kroki_rs/cli/
mod.rs

1use crate::capabilities::Capabilities;
2use crate::config::Config;
3use crate::diagrams::registry::DiagramRegistry;
4use crate::utils::image_converter;
5use anyhow::{Context, Result};
6use num_cpus;
7use sha2::{Digest, Sha256};
8use std::path::PathBuf;
9use std::sync::Arc;
10use tokio::fs;
11use tokio::io::AsyncWriteExt;
12use tokio::sync::Semaphore;
13use walkdir::WalkDir;
14
/// Picks the format the provider should actually render for a requested `format`.
///
/// WebP output is produced by converting an intermediate image, so a request
/// for `"webp"` (compared ASCII case-insensitively) is mapped to `"png"` when
/// `type_` is `"ditaa"` (also case-insensitive) and to `"svg"` for every other
/// diagram type. Any other format is passed through untouched.
///
/// Returns `(base_format, is_webp)`, where `is_webp` tells the caller whether a
/// WebP conversion step is still required.
fn resolve_base_format<'a>(format: &'a str, type_: &str) -> (&'a str, bool) {
    // Anything that is not WebP is rendered directly in the requested format.
    if !format.eq_ignore_ascii_case("webp") {
        return (format, false);
    }

    let intermediate = if type_.eq_ignore_ascii_case("ditaa") {
        "png"
    } else {
        "svg"
    };
    (intermediate, true)
}
28
29/// Core generation pipeline shared by convert and batch.
30/// Returns the final output bytes (including optional WebP conversion).
31async fn generate_diagram(
32    source: &str,
33    type_: &str,
34    format: &str,
35    config: &Config,
36    registry: &DiagramRegistry,
37    cache_dir: &Option<PathBuf>,
38) -> Result<Vec<u8>> {
39    let provider = registry.get(type_).context(format!(
40        "Diagram type '{}' not supported or tool not found",
41        type_
42    ))?;
43
44    // Validate input size
45    if source.len() > config.server.max_input_size {
46        anyhow::bail!(
47            "Input too large ({} bytes). Maximum allowed: {} bytes. Configure via server.max_input_size in kroki.toml.",
48            source.len(),
49            config.server.max_input_size
50        );
51    }
52
53    // Caching: compute hash (includes fonts & plugin configuration to avoid stale cache)
54    let mut fonts = config.all_fonts();
55    fonts.sort();
56    fonts.dedup();
57
58    let mut plugin_signatures: Vec<String> = config
59        .plugins
60        .iter()
61        .map(|plugin| {
62            let args = plugin.args.join(",");
63            let formats = plugin.formats.join(",");
64            format!(
65                "{}|{}|{}|{}|{}|{}",
66                plugin.name,
67                plugin.command,
68                args,
69                formats,
70                plugin.stdin,
71                plugin.timeout_ms.unwrap_or(0)
72            )
73        })
74        .collect();
75    plugin_signatures.sort();
76
77    let mut hasher = Sha256::new();
78    hasher.update(type_);
79    hasher.update(format);
80    hasher.update(source);
81    hasher.update(b"fonts:");
82    for font in &fonts {
83        hasher.update(font.as_bytes());
84        hasher.update([0]);
85    }
86    hasher.update(b"plugins:");
87    for signature in &plugin_signatures {
88        hasher.update(signature.as_bytes());
89        hasher.update([0]);
90    }
91    let hash = hex::encode(hasher.finalize());
92
93    // Check cache
94    if let Some(cache_path) = cache_dir {
95        if !cache_path.exists() {
96            fs::create_dir_all(cache_path).await.ok();
97        }
98        let cached_file = cache_path.join(format!("{}.{}", hash, format));
99        if cached_file.exists() {
100            if let Ok(content) = fs::read(&cached_file).await {
101                tracing::info!("Cache hit! Served from {}", cached_file.display());
102                return Ok(content);
103            }
104        }
105    }
106
107    // Validate
108    provider
109        .validate(source)
110        .context("Source validation failed")?;
111
112    // Generate with format resolution
113    let (base_format, is_webp) = resolve_base_format(format, type_);
114    let mut output_bytes = provider
115        .generate(source, base_format)
116        .await
117        .context("Diagram generation failed")?;
118
119    // WebP post-processing
120    if is_webp {
121        let fonts = config.all_fonts();
122        output_bytes = if base_format == "png" {
123            image_converter::png_to_webp(&output_bytes, image_converter::WebpQuality::Lossless)
124                .await
125                .context("Failed to convert PNG to WebP")?
126        } else {
127            image_converter::svg_to_webp(
128                &output_bytes,
129                image_converter::WebpQuality::Lossless,
130                &fonts,
131                cache_dir.as_deref(),
132            )
133            .await
134            .context("Failed to convert SVG to WebP")?
135        };
136    }
137
138    // Write to cache
139    if let Some(cache_path) = cache_dir {
140        let cached_file = cache_path.join(format!("{}.{}", hash, format));
141        if let Err(e) = fs::write(&cached_file, &output_bytes).await {
142            tracing::warn!("Failed to write to cache: {}", e);
143        } else {
144            tracing::info!("Saved to cache: {}", cached_file.display());
145        }
146    }
147
148    Ok(output_bytes)
149}
150
151pub async fn convert(
152    type_: String,
153    format: String,
154    input: PathBuf,
155    config: Config,
156    cache_dir: Option<PathBuf>,
157) -> Result<()> {
158    let capabilities = Capabilities::discover(&config);
159    let browser_manager = match crate::browser::BrowserManager::start(
160        config.browser.pool_size,
161        config.browser.context_ttl_requests,
162    )
163    .await
164    {
165        Ok(m) => Some(Arc::new(m)),
166        Err(e) => {
167            tracing::warn!("Browser worker failed to start: {}", e);
168            None
169        }
170    };
171    let registry = DiagramRegistry::new(&capabilities, &config, browser_manager);
172    let cache_dir = Config::resolve_cache_dir(cache_dir);
173
174    let source = fs::read_to_string(&input)
175        .await
176        .context(format!("Failed to read input file '{}'", input.display()))?;
177
178    let output_bytes =
179        generate_diagram(&source, &type_, &format, &config, &registry, &cache_dir).await?;
180
181    let mut stdout = tokio::io::stdout();
182    stdout.write_all(&output_bytes).await?;
183    stdout.flush().await?;
184
185    Ok(())
186}
187
188pub async fn batch(
189    format: String,
190    input_dir: PathBuf,
191    type_override: Option<String>,
192    out_dir: Option<PathBuf>,
193    config: Config,
194    cache_dir: Option<PathBuf>,
195) -> Result<()> {
196    if !input_dir.is_dir() {
197        return Err(anyhow::anyhow!("Input must be a directory"));
198    }
199
200    let files: Vec<PathBuf> = WalkDir::new(&input_dir)
201        .into_iter()
202        .filter_map(|e| e.ok())
203        .filter(|e| e.file_type().is_file())
204        .map(|e| e.path().to_owned())
205        .collect();
206
207    tracing::info!("Found {} files in {}", files.len(), input_dir.display());
208
209    let capabilities = Capabilities::discover(&config);
210    let browser_manager = match crate::browser::BrowserManager::start(
211        config.browser.pool_size,
212        config.browser.context_ttl_requests,
213    )
214    .await
215    {
216        Ok(m) => Some(Arc::new(m)),
217        Err(e) => {
218            tracing::warn!("Browser worker failed to start: {}", e);
219            None
220        }
221    };
222    let registry = Arc::new(DiagramRegistry::new(
223        &capabilities,
224        &config,
225        browser_manager,
226    ));
227    let config = Arc::new(config);
228    let cache_dir = Arc::new(Config::resolve_cache_dir(cache_dir));
229    let format = Arc::new(format);
230    let type_override = Arc::new(type_override);
231    let out_dir = Arc::new(out_dir);
232
233    let mut tasks = Vec::new();
234    let failure_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
235
236    let concurrency = std::cmp::max(1, num_cpus::get());
237    let semaphore = Arc::new(Semaphore::new(concurrency));
238
239    for file_path in files {
240        let extension = file_path
241            .extension()
242            .and_then(|e| e.to_str())
243            .unwrap_or("")
244            .to_lowercase();
245
246        let type_ = if let Some(t) = type_override.as_ref() {
247            Some(t.clone())
248        } else {
249            match extension.as_str() {
250                "d2" => Some("d2".to_string()),
251                "dot" | "gv" => Some("graphviz".to_string()),
252                "mmd" | "mermaid" => Some("mermaid".to_string()),
253                "excalidraw" => Some("excalidraw".to_string()),
254                "bpmn" => Some("bpmn".to_string()),
255                "vega" => Some("vega".to_string()),
256                "vl" => Some("vegalite".to_string()),
257                _ => {
258                    if file_path.to_string_lossy().ends_with(".vl.json") {
259                        Some("vegalite".to_string())
260                    } else {
261                        None
262                    }
263                }
264            }
265        };
266
267        if let Some(t) = type_ {
268            let config = config.clone();
269            let registry = registry.clone();
270            let cache_dir = cache_dir.clone();
271            let format = format.clone();
272            let out_dir = out_dir.clone();
273            let input_dir = input_dir.clone();
274            let failure_count = failure_count.clone();
275
276            let semaphore = semaphore.clone();
277            tasks.push(tokio::spawn(async move {
278                let _permit = semaphore.acquire_owned().await.unwrap();
279                let relative_path = file_path.strip_prefix(&input_dir).unwrap_or(&file_path);
280                let mut output_path = if let Some(out) = out_dir.as_ref() {
281                    out.join(relative_path)
282                } else {
283                    file_path.clone()
284                };
285                output_path.set_extension(format.as_str());
286
287                if let Some(parent) = output_path.parent() {
288                    fs::create_dir_all(parent).await.ok();
289                }
290
291                let source = match fs::read_to_string(&file_path).await {
292                    Ok(s) => s,
293                    Err(e) => {
294                        tracing::error!("Failed to read {}: {}", file_path.display(), e);
295                        failure_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
296                        return;
297                    }
298                };
299
300                match generate_diagram(&source, &t, &format, &config, &registry, &cache_dir).await {
301                    Ok(bytes) => {
302                        if let Err(e) = fs::write(&output_path, &bytes).await {
303                            tracing::error!("Failed to write {}: {}", output_path.display(), e);
304                            failure_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
305                        } else {
306                            tracing::info!(
307                                "Converted: {} -> {}",
308                                file_path.display(),
309                                output_path.display()
310                            );
311                        }
312                    }
313                    Err(e) => {
314                        tracing::error!("Failed to convert {}: {}", file_path.display(), e);
315                        failure_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
316                    }
317                }
318            }));
319        }
320    }
321
322    for task in tasks {
323        task.await?;
324    }
325
326    let failures = failure_count.load(std::sync::atomic::Ordering::Relaxed);
327    if failures > 0 {
328        anyhow::bail!("{} file(s) failed to convert", failures);
329    }
330
331    Ok(())
332}