// kroki_rs/server/middleware/circuit_breaker.rs

//! Per-provider circuit breaker for the Kroki-rs server.
//!
//! Implements the Closed → Open → Half-Open state machine pattern.
//! Each diagram provider type gets its own independent circuit breaker.
//! When the circuit is open, requests fail immediately with 503 without
//! invoking the actual provider.

8use crate::config::CircuitBreakerConfig;
9use dashmap::DashMap;
10use std::sync::Arc;
11use std::time::Instant;
12
/// Circuit breaker states.
///
/// `Eq` and `Hash` are derived so states can be compared exhaustively and
/// used as keys in maps/sets (e.g. when aggregating states for reporting).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CircuitState {
    /// Normal operation — requests flow through.
    Closed,
    /// Failures exceeded threshold — requests rejected immediately.
    Open,
    /// Testing recovery — a single probe request is allowed through at a time.
    HalfOpen,
}
23
24/// Internal state for a single provider's circuit breaker.
25struct ProviderCircuit {
26    state: CircuitState,
27    consecutive_failures: u32,
28    last_failure_time: Option<Instant>,
29    failure_threshold: u32,
30    reset_timeout_secs: u64,
31}
32
33impl ProviderCircuit {
34    fn new(config: &CircuitBreakerConfig) -> Self {
35        Self {
36            state: CircuitState::Closed,
37            consecutive_failures: 0,
38            last_failure_time: None,
39            failure_threshold: config.failure_threshold,
40            reset_timeout_secs: config.reset_timeout_secs,
41        }
42    }
43
44    /// Checks if a request should be allowed through.
45    fn should_allow(&mut self) -> bool {
46        match self.state {
47            CircuitState::Closed => true,
48            CircuitState::Open => {
49                // Check if reset timeout has elapsed → transition to HalfOpen
50                if let Some(last_fail) = self.last_failure_time {
51                    if last_fail.elapsed().as_secs() >= self.reset_timeout_secs {
52                        self.state = CircuitState::HalfOpen;
53                        tracing::info!("Circuit breaker transitioning to HalfOpen");
54                        true // Allow one test request
55                    } else {
56                        false
57                    }
58                } else {
59                    false
60                }
61            }
62            CircuitState::HalfOpen => {
63                // Only one request is allowed in HalfOpen; subsequent ones are rejected
64                // until the first one completes.
65                false
66            }
67        }
68    }
69
70    /// Records a successful request. Resets the circuit to Closed.
71    fn record_success(&mut self) {
72        if self.state == CircuitState::HalfOpen {
73            tracing::info!("Circuit breaker closing after successful HalfOpen test");
74        }
75        self.state = CircuitState::Closed;
76        self.consecutive_failures = 0;
77        self.last_failure_time = None;
78    }
79
80    /// Records a failed request. May trip the circuit to Open.
81    fn record_failure(&mut self) {
82        self.consecutive_failures += 1;
83        self.last_failure_time = Some(Instant::now());
84
85        if self.consecutive_failures >= self.failure_threshold {
86            self.state = CircuitState::Open;
87            tracing::warn!(
88                "Circuit breaker opened after {} consecutive failures",
89                self.consecutive_failures
90            );
91        }
92
93        // If we were in HalfOpen and the test failed, go back to Open
94        if self.state == CircuitState::HalfOpen {
95            self.state = CircuitState::Open;
96            tracing::warn!("Circuit breaker re-opened after HalfOpen test failure");
97        }
98    }
99}
100
101/// Manages circuit breakers for all diagram providers.
102#[derive(Clone)]
103pub struct CircuitBreakerManager {
104    circuits: Arc<DashMap<String, ProviderCircuit>>,
105    config: CircuitBreakerConfig,
106}
107
108impl CircuitBreakerManager {
109    /// Creates a new circuit breaker manager.
110    pub fn new(config: &CircuitBreakerConfig) -> Self {
111        Self {
112            circuits: Arc::new(DashMap::new()),
113            config: config.clone(),
114        }
115    }
116
117    /// Checks if the circuit for the given provider allows requests.
118    /// Returns `true` if the request should proceed, `false` if it should be rejected.
119    pub fn should_allow(&self, provider: &str) -> bool {
120        let mut entry = self
121            .circuits
122            .entry(provider.to_string())
123            .or_insert_with(|| ProviderCircuit::new(&self.config));
124        entry.should_allow()
125    }
126
127    /// Records a successful request for the given provider.
128    pub fn record_success(&self, provider: &str) {
129        if let Some(mut entry) = self.circuits.get_mut(provider) {
130            entry.record_success();
131        }
132    }
133
134    /// Records a failed request for the given provider.
135    pub fn record_failure(&self, provider: &str) {
136        let mut entry = self
137            .circuits
138            .entry(provider.to_string())
139            .or_insert_with(|| ProviderCircuit::new(&self.config));
140        entry.record_failure();
141    }
142
143    /// Returns the current state of the circuit for a given provider.
144    pub fn get_state(&self, provider: &str) -> CircuitState {
145        self.circuits
146            .get(provider)
147            .map(|entry| entry.state)
148            .unwrap_or(CircuitState::Closed)
149    }
150
151    /// Returns the states of all known circuits.
152    pub fn get_all_states(&self) -> Vec<(String, CircuitState)> {
153        self.circuits
154            .iter()
155            .map(|entry| (entry.key().clone(), entry.state))
156            .collect()
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// Threshold of 3 failures and a 1-second reset timeout.
    fn test_config() -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            enabled: true,
            failure_threshold: 3,
            reset_timeout_secs: 1,
        }
    }

    /// Trips on the first failure and allows a HalfOpen probe immediately.
    fn instant_reset_config() -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            enabled: true,
            failure_threshold: 1,
            reset_timeout_secs: 0,
        }
    }

    #[test]
    fn test_circuit_starts_closed() {
        let mgr = CircuitBreakerManager::new(&test_config());
        assert!(mgr.should_allow("mermaid"));
        assert_eq!(mgr.get_state("mermaid"), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_opens_after_threshold() {
        let mgr = CircuitBreakerManager::new(&test_config());
        // 3 consecutive failures should open the circuit
        for _ in 0..3 {
            mgr.record_failure("mermaid");
        }
        assert_eq!(mgr.get_state("mermaid"), CircuitState::Open);
        assert!(!mgr.should_allow("mermaid"));
    }

    #[test]
    fn test_below_threshold_stays_closed() {
        let mgr = CircuitBreakerManager::new(&test_config());
        mgr.record_failure("mermaid");
        mgr.record_failure("mermaid");
        assert_eq!(mgr.get_state("mermaid"), CircuitState::Closed);
        assert!(mgr.should_allow("mermaid"));
    }

    #[test]
    fn test_success_resets_failure_count() {
        let mgr = CircuitBreakerManager::new(&test_config());
        mgr.record_failure("mermaid");
        mgr.record_failure("mermaid");
        mgr.record_success("mermaid"); // reset
        mgr.record_failure("mermaid"); // only 1 failure now
        assert_eq!(mgr.get_state("mermaid"), CircuitState::Closed);
    }

    #[test]
    fn test_record_success_on_unknown_provider_is_noop() {
        let mgr = CircuitBreakerManager::new(&test_config());
        mgr.record_success("mermaid");
        assert!(mgr.get_all_states().is_empty());
        assert_eq!(mgr.get_state("mermaid"), CircuitState::Closed);
    }

    #[test]
    fn test_independent_providers() {
        let mgr = CircuitBreakerManager::new(&test_config());
        for _ in 0..3 {
            mgr.record_failure("mermaid");
        }
        assert_eq!(mgr.get_state("mermaid"), CircuitState::Open);
        assert_eq!(mgr.get_state("graphviz"), CircuitState::Closed);
        assert!(mgr.should_allow("graphviz"));
    }

    #[test]
    fn test_half_open_after_timeout() {
        let mgr = CircuitBreakerManager::new(&instant_reset_config());
        mgr.record_failure("mermaid");
        assert_eq!(mgr.get_state("mermaid"), CircuitState::Open);

        // With reset_timeout_secs = 0, should transition to HalfOpen immediately
        assert!(mgr.should_allow("mermaid"));
        assert_eq!(mgr.get_state("mermaid"), CircuitState::HalfOpen);
    }

    #[test]
    fn test_half_open_success_closes() {
        let mgr = CircuitBreakerManager::new(&instant_reset_config());
        mgr.record_failure("mermaid");
        mgr.should_allow("mermaid"); // transitions to HalfOpen
        mgr.record_success("mermaid");
        assert_eq!(mgr.get_state("mermaid"), CircuitState::Closed);
    }

    #[test]
    fn test_half_open_failure_reopens() {
        let mgr = CircuitBreakerManager::new(&instant_reset_config());
        mgr.record_failure("mermaid");
        assert!(mgr.should_allow("mermaid")); // transitions to HalfOpen
        assert_eq!(mgr.get_state("mermaid"), CircuitState::HalfOpen);
        mgr.record_failure("mermaid"); // probe failed
        assert_eq!(mgr.get_state("mermaid"), CircuitState::Open);
    }

    #[test]
    fn test_get_all_states() {
        let mgr = CircuitBreakerManager::new(&test_config());
        mgr.should_allow("mermaid");
        mgr.should_allow("graphviz");
        let mut states = mgr.get_all_states();
        // DashMap iteration order is unspecified; sort for a stable comparison.
        states.sort_by(|a, b| a.0.cmp(&b.0));
        assert_eq!(
            states,
            vec![
                ("graphviz".to_string(), CircuitState::Closed),
                ("mermaid".to_string(), CircuitState::Closed),
            ]
        );
    }
}