// kroki_rs/server/middleware/circuit_breaker.rs

use std::sync::Arc;
use std::time::{Duration, Instant};

use dashmap::DashMap;

use crate::config::CircuitBreakerConfig;

/// State of a single provider's circuit breaker.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CircuitState {
    /// Requests are allowed; failures are counted toward the threshold.
    Closed,
    /// Requests are rejected until the reset timeout has elapsed since the last failure.
    Open,
    /// Entered from `Open` once the reset timeout has passed; the request that triggers the
    /// transition is admitted as a probe, and its outcome closes or re-opens the circuit.
    HalfOpen,
}

24struct ProviderCircuit {
26 state: CircuitState,
27 consecutive_failures: u32,
28 last_failure_time: Option<Instant>,
29 failure_threshold: u32,
30 reset_timeout_secs: u64,
31}
32
33impl ProviderCircuit {
34 fn new(config: &CircuitBreakerConfig) -> Self {
35 Self {
36 state: CircuitState::Closed,
37 consecutive_failures: 0,
38 last_failure_time: None,
39 failure_threshold: config.failure_threshold,
40 reset_timeout_secs: config.reset_timeout_secs,
41 }
42 }
43
44 fn should_allow(&mut self) -> bool {
46 match self.state {
47 CircuitState::Closed => true,
48 CircuitState::Open => {
49 if let Some(last_fail) = self.last_failure_time {
51 if last_fail.elapsed().as_secs() >= self.reset_timeout_secs {
52 self.state = CircuitState::HalfOpen;
53 tracing::info!("Circuit breaker transitioning to HalfOpen");
54 true } else {
56 false
57 }
58 } else {
59 false
60 }
61 }
62 CircuitState::HalfOpen => {
63 false
66 }
67 }
68 }
69
70 fn record_success(&mut self) {
72 if self.state == CircuitState::HalfOpen {
73 tracing::info!("Circuit breaker closing after successful HalfOpen test");
74 }
75 self.state = CircuitState::Closed;
76 self.consecutive_failures = 0;
77 self.last_failure_time = None;
78 }
79
80 fn record_failure(&mut self) {
82 self.consecutive_failures += 1;
83 self.last_failure_time = Some(Instant::now());
84
85 if self.consecutive_failures >= self.failure_threshold {
86 self.state = CircuitState::Open;
87 tracing::warn!(
88 "Circuit breaker opened after {} consecutive failures",
89 self.consecutive_failures
90 );
91 }
92
93 if self.state == CircuitState::HalfOpen {
95 self.state = CircuitState::Open;
96 tracing::warn!("Circuit breaker re-opened after HalfOpen test failure");
97 }
98 }
99}
100
101#[derive(Clone)]
103pub struct CircuitBreakerManager {
104 circuits: Arc<DashMap<String, ProviderCircuit>>,
105 config: CircuitBreakerConfig,
106}
107
108impl CircuitBreakerManager {
109 pub fn new(config: &CircuitBreakerConfig) -> Self {
111 Self {
112 circuits: Arc::new(DashMap::new()),
113 config: config.clone(),
114 }
115 }
116
117 pub fn should_allow(&self, provider: &str) -> bool {
120 let mut entry = self
121 .circuits
122 .entry(provider.to_string())
123 .or_insert_with(|| ProviderCircuit::new(&self.config));
124 entry.should_allow()
125 }
126
127 pub fn record_success(&self, provider: &str) {
129 if let Some(mut entry) = self.circuits.get_mut(provider) {
130 entry.record_success();
131 }
132 }
133
134 pub fn record_failure(&self, provider: &str) {
136 let mut entry = self
137 .circuits
138 .entry(provider.to_string())
139 .or_insert_with(|| ProviderCircuit::new(&self.config));
140 entry.record_failure();
141 }
142
143 pub fn get_state(&self, provider: &str) -> CircuitState {
145 self.circuits
146 .get(provider)
147 .map(|entry| entry.state)
148 .unwrap_or(CircuitState::Closed)
149 }
150
151 pub fn get_all_states(&self) -> Vec<(String, CircuitState)> {
153 self.circuits
154 .iter()
155 .map(|entry| (entry.key().clone(), entry.state))
156 .collect()
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// Trips after 3 consecutive failures; 1-second reset timeout.
    fn test_config() -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            enabled: true,
            failure_threshold: 3,
            reset_timeout_secs: 1,
        }
    }

    /// Trips on the first failure and allows a probe immediately, so HalfOpen transitions
    /// can be exercised without sleeping.
    fn instant_reset_config() -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            enabled: true,
            failure_threshold: 1,
            reset_timeout_secs: 0,
        }
    }

    #[test]
    fn test_circuit_starts_closed() {
        let mgr = CircuitBreakerManager::new(&test_config());
        assert!(mgr.should_allow("mermaid"));
        assert_eq!(mgr.get_state("mermaid"), CircuitState::Closed);
    }

    #[test]
    fn test_below_threshold_stays_closed() {
        let mgr = CircuitBreakerManager::new(&test_config());
        mgr.record_failure("mermaid");
        mgr.record_failure("mermaid");
        assert_eq!(mgr.get_state("mermaid"), CircuitState::Closed);
        assert!(mgr.should_allow("mermaid"));
    }

    #[test]
    fn test_circuit_opens_after_threshold() {
        let mgr = CircuitBreakerManager::new(&test_config());
        for _ in 0..3 {
            mgr.record_failure("mermaid");
        }
        assert_eq!(mgr.get_state("mermaid"), CircuitState::Open);
        assert!(!mgr.should_allow("mermaid"));
    }

    #[test]
    fn test_success_resets_failure_count() {
        let mgr = CircuitBreakerManager::new(&test_config());
        mgr.record_failure("mermaid");
        mgr.record_failure("mermaid");
        mgr.record_success("mermaid");
        // Only one failure since the reset, so the threshold of 3 is not reached.
        mgr.record_failure("mermaid");
        assert_eq!(mgr.get_state("mermaid"), CircuitState::Closed);
    }

    #[test]
    fn test_independent_providers() {
        let mgr = CircuitBreakerManager::new(&test_config());
        for _ in 0..3 {
            mgr.record_failure("mermaid");
        }
        assert_eq!(mgr.get_state("mermaid"), CircuitState::Open);
        assert_eq!(mgr.get_state("graphviz"), CircuitState::Closed);
        assert!(mgr.should_allow("graphviz"));
    }

    #[test]
    fn test_half_open_after_timeout() {
        let mgr = CircuitBreakerManager::new(&instant_reset_config());
        mgr.record_failure("mermaid");
        assert_eq!(mgr.get_state("mermaid"), CircuitState::Open);

        // Reset timeout is zero, so the next check admits a probe.
        assert!(mgr.should_allow("mermaid"));
        assert_eq!(mgr.get_state("mermaid"), CircuitState::HalfOpen);
    }

    #[test]
    fn test_half_open_success_closes() {
        let mgr = CircuitBreakerManager::new(&instant_reset_config());
        mgr.record_failure("mermaid");
        mgr.should_allow("mermaid");
        mgr.record_success("mermaid");
        assert_eq!(mgr.get_state("mermaid"), CircuitState::Closed);
    }

    #[test]
    fn test_half_open_failure_reopens() {
        let mgr = CircuitBreakerManager::new(&instant_reset_config());
        mgr.record_failure("mermaid");
        assert!(mgr.should_allow("mermaid"));
        assert_eq!(mgr.get_state("mermaid"), CircuitState::HalfOpen);
        mgr.record_failure("mermaid");
        assert_eq!(mgr.get_state("mermaid"), CircuitState::Open);
    }

    #[test]
    fn test_unknown_provider_queries_do_not_create_circuits() {
        let mgr = CircuitBreakerManager::new(&test_config());
        mgr.record_success("mermaid");
        assert_eq!(mgr.get_state("graphviz"), CircuitState::Closed);
        assert!(mgr.get_all_states().is_empty());
    }

    #[test]
    fn test_get_all_states() {
        let mgr = CircuitBreakerManager::new(&test_config());
        mgr.should_allow("mermaid");
        mgr.should_allow("graphviz");
        let states = mgr.get_all_states();
        assert_eq!(states.len(), 2);
    }
}