kroki_rs/server/middleware/
auth.rs

//! Authentication middleware for the Kroki-rs server.
//!
//! When `server.auth.enabled = true`, [`auth_middleware`] validates the API key sent in the
//! configured header, and [`admin_auth_middleware`] validates HTTP Basic credentials against
//! `auth.admin_password_hash` (when one is configured).
//! When disabled (the default), all requests pass through, which keeps local development fast.
6
7use crate::config::AuthConfig;
8use axum::{
9    body::Body,
10    extract::Request,
11    http::StatusCode,
12    middleware::Next,
13    response::{IntoResponse, Response},
14};
15
16/// Axum middleware that enforces API key authentication.
17///
18/// Skipped entirely when `auth.enabled = false` (dev mode).
19/// Returns 401 with a JSON body if the key is missing or invalid.
20pub async fn auth_middleware(
21    state: axum::extract::State<crate::server::AppState>,
22    request: Request<Body>,
23    next: Next,
24) -> Response {
25    let auth_config = &state.config.server.auth;
26
27    if !auth_config.enabled {
28        return next.run(request).await;
29    }
30
31    let header_name = &auth_config.header_name;
32    let api_key = request
33        .headers()
34        .get(header_name)
35        .and_then(|v| v.to_str().ok());
36
37    match api_key {
38        Some(key) => {
39            if auth_config.api_keys.iter().any(|entry| entry.key == key) {
40                next.run(request).await
41            } else {
42                tracing::warn!("Invalid API key presented");
43                (
44                    StatusCode::UNAUTHORIZED,
45                    serde_json::json!({
46                        "error": "unauthorized",
47                        "message": "Invalid API key"
48                    })
49                    .to_string(),
50                )
51                    .into_response()
52            }
53        }
54        None => {
55            tracing::warn!("Missing API key in header '{}'", header_name);
56            (
57                StatusCode::UNAUTHORIZED,
58                serde_json::json!({
59                    "error": "unauthorized",
60                    "message": format!("Missing API key. Provide it via the '{}' header.", header_name)
61                })
62                .to_string(),
63            )
64                .into_response()
65        }
66    }
67}
68
69/// Axum middleware that enforces admin authentication via Basic Auth.
70///
71/// Authentication is bypasses if:
72/// 1. `auth.enabled = false` (dev mode)
73/// 2. `auth.admin_password_hash` is not configured
74///
75/// Otherwise, expects "Authorization: Basic `<base64>`" header.
76pub async fn admin_auth_middleware(
77    state: axum::extract::State<crate::server::AppState>,
78    request: Request<Body>,
79    next: Next,
80) -> Response {
81    let auth_config = &state.config.server.auth;
82
83    // Bypass if disabled or no password hash set
84    if !auth_config.enabled || auth_config.admin_password_hash.is_none() {
85        return next.run(request).await;
86    }
87
88    let auth_header = request
89        .headers()
90        .get(axum::http::header::AUTHORIZATION)
91        .and_then(|v| v.to_str().ok());
92
93    let authenticated = if let Some(header) = auth_header {
94        if let Some(encoded) = header.strip_prefix("Basic ") {
95            if let Ok(decoded) = base64::Engine::decode(&base64::prelude::BASE64_STANDARD, encoded)
96            {
97                if let Ok(credentials) = String::from_utf8(decoded) {
98                    if let Some((_user, password)) = credentials.split_once(':') {
99                        if let Some(hash) = &auth_config.admin_password_hash {
100                            bcrypt::verify(password, hash).unwrap_or(false)
101                        } else {
102                            false
103                        }
104                    } else {
105                        false
106                    }
107                } else {
108                    false
109                }
110            } else {
111                false
112            }
113        } else {
114            false
115        }
116    } else {
117        false
118    };
119
120    if authenticated {
121        next.run(request).await
122    } else {
123        tracing::warn!("Admin authentication failed");
124        (
125            StatusCode::UNAUTHORIZED,
126            [(
127                axum::http::header::WWW_AUTHENTICATE,
128                "Basic realm=\"Kroki Admin\"",
129            )],
130            serde_json::json!({
131                "error": "unauthorized",
132                "message": "Admin authentication required"
133            })
134            .to_string(),
135        )
136            .into_response()
137    }
138}
139
140/// Looks up the `ApiKeyEntry` for a given key string.
141/// Returns `None` if the key is not found or auth is disabled.
142pub fn find_api_key_entry<'a>(
143    auth_config: &'a AuthConfig,
144    key: &str,
145) -> Option<&'a crate::config::ApiKeyEntry> {
146    auth_config.api_keys.iter().find(|entry| entry.key == key)
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{ApiKeyEntry, AuthConfig};

    /// Builds an enabled config that holds exactly one API key. It uses the `x-api-key` header
    /// and sets no admin password hash.
    fn config_with_entry(entry: ApiKeyEntry) -> AuthConfig {
        AuthConfig {
            enabled: true,
            api_keys: vec![entry],
            header_name: "x-api-key".to_string(),
            admin_password_hash: None,
        }
    }

    #[test]
    fn test_find_api_key_entry_found() {
        let config = config_with_entry(ApiKeyEntry {
            key: "test-key".to_string(),
            label: "test".to_string(),
            rate_limit: Some(10),
        });
        match find_api_key_entry(&config, "test-key") {
            Some(found) => assert_eq!(found.label, "test"),
            None => panic!("expected an entry for 'test-key'"),
        }
    }

    #[test]
    fn test_find_api_key_entry_not_found() {
        let config = config_with_entry(ApiKeyEntry {
            key: "test-key".to_string(),
            label: "test".to_string(),
            rate_limit: None,
        });
        let lookup = find_api_key_entry(&config, "wrong-key");
        assert!(lookup.is_none());
    }

    #[test]
    fn test_find_api_key_entry_empty_keys() {
        let config = AuthConfig::default();
        assert!(
            find_api_key_entry(&config, "any-key").is_none(),
            "default config should have no keys"
        );
    }
}