From 7700a8fe223d80ca3b942da8aecbe56188a7d7a3 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Thu, 11 Dec 2025 13:50:43 +0000 Subject: [PATCH 01/26] feat: add circuit breaker for upstream provider overload protection Implement per-provider circuit breakers that detect upstream rate limiting (429/503/529 status codes) and temporarily stop sending requests when providers are overloaded. Key features: - Per-provider circuit breakers (Anthropic, OpenAI) - Configurable failure threshold, time window, and cooldown period - Half-open state allows gradual recovery testing - Prometheus metrics for monitoring (state gauge, trips counter, rejects counter) - Thread-safe implementation with proper state machine transitions - Disabled by default for backward compatibility Circuit breaker states: - Closed: normal operation, tracking failures within sliding window - Open: all requests rejected with 503, waiting for cooldown - Half-Open: limited requests allowed to test if upstream recovered Status codes that trigger circuit breaker: - 429 Too Many Requests - 503 Service Unavailable - 529 Anthropic Overloaded Relates to: https://github.com/coder/internal/issues/1153 --- bridge.go | 45 ++++- circuit_breaker.go | 349 ++++++++++++++++++++++++++++++++++++ circuit_breaker_test.go | 382 ++++++++++++++++++++++++++++++++++++++++ interception.go | 52 +++++- metrics.go | 26 +++ 5 files changed, 847 insertions(+), 7 deletions(-) create mode 100644 circuit_breaker.go create mode 100644 circuit_breaker_test.go diff --git a/bridge.go b/bridge.go index 9f2c424..16872a0 100644 --- a/bridge.go +++ b/bridge.go @@ -30,6 +30,10 @@ type RequestBridge struct { mcpProxy mcp.ServerProxier + // circuitBreakers manages circuit breakers for upstream providers. + // When enabled, it protects against cascading failures from upstream rate limits. + circuitBreakers *CircuitBreakerManager + inflightReqs atomic.Int32 inflightWG sync.WaitGroup // For graceful shutdown. @@ -49,12 +53,34 @@ var _ http.Handler = &RequestBridge{} // // mcpProxy will be closed when the [RequestBridge] is closed. func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer) (*RequestBridge, error) { + return NewRequestBridgeWithCircuitBreaker(ctx, providers, recorder, mcpProxy, logger, metrics, tracer, DefaultCircuitBreakerConfig()) +} + +// NewRequestBridgeWithCircuitBreaker creates a new *[RequestBridge] with custom circuit breaker configuration. +// See [NewRequestBridge] for more details. +func NewRequestBridgeWithCircuitBreaker(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer, cbConfig CircuitBreakerConfig) (*RequestBridge, error) { mux := http.NewServeMux() + // Create circuit breaker manager + cbManager := NewCircuitBreakerManager(cbConfig) + + // Set up metrics callback if metrics are provided + if metrics != nil { + cbManager.SetStateChangeCallback(func(provider string, from, to CircuitState) { + metrics.CircuitBreakerState.WithLabelValues(provider).Set(float64(to)) + if to == CircuitOpen { + metrics.CircuitBreakerTrips.WithLabelValues(provider).Inc() + } + }) + } + for _, provider := range providers { + // Pre-create circuit breaker for this provider + cbManager.GetOrCreate(provider.Name()) + // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). for _, path := range provider.BridgedRoutes() { - mux.HandleFunc(path, newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer)) + mux.HandleFunc(path, newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer, cbManager)) } // Any requests which passthrough to this will be reverse-proxied to the upstream. @@ -77,11 +103,12 @@ func NewRequestBridge(ctx context.Context, providers []Provider, recorder Record inflightCtx, cancel := context.WithCancel(context.Background()) return &RequestBridge{ - mux: mux, - logger: logger, - mcpProxy: mcpProxy, - inflightCtx: inflightCtx, - inflightCancel: cancel, + mux: mux, + logger: logger, + mcpProxy: mcpProxy, + circuitBreakers: cbManager, + inflightCtx: inflightCtx, + inflightCancel: cancel, closed: make(chan struct{}, 1), }, nil @@ -153,6 +180,12 @@ func (b *RequestBridge) InflightRequests() int32 { return b.inflightReqs.Load() } +// CircuitBreakers returns the circuit breaker manager for this bridge. +// This can be used to query circuit breaker states or configure callbacks. +func (b *RequestBridge) CircuitBreakers() *CircuitBreakerManager { + return b.circuitBreakers +} + // mergeContexts merges two contexts together, so that if either is cancelled // the returned context is cancelled. The context values will only be used from // the first context. diff --git a/circuit_breaker.go b/circuit_breaker.go new file mode 100644 index 0000000..48800ee --- /dev/null +++ b/circuit_breaker.go @@ -0,0 +1,349 @@ +package aibridge + +import ( + "net/http" + "sync" + "time" +) + +// CircuitState represents the current state of a circuit breaker. +type CircuitState int + +const ( + // CircuitClosed is the normal state - all requests pass through. + CircuitClosed CircuitState = iota + // CircuitOpen is the tripped state - requests are rejected immediately. + CircuitOpen + // CircuitHalfOpen is the testing state - limited requests pass through. + CircuitHalfOpen +) + +func (s CircuitState) String() string { + switch s { + case CircuitClosed: + return "closed" + case CircuitOpen: + return "open" + case CircuitHalfOpen: + return "half-open" + default: + return "unknown" + } +} + +// CircuitBreakerConfig holds configuration for a circuit breaker. +type CircuitBreakerConfig struct { + // Enabled controls whether the circuit breaker is active. + // If false, all requests pass through regardless of failures. + Enabled bool + // FailureThreshold is the number of failures within the window that triggers the circuit to open. + FailureThreshold int64 + // Window is the time window for counting failures. + Window time.Duration + // Cooldown is how long the circuit stays open before transitioning to half-open. + Cooldown time.Duration + // HalfOpenMaxRequests is the maximum number of requests allowed in half-open state + // before deciding whether to close or re-open the circuit. + HalfOpenMaxRequests int64 +} + +// DefaultCircuitBreakerConfig returns sensible defaults for circuit breaker configuration. +func DefaultCircuitBreakerConfig() CircuitBreakerConfig { + return CircuitBreakerConfig{ + Enabled: false, // Disabled by default for backward compatibility + FailureThreshold: 5, + Window: 10 * time.Second, + Cooldown: 30 * time.Second, + HalfOpenMaxRequests: 3, + } +} + +// CircuitBreaker implements the circuit breaker pattern to protect against +// upstream service failures. It tracks failures from upstream providers +// (like rate limit errors) and temporarily blocks requests when the +// failure threshold is exceeded. +type CircuitBreaker struct { + mu sync.RWMutex + + // Current state + state CircuitState + failures int64 // Failure count in current window + windowStart time.Time // Start of current failure counting window + openedAt time.Time // When circuit transitioned to open + + // Half-open state tracking + halfOpenSuccesses int64 + halfOpenFailures int64 + + // Configuration + config CircuitBreakerConfig + + // Provider name for logging/metrics + provider string + + // Optional metrics callback + onStateChange func(provider string, from, to CircuitState) +} + +// NewCircuitBreaker creates a new circuit breaker for the given provider. +func NewCircuitBreaker(provider string, config CircuitBreakerConfig) *CircuitBreaker { + return &CircuitBreaker{ + state: CircuitClosed, + windowStart: time.Now(), + config: config, + provider: provider, + } +} + +// SetStateChangeCallback sets a callback that is invoked when the circuit state changes. +// This is useful for metrics and logging. +func (cb *CircuitBreaker) SetStateChangeCallback(fn func(provider string, from, to CircuitState)) { + cb.mu.Lock() + defer cb.mu.Unlock() + cb.onStateChange = fn +} + +// Allow checks if a request should be allowed through. +// Returns true if the request can proceed, false if it should be rejected. +func (cb *CircuitBreaker) Allow() bool { + if !cb.config.Enabled { + return true + } + + cb.mu.Lock() + defer cb.mu.Unlock() + + now := time.Now() + + switch cb.state { + case CircuitClosed: + return true + + case CircuitOpen: + // Check if cooldown period has elapsed + if now.Sub(cb.openedAt) >= cb.config.Cooldown { + cb.transitionTo(CircuitHalfOpen) + return true + } + return false + + case CircuitHalfOpen: + // Allow limited requests in half-open state + totalHalfOpenRequests := cb.halfOpenSuccesses + cb.halfOpenFailures + return totalHalfOpenRequests < cb.config.HalfOpenMaxRequests + } + + return true +} + +// RecordSuccess records a successful request. +// This is called after a request completes successfully. +func (cb *CircuitBreaker) RecordSuccess() { + if !cb.config.Enabled { + return + } + + cb.mu.Lock() + defer cb.mu.Unlock() + + switch cb.state { + case CircuitHalfOpen: + cb.halfOpenSuccesses++ + // If we've had enough successes in half-open, close the circuit + if cb.halfOpenSuccesses >= cb.config.HalfOpenMaxRequests { + cb.transitionTo(CircuitClosed) + } + case CircuitClosed: + // Reset failure count on success (sliding window behavior) + // This helps prevent false positives from old failures + cb.maybeResetWindow() + } +} + +// RecordFailure records a failed request. +// statusCode is the HTTP status code from the upstream response. +// Returns true if this failure caused the circuit to trip open. +func (cb *CircuitBreaker) RecordFailure(statusCode int) bool { + if !cb.config.Enabled { + return false + } + + // Only count specific error codes as circuit-breaker failures + if !isCircuitBreakerFailure(statusCode) { + return false + } + + cb.mu.Lock() + defer cb.mu.Unlock() + + switch cb.state { + case CircuitClosed: + cb.maybeResetWindow() + cb.failures++ + if cb.failures >= cb.config.FailureThreshold { + cb.transitionTo(CircuitOpen) + return true + } + + case CircuitHalfOpen: + cb.halfOpenFailures++ + // Any failure in half-open state re-opens the circuit + cb.transitionTo(CircuitOpen) + return true + } + + return false +} + +// State returns the current state of the circuit breaker. +func (cb *CircuitBreaker) State() CircuitState { + cb.mu.RLock() + defer cb.mu.RUnlock() + return cb.state +} + +// Provider returns the provider name this circuit breaker is for. +func (cb *CircuitBreaker) Provider() string { + return cb.provider +} + +// Failures returns the current failure count. +func (cb *CircuitBreaker) Failures() int64 { + cb.mu.RLock() + defer cb.mu.RUnlock() + return cb.failures +} + +// transitionTo changes the circuit state. Must be called with lock held. +func (cb *CircuitBreaker) transitionTo(newState CircuitState) { + oldState := cb.state + if oldState == newState { + return + } + + cb.state = newState + now := time.Now() + + switch newState { + case CircuitOpen: + cb.openedAt = now + case CircuitHalfOpen: + cb.halfOpenSuccesses = 0 + cb.halfOpenFailures = 0 + case CircuitClosed: + cb.failures = 0 + cb.windowStart = now + } + + if cb.onStateChange != nil { + // Call callback without holding lock to avoid deadlocks + callback := cb.onStateChange + go callback(cb.provider, oldState, newState) + } +} + +// maybeResetWindow resets the failure count if the window has elapsed. +// Must be called with lock held. +func (cb *CircuitBreaker) maybeResetWindow() { + now := time.Now() + if now.Sub(cb.windowStart) >= cb.config.Window { + cb.failures = 0 + cb.windowStart = now + } +} + +// isCircuitBreakerFailure returns true if the given HTTP status code +// should count as a failure for circuit breaker purposes. +// We specifically track rate limiting and overload errors from upstream. +func isCircuitBreakerFailure(statusCode int) bool { + switch statusCode { + case http.StatusTooManyRequests: // 429 - Rate limited + return true + case http.StatusServiceUnavailable: // 503 - Service unavailable + return true + case 529: // Anthropic-specific "Overloaded" error + return true + default: + return false + } +} + +// CircuitBreakerManager manages circuit breakers for multiple providers. +type CircuitBreakerManager struct { + mu sync.RWMutex + breakers map[string]*CircuitBreaker + config CircuitBreakerConfig + + // Metrics callbacks + onStateChange func(provider string, from, to CircuitState) +} + +// NewCircuitBreakerManager creates a new manager with the given configuration. +func NewCircuitBreakerManager(config CircuitBreakerConfig) *CircuitBreakerManager { + return &CircuitBreakerManager{ + breakers: make(map[string]*CircuitBreaker), + config: config, + } +} + +// SetStateChangeCallback sets the callback for state changes on all circuit breakers. +func (m *CircuitBreakerManager) SetStateChangeCallback(fn func(provider string, from, to CircuitState)) { + m.mu.Lock() + defer m.mu.Unlock() + m.onStateChange = fn + + // Update existing breakers + for _, cb := range m.breakers { + cb.SetStateChangeCallback(fn) + } +} + +// GetOrCreate returns the circuit breaker for the given provider, +// creating one if it doesn't exist. +func (m *CircuitBreakerManager) GetOrCreate(provider string) *CircuitBreaker { + m.mu.RLock() + if cb, ok := m.breakers[provider]; ok { + m.mu.RUnlock() + return cb + } + m.mu.RUnlock() + + m.mu.Lock() + defer m.mu.Unlock() + + // Double-check after acquiring write lock + if cb, ok := m.breakers[provider]; ok { + return cb + } + + cb := NewCircuitBreaker(provider, m.config) + if m.onStateChange != nil { + cb.SetStateChangeCallback(m.onStateChange) + } + m.breakers[provider] = cb + return cb +} + +// Get returns the circuit breaker for the given provider, or nil if not found. +func (m *CircuitBreakerManager) Get(provider string) *CircuitBreaker { + m.mu.RLock() + defer m.mu.RUnlock() + return m.breakers[provider] +} + +// AllStates returns the current state of all circuit breakers. +func (m *CircuitBreakerManager) AllStates() map[string]CircuitState { + m.mu.RLock() + defer m.mu.RUnlock() + + states := make(map[string]CircuitState, len(m.breakers)) + for provider, cb := range m.breakers { + states[provider] = cb.State() + } + return states +} + +// Config returns the configuration used by this manager. +func (m *CircuitBreakerManager) Config() CircuitBreakerConfig { + return m.config +} diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go new file mode 100644 index 0000000..4d52a15 --- /dev/null +++ b/circuit_breaker_test.go @@ -0,0 +1,382 @@ +package aibridge + +import ( + "net/http" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCircuitBreaker_DefaultConfig(t *testing.T) { + t.Parallel() + + cfg := DefaultCircuitBreakerConfig() + assert.False(t, cfg.Enabled, "should be disabled by default") + assert.Equal(t, int64(5), cfg.FailureThreshold) + assert.Equal(t, 10*time.Second, cfg.Window) + assert.Equal(t, 30*time.Second, cfg.Cooldown) + assert.Equal(t, int64(3), cfg.HalfOpenMaxRequests) +} + +func TestCircuitBreaker_DisabledByDefault(t *testing.T) { + t.Parallel() + + cb := NewCircuitBreaker("test", DefaultCircuitBreakerConfig()) + + // Should always allow when disabled + assert.True(t, cb.Allow()) + + // Recording failures should not affect state when disabled + for i := 0; i < 100; i++ { + cb.RecordFailure(http.StatusTooManyRequests) + } + assert.True(t, cb.Allow()) + assert.Equal(t, CircuitClosed, cb.State()) +} + +func TestCircuitBreaker_StateTransitions(t *testing.T) { + t.Parallel() + + cfg := CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: 3, + Window: time.Minute, // Long window so it doesn't reset during test + Cooldown: 50 * time.Millisecond, + HalfOpenMaxRequests: 2, + } + cb := NewCircuitBreaker("test", cfg) + + // Start in closed state + assert.Equal(t, CircuitClosed, cb.State()) + assert.True(t, cb.Allow()) + + // Record failures below threshold + cb.RecordFailure(http.StatusTooManyRequests) + cb.RecordFailure(http.StatusTooManyRequests) + assert.Equal(t, CircuitClosed, cb.State()) + assert.True(t, cb.Allow()) + + // Third failure should trip the circuit + tripped := cb.RecordFailure(http.StatusTooManyRequests) + assert.True(t, tripped) + assert.Equal(t, CircuitOpen, cb.State()) + assert.False(t, cb.Allow()) + + // Wait for cooldown + time.Sleep(60 * time.Millisecond) + + // Should transition to half-open and allow request + assert.True(t, cb.Allow()) + assert.Equal(t, CircuitHalfOpen, cb.State()) + + // Success in half-open should eventually close + cb.RecordSuccess() + cb.RecordSuccess() + assert.Equal(t, CircuitClosed, cb.State()) + assert.True(t, cb.Allow()) +} + +func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { + t.Parallel() + + cfg := CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: 2, + Window: time.Minute, + Cooldown: 50 * time.Millisecond, + HalfOpenMaxRequests: 3, + } + cb := NewCircuitBreaker("test", cfg) + + // Trip the circuit + cb.RecordFailure(http.StatusTooManyRequests) + cb.RecordFailure(http.StatusTooManyRequests) + assert.Equal(t, CircuitOpen, cb.State()) + + // Wait for cooldown + time.Sleep(60 * time.Millisecond) + + // Transition to half-open + assert.True(t, cb.Allow()) + assert.Equal(t, CircuitHalfOpen, cb.State()) + + // Failure in half-open should re-open circuit + tripped := cb.RecordFailure(http.StatusServiceUnavailable) + assert.True(t, tripped) + assert.Equal(t, CircuitOpen, cb.State()) + assert.False(t, cb.Allow()) +} + +func TestCircuitBreaker_OnlyCountsRelevantStatusCodes(t *testing.T) { + t.Parallel() + + cfg := CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: 2, + Window: time.Minute, + Cooldown: time.Minute, + HalfOpenMaxRequests: 2, + } + cb := NewCircuitBreaker("test", cfg) + + // Non-circuit-breaker status codes should not count + cb.RecordFailure(http.StatusBadRequest) // 400 + cb.RecordFailure(http.StatusUnauthorized) // 401 + cb.RecordFailure(http.StatusInternalServerError) // 500 + cb.RecordFailure(http.StatusBadGateway) // 502 + assert.Equal(t, CircuitClosed, cb.State()) + assert.Equal(t, int64(0), cb.Failures()) + + // These should count + cb.RecordFailure(http.StatusTooManyRequests) // 429 + assert.Equal(t, int64(1), cb.Failures()) + + cb.RecordFailure(http.StatusServiceUnavailable) // 503 + assert.Equal(t, CircuitOpen, cb.State()) +} + +func TestCircuitBreaker_Anthropic529(t *testing.T) { + t.Parallel() + + cfg := CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: 1, + Window: time.Minute, + Cooldown: time.Minute, + HalfOpenMaxRequests: 1, + } + cb := NewCircuitBreaker("anthropic", cfg) + + // Anthropic-specific 529 "Overloaded" should trip the circuit + tripped := cb.RecordFailure(529) + assert.True(t, tripped) + assert.Equal(t, CircuitOpen, cb.State()) +} + +func TestCircuitBreaker_WindowReset(t *testing.T) { + t.Parallel() + + cfg := CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: 3, + Window: 50 * time.Millisecond, // Short window + Cooldown: time.Minute, + HalfOpenMaxRequests: 2, + } + cb := NewCircuitBreaker("test", cfg) + + // Record failures + cb.RecordFailure(http.StatusTooManyRequests) + cb.RecordFailure(http.StatusTooManyRequests) + assert.Equal(t, int64(2), cb.Failures()) + + // Wait for window to expire + time.Sleep(60 * time.Millisecond) + + // Next failure should reset counter (due to window expiry) + cb.RecordFailure(http.StatusTooManyRequests) + assert.Equal(t, int64(1), cb.Failures()) + assert.Equal(t, CircuitClosed, cb.State()) +} + +func TestCircuitBreaker_ConcurrentAccess(t *testing.T) { + t.Parallel() + + cfg := CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: 100, + Window: time.Minute, + Cooldown: time.Minute, + HalfOpenMaxRequests: 10, + } + cb := NewCircuitBreaker("test", cfg) + + var wg sync.WaitGroup + numGoroutines := 50 + opsPerGoroutine := 100 + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < opsPerGoroutine; j++ { + cb.Allow() + cb.RecordSuccess() + cb.RecordFailure(http.StatusTooManyRequests) + cb.State() + cb.Failures() + } + }() + } + + wg.Wait() + // Should not panic or deadlock +} + +func TestCircuitBreaker_StateChangeCallback(t *testing.T) { + t.Parallel() + + cfg := CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: 2, + Window: time.Minute, + Cooldown: 50 * time.Millisecond, + HalfOpenMaxRequests: 1, + } + cb := NewCircuitBreaker("test", cfg) + + var mu sync.Mutex + var transitions []struct { + from, to CircuitState + } + + cb.SetStateChangeCallback(func(provider string, from, to CircuitState) { + mu.Lock() + defer mu.Unlock() + transitions = append(transitions, struct{ from, to CircuitState }{from, to}) + }) + + // Trip the circuit + cb.RecordFailure(http.StatusTooManyRequests) + cb.RecordFailure(http.StatusTooManyRequests) + + // Wait for callback + time.Sleep(10 * time.Millisecond) + + // Wait for cooldown and trigger half-open + time.Sleep(60 * time.Millisecond) + cb.Allow() + + // Wait for callback + time.Sleep(10 * time.Millisecond) + + // Success to close + cb.RecordSuccess() + + // Wait for callback + time.Sleep(10 * time.Millisecond) + + mu.Lock() + defer mu.Unlock() + require.Len(t, transitions, 3) + assert.Equal(t, CircuitClosed, transitions[0].from) + assert.Equal(t, CircuitOpen, transitions[0].to) + assert.Equal(t, CircuitOpen, transitions[1].from) + assert.Equal(t, CircuitHalfOpen, transitions[1].to) + assert.Equal(t, CircuitHalfOpen, transitions[2].from) + assert.Equal(t, CircuitClosed, transitions[2].to) +} + +func TestCircuitBreakerManager_GetOrCreate(t *testing.T) { + t.Parallel() + + cfg := CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: 5, + Window: time.Minute, + Cooldown: time.Minute, + } + manager := NewCircuitBreakerManager(cfg) + + // First call should create + cb1 := manager.GetOrCreate("anthropic") + require.NotNil(t, cb1) + assert.Equal(t, "anthropic", cb1.Provider()) + + // Second call should return same instance + cb2 := manager.GetOrCreate("anthropic") + assert.Same(t, cb1, cb2) + + // Different provider gets different instance + cb3 := manager.GetOrCreate("openai") + require.NotNil(t, cb3) + assert.NotSame(t, cb1, cb3) + assert.Equal(t, "openai", cb3.Provider()) +} + +func TestCircuitBreakerManager_AllStates(t *testing.T) { + t.Parallel() + + cfg := CircuitBreakerConfig{ + Enabled: true, + FailureThreshold: 1, + Window: time.Minute, + Cooldown: time.Minute, + } + manager := NewCircuitBreakerManager(cfg) + + manager.GetOrCreate("anthropic") + manager.GetOrCreate("openai") + + // Trip one circuit + manager.Get("anthropic").RecordFailure(http.StatusTooManyRequests) + + states := manager.AllStates() + assert.Equal(t, CircuitOpen, states["anthropic"]) + assert.Equal(t, CircuitClosed, states["openai"]) +} + +func TestCircuitBreakerManager_ConcurrentGetOrCreate(t *testing.T) { + t.Parallel() + + cfg := DefaultCircuitBreakerConfig() + cfg.Enabled = true + manager := NewCircuitBreakerManager(cfg) + + var wg sync.WaitGroup + var results [100]*CircuitBreaker + + for i := 0; i < 100; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + results[idx] = manager.GetOrCreate("test-provider") + }(i) + } + + wg.Wait() + + // All should be the same instance + first := results[0] + for i := 1; i < 100; i++ { + assert.Same(t, first, results[i]) + } +} + +func TestIsCircuitBreakerFailure(t *testing.T) { + t.Parallel() + + tests := []struct { + statusCode int + isFailure bool + }{ + {http.StatusOK, false}, + {http.StatusBadRequest, false}, + {http.StatusUnauthorized, false}, + {http.StatusForbidden, false}, + {http.StatusNotFound, false}, + {http.StatusTooManyRequests, true}, // 429 + {http.StatusInternalServerError, false}, + {http.StatusBadGateway, false}, + {http.StatusServiceUnavailable, true}, // 503 + {529, true}, // Anthropic Overloaded + } + + for _, tt := range tests { + t.Run(http.StatusText(tt.statusCode), func(t *testing.T) { + assert.Equal(t, tt.isFailure, isCircuitBreakerFailure(tt.statusCode)) + }) + } +} + +func TestCircuitState_String(t *testing.T) { + t.Parallel() + + assert.Equal(t, "closed", CircuitClosed.String()) + assert.Equal(t, "open", CircuitOpen.String()) + assert.Equal(t, "half-open", CircuitHalfOpen.String()) + assert.Equal(t, "unknown", CircuitState(99).String()) +} diff --git a/interception.go b/interception.go index 46ec7bd..d1d8444 100644 --- a/interception.go +++ b/interception.go @@ -40,11 +40,26 @@ const recordingTimeout = time.Second * 5 // newInterceptionProcessor returns an [http.HandlerFunc] which is capable of creating a new interceptor and processing a given request // using [Provider] p, recording all usage events using [Recorder] recorder. -func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer) http.HandlerFunc { +func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer, cbManager *CircuitBreakerManager) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, span := tracer.Start(r.Context(), "Intercept") defer span.End() + // Check circuit breaker before proceeding + cb := cbManager.GetOrCreate(p.Name()) + if !cb.Allow() { + span.SetStatus(codes.Error, "circuit breaker open") + logger.Debug(ctx, "request rejected by circuit breaker", + slog.F("provider", p.Name()), + slog.F("circuit_state", cb.State().String()), + ) + if metrics != nil { + metrics.CircuitBreakerRejects.WithLabelValues(p.Name()).Inc() + } + http.Error(w, fmt.Sprintf("%s is currently unavailable due to upstream rate limiting. Please try again later.", p.Name()), http.StatusServiceUnavailable) + return + } + interceptor, err := p.CreateInterceptor(w, r.WithContext(ctx), tracer) if err != nil { span.SetStatus(codes.Error, fmt.Sprintf("failed to create interceptor: %v", err)) @@ -116,11 +131,24 @@ func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.Server } span.SetStatus(codes.Error, fmt.Sprintf("interception failed: %v", err)) log.Warn(ctx, "interception failed", slog.Error(err)) + + // Record failure for circuit breaker - extract status code if available + if statusCode := extractStatusCodeFromError(err); statusCode > 0 { + if cb.RecordFailure(statusCode) { + log.Warn(ctx, "circuit breaker tripped", + slog.F("provider", p.Name()), + slog.F("status_code", statusCode), + ) + } + } } else { if metrics != nil { metrics.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), InterceptionCountStatusCompleted, route, r.Method, actor.id).Add(1) } log.Debug(ctx, "interception ended") + + // Record success for circuit breaker + cb.RecordSuccess() } asyncRecorder.RecordInterceptionEnded(ctx, &InterceptionRecordEnded{ID: interceptor.ID().String()}) @@ -128,3 +156,25 @@ func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.Server asyncRecorder.Wait() } } + +// extractStatusCodeFromError attempts to extract an HTTP status code from an error. +// This is used for circuit breaker failure tracking. +func extractStatusCodeFromError(err error) int { + if err == nil { + return 0 + } + + // Check for Anthropic error response + var antErr *AnthropicErrorResponse + if errors.As(err, &antErr) && antErr != nil { + return antErr.StatusCode + } + + // Check for OpenAI error response + var oaiErr *OpenAIErrorResponse + if errors.As(err, &oaiErr) && oaiErr != nil { + return oaiErr.StatusCode + } + + return 0 +} diff --git a/metrics.go b/metrics.go index 32d5a78..565029a 100644 --- a/metrics.go +++ b/metrics.go @@ -28,6 +28,11 @@ type Metrics struct { // Tool-related metrics. InjectedToolUseCount *prometheus.CounterVec NonInjectedToolUseCount *prometheus.CounterVec + + // Circuit breaker metrics. + CircuitBreakerState *prometheus.GaugeVec // Current state (0=closed, 1=open, 2=half-open) + CircuitBreakerTrips *prometheus.CounterVec // Total times circuit opened + CircuitBreakerRejects *prometheus.CounterVec // Requests rejected due to open circuit } // NewMetrics creates AND registers metrics. It will panic if a collector has already been registered. @@ -102,5 +107,26 @@ func NewMetrics(reg prometheus.Registerer) *Metrics { Name: "total", Help: "The number of times an AI model selected a tool to be invoked by the client.", }, append(baseLabels, "name")), + + // Circuit breaker metrics. + + // Pessimistic cardinality: 2 providers = up to 2. + CircuitBreakerState: promauto.With(reg).NewGaugeVec(prometheus.GaugeOpts{ + Subsystem: "circuit_breaker", + Name: "state", + Help: "Current state of the circuit breaker (0=closed, 1=open, 2=half-open).", + }, []string{"provider"}), + // Pessimistic cardinality: 2 providers = up to 2. + CircuitBreakerTrips: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ + Subsystem: "circuit_breaker", + Name: "trips_total", + Help: "Total number of times the circuit breaker has tripped open.", + }, []string{"provider"}), + // Pessimistic cardinality: 2 providers = up to 2. + CircuitBreakerRejects: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ + Subsystem: "circuit_breaker", + Name: "rejects_total", + Help: "Total number of requests rejected due to open circuit breaker.", + }, []string{"provider"}), } } From aad288c4cacf7c854dd757fc315153c5485347c6 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Fri, 12 Dec 2025 13:25:59 +0000 Subject: [PATCH 02/26] chore: apply make fmt --- circuit_breaker.go | 10 +++++----- circuit_breaker_test.go | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/circuit_breaker.go b/circuit_breaker.go index 48800ee..d863cad 100644 --- a/circuit_breaker.go +++ b/circuit_breaker.go @@ -66,11 +66,11 @@ type CircuitBreaker struct { mu sync.RWMutex // Current state - state CircuitState - failures int64 // Failure count in current window + state CircuitState + failures int64 // Failure count in current window windowStart time.Time // Start of current failure counting window - openedAt time.Time // When circuit transitioned to open - + openedAt time.Time // When circuit transitioned to open + // Half-open state tracking halfOpenSuccesses int64 halfOpenFailures int64 @@ -291,7 +291,7 @@ func (m *CircuitBreakerManager) SetStateChangeCallback(fn func(provider string, m.mu.Lock() defer m.mu.Unlock() m.onStateChange = fn - + // Update existing breakers for _, cb := range m.breakers { cb.SetStateChangeCallback(fn) diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go index 4d52a15..3f84ba6 100644 --- a/circuit_breaker_test.go +++ b/circuit_breaker_test.go @@ -131,9 +131,9 @@ func TestCircuitBreaker_OnlyCountsRelevantStatusCodes(t *testing.T) { assert.Equal(t, int64(0), cb.Failures()) // These should count - cb.RecordFailure(http.StatusTooManyRequests) // 429 + cb.RecordFailure(http.StatusTooManyRequests) // 429 assert.Equal(t, int64(1), cb.Failures()) - + cb.RecordFailure(http.StatusServiceUnavailable) // 503 assert.Equal(t, CircuitOpen, cb.State()) } From 47253f193976b840bb6d2c85fdcfa4c1951a58d1 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Tue, 16 Dec 2025 10:27:35 +0000 Subject: [PATCH 03/26] refactor: use sony/gobreaker for circuit breakers with per-endpoint isolation - Replace custom circuit breaker implementation with sony/gobreaker - Change from per-provider to per-endpoint circuit breakers (e.g., OpenAI chat completions failing won't block responses API) - Simplify API: CircuitBreakers manages all breakers internally - Update metrics to include endpoint label - Simplify tests to focus on key behaviors Based on PR review feedback suggesting use of established library and per-endpoint granularity for better fault isolation. --- bridge.go | 30 ++-- circuit_breaker.go | 348 ++++++++++------------------------------ circuit_breaker_test.go | 266 ++++++++---------------------- go.mod | 2 + go.sum | 2 + interception.go | 20 +-- metrics.go | 12 +- 7 files changed, 186 insertions(+), 494 deletions(-) diff --git a/bridge.go b/bridge.go index 16872a0..3147486 100644 --- a/bridge.go +++ b/bridge.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "strings" "sync" "sync/atomic" @@ -32,7 +33,7 @@ type RequestBridge struct { // circuitBreakers manages circuit breakers for upstream providers. // When enabled, it protects against cascading failures from upstream rate limits. - circuitBreakers *CircuitBreakerManager + circuitBreakers *CircuitBreakers inflightReqs atomic.Int32 inflightWG sync.WaitGroup // For graceful shutdown. @@ -61,26 +62,24 @@ func NewRequestBridge(ctx context.Context, providers []Provider, recorder Record func NewRequestBridgeWithCircuitBreaker(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer, cbConfig CircuitBreakerConfig) (*RequestBridge, error) { mux := http.NewServeMux() - // Create circuit breaker manager - cbManager := NewCircuitBreakerManager(cbConfig) - - // Set up metrics callback if metrics are provided + // Create circuit breakers with metrics callback + var onChange func(name string, from, to CircuitState) if metrics != nil { - cbManager.SetStateChangeCallback(func(provider string, from, to CircuitState) { - metrics.CircuitBreakerState.WithLabelValues(provider).Set(float64(to)) + onChange = func(name string, from, to CircuitState) { + provider, endpoint, _ := strings.Cut(name, ":") + metrics.CircuitBreakerState.WithLabelValues(provider, endpoint).Set(float64(to)) if to == CircuitOpen { - metrics.CircuitBreakerTrips.WithLabelValues(provider).Inc() + metrics.CircuitBreakerTrips.WithLabelValues(provider, endpoint).Inc() } - }) + } } + cbs := NewCircuitBreakers(cbConfig, onChange) for _, provider := range providers { - // Pre-create circuit breaker for this provider - cbManager.GetOrCreate(provider.Name()) // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). for _, path := range provider.BridgedRoutes() { - mux.HandleFunc(path, newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer, cbManager)) + mux.HandleFunc(path, newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer, cbs)) } // Any requests which passthrough to this will be reverse-proxied to the upstream. @@ -106,7 +105,7 @@ func NewRequestBridgeWithCircuitBreaker(ctx context.Context, providers []Provide mux: mux, logger: logger, mcpProxy: mcpProxy, - circuitBreakers: cbManager, + circuitBreakers: cbs, inflightCtx: inflightCtx, inflightCancel: cancel, @@ -180,9 +179,8 @@ func (b *RequestBridge) InflightRequests() int32 { return b.inflightReqs.Load() } -// CircuitBreakers returns the circuit breaker manager for this bridge. -// This can be used to query circuit breaker states or configure callbacks. -func (b *RequestBridge) CircuitBreakers() *CircuitBreakerManager { +// CircuitBreakers returns the circuit breakers for this bridge. +func (b *RequestBridge) CircuitBreakers() *CircuitBreakers { return b.circuitBreakers } diff --git a/circuit_breaker.go b/circuit_breaker.go index d863cad..d27f0cd 100644 --- a/circuit_breaker.go +++ b/circuit_breaker.go @@ -1,9 +1,12 @@ package aibridge import ( + "fmt" "net/http" "sync" "time" + + "github.com/sony/gobreaker/v2" ) // CircuitState represents the current state of a circuit breaker. @@ -31,19 +34,31 @@ func (s CircuitState) String() string { } } -// CircuitBreakerConfig holds configuration for a circuit breaker. +// toCircuitState converts gobreaker.State to our CircuitState. +func toCircuitState(s gobreaker.State) CircuitState { + switch s { + case gobreaker.StateClosed: + return CircuitClosed + case gobreaker.StateOpen: + return CircuitOpen + case gobreaker.StateHalfOpen: + return CircuitHalfOpen + default: + return CircuitClosed + } +} + +// CircuitBreakerConfig holds configuration for circuit breakers. type CircuitBreakerConfig struct { - // Enabled controls whether the circuit breaker is active. - // If false, all requests pass through regardless of failures. + // Enabled controls whether circuit breakers are active. Enabled bool - // FailureThreshold is the number of failures within the window that triggers the circuit to open. + // FailureThreshold is the number of consecutive failures that triggers the circuit to open. FailureThreshold int64 // Window is the time window for counting failures. Window time.Duration // Cooldown is how long the circuit stays open before transitioning to half-open. Cooldown time.Duration - // HalfOpenMaxRequests is the maximum number of requests allowed in half-open state - // before deciding whether to close or re-open the circuit. + // HalfOpenMaxRequests is the maximum number of requests allowed in half-open state. HalfOpenMaxRequests int64 } @@ -58,292 +73,97 @@ func DefaultCircuitBreakerConfig() CircuitBreakerConfig { } } -// CircuitBreaker implements the circuit breaker pattern to protect against -// upstream service failures. It tracks failures from upstream providers -// (like rate limit errors) and temporarily blocks requests when the -// failure threshold is exceeded. -type CircuitBreaker struct { - mu sync.RWMutex - - // Current state - state CircuitState - failures int64 // Failure count in current window - windowStart time.Time // Start of current failure counting window - openedAt time.Time // When circuit transitioned to open - - // Half-open state tracking - halfOpenSuccesses int64 - halfOpenFailures int64 - - // Configuration - config CircuitBreakerConfig - - // Provider name for logging/metrics - provider string - - // Optional metrics callback - onStateChange func(provider string, from, to CircuitState) -} - -// NewCircuitBreaker creates a new circuit breaker for the given provider. -func NewCircuitBreaker(provider string, config CircuitBreakerConfig) *CircuitBreaker { - return &CircuitBreaker{ - state: CircuitClosed, - windowStart: time.Now(), - config: config, - provider: provider, - } -} - -// SetStateChangeCallback sets a callback that is invoked when the circuit state changes. -// This is useful for metrics and logging. -func (cb *CircuitBreaker) SetStateChangeCallback(fn func(provider string, from, to CircuitState)) { - cb.mu.Lock() - defer cb.mu.Unlock() - cb.onStateChange = fn -} - -// Allow checks if a request should be allowed through. -// Returns true if the request can proceed, false if it should be rejected. -func (cb *CircuitBreaker) Allow() bool { - if !cb.config.Enabled { - return true - } - - cb.mu.Lock() - defer cb.mu.Unlock() - - now := time.Now() - - switch cb.state { - case CircuitClosed: - return true - - case CircuitOpen: - // Check if cooldown period has elapsed - if now.Sub(cb.openedAt) >= cb.config.Cooldown { - cb.transitionTo(CircuitHalfOpen) - return true - } - return false - - case CircuitHalfOpen: - // Allow limited requests in half-open state - totalHalfOpenRequests := cb.halfOpenSuccesses + cb.halfOpenFailures - return totalHalfOpenRequests < cb.config.HalfOpenMaxRequests - } - - return true -} - -// RecordSuccess records a successful request. -// This is called after a request completes successfully. -func (cb *CircuitBreaker) RecordSuccess() { - if !cb.config.Enabled { - return - } - - cb.mu.Lock() - defer cb.mu.Unlock() - - switch cb.state { - case CircuitHalfOpen: - cb.halfOpenSuccesses++ - // If we've had enough successes in half-open, close the circuit - if cb.halfOpenSuccesses >= cb.config.HalfOpenMaxRequests { - cb.transitionTo(CircuitClosed) - } - case CircuitClosed: - // Reset failure count on success (sliding window behavior) - // This helps prevent false positives from old failures - cb.maybeResetWindow() - } -} - -// RecordFailure records a failed request. -// statusCode is the HTTP status code from the upstream response. -// Returns true if this failure caused the circuit to trip open. -func (cb *CircuitBreaker) RecordFailure(statusCode int) bool { - if !cb.config.Enabled { - return false - } - - // Only count specific error codes as circuit-breaker failures - if !isCircuitBreakerFailure(statusCode) { - return false - } - - cb.mu.Lock() - defer cb.mu.Unlock() - - switch cb.state { - case CircuitClosed: - cb.maybeResetWindow() - cb.failures++ - if cb.failures >= cb.config.FailureThreshold { - cb.transitionTo(CircuitOpen) - return true - } - - case CircuitHalfOpen: - cb.halfOpenFailures++ - // Any failure in half-open state re-opens the circuit - cb.transitionTo(CircuitOpen) - return true - } - - return false -} - -// State returns the current state of the circuit breaker. -func (cb *CircuitBreaker) State() CircuitState { - cb.mu.RLock() - defer cb.mu.RUnlock() - return cb.state -} - -// Provider returns the provider name this circuit breaker is for. -func (cb *CircuitBreaker) Provider() string { - return cb.provider -} - -// Failures returns the current failure count. -func (cb *CircuitBreaker) Failures() int64 { - cb.mu.RLock() - defer cb.mu.RUnlock() - return cb.failures -} - -// transitionTo changes the circuit state. Must be called with lock held. -func (cb *CircuitBreaker) transitionTo(newState CircuitState) { - oldState := cb.state - if oldState == newState { - return - } - - cb.state = newState - now := time.Now() - - switch newState { - case CircuitOpen: - cb.openedAt = now - case CircuitHalfOpen: - cb.halfOpenSuccesses = 0 - cb.halfOpenFailures = 0 - case CircuitClosed: - cb.failures = 0 - cb.windowStart = now - } - - if cb.onStateChange != nil { - // Call callback without holding lock to avoid deadlocks - callback := cb.onStateChange - go callback(cb.provider, oldState, newState) - } -} - -// maybeResetWindow resets the failure count if the window has elapsed. -// Must be called with lock held. -func (cb *CircuitBreaker) maybeResetWindow() { - now := time.Now() - if now.Sub(cb.windowStart) >= cb.config.Window { - cb.failures = 0 - cb.windowStart = now - } -} - // isCircuitBreakerFailure returns true if the given HTTP status code // should count as a failure for circuit breaker purposes. -// We specifically track rate limiting and overload errors from upstream. func isCircuitBreakerFailure(statusCode int) bool { switch statusCode { - case http.StatusTooManyRequests: // 429 - Rate limited - return true - case http.StatusServiceUnavailable: // 503 - Service unavailable - return true - case 529: // Anthropic-specific "Overloaded" error + case http.StatusTooManyRequests, // 429 + http.StatusServiceUnavailable, // 503 + 529: // Anthropic "Overloaded" return true default: return false } } -// CircuitBreakerManager manages circuit breakers for multiple providers. -type CircuitBreakerManager struct { - mu sync.RWMutex - breakers map[string]*CircuitBreaker +// CircuitBreakers manages per-endpoint circuit breakers using sony/gobreaker. +// Circuit breakers are keyed by "provider:endpoint" for per-endpoint isolation. +type CircuitBreakers struct { + breakers sync.Map // map[string]*gobreaker.CircuitBreaker[any] config CircuitBreakerConfig - - // Metrics callbacks - onStateChange func(provider string, from, to CircuitState) + onChange func(name string, from, to CircuitState) } -// NewCircuitBreakerManager creates a new manager with the given configuration. -func NewCircuitBreakerManager(config CircuitBreakerConfig) *CircuitBreakerManager { - return &CircuitBreakerManager{ - breakers: make(map[string]*CircuitBreaker), +// NewCircuitBreakers creates a new circuit breaker manager. +func NewCircuitBreakers(config CircuitBreakerConfig, onChange func(name string, from, to CircuitState)) *CircuitBreakers { + return &CircuitBreakers{ config: config, + onChange: onChange, } } -// SetStateChangeCallback sets the callback for state changes on all circuit breakers. -func (m *CircuitBreakerManager) SetStateChangeCallback(fn func(provider string, from, to CircuitState)) { - m.mu.Lock() - defer m.mu.Unlock() - m.onStateChange = fn - - // Update existing breakers - for _, cb := range m.breakers { - cb.SetStateChangeCallback(fn) +// Allow checks if a request to provider/endpoint should be allowed. +func (c *CircuitBreakers) Allow(provider, endpoint string) bool { + if !c.config.Enabled { + return true } + cb := c.getOrCreate(provider, endpoint) + return cb.State() != gobreaker.StateOpen } -// GetOrCreate returns the circuit breaker for the given provider, -// creating one if it doesn't exist. -func (m *CircuitBreakerManager) GetOrCreate(provider string) *CircuitBreaker { - m.mu.RLock() - if cb, ok := m.breakers[provider]; ok { - m.mu.RUnlock() - return cb - } - m.mu.RUnlock() - - m.mu.Lock() - defer m.mu.Unlock() - - // Double-check after acquiring write lock - if cb, ok := m.breakers[provider]; ok { - return cb +// RecordSuccess records a successful request. +func (c *CircuitBreakers) RecordSuccess(provider, endpoint string) { + if !c.config.Enabled { + return } + cb := c.getOrCreate(provider, endpoint) + _, _ = cb.Execute(func() (any, error) { return nil, nil }) +} - cb := NewCircuitBreaker(provider, m.config) - if m.onStateChange != nil { - cb.SetStateChangeCallback(m.onStateChange) +// RecordFailure records a failed request. Returns true if this caused the circuit to open. +func (c *CircuitBreakers) RecordFailure(provider, endpoint string, statusCode int) bool { + if !c.config.Enabled || !isCircuitBreakerFailure(statusCode) { + return false } - m.breakers[provider] = cb - return cb + cb := c.getOrCreate(provider, endpoint) + before := cb.State() + _, _ = cb.Execute(func() (any, error) { + return nil, fmt.Errorf("upstream error: %d", statusCode) + }) + return before != gobreaker.StateOpen && cb.State() == gobreaker.StateOpen } -// Get returns the circuit breaker for the given provider, or nil if not found. -func (m *CircuitBreakerManager) Get(provider string) *CircuitBreaker { - m.mu.RLock() - defer m.mu.RUnlock() - return m.breakers[provider] +// State returns the current state for a provider/endpoint. +func (c *CircuitBreakers) State(provider, endpoint string) CircuitState { + if !c.config.Enabled { + return CircuitClosed + } + cb := c.getOrCreate(provider, endpoint) + return toCircuitState(cb.State()) } -// AllStates returns the current state of all circuit breakers. -func (m *CircuitBreakerManager) AllStates() map[string]CircuitState { - m.mu.RLock() - defer m.mu.RUnlock() +func (c *CircuitBreakers) getOrCreate(provider, endpoint string) *gobreaker.CircuitBreaker[any] { + key := provider + ":" + endpoint + if v, ok := c.breakers.Load(key); ok { + return v.(*gobreaker.CircuitBreaker[any]) + } - states := make(map[string]CircuitState, len(m.breakers)) - for provider, cb := range m.breakers { - states[provider] = cb.State() + settings := gobreaker.Settings{ + Name: key, + MaxRequests: uint32(c.config.HalfOpenMaxRequests), + Interval: c.config.Window, + Timeout: c.config.Cooldown, + ReadyToTrip: func(counts gobreaker.Counts) bool { + return counts.ConsecutiveFailures >= uint32(c.config.FailureThreshold) + }, + OnStateChange: func(name string, from, to gobreaker.State) { + if c.onChange != nil { + c.onChange(name, toCircuitState(from), toCircuitState(to)) + } + }, } - return states -} -// Config returns the configuration used by this manager. -func (m *CircuitBreakerManager) Config() CircuitBreakerConfig { - return m.config + cb := gobreaker.NewCircuitBreaker[any](settings) + actual, _ := c.breakers.LoadOrStore(key, cb) + return actual.(*gobreaker.CircuitBreaker[any]) } diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go index 3f84ba6..0d8deb9 100644 --- a/circuit_breaker_test.go +++ b/circuit_breaker_test.go @@ -21,96 +21,86 @@ func TestCircuitBreaker_DefaultConfig(t *testing.T) { assert.Equal(t, int64(3), cfg.HalfOpenMaxRequests) } -func TestCircuitBreaker_DisabledByDefault(t *testing.T) { +func TestCircuitBreakers_DisabledByDefault(t *testing.T) { t.Parallel() - cb := NewCircuitBreaker("test", DefaultCircuitBreakerConfig()) + cbs := NewCircuitBreakers(DefaultCircuitBreakerConfig(), nil) // Should always allow when disabled - assert.True(t, cb.Allow()) + assert.True(t, cbs.Allow("anthropic", "/v1/messages")) // Recording failures should not affect state when disabled for i := 0; i < 100; i++ { - cb.RecordFailure(http.StatusTooManyRequests) + cbs.RecordFailure("anthropic", "/v1/messages", http.StatusTooManyRequests) } - assert.True(t, cb.Allow()) - assert.Equal(t, CircuitClosed, cb.State()) + assert.True(t, cbs.Allow("anthropic", "/v1/messages")) + assert.Equal(t, CircuitClosed, cbs.State("anthropic", "/v1/messages")) } -func TestCircuitBreaker_StateTransitions(t *testing.T) { +func TestCircuitBreakers_StateTransitions(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ Enabled: true, FailureThreshold: 3, - Window: time.Minute, // Long window so it doesn't reset during test + Window: time.Minute, Cooldown: 50 * time.Millisecond, HalfOpenMaxRequests: 2, } - cb := NewCircuitBreaker("test", cfg) + cbs := NewCircuitBreakers(cfg, nil) // Start in closed state - assert.Equal(t, CircuitClosed, cb.State()) - assert.True(t, cb.Allow()) + assert.Equal(t, CircuitClosed, cbs.State("test", "/api")) + assert.True(t, cbs.Allow("test", "/api")) // Record failures below threshold - cb.RecordFailure(http.StatusTooManyRequests) - cb.RecordFailure(http.StatusTooManyRequests) - assert.Equal(t, CircuitClosed, cb.State()) - assert.True(t, cb.Allow()) + cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) + cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) + assert.Equal(t, CircuitClosed, cbs.State("test", "/api")) // Third failure should trip the circuit - tripped := cb.RecordFailure(http.StatusTooManyRequests) + tripped := cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) assert.True(t, tripped) - assert.Equal(t, CircuitOpen, cb.State()) - assert.False(t, cb.Allow()) + assert.Equal(t, CircuitOpen, cbs.State("test", "/api")) + assert.False(t, cbs.Allow("test", "/api")) // Wait for cooldown time.Sleep(60 * time.Millisecond) // Should transition to half-open and allow request - assert.True(t, cb.Allow()) - assert.Equal(t, CircuitHalfOpen, cb.State()) + assert.True(t, cbs.Allow("test", "/api")) + assert.Equal(t, CircuitHalfOpen, cbs.State("test", "/api")) // Success in half-open should eventually close - cb.RecordSuccess() - cb.RecordSuccess() - assert.Equal(t, CircuitClosed, cb.State()) - assert.True(t, cb.Allow()) + cbs.RecordSuccess("test", "/api") + cbs.RecordSuccess("test", "/api") + assert.Equal(t, CircuitClosed, cbs.State("test", "/api")) } -func TestCircuitBreaker_HalfOpenFailure(t *testing.T) { +func TestCircuitBreakers_PerEndpointIsolation(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ Enabled: true, - FailureThreshold: 2, + FailureThreshold: 1, Window: time.Minute, - Cooldown: 50 * time.Millisecond, - HalfOpenMaxRequests: 3, + Cooldown: time.Minute, + HalfOpenMaxRequests: 1, } - cb := NewCircuitBreaker("test", cfg) - - // Trip the circuit - cb.RecordFailure(http.StatusTooManyRequests) - cb.RecordFailure(http.StatusTooManyRequests) - assert.Equal(t, CircuitOpen, cb.State()) + cbs := NewCircuitBreakers(cfg, nil) - // Wait for cooldown - time.Sleep(60 * time.Millisecond) + // Trip circuit for one endpoint + cbs.RecordFailure("openai", "/v1/chat/completions", http.StatusTooManyRequests) + assert.Equal(t, CircuitOpen, cbs.State("openai", "/v1/chat/completions")) - // Transition to half-open - assert.True(t, cb.Allow()) - assert.Equal(t, CircuitHalfOpen, cb.State()) - - // Failure in half-open should re-open circuit - tripped := cb.RecordFailure(http.StatusServiceUnavailable) - assert.True(t, tripped) - assert.Equal(t, CircuitOpen, cb.State()) - assert.False(t, cb.Allow()) + // Other endpoints should still be closed + assert.Equal(t, CircuitClosed, cbs.State("openai", "/v1/responses")) + assert.Equal(t, CircuitClosed, cbs.State("anthropic", "/v1/messages")) + assert.True(t, cbs.Allow("openai", "/v1/responses")) + assert.True(t, cbs.Allow("anthropic", "/v1/messages")) } -func TestCircuitBreaker_OnlyCountsRelevantStatusCodes(t *testing.T) { +func TestCircuitBreakers_OnlyCountsRelevantStatusCodes(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ @@ -120,25 +110,22 @@ func TestCircuitBreaker_OnlyCountsRelevantStatusCodes(t *testing.T) { Cooldown: time.Minute, HalfOpenMaxRequests: 2, } - cb := NewCircuitBreaker("test", cfg) + cbs := NewCircuitBreakers(cfg, nil) // Non-circuit-breaker status codes should not count - cb.RecordFailure(http.StatusBadRequest) // 400 - cb.RecordFailure(http.StatusUnauthorized) // 401 - cb.RecordFailure(http.StatusInternalServerError) // 500 - cb.RecordFailure(http.StatusBadGateway) // 502 - assert.Equal(t, CircuitClosed, cb.State()) - assert.Equal(t, int64(0), cb.Failures()) + cbs.RecordFailure("test", "/api", http.StatusBadRequest) // 400 + cbs.RecordFailure("test", "/api", http.StatusUnauthorized) // 401 + cbs.RecordFailure("test", "/api", http.StatusInternalServerError) // 500 + cbs.RecordFailure("test", "/api", http.StatusBadGateway) // 502 + assert.Equal(t, CircuitClosed, cbs.State("test", "/api")) // These should count - cb.RecordFailure(http.StatusTooManyRequests) // 429 - assert.Equal(t, int64(1), cb.Failures()) - - cb.RecordFailure(http.StatusServiceUnavailable) // 503 - assert.Equal(t, CircuitOpen, cb.State()) + cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) // 429 + cbs.RecordFailure("test", "/api", http.StatusServiceUnavailable) // 503 + assert.Equal(t, CircuitOpen, cbs.State("test", "/api")) } -func TestCircuitBreaker_Anthropic529(t *testing.T) { +func TestCircuitBreakers_Anthropic529(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ @@ -148,75 +135,44 @@ func TestCircuitBreaker_Anthropic529(t *testing.T) { Cooldown: time.Minute, HalfOpenMaxRequests: 1, } - cb := NewCircuitBreaker("anthropic", cfg) + cbs := NewCircuitBreakers(cfg, nil) // Anthropic-specific 529 "Overloaded" should trip the circuit - tripped := cb.RecordFailure(529) + tripped := cbs.RecordFailure("anthropic", "/v1/messages", 529) assert.True(t, tripped) - assert.Equal(t, CircuitOpen, cb.State()) -} - -func TestCircuitBreaker_WindowReset(t *testing.T) { - t.Parallel() - - cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 3, - Window: 50 * time.Millisecond, // Short window - Cooldown: time.Minute, - HalfOpenMaxRequests: 2, - } - cb := NewCircuitBreaker("test", cfg) - - // Record failures - cb.RecordFailure(http.StatusTooManyRequests) - cb.RecordFailure(http.StatusTooManyRequests) - assert.Equal(t, int64(2), cb.Failures()) - - // Wait for window to expire - time.Sleep(60 * time.Millisecond) - - // Next failure should reset counter (due to window expiry) - cb.RecordFailure(http.StatusTooManyRequests) - assert.Equal(t, int64(1), cb.Failures()) - assert.Equal(t, CircuitClosed, cb.State()) + assert.Equal(t, CircuitOpen, cbs.State("anthropic", "/v1/messages")) } -func TestCircuitBreaker_ConcurrentAccess(t *testing.T) { +func TestCircuitBreakers_ConcurrentAccess(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ Enabled: true, - FailureThreshold: 100, + FailureThreshold: 1000, Window: time.Minute, Cooldown: time.Minute, HalfOpenMaxRequests: 10, } - cb := NewCircuitBreaker("test", cfg) + cbs := NewCircuitBreakers(cfg, nil) var wg sync.WaitGroup - numGoroutines := 50 - opsPerGoroutine := 100 - - for i := 0; i < numGoroutines; i++ { + for i := 0; i < 50; i++ { wg.Add(1) go func() { defer wg.Done() - for j := 0; j < opsPerGoroutine; j++ { - cb.Allow() - cb.RecordSuccess() - cb.RecordFailure(http.StatusTooManyRequests) - cb.State() - cb.Failures() + for j := 0; j < 100; j++ { + cbs.Allow("test", "/api") + cbs.RecordSuccess("test", "/api") + cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) + cbs.State("test", "/api") } }() } - wg.Wait() // Should not panic or deadlock } -func TestCircuitBreaker_StateChangeCallback(t *testing.T) { +func TestCircuitBreakers_StateChangeCallback(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ @@ -226,38 +182,29 @@ func TestCircuitBreaker_StateChangeCallback(t *testing.T) { Cooldown: 50 * time.Millisecond, HalfOpenMaxRequests: 1, } - cb := NewCircuitBreaker("test", cfg) var mu sync.Mutex - var transitions []struct { - from, to CircuitState - } + var transitions []struct{ from, to CircuitState } - cb.SetStateChangeCallback(func(provider string, from, to CircuitState) { + cbs := NewCircuitBreakers(cfg, func(name string, from, to CircuitState) { mu.Lock() defer mu.Unlock() transitions = append(transitions, struct{ from, to CircuitState }{from, to}) }) // Trip the circuit - cb.RecordFailure(http.StatusTooManyRequests) - cb.RecordFailure(http.StatusTooManyRequests) - - // Wait for callback - time.Sleep(10 * time.Millisecond) + cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) + cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) // Wait for cooldown and trigger half-open time.Sleep(60 * time.Millisecond) - cb.Allow() - - // Wait for callback - time.Sleep(10 * time.Millisecond) + cbs.Allow("test", "/api") // Success to close - cb.RecordSuccess() + cbs.RecordSuccess("test", "/api") - // Wait for callback - time.Sleep(10 * time.Millisecond) + // Wait for callbacks + time.Sleep(20 * time.Millisecond) mu.Lock() defer mu.Unlock() @@ -270,82 +217,6 @@ func TestCircuitBreaker_StateChangeCallback(t *testing.T) { assert.Equal(t, CircuitClosed, transitions[2].to) } -func TestCircuitBreakerManager_GetOrCreate(t *testing.T) { - t.Parallel() - - cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 5, - Window: time.Minute, - Cooldown: time.Minute, - } - manager := NewCircuitBreakerManager(cfg) - - // First call should create - cb1 := manager.GetOrCreate("anthropic") - require.NotNil(t, cb1) - assert.Equal(t, "anthropic", cb1.Provider()) - - // Second call should return same instance - cb2 := manager.GetOrCreate("anthropic") - assert.Same(t, cb1, cb2) - - // Different provider gets different instance - cb3 := manager.GetOrCreate("openai") - require.NotNil(t, cb3) - assert.NotSame(t, cb1, cb3) - assert.Equal(t, "openai", cb3.Provider()) -} - -func TestCircuitBreakerManager_AllStates(t *testing.T) { - t.Parallel() - - cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 1, - Window: time.Minute, - Cooldown: time.Minute, - } - manager := NewCircuitBreakerManager(cfg) - - manager.GetOrCreate("anthropic") - manager.GetOrCreate("openai") - - // Trip one circuit - manager.Get("anthropic").RecordFailure(http.StatusTooManyRequests) - - states := manager.AllStates() - assert.Equal(t, CircuitOpen, states["anthropic"]) - assert.Equal(t, CircuitClosed, states["openai"]) -} - -func TestCircuitBreakerManager_ConcurrentGetOrCreate(t *testing.T) { - t.Parallel() - - cfg := DefaultCircuitBreakerConfig() - cfg.Enabled = true - manager := NewCircuitBreakerManager(cfg) - - var wg sync.WaitGroup - var results [100]*CircuitBreaker - - for i := 0; i < 100; i++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - results[idx] = manager.GetOrCreate("test-provider") - }(i) - } - - wg.Wait() - - // All should be the same instance - first := results[0] - for i := 1; i < 100; i++ { - assert.Same(t, first, results[i]) - } -} - func TestIsCircuitBreakerFailure(t *testing.T) { t.Parallel() @@ -356,11 +227,8 @@ func TestIsCircuitBreakerFailure(t *testing.T) { {http.StatusOK, false}, {http.StatusBadRequest, false}, {http.StatusUnauthorized, false}, - {http.StatusForbidden, false}, - {http.StatusNotFound, false}, {http.StatusTooManyRequests, true}, // 429 {http.StatusInternalServerError, false}, - {http.StatusBadGateway, false}, {http.StatusServiceUnavailable, true}, // 503 {529, true}, // Anthropic Overloaded } diff --git a/go.mod b/go.mod index 9a62089..4715f99 100644 --- a/go.mod +++ b/go.mod @@ -33,6 +33,8 @@ require ( go.opentelemetry.io/otel/trace v1.38.0 ) +require github.com/sony/gobreaker/v2 v2.3.0 + require ( github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect diff --git a/go.sum b/go.sum index 385345d..fff1ee3 100644 --- a/go.sum +++ b/go.sum @@ -110,6 +110,8 @@ github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/sony/gobreaker/v2 v2.3.0 h1:7VYxZ69QXRQ2Q4eEawHn6eU4FiuwovzJwsUMA03Lu4I= +github.com/sony/gobreaker/v2 v2.3.0/go.mod h1:pTyFJgcZ3h2tdQVLZZruK2C0eoFL1fb/G83wK1ZQl+s= github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= diff --git a/interception.go b/interception.go index d1d8444..f1d7472 100644 --- a/interception.go +++ b/interception.go @@ -40,23 +40,26 @@ const recordingTimeout = time.Second * 5 // newInterceptionProcessor returns an [http.HandlerFunc] which is capable of creating a new interceptor and processing a given request // using [Provider] p, recording all usage events using [Recorder] recorder. -func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer, cbManager *CircuitBreakerManager) http.HandlerFunc { +func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer, cbs *CircuitBreakers) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, span := tracer.Start(r.Context(), "Intercept") defer span.End() + // Extract endpoint (route) for per-endpoint circuit breaker + route := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/%s", p.Name())) + // Check circuit breaker before proceeding - cb := cbManager.GetOrCreate(p.Name()) - if !cb.Allow() { + if !cbs.Allow(p.Name(), route) { span.SetStatus(codes.Error, "circuit breaker open") logger.Debug(ctx, "request rejected by circuit breaker", slog.F("provider", p.Name()), - slog.F("circuit_state", cb.State().String()), + slog.F("endpoint", route), + slog.F("circuit_state", cbs.State(p.Name(), route).String()), ) if metrics != nil { - metrics.CircuitBreakerRejects.WithLabelValues(p.Name()).Inc() + metrics.CircuitBreakerRejects.WithLabelValues(p.Name(), route).Inc() } - http.Error(w, fmt.Sprintf("%s is currently unavailable due to upstream rate limiting. Please try again later.", p.Name()), http.StatusServiceUnavailable) + http.Error(w, fmt.Sprintf("%s %s is currently unavailable due to upstream rate limiting. Please try again later.", p.Name(), route), http.StatusServiceUnavailable) return } @@ -108,7 +111,6 @@ func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.Server return } - route := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/%s", p.Name())) log := logger.With( slog.F("route", route), slog.F("provider", p.Name()), @@ -134,7 +136,7 @@ func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.Server // Record failure for circuit breaker - extract status code if available if statusCode := extractStatusCodeFromError(err); statusCode > 0 { - if cb.RecordFailure(statusCode) { + if cbs.RecordFailure(p.Name(), route, statusCode) { log.Warn(ctx, "circuit breaker tripped", slog.F("provider", p.Name()), slog.F("status_code", statusCode), @@ -148,7 +150,7 @@ func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.Server log.Debug(ctx, "interception ended") // Record success for circuit breaker - cb.RecordSuccess() + cbs.RecordSuccess(p.Name(), route) } asyncRecorder.RecordInterceptionEnded(ctx, &InterceptionRecordEnded{ID: interceptor.ID().String()}) diff --git a/metrics.go b/metrics.go index 565029a..f744d10 100644 --- a/metrics.go +++ b/metrics.go @@ -110,23 +110,23 @@ func NewMetrics(reg prometheus.Registerer) *Metrics { // Circuit breaker metrics. - // Pessimistic cardinality: 2 providers = up to 2. + // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. CircuitBreakerState: promauto.With(reg).NewGaugeVec(prometheus.GaugeOpts{ Subsystem: "circuit_breaker", Name: "state", Help: "Current state of the circuit breaker (0=closed, 1=open, 2=half-open).", - }, []string{"provider"}), - // Pessimistic cardinality: 2 providers = up to 2. + }, []string{"provider", "endpoint"}), + // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. CircuitBreakerTrips: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ Subsystem: "circuit_breaker", Name: "trips_total", Help: "Total number of times the circuit breaker has tripped open.", - }, []string{"provider"}), - // Pessimistic cardinality: 2 providers = up to 2. + }, []string{"provider", "endpoint"}), + // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. CircuitBreakerRejects: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ Subsystem: "circuit_breaker", Name: "rejects_total", Help: "Total number of requests rejected due to open circuit breaker.", - }, []string{"provider"}), + }, []string{"provider", "endpoint"}), } } From 8cf2d18baddab1342b417179f4440c2a2f8268c7 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Tue, 16 Dec 2025 10:30:03 +0000 Subject: [PATCH 04/26] refactor: align CircuitBreakerConfig fields with gobreaker.Settings Rename fields to match gobreaker naming convention: - Window -> Interval - Cooldown -> Timeout - HalfOpenMaxRequests -> MaxRequests - FailureThreshold type int64 -> uint32 --- circuit_breaker.go | 33 ++++++++++---------- circuit_breaker_test.go | 68 ++++++++++++++++++++--------------------- 2 files changed, 51 insertions(+), 50 deletions(-) diff --git a/circuit_breaker.go b/circuit_breaker.go index d27f0cd..5ecf674 100644 --- a/circuit_breaker.go +++ b/circuit_breaker.go @@ -49,27 +49,28 @@ func toCircuitState(s gobreaker.State) CircuitState { } // CircuitBreakerConfig holds configuration for circuit breakers. +// Fields match gobreaker.Settings for clarity. type CircuitBreakerConfig struct { // Enabled controls whether circuit breakers are active. Enabled bool + // MaxRequests is the maximum number of requests allowed in half-open state. + MaxRequests uint32 + // Interval is the cyclic period of the closed state for clearing internal counts. + Interval time.Duration + // Timeout is how long the circuit stays open before transitioning to half-open. + Timeout time.Duration // FailureThreshold is the number of consecutive failures that triggers the circuit to open. - FailureThreshold int64 - // Window is the time window for counting failures. - Window time.Duration - // Cooldown is how long the circuit stays open before transitioning to half-open. - Cooldown time.Duration - // HalfOpenMaxRequests is the maximum number of requests allowed in half-open state. - HalfOpenMaxRequests int64 + FailureThreshold uint32 } // DefaultCircuitBreakerConfig returns sensible defaults for circuit breaker configuration. func DefaultCircuitBreakerConfig() CircuitBreakerConfig { return CircuitBreakerConfig{ - Enabled: false, // Disabled by default for backward compatibility - FailureThreshold: 5, - Window: 10 * time.Second, - Cooldown: 30 * time.Second, - HalfOpenMaxRequests: 3, + Enabled: false, // Disabled by default for backward compatibility + FailureThreshold: 5, + Interval: 10 * time.Second, + Timeout: 30 * time.Second, + MaxRequests: 3, } } @@ -150,11 +151,11 @@ func (c *CircuitBreakers) getOrCreate(provider, endpoint string) *gobreaker.Circ settings := gobreaker.Settings{ Name: key, - MaxRequests: uint32(c.config.HalfOpenMaxRequests), - Interval: c.config.Window, - Timeout: c.config.Cooldown, + MaxRequests: c.config.MaxRequests, + Interval: c.config.Interval, + Timeout: c.config.Timeout, ReadyToTrip: func(counts gobreaker.Counts) bool { - return counts.ConsecutiveFailures >= uint32(c.config.FailureThreshold) + return counts.ConsecutiveFailures >= c.config.FailureThreshold }, OnStateChange: func(name string, from, to gobreaker.State) { if c.onChange != nil { diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go index 0d8deb9..c7e2c82 100644 --- a/circuit_breaker_test.go +++ b/circuit_breaker_test.go @@ -15,10 +15,10 @@ func TestCircuitBreaker_DefaultConfig(t *testing.T) { cfg := DefaultCircuitBreakerConfig() assert.False(t, cfg.Enabled, "should be disabled by default") - assert.Equal(t, int64(5), cfg.FailureThreshold) - assert.Equal(t, 10*time.Second, cfg.Window) - assert.Equal(t, 30*time.Second, cfg.Cooldown) - assert.Equal(t, int64(3), cfg.HalfOpenMaxRequests) + assert.Equal(t, uint32(5), cfg.FailureThreshold) + assert.Equal(t, 10*time.Second, cfg.Interval) + assert.Equal(t, 30*time.Second, cfg.Timeout) + assert.Equal(t, uint32(3), cfg.MaxRequests) } func TestCircuitBreakers_DisabledByDefault(t *testing.T) { @@ -41,11 +41,11 @@ func TestCircuitBreakers_StateTransitions(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 3, - Window: time.Minute, - Cooldown: 50 * time.Millisecond, - HalfOpenMaxRequests: 2, + Enabled: true, + FailureThreshold: 3, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, + MaxRequests: 2, } cbs := NewCircuitBreakers(cfg, nil) @@ -81,11 +81,11 @@ func TestCircuitBreakers_PerEndpointIsolation(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 1, - Window: time.Minute, - Cooldown: time.Minute, - HalfOpenMaxRequests: 1, + Enabled: true, + FailureThreshold: 1, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 1, } cbs := NewCircuitBreakers(cfg, nil) @@ -104,11 +104,11 @@ func TestCircuitBreakers_OnlyCountsRelevantStatusCodes(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 2, - Window: time.Minute, - Cooldown: time.Minute, - HalfOpenMaxRequests: 2, + Enabled: true, + FailureThreshold: 2, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 2, } cbs := NewCircuitBreakers(cfg, nil) @@ -129,11 +129,11 @@ func TestCircuitBreakers_Anthropic529(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 1, - Window: time.Minute, - Cooldown: time.Minute, - HalfOpenMaxRequests: 1, + Enabled: true, + FailureThreshold: 1, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 1, } cbs := NewCircuitBreakers(cfg, nil) @@ -147,11 +147,11 @@ func TestCircuitBreakers_ConcurrentAccess(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 1000, - Window: time.Minute, - Cooldown: time.Minute, - HalfOpenMaxRequests: 10, + Enabled: true, + FailureThreshold: 1000, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 10, } cbs := NewCircuitBreakers(cfg, nil) @@ -176,11 +176,11 @@ func TestCircuitBreakers_StateChangeCallback(t *testing.T) { t.Parallel() cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 2, - Window: time.Minute, - Cooldown: 50 * time.Millisecond, - HalfOpenMaxRequests: 1, + Enabled: true, + FailureThreshold: 2, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, + MaxRequests: 1, } var mu sync.Mutex From 8e44145e9bb28a2fb0ceefcbcb7cbe85380e780b Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Tue, 16 Dec 2025 10:32:01 +0000 Subject: [PATCH 05/26] refactor: remove CircuitState, use gobreaker.State directly --- bridge.go | 10 ++++---- circuit_breaker.go | 51 +++++----------------------------------- circuit_breaker_test.go | 52 +++++++++++++++++------------------------ interception.go | 2 +- 4 files changed, 34 insertions(+), 81 deletions(-) diff --git a/bridge.go b/bridge.go index 3147486..4cafa6b 100644 --- a/bridge.go +++ b/bridge.go @@ -10,9 +10,9 @@ import ( "cdr.dev/slog" "github.com/coder/aibridge/mcp" - "go.opentelemetry.io/otel/trace" - "github.com/hashicorp/go-multierror" + "github.com/sony/gobreaker/v2" + "go.opentelemetry.io/otel/trace" ) // RequestBridge is an [http.Handler] which is capable of masquerading as AI providers' APIs; @@ -63,12 +63,12 @@ func NewRequestBridgeWithCircuitBreaker(ctx context.Context, providers []Provide mux := http.NewServeMux() // Create circuit breakers with metrics callback - var onChange func(name string, from, to CircuitState) + var onChange func(name string, from, to gobreaker.State) if metrics != nil { - onChange = func(name string, from, to CircuitState) { + onChange = func(name string, from, to gobreaker.State) { provider, endpoint, _ := strings.Cut(name, ":") metrics.CircuitBreakerState.WithLabelValues(provider, endpoint).Set(float64(to)) - if to == CircuitOpen { + if to == gobreaker.StateOpen { metrics.CircuitBreakerTrips.WithLabelValues(provider, endpoint).Inc() } } diff --git a/circuit_breaker.go b/circuit_breaker.go index 5ecf674..06754bb 100644 --- a/circuit_breaker.go +++ b/circuit_breaker.go @@ -9,45 +9,6 @@ import ( "github.com/sony/gobreaker/v2" ) -// CircuitState represents the current state of a circuit breaker. -type CircuitState int - -const ( - // CircuitClosed is the normal state - all requests pass through. - CircuitClosed CircuitState = iota - // CircuitOpen is the tripped state - requests are rejected immediately. - CircuitOpen - // CircuitHalfOpen is the testing state - limited requests pass through. - CircuitHalfOpen -) - -func (s CircuitState) String() string { - switch s { - case CircuitClosed: - return "closed" - case CircuitOpen: - return "open" - case CircuitHalfOpen: - return "half-open" - default: - return "unknown" - } -} - -// toCircuitState converts gobreaker.State to our CircuitState. -func toCircuitState(s gobreaker.State) CircuitState { - switch s { - case gobreaker.StateClosed: - return CircuitClosed - case gobreaker.StateOpen: - return CircuitOpen - case gobreaker.StateHalfOpen: - return CircuitHalfOpen - default: - return CircuitClosed - } -} - // CircuitBreakerConfig holds configuration for circuit breakers. // Fields match gobreaker.Settings for clarity. type CircuitBreakerConfig struct { @@ -92,11 +53,11 @@ func isCircuitBreakerFailure(statusCode int) bool { type CircuitBreakers struct { breakers sync.Map // map[string]*gobreaker.CircuitBreaker[any] config CircuitBreakerConfig - onChange func(name string, from, to CircuitState) + onChange func(name string, from, to gobreaker.State) } // NewCircuitBreakers creates a new circuit breaker manager. -func NewCircuitBreakers(config CircuitBreakerConfig, onChange func(name string, from, to CircuitState)) *CircuitBreakers { +func NewCircuitBreakers(config CircuitBreakerConfig, onChange func(name string, from, to gobreaker.State)) *CircuitBreakers { return &CircuitBreakers{ config: config, onChange: onChange, @@ -135,12 +96,12 @@ func (c *CircuitBreakers) RecordFailure(provider, endpoint string, statusCode in } // State returns the current state for a provider/endpoint. -func (c *CircuitBreakers) State(provider, endpoint string) CircuitState { +func (c *CircuitBreakers) State(provider, endpoint string) gobreaker.State { if !c.config.Enabled { - return CircuitClosed + return gobreaker.StateClosed } cb := c.getOrCreate(provider, endpoint) - return toCircuitState(cb.State()) + return cb.State() } func (c *CircuitBreakers) getOrCreate(provider, endpoint string) *gobreaker.CircuitBreaker[any] { @@ -159,7 +120,7 @@ func (c *CircuitBreakers) getOrCreate(provider, endpoint string) *gobreaker.Circ }, OnStateChange: func(name string, from, to gobreaker.State) { if c.onChange != nil { - c.onChange(name, toCircuitState(from), toCircuitState(to)) + c.onChange(name, from, to) } }, } diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go index c7e2c82..ca45620 100644 --- a/circuit_breaker_test.go +++ b/circuit_breaker_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/sony/gobreaker/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -34,7 +35,7 @@ func TestCircuitBreakers_DisabledByDefault(t *testing.T) { cbs.RecordFailure("anthropic", "/v1/messages", http.StatusTooManyRequests) } assert.True(t, cbs.Allow("anthropic", "/v1/messages")) - assert.Equal(t, CircuitClosed, cbs.State("anthropic", "/v1/messages")) + assert.Equal(t, gobreaker.StateClosed, cbs.State("anthropic", "/v1/messages")) } func TestCircuitBreakers_StateTransitions(t *testing.T) { @@ -50,18 +51,18 @@ func TestCircuitBreakers_StateTransitions(t *testing.T) { cbs := NewCircuitBreakers(cfg, nil) // Start in closed state - assert.Equal(t, CircuitClosed, cbs.State("test", "/api")) + assert.Equal(t, gobreaker.StateClosed, cbs.State("test", "/api")) assert.True(t, cbs.Allow("test", "/api")) // Record failures below threshold cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) - assert.Equal(t, CircuitClosed, cbs.State("test", "/api")) + assert.Equal(t, gobreaker.StateClosed, cbs.State("test", "/api")) // Third failure should trip the circuit tripped := cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) assert.True(t, tripped) - assert.Equal(t, CircuitOpen, cbs.State("test", "/api")) + assert.Equal(t, gobreaker.StateOpen, cbs.State("test", "/api")) assert.False(t, cbs.Allow("test", "/api")) // Wait for cooldown @@ -69,12 +70,12 @@ func TestCircuitBreakers_StateTransitions(t *testing.T) { // Should transition to half-open and allow request assert.True(t, cbs.Allow("test", "/api")) - assert.Equal(t, CircuitHalfOpen, cbs.State("test", "/api")) + assert.Equal(t, gobreaker.StateHalfOpen, cbs.State("test", "/api")) // Success in half-open should eventually close cbs.RecordSuccess("test", "/api") cbs.RecordSuccess("test", "/api") - assert.Equal(t, CircuitClosed, cbs.State("test", "/api")) + assert.Equal(t, gobreaker.StateClosed, cbs.State("test", "/api")) } func TestCircuitBreakers_PerEndpointIsolation(t *testing.T) { @@ -91,11 +92,11 @@ func TestCircuitBreakers_PerEndpointIsolation(t *testing.T) { // Trip circuit for one endpoint cbs.RecordFailure("openai", "/v1/chat/completions", http.StatusTooManyRequests) - assert.Equal(t, CircuitOpen, cbs.State("openai", "/v1/chat/completions")) + assert.Equal(t, gobreaker.StateOpen, cbs.State("openai", "/v1/chat/completions")) // Other endpoints should still be closed - assert.Equal(t, CircuitClosed, cbs.State("openai", "/v1/responses")) - assert.Equal(t, CircuitClosed, cbs.State("anthropic", "/v1/messages")) + assert.Equal(t, gobreaker.StateClosed, cbs.State("openai", "/v1/responses")) + assert.Equal(t, gobreaker.StateClosed, cbs.State("anthropic", "/v1/messages")) assert.True(t, cbs.Allow("openai", "/v1/responses")) assert.True(t, cbs.Allow("anthropic", "/v1/messages")) } @@ -117,12 +118,12 @@ func TestCircuitBreakers_OnlyCountsRelevantStatusCodes(t *testing.T) { cbs.RecordFailure("test", "/api", http.StatusUnauthorized) // 401 cbs.RecordFailure("test", "/api", http.StatusInternalServerError) // 500 cbs.RecordFailure("test", "/api", http.StatusBadGateway) // 502 - assert.Equal(t, CircuitClosed, cbs.State("test", "/api")) + assert.Equal(t, gobreaker.StateClosed, cbs.State("test", "/api")) // These should count cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) // 429 cbs.RecordFailure("test", "/api", http.StatusServiceUnavailable) // 503 - assert.Equal(t, CircuitOpen, cbs.State("test", "/api")) + assert.Equal(t, gobreaker.StateOpen, cbs.State("test", "/api")) } func TestCircuitBreakers_Anthropic529(t *testing.T) { @@ -140,7 +141,7 @@ func TestCircuitBreakers_Anthropic529(t *testing.T) { // Anthropic-specific 529 "Overloaded" should trip the circuit tripped := cbs.RecordFailure("anthropic", "/v1/messages", 529) assert.True(t, tripped) - assert.Equal(t, CircuitOpen, cbs.State("anthropic", "/v1/messages")) + assert.Equal(t, gobreaker.StateOpen, cbs.State("anthropic", "/v1/messages")) } func TestCircuitBreakers_ConcurrentAccess(t *testing.T) { @@ -184,12 +185,12 @@ func TestCircuitBreakers_StateChangeCallback(t *testing.T) { } var mu sync.Mutex - var transitions []struct{ from, to CircuitState } + var transitions []struct{ from, to gobreaker.State } - cbs := NewCircuitBreakers(cfg, func(name string, from, to CircuitState) { + cbs := NewCircuitBreakers(cfg, func(name string, from, to gobreaker.State) { mu.Lock() defer mu.Unlock() - transitions = append(transitions, struct{ from, to CircuitState }{from, to}) + transitions = append(transitions, struct{ from, to gobreaker.State }{from, to}) }) // Trip the circuit @@ -209,12 +210,12 @@ func TestCircuitBreakers_StateChangeCallback(t *testing.T) { mu.Lock() defer mu.Unlock() require.Len(t, transitions, 3) - assert.Equal(t, CircuitClosed, transitions[0].from) - assert.Equal(t, CircuitOpen, transitions[0].to) - assert.Equal(t, CircuitOpen, transitions[1].from) - assert.Equal(t, CircuitHalfOpen, transitions[1].to) - assert.Equal(t, CircuitHalfOpen, transitions[2].from) - assert.Equal(t, CircuitClosed, transitions[2].to) + assert.Equal(t, gobreaker.StateClosed, transitions[0].from) + assert.Equal(t, gobreaker.StateOpen, transitions[0].to) + assert.Equal(t, gobreaker.StateOpen, transitions[1].from) + assert.Equal(t, gobreaker.StateHalfOpen, transitions[1].to) + assert.Equal(t, gobreaker.StateHalfOpen, transitions[2].from) + assert.Equal(t, gobreaker.StateClosed, transitions[2].to) } func TestIsCircuitBreakerFailure(t *testing.T) { @@ -239,12 +240,3 @@ func TestIsCircuitBreakerFailure(t *testing.T) { }) } } - -func TestCircuitState_String(t *testing.T) { - t.Parallel() - - assert.Equal(t, "closed", CircuitClosed.String()) - assert.Equal(t, "open", CircuitOpen.String()) - assert.Equal(t, "half-open", CircuitHalfOpen.String()) - assert.Equal(t, "unknown", CircuitState(99).String()) -} diff --git a/interception.go b/interception.go index f1d7472..da589be 100644 --- a/interception.go +++ b/interception.go @@ -54,7 +54,7 @@ func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.Server logger.Debug(ctx, "request rejected by circuit breaker", slog.F("provider", p.Name()), slog.F("endpoint", route), - slog.F("circuit_state", cbs.State(p.Name(), route).String()), + slog.F("circuit_state", cbs.State(p.Name(), route)), ) if metrics != nil { metrics.CircuitBreakerRejects.WithLabelValues(p.Name(), route).Inc() From 7af3bc1c1fe30b6f6cddb6a0f47cacacbc689bfb Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Wed, 17 Dec 2025 08:51:36 +0000 Subject: [PATCH 06/26] refactor: implement circuit breaker as middleware with per-provider configs Address PR review feedback: 1. Middleware pattern - Circuit breaker is now HTTP middleware that wraps handlers, capturing response status codes directly instead of extracting from provider-specific error types. 2. Per-provider configs - NewCircuitBreakers takes map[string]CircuitBreakerConfig keyed by provider name. Providers not in the map have no circuit breaker. 3. Remove provider overfitting - Deleted extractStatusCodeFromError() which hardcoded AnthropicErrorResponse and OpenAIErrorResponse types. Middleware now uses statusCapturingWriter to inspect actual HTTP response codes. 4. Configurable failure detection - IsFailure func in config allows providers to define custom status codes as failures. Defaults to 429/503/529. 5. Fix gauge values - State gauge now uses 0 (closed), 0.5 (half-open), 1 (open) 6. Integration tests - Replaced unit tests with httptest-based integration tests that verify actual behavior: upstream errors trip circuit, requests get blocked, recovery after timeout, per-endpoint isolation. 7. Error message - Changed from 'upstream rate limiting' to 'circuit breaker is open' --- bridge.go | 19 +- circuit_breaker.go | 166 ++++++++++++------ circuit_breaker_test.go | 372 +++++++++++++++++++++------------------- interception.go | 53 +----- 4 files changed, 325 insertions(+), 285 deletions(-) diff --git a/bridge.go b/bridge.go index 4cafa6b..bef401c 100644 --- a/bridge.go +++ b/bridge.go @@ -54,12 +54,13 @@ var _ http.Handler = &RequestBridge{} // // mcpProxy will be closed when the [RequestBridge] is closed. func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer) (*RequestBridge, error) { - return NewRequestBridgeWithCircuitBreaker(ctx, providers, recorder, mcpProxy, logger, metrics, tracer, DefaultCircuitBreakerConfig()) + return NewRequestBridgeWithCircuitBreaker(ctx, providers, recorder, mcpProxy, logger, metrics, tracer, nil) } -// NewRequestBridgeWithCircuitBreaker creates a new *[RequestBridge] with custom circuit breaker configuration. -// See [NewRequestBridge] for more details. -func NewRequestBridgeWithCircuitBreaker(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer, cbConfig CircuitBreakerConfig) (*RequestBridge, error) { +// NewRequestBridgeWithCircuitBreaker creates a new *[RequestBridge] with per-provider circuit breaker configuration. +// The cbConfigs map is keyed by provider name. Providers not in the map will not have circuit breaker protection. +// Pass nil to disable circuit breakers entirely. +func NewRequestBridgeWithCircuitBreaker(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer, cbConfigs map[string]CircuitBreakerConfig) (*RequestBridge, error) { mux := http.NewServeMux() // Create circuit breakers with metrics callback @@ -67,19 +68,21 @@ func NewRequestBridgeWithCircuitBreaker(ctx context.Context, providers []Provide if metrics != nil { onChange = func(name string, from, to gobreaker.State) { provider, endpoint, _ := strings.Cut(name, ":") - metrics.CircuitBreakerState.WithLabelValues(provider, endpoint).Set(float64(to)) + metrics.CircuitBreakerState.WithLabelValues(provider, endpoint).Set(stateToGaugeValue(to)) if to == gobreaker.StateOpen { metrics.CircuitBreakerTrips.WithLabelValues(provider, endpoint).Inc() } } } - cbs := NewCircuitBreakers(cbConfig, onChange) + cbs := NewCircuitBreakers(cbConfigs, onChange) for _, provider := range providers { - // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). for _, path := range provider.BridgedRoutes() { - mux.HandleFunc(path, newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer, cbs)) + handler := newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer) + // Wrap with circuit breaker middleware if configured for this provider + handler = CircuitBreakerMiddleware(cbs, metrics, provider.Name())(handler).ServeHTTP + mux.HandleFunc(path, handler) } // Any requests which passthrough to this will be reverse-proxied to the upstream. diff --git a/circuit_breaker.go b/circuit_breaker.go index 06754bb..9ffd0cc 100644 --- a/circuit_breaker.go +++ b/circuit_breaker.go @@ -3,6 +3,7 @@ package aibridge import ( "fmt" "net/http" + "strings" "sync" "time" @@ -12,8 +13,6 @@ import ( // CircuitBreakerConfig holds configuration for circuit breakers. // Fields match gobreaker.Settings for clarity. type CircuitBreakerConfig struct { - // Enabled controls whether circuit breakers are active. - Enabled bool // MaxRequests is the maximum number of requests allowed in half-open state. MaxRequests uint32 // Interval is the cyclic period of the closed state for clearing internal counts. @@ -22,22 +21,26 @@ type CircuitBreakerConfig struct { Timeout time.Duration // FailureThreshold is the number of consecutive failures that triggers the circuit to open. FailureThreshold uint32 + // IsFailure determines if a status code should count as a failure. + // If nil, defaults to 429, 503, and 529 (Anthropic overloaded). + IsFailure func(statusCode int) bool } // DefaultCircuitBreakerConfig returns sensible defaults for circuit breaker configuration. func DefaultCircuitBreakerConfig() CircuitBreakerConfig { return CircuitBreakerConfig{ - Enabled: false, // Disabled by default for backward compatibility FailureThreshold: 5, Interval: 10 * time.Second, Timeout: 30 * time.Second, MaxRequests: 3, + IsFailure: DefaultIsFailure, } } -// isCircuitBreakerFailure returns true if the given HTTP status code -// should count as a failure for circuit breaker purposes. -func isCircuitBreakerFailure(statusCode int) bool { +// DefaultIsFailure returns true for status codes that typically indicate +// upstream overload: 429 (Too Many Requests), 503 (Service Unavailable), +// and 529 (Anthropic Overloaded). +func DefaultIsFailure(statusCode int) bool { switch statusCode { case http.StatusTooManyRequests, // 429 http.StatusServiceUnavailable, // 503 @@ -52,59 +55,40 @@ func isCircuitBreakerFailure(statusCode int) bool { // Circuit breakers are keyed by "provider:endpoint" for per-endpoint isolation. type CircuitBreakers struct { breakers sync.Map // map[string]*gobreaker.CircuitBreaker[any] - config CircuitBreakerConfig + configs map[string]CircuitBreakerConfig onChange func(name string, from, to gobreaker.State) } -// NewCircuitBreakers creates a new circuit breaker manager. -func NewCircuitBreakers(config CircuitBreakerConfig, onChange func(name string, from, to gobreaker.State)) *CircuitBreakers { +// NewCircuitBreakers creates a new circuit breaker manager with per-provider configs. +// The configs map is keyed by provider name. Providers not in the map will not have +// circuit breaker protection. +func NewCircuitBreakers(configs map[string]CircuitBreakerConfig, onChange func(name string, from, to gobreaker.State)) *CircuitBreakers { return &CircuitBreakers{ - config: config, + configs: configs, onChange: onChange, } } -// Allow checks if a request to provider/endpoint should be allowed. -func (c *CircuitBreakers) Allow(provider, endpoint string) bool { - if !c.config.Enabled { - return true - } - cb := c.getOrCreate(provider, endpoint) - return cb.State() != gobreaker.StateOpen -} - -// RecordSuccess records a successful request. -func (c *CircuitBreakers) RecordSuccess(provider, endpoint string) { - if !c.config.Enabled { - return +// getConfig returns the config for a provider, or nil if not configured. +func (c *CircuitBreakers) getConfig(provider string) *CircuitBreakerConfig { + if c.configs == nil { + return nil } - cb := c.getOrCreate(provider, endpoint) - _, _ = cb.Execute(func() (any, error) { return nil, nil }) -} - -// RecordFailure records a failed request. Returns true if this caused the circuit to open. -func (c *CircuitBreakers) RecordFailure(provider, endpoint string, statusCode int) bool { - if !c.config.Enabled || !isCircuitBreakerFailure(statusCode) { - return false + cfg, ok := c.configs[provider] + if !ok { + return nil } - cb := c.getOrCreate(provider, endpoint) - before := cb.State() - _, _ = cb.Execute(func() (any, error) { - return nil, fmt.Errorf("upstream error: %d", statusCode) - }) - return before != gobreaker.StateOpen && cb.State() == gobreaker.StateOpen + return &cfg } -// State returns the current state for a provider/endpoint. -func (c *CircuitBreakers) State(provider, endpoint string) gobreaker.State { - if !c.config.Enabled { - return gobreaker.StateClosed +// getOrCreate returns the circuit breaker for a provider/endpoint, creating if needed. +// Returns nil if the provider is not configured. +func (c *CircuitBreakers) getOrCreate(provider, endpoint string) *gobreaker.CircuitBreaker[any] { + cfg := c.getConfig(provider) + if cfg == nil { + return nil } - cb := c.getOrCreate(provider, endpoint) - return cb.State() -} -func (c *CircuitBreakers) getOrCreate(provider, endpoint string) *gobreaker.CircuitBreaker[any] { key := provider + ":" + endpoint if v, ok := c.breakers.Load(key); ok { return v.(*gobreaker.CircuitBreaker[any]) @@ -112,11 +96,11 @@ func (c *CircuitBreakers) getOrCreate(provider, endpoint string) *gobreaker.Circ settings := gobreaker.Settings{ Name: key, - MaxRequests: c.config.MaxRequests, - Interval: c.config.Interval, - Timeout: c.config.Timeout, + MaxRequests: cfg.MaxRequests, + Interval: cfg.Interval, + Timeout: cfg.Timeout, ReadyToTrip: func(counts gobreaker.Counts) bool { - return counts.ConsecutiveFailures >= c.config.FailureThreshold + return counts.ConsecutiveFailures >= cfg.FailureThreshold }, OnStateChange: func(name string, from, to gobreaker.State) { if c.onChange != nil { @@ -129,3 +113,87 @@ func (c *CircuitBreakers) getOrCreate(provider, endpoint string) *gobreaker.Circ actual, _ := c.breakers.LoadOrStore(key, cb) return actual.(*gobreaker.CircuitBreaker[any]) } + +// statusCapturingWriter wraps http.ResponseWriter to capture the status code. +type statusCapturingWriter struct { + http.ResponseWriter + statusCode int + headerWritten bool +} + +func (w *statusCapturingWriter) WriteHeader(code int) { + if !w.headerWritten { + w.statusCode = code + w.headerWritten = true + } + w.ResponseWriter.WriteHeader(code) +} + +func (w *statusCapturingWriter) Write(b []byte) (int, error) { + if !w.headerWritten { + w.statusCode = http.StatusOK + w.headerWritten = true + } + return w.ResponseWriter.Write(b) +} + +// CircuitBreakerMiddleware returns middleware that wraps handlers with circuit breaker protection. +// It captures the response status code to determine success/failure without provider-specific logic. +func CircuitBreakerMiddleware(cbs *CircuitBreakers, metrics *Metrics, provider string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + cfg := cbs.getConfig(provider) + if cfg == nil { + // No config for this provider, pass through + return next + } + + isFailure := cfg.IsFailure + if isFailure == nil { + isFailure = DefaultIsFailure + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + endpoint := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/%s", provider)) + + // Check if circuit is open + cb := cbs.getOrCreate(provider, endpoint) + if cb != nil && cb.State() == gobreaker.StateOpen { + if metrics != nil { + metrics.CircuitBreakerRejects.WithLabelValues(provider, endpoint).Inc() + } + http.Error(w, "circuit breaker is open", http.StatusServiceUnavailable) + return + } + + // Wrap response writer to capture status code + sw := &statusCapturingWriter{ResponseWriter: w, statusCode: http.StatusOK} + next.ServeHTTP(sw, r) + + // Record result + if cb != nil { + if isFailure(sw.statusCode) { + _, _ = cb.Execute(func() (any, error) { + return nil, fmt.Errorf("upstream error: %d", sw.statusCode) + }) + } else { + _, _ = cb.Execute(func() (any, error) { return nil, nil }) + } + } + }) + } +} + +// stateToGaugeValue converts gobreaker.State to a gauge value. +// closed=0, half-open=0.5, open=1 +func stateToGaugeValue(s gobreaker.State) float64 { + switch s { + case gobreaker.StateClosed: + return 0 + case gobreaker.StateHalfOpen: + return 0.5 + case gobreaker.StateOpen: + return 1 + default: + return 0 + } +} diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go index ca45620..dca23e2 100644 --- a/circuit_breaker_test.go +++ b/circuit_breaker_test.go @@ -1,8 +1,10 @@ package aibridge import ( + "io" "net/http" - "sync" + "net/http/httptest" + "sync/atomic" "testing" "time" @@ -11,214 +13,225 @@ import ( "github.com/stretchr/testify/require" ) -func TestCircuitBreaker_DefaultConfig(t *testing.T) { +func TestCircuitBreakerMiddleware_TripsOnUpstreamErrors(t *testing.T) { t.Parallel() - cfg := DefaultCircuitBreakerConfig() - assert.False(t, cfg.Enabled, "should be disabled by default") - assert.Equal(t, uint32(5), cfg.FailureThreshold) - assert.Equal(t, 10*time.Second, cfg.Interval) - assert.Equal(t, 30*time.Second, cfg.Timeout) - assert.Equal(t, uint32(3), cfg.MaxRequests) -} - -func TestCircuitBreakers_DisabledByDefault(t *testing.T) { - t.Parallel() + var upstreamCalls atomic.Int32 - cbs := NewCircuitBreakers(DefaultCircuitBreakerConfig(), nil) - - // Should always allow when disabled - assert.True(t, cbs.Allow("anthropic", "/v1/messages")) + // Mock upstream that returns 429 + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamCalls.Add(1) + w.WriteHeader(http.StatusTooManyRequests) + }) - // Recording failures should not affect state when disabled - for i := 0; i < 100; i++ { - cbs.RecordFailure("anthropic", "/v1/messages", http.StatusTooManyRequests) + // Create circuit breaker with low threshold + cbs := NewCircuitBreakers(map[string]CircuitBreakerConfig{ + "test": { + FailureThreshold: 2, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, + MaxRequests: 1, + }, + }, nil) + + // Wrap upstream with circuit breaker middleware + handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) + server := httptest.NewServer(handler) + defer server.Close() + + // First 2 requests hit upstream, get 429 + for i := 0; i < 2; i++ { + resp, err := http.Get(server.URL + "/test/v1/messages") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) } - assert.True(t, cbs.Allow("anthropic", "/v1/messages")) - assert.Equal(t, gobreaker.StateClosed, cbs.State("anthropic", "/v1/messages")) + assert.Equal(t, int32(2), upstreamCalls.Load()) + + // Third request should get 503 "circuit breaker is open" without hitting upstream + resp, err := http.Get(server.URL + "/test/v1/messages") + require.NoError(t, err) + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + assert.Contains(t, string(body), "circuit breaker is open") + assert.Equal(t, int32(2), upstreamCalls.Load()) // No new upstream call + + // Wait for timeout, verify recovery + time.Sleep(60 * time.Millisecond) + + // Next request should hit upstream again (half-open state) + resp, err = http.Get(server.URL + "/test/v1/messages") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, int32(3), upstreamCalls.Load()) } -func TestCircuitBreakers_StateTransitions(t *testing.T) { +func TestCircuitBreakerMiddleware_PerEndpointIsolation(t *testing.T) { t.Parallel() - cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 3, - Interval: time.Minute, - Timeout: 50 * time.Millisecond, - MaxRequests: 2, - } - cbs := NewCircuitBreakers(cfg, nil) + chatCalls := atomic.Int32{} + responsesCalls := atomic.Int32{} + + // Mock upstream - /chat returns 429, /responses returns 200 + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/test/v1/chat/completions" { + chatCalls.Add(1) + w.WriteHeader(http.StatusTooManyRequests) + } else { + responsesCalls.Add(1) + w.WriteHeader(http.StatusOK) + } + }) + + cbs := NewCircuitBreakers(map[string]CircuitBreakerConfig{ + "test": { + FailureThreshold: 1, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 1, + }, + }, nil) + + handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) + server := httptest.NewServer(handler) + defer server.Close() + + // Trip circuit on /chat/completions + resp, err := http.Get(server.URL + "/test/v1/chat/completions") + require.NoError(t, err) + resp.Body.Close() + + // /chat/completions should now be blocked + resp, err = http.Get(server.URL + "/test/v1/chat/completions") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + assert.Equal(t, int32(1), chatCalls.Load()) // Only 1 call, second was blocked + + // /responses should still work + resp, err = http.Get(server.URL + "/test/v1/responses") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, int32(1), responsesCalls.Load()) +} - // Start in closed state - assert.Equal(t, gobreaker.StateClosed, cbs.State("test", "/api")) - assert.True(t, cbs.Allow("test", "/api")) +func TestCircuitBreakerMiddleware_NotConfigured(t *testing.T) { + t.Parallel() - // Record failures below threshold - cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) - cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) - assert.Equal(t, gobreaker.StateClosed, cbs.State("test", "/api")) + var upstreamCalls atomic.Int32 - // Third failure should trip the circuit - tripped := cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) - assert.True(t, tripped) - assert.Equal(t, gobreaker.StateOpen, cbs.State("test", "/api")) - assert.False(t, cbs.Allow("test", "/api")) + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamCalls.Add(1) + w.WriteHeader(http.StatusTooManyRequests) + }) - // Wait for cooldown - time.Sleep(60 * time.Millisecond) + // No config for "test" provider + cbs := NewCircuitBreakers(nil, nil) - // Should transition to half-open and allow request - assert.True(t, cbs.Allow("test", "/api")) - assert.Equal(t, gobreaker.StateHalfOpen, cbs.State("test", "/api")) + handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) + server := httptest.NewServer(handler) + defer server.Close() - // Success in half-open should eventually close - cbs.RecordSuccess("test", "/api") - cbs.RecordSuccess("test", "/api") - assert.Equal(t, gobreaker.StateClosed, cbs.State("test", "/api")) + // All requests should pass through even with 429s + for i := 0; i < 10; i++ { + resp, err := http.Get(server.URL + "/test/v1/messages") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + } + assert.Equal(t, int32(10), upstreamCalls.Load()) } -func TestCircuitBreakers_PerEndpointIsolation(t *testing.T) { +func TestCircuitBreakerMiddleware_RecoveryAfterSuccess(t *testing.T) { t.Parallel() - cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 1, - Interval: time.Minute, - Timeout: time.Minute, - MaxRequests: 1, - } - cbs := NewCircuitBreakers(cfg, nil) + var returnError atomic.Bool + returnError.Store(true) - // Trip circuit for one endpoint - cbs.RecordFailure("openai", "/v1/chat/completions", http.StatusTooManyRequests) - assert.Equal(t, gobreaker.StateOpen, cbs.State("openai", "/v1/chat/completions")) + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if returnError.Load() { + w.WriteHeader(http.StatusTooManyRequests) + } else { + w.WriteHeader(http.StatusOK) + } + }) - // Other endpoints should still be closed - assert.Equal(t, gobreaker.StateClosed, cbs.State("openai", "/v1/responses")) - assert.Equal(t, gobreaker.StateClosed, cbs.State("anthropic", "/v1/messages")) - assert.True(t, cbs.Allow("openai", "/v1/responses")) - assert.True(t, cbs.Allow("anthropic", "/v1/messages")) -} + cbs := NewCircuitBreakers(map[string]CircuitBreakerConfig{ + "test": { + FailureThreshold: 2, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, + MaxRequests: 1, + }, + }, nil) -func TestCircuitBreakers_OnlyCountsRelevantStatusCodes(t *testing.T) { - t.Parallel() + handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) + server := httptest.NewServer(handler) + defer server.Close() - cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 2, - Interval: time.Minute, - Timeout: time.Minute, - MaxRequests: 2, + // Trip the circuit + for i := 0; i < 2; i++ { + resp, _ := http.Get(server.URL + "/test/v1/messages") + resp.Body.Close() } - cbs := NewCircuitBreakers(cfg, nil) - - // Non-circuit-breaker status codes should not count - cbs.RecordFailure("test", "/api", http.StatusBadRequest) // 400 - cbs.RecordFailure("test", "/api", http.StatusUnauthorized) // 401 - cbs.RecordFailure("test", "/api", http.StatusInternalServerError) // 500 - cbs.RecordFailure("test", "/api", http.StatusBadGateway) // 502 - assert.Equal(t, gobreaker.StateClosed, cbs.State("test", "/api")) - - // These should count - cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) // 429 - cbs.RecordFailure("test", "/api", http.StatusServiceUnavailable) // 503 - assert.Equal(t, gobreaker.StateOpen, cbs.State("test", "/api")) -} -func TestCircuitBreakers_Anthropic529(t *testing.T) { - t.Parallel() + // Circuit should be open + resp, _ := http.Get(server.URL + "/test/v1/messages") + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + resp.Body.Close() - cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 1, - Interval: time.Minute, - Timeout: time.Minute, - MaxRequests: 1, - } - cbs := NewCircuitBreakers(cfg, nil) - - // Anthropic-specific 529 "Overloaded" should trip the circuit - tripped := cbs.RecordFailure("anthropic", "/v1/messages", 529) - assert.True(t, tripped) - assert.Equal(t, gobreaker.StateOpen, cbs.State("anthropic", "/v1/messages")) -} + // Wait for timeout, switch upstream to success + time.Sleep(60 * time.Millisecond) + returnError.Store(false) -func TestCircuitBreakers_ConcurrentAccess(t *testing.T) { - t.Parallel() + // Half-open: one request allowed + resp, _ = http.Get(server.URL + "/test/v1/messages") + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() - cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 1000, - Interval: time.Minute, - Timeout: time.Minute, - MaxRequests: 10, - } - cbs := NewCircuitBreakers(cfg, nil) - - var wg sync.WaitGroup - for i := 0; i < 50; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for j := 0; j < 100; j++ { - cbs.Allow("test", "/api") - cbs.RecordSuccess("test", "/api") - cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) - cbs.State("test", "/api") - } - }() - } - wg.Wait() - // Should not panic or deadlock + // Circuit should be closed now, more requests allowed + resp, _ = http.Get(server.URL + "/test/v1/messages") + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() } -func TestCircuitBreakers_StateChangeCallback(t *testing.T) { +func TestCircuitBreakerMiddleware_CustomIsFailure(t *testing.T) { t.Parallel() - cfg := CircuitBreakerConfig{ - Enabled: true, - FailureThreshold: 2, - Interval: time.Minute, - Timeout: 50 * time.Millisecond, - MaxRequests: 1, - } - - var mu sync.Mutex - var transitions []struct{ from, to gobreaker.State } - - cbs := NewCircuitBreakers(cfg, func(name string, from, to gobreaker.State) { - mu.Lock() - defer mu.Unlock() - transitions = append(transitions, struct{ from, to gobreaker.State }{from, to}) + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) // 502 }) - // Trip the circuit - cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) - cbs.RecordFailure("test", "/api", http.StatusTooManyRequests) - - // Wait for cooldown and trigger half-open - time.Sleep(60 * time.Millisecond) - cbs.Allow("test", "/api") - - // Success to close - cbs.RecordSuccess("test", "/api") - - // Wait for callbacks - time.Sleep(20 * time.Millisecond) - - mu.Lock() - defer mu.Unlock() - require.Len(t, transitions, 3) - assert.Equal(t, gobreaker.StateClosed, transitions[0].from) - assert.Equal(t, gobreaker.StateOpen, transitions[0].to) - assert.Equal(t, gobreaker.StateOpen, transitions[1].from) - assert.Equal(t, gobreaker.StateHalfOpen, transitions[1].to) - assert.Equal(t, gobreaker.StateHalfOpen, transitions[2].from) - assert.Equal(t, gobreaker.StateClosed, transitions[2].to) + // Custom IsFailure that treats 502 as failure + cbs := NewCircuitBreakers(map[string]CircuitBreakerConfig{ + "test": { + FailureThreshold: 1, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 1, + IsFailure: func(statusCode int) bool { + return statusCode == http.StatusBadGateway + }, + }, + }, nil) + + handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) + server := httptest.NewServer(handler) + defer server.Close() + + // First request returns 502, trips circuit + resp, _ := http.Get(server.URL + "/test/v1/messages") + resp.Body.Close() + + // Second request should be blocked + resp, _ = http.Get(server.URL + "/test/v1/messages") + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + resp.Body.Close() } -func TestIsCircuitBreakerFailure(t *testing.T) { +func TestDefaultIsFailure(t *testing.T) { t.Parallel() tests := []struct { @@ -230,13 +243,20 @@ func TestIsCircuitBreakerFailure(t *testing.T) { {http.StatusUnauthorized, false}, {http.StatusTooManyRequests, true}, // 429 {http.StatusInternalServerError, false}, + {http.StatusBadGateway, false}, {http.StatusServiceUnavailable, true}, // 503 {529, true}, // Anthropic Overloaded } for _, tt := range tests { - t.Run(http.StatusText(tt.statusCode), func(t *testing.T) { - assert.Equal(t, tt.isFailure, isCircuitBreakerFailure(tt.statusCode)) - }) + assert.Equal(t, tt.isFailure, DefaultIsFailure(tt.statusCode), "status code %d", tt.statusCode) } } + +func TestStateToGaugeValue(t *testing.T) { + t.Parallel() + + assert.Equal(t, float64(0), stateToGaugeValue(gobreaker.StateClosed)) + assert.Equal(t, float64(0.5), stateToGaugeValue(gobreaker.StateHalfOpen)) + assert.Equal(t, float64(1), stateToGaugeValue(gobreaker.StateOpen)) +} diff --git a/interception.go b/interception.go index da589be..62201aa 100644 --- a/interception.go +++ b/interception.go @@ -40,29 +40,13 @@ const recordingTimeout = time.Second * 5 // newInterceptionProcessor returns an [http.HandlerFunc] which is capable of creating a new interceptor and processing a given request // using [Provider] p, recording all usage events using [Recorder] recorder. -func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer, cbs *CircuitBreakers) http.HandlerFunc { +func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, span := tracer.Start(r.Context(), "Intercept") defer span.End() - // Extract endpoint (route) for per-endpoint circuit breaker route := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/%s", p.Name())) - // Check circuit breaker before proceeding - if !cbs.Allow(p.Name(), route) { - span.SetStatus(codes.Error, "circuit breaker open") - logger.Debug(ctx, "request rejected by circuit breaker", - slog.F("provider", p.Name()), - slog.F("endpoint", route), - slog.F("circuit_state", cbs.State(p.Name(), route)), - ) - if metrics != nil { - metrics.CircuitBreakerRejects.WithLabelValues(p.Name(), route).Inc() - } - http.Error(w, fmt.Sprintf("%s %s is currently unavailable due to upstream rate limiting. Please try again later.", p.Name(), route), http.StatusServiceUnavailable) - return - } - interceptor, err := p.CreateInterceptor(w, r.WithContext(ctx), tracer) if err != nil { span.SetStatus(codes.Error, fmt.Sprintf("failed to create interceptor: %v", err)) @@ -133,24 +117,11 @@ func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.Server } span.SetStatus(codes.Error, fmt.Sprintf("interception failed: %v", err)) log.Warn(ctx, "interception failed", slog.Error(err)) - - // Record failure for circuit breaker - extract status code if available - if statusCode := extractStatusCodeFromError(err); statusCode > 0 { - if cbs.RecordFailure(p.Name(), route, statusCode) { - log.Warn(ctx, "circuit breaker tripped", - slog.F("provider", p.Name()), - slog.F("status_code", statusCode), - ) - } - } } else { if metrics != nil { metrics.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), InterceptionCountStatusCompleted, route, r.Method, actor.id).Add(1) } log.Debug(ctx, "interception ended") - - // Record success for circuit breaker - cbs.RecordSuccess(p.Name(), route) } asyncRecorder.RecordInterceptionEnded(ctx, &InterceptionRecordEnded{ID: interceptor.ID().String()}) @@ -158,25 +129,3 @@ func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.Server asyncRecorder.Wait() } } - -// extractStatusCodeFromError attempts to extract an HTTP status code from an error. -// This is used for circuit breaker failure tracking. -func extractStatusCodeFromError(err error) int { - if err == nil { - return 0 - } - - // Check for Anthropic error response - var antErr *AnthropicErrorResponse - if errors.As(err, &antErr) && antErr != nil { - return antErr.StatusCode - } - - // Check for OpenAI error response - var oaiErr *OpenAIErrorResponse - if errors.As(err, &oaiErr) && oaiErr != nil { - return oaiErr.StatusCode - } - - return 0 -} From 521df9b7d233fb997a132b6e3ac9453353fe5f6c Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Wed, 17 Dec 2025 09:17:59 +0000 Subject: [PATCH 07/26] docs: clarify noop behavior when provider not configured --- circuit_breaker.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/circuit_breaker.go b/circuit_breaker.go index 9ffd0cc..9a39538 100644 --- a/circuit_breaker.go +++ b/circuit_breaker.go @@ -139,11 +139,12 @@ func (w *statusCapturingWriter) Write(b []byte) (int, error) { // CircuitBreakerMiddleware returns middleware that wraps handlers with circuit breaker protection. // It captures the response status code to determine success/failure without provider-specific logic. +// If the provider is not configured, returns a noop middleware (passes through without any circuit breaker). func CircuitBreakerMiddleware(cbs *CircuitBreakers, metrics *Metrics, provider string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { cfg := cbs.getConfig(provider) if cfg == nil { - // No config for this provider, pass through + // Noop: no config for this provider, pass through without circuit breaker return next } From c85b8365d21956a145c7a96227001afa6d837873 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Wed, 17 Dec 2025 10:15:01 +0000 Subject: [PATCH 08/26] Update go.mod --- go.mod | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/go.mod b/go.mod index 4715f99..329c7bf 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/mark3labs/mcp-go v0.38.0 github.com/prometheus/client_golang v1.23.2 + github.com/sony/gobreaker/v2 v2.3.0 github.com/stretchr/testify v1.11.1 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 @@ -33,8 +34,6 @@ require ( go.opentelemetry.io/otel/trace v1.38.0 ) -require github.com/sony/gobreaker/v2 v2.3.0 - require ( github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect From e4469544c223b84652bdf9b400af710940fa1fff Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Wed, 17 Dec 2025 10:17:54 +0000 Subject: [PATCH 09/26] fix: update metrics help text to reflect 0/0.5/1 gauge values --- metrics.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metrics.go b/metrics.go index f744d10..e7fbb5c 100644 --- a/metrics.go +++ b/metrics.go @@ -30,7 +30,7 @@ type Metrics struct { NonInjectedToolUseCount *prometheus.CounterVec // Circuit breaker metrics. - CircuitBreakerState *prometheus.GaugeVec // Current state (0=closed, 1=open, 2=half-open) + CircuitBreakerState *prometheus.GaugeVec // Current state (0=closed, 0.5=half-open, 1=open) CircuitBreakerTrips *prometheus.CounterVec // Total times circuit opened CircuitBreakerRejects *prometheus.CounterVec // Requests rejected due to open circuit } @@ -114,7 +114,7 @@ func NewMetrics(reg prometheus.Registerer) *Metrics { CircuitBreakerState: promauto.With(reg).NewGaugeVec(prometheus.GaugeOpts{ Subsystem: "circuit_breaker", Name: "state", - Help: "Current state of the circuit breaker (0=closed, 1=open, 2=half-open).", + Help: "Current state of the circuit breaker (0=closed, 0.5=half-open, 1=open).", }, []string{"provider", "endpoint"}), // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. CircuitBreakerTrips: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ From 1d2315edfa9004d3faf8daf44c9eaf4c40b4c385 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Wed, 17 Dec 2025 10:20:58 +0000 Subject: [PATCH 10/26] refactor: add CircuitBreaker interface with NoopCircuitBreaker - Add CircuitBreaker interface with Allow(), RecordSuccess(), RecordFailure() - Add NoopCircuitBreaker struct for providers without circuit breaker config - Add gobreakerCircuitBreaker wrapping sony/gobreaker implementation - CircuitBreakers.Get() returns NoopCircuitBreaker when provider not configured - Add http.Flusher support to statusCapturingWriter for SSE streaming - Add Unwrap() for ResponseWriter interface detection --- circuit_breaker.go | 111 +++++++++++++++++++++++++++++++-------------- 1 file changed, 78 insertions(+), 33 deletions(-) diff --git a/circuit_breaker.go b/circuit_breaker.go index 9a39538..5cbf8eb 100644 --- a/circuit_breaker.go +++ b/circuit_breaker.go @@ -51,17 +51,57 @@ func DefaultIsFailure(statusCode int) bool { } } +// CircuitBreaker defines the interface for circuit breaker implementations. +type CircuitBreaker interface { + // Allow returns true if the request should be allowed through. + Allow() bool + // RecordSuccess records a successful request. + RecordSuccess() + // RecordFailure records a failed request with the given status code. + RecordFailure(statusCode int) +} + +// NoopCircuitBreaker is a circuit breaker that always allows requests through. +// Used when circuit breaker is not configured for a provider. +type NoopCircuitBreaker struct{} + +func (NoopCircuitBreaker) Allow() bool { return true } +func (NoopCircuitBreaker) RecordSuccess() {} +func (NoopCircuitBreaker) RecordFailure(statusCode int) {} + +// gobreakerCircuitBreaker wraps sony/gobreaker to implement CircuitBreaker. +type gobreakerCircuitBreaker struct { + cb *gobreaker.CircuitBreaker[any] + isFailure func(statusCode int) bool +} + +func (g *gobreakerCircuitBreaker) Allow() bool { + return g.cb.State() != gobreaker.StateOpen +} + +func (g *gobreakerCircuitBreaker) RecordSuccess() { + _, _ = g.cb.Execute(func() (any, error) { return nil, nil }) +} + +func (g *gobreakerCircuitBreaker) RecordFailure(statusCode int) { + if g.isFailure(statusCode) { + _, _ = g.cb.Execute(func() (any, error) { + return nil, fmt.Errorf("upstream error: %d", statusCode) + }) + } +} + // CircuitBreakers manages per-endpoint circuit breakers using sony/gobreaker. // Circuit breakers are keyed by "provider:endpoint" for per-endpoint isolation. type CircuitBreakers struct { - breakers sync.Map // map[string]*gobreaker.CircuitBreaker[any] + breakers sync.Map // map[string]CircuitBreaker configs map[string]CircuitBreakerConfig onChange func(name string, from, to gobreaker.State) } // NewCircuitBreakers creates a new circuit breaker manager with per-provider configs. -// The configs map is keyed by provider name. Providers not in the map will not have -// circuit breaker protection. +// The configs map is keyed by provider name. Providers not in the map will use +// NoopCircuitBreaker (always allows requests). func NewCircuitBreakers(configs map[string]CircuitBreakerConfig, onChange func(name string, from, to gobreaker.State)) *CircuitBreakers { return &CircuitBreakers{ configs: configs, @@ -81,17 +121,22 @@ func (c *CircuitBreakers) getConfig(provider string) *CircuitBreakerConfig { return &cfg } -// getOrCreate returns the circuit breaker for a provider/endpoint, creating if needed. -// Returns nil if the provider is not configured. -func (c *CircuitBreakers) getOrCreate(provider, endpoint string) *gobreaker.CircuitBreaker[any] { +// Get returns the circuit breaker for a provider/endpoint. +// Returns NoopCircuitBreaker if the provider is not configured. +func (c *CircuitBreakers) Get(provider, endpoint string) CircuitBreaker { cfg := c.getConfig(provider) if cfg == nil { - return nil + return NoopCircuitBreaker{} } key := provider + ":" + endpoint if v, ok := c.breakers.Load(key); ok { - return v.(*gobreaker.CircuitBreaker[any]) + return v.(CircuitBreaker) + } + + isFailure := cfg.IsFailure + if isFailure == nil { + isFailure = DefaultIsFailure } settings := gobreaker.Settings{ @@ -109,12 +154,16 @@ func (c *CircuitBreakers) getOrCreate(provider, endpoint string) *gobreaker.Circ }, } - cb := gobreaker.NewCircuitBreaker[any](settings) + cb := &gobreakerCircuitBreaker{ + cb: gobreaker.NewCircuitBreaker[any](settings), + isFailure: isFailure, + } actual, _ := c.breakers.LoadOrStore(key, cb) - return actual.(*gobreaker.CircuitBreaker[any]) + return actual.(CircuitBreaker) } // statusCapturingWriter wraps http.ResponseWriter to capture the status code. +// It also implements http.Flusher to support streaming responses. type statusCapturingWriter struct { http.ResponseWriter statusCode int @@ -137,28 +186,28 @@ func (w *statusCapturingWriter) Write(b []byte) (int, error) { return w.ResponseWriter.Write(b) } +func (w *statusCapturingWriter) Flush() { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// Unwrap returns the underlying ResponseWriter for interface checks. +func (w *statusCapturingWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} + // CircuitBreakerMiddleware returns middleware that wraps handlers with circuit breaker protection. // It captures the response status code to determine success/failure without provider-specific logic. -// If the provider is not configured, returns a noop middleware (passes through without any circuit breaker). +// If the provider is not configured, uses NoopCircuitBreaker (always allows requests). func CircuitBreakerMiddleware(cbs *CircuitBreakers, metrics *Metrics, provider string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { - cfg := cbs.getConfig(provider) - if cfg == nil { - // Noop: no config for this provider, pass through without circuit breaker - return next - } - - isFailure := cfg.IsFailure - if isFailure == nil { - isFailure = DefaultIsFailure - } - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { endpoint := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/%s", provider)) + cb := cbs.Get(provider, endpoint) // Check if circuit is open - cb := cbs.getOrCreate(provider, endpoint) - if cb != nil && cb.State() == gobreaker.StateOpen { + if !cb.Allow() { if metrics != nil { metrics.CircuitBreakerRejects.WithLabelValues(provider, endpoint).Inc() } @@ -170,15 +219,11 @@ func CircuitBreakerMiddleware(cbs *CircuitBreakers, metrics *Metrics, provider s sw := &statusCapturingWriter{ResponseWriter: w, statusCode: http.StatusOK} next.ServeHTTP(sw, r) - // Record result - if cb != nil { - if isFailure(sw.statusCode) { - _, _ = cb.Execute(func() (any, error) { - return nil, fmt.Errorf("upstream error: %d", sw.statusCode) - }) - } else { - _, _ = cb.Execute(func() (any, error) { return nil, nil }) - } + // Record result - NoopCircuitBreaker methods are no-ops + if sw.statusCode >= 400 { + cb.RecordFailure(sw.statusCode) + } else { + cb.RecordSuccess() } }) } From 6994f89b3b1bf6aa7dea6a33abc3d9fed5d5fa02 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Wed, 17 Dec 2025 10:27:29 +0000 Subject: [PATCH 11/26] refactor: use gobreaker Execute for proper half-open rejection handling - Changed CircuitBreaker interface to Execute(fn func() int) (statusCode, rejected) - Use gobreaker.Execute() to properly handle both ErrOpenState and ErrTooManyRequests - NoopCircuitBreaker.Execute simply runs the function and returns not rejected - Simplified middleware by removing separate Allow/Record pattern --- circuit_breaker.go | 72 +++++++++++++++++++++------------------------- 1 file changed, 33 insertions(+), 39 deletions(-) diff --git a/circuit_breaker.go b/circuit_breaker.go index 5cbf8eb..04b86e8 100644 --- a/circuit_breaker.go +++ b/circuit_breaker.go @@ -1,6 +1,7 @@ package aibridge import ( + "errors" "fmt" "net/http" "strings" @@ -53,42 +54,40 @@ func DefaultIsFailure(statusCode int) bool { // CircuitBreaker defines the interface for circuit breaker implementations. type CircuitBreaker interface { - // Allow returns true if the request should be allowed through. - Allow() bool - // RecordSuccess records a successful request. - RecordSuccess() - // RecordFailure records a failed request with the given status code. - RecordFailure(statusCode int) + // Execute runs the given function if the circuit allows it. + // Returns (statusCode, rejected) where rejected is true if the circuit breaker blocked the request. + Execute(fn func() int) (statusCode int, rejected bool) } // NoopCircuitBreaker is a circuit breaker that always allows requests through. // Used when circuit breaker is not configured for a provider. type NoopCircuitBreaker struct{} -func (NoopCircuitBreaker) Allow() bool { return true } -func (NoopCircuitBreaker) RecordSuccess() {} -func (NoopCircuitBreaker) RecordFailure(statusCode int) {} +func (NoopCircuitBreaker) Execute(fn func() int) (int, bool) { + return fn(), false +} // gobreakerCircuitBreaker wraps sony/gobreaker to implement CircuitBreaker. type gobreakerCircuitBreaker struct { - cb *gobreaker.CircuitBreaker[any] + cb *gobreaker.CircuitBreaker[int] isFailure func(statusCode int) bool } -func (g *gobreakerCircuitBreaker) Allow() bool { - return g.cb.State() != gobreaker.StateOpen -} - -func (g *gobreakerCircuitBreaker) RecordSuccess() { - _, _ = g.cb.Execute(func() (any, error) { return nil, nil }) -} - -func (g *gobreakerCircuitBreaker) RecordFailure(statusCode int) { - if g.isFailure(statusCode) { - _, _ = g.cb.Execute(func() (any, error) { - return nil, fmt.Errorf("upstream error: %d", statusCode) - }) +func (g *gobreakerCircuitBreaker) Execute(fn func() int) (int, bool) { + statusCode, err := g.cb.Execute(func() (int, error) { + code := fn() + if g.isFailure(code) { + return code, fmt.Errorf("upstream error: %d", code) + } + return code, nil + }) + if err != nil { + // Check if rejected by circuit breaker (open or half-open with too many requests) + if errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) { + return 0, true + } } + return statusCode, false } // CircuitBreakers manages per-endpoint circuit breakers using sony/gobreaker. @@ -155,7 +154,7 @@ func (c *CircuitBreakers) Get(provider, endpoint string) CircuitBreaker { } cb := &gobreakerCircuitBreaker{ - cb: gobreaker.NewCircuitBreaker[any](settings), + cb: gobreaker.NewCircuitBreaker[int](settings), isFailure: isFailure, } actual, _ := c.breakers.LoadOrStore(key, cb) @@ -203,27 +202,22 @@ func (w *statusCapturingWriter) Unwrap() http.ResponseWriter { func CircuitBreakerMiddleware(cbs *CircuitBreakers, metrics *Metrics, provider string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - endpoint := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/%s", provider)) + endpoint := strings.TrimPrefix(r.URL.Path, "/"+provider) cb := cbs.Get(provider, endpoint) - // Check if circuit is open - if !cb.Allow() { + // Wrap response writer to capture status code + sw := &statusCapturingWriter{ResponseWriter: w, statusCode: http.StatusOK} + + _, rejected := cb.Execute(func() int { + next.ServeHTTP(sw, r) + return sw.statusCode + }) + + if rejected { if metrics != nil { metrics.CircuitBreakerRejects.WithLabelValues(provider, endpoint).Inc() } http.Error(w, "circuit breaker is open", http.StatusServiceUnavailable) - return - } - - // Wrap response writer to capture status code - sw := &statusCapturingWriter{ResponseWriter: w, statusCode: http.StatusOK} - next.ServeHTTP(sw, r) - - // Record result - NoopCircuitBreaker methods are no-ops - if sw.statusCode >= 400 { - cb.RecordFailure(sw.statusCode) - } else { - cb.RecordSuccess() } }) } From 6a7d57870ff4ddd6c9becaf52b59f9ff40f84343 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Wed, 17 Dec 2025 10:28:51 +0000 Subject: [PATCH 12/26] refactor: remove unused circuitBreakers field and getter from RequestBridge --- bridge.go | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/bridge.go b/bridge.go index bef401c..6dead52 100644 --- a/bridge.go +++ b/bridge.go @@ -31,10 +31,6 @@ type RequestBridge struct { mcpProxy mcp.ServerProxier - // circuitBreakers manages circuit breakers for upstream providers. - // When enabled, it protects against cascading failures from upstream rate limits. - circuitBreakers *CircuitBreakers - inflightReqs atomic.Int32 inflightWG sync.WaitGroup // For graceful shutdown. @@ -105,11 +101,10 @@ func NewRequestBridgeWithCircuitBreaker(ctx context.Context, providers []Provide inflightCtx, cancel := context.WithCancel(context.Background()) return &RequestBridge{ - mux: mux, - logger: logger, - mcpProxy: mcpProxy, - circuitBreakers: cbs, - inflightCtx: inflightCtx, + mux: mux, + logger: logger, + mcpProxy: mcpProxy, + inflightCtx: inflightCtx, inflightCancel: cancel, closed: make(chan struct{}, 1), @@ -182,11 +177,6 @@ func (b *RequestBridge) InflightRequests() int32 { return b.inflightReqs.Load() } -// CircuitBreakers returns the circuit breakers for this bridge. -func (b *RequestBridge) CircuitBreakers() *CircuitBreakers { - return b.circuitBreakers -} - // mergeContexts merges two contexts together, so that if either is cancelled // the returned context is cancelled. The context values will only be used from // the first context. From b0ff0eb78aae6de243bea373ef532487c754f2ac Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Wed, 17 Dec 2025 10:59:46 +0000 Subject: [PATCH 13/26] use per-provider maps for endpoints --- bridge.go | 6 ++---- circuit_breaker.go | 27 +++++++++++++++++---------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/bridge.go b/bridge.go index 6dead52..a9e412b 100644 --- a/bridge.go +++ b/bridge.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "net/http" - "strings" "sync" "sync/atomic" @@ -60,10 +59,9 @@ func NewRequestBridgeWithCircuitBreaker(ctx context.Context, providers []Provide mux := http.NewServeMux() // Create circuit breakers with metrics callback - var onChange func(name string, from, to gobreaker.State) + var onChange func(provider, endpoint string, from, to gobreaker.State) if metrics != nil { - onChange = func(name string, from, to gobreaker.State) { - provider, endpoint, _ := strings.Cut(name, ":") + onChange = func(provider, endpoint string, from, to gobreaker.State) { metrics.CircuitBreakerState.WithLabelValues(provider, endpoint).Set(stateToGaugeValue(to)) if to == gobreaker.StateOpen { metrics.CircuitBreakerTrips.WithLabelValues(provider, endpoint).Inc() diff --git a/circuit_breaker.go b/circuit_breaker.go index 04b86e8..0f87dd2 100644 --- a/circuit_breaker.go +++ b/circuit_breaker.go @@ -91,17 +91,18 @@ func (g *gobreakerCircuitBreaker) Execute(fn func() int) (int, bool) { } // CircuitBreakers manages per-endpoint circuit breakers using sony/gobreaker. -// Circuit breakers are keyed by "provider:endpoint" for per-endpoint isolation. +// Organized as a per-provider map with endpoint keys. type CircuitBreakers struct { - breakers sync.Map // map[string]CircuitBreaker + // breakers is map[provider]*sync.Map where inner map is endpoint -> CircuitBreaker + breakers sync.Map configs map[string]CircuitBreakerConfig - onChange func(name string, from, to gobreaker.State) + onChange func(provider, endpoint string, from, to gobreaker.State) } // NewCircuitBreakers creates a new circuit breaker manager with per-provider configs. // The configs map is keyed by provider name. Providers not in the map will use // NoopCircuitBreaker (always allows requests). -func NewCircuitBreakers(configs map[string]CircuitBreakerConfig, onChange func(name string, from, to gobreaker.State)) *CircuitBreakers { +func NewCircuitBreakers(configs map[string]CircuitBreakerConfig, onChange func(provider, endpoint string, from, to gobreaker.State)) *CircuitBreakers { return &CircuitBreakers{ configs: configs, onChange: onChange, @@ -120,6 +121,12 @@ func (c *CircuitBreakers) getConfig(provider string) *CircuitBreakerConfig { return &cfg } +// getProviderBreakers returns the endpoint map for a provider, creating it if needed. +func (c *CircuitBreakers) getProviderBreakers(provider string) *sync.Map { + v, _ := c.breakers.LoadOrStore(provider, &sync.Map{}) + return v.(*sync.Map) +} + // Get returns the circuit breaker for a provider/endpoint. // Returns NoopCircuitBreaker if the provider is not configured. func (c *CircuitBreakers) Get(provider, endpoint string) CircuitBreaker { @@ -128,8 +135,8 @@ func (c *CircuitBreakers) Get(provider, endpoint string) CircuitBreaker { return NoopCircuitBreaker{} } - key := provider + ":" + endpoint - if v, ok := c.breakers.Load(key); ok { + providerBreakers := c.getProviderBreakers(provider) + if v, ok := providerBreakers.Load(endpoint); ok { return v.(CircuitBreaker) } @@ -139,16 +146,16 @@ func (c *CircuitBreakers) Get(provider, endpoint string) CircuitBreaker { } settings := gobreaker.Settings{ - Name: key, + Name: provider + ":" + endpoint, MaxRequests: cfg.MaxRequests, Interval: cfg.Interval, Timeout: cfg.Timeout, ReadyToTrip: func(counts gobreaker.Counts) bool { return counts.ConsecutiveFailures >= cfg.FailureThreshold }, - OnStateChange: func(name string, from, to gobreaker.State) { + OnStateChange: func(_ string, from, to gobreaker.State) { if c.onChange != nil { - c.onChange(name, from, to) + c.onChange(provider, endpoint, from, to) } }, } @@ -157,7 +164,7 @@ func (c *CircuitBreakers) Get(provider, endpoint string) CircuitBreaker { cb: gobreaker.NewCircuitBreaker[int](settings), isFailure: isFailure, } - actual, _ := c.breakers.LoadOrStore(key, cb) + actual, _ := providerBreakers.LoadOrStore(endpoint, cb) return actual.(CircuitBreaker) } From bee7a4d87396e0154dfe54f0cf17878d3a738870 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Wed, 17 Dec 2025 12:38:53 +0000 Subject: [PATCH 14/26] make fmt --- bridge.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridge.go b/bridge.go index a9e412b..8c63d7a 100644 --- a/bridge.go +++ b/bridge.go @@ -103,7 +103,7 @@ func NewRequestBridgeWithCircuitBreaker(ctx context.Context, providers []Provide logger: logger, mcpProxy: mcpProxy, inflightCtx: inflightCtx, - inflightCancel: cancel, + inflightCancel: cancel, closed: make(chan struct{}, 1), }, nil From 98c7b7abafd2f5051f618ec7bcb2fb935a89949c Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Wed, 17 Dec 2025 12:51:25 +0000 Subject: [PATCH 15/26] use mux.Handle for cb middleware --- bridge.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bridge.go b/bridge.go index 8c63d7a..aa4feb7 100644 --- a/bridge.go +++ b/bridge.go @@ -75,8 +75,8 @@ func NewRequestBridgeWithCircuitBreaker(ctx context.Context, providers []Provide for _, path := range provider.BridgedRoutes() { handler := newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer) // Wrap with circuit breaker middleware if configured for this provider - handler = CircuitBreakerMiddleware(cbs, metrics, provider.Name())(handler).ServeHTTP - mux.HandleFunc(path, handler) + wrapped := CircuitBreakerMiddleware(cbs, metrics, provider.Name())(handler) + mux.Handle(path, wrapped) } // Any requests which passthrough to this will be reverse-proxied to the upstream. From 773326670fe2512cfae6d4b0c00ca9d87a520582 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Wed, 17 Dec 2025 13:03:26 +0000 Subject: [PATCH 16/26] Move CircuitBreakerConfig to the Provider struct --- bridge.go | 20 ++++++++++++-------- config.go | 3 ++- provider.go | 3 +++ provider_anthropic.go | 4 ++++ provider_openai.go | 12 +++++++++--- 5 files changed, 30 insertions(+), 12 deletions(-) diff --git a/bridge.go b/bridge.go index aa4feb7..db87c33 100644 --- a/bridge.go +++ b/bridge.go @@ -48,16 +48,20 @@ var _ http.Handler = &RequestBridge{} // A [Recorder] is also required to record prompt, tool, and token use. // // mcpProxy will be closed when the [RequestBridge] is closed. +// +// Circuit breaker configuration is obtained from each provider's CircuitBreakerConfig() method. +// Providers returning nil will not have circuit breaker protection. func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer) (*RequestBridge, error) { - return NewRequestBridgeWithCircuitBreaker(ctx, providers, recorder, mcpProxy, logger, metrics, tracer, nil) -} - -// NewRequestBridgeWithCircuitBreaker creates a new *[RequestBridge] with per-provider circuit breaker configuration. -// The cbConfigs map is keyed by provider name. Providers not in the map will not have circuit breaker protection. -// Pass nil to disable circuit breakers entirely. -func NewRequestBridgeWithCircuitBreaker(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer, cbConfigs map[string]CircuitBreakerConfig) (*RequestBridge, error) { mux := http.NewServeMux() + // Build circuit breaker configs from providers + cbConfigs := make(map[string]CircuitBreakerConfig) + for _, p := range providers { + if cfg := p.CircuitBreakerConfig(); cfg != nil { + cbConfigs[p.Name()] = *cfg + } + } + // Create circuit breakers with metrics callback var onChange func(provider, endpoint string, from, to gobreaker.State) if metrics != nil { @@ -74,7 +78,7 @@ func NewRequestBridgeWithCircuitBreaker(ctx context.Context, providers []Provide // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). for _, path := range provider.BridgedRoutes() { handler := newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer) - // Wrap with circuit breaker middleware if configured for this provider + // Wrap with circuit breaker middleware (uses NoopCircuitBreaker if not configured) wrapped := CircuitBreakerMiddleware(cbs, metrics, provider.Name())(handler) mux.Handle(path, wrapped) } diff --git a/config.go b/config.go index 8dc6f1d..ff1f639 100644 --- a/config.go +++ b/config.go @@ -1,7 +1,8 @@ package aibridge type ProviderConfig struct { - BaseURL, Key string + BaseURL, Key string + CircuitBreaker *CircuitBreakerConfig } type ( diff --git a/provider.go b/provider.go index 20f8f52..a2d1d87 100644 --- a/provider.go +++ b/provider.go @@ -33,4 +33,7 @@ type Provider interface { AuthHeader() string // InjectAuthHeader allows [Provider]s to set its authentication header. InjectAuthHeader(*http.Header) + + // CircuitBreakerConfig returns the circuit breaker configuration for the provider. + CircuitBreakerConfig() *CircuitBreakerConfig } diff --git a/provider_anthropic.go b/provider_anthropic.go index fb5d10b..d07502e 100644 --- a/provider_anthropic.go +++ b/provider_anthropic.go @@ -108,6 +108,10 @@ func (p *AnthropicProvider) InjectAuthHeader(headers *http.Header) { headers.Set(p.AuthHeader(), p.cfg.Key) } +func (p *AnthropicProvider) CircuitBreakerConfig() *CircuitBreakerConfig { + return p.cfg.CircuitBreaker +} + func getAnthropicErrorResponse(err error) *AnthropicErrorResponse { var apierr *anthropic.Error if !errors.As(err, &apierr) { diff --git a/provider_openai.go b/provider_openai.go index 68777e7..65288f6 100644 --- a/provider_openai.go +++ b/provider_openai.go @@ -17,7 +17,8 @@ var _ Provider = &OpenAIProvider{} // OpenAIProvider allows for interactions with the OpenAI API. type OpenAIProvider struct { - baseURL, key string + baseURL, key string + circuitBreaker *CircuitBreakerConfig } const ( @@ -36,8 +37,9 @@ func NewOpenAIProvider(cfg OpenAIConfig) *OpenAIProvider { } return &OpenAIProvider{ - baseURL: cfg.BaseURL, - key: cfg.Key, + baseURL: cfg.BaseURL, + key: cfg.Key, + circuitBreaker: cfg.CircuitBreaker, } } @@ -108,3 +110,7 @@ func (p *OpenAIProvider) InjectAuthHeader(headers *http.Header) { headers.Set(p.AuthHeader(), "Bearer "+p.key) } + +func (p *OpenAIProvider) CircuitBreakerConfig() *CircuitBreakerConfig { + return p.circuitBreaker +} From 7c7c85b3db75f90bfa675afc212ddebfd40aa987 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Wed, 17 Dec 2025 13:10:26 +0000 Subject: [PATCH 17/26] Update tests --- circuit_breaker_test.go | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go index dca23e2..5e9e767 100644 --- a/circuit_breaker_test.go +++ b/circuit_breaker_test.go @@ -57,14 +57,16 @@ func TestCircuitBreakerMiddleware_TripsOnUpstreamErrors(t *testing.T) { assert.Contains(t, string(body), "circuit breaker is open") assert.Equal(t, int32(2), upstreamCalls.Load()) // No new upstream call - // Wait for timeout, verify recovery - time.Sleep(60 * time.Millisecond) - - // Next request should hit upstream again (half-open state) - resp, err = http.Get(server.URL + "/test/v1/messages") - require.NoError(t, err) - resp.Body.Close() - assert.Equal(t, int32(3), upstreamCalls.Load()) + // Wait for timeout, verify recovery (circuit transitions to half-open) + require.Eventually(t, func() bool { + resp, err = http.Get(server.URL + "/test/v1/messages") + if err != nil { + return false + } + resp.Body.Close() + // Request hit upstream again (half-open state allows probe request) + return upstreamCalls.Load() == 3 + }, 5*time.Second, 25*time.Millisecond) } func TestCircuitBreakerMiddleware_PerEndpointIsolation(t *testing.T) { @@ -182,14 +184,19 @@ func TestCircuitBreakerMiddleware_RecoveryAfterSuccess(t *testing.T) { assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) resp.Body.Close() - // Wait for timeout, switch upstream to success - time.Sleep(60 * time.Millisecond) + // Switch upstream to success before we start polling returnError.Store(false) - // Half-open: one request allowed - resp, _ = http.Get(server.URL + "/test/v1/messages") - assert.Equal(t, http.StatusOK, resp.StatusCode) - resp.Body.Close() + // Wait for timeout (circuit transitions to half-open), then verify recovery + require.Eventually(t, func() bool { + resp, err := http.Get(server.URL + "/test/v1/messages") + if err != nil { + return false + } + defer resp.Body.Close() + // Half-open: request goes through and succeeds + return resp.StatusCode == http.StatusOK + }, 5*time.Second, 25*time.Millisecond) // Circuit should be closed now, more requests allowed resp, _ = http.Get(server.URL + "/test/v1/messages") From 8943ef01f6bf33b1ff682eb383e0f5e1393275c0 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Wed, 17 Dec 2025 15:25:45 +0000 Subject: [PATCH 18/26] default noop func for onChange --- bridge.go | 2 +- circuit_breaker.go | 4 +--- circuit_breaker_test.go | 8 ++++---- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/bridge.go b/bridge.go index db87c33..00e2654 100644 --- a/bridge.go +++ b/bridge.go @@ -63,7 +63,7 @@ func NewRequestBridge(ctx context.Context, providers []Provider, recorder Record } // Create circuit breakers with metrics callback - var onChange func(provider, endpoint string, from, to gobreaker.State) + onChange := func(provider, endpoint string, from, to gobreaker.State) {} if metrics != nil { onChange = func(provider, endpoint string, from, to gobreaker.State) { metrics.CircuitBreakerState.WithLabelValues(provider, endpoint).Set(stateToGaugeValue(to)) diff --git a/circuit_breaker.go b/circuit_breaker.go index 0f87dd2..5df67f9 100644 --- a/circuit_breaker.go +++ b/circuit_breaker.go @@ -154,9 +154,7 @@ func (c *CircuitBreakers) Get(provider, endpoint string) CircuitBreaker { return counts.ConsecutiveFailures >= cfg.FailureThreshold }, OnStateChange: func(_ string, from, to gobreaker.State) { - if c.onChange != nil { - c.onChange(provider, endpoint, from, to) - } + c.onChange(provider, endpoint, from, to) }, } diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go index 5e9e767..ee80f3b 100644 --- a/circuit_breaker_test.go +++ b/circuit_breaker_test.go @@ -32,7 +32,7 @@ func TestCircuitBreakerMiddleware_TripsOnUpstreamErrors(t *testing.T) { Timeout: 50 * time.Millisecond, MaxRequests: 1, }, - }, nil) + }, func(provider, endpoint string, from, to gobreaker.State) {}) // Wrap upstream with circuit breaker middleware handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) @@ -93,7 +93,7 @@ func TestCircuitBreakerMiddleware_PerEndpointIsolation(t *testing.T) { Timeout: time.Minute, MaxRequests: 1, }, - }, nil) + }, func(provider, endpoint string, from, to gobreaker.State) {}) handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) server := httptest.NewServer(handler) @@ -167,7 +167,7 @@ func TestCircuitBreakerMiddleware_RecoveryAfterSuccess(t *testing.T) { Timeout: 50 * time.Millisecond, MaxRequests: 1, }, - }, nil) + }, func(provider, endpoint string, from, to gobreaker.State) {}) handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) server := httptest.NewServer(handler) @@ -222,7 +222,7 @@ func TestCircuitBreakerMiddleware_CustomIsFailure(t *testing.T) { return statusCode == http.StatusBadGateway }, }, - }, nil) + }, func(provider, endpoint string, from, to gobreaker.State) {}) handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) server := httptest.NewServer(handler) From 7d2dcb1017511ececefe33153d52e20bf2d358bb Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Wed, 17 Dec 2025 15:36:09 +0000 Subject: [PATCH 19/26] create CircuitBreakers per Provider instead of a global one and remove gobreakerCircuitBraker along with the interface and noop struct --- bridge.go | 34 ++++----- circuit_breaker.go | 148 ++++++++++++---------------------------- circuit_breaker_test.go | 66 ++++++++---------- 3 files changed, 86 insertions(+), 162 deletions(-) diff --git a/bridge.go b/bridge.go index 00e2654..bb0c9d2 100644 --- a/bridge.go +++ b/bridge.go @@ -54,31 +54,27 @@ var _ http.Handler = &RequestBridge{} func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer) (*RequestBridge, error) { mux := http.NewServeMux() - // Build circuit breaker configs from providers - cbConfigs := make(map[string]CircuitBreakerConfig) - for _, p := range providers { - if cfg := p.CircuitBreakerConfig(); cfg != nil { - cbConfigs[p.Name()] = *cfg - } - } - - // Create circuit breakers with metrics callback - onChange := func(provider, endpoint string, from, to gobreaker.State) {} - if metrics != nil { - onChange = func(provider, endpoint string, from, to gobreaker.State) { - metrics.CircuitBreakerState.WithLabelValues(provider, endpoint).Set(stateToGaugeValue(to)) - if to == gobreaker.StateOpen { - metrics.CircuitBreakerTrips.WithLabelValues(provider, endpoint).Inc() + for _, provider := range providers { + // Create per-provider circuit breaker if configured + var cbs *ProviderCircuitBreakers + if cfg := provider.CircuitBreakerConfig(); cfg != nil { + onChange := func(endpoint string, from, to gobreaker.State) {} + if metrics != nil { + providerName := provider.Name() + onChange = func(endpoint string, from, to gobreaker.State) { + metrics.CircuitBreakerState.WithLabelValues(providerName, endpoint).Set(stateToGaugeValue(to)) + if to == gobreaker.StateOpen { + metrics.CircuitBreakerTrips.WithLabelValues(providerName, endpoint).Inc() + } + } } + cbs = NewProviderCircuitBreakers(provider.Name(), *cfg, onChange) } - } - cbs := NewCircuitBreakers(cbConfigs, onChange) - for _, provider := range providers { // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). for _, path := range provider.BridgedRoutes() { handler := newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer) - // Wrap with circuit breaker middleware (uses NoopCircuitBreaker if not configured) + // Wrap with circuit breaker middleware (nil cbs passes through) wrapped := CircuitBreakerMiddleware(cbs, metrics, provider.Name())(handler) mux.Handle(path, wrapped) } diff --git a/circuit_breaker.go b/circuit_breaker.go index 5df67f9..2adbf18 100644 --- a/circuit_breaker.go +++ b/circuit_breaker.go @@ -52,118 +52,48 @@ func DefaultIsFailure(statusCode int) bool { } } -// CircuitBreaker defines the interface for circuit breaker implementations. -type CircuitBreaker interface { - // Execute runs the given function if the circuit allows it. - // Returns (statusCode, rejected) where rejected is true if the circuit breaker blocked the request. - Execute(fn func() int) (statusCode int, rejected bool) -} - -// NoopCircuitBreaker is a circuit breaker that always allows requests through. -// Used when circuit breaker is not configured for a provider. -type NoopCircuitBreaker struct{} - -func (NoopCircuitBreaker) Execute(fn func() int) (int, bool) { - return fn(), false -} - -// gobreakerCircuitBreaker wraps sony/gobreaker to implement CircuitBreaker. -type gobreakerCircuitBreaker struct { - cb *gobreaker.CircuitBreaker[int] - isFailure func(statusCode int) bool -} - -func (g *gobreakerCircuitBreaker) Execute(fn func() int) (int, bool) { - statusCode, err := g.cb.Execute(func() (int, error) { - code := fn() - if g.isFailure(code) { - return code, fmt.Errorf("upstream error: %d", code) - } - return code, nil - }) - if err != nil { - // Check if rejected by circuit breaker (open or half-open with too many requests) - if errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) { - return 0, true - } - } - return statusCode, false -} - -// CircuitBreakers manages per-endpoint circuit breakers using sony/gobreaker. -// Organized as a per-provider map with endpoint keys. -type CircuitBreakers struct { - // breakers is map[provider]*sync.Map where inner map is endpoint -> CircuitBreaker - breakers sync.Map - configs map[string]CircuitBreakerConfig - onChange func(provider, endpoint string, from, to gobreaker.State) -} - -// NewCircuitBreakers creates a new circuit breaker manager with per-provider configs. -// The configs map is keyed by provider name. Providers not in the map will use -// NoopCircuitBreaker (always allows requests). -func NewCircuitBreakers(configs map[string]CircuitBreakerConfig, onChange func(provider, endpoint string, from, to gobreaker.State)) *CircuitBreakers { - return &CircuitBreakers{ - configs: configs, +// ProviderCircuitBreakers manages per-endpoint circuit breakers for a single provider. +type ProviderCircuitBreakers struct { + provider string + config CircuitBreakerConfig + breakers sync.Map // endpoint -> *gobreaker.CircuitBreaker[struct{}] + onChange func(endpoint string, from, to gobreaker.State) +} + +// NewProviderCircuitBreakers creates circuit breakers for a single provider. +func NewProviderCircuitBreakers(provider string, config CircuitBreakerConfig, onChange func(endpoint string, from, to gobreaker.State)) *ProviderCircuitBreakers { + if config.IsFailure == nil { + config.IsFailure = DefaultIsFailure + } + return &ProviderCircuitBreakers{ + provider: provider, + config: config, onChange: onChange, } } -// getConfig returns the config for a provider, or nil if not configured. -func (c *CircuitBreakers) getConfig(provider string) *CircuitBreakerConfig { - if c.configs == nil { - return nil - } - cfg, ok := c.configs[provider] - if !ok { - return nil - } - return &cfg -} - -// getProviderBreakers returns the endpoint map for a provider, creating it if needed. -func (c *CircuitBreakers) getProviderBreakers(provider string) *sync.Map { - v, _ := c.breakers.LoadOrStore(provider, &sync.Map{}) - return v.(*sync.Map) -} - -// Get returns the circuit breaker for a provider/endpoint. -// Returns NoopCircuitBreaker if the provider is not configured. -func (c *CircuitBreakers) Get(provider, endpoint string) CircuitBreaker { - cfg := c.getConfig(provider) - if cfg == nil { - return NoopCircuitBreaker{} - } - - providerBreakers := c.getProviderBreakers(provider) - if v, ok := providerBreakers.Load(endpoint); ok { - return v.(CircuitBreaker) - } - - isFailure := cfg.IsFailure - if isFailure == nil { - isFailure = DefaultIsFailure +// Get returns the circuit breaker for an endpoint, creating it if needed. +func (p *ProviderCircuitBreakers) Get(endpoint string) *gobreaker.CircuitBreaker[struct{}] { + if v, ok := p.breakers.Load(endpoint); ok { + return v.(*gobreaker.CircuitBreaker[struct{}]) } settings := gobreaker.Settings{ - Name: provider + ":" + endpoint, - MaxRequests: cfg.MaxRequests, - Interval: cfg.Interval, - Timeout: cfg.Timeout, + Name: p.provider + ":" + endpoint, + MaxRequests: p.config.MaxRequests, + Interval: p.config.Interval, + Timeout: p.config.Timeout, ReadyToTrip: func(counts gobreaker.Counts) bool { - return counts.ConsecutiveFailures >= cfg.FailureThreshold + return counts.ConsecutiveFailures >= p.config.FailureThreshold }, OnStateChange: func(_ string, from, to gobreaker.State) { - c.onChange(provider, endpoint, from, to) + p.onChange(endpoint, from, to) }, } - cb := &gobreakerCircuitBreaker{ - cb: gobreaker.NewCircuitBreaker[int](settings), - isFailure: isFailure, - } - actual, _ := providerBreakers.LoadOrStore(endpoint, cb) - return actual.(CircuitBreaker) + cb := gobreaker.NewCircuitBreaker[struct{}](settings) + actual, _ := p.breakers.LoadOrStore(endpoint, cb) + return actual.(*gobreaker.CircuitBreaker[struct{}]) } // statusCapturingWriter wraps http.ResponseWriter to capture the status code. @@ -203,22 +133,30 @@ func (w *statusCapturingWriter) Unwrap() http.ResponseWriter { // CircuitBreakerMiddleware returns middleware that wraps handlers with circuit breaker protection. // It captures the response status code to determine success/failure without provider-specific logic. -// If the provider is not configured, uses NoopCircuitBreaker (always allows requests). -func CircuitBreakerMiddleware(cbs *CircuitBreakers, metrics *Metrics, provider string) func(http.Handler) http.Handler { +// If cbs is nil, requests pass through without circuit breaker protection. +func CircuitBreakerMiddleware(cbs *ProviderCircuitBreakers, metrics *Metrics, provider string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { + // No circuit breaker configured - pass through + if cbs == nil { + return next + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { endpoint := strings.TrimPrefix(r.URL.Path, "/"+provider) - cb := cbs.Get(provider, endpoint) + cb := cbs.Get(endpoint) // Wrap response writer to capture status code sw := &statusCapturingWriter{ResponseWriter: w, statusCode: http.StatusOK} - _, rejected := cb.Execute(func() int { + _, err := cb.Execute(func() (struct{}, error) { next.ServeHTTP(sw, r) - return sw.statusCode + if cbs.config.IsFailure(sw.statusCode) { + return struct{}{}, fmt.Errorf("upstream error: %d", sw.statusCode) + } + return struct{}{}, nil }) - if rejected { + if err != nil && (errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests)) { if metrics != nil { metrics.CircuitBreakerRejects.WithLabelValues(provider, endpoint).Inc() } diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go index ee80f3b..95a7c78 100644 --- a/circuit_breaker_test.go +++ b/circuit_breaker_test.go @@ -25,14 +25,12 @@ func TestCircuitBreakerMiddleware_TripsOnUpstreamErrors(t *testing.T) { }) // Create circuit breaker with low threshold - cbs := NewCircuitBreakers(map[string]CircuitBreakerConfig{ - "test": { - FailureThreshold: 2, - Interval: time.Minute, - Timeout: 50 * time.Millisecond, - MaxRequests: 1, - }, - }, func(provider, endpoint string, from, to gobreaker.State) {}) + cbs := NewProviderCircuitBreakers("test", CircuitBreakerConfig{ + FailureThreshold: 2, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, + MaxRequests: 1, + }, func(endpoint string, from, to gobreaker.State) {}) // Wrap upstream with circuit breaker middleware handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) @@ -86,14 +84,12 @@ func TestCircuitBreakerMiddleware_PerEndpointIsolation(t *testing.T) { } }) - cbs := NewCircuitBreakers(map[string]CircuitBreakerConfig{ - "test": { - FailureThreshold: 1, - Interval: time.Minute, - Timeout: time.Minute, - MaxRequests: 1, - }, - }, func(provider, endpoint string, from, to gobreaker.State) {}) + cbs := NewProviderCircuitBreakers("test", CircuitBreakerConfig{ + FailureThreshold: 1, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 1, + }, func(endpoint string, from, to gobreaker.State) {}) handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) server := httptest.NewServer(handler) @@ -129,10 +125,8 @@ func TestCircuitBreakerMiddleware_NotConfigured(t *testing.T) { w.WriteHeader(http.StatusTooManyRequests) }) - // No config for "test" provider - cbs := NewCircuitBreakers(nil, nil) - - handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) + // No circuit breaker configured (nil) + handler := CircuitBreakerMiddleware(nil, nil, "test")(upstream) server := httptest.NewServer(handler) defer server.Close() @@ -160,14 +154,12 @@ func TestCircuitBreakerMiddleware_RecoveryAfterSuccess(t *testing.T) { } }) - cbs := NewCircuitBreakers(map[string]CircuitBreakerConfig{ - "test": { - FailureThreshold: 2, - Interval: time.Minute, - Timeout: 50 * time.Millisecond, - MaxRequests: 1, - }, - }, func(provider, endpoint string, from, to gobreaker.State) {}) + cbs := NewProviderCircuitBreakers("test", CircuitBreakerConfig{ + FailureThreshold: 2, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, + MaxRequests: 1, + }, func(endpoint string, from, to gobreaker.State) {}) handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) server := httptest.NewServer(handler) @@ -212,17 +204,15 @@ func TestCircuitBreakerMiddleware_CustomIsFailure(t *testing.T) { }) // Custom IsFailure that treats 502 as failure - cbs := NewCircuitBreakers(map[string]CircuitBreakerConfig{ - "test": { - FailureThreshold: 1, - Interval: time.Minute, - Timeout: time.Minute, - MaxRequests: 1, - IsFailure: func(statusCode int) bool { - return statusCode == http.StatusBadGateway - }, + cbs := NewProviderCircuitBreakers("test", CircuitBreakerConfig{ + FailureThreshold: 1, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 1, + IsFailure: func(statusCode int) bool { + return statusCode == http.StatusBadGateway }, - }, func(provider, endpoint string, from, to gobreaker.State) {}) + }, func(endpoint string, from, to gobreaker.State) {}) handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) server := httptest.NewServer(handler) From e3438f40dd95733ec1b73b84b42b58c76f0a35b6 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Wed, 17 Dec 2025 17:27:29 +0100 Subject: [PATCH 20/26] Update bridge.go MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Paweł Banaszewski --- bridge.go | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/bridge.go b/bridge.go index bb0c9d2..d60b2fd 100644 --- a/bridge.go +++ b/bridge.go @@ -57,20 +57,27 @@ func NewRequestBridge(ctx context.Context, providers []Provider, recorder Record for _, provider := range providers { // Create per-provider circuit breaker if configured var cbs *ProviderCircuitBreakers - if cfg := provider.CircuitBreakerConfig(); cfg != nil { - onChange := func(endpoint string, from, to gobreaker.State) {} - if metrics != nil { - providerName := provider.Name() - onChange = func(endpoint string, from, to gobreaker.State) { - metrics.CircuitBreakerState.WithLabelValues(providerName, endpoint).Set(stateToGaugeValue(to)) - if to == gobreaker.StateOpen { - metrics.CircuitBreakerTrips.WithLabelValues(providerName, endpoint).Inc() - } + onChange := func(endpoint string, from, to gobreaker.State) {} + + if cfg := provider.CircuitBreakerConfig(); cfg != nil && metrics != nil { + providerName := provider.Name() + onChange = func(endpoint string, from, to gobreaker.State) { + metrics.CircuitBreakerState.WithLabelValues(providerName, endpoint).Set(stateToGaugeValue(to)) + if to == gobreaker.StateOpen { + metrics.CircuitBreakerTrips.WithLabelValues(providerName, endpoint).Inc() } } - cbs = NewProviderCircuitBreakers(provider.Name(), *cfg, onChange) } + cbs = NewProviderCircuitBreakers(provider.Name(), *cfg, onChange) + + // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). + for _, path := range provider.BridgedRoutes() { + handler := newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer) + // Wrap with circuit breaker middleware (nil cbs passes through) + wrapped := CircuitBreakerMiddleware(cbs, metrics, provider.Name())(handler) + mux.Handle(path, wrapped) + } // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). for _, path := range provider.BridgedRoutes() { handler := newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer) From a32f246a1f81f9642e37e3fd0287b7f5a06c4d42 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Wed, 17 Dec 2025 16:28:16 +0000 Subject: [PATCH 21/26] fix format --- bridge.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bridge.go b/bridge.go index d60b2fd..84d8af6 100644 --- a/bridge.go +++ b/bridge.go @@ -59,7 +59,7 @@ func NewRequestBridge(ctx context.Context, providers []Provider, recorder Record var cbs *ProviderCircuitBreakers onChange := func(endpoint string, from, to gobreaker.State) {} - if cfg := provider.CircuitBreakerConfig(); cfg != nil && metrics != nil { + if cfg := provider.CircuitBreakerConfig(); cfg != nil && metrics != nil { providerName := provider.Name() onChange = func(endpoint string, from, to gobreaker.State) { metrics.CircuitBreakerState.WithLabelValues(providerName, endpoint).Set(stateToGaugeValue(to)) @@ -70,7 +70,7 @@ func NewRequestBridge(ctx context.Context, providers []Provider, recorder Record } cbs = NewProviderCircuitBreakers(provider.Name(), *cfg, onChange) - + // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). for _, path := range provider.BridgedRoutes() { handler := newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer) From e9290981cf44c5d89507df64bb7060650611320b Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Wed, 17 Dec 2025 17:02:34 +0000 Subject: [PATCH 22/26] Apply review suggestions --- bridge.go | 21 ++++++--------------- circuit_breaker.go | 18 ++++++++++++------ circuit_breaker_test.go | 18 +++++++++--------- 3 files changed, 27 insertions(+), 30 deletions(-) diff --git a/bridge.go b/bridge.go index 84d8af6..0fcb272 100644 --- a/bridge.go +++ b/bridge.go @@ -56,33 +56,24 @@ func NewRequestBridge(ctx context.Context, providers []Provider, recorder Record for _, provider := range providers { // Create per-provider circuit breaker if configured - var cbs *ProviderCircuitBreakers + cfg := provider.CircuitBreakerConfig() onChange := func(endpoint string, from, to gobreaker.State) {} - if cfg := provider.CircuitBreakerConfig(); cfg != nil && metrics != nil { - providerName := provider.Name() + if cfg != nil && metrics != nil { onChange = func(endpoint string, from, to gobreaker.State) { - metrics.CircuitBreakerState.WithLabelValues(providerName, endpoint).Set(stateToGaugeValue(to)) + metrics.CircuitBreakerState.WithLabelValues(provider.Name(), endpoint).Set(stateToGaugeValue(to)) if to == gobreaker.StateOpen { - metrics.CircuitBreakerTrips.WithLabelValues(providerName, endpoint).Inc() + metrics.CircuitBreakerTrips.WithLabelValues(provider.Name(), endpoint).Inc() } } } + cbs := NewProviderCircuitBreakers(provider.Name(), cfg, onChange) - cbs = NewProviderCircuitBreakers(provider.Name(), *cfg, onChange) - - // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). - for _, path := range provider.BridgedRoutes() { - handler := newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer) - // Wrap with circuit breaker middleware (nil cbs passes through) - wrapped := CircuitBreakerMiddleware(cbs, metrics, provider.Name())(handler) - mux.Handle(path, wrapped) - } // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). for _, path := range provider.BridgedRoutes() { handler := newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer) // Wrap with circuit breaker middleware (nil cbs passes through) - wrapped := CircuitBreakerMiddleware(cbs, metrics, provider.Name())(handler) + wrapped := CircuitBreakerMiddleware(cbs, metrics)(handler) mux.Handle(path, wrapped) } diff --git a/circuit_breaker.go b/circuit_breaker.go index 2adbf18..9e1ef20 100644 --- a/circuit_breaker.go +++ b/circuit_breaker.go @@ -61,13 +61,17 @@ type ProviderCircuitBreakers struct { } // NewProviderCircuitBreakers creates circuit breakers for a single provider. -func NewProviderCircuitBreakers(provider string, config CircuitBreakerConfig, onChange func(endpoint string, from, to gobreaker.State)) *ProviderCircuitBreakers { +// Returns nil if config is nil (no circuit breaker protection). +func NewProviderCircuitBreakers(provider string, config *CircuitBreakerConfig, onChange func(endpoint string, from, to gobreaker.State)) *ProviderCircuitBreakers { + if config == nil { + return nil + } if config.IsFailure == nil { config.IsFailure = DefaultIsFailure } return &ProviderCircuitBreakers{ provider: provider, - config: config, + config: *config, onChange: onChange, } } @@ -87,7 +91,9 @@ func (p *ProviderCircuitBreakers) Get(endpoint string) *gobreaker.CircuitBreaker return counts.ConsecutiveFailures >= p.config.FailureThreshold }, OnStateChange: func(_ string, from, to gobreaker.State) { - p.onChange(endpoint, from, to) + if p.onChange != nil { + p.onChange(endpoint, from, to) + } }, } @@ -134,7 +140,7 @@ func (w *statusCapturingWriter) Unwrap() http.ResponseWriter { // CircuitBreakerMiddleware returns middleware that wraps handlers with circuit breaker protection. // It captures the response status code to determine success/failure without provider-specific logic. // If cbs is nil, requests pass through without circuit breaker protection. -func CircuitBreakerMiddleware(cbs *ProviderCircuitBreakers, metrics *Metrics, provider string) func(http.Handler) http.Handler { +func CircuitBreakerMiddleware(cbs *ProviderCircuitBreakers, metrics *Metrics) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { // No circuit breaker configured - pass through if cbs == nil { @@ -142,7 +148,7 @@ func CircuitBreakerMiddleware(cbs *ProviderCircuitBreakers, metrics *Metrics, pr } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - endpoint := strings.TrimPrefix(r.URL.Path, "/"+provider) + endpoint := strings.TrimPrefix(r.URL.Path, "/"+cbs.provider) cb := cbs.Get(endpoint) // Wrap response writer to capture status code @@ -158,7 +164,7 @@ func CircuitBreakerMiddleware(cbs *ProviderCircuitBreakers, metrics *Metrics, pr if err != nil && (errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests)) { if metrics != nil { - metrics.CircuitBreakerRejects.WithLabelValues(provider, endpoint).Inc() + metrics.CircuitBreakerRejects.WithLabelValues(cbs.provider, endpoint).Inc() } http.Error(w, "circuit breaker is open", http.StatusServiceUnavailable) } diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go index 95a7c78..d509838 100644 --- a/circuit_breaker_test.go +++ b/circuit_breaker_test.go @@ -25,7 +25,7 @@ func TestCircuitBreakerMiddleware_TripsOnUpstreamErrors(t *testing.T) { }) // Create circuit breaker with low threshold - cbs := NewProviderCircuitBreakers("test", CircuitBreakerConfig{ + cbs := NewProviderCircuitBreakers("test", &CircuitBreakerConfig{ FailureThreshold: 2, Interval: time.Minute, Timeout: 50 * time.Millisecond, @@ -33,7 +33,7 @@ func TestCircuitBreakerMiddleware_TripsOnUpstreamErrors(t *testing.T) { }, func(endpoint string, from, to gobreaker.State) {}) // Wrap upstream with circuit breaker middleware - handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) + handler := CircuitBreakerMiddleware(cbs, nil)(upstream) server := httptest.NewServer(handler) defer server.Close() @@ -84,14 +84,14 @@ func TestCircuitBreakerMiddleware_PerEndpointIsolation(t *testing.T) { } }) - cbs := NewProviderCircuitBreakers("test", CircuitBreakerConfig{ + cbs := NewProviderCircuitBreakers("test", &CircuitBreakerConfig{ FailureThreshold: 1, Interval: time.Minute, Timeout: time.Minute, MaxRequests: 1, }, func(endpoint string, from, to gobreaker.State) {}) - handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) + handler := CircuitBreakerMiddleware(cbs, nil)(upstream) server := httptest.NewServer(handler) defer server.Close() @@ -126,7 +126,7 @@ func TestCircuitBreakerMiddleware_NotConfigured(t *testing.T) { }) // No circuit breaker configured (nil) - handler := CircuitBreakerMiddleware(nil, nil, "test")(upstream) + handler := CircuitBreakerMiddleware(nil, nil)(upstream) server := httptest.NewServer(handler) defer server.Close() @@ -154,14 +154,14 @@ func TestCircuitBreakerMiddleware_RecoveryAfterSuccess(t *testing.T) { } }) - cbs := NewProviderCircuitBreakers("test", CircuitBreakerConfig{ + cbs := NewProviderCircuitBreakers("test", &CircuitBreakerConfig{ FailureThreshold: 2, Interval: time.Minute, Timeout: 50 * time.Millisecond, MaxRequests: 1, }, func(endpoint string, from, to gobreaker.State) {}) - handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) + handler := CircuitBreakerMiddleware(cbs, nil)(upstream) server := httptest.NewServer(handler) defer server.Close() @@ -204,7 +204,7 @@ func TestCircuitBreakerMiddleware_CustomIsFailure(t *testing.T) { }) // Custom IsFailure that treats 502 as failure - cbs := NewProviderCircuitBreakers("test", CircuitBreakerConfig{ + cbs := NewProviderCircuitBreakers("test", &CircuitBreakerConfig{ FailureThreshold: 1, Interval: time.Minute, Timeout: time.Minute, @@ -214,7 +214,7 @@ func TestCircuitBreakerMiddleware_CustomIsFailure(t *testing.T) { }, }, func(endpoint string, from, to gobreaker.State) {}) - handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) + handler := CircuitBreakerMiddleware(cbs, nil)(upstream) server := httptest.NewServer(handler) defer server.Close() From ab08de4d544c0b0296942c0451c17eb691ceaa57 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Thu, 18 Dec 2025 12:44:17 +0000 Subject: [PATCH 23/26] Apply review suggestions and add proper integration tests --- circuit_breaker.go | 2 +- circuit_breaker_integration_test.go | 208 ++++++++++++++++++++++++++++ circuit_breaker_test.go | 111 --------------- 3 files changed, 209 insertions(+), 112 deletions(-) create mode 100644 circuit_breaker_integration_test.go diff --git a/circuit_breaker.go b/circuit_breaker.go index 9e1ef20..e5a8530 100644 --- a/circuit_breaker.go +++ b/circuit_breaker.go @@ -162,7 +162,7 @@ func CircuitBreakerMiddleware(cbs *ProviderCircuitBreakers, metrics *Metrics) fu return struct{}{}, nil }) - if err != nil && (errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests)) { + if errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) { if metrics != nil { metrics.CircuitBreakerRejects.WithLabelValues(cbs.provider, endpoint).Inc() } diff --git a/circuit_breaker_integration_test.go b/circuit_breaker_integration_test.go new file mode 100644 index 0000000..ee63e69 --- /dev/null +++ b/circuit_breaker_integration_test.go @@ -0,0 +1,208 @@ +package aibridge_test + +import ( + "context" + "io" + "net" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/aibridge" + "github.com/coder/aibridge/mcp" + "github.com/prometheus/client_golang/prometheus" + promtest "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" +) + +func TestCircuitBreaker_WithNewRequestBridge(t *testing.T) { + t.Parallel() + + var upstreamCalls atomic.Int32 + + // Mock upstream that returns 429 in Anthropic error format. + // x-should-retry: false is required to disable SDK automatic retries (default MaxRetries=2). + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamCalls.Add(1) + w.Header().Set("Content-Type", "application/json") + w.Header().Set("x-should-retry", "false") + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`)) + })) + defer mockUpstream.Close() + + metrics := aibridge.NewMetrics(prometheus.NewRegistry()) + + // Create provider with circuit breaker config + provider := aibridge.NewAnthropicProvider(aibridge.AnthropicConfig{ + BaseURL: mockUpstream.URL, + Key: "test-key", + CircuitBreaker: &aibridge.CircuitBreakerConfig{ + FailureThreshold: 2, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, + MaxRequests: 1, + }, + }, nil) + + ctx := t.Context() + tracer := otel.Tracer("forTesting") + logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) + bridge, err := aibridge.NewRequestBridge(ctx, + []aibridge.Provider{provider}, + &mockRecorderClient{}, + mcp.NewServerProxyManager(nil, tracer), + logger, + metrics, + tracer, + ) + require.NoError(t, err) + + mockSrv := httptest.NewUnstartedServer(bridge) + t.Cleanup(mockSrv.Close) + mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibridge.AsActor(ctx, "test-user-id", nil) + } + mockSrv.Start() + + makeRequest := func() *http.Response { + body := `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}` + req, _ := http.NewRequest("POST", mockSrv.URL+"/anthropic/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-api-key", "test") + req.Header.Set("anthropic-version", "2023-06-01") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + _, _ = io.ReadAll(resp.Body) + resp.Body.Close() + return resp + } + + // First 2 requests hit upstream, get 429 + for i := 0; i < 2; i++ { + resp := makeRequest() + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + } + assert.Equal(t, int32(2), upstreamCalls.Load()) + + // Third request should be blocked by circuit breaker + resp := makeRequest() + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + assert.Equal(t, int32(2), upstreamCalls.Load()) // No new upstream call + + // Verify metrics were recorded via NewRequestBridge's onChange callback + trips := promtest.ToFloat64(metrics.CircuitBreakerTrips.WithLabelValues(aibridge.ProviderAnthropic, "/v1/messages")) + assert.Equal(t, 1.0, trips, "CircuitBreakerTrips should be 1") + + state := promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(aibridge.ProviderAnthropic, "/v1/messages")) + assert.Equal(t, 1.0, state, "CircuitBreakerState should be 1 (open)") + + rejects := promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(aibridge.ProviderAnthropic, "/v1/messages")) + assert.Equal(t, 1.0, rejects, "CircuitBreakerRejects should be 1") +} + +func TestCircuitBreaker_HalfOpenAndRecovery(t *testing.T) { + t.Parallel() + + var upstreamCalls atomic.Int32 + var returnError atomic.Bool + returnError.Store(true) + + // Mock upstream that can switch between error and success. + // x-should-retry: false is required to disable SDK automatic retries. + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamCalls.Add(1) + w.Header().Set("Content-Type", "application/json") + w.Header().Set("x-should-retry", "false") + if returnError.Load() { + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`)) + } else { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"msg_123","type":"message","role":"assistant","content":[{"type":"text","text":"hi"}],"model":"claude-sonnet-4-20250514","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`)) + } + })) + t.Cleanup(mockUpstream.Close) + + metrics := aibridge.NewMetrics(prometheus.NewRegistry()) + + provider := aibridge.NewAnthropicProvider(aibridge.AnthropicConfig{ + BaseURL: mockUpstream.URL, + Key: "test-key", + CircuitBreaker: &aibridge.CircuitBreakerConfig{ + FailureThreshold: 2, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, // Short timeout for faster test + MaxRequests: 1, + }, + }, nil) + + ctx := t.Context() + tracer := otel.Tracer("forTesting") + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) + bridge, err := aibridge.NewRequestBridge(ctx, + []aibridge.Provider{provider}, + &mockRecorderClient{}, + mcp.NewServerProxyManager(nil, tracer), + logger, + metrics, + tracer, + ) + require.NoError(t, err) + + mockSrv := httptest.NewUnstartedServer(bridge) + t.Cleanup(mockSrv.Close) + mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibridge.AsActor(ctx, "test-user-id", nil) + } + mockSrv.Start() + + makeRequest := func() *http.Response { + body := `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}` + req, _ := http.NewRequest("POST", mockSrv.URL+"/anthropic/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-api-key", "test") + req.Header.Set("anthropic-version", "2023-06-01") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + _, _ = io.ReadAll(resp.Body) + resp.Body.Close() + return resp + } + + // Trip the circuit with 2 failures + for i := 0; i < 2; i++ { + resp := makeRequest() + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + } + assert.Equal(t, int32(2), upstreamCalls.Load()) + + // Verify circuit is open + state := promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(aibridge.ProviderAnthropic, "/v1/messages")) + assert.Equal(t, 1.0, state, "CircuitBreakerState should be 1 (open)") + + // Switch upstream to success before recovery + returnError.Store(false) + + // Wait for timeout, then make a request to trigger recovery + time.Sleep(60 * time.Millisecond) + + // This request triggers recovery: half-open -> probe succeeds -> closed + resp := makeRequest() + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Verify circuit is now closed (state = 0) + state = promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(aibridge.ProviderAnthropic, "/v1/messages")) + assert.Equal(t, 0.0, state, "CircuitBreakerState should be 0 (closed)") + + // Additional requests should succeed + resp = makeRequest() + assert.Equal(t, http.StatusOK, resp.StatusCode) +} diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go index d509838..afa8cfb 100644 --- a/circuit_breaker_test.go +++ b/circuit_breaker_test.go @@ -1,7 +1,6 @@ package aibridge import ( - "io" "net/http" "net/http/httptest" "sync/atomic" @@ -13,60 +12,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestCircuitBreakerMiddleware_TripsOnUpstreamErrors(t *testing.T) { - t.Parallel() - - var upstreamCalls atomic.Int32 - - // Mock upstream that returns 429 - upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - upstreamCalls.Add(1) - w.WriteHeader(http.StatusTooManyRequests) - }) - - // Create circuit breaker with low threshold - cbs := NewProviderCircuitBreakers("test", &CircuitBreakerConfig{ - FailureThreshold: 2, - Interval: time.Minute, - Timeout: 50 * time.Millisecond, - MaxRequests: 1, - }, func(endpoint string, from, to gobreaker.State) {}) - - // Wrap upstream with circuit breaker middleware - handler := CircuitBreakerMiddleware(cbs, nil)(upstream) - server := httptest.NewServer(handler) - defer server.Close() - - // First 2 requests hit upstream, get 429 - for i := 0; i < 2; i++ { - resp, err := http.Get(server.URL + "/test/v1/messages") - require.NoError(t, err) - resp.Body.Close() - assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) - } - assert.Equal(t, int32(2), upstreamCalls.Load()) - - // Third request should get 503 "circuit breaker is open" without hitting upstream - resp, err := http.Get(server.URL + "/test/v1/messages") - require.NoError(t, err) - body, _ := io.ReadAll(resp.Body) - resp.Body.Close() - assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) - assert.Contains(t, string(body), "circuit breaker is open") - assert.Equal(t, int32(2), upstreamCalls.Load()) // No new upstream call - - // Wait for timeout, verify recovery (circuit transitions to half-open) - require.Eventually(t, func() bool { - resp, err = http.Get(server.URL + "/test/v1/messages") - if err != nil { - return false - } - resp.Body.Close() - // Request hit upstream again (half-open state allows probe request) - return upstreamCalls.Load() == 3 - }, 5*time.Second, 25*time.Millisecond) -} - func TestCircuitBreakerMiddleware_PerEndpointIsolation(t *testing.T) { t.Parallel() @@ -140,62 +85,6 @@ func TestCircuitBreakerMiddleware_NotConfigured(t *testing.T) { assert.Equal(t, int32(10), upstreamCalls.Load()) } -func TestCircuitBreakerMiddleware_RecoveryAfterSuccess(t *testing.T) { - t.Parallel() - - var returnError atomic.Bool - returnError.Store(true) - - upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if returnError.Load() { - w.WriteHeader(http.StatusTooManyRequests) - } else { - w.WriteHeader(http.StatusOK) - } - }) - - cbs := NewProviderCircuitBreakers("test", &CircuitBreakerConfig{ - FailureThreshold: 2, - Interval: time.Minute, - Timeout: 50 * time.Millisecond, - MaxRequests: 1, - }, func(endpoint string, from, to gobreaker.State) {}) - - handler := CircuitBreakerMiddleware(cbs, nil)(upstream) - server := httptest.NewServer(handler) - defer server.Close() - - // Trip the circuit - for i := 0; i < 2; i++ { - resp, _ := http.Get(server.URL + "/test/v1/messages") - resp.Body.Close() - } - - // Circuit should be open - resp, _ := http.Get(server.URL + "/test/v1/messages") - assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) - resp.Body.Close() - - // Switch upstream to success before we start polling - returnError.Store(false) - - // Wait for timeout (circuit transitions to half-open), then verify recovery - require.Eventually(t, func() bool { - resp, err := http.Get(server.URL + "/test/v1/messages") - if err != nil { - return false - } - defer resp.Body.Close() - // Half-open: request goes through and succeeds - return resp.StatusCode == http.StatusOK - }, 5*time.Second, 25*time.Millisecond) - - // Circuit should be closed now, more requests allowed - resp, _ = http.Get(server.URL + "/test/v1/messages") - assert.Equal(t, http.StatusOK, resp.StatusCode) - resp.Body.Close() -} - func TestCircuitBreakerMiddleware_CustomIsFailure(t *testing.T) { t.Parallel() From 161db92bd676422b4a17ba7f5bc1215bee2a999c Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Thu, 18 Dec 2025 12:46:58 +0000 Subject: [PATCH 24/26] Add test to check circuit breaker config --- circuit_breaker_test.go | 50 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go index afa8cfb..5ec2455 100644 --- a/circuit_breaker_test.go +++ b/circuit_breaker_test.go @@ -146,3 +146,53 @@ func TestStateToGaugeValue(t *testing.T) { assert.Equal(t, float64(0.5), stateToGaugeValue(gobreaker.StateHalfOpen)) assert.Equal(t, float64(1), stateToGaugeValue(gobreaker.StateOpen)) } + +func TestProviderCircuitBreakerConfig(t *testing.T) { + t.Parallel() + + cfg := &CircuitBreakerConfig{ + FailureThreshold: 5, + Interval: 2 * time.Minute, + Timeout: 30 * time.Second, + MaxRequests: 3, + IsFailure: func(code int) bool { + return code == 429 + }, + } + + t.Run("AnthropicProvider", func(t *testing.T) { + t.Parallel() + provider := NewAnthropicProvider(AnthropicConfig{ + CircuitBreaker: cfg, + }, nil) + + got := provider.CircuitBreakerConfig() + require.NotNil(t, got) + assert.Equal(t, cfg.FailureThreshold, got.FailureThreshold) + assert.Equal(t, cfg.Interval, got.Interval) + assert.Equal(t, cfg.Timeout, got.Timeout) + assert.Equal(t, cfg.MaxRequests, got.MaxRequests) + assert.NotNil(t, got.IsFailure) + }) + + t.Run("OpenAIProvider", func(t *testing.T) { + t.Parallel() + provider := NewOpenAIProvider(OpenAIConfig{ + CircuitBreaker: cfg, + }) + + got := provider.CircuitBreakerConfig() + require.NotNil(t, got) + assert.Equal(t, cfg.FailureThreshold, got.FailureThreshold) + assert.Equal(t, cfg.Interval, got.Interval) + assert.Equal(t, cfg.Timeout, got.Timeout) + assert.Equal(t, cfg.MaxRequests, got.MaxRequests) + assert.NotNil(t, got.IsFailure) + }) + + t.Run("NilConfig", func(t *testing.T) { + t.Parallel() + provider := NewAnthropicProvider(AnthropicConfig{}, nil) + assert.Nil(t, provider.CircuitBreakerConfig()) + }) +} From 33ea4aeb4f86a510f0fe1a228c0d9cacd8f0c3fc Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Thu, 18 Dec 2025 13:27:01 +0000 Subject: [PATCH 25/26] Remove test --- circuit_breaker_test.go | 50 ----------------------------------------- 1 file changed, 50 deletions(-) diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go index 5ec2455..afa8cfb 100644 --- a/circuit_breaker_test.go +++ b/circuit_breaker_test.go @@ -146,53 +146,3 @@ func TestStateToGaugeValue(t *testing.T) { assert.Equal(t, float64(0.5), stateToGaugeValue(gobreaker.StateHalfOpen)) assert.Equal(t, float64(1), stateToGaugeValue(gobreaker.StateOpen)) } - -func TestProviderCircuitBreakerConfig(t *testing.T) { - t.Parallel() - - cfg := &CircuitBreakerConfig{ - FailureThreshold: 5, - Interval: 2 * time.Minute, - Timeout: 30 * time.Second, - MaxRequests: 3, - IsFailure: func(code int) bool { - return code == 429 - }, - } - - t.Run("AnthropicProvider", func(t *testing.T) { - t.Parallel() - provider := NewAnthropicProvider(AnthropicConfig{ - CircuitBreaker: cfg, - }, nil) - - got := provider.CircuitBreakerConfig() - require.NotNil(t, got) - assert.Equal(t, cfg.FailureThreshold, got.FailureThreshold) - assert.Equal(t, cfg.Interval, got.Interval) - assert.Equal(t, cfg.Timeout, got.Timeout) - assert.Equal(t, cfg.MaxRequests, got.MaxRequests) - assert.NotNil(t, got.IsFailure) - }) - - t.Run("OpenAIProvider", func(t *testing.T) { - t.Parallel() - provider := NewOpenAIProvider(OpenAIConfig{ - CircuitBreaker: cfg, - }) - - got := provider.CircuitBreakerConfig() - require.NotNil(t, got) - assert.Equal(t, cfg.FailureThreshold, got.FailureThreshold) - assert.Equal(t, cfg.Interval, got.Interval) - assert.Equal(t, cfg.Timeout, got.Timeout) - assert.Equal(t, cfg.MaxRequests, got.MaxRequests) - assert.NotNil(t, got.IsFailure) - }) - - t.Run("NilConfig", func(t *testing.T) { - t.Parallel() - provider := NewAnthropicProvider(AnthropicConfig{}, nil) - assert.Nil(t, provider.CircuitBreakerConfig()) - }) -} From dbfab23c547e524c663aea95bd83c103c57a9ea1 Mon Sep 17 00:00:00 2001 From: Kacper Sawicki Date: Thu, 18 Dec 2025 14:01:45 +0000 Subject: [PATCH 26/26] Remove TestCircuitBreaker_HalfOpenAndRecovery --- circuit_breaker_integration_test.go | 99 ----------------------------- 1 file changed, 99 deletions(-) diff --git a/circuit_breaker_integration_test.go b/circuit_breaker_integration_test.go index ee63e69..59dab66 100644 --- a/circuit_breaker_integration_test.go +++ b/circuit_breaker_integration_test.go @@ -107,102 +107,3 @@ func TestCircuitBreaker_WithNewRequestBridge(t *testing.T) { rejects := promtest.ToFloat64(metrics.CircuitBreakerRejects.WithLabelValues(aibridge.ProviderAnthropic, "/v1/messages")) assert.Equal(t, 1.0, rejects, "CircuitBreakerRejects should be 1") } - -func TestCircuitBreaker_HalfOpenAndRecovery(t *testing.T) { - t.Parallel() - - var upstreamCalls atomic.Int32 - var returnError atomic.Bool - returnError.Store(true) - - // Mock upstream that can switch between error and success. - // x-should-retry: false is required to disable SDK automatic retries. - mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - upstreamCalls.Add(1) - w.Header().Set("Content-Type", "application/json") - w.Header().Set("x-should-retry", "false") - if returnError.Load() { - w.WriteHeader(http.StatusTooManyRequests) - _, _ = w.Write([]byte(`{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`)) - } else { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"id":"msg_123","type":"message","role":"assistant","content":[{"type":"text","text":"hi"}],"model":"claude-sonnet-4-20250514","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`)) - } - })) - t.Cleanup(mockUpstream.Close) - - metrics := aibridge.NewMetrics(prometheus.NewRegistry()) - - provider := aibridge.NewAnthropicProvider(aibridge.AnthropicConfig{ - BaseURL: mockUpstream.URL, - Key: "test-key", - CircuitBreaker: &aibridge.CircuitBreakerConfig{ - FailureThreshold: 2, - Interval: time.Minute, - Timeout: 50 * time.Millisecond, // Short timeout for faster test - MaxRequests: 1, - }, - }, nil) - - ctx := t.Context() - tracer := otel.Tracer("forTesting") - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - bridge, err := aibridge.NewRequestBridge(ctx, - []aibridge.Provider{provider}, - &mockRecorderClient{}, - mcp.NewServerProxyManager(nil, tracer), - logger, - metrics, - tracer, - ) - require.NoError(t, err) - - mockSrv := httptest.NewUnstartedServer(bridge) - t.Cleanup(mockSrv.Close) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return aibridge.AsActor(ctx, "test-user-id", nil) - } - mockSrv.Start() - - makeRequest := func() *http.Response { - body := `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}` - req, _ := http.NewRequest("POST", mockSrv.URL+"/anthropic/v1/messages", strings.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("x-api-key", "test") - req.Header.Set("anthropic-version", "2023-06-01") - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - _, _ = io.ReadAll(resp.Body) - resp.Body.Close() - return resp - } - - // Trip the circuit with 2 failures - for i := 0; i < 2; i++ { - resp := makeRequest() - assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) - } - assert.Equal(t, int32(2), upstreamCalls.Load()) - - // Verify circuit is open - state := promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(aibridge.ProviderAnthropic, "/v1/messages")) - assert.Equal(t, 1.0, state, "CircuitBreakerState should be 1 (open)") - - // Switch upstream to success before recovery - returnError.Store(false) - - // Wait for timeout, then make a request to trigger recovery - time.Sleep(60 * time.Millisecond) - - // This request triggers recovery: half-open -> probe succeeds -> closed - resp := makeRequest() - assert.Equal(t, http.StatusOK, resp.StatusCode) - - // Verify circuit is now closed (state = 0) - state = promtest.ToFloat64(metrics.CircuitBreakerState.WithLabelValues(aibridge.ProviderAnthropic, "/v1/messages")) - assert.Equal(t, 0.0, state, "CircuitBreakerState should be 0 (closed)") - - // Additional requests should succeed - resp = makeRequest() - assert.Equal(t, http.StatusOK, resp.StatusCode) -}