kroki_rs/config/mod.rs

1use serde::Deserialize;
2use std::env;
3use std::fs;
4use std::path::PathBuf;
5
6/// Main configuration for Kroki-rs.
7#[derive(Debug, Deserialize, Clone, Default)]
8pub struct Config {
9    #[serde(default)]
10    pub server: ServerConfig,
11    #[serde(default)]
12    pub browser: BrowserConfig,
13    #[serde(default)]
14    pub artifacts: ArtifactsConfig,
15    #[serde(default)]
16    pub plugins: Vec<PluginConfig>,
17    // Tool-specific configurations (TD-03)
18    #[serde(default)]
19    pub graphviz: ToolConfig,
20    #[serde(default)]
21    pub mermaid: ToolConfig,
22    #[serde(default)]
23    pub vega: ToolConfig,
24    #[serde(default)]
25    pub vegalite: ToolConfig,
26    #[serde(default)]
27    pub wavedrom: ToolConfig,
28    #[serde(default)]
29    pub bpmn: ToolConfig,
30    #[serde(default)]
31    pub d2: ToolConfig,
32    #[serde(default)]
33    pub ditaa: ToolConfig,
34    #[serde(default)]
35    pub excalidraw: ToolConfig,
36}
37
38#[derive(Debug, Deserialize, Clone, Default)]
39pub struct PluginConfig {
40    pub name: String,
41    pub command: String,
42    #[serde(default)]
43    pub args: Vec<String>,
44    #[serde(default = "default_true")]
45    pub stdin: bool,
46    #[serde(default)]
47    pub formats: Vec<String>,
48    pub timeout_ms: Option<u64>,
49}
50
51#[derive(Debug, Deserialize, Clone)]
52pub struct ServerConfig {
53    #[serde(default = "default_port")]
54    pub port: u16,
55    #[serde(default = "default_admin_port")]
56    pub admin_port: u16,
57    #[serde(default = "default_host")]
58    pub host: String,
59    #[serde(default = "default_log_level")]
60    pub log_level: String,
61    #[serde(default = "default_timeout")]
62    pub timeout_ms: u64,
63    /// Maximum allowed input size in bytes (default: 1MB).
64    #[serde(default = "default_max_input_size")]
65    pub max_input_size: usize,
66    /// Maximum allowed output size in bytes (default: 50MB).
67    #[serde(default = "default_max_output_size")]
68    pub max_output_size: usize,
69    /// Authentication configuration (disabled by default for dev mode).
70    #[serde(default)]
71    pub auth: AuthConfig,
72    /// Rate limiting configuration (disabled by default for dev mode).
73    #[serde(default)]
74    pub rate_limit: RateLimitConfig,
75    /// Circuit breaker configuration (disabled by default).
76    #[serde(default)]
77    pub circuit_breaker: CircuitBreakerConfig,
78    /// Metrics configuration (enabled by default).
79    #[serde(default)]
80    pub metrics: MetricsConfig,
81    /// Telemetry/OTel configuration (disabled by default).
82    #[serde(default)]
83    pub telemetry: TelemetryConfig,
84}
85
86impl Default for ServerConfig {
87    fn default() -> Self {
88        Self {
89            port: 8000,
90            admin_port: 8081,
91            host: "localhost".to_string(),
92            log_level: "info".to_string(),
93            timeout_ms: 5000,
94            max_input_size: 1_048_576,   // 1MB
95            max_output_size: 52_428_800, // 50MB
96            auth: AuthConfig::default(),
97            rate_limit: RateLimitConfig::default(),
98            circuit_breaker: CircuitBreakerConfig::default(),
99            metrics: MetricsConfig::default(),
100            telemetry: TelemetryConfig::default(),
101        }
102    }
103}
104
105#[derive(Debug, Deserialize, Clone, Default)]
106pub struct ToolConfig {
107    pub bin_path: Option<String>,
108    pub timeout_ms: Option<u64>,
109    pub config_path: Option<String>,
110    #[serde(default)]
111    pub fonts: Vec<String>,
112}
113
114impl ToolConfig {
115    pub fn apply_env_overrides(&mut self, prefix: &str) {
116        let prefix_upper = prefix.to_uppercase();
117        if let Ok(v) = env::var(format!("KROKI_{}_BIN", prefix_upper)) {
118            self.bin_path = Some(v);
119        }
120        if let Ok(v) = env::var(format!("KROKI_{}_TIMEOUT", prefix_upper)) {
121            if let Ok(t) = v.parse() {
122                self.timeout_ms = Some(t);
123            }
124        }
125        if let Ok(v) = env::var(format!("KROKI_{}_CONFIG", prefix_upper)) {
126            self.config_path = Some(v);
127        }
128    }
129}
130
131#[derive(Debug, Deserialize, Clone)]
132pub struct ApiKeyEntry {
133    pub key: String,
134    pub label: String,
135    /// Optional per-key rate limit (requests per second).
136    pub rate_limit: Option<u32>,
137}
138
139/// Authentication configuration.
140/// When `enabled = false` (default), all auth is bypassed (dev mode).
141#[derive(Debug, Deserialize, Clone)]
142pub struct AuthConfig {
143    #[serde(default)]
144    pub enabled: bool,
145    /// List of valid API keys with optional per-key rate limits.
146    #[serde(default)]
147    pub api_keys: Vec<ApiKeyEntry>,
148    /// HTTP header name for API key extraction.
149    #[serde(default = "default_auth_header")]
150    pub header_name: String,
151    /// Bcrpyt hash of the admin password.
152    pub admin_password_hash: Option<String>,
153}
154
155impl Default for AuthConfig {
156    fn default() -> Self {
157        Self {
158            enabled: false,
159            api_keys: Vec::new(),
160            header_name: "x-api-key".to_string(),
161            admin_password_hash: None,
162        }
163    }
164}
165
166/// Rate limiting configuration using token-bucket algorithm.
167/// When `enabled = false` (default), no rate limiting is applied.
168#[derive(Debug, Deserialize, Clone)]
169pub struct RateLimitConfig {
170    #[serde(default)]
171    pub enabled: bool,
172    /// Maximum sustained requests per second.
173    #[serde(default = "default_rps")]
174    pub requests_per_second: u32,
175    /// Maximum burst size above the sustained rate.
176    #[serde(default = "default_burst")]
177    pub burst_size: u32,
178}
179
180impl Default for RateLimitConfig {
181    fn default() -> Self {
182        Self {
183            enabled: false,
184            requests_per_second: 10,
185            burst_size: 50,
186        }
187    }
188}
189
190/// Circuit breaker configuration for per-provider failure isolation.
191/// When `enabled = false` (default), circuit breaker is not applied.
192#[derive(Debug, Deserialize, Clone)]
193pub struct CircuitBreakerConfig {
194    #[serde(default)]
195    pub enabled: bool,
196    /// Number of consecutive failures before the circuit opens.
197    #[serde(default = "default_failure_threshold")]
198    pub failure_threshold: u32,
199    /// Seconds to wait before transitioning from Open to Half-Open.
200    #[serde(default = "default_reset_timeout")]
201    pub reset_timeout_secs: u64,
202}
203
204impl Default for CircuitBreakerConfig {
205    fn default() -> Self {
206        Self {
207            enabled: false,
208            failure_threshold: 5,
209            reset_timeout_secs: 30,
210        }
211    }
212}
213
214/// Prometheus metrics configuration.
215#[derive(Debug, Deserialize, Clone)]
216pub struct MetricsConfig {
217    #[serde(default = "default_true")]
218    pub enabled: bool,
219    /// Whether to expose a /metrics endpoint on the admin server.
220    #[serde(default = "default_false")]
221    pub export_endpoint: bool,
222}
223
224impl Default for MetricsConfig {
225    fn default() -> Self {
226        Self {
227            enabled: true,
228            export_endpoint: true,
229        }
230    }
231}
232
233/// Telemetry (OpenTelemetry) configuration.
234#[derive(Debug, Deserialize, Clone, Default)]
235pub struct TelemetryConfig {
236    #[serde(default)]
237    pub enabled: bool,
238    /// OTLP exporter endpoint.
239    #[serde(default)]
240    pub otlp_endpoint: Option<String>,
241}
242
243#[derive(Debug, Deserialize, Clone)]
244pub struct BrowserConfig {
245    #[serde(default = "default_pool_size")]
246    pub pool_size: usize,
247    /// Number of requests after which a browser context is recreated to prevent memory leaks.
248    #[serde(default = "default_context_ttl")]
249    pub context_ttl_requests: usize,
250}
251
252impl Default for BrowserConfig {
253    fn default() -> Self {
254        Self {
255            pool_size: 4,
256            context_ttl_requests: 100,
257        }
258    }
259}
260
261#[derive(Debug, Deserialize, Clone, Default)]
262pub struct ArtifactsConfig {
263    pub cache_dir: Option<PathBuf>,
264}
265
/// Output formats supported by Kroki-rs, as lowercase extension names.
pub const SUPPORTED_FORMATS: &[&str] = &["svg", "png", "pdf", "webp", "txt"];
268
269impl Config {
270    /// Loads the configuration from a path, environment variable, or default file.
271    pub fn load(path: Option<PathBuf>) -> anyhow::Result<Self> {
272        let path = if let Some(p) = path {
273            Some(p)
274        } else if let Ok(p) = env::var("KROKI_CONFIG") {
275            Some(PathBuf::from(p))
276        } else if fs::metadata("kroki.toml").is_ok() {
277            Some(PathBuf::from("kroki.toml"))
278        } else {
279            None
280        };
281
282        let mut config = if let Some(p) = path {
283            let content = fs::read_to_string(p)?; // TODO handle toml error nicely?
284            toml::from_str(&content)?
285        } else {
286            Config::default()
287        };
288
289        config.apply_env_overrides();
290
291        Ok(config)
292    }
293
294    fn apply_env_overrides(&mut self) {
295        if let Ok(v) = env::var("KROKI_PORT") {
296            if let Ok(p) = v.parse() {
297                self.server.port = p;
298            }
299        }
300        if let Ok(v) = env::var("KROKI_ADMIN_PORT") {
301            if let Ok(p) = v.parse() {
302                self.server.admin_port = p;
303            }
304        }
305        if let Ok(v) = env::var("KROKI_LOG_LEVEL") {
306            self.server.log_level = v;
307        }
308        if let Ok(v) = env::var("KROKI_HOST") {
309            self.server.host = v;
310        }
311        if let Ok(password) = env::var("KROKI_ADMIN_PASSWORD") {
312            if let Ok(hash) = bcrypt::hash(password, bcrypt::DEFAULT_COST) {
313                self.server.auth.admin_password_hash = Some(hash);
314            }
315        }
316        if let Ok(v) = env::var("KROKI_TIMEOUT") {
317            if let Ok(t) = v.parse() {
318                self.server.timeout_ms = t;
319            }
320        }
321        if let Ok(v) = env::var("KROKI_MAX_INPUT_SIZE") {
322            if let Ok(s) = v.parse() {
323                self.server.max_input_size = s;
324            }
325        }
326        if let Ok(v) = env::var("KROKI_MAX_OUTPUT_SIZE") {
327            if let Ok(s) = v.parse() {
328                self.server.max_output_size = s;
329            }
330        }
331
332        // Apply overrides for all tools
333        self.graphviz.apply_env_overrides("graphviz");
334        self.mermaid.apply_env_overrides("mermaid");
335        self.vega.apply_env_overrides("vega");
336        self.vegalite.apply_env_overrides("vegalite");
337        self.wavedrom.apply_env_overrides("wavedrom");
338        self.bpmn.apply_env_overrides("bpmn");
339        self.d2.apply_env_overrides("d2");
340        self.ditaa.apply_env_overrides("ditaa");
341        self.excalidraw.apply_env_overrides("excalidraw");
342    }
343
344    /// Resolves the cache directory, creating it if it doesn't exist.
345    pub fn resolve_cache_dir(custom_path: Option<PathBuf>) -> Option<PathBuf> {
346        let path = custom_path.or_else(|| {
347            dirs::cache_dir().map(|mut p| {
348                p.push("kroki-rs");
349                p
350            })
351        });
352
353        if let Some(ref p) = path {
354            let _ = fs::create_dir_all(p);
355        }
356        path
357    }
358
359    /// Aggregates font information from all configured tools.
360    pub fn all_fonts(&self) -> Vec<String> {
361        let mut fonts = Vec::new();
362        fonts.extend(self.graphviz.fonts.clone());
363        fonts.extend(self.mermaid.fonts.clone());
364        fonts.extend(self.vega.fonts.clone());
365        fonts.extend(self.vegalite.fonts.clone());
366        fonts.extend(self.wavedrom.fonts.clone());
367        fonts.extend(self.bpmn.fonts.clone());
368        fonts.extend(self.d2.fonts.clone());
369        fonts.extend(self.ditaa.fonts.clone());
370        fonts.extend(self.excalidraw.fonts.clone());
371        fonts
372    }
373}
374
// Default-value helpers for the `#[serde(default = "...")]` attributes above.

/// Default main server port.
fn default_port() -> u16 {
    8_000
}

/// Default admin server port.
fn default_admin_port() -> u16 {
    8_081
}

/// Default server host.
fn default_host() -> String {
    String::from("localhost")
}

/// Default log level name.
fn default_log_level() -> String {
    String::from("info")
}

/// Default timeout, in milliseconds.
fn default_timeout() -> u64 {
    5_000
}

/// Default maximum input size: 1 MiB.
fn default_max_input_size() -> usize {
    1024 * 1024
}

/// Default maximum output size: 50 MiB.
fn default_max_output_size() -> usize {
    50 * 1024 * 1024
}
/// Default HTTP header carrying the API key.
fn default_auth_header() -> String {
    String::from("x-api-key")
}

/// Default sustained rate limit, in requests per second.
fn default_rps() -> u32 {
    10
}

/// Default rate-limit burst size.
fn default_burst() -> u32 {
    5 * 10
}

/// Default number of consecutive failures that opens a circuit.
fn default_failure_threshold() -> u32 {
    5
}

/// Default circuit reset timeout, in seconds.
fn default_reset_timeout() -> u64 {
    30
}
/// Default browser pool size.
fn default_pool_size() -> usize {
    4
}

/// Default number of requests before a browser context is recreated.
fn default_context_ttl() -> usize {
    100
}
/// `true`, for boolean fields that default to on.
fn default_true() -> bool {
    !default_false()
}

/// `false`, for boolean fields that default to off.
fn default_false() -> bool {
    false
}