// kroki_rs/diagrams/mod.rs

1pub mod providers {
2    pub mod bpmn;
3    pub mod cmd;
4    pub mod d2;
5    pub mod ditaa;
6    pub mod excalidraw;
7    pub mod mermaid;
8    pub mod plugin;
9    pub mod vega;
10    pub mod wavedrom;
11}
12pub mod error;
13pub mod registry;
14
15pub use error::{DiagramError, DiagramResult};
16
17use async_trait::async_trait;
18
/// Picks a process timeout, in milliseconds, that grows with the size of the input.
///
/// Starts at 3000 ms and adds 1000 ms for every *complete* 10240-byte chunk of source,
/// never returning more than 10000 ms.
pub fn adaptive_timeout(source_len: usize) -> u64 {
    const BASE_MS: u64 = 3_000;
    const MS_PER_STEP: u64 = 1_000;
    const BYTES_PER_STEP: u64 = 10_240;
    const CEILING_MS: u64 = 10_000;

    // Integer division: partial chunks do not add time.
    let full_steps = source_len as u64 / BYTES_PER_STEP;
    let scaled = BASE_MS + full_steps * MS_PER_STEP;
    scaled.min(CEILING_MS)
}
26
/// Defines a provider struct plus its constructor, removing the boilerplate shared by
/// structurally identical providers.
///
/// `define_provider!(Name)` expands to:
///
/// ```ignore
/// pub struct Name {
///     pub bin_path: std::path::PathBuf,
///     pub timeout_ms: Option<u64>,
/// }
///
/// impl Name {
///     pub fn new(bin_path: std::path::PathBuf, timeout_ms: Option<u64>) -> Self { /* ... */ }
/// }
/// ```
///
/// Outer attributes placed before the name, including `///` doc comments and
/// `#[derive(...)]`, are forwarded to the generated struct:
///
/// ```ignore
/// define_provider!(
///     /// Docs for the generated provider.
///     #[derive(Debug, Clone)]
///     SomeProvider
/// );
/// ```
#[macro_export]
macro_rules! define_provider {
    ($(#[$attr:meta])* $name:ident) => {
        $(#[$attr])*
        pub struct $name {
            /// Path to the provider's binary.
            pub bin_path: std::path::PathBuf,
            /// Optional timeout in milliseconds (`None` = no explicit override).
            pub timeout_ms: Option<u64>,
        }

        impl $name {
            /// Creates the provider from its binary path and optional timeout.
            pub fn new(bin_path: std::path::PathBuf, timeout_ms: Option<u64>) -> Self {
                Self {
                    bin_path,
                    timeout_ms,
                }
            }
        }
    };
}
46pub(crate) use define_provider;
47
48/// Safely executes a child process, automatically managing timeouts, input piping, and memory cleanup.
49///
50/// # Arguments
51/// * `tool_name` - Human-readable name of the tool (e.g., "mmdc", "dot") for error messages.
52/// * `cmd` - The tokio Command to execute.
53/// * `source` - Optional bytes to pipe to stdin.
54/// * `timeout_ms` - Optional explicit timeout; falls back to adaptive_timeout.
55/// * `source_len` - Length of the source input (used for adaptive timeout and error context).
56pub async fn run_process_with_timeout(
57    tool_name: &str,
58    mut cmd: tokio::process::Command,
59    source: Option<&[u8]>,
60    timeout_ms: Option<u64>,
61    source_len: usize,
62) -> DiagramResult<std::process::Output> {
63    use tokio::io::AsyncWriteExt;
64
65    cmd.kill_on_drop(true);
66    let mut child = cmd.spawn().map_err(|e| {
67        if e.kind() == std::io::ErrorKind::NotFound {
68            DiagramError::ToolNotFound(tool_name.to_string())
69        } else {
70            DiagramError::ProcessFailed(format!("Failed to spawn '{}': {}", tool_name, e))
71        }
72    })?;
73
74    if let (Some(mut stdin), Some(src)) = (child.stdin.take(), source) {
75        stdin.write_all(src).await.map_err(|e| {
76            DiagramError::ProcessFailed(format!(
77                "Failed to write to stdin of '{}': {}",
78                tool_name, e
79            ))
80        })?;
81    }
82
83    let actual_timeout = std::cmp::min(
84        timeout_ms.unwrap_or_else(|| adaptive_timeout(source_len)),
85        20000,
86    );
87    let output_future = child.wait_with_output();
88
89    match tokio::time::timeout(
90        std::time::Duration::from_millis(actual_timeout),
91        output_future,
92    )
93    .await
94    {
95        Ok(Ok(out)) => Ok(out),
96        Ok(Err(e)) => Err(DiagramError::ProcessFailed(format!(
97            "'{}' IO error (input: {} bytes): {}",
98            tool_name, source_len, e
99        ))),
100        Err(_) => Err(DiagramError::ExecutionTimeout {
101            tool: tool_name.to_string(),
102            timeout_ms: actual_timeout,
103            bytes: source_len,
104        }),
105    }
106}
107
108/// A trait for diagram generation providers.
109///
110/// Each provider implementation is responsible for a specific diagram type
111/// (e.g., Mermaid, Graphviz).
112#[async_trait]
113pub trait DiagramProvider {
114    /// Validates the diagram source text.
115    ///
116    /// Returns `Ok(())` if the source is valid, or an error otherwise.
117    fn validate(&self, source: &str) -> DiagramResult<()>;
118
119    /// Generates a diagram image from the source text.
120    ///
121    /// # Arguments
122    /// * `source` - The diagram description text.
123    /// * `format` - The desired output format (e.g., "svg", "png").
124    ///
125    /// Returns a `Vec<u8>` containing the image data.
126    async fn generate(&self, source: &str, format: &str) -> DiagramResult<Vec<u8>>;
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::process::Command;

    /// `adaptive_timeout` adds 1000 ms per full 10240 bytes on top of 3000 ms, capped at 10000 ms.
    #[test]
    fn test_adaptive_timeout_scaling_and_cap() {
        assert_eq!(adaptive_timeout(0), 3000);
        assert_eq!(adaptive_timeout(10_239), 3000);
        assert_eq!(adaptive_timeout(10_240), 4000);
        assert_eq!(adaptive_timeout(71_680), 10000);
        assert_eq!(adaptive_timeout(usize::MAX), 10000);
    }

    /// A process that outlives its timeout yields `ExecutionTimeout` with the effective values.
    #[tokio::test]
    async fn test_run_process_with_timeout() {
        let mut cmd = Command::new("sleep");
        cmd.arg("2"); // Sleep for 2 seconds

        // Set timeout to 100ms, which is much shorter than 2 seconds
        let result = run_process_with_timeout("sleep", cmd, None, Some(100), 0).await;

        match result {
            Err(DiagramError::ExecutionTimeout {
                tool,
                timeout_ms,
                bytes,
            }) => {
                assert_eq!(tool, "sleep");
                assert_eq!(timeout_ms, 100);
                assert_eq!(bytes, 0);
            }
            _ => panic!("Expected ExecutionTimeout error"),
        }
    }

    /// Input passed as `source` reaches the child's stdin and its stdout is captured.
    #[tokio::test]
    async fn test_run_process_pipes_stdin_to_stdout() {
        let input: &[u8] = b"graph TD; A-->B";
        let mut cmd = Command::new("cat");
        cmd.stdin(std::process::Stdio::piped())
            .stdout(std::process::Stdio::piped());

        let result =
            run_process_with_timeout("cat", cmd, Some(input), Some(5000), input.len()).await;
        let output = match result {
            Ok(output) => output,
            Err(_) => panic!("Expected `cat` to echo its input"),
        };

        assert!(output.status.success());
        assert_eq!(output.stdout, input);
    }

    /// A missing executable is reported as `ToolNotFound` carrying the human-readable name.
    #[tokio::test]
    async fn test_run_process_missing_tool() {
        let cmd = Command::new("kroki-rs-nonexistent-tool-for-tests");
        let result = run_process_with_timeout("missing-tool", cmd, None, Some(1000), 0).await;

        match result {
            Err(DiagramError::ToolNotFound(tool)) => assert_eq!(tool, "missing-tool"),
            _ => panic!("Expected ToolNotFound error"),
        }
    }
}