diff --git a/AGENTS.md b/AGENTS.md index 6d4439814..ea2ccacc2 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -109,13 +109,13 @@ management. - `make typecheck` - `cargo check --all --benches --tests --examples` - `cargo check --all --benches --tests --examples` - `--no-default-features --features libsql` + `--no-default-features --features libsql-test-helpers` - `cargo check --all --benches --tests --examples --all-features` - `cargo check --manifest-path tools-src/github/Cargo.toml --tests` - `make lint` - `cargo clippy --all --benches --tests --examples -- -D warnings` - `cargo clippy --all --benches --tests --examples` - `--no-default-features --features libsql -- -D warnings` + `--no-default-features --features libsql-test-helpers -- -D warnings` - `cargo clippy --all --benches --tests --examples --all-features -- -D warnings` - `cargo clippy --manifest-path tools-src/github/Cargo.toml --tests -- -D warnings` - `make test` diff --git a/Cargo.lock b/Cargo.lock index 4bab0ec4e..423ecaebd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -325,9 +325,9 @@ dependencies = [ [[package]] name = "async-signal" -version = "0.2.13" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43c070bbf59cd3570b6b2dd54cd772527c7c3620fce8be898406dd3ed6adc64c" +checksum = "52b5aaafa020cf5053a01f2a60e8ff5dccf550f0f77ec54a4e47285ac2bab485" dependencies = [ "async-io", "async-lock", @@ -666,11 +666,11 @@ dependencies = [ "hyper 0.14.32", "hyper 1.9.0", "hyper-rustls 0.24.2", - "hyper-rustls 0.27.7", + "hyper-rustls 0.27.9", "hyper-util", "pin-project-lite", "rustls 0.21.12", - "rustls 0.23.37", + "rustls 0.23.38", "rustls-native-certs 0.8.3", "rustls-pki-types", "tokio", @@ -828,9 +828,9 @@ dependencies = [ [[package]] name = "axum" -version = "0.8.8" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" +checksum = 
"31b698c5f9a010f6573133b09e0de5408834d0c82f8d7475a89fc1867a71cd90" dependencies = [ "axum-core 0.5.6", "base64 0.22.1", @@ -855,7 +855,7 @@ dependencies = [ "sha1", "sync_wrapper 1.0.2", "tokio", - "tokio-tungstenite 0.28.0", + "tokio-tungstenite 0.29.0", "tower 0.5.3", "tower-layer", "tower-service", @@ -941,7 +941,7 @@ version = "0.66.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2b84e06fc203107bfbad243f4aba2af864eb7db3b1cf46ea0a023b0b433d2a7" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "cexpr", "clang-sys", "lazy_static", @@ -981,9 +981,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.11.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" +checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" [[package]] name = "bitvec" @@ -1068,12 +1068,12 @@ dependencies = [ "http-body-util", "hyper 1.9.0", "hyper-named-pipe", - "hyper-rustls 0.27.7", + "hyper-rustls 0.27.9", "hyper-util", "hyperlocal", "log", "pin-project-lite", - "rustls 0.23.37", + "rustls 0.23.38", "rustls-native-certs 0.8.3", "rustls-pemfile", "rustls-pki-types", @@ -1270,9 +1270,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.59" +version = "1.2.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7a4d3ec6524d28a329fc53654bbadc9bdd7b0431f5d65f1a56ffb28a1ee5283" +checksum = "43c5703da9466b66a946814e1adf53ea2c90f10063b86290cc9eb67ce3478a20" dependencies = [ "find-msvc-tools", "jobserver", @@ -1309,7 +1309,7 @@ checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601" dependencies = [ "cfg-if", "cpufeatures 0.3.0", - "rand_core 0.10.0", + "rand_core 0.10.1", ] [[package]] @@ -1381,9 +1381,9 @@ dependencies = [ [[package]] name = "clap_complete" -version = "4.6.0" +version = "4.6.2" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19c9f1dde76b736e3681f28cec9d5a61299cbaae0fce80a68e43724ad56031eb" +checksum = "3ff7a1dccbdd8b078c2bdebff47e404615151534d5043da397ec50286816f9cb" dependencies = [ "clap", ] @@ -1824,7 +1824,7 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "829d955a0bb380ef178a640b91779e3987da38c9aea133b20614cfed8cdea9c6" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "crossterm_winapi", "mio", "parking_lot", @@ -1840,7 +1840,7 @@ version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d8b9f2e4c67f833b660cdb0a3523065869fb35570177239812ed4c905aeff87b" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "crossterm_winapi", "derive_more", "document-features", @@ -2039,6 +2039,17 @@ dependencies = [ "uuid", ] +[[package]] +name = "delegate" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "780eb241654bf097afb00fc5f054a09b687dad862e485fdcf8399bb056565370" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "der" version = "0.7.10" @@ -2456,9 +2467,9 @@ checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" [[package]] name = "fastrand" -version = "2.4.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a043dc74da1e37d6afe657061213aa6f425f855399a11d3463c6ecccc4dfda1f" +checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" [[package]] name = "fd-lock" @@ -2737,7 +2748,7 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27d12c0aed7f1e24276a241aadc4cb8ea9f83000f34bc062b7cc2d51e3b0fabd" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "debugid", "fxhash", "serde", @@ -2809,7 +2820,7 @@ dependencies = [ "cfg-if", "libc", "r-efi 6.0.0", - "rand_core 0.10.0", + "rand_core 
0.10.1", "wasip2", "wasip3", ] @@ -2831,7 +2842,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" dependencies = [ "fallible-iterator 0.3.0", - "indexmap 2.13.1", + "indexmap 2.14.0", "stable_deref_trait", ] @@ -2853,7 +2864,7 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap 2.13.1", + "indexmap 2.14.0", "slab", "tokio", "tokio-util", @@ -2872,7 +2883,7 @@ dependencies = [ "futures-core", "futures-sink", "http 1.4.0", - "indexmap 2.13.1", + "indexmap 2.14.0", "slab", "tokio", "tokio-util", @@ -2920,6 +2931,12 @@ dependencies = [ "foldhash 0.2.0", ] +[[package]] +name = "hashbrown" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" + [[package]] name = "hashlink" version = "0.8.4" @@ -3228,16 +3245,15 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.27.7" +version = "0.27.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +checksum = "33ca68d021ef39cf6463ab54c1d0f5daf03377b70561305bb89a8f83aab66e0f" dependencies = [ "http 1.4.0", "hyper 1.9.0", "hyper-util", - "rustls 0.23.37", + "rustls 0.23.38", "rustls-native-certs 0.8.3", - "rustls-pki-types", "tokio", "tokio-rustls 0.26.4", "tower-service", @@ -3463,12 +3479,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.13.1" +version = "2.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45a8a2b9cb3e0b0c1803dbb0758ffac5de2f425b23c28f518faabd9d805342ff" +checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" dependencies = [ "equivalent", - "hashbrown 0.16.1", + "hashbrown 0.17.0", "serde", "serde_core", ] @@ -3491,6 +3507,7 @@ checksum = 
"7b4a6248eb93a4401ed2f37dfe8ea592d3cf05b7cf4f8efa867b6895af7e094e" dependencies = [ "console", "once_cell", + "serde", "similar", "tempfile", ] @@ -3537,7 +3554,7 @@ dependencies = [ "aws-config", "aws-sdk-bedrockruntime", "aws-smithy-types", - "axum 0.8.8", + "axum 0.8.9", "base64 0.22.1", "blake3", "bollard", @@ -3550,6 +3567,7 @@ dependencies = [ "cron", "crossterm 0.28.1", "deadpool-postgres", + "delegate", "dirs 6.0.0", "dotenvy", "ed25519-dalek", @@ -3589,9 +3607,10 @@ dependencies = [ "rusqlite", "rust_decimal", "rust_decimal_macros", - "rustls 0.23.37", + "rustls 0.23.38", "rustls-native-certs 0.8.3", "rustyline", + "scopeguard", "secrecy", "secret-service", "security-framework 3.7.0", @@ -3703,9 +3722,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.94" +version = "0.3.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e04e2ef80ce82e13552136fabeef8a5ed1f985a96805761cbb9a2c34e7664d9" +checksum = "2964e92d1d9dc3364cae4d718d93f227e3abb088e747d92e0395bfdedf1c12ca" dependencies = [ "cfg-if", "futures-util", @@ -3730,11 +3749,11 @@ version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b73885c6a3cefdf7a1db0327cefbe4b9b72cac94cae4b19ede4fa492d8af02a0" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "crc", "cssparser", "html5ever 0.38.0", - "indexmap 2.13.1", + "indexmap 2.14.0", "precomputed-hash", "selectors 0.35.0", ] @@ -3788,9 +3807,9 @@ checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" [[package]] name = "libc" -version = "0.2.184" +version = "0.2.185" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48f5d2a454e16a5ea0f4ced81bd44e4cfc7bd3a507b61887c99fd3538b28e4af" +checksum = "52ff2c0fe9bc6cb6b14a0592c2ff4fa9ceb83eea9db979b0487cd054946a2b8f" [[package]] name = "libloading" @@ -3810,14 +3829,14 @@ checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" [[package]] name = "libredox" 
-version = "0.1.15" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ddbf48fd451246b1f8c2610bd3b4ac0cc6e149d89832867093ab69a17194f08" +checksum = "e02f3bb43d335493c96bf3fd3a321600bf6bd07ed34bc64118e9293bdffea46c" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "libc", "plain", - "redox_syscall 0.7.3", + "redox_syscall 0.7.4", ] [[package]] @@ -3831,7 +3850,7 @@ dependencies = [ "async-trait", "base64 0.21.7", "bincode", - "bitflags 2.11.0", + "bitflags 2.11.1", "bytes", "fallible-iterator 0.3.0", "futures", @@ -3886,7 +3905,7 @@ version = "0.33.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae65c66088dcd309abbd5617ae046abac2a2ee0a7fdada5127353bd68e0a27ea" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "fallible-iterator 0.2.0", "fallible-streaming-iterator", "hashlink 0.8.4", @@ -3900,10 +3919,10 @@ version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "15a90128c708356af8f7d767c9ac2946692c9112b4f74f07b99a01a60680e413" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "cc", "fallible-iterator 0.3.0", - "indexmap 2.13.1", + "indexmap 2.14.0", "log", "memchr", "phf 0.11.3", @@ -4019,7 +4038,7 @@ checksum = "c5c8ecfc6c72051981c0459f75ccc585e7ff67c70829560cda8e647882a9abff" dependencies = [ "encoding_rs", "flate2", - "indexmap 2.13.1", + "indexmap 2.14.0", "itoa", "log", "md-5 0.10.6", @@ -4031,9 +4050,9 @@ dependencies = [ [[package]] name = "lru" -version = "0.16.3" +version = "0.16.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1dc47f592c06f33f8e3aea9591776ec7c9f9e4124778ff8a3c3b87159f7e593" +checksum = "7f66e8d5d03f609abc3a39e6f08e4164ebf1447a732906d39eb9b99b7919ef39" dependencies = [ "hashbrown 0.16.1", ] @@ -4289,7 +4308,7 @@ version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "cfg-if", "cfg_aliases", "libc", @@ -4302,7 +4321,7 @@ version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "cfg-if", "cfg_aliases", "libc", @@ -4422,7 +4441,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", ] [[package]] @@ -4442,7 +4461,7 @@ checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" dependencies = [ "crc32fast", "hashbrown 0.15.5", - "indexmap 2.13.1", + "indexmap 2.14.0", "memchr", ] @@ -4486,11 +4505,11 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.76" +version = "0.10.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf" +checksum = "bfe4646e360ec77dff7dde40ed3d6c5fee52d156ef4a62f53973d38294dad87f" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "cfg-if", "foreign-types", "libc", @@ -4524,9 +4543,9 @@ checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" [[package]] name = "openssl-sys" -version = "0.9.112" +version = "0.9.113" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57d55af3b3e226502be1526dfdba67ab0e9c96fc293004e79576b2b9edb0dbdb" +checksum = "ad2f2c0eba47118757e4c6d2bff2838f3e0523380021356e7875e858372ce644" dependencies = [ "cc", "libc", @@ -4876,9 +4895,9 @@ dependencies = [ [[package]] name = "pkg-config" -version = "0.3.32" +version = "0.3.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +checksum = "19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e" [[package]] name = "plain" @@ -4943,7 +4962,7 @@ dependencies = [ "hmac 0.13.0", "md-5 0.11.0", "memchr", - "rand 0.10.0", + "rand 0.10.1", "sha2 0.11.0", "stringprep", ] @@ -5051,7 +5070,7 @@ version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e67ba7e9b2b56446f1d419b1d807906278ffa1a658a8a5d8a39dcb1f5a78614f" dependencies = [ - "toml_edit 0.25.10+spec-1.1.0", + "toml_edit 0.25.11+spec-1.1.0", ] [[package]] @@ -5071,9 +5090,9 @@ checksum = "4b45fcc2344c680f5025fe57779faef368840d0bd1f42f216291f0dc4ace4744" dependencies = [ "bit-set", "bit-vec", - "bitflags 2.11.0", + "bitflags 2.11.1", "num-traits", - "rand 0.9.2", + "rand 0.9.4", "rand_chacha 0.9.0", "rand_xorshift", "regex-syntax", @@ -5164,7 +5183,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash 2.1.2", - "rustls 0.23.37", + "rustls 0.23.38", "socket2 0.6.3", "thiserror 2.0.18", "tokio", @@ -5181,10 +5200,10 @@ dependencies = [ "bytes", "getrandom 0.3.4", "lru-slab", - "rand 0.9.2", + "rand 0.9.4", "ring", "rustc-hash 2.1.2", - "rustls 0.23.37", + "rustls 0.23.38", "rustls-pki-types", "slab", "thiserror 2.0.18", @@ -5257,9 +5276,9 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.2" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +checksum = "44c5af06bb1b7d3216d91932aed5265164bf384dc89cd6ba05cf59a35f5f76ea" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.5", @@ -5267,13 +5286,13 @@ dependencies = [ [[package]] name = "rand" -version = "0.10.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc266eb313df6c5c09c1c7b1fbe2510961e5bcd3add930c1e31f7ed9da0feff8" +checksum = "d2e8e8bcc7961af1fdac401278c6a831614941f6164ee3bf4ce61b7edb162207" 
dependencies = [ "chacha20", "getrandom 0.4.2", - "rand_core 0.10.0", + "rand_core 0.10.1", ] [[package]] @@ -5316,9 +5335,9 @@ dependencies = [ [[package]] name = "rand_core" -version = "0.10.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c8d0fd677905edcbeedbf2edb6494d676f0e98d54d5cf9bda0b061cb8fb8aba" +checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69" [[package]] name = "rand_xorshift" @@ -5337,9 +5356,9 @@ checksum = "973443cf09a9c8656b574a866ab68dfa19f0867d0340648c7d2f6a71b8a8ea68" [[package]] name = "rayon" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +checksum = "fb39b166781f92d482534ef4b4b1b2568f42613b53e5b6c160e24cfbfa30926d" dependencies = [ "either", "rayon-core", @@ -5357,11 +5376,11 @@ dependencies = [ [[package]] name = "readabilityrs" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3eb174b0af6c181a87d68b42800806657bfbdf88b566f819aaadb9d2a7b7699d" +checksum = "d90c6e1dad698d9f3c80a8d91bc0efc8c2397cb5ca4bbffb9c0fb88a9a66e6b8" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "kuchikikiki", "once_cell", "regex", @@ -5388,16 +5407,16 @@ version = "0.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", ] [[package]] name = "redox_syscall" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ce70a74e890531977d37e532c34d45e9055d2409ed08ddba14529471ed0be16" +checksum = "f450ad9c3b1da563fb6948a8e0fb0fb9269711c9c73d9ea1de5058c79c8d643a" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", ] [[package]] @@ -5567,7 +5586,7 @@ dependencies = [ 
"http-body 1.0.1", "http-body-util", "hyper 1.9.0", - "hyper-rustls 0.27.7", + "hyper-rustls 0.27.9", "hyper-tls", "hyper-util", "js-sys", @@ -5578,7 +5597,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.37", + "rustls 0.23.38", "rustls-native-certs 0.8.3", "rustls-pki-types", "serde", @@ -5709,7 +5728,7 @@ version = "0.32.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "fallible-iterator 0.3.0", "fallible-streaming-iterator", "hashlink 0.9.1", @@ -5778,7 +5797,7 @@ version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "errno", "libc", "linux-raw-sys 0.4.15", @@ -5791,7 +5810,7 @@ version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "errno", "libc", "linux-raw-sys 0.12.1", @@ -5836,15 +5855,15 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.37" +version = "0.23.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" +checksum = "69f9466fb2c14ea04357e91413efb882e2a6d4a406e625449bc0a5d360d53a21" dependencies = [ "aws-lc-rs", "once_cell", "ring", "rustls-pki-types", - "rustls-webpki 0.103.10", + "rustls-webpki 0.103.12", "subtle", "zeroize", ] @@ -5916,9 +5935,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.103.10" +version = "0.103.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df33b2b81ac578cabaf06b89b0631153a3f416b0a886e8a7a1707fb51abbd1ef" +checksum = 
"8279bb85272c9f10811ae6a6c547ff594d6a7f3c6c6b02ee9726d1d0dcfcdd06" dependencies = [ "aws-lc-rs", "ring", @@ -5950,7 +5969,7 @@ version = "17.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e902948a25149d50edc1a8e0141aad50f54e22ba83ff988cf8f7c9ef07f50564" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "cfg-if", "clipboard-win", "fd-lock", @@ -6111,7 +6130,7 @@ version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "core-foundation 0.9.4", "core-foundation-sys", "libc", @@ -6124,7 +6143,7 @@ version = "3.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "core-foundation 0.10.1", "core-foundation-sys", "libc", @@ -6147,7 +6166,7 @@ version = "0.33.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "feef350c36147532e1b79ea5c1f3791373e61cbd9a6a2615413b3807bb164fb7" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "cssparser", "derive_more", "log", @@ -6166,7 +6185,7 @@ version = "0.35.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "93fdfed56cd634f04fe8b9ddf947ae3dc493483e819593d2ba17df9ad05db8b2" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "cssparser", "derive_more", "log", @@ -6305,7 +6324,7 @@ dependencies = [ "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.13.1", + "indexmap 2.14.0", "schemars 0.9.0", "schemars 1.2.1", "serde_core", @@ -6332,7 +6351,7 @@ version = "0.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "59e2dd588bf1597a252c3b920e0143eb99b0f76e4e082f4c92ce34fbc9e71ddd" dependencies = [ - "indexmap 2.13.1", + "indexmap 2.14.0", "itoa", "libyml", "memchr", @@ -6670,7 +6689,7 @@ 
version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "core-foundation 0.9.4", "system-configuration-sys", ] @@ -6691,7 +6710,7 @@ version = "0.27.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cc4592f674ce18521c2a81483873a49596655b179f71c5e05d10c1fe66c78745" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "cap-fs-ext", "cap-std", "fd-lock", @@ -6970,9 +6989,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.51.0" +version = "1.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bd1c4c0fc4a7ab90fc15ef6daaa3ec3b893f004f915f2392557ed23237820cd" +checksum = "a91135f59b1cbf38c91e73cf3386fca9bb77915c45ce2771460c9d92f0f3d776" dependencies = [ "bytes", "libc", @@ -7036,7 +7055,7 @@ dependencies = [ "pin-project-lite", "postgres-protocol", "postgres-types", - "rand 0.10.0", + "rand 0.10.1", "socket2 0.6.3", "tokio", "tokio-util", @@ -7051,7 +7070,7 @@ checksum = "27d684bad428a0f2481f42241f821db42c54e2dc81d8c00db8536c506b0a0144" dependencies = [ "const-oid 0.9.6", "ring", - "rustls 0.23.37", + "rustls 0.23.38", "tokio", "tokio-postgres", "tokio-rustls 0.26.4", @@ -7085,7 +7104,7 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" dependencies = [ - "rustls 0.23.37", + "rustls 0.23.38", "tokio", ] @@ -7141,14 +7160,14 @@ dependencies = [ [[package]] name = "tokio-tungstenite" -version = "0.28.0" +version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857" +checksum = "8f72a05e828585856dacd553fba484c242c46e391fb0e58917c942ee9202915c" dependencies = [ "futures-util", "log", "tokio", - "tungstenite 0.28.0", + 
"tungstenite 0.29.0", ] [[package]] @@ -7182,7 +7201,7 @@ version = "1.1.2+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81f3d15e84cbcd896376e6730314d59fb5a87f31e4b038454184435cd57defee" dependencies = [ - "indexmap 2.13.1", + "indexmap 2.14.0", "serde_core", "serde_spanned 1.1.1", "toml_datetime 1.1.1+spec-1.1.0", @@ -7215,7 +7234,7 @@ version = "0.22.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" dependencies = [ - "indexmap 2.13.1", + "indexmap 2.14.0", "serde", "serde_spanned 0.6.9", "toml_datetime 0.6.11", @@ -7225,11 +7244,11 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.25.10+spec-1.1.0" +version = "0.25.11+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a82418ca169e235e6c399a84e395ab6debeb3bc90edc959bf0f48647c6a32d1b" +checksum = "0b59c4d22ed448339746c59b905d24568fcbb3ab65a500494f7b8c3e97739f2b" dependencies = [ - "indexmap 2.13.1", + "indexmap 2.14.0", "toml_datetime 1.1.1+spec-1.1.0", "toml_parser", "winnow 1.0.1", @@ -7345,7 +7364,7 @@ version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61c5bb1d698276a2443e5ecfabc1008bf15a36c12e6a7176e7bf089ea9131140" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "bytes", "futures-core", "futures-util", @@ -7365,7 +7384,7 @@ version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "bytes", "futures-util", "http 1.4.0", @@ -7530,7 +7549,7 @@ dependencies = [ "http 1.4.0", "httparse", "log", - "rand 0.9.2", + "rand 0.9.4", "sha1", "thiserror 2.0.18", "utf-8", @@ -7538,19 +7557,18 @@ dependencies = [ [[package]] name = "tungstenite" -version = "0.28.0" +version = "0.29.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "8628dcc84e5a09eb3d8423d6cb682965dea9133204e8fb3efee74c2a0c259442" +checksum = "6c01152af293afb9c7c2a57e4b559c5620b421f6d133261c60dd2d0cdb38e6b8" dependencies = [ "bytes", "data-encoding", "http 1.4.0", "httparse", "log", - "rand 0.9.2", + "rand 0.9.4", "sha1", "thiserror 2.0.18", - "utf-8", ] [[package]] @@ -7831,9 +7849,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.117" +version = "0.2.118" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0551fc1bb415591e3372d0bc4780db7e587d84e2a7e79da121051c5c4b89d0b0" +checksum = "0bf938a0bacb0469e83c1e148908bd7d5a6010354cf4fb73279b7447422e3a89" dependencies = [ "cfg-if", "once_cell", @@ -7845,9 +7863,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.67" +version = "0.4.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03623de6905b7206edd0a75f69f747f134b7f0a2323392d664448bf2d3c5d87e" +checksum = "f371d383f2fb139252e0bfac3b81b265689bf45b6874af544ffa4c975ac1ebf8" dependencies = [ "js-sys", "wasm-bindgen", @@ -7855,9 +7873,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.117" +version = "0.2.118" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fbdf9a35adf44786aecd5ff89b4563a90325f9da0923236f6104e603c7e86be" +checksum = "eeff24f84126c0ec2db7a449f0c2ec963c6a49efe0698c4242929da037ca28ed" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -7865,9 +7883,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.117" +version = "0.2.118" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dca9693ef2bab6d4e6707234500350d8dad079eb508dca05530c85dc3a529ff2" +checksum = "9d08065faf983b2b80a79fd87d8254c409281cf7de75fc4b773019824196c904" dependencies = [ "bumpalo", "proc-macro2", @@ -7878,9 +7896,9 @@ dependencies = [ [[package]] name = 
"wasm-bindgen-shared" -version = "0.2.117" +version = "0.2.118" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39129a682a6d2d841b6c429d0c51e5cb0ed1a03829d8b3d1e69a011e62cb3d3b" +checksum = "5fd04d9e306f1907bd13c6361b5c6bfc7b3b3c095ed3f8a9246390f8dbdee129" dependencies = [ "unicode-ident", ] @@ -7922,7 +7940,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" dependencies = [ "anyhow", - "indexmap 2.13.1", + "indexmap 2.14.0", "wasm-encoder 0.244.0", "wasmparser 0.244.0", ] @@ -7947,9 +7965,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d07b6a3b550fefa1a914b6d54fc175dd11c3392da11eee604e6ffc759805d25" dependencies = [ "ahash 0.8.12", - "bitflags 2.11.0", + "bitflags 2.11.1", "hashbrown 0.14.5", - "indexmap 2.13.1", + "indexmap 2.14.0", "semver", "serde", ] @@ -7960,9 +7978,9 @@ version = "0.221.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d06bfa36ab3ac2be0dee563380147a5b81ba10dd8885d7fbbc9eb574be67d185" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "hashbrown 0.15.5", - "indexmap 2.13.1", + "indexmap 2.14.0", "semver", "serde", ] @@ -7973,9 +7991,9 @@ version = "0.244.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "hashbrown 0.15.5", - "indexmap 2.13.1", + "indexmap 2.14.0", "semver", ] @@ -7985,8 +8003,8 @@ version = "0.246.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "71cde4757396defafd25417cfb36aa3161027d06d865b0c24baaae229aac005d" dependencies = [ - "bitflags 2.11.0", - "indexmap 2.13.1", + "bitflags 2.11.1", + "indexmap 2.14.0", "semver", ] @@ -8010,7 +8028,7 @@ dependencies = [ "addr2line", "anyhow", "async-trait", - "bitflags 2.11.0", + "bitflags 2.11.1", 
"bumpalo", "cc", "cfg-if", @@ -8018,7 +8036,7 @@ dependencies = [ "fxprof-processed-profile", "gimli", "hashbrown 0.14.5", - "indexmap 2.13.1", + "indexmap 2.14.0", "ittapi", "libc", "libm", @@ -8144,7 +8162,7 @@ dependencies = [ "cranelift-bitset", "cranelift-entity", "gimli", - "indexmap 2.13.1", + "indexmap 2.14.0", "log", "object 0.36.7", "postcard", @@ -8223,7 +8241,7 @@ checksum = "1a8e04b9a4c68ad018b330a4f4914b82b01dc3582d715ce21a93564c7f26b19f" dependencies = [ "anyhow", "async-trait", - "bitflags 2.11.0", + "bitflags 2.11.1", "bytes", "cap-fs-ext", "cap-net-ext", @@ -8270,7 +8288,7 @@ checksum = "5f38f7a5eb2f06f53fe943e7fb8bf4197f7cf279f1bc52c0ce56e9d3ffd750a4" dependencies = [ "anyhow", "heck", - "indexmap 2.13.1", + "indexmap 2.14.0", "wit-parser 0.221.3", ] @@ -8307,9 +8325,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.94" +version = "0.3.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd70027e39b12f0849461e08ffc50b9cd7688d942c1c8e3c7b22273236b4dd0a" +checksum = "4f2dfbb17949fa2088e5d39408c48368947b86f7834484e87b73de55bc14d97d" dependencies = [ "js-sys", "wasm-bindgen", @@ -8394,7 +8412,7 @@ checksum = "3b23e3dc273d1e35cab9f38a5f76487aeeedcfa6a3fb594e209ee7b6f8b41dcc" dependencies = [ "anyhow", "async-trait", - "bitflags 2.11.0", + "bitflags 2.11.1", "thiserror 1.0.69", "tracing", "wasmtime", @@ -8801,7 +8819,7 @@ version = "0.36.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f3fd376f71958b862e7afb20cfe5a22830e1963462f3a17f49d82a6c1d1f42d" dependencies = [ - "bitflags 2.11.0", + "bitflags 2.11.1", "windows-sys 0.59.0", ] @@ -8833,7 +8851,7 @@ checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" dependencies = [ "anyhow", "heck", - "indexmap 2.13.1", + "indexmap 2.14.0", "prettyplease", "syn 2.0.117", "wasm-metadata", @@ -8863,8 +8881,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" dependencies = [ "anyhow", - "bitflags 2.11.0", - "indexmap 2.13.1", + "bitflags 2.11.1", + "indexmap 2.14.0", "log", "serde", "serde_derive", @@ -8883,7 +8901,7 @@ checksum = "896112579ed56b4a538b07a3d16e562d101ff6265c46b515ce0c701eef16b2ac" dependencies = [ "anyhow", "id-arena", - "indexmap 2.13.1", + "indexmap 2.14.0", "log", "semver", "serde", @@ -8901,7 +8919,7 @@ checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" dependencies = [ "anyhow", "id-arena", - "indexmap 2.13.1", + "indexmap 2.14.0", "log", "semver", "serde", @@ -9194,7 +9212,7 @@ dependencies = [ "crossbeam-utils", "displaydoc", "flate2", - "indexmap 2.13.1", + "indexmap 2.14.0", "memchr", "thiserror 2.0.18", "zopfli", diff --git a/Cargo.toml b/Cargo.toml index e6d532e5b..0c735909c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -199,13 +199,15 @@ tracing-test = "0.2" tokio-tungstenite = "0.26" testcontainers-modules = { version = "0.11", features = ["postgres"] } pretty_assertions = "1" -insta = "1.46.3" +insta = { version = "1.46.3", features = ["json"] } rstest = "0.26.1" -proptest = "1.6.0" tempfile = "3" mockall = "0.13" trybuild = "1" +proptest = "1.6.0" +delegate = "0.13" gag = "1.0.0" +scopeguard = "1.2.0" [features] default = ["postgres", "libsql", "html-to-markdown", "docker"] @@ -221,7 +223,8 @@ postgres = [ "dep:pgvector", "rust_decimal/db-tokio-postgres", ] -libsql = ["dep:libsql", "test-helpers"] +libsql = ["dep:libsql"] +libsql-test-helpers = ["libsql", "test-helpers"] integration = [] html-to-markdown = ["dep:html-to-markdown-rs", "dep:readabilityrs"] bedrock = ["dep:aws-config", "dep:aws-sdk-bedrockruntime", "dep:aws-smithy-types"] diff --git a/Makefile b/Makefile index d2984079d..36a843785 100644 --- a/Makefile +++ b/Makefile @@ -28,13 +28,13 @@ check-fmt: typecheck: $(CARGO) check --all --benches --tests --examples $(TEST_FEATURES) - $(CARGO) check --all --benches --tests --examples 
--no-default-features --features libsql,test-helpers + $(CARGO) check --all --benches --tests --examples --no-default-features --features libsql-test-helpers $(CARGO) check --all --benches --tests --examples --all-features $(TEST_FEATURES) $(CARGO) check --manifest-path $(GITHUB_TOOL_MANIFEST) --tests lint: $(CARGO) clippy --all --benches --tests --examples $(TEST_FEATURES) -- -D warnings - $(CARGO) clippy --all --benches --tests --examples --no-default-features --features libsql,test-helpers -- -D warnings + $(CARGO) clippy --all --benches --tests --examples --no-default-features --features libsql-test-helpers -- -D warnings $(CARGO) clippy --all --benches --tests --examples --all-features $(TEST_FEATURES) -- -D warnings $(CARGO) clippy --manifest-path $(GITHUB_TOOL_MANIFEST) --tests -- -D warnings @@ -51,15 +51,15 @@ test-cargo: test-matrix: $(MAKE) build-github-tool-wasm $(NEXTEST) run --workspace $(TEST_FEATURES) --profile $(NEXTEST_PROFILE) - $(NEXTEST) run --workspace --no-default-features --features libsql,test-helpers --profile $(NEXTEST_PROFILE) - $(NEXTEST) run --workspace --features postgres,libsql,html-to-markdown,test-helpers --profile $(NEXTEST_PROFILE) + $(NEXTEST) run --workspace --no-default-features --features libsql-test-helpers --profile $(NEXTEST_PROFILE) + $(NEXTEST) run --workspace --features postgres,libsql-test-helpers,html-to-markdown --profile $(NEXTEST_PROFILE) $(CARGO) test --manifest-path $(GITHUB_TOOL_MANIFEST) -- --nocapture test-matrix-cargo: $(MAKE) build-github-tool-wasm $(CARGO) test $(TEST_FEATURES) -- --nocapture - $(CARGO) test --no-default-features --features libsql,test-helpers -- --nocapture - $(CARGO) test --features postgres,libsql,html-to-markdown,test-helpers -- --nocapture + $(CARGO) test --no-default-features --features libsql-test-helpers -- --nocapture + $(CARGO) test --features postgres,libsql-test-helpers,html-to-markdown -- --nocapture $(CARGO) test --manifest-path $(GITHUB_TOOL_MANIFEST) -- --nocapture clean: 
diff --git a/docs/developers-guide.md b/docs/developers-guide.md index 3e36afeac..12bad1beb 100644 --- a/docs/developers-guide.md +++ b/docs/developers-guide.md @@ -331,6 +331,71 @@ export DATABASE_URL=postgres://localhost/ironclaw Adjust the connection string if the local PostgreSQL instance requires a different host, user, or password. +### Atomic terminal job persistence + +Use `Database::persist_terminal_result_and_status(...)` with +`TerminalJobPersistence` whenever a code path must persist a terminal +`agent_jobs` status and its matching `job_events` row as one unit. This is the +required path for worker completion, failure, and stuck transitions where +split writes could leave the job row and event history out of sync. + +Prefer the atomic path instead of separate status and event writes when all of +the following are true: + +- the status transition is terminal (`completed`, `failed`, or `stuck`) +- the event payload is the canonical terminal result that history readers and + SSE consumers expect +- the caller must roll back the terminal transition if either write fails + +The contract is: + +- the `agent_jobs` update and the `job_events` insert succeed together or are + both rolled back +- the API returns an error when the job does not exist, the job is not a + direct agent job, or the backend cannot complete the transaction +- callers remain responsible for restoring any in-memory state if the atomic + write fails after the local state machine has already advanced + +Backend expectations: + +- PostgreSQL executes both writes inside one database transaction and rolls + back both records on any failure +- libSQL follows the same all-or-nothing contract for the writes it owns, but + callers should still treat transport or replication failures as failed + writes and retry or roll back their in-memory state accordingly +- `NullDatabase` accepts the call for tests and does not persist anything + +Common failure modes include missing jobs, non-direct jobs, 
constraint +violations, serialization errors, and pool or transport failures. Callers +should surface the error, avoid assuming the terminal state was stored, and +delegate retry or compensation to the workflow that owns the job. + +Example: + +```rust +store + .persist_terminal_result_and_status(TerminalJobPersistence { + job_id, + status: JobState::Completed, + failure_reason: None, + event_type: SandboxEventType::from("result"), + event_data: &serde_json::json!({ + "status": "completed", + "success": true, + "message": "Job completed successfully", + }), + }) + .await?; +``` + +Migration guidance: + +- replace paired terminal `update_job_status(...)` and `save_job_event(...)` + calls with `persist_terminal_result_and_status(...)` +- keep non-terminal progress updates on the older separate APIs +- add rollback regression coverage for both supported backends before + releasing new terminal transitions + ## End-to-end (E2E) prerequisites For browser-based tests: @@ -517,6 +582,40 @@ reload sequence: The manager is created via `create_hot_reload_manager()` which wires together the default implementations based on available stores. +### Webhook server lifecycle / listener-based API + +`WebhookServer::start_with_listener()` and +`WebhookServer::restart_with_listener()` are the listener-oriented variants of +the older bind-by-address lifecycle. They accept a pre-bound +`tokio::net::TcpListener`, which means the caller owns listener acquisition and +bind failure timing before handing the socket to the webhook server. 
+ +The contract differs from `start()` and `restart_with_addr()` in three +important ways: + +- the caller passes an already-bound listener instead of asking + `WebhookServer` to bind one internally; +- `config.addr` is updated from `listener.local_addr()` so the stored runtime + address reflects the real bound socket; and +- the server still merges any queued route fragments into one router on first + start and saves that router in `merged_router` for later listener restarts. + +Use the listener-based API for hot-reload and integration-test flows that need +OS-selected ports, externally managed socket setup, or socket handoff between +components. In both methods, route ownership remains with the server once the +listener has been accepted; callers should finish route registration before the +first start, just as they would with `start()`. + +Migration notes for maintainers: + +- pre-bind the listener and pass ownership into the method; +- expect the methods to remain async because the serving task is still spawned, + and graceful shutdown wiring still happens inside `WebhookServer`; +- handle bind and startup failures through `ChannelError::StartupFailed`, which + now covers listener-derived startup errors as well as internal bind errors; +- prefer `restart_with_listener()` in reload paths when the caller needs to + validate a replacement listener before the old one is torn down. + ### Extension guidance Adding a new config source: diff --git a/docs/testing-abstractions.md b/docs/testing-abstractions.md new file mode 100644 index 000000000..d7f43e430 --- /dev/null +++ b/docs/testing-abstractions.md @@ -0,0 +1,171 @@ +# Testing abstractions guide + +This document describes the crate-wide testing abstractions available in the +`ironclaw::testing` module and when to use each one. + +Note: `ironclaw::testing` and all of its re-exports are test-only surfaces. 
+They are compiled only when `#[cfg(test)]` is active, so these symbols are +unavailable in non-test builds and will fail with unresolved import or +visibility errors if used from production code or library consumers. Use the +`ironclaw::testing` module and its re-exports only from tests or +`#[cfg(test)]`-gated helper crates. + +## Overview + +The testing module provides several complementary abstractions for different +testing scenarios: + +Table: Testing abstractions and recommended use cases + +| Abstraction | Purpose | Use when | +| ----------- | ------- | -------- | +| `TestHarnessBuilder` | Full integration testing with real database | Testing actual persistence with a real database | +| `CapturingStore` | Unit testing without database | Verifying interactions without a real database | +| `NullDatabase` | Baseline test double | Creating baseline test doubles or custom mocks | + +## Test harness builder (`TestHarnessBuilder`) + +Located in: `crate::testing::TestHarnessBuilder` + +The `TestHarnessBuilder` constructs a fully-wired `AgentDeps` with a real +libSQL-backed database (when the `libsql` feature is enabled). This is the +correct choice for integration-style tests that need to verify actual +persistence behaviour. + +```rust +use ironclaw::testing::TestHarnessBuilder; + +#[tokio::test] +async fn test_something() { + let harness = TestHarnessBuilder::new().build().await; + // use harness.deps, harness.db, etc. +} +``` + +**When to use:** Choose `TestHarnessBuilder` to verify actual database +persistence or to test components that require a real `Database` trait +implementation. + +**Do not mix with:** `CapturingStore`. The harness uses its own database +internally; mixing it with `CapturingStore` will cause confusing behaviour. + +## Capturing store (`CapturingStore`) + +Located in: `crate::testing::CapturingStore` + +`CapturingStore` is a decorator wrapper around `NullDatabase` that records all +status updates and events for later inspection. 
It implements the `Database` +trait and can be used anywhere a database is required. + +```rust +use std::sync::Arc; + +use ironclaw::testing::CapturingStore; + +#[tokio::test] +async fn captures_calls() { + let store = Arc::new(CapturingStore::new()); + // Pass Arc::clone(&store) to components that need a Database + // ... exercise the system under test ... + + // Later, inspect captured calls: + let _status = store.calls().last_status.lock().await.clone(); +} +``` + +**Related types:** + +- `StatusCall` / `StatusCallWithId` — Captured status update calls +- `EventCall` / `EventCallWithId` — Captured event calls with full history + +**When to use:** Choose `CapturingStore` for unit tests that must not hit a +real database but need to verify that persistence calls were made correctly. + +**Do not mix with:** The full `TestHarnessBuilder`. Use `CapturingStore` with +manually-constructed components, not the full harness. + +## Null database (`NullDatabase`) + +Located in: `crate::testing::NullDatabase` + +`NullDatabase` is a no-op database implementation that mostly returns empty +defaults (`Ok(None)`, `Ok(vec![])`, and similar) and serves as a baseline for +test doubles that need to override only specific methods. There are important +exceptions: `NullWorkspaceStore` document reads return +`NullDatabase::doc_not_found(...)`, which constructs the concrete +`WorkspaceError::DocumentNotFound` variant, and chunk insertion synthesizes +stable Universally Unique Identifiers (UUIDs) instead of returning a trivial +default. + +```rust +use ironclaw::testing::NullDatabase; + +fn example() { + let db = NullDatabase::new(); + // Most operations return empty defaults, but workspace reads return + // NullDatabase::doc_not_found(...) / WorkspaceError::DocumentNotFound, + // and insert_chunk synthesizes IDs. + let _ = db; +} +``` + +**When to use:** Use `NullDatabase` as a base for custom mocks that require +fine-grained control over specific database operations. 
+ +## Worker harness + +Located in: `crate::testing::worker_harness` + +The worker harness provides helpers for constructing `Worker` instances in +tests, including: + +- `make_worker()` — Build a Worker with the given tools +- `make_worker_with_capturing_store()` — Build a Worker with a CapturingStore +- `TerminalMethod` — Helper enum for driving terminal state transitions + +```rust +#[tokio::test] +async fn test_terminal_completed() -> anyhow::Result<()> { + use ironclaw::testing::worker_harness::{make_worker, TerminalMethod}; + + let worker = make_worker(vec![]).await?; + TerminalMethod::Completed.apply_transition(&worker).await?; + Ok(()) +} +``` + +**When to use:** Use the worker harness when testing `Worker` behaviour +specifically. + +## Choosing the right abstraction + +This flowchart guides maintainers to the right testing abstraction by first +checking whether the test needs real persistence, then whether it only needs +to inspect captured calls, and finally whether it needs a bespoke mock. 
+ +```mermaid +flowchart TD + start[Choose a testing abstraction] + persist{Need to test persistence?} + calls{Need to verify calls?} + mock{Writing a custom mock?} + harness[TestHarnessBuilder] + capturing[CapturingStore] + null_db[NullDatabase] + + start --> persist + persist -- Yes --> harness + persist -- No --> calls + calls -- Yes --> capturing + calls -- No --> mock + mock -- Yes --> null_db + mock -- No --> null_db +``` + +Figure: Choosing the right testing abstraction + +## Additional resources + +- `crate::testing::TestHarnessBuilder` — Full harness builder +- `crate::testing::null_db::{NullDatabase, CapturingStore, EventCall, + StatusCall}` — Database test doubles diff --git a/docs/webhook-server-design.md b/docs/webhook-server-design.md index f696c3bb5..a4112a134 100644 --- a/docs/webhook-server-design.md +++ b/docs/webhook-server-design.md @@ -142,7 +142,35 @@ This behaviour is directly exercised by the current tests in restart leaves the old listener serving traffic and restores the previous address in server state. -## 6. Relationship to hot reload +## 6. Listener-based lifecycle API + +The listener-based lifecycle methods, +`start_with_listener()` and `restart_with_listener()`, extend the original +address-driven API without changing the server's route-ownership model. + +They exist for two concrete call patterns: + +- hot-reload flows that want to validate a replacement listener before the old + one is shut down, and +- integration tests that need OS-selected ports or pre-bound sockets. 
+ +The contract is: + +- the caller pre-binds a `tokio::net::TcpListener` and transfers ownership + into `WebhookServer`; +- the server updates `config.addr` from `listener.local_addr()` so subsequent + status and restart logic sees the real active bind address; +- first start still merges queued routes and stores the result in + `merged_router`; and +- subsequent listener-based restarts reuse `merged_router` rather than asking + channels to rebuild their route fragments. + +That makes the listener-based API an internal lifecycle extension, not a new +route-registration model. Callers should still finish route setup before the +first start and should still expect async startup and +`ChannelError::StartupFailed` on listener or server boot failures. + +## 7. Relationship to hot reload The webhook server and the SIGHUP handler in `src/main.rs` have different responsibilities and should be understood separately. @@ -175,7 +203,7 @@ That distinction matters when debugging incidents: - if the question is “why did the runtime try to restart at all?”, the answer lives in `main.rs`. -## 7. Current trade-offs +## 8. Current trade-offs The present design is pragmatic, but it comes with trade-offs. @@ -192,7 +220,7 @@ The present design is pragmatic, but it comes with trade-offs. None of those trade-offs are inherently wrong for the current system. They are simply the shape maintainers need to preserve or revise deliberately. -## 8. Maintainer guidance +## 9. Maintainer guidance When changing webhook behaviour, treat these as the current invariants: @@ -210,7 +238,11 @@ for: - failed rebind with old-listener rollback; and - clean shutdown after the server has been restarted. -## 9. References +Internal API note: reload paths that can pre-bind a replacement socket should +prefer `restart_with_listener()` over `restart_with_addr()` so bind failures +surface before the old listener is torn down. + +## 10. 
References - `src/channels/webhook_server.rs` - `src/main.rs` diff --git a/proptest-regressions/agent/dispatcher/delegate/llm_hooks.txt b/proptest-regressions/agent/dispatcher/delegate/llm_hooks.txt new file mode 100644 index 000000000..4112f9c23 --- /dev/null +++ b/proptest-regressions/agent/dispatcher/delegate/llm_hooks.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc dbc84d471c9b38875af1eab10d8a1fa3a9048d0b6f6b31faaeb0e015472ed381 # shrinks to messages = [ChatMessage { role: Assistant, content: "", content_parts: [], tool_call_id: None, name: None, tool_calls: None }, ChatMessage { role: System, content: "", content_parts: [], tool_call_id: None, name: None, tool_calls: None }] diff --git a/src/agent/dispatcher/delegate/llm_hooks.rs b/src/agent/dispatcher/delegate/llm_hooks.rs new file mode 100644 index 000000000..c86efc131 --- /dev/null +++ b/src/agent/dispatcher/delegate/llm_hooks.rs @@ -0,0 +1,451 @@ +//! LLM hook implementations for the chat delegate. +//! +//! Contains the LLM call hooks (check_signals, before_llm_call, call_llm, +//! handle_text_response) and helper functions for message compaction and +//! response sanitization. + +use crate::agent::agentic_loop::{LoopOutcome, LoopSignal, TextAction}; +use crate::agent::cost_guard::CostGuard; +use crate::agent::dispatcher::delegate::ChatDelegate; +use crate::agent::session::ThreadState; +use crate::channels::StatusUpdate; +use crate::error::Error; +use crate::history::LlmCallRecord; +use crate::llm::{ChatMessage, Reasoning, ReasoningContext}; + +/// Check if the agent loop should stop due to external signals. 
+pub(crate) async fn check_signals(delegate: &ChatDelegate<'_>) -> LoopSignal { + let sess = delegate.session.lock().await; + if let Some(thread) = sess.threads.get(&delegate.thread_id) + && thread.state == ThreadState::Interrupted + { + return LoopSignal::Stop; + } + LoopSignal::Continue +} + +/// Prepare context before calling the LLM. +pub(crate) async fn before_llm_call( + delegate: &ChatDelegate<'_>, + reason_ctx: &mut ReasoningContext, + iteration: usize, +) -> Option { + // Inject a nudge message when approaching the iteration limit so the + // LLM is aware it should produce a final answer on the next turn. + if iteration == delegate.nudge_at { + reason_ctx.messages.push(ChatMessage::system( + "You are approaching the tool call limit. \ + Provide your best final answer on the next response \ + using the information you have gathered so far. \ + Do not call any more tools.", + )); + } + + let force_text = iteration >= delegate.force_text_at; + + // Refresh tool definitions each iteration so newly built tools become visible + let tool_defs = delegate.agent.tools().tool_definitions().await; + + // Apply trust-based tool attenuation if skills are active. 
+ let tool_defs = if !delegate.active_skills.is_empty() { + let result = crate::skills::attenuate_tools(&tool_defs, &delegate.active_skills); + tracing::debug!( + min_trust = %result.min_trust, + tools_available = result.tools.len(), + tools_removed = result.removed_tools.len(), + removed = ?result.removed_tools, + explanation = %result.explanation, + "Tool attenuation applied" + ); + result.tools + } else { + tool_defs + }; + + // Update context for this iteration + reason_ctx.available_tools = if force_text { Vec::new() } else { tool_defs }; + reason_ctx.system_prompt = if force_text { + Some(delegate.cached_prompt_no_tools.clone()) + } else { + None + }; + reason_ctx.force_text = force_text; + + if force_text { + tracing::info!( + iteration, + "Forcing text-only response (iteration limit reached)" + ); + } + + let _ = delegate + .agent + .channels + .send_status( + &delegate.message.channel, + StatusUpdate::Thinking("Calling LLM...".into()), + &delegate.message.metadata, + ) + .await; + + None +} + +/// Call the LLM and handle context-length-exceeded errors. 
+pub(crate) async fn call_llm( + delegate: &ChatDelegate<'_>, + reasoning: &Reasoning, + reason_ctx: &mut ReasoningContext, + iteration: usize, +) -> Result { + check_cost_guardrail(delegate.agent.cost_guard()).await?; + let output = invoke_with_retry(delegate, reasoning, reason_ctx, iteration).await?; + record_and_log_cost(delegate, &output).await; + Ok(output) +} + +async fn check_cost_guardrail(cost_guard: &CostGuard) -> Result<(), Error> { + if let Err(limit) = cost_guard.check_allowed().await { + return Err(crate::error::LlmError::InvalidResponse { + provider: "agent".to_string(), + reason: limit.to_string(), + } + .into()); + } + Ok(()) +} + +async fn invoke_with_retry( + delegate: &ChatDelegate<'_>, + reasoning: &Reasoning, + reason_ctx: &mut ReasoningContext, + iteration: usize, +) -> Result { + match reasoning.respond_with_tools(reason_ctx).await { + Ok(output) => Ok(output), + Err(crate::error::LlmError::ContextLengthExceeded { used, limit }) => { + tracing::warn!( + used, + limit, + iteration, + "Context length exceeded, compacting messages and retrying" + ); + record_partial_llm_call(delegate, u32::try_from(used).unwrap_or(u32::MAX)).await; + reason_ctx.messages = compact_messages_for_retry(&reason_ctx.messages); + if reason_ctx.force_text { + reason_ctx.available_tools.clear(); + } + check_cost_guardrail(delegate.agent.cost_guard()).await?; + match reasoning.respond_with_tools(reason_ctx).await { + Ok(output) => Ok(output), + Err(retry_err) => { + if let crate::error::LlmError::ContextLengthExceeded { + used: retry_used, .. 
+ } = retry_err + { + record_partial_llm_call( + delegate, + u32::try_from(retry_used).unwrap_or(u32::MAX), + ) + .await; + } + tracing::error!( + original_used = used, + original_limit = limit, + retry_error = %retry_err, + "Retry after auto-compaction also failed" + ); + Err(crate::error::Error::from(retry_err)) + } + } + } + Err(e) => Err(e.into()), + } +} + +async fn record_and_log_cost(delegate: &ChatDelegate<'_>, output: &crate::llm::RespondOutput) { + let model_name = delegate.agent.llm().active_model_name(); + let read_discount = delegate.agent.llm().cache_read_discount(); + let write_multiplier = delegate.agent.llm().cache_write_multiplier(); + let call_cost = delegate + .agent + .cost_guard() + .record_llm_call( + &model_name, + output.usage.input_tokens, + output.usage.output_tokens, + output.usage.cache_read_input_tokens, + output.usage.cache_creation_input_tokens, + read_discount, + write_multiplier, + Some(delegate.agent.llm().cost_per_token()), + ) + .await; + tracing::debug!( + "LLM call used {} input + {} output tokens (${:.6})", + output.usage.input_tokens, + output.usage.output_tokens, + call_cost, + ); +} + +async fn record_partial_llm_call(delegate: &ChatDelegate<'_>, used: u32) { + let model_name = delegate.agent.llm().active_model_name(); + let read_discount = delegate.agent.llm().cache_read_discount(); + let write_multiplier = delegate.agent.llm().cache_write_multiplier(); + let call_cost = delegate + .agent + .cost_guard() + .record_llm_call( + &model_name, + used, + 0, + 0, + 0, + read_discount, + write_multiplier, + Some(delegate.agent.llm().cost_per_token()), + ) + .await; + + let Some(store) = delegate.agent.store() else { + return; + }; + + let purpose = + "context_length_exceeded:auto_compaction_retry (partial/estimated input tokens only)"; + let record = LlmCallRecord { + job_id: Some(delegate.job_ctx.job_id), + conversation_id: delegate.job_ctx.conversation_id, + provider: "agent", + model: &model_name, + input_tokens: used, + 
output_tokens: 0, + cost: call_cost, + purpose: Some(purpose), + }; + + if let Err(error) = store.record_llm_call(&record).await { + tracing::warn!(%error, "Failed to persist partial LLM call audit entry"); + } +} + +/// Handle a text response from the LLM. +pub(crate) async fn handle_text_response(_delegate: &ChatDelegate<'_>, text: &str) -> TextAction { + // Strip internal "[Called tool ...]" text that can leak when + // provider flattening (e.g. NEAR AI) converts tool_calls to + // plain text and the LLM echoes it back. + let sanitized = strip_internal_tool_call_text(text); + TextAction::Return(LoopOutcome::Response(sanitized)) +} + +/// Collect all System messages from the slice. +fn collect_system_messages(messages: &[ChatMessage]) -> Vec { + use crate::llm::Role; + messages + .iter() + .filter(|m| m.role == Role::System) + .cloned() + .collect() +} + +/// Compact messages when a User message is present. +fn compact_around_user_message(messages: &[ChatMessage], user_idx: usize) -> Vec { + let mut compacted = collect_system_messages(&messages[..user_idx]); + + if user_idx > 0 { + compacted.push(ChatMessage::system( + "[Note: Earlier conversation history was automatically compacted \ + to fit within the context window. The most recent exchange is preserved below.]", + )); + } + + compacted.extend_from_slice(&messages[user_idx..]); + compacted +} + +/// Compact messages when no User message exists (edge case). 
+fn compact_without_user_message(messages: &[ChatMessage]) -> Vec { + use crate::llm::Role; + let non_system_indices: Vec<_> = messages + .iter() + .enumerate() + .filter_map(|(idx, message)| (message.role != Role::System).then_some(idx)) + .collect(); + let keep = if non_system_indices.len() >= 2 { 2 } else { 1 }; + let retained_non_system: std::collections::HashSet<_> = + non_system_indices.into_iter().rev().take(keep).collect(); + messages + .iter() + .enumerate() + .filter(|(idx, message)| message.role == Role::System || retained_non_system.contains(idx)) + .map(|(_, message)| message.clone()) + .collect() +} + +/// Compact messages for retry after a context-length-exceeded error. +/// +/// Keeps all `System` messages (which carry the system prompt and instructions), +/// finds the last `User` message, and retains it plus every subsequent message +/// (the current turn's assistant tool calls and tool results). A short note is +/// inserted so the LLM knows earlier history was dropped. +pub(crate) fn compact_messages_for_retry(messages: &[ChatMessage]) -> Vec { + use crate::llm::Role; + match messages.iter().rposition(|m| m.role == Role::User) { + Some(idx) => compact_around_user_message(messages, idx), + None => compact_without_user_message(messages), + } +} + +/// Strip internal `[Called tool ...]` and `[Tool ... returned: ...]` markers +/// from a response string. These markers are inserted by provider-level message +/// flattening (e.g. NEAR AI) and can leak into the user-visible response when +/// the LLM echoes them back. +pub(crate) fn strip_internal_tool_call_text(text: &str) -> String { + // Remove lines that are purely internal tool-call markers. 
+ // Pattern: lines matching `[Called tool (...)]` or `[Tool returned: ...]` + let result = text + .lines() + .filter(|line| { + let trimmed = line.trim(); + !((trimmed.starts_with("[Called tool ") && trimmed.ends_with(']')) + || (trimmed.starts_with("[Tool ") + && trimmed.contains(" returned:") + && trimmed.ends_with(']'))) + }) + .fold(String::new(), |mut acc, s| { + if !acc.is_empty() { + acc.push('\n'); + } + acc.push_str(s); + acc + }); + + let result = result.trim(); + if result.is_empty() { + "I wasn't able to complete that request. Could you try rephrasing or providing more details?".to_string() + } else { + result.to_string() + } +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use super::*; + use crate::llm::Role; + + const COMPACTION_NOTE: &str = concat!( + "[Note: Earlier conversation history was automatically compacted ", + "to fit within the context window. The most recent exchange is preserved below.]" + ); + + fn message(role: Role, content: String) -> ChatMessage { + ChatMessage { + role, + content, + content_parts: Vec::new(), + tool_call_id: None, + name: None, + tool_calls: None, + } + } + + fn message_fingerprint(message: &ChatMessage) -> (Role, &str) { + (message.role, message.content.as_str()) + } + + fn generated_message_strategy() -> impl Strategy { + ( + prop_oneof![ + Just(Role::System), + Just(Role::User), + Just(Role::Assistant), + Just(Role::Tool), + ], + any::(), + ) + .prop_map(|(role, content)| message(role, content)) + } + + proptest! 
{ + #[test] + fn compact_messages_for_retry_preserves_compaction_invariants( + messages in prop::collection::vec(generated_message_strategy(), 0..32) + ) { + let compacted = compact_messages_for_retry(&messages); + let compacted_without_note: Vec<_> = compacted + .iter() + .filter(|message| message.role != Role::System || message.content != COMPACTION_NOTE) + .collect(); + + let mut next_idx = 0usize; + for compacted_message in &compacted_without_note { + let fingerprint = message_fingerprint(compacted_message); + let matched_idx = messages[next_idx..] + .iter() + .position(|original| message_fingerprint(original) == fingerprint) + .map(|offset| next_idx + offset); + prop_assert!( + matched_idx.is_some(), + "compacted message {:?} should appear in original input after index {}", + fingerprint, + next_idx + ); + next_idx = matched_idx.expect("position checked above") + 1; + } + + if let Some(user_idx) = messages.iter().rposition(|message| message.role == Role::User) { + let expected_suffix: Vec<_> = messages[user_idx..] 
+ .iter() + .map(message_fingerprint) + .collect(); + let actual_suffix: Vec<_> = compacted_without_note + .iter() + .rev() + .take(expected_suffix.len()) + .copied() + .collect::>() + .into_iter() + .rev() + .map(message_fingerprint) + .collect(); + prop_assert_eq!(actual_suffix, expected_suffix); + } + + for system_message in messages.iter().filter(|message| message.role == Role::System) { + let original_count = messages + .iter() + .filter(|message| message_fingerprint(message) == message_fingerprint(system_message)) + .count(); + let compacted_count = compacted + .iter() + .filter(|message| message_fingerprint(message) == message_fingerprint(system_message)) + .count(); + prop_assert!( + compacted_count >= original_count, + "expected all system messages to remain present: {:?}", + message_fingerprint(system_message) + ); + } + + let note_count = compacted + .iter() + .filter(|message| message.role == Role::System && message.content == COMPACTION_NOTE) + .count(); + let truncation_occurred = messages + .iter() + .rposition(|message| message.role == Role::User) + .is_some_and(|user_idx| user_idx > 0); + + prop_assert!(note_count <= 1, "compaction note inserted more than once"); + if note_count == 1 { + prop_assert!( + truncation_occurred, + "compaction note should only appear when history before the preserved suffix was truncated" + ); + } + } + } +} diff --git a/src/agent/dispatcher/delegate/mod.rs b/src/agent/dispatcher/delegate/mod.rs new file mode 100644 index 000000000..f8da085f9 --- /dev/null +++ b/src/agent/dispatcher/delegate/mod.rs @@ -0,0 +1,91 @@ +//! Chat delegate implementation for the agentic loop. +//! +//! Contains the `ChatDelegate` struct and its implementation of `NativeLoopDelegate`, +//! which customizes the shared agentic loop for interactive chat sessions. +//! +//! This module is split into child submodules by responsibility: +//! - `llm_hooks`: LLM call hooks and helper functions +//! 
- `tool_exec`: Tool execution logic and helpers + +mod llm_hooks; +mod tool_exec; + +use std::sync::Arc; + +use tokio::sync::Mutex; +use uuid::Uuid; + +use crate::agent::Agent; +use crate::agent::agentic_loop::{LoopOutcome, LoopSignal, NativeLoopDelegate, TextAction}; +use crate::agent::session::Session; +use crate::channels::IncomingMessage; +use crate::context::JobContext; +use crate::error::Error; +use crate::llm::{Reasoning, ReasoningContext}; + +// Re-export items used by other modules in the crate. +#[cfg(test)] +pub(crate) use llm_hooks::{compact_messages_for_retry, strip_internal_tool_call_text}; +pub(crate) use tool_exec::{ + ToolCallSpec, check_auth_required, execute_chat_tool_standalone, parse_auth_result, +}; + +/// Delegate for the chat (dispatcher) context. +/// +/// Implements `LoopDelegate` to customize the shared agentic loop for +/// interactive chat sessions with the full 3-phase tool execution +/// (preflight → parallel exec → post-flight), approval flow, hooks, +/// auth intercept, and cost tracking. 
+pub(super) struct ChatDelegate<'a> { + pub(super) agent: &'a Agent, + pub(super) session: Arc>, + pub(super) thread_id: Uuid, + pub(super) message: &'a IncomingMessage, + pub(super) job_ctx: JobContext, + pub(super) active_skills: Vec, + pub(super) cached_prompt: String, + pub(super) cached_prompt_no_tools: String, + pub(super) nudge_at: usize, + pub(super) force_text_at: usize, + pub(super) user_tz: chrono_tz::Tz, +} + +impl<'a> NativeLoopDelegate for ChatDelegate<'a> { + async fn check_signals(&self) -> LoopSignal { + llm_hooks::check_signals(self).await + } + + async fn before_llm_call( + &self, + reason_ctx: &mut ReasoningContext, + iteration: usize, + ) -> Option { + llm_hooks::before_llm_call(self, reason_ctx, iteration).await + } + + async fn call_llm( + &self, + reasoning: &Reasoning, + reason_ctx: &mut ReasoningContext, + iteration: usize, + ) -> Result { + llm_hooks::call_llm(self, reasoning, reason_ctx, iteration).await + } + + async fn handle_text_response( + &self, + text: &str, + _reason_ctx: &mut ReasoningContext, + ) -> TextAction { + llm_hooks::handle_text_response(self, text).await + } + + async fn execute_tool_calls( + &self, + tool_calls: Vec, + content: Option, + reason_ctx: &mut ReasoningContext, + ) -> Result, Error> { + tool_exec::execute_tool_calls(self, tool_calls, content, reason_ctx).await + } +} diff --git a/src/agent/dispatcher/delegate/tool_exec/execution.rs b/src/agent/dispatcher/delegate/tool_exec/execution.rs new file mode 100644 index 000000000..4d721fbb4 --- /dev/null +++ b/src/agent/dispatcher/delegate/tool_exec/execution.rs @@ -0,0 +1,231 @@ +//! Execution stage for chat tool execution. +//! +//! Runs the preflight-approved tool calls, batches them where safe, and +//! captures raw results for the later postflight phase to interpret. 
+ +use tokio::task::JoinSet; + +use crate::agent::dispatcher::delegate::ChatDelegate; +use crate::channels::StatusUpdate; +use crate::context::JobContext; +use crate::error::Error; +use crate::safety::SafetyLayer; +use crate::tools::ToolRegistry; + +use super::postflight::check_auth_required; + +/// Allocate the exec-results buffer and dispatch Phase 2 tool execution. +pub(super) async fn run_phase2( + delegate: &ChatDelegate<'_>, + preflight_len: usize, + runnable: &[(usize, crate::llm::ToolCall)], +) -> Vec>> { + let mut exec_results: Vec>> = + (0..preflight_len).map(|_| None).collect(); + let mut start = 0; + while start < runnable.len() { + if is_auth_barrier_tool(&runnable[start].1.name) { + let batch = &runnable[start..=start]; + run_tool_batch_inline(delegate, batch, &mut exec_results).await; + if let Some(result) = &exec_results[runnable[start].0] + && check_auth_required(&runnable[start].1.name, result).is_some() + { + break; + } + start += 1; + continue; + } + + let mut end = start; + while end < runnable.len() && !is_auth_barrier_tool(&runnable[end].1.name) { + end += 1; + } + + let batch = &runnable[start..end]; + if batch.len() <= 1 { + run_tool_batch_inline(delegate, batch, &mut exec_results).await; + } else { + run_tool_batch_parallel(delegate, batch, &mut exec_results).await; + } + start = end; + } + exec_results +} + +pub(super) fn is_auth_barrier_tool(tool_name: &str) -> bool { + matches!(tool_name, "tool_auth" | "tool_activate") +} + +/// Run a batch of tools inline (sequential execution for small batches). +pub(super) async fn run_tool_batch_inline( + delegate: &ChatDelegate<'_>, + runnable: &[(usize, crate::llm::ToolCall)], + exec_results: &mut [Option>], +) { + for (pf_idx, tc) in runnable { + let result = execute_one_tool(delegate, tc).await; + exec_results[*pf_idx] = Some(result); + } +} + +/// Run a batch of tools in parallel (for large batches). 
+pub(super) async fn run_tool_batch_parallel( + delegate: &ChatDelegate<'_>, + runnable: &[(usize, crate::llm::ToolCall)], + exec_results: &mut [Option>], +) { + let mut join_set = JoinSet::new(); + + for (pf_idx, tc) in runnable { + let pf_idx = *pf_idx; + let tools = delegate.agent.tools().clone(); + let safety = delegate.agent.safety().clone(); + let channels = delegate.agent.channels.clone(); + let job_ctx = delegate.job_ctx.clone(); + let tc = tc.clone(); + let channel = delegate.message.channel.clone(); + let metadata = delegate.message.metadata.clone(); + + join_set.spawn(async move { + let _ = channels + .send_status( + &channel, + StatusUpdate::ToolStarted { + name: tc.name.clone(), + }, + &metadata, + ) + .await; + + let result = execute_chat_tool_standalone( + &tools, + &safety, + &ToolCallSpec { + name: &tc.name, + params: &tc.arguments, + }, + &job_ctx, + ) + .await; + + let par_tool = tools.get(&tc.name).await; + let _ = channels + .send_status( + &channel, + StatusUpdate::tool_completed( + tc.name.clone(), + &result, + &tc.arguments, + par_tool.as_deref(), + ), + &metadata, + ) + .await; + + (pf_idx, result) + }); + } + + while let Some(join_result) = join_set.join_next().await { + match join_result { + Ok((pf_idx, result)) => { + exec_results[pf_idx] = Some(result); + } + Err(e) => { + if e.is_panic() { + tracing::error!("Chat tool execution task panicked: {}", e); + } else { + tracing::error!("Chat tool execution task cancelled: {}", e); + } + } + } + } + + for (pf_idx, tc) in runnable.iter() { + if exec_results[*pf_idx].is_none() { + tracing::error!( + tool = %tc.name, + "Filling failed task slot with error" + ); + exec_results[*pf_idx] = Some(Err(crate::error::ToolError::ExecutionFailed { + name: tc.name.clone(), + reason: "Task failed during execution".to_string(), + } + .into())); + } + } +} + +/// Execute a single tool inline (for small batches). 
+pub(super) async fn execute_one_tool( + delegate: &ChatDelegate<'_>, + tc: &crate::llm::ToolCall, +) -> Result { + send_tool_started(delegate, &tc.name).await; + let result = delegate + .agent + .execute_chat_tool(&tc.name, &tc.arguments, &delegate.job_ctx) + .await; + send_tool_completed(delegate, &tc.name, &result, &tc.arguments).await; + result +} + +/// Send ToolStarted status update. +async fn send_tool_started(delegate: &ChatDelegate<'_>, tool_name: &str) { + let _ = delegate + .agent + .channels + .send_status( + &delegate.message.channel, + StatusUpdate::ToolStarted { + name: tool_name.to_string(), + }, + &delegate.message.metadata, + ) + .await; +} + +/// Send tool_completed status update. +async fn send_tool_completed( + delegate: &ChatDelegate<'_>, + tool_name: &str, + result: &Result, + arguments: &serde_json::Value, +) { + let disp_tool = delegate.agent.tools().get(tool_name).await; + let _ = delegate + .agent + .channels + .send_status( + &delegate.message.channel, + StatusUpdate::tool_completed( + tool_name.to_string(), + result, + arguments, + disp_tool.as_deref(), + ), + &delegate.message.metadata, + ) + .await; +} + +/// Specification for a tool call to be executed. +pub(crate) struct ToolCallSpec<'a> { + pub(crate) name: &'a str, + pub(crate) params: &'a serde_json::Value, +} + +/// Execute a chat tool without requiring `&Agent`. +/// +/// This standalone function enables parallel invocation from spawned JoinSet +/// tasks, which cannot borrow `&self`. Delegates to the shared +/// `execute_tool_with_safety` pipeline. 
+pub(crate) async fn execute_chat_tool_standalone( + tools: &ToolRegistry, + safety: &SafetyLayer, + spec: &ToolCallSpec<'_>, + job_ctx: &JobContext, +) -> Result { + crate::tools::execute::execute_tool_with_safety(tools, safety, spec.name, spec.params, job_ctx) + .await +} diff --git a/src/agent/dispatcher/delegate/tool_exec/mod.rs b/src/agent/dispatcher/delegate/tool_exec/mod.rs new file mode 100644 index 000000000..72f8cdc6d --- /dev/null +++ b/src/agent/dispatcher/delegate/tool_exec/mod.rs @@ -0,0 +1,113 @@ +//! Tool execution logic for the chat delegate. +//! +//! Splits the 3-phase tool execution pipeline into cohesive submodules: +//! preflight, execution, postflight, and recording. + +pub mod execution; +pub mod postflight; +pub mod preflight; +pub mod recording; + +use uuid::Uuid; + +use crate::agent::dispatcher::delegate::ChatDelegate; +use crate::agent::session::PendingApproval; +use crate::channels::StatusUpdate; +use crate::error::Error; +use crate::llm::{ChatMessage, ReasoningContext}; + +pub(crate) use execution::ToolCallSpec; +pub(crate) use execution::execute_chat_tool_standalone; +pub(crate) use postflight::{check_auth_required, parse_auth_result}; + +fn build_pending_approval( + delegate: &ChatDelegate<'_>, + candidate: preflight::ApprovalCandidate, + tool_calls: &[crate::llm::ToolCall], + reason_ctx: &ReasoningContext, +) -> PendingApproval { + let display_params = crate::tools::redact_params( + &candidate.tool_call.arguments, + candidate.tool.sensitive_params(), + ); + PendingApproval { + request_id: Uuid::new_v4(), + tool_name: candidate.tool_call.name.clone(), + parameters: candidate.tool_call.arguments.clone(), + display_parameters: display_params, + description: candidate.tool.description().to_string(), + tool_call_id: candidate.tool_call.id.clone(), + context_messages: reason_ctx.messages.clone(), + deferred_tool_calls: tool_calls[candidate.idx + 1..].to_vec(), + user_timezone: Some(delegate.user_tz.name().to_string()), + } +} + +fn 
finalized_tool_calls( + original_tool_calls: &[crate::llm::ToolCall], + preflight: &[(crate::llm::ToolCall, preflight::PreflightOutcome)], + approval_needed: Option<&preflight::ApprovalCandidate>, +) -> Vec { + let mut finalized = preflight + .iter() + .map(|(tc, _)| tc.clone()) + .collect::>(); + if let Some(candidate) = approval_needed { + finalized.push(candidate.tool_call.clone()); + finalized.extend_from_slice(&original_tool_calls[candidate.idx + 1..]); + } + finalized +} + +/// Execute tool calls with 3-phase pipeline (preflight → execution → post-flight). +pub(crate) async fn execute_tool_calls( + delegate: &ChatDelegate<'_>, + tool_calls: Vec, + content: Option, + reason_ctx: &mut ReasoningContext, +) -> Result, Error> { + use crate::agent::agentic_loop::LoopOutcome; + + let (batch, approval_needed) = preflight::group_tool_calls(delegate, &tool_calls).await?; + let preflight::ToolBatch { + preflight, + runnable, + } = batch; + let finalized_tool_calls = + finalized_tool_calls(&tool_calls, &preflight, approval_needed.as_ref()); + + reason_ctx + .messages + .push(ChatMessage::assistant_with_tool_calls( + content, + finalized_tool_calls.clone(), + )); + + let _ = delegate + .agent + .channels + .send_status( + &delegate.message.channel, + StatusUpdate::Thinking(format!("Executing {} tool(s)...", tool_calls.len())), + &delegate.message.metadata, + ) + .await; + + recording::record_redacted_tool_calls(delegate, &finalized_tool_calls).await; + + let mut exec_results = execution::run_phase2(delegate, preflight.len(), &runnable).await; + let deferred_auth = + postflight::run_postflight(delegate, preflight, &mut exec_results, reason_ctx).await; + + if let Some(instructions) = deferred_auth { + return Ok(Some(LoopOutcome::Response(instructions))); + } + + if let Some(candidate) = approval_needed { + let pending = + build_pending_approval(delegate, candidate, &finalized_tool_calls, reason_ctx); + return Ok(Some(LoopOutcome::NeedApproval(Box::new(pending)))); + } + + 
Ok(None) +} diff --git a/src/agent/dispatcher/delegate/tool_exec/postflight.rs b/src/agent/dispatcher/delegate/tool_exec/postflight.rs new file mode 100644 index 000000000..661badfa9 --- /dev/null +++ b/src/agent/dispatcher/delegate/tool_exec/postflight.rs @@ -0,0 +1,373 @@ +//! Postflight stage for chat tool execution. +//! +//! Interprets tool results, emits auth and image side effects, and folds each +//! indexed outcome back into both thread history and the reasoning context. + +use crate::agent::dispatcher::delegate::ChatDelegate; +use crate::channels::StatusUpdate; +use crate::error::Error; +use crate::llm::{ChatMessage, ReasoningContext}; + +use super::execution::is_auth_barrier_tool; +use super::recording::record_tool_outcome; + +/// Parsed auth result fields for emitting StatusUpdate::AuthRequired. +pub(crate) struct ParsedAuthData { + pub(crate) auth_url: Option, + pub(crate) setup_url: Option, +} + +/// Parsed auth result fields for emitting StatusUpdate::AuthRequired. +pub(crate) struct AuthBarrierData { + pub(crate) extension_name: String, + pub(crate) instructions: String, + pub(crate) auth_url: Option, + pub(crate) setup_url: Option, +} + +pub(super) struct ToolCtx<'a> { + pub(super) pf_idx: usize, + pub(super) tc: &'a crate::llm::ToolCall, +} + +/// Parse auth-barrier details from a tool_auth/tool_activate result. 
+pub(crate) fn parse_auth_barrier( + tool_name: &str, + result: &Result, +) -> Option { + if !is_auth_barrier_tool(tool_name) { + return None; + } + let output = result.as_ref().ok()?; + let parsed: serde_json::Value = serde_json::from_str(output).ok()?; + if parsed.get("awaiting_token") != Some(&serde_json::Value::Bool(true)) { + return None; + } + let extension_name = parsed.get("name")?.as_str()?.to_string(); + let instructions = parsed + .get("instructions") + .and_then(|v| v.as_str()) + .unwrap_or("Please provide your API token/key.") + .to_string(); + let auth_url = parsed + .get("auth_url") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + let setup_url = parsed + .get("setup_url") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + Some(AuthBarrierData { + extension_name, + instructions, + auth_url, + setup_url, + }) +} + +pub(crate) fn parse_auth_result(tool_name: &str, result: &Result) -> ParsedAuthData { + let auth_barrier = parse_auth_barrier(tool_name, result); + ParsedAuthData { + auth_url: auth_barrier.as_ref().and_then(|data| data.auth_url.clone()), + setup_url: auth_barrier.and_then(|data| data.setup_url), + } +} + +pub(crate) fn check_auth_required( + tool_name: &str, + result: &Result, +) -> Option<(String, String)> { + let auth_barrier = parse_auth_barrier(tool_name, result)?; + Some((auth_barrier.extension_name, auth_barrier.instructions)) +} + +async fn handle_auth_barrier( + delegate: &ChatDelegate<'_>, + tc: &crate::llm::ToolCall, + tool_result: &Result, +) -> Option { + let auth_barrier = parse_auth_barrier(&tc.name, tool_result)?; + { + let mut sess = delegate.session.lock().await; + if let Some(thread) = sess.threads.get_mut(&delegate.thread_id) { + thread.enter_auth_mode(auth_barrier.extension_name.clone()); + } + } + let _ = delegate + .agent + .channels + .send_status( + &delegate.message.channel, + StatusUpdate::AuthRequired { + extension_name: auth_barrier.extension_name, + instructions: 
Some(auth_barrier.instructions.clone()), + auth_url: auth_barrier.auth_url, + setup_url: auth_barrier.setup_url, + }, + &delegate.message.metadata, + ) + .await; + Some(auth_barrier.instructions) +} + +/// Phase 3: iterate preflight outcomes in original order, dispatching each +/// to `handle_rejected_tool` or `process_runnable_tool`. +/// Returns the first deferred-auth instruction string, if any. +pub(super) async fn run_postflight( + delegate: &ChatDelegate<'_>, + preflight: Vec<(crate::llm::ToolCall, super::preflight::PreflightOutcome)>, + exec_results: &mut [Option>], + reason_ctx: &mut ReasoningContext, +) -> Option { + let mut deferred_auth: Option = None; + for (pf_idx, (tc, outcome)) in preflight.into_iter().enumerate() { + match outcome { + super::preflight::PreflightOutcome::Rejected(error_msg) => { + handle_rejected_tool( + delegate, + ToolCtx { pf_idx, tc: &tc }, + &error_msg, + reason_ctx, + ) + .await; + } + super::preflight::PreflightOutcome::Runnable => { + let tool_result = exec_results + .get_mut(pf_idx) + .and_then(Option::take) + .unwrap_or_else(|| { + Err(crate::error::ToolError::ExecutionFailed { + name: tc.name.clone(), + reason: "No result available".to_string(), + } + .into()) + }); + if let Some(instructions) = + process_runnable_tool(delegate, pf_idx, &tc, tool_result, reason_ctx).await + { + deferred_auth = Some(instructions); + break; + } + } + } + } + deferred_auth +} + +/// Handle rejected tool call outcome. +pub(super) async fn handle_rejected_tool( + delegate: &ChatDelegate<'_>, + tool: ToolCtx<'_>, + error_msg: &str, + reason_ctx: &mut ReasoningContext, +) { + record_tool_outcome(delegate, tool.pf_idx, error_msg, true).await; + reason_ctx.messages.push(ChatMessage::tool_result( + &tool.tc.id, + &tool.tc.name, + error_msg.to_string(), + )); +} + +/// Process post-flight for a single runnable tool. 
+pub(super) async fn process_runnable_tool( + delegate: &ChatDelegate<'_>, + pf_idx: usize, + tc: &crate::llm::ToolCall, + tool_result: Result, + reason_ctx: &mut ReasoningContext, +) -> Option { + use crate::agent::dispatcher::{PREVIEW_MAX_CHARS, truncate_for_preview}; + + let is_tool_error = tool_result.is_err(); + + let output = match &tool_result { + Ok(output) => output, + Err(e) => { + let error_msg = format!("Tool '{}' failed: {}", tc.name, e); + let (preview_text, wrapped_text) = sanitize_output(delegate, &tc.name, &error_msg); + fold_into_context( + delegate, + ToolCtx { pf_idx, tc }, + ToolOutcome { + result_content: wrapped_text, + is_tool_error: true, + }, + reason_ctx, + ) + .await; + let preview = truncate_for_preview(&preview_text, PREVIEW_MAX_CHARS); + if !preview.is_empty() { + let _ = delegate + .agent + .channels + .send_status( + &delegate.message.channel, + StatusUpdate::ToolResult { + name: tc.name.clone(), + preview, + }, + &delegate.message.metadata, + ) + .await; + } + return None; + } + }; + + let is_image_sentinel = maybe_emit_image_sentinel(delegate, &tc.name, output).await; + let image_sentinel_summary = image_sentinel_summary(output); + + let (result_content, preview) = if is_image_sentinel { + let summary = image_sentinel_summary.unwrap_or_else(|| "[Image generated]".to_string()); + let (preview_text, wrapped_text) = sanitize_output(delegate, &tc.name, &summary); + let preview = truncate_for_preview(&preview_text, PREVIEW_MAX_CHARS); + (wrapped_text, preview) + } else { + let (preview_text, wrapped_text) = sanitize_output(delegate, &tc.name, output); + let preview = truncate_for_preview(&preview_text, PREVIEW_MAX_CHARS); + (wrapped_text, preview) + }; + + if !is_image_sentinel && !preview.is_empty() { + let _ = delegate + .agent + .channels + .send_status( + &delegate.message.channel, + StatusUpdate::ToolResult { + name: tc.name.clone(), + preview, + }, + &delegate.message.metadata, + ) + .await; + } + + let auth_instructions = 
handle_auth_barrier(delegate, tc, &tool_result).await; + + // Stash raw `output` by `tc.id` for auditing/debugging while the LLM sees a separately sanitised form. + delegate + .job_ctx + .tool_output_stash + .write() + .await + .insert(tc.id.clone(), output.clone()); + + fold_into_context( + delegate, + ToolCtx { pf_idx, tc }, + ToolOutcome { + result_content, + is_tool_error, + }, + reason_ctx, + ) + .await; + + auth_instructions +} + +/// Emit image sentinel status update if applicable. +async fn maybe_emit_image_sentinel( + delegate: &ChatDelegate<'_>, + tool_name: &str, + output: &str, +) -> bool { + if !matches!(tool_name, "image_generate" | "image_edit") { + return false; + } + + if let Ok(sentinel) = serde_json::from_str::(output) + && sentinel.get("type").and_then(|v| v.as_str()) == Some("image_generated") + { + let data_url = sentinel + .get("data") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(); + let path = sentinel + .get("path") + .and_then(|v| v.as_str()) + .map(String::from); + if data_url.is_empty() { + tracing::warn!("Image generation sentinel has empty data URL, skipping broadcast"); + } else { + let _ = delegate + .agent + .channels + .send_status( + &delegate.message.channel, + StatusUpdate::ImageGenerated { data_url, path }, + &delegate.message.metadata, + ) + .await; + } + return true; + } + false +} + +fn image_sentinel_summary(output: &str) -> Option { + let sentinel = serde_json::from_str::(output).ok()?; + if sentinel.get("type").and_then(|value| value.as_str()) != Some("image_generated") { + return None; + } + + let mut parts = vec!["[Image generated]".to_string()]; + if let Some(media_type) = sentinel.get("media_type").and_then(|value| value.as_str()) { + parts.push(format!("type={media_type}")); + } + if let Some(size) = sentinel.get("size").and_then(|value| value.as_str()) { + parts.push(format!("size={size}")); + } + if let Some(path) = sentinel.get("path").and_then(|value| value.as_str()) { + 
parts.push(format!("path={path}")); + } else if let Some(source_path) = sentinel.get("source_path").and_then(|value| value.as_str()) { + parts.push(format!("source={source_path}")); + } + Some(parts.join(" ")) +} + +/// Sanitize tool output and return both preview text (raw sanitized) and wrapped text (for LLM). +fn sanitize_output(delegate: &ChatDelegate<'_>, tool_name: &str, output: &str) -> (String, String) { + let sanitized = delegate + .agent + .safety() + .sanitize_tool_output(tool_name, output); + let preview_text = sanitized.content.clone(); + let wrapped_text = + delegate + .agent + .safety() + .wrap_for_llm(tool_name, &sanitized.content, sanitized.was_modified); + (preview_text, wrapped_text) +} + +/// Outcome of a tool execution for folding into context. +pub(super) struct ToolOutcome { + pub(super) result_content: String, + pub(super) is_tool_error: bool, +} + +/// Fold tool result into context messages. +pub(super) async fn fold_into_context( + delegate: &ChatDelegate<'_>, + tool: ToolCtx<'_>, + outcome: ToolOutcome, + reason_ctx: &mut ReasoningContext, +) { + record_tool_outcome( + delegate, + tool.pf_idx, + &outcome.result_content, + outcome.is_tool_error, + ) + .await; + + reason_ctx.messages.push(ChatMessage::tool_result( + &tool.tc.id, + &tool.tc.name, + outcome.result_content, + )); +} diff --git a/src/agent/dispatcher/delegate/tool_exec/preflight.rs b/src/agent/dispatcher/delegate/tool_exec/preflight.rs new file mode 100644 index 000000000..93c2860bd --- /dev/null +++ b/src/agent/dispatcher/delegate/tool_exec/preflight.rs @@ -0,0 +1,213 @@ +//! Preflight stage for chat tool execution. +//! +//! Applies hooks, validates tool calls, and classifies each call as runnable, +//! rejected, or requiring explicit user approval before execution. + +use std::sync::Arc; + +use crate::agent::dispatcher::delegate::ChatDelegate; +use crate::error::Error; +use crate::tools::redact_params; + +/// Outcome of preflight check for a single tool call. 
+pub(crate) enum PreflightOutcome { + /// Tool call was rejected by a hook. + Rejected(String), + /// Tool call is runnable. + Runnable, +} + +/// Result of grouping tool calls into batches. +pub(crate) struct ToolBatch { + /// Preflight outcomes for each tool call. + pub(super) preflight: Vec<(crate::llm::ToolCall, PreflightOutcome)>, + /// Indices of runnable tools (pointing into preflight). + pub(super) runnable: Vec<(usize, crate::llm::ToolCall)>, +} + +/// A tool call that requires user approval, together with its index in the +/// original call sequence (used to build the deferred-call slice). +pub(super) struct ApprovalCandidate { + pub idx: usize, + pub tool_call: crate::llm::ToolCall, + pub tool: Arc, +} + +/// Restore original values for sensitive fields into a mutable JSON object. +/// +/// After a hook modifies tool parameters, any sensitive key that was +/// redacted before the hook must be put back from the original call to +/// prevent secret loss. +fn restore_sensitive_fields( + obj: &mut serde_json::Map, + original_args: &serde_json::Value, + sensitive: &[&str], +) { + for key in sensitive { + if let Some(orig_val) = original_args.get(*key) { + obj.insert((*key).to_string(), orig_val.clone()); + } + } +} + +/// Apply hook parameter modification to a tool call. +fn apply_hook_param_modification( + tc: &mut crate::llm::ToolCall, + original_tc: &crate::llm::ToolCall, + sensitive: &[&str], + new_params: &str, +) { + match serde_json::from_str::(new_params) { + Ok(mut parsed) => { + if let Some(obj) = parsed.as_object_mut() { + restore_sensitive_fields(obj, &original_tc.arguments, sensitive); + } + tc.arguments = parsed; + } + Err(e) => { + tracing::warn!( + tool = %tc.name, + "Hook returned non-JSON modification for ToolCall, ignoring: {}", + e + ); + } + } +} + +/// Apply the BeforeToolCall hook and return rejection message if any. 
+pub(super) async fn apply_before_tool_call_hook( + delegate: &ChatDelegate<'_>, + original_tc: &crate::llm::ToolCall, + tc: &mut crate::llm::ToolCall, + sensitive: &[&str], +) -> Option { + let hook_params = redact_params(&tc.arguments, sensitive); + let event = crate::hooks::HookEvent::ToolCall { + tool_name: tc.name.clone(), + parameters: hook_params, + user_id: delegate.message.user_id.clone(), + context: "chat".to_string(), + }; + match delegate.agent.hooks().run(&event).await { + Err(crate::hooks::HookError::Rejected { reason }) => { + Some(format!("Tool call rejected by hook: {}", reason)) + } + Err(err) => Some(format!("Tool call blocked by hook policy: {}", err)), + Ok(crate::hooks::HookOutcome::Continue { + modified: Some(new_params), + }) => { + apply_hook_param_modification(tc, original_tc, sensitive, &new_params); + None + } + _ => None, + } +} + +/// Check if a tool requires approval based on its configuration and auto-approve settings. +async fn tool_requires_approval( + delegate: &ChatDelegate<'_>, + tool: &Arc, + tc: &crate::llm::ToolCall, +) -> bool { + use crate::tools::ApprovalRequirement; + match tool.requires_approval(&tc.arguments) { + ApprovalRequirement::Never => false, + ApprovalRequirement::Always => true, + ApprovalRequirement::UnlessAutoApproved => { + let sess = delegate.session.lock().await; + !sess.is_tool_auto_approved(&tc.name) + } + } +} + +async fn approval_required_tool( + delegate: &ChatDelegate<'_>, + tool_opt: Option>, + tc: &crate::llm::ToolCall, +) -> Option> { + if delegate.agent.config.auto_approve_tools { + return None; + } + let tool = tool_opt?; + if tool_requires_approval(delegate, &tool, tc).await { + Some(tool) + } else { + None + } +} + +/// The outcome of pre-flight classification for a single tool call. +enum ToolCallOutcome { + /// The before-hook rejected this call with a message. + Rejected(String), + /// The call requires user approval before it may run. 
+ NeedsApproval(ApprovalCandidate), + /// The call is cleared to run immediately. + Runnable, +} + +async fn classify_tool_call( + delegate: &ChatDelegate<'_>, + idx: usize, + original_tc: &crate::llm::ToolCall, + tc: &mut crate::llm::ToolCall, +) -> ToolCallOutcome { + let tool_opt = delegate.agent.tools().get(&tc.name).await; + let sensitive = tool_opt + .as_ref() + .map(|t| t.sensitive_params()) + .unwrap_or(&[]); + + if let Some(rejection_msg) = + apply_before_tool_call_hook(delegate, original_tc, tc, sensitive).await + { + return ToolCallOutcome::Rejected(rejection_msg); + } + + if let Some(tool) = approval_required_tool(delegate, tool_opt, tc).await { + return ToolCallOutcome::NeedsApproval(ApprovalCandidate { + idx, + tool_call: tc.clone(), + tool, + }); + } + + ToolCallOutcome::Runnable +} + +/// Group tool calls into preflight outcomes and runnable batch. +pub(super) async fn group_tool_calls( + delegate: &ChatDelegate<'_>, + tool_calls: &[crate::llm::ToolCall], +) -> Result<(ToolBatch, Option), Error> { + let mut preflight: Vec<(crate::llm::ToolCall, PreflightOutcome)> = Vec::new(); + let mut runnable: Vec<(usize, crate::llm::ToolCall)> = Vec::new(); + let mut approval_needed = None; + + for (idx, original_tc) in tool_calls.iter().enumerate() { + let mut tc = original_tc.clone(); + + match classify_tool_call(delegate, idx, original_tc, &mut tc).await { + ToolCallOutcome::Rejected(msg) => { + preflight.push((tc, PreflightOutcome::Rejected(msg))); + } + ToolCallOutcome::NeedsApproval(candidate) => { + approval_needed = Some(candidate); + break; + } + ToolCallOutcome::Runnable => { + let pf_idx = preflight.len(); + preflight.push((tc.clone(), PreflightOutcome::Runnable)); + runnable.push((pf_idx, tc)); + } + } + } + + Ok(( + ToolBatch { + preflight, + runnable, + }, + approval_needed, + )) +} diff --git a/src/agent/dispatcher/delegate/tool_exec/recording.rs b/src/agent/dispatcher/delegate/tool_exec/recording.rs new file mode 100644 index 000000000..fc9eb2718 
--- /dev/null +++ b/src/agent/dispatcher/delegate/tool_exec/recording.rs @@ -0,0 +1,74 @@ +//! Recording helpers for chat tool execution. +//! +//! Persists redacted tool calls and writes indexed outcomes back to the +//! current turn so later results stay aligned with the originating call. + +use crate::agent::dispatcher::delegate::ChatDelegate; + +/// Compute the safe (redacted) argument map for a single tool call. +async fn redact_single_tool_call( + agent: &crate::agent::Agent, + tc: &crate::llm::ToolCall, +) -> serde_json::Value { + if let Some(tool) = agent.tools().get(&tc.name).await { + crate::tools::redact_params(&tc.arguments, tool.sensitive_params()) + } else { + tc.arguments.clone() + } +} + +/// Record redacted tool-call args into the current turn of the session thread. +pub(super) async fn write_tool_calls_to_thread( + delegate: &ChatDelegate<'_>, + tool_calls: &[crate::llm::ToolCall], + redacted_args: Vec, +) { + let mut sess = delegate.session.lock().await; + let Some(thread) = sess.threads.get_mut(&delegate.thread_id) else { + return; + }; + let Some(turn) = thread.last_turn_mut() else { + return; + }; + for (tc, safe_args) in tool_calls.iter().zip(redacted_args) { + turn.record_tool_call(&tc.name, safe_args); + } +} + +/// Record tool calls in the session thread with sensitive params redacted. +pub(super) async fn record_redacted_tool_calls( + delegate: &ChatDelegate<'_>, + tool_calls: &[crate::llm::ToolCall], +) { + let mut redacted_args: Vec = Vec::with_capacity(tool_calls.len()); + for tc in tool_calls { + redacted_args.push(redact_single_tool_call(delegate.agent, tc).await); + } + write_tool_calls_to_thread(delegate, tool_calls, redacted_args).await; +} + +/// Record tool outcome in the thread. 
+pub(super) async fn record_tool_outcome( + delegate: &ChatDelegate<'_>, + tool_call_idx: usize, + result_content: &str, + is_tool_error: bool, +) { + let mut sess = delegate.session.lock().await; + if let Some(thread) = sess.threads.get_mut(&delegate.thread_id) + && let Some(turn) = thread.last_turn_mut() + { + let record_result = if is_tool_error { + turn.record_tool_error_at(tool_call_idx, result_content.to_string()) + } else { + turn.record_tool_result_content_at(tool_call_idx, result_content) + }; + if let Err(error) = record_result { + tracing::warn!( + tool_call_idx, + %error, + "Failed to record tool outcome in session turn" + ); + } + } +} diff --git a/src/agent/dispatcher.rs b/src/agent/dispatcher/mod.rs similarity index 62% rename from src/agent/dispatcher.rs rename to src/agent/dispatcher/mod.rs index 08b890c98..f2c0a768b 100644 --- a/src/agent/dispatcher.rs +++ b/src/agent/dispatcher/mod.rs @@ -2,35 +2,34 @@ //! //! Extracted from `agent_loop.rs` to keep the core agentic tool execution //! loop (LLM call -> tool calls -> repeat) in its own focused module. +//! +//! This module is organized into submodules by responsibility: +//! - `preflight`: Tool call preflight checks and batching +//! - `execution`: Tool execution (inline and parallel) +//! - `postflight`: Post-execution processing and context folding +//! 
- `delegate`: Chat delegate implementation of NativeLoopDelegate + +mod delegate; use std::sync::Arc; use tokio::sync::Mutex; -use tokio::task::JoinSet; use uuid::Uuid; use crate::agent::Agent; -use crate::agent::session::{PendingApproval, Session, ThreadState}; -use crate::channels::{IncomingMessage, StatusUpdate}; +use crate::agent::session::{PendingApproval, Session}; +use crate::channels::IncomingMessage; use crate::context::JobContext; use crate::error::Error; -use crate::agent::agentic_loop::{ - AgenticLoopConfig, LoopOutcome, LoopSignal, NativeLoopDelegate, TextAction, -}; +use crate::agent::agentic_loop::{AgenticLoopConfig, LoopOutcome}; use crate::llm::{ChatMessage, Reasoning, ReasoningContext}; -use crate::tools::redact_params; pub(crate) const PREVIEW_MAX_CHARS: usize = 1024; - -/// Check if a string is valid JSON (object or array). -fn is_valid_json(s: &str) -> bool { - let t = s.trim(); - if !(t.starts_with('{') || t.starts_with('[')) { - return false; - } - serde_json::from_str::(t).is_ok() -} +// Re-export items used by other modules from the delegate submodule +pub(crate) use delegate::{ + ToolCallSpec, check_auth_required, execute_chat_tool_standalone, parse_auth_result, +}; /// Collapse a tool output string into a single-line preview for display. pub(crate) fn truncate_for_preview(output: &str, max_chars: usize) -> String { @@ -228,7 +227,7 @@ impl Agent { let force_text_at = max_tool_iterations; let nudge_at = max_tool_iterations.saturating_sub(1); - let delegate = ChatDelegate { + let delegate = delegate::ChatDelegate { agent: self, session: session.clone(), thread_id, @@ -295,996 +294,16 @@ impl Agent { execute_chat_tool_standalone( self.tools(), self.safety(), - &ChatToolRequest { tool_name, params }, + &ToolCallSpec { + name: tool_name, + params, + }, job_ctx, ) .await } } -/// Delegate for the chat (dispatcher) context. 
-/// -/// Implements `LoopDelegate` to customize the shared agentic loop for -/// interactive chat sessions with the full 3-phase tool execution -/// (preflight → parallel exec → post-flight), approval flow, hooks, -/// auth intercept, and cost tracking. -struct ChatDelegate<'a> { - agent: &'a Agent, - session: Arc>, - thread_id: Uuid, - message: &'a IncomingMessage, - job_ctx: JobContext, - active_skills: Vec, - cached_prompt: String, - cached_prompt_no_tools: String, - nudge_at: usize, - force_text_at: usize, - user_tz: chrono_tz::Tz, -} - -/// Execution context for tool calls. -#[expect(dead_code, reason = "scaffolding for future tool-exec refactor")] -struct ExecCtx<'a> { - tools: &'a Arc, - safety: &'a Arc, - channels: &'a Arc, - channel: &'a str, - user_id: &'a str, - metadata: &'a serde_json::Value, - preview_limit: usize, -} - -impl<'a> ExecCtx<'a> { - #[expect(dead_code, reason = "scaffolding for future tool-exec refactor")] - fn new( - tools: &'a Arc, - safety: &'a Arc, - channels: &'a Arc, - channel: &'a str, - user_id: &'a str, - metadata: &'a serde_json::Value, - preview_limit: usize, - ) -> Self { - Self { - tools, - safety, - channels, - channel, - user_id, - metadata, - preview_limit, - } - } -} - -/// Outcome of preflight check for a single tool call. -enum PreflightOutcome { - Rejected(String), - Runnable, -} - -/// Result of grouping tool calls into batches. -struct ToolBatch { - preflight: Vec<(crate::llm::ToolCall, PreflightOutcome)>, - runnable: Vec<(usize, crate::llm::ToolCall)>, -} - -impl<'a> ChatDelegate<'a> { - /// Group tool calls into preflight outcomes and runnable batch. 
- async fn group_tool_calls( - &self, - tool_calls: &[crate::llm::ToolCall], - ) -> Result< - ( - ToolBatch, - Option<(usize, crate::llm::ToolCall, Arc)>, - ), - Error, - > { - let mut preflight: Vec<(crate::llm::ToolCall, PreflightOutcome)> = Vec::new(); - let mut runnable: Vec<(usize, crate::llm::ToolCall)> = Vec::new(); - let mut approval_needed: Option<( - usize, - crate::llm::ToolCall, - Arc, - )> = None; - - for (idx, original_tc) in tool_calls.iter().enumerate() { - let mut tc = original_tc.clone(); - - let tool_opt = self.agent.tools().get(&tc.name).await; - let sensitive = tool_opt - .as_ref() - .map(|t| t.sensitive_params()) - .unwrap_or(&[]); - - // Hook: BeforeToolCall - let hook_params = redact_params(&tc.arguments, sensitive); - let event = crate::hooks::HookEvent::ToolCall { - tool_name: tc.name.clone(), - parameters: hook_params, - user_id: self.message.user_id.clone(), - context: "chat".to_string(), - }; - match self.agent.hooks().run(&event).await { - Err(crate::hooks::HookError::Rejected { reason }) => { - preflight.push(( - tc, - PreflightOutcome::Rejected(format!( - "Tool call rejected by hook: {}", - reason - )), - )); - continue; - } - Err(err) => { - preflight.push(( - tc, - PreflightOutcome::Rejected(format!( - "Tool call blocked by hook policy: {}", - err - )), - )); - continue; - } - Ok(crate::hooks::HookOutcome::Continue { - modified: Some(new_params), - }) => match serde_json::from_str::(&new_params) { - Ok(mut parsed) => { - if let Some(obj) = parsed.as_object_mut() { - for key in sensitive { - if let Some(orig_val) = original_tc.arguments.get(*key) { - obj.insert((*key).to_string(), orig_val.clone()); - } - } - } - tc.arguments = parsed; - } - Err(e) => { - tracing::warn!( - tool = %tc.name, - "Hook returned non-JSON modification for ToolCall, ignoring: {}", - e - ); - } - }, - _ => {} - } - - // Check if tool requires approval - if !self.agent.config.auto_approve_tools - && let Some(tool) = tool_opt - { - use 
crate::tools::ApprovalRequirement; - let needs_approval = match tool.requires_approval(&tc.arguments) { - ApprovalRequirement::Never => false, - ApprovalRequirement::UnlessAutoApproved => { - let sess = self.session.lock().await; - !sess.is_tool_auto_approved(&tc.name) - } - ApprovalRequirement::Always => true, - }; - - if needs_approval { - approval_needed = Some((idx, tc, tool)); - break; - } - } - - let preflight_idx = preflight.len(); - preflight.push((tc.clone(), PreflightOutcome::Runnable)); - runnable.push((preflight_idx, tc)); - } - - Ok(( - ToolBatch { - preflight, - runnable, - }, - approval_needed, - )) - } - - /// Send ToolStarted status update. - async fn send_tool_started(&self, tool_name: &str) { - let _ = self - .agent - .channels - .send_status( - &self.message.channel, - StatusUpdate::ToolStarted { - name: tool_name.to_string(), - }, - &self.message.metadata, - ) - .await; - } - - /// Send tool_completed status update. - async fn send_tool_completed( - &self, - tool_name: &str, - result: &Result, - arguments: &serde_json::Value, - ) { - let disp_tool = self.agent.tools().get(tool_name).await; - let _ = self - .agent - .channels - .send_status( - &self.message.channel, - StatusUpdate::tool_completed( - tool_name.to_string(), - result, - arguments, - disp_tool.as_deref(), - ), - &self.message.metadata, - ) - .await; - } - - /// Execute a single tool inline (for small batches). - async fn execute_one_tool(&self, tc: &crate::llm::ToolCall) -> Result { - self.send_tool_started(&tc.name).await; - let result = self - .agent - .execute_chat_tool(&tc.name, &tc.arguments, &self.job_ctx) - .await; - self.send_tool_completed(&tc.name, &result, &tc.arguments) - .await; - result - } - - /// Sanitize tool output and return both preview text (raw sanitized) and wrapped text (for LLM). 
- fn sanitize_output(&self, tool_name: &str, output: &str) -> (String, String) { - let sanitized = self.agent.safety().sanitize_tool_output(tool_name, output); - let preview_text = sanitized.content.clone(); - let wrapped_text = - self.agent - .safety() - .wrap_for_llm(tool_name, &sanitized.content, sanitized.was_modified); - (preview_text, wrapped_text) - } - - /// Record tool outcome in the thread. - async fn record_tool_outcome( - &self, - _tool_name: &str, - result_content: &str, - is_tool_error: bool, - ) { - let mut sess = self.session.lock().await; - if let Some(thread) = sess.threads.get_mut(&self.thread_id) - && let Some(turn) = thread.last_turn_mut() - { - if is_tool_error { - turn.record_tool_error(result_content.to_string()); - } else { - turn.record_tool_result(serde_json::json!(result_content)); - } - } - } - - /// Emit image sentinel status update if applicable. - async fn maybe_emit_image_sentinel(&self, tool_name: &str, output: &str) -> bool { - if !matches!(tool_name, "image_generate" | "image_edit") { - return false; - } - - if let Ok(sentinel) = serde_json::from_str::(output) - && sentinel.get("type").and_then(|v| v.as_str()) == Some("image_generated") - { - let data_url = sentinel - .get("data") - .and_then(|v| v.as_str()) - .unwrap_or_default() - .to_string(); - let path = sentinel - .get("path") - .and_then(|v| v.as_str()) - .map(String::from); - if data_url.is_empty() { - tracing::warn!("Image generation sentinel has empty data URL, skipping broadcast"); - } else { - let _ = self - .agent - .channels - .send_status( - &self.message.channel, - StatusUpdate::ImageGenerated { data_url, path }, - &self.message.metadata, - ) - .await; - } - return true; - } - false - } - - /// Fold tool result into context messages. 
- async fn fold_into_context( - &self, - tc: &crate::llm::ToolCall, - result_content: String, - is_tool_error: bool, - reason_ctx: &mut ReasoningContext, - ) { - // Record sanitized result in thread - self.record_tool_outcome(&tc.name, &result_content, is_tool_error) - .await; - - reason_ctx - .messages - .push(ChatMessage::tool_result(&tc.id, &tc.name, result_content)); - } - - /// Run a batch of tools inline (sequential execution for small batches). - async fn run_tool_batch_inline( - &self, - runnable: &[(usize, crate::llm::ToolCall)], - exec_results: &mut [Option>], - ) { - for (pf_idx, tc) in runnable { - let result = self.execute_one_tool(tc).await; - exec_results[*pf_idx] = Some(result); - } - } - - /// Run a batch of tools in parallel (for large batches). - async fn run_tool_batch_parallel( - &self, - runnable: &[(usize, crate::llm::ToolCall)], - exec_results: &mut [Option>], - ) { - let mut join_set = JoinSet::new(); - - for (pf_idx, tc) in runnable { - let pf_idx = *pf_idx; - let tools = self.agent.tools().clone(); - let safety = self.agent.safety().clone(); - let channels = self.agent.channels.clone(); - let job_ctx = self.job_ctx.clone(); - let tc = tc.clone(); - let channel = self.message.channel.clone(); - let metadata = self.message.metadata.clone(); - - join_set.spawn(async move { - let _ = channels - .send_status( - &channel, - StatusUpdate::ToolStarted { - name: tc.name.clone(), - }, - &metadata, - ) - .await; - - let result = execute_chat_tool_standalone( - &tools, - &safety, - &ChatToolRequest { - tool_name: &tc.name, - params: &tc.arguments, - }, - &job_ctx, - ) - .await; - - let par_tool = tools.get(&tc.name).await; - let _ = channels - .send_status( - &channel, - StatusUpdate::tool_completed( - tc.name.clone(), - &result, - &tc.arguments, - par_tool.as_deref(), - ), - &metadata, - ) - .await; - - (pf_idx, result) - }); - } - - while let Some(join_result) = join_set.join_next().await { - match join_result { - Ok((pf_idx, result)) => { - 
exec_results[pf_idx] = Some(result); - } - Err(e) => { - if e.is_panic() { - tracing::error!("Chat tool execution task panicked: {}", e); - } else { - tracing::error!("Chat tool execution task cancelled: {}", e); - } - } - } - } - - // Fill panicked slots with error results - for (pf_idx, tc) in runnable.iter() { - if exec_results[*pf_idx].is_none() { - tracing::error!( - tool = %tc.name, - "Filling failed task slot with error" - ); - exec_results[*pf_idx] = Some(Err(crate::error::ToolError::ExecutionFailed { - name: tc.name.clone(), - reason: "Task failed during execution".to_string(), - } - .into())); - } - } - } - - /// Handle rejected tool call outcome. - async fn handle_rejected_tool( - &self, - tc: &crate::llm::ToolCall, - error_msg: &str, - reason_ctx: &mut ReasoningContext, - ) { - { - let mut sess = self.session.lock().await; - if let Some(thread) = sess.threads.get_mut(&self.thread_id) - && let Some(turn) = thread.last_turn_mut() - { - turn.record_tool_error(error_msg.to_string()); - } - } - reason_ctx.messages.push(ChatMessage::tool_result( - &tc.id, - &tc.name, - error_msg.to_string(), - )); - } - - /// Process post-flight for a single runnable tool. 
- async fn process_runnable_tool( - &self, - tc: &crate::llm::ToolCall, - tool_result: Result, - reason_ctx: &mut ReasoningContext, - ) -> Option { - let is_tool_error = tool_result.is_err(); - - // Handle error case early - let output = match &tool_result { - Ok(output) => output, - Err(e) => { - let error_msg = format!("Tool '{}' failed: {}", tc.name, e); - self.fold_into_context(tc, error_msg, true, reason_ctx) - .await; - return None; - } - }; - - // Detect image generation sentinel - let is_image_sentinel = self.maybe_emit_image_sentinel(&tc.name, output).await; - - // Determine result content and preview based on whether output is valid JSON - let (result_content, preview) = if is_valid_json(output) { - // For JSON-producing tools, persist raw JSON without wrapping - let preview = truncate_for_preview(output, PREVIEW_MAX_CHARS); - (output.clone(), preview) - } else { - // Sanitize tool output first (before sending preview or using in context) - // preview_text is raw sanitized for preview, wrapped_text is for LLM context - let (preview_text, wrapped_text) = self.sanitize_output(&tc.name, output); - let preview = truncate_for_preview(&preview_text, PREVIEW_MAX_CHARS); - (wrapped_text, preview) - }; - - // Send ToolResult preview - if !is_image_sentinel && !preview.is_empty() { - let _ = self - .agent - .channels - .send_status( - &self.message.channel, - StatusUpdate::ToolResult { - name: tc.name.clone(), - preview, - }, - &self.message.metadata, - ) - .await; - } - - // Check for auth awaiting (use original tool_result for auth detection) - let auth_instructions = - if let Some((ext_name, instructions)) = check_auth_required(&tc.name, &tool_result) { - let auth_data = parse_auth_result(&tool_result); - { - let mut sess = self.session.lock().await; - if let Some(thread) = sess.threads.get_mut(&self.thread_id) { - thread.enter_auth_mode(ext_name.clone()); - } - } - let _ = self - .agent - .channels - .send_status( - &self.message.channel, - 
StatusUpdate::AuthRequired { - extension_name: ext_name, - instructions: Some(instructions.clone()), - auth_url: auth_data.auth_url, - setup_url: auth_data.setup_url, - }, - &self.message.metadata, - ) - .await; - Some(instructions) - } else { - None - }; - - // Stash full output so subsequent tools can reference it - self.job_ctx - .tool_output_stash - .write() - .await - .insert(tc.id.clone(), output.clone()); - - // Fold result into context - self.fold_into_context(tc, result_content, is_tool_error, reason_ctx) - .await; - - auth_instructions - } -} - -impl<'a> NativeLoopDelegate for ChatDelegate<'a> { - async fn check_signals(&self) -> LoopSignal { - let sess = self.session.lock().await; - if let Some(thread) = sess.threads.get(&self.thread_id) - && thread.state == ThreadState::Interrupted - { - return LoopSignal::Stop; - } - LoopSignal::Continue - } - - async fn before_llm_call( - &self, - reason_ctx: &mut ReasoningContext, - iteration: usize, - ) -> Option { - // Inject a nudge message when approaching the iteration limit so the - // LLM is aware it should produce a final answer on the next turn. - if iteration == self.nudge_at { - reason_ctx.messages.push(ChatMessage::system( - "You are approaching the tool call limit. \ - Provide your best final answer on the next response \ - using the information you have gathered so far. \ - Do not call any more tools.", - )); - } - - let force_text = iteration >= self.force_text_at; - - // Refresh tool definitions each iteration so newly built tools become visible - let tool_defs = self.agent.tools().tool_definitions().await; - - // Apply trust-based tool attenuation if skills are active. 
- let tool_defs = if !self.active_skills.is_empty() { - let result = crate::skills::attenuate_tools(&tool_defs, &self.active_skills); - tracing::debug!( - min_trust = %result.min_trust, - tools_available = result.tools.len(), - tools_removed = result.removed_tools.len(), - removed = ?result.removed_tools, - explanation = %result.explanation, - "Tool attenuation applied" - ); - result.tools - } else { - tool_defs - }; - - // Update context for this iteration - reason_ctx.available_tools = tool_defs; - reason_ctx.system_prompt = Some(if force_text { - self.cached_prompt_no_tools.clone() - } else { - self.cached_prompt.clone() - }); - reason_ctx.force_text = force_text; - - if force_text { - tracing::info!( - iteration, - "Forcing text-only response (iteration limit reached)" - ); - } - - let _ = self - .agent - .channels - .send_status( - &self.message.channel, - StatusUpdate::Thinking("Calling LLM...".into()), - &self.message.metadata, - ) - .await; - - None - } - - async fn call_llm( - &self, - reasoning: &Reasoning, - reason_ctx: &mut ReasoningContext, - iteration: usize, - ) -> Result { - // Enforce cost guardrails before the LLM call - if let Err(limit) = self.agent.cost_guard().check_allowed().await { - return Err(crate::error::LlmError::InvalidResponse { - provider: "agent".to_string(), - reason: limit.to_string(), - } - .into()); - } - - let output = match reasoning.respond_with_tools(reason_ctx).await { - Ok(output) => output, - Err(crate::error::LlmError::ContextLengthExceeded { used, limit }) => { - tracing::warn!( - used, - limit, - iteration, - "Context length exceeded, compacting messages and retrying" - ); - - // Compact messages in place and retry - reason_ctx.messages = compact_messages_for_retry(&reason_ctx.messages); - - // When force_text, clear tools to further reduce token count - if reason_ctx.force_text { - reason_ctx.available_tools.clear(); - } - - reasoning - .respond_with_tools(reason_ctx) - .await - .map_err(|retry_err| { - 
tracing::error!( - original_used = used, - original_limit = limit, - retry_error = %retry_err, - "Retry after auto-compaction also failed" - ); - crate::error::Error::from(retry_err) - })? - } - Err(e) => return Err(e.into()), - }; - - // Record cost and track token usage - let model_name = self.agent.llm().active_model_name(); - let read_discount = self.agent.llm().cache_read_discount(); - let write_multiplier = self.agent.llm().cache_write_multiplier(); - let call_cost = self - .agent - .cost_guard() - .record_llm_call( - &model_name, - output.usage.input_tokens, - output.usage.output_tokens, - output.usage.cache_read_input_tokens, - output.usage.cache_creation_input_tokens, - read_discount, - write_multiplier, - Some(self.agent.llm().cost_per_token()), - ) - .await; - tracing::debug!( - "LLM call used {} input + {} output tokens (${:.6})", - output.usage.input_tokens, - output.usage.output_tokens, - call_cost, - ); - - Ok(output) - } - - async fn handle_text_response( - &self, - text: &str, - _reason_ctx: &mut ReasoningContext, - ) -> TextAction { - // Strip internal "[Called tool ...]" text that can leak when - // provider flattening (e.g. NEAR AI) converts tool_calls to - // plain text and the LLM echoes it back. - let sanitized = strip_internal_tool_call_text(text); - TextAction::Return(LoopOutcome::Response(sanitized)) - } - - async fn execute_tool_calls( - &self, - tool_calls: Vec, - content: Option, - reason_ctx: &mut ReasoningContext, - ) -> Result, Error> { - // Add the assistant message with tool_calls to context. - // OpenAI protocol requires this before tool-result messages. 
- reason_ctx - .messages - .push(ChatMessage::assistant_with_tool_calls( - content, - tool_calls.clone(), - )); - - // Execute tools and add results to context - let _ = self - .agent - .channels - .send_status( - &self.message.channel, - StatusUpdate::Thinking(format!("Executing {} tool(s)...", tool_calls.len())), - &self.message.metadata, - ) - .await; - - // Record tool calls in the thread with sensitive params redacted. - { - let mut redacted_args: Vec = Vec::with_capacity(tool_calls.len()); - for tc in &tool_calls { - let safe = if let Some(tool) = self.agent.tools().get(&tc.name).await { - redact_params(&tc.arguments, tool.sensitive_params()) - } else { - tc.arguments.clone() - }; - redacted_args.push(safe); - } - let mut sess = self.session.lock().await; - if let Some(thread) = sess.threads.get_mut(&self.thread_id) - && let Some(turn) = thread.last_turn_mut() - { - for (tc, safe_args) in tool_calls.iter().zip(redacted_args) { - turn.record_tool_call(&tc.name, safe_args); - } - } - } - - // === Phase 1: Preflight (sequential) === - let (batch, approval_needed) = self.group_tool_calls(&tool_calls).await?; - let ToolBatch { - preflight, - runnable, - } = batch; - - // === Phase 2: Parallel execution === - let mut exec_results: Vec>> = - (0..preflight.len()).map(|_| None).collect(); - - if runnable.len() <= 1 { - self.run_tool_batch_inline(&runnable, &mut exec_results) - .await; - } else { - self.run_tool_batch_parallel(&runnable, &mut exec_results) - .await; - } - - // === Phase 3: Post-flight (sequential, in original order) === - let mut deferred_auth: Option = None; - - for (pf_idx, (tc, outcome)) in preflight.into_iter().enumerate() { - match outcome { - PreflightOutcome::Rejected(error_msg) => { - self.handle_rejected_tool(&tc, &error_msg, reason_ctx).await; - } - PreflightOutcome::Runnable => { - let tool_result = exec_results[pf_idx].take().unwrap_or_else(|| { - Err(crate::error::ToolError::ExecutionFailed { - name: tc.name.clone(), - reason: "No result 
available".to_string(), - } - .into()) - }); - - if let Some(instructions) = self - .process_runnable_tool(&tc, tool_result, reason_ctx) - .await - { - deferred_auth = Some(instructions); - } - } - } - } - - // Return auth response after all results are recorded - if let Some(instructions) = deferred_auth { - return Ok(Some(LoopOutcome::Response(instructions))); - } - - // Handle approval if a tool needed it - if let Some((approval_idx, tc, tool)) = approval_needed { - let display_params = redact_params(&tc.arguments, tool.sensitive_params()); - let pending = PendingApproval { - request_id: Uuid::new_v4(), - tool_name: tc.name.clone(), - parameters: tc.arguments.clone(), - display_parameters: display_params, - description: tool.description().to_string(), - tool_call_id: tc.id.clone(), - context_messages: reason_ctx.messages.clone(), - deferred_tool_calls: tool_calls[approval_idx + 1..].to_vec(), - user_timezone: Some(self.user_tz.name().to_string()), - }; - - return Ok(Some(LoopOutcome::NeedApproval(Box::new(pending)))); - } - - Ok(None) - } -} - -/// Describes a single tool invocation passed to `execute_chat_tool_standalone`. -pub(crate) struct ChatToolRequest<'a> { - pub(crate) tool_name: &'a str, - pub(crate) params: &'a serde_json::Value, -} -/// Execute a chat tool without requiring `&Agent`. -/// -/// This standalone function enables parallel invocation from spawned JoinSet -/// tasks, which cannot borrow `&self`. Delegates to the shared -/// `execute_tool_with_safety` pipeline. -pub(super) async fn execute_chat_tool_standalone( - tools: &crate::tools::ToolRegistry, - safety: &crate::safety::SafetyLayer, - request: &ChatToolRequest<'_>, - job_ctx: &crate::context::JobContext, -) -> Result { - crate::tools::execute::execute_tool_with_safety( - tools, - safety, - request.tool_name, - request.params, - job_ctx, - ) - .await -} - -/// Parsed auth result fields for emitting StatusUpdate::AuthRequired. 
-pub(super) struct ParsedAuthData { - pub(super) auth_url: Option, - pub(super) setup_url: Option, -} - -/// Extract auth_url and setup_url from a tool_auth result JSON string. -pub(super) fn parse_auth_result(result: &Result) -> ParsedAuthData { - let parsed = result - .as_ref() - .ok() - .and_then(|s| serde_json::from_str::(s).ok()); - ParsedAuthData { - auth_url: parsed - .as_ref() - .and_then(|v| v.get("auth_url")) - .and_then(|v| v.as_str()) - .map(|s| s.to_string()), - setup_url: parsed - .as_ref() - .and_then(|v| v.get("setup_url")) - .and_then(|v| v.as_str()) - .map(|s| s.to_string()), - } -} - -/// Check if a tool_auth result indicates the extension is awaiting a token. -/// -/// Returns `Some((extension_name, instructions))` if the tool result contains -/// `awaiting_token: true`, meaning the thread should enter auth mode. -pub(super) fn check_auth_required( - tool_name: &str, - result: &Result, -) -> Option<(String, String)> { - if tool_name != "tool_auth" && tool_name != "tool_activate" { - return None; - } - let output = result.as_ref().ok()?; - let parsed: serde_json::Value = serde_json::from_str(output).ok()?; - if parsed.get("awaiting_token") != Some(&serde_json::Value::Bool(true)) { - return None; - } - let name = parsed.get("name")?.as_str()?.to_string(); - let instructions = parsed - .get("instructions") - .and_then(|v| v.as_str()) - .unwrap_or("Please provide your API token/key.") - .to_string(); - Some((name, instructions)) -} - -/// Compact messages for retry after a context-length-exceeded error. -/// -/// Keeps all `System` messages (which carry the system prompt and instructions), -/// finds the last `User` message, and retains it plus every subsequent message -/// (the current turn's assistant tool calls and tool results). A short note is -/// inserted so the LLM knows earlier history was dropped. 
-fn compact_messages_for_retry(messages: &[ChatMessage]) -> Vec { - use crate::llm::Role; - - let mut compacted = Vec::new(); - - // Find the last User message index - let last_user_idx = messages.iter().rposition(|m| m.role == Role::User); - - if let Some(idx) = last_user_idx { - // Keep System messages that appear BEFORE the last User message. - // System messages after that point (e.g. nudges) are included in the - // slice extension below, avoiding duplication. - for msg in &messages[..idx] { - if msg.role == Role::System { - compacted.push(msg.clone()); - } - } - - // Only add a compaction note if there was earlier history that is being dropped - if idx > 0 { - compacted.push(ChatMessage::system( - "[Note: Earlier conversation history was automatically compacted \ - to fit within the context window. The most recent exchange is preserved below.]", - )); - } - - // Keep the last User message and everything after it - compacted.extend_from_slice(&messages[idx..]); - } else { - // No user messages found (shouldn't happen normally); keep everything, - // with system messages first to preserve prompt ordering. - for msg in messages { - if msg.role == Role::System { - compacted.push(msg.clone()); - } - } - for msg in messages { - if msg.role != Role::System { - compacted.push(msg.clone()); - } - } - } - - compacted -} - -/// Strip internal `[Called tool ...]` and `[Tool ... returned: ...]` markers -/// from a response string. These markers are inserted by provider-level message -/// flattening (e.g. NEAR AI) and can leak into the user-visible response when -/// the LLM echoes them back. -fn strip_internal_tool_call_text(text: &str) -> String { - // Remove lines that are purely internal tool-call markers. 
- // Pattern: lines matching `[Called tool (...)]` or `[Tool returned: ...]` - let result = text - .lines() - .filter(|line| { - let trimmed = line.trim(); - !((trimmed.starts_with("[Called tool ") && trimmed.ends_with(']')) - || (trimmed.starts_with("[Tool ") - && trimmed.contains(" returned:") - && trimmed.ends_with(']'))) - }) - .fold(String::new(), |mut acc, s| { - if !acc.is_empty() { - acc.push('\n'); - } - acc.push_str(s); - acc - }); - - let result = result.trim(); - if result.is_empty() { - "I wasn't able to complete that request. Could you try rephrasing or providing more details?".to_string() - } else { - result.to_string() - } -} - #[cfg(test)] mod tests { use std::path::PathBuf; @@ -1711,8 +730,8 @@ mod tests { let result = super::execute_chat_tool_standalone( ®istry, &safety, - &super::ChatToolRequest { - tool_name: "echo", + &super::ToolCallSpec { + name: "echo", params: &serde_json::json!({"message": "hello"}), }, &job_ctx, @@ -1741,8 +760,8 @@ mod tests { let result = super::execute_chat_tool_standalone( ®istry, &safety, - &super::ChatToolRequest { - tool_name: "nonexistent", + &super::ToolCallSpec { + name: "nonexistent", params: &serde_json::json!({}), }, &job_ctx, @@ -1754,7 +773,7 @@ mod tests { // ---- compact_messages_for_retry tests ---- - use super::compact_messages_for_retry; + use super::delegate::{compact_messages_for_retry, strip_internal_tool_call_text}; use crate::llm::{ChatMessage, Role}; #[test] @@ -2117,7 +1136,7 @@ mod tests { if iteration >= force_text_at { hit_force_text = true; } - if iteration > max_iter + 1 { + if iteration >= hard_ceiling { hit_ceiling = true; } } @@ -2127,8 +1146,14 @@ mod tests { ); // The ceiling should only fire if force_text somehow didn't break assert!( - hit_ceiling || hard_ceiling <= max_iter + 1, - "ceiling logic inconsistent for max_iter={max_iter}" + hard_ceiling == max_iter + 1, + "hard_ceiling ({hard_ceiling}) must equal max_iter + 1 ({})", + max_iter + 1 + ); + assert!( + !hit_force_text || 
hit_ceiling, + "force_text_at ({force_text_at}) and hard_ceiling ({hard_ceiling}) diverged: \ + hit_force_text={hit_force_text}, hit_ceiling={hit_ceiling}" ); } } @@ -2408,28 +1433,28 @@ mod tests { #[test] fn test_strip_internal_tool_call_text_removes_markers() { let input = "[Called tool search({\"query\": \"test\"})]\nHere is the answer."; - let result = super::strip_internal_tool_call_text(input); + let result = strip_internal_tool_call_text(input); assert_eq!(result, "Here is the answer."); } #[test] fn test_strip_internal_tool_call_text_removes_returned_markers() { let input = "[Tool search returned: some result]\nSummary of findings."; - let result = super::strip_internal_tool_call_text(input); + let result = strip_internal_tool_call_text(input); assert_eq!(result, "Summary of findings."); } #[test] fn test_strip_internal_tool_call_text_all_markers_yields_fallback() { let input = "[Called tool search({\"query\": \"test\"})]\n[Tool search returned: error]"; - let result = super::strip_internal_tool_call_text(input); + let result = strip_internal_tool_call_text(input); assert!(result.contains("wasn't able to complete")); } #[test] fn test_strip_internal_tool_call_text_preserves_normal_text() { let input = "This is a normal response with [brackets] inside."; - let result = super::strip_internal_tool_call_text(input); + let result = strip_internal_tool_call_text(input); assert_eq!(result, input); } diff --git a/src/agent/self_repair/default.rs b/src/agent/self_repair/default.rs index a529f83dc..e27ea50ce 100644 --- a/src/agent/self_repair/default.rs +++ b/src/agent/self_repair/default.rs @@ -168,6 +168,11 @@ impl NativeSelfRepair for DefaultSelfRepair { message: format!("Job {} already recovered", job.job_id), }) } + Ok(Err(JobRecoveryError::InvariantViolation(reason))) => Err(RepairError::Failed { + target_type: "job".to_string(), + target_id: job.job_id, + reason, + }), Err(e) => Err(RepairError::Failed { target_type: "job".to_string(), target_id: job.job_id, 
diff --git a/src/agent/session.rs b/src/agent/session.rs index 8e7d956b5..57e5218ad 100644 --- a/src/agent/session.rs +++ b/src/agent/session.rs @@ -14,6 +14,7 @@ use std::collections::{HashMap, HashSet}; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; +use thiserror::Error; use uuid::Uuid; use crate::channels::web::util::truncate_preview; @@ -41,6 +42,13 @@ pub struct Session { pub auto_approved_tools: HashSet, } +/// Errors for indexed tool-call mutations on a turn. +#[derive(Debug, Error, PartialEq, Eq)] +pub enum ToolCallIndexError { + #[error("tool call index {idx} out of bounds (len={len})")] + OutOfBounds { idx: usize, len: usize }, +} + impl Session { /// Create a new session. pub fn new(user_id: impl Into) -> Self { @@ -201,11 +209,10 @@ pub struct Thread { } impl Thread { - /// Create a new thread. - pub fn new(session_id: Uuid) -> Self { + fn init(session_id: Uuid, thread_id: Uuid) -> Self { let now = Utc::now(); Self { - id: Uuid::new_v4(), + id: thread_id, session_id, state: ThreadState::Idle, turns: Vec::new(), @@ -218,21 +225,15 @@ impl Thread { } } + /// Create a new thread. + pub fn new(session_id: Uuid) -> Self { + let thread_id = Uuid::new_v4(); + Self::init(session_id, thread_id) + } + /// Create a thread with a specific ID (for DB hydration). pub fn with_id(id: Uuid, session_id: Uuid) -> Self { - let now = Utc::now(); - Self { - id, - session_id, - state: ThreadState::Idle, - turns: Vec::new(), - created_at: now, - updated_at: now, - metadata: serde_json::Value::Null, - pending_approval: None, - pending_auth: None, - in_flight_auth: false, - } + Self::init(session_id, id) } /// Get the current turn number (1-indexed for display). 
@@ -518,6 +519,22 @@ pub struct Turn { } impl Turn { + fn set_tool_outcome_at( + &mut self, + idx: usize, + result: Option, + error: Option, + ) -> Result<(), ToolCallIndexError> { + let len = self.tool_calls.len(); + let tool_call = self + .tool_calls + .get_mut(idx) + .ok_or(ToolCallIndexError::OutOfBounds { idx, len })?; + tool_call.result = result; + tool_call.error = error; + Ok(()) + } + /// Create a new turn. pub fn new(turn_number: usize, user_input: impl Into) -> Self { Self { @@ -574,12 +591,55 @@ impl Turn { } } + /// Record tool call result for a specific tool-call slot. + pub fn record_tool_result_at( + &mut self, + idx: usize, + result: serde_json::Value, + ) -> Result<(), ToolCallIndexError> { + self.set_tool_outcome_at(idx, Some(result), None) + } + + fn parse_tool_result(result_content: &str) -> serde_json::Value { + let trimmed = result_content.trim_start(); + if matches!(trimmed.as_bytes().first(), Some(b'{' | b'[')) { + serde_json::from_str(result_content) + .unwrap_or_else(|_| serde_json::Value::String(result_content.to_string())) + } else { + serde_json::Value::String(result_content.to_string()) + } + } + + /// Record tool call result, parsing structured JSON where possible. + pub fn record_tool_result_content(&mut self, result_content: &str) { + self.record_tool_result(Self::parse_tool_result(result_content)); + } + + /// Record tool call result for a specific slot, parsing structured JSON + /// where possible. + pub fn record_tool_result_content_at( + &mut self, + idx: usize, + result_content: &str, + ) -> Result<(), ToolCallIndexError> { + self.record_tool_result_at(idx, Self::parse_tool_result(result_content)) + } + /// Record tool call error. pub fn record_tool_error(&mut self, error: impl Into) { if let Some(call) = self.tool_calls.last_mut() { call.error = Some(error.into()); } } + + /// Record tool call error for a specific tool-call slot. 
+ pub fn record_tool_error_at( + &mut self, + idx: usize, + error: impl Into, + ) -> Result<(), ToolCallIndexError> { + self.set_tool_outcome_at(idx, None, Some(error.into())) + } } /// Record of a tool call made during a turn. @@ -599,6 +659,8 @@ pub struct TurnToolCall { mod tests { use super::*; + mod record_tool_result_content; + #[test] fn test_session_creation() { let mut session = Session::new("user-123"); @@ -742,6 +804,24 @@ mod tests { assert!(restored.pending_auth.is_none()); } + #[test] + fn test_in_flight_auth_is_transient_across_serde() { + let mut thread = Thread::new(Uuid::new_v4()); + thread.in_flight_auth = true; + + let json = serde_json::to_string(&thread).expect("thread should serialise"); + assert!( + !json.contains("in_flight_auth"), + "in_flight_auth must be omitted from serialised JSON" + ); + + let restored: Thread = serde_json::from_str(&json).expect("thread should deserialise"); + assert!( + !restored.in_flight_auth, + "in_flight_auth must default to false after deserialisation" + ); + } + #[test] fn test_thread_with_id() { let specific_id = Uuid::new_v4(); diff --git a/src/agent/session/tests/record_tool_result_content.rs b/src/agent/session/tests/record_tool_result_content.rs new file mode 100644 index 000000000..e80a1e8e9 --- /dev/null +++ b/src/agent/session/tests/record_tool_result_content.rs @@ -0,0 +1,50 @@ +//! Tests for `Turn::record_tool_result_content` JSON-aware parsing behaviour. 
+ +use rstest::rstest; + +use super::*; + +#[rstest] +#[case( + r#"{"ok":true,"items":[1,2]}"#, + serde_json::json!({"ok": true, "items": [1, 2]}) +)] +#[case("plain text", serde_json::Value::String("plain text".to_string()))] +#[case("[1,2,3]", serde_json::json!([1, 2, 3]))] +#[case("{bad", serde_json::Value::String("{bad".to_string()))] +#[case("[bad", serde_json::Value::String("[bad".to_string()))] +#[case(" {\"ok\":true}", serde_json::json!({"ok": true}))] +fn record_tool_result_content_cases( + #[case] raw_content: &str, + #[case] expected: serde_json::Value, +) { + let mut turn = Turn::new(1, "input"); + turn.record_tool_call("json", serde_json::json!({})); + turn.record_tool_result_content(raw_content); + + assert_eq!(turn.tool_calls[0].result, Some(expected)); +} + +#[test] +fn record_tool_result_at_returns_out_of_bounds_error() { + let mut turn = Turn::new(1, "input"); + turn.record_tool_call("json", serde_json::json!({})); + + let error = turn + .record_tool_result_at(1, serde_json::json!({"ok": true})) + .expect_err("out-of-bounds result write should fail"); + + assert_eq!(error, ToolCallIndexError::OutOfBounds { idx: 1, len: 1 }); +} + +#[test] +fn record_tool_error_at_returns_out_of_bounds_error() { + let mut turn = Turn::new(1, "input"); + turn.record_tool_call("json", serde_json::json!({})); + + let error = turn + .record_tool_error_at(1, "boom") + .expect_err("out-of-bounds error write should fail"); + + assert_eq!(error, ToolCallIndexError::OutOfBounds { idx: 1, len: 1 }); +} diff --git a/src/agent/thread_ops.rs b/src/agent/thread_ops.rs index 78bfebd44..2a9262637 100644 --- a/src/agent/thread_ops.rs +++ b/src/agent/thread_ops.rs @@ -2,12 +2,34 @@ //! //! Extracted from `agent_loop.rs` to isolate thread management (user input //! processing, undo/redo, approval, auth, persistence) from the core loop. +//! +//! This module is organized into submodules by responsibility: +//! - `approval`: Tool approval handling +//! 
- `control`: Thread control commands (undo, redo, interrupt, compact, clear, new, switch, resume) +//! - `dispatch`: Submission dispatch and hook adapters +//! - `document_store`: Document storage for extracted content +//! - `hydration`: Thread hydration from database +//! - `message_rebuild`: Message reconstruction from DB records +//! - `persistence`: Database persistence for messages and tool calls +//! - `turn_compaction_checkpointing`: Pre-turn compaction and undo checkpoints +//! - `turn_execution`: User turn execution and agentic loop orchestration +//! - `turn_preparation`: Thread-state checks, safety validation, and turn setup +//! - `turn_result_finalisation`: Loop-result handling and response persistence pub(crate) mod approval; +mod control; mod dispatch; mod document_store; +mod hydration; mod message_rebuild; mod persistence; +mod turn_compaction_checkpointing; +mod turn_execution; +mod turn_preparation; +mod turn_result_finalisation; + +pub(super) use persistence::TurnPersistContext; +pub(super) use turn_preparation::{PrepareTurnResult, UserTurnRequest}; use std::sync::Arc; @@ -15,19 +37,13 @@ use tokio::sync::Mutex; use uuid::Uuid; use crate::agent::Agent; -use crate::agent::compaction::ContextCompactor; -use crate::agent::dispatcher::AgenticLoopResult; -use crate::agent::session::{Session, ThreadState}; -use crate::agent::submission::{Submission, SubmissionParser, SubmissionResult}; -use crate::channels::web::util::truncate_preview; -use crate::channels::{IncomingMessage, StatusUpdate}; +use crate::agent::session::Session; +use crate::agent::submission::{Submission, SubmissionParser}; +use crate::channels::IncomingMessage; use crate::error::Error; -use crate::llm::ChatMessage; use dispatch::DispatchCtx; use document_store::store_extracted_documents as store_extracted_documents_impl; -use message_rebuild::rebuild_chat_messages_from_db; -use persistence::gateway_conversation_params; pub(super) async fn store_extracted_documents( workspace: &Arc, @@ 
-37,41 +53,6 @@ pub(super) async fn store_extracted_documents( } impl Agent { - async fn hydrate_and_resolve_session_thread( - &self, - message: &IncomingMessage, - ) -> (Arc>, Uuid) { - // Hydrate thread from DB if it's a historical thread not in memory - if let Some(ref external_thread_id) = message.thread_id { - tracing::trace!( - message_id = %message.id, - thread_id = %external_thread_id, - "Hydrating thread from DB" - ); - self.maybe_hydrate_thread(message, external_thread_id).await; - } - - tracing::debug!( - message_id = %message.id, - "Resolving session and thread" - ); - let (session, thread_id) = self - .session_manager - .resolve_thread( - &message.user_id, - &message.channel, - message.thread_id.as_deref(), - ) - .await; - tracing::debug!( - message_id = %message.id, - thread_id = %thread_id, - "Resolved session and thread" - ); - - (session, thread_id) - } - async fn check_auth_mode_intercept( &self, message: &IncomingMessage, @@ -118,11 +99,11 @@ impl Agent { } _ => { // Any control submission (interrupt, undo, etc.) cancels auth mode. - // Clear the in_flight_auth marker; pending_auth is cleared separately - // by the control handler path. + // Clear both auth markers so the next user turn is not intercepted. 
let mut sess = session.lock().await; if let Some(thread) = sess.threads.get_mut(&thread_id) { thread.in_flight_auth = false; + thread.pending_auth = None; } // Fall through to normal handling } @@ -163,7 +144,7 @@ impl Agent { // Parse submission type first let submission = SubmissionParser::parse(&message.content); - let (session, thread_id) = self.hydrate_and_resolve_session_thread(message).await; + let (session, thread_id) = self.hydrate_and_resolve_session_thread(message).await?; if let Some(result) = self .check_auth_mode_intercept(message, &submission, session.clone(), thread_id) @@ -213,796 +194,4 @@ impl Agent { let result = self.dispatch_submission(ctx, submission).await?; self.map_submission_result(message, result).await } - - /// Hydrate a historical thread from DB into memory if not already present. - /// - /// Called before `resolve_thread` so that the session manager finds the - /// thread on lookup instead of creating a new one. - /// - /// Creates an in-memory thread with the exact UUID the frontend sent, - /// even when the conversation has zero messages (e.g. a brand-new - /// assistant thread). Without this, `resolve_thread` would mint a - /// fresh UUID and all messages would land in the wrong conversation. - pub(super) async fn maybe_hydrate_thread( - &self, - message: &IncomingMessage, - external_thread_id: &str, - ) { - // Only hydrate UUID-shaped thread IDs (web gateway uses UUIDs) - let thread_uuid = match Uuid::parse_str(external_thread_id) { - Ok(id) => id, - Err(_) => return, - }; - - // Check if already in memory - let session = self - .session_manager - .get_or_create_session(&message.user_id) - .await; - { - let sess = session.lock().await; - if sess.threads.contains_key(&thread_uuid) { - return; - } - } - - // Load history from DB (may be empty for a newly created thread). 
- let mut chat_messages: Vec = Vec::new(); - let msg_count; - - if let Some(store) = self.store() { - let db_messages = store - .list_conversation_messages(thread_uuid) - .await - .unwrap_or_default(); - msg_count = db_messages.len(); - chat_messages = rebuild_chat_messages_from_db(&db_messages, self.safety()); - } else { - msg_count = 0; - } - - // Create thread with the historical ID and restore messages - let session_id = { - let sess = session.lock().await; - sess.id - }; - - let mut thread = crate::agent::session::Thread::with_id(thread_uuid, session_id); - if !chat_messages.is_empty() { - thread.restore_from_messages(chat_messages); - } - - // Insert into session and register with session manager - { - let mut sess = session.lock().await; - sess.threads.insert(thread_uuid, thread); - sess.active_thread = Some(thread_uuid); - sess.last_active_at = chrono::Utc::now(); - } - - self.session_manager - .register_thread( - &message.user_id, - &message.channel, - thread_uuid, - Arc::clone(&session), - ) - .await; - - tracing::debug!( - "Hydrated thread {} from DB ({} messages)", - thread_uuid, - msg_count - ); - } - - pub(super) async fn process_user_input( - &self, - message: &IncomingMessage, - session: Arc>, - thread_id: Uuid, - content: &str, - ) -> Result { - tracing::debug!( - message_id = %message.id, - thread_id = %thread_id, - content_len = content.len(), - "Processing user input" - ); - - // First check thread state without holding lock during I/O - let thread_state = { - let sess = session.lock().await; - let thread = sess - .threads - .get(&thread_id) - .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; - thread.state - }; - - tracing::debug!( - message_id = %message.id, - thread_id = %thread_id, - thread_state = ?thread_state, - "Checked thread state" - ); - - // Check thread state - match thread_state { - ThreadState::Processing => { - tracing::warn!( - message_id = %message.id, - thread_id = %thread_id, - "Thread is 
processing, rejecting new input" - ); - return Ok(SubmissionResult::error( - "Turn in progress. Use /interrupt to cancel.", - )); - } - ThreadState::AwaitingApproval => { - tracing::warn!( - message_id = %message.id, - thread_id = %thread_id, - "Thread awaiting approval, rejecting new input" - ); - return Ok(SubmissionResult::error( - "Waiting for approval. Use /interrupt to cancel.", - )); - } - ThreadState::Completed => { - tracing::warn!( - message_id = %message.id, - thread_id = %thread_id, - "Thread completed, rejecting new input" - ); - return Ok(SubmissionResult::error( - "Thread completed. Use /thread new.", - )); - } - ThreadState::Idle | ThreadState::Interrupted => { - // Can proceed - } - } - - // Safety validation for user input - let validation = self.safety().validate_input(content); - if !validation.is_valid { - let details = validation - .errors - .iter() - .map(|e| format!("{}: {}", e.field, e.message)) - .collect::>() - .join("; "); - return Ok(SubmissionResult::error(format!( - "Input rejected by safety validation: {}", - details - ))); - } - - let violations = self.safety().check_policy(content); - if violations - .iter() - .any(|rule| rule.action == crate::safety::PolicyAction::Block) - { - return Ok(SubmissionResult::error("Input rejected by safety policy.")); - } - - // Scan inbound messages for secrets (API keys, tokens). - // Catching them here prevents the LLM from echoing them back, which - // would trigger the outbound leak detector and create error loops. 
- if let Some(warning) = self.safety().scan_inbound_for_secrets(content) { - tracing::warn!( - user = %message.user_id, - channel = %message.channel, - "Inbound message blocked: contains leaked secret" - ); - return Ok(SubmissionResult::error(warning)); - } - - // Handle explicit commands (starting with /) directly - // Everything else goes through the normal agentic loop with tools - let temp_message = IncomingMessage { - content: content.to_string(), - ..message.clone() - }; - - if let Some(intent) = self.router.route_command(&temp_message) { - // Explicit command like /status, /job, /list - handle directly - return self.handle_job_or_command(intent, message).await; - } - - // Natural language goes through the agentic loop - // Job tools (create_job, list_jobs, etc.) are in the tool registry - - // Auto-compact if needed BEFORE adding new turn - { - let mut sess = session.lock().await; - let thread = sess - .threads - .get_mut(&thread_id) - .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; - - let messages = thread.messages(); - if let Some(strategy) = self.context_monitor.suggest_compaction(&messages) { - let pct = self.context_monitor.usage_percent(&messages); - tracing::info!("Context at {:.1}% capacity, auto-compacting", pct); - - // Notify the user that compaction is happening - let _ = self - .channels - .send_status( - &message.channel, - StatusUpdate::Status(format!( - "Context at {:.0}% capacity, compacting...", - pct - )), - &message.metadata, - ) - .await; - - let compactor = ContextCompactor::new(self.llm().clone()); - if let Err(e) = compactor - .compact(thread, strategy, self.workspace().map(|w| w.as_ref())) - .await - { - tracing::warn!("Auto-compaction failed: {}", e); - } - } - } - - // Create checkpoint before turn - let undo_mgr = self.session_manager.get_undo_manager(thread_id).await; - { - let sess = session.lock().await; - let thread = sess - .threads - .get(&thread_id) - .ok_or_else(|| 
Error::from(crate::error::JobError::NotFound { id: thread_id }))?; - - let mut mgr = undo_mgr.lock().await; - mgr.checkpoint( - thread.turn_number(), - thread.messages(), - format!("Before turn {}", thread.turn_number()), - ); - } - - // Augment content with attachment context (transcripts, metadata, images) - let augmented = - crate::agent::attachments::augment_with_attachments(content, &message.attachments); - let (effective_content, image_parts) = match &augmented { - Some(result) => (result.text.as_str(), result.image_parts.clone()), - None => (content, Vec::new()), - }; - - // Start the turn and get messages - let turn_messages = { - let mut sess = session.lock().await; - let thread = sess - .threads - .get_mut(&thread_id) - .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; - let turn = thread.start_turn(effective_content); - turn.image_content_parts = image_parts; - thread.messages() - }; - - // Persist user message to DB immediately so it survives crashes - tracing::debug!( - message_id = %message.id, - thread_id = %thread_id, - "Persisting user message to DB" - ); - self.persist_user_message(thread_id, &message.user_id, effective_content) - .await; - - tracing::debug!( - message_id = %message.id, - thread_id = %thread_id, - "User message persisted, starting agentic loop" - ); - - // Send thinking status - let _ = self - .channels - .send_status( - &message.channel, - StatusUpdate::Thinking("Processing...".into()), - &message.metadata, - ) - .await; - - // Run the agentic tool execution loop - let result = self - .run_agentic_loop(message, session.clone(), thread_id, turn_messages) - .await; - - // Re-acquire lock and check if interrupted - let interrupted = { - let mut sess = session.lock().await; - let thread = sess - .threads - .get_mut(&thread_id) - .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; - thread.state == ThreadState::Interrupted - }; - if interrupted { - let _ = self - .channels 
- .send_status( - &message.channel, - StatusUpdate::Status("Interrupted".into()), - &message.metadata, - ) - .await; - return Ok(SubmissionResult::Interrupted); - } - - // Re-acquire lock for processing result - let mut sess = session.lock().await; - let thread = sess - .threads - .get_mut(&thread_id) - .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; - - // Complete, fail, or request approval - match result { - Ok(AgenticLoopResult::Response(response)) => { - // Drop the session lock before running the response transform hook - drop(sess); - - // Hook: TransformResponse — allow hooks to modify or reject the final response - let response = { - let event = crate::hooks::HookEvent::ResponseTransform { - user_id: message.user_id.clone(), - thread_id: thread_id.to_string(), - response: response.clone(), - }; - match self.hooks().run(&event).await { - Err(crate::hooks::HookError::Rejected { reason }) => { - format!("[Response filtered: {}]", reason) - } - Ok(crate::hooks::HookOutcome::Reject { reason }) => { - format!("[Response filtered: {}]", reason) - } - Err(err) => { - tracing::warn!("TransformResponse hook failed open: {}", err); - response - } - Ok(crate::hooks::HookOutcome::Continue { - modified: Some(new_response), - }) => new_response, - _ => response, // fail-open: use original - } - }; - - // Re-acquire lock to complete turn and snapshot data - let completion = { - let mut sess = session.lock().await; - let thread = sess.threads.get_mut(&thread_id).ok_or_else(|| { - Error::from(crate::error::JobError::NotFound { id: thread_id }) - })?; - if thread.state == ThreadState::Interrupted { - None - } else { - thread.complete_turn(&response); - Some( - thread - .turns - .last() - .map(|t| (t.turn_number, t.tool_calls.clone())) - .unwrap_or_default(), - ) - } - }; - let Some((turn_number, tool_calls)) = completion else { - let _ = self - .channels - .send_status( - &message.channel, - StatusUpdate::Status("Interrupted".into()), - 
&message.metadata, - ) - .await; - return Ok(SubmissionResult::Interrupted); - }; - // Lock is dropped here at end of block - - let _ = self - .channels - .send_status( - &message.channel, - StatusUpdate::Status("Done".into()), - &message.metadata, - ) - .await; - - // Persist tool calls then assistant response (user message already persisted at turn start) - self.persist_tool_calls(thread_id, &message.user_id, turn_number, &tool_calls) - .await; - self.persist_assistant_response(thread_id, &message.user_id, &response) - .await; - - Ok(SubmissionResult::response(response)) - } - Ok(AgenticLoopResult::NeedApproval { pending }) => { - // Store pending approval in thread and update state - let request_id = pending.request_id; - let tool_name = pending.tool_name.clone(); - let description = pending.description.clone(); - let parameters = pending.display_parameters.clone(); - thread.await_approval(pending); - // Drop the session lock before async operations - drop(sess); - - let _ = self - .channels - .send_status( - &message.channel, - StatusUpdate::Status("Awaiting approval".into()), - &message.metadata, - ) - .await; - Ok(SubmissionResult::NeedApproval { - request_id, - tool_name, - description, - parameters, - }) - } - Err(e) => { - thread.fail_turn(e.to_string()); - // User message already persisted at turn start; nothing else to save - Ok(SubmissionResult::error(e.to_string())) - } - } - } - - /// Persist the user message to the DB at turn start (before the agentic loop). - /// - /// This ensures the user message is durable even if the process crashes - /// mid-response. Call this right after `thread.start_turn()`. 
- pub(super) async fn persist_user_message( - &self, - thread_id: Uuid, - user_id: &str, - user_input: &str, - ) { - let store = match self.store() { - Some(s) => Arc::clone(s), - None => return, - }; - - if let Err(e) = store - .ensure_conversation(gateway_conversation_params(thread_id, user_id)) - .await - { - tracing::warn!("Failed to ensure conversation {}: {}", thread_id, e); - return; - } - - if let Err(e) = store - .add_conversation_message(thread_id, "user", user_input) - .await - { - tracing::warn!("Failed to persist user message: {}", e); - } - } - - /// Persist the assistant response to the DB after the agentic loop completes. - /// - /// Re-ensures the conversation row exists so that assistant responses are - /// still persisted even if `persist_user_message` failed transiently at - /// turn start (e.g. a brief DB blip that resolved before response time). - pub(super) async fn persist_assistant_response( - &self, - thread_id: Uuid, - user_id: &str, - response: &str, - ) { - let store = match self.store() { - Some(s) => Arc::clone(s), - None => return, - }; - - if let Err(e) = store - .ensure_conversation(gateway_conversation_params(thread_id, user_id)) - .await - { - tracing::warn!("Failed to ensure conversation {}: {}", thread_id, e); - return; - } - - if let Err(e) = store - .add_conversation_message(thread_id, "assistant", response) - .await - { - tracing::warn!("Failed to persist assistant message: {}", e); - } - } - - /// Persist tool call summaries to the DB as a `role="tool_calls"` message. - /// - /// Stored between the user and assistant messages so that - /// `build_turns_from_db_messages` can reconstruct the tool call history. - /// Content is a JSON array of tool call summaries. 
- pub(super) async fn persist_tool_calls( - &self, - thread_id: Uuid, - user_id: &str, - turn_number: usize, - tool_calls: &[crate::agent::session::TurnToolCall], - ) { - if tool_calls.is_empty() { - return; - } - - let store = match self.store() { - Some(s) => Arc::clone(s), - None => return, - }; - - let summaries: Vec = tool_calls - .iter() - .enumerate() - .map(|(i, tc)| { - let mut obj = serde_json::json!({ - "name": tc.name, - "call_id": format!("turn{}_{}", turn_number, i), - }); - if let Some(ref result) = tc.result { - let preview = match result { - serde_json::Value::String(s) => truncate_preview(s, 500), - other => truncate_preview(&other.to_string(), 500), - }; - obj["result_preview"] = serde_json::Value::String(preview); - // Store full result (truncated to ~1000 chars) for LLM context rebuild - let full_result = match result { - serde_json::Value::String(s) => truncate_preview(s, 1000), - other => truncate_preview(&other.to_string(), 1000), - }; - obj["result"] = serde_json::Value::String(full_result); - } - if let Some(ref error) = tc.error { - obj["error"] = serde_json::Value::String(truncate_preview(error, 200)); - } - obj - }) - .collect(); - - let content = match serde_json::to_string(&summaries) { - Ok(c) => c, - Err(e) => { - tracing::warn!("Failed to serialize tool calls: {}", e); - return; - } - }; - - if let Err(e) = store - .ensure_conversation(gateway_conversation_params(thread_id, user_id)) - .await - { - tracing::warn!("Failed to ensure conversation {}: {}", thread_id, e); - return; - } - - if let Err(e) = store - .add_conversation_message(thread_id, "tool_calls", &content) - .await - { - tracing::warn!("Failed to persist tool calls: {}", e); - } - } - - pub(super) async fn process_undo( - &self, - session: Arc>, - thread_id: Uuid, - ) -> Result { - let undo_mgr = self.session_manager.get_undo_manager(thread_id).await; - let mut mgr = undo_mgr.lock().await; - - if !mgr.can_undo() { - return Ok(SubmissionResult::ok_with_message("Nothing 
to undo.")); - } - - let mut sess = session.lock().await; - let thread = sess - .threads - .get_mut(&thread_id) - .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; - - // Save current state to redo, get previous checkpoint - let current_messages = thread.messages(); - let current_turn = thread.turn_number(); - - if let Some(checkpoint) = mgr.undo(current_turn, current_messages) { - // Extract values before consuming the reference - let turn_number = checkpoint.turn_number; - let messages = checkpoint.messages.clone(); - let undo_count = mgr.undo_count(); - // Restore thread from checkpoint - thread.restore_from_messages(messages); - Ok(SubmissionResult::ok_with_message(format!( - "Undone to turn {}. {} undo(s) remaining.", - turn_number, undo_count - ))) - } else { - Ok(SubmissionResult::error("Undo failed.")) - } - } - - pub(super) async fn process_redo( - &self, - session: Arc>, - thread_id: Uuid, - ) -> Result { - let undo_mgr = self.session_manager.get_undo_manager(thread_id).await; - let mut mgr = undo_mgr.lock().await; - - if !mgr.can_redo() { - return Ok(SubmissionResult::ok_with_message("Nothing to redo.")); - } - - let mut sess = session.lock().await; - let thread = sess - .threads - .get_mut(&thread_id) - .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; - - let current_messages = thread.messages(); - let current_turn = thread.turn_number(); - - if let Some(checkpoint) = mgr.redo(current_turn, current_messages) { - thread.restore_from_messages(checkpoint.messages); - Ok(SubmissionResult::ok_with_message(format!( - "Redone to turn {}.", - checkpoint.turn_number - ))) - } else { - Ok(SubmissionResult::error("Redo failed.")) - } - } - - pub(super) async fn process_interrupt( - &self, - session: Arc>, - thread_id: Uuid, - ) -> Result { - let mut sess = session.lock().await; - let thread = sess - .threads - .get_mut(&thread_id) - .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: 
thread_id }))?; - - match thread.state { - ThreadState::Processing | ThreadState::AwaitingApproval => { - thread.interrupt(); - Ok(SubmissionResult::ok_with_message("Interrupted.")) - } - _ => Ok(SubmissionResult::ok_with_message("Nothing to interrupt.")), - } - } - - pub(super) async fn process_compact( - &self, - session: Arc>, - thread_id: Uuid, - ) -> Result { - let mut sess = session.lock().await; - let thread = sess - .threads - .get_mut(&thread_id) - .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; - - let messages = thread.messages(); - let usage = self.context_monitor.usage_percent(&messages); - let strategy = self - .context_monitor - .suggest_compaction(&messages) - .unwrap_or( - crate::agent::context_monitor::CompactionStrategy::Summarize { keep_recent: 5 }, - ); - - let compactor = ContextCompactor::new(self.llm().clone()); - match compactor - .compact(thread, strategy, self.workspace().map(|w| w.as_ref())) - .await - { - Ok(result) => { - let mut msg = format!( - "Compacted: {} turns removed, {} → {} tokens (was {:.1}% full)", - result.turns_removed, result.tokens_before, result.tokens_after, usage - ); - if result.summary_written { - msg.push_str(", summary saved to workspace"); - } - Ok(SubmissionResult::ok_with_message(msg)) - } - Err(e) => Ok(SubmissionResult::error(format!("Compaction failed: {}", e))), - } - } - - pub(super) async fn process_clear( - &self, - session: Arc>, - thread_id: Uuid, - ) -> Result { - let mut sess = session.lock().await; - let thread = sess - .threads - .get_mut(&thread_id) - .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; - thread.turns.clear(); - thread.state = ThreadState::Idle; - - // Clear undo history too - let undo_mgr = self.session_manager.get_undo_manager(thread_id).await; - undo_mgr.lock().await.clear(); - - Ok(SubmissionResult::ok_with_message("Thread cleared.")) - } - - pub(super) async fn process_new_thread( - &self, - message: 
&IncomingMessage, - ) -> Result { - let session = self - .session_manager - .get_or_create_session(&message.user_id) - .await; - let mut sess = session.lock().await; - let thread = sess.create_thread(); - let thread_id = thread.id; - Ok(SubmissionResult::ok_with_message(format!( - "New thread: {}", - thread_id - ))) - } - - pub(super) async fn process_switch_thread( - &self, - message: &IncomingMessage, - target_thread_id: Uuid, - ) -> Result { - let session = self - .session_manager - .get_or_create_session(&message.user_id) - .await; - let mut sess = session.lock().await; - - if sess.switch_thread(target_thread_id) { - Ok(SubmissionResult::ok_with_message(format!( - "Switched to thread {}", - target_thread_id - ))) - } else { - Ok(SubmissionResult::error("Thread not found.")) - } - } - - pub(super) async fn process_resume( - &self, - session: Arc>, - thread_id: Uuid, - checkpoint_id: Uuid, - ) -> Result { - let undo_mgr = self.session_manager.get_undo_manager(thread_id).await; - let mut mgr = undo_mgr.lock().await; - - if let Some(checkpoint) = mgr.restore(checkpoint_id) { - let mut sess = session.lock().await; - let thread = sess - .threads - .get_mut(&thread_id) - .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; - thread.restore_from_messages(checkpoint.messages); - Ok(SubmissionResult::ok_with_message(format!( - "Resumed from checkpoint: {}", - checkpoint.description - ))) - } else { - Ok(SubmissionResult::error("Checkpoint not found.")) - } - } } diff --git a/src/agent/thread_ops/approval.rs b/src/agent/thread_ops/approval.rs index 6f44638e4..21f863334 100644 --- a/src/agent/thread_ops/approval.rs +++ b/src/agent/thread_ops/approval.rs @@ -40,11 +40,12 @@ use uuid::Uuid; use crate::agent::Agent; use crate::agent::dispatcher::{ - AgenticLoopResult, ChatToolRequest, check_auth_required, execute_chat_tool_standalone, + AgenticLoopResult, ToolCallSpec, check_auth_required, execute_chat_tool_standalone, parse_auth_result, }; use 
crate::agent::session::{PendingApproval, Session, ThreadState}; use crate::agent::submission::SubmissionResult; +use crate::agent::thread_ops::TurnPersistContext; use crate::channels::{IncomingMessage, StatusUpdate}; use crate::context::JobContext; use crate::error::Error; @@ -166,6 +167,8 @@ struct AuthInterceptParams<'a> { env: &'a MsgEnv, /// Tool execution result (used to extract auth URLs). tool_result: &'a Result, + /// Tool name for auth-barrier result parsing. + tool_name: &'a str, /// Extension name requiring authentication. ext_name: String, /// Instructions to display to the user. @@ -363,7 +366,7 @@ impl Agent { if is_tool_error { turn.record_tool_error(result_content.clone()); } else { - turn.record_tool_result(serde_json::json!(result_content)); + turn.record_tool_result_content(&result_content); } } } @@ -385,6 +388,7 @@ impl Agent { thread_id: scope.thread_id, env: &scope.env, tool_result, + tool_name: &pending.tool_name, ext_name, instructions: instructions.clone(), pending: Some(pending.clone()), @@ -550,8 +554,8 @@ impl Agent { let result = execute_chat_tool_standalone( &tools, &safety, - &ChatToolRequest { - tool_name: &tc.name, + &ToolCallSpec { + name: &tc.name, params: &tc.arguments, }, &job_ctx, @@ -643,7 +647,7 @@ impl Agent { if is_deferred_error { turn.record_tool_error(deferred_content.clone()); } else { - turn.record_tool_result(serde_json::json!(deferred_content)); + turn.record_tool_result_content(&deferred_content); } } } @@ -673,6 +677,7 @@ impl Agent { thread_id: scope.thread_id, env: &scope.env, tool_result: &deferred_result, + tool_name: &tc.name, ext_name, instructions: instructions.clone(), pending: Some(fresh_pending), @@ -773,13 +778,12 @@ impl Agent { }; // User message already persisted at turn start; save tool calls then assistant response - self.persist_tool_calls( - scope.thread_id, - &scope.env.user_id, + let persist_ctx = TurnPersistContext { + thread_id: scope.thread_id, + user_id: &scope.env.user_id, turn_number, - 
&tool_calls, - ) - .await; + }; + self.persist_tool_calls(&persist_ctx, &tool_calls).await; self.persist_assistant_response(scope.thread_id, &scope.env.user_id, response) .await; let _ = self @@ -1105,7 +1109,7 @@ impl Agent { /// to preserve deferred tool calls and context messages, completes + persists /// the turn, and sends the AuthRequired status to the channel. async fn handle_auth_intercept(&self, params: AuthInterceptParams<'_>) { - let auth_data = parse_auth_result(params.tool_result); + let auth_data = parse_auth_result(params.tool_name, params.tool_result); { let mut sess = params.session.lock().await; if let Some(thread) = sess.threads.get_mut(¶ms.thread_id) { @@ -1259,6 +1263,7 @@ impl Agent { let mut sess = scope.session.lock().await; if let Some(thread) = sess.threads.get_mut(&scope.thread_id) { thread.pending_auth = None; + thread.clear_pending_approval(); } } tracing::info!( diff --git a/src/agent/thread_ops/control.rs b/src/agent/thread_ops/control.rs new file mode 100644 index 000000000..ad494b09b --- /dev/null +++ b/src/agent/thread_ops/control.rs @@ -0,0 +1,305 @@ +//! Thread control command handlers. +//! +//! Contains handlers for thread lifecycle state transitions: +//! - Undo/redo operations +//! - Interrupt processing +//! - Context compaction +//! - Thread clearing +//! - New thread creation +//! - Thread switching +//! 
- Resume from checkpoint + +use std::sync::Arc; + +use chrono::Utc; +use tokio::sync::Mutex; +use uuid::Uuid; + +use crate::agent::Agent; +use crate::agent::compaction::ContextCompactor; +use crate::agent::session::{Session, ThreadState}; +use crate::agent::submission::SubmissionResult; +use crate::agent::undo::{Checkpoint, UndoManager}; +use crate::error::Error; +use crate::llm::ChatMessage; + +#[derive(Clone, Copy)] +enum RewindOp { + Undo, + Redo, +} + +impl Agent { + fn availability_message(mgr: &UndoManager, op: RewindOp) -> Option<&'static str> { + match op { + RewindOp::Undo if !mgr.can_undo() => Some("Nothing to undo."), + RewindOp::Redo if !mgr.can_redo() => Some("Nothing to redo."), + _ => None, + } + } + + fn failure_msg(op: RewindOp) -> &'static str { + match op { + RewindOp::Undo => "Undo failed.", + RewindOp::Redo => "Redo failed.", + } + } + + fn success_msg(op: RewindOp, turn: usize, undo_count: usize) -> String { + match op { + RewindOp::Undo => format!("Undone to turn {turn}.\n{undo_count} undo(s) remaining."), + RewindOp::Redo => format!("Redone to turn {turn}."), + } + } + + fn perform_rewind( + mgr: &mut UndoManager, + op: RewindOp, + current_turn: usize, + current_messages: Vec, + ) -> Option { + match op { + RewindOp::Undo => mgr.undo(current_turn, current_messages), + RewindOp::Redo => mgr.redo(current_turn, current_messages), + } + } + + async fn restore_thread_from_checkpoint( + session: &Arc>, + thread_id: Uuid, + messages: Vec, + ) -> Result<(), Error> { + let mut sess = session.lock().await; + let thread = sess + .threads + .get_mut(&thread_id) + .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; + thread.restore_from_messages(messages); + thread.updated_at = Utc::now(); + Ok(()) + } + + async fn process_rewind( + &self, + session: Arc>, + thread_id: Uuid, + op: RewindOp, + ) -> Result { + let undo_mgr = self.session_manager.get_undo_manager(thread_id).await; + let mut mgr = undo_mgr.lock().await; + + if 
let Some(msg) = Self::availability_message(&mgr, op) { + return Ok(SubmissionResult::ok_with_message(msg.to_string())); + } + + let (turn, messages) = { + let sess = session.lock().await; + let thread = sess + .threads + .get(&thread_id) + .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; + (thread.turn_number(), thread.messages()) + }; + + let Some(cp) = Self::perform_rewind(&mut mgr, op, turn, messages) else { + return Ok(SubmissionResult::error(Self::failure_msg(op))); + }; + + let msg = Self::success_msg(op, cp.turn_number, mgr.undo_count()); + Self::restore_thread_from_checkpoint(&session, thread_id, cp.messages).await?; + Ok(SubmissionResult::ok_with_message(msg)) + } + + pub(super) async fn process_undo( + &self, + session: Arc>, + thread_id: Uuid, + ) -> Result { + self.process_rewind(session, thread_id, RewindOp::Undo) + .await + } + + pub(super) async fn process_redo( + &self, + session: Arc>, + thread_id: Uuid, + ) -> Result { + self.process_rewind(session, thread_id, RewindOp::Redo) + .await + } + + pub(super) async fn process_interrupt( + &self, + session: Arc>, + thread_id: Uuid, + ) -> Result { + let mut sess = session.lock().await; + let thread = sess + .threads + .get_mut(&thread_id) + .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; + + match thread.state { + ThreadState::Processing | ThreadState::AwaitingApproval => { + thread.interrupt(); + Ok(SubmissionResult::ok_with_message("Interrupted.")) + } + _ => Ok(SubmissionResult::ok_with_message("Nothing to interrupt.")), + } + } + + pub(super) async fn process_compact( + &self, + session: Arc>, + thread_id: Uuid, + ) -> Result { + let (mut thread_snapshot, usage, strategy) = { + let sess = session.lock().await; + let thread = sess + .threads + .get(&thread_id) + .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; + + let messages = thread.messages(); + let usage = 
self.context_monitor.usage_percent(&messages); + let strategy = self + .context_monitor + .suggest_compaction(&messages) + .unwrap_or( + crate::agent::context_monitor::CompactionStrategy::Summarize { keep_recent: 5 }, + ); + + (thread.clone(), usage, strategy) + }; + + let original_updated_at = thread_snapshot.updated_at; + let original_turns_len = thread_snapshot.turns.len(); + let compactor = ContextCompactor::new(self.llm().clone()); + match compactor + .compact( + &mut thread_snapshot, + strategy, + self.workspace().map(|w| w.as_ref()), + ) + .await + { + Ok(result) => { + let mut sess = session.lock().await; + let thread = sess.threads.get_mut(&thread_id).ok_or_else(|| { + Error::from(crate::error::JobError::NotFound { id: thread_id }) + })?; + if thread.updated_at != original_updated_at + || thread.turns.len() != original_turns_len + { + return Ok(SubmissionResult::error( + "Thread changed while compaction was running. Please retry.", + )); + } + thread.turns = thread_snapshot.turns; + thread.updated_at = Utc::now(); + + let mut msg = format!( + "Compacted: {} turns removed, {} → {} tokens (was {:.1}% full)", + result.turns_removed, result.tokens_before, result.tokens_after, usage + ); + if result.summary_written { + msg.push_str(", summary saved to workspace"); + } + Ok(SubmissionResult::ok_with_message(msg)) + } + Err(e) => Ok(SubmissionResult::error(format!("Compaction failed: {}", e))), + } + } + + pub(super) async fn process_clear( + &self, + session: Arc>, + thread_id: Uuid, + ) -> Result { + let undo_mgr = self.session_manager.get_undo_manager(thread_id).await; + undo_mgr.lock().await.clear(); + + let mut sess = session.lock().await; + let thread = sess + .threads + .get_mut(&thread_id) + .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; + thread.turns.clear(); + thread.state = ThreadState::Idle; + thread.updated_at = Utc::now(); + + Ok(SubmissionResult::ok_with_message("Thread cleared.")) + } + + pub(super) async fn 
process_new_thread( + &self, + message: &crate::channels::IncomingMessage, + ) -> Result { + let session = self + .session_manager + .get_or_create_session(&message.user_id) + .await; + let mut sess = session.lock().await; + let thread = sess.create_thread(); + let thread_id = thread.id; + Ok(SubmissionResult::ok_with_message(format!( + "New thread: {}", + thread_id + ))) + } + + pub(super) async fn process_switch_thread( + &self, + message: &crate::channels::IncomingMessage, + target_thread_id: Uuid, + ) -> Result { + let session = self + .session_manager + .get_or_create_session(&message.user_id) + .await; + let mut sess = session.lock().await; + + if sess.switch_thread(target_thread_id) { + Ok(SubmissionResult::ok_with_message(format!( + "Switched to thread {}", + target_thread_id + ))) + } else { + Ok(SubmissionResult::error("Thread not found.")) + } + } + + pub(super) async fn process_resume( + &self, + session: Arc>, + thread_id: Uuid, + checkpoint_id: Uuid, + ) -> Result { + { + let sess = session.lock().await; + let _thread = sess + .threads + .get(&thread_id) + .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; + } + + let undo_mgr = self.session_manager.get_undo_manager(thread_id).await; + let mut mgr = undo_mgr.lock().await; + + if let Some(checkpoint) = mgr.restore(checkpoint_id) { + let mut sess = session.lock().await; + let thread = sess + .threads + .get_mut(&thread_id) + .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; + thread.restore_from_messages(checkpoint.messages); + thread.updated_at = Utc::now(); + Ok(SubmissionResult::ok_with_message(format!( + "Resumed from checkpoint: {}", + checkpoint.description + ))) + } else { + Ok(SubmissionResult::error("Checkpoint not found.")) + } + } +} diff --git a/src/agent/thread_ops/dispatch.rs b/src/agent/thread_ops/dispatch.rs index 7fe398f04..b9cece87a 100644 --- a/src/agent/thread_ops/dispatch.rs +++ b/src/agent/thread_ops/dispatch.rs @@ -8,6 
+8,7 @@ use uuid::Uuid; use crate::agent::Agent; use crate::agent::session::Session; use crate::agent::submission::{Submission, SubmissionParser, SubmissionResult}; +use crate::agent::thread_ops::UserTurnRequest; use crate::agent::thread_ops::approval::{ApprovalParams, TurnScope}; use crate::channels::{IncomingMessage, StatusUpdate}; use crate::error::Error; @@ -111,16 +112,41 @@ impl Agent { } } + /// Route an approval decision to `process_approval`. + /// + /// Called by both `ExecApproval` (which carries an explicit `request_id`) and + /// `ApprovalResponse` (which relies on the session's pending approval slot). + async fn dispatch_approval( + &self, + ctx: &DispatchCtx, + params: ApprovalParams, + ) -> Result { + let scope = TurnScope::new(ctx.session.clone(), ctx.thread_id, &ctx.message); + self.process_approval(scope, params).await + } + + /// Build a [`UserTurnRequest`] from the dispatch context and delegate to + /// [`Agent::process_user_input`]. + async fn dispatch_user_input( + &self, + ctx: DispatchCtx, + content: String, + ) -> Result { + let req = UserTurnRequest { + session: ctx.session, + thread_id: ctx.thread_id, + content, + }; + self.process_user_input(&ctx.message, req).await + } + pub(super) async fn dispatch_submission( &self, ctx: DispatchCtx, submission: Submission, ) -> Result { match submission { - Submission::UserInput { content } => { - self.process_user_input(&ctx.message, ctx.session, ctx.thread_id, &content) - .await - } + Submission::UserInput { content } => self.dispatch_user_input(ctx, content).await, Submission::SystemCommand { command, args } => { tracing::debug!( "[agent_loop] SystemCommand: command={}, channel={}", @@ -159,22 +185,26 @@ impl Agent { approved, always, } => { - let scope = TurnScope::new(ctx.session.clone(), ctx.thread_id, &ctx.message); - let params = ApprovalParams { - request_id: Some(request_id), - approved, - always, - }; - self.process_approval(scope, params).await + self.dispatch_approval( + &ctx, + 
ApprovalParams { + request_id: Some(request_id), + approved, + always, + }, + ) + .await } Submission::ApprovalResponse { approved, always } => { - let scope = TurnScope::new(ctx.session.clone(), ctx.thread_id, &ctx.message); - let params = ApprovalParams { - request_id: None, - approved, - always, - }; - self.process_approval(scope, params).await + self.dispatch_approval( + &ctx, + ApprovalParams { + request_id: None, + approved, + always, + }, + ) + .await } } } diff --git a/src/agent/thread_ops/hydration.rs b/src/agent/thread_ops/hydration.rs new file mode 100644 index 000000000..921e7b568 --- /dev/null +++ b/src/agent/thread_ops/hydration.rs @@ -0,0 +1,141 @@ +//! Thread hydration from database. +//! +//! Handles loading historical threads from the database into memory, +//! including message reconstruction and session registration. + +use std::sync::Arc; + +use tokio::sync::Mutex; +use uuid::Uuid; + +use crate::agent::Agent; +use crate::agent::session::Session; +use crate::agent::thread_ops::message_rebuild::rebuild_chat_messages_from_db; +use crate::channels::IncomingMessage; +use crate::llm::ChatMessage; + +impl Agent { + /// Hydrate and resolve session/thread for an incoming message. + /// + /// This is the main entry point for message handling. It hydrates the thread + /// from the database if needed, then resolves the session and thread IDs. 
+ pub(super) async fn hydrate_and_resolve_session_thread( + &self, + message: &IncomingMessage, + ) -> Result<(Arc>, Uuid), crate::error::Error> { + // Hydrate thread from DB if it's a historical thread not in memory + if let Some(ref external_thread_id) = message.thread_id { + tracing::trace!( + message_id = %message.id, + thread_id = %external_thread_id, + "Hydrating thread from DB" + ); + self.maybe_hydrate_thread(message, external_thread_id) + .await?; + } + + tracing::debug!( + message_id = %message.id, + "Resolving session and thread" + ); + let (session, thread_id) = self + .session_manager + .resolve_thread( + &message.user_id, + &message.channel, + message.thread_id.as_deref(), + ) + .await; + tracing::debug!( + message_id = %message.id, + thread_id = %thread_id, + "Resolved session and thread" + ); + + Ok((session, thread_id)) + } + + /// Hydrate a historical thread from DB into memory if not already present. + /// + /// Called before `resolve_thread` so that the session manager finds the + /// thread on lookup instead of creating a new one. + /// + /// Creates an in-memory thread with the exact UUID the frontend sent, + /// even when the conversation has zero messages (e.g. a brand-new + /// assistant thread). Without this, `resolve_thread` would mint a + /// fresh UUID and all messages would land in the wrong conversation. + pub(super) async fn maybe_hydrate_thread( + &self, + message: &IncomingMessage, + external_thread_id: &str, + ) -> Result<(), crate::error::Error> { + // Only hydrate UUID-shaped thread IDs (web gateway uses UUIDs) + let thread_uuid = match Uuid::parse_str(external_thread_id) { + Ok(id) => id, + Err(_) => return Ok(()), + }; + + // Check if already in memory + let session = self + .session_manager + .get_or_create_session(&message.user_id) + .await; + { + let sess = session.lock().await; + if sess.threads.contains_key(&thread_uuid) { + return Ok(()); + } + } + + // Load history from DB (may be empty for a newly created thread). 
+ let mut chat_messages: Vec = Vec::new(); + let msg_count; + + if let Some(store) = self.store() { + let db_messages = store.list_conversation_messages(thread_uuid).await?; + msg_count = db_messages.len(); + chat_messages = rebuild_chat_messages_from_db(&db_messages, self.safety()); + } else { + msg_count = 0; + } + + // Create thread with the historical ID and restore messages + let session_id = { + let sess = session.lock().await; + sess.id + }; + + let mut thread = crate::agent::session::Thread::with_id(thread_uuid, session_id); + if !chat_messages.is_empty() { + thread.restore_from_messages(chat_messages); + } + + // Insert into session and register with session manager + { + let mut sess = session.lock().await; + if sess.threads.contains_key(&thread_uuid) { + return Ok(()); + } + sess.threads.insert(thread_uuid, thread); + sess.active_thread = Some(thread_uuid); + sess.last_active_at = chrono::Utc::now(); + } + + self.session_manager + .register_thread( + &message.user_id, + &message.channel, + thread_uuid, + Arc::clone(&session), + ) + .await; + + tracing::debug!( + "Hydrated thread {} from DB ({} messages)", + thread_uuid, + msg_count + ); + + Ok(()) + } +} diff --git a/src/agent/thread_ops/persistence.rs b/src/agent/thread_ops/persistence.rs index eb3ea0739..370642244 100644 --- a/src/agent/thread_ops/persistence.rs +++ b/src/agent/thread_ops/persistence.rs @@ -2,9 +2,55 @@ //! //! Contains utilities for building database parameters and managing conversation persistence. -use crate::db::EnsureConversationParams; +use std::sync::Arc; + use uuid::Uuid; +use crate::agent::Agent; +use crate::channels::web::util::truncate_preview; +use crate::db::EnsureConversationParams; + +/// Context for persisting turn-related data. +/// +/// Groups thread_id, user_id, and turn_number to reduce the argument count +/// of persistence functions (addresses CodeScene "Excess Number of Function Arguments"). 
+#[derive(Clone)] +pub(crate) struct TurnPersistContext<'a> { + pub thread_id: Uuid, + pub user_id: &'a str, + pub turn_number: usize, +} + +/// Convert a JSON value to a preview string with the given character limit. +fn value_to_preview(v: &serde_json::Value, limit: usize) -> String { + match v { + serde_json::Value::String(s) => truncate_preview(s, limit), + other => truncate_preview(&other.to_string(), limit), + } +} + +/// Summarize a single tool call into a JSON object. +fn summarize_tool_call( + turn_number: usize, + i: usize, + tc: &crate::agent::session::TurnToolCall, +) -> serde_json::Value { + let mut obj = serde_json::json!({ + "name": tc.name, + "call_id": format!("turn{}_{}", turn_number, i), + "parameters": serde_json::to_value(&tc.parameters) + .unwrap_or_else(|_| serde_json::json!({})), + }); + if let Some(ref result) = tc.result { + obj["result_preview"] = serde_json::Value::String(value_to_preview(result, 500)); + obj["result"] = result.clone(); + } + if let Some(ref error) = tc.error { + obj["error"] = serde_json::Value::String(error.clone()); + } + obj +} + /// Helper to build EnsureConversationParams for gateway conversations. /// /// Gateway conversations use channel="gateway", id=thread_id, and thread_id=None. @@ -19,3 +65,117 @@ pub(super) fn gateway_conversation_params( thread_id: None, } } + +impl Agent { + /// Persist the user message to the DB at turn start (before the agentic loop). + /// + /// This ensures the user message is durable even if the process crashes + /// mid-response. Call this right after `thread.start_turn()`. 
+ pub(super) async fn persist_user_message( + &self, + thread_id: Uuid, + user_id: &str, + user_input: &str, + ) { + let store = match self.store() { + Some(s) => Arc::clone(s), + None => return, + }; + + if let Err(e) = store + .ensure_conversation(gateway_conversation_params(thread_id, user_id)) + .await + { + tracing::warn!("Failed to ensure conversation {}: {}", thread_id, e); + return; + } + + if let Err(e) = store + .add_conversation_message(thread_id, "user", user_input) + .await + { + tracing::warn!("Failed to persist user message: {}", e); + } + } + + /// Persist the assistant response to the DB after the agentic loop completes. + /// + /// Re-ensures the conversation row exists so that assistant responses are + /// still persisted even if `persist_user_message` failed transiently at + /// turn start (e.g. a brief DB blip that resolved before response time). + pub(super) async fn persist_assistant_response( + &self, + thread_id: Uuid, + user_id: &str, + response: &str, + ) { + let store = match self.store() { + Some(s) => Arc::clone(s), + None => return, + }; + + if let Err(e) = store + .ensure_conversation(gateway_conversation_params(thread_id, user_id)) + .await + { + tracing::warn!("Failed to ensure conversation {}: {}", thread_id, e); + return; + } + + if let Err(e) = store + .add_conversation_message(thread_id, "assistant", response) + .await + { + tracing::warn!("Failed to persist assistant message: {}", e); + } + } + + /// Persist tool call summaries to the DB as a `role="tool_calls"` message. + /// + /// Stored between the user and assistant messages so that + /// `build_turns_from_db_messages` can reconstruct the tool call history. + /// Content is a JSON array of tool call summaries. 
+ pub(super) async fn persist_tool_calls( + &self, + ctx: &TurnPersistContext<'_>, + tool_calls: &[crate::agent::session::TurnToolCall], + ) { + if tool_calls.is_empty() { + return; + } + + let store = match self.store() { + Some(s) => Arc::clone(s), + None => return, + }; + + let summaries: Vec = tool_calls + .iter() + .enumerate() + .map(|(i, tc)| summarize_tool_call(ctx.turn_number, i, tc)) + .collect(); + + let content = match serde_json::to_string(&summaries) { + Ok(c) => c, + Err(e) => { + tracing::warn!("Failed to serialize tool calls: {}", e); + return; + } + }; + + if let Err(e) = store + .ensure_conversation(gateway_conversation_params(ctx.thread_id, ctx.user_id)) + .await + { + tracing::warn!("Failed to ensure conversation {}: {}", ctx.thread_id, e); + return; + } + + if let Err(e) = store + .add_conversation_message(ctx.thread_id, "tool_calls", &content) + .await + { + tracing::warn!("Failed to persist tool calls: {}", e); + } + } +} diff --git a/src/agent/thread_ops/turn_compaction_checkpointing.rs b/src/agent/thread_ops/turn_compaction_checkpointing.rs new file mode 100644 index 000000000..fb8c17ea0 --- /dev/null +++ b/src/agent/thread_ops/turn_compaction_checkpointing.rs @@ -0,0 +1,121 @@ +//! Context compaction and checkpoint helpers for user turns. 
+ +use std::sync::Arc; + +use tokio::sync::Mutex; +use uuid::Uuid; + +use crate::agent::Agent; +use crate::agent::compaction::ContextCompactor; +use crate::agent::session::{Session, Thread}; +use crate::channels::{IncomingMessage, StatusUpdate}; +use crate::error::Error; + +impl Agent { + async fn notify_compaction_status(&self, message: &IncomingMessage, pct: f32) { + let _ = self + .channels + .send_status( + &message.channel, + StatusUpdate::Status(format!("Context at {:.0}% capacity, compacting...", pct)), + &message.metadata, + ) + .await; + } + + async fn try_compact_snapshot( + &self, + snapshot: &mut Thread, + strategy: crate::agent::context_monitor::CompactionStrategy, + ) -> bool { + let compactor = ContextCompactor::new(self.llm().clone()); + match compactor + .compact(snapshot, strategy, self.workspace().map(|w| w.as_ref())) + .await + { + Ok(_) => true, + Err(e) => { + tracing::warn!("Auto-compaction failed: {}", e); + false + } + } + } + + async fn apply_compaction_if_fresh( + &self, + session: &Arc>, + thread_id: Uuid, + snapshot: Thread, + ) { + let mut sess = session.lock().await; + if let Some(thread) = sess.threads.get_mut(&thread_id) { + if thread.updated_at == snapshot.updated_at + && thread.turns.len() == snapshot.turns.len() + { + *thread = snapshot; + } else { + tracing::warn!( + thread_id = %thread_id, + "Skipped applying stale auto-compaction result" + ); + } + } + } + + /// Auto-compact context if needed before adding new turn. + pub(super) async fn maybe_compact_context( + &self, + message: &IncomingMessage, + session: &Arc>, + thread_id: Uuid, + ) -> Result<(), Error> { + let mut thread_snapshot = { + let sess = session.lock().await; + sess.threads + .get(&thread_id) + .cloned() + .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))? 
+ }; + + let messages = thread_snapshot.messages(); + let Some(strategy) = self.context_monitor.suggest_compaction(&messages) else { + return Ok(()); + }; + + let pct = self.context_monitor.usage_percent(&messages); + tracing::info!("Context at {:.1}% capacity, auto-compacting", pct); + self.notify_compaction_status(message, pct as f32).await; + + if !self + .try_compact_snapshot(&mut thread_snapshot, strategy) + .await + { + return Ok(()); + } + + self.apply_compaction_if_fresh(session, thread_id, thread_snapshot) + .await; + Ok(()) + } + + /// Create checkpoint before turn. + pub(super) async fn checkpoint_before_turn( + &self, + session: &Arc>, + thread_id: Uuid, + ) -> Result<(), Error> { + let undo_mgr = self.session_manager.get_undo_manager(thread_id).await; + let (turn_number, messages) = { + let sess = session.lock().await; + let thread = sess + .threads + .get(&thread_id) + .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; + (thread.turn_number(), thread.messages()) + }; + + let mut mgr = undo_mgr.lock().await; + mgr.checkpoint(turn_number, messages, format!("Before turn {turn_number}")); + Ok(()) + } +} diff --git a/src/agent/thread_ops/turn_execution.rs b/src/agent/thread_ops/turn_execution.rs new file mode 100644 index 000000000..b170f9a7f --- /dev/null +++ b/src/agent/thread_ops/turn_execution.rs @@ -0,0 +1,80 @@ +//! User turn execution and agentic loop orchestration. +//! +//! Keeps the top-level phase ordering in one place while sibling modules own +//! turn preparation, context compaction/checkpointing, and result +//! finalisation. 
+ +use crate::agent::Agent; +use crate::agent::submission::SubmissionResult; +use crate::agent::thread_ops::{PrepareTurnResult, UserTurnRequest}; +use crate::channels::{IncomingMessage, StatusUpdate}; +use crate::error::Error; + +impl Agent { + pub(super) async fn process_user_input( + &self, + message: &IncomingMessage, + req: UserTurnRequest, + ) -> Result { + tracing::debug!( + message_id = %message.id, + thread_id = %req.thread_id, + content_len = req.content.len(), + "Processing user input" + ); + + // Phase 1: Check thread state + if let Some(result) = self + .check_thread_state(message, &req.session, req.thread_id) + .await? + { + return Ok(result); + } + + // Phase 2: Safety validation + if let Some(result) = self.validate_safety(message, &req.content) { + return Ok(result); + } + + // Phase 3: Route explicit commands + let temp_message = IncomingMessage { + content: req.content.to_string(), + ..message.clone() + }; + if let Some(intent) = self.router.route_command(&temp_message) { + return self.handle_job_or_command(intent, message).await; + } + + // Phase 4: Auto-compact context if needed + self.maybe_compact_context(message, &req.session, req.thread_id) + .await?; + + // Phase 5: Create checkpoint + self.checkpoint_before_turn(&req.session, req.thread_id) + .await?; + + // Phase 6: Prepare turn + let turn_messages = match self.prepare_turn(message, &req).await? 
{ + PrepareTurnResult::Prepared { turn_messages } => turn_messages, + PrepareTurnResult::Rejected(result) => return Ok(result), + }; + + // Phase 7: Send thinking status and run agentic loop + let _ = self + .channels + .send_status( + &message.channel, + StatusUpdate::Thinking("Processing...".into()), + &message.metadata, + ) + .await; + + let result = self + .run_agentic_loop(message, req.session.clone(), req.thread_id, turn_messages) + .await; + + // Phase 8: Handle loop result + self.handle_loop_result(message, &req.session, req.thread_id, result) + .await + } +} diff --git a/src/agent/thread_ops/turn_preparation.rs b/src/agent/thread_ops/turn_preparation.rs new file mode 100644 index 000000000..2d74bc327 --- /dev/null +++ b/src/agent/thread_ops/turn_preparation.rs @@ -0,0 +1,171 @@ +//! Turn preparation helpers for interactive user input. + +use std::sync::Arc; + +use tokio::sync::Mutex; +use uuid::Uuid; + +use crate::agent::Agent; +use crate::agent::session::{Session, ThreadState}; +use crate::agent::submission::SubmissionResult; +use crate::channels::IncomingMessage; +use crate::error::Error; + +/// Request parameters for processing a user turn. +/// +/// Groups the session, thread ID, and content to reduce the argument count +/// of `process_user_input` (addresses CodeScene "Excess Number of Function Arguments"). +#[derive(Clone)] +pub(crate) struct UserTurnRequest { + pub session: Arc>, + pub thread_id: Uuid, + pub content: String, +} + +pub(crate) enum PrepareTurnResult { + Prepared { + turn_messages: Vec, + }, + Rejected(SubmissionResult), +} + +impl Agent { + fn thread_state_submission_result(&self, state: ThreadState) -> Option { + match state { + ThreadState::Processing => Some(SubmissionResult::error( + "Turn in progress. Use /interrupt to cancel.", + )), + ThreadState::AwaitingApproval => Some(SubmissionResult::error( + "Waiting for approval. 
Use /interrupt to cancel.", + )), + ThreadState::Completed => Some(SubmissionResult::error( + "Thread completed. Use /thread new.", + )), + ThreadState::Idle | ThreadState::Interrupted => None, + } + } + + /// Check thread state and return error if not in a processable state. + pub(super) async fn check_thread_state( + &self, + message: &IncomingMessage, + session: &Arc>, + thread_id: Uuid, + ) -> Result, Error> { + let thread_state = { + let sess = session.lock().await; + let thread = sess + .threads + .get(&thread_id) + .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; + thread.state + }; + + tracing::debug!( + message_id = %message.id, + thread_id = %thread_id, + thread_state = ?thread_state, + "Checked thread state" + ); + + if let Some(result) = self.thread_state_submission_result(thread_state) { + tracing::warn!( + message_id = %message.id, + thread_id = %thread_id, + thread_state = ?thread_state, + "Thread state blocks new input" + ); + Ok(Some(result)) + } else { + Ok(None) + } + } + + /// Validate safety for user input. 
+ pub(super) fn validate_safety( + &self, + message: &IncomingMessage, + content: &str, + ) -> Option { + let validation = self.safety().validate_input(content); + if !validation.is_valid { + let details = validation + .errors + .iter() + .map(|e| format!("{}: {}", e.field, e.message)) + .collect::>() + .join("; "); + return Some(SubmissionResult::error(format!( + "Input rejected by safety validation: {}", + details + ))); + } + + let violations = self.safety().check_policy(content); + if violations + .iter() + .any(|rule| rule.action == crate::safety::PolicyAction::Block) + { + return Some(SubmissionResult::error("Input rejected by safety policy.")); + } + + if let Some(warning) = self.safety().scan_inbound_for_secrets(content) { + tracing::warn!( + user = %message.user_id, + channel = %message.channel, + "Inbound message blocked: contains leaked secret" + ); + return Some(SubmissionResult::error(warning)); + } + + None + } + + /// Prepare turn by augmenting content and starting the turn. 
+ pub(super) async fn prepare_turn( + &self, + message: &IncomingMessage, + req: &UserTurnRequest, + ) -> Result { + let content = req.content.as_str(); + let augmented = + crate::agent::attachments::augment_with_attachments(content, &message.attachments); + let (effective_content, image_parts) = match &augmented { + Some(result) => (result.text.as_str(), result.image_parts.clone()), + None => (content, Vec::new()), + }; + + if let Some(result) = self.validate_safety(message, effective_content) { + return Ok(PrepareTurnResult::Rejected(result)); + } + + let turn_messages = { + let mut sess = req.session.lock().await; + let thread = sess.threads.get_mut(&req.thread_id).ok_or_else(|| { + Error::from(crate::error::JobError::NotFound { id: req.thread_id }) + })?; + if let Some(result) = self.thread_state_submission_result(thread.state) { + return Ok(PrepareTurnResult::Rejected(result)); + } + let turn = thread.start_turn(effective_content); + turn.image_content_parts = image_parts; + thread.messages() + }; + + tracing::debug!( + message_id = %message.id, + thread_id = %req.thread_id, + "Persisting user message to DB" + ); + self.persist_user_message(req.thread_id, &message.user_id, effective_content) + .await; + + tracing::debug!( + message_id = %message.id, + thread_id = %req.thread_id, + "User message persisted, starting agentic loop" + ); + + Ok(PrepareTurnResult::Prepared { turn_messages }) + } +} diff --git a/src/agent/thread_ops/turn_result_finalisation.rs b/src/agent/thread_ops/turn_result_finalisation.rs new file mode 100644 index 000000000..6ac3119df --- /dev/null +++ b/src/agent/thread_ops/turn_result_finalisation.rs @@ -0,0 +1,166 @@ +//! Result finalisation helpers for completed user turns. 
+ +use std::sync::Arc; + +use tokio::sync::Mutex; +use uuid::Uuid; + +use crate::agent::Agent; +use crate::agent::dispatcher::AgenticLoopResult; +use crate::agent::session::{Session, ThreadState}; +use crate::agent::submission::SubmissionResult; +use crate::agent::thread_ops::TurnPersistContext; +use crate::channels::{IncomingMessage, StatusUpdate}; +use crate::error::Error; + +impl Agent { + /// Apply response transform hook. + async fn apply_response_transform_hook( + &self, + message: &IncomingMessage, + thread_id: Uuid, + response: String, + ) -> String { + let event = crate::hooks::HookEvent::ResponseTransform { + user_id: message.user_id.clone(), + thread_id: thread_id.to_string(), + response: response.clone(), + }; + match self.hooks().run(&event).await { + Err(crate::hooks::HookError::Rejected { reason }) => { + format!("[Response filtered: {}]", reason) + } + Ok(crate::hooks::HookOutcome::Reject { reason }) => { + format!("[Response filtered: {}]", reason) + } + Err(err) => { + tracing::warn!("TransformResponse hook failed open: {}", err); + response + } + Ok(crate::hooks::HookOutcome::Continue { + modified: Some(new_response), + }) => new_response, + _ => response, + } + } + + /// Handle the result from the agentic loop. 
+ pub(super) async fn handle_loop_result( + &self, + message: &IncomingMessage, + session: &Arc>, + thread_id: Uuid, + result: Result, + ) -> Result { + let interrupted = { + let mut sess = session.lock().await; + let thread = sess + .threads + .get_mut(&thread_id) + .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; + thread.state == ThreadState::Interrupted + }; + + if interrupted { + let _ = self + .channels + .send_status( + &message.channel, + StatusUpdate::Status("Interrupted".into()), + &message.metadata, + ) + .await; + return Ok(SubmissionResult::Interrupted); + } + + let mut sess = session.lock().await; + let thread = sess + .threads + .get_mut(&thread_id) + .ok_or_else(|| Error::from(crate::error::JobError::NotFound { id: thread_id }))?; + + match result { + Ok(AgenticLoopResult::Response(response)) => { + drop(sess); + let response = self + .apply_response_transform_hook(message, thread_id, response) + .await; + + let completion = { + let mut sess = session.lock().await; + let thread = sess.threads.get_mut(&thread_id).ok_or_else(|| { + Error::from(crate::error::JobError::NotFound { id: thread_id }) + })?; + if thread.state == ThreadState::Interrupted { + None + } else { + thread.complete_turn(&response); + thread + .turns + .last() + .map(|t| (t.turn_number, t.tool_calls.clone())) + } + }; + + let Some((turn_number, tool_calls)) = completion else { + let _ = self + .channels + .send_status( + &message.channel, + StatusUpdate::Status("Interrupted".into()), + &message.metadata, + ) + .await; + return Ok(SubmissionResult::Interrupted); + }; + + let _ = self + .channels + .send_status( + &message.channel, + StatusUpdate::Status("Done".into()), + &message.metadata, + ) + .await; + + let persist_ctx = TurnPersistContext { + thread_id, + user_id: &message.user_id, + turn_number, + }; + self.persist_tool_calls(&persist_ctx, &tool_calls).await; + self.persist_assistant_response(thread_id, &message.user_id, &response) + .await; + + 
Ok(SubmissionResult::response(response)) + } + Ok(AgenticLoopResult::NeedApproval { pending }) => { + let request_id = pending.request_id; + let tool_name = pending.tool_name.clone(); + let description = pending.description.clone(); + let parameters = pending.display_parameters.clone(); + thread.await_approval(pending); + drop(sess); + + let _ = self + .channels + .send_status( + &message.channel, + StatusUpdate::Status("Awaiting approval".into()), + &message.metadata, + ) + .await; + Ok(SubmissionResult::NeedApproval { + request_id, + tool_name, + description, + parameters, + }) + } + Err(e) => { + thread.fail_turn(e.to_string()); + Ok(SubmissionResult::error(e.to_string())) + } + } + } +} diff --git a/src/channels/web/server/tests/memory.rs b/src/channels/web/server/tests/memory.rs index a25dfa8ce..13c555c14 100644 --- a/src/channels/web/server/tests/memory.rs +++ b/src/channels/web/server/tests/memory.rs @@ -1,33 +1,33 @@ //! Tests for the web gateway memory search and read routes. -#[cfg(feature = "libsql")] +#[cfg(all(feature = "libsql", feature = "test-helpers"))] use axum::{Router, body::Body, routing::get, routing::post}; -#[cfg(feature = "libsql")] +#[cfg(all(feature = "libsql", feature = "test-helpers"))] use rstest::{fixture, rstest}; -#[cfg(feature = "libsql")] +#[cfg(all(feature = "libsql", feature = "test-helpers"))] use tempfile::TempDir; -#[cfg(feature = "libsql")] +#[cfg(all(feature = "libsql", feature = "test-helpers"))] use tower::ServiceExt; -#[cfg(feature = "libsql")] +#[cfg(all(feature = "libsql", feature = "test-helpers"))] use super::fixtures::{TestGatewayStateFactory, test_gateway_state}; -#[cfg(feature = "libsql")] +#[cfg(all(feature = "libsql", feature = "test-helpers"))] use crate::channels::web::handlers::memory::{ memory_read_handler, memory_search_handler, memory_tree_handler, }; -#[cfg(feature = "libsql")] +#[cfg(all(feature = "libsql", feature = "test-helpers"))] use crate::workspace::Workspace; -#[cfg(feature = "libsql")] 
+#[cfg(all(feature = "libsql", feature = "test-helpers"))] use axum::http::StatusCode; -#[cfg(feature = "libsql")] +#[cfg(all(feature = "libsql", feature = "test-helpers"))] type TestWorkspaceFixture = (std::sync::Arc, TempDir); -#[cfg(feature = "libsql")] +#[cfg(all(feature = "libsql", feature = "test-helpers"))] #[derive(Clone, Copy, Debug, Default)] struct TestWorkspaceFactory; -#[cfg(feature = "libsql")] +#[cfg(all(feature = "libsql", feature = "test-helpers"))] impl TestWorkspaceFactory { async fn build(self) -> TestWorkspaceFixture { let (db, temp_dir) = crate::testing::test_db().await; @@ -38,13 +38,13 @@ impl TestWorkspaceFactory { } } -#[cfg(feature = "libsql")] +#[cfg(all(feature = "libsql", feature = "test-helpers"))] #[fixture] fn test_workspace() -> TestWorkspaceFactory { TestWorkspaceFactory } -#[cfg(feature = "libsql")] +#[cfg(all(feature = "libsql", feature = "test-helpers"))] #[rstest] #[tokio::test] async fn test_memory_search_results_round_trip_via_read_path( @@ -109,7 +109,7 @@ async fn test_memory_search_results_round_trip_via_read_path( assert_eq!(read_json["content"], "alpha needle beta"); } -#[cfg(feature = "libsql")] +#[cfg(all(feature = "libsql", feature = "test-helpers"))] #[rstest] #[tokio::test] async fn test_memory_tree_honours_depth_query( diff --git a/src/channels/webhook_server.rs b/src/channels/webhook_server.rs index e8ea89e2a..2a43f9ff4 100644 --- a/src/channels/webhook_server.rs +++ b/src/channels/webhook_server.rs @@ -58,6 +58,28 @@ impl WebhookServer { self.bind_and_spawn(app).await } + /// Bind using an already-bound [`tokio::net::TcpListener`], merge all route + /// fragments, and spawn the server. The listener's local address is stored + /// in `config.addr` so `current_addr()` stays accurate. 
+ pub async fn start_with_listener( + &mut self, + listener: tokio::net::TcpListener, + ) -> Result<(), ChannelError> { + let addr = listener + .local_addr() + .map_err(|e| ChannelError::StartupFailed { + name: "webhook_server".to_string(), + reason: format!("local_addr failed: {e}"), + })?; + self.config.addr = addr; + let mut app = Router::new(); + for fragment in self.routes.drain(..) { + app = app.merge(fragment); + } + self.merged_router = Some(app.clone()); + self.spawn_on_listener(listener, app).await + } + /// Bind a listener to the configured address and spawn the server task. /// Private helper used by both start() and restart_with_addr(). async fn bind_and_spawn(&mut self, app: Router) -> Result<(), ChannelError> { @@ -65,14 +87,28 @@ impl WebhookServer { .await .map_err(|e| ChannelError::StartupFailed { name: "webhook_server".to_string(), - reason: format!("Failed to bind to {}: {}", self.config.addr, e), + reason: format!("Failed to bind to {}: {e}", self.config.addr), + })?; + let addr = listener + .local_addr() + .map_err(|e| ChannelError::StartupFailed { + name: "webhook_server".to_string(), + reason: format!("local_addr failed: {e}"), })?; + self.config.addr = addr; + self.spawn_on_listener(listener, app).await + } + /// Spawn the server on an already-bound listener. + /// Private helper that contains the common shutdown-channel and task-spawn logic. 
+ async fn spawn_on_listener( + &mut self, + listener: tokio::net::TcpListener, + app: Router, + ) -> Result<(), ChannelError> { tracing::info!("Webhook server listening on {}", self.config.addr); - let (shutdown_tx, shutdown_rx) = oneshot::channel(); self.shutdown_tx = Some(shutdown_tx); - let handle = tokio::spawn(async move { if let Err(e) = axum::serve(listener, app) .with_graceful_shutdown(async { @@ -81,39 +117,29 @@ impl WebhookServer { }) .await { - tracing::error!("Webhook server error: {}", e); + tracing::error!("Webhook server error: {e}"); } }); - self.handle = Some(handle); Ok(()) } - /// Gracefully shut down the current listener and rebind to a new address. - /// The merged router from the original `start()` call is reused. - /// - /// If binding to the new address fails, the old listener remains active and - /// state is restored. This prevents a denial-of-service if the new address - /// is invalid or already in use. - pub async fn restart_with_addr(&mut self, new_addr: SocketAddr) -> Result<(), ChannelError> { - let app = self - .merged_router - .clone() - .ok_or_else(|| ChannelError::StartupFailed { - name: "webhook_server".to_string(), - reason: "restart_with_addr called before start()".to_string(), - })?; - - // Save old state for rollback if new bind fails + /// Shared restart kernel. Saves current listener state, spawns the server on + /// `listener` bound at `new_addr`, shuts down the old server on success, or + /// restores the previous state on failure. 
+ async fn swap_listener( + &mut self, + new_addr: SocketAddr, + listener: tokio::net::TcpListener, + app: Router, + ) -> Result<(), ChannelError> { let old_addr = self.config.addr; let old_shutdown_tx = self.shutdown_tx.take(); let old_handle = self.handle.take(); - // Update config to new address and try to bind self.config.addr = new_addr; - match self.bind_and_spawn(app).await { + match self.spawn_on_listener(listener, app).await { Ok(()) => { - // New listener is running, gracefully shut down the old one if let Some(tx) = old_shutdown_tx { let _ = tx.send(()); } @@ -123,7 +149,6 @@ impl WebhookServer { Ok(()) } Err(e) => { - // Restore old state; old listener remains active self.config.addr = old_addr; self.shutdown_tx = old_shutdown_tx; self.handle = old_handle; @@ -132,6 +157,65 @@ impl WebhookServer { } } + /// Gracefully shut down the current listener and rebind to a new address. + /// The merged router from the original `start()` call is reused. + /// + /// If binding to the new address fails, the old listener remains active and + /// state is restored. This prevents a denial-of-service if the new address + /// is invalid or already in use. 
+ pub async fn restart_with_addr(&mut self, new_addr: SocketAddr) -> Result<(), ChannelError> { + let app = self + .merged_router + .clone() + .ok_or_else(|| ChannelError::StartupFailed { + name: "webhook_server".to_string(), + reason: "restart_with_addr called before start()".to_string(), + })?; + + let listener = tokio::net::TcpListener::bind(new_addr).await.map_err(|e| { + ChannelError::StartupFailed { + name: "webhook_server".to_string(), + reason: format!("Failed to bind to {new_addr}: {e}"), + } + })?; + let addr = listener + .local_addr() + .map_err(|e| ChannelError::StartupFailed { + name: "webhook_server".to_string(), + reason: format!("local_addr failed: {e}"), + })?; + + self.swap_listener(addr, listener, app).await + } + + /// Shut down the running server and restart it on the already-bound + /// `listener`, inheriting all previously added routes from + /// `self.merged_router`. + pub async fn restart_with_listener( + &mut self, + listener: tokio::net::TcpListener, + ) -> Result<(), ChannelError> { + let app = self + .merged_router + .clone() + .ok_or_else(|| ChannelError::StartupFailed { + name: "webhook_server".to_string(), + reason: "restart_with_listener called before start()".to_string(), + })?; + + // Extract address from the provided listener before mutating self, + // so that old_addr, old_shutdown_tx and old_handle remain intact + // until we know local_addr() succeeds. + let addr = listener + .local_addr() + .map_err(|e| ChannelError::StartupFailed { + name: "webhook_server".to_string(), + reason: format!("local_addr failed: {e}"), + })?; + + self.swap_listener(addr, listener, app).await + } + /// Return the current bind address. 
pub fn current_addr(&self) -> SocketAddr { self.config.addr @@ -155,169 +239,3 @@ impl WebhookServer { } } } - -#[cfg(test)] -mod tests { - use std::net::TcpListener as StdTcpListener; - - use axum::Json; - use rstest::{fixture, rstest}; - use serde_json::json; - - use super::*; - - /// A started webhook server with a `/health` route and a pre-built client. - struct StartedWebhookServer { - server: WebhookServer, - addr: SocketAddr, - client: reqwest::Client, - } - - /// Finds an available port, creates a [`WebhookServer`] with a `/health` - /// route, starts the server, and returns the address and a client. - #[fixture] - async fn started_webhook_server() - -> Result> { - let port = { - let listener = StdTcpListener::bind("127.0.0.1:0")?; - listener.local_addr()?.port() - }; - let addr: SocketAddr = format!("127.0.0.1:{}", port).parse()?; - let mut server = WebhookServer::new(WebhookServerConfig { addr }); - server.add_routes(Router::new().route( - "/health", - axum::routing::get(|| async { Json(json!({"status": "ok"})) }), - )); - server.start().await?; - Ok(StartedWebhookServer { - server, - addr, - client: reqwest::Client::new(), - }) - } - - #[rstest] - #[tokio::test] - async fn test_restart_with_addr_rebinds_listener( - #[future] started_webhook_server: Result< - StartedWebhookServer, - Box, - >, - ) -> Result<(), Box> { - let StartedWebhookServer { - mut server, - addr: addr1, - client, - } = started_webhook_server.await?; - - assert_eq!( - server.current_addr(), - addr1, - "Server should be bound to initial address" - ); - - let response = client - .get(format!("http://{}/health", addr1)) - .send() - .await?; - assert_eq!( - response.status(), - 200, - "First server should respond to health check" - ); - - // Find a second available port and restart - let port2 = { - let listener = StdTcpListener::bind("127.0.0.1:0")?; - listener.local_addr()?.port() - }; - let addr2: SocketAddr = format!("127.0.0.1:{}", port2).parse()?; - - 
server.restart_with_addr(addr2).await?; - - assert_eq!( - server.current_addr(), - addr2, - "Server address should be updated after restart" - ); - assert_ne!( - addr1, addr2, - "Address should change after restart_with_addr" - ); - - let response = client - .get(format!("http://{}/health", addr2)) - .send() - .await?; - assert_eq!( - response.status(), - 200, - "Restarted server should respond to health check on new address" - ); - - let old_result = tokio::time::timeout( - std::time::Duration::from_millis(200), - client.get(format!("http://{}/health", addr1)).send(), - ) - .await; - assert!( - old_result.is_err() || old_result.ok().and_then(|r| r.ok()).is_none(), - "Old address should not respond after server restarts" - ); - - server.shutdown().await; - Ok(()) - } - - #[rstest] - #[tokio::test] - async fn test_restart_with_addr_rollback_on_bind_failure( - #[future] started_webhook_server: Result< - StartedWebhookServer, - Box, - >, - ) -> Result<(), Box> { - let StartedWebhookServer { - mut server, - addr: addr1, - client, - } = started_webhook_server.await?; - - let response = client - .get(format!("http://{}/health", addr1)) - .send() - .await?; - assert_eq!(response.status(), 200, "Server should be listening"); - - // Occupy a second port so the restart bind fails deterministically. 
- let occupied_listener = StdTcpListener::bind("127.0.0.1:0")?; - let conflict_addr = occupied_listener.local_addr()?; - - let result = server.restart_with_addr(conflict_addr).await; - assert!( - result.is_err(), - "Restart with already-bound address should fail" - ); - - drop(occupied_listener); - - let response = client - .get(format!("http://{}/health", addr1)) - .send() - .await?; - assert_eq!( - response.status(), - 200, - "Old listener should still be running after failed restart" - ); - - assert_eq!( - server.current_addr(), - addr1, - "Server address should be restored after failed restart" - ); - - server.shutdown().await; - Ok(()) - } -} diff --git a/src/context/manager.rs b/src/context/manager.rs index aef15435b..5ef8273f9 100644 --- a/src/context/manager.rs +++ b/src/context/manager.rs @@ -612,6 +612,39 @@ mod tests { assert_eq!(stuck[0], id2); } + #[tokio::test] + async fn find_stuck_contexts_returns_only_stuck_contexts() { + let manager = ContextManager::new(10); + let stuck_id = manager.create_job("stuck", "desc").await.unwrap(); + let active_id = manager.create_job("active", "desc").await.unwrap(); + + manager + .update_context(stuck_id, |ctx| { + ctx.transition_to(crate::context::JobState::InProgress, None) + }) + .await + .unwrap() + .unwrap(); + manager + .update_context(stuck_id, |ctx| ctx.mark_stuck("timeout")) + .await + .unwrap() + .unwrap(); + + manager + .update_context(active_id, |ctx| { + ctx.transition_to(crate::context::JobState::InProgress, None) + }) + .await + .unwrap() + .unwrap(); + + let stuck_contexts = manager.find_stuck_contexts().await; + + assert_eq!(stuck_contexts.len(), 1); + assert_eq!(stuck_contexts[0].job_id, stuck_id); + } + #[tokio::test] async fn active_count_tracks_non_terminal_jobs() { let manager = ContextManager::new(10); diff --git a/src/context/rollback_tests.rs b/src/context/rollback_tests.rs new file mode 100644 index 000000000..7257a8d2c --- /dev/null +++ b/src/context/rollback_tests.rs @@ -0,0 +1,208 @@ +//! 
Rollback-specific tests for `JobContext::set_state_rollback`. + +use super::*; + +fn all_job_states() -> [JobState; 8] { + [ + JobState::Pending, + JobState::InProgress, + JobState::Completed, + JobState::Submitted, + JobState::Accepted, + JobState::Failed, + JobState::Stuck, + JobState::Cancelled, + ] +} + +fn completion_timestamp_for(transitions: &[StateTransition]) -> Option> { + transitions + .iter() + .rev() + .find(|transition| { + matches!( + transition.to, + JobState::Completed | JobState::Accepted | JobState::Failed | JobState::Cancelled + ) + }) + .map(|transition| transition.timestamp) +} + +fn rollback_tracked_as_completed(state: JobState) -> bool { + matches!( + state, + JobState::Completed | JobState::Accepted | JobState::Failed | JobState::Cancelled + ) +} + +fn transition_snapshot( + transitions: &[StateTransition], +) -> Vec<(JobState, JobState, DateTime, Option)> { + transitions + .iter() + .map(|transition| { + ( + transition.from, + transition.to, + transition.timestamp, + transition.reason.clone(), + ) + }) + .collect() +} + +#[test] +fn test_set_state_rollback_ignores_mismatched_transition_history() { + let mut ctx = JobContext::new("Test", "Rollback mismatch test"); + ctx.transition_to(JobState::InProgress, None) + .expect("failed to transition to InProgress"); + ctx.transition_to(JobState::Completed, Some("Done".to_string())) + .expect("failed to transition to Completed"); + + let expected_state = ctx.state; + let expected_completed_at = ctx.completed_at; + let expected_transition_len = ctx.transitions.len(); + let expected_last_transition = ctx + .transitions + .last() + .map(|transition| (transition.from, transition.to, transition.reason.clone())); + + ctx.set_state_rollback(JobState::Pending); + + assert_eq!( + ctx.state, expected_state, + "rollback should not change state when the latest transition does not match" + ); + assert_eq!( + ctx.completed_at, expected_completed_at, + "rollback should not change completed_at when the latest 
transition does not match" + ); + assert_eq!( + ctx.transitions.len(), + expected_transition_len, + "rollback should not change transition count when the latest transition does not match" + ); + assert_eq!( + ctx.transitions.last().map(|transition| ( + transition.from, + transition.to, + transition.reason.clone() + )), + expected_last_transition, + "rollback should not change the latest transition when the latest transition does not match" + ); +} + +#[test] +fn test_set_state_rollback_applies_across_bounded_state_pairs() { + let base = Utc::now(); + + for (previous_idx, previous) in all_job_states().into_iter().enumerate() { + for (current_idx, current) in all_job_states().into_iter().enumerate() { + let mut ctx = JobContext::new("Test", "Rollback property test"); + let earlier_timestamp = + base + chrono::Duration::seconds((previous_idx * 10 + current_idx) as i64); + let rollback_timestamp = earlier_timestamp + chrono::Duration::seconds(1); + + ctx.transitions.push(StateTransition { + from: JobState::Pending, + to: JobState::Completed, + timestamp: earlier_timestamp, + reason: Some("earlier terminal".to_string()), + }); + ctx.transitions.push(StateTransition { + from: previous, + to: current, + timestamp: rollback_timestamp, + reason: Some("rollback edge".to_string()), + }); + ctx.state = current; + ctx.completed_at = Some(rollback_timestamp); + + let before_len = ctx.transitions.len(); + assert!( + ctx.last_transition_matches_rollback(previous), + "expected rollback edge to match for previous={previous:?}, current={current:?}" + ); + + ctx.set_state_rollback(previous); + + assert_eq!( + ctx.state, previous, + "rollback should restore previous state for previous={previous:?}, current={current:?}" + ); + assert_eq!( + ctx.transitions.len(), + before_len - 1, + "rollback should remove the latest transition for previous={previous:?}, current={current:?}" + ); + assert_eq!( + ctx.completed_at, + if rollback_tracked_as_completed(previous) { + 
completion_timestamp_for(&ctx.transitions) + } else { + None + }, + "rollback should recompute completed_at from remaining transitions for previous={previous:?}, current={current:?}" + ); + } + } +} + +#[test] +fn test_set_state_rollback_skips_mismatched_edges_across_bounded_state_pairs() { + let base = Utc::now(); + + for (previous_idx, previous) in all_job_states().into_iter().enumerate() { + for (current_idx, current) in all_job_states().into_iter().enumerate() { + let mut ctx = JobContext::new("Test", "Rollback mismatch property test"); + let earlier_timestamp = + base + chrono::Duration::seconds((previous_idx * 10 + current_idx) as i64); + let latest_timestamp = earlier_timestamp + chrono::Duration::seconds(1); + let mismatched_from = all_job_states() + .into_iter() + .find(|candidate| *candidate != previous) + .expect("expected at least one distinct JobState"); + + ctx.transitions.push(StateTransition { + from: JobState::Pending, + to: JobState::Accepted, + timestamp: earlier_timestamp, + reason: Some("earlier terminal".to_string()), + }); + ctx.transitions.push(StateTransition { + from: mismatched_from, + to: current, + timestamp: latest_timestamp, + reason: Some("mismatched rollback edge".to_string()), + }); + ctx.state = current; + ctx.completed_at = Some(latest_timestamp); + + let expected_state = ctx.state; + let expected_completed_at = ctx.completed_at; + let expected_transitions = transition_snapshot(&ctx.transitions); + + assert!( + !ctx.last_transition_matches_rollback(previous), + "expected rollback edge mismatch for previous={previous:?}, current={current:?}" + ); + + ctx.set_state_rollback(previous); + + assert_eq!( + ctx.state, expected_state, + "rollback should not change state when the edge mismatches for previous={previous:?}, current={current:?}" + ); + assert_eq!( + ctx.completed_at, expected_completed_at, + "rollback should not change completed_at when the edge mismatches for previous={previous:?}, current={current:?}" + ); + assert_eq!( + 
transition_snapshot(&ctx.transitions), + expected_transitions, + "rollback should not change transitions when the edge mismatches for previous={previous:?}, current={current:?}" + ); + } + } +} diff --git a/src/context/state.rs b/src/context/state.rs index 137bedc28..0a575ceb9 100644 --- a/src/context/state.rs +++ b/src/context/state.rs @@ -18,6 +18,9 @@ pub enum JobRecoveryError { /// Job is not in the Stuck state and cannot be recovered. #[error("Job is not stuck")] NotStuck, + /// An unexpected state-machine invariant was violated during recovery. + #[error("Recovery invariant violated: {0}")] + InvariantViolation(String), } /// State of a job. @@ -287,6 +290,47 @@ impl JobContext { Ok(()) } + /// Check whether the newest recorded transition matches a rollback from + /// `previous` back to the current in-memory state. + fn last_transition_matches_rollback(&self, previous: JobState) -> bool { + self.transitions + .last() + .is_some_and(|t| t.from == previous && t.to == self.state) + } + + /// Directly set the state without transition validation. + /// + /// Intended for rollback paths where the in-memory context must be + /// restored to a previous state after a persistence failure, bypassing + /// [`Self::transition_to`] validation. + pub(crate) fn set_state_rollback(&mut self, previous: JobState) { + if !self.last_transition_matches_rollback(previous) { + return; + } + self.transitions.pop(); + self.state = previous; + self.completed_at = if matches!( + self.state, + JobState::Completed | JobState::Accepted | JobState::Failed | JobState::Cancelled + ) { + self.transitions + .iter() + .rev() + .find(|t| { + matches!( + t.to, + JobState::Completed + | JobState::Accepted + | JobState::Failed + | JobState::Cancelled + ) + }) + .map(|t| t.timestamp) + } else { + None + }; + } + /// Add to the actual cost. 
pub fn add_cost(&mut self, cost: Decimal) { self.actual_cost += cost; @@ -345,7 +389,7 @@ impl JobContext { } self.repair_attempts += 1; self.transition_to(JobState::InProgress, Some("Recovery attempt".to_string())) - .map_err(|e| panic!("Failed to transition from Stuck to InProgress: {}", e)) + .map_err(JobRecoveryError::InvariantViolation) } } @@ -358,3 +402,7 @@ impl Default for JobContext { #[cfg(test)] #[path = "state_tests.rs"] mod tests; + +#[cfg(test)] +#[path = "rollback_tests.rs"] +mod rollback_tests; diff --git a/src/db/forwarders.rs b/src/db/forwarders.rs index e448ddc35..5f0259ac3 100644 --- a/src/db/forwarders.rs +++ b/src/db/forwarders.rs @@ -211,6 +211,7 @@ impl_db_forwarders! { dyn = Database, native = NativeDatabase, methods = { + fn persist_terminal_result_and_status(params: TerminalJobPersistence<'a>) -> Result<(), DatabaseError>; fn run_migrations() -> Result<(), DatabaseError>; } } diff --git a/src/db/libsql/jobs.rs b/src/db/libsql/jobs.rs index 61e2a109d..d28b88390 100644 --- a/src/db/libsql/jobs.rs +++ b/src/db/libsql/jobs.rs @@ -14,7 +14,9 @@ use super::{ opt_text, opt_text_owned, }; use crate::context::{ActionRecord, JobContext, JobState}; -use crate::db::{EstimationActualsParams, EstimationSnapshotParams, NativeJobStore}; +use crate::db::{ + EstimationActualsParams, EstimationSnapshotParams, NativeJobStore, TerminalJobPersistence, +}; use crate::error::DatabaseError; use crate::history::{AgentJobRecord, AgentJobSummary, LlmCallRecord}; @@ -116,6 +118,54 @@ impl LibSqlBackend { .map_err(|e| DatabaseError::Query(e.to_string()))?; Ok(()) } + + pub(crate) async fn persist_terminal_result_and_status( + &self, + params: TerminalJobPersistence<'_>, + ) -> Result<(), DatabaseError> { + let TerminalJobPersistence { + job_id, + status, + failure_reason, + event_type, + event_data, + } = params; + let conn = self.connect().await?; + let tx = conn + .transaction() + .await + .map_err(|e| DatabaseError::Query(e.to_string()))?; + tx.execute( + "INSERT 
INTO job_events (job_id, event_type, data) VALUES (?1, ?2, ?3)", + params![ + job_id.to_string(), + event_type.as_str().to_string(), + event_data.to_string() + ], + ) + .await + .map_err(|e| DatabaseError::Query(e.to_string()))?; + let rows_affected = tx + .execute( + "UPDATE agent_jobs SET status = ?2, failure_reason = ?3 WHERE id = ?1 AND source = 'direct'", + params![job_id.to_string(), status.to_string(), opt_text(failure_reason)], + ) + .await + .map_err(|e| DatabaseError::Query(e.to_string()))?; + if rows_affected == 0 { + tx.rollback() + .await + .map_err(|e| DatabaseError::Query(e.to_string()))?; + return Err(DatabaseError::NotFound { + entity: "agent_job".to_string(), + id: job_id.to_string(), + }); + } + tx.commit() + .await + .map_err(|e| DatabaseError::Query(e.to_string()))?; + Ok(()) + } } impl NativeJobStore for LibSqlBackend { @@ -321,3 +371,105 @@ impl NativeJobStore for LibSqlBackend { jobs_history::update_estimation_actuals(self, params).await } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::db::NativeDatabase; + use crate::db::SandboxEventType; + use chrono::Utc; + use serde_json::json; + + async fn count_job_events(backend: &LibSqlBackend, job_id: Uuid) -> i64 { + let conn = backend.connect().await.expect("connection should succeed"); + let mut rows = conn + .query( + "SELECT COUNT(*) FROM job_events WHERE job_id = ?1", + params![job_id.to_string()], + ) + .await + .expect("count query should succeed"); + let row = rows + .next() + .await + .expect("count row should load") + .expect("count row should exist"); + row.get::(0).expect("count column should decode") + } + + async fn seed_non_direct_job(backend: &LibSqlBackend, job_id: Uuid) { + let conn = backend.connect().await.expect("connection should succeed"); + conn.execute( + r#" + INSERT INTO agent_jobs ( + id, title, description, status, source, user_id, project_dir, created_at + ) VALUES (?1, ?2, ?3, ?4, 'sandbox', ?5, ?6, ?7) + "#, + params![ + job_id.to_string(), + "Sandbox 
test job", + "{}", + "creating", + "test-user", + "/tmp/test-project", + Utc::now().to_rfc3339(), + ], + ) + .await + .expect("sandbox job should seed"); + } + + #[tokio::test] + async fn persist_terminal_result_and_status_rejects_unknown_job_ids() { + let dir = tempfile::tempdir().expect("tempdir should succeed"); + let db_path = dir.path().join("jobs.sqlite"); + let backend = LibSqlBackend::new_local(&db_path) + .await + .expect("new_local should succeed"); + backend + .run_migrations() + .await + .expect("migrations should succeed"); + + let job_id = Uuid::new_v4(); + let result = backend + .persist_terminal_result_and_status(TerminalJobPersistence { + job_id, + status: JobState::Completed, + failure_reason: None, + event_type: SandboxEventType::from("result"), + event_data: &json!({"status": "completed"}), + }) + .await; + + assert!(result.is_err(), "unknown job ID should fail"); + assert_eq!( + count_job_events(&backend, job_id).await, + 0, + "unknown job ID should not leave a terminal event behind" + ); + + let sandbox_job_id = Uuid::new_v4(); + seed_non_direct_job(&backend, sandbox_job_id).await; + + let sandbox_result = backend + .persist_terminal_result_and_status(TerminalJobPersistence { + job_id: sandbox_job_id, + status: JobState::Completed, + failure_reason: None, + event_type: SandboxEventType::from("result"), + event_data: &json!({"status": "completed"}), + }) + .await; + + assert!( + sandbox_result.is_err(), + "non-direct job ID should fail terminal persistence" + ); + assert_eq!( + count_job_events(&backend, sandbox_job_id).await, + 0, + "non-direct job ID should not leave a terminal event behind" + ); + } +} diff --git a/src/db/libsql/mod.rs b/src/db/libsql/mod.rs index 4055a355a..fd1446dda 100644 --- a/src/db/libsql/mod.rs +++ b/src/db/libsql/mod.rs @@ -22,6 +22,7 @@ use std::sync::Arc; use crate::db::NativeDatabase; use crate::error::DatabaseError; use libsql::{Connection, Database as LibSqlDatabase}; +use tokio::fs; use 
crate::db::libsql_migrations; pub(crate) use helpers::{ @@ -40,20 +41,24 @@ pub struct LibSqlBackend { } impl LibSqlBackend { - /// Create a new local embedded database. - pub async fn new_local(path: &Path) -> Result { - // Ensure parent directory exists + /// Ensure the parent directory of `path` exists, creating it and all + /// ancestors if necessary. + async fn ensure_parent_dir(path: &Path) -> Result<(), DatabaseError> { if let Some(parent) = path.parent() { - std::fs::create_dir_all(parent).map_err(|e| { - DatabaseError::Pool(format!("Failed to create database directory: {}", e)) + fs::create_dir_all(parent).await.map_err(|e| { + DatabaseError::Pool(format!("Failed to create database directory: {e}")) })?; } + Ok(()) + } + /// Create a new local embedded database. + pub async fn new_local(path: &Path) -> Result { + Self::ensure_parent_dir(path).await?; let db = libsql::Builder::new_local(path) .build() .await - .map_err(|e| DatabaseError::Pool(format!("Failed to open libSQL database: {}", e)))?; - + .map_err(|e| DatabaseError::Pool(format!("Failed to open libSQL database: {e}")))?; Ok(Self { db: Arc::new(db) }) } @@ -75,17 +80,11 @@ impl LibSqlBackend { url: &str, auth_token: &str, ) -> Result { - if let Some(parent) = path.parent() { - std::fs::create_dir_all(parent).map_err(|e| { - DatabaseError::Pool(format!("Failed to create database directory: {}", e)) - })?; - } - + Self::ensure_parent_dir(path).await?; let db = libsql::Builder::new_remote_replica(path, url.to_string(), auth_token.to_string()) .build() .await - .map_err(|e| DatabaseError::Pool(format!("Failed to open remote replica: {}", e)))?; - + .map_err(|e| DatabaseError::Pool(format!("Failed to open remote replica: {e}")))?; Ok(Self { db: Arc::new(db) }) } @@ -137,6 +136,13 @@ impl LibSqlBackend { } impl NativeDatabase for LibSqlBackend { + async fn persist_terminal_result_and_status( + &self, + params: crate::db::TerminalJobPersistence<'_>, + ) -> Result<(), DatabaseError> { + 
LibSqlBackend::persist_terminal_result_and_status(self, params).await + } + async fn run_migrations(&self) -> Result<(), DatabaseError> { let conn = self.connect().await?; // WAL mode persists in the database file: all future connections benefit. diff --git a/src/db/mod.rs b/src/db/mod.rs index 79168079d..d1933761c 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -32,8 +32,8 @@ mod traits; pub use traits::{ ConversationStore, Database, JobStore, NativeConversationStore, NativeDatabase, NativeJobStore, NativeRoutineStore, NativeSandboxStore, NativeSettingsStore, NativeToolFailureStore, - NativeWorkspaceStore, RoutineStore, SandboxStore, SettingsStore, ToolFailureStore, - WorkspaceStore, + NativeWorkspaceStore, RoutineStore, SandboxStore, SettingsStore, TerminalJobPersistence, + ToolFailureStore, WorkspaceStore, }; mod types; diff --git a/src/db/postgres/mod.rs b/src/db/postgres/mod.rs index e307faca2..1ebe78ee2 100644 --- a/src/db/postgres/mod.rs +++ b/src/db/postgres/mod.rs @@ -56,6 +56,13 @@ impl PgBackend { } impl NativeDatabase for PgBackend { + async fn persist_terminal_result_and_status( + &self, + params: crate::db::TerminalJobPersistence<'_>, + ) -> Result<(), DatabaseError> { + self.store.persist_terminal_result_and_status(params).await + } + async fn run_migrations(&self) -> Result<(), DatabaseError> { self.store.run_migrations().await } diff --git a/src/db/traits/database.rs b/src/db/traits/database.rs index 784c3eb4e..61539ea98 100644 --- a/src/db/traits/database.rs +++ b/src/db/traits/database.rs @@ -5,6 +5,10 @@ use core::future::Future; +use uuid::Uuid; + +use crate::context::JobState; +use crate::db::SandboxEventType; use crate::db::params::DbFuture; use crate::error::DatabaseError; @@ -31,6 +35,12 @@ pub trait Database: + Send + Sync { + /// Atomically persist a terminal job event together with the terminal job status. 
+ fn persist_terminal_result_and_status<'a>( + &'a self, + params: TerminalJobPersistence<'a>, + ) -> DbFuture<'a, Result<(), DatabaseError>>; + /// Apply all pending schema migrations before the backend is used. /// /// Implementations must be idempotent, so callers may safely invoke this @@ -57,6 +67,12 @@ pub trait NativeDatabase: + Send + Sync { + /// Native async form of [`Database::persist_terminal_result_and_status`]. + fn persist_terminal_result_and_status<'a>( + &'a self, + params: TerminalJobPersistence<'a>, + ) -> impl Future<Output = Result<(), DatabaseError>> + Send + 'a; + /// Apply all pending schema migrations before the backend is used. /// /// Implementations must be idempotent, so callers may safely invoke this @@ -71,3 +87,17 @@ pub trait NativeDatabase: /// call sites run this once immediately after backend construction. fn run_migrations<'a>(&'a self) -> impl Future<Output = Result<(), DatabaseError>> + Send + 'a; } + +/// Parameters for atomically persisting a terminal event and terminal status. +pub struct TerminalJobPersistence<'a> { + /// Direct agent job UUID being updated. + pub job_id: Uuid, + /// Terminal job status to persist. + pub status: JobState, + /// Optional failure or completion reason to persist on the job row. + pub failure_reason: Option<&'a str>, + /// Event type written to `job_events`. + pub event_type: SandboxEventType, + /// Structured event payload written alongside the status transition. 
+ pub event_data: &'a serde_json::Value, +} diff --git a/src/db/traits/mod.rs b/src/db/traits/mod.rs index c3abcb30d..0236abc55 100644 --- a/src/db/traits/mod.rs +++ b/src/db/traits/mod.rs @@ -14,7 +14,7 @@ pub mod tool_failure; pub mod workspace; pub use conversation::{ConversationStore, NativeConversationStore}; -pub use database::{Database, NativeDatabase}; +pub use database::{Database, NativeDatabase, TerminalJobPersistence}; pub use job::{JobStore, NativeJobStore}; pub use routine::{NativeRoutineStore, RoutineStore}; pub use sandbox::{NativeSandboxStore, SandboxStore}; diff --git a/src/history/store/jobs.rs b/src/history/store/jobs.rs index 89cb40c6a..5ec87b81b 100644 --- a/src/history/store/jobs.rs +++ b/src/history/store/jobs.rs @@ -14,6 +14,8 @@ use super::Store; #[cfg(feature = "postgres")] use crate::context::{JobContext, JobState}; #[cfg(feature = "postgres")] +use crate::db::TerminalJobPersistence; +#[cfg(feature = "postgres")] use crate::error::DatabaseError; #[cfg(feature = "postgres")] @@ -168,6 +170,47 @@ impl Store { Ok(()) } + /// Persist a terminal result event and terminal status in one transaction. 
+ pub async fn persist_terminal_result_and_status( + &self, + params: TerminalJobPersistence<'_>, + ) -> Result<(), DatabaseError> { + let TerminalJobPersistence { + job_id, + status, + failure_reason, + event_type, + event_data, + } = params; + let mut conn = self.conn().await?; + let tx = conn.transaction().await?; + let status_str = status.to_string(); + + tx.execute( + r#" + INSERT INTO job_events (job_id, event_type, data) + VALUES ($1, $2, $3) + "#, + &[&job_id, &event_type.as_str(), event_data], + ) + .await?; + let rows_affected = tx + .execute( + "UPDATE agent_jobs SET status = $2, failure_reason = $3 WHERE id = $1 AND source = 'direct'", + &[&job_id, &status_str, &failure_reason], + ) + .await?; + if rows_affected != 1 { + tx.rollback().await?; + return Err(DatabaseError::NotFound { + entity: "agent_job".to_string(), + id: job_id.to_string(), + }); + } + tx.commit().await?; + Ok(()) + } + /// Mark job as stuck. pub async fn mark_job_stuck(&self, id: Uuid) -> Result<(), DatabaseError> { let conn = self.conn().await?; @@ -270,9 +313,46 @@ mod tests { #[cfg(feature = "postgres")] use crate::context::StateTransition; #[cfg(feature = "postgres")] + use crate::db::TerminalJobPersistence; + #[cfg(feature = "postgres")] + use crate::db::postgres::PgBackend; + #[cfg(feature = "postgres")] use crate::testing::postgres::try_test_pg_db; #[cfg(feature = "postgres")] use rstest::rstest; + #[cfg(feature = "postgres")] + use serde_json::json; + + #[cfg(feature = "postgres")] + enum RollbackScenario { + UnknownJob, + NonDirectJob, + } + + #[cfg(feature = "postgres")] + async fn prepare_job_for_rollback( + backend: &PgBackend, + store: &Store, + scenario: RollbackScenario, + ) -> Result<(Uuid, Option), Box> { + match scenario { + RollbackScenario::UnknownJob => Ok((Uuid::new_v4(), None)), + RollbackScenario::NonDirectJob => { + let ctx = JobContext::with_user("test-user", "sandbox-like job", "rollback check"); + let job_id = ctx.job_id; + store.save_job(&ctx).await?; + + 
let conn = backend.pool().get().await?; + conn.execute( + "UPDATE agent_jobs SET source = 'sandbox' WHERE id = $1", + &[&job_id], + ) + .await?; + + Ok((job_id, Some(ctx))) + } + } + } /// Regression test: save_job must persist user-owned and context fields. /// Requires a running PostgreSQL instance (integration tier). @@ -336,4 +416,45 @@ mod tests { assert_eq!(summary.failed, 12); assert_eq!(summary.stuck, 5); } + + #[cfg(feature = "postgres")] + #[rstest] + #[case(RollbackScenario::UnknownJob)] + #[case(RollbackScenario::NonDirectJob)] + #[tokio::test] + async fn persist_terminal_result_and_status_rolls_back_on_invalid_job( + #[case] scenario: RollbackScenario, + ) -> Result<(), Box> { + let Some(backend) = try_test_pg_db().await? else { + return Ok(()); + }; + let store = Store::from_pool(backend.pool()); + let (job_id, saved_ctx) = prepare_job_for_rollback(&backend, &store, scenario).await?; + + let result = store + .persist_terminal_result_and_status(TerminalJobPersistence { + job_id, + status: JobState::Failed, + failure_reason: Some("terminal rollback regression"), + event_type: crate::db::SandboxEventType::from("result"), + event_data: &json!({"status": "failed"}), + }) + .await; + assert!(result.is_err(), "invalid terminal job write should fail"); + + let conn = backend.pool().get().await?; + let count: i64 = conn + .query_one( + "SELECT COUNT(*) FROM job_events WHERE job_id = $1", + &[&job_id], + ) + .await? + .get(0); + assert_eq!(count, 0, "rollback should remove inserted job_events rows"); + if let Some(ctx) = saved_ctx { + conn.execute("DELETE FROM agent_jobs WHERE id = $1", &[&ctx.job_id]) + .await?; + } + Ok(()) + } } diff --git a/src/llm/provider.rs b/src/llm/provider.rs index 3f6ad929a..afff6b4d0 100644 --- a/src/llm/provider.rs +++ b/src/llm/provider.rs @@ -195,7 +195,7 @@ impl CompletionRequest { } /// Response from a chat completion. 
-#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct CompletionResponse { pub content: String, pub input_tokens: u32, @@ -210,8 +210,9 @@ pub struct CompletionResponse { } /// Why the completion finished. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum FinishReason { + #[default] Stop, Length, ToolUse, @@ -299,7 +300,7 @@ impl ToolCompletionRequest { } /// Response from a completion with potential tool calls. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct ToolCompletionResponse { /// Text content (may be empty if tool calls are present). pub content: Option, @@ -661,6 +662,7 @@ pub fn strip_unsupported_tool_params( #[cfg(test)] mod tests { use super::*; + mod default_contracts; #[test] fn test_sanitize_preserves_valid_pairs() { diff --git a/src/llm/provider/tests/default_contracts.rs b/src/llm/provider/tests/default_contracts.rs new file mode 100644 index 000000000..35be97d6f --- /dev/null +++ b/src/llm/provider/tests/default_contracts.rs @@ -0,0 +1,56 @@ +//! Verifies `Default` implementations for LLM response types used by the provider. + +use super::*; + +macro_rules! 
assert_llm_defaults { + ($resp:expr, content_ok = $content_ok:expr, tool_calls_len = $tc_len:expr) => {{ + let r = &$resp; + assert!( + $content_ok + && $tc_len == 0 + && r.input_tokens == 0 + && r.output_tokens == 0 + && r.finish_reason == FinishReason::Stop + && r.cache_read_input_tokens == 0 + && r.cache_creation_input_tokens == 0, + "default {} mismatch: content_ok={}, tool_calls_len={}, in={}, out={}, finish_reason={:?}, cache_read={}, cache_create={}", + std::any::type_name_of_val(r), + $content_ok, + $tc_len, + r.input_tokens, + r.output_tokens, + r.finish_reason, + r.cache_read_input_tokens, + r.cache_creation_input_tokens + ); + }}; +} + +fn assert_finish_reason_is_stop(fr: FinishReason) { + assert!( + fr == FinishReason::Stop, + "FinishReason::default() should be Stop, got: {:?}", + fr + ); +} + +#[test] +fn default_finish_reason_is_stop() { + assert_finish_reason_is_stop(FinishReason::default()); +} + +#[test] +fn default_completion_response_matches_contract() { + let r = CompletionResponse::default(); + assert_llm_defaults!(r, content_ok = r.content.is_empty(), tool_calls_len = 0); +} + +#[test] +fn default_tool_completion_response_matches_contract() { + let r = ToolCompletionResponse::default(); + assert_llm_defaults!( + r, + content_ok = r.content.is_none(), + tool_calls_len = r.tool_calls.len() + ); +} diff --git a/src/testing/mod.rs b/src/testing/mod.rs index 09ad0951a..0a2f8ba66 100644 --- a/src/testing/mod.rs +++ b/src/testing/mod.rs @@ -1,10 +1,17 @@ -//! Test harness for constructing `AgentDeps` with sensible defaults. +//! Test harnesses, doubles, and helpers for crate-level tests. //! -//! Provides: -//! - [`StubLlm`]: A configurable LLM provider that returns a fixed response -//! - [`StubChannel`]: A configurable channel stub with message injection and response capture -//! - [`TestHarnessBuilder`]: Builder for wiring `AgentDeps` with defaults -//! - [`TestHarness`]: The assembled components ready for use in tests +//! 
The public surface here supports both full integration-style tests and +//! targeted unit tests. Use [`TestHarnessBuilder`] and [`TestHarness`] when a +//! test needs fully wired `AgentDeps` with sensible defaults, [`null_db`] when +//! the test needs null persistence or captured persistence calls, and +//! [`worker_harness`] when the focus is `Worker` setup and terminal-state +//! behaviour. +//! +//! The [`null_db`] exports cover both null persistence and call verification: +//! [`NullDatabase`] is the baseline no-op database, [`CapturingStore`] records +//! persistence interactions, and [`Calls`], [`EventCall`], +//! [`EventCallWithId`], [`StatusCall`], and [`StatusCallWithId`] expose the +//! captured status and event payloads for assertions. //! //! # Usage //! @@ -26,12 +33,22 @@ pub mod postgres; mod settings_tests; pub mod test_utils; +#[cfg(test)] +pub mod null_db; +#[cfg(test)] +pub use null_db::{ + Calls, CapturingStore, EventCall, EventCallWithId, NullDatabase, StatusCall, StatusCallWithId, +}; +#[cfg(test)] +pub mod worker_harness; + +use anyhow::Result; use std::sync::Arc; use std::sync::Mutex; use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; -use anyhow::Result; use rust_decimal::Decimal; +#[cfg(all(feature = "libsql", feature = "test-helpers"))] use tempfile::TempDir; use tokio::sync::mpsc; @@ -42,7 +59,7 @@ use crate::channels::{ use crate::db::Database; use crate::error::{ChannelError, LlmError}; -#[cfg(test)] +#[cfg(all(test, feature = "libsql", feature = "test-helpers"))] use crate::db::{ EnsureConversationParams, EstimationActualsParams, EstimationSnapshotParams, RoutineRunCompletion, RoutineRuntimeUpdate, SandboxJobStatusUpdate, SettingKey, UserId, @@ -60,7 +77,7 @@ use crate::tools::wasm::{Capabilities, WasmToolWrapper}; /// /// Returns the database and a `TempDir` guard — the database file is /// deleted when the guard is dropped. 
-#[cfg(feature = "libsql")] +#[cfg(all(feature = "libsql", feature = "test-helpers"))] pub async fn test_db() -> (Arc, TempDir) { use crate::db::libsql::LibSqlBackend; use tempfile::tempdir; @@ -374,7 +391,7 @@ pub struct TestHarness { pub channel: Option<(mpsc::Sender, ChannelManager)>, /// Temp directory guard — keeps the test database alive. Dropped /// automatically when the harness goes out of scope. - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] _temp_dir: TempDir, } @@ -433,7 +450,7 @@ impl TestHarnessBuilder { } /// Build the harness with defaults applied. - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] pub async fn build(self) -> TestHarness { use crate::agent::cost_guard::{CostGuard, CostGuardConfig}; use crate::config::{SafetyConfig, SkillsConfig}; @@ -516,7 +533,7 @@ impl Default for TestHarnessBuilder { mod tests { use super::*; - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn test_harness_builds_with_defaults() { let harness = TestHarnessBuilder::new().build().await; @@ -524,7 +541,7 @@ mod tests { assert_eq!(harness.deps.llm.model_name(), "stub-model"); } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn test_harness_custom_llm() { let custom_llm = Arc::new(StubLlm::new("custom response").with_model_name("my-model")); @@ -532,7 +549,7 @@ mod tests { assert_eq!(harness.deps.llm.model_name(), "my-model"); } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn test_harness_db_works() { let harness = TestHarnessBuilder::new().build().await; @@ -547,7 +564,7 @@ mod tests { // === QA Plan P1 - 2.2: Turn persistence round-trip tests === - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn test_conversation_message_round_trip() { let 
harness = TestHarnessBuilder::new().build().await; @@ -594,7 +611,7 @@ mod tests { assert!(msgs[1].created_at <= msgs[2].created_at); } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn test_conversation_metadata_persistence() { let harness = TestHarnessBuilder::new().build().await; @@ -646,7 +663,7 @@ mod tests { assert_eq!(meta["model"], "gpt-4"); } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn test_conversation_belongs_to_user() { let harness = TestHarnessBuilder::new().build().await; @@ -672,7 +689,7 @@ mod tests { ); } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn test_ensure_conversation_idempotent() { let harness = TestHarnessBuilder::new().build().await; @@ -716,7 +733,7 @@ mod tests { assert_eq!(msgs[0].content, "test message"); } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn test_paginated_messages() { let harness = TestHarnessBuilder::new().build().await; @@ -758,7 +775,7 @@ mod tests { } } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn test_conversations_with_preview() { let harness = TestHarnessBuilder::new().build().await; @@ -794,7 +811,7 @@ mod tests { } } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn test_job_action_persistence() { use crate::context::{ActionRecord, JobContext, JobState}; @@ -906,6 +923,7 @@ mod tests { assert!(channel.health_check().await.is_err()); } + #[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn test_harness_with_channel() { let harness = TestHarnessBuilder::new().with_stub_channel().build().await; @@ -924,7 +942,7 @@ mod tests { assert!(names.contains(&"stub".to_string())); } - #[cfg(feature = 
"libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn test_tool_failure_tracking() { let harness = TestHarnessBuilder::new().build().await; @@ -953,7 +971,7 @@ mod tests { .expect("mark repaired"); } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] async fn create_routine_fixture(db: &Arc) -> uuid::Uuid { use crate::agent::routine::{ NotifyConfig, Routine, RoutineAction, RoutineGuardrails, Trigger, @@ -1040,7 +1058,7 @@ mod tests { routine_id } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] async fn start_routine_run(db: &Arc, routine_id: uuid::Uuid) -> uuid::Uuid { use crate::agent::routine::{RoutineRun, RunStatus}; @@ -1071,7 +1089,7 @@ mod tests { run_id } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] async fn complete_routine_run_ok(db: &Arc, run_id: uuid::Uuid) { use crate::agent::routine::RunStatus; @@ -1085,7 +1103,7 @@ mod tests { .expect("complete run"); } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] async fn assert_history_len(db: &Arc, routine_id: uuid::Uuid, expected: usize) { use crate::agent::routine::RunStatus; @@ -1099,7 +1117,7 @@ mod tests { } } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] async fn delete_routine_and_assert_absent(db: &Arc, routine_id: uuid::Uuid) { let deleted = db.delete_routine(routine_id).await.expect("delete"); assert!(deleted); @@ -1109,7 +1127,7 @@ mod tests { assert!(!deleted); } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn test_routine_crud() { let harness = TestHarnessBuilder::new().build().await; @@ -1122,7 +1140,7 @@ mod tests { delete_routine_and_assert_absent(db, routine_id).await; } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn 
test_routine_runtime_update() { use crate::agent::routine::{ @@ -1193,7 +1211,7 @@ mod tests { db.delete_routine(routine_id).await.expect("delete"); } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn test_llm_call_recording() { use crate::history::LlmCallRecord; @@ -1216,7 +1234,7 @@ mod tests { assert!(!call_id.is_nil()); } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn test_sandbox_job_lifecycle() { use crate::history::SandboxJobRecord; @@ -1311,7 +1329,7 @@ mod tests { assert!(!not_belongs); } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn test_sandbox_job_mode() { use crate::history::SandboxJobRecord; @@ -1352,7 +1370,7 @@ mod tests { assert_eq!(mode, crate::db::SandboxMode::ClaudeCode); } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn test_job_events() { use crate::history::SandboxJobRecord; @@ -1425,7 +1443,7 @@ mod tests { assert_eq!(events[0].event_type, "tool_call"); } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn test_estimation_snapshot_round_trip() { let harness = TestHarnessBuilder::new().build().await; diff --git a/src/testing/null_db/capturing_store/delegation.rs b/src/testing/null_db/capturing_store/delegation.rs new file mode 100644 index 000000000..027011c00 --- /dev/null +++ b/src/testing/null_db/capturing_store/delegation.rs @@ -0,0 +1,396 @@ +//! Delegate implementations for CapturingStore. +//! +//! This module contains all the `delegate!` macro invocations that forward +//! trait implementations to the inner NullDatabase. The CapturingStore +//! overrides `persist_terminal_result_and_status`, `update_job_status`, and +//! `save_job_event` to capture calls; all other methods are delegated +//! 
unchanged through the `delegate!` macro invocations below. + +use delegate::delegate; +use uuid::Uuid; + +use crate::agent::{Routine, routine::RoutineRun}; +use crate::context::JobState; +use crate::db::{ + EnsureConversationParams, EstimationActualsParams, EstimationSnapshotParams, + HybridSearchParams, InsertChunkParams, SandboxEventType, SandboxJobStatusUpdate, SandboxMode, + SettingKey, UserId, +}; +use crate::error::{DatabaseError, WorkspaceError}; +use crate::history::{ + AgentJobRecord, AgentJobSummary, ConversationMessage, ConversationSummary, JobEventRecord, + LlmCallRecord, SandboxJobRecord, SandboxJobSummary, SettingRow, +}; +use crate::workspace::{MemoryChunk, MemoryDocument, SearchResult, WorkspaceEntry}; + +use super::CapturingStore; + +impl crate::db::NativeDatabase for CapturingStore { + async fn persist_terminal_result_and_status( + &self, + params: crate::db::TerminalJobPersistence<'_>, + ) -> Result<(), DatabaseError> { + self.calls + .record_event(params.job_id, params.event_type, params.event_data) + .await; + self.calls + .record_status(params.job_id, params.status, params.failure_reason) + .await; + Ok(()) + } + + delegate! { + to self.inner { + async fn run_migrations(&self) -> Result<(), DatabaseError>; + } + } +} + +impl crate::db::NativeJobStore for CapturingStore { + delegate! 
{ + to self.inner { + async fn save_job(&self, ctx: &crate::context::JobContext) -> Result<(), DatabaseError>; + async fn get_job( + &self, + id: Uuid + ) -> Result, DatabaseError>; + async fn mark_job_stuck(&self, id: Uuid) -> Result<(), DatabaseError>; + async fn get_stuck_jobs(&self) -> Result, DatabaseError>; + async fn list_agent_jobs(&self) -> Result, DatabaseError>; + async fn agent_job_summary(&self) -> Result; + async fn get_agent_job_failure_reason( + &self, + id: Uuid + ) -> Result, DatabaseError>; + async fn save_action( + &self, + job_id: Uuid, + action: &crate::context::ActionRecord + ) -> Result<(), DatabaseError>; + async fn get_job_actions( + &self, + job_id: Uuid + ) -> Result, DatabaseError>; + async fn record_llm_call( + &self, + record: &LlmCallRecord<'_> + ) -> Result; + async fn save_estimation_snapshot( + &self, + params: EstimationSnapshotParams<'_> + ) -> Result; + async fn update_estimation_actuals( + &self, + params: EstimationActualsParams + ) -> Result<(), DatabaseError>; + } + } + + async fn update_job_status( + &self, + id: Uuid, + status: JobState, + failure_reason: Option<&str>, + ) -> Result<(), DatabaseError> { + self.calls.record_status(id, status, failure_reason).await; + Ok(()) + } +} + +impl crate::db::NativeSandboxStore for CapturingStore { + delegate! 
{ + to self.inner { + async fn save_sandbox_job(&self, job: &SandboxJobRecord) -> Result<(), DatabaseError>; + async fn get_sandbox_job( + &self, + id: Uuid + ) -> Result, DatabaseError>; + async fn list_sandbox_jobs(&self) -> Result, DatabaseError>; + async fn update_sandbox_job_status( + &self, + params: SandboxJobStatusUpdate<'_> + ) -> Result<(), DatabaseError>; + async fn cleanup_stale_sandbox_jobs(&self) -> Result; + async fn sandbox_job_summary(&self) -> Result; + async fn list_sandbox_jobs_for_user( + &self, + user_id: UserId + ) -> Result, DatabaseError>; + async fn sandbox_job_summary_for_user( + &self, + user_id: UserId + ) -> Result; + async fn sandbox_job_belongs_to_user( + &self, + job_id: Uuid, + user_id: UserId + ) -> Result; + async fn update_sandbox_job_mode( + &self, + id: Uuid, + mode: SandboxMode + ) -> Result<(), DatabaseError>; + async fn get_sandbox_job_mode( + &self, + id: Uuid + ) -> Result, DatabaseError>; + async fn list_job_events( + &self, + job_id: Uuid, + before_id: Option, + limit: Option + ) -> Result, DatabaseError>; + } + } + + async fn save_job_event( + &self, + job_id: Uuid, + event_type: SandboxEventType, + data: &serde_json::Value, + ) -> Result<(), DatabaseError> { + self.calls.record_event(job_id, event_type, data).await; + Ok(()) + } +} + +impl crate::db::NativeConversationStore for CapturingStore { + delegate! 
{ + to self.inner { + async fn create_conversation( + &self, + channel: &str, + user_id: &str, + thread_id: Option<&str> + ) -> Result; + async fn touch_conversation(&self, id: Uuid) -> Result<(), DatabaseError>; + async fn add_conversation_message( + &self, + conversation_id: Uuid, + role: &str, + content: &str + ) -> Result; + async fn ensure_conversation( + &self, + params: EnsureConversationParams<'_> + ) -> Result<(), DatabaseError>; + async fn list_conversations_with_preview( + &self, + user_id: &str, + channel: &str, + limit: usize + ) -> Result, DatabaseError>; + async fn list_conversations_all_channels( + &self, + user_id: &str, + limit: usize + ) -> Result, DatabaseError>; + async fn get_or_create_routine_conversation( + &self, + routine_id: Uuid, + routine_name: &str, + user_id: &str + ) -> Result; + async fn get_or_create_heartbeat_conversation( + &self, + user_id: &str + ) -> Result; + async fn get_or_create_assistant_conversation( + &self, + user_id: &str, + channel: &str + ) -> Result; + async fn create_conversation_with_metadata( + &self, + channel: &str, + user_id: &str, + metadata: &serde_json::Value + ) -> Result; + async fn update_conversation_metadata_field( + &self, + id: Uuid, + key: &str, + value: &serde_json::Value + ) -> Result<(), DatabaseError>; + async fn get_conversation_metadata( + &self, + id: Uuid + ) -> Result, DatabaseError>; + async fn list_conversation_messages( + &self, + conversation_id: Uuid + ) -> Result, DatabaseError>; + async fn list_conversation_messages_paginated( + &self, + conversation_id: Uuid, + before: Option<(chrono::DateTime, Uuid)>, + limit: usize + ) -> Result<(Vec, bool), DatabaseError>; + async fn conversation_belongs_to_user( + &self, + conversation_id: Uuid, + user_id: &str + ) -> Result; + } + } +} + +impl crate::db::NativeRoutineStore for CapturingStore { + delegate! 
{ + to self.inner { + async fn create_routine(&self, routine: &Routine) -> Result<(), DatabaseError>; + async fn get_routine(&self, id: Uuid) -> Result, DatabaseError>; + async fn get_routine_by_name( + &self, + user_id: &str, + name: &str + ) -> Result, DatabaseError>; + async fn list_routines(&self, user_id: &str) -> Result, DatabaseError>; + async fn list_all_routines(&self) -> Result, DatabaseError>; + async fn update_routine(&self, routine: &Routine) -> Result<(), DatabaseError>; + async fn delete_routine(&self, id: Uuid) -> Result; + async fn update_routine_runtime( + &self, + update: crate::db::RoutineRuntimeUpdate<'_> + ) -> Result<(), DatabaseError>; + async fn create_routine_run(&self, run: &RoutineRun) -> Result<(), DatabaseError>; + async fn list_routine_runs( + &self, + routine_id: Uuid, + limit: i64 + ) -> Result, DatabaseError>; + async fn complete_routine_run( + &self, + completion: crate::db::RoutineRunCompletion<'_> + ) -> Result<(), DatabaseError>; + async fn list_event_routines(&self) -> Result, DatabaseError>; + async fn list_due_cron_routines(&self) -> Result, DatabaseError>; + async fn count_running_routine_runs(&self, routine_id: Uuid) -> Result; + async fn link_routine_run_to_job( + &self, + run_id: Uuid, + job_id: Uuid + ) -> Result<(), DatabaseError>; + } + } +} + +impl crate::db::NativeToolFailureStore for CapturingStore { + delegate! { + to self.inner { + async fn record_tool_failure( + &self, + tool_name: &str, + error: &str + ) -> Result<(), DatabaseError>; + async fn get_broken_tools( + &self, + threshold: i32 + ) -> Result, DatabaseError>; + async fn mark_tool_repaired(&self, tool_name: &str) -> Result<(), DatabaseError>; + async fn increment_repair_attempts(&self, tool_name: &str) -> Result<(), DatabaseError>; + } + } +} + +impl crate::db::NativeSettingsStore for CapturingStore { + delegate! 
{ + to self.inner { + async fn get_setting( + &self, + user_id: UserId, + key: SettingKey + ) -> Result, DatabaseError>; + async fn get_setting_full( + &self, + user_id: UserId, + key: SettingKey + ) -> Result, DatabaseError>; + async fn delete_setting( + &self, + user_id: UserId, + key: SettingKey + ) -> Result; + async fn list_settings( + &self, + user_id: UserId + ) -> Result, DatabaseError>; + async fn set_setting( + &self, + user_id: UserId, + key: SettingKey, + value: &serde_json::Value + ) -> Result<(), DatabaseError>; + async fn get_all_settings( + &self, + user_id: UserId + ) -> Result, DatabaseError>; + async fn set_all_settings( + &self, + user_id: UserId, + settings: &std::collections::HashMap + ) -> Result<(), DatabaseError>; + async fn has_settings(&self, user_id: UserId) -> Result; + } + } +} + +impl crate::db::NativeWorkspaceStore for CapturingStore { + delegate! { + to self.inner { + async fn get_document_by_path( + &self, + user_id: &str, + agent_id: Option, + path: &str + ) -> Result; + async fn get_document_by_id(&self, id: Uuid) -> Result; + async fn get_or_create_document_by_path( + &self, + user_id: &str, + agent_id: Option, + path: &str + ) -> Result; + async fn update_document(&self, id: Uuid, content: &str) -> Result<(), WorkspaceError>; + async fn delete_document_by_path( + &self, + user_id: &str, + agent_id: Option, + path: &str + ) -> Result<(), WorkspaceError>; + async fn list_directory( + &self, + user_id: &str, + agent_id: Option, + directory: &str + ) -> Result, WorkspaceError>; + async fn list_all_paths( + &self, + user_id: &str, + agent_id: Option + ) -> Result, WorkspaceError>; + async fn list_documents( + &self, + user_id: &str, + agent_id: Option + ) -> Result, WorkspaceError>; + async fn delete_chunks(&self, document_id: Uuid) -> Result<(), WorkspaceError>; + async fn insert_chunk(&self, params: InsertChunkParams<'_>) -> Result; + async fn update_chunk_embedding( + &self, + chunk_id: Uuid, + embedding: &[f32] + ) -> Result<(), 
WorkspaceError>; + async fn get_chunks_without_embeddings( + &self, + user_id: &str, + agent_id: Option, + limit: usize + ) -> Result, WorkspaceError>; + async fn hybrid_search( + &self, + params: HybridSearchParams<'_> + ) -> Result, WorkspaceError>; + } + } +} diff --git a/src/testing/null_db/capturing_store/mod.rs b/src/testing/null_db/capturing_store/mod.rs new file mode 100644 index 000000000..bc911c63a --- /dev/null +++ b/src/testing/null_db/capturing_store/mod.rs @@ -0,0 +1,168 @@ +//! Capturing database wrapper for tests. +//! +//! Provides a [`CapturingStore`] that wraps [`NullDatabase`] and captures +//! specific method calls for test assertions. +//! +//! Captured calls include job IDs via [`StatusCallWithId`] and [`EventCallWithId`] +//! in the `status_history` and `event_history` collections, while [`StatusCall`] +//! and [`EventCall`] provide the simpler view without IDs in `last_status` and +//! `last_event`. + +use std::sync::Arc; + +use tokio::sync::Mutex; +use uuid::Uuid; + +use crate::context::JobState; +use crate::db::SandboxEventType; + +use super::NullDatabase; + +mod delegation; + +/// Captured status update call. +#[derive(Debug, Clone)] +pub struct StatusCall { + /// The job status that was recorded. + pub status: JobState, + /// Optional failure reason associated with the status. + pub reason: Option, +} + +/// Captured status update call with job ID. +#[derive(Debug, Clone)] +pub struct StatusCallWithId { + /// The job ID associated with this status update. + pub job_id: Uuid, + /// The job status that was recorded. + pub status: JobState, + /// Optional failure reason associated with the status. + pub reason: Option, +} + +/// Captured job event call. +#[derive(Debug, Clone)] +pub struct EventCall { + /// The event type string (e.g., "result"). + pub event_type: String, + /// The JSON data payload associated with the event. + pub data: serde_json::Value, +} + +/// Captured job event call with job ID. 
+#[derive(Debug, Clone)] +pub struct EventCallWithId { + /// The job ID associated with this event. + pub job_id: Uuid, + /// The event type string (e.g., "result"). + pub event_type: String, + /// The JSON data payload associated with the event. + pub data: serde_json::Value, +} + +/// Thread-safe storage for captured calls. +#[derive(Debug, Default)] +pub struct Calls { + /// The last status update call captured, if any. + pub last_status: Mutex>, + /// The last event call captured, if any. + pub last_event: Mutex>, + /// Full history of all status calls with job IDs. + pub status_history: Mutex>, + /// Full history of all event calls with job IDs. + pub event_history: Mutex>, +} + +impl Calls { + /// Create a new empty Calls container. + pub fn new() -> Self { + Self::default() + } + + /// Record a status update call. + /// + /// The call is stored in both `last_status` (overwriting previous) + /// and appended to `status_history` with the job ID for tests that need + /// to verify call counts or per-job tracking. + pub async fn record_status(&self, job_id: Uuid, status: JobState, reason: Option<&str>) { + let last_call = StatusCall { + status, + reason: reason.map(ToOwned::to_owned), + }; + let history_call = StatusCallWithId { + job_id, + status, + reason: reason.map(ToOwned::to_owned), + }; + *self.last_status.lock().await = Some(last_call); + self.status_history.lock().await.push(history_call); + } + + /// Record an event call. + /// + /// The call is stored in both `last_event` (overwriting previous) + /// and appended to `event_history` with the job ID for tests that need + /// to verify call counts or per-job tracking. 
+ pub async fn record_event( + &self, + job_id: Uuid, + event_type: SandboxEventType, + data: &serde_json::Value, + ) { + let last_call = EventCall { + event_type: event_type.as_str().to_string(), + data: data.clone(), + }; + let history_call = EventCallWithId { + job_id, + event_type: event_type.as_str().to_string(), + data: data.clone(), + }; + *self.last_event.lock().await = Some(last_call); + self.event_history.lock().await.push(history_call); + } + + /// Clear all captured call history. + pub async fn clear(&self) { + *self.last_status.lock().await = None; + *self.last_event.lock().await = None; + self.status_history.lock().await.clear(); + self.event_history.lock().await.clear(); + } +} + +/// A database wrapper that captures calls to specific methods for testing. +/// +/// Delegates all other methods to the inner [`NullDatabase`]. +/// +/// The `last_status` and `last_event` fields store the most recent call +/// (without job ID), while `status_history` and `event_history` maintain +/// full call sequences with job IDs via [`StatusCallWithId`] and +/// [`EventCallWithId`]. This supports tests that need to verify call counts +/// (e.g., duplicate transition rejection) or per-job tracking. +#[derive(Debug)] +pub struct CapturingStore { + pub(crate) inner: NullDatabase, + calls: Arc, +} + +impl CapturingStore { + /// Create a new capturing store with an inner NullDatabase. + pub fn new() -> Self { + Self { + inner: NullDatabase::new(), + calls: Arc::new(Calls::new()), + } + } + + /// Access the captured calls for assertions. + pub fn calls(&self) -> &Arc { + &self.calls + } +} + +impl Default for CapturingStore { + fn default() -> Self { + Self::new() + } +} diff --git a/src/testing/null_db/mod.rs b/src/testing/null_db/mod.rs new file mode 100644 index 000000000..e0492a010 --- /dev/null +++ b/src/testing/null_db/mod.rs @@ -0,0 +1,19 @@ +//! Test-only database doubles and captured-call helpers. +//! +//! 
[`NullDatabase`] provides null defaults across the `Native*Store` traits for +//! bespoke mocks, while [`CapturingStore`] wraps that baseline with captured +//! [`Calls`], [`EventCall`], [`EventCallWithId`], [`StatusCall`], and +//! [`StatusCallWithId`] records for persistence assertions. +//! +//! Choose the right testing abstraction for the job: use +//! [`crate::testing::TestHarnessBuilder`] for persistence testing with a real +//! database, [`CapturingStore`] for verifying calls without durable storage, +//! or [`NullDatabase`] when a test needs a custom mock with null behaviour. + +mod capturing_store; +mod null_database; + +pub use capturing_store::{ + Calls, CapturingStore, EventCall, EventCallWithId, StatusCall, StatusCallWithId, +}; +pub use null_database::NullDatabase; diff --git a/src/testing/null_db/null_database.rs b/src/testing/null_db/null_database.rs new file mode 100644 index 000000000..673a82dc2 --- /dev/null +++ b/src/testing/null_db/null_database.rs @@ -0,0 +1,174 @@ +//! Null database implementation for tests. +//! +//! Most methods return empty defaults (`Ok(None)`, `Ok(vec![])`, etc.), but +//! some return [`WorkspaceError::DocumentNotFound`] for missing documents. +//! UUIDs are generated deterministically via an internal counter (see +//! [`next_synthetic_uuid`](NullDatabase::next_synthetic_uuid)) and cache +//! entries are stable per-key, ensuring reproducible test results. +//! Use this as a baseline for test doubles that need to override only +//! specific methods while delegating the rest to null behavior. + +use std::collections::HashMap; +use std::hash::Hash; +use std::sync::Mutex; + +use crate::error::DatabaseError; +use crate::error::WorkspaceError; + +mod conversation_store; +mod job_store; +mod routine_store; +mod sandbox_store; +mod settings_store; +mod tool_failure_store; +mod workspace_store; + +/// Key for the routine conversation cache. 
+/// +/// Only includes routine_id and user_id to ensure singleton semantics +/// (changing the routine name should not create a new conversation). +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(super) struct RoutineConvKey { + pub routine_id: uuid::Uuid, + pub user_id: String, +} + +/// Key for the assistant conversation cache. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(super) struct AssistantConvKey { + pub user_id: String, + pub channel: String, +} + +/// A no-op database implementation for testing. +/// +/// Most methods return empty defaults (`Ok(None)`, `Ok(vec![])`, etc.), but +/// some return [`WorkspaceError::DocumentNotFound`] for missing documents. +/// UUIDs are generated deterministically via an internal counter (see +/// [`next_synthetic_uuid`](NullDatabase::next_synthetic_uuid)) and cache +/// entries are stable per-key, ensuring reproducible test results. +/// Use this as a baseline for test doubles that need to override only +/// specific methods while delegating the rest to null behavior. +#[derive(Debug, Default)] +pub struct NullDatabase { + /// Stable UUIDs for routine conversations, keyed by (routine_id, user_id). + pub(super) routine_conv_cache: Mutex>, + /// Stable UUIDs for heartbeat conversations, keyed by user_id. + pub(super) heartbeat_conv_cache: Mutex>, + /// Stable UUIDs for assistant conversations, keyed by (user_id, channel). + pub(super) assistant_conv_cache: Mutex>, + /// Counter for deterministic synthetic UUIDs. + pub(super) uuid_counter: Mutex, +} + +impl NullDatabase { + /// Create a new null database instance. + pub fn new() -> Self { + Self::default() + } + + /// Helper for document-not-found errors in workspace operations. + pub(super) fn doc_not_found(doc_type: &str) -> WorkspaceError { + WorkspaceError::DocumentNotFound { + doc_type: doc_type.to_string(), + user_id: "test".to_string(), + } + } + + /// Generate a deterministic synthetic UUID based on an internal counter. 
+ /// + /// Each call increments the counter and returns a UUID with the counter + /// value embedded in the UUID bytes. This provides reproducible IDs + /// for tests that need stable values across multiple calls. + pub(super) fn next_synthetic_uuid(&self) -> Result { + let mut counter = self + .uuid_counter + .lock() + .map_err(|_| DatabaseError::Query("lock poisoned".to_string()))?; + *counter += 1; + // Embed counter in UUID bytes for deterministic generation + let bytes = counter.to_be_bytes(); + let mut uuid_bytes = [0u8; 16]; + uuid_bytes[0..16].copy_from_slice(&[ + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], + bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15], + ]); + Ok(uuid::Uuid::from_bytes(uuid_bytes)) + } + + /// Lock `cache` and return the UUID already stored under `key`, + /// inserting a fresh synthetic UUID if the entry is absent. + pub(super) fn get_or_create_in_cache( + &self, + cache: &Mutex>, + key: K, + ) -> Result { + let mut map = cache + .lock() + .map_err(|_| DatabaseError::Query("lock poisoned".to_string()))?; + if let Some(id) = map.get(&key) { + return Ok(*id); + } + let id = self.next_synthetic_uuid()?; + map.insert(key, id); + Ok(id) + } +} + +impl crate::db::NativeDatabase for NullDatabase { + async fn persist_terminal_result_and_status( + &self, + _params: crate::db::TerminalJobPersistence<'_>, + ) -> Result<(), crate::error::DatabaseError> { + Ok(()) + } + + async fn run_migrations(&self) -> Result<(), crate::error::DatabaseError> { + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn synthetic_uuid_sequence_is_unique_across_many_calls() { + let db = NullDatabase::new(); + let mut seen = std::collections::HashSet::new(); + + for _ in 0..100 { + let id = db + .next_synthetic_uuid() + .expect("synthetic UUID generation should not fail"); + assert!(seen.insert(id), "duplicate synthetic UUID: {id}"); + } + } + + #[test] + fn 
cached_ids_are_stable_per_key_and_distinct_across_keys() { + let db = NullDatabase::new(); + let cache = Mutex::new(HashMap::new()); + let keys = (0..10).map(|idx| format!("key-{idx}")).collect::>(); + let mut expected = HashMap::new(); + + for _ in 0..5 { + for key in &keys { + let id = db + .get_or_create_in_cache(&cache, key.clone()) + .expect("cache access should not fail"); + if let Some(existing) = expected.get(key) { + assert_eq!(*existing, id, "cache entry for {key} changed"); + } else { + expected.insert(key.clone(), id); + } + } + } + + let unique = expected + .values() + .copied() + .collect::>(); + assert_eq!(unique.len(), keys.len(), "different keys shared a UUID"); + } +} diff --git a/src/testing/null_db/null_database/conversation_store.rs b/src/testing/null_db/null_database/conversation_store.rs new file mode 100644 index 000000000..876ed0fda --- /dev/null +++ b/src/testing/null_db/null_database/conversation_store.rs @@ -0,0 +1,256 @@ +//! Null implementation of NativeConversationStore for NullDatabase. 
+ +use chrono::{DateTime, Utc}; +use uuid::Uuid; + +use crate::db::EnsureConversationParams; +use crate::error::DatabaseError; +use crate::history::{ConversationMessage, ConversationSummary}; +use crate::testing::null_db::null_database::{AssistantConvKey, RoutineConvKey}; + +use super::NullDatabase; + +impl crate::db::NativeConversationStore for NullDatabase { + async fn create_conversation( + &self, + _channel: &str, + _user_id: &str, + _thread_id: Option<&str>, + ) -> Result { + self.next_synthetic_uuid() + } + + async fn touch_conversation(&self, _id: Uuid) -> Result<(), DatabaseError> { + Ok(()) + } + + async fn add_conversation_message( + &self, + _conversation_id: Uuid, + _role: &str, + _content: &str, + ) -> Result { + self.next_synthetic_uuid() + } + + async fn ensure_conversation( + &self, + _params: EnsureConversationParams<'_>, + ) -> Result<(), DatabaseError> { + Ok(()) + } + + async fn list_conversations_with_preview( + &self, + _user_id: &str, + _channel: &str, + _limit: usize, + ) -> Result, DatabaseError> { + Ok(vec![]) + } + + async fn list_conversations_all_channels( + &self, + _user_id: &str, + _limit: usize, + ) -> Result, DatabaseError> { + Ok(vec![]) + } + + async fn get_or_create_routine_conversation( + &self, + routine_id: Uuid, + _routine_name: &str, + user_id: &str, + ) -> Result { + let key = RoutineConvKey { + routine_id, + user_id: user_id.to_string(), + }; + self.get_or_create_in_cache(&self.routine_conv_cache, key) + } + + async fn get_or_create_heartbeat_conversation( + &self, + user_id: &str, + ) -> Result { + self.get_or_create_in_cache(&self.heartbeat_conv_cache, user_id.to_string()) + } + + async fn get_or_create_assistant_conversation( + &self, + user_id: &str, + channel: &str, + ) -> Result { + let key = AssistantConvKey { + user_id: user_id.to_string(), + channel: channel.to_string(), + }; + self.get_or_create_in_cache(&self.assistant_conv_cache, key) + } + + async fn create_conversation_with_metadata( + &self, + _channel: 
&str, + _user_id: &str, + _metadata: &serde_json::Value, + ) -> Result { + self.next_synthetic_uuid() + } + + async fn update_conversation_metadata_field( + &self, + _id: Uuid, + _key: &str, + _value: &serde_json::Value, + ) -> Result<(), DatabaseError> { + Ok(()) + } + + async fn get_conversation_metadata( + &self, + _id: Uuid, + ) -> Result, DatabaseError> { + Ok(None) + } + + async fn list_conversation_messages( + &self, + _conversation_id: Uuid, + ) -> Result, DatabaseError> { + Ok(vec![]) + } + + async fn list_conversation_messages_paginated( + &self, + _conversation_id: Uuid, + _before: Option<(DateTime, Uuid)>, + _limit: usize, + ) -> Result<(Vec, bool), DatabaseError> { + Ok((vec![], false)) + } + + async fn conversation_belongs_to_user( + &self, + _conversation_id: Uuid, + _user_id: &str, + ) -> Result { + Ok(false) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::db::NativeConversationStore; + + #[tokio::test] + async fn test_get_or_create_routine_conversation_returns_stable_uuid() { + let db = NullDatabase::new(); + let routine_id = Uuid::new_v4(); + + let uuid1 = db + .get_or_create_routine_conversation(routine_id, "test_routine", "user1") + .await + .expect( + "first get_or_create_routine_conversation for test_routine user1 should succeed", + ); + let uuid2 = db + .get_or_create_routine_conversation(routine_id, "test_routine", "user1") + .await + .expect( + "second get_or_create_routine_conversation for test_routine user1 should succeed", + ); + + assert_eq!(uuid1, uuid2, "Same inputs should return same UUID"); + + // Different routine_name but same routine_id should return same UUID (singleton semantics) + let uuid3 = db + .get_or_create_routine_conversation(routine_id, "different_routine", "user1") + .await + .expect("get_or_create_routine_conversation with different routine_name for user1 should succeed"); + assert_eq!( + uuid1, uuid3, + "Same routine_id should return same UUID regardless of routine_name" + ); + + let uuid4 = db + 
.get_or_create_routine_conversation(Uuid::new_v4(), "test_routine", "user1") + .await + .expect("get_or_create_routine_conversation with different routine_id for user1 should succeed"); + assert_ne!( + uuid1, uuid4, + "Different routine_id should return different UUID" + ); + + let uuid5 = db + .get_or_create_routine_conversation(routine_id, "test_routine", "user2") + .await + .expect("get_or_create_routine_conversation for user2 should succeed"); + assert_ne!( + uuid1, uuid5, + "Different user_id should return different UUID" + ); + } + + #[tokio::test] + async fn test_get_or_create_heartbeat_conversation_returns_stable_uuid() { + let db = NullDatabase::new(); + + let uuid1 = db + .get_or_create_heartbeat_conversation("user1") + .await + .expect("first get_or_create_heartbeat_conversation for user1 should succeed"); + let uuid2 = db + .get_or_create_heartbeat_conversation("user1") + .await + .expect("second get_or_create_heartbeat_conversation for user1 should succeed"); + + assert_eq!(uuid1, uuid2, "Same user_id should return same UUID"); + + // Different user should return different UUID + let uuid3 = db + .get_or_create_heartbeat_conversation("user2") + .await + .expect("get_or_create_heartbeat_conversation for user2 should succeed"); + assert_ne!( + uuid1, uuid3, + "Different user_id should return different UUID" + ); + } + + #[tokio::test] + async fn test_get_or_create_assistant_conversation_returns_stable_uuid() { + let db = NullDatabase::new(); + + let uuid1 = db + .get_or_create_assistant_conversation("user1", "slack") + .await + .expect("first get_or_create_assistant_conversation for user1 slack should succeed"); + let uuid2 = db + .get_or_create_assistant_conversation("user1", "slack") + .await + .expect("second get_or_create_assistant_conversation for user1 slack should succeed"); + + assert_eq!(uuid1, uuid2, "Same inputs should return same UUID"); + + // Different inputs should return different UUIDs + let uuid3 = db + 
.get_or_create_assistant_conversation("user2", "slack") + .await + .expect("get_or_create_assistant_conversation for user2 slack should succeed"); + assert_ne!( + uuid1, uuid3, + "Different user_id should return different UUID" + ); + + let uuid4 = db + .get_or_create_assistant_conversation("user1", "discord") + .await + .expect("get_or_create_assistant_conversation for user1 discord should succeed"); + assert_ne!( + uuid1, uuid4, + "Different channel should return different UUID" + ); + } +} diff --git a/src/testing/null_db/null_database/job_store.rs b/src/testing/null_db/null_database/job_store.rs new file mode 100644 index 000000000..aa2ada46a --- /dev/null +++ b/src/testing/null_db/null_database/job_store.rs @@ -0,0 +1,156 @@ +//! Null implementation of NativeJobStore for NullDatabase. + +use uuid::Uuid; + +use crate::context::{ActionRecord, JobContext}; +use crate::db::{EstimationActualsParams, EstimationSnapshotParams}; +use crate::error::DatabaseError; +use crate::history::{AgentJobRecord, AgentJobSummary, LlmCallRecord}; + +use super::NullDatabase; + +impl crate::db::NativeJobStore for NullDatabase { + async fn save_job(&self, _ctx: &JobContext) -> Result<(), DatabaseError> { + Ok(()) + } + + async fn get_job(&self, _id: Uuid) -> Result, DatabaseError> { + Ok(None) + } + + async fn update_job_status( + &self, + _id: Uuid, + _status: crate::context::JobState, + _failure_reason: Option<&str>, + ) -> Result<(), DatabaseError> { + Ok(()) + } + + async fn mark_job_stuck(&self, _id: Uuid) -> Result<(), DatabaseError> { + Ok(()) + } + + async fn get_stuck_jobs(&self) -> Result, DatabaseError> { + Ok(vec![]) + } + + async fn list_agent_jobs(&self) -> Result, DatabaseError> { + Ok(vec![]) + } + + async fn agent_job_summary(&self) -> Result { + Ok(AgentJobSummary::default()) + } + + async fn get_agent_job_failure_reason( + &self, + _id: Uuid, + ) -> Result, DatabaseError> { + Ok(None) + } + + async fn save_action( + &self, + _job_id: Uuid, + _action: &ActionRecord, + 
) -> Result<(), DatabaseError> { + Ok(()) + } + + async fn get_job_actions(&self, _job_id: Uuid) -> Result, DatabaseError> { + Ok(vec![]) + } + + async fn record_llm_call(&self, _record: &LlmCallRecord<'_>) -> Result { + self.next_synthetic_uuid() + } + + async fn save_estimation_snapshot( + &self, + _params: EstimationSnapshotParams<'_>, + ) -> Result { + self.next_synthetic_uuid() + } + + async fn update_estimation_actuals( + &self, + _params: EstimationActualsParams, + ) -> Result<(), DatabaseError> { + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::db::NativeJobStore; + use crate::history::LlmCallRecord; + + #[test] + fn test_synthetic_uuid_is_deterministic() { + let db = NullDatabase::new(); + + let uuid1 = db + .next_synthetic_uuid() + .expect("first synthetic UUID generation should succeed"); + let uuid2 = db + .next_synthetic_uuid() + .expect("second synthetic UUID generation should succeed"); + let uuid3 = db + .next_synthetic_uuid() + .expect("third synthetic UUID generation should succeed"); + + // UUIDs should be sequential and unique + assert_ne!(uuid1, uuid2); + assert_ne!(uuid2, uuid3); + assert_ne!(uuid1, uuid3); + + // Each call should increment the counter + let bytes1 = uuid1.as_bytes(); + let bytes2 = uuid2.as_bytes(); + let bytes3 = uuid3.as_bytes(); + + // Convert all 16 UUID bytes back to u128 (big endian) + let n1 = u128::from_be_bytes(*bytes1); + let n2 = u128::from_be_bytes(*bytes2); + let n3 = u128::from_be_bytes(*bytes3); + + assert_eq!(n1 + 1, n2, "Second UUID should be one greater than first"); + assert_eq!(n2 + 1, n3, "Third UUID should be one greater than second"); + } + + #[tokio::test] + async fn test_record_llm_call_returns_deterministic_uuids() { + use rust_decimal::Decimal; + + let db = NullDatabase::new(); + + let record = LlmCallRecord { + job_id: Some(Uuid::nil()), + conversation_id: None, + provider: "test_provider", + model: "test", + input_tokens: 10, + output_tokens: 20, + cost: Decimal::ZERO, + 
purpose: Some("test"), + }; + + let uuid1 = db + .record_llm_call(&record) + .await + .expect("record_llm_call failed for uuid1"); + let uuid2 = db + .record_llm_call(&record) + .await + .expect("record_llm_call failed for uuid2"); + + assert_ne!(uuid1, uuid2, "Each call should return a new UUID"); + + // Verify they are sequential + let n1 = u128::from_be_bytes(*uuid1.as_bytes()); + let n2 = u128::from_be_bytes(*uuid2.as_bytes()); + assert_eq!(n1 + 1, n2, "UUIDs should be sequential"); + } +} diff --git a/src/testing/null_db/null_database/routine_store.rs b/src/testing/null_db/null_database/routine_store.rs new file mode 100644 index 000000000..b561a023d --- /dev/null +++ b/src/testing/null_db/null_database/routine_store.rs @@ -0,0 +1,89 @@ +//! Null implementation of NativeRoutineStore for NullDatabase. + +use uuid::Uuid; + +use crate::agent::{Routine, routine::RoutineRun}; +use crate::db::RoutineRuntimeUpdate; +use crate::error::DatabaseError; + +use super::NullDatabase; + +impl crate::db::NativeRoutineStore for NullDatabase { + async fn create_routine(&self, _routine: &Routine) -> Result<(), DatabaseError> { + Ok(()) + } + + async fn get_routine(&self, _id: Uuid) -> Result, DatabaseError> { + Ok(None) + } + + async fn get_routine_by_name( + &self, + _user_id: &str, + _name: &str, + ) -> Result, DatabaseError> { + Ok(None) + } + + async fn list_routines(&self, _user_id: &str) -> Result, DatabaseError> { + Ok(vec![]) + } + + async fn list_all_routines(&self) -> Result, DatabaseError> { + Ok(vec![]) + } + + async fn update_routine(&self, _routine: &Routine) -> Result<(), DatabaseError> { + Ok(()) + } + + async fn delete_routine(&self, _id: Uuid) -> Result { + Ok(false) + } + + async fn update_routine_runtime( + &self, + _update: RoutineRuntimeUpdate<'_>, + ) -> Result<(), DatabaseError> { + Ok(()) + } + + async fn create_routine_run(&self, _run: &RoutineRun) -> Result<(), DatabaseError> { + Ok(()) + } + + async fn list_routine_runs( + &self, + _routine_id: Uuid, + 
_limit: i64, + ) -> Result, DatabaseError> { + Ok(vec![]) + } + + async fn complete_routine_run( + &self, + _completion: crate::db::RoutineRunCompletion<'_>, + ) -> Result<(), DatabaseError> { + Ok(()) + } + + async fn list_event_routines(&self) -> Result, DatabaseError> { + Ok(vec![]) + } + + async fn list_due_cron_routines(&self) -> Result, DatabaseError> { + Ok(vec![]) + } + + async fn count_running_routine_runs(&self, _routine_id: Uuid) -> Result { + Ok(0) + } + + async fn link_routine_run_to_job( + &self, + _run_id: Uuid, + _job_id: Uuid, + ) -> Result<(), DatabaseError> { + Ok(()) + } +} diff --git a/src/testing/null_db/null_database/sandbox_store.rs b/src/testing/null_db/null_database/sandbox_store.rs new file mode 100644 index 000000000..cc554397b --- /dev/null +++ b/src/testing/null_db/null_database/sandbox_store.rs @@ -0,0 +1,90 @@ +//! Null implementation of NativeSandboxStore for NullDatabase. + +use uuid::Uuid; + +use crate::db::{SandboxEventType, SandboxJobStatusUpdate, SandboxMode, UserId}; +use crate::error::DatabaseError; +use crate::history::{JobEventRecord, SandboxJobRecord, SandboxJobSummary}; + +use super::NullDatabase; + +impl crate::db::NativeSandboxStore for NullDatabase { + async fn save_sandbox_job(&self, _job: &SandboxJobRecord) -> Result<(), DatabaseError> { + Ok(()) + } + + async fn get_sandbox_job(&self, _id: Uuid) -> Result, DatabaseError> { + Ok(None) + } + + async fn list_sandbox_jobs(&self) -> Result, DatabaseError> { + Ok(vec![]) + } + + async fn update_sandbox_job_status( + &self, + _params: SandboxJobStatusUpdate<'_>, + ) -> Result<(), DatabaseError> { + Ok(()) + } + + async fn cleanup_stale_sandbox_jobs(&self) -> Result { + Ok(0) + } + + async fn sandbox_job_summary(&self) -> Result { + Ok(SandboxJobSummary::default()) + } + + async fn list_sandbox_jobs_for_user( + &self, + _user_id: UserId, + ) -> Result, DatabaseError> { + Ok(vec![]) + } + + async fn sandbox_job_summary_for_user( + &self, + _user_id: UserId, + ) -> Result { + 
Ok(SandboxJobSummary::default()) + } + + async fn sandbox_job_belongs_to_user( + &self, + _job_id: Uuid, + _user_id: UserId, + ) -> Result { + Ok(false) + } + + async fn update_sandbox_job_mode( + &self, + _id: Uuid, + _mode: SandboxMode, + ) -> Result<(), DatabaseError> { + Ok(()) + } + + async fn get_sandbox_job_mode(&self, _id: Uuid) -> Result, DatabaseError> { + Ok(None) + } + + async fn save_job_event( + &self, + _job_id: Uuid, + _event_type: SandboxEventType, + _data: &serde_json::Value, + ) -> Result<(), DatabaseError> { + Ok(()) + } + + async fn list_job_events( + &self, + _job_id: Uuid, + _before_id: Option, + _limit: Option, + ) -> Result, DatabaseError> { + Ok(vec![]) + } +} diff --git a/src/testing/null_db/null_database/settings_store.rs b/src/testing/null_db/null_database/settings_store.rs new file mode 100644 index 000000000..f89c1bb6d --- /dev/null +++ b/src/testing/null_db/null_database/settings_store.rs @@ -0,0 +1,67 @@ +//! Null implementation of NativeSettingsStore for NullDatabase. 
+ +use std::collections::HashMap; + +use crate::db::{SettingKey, UserId}; +use crate::error::DatabaseError; +use crate::history::SettingRow; + +use super::NullDatabase; + +impl crate::db::NativeSettingsStore for NullDatabase { + async fn get_setting( + &self, + _user_id: UserId, + _key: SettingKey, + ) -> Result, DatabaseError> { + Ok(None) + } + + async fn get_setting_full( + &self, + _user_id: UserId, + _key: SettingKey, + ) -> Result, DatabaseError> { + Ok(None) + } + + async fn delete_setting( + &self, + _user_id: UserId, + _key: SettingKey, + ) -> Result { + Ok(false) + } + + async fn list_settings(&self, _user_id: UserId) -> Result, DatabaseError> { + Ok(vec![]) + } + + async fn set_setting( + &self, + _user_id: UserId, + _key: SettingKey, + _value: &serde_json::Value, + ) -> Result<(), DatabaseError> { + Ok(()) + } + + async fn get_all_settings( + &self, + _user_id: UserId, + ) -> Result, DatabaseError> { + Ok(HashMap::new()) + } + + async fn set_all_settings( + &self, + _user_id: UserId, + _settings: &HashMap, + ) -> Result<(), DatabaseError> { + Ok(()) + } + + async fn has_settings(&self, _user_id: UserId) -> Result { + Ok(false) + } +} diff --git a/src/testing/null_db/null_database/tool_failure_store.rs b/src/testing/null_db/null_database/tool_failure_store.rs new file mode 100644 index 000000000..d17a95ddd --- /dev/null +++ b/src/testing/null_db/null_database/tool_failure_store.rs @@ -0,0 +1,28 @@ +//! Null implementation of NativeToolFailureStore for NullDatabase. 
+ +use crate::agent::BrokenTool; +use crate::error::DatabaseError; + +use super::NullDatabase; + +impl crate::db::NativeToolFailureStore for NullDatabase { + async fn record_tool_failure( + &self, + _tool_name: &str, + _error: &str, + ) -> Result<(), DatabaseError> { + Ok(()) + } + + async fn get_broken_tools(&self, _threshold: i32) -> Result, DatabaseError> { + Ok(vec![]) + } + + async fn mark_tool_repaired(&self, _tool_name: &str) -> Result<(), DatabaseError> { + Ok(()) + } + + async fn increment_repair_attempts(&self, _tool_name: &str) -> Result<(), DatabaseError> { + Ok(()) + } +} diff --git a/src/testing/null_db/null_database/workspace_store.rs b/src/testing/null_db/null_database/workspace_store.rs new file mode 100644 index 000000000..ec12472ad --- /dev/null +++ b/src/testing/null_db/null_database/workspace_store.rs @@ -0,0 +1,112 @@ +//! Null implementation of NativeWorkspaceStore for NullDatabase. + +use uuid::Uuid; + +use crate::db::{HybridSearchParams, InsertChunkParams}; +use crate::error::WorkspaceError; +use crate::workspace::{ + MemoryChunk as WorkspaceMemoryChunk, MemoryDocument as WorkspaceMemoryDocument, + SearchResult as WorkspaceSearchResult, WorkspaceEntry as WorkspaceWorkspaceEntry, +}; + +use super::NullDatabase; + +impl crate::db::NativeWorkspaceStore for NullDatabase { + async fn get_document_by_path( + &self, + _user_id: &str, + _agent_id: Option, + _path: &str, + ) -> Result { + Err(NullDatabase::doc_not_found("file")) + } + + async fn get_document_by_id( + &self, + _id: Uuid, + ) -> Result { + Err(NullDatabase::doc_not_found("id")) + } + + async fn get_or_create_document_by_path( + &self, + _user_id: &str, + _agent_id: Option, + _path: &str, + ) -> Result { + Err(NullDatabase::doc_not_found("file")) + } + + async fn update_document(&self, _id: Uuid, _content: &str) -> Result<(), WorkspaceError> { + Ok(()) + } + + async fn delete_document_by_path( + &self, + _user_id: &str, + _agent_id: Option, + _path: &str, + ) -> Result<(), 
WorkspaceError> { + Ok(()) + } + + async fn list_directory( + &self, + _user_id: &str, + _agent_id: Option, + _directory: &str, + ) -> Result, WorkspaceError> { + Ok(vec![]) + } + + async fn list_all_paths( + &self, + _user_id: &str, + _agent_id: Option, + ) -> Result, WorkspaceError> { + Ok(vec![]) + } + + async fn list_documents( + &self, + _user_id: &str, + _agent_id: Option, + ) -> Result, WorkspaceError> { + Ok(vec![]) + } + + async fn delete_chunks(&self, _document_id: Uuid) -> Result<(), WorkspaceError> { + Ok(()) + } + + async fn insert_chunk(&self, _params: InsertChunkParams<'_>) -> Result { + self.next_synthetic_uuid() + .map_err(|err| WorkspaceError::IoError { + reason: err.to_string(), + }) + } + + async fn update_chunk_embedding( + &self, + _chunk_id: Uuid, + _embedding: &[f32], + ) -> Result<(), WorkspaceError> { + Ok(()) + } + + async fn get_chunks_without_embeddings( + &self, + _user_id: &str, + _agent_id: Option, + _limit: usize, + ) -> Result, WorkspaceError> { + Ok(vec![]) + } + + async fn hybrid_search( + &self, + _params: HybridSearchParams<'_>, + ) -> Result, WorkspaceError> { + Ok(vec![]) + } +} diff --git a/src/testing/settings_tests.rs b/src/testing/settings_tests.rs index b450c0554..effd3b58d 100644 --- a/src/testing/settings_tests.rs +++ b/src/testing/settings_tests.rs @@ -1,9 +1,11 @@ //! libSQL settings-store regression tests for the shared test harness. 
+#[cfg(all(feature = "libsql", feature = "test-helpers"))] use super::*; +#[cfg(all(feature = "libsql", feature = "test-helpers"))] use rstest::rstest; -#[cfg(feature = "libsql")] +#[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn test_settings_crud() { let harness = TestHarnessBuilder::new().build().await; @@ -66,7 +68,7 @@ async fn test_settings_crud() { assert!(!deleted); } -#[cfg(feature = "libsql")] +#[cfg(all(feature = "libsql", feature = "test-helpers"))] async fn run_settings_crud_flow(db: &Arc, user_id: UserId, key: SettingKey) { let initial_value = serde_json::json!("dark"); let updated_value = serde_json::json!("light"); @@ -114,7 +116,7 @@ async fn run_settings_crud_flow(db: &Arc, user_id: UserId, key: Se ); } -#[cfg(feature = "libsql")] +#[cfg(all(feature = "libsql", feature = "test-helpers"))] #[rstest] #[case(false)] #[case(true)] @@ -135,7 +137,7 @@ async fn test_settings_crud_variants(#[case] use_owned_strings: bool) { run_settings_crud_flow(db, user_id, key).await; } -#[cfg(feature = "libsql")] +#[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] async fn test_settings_bulk_operations() { let harness = TestHarnessBuilder::new().build().await; diff --git a/src/testing/snapshots/ironclaw__testing__worker_harness__terminal_persistence_result_completed.snap b/src/testing/snapshots/ironclaw__testing__worker_harness__terminal_persistence_result_completed.snap new file mode 100644 index 000000000..e11494bf6 --- /dev/null +++ b/src/testing/snapshots/ironclaw__testing__worker_harness__terminal_persistence_result_completed.snap @@ -0,0 +1,10 @@ +--- +source: src/worker/job.rs +assertion_line: 2077 +expression: "&event_call.data" +--- +{ + "message": "Job completed successfully", + "status": "completed", + "success": true +} diff --git a/src/testing/snapshots/ironclaw__testing__worker_harness__terminal_persistence_result_failed.snap 
b/src/testing/snapshots/ironclaw__testing__worker_harness__terminal_persistence_result_failed.snap new file mode 100644 index 000000000..732d3334d --- /dev/null +++ b/src/testing/snapshots/ironclaw__testing__worker_harness__terminal_persistence_result_failed.snap @@ -0,0 +1,10 @@ +--- +source: src/worker/job.rs +assertion_line: 2077 +expression: "&event_call.data" +--- +{ + "message": "Execution failed: budget exceeded", + "status": "failed", + "success": false +} diff --git a/src/testing/snapshots/ironclaw__testing__worker_harness__terminal_persistence_result_stuck.snap b/src/testing/snapshots/ironclaw__testing__worker_harness__terminal_persistence_result_stuck.snap new file mode 100644 index 000000000..f7d3916b8 --- /dev/null +++ b/src/testing/snapshots/ironclaw__testing__worker_harness__terminal_persistence_result_stuck.snap @@ -0,0 +1,10 @@ +--- +source: src/worker/job.rs +assertion_line: 2077 +expression: "&event_call.data" +--- +{ + "message": "Job stuck: timeout", + "status": "stuck", + "success": false +} diff --git a/src/testing/worker_harness.rs b/src/testing/worker_harness.rs new file mode 100644 index 000000000..bc5c444ab --- /dev/null +++ b/src/testing/worker_harness.rs @@ -0,0 +1,339 @@ +//! Worker test harness for job module tests. +//! +//! Provides helpers for building workers with various configurations for testing. + +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Context as _; + +use crate::config::SafetyConfig; +use crate::context::{ContextManager, JobState}; +use crate::db::Database; +use crate::hooks::HookRegistry; +use crate::llm::{ + CompletionRequest, CompletionResponse, NativeLlmProvider, ToolCompletionRequest, + ToolCompletionResponse, +}; +use crate::safety::SafetyLayer; +use crate::testing::null_db::{CapturingStore, EventCall, StatusCall}; +use crate::tools::{ApprovalContext, Tool, ToolRegistry}; +use crate::worker::Worker; +use crate::worker::job::WorkerDeps; + +/// Stub LLM provider (never called in worker unit tests). 
+pub struct StubLlm; + +impl NativeLlmProvider for StubLlm { + fn model_name(&self) -> &str { + "stub" + } + fn cost_per_token(&self) -> (rust_decimal::Decimal, rust_decimal::Decimal) { + (rust_decimal::Decimal::ZERO, rust_decimal::Decimal::ZERO) + } + async fn complete( + &self, + _req: CompletionRequest, + ) -> Result { + // Return a deterministic stub response instead of panicking. + // This allows tests that construct a Worker to run without + // hitting unimplemented! if the LLM path is accidentally exercised. + Ok(CompletionResponse { + content: "stub response".to_string(), + input_tokens: 0, + output_tokens: 0, + finish_reason: crate::llm::FinishReason::Stop, + cache_read_input_tokens: 0, + cache_creation_input_tokens: 0, + }) + } + async fn complete_with_tools( + &self, + _req: ToolCompletionRequest, + ) -> Result { + // Return a deterministic stub response instead of panicking. + Ok(ToolCompletionResponse { + content: None, + tool_calls: vec![], + input_tokens: 0, + output_tokens: 0, + finish_reason: crate::llm::FinishReason::Stop, + cache_read_input_tokens: 0, + cache_creation_input_tokens: 0, + }) + } +} + +/// Build a ToolRegistry containing the given tools. +pub async fn build_registry(tools: Vec>) -> ToolRegistry { + let registry = ToolRegistry::new(); + for tool in tools { + registry.register(tool).await; + } + registry +} + +/// Build WorkerDeps with the given components. +pub fn base_deps( + cm: Arc, + tools: Arc, + store: Option>, + approval_context: Option, +) -> WorkerDeps { + WorkerDeps { + context_manager: cm, + llm: Arc::new(StubLlm), + safety: Arc::new(SafetyLayer::new(&SafetyConfig { + max_output_length: 100_000, + injection_check_enabled: false, + })), + tools, + store, + hooks: Arc::new(HookRegistry::new()), + timeout: Duration::from_secs(30), + use_planning: false, + sse_tx: None, + approval_context, + http_interceptor: None, + } +} + +/// Build a Worker wired to a ToolRegistry containing the given tools. 
+pub async fn make_worker(tools: Vec>) -> anyhow::Result { + let registry = Arc::new(build_registry(tools).await); + let cm = Arc::new(ContextManager::new(5)); + let job_id = cm + .create_job("test", "test job") + .await + .context("make_worker: create_job failed")?; + let deps = base_deps(cm, registry, None, None); + + Ok(Worker::new(job_id, deps)) +} + +/// Build a Worker with a real database store (libsql feature required). +#[cfg(all(feature = "libsql", feature = "test-helpers"))] +pub async fn make_worker_with_store( + tools: Vec>, +) -> anyhow::Result<(Worker, Arc, tempfile::TempDir)> { + use crate::db::libsql::LibSqlBackend; + use tempfile::tempdir; + + let registry = Arc::new(build_registry(tools).await); + let cm = Arc::new(ContextManager::new(5)); + let job_id = cm + .create_job("test", "test job") + .await + .context("make_worker_with_store: create_job failed")?; + let dir = tempdir()?; + let path = dir.path().join("worker-test.db"); + let backend = LibSqlBackend::new_local(&path) + .await + .context("make_worker_with_store: LibSqlBackend::new_local failed")?; + backend + .run_migrations() + .await + .context("make_worker_with_store: run_migrations failed")?; + let store: Arc = Arc::new(backend); + let ctx = cm + .get_context(job_id) + .await + .context("make_worker_with_store: get_context failed")?; + store + .save_job(&ctx) + .await + .context("make_worker_with_store: save_job failed")?; + let deps = base_deps(cm, registry, Some(store.clone()), None); + + Ok((Worker::new(job_id, deps), store, dir)) +} + +/// Build a Worker with a capturing store for characterisation tests. 
+pub async fn make_worker_with_capturing_store( + tools: Vec>, +) -> anyhow::Result<(Worker, Arc)> { + let registry = Arc::new(build_registry(tools).await); + let cm = Arc::new(ContextManager::new(5)); + let job_id = cm + .create_job("test", "test job") + .await + .context("make_worker_with_capturing_store: create_job failed")?; + + let store = Arc::new(CapturingStore::new()); + let store_dyn: Arc = store.clone(); + let deps = base_deps(cm, registry, Some(store_dyn), None); + + Ok((Worker::new(job_id, deps), store)) +} + +/// Transition a worker's job to InProgress state. +pub async fn transition_to_in_progress(worker: &Worker) -> anyhow::Result<()> { + use crate::context::JobContext; + worker + .context_manager() + .update_context(worker.job_id, |ctx: &mut JobContext| { + ctx.transition_to(JobState::InProgress, None) + }) + .await + .context("transition_to_in_progress: update_context failed")? + .map_err(|s| anyhow::anyhow!("context transition failed: {s}"))?; + Ok(()) +} + +/// Bundles the expected terminal-state outcome for persistence assertions. +pub struct TerminalPersistenceExpectation<'a> { + pub state: JobState, + pub status_str: &'a str, + pub success: bool, + pub message: Option, + pub reason: Option<&'a str>, +} + +fn terminal_event_message( + expected_state: JobState, + expected_reason: Option<&str>, +) -> Option { + match (expected_state, expected_reason) { + (JobState::Completed, _) => Some("Job completed successfully".to_string()), + (JobState::Failed, Some(reason)) => Some(format!("Execution failed: {reason}")), + (JobState::Stuck, Some(reason)) => Some(format!("Job stuck: {reason}")), + _ => None, + } +} + +/// Check captured persistence calls against expected values. 
+pub fn check_terminal_persistence_calls( + status_call: &StatusCall, + event_call: &EventCall, + expected: &TerminalPersistenceExpectation<'_>, +) { + assert_eq!(status_call.status, expected.state); + if let Some(reason) = expected.reason { + assert_eq!(status_call.reason.as_deref(), Some(reason)); + } else { + assert!( + status_call.reason.is_none(), + "Expected no failure reason, but got {:?}", + status_call.reason + ); + } + assert_eq!(event_call.event_type, "result"); + assert_eq!(event_call.data["status"], expected.status_str); + assert_eq!(event_call.data["success"], expected.success); + if let Some(message) = &expected.message { + assert_eq!(event_call.data["message"], message.as_str()); + } else { + assert!( + event_call.data["message"].is_null(), + "Expected no event message, but got {:?}", + event_call.data["message"] + ); + } +} + +/// Assert terminal persistence state matches expected values. +pub async fn assert_terminal_persistence( + store: &CapturingStore, + expected_state: JobState, + expected_status_str: &str, + expected_reason: Option<&str>, +) { + let status_call = store + .calls() + .last_status + .lock() + .await + .clone() + .expect("expected a status update"); + let event_call = store + .calls() + .last_event + .lock() + .await + .clone() + .expect("expected a job event"); + check_terminal_persistence_calls( + &status_call, + &event_call, + &TerminalPersistenceExpectation { + state: expected_state, + status_str: expected_status_str, + success: expected_state == JobState::Completed, + message: terminal_event_message(expected_state, expected_reason), + reason: expected_reason, + }, + ); +} + +/// Assert terminal persistence state with snapshot testing. 
+pub async fn assert_terminal_persistence_with_snapshot( + store: &CapturingStore, + expected_state: JobState, + expected_status_str: &str, + expected_reason: Option<&str>, +) { + let status_call = store + .calls() + .last_status + .lock() + .await + .clone() + .expect("expected a status update"); + let event_call = store + .calls() + .last_event + .lock() + .await + .clone() + .expect("expected a job event"); + check_terminal_persistence_calls( + &status_call, + &event_call, + &TerminalPersistenceExpectation { + state: expected_state, + status_str: expected_status_str, + success: expected_state == JobState::Completed, + message: terminal_event_message(expected_state, expected_reason), + reason: expected_reason, + }, + ); + insta::assert_json_snapshot!( + format!("terminal_persistence_result_{}", expected_status_str), + &event_call.data + ); +} + +/// Methods for driving terminal state transitions in tests. +#[derive(Debug, Clone, Copy)] +pub enum TerminalMethod { + Completed, + Failed(&'static str), + Stuck(&'static str), +} + +impl TerminalMethod { + /// Apply this terminal transition to a worker. + pub async fn apply_transition(&self, worker: &Worker) -> anyhow::Result<()> { + match self { + TerminalMethod::Completed => { + worker + .mark_completed() + .await + .context("apply_transition: mark_completed failed")?; + } + TerminalMethod::Failed(reason) => { + worker + .mark_failed(reason) + .await + .context("apply_transition: mark_failed failed")?; + } + TerminalMethod::Stuck(reason) => { + worker + .mark_stuck(reason) + .await + .context("apply_transition: mark_stuck failed")?; + } + } + Ok(()) + } +} diff --git a/src/worker/api/types.rs b/src/worker/api/types.rs index 3a2bd9eb8..dca0177d6 100644 --- a/src/worker/api/types.rs +++ b/src/worker/api/types.rs @@ -211,7 +211,7 @@ pub struct JobEventPayload { } /// Response from the prompt polling endpoint. 
-#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] pub struct PromptResponse { pub content: String, #[serde(default)] diff --git a/src/worker/job.rs b/src/worker/job.rs index f97893a34..40ff00087 100644 --- a/src/worker/job.rs +++ b/src/worker/job.rs @@ -19,7 +19,7 @@ use crate::agent::scheduler::WorkerMessage; use crate::agent::task::TaskOutput; use crate::channels::web::types::SseEvent; use crate::context::{ContextManager, JobState}; -use crate::db::Database; +use crate::db::{Database, TerminalJobPersistence}; use crate::error::Error; use crate::hooks::HookRegistry; use crate::llm::{ @@ -56,8 +56,16 @@ pub struct WorkerDeps { } /// Worker that executes a single job. +/// +/// The scheduler and worker-oriented unit tests own this type. It coordinates +/// in-memory job state, tool execution, and terminal persistence for one job. pub struct Worker { - job_id: Uuid, + /// Stable job identifier exposed to internal callers and unit tests. + /// + /// Callers use this to correlate scheduler state, context-manager lookups, + /// and persistence assertions. Reading this field has no side effects and + /// does not itself make any state durable. + pub(crate) job_id: Uuid, deps: WorkerDeps, } @@ -79,7 +87,13 @@ impl Worker { } // Convenience accessors to avoid deps.field everywhere - fn context_manager(&self) -> &Arc { + /// Return the shared context manager for this worker's job. + /// + /// Internal crates and unit tests use this accessor to inspect or prepare + /// the in-memory job state before driving the worker. This is a pure + /// accessor: it does not persist anything and requires no rollback by the + /// caller. + pub(crate) fn context_manager(&self) -> &Arc { &self.deps.context_manager } @@ -108,21 +122,6 @@ impl Worker { self.deps.use_planning } - /// Persist a terminal job status before returning to the caller. 
- async fn persist_status(&self, status: JobState, reason: Option) -> Result<(), Error> { - if let Some(store) = self.store() { - let job_id = self.job_id; - store - .update_job_status(job_id, status, reason.as_deref()) - .await - .map_err(|e| crate::error::JobError::PersistenceError { - id: job_id, - reason: e.to_string(), - })?; - } - Ok(()) - } - /// Fire-and-forget persistence and SSE broadcast for non-terminal job /// events only. /// @@ -150,16 +149,24 @@ impl Worker { self.broadcast_event(event_type, &data); } - /// Persist a terminal result event before returning to the caller. - async fn log_terminal_result_event( + /// Persist the terminal event and terminal status in one durable write. + async fn persist_terminal_result_and_status( &self, + status: JobState, + failure_reason: Option<&str>, event_type: &str, - data: serde_json::Value, + data: &serde_json::Value, ) -> Result<(), Error> { let job_id = self.job_id; if let Some(store) = self.store() { store - .save_job_event(job_id, crate::db::SandboxEventType::from(event_type), &data) + .persist_terminal_result_and_status(TerminalJobPersistence { + job_id, + status, + failure_reason, + event_type: crate::db::SandboxEventType::from(event_type), + event_data: data, + }) .await .map_err(|e| crate::error::JobError::PersistenceError { id: job_id, @@ -167,7 +174,7 @@ impl Worker { })?; } - self.broadcast_event(event_type, &data); + self.broadcast_event(event_type, data); Ok(()) } @@ -242,6 +249,38 @@ impl Worker { } } + async fn transition_terminal_state(&self, transition: F) -> Result + where + F: FnOnce(&mut crate::context::JobContext) -> Result<(), String>, + { + let previous = self + .context_manager() + .update_context(self.job_id, |ctx| { + let previous = ctx.state; + let result = if matches!( + previous, + JobState::Completed | JobState::Failed | JobState::Stuck | JobState::Cancelled + ) { + Err(format!( + "Cannot transition from terminal worker state {}", + previous + )) + } else { + transition(ctx) + }; + 
(previous, result) + }) + .await?; + + let (previous_state, transition_result) = previous; + transition_result.map_err(|reason| crate::error::JobError::ContextError { + id: self.job_id, + reason, + })?; + + Ok(previous_state) + } + /// Run the worker until the job is complete or stopped. pub async fn run(self, mut rx: mpsc::Receiver) -> Result<(), Error> { tracing::info!("Worker starting for job {}", self.job_id); @@ -998,83 +1037,130 @@ Report when the job is complete or if you encounter issues you cannot resolve."# Self::execute_tool_inner(&self.deps, self.job_id, tool_name, params).await } - async fn mark_completed(&self) -> Result<(), Error> { - self.context_manager() - .update_context(self.job_id, |ctx| { + /// Mark the job completed and durably persist that terminal outcome. + /// + /// Internal scheduler paths and worker unit tests call this once the job's + /// successful result is known. The method first moves the in-memory + /// [`JobContext`] to `Completed`, then attempts an atomic terminal + /// persistence write for the result event and job status. If persistence + /// fails, it performs a best-effort rollback to the previous in-memory + /// state before returning the error; callers do not need to issue an extra + /// rollback step, but they should treat the terminal outcome as not + /// durable. + pub(crate) async fn mark_completed(&self) -> Result<(), Error> { + self.apply_terminal_transition( + JobState::Completed, + Some("Job completed successfully"), + "completed", + "Job completed successfully".to_string(), + |ctx| { ctx.transition_to( JobState::Completed, Some("Job completed successfully".to_string()), ) - }) - .await? 
- .map_err(|s| crate::error::JobError::ContextError { - id: self.job_id, - reason: s, - })?; - - self.log_terminal_result_event( - "result", - serde_json::json!({ - "status": "completed", - "success": true, - "message": "Job completed successfully", - }), - ) - .await?; - self.persist_status( - JobState::Completed, - Some("Job completed successfully".to_string()), + }, + "mark_completed", ) - .await?; - Ok(()) + .await } - async fn mark_failed(&self, reason: &str) -> Result<(), Error> { - self.context_manager() - .update_context(self.job_id, |ctx| { - ctx.transition_to(JobState::Failed, Some(reason.to_string())) - }) - .await? - .map_err(|s| crate::error::JobError::ContextError { - id: self.job_id, - reason: s, - })?; + /// Roll back the context to the previous state on persistence failure. + async fn rollback_context(&self, previous: Option, operation: &str) { + if let Some(state) = previous { + match self + .context_manager() + .update_context(self.job_id, |ctx| { + ctx.set_state_rollback(state); + }) + .await + { + Ok(()) => { + tracing::error!( + job_id = %self.job_id, + operation, + "Rolled back context state after persistence failure" + ); + } + Err(e) => { + tracing::error!( + job_id = %self.job_id, + operation, + error = %e, + "Failed to roll back context state after persistence failure" + ); + } + } + } + } - self.log_terminal_result_event( - "result", - serde_json::json!({ - "status": "failed", - "success": false, - "message": format!("Execution failed: {}", reason), - }), - ) - .await?; - self.persist_status(JobState::Failed, Some(reason.to_string())) - .await?; + async fn apply_terminal_transition( + &self, + status: JobState, + reason: Option<&str>, + status_str: &str, + message: String, + transition: F, + op_name: &'static str, + ) -> Result<(), Error> + where + F: FnOnce(&mut crate::context::JobContext) -> Result<(), String>, + { + let previous = self.transition_terminal_state(transition).await?; + let event = serde_json::json!({ + "status": 
status_str, + "success": matches!(status, JobState::Completed), + "message": message, + }); + if let Err(e) = self + .persist_terminal_result_and_status(status, reason, "result", &event) + .await + { + self.rollback_context(Some(previous), op_name).await; + return Err(e); + } Ok(()) } - async fn mark_stuck(&self, reason: &str) -> Result<(), Error> { - self.context_manager() - .update_context(self.job_id, |ctx| ctx.mark_stuck(reason)) - .await? - .map_err(|s| crate::error::JobError::ContextError { - id: self.job_id, - reason: s, - })?; + /// Mark the job failed and durably persist the terminal failure. + /// + /// Internal scheduler paths and unit tests call this when execution has + /// reached a terminal error. The method updates the in-memory + /// [`JobContext`] to `Failed`, then attempts one atomic persistence write + /// for the terminal event and status. If that write fails, it best-effort + /// rolls the context back to the previous state before returning the + /// persistence error; callers should not perform additional rollback, but + /// must treat the failure as non-durable. + pub(crate) async fn mark_failed(&self, reason: &str) -> Result<(), Error> { + self.apply_terminal_transition( + JobState::Failed, + Some(reason), + "failed", + format!("Execution failed: {}", reason), + |ctx| ctx.transition_to(JobState::Failed, Some(reason.to_string())), + "mark_failed", + ) + .await + } - self.log_terminal_result_event( - "result", - serde_json::json!({ - "status": "stuck", - "success": false, - "message": format!("Job stuck: {}", reason), - }), + /// Mark the job stuck and durably persist the terminal stuck result. + /// + /// Internal scheduler timeout handling and unit tests call this when the + /// worker cannot make further progress. The method transitions the + /// in-memory [`JobContext`] to `Stuck`, then attempts one atomic terminal + /// persistence write for the result event and status. 
If persistence + /// fails, it best-effort rolls the context back to the prior state before + /// returning the error; callers do not need to clean up the context + /// themselves, but the stuck outcome should be treated as non-durable. + pub(crate) async fn mark_stuck(&self, reason: &str) -> Result<(), Error> { + self.apply_terminal_transition( + JobState::Stuck, + Some(reason), + "stuck", + format!("Job stuck: {}", reason), + |ctx| ctx.mark_stuck(reason), + "mark_stuck", ) - .await?; - self.persist_status(JobState::Stuck, Some(reason.to_string())) - .await?; - Ok(()) + .await } } @@ -1426,13 +1512,10 @@ mod tests { use std::sync::atomic::{AtomicUsize, Ordering}; use super::*; - use crate::config::SafetyConfig; use crate::context::JobContext; use crate::llm::ToolSelection; - use crate::llm::{ - CompletionRequest, CompletionResponse, ToolCompletionRequest, ToolCompletionResponse, - }; - use crate::safety::SafetyLayer; + use crate::testing::CapturingStore; + use crate::testing::worker_harness::*; use crate::tools::{NativeTool, Tool, ToolError as ToolExecError, ToolOutput}; /// A test tool that sleeps for a configurable duration before returning. @@ -1473,110 +1556,6 @@ mod tests { } } - /// Stub LLM provider (never called in these tests). - struct StubLlm; - - impl crate::llm::NativeLlmProvider for StubLlm { - fn model_name(&self) -> &str { - "stub" - } - fn cost_per_token(&self) -> (rust_decimal::Decimal, rust_decimal::Decimal) { - (rust_decimal::Decimal::ZERO, rust_decimal::Decimal::ZERO) - } - async fn complete( - &self, - _req: CompletionRequest, - ) -> Result { - unimplemented!("stub") - } - async fn complete_with_tools( - &self, - _req: ToolCompletionRequest, - ) -> Result { - unimplemented!("stub") - } - } - - /// Build a Worker wired to a ToolRegistry containing the given tools. 
- async fn make_worker(tools: Vec>) -> Worker { - let registry = ToolRegistry::new(); - for t in tools { - registry.register(t).await; - } - - let cm = Arc::new(crate::context::ContextManager::new(5)); - let job_id = cm.create_job("test", "test job").await.unwrap(); - - let deps = WorkerDeps { - context_manager: cm, - llm: Arc::new(StubLlm), - safety: Arc::new(SafetyLayer::new(&SafetyConfig { - max_output_length: 100_000, - injection_check_enabled: false, - })), - tools: Arc::new(registry), - store: None, - hooks: Arc::new(crate::hooks::HookRegistry::new()), - timeout: Duration::from_secs(30), - use_planning: false, - sse_tx: None, - approval_context: None, - http_interceptor: None, - }; - - Worker::new(job_id, deps) - } - - #[cfg(feature = "libsql")] - async fn make_worker_with_store( - tools: Vec>, - ) -> (Worker, Arc, tempfile::TempDir) { - use crate::db::libsql::LibSqlBackend; - use tempfile::tempdir; - - let registry = ToolRegistry::new(); - for t in tools { - registry.register(t).await; - } - - let cm = Arc::new(crate::context::ContextManager::new(5)); - let job_id = cm - .create_job("test", "test job") - .await - .expect("failed to create job"); - let dir = tempdir().expect("failed to create tempdir"); - let path = dir.path().join("worker-test.db"); - let backend = LibSqlBackend::new_local(&path) - .await - .expect("failed to open libsql backend"); - backend - .run_migrations() - .await - .expect("failed to run migrations"); - let store: Arc = Arc::new(backend); - let ctx = cm.get_context(job_id).await.expect("failed to get context"); - store.save_job(&ctx).await.expect("failed to save job"); - - let deps = WorkerDeps { - context_manager: cm, - llm: Arc::new(StubLlm), - safety: Arc::new(SafetyLayer::new(&SafetyConfig { - max_output_length: 100_000, - injection_check_enabled: false, - })), - tools: Arc::new(registry), - store: Some(store.clone()), - hooks: Arc::new(crate::hooks::HookRegistry::new()), - timeout: Duration::from_secs(30), - use_planning: false, 
- sse_tx: None, - approval_context: None, - http_interceptor: None, - }; - - (Worker::new(job_id, deps), store, dir) - } - #[test] fn test_tool_selection_preserves_call_id() { let selection = ToolSelection { @@ -1598,7 +1577,7 @@ mod tests { // See: test_completion_signals, test_completion_negative, etc. #[tokio::test] - async fn test_parallel_speedup() { + async fn test_parallel_speedup() -> Result<(), Box> { let current_active = Arc::new(AtomicUsize::new(0)); let max_active = Arc::new(AtomicUsize::new(0)); let tools: Vec> = (0..3) @@ -1612,7 +1591,7 @@ mod tests { }) .collect(); - let worker = make_worker(tools).await; + let worker = make_worker(tools).await?; let selections: Vec = (0..3) .map(|i| ToolSelection { @@ -1635,69 +1614,77 @@ mod tests { "Expected parallel tool execution to overlap, but max concurrency was {}", max_active.load(Ordering::SeqCst) ); + Ok(()) + } + + fn slow_tool( + name: &str, + delay_ms: u64, + current: &Arc, + max: &Arc, + ) -> Arc { + Arc::new(SlowTool { + tool_name: name.into(), + delay: Duration::from_millis(delay_ms), + current_active: Arc::clone(current), + max_active: Arc::clone(max), + }) + } + + fn tool_selection(name: &str, call_id: &str) -> ToolSelection { + ToolSelection { + tool_name: name.into(), + parameters: serde_json::json!({}), + reasoning: String::new(), + alternatives: vec![], + tool_call_id: call_id.into(), + } } #[tokio::test] - async fn test_result_ordering_preserved() { + async fn test_result_ordering_preserved() -> Result<(), Box> + { let current_active = Arc::new(AtomicUsize::new(0)); let max_active = Arc::new(AtomicUsize::new(0)); + let tools: Vec> = vec![ - Arc::new(SlowTool { - tool_name: "tool_a".into(), - delay: Duration::from_millis(300), - current_active: Arc::clone(¤t_active), - max_active: Arc::clone(&max_active), - }), - Arc::new(SlowTool { - tool_name: "tool_b".into(), - delay: Duration::from_millis(100), - current_active: Arc::clone(¤t_active), - max_active: Arc::clone(&max_active), - }), - 
Arc::new(SlowTool { - tool_name: "tool_c".into(), - delay: Duration::from_millis(200), - current_active: Arc::clone(¤t_active), - max_active: Arc::clone(&max_active), - }), + slow_tool("tool_a", 300, ¤t_active, &max_active), + slow_tool("tool_b", 100, ¤t_active, &max_active), + slow_tool("tool_c", 200, ¤t_active, &max_active), ]; - let worker = make_worker(tools).await; + let worker = make_worker(tools).await?; let selections = vec![ - ToolSelection { - tool_name: "tool_a".into(), - parameters: serde_json::json!({}), - reasoning: String::new(), - alternatives: vec![], - tool_call_id: "call_a".into(), - }, - ToolSelection { - tool_name: "tool_b".into(), - parameters: serde_json::json!({}), - reasoning: String::new(), - alternatives: vec![], - tool_call_id: "call_b".into(), - }, - ToolSelection { - tool_name: "tool_c".into(), - parameters: serde_json::json!({}), - reasoning: String::new(), - alternatives: vec![], - tool_call_id: "call_c".into(), - }, + tool_selection("tool_a", "call_a"), + tool_selection("tool_b", "call_b"), + tool_selection("tool_c", "call_c"), ]; let results = worker.execute_tools_parallel(&selections).await; - assert!(results[0].result.as_ref().unwrap().contains("done_tool_a")); - assert!(results[1].result.as_ref().unwrap().contains("done_tool_b")); - assert!(results[2].result.as_ref().unwrap().contains("done_tool_c")); + for (i, (result, expected)) in results + .iter() + .zip(["done_tool_a", "done_tool_b", "done_tool_c"]) + .enumerate() + { + let result_str = result + .result + .as_ref() + .unwrap_or_else(|_| panic!("tool {i} should return a captured result")) + .clone(); + assert!( + result_str.contains(expected), + "result[{i}] should contain '{expected}'", + ); + } + Ok(()) } #[tokio::test] - async fn test_missing_tool_produces_error_not_panic() { - let worker = make_worker(vec![]).await; + async fn test_missing_tool_produces_error_not_panic() + -> Result<(), Box> { + let worker = make_worker(vec![]).await?; let selections = vec![ToolSelection 
{ tool_name: "nonexistent_tool".into(), @@ -1713,11 +1700,13 @@ mod tests { results[0].result.is_err(), "Missing tool should produce an error, not a panic" ); + Ok(()) } #[tokio::test] - async fn test_mark_completed_twice_returns_error() { - let worker = make_worker(vec![]).await; + async fn test_mark_completed_twice_returns_error() + -> Result<(), Box> { + let worker = make_worker(vec![]).await?; worker .context_manager() @@ -1725,16 +1714,19 @@ mod tests { ctx.transition_to(JobState::InProgress, None) }) .await - .unwrap() - .unwrap(); + .expect("failed to update context before completion test") + .expect("failed to transition job to in-progress before completion test"); - worker.mark_completed().await.unwrap(); + worker + .mark_completed() + .await + .expect("failed to mark job completed in duplicate-completion test"); let ctx = worker .context_manager() .get_context(worker.job_id) .await - .unwrap(); + .expect("failed to reload job context after first completion"); assert_eq!(ctx.state, JobState::Completed); let result = worker.mark_completed().await; @@ -1742,12 +1734,14 @@ mod tests { result.is_err(), "Completed → Completed transition should be rejected by state machine" ); + Ok(()) } - #[cfg(feature = "libsql")] + #[cfg(all(feature = "libsql", feature = "test-helpers"))] #[tokio::test] - async fn test_mark_completed_persists_result_before_returning() { - let (worker, store, _dir) = make_worker_with_store(vec![]).await; + async fn test_mark_completed_persists_result_before_returning() + -> Result<(), Box> { + let (worker, store, _dir) = make_worker_with_store(vec![]).await?; worker .context_manager() @@ -1777,39 +1771,80 @@ mod tests { assert_eq!(events.len(), 1); assert_eq!(events[0].event_type, "result"); assert_eq!(events[0].data["status"], "completed"); + Ok(()) + } + + #[cfg(feature = "libsql")] + async fn make_worker_with_unpersisted_store( + tools: Vec>, + ) -> anyhow::Result<(Worker, tempfile::TempDir)> { + use crate::db::libsql::LibSqlBackend; + use 
tempfile::tempdir; + + let registry = Arc::new(build_registry(tools).await); + let cm = Arc::new(ContextManager::new(5)); + let job_id = cm.create_job("test", "test job").await?; + let dir = tempdir()?; + let path = dir.path().join("worker-test.db"); + let backend = LibSqlBackend::new_local(&path).await?; + backend.run_migrations().await?; + let store: Arc = Arc::new(backend); + let deps = base_deps(cm, registry, Some(store), None); + + Ok((Worker::new(job_id, deps), dir)) + } + + #[cfg(feature = "libsql")] + async fn assert_terminal_persistence_failure_rolls_back( + transition: TerminalMethod, + ) -> Result<(), Box> { + let (worker, _dir) = make_worker_with_unpersisted_store(vec![]).await?; + transition_to_in_progress(&worker).await?; + + let result = transition.apply_transition(&worker).await; + assert!(result.is_err(), "terminal persistence should fail"); + + let ctx = worker.context_manager().get_context(worker.job_id).await?; + assert_eq!( + ctx.state, + JobState::InProgress, + "persistence failure should roll context back to InProgress" + ); + Ok(()) + } + + #[cfg(feature = "libsql")] + #[tokio::test] + async fn test_mark_completed_rolls_back_context_when_persistence_fails() + -> Result<(), Box> { + assert_terminal_persistence_failure_rolls_back(TerminalMethod::Completed).await + } + + #[cfg(feature = "libsql")] + #[tokio::test] + async fn test_mark_failed_rolls_back_context_when_persistence_fails() + -> Result<(), Box> { + assert_terminal_persistence_failure_rolls_back(TerminalMethod::Failed("test failure")).await + } + + #[cfg(feature = "libsql")] + #[tokio::test] + async fn test_mark_stuck_rolls_back_context_when_persistence_fails() + -> Result<(), Box> { + assert_terminal_persistence_failure_rolls_back(TerminalMethod::Stuck("test stuck")).await } /// Build a Worker with the given approval context. 
async fn make_worker_with_approval( tools: Vec>, approval_context: Option, - ) -> Worker { - let registry = ToolRegistry::new(); - for t in tools { - registry.register(t).await; - } - + ) -> Result> { + let registry = Arc::new(build_registry(tools).await); let cm = Arc::new(crate::context::ContextManager::new(5)); - let job_id = cm.create_job("test", "test job").await.unwrap(); - - let deps = WorkerDeps { - context_manager: cm, - llm: Arc::new(StubLlm), - safety: Arc::new(SafetyLayer::new(&SafetyConfig { - max_output_length: 100_000, - injection_check_enabled: false, - })), - tools: Arc::new(registry), - store: None, - hooks: Arc::new(crate::hooks::HookRegistry::new()), - timeout: Duration::from_secs(30), - use_planning: false, - sse_tx: None, - approval_context, - http_interceptor: None, - }; + let job_id = cm.create_job("test", "test job").await?; + let deps = base_deps(cm, registry, None, approval_context); - Worker::new(job_id, deps) + Ok(Worker::new(job_id, deps)) } /// A tool that requires approval (UnlessAutoApproved). 
@@ -1881,8 +1916,9 @@ mod tests { } #[tokio::test] - async fn test_approval_context_unblocks_unless_auto_approved() { - let worker_blocked = make_worker_with_approval(vec![Arc::new(ApprovalTool)], None).await; + async fn test_approval_context_unblocks_unless_auto_approved() + -> Result<(), Box> { + let worker_blocked = make_worker_with_approval(vec![Arc::new(ApprovalTool)], None).await?; let result = worker_blocked .execute_tool("needs_approval", &serde_json::json!({})) .await; @@ -1895,20 +1931,22 @@ mod tests { vec![Arc::new(ApprovalTool)], Some(crate::tools::ApprovalContext::autonomous()), ) - .await; + .await?; let result = worker_allowed .execute_tool("needs_approval", &serde_json::json!({})) .await; assert!(result.is_ok(), "Should be allowed with autonomous context"); + Ok(()) } #[tokio::test] - async fn test_approval_context_blocks_always_unless_permitted() { + async fn test_approval_context_blocks_always_unless_permitted() + -> Result<(), Box> { let worker_blocked = make_worker_with_approval( vec![Arc::new(AlwaysApprovalTool)], Some(crate::tools::ApprovalContext::autonomous()), ) - .await; + .await?; let result = worker_blocked .execute_tool("always_approval", &serde_json::json!({})) .await; @@ -1923,7 +1961,7 @@ mod tests { "always_approval".to_string(), ])), ) - .await; + .await?; let result = worker_allowed .execute_tool("always_approval", &serde_json::json!({})) .await; @@ -1931,11 +1969,13 @@ mod tests { result.is_ok(), "Always tool should be allowed with permission" ); + Ok(()) } #[tokio::test] - async fn test_token_budget_exceeded_fails_job() { - let worker = make_worker(vec![]).await; + async fn test_token_budget_exceeded_fails_job() + -> Result<(), Box> { + let worker = make_worker(vec![]).await?; // Transition to InProgress (required for mark_failed) worker @@ -1944,8 +1984,8 @@ mod tests { ctx.transition_to(JobState::InProgress, None) }) .await - .unwrap() - .unwrap(); + .expect("failed to update context before token-budget failure test") + 
.expect("failed to transition job to in-progress before token-budget failure test"); // Set a token budget worker @@ -1954,14 +1994,14 @@ mod tests { ctx.max_tokens = 100; }) .await - .unwrap(); + .expect("failed to set max token budget for token-budget failure test"); // Simulate adding tokens that exceed the budget let budget_result = worker .context_manager() .update_context(worker.job_id, |ctx| ctx.add_tokens(200)) .await - .unwrap(); + .expect("failed to apply token usage for token-budget failure test"); assert!( budget_result.is_err(), @@ -1972,18 +2012,20 @@ mod tests { worker .mark_failed(&budget_result.unwrap_err()) .await - .unwrap(); + .expect("failed to mark job failed after token budget exceeded"); let ctx = worker .context_manager() .get_context(worker.job_id) .await - .unwrap(); + .expect("failed to reload job context after token-budget failure"); assert_eq!(ctx.state, JobState::Failed); + Ok(()) } #[tokio::test] - async fn test_iteration_cap_marks_failed_not_stuck() { - let worker = make_worker(vec![]).await; + async fn test_iteration_cap_marks_failed_not_stuck() + -> Result<(), Box> { + let worker = make_worker(vec![]).await?; // Transition to InProgress (required for mark_failed) worker @@ -1992,24 +2034,249 @@ mod tests { ctx.transition_to(JobState::InProgress, None) }) .await - .unwrap() - .unwrap(); + .expect("failed to update context before iteration-cap failure test") + .expect("failed to transition job to in-progress before iteration-cap failure test"); // Simulate what the execution loop does when max_iterations is exceeded worker .mark_failed("Maximum iterations exceeded: job hit the iteration cap") .await - .unwrap(); + .expect("failed to mark job failed after hitting the iteration cap"); let ctx = worker .context_manager() .get_context(worker.job_id) .await - .unwrap(); + .expect("failed to reload job context after iteration-cap failure"); assert_eq!( ctx.state, JobState::Failed, "Iteration cap should transition to Failed, not Stuck" ); 
+ Ok(()) + } + + // ----------------------------------------------------------------------- + // Terminal job-state persistence characterisation tests + // ----------------------------------------------------------------------- + + #[rstest::rstest] + #[case::completed( + TerminalTestCase { + method: TerminalMethod::Completed, + expected_state: JobState::Completed, + expected_status: "completed", + expected_reason: Some("Job completed successfully"), + } + )] + #[case::failed( + TerminalTestCase { + method: TerminalMethod::Failed("budget exceeded"), + expected_state: JobState::Failed, + expected_status: "failed", + expected_reason: Some("budget exceeded"), + } + )] + #[case::stuck( + TerminalTestCase { + method: TerminalMethod::Stuck("timeout"), + expected_state: JobState::Stuck, + expected_status: "stuck", + expected_reason: Some("timeout"), + } + )] + #[tokio::test] + async fn test_terminal_state_characterises_persistence( + #[case] case: TerminalTestCase, + ) -> Result<(), Box> { + let (worker, store) = make_worker_with_capturing_store(vec![]).await?; + + // Transition to InProgress first + transition_to_in_progress(&worker).await?; + + // Execute the terminal state transition + case.method.apply_transition(&worker).await?; + + // Verify state in ContextManager + let ctx = worker + .context_manager() + .get_context(worker.job_id) + .await + .expect("failed to get context after terminal transition"); + assert_eq!(ctx.state, case.expected_state); + + assert_terminal_persistence_with_snapshot( + &store, + case.expected_state, + case.expected_status, + case.expected_reason, + ) + .await; + Ok(()) + } + + /// Test case structure for parameterised terminal state tests. 
+ struct TerminalTestCase { + method: TerminalMethod, + expected_state: JobState, + expected_status: &'static str, + expected_reason: Option<&'static str>, + } + + async fn get_call_counts(store: &CapturingStore) -> (usize, usize) { + let calls = store.calls(); + let status_count = calls.status_history.lock().await.len(); + let event_count = calls.event_history.lock().await.len(); + (status_count, event_count) + } + + async fn assert_rejected_does_not_persist( + worker: &Worker, + store: &CapturingStore, + rejected: TerminalMethod, + expected_state: JobState, + before: (usize, usize), + ) { + let result = match rejected { + TerminalMethod::Completed => worker.mark_completed().await, + TerminalMethod::Failed(reason) => worker.mark_failed(reason).await, + TerminalMethod::Stuck(reason) => worker.mark_stuck(reason).await, + }; + assert!( + result.is_err(), + "Terminal transition {:?} after {:?} should be rejected", + rejected, + expected_state + ); + + let after = get_call_counts(store).await; + assert_eq!( + after.0, before.0, + "Rejected transition {:?} after {:?} should not persist status", + rejected, expected_state + ); + assert_eq!( + after.1, before.1, + "Rejected transition {:?} after {:?} should not persist event", + rejected, expected_state + ); + } + + async fn run_single_terminal_case( + method: TerminalMethod, + expected_state: JobState, + expected_status: &str, + expected_reason: Option<&str>, + ) -> anyhow::Result<()> { + let (worker, store) = make_worker_with_capturing_store(vec![]).await?; + transition_to_in_progress(&worker).await?; + + method.apply_transition(&worker).await?; + + let ctx = worker.context_manager().get_context(worker.job_id).await?; + assert_eq!( + ctx.state, expected_state, + "State should match expected terminal state" + ); + + assert_terminal_persistence(&store, expected_state, expected_status, expected_reason).await; + let before = get_call_counts(&store).await; + + for rejected in [ + TerminalMethod::Completed, + 
TerminalMethod::Failed("cross-terminal failure"), + TerminalMethod::Stuck("cross-terminal stuck"), + ] { + assert_rejected_does_not_persist(&worker, &store, rejected, expected_state, before) + .await; + } + + Ok(()) + } + + #[tokio::test] + async fn test_double_completed_transition_rejected() + -> Result<(), Box> { + let (worker, store) = make_worker_with_capturing_store(vec![]).await?; + + // Transition to InProgress first + transition_to_in_progress(&worker).await?; + + // First call succeeds + worker + .mark_completed() + .await + .expect("first mark_completed should succeed"); + + // Record call counts before attempting duplicate transition + let status_count_before = store.calls().status_history.lock().await.len(); + let event_count_before = store.calls().event_history.lock().await.len(); + + // Second call should fail + let result = worker.mark_completed().await; + assert!( + result.is_err(), + "Double transition to Completed should be rejected" + ); + + // Verify no new persistence calls were made on rejected transition + let status_count_after = store.calls().status_history.lock().await.len(); + let event_count_after = store.calls().event_history.lock().await.len(); + assert_eq!( + status_count_after, status_count_before, + "Rejected transition should not persist status" + ); + assert_eq!( + event_count_after, event_count_before, + "Rejected transition should not persist event" + ); + + assert_terminal_persistence_with_snapshot( + &store, + JobState::Completed, + "completed", + Some("Job completed successfully"), + ) + .await; + Ok(()) + } + + /// Terminal transition rejection test for duplicate state changes. + /// + /// Verifies that after transitioning to a terminal state (Completed, + /// Failed, or Stuck), subsequent attempts to transition to any terminal + /// state are rejected and persistence calls remain unchanged. 
+ /// + /// This is a curated test covering the three terminal states; it does + /// not generate arbitrary sequences or property-based inputs. + #[tokio::test] + async fn test_terminal_transition_rejects_duplicates() + -> Result<(), Box> { + let test_cases = [ + ( + TerminalMethod::Completed, + JobState::Completed, + "completed", + Some("Job completed successfully"), + ), + ( + TerminalMethod::Failed("test failure"), + JobState::Failed, + "failed", + Some("test failure"), + ), + ( + TerminalMethod::Stuck("test stuck"), + JobState::Stuck, + "stuck", + Some("test stuck"), + ), + ]; + + for (method, expected_state, expected_status, expected_reason) in test_cases { + run_single_terminal_case(method, expected_state, expected_status, expected_reason) + .await?; + } + Ok(()) } } diff --git a/tests/e2e_traces/builtin_tool_coverage/job.rs b/tests/e2e_traces/builtin_tool_coverage/job.rs index f749416cc..edbebc346 100644 --- a/tests/e2e_traces/builtin_tool_coverage/job.rs +++ b/tests/e2e_traces/builtin_tool_coverage/job.rs @@ -17,6 +17,7 @@ async fn job_create_status() -> anyhow::Result<()> { RigConfig::default(), ) .await?; + let rig = scopeguard::guard(rig, |rig| rig.shutdown()); // Both tools should have succeeded. 
let completed = rig.tool_calls_completed(); @@ -35,32 +36,42 @@ async fn job_create_status() -> anyhow::Result<()> { .iter() .find(|(n, _)| n == "create_job") .expect("create_job result missing"); + let parsed_create = serde_json::from_str::(&create_result.1) + .expect("create_job result should be valid JSON"); assert!( - create_result.1.contains("job_id"), - "create_job should return a job_id: {:?}", - create_result.1 + parsed_create + .get("job_id") + .and_then(serde_json::Value::as_str) + .is_some_and(|job_id| !job_id.is_empty()), + "create_job should return a non-empty job_id: {parsed_create:?}" ); - assert!( - create_result.1.contains("in_progress"), - "create_job should dispatch through the scheduler, not stay pending: {:?}", - create_result.1 + assert_eq!( + parsed_create + .get("status") + .and_then(serde_json::Value::as_str), + Some("in_progress"), + "create_job should dispatch through the scheduler, not stay pending: {parsed_create:?}" ); assert!( - !create_result.1.contains("scheduler unavailable"), - "create_job should not fall back to the unscheduled path: {:?}", - create_result.1 + !parsed_create + .get("error") + .and_then(serde_json::Value::as_str) + .is_some_and(|error| error.contains("scheduler unavailable")), + "create_job should not fall back to the unscheduled path: {parsed_create:?}" ); let status_result = results .iter() .find(|(n, _)| n == "job_status") .expect("job_status result missing"); - assert!( - status_result.1.contains("Test analysis job"), - "job_status should return the job title: {:?}", - status_result.1 + let parsed_status = serde_json::from_str::(&status_result.1) + .expect("job_status result should be valid JSON"); + assert_eq!( + parsed_status + .get("title") + .and_then(serde_json::Value::as_str), + Some("Test analysis job"), + "job_status should return the job title: {parsed_status:?}" ); - - rig.shutdown(); Ok(()) } @@ -79,6 +90,7 @@ async fn job_list_cancel() -> anyhow::Result<()> { RigConfig::default(), ) .await?; + let 
rig = scopeguard::guard(rig, |rig| rig.shutdown()); // All three tools should have succeeded. let completed = rig.tool_calls_completed(); @@ -95,6 +107,33 @@ async fn job_list_cancel() -> anyhow::Result<()> { "cancel_job should succeed: {completed:?}" ); - rig.shutdown(); + let results = rig.tool_results(); + let create_result = results + .iter() + .find(|(n, _)| n == "create_job") + .expect("create_job result missing"); + assert!( + create_result.1.contains("job_id"), + "create_job should return a job_id: {:?}", + create_result.1 + ); + let list_result = results + .iter() + .find(|(n, _)| n == "list_jobs") + .expect("list_jobs result missing"); + assert!( + !list_result.1.is_empty() && list_result.1.contains("job_id"), + "list_jobs should return at least one job entry: {:?}", + list_result.1 + ); + let cancel_result = results + .iter() + .find(|(n, _)| n == "cancel_job") + .expect("cancel_job result missing"); + assert!( + cancel_result.1.contains("cancel") || cancel_result.1.contains("cancelled"), + "cancel_job should report a cancelled outcome: {:?}", + cancel_result.1 + ); Ok(()) } diff --git a/tests/e2e_traces/builtin_tool_coverage/routine.rs b/tests/e2e_traces/builtin_tool_coverage/routine.rs index 550e1401d..33c614494 100644 --- a/tests/e2e_traces/builtin_tool_coverage/routine.rs +++ b/tests/e2e_traces/builtin_tool_coverage/routine.rs @@ -80,19 +80,9 @@ async fn routine_system_event_emit() -> anyhow::Result<()> { .iter() .find(|(n, _)| n == "event_emit") .expect("event_emit result missing"); - assert!( - emit_result.1.contains("fired_routines"), - "event_emit should report fired routine count: {:?}", - emit_result.1 - ); - // Verify at least one routine actually fired (not just that the key exists). 
let emit_json: serde_json::Value = serde_json::from_str(&emit_result.1).expect("event_emit result should be valid JSON"); - assert!( - emit_json["fired_routines"].as_u64().unwrap_or(0) > 0, - "event_emit should have fired at least one routine: {:?}", - emit_result.1 - ); + insta::assert_json_snapshot!("routine_system_event_emit_payload", emit_json); rig.shutdown(); Ok(()) @@ -135,19 +125,15 @@ async fn skill_install_routine_webhook_sim() -> anyhow::Result<()> { .expect("event_emit result missing"); let emit_payload: serde_json::Value = serde_json::from_str(&emit_result.1).expect("event_emit result should be valid JSON"); - let fired_routines = emit_payload - .get("fired_routines") - .and_then(serde_json::Value::as_u64) - .expect("event_emit result should include integer fired_routines"); - assert!( - fired_routines > 0, - "event_emit should report fired routines > 0: {emit_payload:?}" - ); + insta::assert_json_snapshot!("skill_install_emit_payload", emit_payload); - let _history_result = results + let history_result = results .iter() .find(|(n, _)| n == "routine_history") .expect("routine_history result missing"); + let history_json: serde_json::Value = serde_json::from_str(&history_result.1) + .expect("routine_history result should be valid JSON"); + insta::assert_json_snapshot!("skill_install_routine_history_payload", history_json); rig.shutdown(); Ok(()) diff --git a/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__routine__routine_system_event_emit_payload.snap b/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__routine__routine_system_event_emit_payload.snap new file mode 100644 index 000000000..fef4e9266 --- /dev/null +++ b/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__routine__routine_system_event_emit_payload.snap @@ -0,0 +1,10 @@ +--- +source: tests/e2e_traces/builtin_tool_coverage/routine.rs +expression: emit_json +--- +{ + "event_source": "github", 
+ "event_type": "issue.opened", + "fired_routines": 1, + "user_id": "test-user" +} diff --git a/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__routine__skill_install_emit_payload.snap b/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__routine__skill_install_emit_payload.snap new file mode 100644 index 000000000..58429f2f0 --- /dev/null +++ b/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__routine__skill_install_emit_payload.snap @@ -0,0 +1,10 @@ +--- +source: tests/e2e_traces/builtin_tool_coverage/routine.rs +expression: emit_payload +--- +{ + "event_source": "github", + "event_type": "issue.opened", + "fired_routines": 1, + "user_id": "test-user" +} diff --git a/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__routine__skill_install_routine_history_payload.snap b/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__routine__skill_install_routine_history_payload.snap new file mode 100644 index 000000000..27792f4c0 --- /dev/null +++ b/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__routine__skill_install_routine_history_payload.snap @@ -0,0 +1,9 @@ +--- +source: tests/e2e_traces/builtin_tool_coverage/routine.rs +expression: history_json +--- +{ + "routine": "wf-webhook-sim-trace", + "runs": [], + "total_runs": 0 +} diff --git a/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__time__time_diff_result.snap b/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__time__time_diff_result.snap new file mode 100644 index 000000000..d61a1971c --- /dev/null +++ b/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__time__time_diff_result.snap @@ -0,0 +1,5 @@ +--- +source: tests/e2e_traces/builtin_tool_coverage/time.rs +expression: "time_results[1].1" +--- +{ "days": 1, 
"hours": 28, "minutes": 1695, "seconds": 101700 } diff --git a/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__time__time_parse_and_diff_response.snap b/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__time__time_parse_and_diff_response.snap new file mode 100644 index 000000000..3a47721d8 --- /dev/null +++ b/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__time__time_parse_and_diff_response.snap @@ -0,0 +1,5 @@ +--- +source: tests/e2e_traces/builtin_tool_coverage/time.rs +expression: "responses[0].content" +--- +The timestamp 2024-01-15T10:30:00Z was parsed successfully. The difference between the two timestamps is 1 day, 4 hours, and 15 minutes (101700 seconds). diff --git a/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__time__time_parse_invalid_response.snap b/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__time__time_parse_invalid_response.snap new file mode 100644 index 000000000..eeee43a91 --- /dev/null +++ b/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__time__time_parse_invalid_response.snap @@ -0,0 +1,5 @@ +--- +source: tests/e2e_traces/builtin_tool_coverage/time.rs +expression: "responses[0].content" +--- +The timestamp 'not-a-valid-timestamp' could not be parsed. Please provide a valid ISO 8601 timestamp like '2024-01-15T10:30:00Z'. 
diff --git a/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__time__time_parse_invalid_result.snap b/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__time__time_parse_invalid_result.snap new file mode 100644 index 000000000..e633b72bc --- /dev/null +++ b/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__time__time_parse_invalid_result.snap @@ -0,0 +1,5 @@ +--- +source: tests/e2e_traces/builtin_tool_coverage/time.rs +expression: "time_result_previews[0]" +--- +Tool 'time' failed: Tool error: Tool time execution failed: Invalid parameters: invalid timestamp 'not-a-valid-timestamp': expected RFC 3339 or a naive timestamp with timezone/from_timezone diff --git a/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__time__time_parse_result.snap b/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__time__time_parse_result.snap new file mode 100644 index 000000000..6f696b46d --- /dev/null +++ b/tests/e2e_traces/builtin_tool_coverage/snapshots/e2e_traces__builtin_tool_coverage__time__time_parse_result.snap @@ -0,0 +1,5 @@ +--- +source: tests/e2e_traces/builtin_tool_coverage/time.rs +expression: "time_results[0].1" +--- +{ "iso": "2024-01-15T10:30:00+00:00", "unix": 1705314600, "unix_millis": 1705314600000 } diff --git a/tests/e2e_traces/builtin_tool_coverage/time.rs b/tests/e2e_traces/builtin_tool_coverage/time.rs index 9c0570d06..7a6058d9d 100644 --- a/tests/e2e_traces/builtin_tool_coverage/time.rs +++ b/tests/e2e_traces/builtin_tool_coverage/time.rs @@ -8,7 +8,7 @@ async fn time_parse_and_diff() -> anyhow::Result<()> { env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/llm_traces/tools/time_parse_diff.json" ); - let (rig, _trace, _responses) = run_trace_test( + let (rig, _trace, responses) = run_trace_test( fixture_path, "Parse a time and compute a diff", RigConfig { @@ -19,16 +19,27 @@ async fn 
time_parse_and_diff() -> anyhow::Result<()> { ) .await?; - // Time tool should have been called twice (parse + diff). - let started = rig.tool_calls_started(); - let time_count = started.iter().filter(|n| n.as_str() == "time").count(); - assert!( - time_count >= 2, - "Expected >= 2 time tool calls, got {time_count}" - ); - + let result: anyhow::Result<()> = { + // Time tool should have been called twice (parse + diff). + let started = rig.tool_calls_started(); + let time_count = started.iter().filter(|n| n.as_str() == "time").count(); + assert!( + time_count >= 2, + "Expected >= 2 time tool calls, got {time_count}" + ); + let time_results: Vec<_> = rig + .tool_results() + .into_iter() + .filter(|(name, _)| name == "time") + .collect(); + assert_eq!(time_results.len(), 2, "expected exactly 2 time results"); + insta::assert_snapshot!("time_parse_result", time_results[0].1); + insta::assert_snapshot!("time_diff_result", time_results[1].1); + insta::assert_snapshot!("time_parse_and_diff_response", responses[0].content); + Ok(()) + }; rig.shutdown(); - Ok(()) + result } #[tokio::test] @@ -37,7 +48,7 @@ async fn time_parse_invalid() -> anyhow::Result<()> { env!("CARGO_MANIFEST_DIR"), "/tests/fixtures/llm_traces/tools/time_parse_invalid.json" ); - let (rig, _trace, _responses) = run_trace_test( + let (rig, _trace, responses) = run_trace_test( fixture_path, "Parse an invalid timestamp", RigConfig { @@ -48,18 +59,28 @@ async fn time_parse_invalid() -> anyhow::Result<()> { ) .await?; - // The time tool call should have failed (invalid timestamp). 
- let completed = rig.tool_calls_completed(); - let time_results: Vec<_> = completed - .iter() - .filter(|(name, _)| name == "time") - .collect(); - assert!(!time_results.is_empty(), "Expected time tool to be called"); - assert!( - time_results.iter().any(|(_, ok)| !ok), - "Expected at least one failed time call: {time_results:?}" - ); - + let result: anyhow::Result<()> = { + // The time tool call should have failed (invalid timestamp). + let completed = rig.tool_calls_completed(); + let time_results: Vec<_> = completed + .iter() + .filter(|(name, _)| name == "time") + .collect(); + assert!(!time_results.is_empty(), "Expected time tool to be called"); + assert!( + time_results.iter().any(|(_, ok)| !ok), + "Expected at least one failed time call: {time_results:?}" + ); + let time_result_previews: Vec<_> = rig + .tool_results() + .into_iter() + .filter(|(name, _)| name == "time") + .map(|(_, preview)| preview) + .collect(); + insta::assert_snapshot!("time_parse_invalid_result", time_result_previews[0]); + insta::assert_snapshot!("time_parse_invalid_response", responses[0].content); + Ok(()) + }; rig.shutdown(); - Ok(()) + result } diff --git a/tests/e2e_traces/heartbeat.rs b/tests/e2e_traces/heartbeat.rs index 89d27ebd9..ed8f837e4 100644 --- a/tests/e2e_traces/heartbeat.rs +++ b/tests/e2e_traces/heartbeat.rs @@ -66,7 +66,10 @@ async fn heartbeat_findings() { } // No notification since we called check_heartbeat directly (not run). 
- let _ = rx.try_recv(); + assert!( + rx.try_recv().is_err(), + "Expected no notification to be sent when calling check_heartbeat() directly" + ); } #[tokio::test] diff --git a/tests/e2e_traces/routine_cooldown.rs b/tests/e2e_traces/routine_cooldown.rs index af72355e9..80db0305b 100644 --- a/tests/e2e_traces/routine_cooldown.rs +++ b/tests/e2e_traces/routine_cooldown.rs @@ -5,10 +5,7 @@ use std::time::Duration; -use chrono::Utc; - use ironclaw::agent::routine::Trigger; -use ironclaw::db::RoutineRuntimeUpdate; use crate::support::routines::engine_sync::{wait_for_idle, wait_for_persisted_run}; use crate::support::routines::{ @@ -17,7 +14,7 @@ use crate::support::routines::{ use crate::support::trace_llm::{LlmTrace, TraceResponse, TraceStep}; #[tokio::test] -async fn routine_cooldown() { +async fn routine_cooldown() -> anyhow::Result<()> { let (db, _tmp) = create_test_db().await.expect("create_test_db"); let ws = create_workspace(&db); @@ -56,29 +53,31 @@ async fn routine_cooldown() { assert!(fired1 >= 1, "First fire should work"); // Wait for routine execution to complete using deterministic synchronization, - // then verify the routine run was recorded before updating last_run_at. - wait_for_idle(&engine, Duration::from_secs(5)).await; + // then verify the routine run was recorded in the database. + wait_for_idle(&engine, Duration::from_secs(5)).await?; // Wait for routine run to be durably persisted in the database. // Snapshot run count before firing (zero for a freshly-created routine). - wait_for_persisted_run(&db, routine.id, 0, Duration::from_secs(5)).await; - - // Update the routine's last_run_at to now (simulating it just ran). 
- db.update_routine_runtime(RoutineRuntimeUpdate { - id: routine.id, - last_run_at: Utc::now(), - next_fire_at: None, - run_count: 1, - consecutive_failures: 0, - state: &serde_json::json!({}), - }) - .await - .expect("update_routine_runtime"); + wait_for_persisted_run(&db, routine.id, 0, Duration::from_secs(5)).await?; - // Refresh cache to pick up updated last_run_at. + let persisted = db + .get_routine(routine.id) + .await + .expect("get_routine") + .expect("routine present"); + assert!( + persisted.last_run_at.is_some(), + "Expected engine to persist last_run_at" + ); + assert!( + persisted.run_count >= 1, + "Expected engine to persist run_count" + ); engine.refresh_event_cache().await; // Second fire should be blocked by cooldown. let fired2 = engine.check_event_triggers(&msg).await; assert_eq!(fired2, 0, "Second fire should be blocked by cooldown"); + + Ok(()) } diff --git a/tests/e2e_traces/routine_cron.rs b/tests/e2e_traces/routine_cron.rs index 3d792da98..3b2bca9a8 100644 --- a/tests/e2e_traces/routine_cron.rs +++ b/tests/e2e_traces/routine_cron.rs @@ -16,7 +16,7 @@ use crate::support::routines::{ use crate::support::trace_llm::{LlmTrace, TraceResponse, TraceStep}; #[tokio::test] -async fn cron_routine_fires() { +async fn cron_routine_fires() -> anyhow::Result<()> { let (db, _tmp) = create_test_db().await.expect("create_test_db"); let ws = create_workspace(&db); @@ -53,13 +53,15 @@ async fn cron_routine_fires() { // Wait for routine execution to complete using deterministic synchronization, // then verify the routine run was recorded. - wait_for_idle(&engine, Duration::from_secs(5)).await; + wait_for_idle(&engine, Duration::from_secs(5)).await?; // Wait for routine run to be durably persisted in the database. // Snapshot run count before firing (zero for a freshly-created routine). 
- wait_for_persisted_run(&db, routine.id, 0, Duration::from_secs(5)).await; + wait_for_persisted_run(&db, routine.id, 0, Duration::from_secs(5)).await?; // Notification may or may not be sent depending on config; // just verify no panic occurred. Drain the channel. let _ = notify_rx.try_recv(); + + Ok(()) } diff --git a/tests/e2e_traces/routine_event.rs b/tests/e2e_traces/routine_event.rs index 5d804c552..f7a517b11 100644 --- a/tests/e2e_traces/routine_event.rs +++ b/tests/e2e_traces/routine_event.rs @@ -14,7 +14,7 @@ use crate::support::routines::{ use crate::support::trace_llm::{LlmTrace, TraceResponse, TraceStep}; #[tokio::test] -async fn event_trigger_matches() { +async fn event_trigger_matches() -> anyhow::Result<()> { let (db, _tmp) = create_test_db().await.expect("create_test_db"); let ws = create_workspace(&db); @@ -57,14 +57,16 @@ async fn event_trigger_matches() { // Wait for routine execution to complete using deterministic synchronization, // then verify the routine run was recorded. - wait_for_idle(&engine, Duration::from_secs(5)).await; + wait_for_idle(&engine, Duration::from_secs(5)).await?; // Wait for routine run to be durably persisted in the database. // Snapshot run count before firing (zero for a freshly-created routine). - wait_for_persisted_run(&db, routine.id, 0, Duration::from_secs(5)).await; + wait_for_persisted_run(&db, routine.id, 0, Duration::from_secs(5)).await?; // Negative match: message that doesn't match. 
let non_matching_msg = make_test_incoming_message("check the staging environment"); let fired_neg = engine.check_event_triggers(&non_matching_msg).await; assert_eq!(fired_neg, 0, "Expected 0 routines fired on non-match"); + + Ok(()) } diff --git a/tests/e2e_traces/routine_system_event.rs b/tests/e2e_traces/routine_system_event.rs index 39f0e6ca9..7c6226439 100644 --- a/tests/e2e_traces/routine_system_event.rs +++ b/tests/e2e_traces/routine_system_event.rs @@ -13,7 +13,7 @@ use crate::support::routines::{ use crate::support::trace_llm::{LlmTrace, TraceResponse, TraceStep}; #[tokio::test] -async fn system_event_trigger_matches_and_filters() { +async fn system_event_trigger_matches_and_filters() -> anyhow::Result<()> { let (db, _tmp) = create_test_db().await.expect("create_test_db"); let ws = create_workspace(&db); let trace = LlmTrace::single_turn( @@ -30,7 +30,7 @@ async fn system_event_trigger_matches_and_filters() { }], ); let (engine, _notify_rx) = make_minimal_engine(trace, db.clone(), ws); - let routine = register_github_issue_routine(&db, &engine).await; + let routine = register_github_issue_routine(&db, &engine).await?; // Matching event should fire and be recorded in run history. assert_system_event_count( @@ -47,11 +47,11 @@ async fn system_event_trigger_matches_and_filters() { // Wait for routine execution to complete using deterministic synchronization, // then verify the routine run was recorded. - wait_for_idle(&engine, Duration::from_secs(5)).await; + wait_for_idle(&engine, Duration::from_secs(5)).await?; // Wait for routine run to be durably persisted in the database. // Snapshot run count before firing (zero for a freshly-created routine). - wait_for_persisted_run(&db, routine.id, 0, Duration::from_secs(5)).await; + wait_for_persisted_run(&db, routine.id, 0, Duration::from_secs(5)).await?; // Table-driven checks for non-matching and case-insensitive scenarios. 
#[rustfmt::skip] @@ -68,4 +68,6 @@ async fn system_event_trigger_matches_and_filters() { for (spec, expected, msg) in scenarios { assert_system_event_count(&engine, spec, expected, msg).await; } + + Ok(()) } diff --git a/tests/infrastructure.rs b/tests/infrastructure.rs index 176023530..538406343 100644 --- a/tests/infrastructure.rs +++ b/tests/infrastructure.rs @@ -1,6 +1,8 @@ //! Infrastructure integration tests covering heartbeat, pairing, provider //! chaos, SIGHUP reload, and workspace functionality. +mod support; + #[path = "infrastructure/heartbeat.rs"] mod heartbeat; #[path = "infrastructure/pairing.rs"] diff --git a/tests/infrastructure/sighup_reload.rs b/tests/infrastructure/sighup_reload.rs index 3e009ade1..b93716362 100644 --- a/tests/infrastructure/sighup_reload.rs +++ b/tests/infrastructure/sighup_reload.rs @@ -1,170 +1,210 @@ -//! Integration test for SIGHUP hot-reload of HTTP webhook configuration. +//! Integration tests for SIGHUP hot-reload of HTTP webhook configuration. //! -//! This test verifies that: -//! 1. SIGHUP triggers config reload from DB/environment -//! 2. Address changes cause listener restart -//! 3. Secret changes take effect immediately (zero-downtime) -//! 4. Old listener is shut down after successful restart - -#![cfg(unix)] +//! Exercises the reload path end-to-end by driving `WebhookServer` and +//! `HttpChannel` directly — no live binary spawning. +use std::net::{SocketAddr, TcpListener as StdTcpListener}; use std::time::Duration; +use axum::http::StatusCode; +use reqwest::Client; +use secrecy::SecretString; +use serde_json::json; + +use ironclaw::channels::{HttpChannel, NativeChannel, WebhookServer, WebhookServerConfig}; +use ironclaw::config::HttpConfig; +use rstest::{fixture, rstest}; + +use crate::support::webhook_helpers; + +/// Build a minimal health-check server using the given already-bound listener. +/// Returns the started server and the bound address. 
+async fn health_server( + listener: tokio::net::TcpListener, +) -> Result<(WebhookServer, SocketAddr), Box> { + let addr = listener.local_addr()?; + let config = WebhookServerConfig { addr }; + let mut server = WebhookServer::new(config); + server.add_routes(webhook_helpers::health_routes()); + server.start_with_listener(listener).await?; + Ok((server, addr)) +} + +/// POST a webhook payload and return the HTTP status. +async fn post_webhook( + client: &Client, + addr: SocketAddr, + secret: &str, +) -> Result { + Ok(client + .post(format!("http://{}/webhook", addr)) + .json(&json!({"content": "hello", "secret": secret})) + .send() + .await? + .status()) +} + +#[fixture] +fn http_client() -> Result { + webhook_helpers::test_http_client() +} + +#[rstest] #[tokio::test] -#[ignore] // Requires full ironclaw binary and database setup -async fn test_sighup_config_reload_address_change() { - // This is a placeholder integration test structure. - // It demonstrates the test approach and can be run against a live ironclaw instance. - // - // To run this test manually: - // 1. Start ironclaw with HTTP_PORT=19000 HTTP_WEBHOOK_SECRET=initial-secret - // 2. 
Run: cargo test --test sighup_reload_integration -- --ignored --nocapture - // - // The test will: - // - Verify initial webhook responds on port 19000 with "initial-secret" - // - Update environment/DB to use port 19001 and "new-secret" - // - Send SIGHUP to ironclaw - // - Verify old port 19000 stops responding - // - Verify new port 19001 responds with "new-secret" - - let initial_port = 19000u16; - let _new_port = 19001u16; - let initial_secret = "initial-secret"; - let _new_secret = "new-secret"; - - let client = reqwest::Client::builder() - .timeout(Duration::from_secs(2)) - .build() - .expect("Failed to build HTTP client"); - - // Verify initial webhook is listening - let initial_addr = format!("http://127.0.0.1:{}/webhook", initial_port); - let response = client - .post(&initial_addr) - .json(&serde_json::json!({ - "content": "test", - "secret": initial_secret - })) +async fn test_sighup_config_reload_address_change( + http_client: Result, +) -> Result<(), Box> { + let http_client = http_client?; + let listener1 = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let (mut server, addr1) = health_server(listener1).await?; + + // Confirm first address responds. + let resp = http_client + .get(format!("http://{}/health", addr1)) + .send() + .await + .expect("health check"); + assert_eq!(resp.status(), StatusCode::OK); + + // Restart on a second ephemeral port. + let listener2 = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let addr2 = listener2.local_addr()?; + server + .restart_with_listener(listener2) + .await + .expect("restart"); + + // New address should respond. + let resp = http_client + .get(format!("http://{}/health", addr2)) .send() + .await + .expect("health check on new address"); + assert_eq!(resp.status(), StatusCode::OK, "new address should respond"); + + // Old address should refuse connections. 
+ let old_result = tokio::time::timeout( + Duration::from_millis(200), + http_client.get(format!("http://{}/health", addr1)).send(), + ) + .await; + + match old_result { + // Timeout expired — the old address no longer accepts connections. + Err(_) => {} + // Request reached the client stack but the old listener was gone. + Ok(Err(_)) => {} + Ok(Ok(resp)) => { + panic!( + "old address should not respond after restart, got status {}", + resp.status() + ); + } + } + + server.shutdown().await; + Ok(()) +} + +#[rstest] +#[tokio::test] +async fn test_sighup_secret_update_zero_downtime( + http_client: Result, +) -> Result<(), Box> { + let http_client = http_client?; + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + + let channel = HttpChannel::new(HttpConfig { + host: "127.0.0.1".to_string(), + port: addr.port(), + webhook_secret: Some(SecretString::from("old-secret".to_string())), + user_id: "test-user".to_string(), + }); + + // Start the channel so the internal sender is populated. + // `_stream` is intentionally kept to hold the returned `MessageStream` alive, + // ensuring the `HttpChannel`'s internal sender/registration is not dropped + // and the channel lifecycle remains active for the duration of the test. + let _stream = channel.start().await.expect("start channel"); + + let mut server = WebhookServer::new(WebhookServerConfig { addr }); + server.add_routes(channel.routes()); + server.start_with_listener(listener).await?; + + // Old secret should be accepted. + let status = post_webhook(&http_client, addr, "old-secret").await?; + assert_eq!(status, StatusCode::OK, "old secret should work initially"); + + // Hot-swap secret via the public API. + channel + .update_secret(Some(SecretString::from("new-secret".to_string()))) .await; - assert!( - response.is_ok(), - "Initial webhook should be listening on port {}", - initial_port - ); + // Old secret should now be rejected. 
+ let status = post_webhook(&http_client, addr, "old-secret").await?; assert_eq!( - response.unwrap().status(), - 200, - "Request with correct secret should succeed" + status, + StatusCode::UNAUTHORIZED, + "old secret should fail after swap" ); - // In a real test, we would: - // 1. Update the database or environment variables for the new config - // 2. Send SIGHUP to the ironclaw process - // 3. Wait for reload to complete - // 4. Verify new listener is active and old one is inactive - // 5. Verify secret change took effect + // New secret should be accepted. + let status = post_webhook(&http_client, addr, "new-secret").await?; + assert_eq!(status, StatusCode::OK, "new secret should work after swap"); - println!("SIGHUP reload test structure is in place."); - println!("This test requires a running ironclaw instance to verify actual behavior."); + server.shutdown().await; + Ok(()) } +#[rstest] #[tokio::test] -#[ignore] // Requires full ironclaw binary -async fn test_sighup_secret_update_zero_downtime() { - // Test that secret changes take effect immediately without restarting the listener. - // - // Setup: - // - Start ironclaw with HTTP_PORT=19002 HTTP_WEBHOOK_SECRET=original-secret - // - // Test flow: - // 1. Make request with "original-secret" → 200 OK - // 2. Update DB secret to "updated-secret" - // 3. Send SIGHUP - // 4. Make request with "original-secret" → 401 Unauthorized - // 5. Make request with "updated-secret" → 200 OK - // 6. 
Verify listener is still on same port (no restart) - - let port = 19002u16; - let original_secret = "original-secret"; - let _updated_secret = "updated-secret"; - - let client = reqwest::Client::builder() - .timeout(Duration::from_secs(2)) - .build() - .expect("Failed to build HTTP client"); - - let webhook_url = format!("http://127.0.0.1:{}/webhook", port); - - // Verify original secret works - let response = client - .post(&webhook_url) - .json(&serde_json::json!({ - "content": "test", - "secret": original_secret - })) +async fn test_sighup_rollback_on_address_bind_failure( + http_client: Result, +) -> Result<(), Box> { + let http_client = http_client?; + let listener1 = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let (mut server, addr1) = health_server(listener1).await?; + + // Confirm initial address works. + let resp = http_client + .get(format!("http://{}/health", addr1)) .send() - .await; - - assert!( - response.is_ok(), - "Initial request with correct secret should succeed" + .await + .expect("health check"); + assert_eq!( + resp.status(), + StatusCode::OK, + "initial address should respond" ); - assert_eq!(response.unwrap().status(), 200); - // After SIGHUP with updated secret: - // - Original secret should fail - // - Updated secret should succeed - // (This is verified by the hot-swap unit test; integration test - // structure is in place for end-to-end verification) + // Occupy a second ephemeral port so bind deterministically fails. 
+ let occupied = StdTcpListener::bind("127.0.0.1:0").expect("bind conflict port"); + let conflict_addr = occupied.local_addr().expect("conflict local_addr"); - println!("Zero-downtime secret update test structure is in place."); -} + let result = server.restart_with_addr(conflict_addr).await; + assert!(result.is_err(), "restart to occupied port should fail"); -#[tokio::test] -#[ignore] // Requires manual setup -async fn test_sighup_rollback_on_address_bind_failure() { - // Test that if restart_with_addr fails, the old listener remains active - // and state is restored. - // - // Setup: - // - Start ironclaw with HTTP_PORT=19003 HTTP_WEBHOOK_SECRET=test-secret - // - // Test flow: - // 1. Make request to port 19003 → 200 OK - // 2. Update DB to use invalid address (e.g., port 1, which requires root) - // 3. Send SIGHUP - // 4. Verify old listener on port 19003 is still responding - // 5. Verify state was restored (config still shows port 19003) - - let original_port = 19003u16; - let secret = "test-secret"; - - let client = reqwest::Client::builder() - .timeout(Duration::from_secs(2)) - .build() - .expect("Failed to build HTTP client"); - - let webhook_url = format!("http://127.0.0.1:{}/webhook", original_port); - - // Verify original listener is working - let response = client - .post(&webhook_url) - .json(&serde_json::json!({ - "content": "test", - "secret": secret - })) - .send() - .await; + drop(occupied); - assert!(response.is_ok(), "Original listener should be responding"); - assert_eq!(response.unwrap().status(), 200); + // Original listener must still respond. 
+ let resp = http_client + .get(format!("http://{}/health", addr1)) + .send() + .await + .expect("health check after failed restart"); + assert_eq!( + resp.status(), + StatusCode::OK, + "original address should still respond after failed restart" + ); - // After SIGHUP with invalid address: - // - Original listener should still respond - // - No downtime should have occurred - // (Verified by webhook_server unit test; integration structure in place) + assert_eq!( + server.current_addr(), + addr1, + "server address should be restored after failed restart" + ); - println!("SIGHUP rollback test structure is in place."); + server.shutdown().await; + Ok(()) } diff --git a/tests/snapshots/worker_orchestrator_json_shapes__completion_report.snap b/tests/snapshots/worker_orchestrator_json_shapes__completion_report.snap new file mode 100644 index 000000000..79d8da90d --- /dev/null +++ b/tests/snapshots/worker_orchestrator_json_shapes__completion_report.snap @@ -0,0 +1,10 @@ +--- +source: tests/worker_orchestrator_json_shapes.rs +assertion_line: 46 +expression: "&original" +--- +{ + "success": true, + "message": "done", + "iterations": 10 +} diff --git a/tests/snapshots/worker_orchestrator_json_shapes__credential_response.snap b/tests/snapshots/worker_orchestrator_json_shapes__credential_response.snap new file mode 100644 index 000000000..4c4aeb202 --- /dev/null +++ b/tests/snapshots/worker_orchestrator_json_shapes__credential_response.snap @@ -0,0 +1,9 @@ +--- +source: tests/worker_orchestrator_json_shapes.rs +assertion_line: 142 +expression: "&parsed" +--- +{ + "env_var": "API_KEY", + "value": "secret123" +} diff --git a/tests/snapshots/worker_orchestrator_json_shapes__job_description.snap b/tests/snapshots/worker_orchestrator_json_shapes__job_description.snap new file mode 100644 index 000000000..af7844ab8 --- /dev/null +++ b/tests/snapshots/worker_orchestrator_json_shapes__job_description.snap @@ -0,0 +1,10 @@ +--- +source: tests/worker_orchestrator_json_shapes.rs 
+assertion_line: 111 +expression: "&parsed" +--- +{ + "title": "Test Job", + "description": "Do something", + "project_dir": "/tmp/project" +} diff --git a/tests/snapshots/worker_orchestrator_json_shapes__job_event_payload.snap b/tests/snapshots/worker_orchestrator_json_shapes__job_event_payload.snap new file mode 100644 index 000000000..be9ca9322 --- /dev/null +++ b/tests/snapshots/worker_orchestrator_json_shapes__job_event_payload.snap @@ -0,0 +1,11 @@ +--- +source: tests/worker_orchestrator_json_shapes.rs +assertion_line: 32 +expression: "&original" +--- +{ + "event_type": "tool_use", + "data": { + "tool": "bash" + } +} diff --git a/tests/snapshots/worker_orchestrator_json_shapes__prompt_response.snap b/tests/snapshots/worker_orchestrator_json_shapes__prompt_response.snap new file mode 100644 index 000000000..7d54d3967 --- /dev/null +++ b/tests/snapshots/worker_orchestrator_json_shapes__prompt_response.snap @@ -0,0 +1,9 @@ +--- +source: tests/worker_orchestrator_json_shapes.rs +assertion_line: 151 +expression: "&parsed" +--- +{ + "content": "Continue?", + "done": false +} diff --git a/tests/snapshots/worker_orchestrator_json_shapes__proxy_completion_response.snap b/tests/snapshots/worker_orchestrator_json_shapes__proxy_completion_response.snap new file mode 100644 index 000000000..31f7e23c6 --- /dev/null +++ b/tests/snapshots/worker_orchestrator_json_shapes__proxy_completion_response.snap @@ -0,0 +1,13 @@ +--- +source: tests/worker_orchestrator_json_shapes.rs +assertion_line: 92 +expression: "&parsed" +--- +{ + "content": "Hello", + "input_tokens": 100, + "output_tokens": 50, + "finish_reason": "stop", + "cache_read_input_tokens": 10, + "cache_creation_input_tokens": 5 +} diff --git a/tests/snapshots/worker_orchestrator_json_shapes__proxy_tool_completion_request.snap b/tests/snapshots/worker_orchestrator_json_shapes__proxy_tool_completion_request.snap new file mode 100644 index 000000000..30ae2e15a --- /dev/null +++ 
b/tests/snapshots/worker_orchestrator_json_shapes__proxy_tool_completion_request.snap @@ -0,0 +1,18 @@ +--- +source: tests/worker_orchestrator_json_shapes.rs +assertion_line: 76 +expression: "&original" +--- +{ + "messages": [ + { + "role": "user", + "content": "hello" + } + ], + "tools": [], + "model": null, + "max_tokens": null, + "temperature": null, + "tool_choice": "auto" +} diff --git a/tests/snapshots/worker_orchestrator_json_shapes__remote_tool_catalog_response.snap b/tests/snapshots/worker_orchestrator_json_shapes__remote_tool_catalog_response.snap new file mode 100644 index 000000000..d3ad0bccd --- /dev/null +++ b/tests/snapshots/worker_orchestrator_json_shapes__remote_tool_catalog_response.snap @@ -0,0 +1,20 @@ +--- +source: tests/worker_orchestrator_json_shapes.rs +assertion_line: 130 +expression: "&parsed" +--- +{ + "tools": [ + { + "name": "t", + "description": "d", + "parameters": { + "type": "object" + } + } + ], + "toolset_instructions": [ + "Use bash carefully" + ], + "catalog_version": 7 +} diff --git a/tests/snapshots/worker_orchestrator_json_shapes__remote_tool_execution_request.snap b/tests/snapshots/worker_orchestrator_json_shapes__remote_tool_execution_request.snap new file mode 100644 index 000000000..41fabf6f5 --- /dev/null +++ b/tests/snapshots/worker_orchestrator_json_shapes__remote_tool_execution_request.snap @@ -0,0 +1,11 @@ +--- +source: tests/worker_orchestrator_json_shapes.rs +assertion_line: 60 +expression: "&original" +--- +{ + "tool_name": "my_tool", + "params": { + "key": "value" + } +} diff --git a/tests/snapshots/worker_orchestrator_json_shapes__status_update.snap b/tests/snapshots/worker_orchestrator_json_shapes__status_update.snap new file mode 100644 index 000000000..b09fd180e --- /dev/null +++ b/tests/snapshots/worker_orchestrator_json_shapes__status_update.snap @@ -0,0 +1,10 @@ +--- +source: tests/worker_orchestrator_json_shapes.rs +assertion_line: 18 +expression: "&original" +--- +{ + "state": "in_progress", + "message": 
"working", + "iteration": 42 +} diff --git a/tests/support/mod.rs b/tests/support/mod.rs index 56e42227a..cf8997a7f 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -16,6 +16,7 @@ pub mod test_rig; pub mod trace_llm; mod trace_provider; pub mod trace_types; +pub mod webhook_helpers; #[cfg(feature = "libsql")] #[expect( @@ -52,6 +53,16 @@ type AsyncTraceMetrics<'a> = type AsyncTraceRun<'a> = std::pin::Pin< Box>> + 'a>, >; +type AsyncStartedWebhookServer = std::pin::Pin< + Box< + dyn std::future::Future< + Output = Result< + webhook_helpers::StartedWebhookServer, + Box, + >, + > + Send, + >, +>; #[cfg(feature = "libsql")] type AsyncBuildRig = std::pin::Pin>>>; @@ -349,8 +360,26 @@ fn test_rig_symbol_refs() { touch_test_rig_symbols(); } +fn webhook_helpers_symbol_refs() { + const _: fn() -> axum::Router = webhook_helpers::health_routes; + const _: fn() -> Result = webhook_helpers::test_http_client; + const _: fn() -> std::mem::MaybeUninit = + std::mem::MaybeUninit::::uninit; + const _: fn(&webhook_helpers::StartedWebhookServer) = _touch_started_webhook_server_fields; + const _: fn() -> AsyncStartedWebhookServer = _start_health_server_sig; +} + +fn _start_health_server_sig() -> AsyncStartedWebhookServer { + Box::pin(webhook_helpers::start_health_server()) +} + +fn _touch_started_webhook_server_fields(server: &webhook_helpers::StartedWebhookServer) { + let _ = (&server.server, &server.addr, &server.client); +} + const _: fn() = trace_support_symbol_refs; const _: fn() = test_rig_symbol_refs; +const _: fn() = webhook_helpers_symbol_refs; // ============================================================================= // Routines module compile-time assertions @@ -361,33 +390,40 @@ const _: fn() = routines_symbol_refs; #[cfg(feature = "libsql")] fn routines_symbol_refs() { - // Compile-time type assertions for routines module helpers. - // These ensure the public API signatures remain stable. 
- const _: fn( - &std::sync::Arc, - ) -> std::sync::Arc = routines::create_workspace; - const _: fn( - &str, - ironclaw::agent::routine::Trigger, - &str, - ) -> ironclaw::agent::routine::Routine = routines::make_routine; - const _: fn(&str) -> ironclaw::channels::IncomingMessage = routines::make_test_incoming_message; - #[allow(clippy::type_complexity)] - const _: fn( - trace_llm::LlmTrace, - std::sync::Arc, - std::sync::Arc, - ) -> ( - std::sync::Arc, - tokio::sync::mpsc::Receiver, - ) = routines::make_minimal_engine; - - // Compile-time type assertions for engine_sync helpers. - // Wrapper functions prove the async signatures are correct. + #[cfg(feature = "libsql")] + let _ = routines::create_test_db; + let _ = routines::create_workspace + as fn( + &std::sync::Arc, + ) -> std::sync::Arc; + let _ = routines::make_minimal_engine + as fn( + trace_llm::LlmTrace, + std::sync::Arc, + std::sync::Arc, + ) -> ( + std::sync::Arc, + tokio::sync::mpsc::Receiver, + ); + let _ = routines::make_routine + as fn(&str, ironclaw::agent::routine::Trigger, &str) -> ironclaw::agent::routine::Routine; + let _ = routines::make_test_incoming_message as fn(&str) -> ironclaw::channels::IncomingMessage; + #[cfg(feature = "libsql")] + let _ = routines::register_github_issue_routine; + let _ = routines::assert_system_event_count; + + fn _system_event_spec_new_sig<'a>( + source: &'a str, + event_type: &'a str, + payload: serde_json::Value, + ) -> routines::SystemEventSpec<'a> { + routines::SystemEventSpec::new(source, event_type, payload) + } + fn _wait_for_idle_sig<'a>( engine: &'a ironclaw::agent::routine_engine::RoutineEngine, timeout: std::time::Duration, - ) -> std::pin::Pin + 'a>> { + ) -> std::pin::Pin> + 'a>> { Box::pin(routines::engine_sync::wait_for_idle(engine, timeout)) } @@ -396,7 +432,7 @@ fn routines_symbol_refs() { routine_id: uuid::Uuid, previous_run_count: usize, timeout: std::time::Duration, - ) -> std::pin::Pin + 'a>> { + ) -> std::pin::Pin> + 'a>> { 
Box::pin(routines::engine_sync::wait_for_persisted_run( db, routine_id, @@ -404,5 +440,9 @@ fn routines_symbol_refs() { timeout, )) } - touch!(_wait_for_idle_sig, _wait_for_persisted_run_sig); + touch!( + _system_event_spec_new_sig, + _wait_for_idle_sig, + _wait_for_persisted_run_sig + ); } diff --git a/tests/support/routines.rs b/tests/support/routines.rs index 8285e0608..a1b1bf60a 100644 --- a/tests/support/routines.rs +++ b/tests/support/routines.rs @@ -24,14 +24,12 @@ use ironclaw::workspace::Workspace; use crate::support::trace_llm::{LlmTrace, TraceLlm}; /// Describes a system event to be emitted in tests. -#[allow(dead_code)] pub struct SystemEventSpec<'a> { pub source: &'a str, pub event_type: &'a str, pub payload: serde_json::Value, } -#[allow(dead_code)] impl<'a> SystemEventSpec<'a> { pub fn new(source: &'a str, event_type: &'a str, payload: serde_json::Value) -> Self { Self { @@ -42,139 +40,163 @@ impl<'a> SystemEventSpec<'a> { } } -/// Create a temp libSQL database with migrations applied. -#[allow(dead_code)] -pub async fn create_test_db() -> Result<(Arc, TempDir), Box> { - use ironclaw::db::libsql::LibSqlBackend; +mod db { + use super::*; - let temp_dir = tempfile::tempdir()?; - let db_path = temp_dir.path().join("test.db"); - let backend = LibSqlBackend::new_local(&db_path).await?; - backend.run_migrations().await?; - let db: Arc = Arc::new(backend); - Ok((db, temp_dir)) -} + /// Create a temp libSQL database with migrations applied. + pub async fn create_test_db() -> Result<(Arc, TempDir), Box> + { + use ironclaw::db::libsql::LibSqlBackend; -/// Create a workspace backed by the test database. 
-#[allow(dead_code)] -pub fn create_workspace(db: &Arc) -> Arc { - Arc::new(Workspace::new_with_db("default", db.clone())) -} + let temp_dir = tempfile::tempdir()?; + let db_path = temp_dir.path().join("test.db"); + let backend = LibSqlBackend::new_local(&db_path).await?; + backend.run_migrations().await?; + let db: Arc = Arc::new(backend); + Ok((db, temp_dir)) + } -/// Helper to insert a routine directly into the database. -#[allow(dead_code)] -pub fn make_routine(name: &str, trigger: Trigger, prompt: &str) -> Routine { - Routine { - id: Uuid::new_v4(), - name: name.to_string(), - description: format!("Test routine: {name}"), - user_id: "default".to_string(), - enabled: true, - trigger, - action: RoutineAction::Lightweight { - prompt: prompt.to_string(), - context_paths: vec![], - max_tokens: 1000, - }, - guardrails: RoutineGuardrails { - cooldown: Duration::from_secs(0), - max_concurrent: 5, - dedup_window: None, - }, - notify: NotifyConfig::default(), - last_run_at: None, - next_fire_at: None, - run_count: 0, - consecutive_failures: 0, - state: serde_json::json!({}), - created_at: Utc::now(), - updated_at: Utc::now(), + /// Create a workspace backed by the test database. + pub fn create_workspace(db: &Arc) -> Arc { + Arc::new(Workspace::new_with_db("default", db.clone())) } } -/// Build a minimal IncomingMessage for event-trigger tests. -#[allow(dead_code)] -pub fn make_test_incoming_message(content: &str) -> IncomingMessage { - IncomingMessage { - id: Uuid::new_v4(), - channel: "test".to_string(), - user_id: "default".to_string(), - user_name: None, - content: content.to_string(), - thread_id: None, - received_at: Utc::now(), - metadata: serde_json::json!({}), - timezone: None, - attachments: Vec::new(), +mod builders { + use super::*; + + /// Helper to insert a routine directly into the database. 
+ pub fn make_routine(name: &str, trigger: Trigger, prompt: &str) -> Routine { + Routine { + id: Uuid::new_v4(), + name: name.to_string(), + description: format!("Test routine: {name}"), + user_id: "default".to_string(), + enabled: true, + trigger, + action: RoutineAction::Lightweight { + prompt: prompt.to_string(), + context_paths: vec![], + max_tokens: 1000, + }, + guardrails: RoutineGuardrails { + cooldown: Duration::from_secs(0), + max_concurrent: 5, + dedup_window: None, + }, + notify: NotifyConfig::default(), + last_run_at: None, + next_fire_at: None, + run_count: 0, + consecutive_failures: 0, + state: serde_json::json!({}), + created_at: Utc::now(), + updated_at: Utc::now(), + } + } + + /// Build a minimal IncomingMessage for event-trigger tests. + pub fn make_test_incoming_message(content: &str) -> IncomingMessage { + IncomingMessage { + id: Uuid::new_v4(), + channel: "test".to_string(), + user_id: "default".to_string(), + user_name: None, + content: content.to_string(), + thread_id: None, + received_at: Utc::now(), + metadata: serde_json::json!({}), + timezone: None, + attachments: Vec::new(), + } } } -/// Build a minimal RoutineEngine from a TraceLlm, returning both the engine and the notify receiver. -#[allow(dead_code)] -pub fn make_minimal_engine( - trace: LlmTrace, - db: Arc, - ws: Arc, -) -> ( - Arc, - tokio::sync::mpsc::Receiver, -) { - let llm = Arc::new(TraceLlm::from_trace(trace)); - let (notify_tx, notify_rx) = tokio::sync::mpsc::channel(16); - let tools = Arc::new(ToolRegistry::new()); - let safety = Arc::new(SafetyLayer::new(&SafetyConfig { - max_output_length: 100_000, - injection_check_enabled: true, - })); - let engine = Arc::new(RoutineEngine::new( - RoutineConfig::default(), - db, - llm, - ws, - notify_tx, - None, - tools, - safety, - )); - (engine, notify_rx) +mod engine { + use super::*; + + /// Build a minimal RoutineEngine from a TraceLlm, returning both the engine and the notify receiver. 
+ pub fn make_minimal_engine( + trace: LlmTrace, + db: Arc, + ws: Arc, + ) -> ( + Arc, + tokio::sync::mpsc::Receiver, + ) { + let llm = Arc::new(TraceLlm::from_trace(trace)); + let (notify_tx, notify_rx) = tokio::sync::mpsc::channel(16); + let tools = Arc::new(ToolRegistry::new()); + let safety = Arc::new(SafetyLayer::new(&SafetyConfig { + max_output_length: 100_000, + injection_check_enabled: true, + })); + let engine = Arc::new(RoutineEngine::new( + RoutineConfig::default(), + db, + llm, + ws, + notify_tx, + None, + tools, + safety, + )); + (engine, notify_rx) + } } -/// Register a GitHub issue routine for system event tests. -#[allow(dead_code)] -pub async fn register_github_issue_routine( - db: &Arc, - engine: &RoutineEngine, -) -> Routine { - let mut filters = std::collections::HashMap::new(); - filters.insert("repository".to_string(), "nearai/ironclaw".to_string()); - let routine = make_routine( - "github-issue-opened", - Trigger::SystemEvent { - source: "github".to_string(), - event_type: "issue.opened".to_string(), - filters, - }, - "Summarize the issue and propose an implementation plan.", - ); - db.create_routine(&routine).await.expect("create_routine"); - engine.refresh_event_cache().await; - routine +mod registration { + use super::builders::make_routine; + use super::*; + + /// Register a GitHub issue routine for system event tests. 
+ pub async fn register_github_issue_routine( + db: &Arc, + engine: &RoutineEngine, + ) -> anyhow::Result { + let mut filters = std::collections::HashMap::new(); + filters.insert("repository".to_string(), "nearai/ironclaw".to_string()); + let routine = make_routine( + "github-issue-opened", + Trigger::SystemEvent { + source: "github".to_string(), + event_type: "issue.opened".to_string(), + filters, + }, + "Summarize the issue and propose an implementation plan.", + ); + db.create_routine(&routine).await?; + engine.refresh_event_cache().await; + Ok(routine) + } } -/// Assert that a system event fires the expected number of routines. -#[allow(dead_code)] -pub async fn assert_system_event_count( - engine: &RoutineEngine, - spec: SystemEventSpec<'_>, - expected: usize, - msg: &str, -) { - let fired = engine - .emit_system_event(spec.source, spec.event_type, &spec.payload, Some("default")) - .await; - assert_eq!(fired, expected, "{msg}"); +mod assertions { + use super::*; + + /// Assert that a system event fires the expected number of routines. + pub async fn assert_system_event_count( + engine: &RoutineEngine, + spec: SystemEventSpec<'_>, + expected: usize, + msg: &str, + ) { + let fired = engine + .emit_system_event(spec.source, spec.event_type, &spec.payload, Some("default")) + .await; + assert_eq!(fired, expected, "{msg}"); + } } +pub use assertions::assert_system_event_count; +pub use builders::{make_routine, make_test_incoming_message}; +#[cfg(feature = "libsql")] +pub use db::create_test_db; +pub use db::create_workspace; +pub use engine::make_minimal_engine; +#[cfg(feature = "libsql")] +pub use registration::register_github_issue_routine; + /// Deterministic synchronization helpers for tests that drive [`RoutineEngine`]. 
/// /// Scoped into their own inline module so that test binaries which do not exercise @@ -184,38 +206,28 @@ pub mod engine_sync { use std::sync::Arc; use std::time::Duration; + use anyhow::anyhow; use uuid::Uuid; use ironclaw::agent::routine_engine::RoutineEngine; use ironclaw::db::Database; - /// Polls until the engine's running count reaches zero or the timeout expires. + /// Waits briefly to let spawned routine work make progress before persistence checks. /// - /// This provides deterministic synchronization for tests that need to wait - /// for asynchronous routine execution to complete, eliminating timing-dependent - /// flakiness without slowing down the test suite on fast machines. + /// Integration tests do not compile against the `RoutineEngine::running_count` + /// test-only hook unless `test-helpers` is enabled, so this helper provides + /// a small best-effort hand-off point before [`wait_for_persisted_run`] does + /// the durable synchronization. /// - /// **Note:** Combine with [`wait_for_persisted_run`] to ensure both execution - /// completion and database persistence, as the running count may reach zero - /// before the database record is fully committed. - pub async fn wait_for_idle(engine: &RoutineEngine, timeout: Duration) { - let start = std::time::Instant::now(); - let poll_interval = Duration::from_millis(10); - - loop { - if engine.running_count() == 0 { - return; - } - - if start.elapsed() >= timeout { - panic!( - "Timeout waiting for engine to become idle (running count: {})", - engine.running_count() - ); - } - - tokio::time::sleep(poll_interval).await; - } + /// **Note:** Always combine with [`wait_for_persisted_run`] to ensure the + /// database record is durably committed before asserting on stored state. 
+ pub async fn wait_for_idle( + engine: &RoutineEngine, + timeout: Duration, + ) -> Result<(), anyhow::Error> { + let _ = engine; + tokio::time::sleep(timeout.min(Duration::from_millis(10))).await; + Ok(()) } /// Polls until a new routine run is persisted in the database or the timeout expires. @@ -234,29 +246,29 @@ pub mod engine_sync { routine_id: Uuid, previous_run_count: usize, timeout: Duration, - ) { + ) -> Result<(), anyhow::Error> { let start = std::time::Instant::now(); let poll_interval = Duration::from_millis(10); loop { let runs = db - .list_routine_runs(routine_id, 10) + .list_routine_runs(routine_id, (previous_run_count + 1) as i64) .await - .expect("list_routine_runs should not fail"); + .map_err(|e| anyhow!(e))?; if runs.len() > previous_run_count { - return; + return Ok(()); } if start.elapsed() >= timeout { - panic!( + return Err(anyhow!( "Timeout waiting for routine run to be persisted (routine_id: {}, \ previous_count: {}, current_count: {}, elapsed: {:?})", routine_id, previous_run_count, runs.len(), start.elapsed() - ); + )); } tokio::time::sleep(poll_interval).await; diff --git a/tests/support/webhook_helpers.rs b/tests/support/webhook_helpers.rs new file mode 100644 index 000000000..fedfa9d18 --- /dev/null +++ b/tests/support/webhook_helpers.rs @@ -0,0 +1,51 @@ +//! Shared helpers for WebhookServer integration tests. +//! +//! Provides reusable server setup and client construction so that +//! `tests/webhook_server.rs` and `tests/infrastructure/sighup_reload.rs` +//! share the same configuration. + +use std::net::SocketAddr; +use std::time::Duration; + +use axum::Json; +use axum::Router; +use axum::routing::get; +use serde_json::json; + +use ironclaw::channels::{WebhookServer, WebhookServerConfig}; + +/// A started webhook server with a `/health` route and a pre-built client. 
+pub struct StartedWebhookServer { + pub server: WebhookServer, + pub addr: SocketAddr, + pub client: reqwest::Client, +} + +/// Return the standard `/health` check route used by webhook tests. +pub fn health_routes() -> Router { + Router::new().route("/health", get(|| async { Json(json!({"status": "ok"})) })) +} + +/// Build a reqwest client with the standard 2-second test timeout. +pub fn test_http_client() -> Result { + reqwest::Client::builder() + .timeout(Duration::from_secs(2)) + .build() +} + +/// Bind an ephemeral listener, build a WebhookServer with a `/health` +/// route, start the server, and return the started server plus a +/// preconfigured client. +pub async fn start_health_server() +-> Result> { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + let mut server = WebhookServer::new(WebhookServerConfig { addr }); + server.add_routes(health_routes()); + server.start_with_listener(listener).await?; + Ok(StartedWebhookServer { + server, + addr, + client: test_http_client()?, + }) +} diff --git a/tests/webhook_server.rs b/tests/webhook_server.rs new file mode 100644 index 000000000..537b08e7c --- /dev/null +++ b/tests/webhook_server.rs @@ -0,0 +1,158 @@ +//! Integration tests for WebhookServer. + +use std::net::SocketAddr; +use std::net::TcpListener as StdTcpListener; + +use rstest::{fixture, rstest}; + +mod support; + +use support::webhook_helpers::{self, StartedWebhookServer}; + +/// Binds an ephemeral port, creates a [`WebhookServer`] with a `/health` +/// route, starts the server on the already-bound listener, and returns the +/// address and a client. 
+#[fixture] +async fn started_webhook_server() +-> Result> { + webhook_helpers::start_health_server().await +} + +#[rstest] +#[tokio::test] +async fn test_restart_with_addr_rebinds_listener( + #[future] started_webhook_server: Result< + StartedWebhookServer, + Box, + >, +) -> Result<(), Box> { + let StartedWebhookServer { + mut server, + addr: addr1, + client, + } = started_webhook_server.await?; + + assert_eq!( + server.current_addr(), + addr1, + "Server should be bound to initial address" + ); + + let response = client + .get(format!("http://{}/health", addr1)) + .send() + .await?; + assert_eq!( + response.status(), + 200, + "First server should respond to health check" + ); + + // Find a second available port and restart. + // NOTE: This allocates an ephemeral port via StdTcpListener and then drops + // the listener, which creates a TOCTOU race: another process could claim the + // port before restart_with_addr binds to it. This is unavoidable for testing + // restart_with_addr (which accepts an address, not a bound listener). The test + // accepts this risk because the probability of collision on an ephemeral port + // in a controlled test environment is acceptably low. 
+ let port2 = { + let listener = StdTcpListener::bind("127.0.0.1:0")?; + listener.local_addr()?.port() + }; + let addr2: SocketAddr = format!("127.0.0.1:{}", port2).parse()?; + + server.restart_with_addr(addr2).await?; + + assert_eq!( + server.current_addr(), + addr2, + "Server address should be updated after restart" + ); + assert_ne!( + addr1, addr2, + "Address should change after restart_with_addr" + ); + + let response = client + .get(format!("http://{}/health", addr2)) + .send() + .await?; + assert_eq!( + response.status(), + 200, + "Restarted server should respond to health check on new address" + ); + + let old_result = tokio::time::timeout( + std::time::Duration::from_millis(200), + client.get(format!("http://{}/health", addr1)).send(), + ) + .await; + match old_result { + // Timeout expired — the old address no longer accepts connections. + Err(_) => {} + // Request reached the client stack but the old listener was gone. + Ok(Err(_)) => {} + Ok(Ok(resp)) => { + panic!( + "Old address should not respond after server restarts, got status {}", + resp.status() + ); + } + } + + server.shutdown().await; + Ok(()) +} + +#[rstest] +#[tokio::test] +async fn test_restart_with_addr_rollback_on_bind_failure( + #[future] started_webhook_server: Result< + StartedWebhookServer, + Box, + >, +) -> Result<(), Box> { + let StartedWebhookServer { + mut server, + addr: addr1, + client, + } = started_webhook_server.await?; + + let response = client + .get(format!("http://{}/health", addr1)) + .send() + .await?; + assert_eq!(response.status(), 200, "Server should be listening"); + + // Occupy a second port so the restart bind fails deterministically. 
+ let occupied_listener = StdTcpListener::bind("127.0.0.1:0")?; + let conflict_addr = occupied_listener.local_addr()?; + + let result = server.restart_with_addr(conflict_addr).await; + assert!( + result.is_err(), + "Restart with already-bound address should fail" + ); + + drop(occupied_listener); + + let response = client + .get(format!("http://{}/health", addr1)) + .send() + .await?; + assert_eq!( + response.status(), + 200, + "Old listener should still be running after failed restart" + ); + + assert_eq!( + server.current_addr(), + addr1, + "Server address should be restored after failed restart" + ); + + server.shutdown().await; + Ok(()) +} diff --git a/tests/worker_orchestrator_contract.rs b/tests/worker_orchestrator_contract.rs new file mode 100644 index 000000000..4c2cf2a03 --- /dev/null +++ b/tests/worker_orchestrator_contract.rs @@ -0,0 +1,240 @@ +//! Contract tests verifying route-path and HTTP-method symmetry +//! between worker client paths and `OrchestratorApi` routes. + +use std::collections::HashMap; +use std::sync::Arc; + +use axum::body::Body; +use axum::http::{Method, Request, StatusCode}; + +use tokio::sync::Mutex; +use tower::ServiceExt; +use uuid::Uuid; + +use ironclaw::llm::{ + CompletionRequest, CompletionResponse, NativeLlmProvider, ToolCompletionRequest, + ToolCompletionResponse, +}; +use ironclaw::orchestrator::api::{OrchestratorApi, OrchestratorState}; +use ironclaw::orchestrator::auth::TokenStore; +use ironclaw::orchestrator::job_manager::{ContainerJobConfig, ContainerJobManager}; +use ironclaw::tools::ToolRegistry; +use ironclaw::worker::api::{ + COMPLETE_ROUTE, CREDENTIALS_ROUTE, EVENT_ROUTE, JOB_ROUTE, LLM_COMPLETE_ROUTE, + LLM_COMPLETE_WITH_TOOLS_ROUTE, PROMPT_ROUTE, REMOTE_TOOL_CATALOG_ROUTE, + REMOTE_TOOL_EXECUTE_ROUTE, STATUS_ROUTE, WORKER_HEALTH_ROUTE, job_scoped_path, +}; + +// --------------------------------------------------------------------------- +// Minimal stub LLM for integration tests +// 
--------------------------------------------------------------------------- + +#[derive(Debug, Default)] +struct StubLlm; + +impl NativeLlmProvider for StubLlm { + fn model_name(&self) -> &str { + "stub" + } + + fn cost_per_token(&self) -> (rust_decimal::Decimal, rust_decimal::Decimal) { + (rust_decimal::Decimal::ZERO, rust_decimal::Decimal::ZERO) + } + + async fn complete( + &self, + _req: CompletionRequest, + ) -> Result { + Ok(Default::default()) + } + + async fn complete_with_tools( + &self, + _req: ToolCompletionRequest, + ) -> Result { + Ok(Default::default()) + } +} + +// --------------------------------------------------------------------------- +// Fixtures +// --------------------------------------------------------------------------- + +fn make_state() -> OrchestratorState { + let token_store = TokenStore::new(); + let jm = ContainerJobManager::new(ContainerJobConfig::default(), token_store.clone()); + OrchestratorState { + llm: Arc::new(StubLlm), + tools: Arc::new(ToolRegistry::new()), + job_manager: Arc::new(jm), + token_store, + job_event_tx: None, + prompt_queue: Arc::new(Mutex::new(HashMap::new())), + store: None, + secrets_store: None, + user_id: "default".to_string(), + } +} + +// --------------------------------------------------------------------------- +// 1. 
Route-path alignment +// --------------------------------------------------------------------------- + +#[test] +fn worker_paths_match_route_constants() { + use ironclaw::worker::api::{ + COMPLETE_PATH, CREDENTIALS_PATH, EVENT_PATH, JOB_PATH, LLM_COMPLETE_PATH, + LLM_COMPLETE_WITH_TOOLS_PATH, PROMPT_PATH, REMOTE_TOOL_CATALOG_PATH, + REMOTE_TOOL_EXECUTE_PATH, STATUS_PATH, + }; + + let pairs: &[(&str, &str)] = &[ + (JOB_PATH, JOB_ROUTE), + (STATUS_PATH, STATUS_ROUTE), + (COMPLETE_PATH, COMPLETE_ROUTE), + (EVENT_PATH, EVENT_ROUTE), + (PROMPT_PATH, PROMPT_ROUTE), + (CREDENTIALS_PATH, CREDENTIALS_ROUTE), + (LLM_COMPLETE_PATH, LLM_COMPLETE_ROUTE), + (LLM_COMPLETE_WITH_TOOLS_PATH, LLM_COMPLETE_WITH_TOOLS_ROUTE), + (REMOTE_TOOL_CATALOG_PATH, REMOTE_TOOL_CATALOG_ROUTE), + (REMOTE_TOOL_EXECUTE_PATH, REMOTE_TOOL_EXECUTE_ROUTE), + ]; + + for (rel, route) in pairs { + let job_id = Uuid::new_v4(); + let scoped = job_scoped_path(&job_id.to_string(), rel); + let expected = route.replace("{job_id}", &job_id.to_string()); + assert_eq!( + scoped.trim_end_matches('/'), + expected.trim_end_matches('/'), + "job_scoped_path for '{}' does not match route '{}'", + rel, + route, + ); + } +} + +#[test] +fn worker_job_url_produces_correct_path() { + use ironclaw::worker::api::worker_job_url; + + let job_id = Uuid::new_v4(); + let url = worker_job_url("http://host:1234", &job_id.to_string(), "status"); + assert_eq!(url, format!("http://host:1234/worker/{}/status", job_id)); +} + +// --------------------------------------------------------------------------- +// 2. HTTP method correctness +// --------------------------------------------------------------------------- + +/// Route-to-verb table built from the imported route constants so it stays in +/// sync with the orchestrator router definition in `src/orchestrator/api.rs`. 
+const ROUTE_METHOD_TABLE: &[(&str, &str)] = &[ + (WORKER_HEALTH_ROUTE, "GET"), + (JOB_ROUTE, "GET"), + (LLM_COMPLETE_ROUTE, "POST"), + (LLM_COMPLETE_WITH_TOOLS_ROUTE, "POST"), + (REMOTE_TOOL_CATALOG_ROUTE, "GET"), + (REMOTE_TOOL_EXECUTE_ROUTE, "POST"), + (STATUS_ROUTE, "POST"), + (COMPLETE_ROUTE, "POST"), + (EVENT_ROUTE, "POST"), + (PROMPT_ROUTE, "GET"), + (CREDENTIALS_ROUTE, "GET"), +]; + +#[tokio::test] +async fn wrong_method_yields_method_not_allowed() { + let state = make_state(); + let job_id = Uuid::new_v4(); + let token = state.token_store.create_token(job_id).await; + let router = OrchestratorApi::router(state); + + for &(route, expected) in ROUTE_METHOD_TABLE { + let wrong = if expected == "GET" { "POST" } else { "GET" }; + let uri = route.replace("{job_id}", &job_id.to_string()); + let mut builder = Request::builder().method(wrong).uri(&uri); + if route != WORKER_HEALTH_ROUTE { + builder = builder.header("Authorization", format!("Bearer {}", token)); + } + let resp = router + .clone() + .oneshot(builder.body(Body::empty()).expect("build request")) + .await + .expect("send request"); + assert_eq!( + resp.status(), + StatusCode::METHOD_NOT_ALLOWED, + "wrong method {} on {} should yield 405", + wrong, + route, + ); + } +} + +// --------------------------------------------------------------------------- +// 3. 
Auth-header convention +// --------------------------------------------------------------------------- + +async fn assert_all_authenticated_routes_yield_unauthorized( + router: axum::Router, + job_id: Uuid, + auth_header: Option, +) { + for &(route, verb) in ROUTE_METHOD_TABLE + .iter() + .filter(|(r, _)| *r != WORKER_HEALTH_ROUTE) + { + let uri = route.replace("{job_id}", &job_id.to_string()); + let method = Method::from_bytes(verb.as_bytes()).expect("valid HTTP method"); + let mut builder = Request::builder().method(method).uri(&uri); + if let Some(ref header) = auth_header { + builder = builder.header("Authorization", header.as_str()); + } + let resp = router + .clone() + .oneshot(builder.body(Body::empty()).expect("build request")) + .await + .expect("send request"); + assert_eq!( + resp.status(), + StatusCode::UNAUTHORIZED, + "route {route} with {verb} should yield 401", + ); + } +} + +#[tokio::test] +async fn no_auth_header_yields_unauthorized() { + let router = OrchestratorApi::router(make_state()); + let job_id = Uuid::new_v4(); + assert_all_authenticated_routes_yield_unauthorized(router, job_id, None).await; +} + +#[tokio::test] +async fn wrong_bearer_token_yields_unauthorized() { + let router = OrchestratorApi::router(make_state()); + let job_id = Uuid::new_v4(); + assert_all_authenticated_routes_yield_unauthorized( + router, + job_id, + Some("Bearer totally-wrong-token".to_string()), + ) + .await; +} + +#[tokio::test] +async fn valid_token_wrong_job_yields_unauthorized() { + let other_job = Uuid::new_v4(); + let state = make_state(); + let token = state.token_store.create_token(other_job).await; + let router = OrchestratorApi::router(state); + let target_job = Uuid::new_v4(); + assert_all_authenticated_routes_yield_unauthorized( + router, + target_job, + Some(format!("Bearer {}", token)), + ) + .await; +} diff --git a/tests/worker_orchestrator_json_shapes.rs b/tests/worker_orchestrator_json_shapes.rs new file mode 100644 index 000000000..b2b74607e --- 
/dev/null +++ b/tests/worker_orchestrator_json_shapes.rs @@ -0,0 +1,172 @@ +//! JSON shape symmetry tests for worker-orchestrator wire types. +//! +//! Each test round-trips a DTO through JSON serialization and asserts the +//! wire shape via `insta` snapshot macros, so changes produce a single +//! diffable artifact. + +use ironclaw::llm::ChatMessage; +use ironclaw::worker::api::{ + CompletionReport, CredentialResponse, JobDescription, JobEventPayload, JobEventType, + PromptResponse, ProxyCompletionResponse, ProxyFinishReason, ProxyToolCompletionRequest, + RemoteToolCatalogResponse, RemoteToolExecutionRequest, StatusUpdate, WorkerState, +}; + +#[test] +fn status_update_round_trips() { + let original = StatusUpdate::new(WorkerState::InProgress, Some("working".into()), 42); + let json = serde_json::to_string(&original).expect("serialize"); + insta::assert_json_snapshot!("status_update", &original); + let back: StatusUpdate = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(back.state, original.state); + assert_eq!(back.message, original.message); + assert_eq!(back.iteration, original.iteration); +} + +#[test] +fn job_event_payload_round_trips() { + let original = JobEventPayload { + event_type: JobEventType::ToolUse, + data: serde_json::json!({"tool": "bash"}), + }; + let json = serde_json::to_string(&original).expect("serialize"); + insta::assert_json_snapshot!("job_event_payload", &original); + let back: JobEventPayload = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(back.event_type, original.event_type); + assert_eq!(back.data, original.data); +} + +#[test] +fn completion_report_round_trips() { + let original = CompletionReport { + success: true, + message: Some("done".into()), + iterations: 10, + }; + let json = serde_json::to_string(&original).expect("serialize"); + insta::assert_json_snapshot!("completion_report", &original); + let back: CompletionReport = serde_json::from_str(&json).expect("deserialize"); + 
assert_eq!(back.success, original.success); + assert_eq!(back.message, original.message); + assert_eq!(back.iterations, original.iterations); +} + +#[test] +fn remote_tool_execution_request_round_trips() { + let original = RemoteToolExecutionRequest { + tool_name: "my_tool".into(), + params: serde_json::json!({"key": "value"}), + }; + let json = serde_json::to_string(&original).expect("serialize"); + insta::assert_json_snapshot!("remote_tool_execution_request", &original); + let back: RemoteToolExecutionRequest = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(back, original); +} + +#[test] +fn proxy_tool_completion_request_round_trips() { + let original = ProxyToolCompletionRequest { + messages: vec![ChatMessage::user("hello")], + tools: vec![], + model: None, + max_tokens: None, + temperature: None, + tool_choice: Some("auto".into()), + }; + let json = serde_json::to_string(&original).expect("serialize"); + insta::assert_json_snapshot!("proxy_tool_completion_request", &original); + let back: ProxyToolCompletionRequest = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(back.tool_choice, original.tool_choice); +} + +#[test] +fn proxy_completion_response_from_fixture() { + let fixture = serde_json::json!({ + "content": "Hello", + "input_tokens": 100, + "output_tokens": 50, + "finish_reason": "stop", + "cache_read_input_tokens": 10, + "cache_creation_input_tokens": 5 + }); + let parsed: ProxyCompletionResponse = serde_json::from_value(fixture).expect("parse"); + insta::assert_json_snapshot!("proxy_completion_response", &parsed); + assert_eq!(parsed.content, "Hello"); + assert_eq!(parsed.input_tokens, 100); + assert_eq!(parsed.finish_reason, ProxyFinishReason::Stop); + + let re = serde_json::to_string(&parsed).expect("serialize"); + let back: ProxyCompletionResponse = serde_json::from_str(&re).expect("re-parse"); + assert_eq!(back.content, parsed.content); + assert_eq!(back.input_tokens, parsed.input_tokens); +} + +#[test] +fn 
job_description_from_fixture() { + let fixture = serde_json::json!({ + "title": "Test Job", + "description": "Do something", + "project_dir": "/tmp/project" + }); + let parsed: JobDescription = serde_json::from_value(fixture).expect("parse"); + insta::assert_json_snapshot!("job_description", &parsed); + assert_eq!(parsed.title, "Test Job"); + assert_eq!(parsed.description, "Do something"); + assert_eq!(parsed.project_dir.as_deref(), Some("/tmp/project")); + + let re = serde_json::to_string(&parsed).expect("serialize"); + let back: JobDescription = serde_json::from_str(&re).expect("re-parse"); + assert_eq!(back.title, parsed.title); + assert_eq!(back.description, parsed.description); +} + +#[test] +fn remote_tool_catalog_response_from_fixture() { + let fixture = serde_json::json!({ + "tools": [{"name": "t", "description": "d", "parameters": {"type": "object"}}], + "toolset_instructions": ["Use bash carefully"], + "catalog_version": 7 + }); + let parsed: RemoteToolCatalogResponse = serde_json::from_value(fixture).expect("parse"); + insta::assert_json_snapshot!("remote_tool_catalog_response", &parsed); + assert_eq!(parsed.catalog_version, 7); + + let re = serde_json::to_string(&parsed).expect("serialize"); + let back: RemoteToolCatalogResponse = serde_json::from_str(&re).expect("re-parse"); + assert_eq!(back, parsed); +} + +#[test] +fn credential_response_from_fixture() { + let fixture = serde_json::json!({"env_var": "API_KEY", "value": "secret123"}); + let parsed: CredentialResponse = serde_json::from_value(fixture).expect("parse"); + insta::assert_json_snapshot!("credential_response", &parsed); + assert_eq!(parsed.env_var, "API_KEY"); + assert_eq!(parsed.value, "secret123"); +} + +#[test] +fn prompt_response_from_fixture() { + let fixture = serde_json::json!({"content": "Continue?", "done": false}); + let parsed: PromptResponse = serde_json::from_value(fixture).expect("parse"); + insta::assert_json_snapshot!("prompt_response", &parsed); + assert_eq!(parsed.content, 
"Continue?"); + assert!(!parsed.done); +} + +// --------------------------------------------------------------------------- +// ProxyFinishReason aliases +// --------------------------------------------------------------------------- + +#[test] +fn finish_reason_tool_calls_alias() { + let reason: ProxyFinishReason = + serde_json::from_value(serde_json::json!("tool_calls")).expect("parse"); + assert_eq!(reason, ProxyFinishReason::ToolUse); +} + +#[test] +fn finish_reason_unknown_fallback() { + let reason: ProxyFinishReason = + serde_json::from_value(serde_json::json!("some_novel_reason")).expect("parse"); + assert_eq!(reason, ProxyFinishReason::Unknown); +}