kroki_rs/server/middleware/
rate_limit.rs

1//! Token-bucket rate limiting middleware for the Kroki-rs server.
2//!
3//! Implements per-IP rate limiting using a concurrent `DashMap` for O(1) lookups.
4//! When `server.rate_limit.enabled = false` (default), all requests pass through.
5
6use crate::config::RateLimitConfig;
7use axum::{
8    body::Body,
9    extract::{connect_info::ConnectInfo, Request, State},
10    http::{HeaderValue, StatusCode},
11    middleware::Next,
12    response::{IntoResponse, Response},
13};
14use dashmap::DashMap;
15use std::net::{IpAddr, SocketAddr};
16use std::sync::Arc;
17use std::time::Instant;
18
/// A token bucket tracking the request allowance of a single client IP.
///
/// Tokens accumulate continuously at `refill_rate` per second, capped at `max_tokens`.
/// Every permitted request spends exactly one token.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available; fractional while the bucket is refilling.
    tokens: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
    /// Capacity of the bucket, i.e. the largest burst it can absorb.
    max_tokens: f64,
    /// Tokens added per second.
    refill_rate: f64,
}

impl TokenBucket {
    /// Creates a full bucket with room for `burst_size` tokens that refills at
    /// `refill_rate` tokens per second.
    ///
    /// A `burst_size` of zero is treated as one. A bucket with zero capacity could never
    /// hold the single token a request costs, so it would reject every request forever
    /// while `retry_after` kept promising that a token is about to arrive.
    fn new(burst_size: u32, refill_rate: u32) -> Self {
        // `f64::from` is lossless for `u32`, unlike a bare `as` cast which hides intent.
        let capacity = f64::from(burst_size.max(1));
        Self {
            tokens: capacity,
            last_refill: Instant::now(),
            max_tokens: capacity,
            refill_rate: f64::from(refill_rate),
        }
    }

    /// Attempts to consume a token. Returns `true` if the request is allowed and `false`
    /// if it is rate limited.
    ///
    /// The bucket is first topped up for the time elapsed since the previous call, so the
    /// refill timestamp advances on every call, whether or not a token is granted.
    fn try_consume(&mut self) -> bool {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        self.last_refill = now;

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Returns the estimated number of whole seconds until a token is available.
    ///
    /// The estimate is based on the token count as of the last `try_consume` call; it
    /// does not account for time that has passed since. A bucket that never refills
    /// (`refill_rate` of zero) reports a fixed 60 seconds.
    fn retry_after(&self) -> u64 {
        if self.refill_rate <= 0.0 {
            return 60;
        }
        let deficit = 1.0 - self.tokens;
        // Float-to-integer `as` saturates, so a negative value (token already available)
        // becomes 0 rather than wrapping.
        (deficit / self.refill_rate).ceil() as u64
    }
}
61
62/// Shared rate limiter state keyed by client IP.
63#[derive(Clone)]
64pub struct RateLimiter {
65    buckets: Arc<DashMap<IpAddr, TokenBucket>>,
66    config: RateLimitConfig,
67}
68
69impl RateLimiter {
70    /// Creates a new rate limiter from configuration.
71    pub fn new(config: &RateLimitConfig) -> Self {
72        Self {
73            buckets: Arc::new(DashMap::new()),
74            config: config.clone(),
75        }
76    }
77
78    /// Attempts to allow a request from the given IP.
79    /// Returns `Ok(())` if allowed, or `Err(retry_after_secs)` if rate limited.
80    pub fn check(&self, ip: IpAddr) -> Result<(), u64> {
81        let mut entry = self.buckets.entry(ip).or_insert_with(|| {
82            TokenBucket::new(self.config.burst_size, self.config.requests_per_second)
83        });
84
85        if entry.try_consume() {
86            Ok(())
87        } else {
88            Err(entry.retry_after())
89        }
90    }
91}
92
93/// Axum middleware that enforces rate limits.
94///
95/// Skipped entirely when `rate_limit.enabled = false` (dev mode).
96/// Returns 429 with `Retry-After` header when the limit is exceeded.
97pub async fn rate_limit_middleware(
98    State(state): State<crate::server::AppState>,
99    request: Request<Body>,
100    next: Next,
101) -> Response {
102    if !state.config.server.rate_limit.enabled {
103        return next.run(request).await;
104    }
105
106    let fallback_ip = request
107        .extensions()
108        .get::<ConnectInfo<SocketAddr>>()
109        .map(|info| info.0.ip())
110        .unwrap_or_else(|| "127.0.0.1".parse().unwrap());
111
112    let ip = extract_client_ip(&request, fallback_ip);
113
114    if let Some(ref limiter) = state.rate_limiter {
115        match limiter.check(ip) {
116            Ok(()) => next.run(request).await,
117            Err(retry_after) => {
118                tracing::warn!("Rate limit exceeded for IP: {}", ip);
119                let mut response = (
120                    StatusCode::TOO_MANY_REQUESTS,
121                    serde_json::json!({
122                        "error": "rate_limit_exceeded",
123                        "message": "Too many requests. Please retry later.",
124                        "retry_after_seconds": retry_after
125                    })
126                    .to_string(),
127                )
128                    .into_response();
129
130                response.headers_mut().insert(
131                    "retry-after",
132                    HeaderValue::from_str(&retry_after.to_string())
133                        .unwrap_or_else(|_| HeaderValue::from_static("60")),
134                );
135
136                response
137            }
138        }
139    } else {
140        next.run(request).await
141    }
142}
143
144/// Extracts the client IP from the request.
145/// Checks `X-Forwarded-For` first (for reverse proxy setups), then falls back to
146/// `X-Real-IP`, and finally defaults to `127.0.0.1`.
147fn extract_client_ip(request: &Request<Body>, fallback: IpAddr) -> IpAddr {
148    // Try X-Forwarded-For (first IP in the chain)
149    if let Some(forwarded) = request.headers().get("x-forwarded-for") {
150        if let Ok(forwarded_str) = forwarded.to_str() {
151            if let Some(first_ip) = forwarded_str.split(',').next() {
152                if let Ok(ip) = first_ip.trim().parse::<IpAddr>() {
153                    return ip;
154                }
155            }
156        }
157    }
158
159    // Try X-Real-IP
160    if let Some(real_ip) = request.headers().get("x-real-ip") {
161        if let Ok(ip_str) = real_ip.to_str() {
162            if let Ok(ip) = ip_str.parse::<IpAddr>() {
163                return ip;
164            }
165        }
166    }
167
168    fallback
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A fresh bucket grants exactly its burst size and refuses the next request.
    #[test]
    fn test_token_bucket_allows_burst() {
        let mut bucket = TokenBucket::new(5, 1);

        let mut granted = 0;
        while granted < 5 {
            assert!(bucket.try_consume());
            granted += 1;
        }

        // One request past the burst must be rejected.
        let sixth = bucket.try_consume();
        assert!(!sixth);
    }

    /// An emptied bucket grants again once enough time has passed to refill it.
    #[test]
    fn test_token_bucket_refills() {
        // A very high refill rate keeps the simulated wait short.
        let mut bucket = TokenBucket::new(1, 1000);

        let first = bucket.try_consume();
        assert!(first);
        let second = bucket.try_consume();
        assert!(!second);

        // Pretend the last refill happened 10 ms ago instead of sleeping.
        let ten_ms = Duration::from_millis(10);
        bucket.last_refill = Instant::now() - ten_ms;

        let after_wait = bucket.try_consume();
        assert!(after_wait);
    }

    /// Exhausting one client's bucket leaves another client's bucket untouched.
    #[test]
    fn test_rate_limiter_per_ip() {
        let limiter = RateLimiter::new(&RateLimitConfig {
            enabled: true,
            requests_per_second: 1,
            burst_size: 2,
        });

        let first_client: IpAddr = "192.168.1.1".parse().unwrap();
        let second_client: IpAddr = "192.168.1.2".parse().unwrap();

        // The first client uses up its burst of two, then gets rejected.
        assert!(limiter.check(first_client).is_ok());
        assert!(limiter.check(first_client).is_ok());
        assert!(limiter.check(first_client).is_err());

        // The second client has its own, still full, bucket.
        assert!(limiter.check(second_client).is_ok());
    }
}