1use crate::capabilities::Capabilities;
2use crate::config::Config;
3use crate::diagrams::registry::DiagramRegistry;
4use crate::utils::image_converter;
5use anyhow::{Context, Result};
6use num_cpus;
7use sha2::{Digest, Sha256};
8use std::path::PathBuf;
9use std::sync::Arc;
10use tokio::fs;
11use tokio::io::AsyncWriteExt;
12use tokio::sync::Semaphore;
13use walkdir::WalkDir;
14
/// Picks the format the diagram provider should actually render.
///
/// Returns `(base_format, is_webp)`:
/// - when `format` is `"webp"` (ASCII case-insensitive), the base format is
///   `"png"` for the `"ditaa"` diagram type (also case-insensitive) and
///   `"svg"` for every other type, and the flag is `true`;
/// - for any other `format`, the input is returned untouched and the flag is
///   `false`.
fn resolve_base_format<'a>(format: &'a str, type_: &str) -> (&'a str, bool) {
    // Guard clause: anything that is not WebP passes straight through.
    if !format.eq_ignore_ascii_case("webp") {
        return (format, false);
    }

    let intermediate = match type_.eq_ignore_ascii_case("ditaa") {
        true => "png",
        false => "svg",
    };
    (intermediate, true)
}
28
29async fn generate_diagram(
32 source: &str,
33 type_: &str,
34 format: &str,
35 config: &Config,
36 registry: &DiagramRegistry,
37 cache_dir: &Option<PathBuf>,
38) -> Result<Vec<u8>> {
39 let provider = registry.get(type_).context(format!(
40 "Diagram type '{}' not supported or tool not found",
41 type_
42 ))?;
43
44 if source.len() > config.server.max_input_size {
46 anyhow::bail!(
47 "Input too large ({} bytes). Maximum allowed: {} bytes. Configure via server.max_input_size in kroki.toml.",
48 source.len(),
49 config.server.max_input_size
50 );
51 }
52
53 let mut fonts = config.all_fonts();
55 fonts.sort();
56 fonts.dedup();
57
58 let mut plugin_signatures: Vec<String> = config
59 .plugins
60 .iter()
61 .map(|plugin| {
62 let args = plugin.args.join(",");
63 let formats = plugin.formats.join(",");
64 format!(
65 "{}|{}|{}|{}|{}|{}",
66 plugin.name,
67 plugin.command,
68 args,
69 formats,
70 plugin.stdin,
71 plugin.timeout_ms.unwrap_or(0)
72 )
73 })
74 .collect();
75 plugin_signatures.sort();
76
77 let mut hasher = Sha256::new();
78 hasher.update(type_);
79 hasher.update(format);
80 hasher.update(source);
81 hasher.update(b"fonts:");
82 for font in &fonts {
83 hasher.update(font.as_bytes());
84 hasher.update([0]);
85 }
86 hasher.update(b"plugins:");
87 for signature in &plugin_signatures {
88 hasher.update(signature.as_bytes());
89 hasher.update([0]);
90 }
91 let hash = hex::encode(hasher.finalize());
92
93 if let Some(cache_path) = cache_dir {
95 if !cache_path.exists() {
96 fs::create_dir_all(cache_path).await.ok();
97 }
98 let cached_file = cache_path.join(format!("{}.{}", hash, format));
99 if cached_file.exists() {
100 if let Ok(content) = fs::read(&cached_file).await {
101 tracing::info!("Cache hit! Served from {}", cached_file.display());
102 return Ok(content);
103 }
104 }
105 }
106
107 provider
109 .validate(source)
110 .context("Source validation failed")?;
111
112 let (base_format, is_webp) = resolve_base_format(format, type_);
114 let mut output_bytes = provider
115 .generate(source, base_format)
116 .await
117 .context("Diagram generation failed")?;
118
119 if is_webp {
121 let fonts = config.all_fonts();
122 output_bytes = if base_format == "png" {
123 image_converter::png_to_webp(&output_bytes, image_converter::WebpQuality::Lossless)
124 .await
125 .context("Failed to convert PNG to WebP")?
126 } else {
127 image_converter::svg_to_webp(
128 &output_bytes,
129 image_converter::WebpQuality::Lossless,
130 &fonts,
131 cache_dir.as_deref(),
132 )
133 .await
134 .context("Failed to convert SVG to WebP")?
135 };
136 }
137
138 if let Some(cache_path) = cache_dir {
140 let cached_file = cache_path.join(format!("{}.{}", hash, format));
141 if let Err(e) = fs::write(&cached_file, &output_bytes).await {
142 tracing::warn!("Failed to write to cache: {}", e);
143 } else {
144 tracing::info!("Saved to cache: {}", cached_file.display());
145 }
146 }
147
148 Ok(output_bytes)
149}
150
151pub async fn convert(
152 type_: String,
153 format: String,
154 input: PathBuf,
155 config: Config,
156 cache_dir: Option<PathBuf>,
157) -> Result<()> {
158 let capabilities = Capabilities::discover(&config);
159 let browser_manager = match crate::browser::BrowserManager::start(
160 config.browser.pool_size,
161 config.browser.context_ttl_requests,
162 )
163 .await
164 {
165 Ok(m) => Some(Arc::new(m)),
166 Err(e) => {
167 tracing::warn!("Browser worker failed to start: {}", e);
168 None
169 }
170 };
171 let registry = DiagramRegistry::new(&capabilities, &config, browser_manager);
172 let cache_dir = Config::resolve_cache_dir(cache_dir);
173
174 let source = fs::read_to_string(&input)
175 .await
176 .context(format!("Failed to read input file '{}'", input.display()))?;
177
178 let output_bytes =
179 generate_diagram(&source, &type_, &format, &config, ®istry, &cache_dir).await?;
180
181 let mut stdout = tokio::io::stdout();
182 stdout.write_all(&output_bytes).await?;
183 stdout.flush().await?;
184
185 Ok(())
186}
187
188pub async fn batch(
189 format: String,
190 input_dir: PathBuf,
191 type_override: Option<String>,
192 out_dir: Option<PathBuf>,
193 config: Config,
194 cache_dir: Option<PathBuf>,
195) -> Result<()> {
196 if !input_dir.is_dir() {
197 return Err(anyhow::anyhow!("Input must be a directory"));
198 }
199
200 let files: Vec<PathBuf> = WalkDir::new(&input_dir)
201 .into_iter()
202 .filter_map(|e| e.ok())
203 .filter(|e| e.file_type().is_file())
204 .map(|e| e.path().to_owned())
205 .collect();
206
207 tracing::info!("Found {} files in {}", files.len(), input_dir.display());
208
209 let capabilities = Capabilities::discover(&config);
210 let browser_manager = match crate::browser::BrowserManager::start(
211 config.browser.pool_size,
212 config.browser.context_ttl_requests,
213 )
214 .await
215 {
216 Ok(m) => Some(Arc::new(m)),
217 Err(e) => {
218 tracing::warn!("Browser worker failed to start: {}", e);
219 None
220 }
221 };
222 let registry = Arc::new(DiagramRegistry::new(
223 &capabilities,
224 &config,
225 browser_manager,
226 ));
227 let config = Arc::new(config);
228 let cache_dir = Arc::new(Config::resolve_cache_dir(cache_dir));
229 let format = Arc::new(format);
230 let type_override = Arc::new(type_override);
231 let out_dir = Arc::new(out_dir);
232
233 let mut tasks = Vec::new();
234 let failure_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
235
236 let concurrency = std::cmp::max(1, num_cpus::get());
237 let semaphore = Arc::new(Semaphore::new(concurrency));
238
239 for file_path in files {
240 let extension = file_path
241 .extension()
242 .and_then(|e| e.to_str())
243 .unwrap_or("")
244 .to_lowercase();
245
246 let type_ = if let Some(t) = type_override.as_ref() {
247 Some(t.clone())
248 } else {
249 match extension.as_str() {
250 "d2" => Some("d2".to_string()),
251 "dot" | "gv" => Some("graphviz".to_string()),
252 "mmd" | "mermaid" => Some("mermaid".to_string()),
253 "excalidraw" => Some("excalidraw".to_string()),
254 "bpmn" => Some("bpmn".to_string()),
255 "vega" => Some("vega".to_string()),
256 "vl" => Some("vegalite".to_string()),
257 _ => {
258 if file_path.to_string_lossy().ends_with(".vl.json") {
259 Some("vegalite".to_string())
260 } else {
261 None
262 }
263 }
264 }
265 };
266
267 if let Some(t) = type_ {
268 let config = config.clone();
269 let registry = registry.clone();
270 let cache_dir = cache_dir.clone();
271 let format = format.clone();
272 let out_dir = out_dir.clone();
273 let input_dir = input_dir.clone();
274 let failure_count = failure_count.clone();
275
276 let semaphore = semaphore.clone();
277 tasks.push(tokio::spawn(async move {
278 let _permit = semaphore.acquire_owned().await.unwrap();
279 let relative_path = file_path.strip_prefix(&input_dir).unwrap_or(&file_path);
280 let mut output_path = if let Some(out) = out_dir.as_ref() {
281 out.join(relative_path)
282 } else {
283 file_path.clone()
284 };
285 output_path.set_extension(format.as_str());
286
287 if let Some(parent) = output_path.parent() {
288 fs::create_dir_all(parent).await.ok();
289 }
290
291 let source = match fs::read_to_string(&file_path).await {
292 Ok(s) => s,
293 Err(e) => {
294 tracing::error!("Failed to read {}: {}", file_path.display(), e);
295 failure_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
296 return;
297 }
298 };
299
300 match generate_diagram(&source, &t, &format, &config, ®istry, &cache_dir).await {
301 Ok(bytes) => {
302 if let Err(e) = fs::write(&output_path, &bytes).await {
303 tracing::error!("Failed to write {}: {}", output_path.display(), e);
304 failure_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
305 } else {
306 tracing::info!(
307 "Converted: {} -> {}",
308 file_path.display(),
309 output_path.display()
310 );
311 }
312 }
313 Err(e) => {
314 tracing::error!("Failed to convert {}: {}", file_path.display(), e);
315 failure_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
316 }
317 }
318 }));
319 }
320 }
321
322 for task in tasks {
323 task.await?;
324 }
325
326 let failures = failure_count.load(std::sync::atomic::Ordering::Relaxed);
327 if failures > 0 {
328 anyhow::bail!("{} file(s) failed to convert", failures);
329 }
330
331 Ok(())
332}