kroki_rs/server/middleware/
rate_limit.rs1use crate::config::RateLimitConfig;
7use axum::{
8 body::Body,
9 extract::{connect_info::ConnectInfo, Request, State},
10 http::{HeaderValue, StatusCode},
11 middleware::Next,
12 response::{IntoResponse, Response},
13};
14use dashmap::DashMap;
15use std::net::{IpAddr, SocketAddr};
16use std::sync::Arc;
17use std::time::Instant;
18
/// Per-client token bucket.
///
/// Holds up to `max_tokens` tokens, is refilled continuously at `refill_rate`
/// tokens per second, and every admitted request spends exactly one token.
struct TokenBucket {
    /// Tokens currently available (fractional between refills).
    tokens: f64,
    /// Moment `tokens` was last brought up to date.
    last_refill: Instant,
    /// Bucket capacity, i.e. the largest burst admitted back-to-back.
    max_tokens: f64,
    /// Tokens added per second of elapsed time.
    refill_rate: f64,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    ///
    /// `burst_size` is the capacity and `refill_rate` is in tokens per second.
    /// The capacity is clamped to at least one token: a bucket that can never
    /// hold a whole token would reject every request forever, while
    /// `retry_after` kept promising a retry that could never succeed.
    fn new(burst_size: u32, refill_rate: u32) -> Self {
        let capacity = f64::from(burst_size.max(1));
        Self {
            tokens: capacity,
            last_refill: Instant::now(),
            max_tokens: capacity,
            refill_rate: f64::from(refill_rate),
        }
    }

    /// Credits the tokens earned since the last call (capped at capacity),
    /// then spends one token if available.
    ///
    /// Returns `true` when the request is admitted, `false` when rate limited.
    fn try_consume(&mut self) -> bool {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        self.last_refill = now;

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Whole seconds (rounded up) until the next token becomes available.
    ///
    /// Intended to be called right after a failed [`TokenBucket::try_consume`],
    /// when fewer than one token remains. A bucket that never refills
    /// (`refill_rate == 0`) has no meaningful answer, so a fixed 60-second hint
    /// is returned instead.
    fn retry_after(&self) -> u64 {
        if self.refill_rate <= 0.0 {
            return 60;
        }
        let deficit = 1.0 - self.tokens;
        (deficit / self.refill_rate).ceil() as u64
    }
}
61
62#[derive(Clone)]
64pub struct RateLimiter {
65 buckets: Arc<DashMap<IpAddr, TokenBucket>>,
66 config: RateLimitConfig,
67}
68
69impl RateLimiter {
70 pub fn new(config: &RateLimitConfig) -> Self {
72 Self {
73 buckets: Arc::new(DashMap::new()),
74 config: config.clone(),
75 }
76 }
77
78 pub fn check(&self, ip: IpAddr) -> Result<(), u64> {
81 let mut entry = self.buckets.entry(ip).or_insert_with(|| {
82 TokenBucket::new(self.config.burst_size, self.config.requests_per_second)
83 });
84
85 if entry.try_consume() {
86 Ok(())
87 } else {
88 Err(entry.retry_after())
89 }
90 }
91}
92
93pub async fn rate_limit_middleware(
98 State(state): State<crate::server::AppState>,
99 request: Request<Body>,
100 next: Next,
101) -> Response {
102 if !state.config.server.rate_limit.enabled {
103 return next.run(request).await;
104 }
105
106 let fallback_ip = request
107 .extensions()
108 .get::<ConnectInfo<SocketAddr>>()
109 .map(|info| info.0.ip())
110 .unwrap_or_else(|| "127.0.0.1".parse().unwrap());
111
112 let ip = extract_client_ip(&request, fallback_ip);
113
114 if let Some(ref limiter) = state.rate_limiter {
115 match limiter.check(ip) {
116 Ok(()) => next.run(request).await,
117 Err(retry_after) => {
118 tracing::warn!("Rate limit exceeded for IP: {}", ip);
119 let mut response = (
120 StatusCode::TOO_MANY_REQUESTS,
121 serde_json::json!({
122 "error": "rate_limit_exceeded",
123 "message": "Too many requests. Please retry later.",
124 "retry_after_seconds": retry_after
125 })
126 .to_string(),
127 )
128 .into_response();
129
130 response.headers_mut().insert(
131 "retry-after",
132 HeaderValue::from_str(&retry_after.to_string())
133 .unwrap_or_else(|_| HeaderValue::from_static("60")),
134 );
135
136 response
137 }
138 }
139 } else {
140 next.run(request).await
141 }
142}
143
144fn extract_client_ip(request: &Request<Body>, fallback: IpAddr) -> IpAddr {
148 if let Some(forwarded) = request.headers().get("x-forwarded-for") {
150 if let Ok(forwarded_str) = forwarded.to_str() {
151 if let Some(first_ip) = forwarded_str.split(',').next() {
152 if let Ok(ip) = first_ip.trim().parse::<IpAddr>() {
153 return ip;
154 }
155 }
156 }
157 }
158
159 if let Some(real_ip) = request.headers().get("x-real-ip") {
161 if let Ok(ip_str) = real_ip.to_str() {
162 if let Ok(ip) = ip_str.parse::<IpAddr>() {
163 return ip;
164 }
165 }
166 }
167
168 fallback
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bare request to `/` carrying the given headers.
    fn request_with_headers(headers: &[(&str, &str)]) -> Request<Body> {
        headers
            .iter()
            .fold(
                axum::http::Request::builder().uri("/"),
                |builder, (name, value)| builder.header(*name, *value),
            )
            .body(Body::empty())
            .expect("static test request is valid")
    }

    fn loopback() -> IpAddr {
        IpAddr::from([127, 0, 0, 1])
    }

    #[test]
    fn test_token_bucket_allows_burst() {
        let mut bucket = TokenBucket::new(5, 1);
        for _ in 0..5 {
            assert!(bucket.try_consume());
        }
        assert!(!bucket.try_consume());
    }

    #[test]
    fn test_token_bucket_refills() {
        let mut bucket = TokenBucket::new(1, 1000);
        assert!(bucket.try_consume());
        assert!(!bucket.try_consume());
        // Pretend 10 ms passed: at 1000 tokens/s that is well over one token.
        bucket.last_refill = Instant::now() - std::time::Duration::from_millis(10);
        assert!(bucket.try_consume());
    }

    #[test]
    fn test_retry_after_rounds_up_to_whole_seconds() {
        let mut bucket = TokenBucket::new(1, 1);
        assert!(bucket.try_consume());
        assert!(!bucket.try_consume());
        assert_eq!(bucket.retry_after(), 1);
    }

    #[test]
    fn test_retry_after_without_refill() {
        let mut bucket = TokenBucket::new(1, 0);
        assert!(bucket.try_consume());
        assert!(!bucket.try_consume());
        assert_eq!(bucket.retry_after(), 60);
    }

    #[test]
    fn test_rate_limiter_per_ip() {
        let config = RateLimitConfig {
            enabled: true,
            requests_per_second: 1,
            burst_size: 2,
        };
        let limiter = RateLimiter::new(&config);
        let ip1: IpAddr = "192.168.1.1".parse().unwrap();
        let ip2: IpAddr = "192.168.1.2".parse().unwrap();

        assert!(limiter.check(ip1).is_ok());
        assert!(limiter.check(ip1).is_ok());
        assert!(limiter.check(ip1).is_err());
        // A different IP has its own, untouched bucket.
        assert!(limiter.check(ip2).is_ok());
    }

    #[test]
    fn test_client_ip_defaults_to_peer_address() {
        let peer: IpAddr = "203.0.113.9".parse().unwrap();
        assert_eq!(extract_client_ip(&request_with_headers(&[]), peer), peer);
    }

    #[test]
    fn test_client_ip_from_forwarded_for_behind_local_proxy() {
        let request = request_with_headers(&[("x-forwarded-for", "203.0.113.7")]);
        let expected: IpAddr = "203.0.113.7".parse().unwrap();
        assert_eq!(extract_client_ip(&request, loopback()), expected);
    }

    #[test]
    fn test_client_ip_from_real_ip_behind_local_proxy() {
        let request = request_with_headers(&[("x-real-ip", "198.51.100.4")]);
        let expected: IpAddr = "198.51.100.4".parse().unwrap();
        assert_eq!(extract_client_ip(&request, loopback()), expected);
    }
}