diff --git a/.github/workflows/test-integration-docker.yml b/.github/workflows/test-integration-docker.yml index 84e7349..5fb3b4c 100644 --- a/.github/workflows/test-integration-docker.yml +++ b/.github/workflows/test-integration-docker.yml @@ -1,25 +1,25 @@ name: Integration Tests with Docker Mixnet on: - push: - branches: [ main, master ] pull_request: branches: [ main, master ] jobs: test-integration-docker: runs-on: ubuntu-latest - + timeout-minutes: 90 + steps: - name: Checkout thinclient repository uses: actions/checkout@v4 with: path: thinclient - - name: Checkout katzenpost repository + - name: Checkout katzenpost repository uses: actions/checkout@v4 with: repository: katzenpost/katzenpost + ref: 5d627b4bccc4e12d07d28710fb3af905ba083994 path: katzenpost - name: Set up Docker Buildx @@ -58,17 +58,19 @@ jobs: cd katzenpost/docker && make start wait - name: Brief pause to ensure mixnet is fully ready - run: sleep 5 + run: sleep 30 - name: Run all Python tests (including channel API integration tests) + timeout-minutes: 45 run: | cd thinclient - python -m pytest tests/ -vvv -s --tb=short + python -m pytest tests/ -vvv -s --tb=short --timeout=1200 - name: Run Rust integration tests + timeout-minutes: 45 run: | cd thinclient - cargo test --test '*' -- --nocapture + cargo test --test '*' -- --nocapture --test-threads=1 - name: Stop the mixnet if: always() diff --git a/.gitignore b/.gitignore index 0935fec..d4b7173 100644 --- a/.gitignore +++ b/.gitignore @@ -1,17 +1,2 @@ __pycache__/ - /target - - -# Added by cargo -# -# already existing elements were commented out - -#/target - - -# Added by cargo -# -# already existing elements were commented out - -#/target diff --git a/Cargo.lock b/Cargo.lock index be34c12..1cfd9bc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -76,6 +76,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + [[package]] name = "autocfg" version = "1.4.0" @@ -97,6 +103,12 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bitflags" version = "2.9.0" @@ -121,24 +133,98 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bumpalo" +version = "3.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" + [[package]] name = "bytes" version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +[[package]] +name = "cc" +version = "1.2.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" +dependencies = [ + "find-msvc-tools", + "shlex", +] + [[package]] name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "clap" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.55" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" + [[package]] name = "colorchoice" version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" +[[package]] +name = "crdts" +version = "7.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "387808c885b79055facbd4b2e806a683fe1bc37abc7dfa5fea1974ad2d4137b0" +dependencies = [ + "num", + "quickcheck", + "serde", + "tiny-keccak", +] + +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + [[package]] name = "crypto-common" version = "0.1.6" @@ -189,6 +275,36 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + [[package]] name = "generic-array" version = "0.14.7" @@ -210,6 +326,20 @@ dependencies = [ "wasi", ] +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "rand_core 0.10.0", + "wasip2", + "wasip3", +] + [[package]] name = "gimli" version = "0.31.1" @@ -227,6 +357,33 @@ name = "hashbrown" version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +dependencies = [ + "foldhash 0.1.5", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +dependencies = [ + "foldhash 0.2.0", +] + +[[package]] +name = "hashlink" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea0b22561a9c04a7cb1a302c013e0259cd3b4bb619f145b32f72b8b4bcbed230" +dependencies = [ + "hashbrown 0.16.1", +] + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "hex" @@ -234,6 +391,12 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + [[package]] name = "indexmap" version = "2.9.0" @@ -241,7 +404,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.15.2", + "serde", ] [[package]] @@ -280,18 +444,33 @@ dependencies = [ "syn", ] +[[package]] +name = "js-sys" +version = "0.3.91" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b49715b7073f385ba4bc528e5747d02e66cb39c6146efb66b781f131f0fb399c" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + [[package]] name = "katzenpost_thin_client" version = "0.0.11" dependencies = [ + "base64", "blake2", + "clap", + "crdts", "env_logger", "generic-array", "hex", "libc", "log", - "rand", + "rand 0.8.5", + "rusqlite", "serde", + "serde_bytes", "serde_cbor", "serde_json", "tokio", @@ -299,12 +478,29 @@ dependencies = [ "typenum", ] +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + [[package]] name = "libc" version = "0.2.171" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6" +[[package]] +name = "libsqlite3-sys" +version = "0.36.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95b4103cffefa72eb8428cb6b47d6627161e51c2739fc5e3b734584157bc642a" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "lock_api" version = "0.4.12" @@ -347,6 +543,82 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + 
"num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", + "serde", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", + "serde", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", + "serde", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "object" version = "0.36.7" @@ -391,6 +663,12 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + [[package]] name = "portable-atomic" version = "1.11.0" @@ -415,6 +693,16 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + [[package]] name = "proc-macro2" version = "1.0.94" @@ -424,6 +712,17 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "quickcheck" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95c589f335db0f6aaa168a7cd27b1fc6920f5e1470c804f814d9cd6e62a0f70b" +dependencies = [ + "env_logger", + "log", + "rand 0.10.0", +] + [[package]] name = "quote" version = "1.0.40" @@ -433,6 +732,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + [[package]] name = "rand" version = "0.8.5" @@ -441,7 +746,17 @@ checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", "rand_chacha", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc266eb313df6c5c09c1c7b1fbe2510961e5bcd3add930c1e31f7ed9da0feff8" +dependencies = [ + "getrandom 0.4.2", + "rand_core 0.10.0", ] [[package]] @@ -451,7 +766,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", ] [[package]] @@ -460,9 +775,15 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.15", ] +[[package]] +name = "rand_core" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c8d0fd677905edcbeedbf2edb6494d676f0e98d54d5cf9bda0b061cb8fb8aba" + [[package]] name = "redox_syscall" version = "0.5.10" @@ -501,12 +822,43 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +[[package]] +name = "rsqlite-vfs" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8a1f2315036ef6b1fbacd1972e8ee7688030b0a2121edfc2a6550febd41574d" +dependencies = [ + "hashbrown 0.16.1", + "thiserror", +] + +[[package]] +name = "rusqlite" +version = "0.38.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1c93dd1c9683b438c392c492109cb702b8090b2bfc8fed6f6e4eb4523f17af3" +dependencies = [ + "bitflags", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", + "sqlite-wasm-rs", +] + [[package]] name = "rustc-demangle" version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + [[package]] name = "ryu" version = "1.0.20" @@ -519,15 +871,32 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" 
+ [[package]] name = "serde" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" dependencies = [ + "serde_core", "serde_derive", ] +[[package]] +name = "serde_bytes" +version = "0.11.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5d440709e79d88e51ac01c4b72fc6cb7314017bb7da9eeff678aa94c10e3ea8" +dependencies = [ + "serde", + "serde_core", +] + [[package]] name = "serde_cbor" version = "0.11.2" @@ -538,11 +907,20 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + [[package]] name = "serde_derive" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", @@ -570,6 +948,12 @@ dependencies = [ "serde", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "signal-hook-registry" version = "1.4.2" @@ -595,6 +979,24 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "sqlite-wasm-rs" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f4206ed3a67690b9c29b77d728f6acc3ce78f16bf846d83c94f76400320181b" +dependencies = [ + "cc", + "js-sys", + "rsqlite-vfs", + "wasm-bindgen", +] + +[[package]] +name = "strsim" +version = "0.11.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "subtle" version = "2.6.1" @@ -603,15 +1005,44 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.100" +version = "2.0.117" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" dependencies = [ "proc-macro2", "quote", "unicode-ident", ] +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tiny-keccak" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" +dependencies = [ + "crunchy", +] + [[package]] name = "tokio" version = "1.44.1" @@ -694,12 +1125,24 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "utf8parse" version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" 
+[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.5" @@ -712,6 +1155,103 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6532f9a5c1ece3798cb1c2cfdba640b9b3ba884f5db45973a6f442510a87d38e" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18a2d50fcf105fb33bb15f00e7a77b772945a2ee45dcf454961fd843e74c18e6" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03ce4caeaac547cdf713d280eda22a730824dd11e6b8c3ca9e42247b25c631e3" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "75a326b8c223ee17883a4251907455a2431acc2791c98c26279376490c378c16" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.2", + "indexmap", + "semver", +] + [[package]] name = "windows-sys" version = "0.52.0" @@ -803,6 +1343,94 @@ dependencies = [ "memchr", ] +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + [[package]] name = "zerocopy" version = "0.8.24" diff --git a/Cargo.toml b/Cargo.toml index 12ffcb2..318198d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ keywords = ["katzenpost", "cryptography", "sphinx", "mixnet"] libc = "0.2.152" rand = "0.8" serde = { version = "1.0", features = ["derive"] } +serde_bytes = "0.11" serde_json = "1.0" serde_cbor = "0.11" blake2 = "0.10" @@ -24,4 +25,14 @@ hex = "0.4" tokio = { version = "1", features = ["full"] } generic-array = "0.14.0" typenum = "1.16" -toml = "0.8" \ No newline at end of file +toml = "0.8" +rusqlite = { version = "0.38.0", features = ["bundled"] } +clap = { version = "4.5", features = ["derive"] } +base64 = "0.22" + +[[bin]] +name = "copycat" +path = "src/bin/copycat.rs" + +[dev-dependencies] +crdts = "7" diff --git a/katzenpost_thinclient/__init__.py b/katzenpost_thinclient/__init__.py index ba3a6a7..8e4b3cc 100644 --- a/katzenpost_thinclient/__init__.py +++ b/katzenpost_thinclient/__init__.py @@ 
-47,1554 +47,229 @@ async def main(): ``` """ -import socket -import struct -import random -import coloredlogs -import logging -import sys -import io -import os -import asyncio -import cbor2 -import pprintpp -import toml -import hashlib - -from typing import Tuple, Any, Dict, List, Callable - -# Thin Client Error Codes (matching Go implementation) -THIN_CLIENT_SUCCESS = 0 -THIN_CLIENT_ERROR_CONNECTION_LOST = 1 -THIN_CLIENT_ERROR_TIMEOUT = 2 -THIN_CLIENT_ERROR_INVALID_REQUEST = 3 -THIN_CLIENT_ERROR_INTERNAL_ERROR = 4 -THIN_CLIENT_ERROR_MAX_RETRIES = 5 - -THIN_CLIENT_ERROR_INVALID_CHANNEL = 6 -THIN_CLIENT_ERROR_CHANNEL_NOT_FOUND = 7 -THIN_CLIENT_ERROR_PERMISSION_DENIED = 8 -THIN_CLIENT_ERROR_INVALID_PAYLOAD = 9 -THIN_CLIENT_ERROR_SERVICE_UNAVAILABLE = 10 -THIN_CLIENT_ERROR_DUPLICATE_CAPABILITY = 11 -THIN_CLIENT_ERROR_COURIER_CACHE_CORRUPTION = 12 -THIN_CLIENT_PROPAGATION_ERROR = 13 - -def thin_client_error_to_string(error_code: int) -> str: - """Convert a thin client error code to a human-readable string.""" - error_messages = { - THIN_CLIENT_SUCCESS: "Success", - THIN_CLIENT_ERROR_CONNECTION_LOST: "Connection lost", - THIN_CLIENT_ERROR_TIMEOUT: "Timeout", - THIN_CLIENT_ERROR_INVALID_REQUEST: "Invalid request", - THIN_CLIENT_ERROR_INTERNAL_ERROR: "Internal error", - THIN_CLIENT_ERROR_MAX_RETRIES: "Maximum retries exceeded", +# Import core classes and functions +from .core import ( + # Replica error codes (from pigeonhole/errors.go) + REPLICA_SUCCESS, + REPLICA_ERROR_BOX_ID_NOT_FOUND, + REPLICA_ERROR_INVALID_BOX_ID, + REPLICA_ERROR_INVALID_SIGNATURE, + REPLICA_ERROR_DATABASE_FAILURE, + REPLICA_ERROR_INVALID_PAYLOAD, + REPLICA_ERROR_STORAGE_FULL, + REPLICA_ERROR_INTERNAL_ERROR, + REPLICA_ERROR_INVALID_EPOCH, + REPLICA_ERROR_REPLICATION_FAILED, + REPLICA_ERROR_BOX_ALREADY_EXISTS, + # Thin client error codes + THIN_CLIENT_SUCCESS, + THIN_CLIENT_ERROR_CONNECTION_LOST, + THIN_CLIENT_ERROR_TIMEOUT, + THIN_CLIENT_ERROR_INVALID_REQUEST, + THIN_CLIENT_ERROR_INTERNAL_ERROR, + 
THIN_CLIENT_ERROR_MAX_RETRIES, + THIN_CLIENT_ERROR_INVALID_CHANNEL, + THIN_CLIENT_ERROR_CHANNEL_NOT_FOUND, + THIN_CLIENT_ERROR_PERMISSION_DENIED, + THIN_CLIENT_ERROR_INVALID_PAYLOAD, + THIN_CLIENT_ERROR_SERVICE_UNAVAILABLE, + THIN_CLIENT_ERROR_DUPLICATE_CAPABILITY, + THIN_CLIENT_ERROR_COURIER_CACHE_CORRUPTION, + THIN_CLIENT_PROPAGATION_ERROR, + THIN_CLIENT_ERROR_INVALID_WRITE_CAPABILITY, + THIN_CLIENT_ERROR_INVALID_READ_CAPABILITY, + THIN_CLIENT_ERROR_INVALID_RESUME_WRITE_CHANNEL_REQUEST, + THIN_CLIENT_ERROR_INVALID_RESUME_READ_CHANNEL_REQUEST, + THIN_CLIENT_IMPOSSIBLE_HASH_ERROR, + THIN_CLIENT_IMPOSSIBLE_NEW_WRITE_CAP_ERROR, + THIN_CLIENT_IMPOSSIBLE_NEW_STATEFUL_WRITER_ERROR, + THIN_CLIENT_CAPABILITY_ALREADY_IN_USE, + THIN_CLIENT_ERROR_MKEM_DECRYPTION_FAILED, + THIN_CLIENT_ERROR_BACAP_DECRYPTION_FAILED, + THIN_CLIENT_ERROR_START_RESENDING_CANCELLED, + thin_client_error_to_string, + error_code_to_exception, + # Replica exceptions (matching Go sentinel errors) + ReplicaError, + BoxIDNotFoundError, + InvalidBoxIDError, + InvalidSignatureError, + DatabaseFailureError, + InvalidPayloadError, + StorageFullError, + ReplicaInternalError, + InvalidEpochError, + ReplicationFailedError, + BoxAlreadyExistsError, + # Thin client exceptions + MKEMDecryptionFailedError, + BACAPDecryptionFailedError, + StartResendingCancelledError, + ThinClientOfflineError, + # Constants + SURB_ID_SIZE, + MESSAGE_ID_SIZE, + STREAM_ID_LENGTH, + # Classes + ThinClient, + Config, + ConfigFile, + Geometry, + PigeonholeGeometry, + ServiceDescriptor, + # Functions + find_services, + pretty_print_obj, + blake2_256_sum, +) + +# Import legacy channel API classes and methods +from .legacy import ( + WriteChannelReply, + ReadChannelReply, + create_write_channel, + create_read_channel, + write_channel, + read_channel, + read_channel_with_retry, + _send_channel_query_and_wait_for_message_id, + close_channel, +) + +# Import new pigeonhole API methods and result types +from .pigeonhole import ( + stream_id, + 
new_keypair, + encrypt_read, + encrypt_write, + start_resending_encrypted_message, + start_resending_encrypted_message_return_box_exists, + start_resending_encrypted_message_no_retry, + cancel_resending_encrypted_message, + next_message_box_index, + start_resending_copy_command, + cancel_resending_copy_command, + create_courier_envelopes_from_payload, + create_courier_envelopes_from_multi_payload, + set_stream_buffer, + tombstone_box, + tombstone_range, + # Result dataclasses + KeypairResult, + EncryptReadResult, + EncryptWriteResult, + CreateEnvelopesResult, +) + + +# Attach legacy channel API methods to ThinClient +ThinClient.create_write_channel = create_write_channel +ThinClient.create_read_channel = create_read_channel +ThinClient.write_channel = write_channel +ThinClient.read_channel = read_channel +ThinClient.read_channel_with_retry = read_channel_with_retry +ThinClient._send_channel_query_and_wait_for_message_id = _send_channel_query_and_wait_for_message_id +ThinClient.close_channel = close_channel + +# Attach new pigeonhole API methods to ThinClient +ThinClient.stream_id = stream_id +ThinClient.new_keypair = new_keypair +ThinClient.encrypt_read = encrypt_read +ThinClient.encrypt_write = encrypt_write +ThinClient.start_resending_encrypted_message = start_resending_encrypted_message +ThinClient.start_resending_encrypted_message_return_box_exists = start_resending_encrypted_message_return_box_exists +ThinClient.start_resending_encrypted_message_no_retry = start_resending_encrypted_message_no_retry +ThinClient.cancel_resending_encrypted_message = cancel_resending_encrypted_message +ThinClient.next_message_box_index = next_message_box_index +ThinClient.start_resending_copy_command = start_resending_copy_command +ThinClient.cancel_resending_copy_command = cancel_resending_copy_command +ThinClient.create_courier_envelopes_from_payload = create_courier_envelopes_from_payload +ThinClient.create_courier_envelopes_from_multi_payload = 
create_courier_envelopes_from_multi_payload +ThinClient.set_stream_buffer = set_stream_buffer +ThinClient.tombstone_box = tombstone_box +ThinClient.tombstone_range = tombstone_range - THIN_CLIENT_ERROR_INVALID_CHANNEL: "Invalid channel", - THIN_CLIENT_ERROR_CHANNEL_NOT_FOUND: "Channel not found", - THIN_CLIENT_ERROR_PERMISSION_DENIED: "Permission denied", - THIN_CLIENT_ERROR_INVALID_PAYLOAD: "Invalid payload", - THIN_CLIENT_ERROR_SERVICE_UNAVAILABLE: "Service unavailable", - THIN_CLIENT_ERROR_DUPLICATE_CAPABILITY: "Duplicate capability", - THIN_CLIENT_ERROR_COURIER_CACHE_CORRUPTION: "Courier cache corruption", - THIN_CLIENT_PROPAGATION_ERROR: "Propagation error", - } - return error_messages.get(error_code, f"Unknown thin client error code: {error_code}") - -class ThinClientOfflineError(Exception): - pass # Export public API __all__ = [ + # Main classes 'ThinClient', 'ThinClientOfflineError', 'Config', + 'ConfigFile', + 'Geometry', + 'PigeonholeGeometry', 'ServiceDescriptor', + # Legacy channel reply classes 'WriteChannelReply', 'ReadChannelReply', - 'find_services' + # Pigeonhole result dataclasses + 'KeypairResult', + 'EncryptReadResult', + 'EncryptWriteResult', + 'CreateEnvelopesResult', + # Utility functions + 'find_services', + 'pretty_print_obj', + 'blake2_256_sum', + 'thin_client_error_to_string', + 'error_code_to_exception', + # Constants + 'SURB_ID_SIZE', + 'MESSAGE_ID_SIZE', + 'STREAM_ID_LENGTH', + # Replica error codes (from pigeonhole/errors.go) + 'REPLICA_SUCCESS', + 'REPLICA_ERROR_BOX_ID_NOT_FOUND', + 'REPLICA_ERROR_INVALID_BOX_ID', + 'REPLICA_ERROR_INVALID_SIGNATURE', + 'REPLICA_ERROR_DATABASE_FAILURE', + 'REPLICA_ERROR_INVALID_PAYLOAD', + 'REPLICA_ERROR_STORAGE_FULL', + 'REPLICA_ERROR_INTERNAL_ERROR', + 'REPLICA_ERROR_INVALID_EPOCH', + 'REPLICA_ERROR_REPLICATION_FAILED', + 'REPLICA_ERROR_BOX_ALREADY_EXISTS', + # Thin client error codes + 'THIN_CLIENT_SUCCESS', + 'THIN_CLIENT_ERROR_CONNECTION_LOST', + 'THIN_CLIENT_ERROR_TIMEOUT', + 
'THIN_CLIENT_ERROR_INVALID_REQUEST', + 'THIN_CLIENT_ERROR_INTERNAL_ERROR', + 'THIN_CLIENT_ERROR_MAX_RETRIES', + 'THIN_CLIENT_ERROR_INVALID_CHANNEL', + 'THIN_CLIENT_ERROR_CHANNEL_NOT_FOUND', + 'THIN_CLIENT_ERROR_PERMISSION_DENIED', + 'THIN_CLIENT_ERROR_INVALID_PAYLOAD', + 'THIN_CLIENT_ERROR_SERVICE_UNAVAILABLE', + 'THIN_CLIENT_ERROR_DUPLICATE_CAPABILITY', + 'THIN_CLIENT_ERROR_COURIER_CACHE_CORRUPTION', + 'THIN_CLIENT_PROPAGATION_ERROR', + 'THIN_CLIENT_ERROR_INVALID_WRITE_CAPABILITY', + 'THIN_CLIENT_ERROR_INVALID_READ_CAPABILITY', + 'THIN_CLIENT_ERROR_INVALID_RESUME_WRITE_CHANNEL_REQUEST', + 'THIN_CLIENT_ERROR_INVALID_RESUME_READ_CHANNEL_REQUEST', + 'THIN_CLIENT_IMPOSSIBLE_HASH_ERROR', + 'THIN_CLIENT_IMPOSSIBLE_NEW_WRITE_CAP_ERROR', + 'THIN_CLIENT_IMPOSSIBLE_NEW_STATEFUL_WRITER_ERROR', + 'THIN_CLIENT_CAPABILITY_ALREADY_IN_USE', + 'THIN_CLIENT_ERROR_MKEM_DECRYPTION_FAILED', + 'THIN_CLIENT_ERROR_BACAP_DECRYPTION_FAILED', + 'THIN_CLIENT_ERROR_START_RESENDING_CANCELLED', + # Replica exceptions (matching Go sentinel errors) + 'ReplicaError', + 'BoxIDNotFoundError', + 'InvalidBoxIDError', + 'InvalidSignatureError', + 'DatabaseFailureError', + 'InvalidPayloadError', + 'StorageFullError', + 'ReplicaInternalError', + 'InvalidEpochError', + 'ReplicationFailedError', + 'BoxAlreadyExistsError', + # Thin client exceptions + 'MKEMDecryptionFailedError', + 'BACAPDecryptionFailedError', + 'StartResendingCancelledError', ] - -# SURB_ID_SIZE is the size in bytes for the -# Katzenpost SURB ID. -SURB_ID_SIZE = 16 - -# MESSAGE_ID_SIZE is the size in bytes for an ID -# which is unique to the sent message. 
-MESSAGE_ID_SIZE = 16 - - -class WriteChannelReply: - """Reply from WriteChannel operation, matching Rust WriteChannelReply.""" - - def __init__(self, send_message_payload: bytes, current_message_index: bytes, - next_message_index: bytes, envelope_descriptor: bytes, envelope_hash: bytes): - self.send_message_payload = send_message_payload - self.current_message_index = current_message_index - self.next_message_index = next_message_index - self.envelope_hash = envelope_hash - self.envelope_descriptor = envelope_descriptor - - -class ReadChannelReply: - """Reply from ReadChannel operation, matching Rust ReadChannelReply.""" - - def __init__(self, send_message_payload: bytes, current_message_index: bytes, - next_message_index: bytes, reply_index: "int|None", - envelope_descriptor: bytes, envelope_hash: bytes): - self.send_message_payload = send_message_payload - self.current_message_index = current_message_index - self.next_message_index = next_message_index - self.reply_index = reply_index - self.envelope_descriptor = envelope_descriptor - self.envelope_hash = envelope_hash - - -class Geometry: - """ - Geometry describes the geometry of a Sphinx packet. - - NOTE: You must not try to compose a Sphinx Geometry yourself. - It must be programmatically generated by Katzenpost - genconfig or gensphinx CLI utilities. - - We describe all the Sphinx Geometry attributes below, however - the only one you are interested in to faciliate your thin client - message bounds checking is UserForwardPayloadLength, which indicates - the maximum sized message that you can send to a mixnet service in - a single packet. - - Attributes: - PacketLength (int): The total length of a Sphinx packet in bytes. - NrHops (int): The number of hops; determines the header's structure. - HeaderLength (int): The total size of the Sphinx header in bytes. - RoutingInfoLength (int): The length of the routing information portion of the header. 
- PerHopRoutingInfoLength (int): The length of routing info for a single hop. - SURBLength (int): The length of a Single-Use Reply Block (SURB). - SphinxPlaintextHeaderLength (int): The length of the unencrypted plaintext header. - PayloadTagLength (int): The length of the tag used to authenticate the payload. - ForwardPayloadLength (int): The size of the full payload including padding and tag. - UserForwardPayloadLength (int): The usable portion of the payload intended for the recipient. - NextNodeHopLength (int): Derived from the expected maximum routing info block size. - SPRPKeyMaterialLength (int): The length of the key used for SPRP (Sphinx packet payload encryption). - NIKEName (str): Name of the NIKE scheme (if used). Mutually exclusive with KEMName. - KEMName (str): Name of the KEM scheme (if used). Mutually exclusive with NIKEName. - """ - - def __init__(self, *, PacketLength:int, NrHops:int, HeaderLength:int, RoutingInfoLength:int, PerHopRoutingInfoLength:int, SURBLength:int, SphinxPlaintextHeaderLength:int, PayloadTagLength:int, ForwardPayloadLength:int, UserForwardPayloadLength:int, NextNodeHopLength:int, SPRPKeyMaterialLength:int, NIKEName:str='', KEMName:str='') -> None: - self.PacketLength = PacketLength - self.NrHops = NrHops - self.HeaderLength = HeaderLength - self.RoutingInfoLength = RoutingInfoLength - self.PerHopRoutingInfoLength = PerHopRoutingInfoLength - self.SURBLength = SURBLength - self.SphinxPlaintextHeaderLength = SphinxPlaintextHeaderLength - self.PayloadTagLength = PayloadTagLength - self.ForwardPayloadLength = ForwardPayloadLength - self.UserForwardPayloadLength = UserForwardPayloadLength - self.NextNodeHopLength = NextNodeHopLength - self.SPRPKeyMaterialLength = SPRPKeyMaterialLength - self.NIKEName = NIKEName - self.KEMName = KEMName - - def __str__(self) -> str: - return ( - f"PacketLength: {self.PacketLength}\n" - f"NrHops: {self.NrHops}\n" - f"HeaderLength: {self.HeaderLength}\n" - f"RoutingInfoLength: 
{self.RoutingInfoLength}\n" - f"PerHopRoutingInfoLength: {self.PerHopRoutingInfoLength}\n" - f"SURBLength: {self.SURBLength}\n" - f"SphinxPlaintextHeaderLength: {self.SphinxPlaintextHeaderLength}\n" - f"PayloadTagLength: {self.PayloadTagLength}\n" - f"ForwardPayloadLength: {self.ForwardPayloadLength}\n" - f"UserForwardPayloadLength: {self.UserForwardPayloadLength}\n" - f"NextNodeHopLength: {self.NextNodeHopLength}\n" - f"SPRPKeyMaterialLength: {self.SPRPKeyMaterialLength}\n" - f"NIKEName: {self.NIKEName}\n" - f"KEMName: {self.KEMName}" - ) - -class ConfigFile: - """ - ConfigFile represents everything loaded from a TOML file: - network, address, and geometry. - """ - def __init__(self, network:str, address:str, geometry:Geometry) -> None: - self.network : str = network - self.address : str = address - self.geometry : Geometry = geometry - - @classmethod - def load(cls, toml_path:str) -> "ConfigFile": - with open(toml_path, 'r') as f: - data = toml.load(f) - network = data.get('Network') - assert isinstance(network, str) - address = data.get('Address') - assert isinstance(address, str) - geometry_data = data.get('SphinxGeometry') - assert isinstance(geometry_data, dict) - geometry : Geometry = Geometry(**geometry_data) - return cls(network, address, geometry) - - def __str__(self) -> str: - return ( - f"Network: {self.network}\n" - f"Address: {self.address}\n" - f"Geometry:\n{self.geometry}" - ) - - -def pretty_print_obj(obj: "Any") -> str: - """ - Pretty-print a Python object using indentation and return the formatted string. - - This function uses `pprintpp` to format complex data structures - (e.g., dictionaries, lists) in a readable, indented format. - - Args: - obj (Any): The object to pretty-print. - - Returns: - str: The pretty-printed representation of the object. 
- """ - pp = pprintpp.PrettyPrinter(indent=4) - return pp.pformat(obj) - -def blake2_256_sum(data:bytes) -> bytes: - return hashlib.blake2b(data, digest_size=32).digest() - -class ServiceDescriptor: - """ - Describes a mixnet service endpoint retrieved from the PKI document. - - A ServiceDescriptor encapsulates the necessary information for communicating - with a service on the mix network. The service node's identity public key's hash - is used as the destination address along with the service's queue ID. - - Attributes: - recipient_queue_id (bytes): The identifier of the recipient's queue on the mixnet. ("Kaetzchen.endpoint" in the PKI) - mix_descriptor (dict): A CBOR-decoded dictionary describing the mix node, - typically includes the 'IdentityKey' and other metadata. - - Methods: - to_destination(): Returns a tuple of (provider_id_hash, recipient_queue_id), - where the provider ID is a 32-byte BLAKE2b hash of the IdentityKey. - """ - - def __init__(self, recipient_queue_id:bytes, mix_descriptor: "Dict[Any,Any]") -> None: - self.recipient_queue_id = recipient_queue_id - self.mix_descriptor = mix_descriptor - - def to_destination(self) -> "Tuple[bytes,bytes]": - "provider identity key hash and queue id" - provider_id_hash = blake2_256_sum(self.mix_descriptor['IdentityKey']) - return (provider_id_hash, self.recipient_queue_id) - -def find_services(capability:str, doc:"Dict[str,Any]") -> "List[ServiceDescriptor]": - """ - Search the PKI document for services supporting the specified capability. - - This function iterates over all service nodes in the PKI document, - deserializes each CBOR-encoded node, and looks for advertised capabilities. - If a service provides the requested capability, it is returned as a - `ServiceDescriptor`. - - Args: - capability (str): The name of the capability to search for (e.g., "echo"). - doc (dict): The decoded PKI document as a Python dictionary, - which must include a "ServiceNodes" key containing CBOR-encoded descriptors. 
- - Returns: - List[ServiceDescriptor]: A list of matching service descriptors that advertise the capability. - - Raises: - KeyError: If the 'ServiceNodes' field is missing from the PKI document. - """ - services = [] - for node in doc['ServiceNodes']: - mynode = cbor2.loads(node) - - # Check if the node has services in Kaetzchen field (fixed from omitempty) - if 'Kaetzchen' in mynode: - for cap, details in mynode['Kaetzchen'].items(): - if cap == capability: - service_desc = ServiceDescriptor( - recipient_queue_id=bytes(details['endpoint'], 'utf-8'), # why is this bytes when it's string in PKI? - mix_descriptor=mynode - ) - services.append(service_desc) - return services - - -class Config: - """ - Configuration object for the ThinClient containing connection details and event callbacks. - - The Config class loads network configuration from a TOML file and provides optional - callback functions that are invoked when specific events occur during client operation. - - Attributes: - network (str): Network type ('tcp', 'unix', etc.) - address (str): Network address (host:port for TCP, path for Unix sockets) - geometry (Geometry): Sphinx packet geometry parameters - on_connection_status (callable): Callback for connection status changes - on_new_pki_document (callable): Callback for new PKI documents - on_message_sent (callable): Callback for message transmission confirmations - on_message_reply (callable): Callback for received message replies - - Example: - >>> def handle_reply(event): - ... # Process the received reply - ... payload = event['payload'] - >>> - >>> config = Config("client.toml", on_message_reply=handle_reply) - >>> client = ThinClient(config) - """ - - def __init__(self, filepath:str, - on_connection_status:"Callable|None"=None, - on_new_pki_document:"Callable|None"=None, - on_message_sent:"Callable|None"=None, - on_message_reply:"Callable|None"=None) -> None: - """ - Initialize the Config object. 
- - Args: - filepath (str): Path to the TOML config file containing network, address, and geometry. - - on_connection_status (callable, optional): Callback invoked when the daemon's connection - status to the mixnet changes. The callback receives a single argument: - - - event (dict): Connection status event with keys: - - 'is_connected' (bool): True if daemon is connected to mixnet, False otherwise - - 'err' (str, optional): Error message if connection failed, empty string if no error - - Example: ``{'is_connected': True, 'err': ''}`` - - on_new_pki_document (callable, optional): Callback invoked when a new PKI document - is received from the mixnet. The callback receives a single argument: - - - event (dict): PKI document event with keys: - - 'payload' (bytes): CBOR-encoded PKI document data stripped of signatures - - Example: ``{'payload': b'\\xa5\\x64Epoch\\x00...'}`` - - on_message_sent (callable, optional): Callback invoked when a message has been - successfully transmitted to the mixnet. The callback receives a single argument: - - - event (dict): Message sent event with keys: - - 'message_id' (bytes): 16-byte unique identifier for the sent message - - 'surbid' (bytes, optional): SURB ID if message was sent with SURB, None otherwise - - 'sent_at' (str): ISO timestamp when message was sent - - 'reply_eta' (float): Expected round-trip time in seconds for reply - - 'err' (str, optional): Error message if sending failed, empty string if successful - - Example: ``{'message_id': b'\\x01\\x02...', 'surbid': b'\\xaa\\xbb...', 'sent_at': '2024-01-01T12:00:00Z', 'reply_eta': 30.5, 'err': ''}`` - - on_message_reply (callable, optional): Callback invoked when a reply is received - for a previously sent message. 
The callback receives a single argument: - - - event (dict): Message reply event with keys: - - 'message_id' (bytes): 16-byte identifier matching the original message - - 'surbid' (bytes, optional): SURB ID if reply used SURB, None otherwise - - 'payload' (bytes): Reply payload data from the service - - 'reply_index' (int, optional): Index of reply used - - 'error_code' (int): Error code indicating success (0) or specific failure condition - - Example: ``{'message_id': b'\\x01\\x02...', 'surbid': b'\\xaa\\xbb...', 'payload': b'echo response', 'reply_index': 0, 'error_code': 0}`` - - Note: - All callbacks are optional. If not provided, the corresponding events will be ignored. - Callbacks should be lightweight and non-blocking as they are called from the client's - event processing loop. - """ - - cfgfile = ConfigFile.load(filepath) - - self.network = cfgfile.network - self.address = cfgfile.address - self.geometry = cfgfile.geometry - - self.on_connection_status = on_connection_status - self.on_new_pki_document = on_new_pki_document - self.on_message_sent = on_message_sent - self.on_message_reply = on_message_reply - - async def handle_connection_status_event(self, event: asyncio.Event) -> None: - if self.on_connection_status: - return await self.on_connection_status(event) - - async def handle_new_pki_document_event(self, event: asyncio.Event) -> None: - if self.on_new_pki_document: - await self.on_new_pki_document(event) - - async def handle_message_sent_event(self, event: asyncio.Event) -> None: - if self.on_message_sent: - await self.on_message_sent(event) - - async def handle_message_reply_event(self, event: asyncio.Event) -> None: - if self.on_message_reply: - await self.on_message_reply(event) - - -class ThinClient: - """ - A minimal Katzenpost Python thin client for communicating with the local - Katzenpost client daemon over a UNIX or TCP socket. - - The thin client is responsible for: - - Establishing a connection to the client daemon. 
- - Receiving and parsing PKI documents. - - Sending messages to mixnet services (with or without SURBs). - - Handling replies and events via user-defined callbacks. - - All cryptographic operations are handled by the daemon, not by this client. - """ - - def __init__(self, config:Config) -> None: - """ - Initialize the thin client with the given configuration. - - Args: - config (Config): The configuration object containing socket details and callbacks. - - Raises: - RuntimeError: If the network type is not recognized or config is incomplete. - """ - self.pki_doc : Dict[Any,Any] | None = None - self.config = config - self.reply_received_event = asyncio.Event() - - self._is_connected : bool = False # Track connection state - - # Mutexes to serialize socket send/recv operations: - self._send_lock = asyncio.Lock() - self._recv_lock = asyncio.Lock() - - # Letterbox for each response associated (by query_id) with a request. - self.response_queues : Dict[bytes, asyncio.Queue[Dict[str,Any]]] = {} # (query_id|message_id) -> Queue - self.ack_queues : Dict[bytes, asyncio.Queue[Dict[str,Any]]] = {} # (query_id|message_id) -> Queue - - # Channel query message ID correlation (for send_channel_query_await_reply) - self.pending_channel_message_queries : Dict[bytes, asyncio.Event] = {} # message_id -> Event - self.channel_message_query_responses : Dict[bytes, bytes] = {} # message_id -> payload - - - self.logger = logging.getLogger('thinclient') - self.logger.setLevel(logging.DEBUG) - # Only add handler if none exists to avoid duplicate log messages - # XXX: commented out because it did in fact log twice: - #if not self.logger.handlers: - # handler = logging.StreamHandler(sys.stderr) - # self.logger.addHandler(handler) - - if self.config.network is None: - raise RuntimeError("config.network is None") - - network: str = self.config.network.lower() - self.server_addr : str | Tuple[str,int] - if network.lower().startswith("tcp"): - self.socket = socket.socket(socket.AF_INET, 
socket.SOCK_STREAM) - host, port_str = self.config.address.split(":") - self.server_addr = (host, int(port_str)) - elif network.lower().startswith("unix"): - self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - - if self.config.address.startswith("@"): - # Abstract UNIX socket: leading @ means first byte is null - abstract_name = self.config.address[1:] - self.server_addr = f"\0{abstract_name}" - - # Bind to a unique abstract socket for this client - random_bytes = [random.randint(0, 255) for _ in range(16)] - hex_string = ''.join(format(byte, '02x') for byte in random_bytes) - client_abstract = f"\0katzenpost_python_thin_client_{hex_string}" - self.socket.bind(client_abstract) - else: - # Filesystem UNIX socket - self.server_addr = self.config.address - - self.socket.setblocking(False) - else: - raise RuntimeError(f"Unknown network type: {self.config.network}") - - self.socket.setblocking(False) - - - async def start(self, loop:asyncio.AbstractEventLoop) -> None: - """ - Start the thin client: establish connection to the daemon, read initial events, - and begin the background event loop. - - Args: - loop (asyncio.AbstractEventLoop): The running asyncio event loop. 
- - Exceptions: - BrokenPipeError - """ - self.logger.debug("connecting to daemon") - server_addr : str | Tuple[str,int] = '' - - if self.config.network.lower().startswith("tcp"): - host, port_str = self.config.address.split(":") - server_addr = (host, int(port_str)) - elif self.config.network.lower().startswith("unix"): - if self.config.address.startswith("@"): - server_addr = '\0' + self.config.address[1:] - else: - server_addr = self.config.address - else: - raise RuntimeError(f"Unknown network type: {self.config.network}") - - await loop.sock_connect(self.socket, server_addr) - - # 1st message is always a status event - response = await self.recv(loop) - assert response is not None - assert response["connection_status_event"] is not None - await self.handle_response(response) - - # 2nd message is always a new pki doc event - #response = await self.recv(loop) - #assert response is not None - #assert response["new_pki_document_event"] is not None, response - #await self.handle_response(response) - - # Start the read loop as a background task - self.logger.debug("starting read loop") - self.task = loop.create_task(self.worker_loop(loop)) - def handle_loop_err(task): - try: - result = task.result() - except Exception: - import traceback - traceback.print_exc() - raise - self.task.add_done_callback(handle_loop_err) - - def get_config(self) -> Config: - """ - Returns the current configuration object. - - Returns: - Config: The client configuration in use. - """ - return self.config - - def is_connected(self) -> bool: - """ - Returns True if the daemon is connected to the mixnet. - - Returns: - bool: True if connected, False if in offline mode. - """ - return self._is_connected - - def stop(self) -> None: - """ - Gracefully shut down the client and close its socket. 
- """ - self.logger.debug("closing connection to daemon") - self.socket.close() - self.task.cancel() - - async def _send_all(self, data: bytes) -> None: - """ - Send all data using async socket operations with mutex protection. - - This method uses a mutex to prevent race conditions when multiple - coroutines try to send data over the same socket simultaneously. - - Args: - data (bytes): Data to send. - """ - async with self._send_lock: - loop = asyncio.get_running_loop() - await loop.sock_sendall(self.socket, data) - - async def __recv_exactly(self, total:int, loop:asyncio.AbstractEventLoop) -> bytes: - "receive exactly (total) bytes or die trying raising BrokenPipeError" - buf = bytearray(total) - remain = memoryview(buf) - while len(remain): - if not (nread := await loop.sock_recv_into(self.socket, remain)): - raise BrokenPipeError - remain = remain[nread:] - return buf - - async def recv(self, loop:asyncio.AbstractEventLoop) -> "Dict[Any,Any]": - """ - Receive a CBOR-encoded message from the daemon. - - Args: - loop (asyncio.AbstractEventLoop): Event loop to use for socket reads. - - Returns: - dict: Decoded CBOR response from the daemon. - - Raises: - BrokenPipeError: If connection fails - ValueError: If message framing fails. - """ - async with self._recv_lock: - length_prefix = await self.__recv_exactly(4, loop) - message_length = struct.unpack('>I', length_prefix)[0] - raw_data = await self.__recv_exactly(message_length, loop) - try: - response = cbor2.loads(raw_data) - except cbor2.CBORDecodeValueError as e: - self.logger.error(f"{e}") - raise ValueError(f"{e}") - response = {k:v for k,v in response.items() if v} # filter empty KV pairs - if not (set(response.keys()) & {'new_pki_document_event'}): - self.logger.debug(f"Received daemon response: [{len(raw_data)}] {type(response)} {response}") - return response - - async def worker_loop(self, loop:asyncio.events.AbstractEventLoop) -> None: - """ - Background task that listens for events and dispatches them. 
- """ - self.logger.debug("read loop start") - while True: - try: - response = await self.recv(loop) - except asyncio.CancelledError: - # Handle cancellation of the read loop - self.logger.error(f"worker_loop cancelled") - break - except Exception as e: - self.logger.error(f"Error reading from socket: {e}") - raise - else: - def handle_response_err(task): - try: - result = task.result() - except Exception: - import traceback - traceback.print_exc() - raise - resp = asyncio.create_task(self.handle_response(response)) - resp.add_done_callback(handle_response_err) - - def parse_status(self, event: "Dict[str,Any]") -> None: - """ - Parse a connection status event and update connection state. - """ - self.logger.debug("parse status") - assert event is not None - - self._is_connected = event.get("is_connected", False) - - if self._is_connected: - self.logger.debug("Daemon is connected to mixnet - full functionality available") - else: - self.logger.info("Daemon is not connected to mixnet - entering offline mode") - - self.logger.debug("parse status success") - - def pki_document(self) -> "Dict[str,Any] | None": - """ - Retrieve the latest PKI document received. - - Returns: - dict: Parsed CBOR PKI document. - """ - return self.pki_doc - - def parse_pki_doc(self, event: "Dict[str,Any]") -> None: - """ - Parse and store a new PKI document received from the daemon. - """ - self.logger.debug("parse pki doc") - assert event is not None - assert event["payload"] is not None - raw_pki_doc = cbor2.loads(event["payload"]) - self.pki_doc = raw_pki_doc - self.logger.debug("parse pki doc success") - - def get_services(self, capability:str) -> "List[ServiceDescriptor]": - """ - Look up all services in the PKI that advertise a given capability. - - Args: - capability (str): Capability name (e.g., "echo"). - - Returns: - list[ServiceDescriptor]: Matching services.xsy - - Raises: - Exception: If PKI is missing or no services match. 
- """ - doc = self.pki_document() - if doc == None: - raise Exception("pki doc is nil") - descriptors = find_services(capability, doc) - if not descriptors: - raise Exception("service not found in pki doc") - return descriptors - - def get_service(self, service_name:str) -> ServiceDescriptor: - """ - Select a random service matching a capability. - - Args: - service_name (str): The capability name (e.g., "echo"). - - Returns: - ServiceDescriptor: One of the matching services. - """ - service_descriptors = self.get_services(service_name) - return random.choice(service_descriptors) - - @staticmethod - def new_message_id() -> bytes: - """ - Generate a new 16-byte message ID for use with ARQ sends. - - Returns: - bytes: Random 16-byte identifier. - """ - return os.urandom(MESSAGE_ID_SIZE) - - def new_surb_id(self) -> bytes: - """ - Generate a new 16-byte SURB ID for reply-capable sends. - - Returns: - bytes: Random 16-byte identifier. - """ - return os.urandom(SURB_ID_SIZE) - - def new_query_id(self) -> bytes: - """ - Generate a new 16-byte query ID for channel API operations. - - Returns: - bytes: Random 16-byte identifier. 
- """ - return os.urandom(16) - - async def _send_and_wait(self, *, query_id:bytes, request: Dict[str, Any]) -> Dict[str, Any]: - cbor_request = cbor2.dumps(request) - length_prefix = struct.pack('>I', len(cbor_request)) - length_prefixed_request = length_prefix + cbor_request - assert query_id not in self.response_queues - self.response_queues[query_id] = asyncio.Queue(maxsize=1) - request_type = list(request.keys())[0] - try: - await self._send_all(length_prefixed_request) - self.logger.info(f"{request_type} request sent.") - reply = await self.response_queues[query_id].get() - self.logger.info(f"{request_type} response received.") - # TODO error handling, see _wait_for_channel_reply - return reply - except asyncio.CancelledError: - self.logger.info("{request_type} task cancelled.") - raise - finally: - del self.response_queues[query_id] - - async def _wait_for_channel_reply(self, expected_reply_type: str) -> Dict[Any, Any]: - """ - Wait for a channel API reply using response queues (simulating Rust's event sinks). - - Args: - expected_reply_type: The expected reply type (e.g., "create_write_channel_reply"). - - Returns: - Dict: The reply data. - - Raises: - Exception: If the reply contains an error or times out. 
- """ - # Create a queue for this reply type - queue = asyncio.Queue(maxsize=1) - self.channel_response_queues[expected_reply_type] = queue - - try: - # Wait for the reply with timeout - reply = await asyncio.wait_for(queue.get(), timeout=30.0) - - # Check for errors (matching Rust implementation) - error_code = reply.get("error_code", 0) - if error_code != 0: - raise Exception(f"{expected_reply_type} failed with error code: {error_code}") - - if reply.get("err"): - raise Exception(f"{expected_reply_type} failed: {reply['err']}") - - return reply - - except asyncio.TimeoutError: - raise Exception(f"Timeout waiting for {expected_reply_type}") - finally: - # Clean up - self.channel_response_queues.pop(expected_reply_type, None) - - async def handle_response(self, response: "Dict[str,Any]") -> None: - """ - Dispatch a parsed CBOR response to the appropriate handler or callback. - """ - assert response is not None - - if response.get("connection_status_event") is not None: - self.logger.debug("connection status event") - self.parse_status(response["connection_status_event"]) - await self.config.handle_connection_status_event(response["connection_status_event"]) - return - if response.get("new_pki_document_event") is not None: - self.logger.debug("new pki doc event") - self.parse_pki_doc(response["new_pki_document_event"]) - await self.config.handle_new_pki_document_event(response["new_pki_document_event"]) - return - if response.get("message_sent_event") is not None: - self.logger.debug("message sent event") - await self.config.handle_message_sent_event(response["message_sent_event"]) - return - if response.get("message_reply_event") is not None: - self.logger.debug("message reply event") - reply = response["message_reply_event"] - self.reply_received_event.set() - await self.config.handle_message_reply_event(reply) - return - # Handle channel query events (for send_channel_query_await_reply), this is the ACK from the local clientd (not courier) - if 
response.get("channel_query_sent_event") is not None: - # channel_query_sent_event': {'message_id': b'\xb7\xd5\xaeG\x8a\xc4\x96\x99|M\x89c\x90\xc3\xd4\x1f', 'sent_at': 1758485828, 'reply_eta': 1179000000, 'error_code': 0}, - self.logger.debug("channel_query_sent_event") - event = response["channel_query_sent_event"] - message_id = event.get("message_id") - if message_id is not None: - # Check for error in sent event - error_code = event.get("error_code", 0) - if error_code != 0: - # Store error for the waiting coroutine - if message_id in self.pending_channel_message_queries: - self.channel_message_query_responses[message_id] = f"Channel query send failed with error code: {error_code}".encode() - self.pending_channel_message_queries[message_id].set() - # Continue waiting for the reply (don't return here) - return - - if query_ack := response.get("channel_query_reply_event", None): - # this is the ACK from the courier - self.logger.debug("channel_query_reply_event") - event = response["channel_query_reply_event"] - message_id = event.get("message_id") - - if message_id is None: - self.logger.error("channel_query_reply_event without message_id") - return - - # TODO wait why are we storing these indefinitely if we don't really care about them?? 
- if error_code := event.get("error_code", 0): - error_msg = f"Channel query failed with error code: {error_code}".encode() - self.channel_message_query_responses[message_id] = error_msg - else: - # Extract the payload - payload = event.get("payload", b"") - self.channel_message_query_responses[message_id] = payload - - if (queue := self.ack_queues.get(message_id, None)): - self.logger.debug(f"ack_queues: populated with message_id {message_id.hex()}") - asyncio.create_task(queue.put(query_ack)) - else: - self.logger.error(f"channel_query_reply_event for message_id {message_id.hex()}, but there is no listener") - - - # Signal the waiting coroutine - if message_id in self.pending_channel_message_queries: - self.pending_channel_message_queries[message_id].set() - return - - for reply_type, reply in response.items(): - if not reply: - continue - self.logger.debug(f"channel {reply_type} event") - if not reply_type.endswith("_reply") or not (query_id := reply.get("query_id", None)): - self.logger.debug(f"{reply_type} is not a reply, or can't get query_id") - # 'create_read_channel_reply': {'query_id': None, 'channel_id': 0, 'error_code': 21}, - # DEBUG [thinclient] channel_query_reply_event is not a reply, or can't get query_id - # REPLY {'message_id': b'\xfd\xc0\x9d\xcfh\xa3\x88X[\xab\xa8\xd3\x1b\x8b\x15\xd1', 'payload': b'', 'reply_index': None, 'error_code': 0} - # SELF.RESPONSE_QUEUES {} - print("REPLY", reply) - print('SELF.RESPONSE_QUEUES', self.response_queues) - continue - if not (queue := self.response_queues.get(query_id, None)): - self.logger.debug(f"query_id for {reply_type} has no listener") - continue - # avoid blocking recv loop: - asyncio.create_task(queue.put(reply)) - - - - async def send_message_without_reply(self, payload:bytes|str, dest_node:bytes, dest_queue:bytes) -> None: - """ - Send a fire-and-forget message with no SURB or reply handling. - This method requires mixnet connectivity. - - Args: - payload (bytes or str): Message payload. 
- dest_node (bytes): Destination node identity hash. - dest_queue (bytes): Destination recipient queue ID. - - Raises: - ThinClientOfflineError: If in offline mode (daemon not connected to mixnet). - """ - # Check if we're in offline mode - if not self._is_connected: - raise ThinClientOfflineError("cannot send_message_without_reply in offline mode - daemon not connected to mixnet") - - if not isinstance(payload, bytes): - payload = payload.encode('utf-8') # Encoding the string to bytes - - # Create the SendMessage structure - send_message = { - "id": None, # No ID for fire-and-forget messages - "with_surb": False, - "surbid": None, # No SURB ID for fire-and-forget messages - "destination_id_hash": dest_node, - "recipient_queue_id": dest_queue, - "payload": payload, - } - - # Wrap in the new Request structure - request = { - "send_message": send_message - } - - cbor_request = cbor2.dumps(request) - length_prefix = struct.pack('>I', len(cbor_request)) - length_prefixed_request = length_prefix + cbor_request - try: - await self._send_all(length_prefixed_request) - self.logger.info("Message sent successfully.") - except Exception as e: - self.logger.error(f"Error sending message: {e}") - - async def send_message(self, surb_id:bytes, payload:bytes|str, dest_node:bytes, dest_queue:bytes) -> None: - """ - Send a message using a SURB to allow the recipient to send a reply. - This method requires mixnet connectivity. - - Args: - surb_id (bytes): SURB identifier for reply correlation. - payload (bytes or str): Message payload. - dest_node (bytes): Destination node identity hash. - dest_queue (bytes): Destination recipient queue ID. - - Raises: - ThinClientOfflineError: If in offline mode (daemon not connected to mixnet). 
- """ - # Check if we're in offline mode - if not self._is_connected: - raise ThinClientOfflineError("cannot send message in offline mode - daemon not connected to mixnet") - - if not isinstance(payload, bytes): - payload = payload.encode('utf-8') # Encoding the string to bytes - - # Create the SendMessage structure - send_message = { - "id": None, # No ID for regular messages - "with_surb": True, - "surbid": surb_id, - "destination_id_hash": dest_node, - "recipient_queue_id": dest_queue, - "payload": payload, - } - - # Wrap in the new Request structure - request = { - "send_message": send_message - } - - cbor_request = cbor2.dumps(request) - length_prefix = struct.pack('>I', len(cbor_request)) - length_prefixed_request = length_prefix + cbor_request - try: - await self._send_all(length_prefixed_request) - self.logger.info("Message sent successfully.") - except Exception as e: - self.logger.error(f"Error sending message: {e}") - - - - async def send_reliable_message(self, message_id:bytes, payload:bytes|str, dest_node:bytes, dest_queue:bytes) -> None: - """ - Send a reliable message using an ARQ mechanism and message ID. - This method requires mixnet connectivity. - - Args: - message_id (bytes): Message ID for reply correlation. - payload (bytes or str): Message payload. - dest_node (bytes): Destination node identity hash. - dest_queue (bytes): Destination recipient queue ID. - - Raises: - ThinClientOfflineError: If in offline mode (daemon not connected to mixnet). 
- """ - # Check if we're in offline mode - if not self._is_connected: - raise ThinClientOfflineError("cannot send reliable message in offline mode - daemon not connected to mixnet") - - if not isinstance(payload, bytes): - payload = payload.encode('utf-8') # Encoding the string to bytes - - # Create the SendARQMessage structure - send_arq_message = { - "id": message_id, - "with_surb": True, - "surbid": None, # ARQ messages don't use SURB IDs directly - "destination_id_hash": dest_node, - "recipient_queue_id": dest_queue, - "payload": payload, - } - - # Wrap in the new Request structure - request = { - "send_arq_message": send_arq_message - } - - cbor_request = cbor2.dumps(request) - length_prefix = struct.pack('>I', len(cbor_request)) - length_prefixed_request = length_prefix + cbor_request - try: - await self._send_all(length_prefixed_request) - self.logger.info("Message sent successfully.") - except Exception as e: - self.logger.error(f"Error sending message: {e}") - - def pretty_print_pki_doc(self, doc: "Dict[str,Any]") -> None: - """ - Pretty-print a parsed PKI document with fully decoded CBOR nodes. - - Args: - doc (dict): Raw PKI document from the daemon. 
- """ - assert doc is not None - assert doc['GatewayNodes'] is not None - assert doc['ServiceNodes'] is not None - assert doc['Topology'] is not None - - new_doc = doc - gateway_nodes = [] - service_nodes = [] - topology = [] - - for gateway_cert_blob in doc['GatewayNodes']: - gateway_cert = cbor2.loads(gateway_cert_blob) - gateway_nodes.append(gateway_cert) - - for service_cert_blob in doc['ServiceNodes']: - service_cert = cbor2.loads(service_cert_blob) - service_nodes.append(service_cert) - - for layer in doc['Topology']: - for mix_desc_blob in layer: - mix_cert = cbor2.loads(mix_desc_blob) - topology.append(mix_cert) # flatten, no prob, relax - - new_doc['GatewayNodes'] = gateway_nodes - new_doc['ServiceNodes'] = service_nodes - new_doc['Topology'] = topology - pretty_print_obj(new_doc) - - async def await_message_reply(self) -> None: - """ - Asynchronously block until a reply is received from the daemon. - """ - await self.reply_received_event.wait() - - # Channel API methods - - async def create_write_channel(self) -> "Tuple[int, bytes, bytes]": - """ - Creates a new Pigeonhole write channel for sending messages. - - Returns: - tuple: (channel_id, read_cap, write_cap) where: - - channel_id is the 16-bit channel ID - - read_cap is the read capability for sharing - - write_cap is the write capability for persistence - - Raises: - Exception: If the channel creation fails. - """ - query_id = self.new_query_id() - - request = { - "create_write_channel": { - "query_id": query_id - } - } - - try: - reply = await self._send_and_wait(query_id=query_id, request=request) - except Exception as e: - self.logger.error(f"Error creating write channel: {e}") - raise e - - channel_id = reply["channel_id"] - read_cap = reply["read_cap"] - write_cap = reply["write_cap"] - - return channel_id, read_cap, write_cap - - async def create_read_channel(self, read_cap: bytes) -> int: - """ - Creates a read channel from a read capability. - - Args: - read_cap: The read capability bytes. 
- - Returns: - int: The channel ID. - - Raises: - Exception: If the read channel creation fails. - """ - query_id = self.new_query_id() - - request = { - "create_read_channel": { - "query_id": query_id, - "read_cap": read_cap - } - } - - try: - reply = await self._send_and_wait(query_id=query_id, request=request) - except Exception as e: - self.logger.error(f"Error creating read channel: {e}") - raise - - # client2/thin/thin_messages.go: ThinClientCapabilityAlreadyInUse uint8 = 21 - - channel_id = reply["channel_id"] - return channel_id - - async def write_channel(self, channel_id: int, payload: "bytes|str") -> WriteChannelReply: - """ - Prepares a message for writing to a Pigeonhole channel. - - Args: - channel_id: The 16-bit channel ID. - payload: The data to write to the channel. - - Returns: - WriteChannelReply: Reply containing send_message_payload and other metadata. - // ThinClientErrorInternalError indicates an internal error occurred within - // the client daemon or thin client that prevented operation completion. - ThinClientErrorInternalError uint8 = 4 - - - Raises: - Exception: If the write preparation fails. - """ - if not isinstance(payload, bytes): - payload = payload.encode('utf-8') - - query_id = self.new_query_id() - - request = { - "write_channel": { - "channel_id": channel_id, - "query_id": query_id, - "payload": payload - } - } - - try: - reply = await self._send_and_wait(query_id=query_id, request=request) - except Exception as e: - self.logger.error(f"Error preparing write to channel: {e}") - raise - - if reply['error_code'] != 0: - # Examples: - # 12:24:32.206 ERRO katzenpost/client2: writeChannel failure: failed to create write request: pki: replica not found - # - This one will probably never succeed? Why is the client using a bad replica? 
- # - raise Exception(f"write_channel got error from clientd: {reply['error_code']}") - - return WriteChannelReply( - send_message_payload=reply["send_message_payload"], - current_message_index=reply["current_message_index"], - next_message_index=reply["next_message_index"], - envelope_descriptor=reply["envelope_descriptor"], - envelope_hash=reply["envelope_hash"] - ) - - - async def read_channel(self, channel_id: int, message_box_index: "bytes|None" = None, - reply_index: "int|None" = None) -> ReadChannelReply: - """ - Prepares a read query for a Pigeonhole channel. - - Args: - channel_id: The 16-bit channel ID. - message_box_index: Optional message box index for resuming from a specific position. - reply_index: Optional index of the reply to return. - - Returns: - ReadChannelReply: Reply containing send_message_payload and other metadata. - - Raises: - Exception: If the read preparation fails. - """ - query_id = self.new_query_id() - - request_data = { - "channel_id": channel_id, - "query_id": query_id - } - - if message_box_index is not None: - request_data["message_box_index"] = message_box_index - - if reply_index is not None: - request_data["reply_index"] = reply_index - - request = { - "read_channel": request_data - } - - try: - reply = await self._send_and_wait(query_id=query_id, request=request) - except Exception as e: - self.logger.error(f"Error preparing read from channel: {e}") - raise - - return ReadChannelReply( - send_message_payload=reply["send_message_payload"], - current_message_index=reply["current_message_index"], - next_message_index=reply["next_message_index"], - reply_index=reply.get("reply_index"), - envelope_descriptor=reply["envelope_descriptor"], - envelope_hash=reply["envelope_hash"] - ) - - - async def resume_write_channel(self, write_cap: bytes, message_box_index: "bytes|None" = None) -> int: - """ - Resumes a write channel from a previous session. - - Args: - write_cap: The write capability bytes. 
- message_box_index: Optional message box index for resuming from a specific position. - - Returns: - int: The channel ID. - - Raises: - Exception: If the channel resumption fails. - """ - query_id = self.new_query_id() - - request_data = { - "query_id": query_id, - "write_cap": write_cap - } - - if message_box_index is not None: - request_data["message_box_index"] = message_box_index - - request = { - "resume_write_channel": request_data - } - - try: - reply = await self._send_and_wait(query_id=query_id, request=request) - except Exception as e: - self.logger.error(f"Error resuming write channel: {e}") - raise - return reply["channel_id"] - - - async def resume_read_channel(self, read_cap: bytes, next_message_index: "bytes|None" = None, - reply_index: "int|None" = None) -> int: - """ - Resumes a read channel from a previous session. - - Args: - read_cap: The read capability bytes. - next_message_index: Optional next message index for resuming from a specific position. - reply_index: Optional reply index. - - Returns: - int: The channel ID. - - Raises: - Exception: If the channel resumption fails. 
- """ - query_id = self.new_query_id() - - request_data = { - "query_id": query_id, - "read_cap": read_cap - } - - if next_message_index is not None: - request_data["next_message_index"] = next_message_index - - if reply_index is not None: - request_data["reply_index"] = reply_index - - request = { - "resume_read_channel": request_data - } - - try: - reply = await self._send_and_wait(query_id=query_id, request=request) - except Exception as e: - self.logger.error(f"Error resuming read channel: {e}") - raise - if not reply["channel_id"]: - self.logger.error(f"Error resuming read channel: no channel_id") - raise Exception("TODO resume_read_channel error", reply) - return reply["channel_id"] - - - async def resume_write_channel_query(self, write_cap: bytes, message_box_index: bytes, - envelope_descriptor: bytes, envelope_hash: bytes) -> int: - """ - Resumes a write channel with a specific query state. - This method provides more granular resumption control than resume_write_channel - by allowing the application to resume from a specific query state, including - the envelope descriptor and hash. This is useful when resuming from a partially - completed write operation that was interrupted during transmission. - - Args: - write_cap: The write capability bytes. - message_box_index: Message box index for resuming from a specific position (WriteChannelReply.current_message_index). - envelope_descriptor: Envelope descriptor from previous query (WriteChannelReply.envelope_descriptor). - envelope_hash: Envelope hash from previous query (WriteChannelReply.envelope_hash). - - Returns: - int: The channel ID. - - Raises: - Exception: If the channel resumption fails. 
- """ - query_id = self.new_query_id() - - request = { - "resume_write_channel_query": { - "query_id": query_id, - "write_cap": write_cap, - "message_box_index": message_box_index, - "envelope_descriptor": envelope_descriptor, - "envelope_hash": envelope_hash - } - } - - try: - reply = await self._send_and_wait(query_id=query_id, request=request) - except Exception as e: - self.logger.error(f"Error resuming write channel query: {e}") - raise - return reply["channel_id"] - - - async def resume_read_channel_query(self, read_cap: bytes, next_message_index: bytes, - reply_index: "int|None", envelope_descriptor: bytes, - envelope_hash: bytes) -> int: - """ - Resumes a read channel with a specific query state. - This method provides more granular resumption control than resume_read_channel - by allowing the application to resume from a specific query state, including - the envelope descriptor and hash. This is useful when resuming from a partially - completed read operation that was interrupted during transmission. - - Args: - read_cap: The read capability bytes. - next_message_index: Next message index for resuming from a specific position. - reply_index: Optional reply index. - envelope_descriptor: Envelope descriptor from previous query. - envelope_hash: Envelope hash from previous query. - - Returns: - int: The channel ID. - - Raises: - Exception: If the channel resumption fails. 
- """ - query_id = self.new_query_id() - - request_data = { - "query_id": query_id, - "read_cap": read_cap, - "next_message_index": next_message_index, - "envelope_descriptor": envelope_descriptor, - "envelope_hash": envelope_hash - } - - if reply_index is not None: - request_data["reply_index"] = reply_index - - request = { - "resume_read_channel_query": request_data - } - - try: - reply = await self._send_and_wait(query_id=query_id, request=request) - except Exception as e: - self.logger.error(f"Error resuming read channel query: {e}") - raise - return reply["channel_id"] - - - async def get_courier_destination(self) -> "Tuple[bytes, bytes]": - """ - Gets the courier service destination for channel queries. - This is a convenience method that combines get_service("courier") - and to_destination() to get the destination node and queue for - use with send_channel_query and send_channel_query_await_reply. - - Returns: - tuple: (dest_node, dest_queue) where: - - dest_node is the destination node identity hash - - dest_queue is the destination recipient queue ID - - Raises: - Exception: If the courier service is not found. - """ - courier_service = self.get_service("courier") - dest_node, dest_queue = courier_service.to_destination() - return dest_node, dest_queue - - async def send_channel_query_await_reply(self, channel_id: int, payload: bytes, - dest_node: bytes, dest_queue: bytes, - message_id: bytes, timeout_seconds=30.0) -> bytes: - """ - Sends a channel query and waits for the reply. - This combines send_channel_query with event handling to wait for the response. - - Args: - channel_id: The 16-bit channel ID. - payload: The prepared query payload. - dest_node: Destination node identity hash. - dest_queue: Destination recipient queue ID. - message_id: Message ID for reply correlation. - timeout_seconds: float (seconds to wait), None for indefinite wait - - Returns: - bytes: The received payload from the channel. 
- - Raises: - ThinClientOfflineError: If in offline mode (daemon not connected to mixnet). - Exception: If the query fails or times out. - """ - # Check if we're in offline mode - if not self._is_connected: - raise ThinClientOfflineError("cannot send_channel_query_await_reply in offline mode - daemon not connected to mixnet") - - # Create an event for this message_id - if message_id not in self.pending_channel_message_queries: - event = asyncio.Event() - self.pending_channel_message_queries[message_id] = event - - try: - # Send the channel query - await self.send_channel_query(channel_id, payload=payload, dest_node=dest_node, dest_queue=dest_queue, message_id=message_id) - - # Wait for the reply with timeout - await asyncio.wait_for(event.wait(), timeout=timeout_seconds) - - # Get the response payload - if message_id not in self.channel_message_query_responses: - raise Exception("No channel query reply received within timeout_seconds") - - response_payload = self.channel_message_query_responses[message_id] - - # Check if it's an error message - if isinstance(response_payload, bytes) and response_payload.startswith(b"Channel query"): - raise Exception(response_payload.decode()) - - return response_payload - - except asyncio.TimeoutError: - raise Exception("Timeout waiting for channel query reply") - finally: - # Clean up - self.pending_channel_message_queries.pop(message_id, None) - self.channel_message_query_responses.pop(message_id, None) - - async def send_channel_query(self, channel_id: int, *, payload: bytes, dest_node: bytes, - dest_queue: bytes, message_id: bytes) -> None: - """ - Sends a prepared channel query to the mixnet without waiting for a reply. - - Args: - channel_id: The 16-bit channel ID. - payload: Channel query payload prepared by write_channel or read_channel. - dest_node: Destination node identity hash. - dest_queue: Destination recipient queue ID. - message_id: Message ID for reply correlation. 
- - Raises: - ThinClientOfflineError: If in offline mode (daemon not connected to mixnet). - """ - # Check if we're in offline mode - if not self._is_connected: - raise ThinClientOfflineError("cannot send_channel_query while not is_connected() - daemon not connected to mixnet") - - if not isinstance(payload, bytes): - self.logger.error("send_channel_query: type error: payload= must be bytes()") - payload = payload.encode('utf-8') - - # Create the SendChannelQuery structure (matches Rust implementation) - send_channel_query = { - "message_id": message_id, - "channel_id": channel_id, - "destination_id_hash": dest_node, - "recipient_queue_id": dest_queue, - "payload": payload, - } - - # Wrap in the Request structure - request = { - "send_channel_query": send_channel_query - } - - cbor_request = cbor2.dumps(request) - length_prefix = struct.pack('>I', len(cbor_request)) - length_prefixed_request = length_prefix + cbor_request - - try: - await self._send_all(length_prefixed_request) - self.logger.info(f"Channel query sent successfully for channel {channel_id}.") - except Exception as e: - self.logger.error(f"Error sending channel query: {e}") - raise - - async def close_channel(self, channel_id: int) -> None: - """ - Closes a pigeonhole channel and cleans up its resources. - This helps avoid running out of channel IDs by properly releasing them. - This operation is infallible - it sends the close request and returns immediately. - - Args: - channel_id: The 16-bit channel ID to close. - - Raises: - Exception: If the socket send operation fails. 
- """ - - request = { - "close_channel": { - "channel_id": channel_id - } - } - - cbor_request = cbor2.dumps(request) - length_prefix = struct.pack('>I', len(cbor_request)) - length_prefixed_request = length_prefix + cbor_request - - try: - # CloseChannel is infallible - fire and forget, no reply expected - await self._send_all(length_prefixed_request) - except Exception as e: - self.logger.error(f"Error sending close channel request: {e}") - raise - self.logger.info(f"CloseChannel request sent for channel {channel_id}.") diff --git a/katzenpost_thinclient/core.py b/katzenpost_thinclient/core.py new file mode 100644 index 0000000..7ff43c5 --- /dev/null +++ b/katzenpost_thinclient/core.py @@ -0,0 +1,1393 @@ +# SPDX-FileCopyrightText: Copyright (C) 2024 David Stainton +# SPDX-License-Identifier: AGPL-3.0-only + +""" +Katzenpost Python Thin Client - Core Module +============================================ + +This module provides the core functionality for the Katzenpost thin client, +including the ThinClient class, configuration, and helper utilities. +""" + +import socket +import struct +import random +import coloredlogs +import logging +import sys +import io +import os +import asyncio +import cbor2 +import pprintpp +import toml +import hashlib + +from typing import Tuple, Any, Dict, List, Callable + +# Pigeonhole Replica Error Codes (matching Go pigeonhole/errors.go) +# These are error codes returned by storage replicas, passed through by the daemon +# for the StartResendingEncryptedMessage API. 
+REPLICA_SUCCESS = 0 +REPLICA_ERROR_BOX_ID_NOT_FOUND = 1 +REPLICA_ERROR_INVALID_BOX_ID = 2 +REPLICA_ERROR_INVALID_SIGNATURE = 3 +REPLICA_ERROR_DATABASE_FAILURE = 4 +REPLICA_ERROR_INVALID_PAYLOAD = 5 +REPLICA_ERROR_STORAGE_FULL = 6 +REPLICA_ERROR_INTERNAL_ERROR = 7 +REPLICA_ERROR_INVALID_EPOCH = 8 +REPLICA_ERROR_REPLICATION_FAILED = 9 +REPLICA_ERROR_BOX_ALREADY_EXISTS = 10 + +# Thin Client Error Codes (matching Go implementation) +# These are error codes for thin client operations (separate from replica errors) +THIN_CLIENT_SUCCESS = 0 +THIN_CLIENT_ERROR_CONNECTION_LOST = 1 +THIN_CLIENT_ERROR_TIMEOUT = 2 +THIN_CLIENT_ERROR_INVALID_REQUEST = 3 +THIN_CLIENT_ERROR_INTERNAL_ERROR = 4 +THIN_CLIENT_ERROR_MAX_RETRIES = 5 +THIN_CLIENT_ERROR_INVALID_CHANNEL = 6 +THIN_CLIENT_ERROR_CHANNEL_NOT_FOUND = 7 +THIN_CLIENT_ERROR_PERMISSION_DENIED = 8 +THIN_CLIENT_ERROR_INVALID_PAYLOAD = 9 +THIN_CLIENT_ERROR_SERVICE_UNAVAILABLE = 10 +THIN_CLIENT_ERROR_DUPLICATE_CAPABILITY = 11 +THIN_CLIENT_ERROR_COURIER_CACHE_CORRUPTION = 12 +THIN_CLIENT_PROPAGATION_ERROR = 13 +THIN_CLIENT_ERROR_INVALID_WRITE_CAPABILITY = 14 +THIN_CLIENT_ERROR_INVALID_READ_CAPABILITY = 15 +THIN_CLIENT_ERROR_INVALID_RESUME_WRITE_CHANNEL_REQUEST = 16 +THIN_CLIENT_ERROR_INVALID_RESUME_READ_CHANNEL_REQUEST = 17 +THIN_CLIENT_IMPOSSIBLE_HASH_ERROR = 18 +THIN_CLIENT_IMPOSSIBLE_NEW_WRITE_CAP_ERROR = 19 +THIN_CLIENT_IMPOSSIBLE_NEW_STATEFUL_WRITER_ERROR = 20 +THIN_CLIENT_CAPABILITY_ALREADY_IN_USE = 21 +THIN_CLIENT_ERROR_MKEM_DECRYPTION_FAILED = 22 +THIN_CLIENT_ERROR_BACAP_DECRYPTION_FAILED = 23 +THIN_CLIENT_ERROR_START_RESENDING_CANCELLED = 24 + +def thin_client_error_to_string(error_code: int) -> str: + """Convert a thin client error code to a human-readable string.""" + error_messages = { + THIN_CLIENT_SUCCESS: "Success", + THIN_CLIENT_ERROR_CONNECTION_LOST: "Connection lost", + THIN_CLIENT_ERROR_TIMEOUT: "Timeout", + THIN_CLIENT_ERROR_INVALID_REQUEST: "Invalid request", + THIN_CLIENT_ERROR_INTERNAL_ERROR: "Internal error", + 
THIN_CLIENT_ERROR_MAX_RETRIES: "Maximum retries exceeded", + THIN_CLIENT_ERROR_INVALID_CHANNEL: "Invalid channel", + THIN_CLIENT_ERROR_CHANNEL_NOT_FOUND: "Channel not found", + THIN_CLIENT_ERROR_PERMISSION_DENIED: "Permission denied", + THIN_CLIENT_ERROR_INVALID_PAYLOAD: "Invalid payload", + THIN_CLIENT_ERROR_SERVICE_UNAVAILABLE: "Service unavailable", + THIN_CLIENT_ERROR_DUPLICATE_CAPABILITY: "Duplicate capability", + THIN_CLIENT_ERROR_COURIER_CACHE_CORRUPTION: "Courier cache corruption", + THIN_CLIENT_PROPAGATION_ERROR: "Propagation error", + THIN_CLIENT_ERROR_INVALID_WRITE_CAPABILITY: "Invalid write capability", + THIN_CLIENT_ERROR_INVALID_READ_CAPABILITY: "Invalid read capability", + THIN_CLIENT_ERROR_INVALID_RESUME_WRITE_CHANNEL_REQUEST: "Invalid resume write channel request", + THIN_CLIENT_ERROR_INVALID_RESUME_READ_CHANNEL_REQUEST: "Invalid resume read channel request", + THIN_CLIENT_IMPOSSIBLE_HASH_ERROR: "Impossible hash error", + THIN_CLIENT_IMPOSSIBLE_NEW_WRITE_CAP_ERROR: "Failed to create new write capability", + THIN_CLIENT_IMPOSSIBLE_NEW_STATEFUL_WRITER_ERROR: "Failed to create new stateful writer", + THIN_CLIENT_CAPABILITY_ALREADY_IN_USE: "Capability already in use", + THIN_CLIENT_ERROR_MKEM_DECRYPTION_FAILED: "MKEM decryption failed", + THIN_CLIENT_ERROR_BACAP_DECRYPTION_FAILED: "BACAP decryption failed", + THIN_CLIENT_ERROR_START_RESENDING_CANCELLED: "Start resending cancelled", + } + return error_messages.get(error_code, f"Unknown thin client error code: {error_code}") + + +# Pigeonhole Replica Exceptions (matching Go sentinel errors in thin/thin.go) +# These exceptions can be caught using isinstance() for specific error handling, +# similar to how Go uses errors.Is() with sentinel errors. + +class ReplicaError(Exception): + """Base class for all replica errors.""" + pass + +class BoxIDNotFoundError(ReplicaError): + """Box ID not found on the replica. 
Occurs when reading from a non-existent mailbox.""" + pass + +class InvalidBoxIDError(ReplicaError): + """Invalid box ID format.""" + pass + +class InvalidSignatureError(ReplicaError): + """Signature verification failed.""" + pass + +class DatabaseFailureError(ReplicaError): + """Replica encountered a database error.""" + pass + +class InvalidPayloadError(ReplicaError): + """Payload data is invalid.""" + pass + +class StorageFullError(ReplicaError): + """Replica's storage capacity has been exceeded.""" + pass + +class ReplicaInternalError(ReplicaError): + """Internal error on the replica.""" + pass + +class InvalidEpochError(ReplicaError): + """Epoch is invalid or expired.""" + pass + +class ReplicationFailedError(ReplicaError): + """Replication to other replicas failed.""" + pass + +class BoxAlreadyExistsError(ReplicaError): + """Box already contains data. Pigeonhole writes are immutable.""" + pass + +class MKEMDecryptionFailedError(Exception): + """MKEM envelope decryption failed with all replica keys.""" + pass + +class BACAPDecryptionFailedError(Exception): + """BACAP payload decryption or signature verification failed.""" + pass + +class StartResendingCancelledError(Exception): + """StartResendingEncryptedMessage operation was cancelled.""" + pass + + +def error_code_to_exception(error_code: int) -> Exception: + """ + Maps error codes to exception instances for StartResendingEncryptedMessage. + This matches Go's errorCodeToSentinel function in thin/pigeonhole.go. + + The daemon passes through pigeonhole replica error codes (1-9) for replica-level errors. + For other errors (thin client errors like decryption failures), specific exceptions are raised. 
+ """ + if error_code == REPLICA_SUCCESS: + return None + + # Pigeonhole replica error codes (from pigeonhole/errors.go) + if error_code == REPLICA_ERROR_BOX_ID_NOT_FOUND: # 1 + return BoxIDNotFoundError("box ID not found") + elif error_code == REPLICA_ERROR_INVALID_BOX_ID: # 2 + return InvalidBoxIDError("invalid box ID") + elif error_code == REPLICA_ERROR_INVALID_SIGNATURE: # 3 + return InvalidSignatureError("invalid signature") + elif error_code == REPLICA_ERROR_DATABASE_FAILURE: # 4 + return DatabaseFailureError("database failure") + elif error_code == REPLICA_ERROR_INVALID_PAYLOAD: # 5 + return InvalidPayloadError("invalid payload") + elif error_code == REPLICA_ERROR_STORAGE_FULL: # 6 + return StorageFullError("storage full") + elif error_code == REPLICA_ERROR_INTERNAL_ERROR: # 7 + return ReplicaInternalError("replica internal error") + elif error_code == REPLICA_ERROR_INVALID_EPOCH: # 8 + return InvalidEpochError("invalid epoch") + elif error_code == REPLICA_ERROR_REPLICATION_FAILED: # 9 + return ReplicationFailedError("replication failed") + elif error_code == REPLICA_ERROR_BOX_ALREADY_EXISTS: # 10 + return BoxAlreadyExistsError("box already exists") + + # Thin client decryption error codes + elif error_code == THIN_CLIENT_ERROR_MKEM_DECRYPTION_FAILED: # 22 + return MKEMDecryptionFailedError("MKEM decryption failed") + elif error_code == THIN_CLIENT_ERROR_BACAP_DECRYPTION_FAILED: # 23 + return BACAPDecryptionFailedError("BACAP decryption failed") + + # Thin client operation error codes + elif error_code == THIN_CLIENT_ERROR_START_RESENDING_CANCELLED: # 24 + return StartResendingCancelledError("start resending cancelled") + + # For other error codes, return a generic exception with the error string + else: + return Exception(thin_client_error_to_string(error_code)) + + +class ThinClientOfflineError(Exception): + pass + +# SURB_ID_SIZE is the size in bytes for the +# Katzenpost SURB ID. 
+SURB_ID_SIZE = 16 + +# MESSAGE_ID_SIZE is the size in bytes for an ID +# which is unique to the sent message. +MESSAGE_ID_SIZE = 16 + +# STREAM_ID_LENGTH is the length of a stream ID in bytes. +# Used for multi-call envelope encoding streams. +STREAM_ID_LENGTH = 16 + + +class Geometry: + """ + Geometry describes the geometry of a Sphinx packet. + + NOTE: You must not try to compose a Sphinx Geometry yourself. + It must be programmatically generated by Katzenpost + genconfig or gensphinx CLI utilities. + + We describe all the Sphinx Geometry attributes below, however + the only one you are interested in to faciliate your thin client + message bounds checking is UserForwardPayloadLength, which indicates + the maximum sized message that you can send to a mixnet service in + a single packet. + + Attributes: + PacketLength (int): The total length of a Sphinx packet in bytes. + NrHops (int): The number of hops; determines the header's structure. + HeaderLength (int): The total size of the Sphinx header in bytes. + RoutingInfoLength (int): The length of the routing information portion of the header. + PerHopRoutingInfoLength (int): The length of routing info for a single hop. + SURBLength (int): The length of a Single-Use Reply Block (SURB). + SphinxPlaintextHeaderLength (int): The length of the unencrypted plaintext header. + PayloadTagLength (int): The length of the tag used to authenticate the payload. + ForwardPayloadLength (int): The size of the full payload including padding and tag. + UserForwardPayloadLength (int): The usable portion of the payload intended for the recipient. + NextNodeHopLength (int): Derived from the expected maximum routing info block size. + SPRPKeyMaterialLength (int): The length of the key used for SPRP (Sphinx packet payload encryption). + NIKEName (str): Name of the NIKE scheme (if used). Mutually exclusive with KEMName. + KEMName (str): Name of the KEM scheme (if used). Mutually exclusive with NIKEName. 
+ """ + + def __init__(self, *, PacketLength:int, NrHops:int, HeaderLength:int, RoutingInfoLength:int, PerHopRoutingInfoLength:int, SURBLength:int, SphinxPlaintextHeaderLength:int, PayloadTagLength:int, ForwardPayloadLength:int, UserForwardPayloadLength:int, NextNodeHopLength:int, SPRPKeyMaterialLength:int, NIKEName:str='', KEMName:str='') -> None: + self.PacketLength = PacketLength + self.NrHops = NrHops + self.HeaderLength = HeaderLength + self.RoutingInfoLength = RoutingInfoLength + self.PerHopRoutingInfoLength = PerHopRoutingInfoLength + self.SURBLength = SURBLength + self.SphinxPlaintextHeaderLength = SphinxPlaintextHeaderLength + self.PayloadTagLength = PayloadTagLength + self.ForwardPayloadLength = ForwardPayloadLength + self.UserForwardPayloadLength = UserForwardPayloadLength + self.NextNodeHopLength = NextNodeHopLength + self.SPRPKeyMaterialLength = SPRPKeyMaterialLength + self.NIKEName = NIKEName + self.KEMName = KEMName + + def __str__(self) -> str: + return ( + f"PacketLength: {self.PacketLength}\n" + f"NrHops: {self.NrHops}\n" + f"HeaderLength: {self.HeaderLength}\n" + f"RoutingInfoLength: {self.RoutingInfoLength}\n" + f"PerHopRoutingInfoLength: {self.PerHopRoutingInfoLength}\n" + f"SURBLength: {self.SURBLength}\n" + f"SphinxPlaintextHeaderLength: {self.SphinxPlaintextHeaderLength}\n" + f"PayloadTagLength: {self.PayloadTagLength}\n" + f"ForwardPayloadLength: {self.ForwardPayloadLength}\n" + f"UserForwardPayloadLength: {self.UserForwardPayloadLength}\n" + f"NextNodeHopLength: {self.NextNodeHopLength}\n" + f"SPRPKeyMaterialLength: {self.SPRPKeyMaterialLength}\n" + f"NIKEName: {self.NIKEName}\n" + f"KEMName: {self.KEMName}" + ) + + +class PigeonholeGeometry: + """ + PigeonholeGeometry describes the geometry of a Pigeonhole envelope. + + This provides mathematically precise geometry calculations for the + Pigeonhole protocol using trunnel's fixed binary format. + + It supports 3 distinct use cases: + 1. 
Given MaxPlaintextPayloadLength → compute all envelope sizes + 2. Given precomputed Pigeonhole Geometry → derive accommodating Sphinx Geometry + 3. Given Sphinx Geometry constraint → derive optimal Pigeonhole Geometry + + Attributes: + max_plaintext_payload_length (int): The maximum usable plaintext payload size within a Box. + courier_query_read_length (int): The size of a CourierQuery containing a ReplicaRead. + courier_query_write_length (int): The size of a CourierQuery containing a ReplicaWrite. + courier_query_reply_read_length (int): The size of a CourierQueryReply containing a ReplicaReadReply. + courier_query_reply_write_length (int): The size of a CourierQueryReply containing a ReplicaWriteReply. + nike_name (str): The NIKE scheme name used in MKEM for encrypting to multiple storage replicas. + signature_scheme_name (str): The signature scheme used for BACAP (always "Ed25519"). + """ + + # Length prefix for padded payloads + LENGTH_PREFIX_SIZE = 4 + + def __init__( + self, + *, + max_plaintext_payload_length: int, + courier_query_read_length: int = 0, + courier_query_write_length: int = 0, + courier_query_reply_read_length: int = 0, + courier_query_reply_write_length: int = 0, + nike_name: str = "", + signature_scheme_name: str = "Ed25519" + ) -> None: + self.max_plaintext_payload_length = max_plaintext_payload_length + self.courier_query_read_length = courier_query_read_length + self.courier_query_write_length = courier_query_write_length + self.courier_query_reply_read_length = courier_query_reply_read_length + self.courier_query_reply_write_length = courier_query_reply_write_length + self.nike_name = nike_name + self.signature_scheme_name = signature_scheme_name + + def validate(self) -> None: + """ + Validates that the geometry has valid parameters. + + Raises: + ValueError: If the geometry is invalid. 
+ """ + if self.max_plaintext_payload_length <= 0: + raise ValueError("max_plaintext_payload_length must be positive") + if not self.nike_name: + raise ValueError("nike_name must be set") + if self.signature_scheme_name != "Ed25519": + raise ValueError("signature_scheme_name must be 'Ed25519'") + + def padded_payload_length(self) -> int: + """ + Returns the payload size after adding length prefix. + + Returns: + int: The padded payload length (max_plaintext_payload_length + 4). + """ + return self.max_plaintext_payload_length + self.LENGTH_PREFIX_SIZE + + def __str__(self) -> str: + return ( + f"PigeonholeGeometry:\n" + f" max_plaintext_payload_length: {self.max_plaintext_payload_length} bytes\n" + f" courier_query_read_length: {self.courier_query_read_length} bytes\n" + f" courier_query_write_length: {self.courier_query_write_length} bytes\n" + f" courier_query_reply_read_length: {self.courier_query_reply_read_length} bytes\n" + f" courier_query_reply_write_length: {self.courier_query_reply_write_length} bytes\n" + f" nike_name: {self.nike_name}\n" + f" signature_scheme_name: {self.signature_scheme_name}" + ) + + +class ConfigFile: + """ + ConfigFile represents everything loaded from a TOML file: + network, address, and geometry. 
+ """ + def __init__(self, network:str, address:str, geometry:Geometry) -> None: + self.network : str = network + self.address : str = address + self.geometry : Geometry = geometry + + @classmethod + def load(cls, toml_path:str) -> "ConfigFile": + with open(toml_path, 'r') as f: + data = toml.load(f) + network = data.get('Network') + assert isinstance(network, str) + address = data.get('Address') + assert isinstance(address, str) + geometry_data = data.get('SphinxGeometry') + assert isinstance(geometry_data, dict) + geometry : Geometry = Geometry(**geometry_data) + return cls(network, address, geometry) + + def __str__(self) -> str: + return ( + f"Network: {self.network}\n" + f"Address: {self.address}\n" + f"Geometry:\n{self.geometry}" + ) + + +def pretty_print_obj(obj: "Any") -> str: + """ + Pretty-print a Python object using indentation and return the formatted string. + + This function uses `pprintpp` to format complex data structures + (e.g., dictionaries, lists) in a readable, indented format. + + Args: + obj (Any): The object to pretty-print. + + Returns: + str: The pretty-printed representation of the object. + """ + pp = pprintpp.PrettyPrinter(indent=4) + return pp.pformat(obj) + +def blake2_256_sum(data:bytes) -> bytes: + return hashlib.blake2b(data, digest_size=32).digest() + +class ServiceDescriptor: + """ + Describes a mixnet service endpoint retrieved from the PKI document. + + A ServiceDescriptor encapsulates the necessary information for communicating + with a service on the mix network. The service node's identity public key's hash + is used as the destination address along with the service's queue ID. + + Attributes: + recipient_queue_id (bytes): The identifier of the recipient's queue on the mixnet. ("Kaetzchen.endpoint" in the PKI) + mix_descriptor (dict): A CBOR-decoded dictionary describing the mix node, + typically includes the 'IdentityKey' and other metadata. 
+ + Methods: + to_destination(): Returns a tuple of (provider_id_hash, recipient_queue_id), + where the provider ID is a 32-byte BLAKE2b hash of the IdentityKey. + """ + + def __init__(self, recipient_queue_id:bytes, mix_descriptor: "Dict[Any,Any]") -> None: + self.recipient_queue_id = recipient_queue_id + self.mix_descriptor = mix_descriptor + + def to_destination(self) -> "Tuple[bytes,bytes]": + "provider identity key hash and queue id" + provider_id_hash = blake2_256_sum(self.mix_descriptor['IdentityKey']) + return (provider_id_hash, self.recipient_queue_id) + +def find_services(capability:str, doc:"Dict[str,Any]") -> "List[ServiceDescriptor]": + """ + Search the PKI document for services supporting the specified capability. + + This function iterates over all service nodes in the PKI document, + deserializes each CBOR-encoded node, and looks for advertised capabilities. + If a service provides the requested capability, it is returned as a + `ServiceDescriptor`. + + Args: + capability (str): The name of the capability to search for (e.g., "echo"). + doc (dict): The decoded PKI document as a Python dictionary, + which must include a "ServiceNodes" key containing CBOR-encoded descriptors. + + Returns: + List[ServiceDescriptor]: A list of matching service descriptors that advertise the capability. + + Raises: + KeyError: If the 'ServiceNodes' field is missing from the PKI document. + """ + services = [] + for node in doc['ServiceNodes']: + mynode = cbor2.loads(node) + + # Check if the node has services in Kaetzchen field (fixed from omitempty) + if 'Kaetzchen' in mynode: + for cap, details in mynode['Kaetzchen'].items(): + if cap == capability: + service_desc = ServiceDescriptor( + recipient_queue_id=bytes(details['endpoint'], 'utf-8'), # why is this bytes when it's string in PKI? 
+ mix_descriptor=mynode + ) + services.append(service_desc) + return services + + +class Config: + """ + Configuration object for the ThinClient containing connection details and event callbacks. + + The Config class loads network configuration from a TOML file and provides optional + callback functions that are invoked when specific events occur during client operation. + + Attributes: + network (str): Network type ('tcp', 'unix', etc.) + address (str): Network address (host:port for TCP, path for Unix sockets) + geometry (Geometry): Sphinx packet geometry parameters + on_connection_status (callable): Callback for connection status changes + on_new_pki_document (callable): Callback for new PKI documents + on_message_sent (callable): Callback for message transmission confirmations + on_message_reply (callable): Callback for received message replies + + Example: + >>> def handle_reply(event): + ... # Process the received reply + ... payload = event['payload'] + >>> + >>> config = Config("client.toml", on_message_reply=handle_reply) + >>> client = ThinClient(config) + """ + + def __init__(self, filepath:str, + on_connection_status:"Callable|None"=None, + on_new_pki_document:"Callable|None"=None, + on_message_sent:"Callable|None"=None, + on_message_reply:"Callable|None"=None) -> None: + """ + Initialize the Config object. + + Args: + filepath (str): Path to the TOML config file containing network, address, and geometry. + + on_connection_status (callable, optional): Callback invoked when the daemon's connection + status to the mixnet changes. 
The callback receives a single argument: + + - event (dict): Connection status event with keys: + - 'is_connected' (bool): True if daemon is connected to mixnet, False otherwise + - 'err' (str, optional): Error message if connection failed, empty string if no error + + Example: ``{'is_connected': True, 'err': ''}`` + + on_new_pki_document (callable, optional): Callback invoked when a new PKI document + is received from the mixnet. The callback receives a single argument: + + - event (dict): PKI document event with keys: + - 'payload' (bytes): CBOR-encoded PKI document data stripped of signatures + + Example: ``{'payload': b'\\xa5\\x64Epoch\\x00...'}`` + + on_message_sent (callable, optional): Callback invoked when a message has been + successfully transmitted to the mixnet. The callback receives a single argument: + + - event (dict): Message sent event with keys: + - 'message_id' (bytes): 16-byte unique identifier for the sent message + - 'surbid' (bytes, optional): SURB ID if message was sent with SURB, None otherwise + - 'sent_at' (str): ISO timestamp when message was sent + - 'reply_eta' (float): Expected round-trip time in seconds for reply + - 'err' (str, optional): Error message if sending failed, empty string if successful + + Example: ``{'message_id': b'\\x01\\x02...', 'surbid': b'\\xaa\\xbb...', 'sent_at': '2024-01-01T12:00:00Z', 'reply_eta': 30.5, 'err': ''}`` + + on_message_reply (callable, optional): Callback invoked when a reply is received + for a previously sent message. 
The callback receives a single argument: + + - event (dict): Message reply event with keys: + - 'message_id' (bytes): 16-byte identifier matching the original message + - 'surbid' (bytes, optional): SURB ID if reply used SURB, None otherwise + - 'payload' (bytes): Reply payload data from the service + - 'reply_index' (int, optional): Index of reply used (relevant for channel reads) + - 'error_code' (int): Error code indicating success (0) or specific failure condition + + Example: ``{'message_id': b'\\x01\\x02...', 'surbid': b'\\xaa\\xbb...', 'payload': b'echo response', 'reply_index': 0, 'error_code': 0}`` + + Note: + All callbacks are optional. If not provided, the corresponding events will be ignored. + Callbacks should be lightweight and non-blocking as they are called from the client's + event processing loop. + """ + + cfgfile = ConfigFile.load(filepath) + + self.network = cfgfile.network + self.address = cfgfile.address + self.geometry = cfgfile.geometry + + self.on_connection_status = on_connection_status + self.on_new_pki_document = on_new_pki_document + self.on_message_sent = on_message_sent + self.on_message_reply = on_message_reply + + async def handle_connection_status_event(self, event: asyncio.Event) -> None: + if self.on_connection_status: + return await self.on_connection_status(event) + + async def handle_new_pki_document_event(self, event: asyncio.Event) -> None: + if self.on_new_pki_document: + await self.on_new_pki_document(event) + + async def handle_message_sent_event(self, event: asyncio.Event) -> None: + if self.on_message_sent: + await self.on_message_sent(event) + + async def handle_message_reply_event(self, event: asyncio.Event) -> None: + if self.on_message_reply: + await self.on_message_reply(event) + + +class ThinClient: + """ + A minimal Katzenpost Python thin client for communicating with the local + Katzenpost client daemon over a UNIX or TCP socket. 
+ + The thin client is responsible for: + - Establishing a connection to the client daemon. + - Receiving and parsing PKI documents. + - Sending messages to mixnet services (with or without SURBs). + - Handling replies and events via user-defined callbacks. + + All cryptographic operations are handled by the daemon, not by this client. + """ + + def __init__(self, config:Config) -> None: + """ + Initialize the thin client with the given configuration. + + Args: + config (Config): The configuration object containing socket details and callbacks. + + Raises: + RuntimeError: If the network type is not recognized or config is incomplete. + """ + self.pki_doc : Dict[Any,Any] | None = None + self.config = config + self.reply_received_event = asyncio.Event() + self.channel_reply_event = asyncio.Event() + self.channel_reply_data : Dict[Any,Any] | None = None + # For handling async read channel responses with message ID correlation + self.pending_read_channels : Dict[bytes,asyncio.Event] = {} # message_id -> asyncio.Event + self.read_channel_responses : Dict[bytes,bytes] = {} # message_id -> payload + self._is_connected : bool = False # Track connection state + self._stopping : bool = False # Track shutdown state to suppress expected errors + + # Mutexes to serialize socket send/recv operations: + self._send_lock = asyncio.Lock() + self._recv_lock = asyncio.Lock() + + # Letterbox for each response associated (by query_id) with a request. 
+ self.response_queues : Dict[bytes, asyncio.Queue[Dict[str,Any]]] = {} # (query_id|message_id) -> Queue + self.ack_queues : Dict[bytes, asyncio.Queue[Dict[str,Any]]] = {} # (query_id|message_id) -> Queue + + # Channel query message ID correlation (for send_channel_query_await_reply) + self.pending_channel_message_queries : Dict[bytes, asyncio.Event] = {} # message_id -> Event + self.channel_message_query_responses : Dict[bytes, bytes] = {} # message_id -> payload + + # For message ID-based reply matching (old channel API) + self._expected_message_id : bytes | None = None + self._received_reply_payload : bytes | None = None + self._reply_received_for_message_id : asyncio.Event | None = None + self.logger = logging.getLogger('thinclient') + self.logger.setLevel(logging.DEBUG) + # Only add handler if none exists to avoid duplicate log messages + # XXX: commented out because it did in fact log twice: + #if not self.logger.handlers: + # handler = logging.StreamHandler(sys.stderr) + # self.logger.addHandler(handler) + + if self.config.network is None: + raise RuntimeError("config.network is None") + + network: str = self.config.network.lower() + self.server_addr : str | Tuple[str,int] + if network.lower().startswith("tcp"): + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + host, port_str = self.config.address.split(":") + self.server_addr = (host, int(port_str)) + elif network.lower().startswith("unix"): + self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + + if self.config.address.startswith("@"): + # Abstract UNIX socket: leading @ means first byte is null + abstract_name = self.config.address[1:] + self.server_addr = f"\0{abstract_name}" + + # Bind to a unique abstract socket for this client + random_bytes = [random.randint(0, 255) for _ in range(16)] + hex_string = ''.join(format(byte, '02x') for byte in random_bytes) + client_abstract = f"\0katzenpost_python_thin_client_{hex_string}" + self.socket.bind(client_abstract) + else: + # 
Filesystem UNIX socket + self.server_addr = self.config.address + + self.socket.setblocking(False) + else: + raise RuntimeError(f"Unknown network type: {self.config.network}") + + self.socket.setblocking(False) + + + async def start(self, loop:asyncio.AbstractEventLoop) -> None: + """ + Start the thin client: establish connection to the daemon, read initial events, + and begin the background event loop. + + Args: + loop (asyncio.AbstractEventLoop): The running asyncio event loop. + + Exceptions: + BrokenPipeError + """ + self.logger.debug("connecting to daemon") + server_addr : str | Tuple[str,int] = '' + + if self.config.network.lower().startswith("tcp"): + host, port_str = self.config.address.split(":") + server_addr = (host, int(port_str)) + elif self.config.network.lower().startswith("unix"): + if self.config.address.startswith("@"): + server_addr = '\0' + self.config.address[1:] + else: + server_addr = self.config.address + else: + raise RuntimeError(f"Unknown network type: {self.config.network}") + + await loop.sock_connect(self.socket, server_addr) + + # 1st message is always a status event + response = await self.recv(loop) + assert response is not None + assert response["connection_status_event"] is not None + await self.handle_response(response) + + # 2nd message is always a new pki doc event + #response = await self.recv(loop) + #assert response is not None + #assert response["new_pki_document_event"] is not None, response + #await self.handle_response(response) + + # Start the read loop as a background task + self.logger.debug("starting read loop") + self.task = loop.create_task(self.worker_loop(loop)) + def handle_loop_err(task): + # Check stopping flag first - if we're shutting down, all errors are expected + if self._stopping: + return + try: + result = task.result() + except asyncio.CancelledError: + # Task was cancelled during shutdown - expected behavior + pass + except (BrokenPipeError, ConnectionResetError, OSError) as e: + # Connection errors 
can occur due to race conditions during shutdown + # Double-check _stopping flag as it may have been set after the exception + if not self._stopping: + self.logger.error(f"Unexpected connection error in worker loop: {e}") + except Exception: + import traceback + traceback.print_exc() + raise + self.task.add_done_callback(handle_loop_err) + + def get_config(self) -> Config: + """ + Returns the current configuration object. + + Returns: + Config: The client configuration in use. + """ + return self.config + + def is_connected(self) -> bool: + """ + Returns True if the daemon is connected to the mixnet. + + Returns: + bool: True if connected, False if in offline mode. + """ + return self._is_connected + + def stop(self) -> None: + """ + Gracefully shut down the client and close its socket. + """ + self.logger.debug("closing connection to daemon") + self._stopping = True # Set flag to suppress expected BrokenPipeError + self.socket.close() + self.task.cancel() + + async def _send_all(self, data: bytes) -> None: + """ + Send all data using async socket operations with mutex protection. + + This method uses a mutex to prevent race conditions when multiple + coroutines try to send data over the same socket simultaneously. + + Args: + data (bytes): Data to send. + """ + async with self._send_lock: + loop = asyncio.get_running_loop() + await loop.sock_sendall(self.socket, data) + + async def __recv_exactly(self, total:int, loop:asyncio.AbstractEventLoop) -> bytes: + "receive exactly (total) bytes or die trying raising BrokenPipeError" + buf = bytearray(total) + remain = memoryview(buf) + while len(remain): + if not (nread := await loop.sock_recv_into(self.socket, remain)): + raise BrokenPipeError + remain = remain[nread:] + return buf + + async def recv(self, loop:asyncio.AbstractEventLoop) -> "Dict[Any,Any]": + """ + Receive a CBOR-encoded message from the daemon. + + Args: + loop (asyncio.AbstractEventLoop): Event loop to use for socket reads. 
+ + Returns: + dict: Decoded CBOR response from the daemon. + + Raises: + BrokenPipeError: If connection fails + ValueError: If message framing fails. + """ + async with self._recv_lock: + length_prefix = await self.__recv_exactly(4, loop) + message_length = struct.unpack('>I', length_prefix)[0] + raw_data = await self.__recv_exactly(message_length, loop) + try: + response = cbor2.loads(raw_data) + except cbor2.CBORDecodeValueError as e: + self.logger.error(f"{e}") + raise ValueError(f"{e}") + response = {k:v for k,v in response.items() if v} # filter empty KV pairs + if not (set(response.keys()) & {'new_pki_document_event'}): + self.logger.debug(f"Received daemon response: [{len(raw_data)}] {type(response)} {response}") + return response + + async def worker_loop(self, loop:asyncio.events.AbstractEventLoop) -> None: + """ + Background task that listens for events and dispatches them. + """ + self.logger.debug("read loop start") + while True: + try: + response = await self.recv(loop) + except asyncio.CancelledError: + # Handle cancellation of the read loop - expected during shutdown + self.logger.debug("worker_loop cancelled") + break + except (BrokenPipeError, ConnectionResetError, OSError) as e: + # Connection errors during shutdown are expected + if self._stopping: + self.logger.debug(f"Connection closed during shutdown: {e}") + break + else: + self.logger.error(f"Unexpected connection error: {e}") + raise + except Exception as e: + self.logger.error(f"Error reading from socket: {e}") + raise + else: + def handle_response_err(task): + try: + result = task.result() + except Exception: + import traceback + traceback.print_exc() + raise + resp = asyncio.create_task(self.handle_response(response)) + resp.add_done_callback(handle_response_err) + + def parse_status(self, event: "Dict[str,Any]") -> None: + """ + Parse a connection status event and update connection state. 
+ """ + self.logger.debug("parse status") + assert event is not None + + self._is_connected = event.get("is_connected", False) + + if self._is_connected: + self.logger.debug("Daemon is connected to mixnet - full functionality available") + else: + self.logger.info("Daemon is not connected to mixnet - entering offline mode (channel operations will work)") + + self.logger.debug("parse status success") + + def pki_document(self) -> "Dict[str,Any] | None": + """ + Retrieve the latest PKI document received. + + Returns: + dict: Parsed CBOR PKI document. + """ + return self.pki_doc + + def parse_pki_doc(self, event: "Dict[str,Any]") -> None: + """ + Parse and store a new PKI document received from the daemon. + """ + self.logger.debug("parse pki doc") + assert event is not None + assert event["payload"] is not None + raw_pki_doc = cbor2.loads(event["payload"]) + self.pki_doc = raw_pki_doc + self.logger.debug("parse pki doc success") + + def get_services(self, capability:str) -> "List[ServiceDescriptor]": + """ + Look up all services in the PKI that advertise a given capability. + + Args: + capability (str): Capability name (e.g., "echo"). + + Returns: + list[ServiceDescriptor]: Matching services.xsy + + Raises: + Exception: If PKI is missing or no services match. + """ + doc = self.pki_document() + if doc == None: + raise Exception("pki doc is nil") + descriptors = find_services(capability, doc) + if not descriptors: + raise Exception("service not found in pki doc") + return descriptors + + def get_service(self, service_name:str) -> ServiceDescriptor: + """ + Select a random service matching a capability. + + Args: + service_name (str): The capability name (e.g., "echo"). + + Returns: + ServiceDescriptor: One of the matching services. + """ + service_descriptors = self.get_services(service_name) + return random.choice(service_descriptors) + + @staticmethod + def new_message_id() -> bytes: + """ + Generate a new 16-byte message ID for use with ARQ sends. 
+ + Returns: + bytes: Random 16-byte identifier. + """ + return os.urandom(MESSAGE_ID_SIZE) + + def new_surb_id(self) -> bytes: + """ + Generate a new 16-byte SURB ID for reply-capable sends. + + Returns: + bytes: Random 16-byte identifier. + """ + return os.urandom(SURB_ID_SIZE) + + def new_query_id(self) -> bytes: + """ + Generate a new 16-byte query ID for channel API operations. + + Returns: + bytes: Random 16-byte identifier. + """ + return os.urandom(16) + + @staticmethod + def new_stream_id() -> bytes: + """ + Generate a new 16-byte stream ID for copy stream operations. + + Stream IDs are used to identify encoder instances for multi-call + envelope encoding streams. All calls for the same stream must use + the same stream ID. + + Returns: + bytes: Random 16-byte stream identifier. + """ + return os.urandom(STREAM_ID_LENGTH) + + async def _send_and_wait(self, *, query_id:bytes, request: Dict[str, Any]) -> Dict[str, Any]: + cbor_request = cbor2.dumps(request) + length_prefix = struct.pack('>I', len(cbor_request)) + length_prefixed_request = length_prefix + cbor_request + assert query_id not in self.response_queues + self.response_queues[query_id] = asyncio.Queue(maxsize=1) + request_type = list(request.keys())[0] + try: + await self._send_all(length_prefixed_request) + self.logger.info(f"{request_type} request sent.") + reply = await self.response_queues[query_id].get() + self.logger.info(f"{request_type} response received.") + # TODO error handling, see _wait_for_channel_reply + return reply + except asyncio.CancelledError: + self.logger.info("{request_type} task cancelled.") + raise + finally: + del self.response_queues[query_id] + + async def handle_response(self, response: "Dict[str,Any]") -> None: + """ + Dispatch a parsed CBOR response to the appropriate handler or callback. 
+ """ + assert response is not None + + if response.get("connection_status_event") is not None: + self.logger.debug("connection status event") + self.parse_status(response["connection_status_event"]) + await self.config.handle_connection_status_event(response["connection_status_event"]) + return + if response.get("new_pki_document_event") is not None: + self.logger.debug("new pki doc event") + self.parse_pki_doc(response["new_pki_document_event"]) + await self.config.handle_new_pki_document_event(response["new_pki_document_event"]) + return + if response.get("message_sent_event") is not None: + self.logger.debug("message sent event") + await self.config.handle_message_sent_event(response["message_sent_event"]) + return + if response.get("message_reply_event") is not None: + self.logger.debug("message reply event") + reply = response["message_reply_event"] + + # Check if this reply matches our expected message ID for old channel operations + if hasattr(self, '_expected_message_id') and self._expected_message_id is not None: + reply_message_id = reply.get("message_id") + if reply_message_id is not None and reply_message_id == self._expected_message_id: + self.logger.debug(f"Received matching MessageReplyEvent for message_id {reply_message_id.hex()[:16]}...") + # Handle error in reply using error_code field + error_code = reply.get("error_code", 0) + self.logger.debug(f"MessageReplyEvent: error_code={error_code}") + if error_code != 0: + error_msg = thin_client_error_to_string(error_code) + self.logger.debug(f"Reply contains error: {error_msg} (error code {error_code})") + self._received_reply_payload = None + else: + payload = reply.get("payload") + if payload is None: + self._received_reply_payload = b"" + else: + self._received_reply_payload = payload + self.logger.debug(f"Reply contains {len(self._received_reply_payload)} bytes of payload") + + # Signal that we received the matching reply + if hasattr(self, '_reply_received_for_message_id'): + 
self._reply_received_for_message_id.set() + return + else: + if reply_message_id is not None: + self.logger.debug(f"Received MessageReplyEvent with mismatched message_id (expected {self._expected_message_id.hex()[:16]}..., got {reply_message_id.hex()[:16]}...), ignoring") + else: + self.logger.debug("Received MessageReplyEvent with nil message_id, ignoring") + + # Fall back to original behavior for non-channel operations + self.reply_received_event.set() + await self.config.handle_message_reply_event(reply) + return + # Handle channel query events (for send_channel_query_await_reply), this is the ACK from the local clientd (not courier) + if response.get("channel_query_sent_event") is not None: + # channel_query_sent_event': {'message_id': b'\xb7\xd5\xaeG\x8a\xc4\x96\x99|M\x89c\x90\xc3\xd4\x1f', 'sent_at': 1758485828, 'reply_eta': 1179000000, 'error_code': 0}, + self.logger.debug("channel_query_sent_event") + event = response["channel_query_sent_event"] + message_id = event.get("message_id") + if message_id is not None: + # Check for error in sent event + error_code = event.get("error_code", 0) + if error_code != 0: + # Store error for the waiting coroutine + if message_id in self.pending_channel_message_queries: + self.channel_message_query_responses[message_id] = f"Channel query send failed with error code: {error_code}".encode() + self.pending_channel_message_queries[message_id].set() + # Continue waiting for the reply (don't return here) + return + + # Handle old channel API replies + if response.get("create_write_channel_reply") is not None: + self.logger.debug("channel create_write_channel_reply event") + self.channel_reply_data = response + self.channel_reply_event.set() + return + + if response.get("create_read_channel_reply") is not None: + self.logger.debug("channel create_read_channel_reply event") + self.channel_reply_data = response + self.channel_reply_event.set() + return + + if response.get("write_channel_reply") is not None: + 
self.logger.debug("channel write_channel_reply event") + self.channel_reply_data = response + self.channel_reply_event.set() + return + + if response.get("read_channel_reply") is not None: + self.logger.debug("channel read_channel_reply event") + self.channel_reply_data = response + self.channel_reply_event.set() + return + + if response.get("copy_channel_reply") is not None: + self.logger.debug("channel copy_channel_reply event") + self.channel_reply_data = response + self.channel_reply_event.set() + return + + # Handle newer channel query reply events + if query_ack := response.get("channel_query_reply_event", None): + # this is the ACK from the courier + self.logger.debug("channel_query_reply_event") + event = response["channel_query_reply_event"] + message_id = event.get("message_id") + + if message_id is None: + self.logger.error("channel_query_reply_event without message_id") + return + + # TODO wait why are we storing these indefinitely if we don't really care about them?? + if error_code := event.get("error_code", 0): + error_msg = f"Channel query failed with error code: {error_code}".encode() + self.channel_message_query_responses[message_id] = error_msg + else: + # Extract the payload + payload = event.get("payload", b"") + self.channel_message_query_responses[message_id] = payload + + if (queue := self.ack_queues.get(message_id, None)): + self.logger.debug(f"ack_queues: populated with message_id {message_id.hex()}") + asyncio.create_task(queue.put(query_ack)) + else: + self.logger.error(f"channel_query_reply_event for message_id {message_id.hex()}, but there is no listener") + + + # Signal the waiting coroutine + if message_id in self.pending_channel_message_queries: + self.pending_channel_message_queries[message_id].set() + return + + for reply_type, reply in response.items(): + if not reply: + continue + self.logger.debug(f"channel {reply_type} event") + if not reply_type.endswith("_reply") or not (query_id := reply.get("query_id", None)): + 
self.logger.debug(f"{reply_type} is not a reply, or can't get query_id") + # 'create_read_channel_reply': {'query_id': None, 'channel_id': 0, 'error_code': 21}, + # DEBUG [thinclient] channel_query_reply_event is not a reply, or can't get query_id + # REPLY {'message_id': b'\xfd\xc0\x9d\xcfh\xa3\x88X[\xab\xa8\xd3\x1b\x8b\x15\xd1', 'payload': b'', 'reply_index': None, 'error_code': 0} + # SELF.RESPONSE_QUEUES {} + print("REPLY", reply) + print('SELF.RESPONSE_QUEUES', self.response_queues) + continue + if not (queue := self.response_queues.get(query_id, None)): + self.logger.debug(f"query_id for {reply_type} has no listener") + continue + # avoid blocking recv loop: + asyncio.create_task(queue.put(reply)) + + + + async def send_message_without_reply(self, payload:bytes|str, dest_node:bytes, dest_queue:bytes) -> None: + """ + Send a fire-and-forget message with no SURB or reply handling. + This method requires mixnet connectivity. + + Args: + payload (bytes or str): Message payload. + dest_node (bytes): Destination node identity hash. + dest_queue (bytes): Destination recipient queue ID. + + Raises: + ThinClientOfflineError: If in offline mode (daemon not connected to mixnet). 
+ """ + # Check if we're in offline mode + if not self._is_connected: + raise ThinClientOfflineError("cannot send_message_without_reply in offline mode - daemon not connected to mixnet") + + if not isinstance(payload, bytes): + payload = payload.encode('utf-8') # Encoding the string to bytes + + # Create the SendMessage structure + send_message = { + "id": None, # No ID for fire-and-forget messages + "with_surb": False, + "surbid": None, # No SURB ID for fire-and-forget messages + "destination_id_hash": dest_node, + "recipient_queue_id": dest_queue, + "payload": payload, + } + + # Wrap in the new Request structure + request = { + "send_message": send_message + } + + cbor_request = cbor2.dumps(request) + length_prefix = struct.pack('>I', len(cbor_request)) + length_prefixed_request = length_prefix + cbor_request + try: + await self._send_all(length_prefixed_request) + self.logger.info("Message sent successfully.") + except Exception as e: + self.logger.error(f"Error sending message: {e}") + + async def send_message(self, surb_id:bytes, payload:bytes|str, dest_node:bytes, dest_queue:bytes) -> None: + """ + Send a message using a SURB to allow the recipient to send a reply. + This method requires mixnet connectivity. + + Args: + surb_id (bytes): SURB identifier for reply correlation. + payload (bytes or str): Message payload. + dest_node (bytes): Destination node identity hash. + dest_queue (bytes): Destination recipient queue ID. + + Raises: + ThinClientOfflineError: If in offline mode (daemon not connected to mixnet). 
+ """ + # Check if we're in offline mode + if not self._is_connected: + raise ThinClientOfflineError("cannot send message in offline mode - daemon not connected to mixnet") + + if not isinstance(payload, bytes): + payload = payload.encode('utf-8') # Encoding the string to bytes + + # Create the SendMessage structure + send_message = { + "id": None, # No ID for regular messages + "with_surb": True, + "surbid": surb_id, + "destination_id_hash": dest_node, + "recipient_queue_id": dest_queue, + "payload": payload, + } + + # Wrap in the new Request structure + request = { + "send_message": send_message + } + + cbor_request = cbor2.dumps(request) + length_prefix = struct.pack('>I', len(cbor_request)) + length_prefixed_request = length_prefix + cbor_request + try: + await self._send_all(length_prefixed_request) + self.logger.info("Message sent successfully.") + except Exception as e: + self.logger.error(f"Error sending message: {e}") + + async def send_channel_query(self, channel_id:int, payload:bytes, dest_node:bytes, dest_queue:bytes, message_id:"bytes|None"=None): + """ + Send a channel query (prepared by write_channel or read_channel) to the mixnet. + This method sets the ChannelID inside the Request for proper channel handling. + This method requires mixnet connectivity. + + Args: + channel_id (int): The 16-bit channel ID. + payload (bytes): Channel query payload prepared by write_channel or read_channel. + dest_node (bytes): Destination node identity hash. + dest_queue (bytes): Destination recipient queue ID. + message_id (bytes, optional): Message ID for reply correlation. If None, generates a new one. + + Returns: + bytes: The message ID used for this query (either provided or generated). + + Raises: + RuntimeError: If in offline mode (daemon not connected to mixnet). 
+ """ + # Check if we're in offline mode + if not self._is_connected: + raise RuntimeError("cannot send channel query in offline mode - daemon not connected to mixnet") + + if not isinstance(payload, bytes): + payload = payload.encode('utf-8') # Encoding the string to bytes + + # Generate message ID if not provided, and SURB ID + if message_id is None: + message_id = self.new_message_id() + self.logger.debug(f"send_channel_query: Generated message_id {message_id.hex()[:16]}...") + else: + self.logger.debug(f"send_channel_query: Using provided message_id {message_id.hex()[:16]}...") + + surb_id = self.new_surb_id() + + # Create the SendMessage structure with ChannelID + + send_message = { + "channel_id": channel_id, # This is the key difference from send_message + "id": message_id, # Use generated message_id for reply correlation + "with_surb": True, + "surbid": surb_id, + "destination_id_hash": dest_node, + "recipient_queue_id": dest_queue, + "payload": payload, + } + + # Wrap in the new Request structure + request = { + "send_message": send_message + } + + cbor_request = cbor2.dumps(request) + length_prefix = struct.pack('>I', len(cbor_request)) + length_prefixed_request = length_prefix + cbor_request + try: + await self._send_all(length_prefixed_request) + self.logger.info(f"Channel query sent successfully for channel {channel_id}.") + return message_id + except Exception as e: + self.logger.error(f"Error sending channel query: {e}") + raise + + async def send_reliable_message(self, message_id:bytes, payload:bytes|str, dest_node:bytes, dest_queue:bytes) -> None: + """ + Send a reliable message using an ARQ mechanism and message ID. + This method requires mixnet connectivity. + + Args: + message_id (bytes): Message ID for reply correlation. + payload (bytes or str): Message payload. + dest_node (bytes): Destination node identity hash. + dest_queue (bytes): Destination recipient queue ID. 
+ + Raises: + ThinClientOfflineError: If in offline mode (daemon not connected to mixnet). + """ + # Check if we're in offline mode + if not self._is_connected: + raise ThinClientOfflineError("cannot send reliable message in offline mode - daemon not connected to mixnet") + + if not isinstance(payload, bytes): + payload = payload.encode('utf-8') # Encoding the string to bytes + + # Create the SendARQMessage structure + send_arq_message = { + "id": message_id, + "with_surb": True, + "surbid": None, # ARQ messages don't use SURB IDs directly + "destination_id_hash": dest_node, + "recipient_queue_id": dest_queue, + "payload": payload, + } + + # Wrap in the new Request structure + request = { + "send_arq_message": send_arq_message + } + + cbor_request = cbor2.dumps(request) + length_prefix = struct.pack('>I', len(cbor_request)) + length_prefixed_request = length_prefix + cbor_request + try: + await self._send_all(length_prefixed_request) + self.logger.info("Message sent successfully.") + except Exception as e: + self.logger.error(f"Error sending message: {e}") + + def pretty_print_pki_doc(self, doc: "Dict[str,Any]") -> None: + """ + Pretty-print a parsed PKI document with fully decoded CBOR nodes. + + Args: + doc (dict): Raw PKI document from the daemon. 
+ """ + assert doc is not None + assert doc['GatewayNodes'] is not None + assert doc['ServiceNodes'] is not None + assert doc['Topology'] is not None + + new_doc = doc + gateway_nodes = [] + service_nodes = [] + topology = [] + + for gateway_cert_blob in doc['GatewayNodes']: + gateway_cert = cbor2.loads(gateway_cert_blob) + gateway_nodes.append(gateway_cert) + + for service_cert_blob in doc['ServiceNodes']: + service_cert = cbor2.loads(service_cert_blob) + service_nodes.append(service_cert) + + for layer in doc['Topology']: + for mix_desc_blob in layer: + mix_cert = cbor2.loads(mix_desc_blob) + topology.append(mix_cert) # flatten, no prob, relax + + new_doc['GatewayNodes'] = gateway_nodes + new_doc['ServiceNodes'] = service_nodes + new_doc['Topology'] = topology + pretty_print_obj(new_doc) + + async def await_message_reply(self) -> None: + """ + Asynchronously block until a reply is received from the daemon. + """ + await self.reply_received_event.wait() + diff --git a/katzenpost_thinclient/legacy.py b/katzenpost_thinclient/legacy.py new file mode 100644 index 0000000..40941f5 --- /dev/null +++ b/katzenpost_thinclient/legacy.py @@ -0,0 +1,455 @@ +# SPDX-FileCopyrightText: Copyright (C) 2024 David Stainton +# SPDX-License-Identifier: AGPL-3.0-only + +""" +Katzenpost Python Thin Client - Legacy Channel API +=================================================== + +This module provides the old channel-based Pigeonhole API methods. +These methods use the channel_id pattern and are maintained for +backward compatibility. 
+""" + +import asyncio +import struct +import cbor2 + +from typing import Tuple, Any, Dict + +from .core import thin_client_error_to_string + + +class WriteChannelReply: + """Reply from WriteChannel operation, matching Rust WriteChannelReply.""" + + def __init__(self, send_message_payload: bytes, current_message_index: bytes, + next_message_index: bytes, envelope_descriptor: bytes, envelope_hash: bytes): + self.send_message_payload = send_message_payload + self.current_message_index = current_message_index + self.next_message_index = next_message_index + self.envelope_hash = envelope_hash + self.envelope_descriptor = envelope_descriptor + + +class ReadChannelReply: + """Reply from ReadChannel operation, matching Rust ReadChannelReply.""" + + def __init__(self, send_message_payload: bytes, current_message_index: bytes, + next_message_index: bytes, reply_index: "int|None", + envelope_descriptor: bytes, envelope_hash: bytes): + self.send_message_payload = send_message_payload + self.current_message_index = current_message_index + self.next_message_index = next_message_index + self.reply_index = reply_index + self.envelope_descriptor = envelope_descriptor + self.envelope_hash = envelope_hash + + +# Legacy channel API methods - these will be attached to ThinClient class + +async def create_write_channel(self, write_cap: "bytes|None"=None, message_box_index: "bytes|None"=None) -> "Tuple[int,bytes,bytes,bytes]": + """ + Create a new pigeonhole write channel. + + Args: + write_cap: Optional WriteCap for resuming an existing channel. + message_box_index: Optional MessageBoxIndex for resuming from a specific position. + + Returns: + tuple: (channel_id, read_cap, write_cap, next_message_index) where: + - channel_id is 16-bit channel ID + - read_cap is the read capability for sharing + - write_cap is the write capability for persistence + - next_message_index is the current position for crash consistency + + Raises: + Exception: If the channel creation fails. 
+ """ + request_data = {} + + if write_cap is not None: + request_data["write_cap"] = write_cap + + if message_box_index is not None: + request_data["message_box_index"] = message_box_index + + request = { + "create_write_channel": request_data + } + + cbor_request = cbor2.dumps(request) + length_prefix = struct.pack('>I', len(cbor_request)) + length_prefixed_request = length_prefix + cbor_request + + try: + # Clear previous reply data and reset event + self.channel_reply_data = None + self.channel_reply_event.clear() + + await self._send_all(length_prefixed_request) + self.logger.info("CreateWriteChannel request sent successfully.") + + # Wait for CreateWriteChannelReply via the background worker + await self.channel_reply_event.wait() + + if self.channel_reply_data and self.channel_reply_data.get("create_write_channel_reply"): + reply = self.channel_reply_data["create_write_channel_reply"] + error_code = reply.get("error_code", 0) + if error_code != 0: + error_msg = thin_client_error_to_string(error_code) + raise Exception(f"CreateWriteChannel failed: {error_msg} (error code {error_code})") + return reply["channel_id"], reply["read_cap"], reply["write_cap"], reply["next_message_index"] + else: + raise Exception("No create_write_channel_reply received") + + except Exception as e: + self.logger.error(f"Error creating write channel: {e}") + raise + + +async def create_read_channel(self, read_cap: bytes, message_box_index: "bytes|None"=None) -> "Tuple[int,bytes]": + """ + Create a read channel from a read capability. + + Args: + read_cap: The read capability object. + message_box_index: Optional MessageBoxIndex for resuming from a specific position. + + Returns: + tuple: (channel_id, next_message_index) where: + - channel_id is the 16-bit channel ID + - next_message_index is the current position for crash consistency + + Raises: + Exception: If the read channel creation fails. 
+ """ + request_data = { + "read_cap": read_cap + } + + if message_box_index is not None: + request_data["message_box_index"] = message_box_index + + request = { + "create_read_channel": request_data + } + + cbor_request = cbor2.dumps(request) + length_prefix = struct.pack('>I', len(cbor_request)) + length_prefixed_request = length_prefix + cbor_request + + try: + # Clear previous reply data and reset event + self.channel_reply_data = None + self.channel_reply_event.clear() + + await self._send_all(length_prefixed_request) + self.logger.info("CreateReadChannel request sent successfully.") + + # Wait for CreateReadChannelReply via the background worker + await self.channel_reply_event.wait() + + if self.channel_reply_data and self.channel_reply_data.get("create_read_channel_reply"): + reply = self.channel_reply_data["create_read_channel_reply"] + error_code = reply.get("error_code", 0) + if error_code != 0: + error_msg = thin_client_error_to_string(error_code) + raise Exception(f"CreateReadChannel failed: {error_msg} (error code {error_code})") + return reply["channel_id"], reply["next_message_index"] + else: + raise Exception("No create_read_channel_reply received") + + except Exception as e: + self.logger.error(f"Error creating read channel: {e}") + raise + + +async def write_channel(self, channel_id: int, payload: "bytes|str") -> "Tuple[bytes,bytes]": + """ + Prepare a write message for a pigeonhole channel and return the SendMessage payload and next MessageBoxIndex. + The thin client must then call send_message with the returned payload to actually send the message. + + Args: + channel_id (int): The 16-bit channel ID. + payload (bytes or str): The data to write to the channel. + + Returns: + tuple: (send_message_payload, next_message_index) where: + - send_message_payload is the prepared payload for send_message + - next_message_index is the position to use after courier acknowledgment + + Raises: + Exception: If the write preparation fails. 
+ """ + if not isinstance(payload, bytes): + payload = payload.encode('utf-8') + + request = { + "write_channel": { + "channel_id": channel_id, + "payload": payload + } + } + + cbor_request = cbor2.dumps(request) + length_prefix = struct.pack('>I', len(cbor_request)) + length_prefixed_request = length_prefix + cbor_request + + try: + # Clear previous reply data and reset event + self.channel_reply_data = None + self.channel_reply_event.clear() + + await self._send_all(length_prefixed_request) + self.logger.info("WriteChannel prepare request sent successfully.") + + # Wait for WriteChannelReply via the background worker + await self.channel_reply_event.wait() + + if self.channel_reply_data and self.channel_reply_data.get("write_channel_reply"): + reply = self.channel_reply_data["write_channel_reply"] + error_code = reply.get("error_code", 0) + if error_code != 0: + error_msg = thin_client_error_to_string(error_code) + raise Exception(f"WriteChannel failed: {error_msg} (error code {error_code})") + return reply["send_message_payload"], reply["next_message_index"] + else: + raise Exception("No write_channel_reply received") + + except Exception as e: + self.logger.error(f"Error preparing write to channel: {e}") + raise + + +async def read_channel(self, channel_id: int, message_id: "bytes|None"=None, reply_index: "int|None"=None) -> "Tuple[bytes,bytes,int|None]": + """ + Prepare a read query for a pigeonhole channel and return the SendMessage payload, next MessageBoxIndex, and used ReplyIndex. + The thin client must then call send_message with the returned payload to actually send the query. + + Args: + channel_id (int): The 16-bit channel ID. + message_id (bytes, optional): The 16-byte message ID for correlation. If None, generates a new one. + reply_index (int, optional): The index of the reply to return. If None, defaults to 0. 
+ + Returns: + tuple: (send_message_payload, next_message_index, used_reply_index) where: + - send_message_payload is the prepared payload for send_message + - next_message_index is the position to use after successful read + - used_reply_index is the reply index that was used (or None if not specified) + + Raises: + Exception: If the read preparation fails. + """ + if message_id is None: + message_id = self.new_message_id() + + request_data = { + "channel_id": channel_id, + "message_id": message_id + } + + if reply_index is not None: + request_data["reply_index"] = reply_index + + request = { + "read_channel": request_data + } + + cbor_request = cbor2.dumps(request) + length_prefix = struct.pack('>I', len(cbor_request)) + length_prefixed_request = length_prefix + cbor_request + + try: + # Clear previous reply data and reset event + self.channel_reply_data = None + self.channel_reply_event.clear() + + await self._send_all(length_prefixed_request) + self.logger.info(f"ReadChannel request sent for message_id {message_id.hex()[:16]}...") + + # Wait for ReadChannelReply via the background worker + await self.channel_reply_event.wait() + + if self.channel_reply_data and self.channel_reply_data.get("read_channel_reply"): + reply = self.channel_reply_data["read_channel_reply"] + error_code = reply.get("error_code", 0) + if error_code != 0: + error_msg = thin_client_error_to_string(error_code) + raise Exception(f"ReadChannel failed: {error_msg} (error code {error_code})") + + used_reply_index = reply.get("reply_index") + return reply["send_message_payload"], reply["next_message_index"], used_reply_index + else: + raise Exception("No read_channel_reply received") + + except Exception as e: + self.logger.error(f"Error preparing read from channel: {e}") + raise + + +async def read_channel_with_retry(self, channel_id: int, dest_node: bytes, dest_queue: bytes, + max_retries: int = 2) -> bytes: + """ + Send a read query for a pigeonhole channel with automatic reply index retry. 
+ It first tries reply index 0 up to max_retries times, and if that fails, + it tries reply index 1 up to max_retries times. + This method handles the common case where the courier has cached replies at different indices + and accounts for timing issues where messages may not have propagated yet. + This method requires mixnet connectivity and will fail in offline mode. + The method generates its own message ID and matches replies for correct correlation. + + Args: + channel_id (int): The 16-bit channel ID. + dest_node (bytes): Destination node identity hash. + dest_queue (bytes): Destination recipient queue ID. + max_retries (int): Maximum number of attempts per reply index (default: 2). + + Returns: + bytes: The received payload from the channel. + + Raises: + RuntimeError: If in offline mode (daemon not connected to mixnet). + Exception: If all retry attempts fail. + """ + # Check if we're in offline mode + if not self._is_connected: + raise RuntimeError("cannot send channel query in offline mode - daemon not connected to mixnet") + + # Generate a new message ID for this read operation + message_id = self.new_message_id() + self.logger.debug(f"read_channel_with_retry: Generated message_id {message_id.hex()[:16]}...") + + reply_indices = [0, 1] + + for reply_index in reply_indices: + self.logger.debug(f"read_channel_with_retry: Trying reply index {reply_index}") + + # Prepare the read query for this reply index + try: + # read_channel expects int channel_id + payload, _, _ = await self.read_channel(channel_id, message_id, reply_index) + except Exception as e: + self.logger.error(f"Failed to prepare read query with reply index {reply_index}: {e}") + continue + + # Try this reply index up to max_retries times + for attempt in range(1, max_retries + 1): + self.logger.debug(f"read_channel_with_retry: Reply index {reply_index} attempt {attempt}/{max_retries}") + + try: + # Send the channel query and wait for matching reply + result = await 
self._send_channel_query_and_wait_for_message_id( + channel_id, payload, dest_node, dest_queue, message_id, is_read_operation=True + ) + + # For read operations, we should only consider it successful if we got actual data + if len(result) > 0: + self.logger.debug(f"read_channel_with_retry: Reply index {reply_index} succeeded on attempt {attempt} with {len(result)} bytes") + return result + else: + self.logger.debug(f"read_channel_with_retry: Reply index {reply_index} attempt {attempt} got empty payload, treating as failure") + raise Exception("received empty payload - message not available yet") + + except Exception as e: + self.logger.debug(f"read_channel_with_retry: Reply index {reply_index} attempt {attempt} failed: {e}") + + # If this was the last attempt for this reply index, move to next reply index + if attempt == max_retries: + break + + # Add a delay between retries to allow for message propagation (match Go client) + await asyncio.sleep(5.0) + + # All reply indices and attempts failed + self.logger.debug(f"read_channel_with_retry: All reply indices failed after {max_retries} attempts each") + raise Exception("all reply indices failed after multiple attempts") + + +async def _send_channel_query_and_wait_for_message_id(self, channel_id: int, payload: bytes, + dest_node: bytes, dest_queue: bytes, + expected_message_id: bytes, is_read_operation: bool = True) -> bytes: + """ + Send a channel query and wait for a reply with the specified message ID. + This method matches replies by message ID to ensure correct correlation. 
+ + Args: + channel_id (int): The channel ID for the query + payload (bytes): The prepared query payload + dest_node (bytes): Destination node identity hash + dest_queue (bytes): Destination recipient queue ID + expected_message_id (bytes): The message ID to match replies against + is_read_operation (bool): Whether this is a read operation (affects empty payload handling) + + Returns: + bytes: The received payload + + Raises: + Exception: If the query fails or times out + """ + # Store the expected message ID for reply matching + self._expected_message_id = expected_message_id + self._received_reply_payload = None + self._reply_received_for_message_id = asyncio.Event() + self._reply_received_for_message_id.clear() + + try: + # Send the channel query with the specific expected_message_id + actual_message_id = await self.send_channel_query(channel_id, payload, dest_node, dest_queue, expected_message_id) + + # Verify that the message ID matches what we expected + assert actual_message_id == expected_message_id, f"Message ID mismatch: expected {expected_message_id.hex()}, got {actual_message_id.hex()}" + + # Wait for the matching reply with timeout + await asyncio.wait_for(self._reply_received_for_message_id.wait(), timeout=120.0) + + # Check if we got a valid payload + if self._received_reply_payload is None: + raise Exception("no reply received for message ID") + + # Handle empty payload based on operation type + if len(self._received_reply_payload) == 0: + if is_read_operation: + raise Exception("message not available yet - empty payload") + else: + return b"" # Empty payload is success for write operations + + return self._received_reply_payload + + except asyncio.TimeoutError: + raise Exception("timeout waiting for reply") + finally: + # Clean up + self._expected_message_id = None + self._received_reply_payload = None + + +async def close_channel(self, channel_id: int) -> None: + """ + Close a pigeonhole channel and clean up its resources. 
+ This helps avoid running out of channel IDs by properly releasing them. + This operation is infallible - it sends the close request and returns immediately. + + Args: + channel_id (int): The 16-bit channel ID to close. + + Raises: + Exception: If the socket send operation fails. + """ + request = { + "close_channel": { + "channel_id": channel_id + } + } + + cbor_request = cbor2.dumps(request) + length_prefix = struct.pack('>I', len(cbor_request)) + length_prefixed_request = length_prefix + cbor_request + + try: + # CloseChannel is infallible - fire and forget, no reply expected + await self._send_all(length_prefixed_request) + self.logger.info(f"CloseChannel request sent for channel {channel_id}.") + except Exception as e: + self.logger.error(f"Error sending close channel request: {e}") + raise + diff --git a/katzenpost_thinclient/pigeonhole.py b/katzenpost_thinclient/pigeonhole.py new file mode 100644 index 0000000..9f1cf89 --- /dev/null +++ b/katzenpost_thinclient/pigeonhole.py @@ -0,0 +1,978 @@ +# SPDX-FileCopyrightText: Copyright (C) 2024 David Stainton +# SPDX-License-Identifier: AGPL-3.0-only + +""" +Katzenpost Python Thin Client - New Pigeonhole API +=================================================== + +This module provides the new capability-based Pigeonhole API methods. +These methods use WriteCap/ReadCap keypairs and provide direct +control over the Pigeonhole protocol. 
+""" + +import os +from dataclasses import dataclass +from typing import Any, Dict, List + +from .core import ( + THIN_CLIENT_SUCCESS, + thin_client_error_to_string, + error_code_to_exception, + PigeonholeGeometry, + STREAM_ID_LENGTH, +) + + +@dataclass +class KeypairResult: + """Result from new_keypair containing the generated capabilities.""" + write_cap: bytes + read_cap: bytes + first_message_index: bytes + + +@dataclass +class EncryptReadResult: + """Result from encrypt_read containing the encrypted read request.""" + message_ciphertext: bytes + next_message_index: bytes + envelope_descriptor: bytes + envelope_hash: bytes + + +@dataclass +class EncryptWriteResult: + """Result from encrypt_write containing the encrypted write request.""" + message_ciphertext: bytes + envelope_descriptor: bytes + envelope_hash: bytes + + +# New Pigeonhole API methods - these will be attached to ThinClient class + + +def stream_id(self) -> bytes: + """ + Generate a new 16-byte stream ID for copy stream operations. + + Stream IDs are used to identify encoder instances for multi-call + envelope encoding streams. All calls for the same stream must use + the same stream ID. + + Returns: + bytes: Random 16-byte stream identifier. + """ + return os.urandom(STREAM_ID_LENGTH) + + +async def new_keypair(self, seed: bytes) -> KeypairResult: + """ + Creates a new keypair for use with the Pigeonhole protocol. + + This method generates a WriteCap and ReadCap from the provided seed using + the BACAP (Blinding-and-Capability) protocol. The WriteCap should be stored + securely for writing messages, while the ReadCap can be shared with others + to allow them to read messages. + + Args: + seed: 32-byte seed used to derive the keypair. + + Returns: + KeypairResult: Contains write_cap, read_cap, and first_message_index. + + Raises: + Exception: If the keypair creation fails. + ValueError: If seed is not exactly 32 bytes. 
+ + Example: + >>> import os + >>> seed = os.urandom(32) + >>> result = await client.new_keypair(seed) + >>> # Share result.read_cap with Bob so he can read messages + >>> # Store result.write_cap for sending messages + """ + if len(seed) != 32: + raise ValueError("seed must be exactly 32 bytes") + + query_id = self.new_query_id() + + request = { + "new_keypair": { + "query_id": query_id, + "seed": seed + } + } + + try: + reply = await self._send_and_wait(query_id=query_id, request=request) + except Exception as e: + self.logger.error(f"Error creating keypair: {e}") + raise + + if reply.get('error_code', 0) != THIN_CLIENT_SUCCESS: + error_msg = thin_client_error_to_string(reply['error_code']) + raise Exception(f"new_keypair failed: {error_msg}") + + return KeypairResult( + write_cap=reply["write_cap"], + read_cap=reply["read_cap"], + first_message_index=reply["first_message_index"] + ) + + +async def encrypt_read(self, read_cap: bytes, message_box_index: bytes) -> EncryptReadResult: + """ + Encrypts a read operation for a given read capability. + + This method prepares an encrypted read request that can be sent to the + courier service to retrieve a message from a pigeonhole box. The returned + ciphertext should be sent via start_resending_encrypted_message. + + Args: + read_cap: Read capability that grants access to the channel. + message_box_index: Starting read position for the channel. + + Returns: + EncryptReadResult: Contains message_ciphertext, next_message_index, + envelope_descriptor, and envelope_hash. + + Raises: + Exception: If the encryption fails. 
+ + Example: + >>> result = await client.encrypt_read(read_cap, message_box_index) + >>> # Send result.message_ciphertext via start_resending_encrypted_message + """ + query_id = self.new_query_id() + + request = { + "encrypt_read": { + "query_id": query_id, + "read_cap": read_cap, + "message_box_index": message_box_index + } + } + + try: + reply = await self._send_and_wait(query_id=query_id, request=request) + except Exception as e: + self.logger.error(f"Error encrypting read: {e}") + raise + + if reply.get('error_code', 0) != THIN_CLIENT_SUCCESS: + error_msg = thin_client_error_to_string(reply['error_code']) + raise Exception(f"encrypt_read failed: {error_msg}") + + return EncryptReadResult( + message_ciphertext=reply["message_ciphertext"], + next_message_index=reply["next_message_index"], + envelope_descriptor=reply["envelope_descriptor"], + envelope_hash=reply["envelope_hash"] + ) + + +async def encrypt_write(self, plaintext: bytes, write_cap: bytes, message_box_index: bytes) -> EncryptWriteResult: + """ + Encrypts a write operation for a given write capability. + + This method prepares an encrypted write request that can be sent to the + courier service to store a message in a pigeonhole box. The returned + ciphertext should be sent via start_resending_encrypted_message. + + Plaintext Size Constraint: + The plaintext must not exceed PigeonholeGeometry.max_plaintext_payload_length + bytes. The daemon internally adds a 4-byte big-endian length prefix before + padding and encryption, so the actual wire format is: + [4-byte length][plaintext][zero padding]. + + If the plaintext exceeds the maximum size, the daemon will return + ThinClientErrorInvalidRequest. + + Args: + plaintext: The plaintext message to encrypt. Must be at most + PigeonholeGeometry.max_plaintext_payload_length bytes. + write_cap: Write capability that grants access to the channel. + message_box_index: The message box index for this write operation. 
+ + Returns: + EncryptWriteResult: Contains message_ciphertext, envelope_descriptor, + and envelope_hash. + + Raises: + Exception: If the encryption fails (including if plaintext is too large). + + Example: + >>> plaintext = b"Hello, Bob!" + >>> result = await client.encrypt_write(plaintext, write_cap, message_box_index) + >>> # Send result.message_ciphertext via start_resending_encrypted_message + """ + query_id = self.new_query_id() + + request = { + "encrypt_write": { + "query_id": query_id, + "plaintext": plaintext, + "write_cap": write_cap, + "message_box_index": message_box_index + } + } + + try: + reply = await self._send_and_wait(query_id=query_id, request=request) + except Exception as e: + self.logger.error(f"Error encrypting write: {e}") + raise + + if reply.get('error_code', 0) != THIN_CLIENT_SUCCESS: + error_msg = thin_client_error_to_string(reply['error_code']) + raise Exception(f"encrypt_write failed: {error_msg}") + + return EncryptWriteResult( + message_ciphertext=reply["message_ciphertext"], + envelope_descriptor=reply["envelope_descriptor"], + envelope_hash=reply["envelope_hash"] + ) + + +async def start_resending_encrypted_message( + self, + read_cap: "bytes|None", + write_cap: "bytes|None", + next_message_index: "bytes|None", + reply_index: "int|None", + envelope_descriptor: bytes, + message_ciphertext: bytes, + envelope_hash: bytes, + no_retry_on_box_id_not_found: bool = False, + no_idempotent_box_already_exists: bool = False +) -> bytes: + """ + Starts resending an encrypted message via ARQ. + + This method initiates automatic repeat request (ARQ) for an encrypted message, + which will be resent periodically until either: + - A reply is received from the courier + - The message is cancelled via cancel_resending_encrypted_message + - The client is shut down + + This is used for both read and write operations in the new Pigeonhole API. 
+ + The daemon implements a finite state machine (FSM) for handling the stop-and-wait ARQ protocol: + - For write operations (write_cap != None, read_cap == None): + The method waits for an ACK from the courier and returns immediately. + - For read operations (read_cap != None, write_cap == None): + The method waits for an ACK from the courier, then the daemon automatically + sends a new SURB to request the payload, and this method waits for the payload. + The daemon performs all decryption (MKEM envelope + BACAP payload) and returns + the fully decrypted plaintext. + + Args: + read_cap: Read capability (can be None for write operations, required for reads). + write_cap: Write capability (can be None for read operations, required for writes). + next_message_index: Next message index for BACAP decryption (required for reads). + reply_index: Index of the reply to use (typically 0 or 1). + envelope_descriptor: Serialized envelope descriptor for MKEM decryption. + message_ciphertext: MKEM-encrypted message to send (from encrypt_read or encrypt_write). + envelope_hash: Hash of the courier envelope. + no_retry_on_box_id_not_found: If True, BoxIDNotFound errors on reads trigger + immediate error instead of automatic retries. By default (False), reads + will retry up to 10 times to handle replication lag. Set to True to get + immediate BoxIDNotFound error without retries. + no_idempotent_box_already_exists: If True, BoxAlreadyExists errors on writes are + returned as errors instead of being treated as idempotent success. + By default (False), BoxAlreadyExists is treated as success (the write + already happened). Set to True to detect whether a write was actually + performed or if the box already existed. + + Returns: + bytes: For read operations, the decrypted plaintext message (at most + PigeonholeGeometry.max_plaintext_payload_length bytes). The length + prefix and padding are automatically removed by the daemon. 
+ For write operations, returns an empty bytes object on success. + + Raises: + BoxIDNotFoundError: If no_retry_on_box_id_not_found=True and the box does not exist. + BoxAlreadyExistsError: If no_idempotent_box_already_exists=True and the box + already contains data. + Exception: If the operation fails. Check error_code for specific errors. + + Example: + >>> plaintext = await client.start_resending_encrypted_message( + ... read_cap, None, next_index, reply_idx, env_desc, ciphertext, env_hash) + >>> print(f"Received: {plaintext}") + """ + query_id = self.new_query_id() + + request = { + "start_resending_encrypted_message": { + "query_id": query_id, + "read_cap": read_cap, + "write_cap": write_cap, + "next_message_index": next_message_index, + "reply_index": reply_index, + "envelope_descriptor": envelope_descriptor, + "message_ciphertext": message_ciphertext, + "envelope_hash": envelope_hash, + "no_retry_on_box_id_not_found": no_retry_on_box_id_not_found, + "no_idempotent_box_already_exists": no_idempotent_box_already_exists + } + } + + try: + reply = await self._send_and_wait(query_id=query_id, request=request) + except Exception as e: + self.logger.error(f"Error starting resending encrypted message: {e}") + raise + + error_code = reply.get('error_code', 0) + if error_code != THIN_CLIENT_SUCCESS: + # Use error_code_to_exception to map error codes to specific exceptions + # This matches Go's errorCodeToSentinel behavior for replica error codes (1-9) + # and thin client error codes (22-24) + exc = error_code_to_exception(error_code) + if exc: + raise exc + # Should not reach here, but fallback just in case + error_msg = thin_client_error_to_string(error_code) + raise Exception(f"start_resending_encrypted_message failed: {error_msg}") + + return reply.get("plaintext", b"") + + +async def start_resending_encrypted_message_return_box_exists( + self, + read_cap: "bytes|None", + write_cap: "bytes|None", + next_message_index: "bytes|None", + reply_index: "int|None", + 
envelope_descriptor: bytes, + message_ciphertext: bytes, + envelope_hash: bytes +) -> bytes: + """ + Like start_resending_encrypted_message but returns BoxAlreadyExists errors. + + This is a convenience method that calls start_resending_encrypted_message with + no_idempotent_box_already_exists=True. Use this when you want to detect whether + a write was actually performed or if the box already existed. + + Args: + read_cap: Read capability (can be None for write operations, required for reads). + write_cap: Write capability (can be None for read operations, required for writes). + next_message_index: Next message index for BACAP decryption (required for reads). + reply_index: Index of the reply to use (typically 0 or 1). + envelope_descriptor: Serialized envelope descriptor for MKEM decryption. + message_ciphertext: MKEM-encrypted message to send (from encrypt_read or encrypt_write). + envelope_hash: Hash of the courier envelope. + + Returns: + bytes: For read operations, the decrypted plaintext message. + For write operations, returns an empty bytes object on success. + + Raises: + BoxAlreadyExistsError: If the box already contains data. + Exception: If the operation fails. + + Example: + >>> try: + ... await client.start_resending_encrypted_message_return_box_exists( + ... None, write_cap, None, None, env_desc, ciphertext, env_hash) + ... except BoxAlreadyExistsError: + ... 
print("Box already has data - write was idempotent") + """ + return await self.start_resending_encrypted_message( + read_cap=read_cap, + write_cap=write_cap, + next_message_index=next_message_index, + reply_index=reply_index, + envelope_descriptor=envelope_descriptor, + message_ciphertext=message_ciphertext, + envelope_hash=envelope_hash, + no_idempotent_box_already_exists=True + ) + + +async def start_resending_encrypted_message_no_retry( + self, + read_cap: "bytes|None", + write_cap: "bytes|None", + next_message_index: "bytes|None", + reply_index: "int|None", + envelope_descriptor: bytes, + message_ciphertext: bytes, + envelope_hash: bytes +) -> bytes: + """ + Like start_resending_encrypted_message but disables automatic retries on BoxIDNotFound. + + This is a convenience method that calls start_resending_encrypted_message with + no_retry_on_box_id_not_found=True. Use this when you want immediate error feedback + rather than waiting for potential replication lag to resolve. + + Args: + read_cap: Read capability (can be None for write operations, required for reads). + write_cap: Write capability (can be None for read operations, required for writes). + next_message_index: Next message index for BACAP decryption (required for reads). + reply_index: Index of the reply to use (typically 0 or 1). + envelope_descriptor: Serialized envelope descriptor for MKEM decryption. + message_ciphertext: MKEM-encrypted message to send (from encrypt_read or encrypt_write). + envelope_hash: Hash of the courier envelope. + + Returns: + bytes: For read operations, the decrypted plaintext message. + For write operations, returns an empty bytes object on success. + + Raises: + BoxIDNotFoundError: If the box does not exist (no automatic retries). + Exception: If the operation fails. + + Example: + >>> try: + ... plaintext = await client.start_resending_encrypted_message_no_retry( + ... read_cap, None, next_index, reply_idx, env_desc, ciphertext, env_hash) + ... 
except BoxIDNotFoundError: + ... print("Box not found - message not yet written") + """ + return await self.start_resending_encrypted_message( + read_cap=read_cap, + write_cap=write_cap, + next_message_index=next_message_index, + reply_index=reply_index, + envelope_descriptor=envelope_descriptor, + message_ciphertext=message_ciphertext, + envelope_hash=envelope_hash, + no_retry_on_box_id_not_found=True + ) + + +async def cancel_resending_encrypted_message(self, envelope_hash: bytes) -> None: + """ + Cancels ARQ resending for an encrypted message. + + This method stops the automatic repeat request (ARQ) for a previously started + encrypted message transmission. This is useful when: + - A reply has been received through another channel + - The operation should be aborted + - The message is no longer needed + + Args: + envelope_hash: Hash of the courier envelope to cancel. + + Raises: + Exception: If the cancellation fails. + + Example: + >>> await client.cancel_resending_encrypted_message(env_hash) + """ + query_id = self.new_query_id() + + request = { + "cancel_resending_encrypted_message": { + "query_id": query_id, + "envelope_hash": envelope_hash + } + } + + try: + reply = await self._send_and_wait(query_id=query_id, request=request) + except Exception as e: + self.logger.error(f"Error cancelling resending encrypted message: {e}") + raise + + if reply.get('error_code', 0) != THIN_CLIENT_SUCCESS: + error_msg = thin_client_error_to_string(reply['error_code']) + raise Exception(f"cancel_resending_encrypted_message failed: {error_msg}") + + +async def next_message_box_index(self, message_box_index: bytes) -> bytes: + """ + Increments a MessageBoxIndex using the BACAP NextIndex method. + + This method is used when sending multiple messages to different mailboxes using + the same WriteCap or ReadCap. 
It properly advances the cryptographic state by: + - Incrementing the Idx64 counter + - Deriving new encryption and blinding keys using HKDF + - Updating the HKDF state for the next iteration + + The daemon handles the cryptographic operations internally, ensuring correct + BACAP protocol implementation. + + Args: + message_box_index: Current message box index to increment (as bytes). + + Returns: + bytes: The next message box index. + + Raises: + Exception: If the increment operation fails. + + Example: + >>> current_index = first_message_index + >>> next_index = await client.next_message_box_index(current_index) + >>> # Use next_index for the next message + """ + query_id = self.new_query_id() + + request = { + "next_message_box_index": { + "query_id": query_id, + "message_box_index": message_box_index + } + } + + try: + reply = await self._send_and_wait(query_id=query_id, request=request) + except Exception as e: + self.logger.error(f"Error incrementing message box index: {e}") + raise + + if reply.get('error_code', 0) != THIN_CLIENT_SUCCESS: + error_msg = thin_client_error_to_string(reply['error_code']) + raise Exception(f"next_message_box_index failed: {error_msg}") + + return reply.get("next_message_box_index") + + +async def start_resending_copy_command( + self, + write_cap: bytes, + courier_identity_hash: "bytes|None" = None, + courier_queue_id: "bytes|None" = None +) -> None: + """ + Starts resending a copy command to a courier via ARQ. + + This method instructs a courier to read data from a temporary channel + (identified by the write_cap) and write it to the destination channel. + The command is automatically retransmitted until acknowledged. + + If courier_identity_hash and courier_queue_id are both provided, + the copy command is sent to that specific courier. Otherwise, a + random courier is selected. + + Args: + write_cap: Write capability for the temporary channel containing the data. 
+ courier_identity_hash: Optional identity hash of a specific courier to use. + courier_queue_id: Optional queue ID for the specified courier. Must be set + if courier_identity_hash is set. + + Raises: + Exception: If the operation fails. + + Example: + >>> # Send copy command to a random courier + >>> await client.start_resending_copy_command(temp_write_cap) + >>> # Send copy command to a specific courier + >>> await client.start_resending_copy_command( + ... temp_write_cap, courier_identity_hash, courier_queue_id) + """ + query_id = self.new_query_id() + + request_data = { + "query_id": query_id, + "write_cap": write_cap, + } + + if courier_identity_hash is not None: + request_data["courier_identity_hash"] = courier_identity_hash + if courier_queue_id is not None: + request_data["courier_queue_id"] = courier_queue_id + + request = { + "start_resending_copy_command": request_data + } + + try: + reply = await self._send_and_wait(query_id=query_id, request=request) + except Exception as e: + self.logger.error(f"Error starting resending copy command: {e}") + raise + + if reply.get('error_code', 0) != THIN_CLIENT_SUCCESS: + error_msg = thin_client_error_to_string(reply['error_code']) + raise Exception(f"start_resending_copy_command failed: {error_msg}") + + +async def cancel_resending_copy_command(self, write_cap_hash: bytes) -> None: + """ + Cancels ARQ resending for a copy command. + + This method stops the automatic repeat request (ARQ) for a previously started + copy command. Use this when: + - The copy operation should be aborted + - The operation is no longer needed + - You want to clean up pending ARQ operations + + Args: + write_cap_hash: Hash of the WriteCap used in start_resending_copy_command. + + Raises: + Exception: If the cancellation fails. 
+ + Example: + >>> await client.cancel_resending_copy_command(write_cap_hash) + """ + query_id = self.new_query_id() + + request = { + "cancel_resending_copy_command": { + "query_id": query_id, + "write_cap_hash": write_cap_hash + } + } + + try: + reply = await self._send_and_wait(query_id=query_id, request=request) + except Exception as e: + self.logger.error(f"Error cancelling resending copy command: {e}") + raise + + if reply.get('error_code', 0) != THIN_CLIENT_SUCCESS: + error_msg = thin_client_error_to_string(reply['error_code']) + raise Exception(f"cancel_resending_copy_command failed: {error_msg}") + + +async def create_courier_envelopes_from_payload( + self, + query_id: bytes, + stream_id: bytes, + payload: bytes, + dest_write_cap: bytes, + dest_start_index: bytes, + is_last: bool +) -> "CreateEnvelopesResult": + """ + Creates multiple CourierEnvelopes from a payload of any size. + + The payload is automatically chunked and each chunk is wrapped in a + CourierEnvelope. Each returned chunk is a serialized CopyStreamElement + ready to be written to a box. + + Multiple calls can be made with the same stream_id to build up a stream + incrementally. The first call creates a new encoder (first element gets + IsStart=true). The final call should have is_last=True (last element + gets IsFinal=true). + + The buffer in the result contains the current encoder buffer which + you should persist for crash recovery. On restart, use `set_stream_buffer` + to restore the state before continuing the stream. + + Args: + query_id: 16-byte query identifier for correlating requests and replies. + stream_id: 16-byte identifier for the encoder instance. All calls for + the same stream must use the same stream ID. + payload: The data to be encoded into courier envelopes. + dest_write_cap: Write capability for the destination channel. + dest_start_index: Starting index in the destination channel. + is_last: Whether this is the last payload in the sequence. 
When True, + the final CopyStreamElement will have IsFinal=true and the + encoder instance will be removed. + + Returns: + CreateEnvelopesResult: Contains envelopes and buffer state for crash recovery. + + Raises: + Exception: If the envelope creation fails. + + Example: + >>> query_id = client.new_query_id() + >>> stream_id = client.new_stream_id() + >>> result = await client.create_courier_envelopes_from_payload( + ... query_id, stream_id, payload, dest_write_cap, dest_start_index, is_last=False) + >>> # Persist buffer for crash recovery + >>> save_to_disk(stream_id, result.buffer) + >>> for env in result.envelopes: + ... # Write each envelope to the copy stream + ... pass + """ + + request = { + "create_courier_envelopes_from_payload": { + "query_id": query_id, + "stream_id": stream_id, + "payload": payload, + "dest_write_cap": dest_write_cap, + "dest_start_index": dest_start_index, + "is_last": is_last + } + } + + try: + reply = await self._send_and_wait(query_id=query_id, request=request) + except Exception as e: + self.logger.error(f"Error creating courier envelopes from payload: {e}") + raise + + if reply.get('error_code', 0) != THIN_CLIENT_SUCCESS: + error_msg = thin_client_error_to_string(reply['error_code']) + raise Exception(f"create_courier_envelopes_from_payload failed: {error_msg}") + + return CreateEnvelopesResult( + envelopes=reply.get("envelopes", []), + buffer=reply.get("buffer", b"") + ) + + +async def create_courier_envelopes_from_multi_payload( + self, + stream_id: bytes, + destinations: "List[Dict[str, Any]]", + is_last: bool +) -> "CreateEnvelopesResult": + """ + Creates CourierEnvelopes from multiple payloads going to different destinations. + + This is more space-efficient than calling create_courier_envelopes_from_payload + multiple times because envelopes from different destinations are packed + together in the copy stream without wasting space. + + Multiple calls can be made with the same stream_id to build up a stream + incrementally. 
The first call creates a new encoder (first element gets + IsStart=true). The final call should have is_last=True (last element + gets IsFinal=true). + + The buffer in the result contains the current encoder buffer which + you should persist for crash recovery. On restart, use `set_stream_buffer` + to restore the state before continuing the stream. + + Args: + stream_id: 16-byte identifier for the encoder instance. All calls for + the same stream must use the same stream ID. + destinations: List of destination payloads, each a dict with: + - "payload": bytes - The data to be written + - "write_cap": bytes - Write capability for destination + - "start_index": bytes - Starting index in destination + is_last: Whether this is the last set of payloads in the sequence. + When True, the final CopyStreamElement will have IsFinal=true + and the encoder instance will be removed. + + Returns: + CreateEnvelopesResult: Contains envelopes and buffer state for crash recovery. + + Raises: + Exception: If the envelope creation fails. + + Example: + >>> stream_id = client.new_stream_id() + >>> destinations = [ + ... {"payload": data1, "write_cap": cap1, "start_index": idx1}, + ... {"payload": data2, "write_cap": cap2, "start_index": idx2}, + ... ] + >>> result = await client.create_courier_envelopes_from_multi_payload( + ... 
stream_id, destinations, is_last=False) + >>> # Persist buffer for crash recovery + >>> save_to_disk(stream_id, result.buffer) + """ + query_id = self.new_query_id() + + request = { + "create_courier_envelopes_from_multi_payload": { + "query_id": query_id, + "stream_id": stream_id, + "destinations": destinations, + "is_last": is_last + } + } + + try: + reply = await self._send_and_wait(query_id=query_id, request=request) + except Exception as e: + self.logger.error(f"Error creating courier envelopes from payloads: {e}") + raise + + if reply.get('error_code', 0) != THIN_CLIENT_SUCCESS: + error_msg = thin_client_error_to_string(reply['error_code']) + raise Exception(f"create_courier_envelopes_from_multi_payload failed: {error_msg}") + + return CreateEnvelopesResult( + envelopes=reply.get("envelopes", []), + buffer=reply.get("buffer", b"") + ) + + +@dataclass +class CreateEnvelopesResult: + """Result of creating courier envelopes, including envelopes and buffer for crash recovery.""" + envelopes: "List[bytes]" + """The serialized CopyStreamElements to send to the network.""" + buffer: bytes + """The buffered data that hasn't been output yet. Persist this for crash recovery.""" + + +async def set_stream_buffer( + self, + stream_id: bytes, + buffer: bytes +) -> None: + """ + Restores the buffered state for a given stream ID. + + This is useful for crash recovery: after restart, call this method with the + buffer that was returned by `create_courier_envelopes_from_payload` or + `create_courier_envelopes_from_multi_payload` before the crash/shutdown. + + Note: This will create a new encoder if one doesn't exist for this stream_id, + or replace the buffer contents if one already exists. + + Args: + stream_id: 16-byte identifier for the encoder instance. + buffer: The buffered data to restore (from CreateEnvelopesResult.buffer). + + Returns: + None + + Raises: + ValueError: If stream_id is not exactly 16 bytes. + Exception: If the operation fails. 
+ + Example: + >>> # During streaming, save the buffer from each call + >>> result = await client.create_courier_envelopes_from_payload( + ... query_id, stream_id, data, ..., is_last=False) + >>> save_to_disk(stream_id, result.buffer) + >>> + >>> # On restart, restore the stream state + >>> buffer = load_from_disk(stream_id) + >>> await client.set_stream_buffer(stream_id, buffer) + >>> # Now continue streaming from where we left off + >>> await client.create_courier_envelopes_from_payload( + ... query_id, stream_id, more_data, ..., is_last=True) + """ + if len(stream_id) != STREAM_ID_LENGTH: + raise ValueError(f"stream_id must be exactly {STREAM_ID_LENGTH} bytes") + + query_id = self.new_query_id() + + request = { + "set_stream_buffer": { + "query_id": query_id, + "stream_id": stream_id, + "buffer": buffer + } + } + + try: + reply = await self._send_and_wait(query_id=query_id, request=request) + except Exception as e: + self.logger.error(f"Error setting stream buffer: {e}") + raise + + if reply.get('error_code', 0) != THIN_CLIENT_SUCCESS: + error_msg = thin_client_error_to_string(reply['error_code']) + raise Exception(f"set_stream_buffer failed: {error_msg}") + + +async def tombstone_box( + self, + write_cap: bytes, + box_index: bytes +) -> EncryptWriteResult: + """ + Create a tombstone for a single pigeonhole box. + + This method creates a tombstone (empty payload with signature) for deleting + the specified box. The caller must send the returned values via + start_resending_encrypted_message to complete the tombstone operation. + + Args: + write_cap: Write capability for the box. + box_index: Index of the box to tombstone. + + Returns: + EncryptWriteResult: Contains message_ciphertext, envelope_descriptor, + and envelope_hash. + + Raises: + ValueError: If any argument is None. + Exception: If the encrypt operation fails. + + Example: + >>> result = await client.tombstone_box(write_cap, box_index) + >>> await client.start_resending_encrypted_message( + ... 
None, write_cap, None, None, + ... result.envelope_descriptor, result.message_ciphertext, result.envelope_hash) + """ + if write_cap is None: + raise ValueError("write_cap cannot be None") + if box_index is None: + raise ValueError("box_index cannot be None") + + # Tombstones are created by sending an empty plaintext to encrypt_write + # The daemon will detect this and sign an empty payload instead of encrypting + return await self.encrypt_write(b'', write_cap, box_index) + + +async def tombstone_range( + self, + write_cap: bytes, + start: bytes, + max_count: int +) -> "Dict[str, Any]": + """ + Create tombstones for a range of pigeonhole boxes. + + This method creates tombstones for up to max_count boxes, + starting from the specified box index and advancing through consecutive + indices. The caller must send each envelope via start_resending_encrypted_message + to complete the tombstone operations. + + If an error occurs during the operation, a partial result is returned + containing the envelopes created so far and the next index. + + Args: + write_cap: Write capability for the boxes. + start: Starting MessageBoxIndex. + max_count: Maximum number of boxes to tombstone. + + Returns: + Dict[str, Any]: A dictionary with: + - "envelopes" (List[Dict]): List of envelope dicts, each containing: + - "message_ciphertext": The tombstone payload. + - "envelope_descriptor": The envelope descriptor. + - "envelope_hash": The envelope hash for cancellation. + - "box_index": The box index this envelope is for. + - "next" (bytes): The next MessageBoxIndex after the last processed. + + Raises: + ValueError: If write_cap or start is None. + + Example: + >>> result = await client.tombstone_range(write_cap, start_index, 10) + >>> for envelope in result["envelopes"]: + ... await client.start_resending_encrypted_message( + ... None, write_cap, None, None, + ... envelope["envelope_descriptor"], + ... envelope["message_ciphertext"], + ... 
envelope["envelope_hash"]) + """ + if write_cap is None: + raise ValueError("write_cap cannot be None") + if start is None: + raise ValueError("start index cannot be None") + if max_count == 0: + return {"envelopes": [], "next": start} + + cur = start + envelopes = [] + + while len(envelopes) < max_count: + try: + result = await self.tombstone_box(write_cap, cur) + envelopes.append({ + "message_ciphertext": result.message_ciphertext, + "envelope_descriptor": result.envelope_descriptor, + "envelope_hash": result.envelope_hash, + "box_index": cur, + }) + except Exception as e: + self.logger.error(f"Error creating tombstone for box at index {len(envelopes)}: {e}") + return {"envelopes": envelopes, "next": cur, "error": str(e)} + + try: + cur = await self.next_message_box_index(cur) + except Exception as e: + self.logger.error(f"Error getting next index after creating tombstone: {e}") + return {"envelopes": envelopes, "next": cur, "error": str(e)} + + return {"envelopes": envelopes, "next": cur} + diff --git a/pyproject.toml b/pyproject.toml index ef3998f..da8f132 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,5 +39,6 @@ test = [ "pytest", "pytest-cov", "pytest-asyncio", + "pytest-timeout", ] diff --git a/pytest.ini b/pytest.ini index c372928..f0a59b3 100644 --- a/pytest.ini +++ b/pytest.ini @@ -24,6 +24,8 @@ addopts = --durations=10 # Timeout configuration +# Default timeout per test: 5 minutes (300 seconds) for unit tests +# Integration tests override this with --timeout flag in CI timeout = 300 timeout_method = thread diff --git a/src/bin/copycat.rs b/src/bin/copycat.rs new file mode 100644 index 0000000..062615e --- /dev/null +++ b/src/bin/copycat.rs @@ -0,0 +1,438 @@ +// SPDX-FileCopyrightText: © 2026 David Stainton +// SPDX-License-Identifier: AGPL-3.0-only + +//! copycat - A CLI tool for reading and writing to Katzenpost pigeonhole channels +//! +//! Similar to cat or netcat, copycat can: +//! 
- Read from stdin or a file and write to a copy stream (send mode) +//! - Read from a channel and write to stdout (receive mode) + +use std::fs::File; +use std::io::{self, BufReader, Read, Write}; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; + +use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64}; +use clap::{Parser, Subcommand}; +use tokio::time::sleep; + +use katzenpost_thin_client::{Config, ThinClient, ThinClientError}; +use katzenpost_thin_client::persistent::{ + PigeonholeClient, ReadCapability, PigeonholeDbError, +}; + +/// Chunk size for streaming input data (4KB) +/// Smaller chunks give more frequent progress updates and lower memory usage. +/// Each chunk is processed by add_multi_payload which creates courier envelopes. +const CHUNK_SIZE: usize = 4 * 1024; + +#[derive(Parser)] +#[command(name = "copycat")] +#[command(about = "Katzenpost pigeonhole copy stream tool")] +#[command(long_about = "A CLI tool for reading and writing to Katzenpost pigeonhole channels.\n\n\ +Similar to cat or netcat, copycat can:\n\ +- Read from stdin or a file and write to a copy stream (send mode)\n\ +- Read from a channel and write to stdout (receive mode)\n\n\ +This tool uses the Pigeonhole protocol with Copy Commands to provide\n\ +reliable message delivery through the mixnet.")] +struct Cli { + #[command(subcommand)] + command: Commands, +} + +#[derive(Subcommand)] +enum Commands { + /// Generate a new keypair and print both capabilities + Genkey { + /// Configuration file (required) + #[arg(short, long)] + config: PathBuf, + }, + + /// Read from stdin or file and write to a copy stream + Send { + /// Configuration file (required) + #[arg(short, long)] + config: PathBuf, + + /// Write capability (base64) + #[arg(short, long)] + write_cap: String, + + /// Input file (default: stdin) + #[arg(short, long)] + file: Option, + + /// Start index (base64, optional) + #[arg(short, long)] + index: Option, + }, + + /// Read from a channel 
and write to stdout + Receive { + /// Configuration file (required) + #[arg(short, long)] + config: PathBuf, + + /// Read capability (base64) + #[arg(short, long)] + read_cap: String, + + /// Start index (base64, optional) + #[arg(short, long)] + index: Option, + }, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + env_logger::init(); + let cli = Cli::parse(); + + match cli.command { + Commands::Genkey { config } => run_genkey(config).await, + Commands::Send { config, write_cap, file, index } => { + run_send(config, write_cap, file, index).await + } + Commands::Receive { config, read_cap, index } => { + run_receive(config, read_cap, index).await + } + } +} + +/// Initialize the thin client from config file +async fn init_client(config_path: PathBuf) -> Result, Box> { + let cfg = Config::new(config_path.to_str().ok_or("Invalid config path")?)?; + let client = ThinClient::new(cfg).await?; + + // Wait for PKI document with timeout + eprintln!("Waiting for PKI document..."); + let timeout = Duration::from_secs(60); + let start = std::time::Instant::now(); + + loop { + if start.elapsed() > timeout { + return Err("Timeout waiting for PKI document".into()); + } + if client.pki_document().await.is_ok() { + break; + } + sleep(Duration::from_millis(100)).await; + } + + eprintln!("Connected to mixnet"); + Ok(client) +} + +/// Generate a new keypair and print capabilities +async fn run_genkey(config: PathBuf) -> Result<(), Box> { + let client = init_client(config).await?; + let pigeonhole = PigeonholeClient::new_in_memory(client)?; + + // Create a temporary channel to generate the keypair + let channel = pigeonhole.create_channel("genkey-temp").await?; + + // Get the raw capabilities + let write_cap = channel.write_cap().ok_or("Failed to get write capability")?; + let read_cap = channel.read_cap(); + let first_index = channel.write_index().ok_or("Failed to get write index")?; + + println!("Read Capability (share with recipient):"); + println!("{}\n", 
BASE64.encode(read_cap)); + + println!("Write Capability (keep secret):"); + println!("{}\n", BASE64.encode(write_cap)); + + println!("First Index:"); + println!("{}", BASE64.encode(first_index)); + + Ok(()) +} + +/// Read from stdin or file and send via copy stream +async fn run_send( + config: PathBuf, + write_cap_b64: String, + input_file: Option, + start_index_b64: Option, +) -> Result<(), Box> { + // Decode write capability + let write_cap = BASE64.decode(&write_cap_b64)?; + + // BACAP WriteCap is 168 bytes: 64-byte PrivateKey + 104-byte MessageBoxIndex + // Extract start_index from the write_cap if not provided explicitly + const PRIVATE_KEY_SIZE: usize = 64; + const MESSAGE_BOX_INDEX_SIZE: usize = 104; + const WRITE_CAP_SIZE: usize = PRIVATE_KEY_SIZE + MESSAGE_BOX_INDEX_SIZE; + + let start_index = if let Some(idx_b64) = start_index_b64 { + BASE64.decode(&idx_b64)? + } else if write_cap.len() == WRITE_CAP_SIZE { + // Extract firstMessageBoxIndex from bytes 64-168 of the WriteCap + write_cap[PRIVATE_KEY_SIZE..].to_vec() + } else { + return Err(format!( + "Invalid write capability size: {} bytes (expected {} bytes). 
\ + Either provide a full WriteCap or use -i flag to specify start index.", + write_cap.len(), WRITE_CAP_SIZE + ).into()); + }; + + // Determine input source and total size + // For files: get size from metadata and stream in chunks (memory efficient) + // For stdin: must buffer to determine total size for length prefix + let (total_len, mut input_reader): (u64, Box) = if let Some(ref path) = input_file { + let metadata = std::fs::metadata(path)?; + let file = File::open(path)?; + (metadata.len(), Box::new(BufReader::new(file))) + } else { + // For stdin, we must read all data to know the length + eprintln!("Reading from stdin (buffering to determine length)..."); + let mut buf = Vec::new(); + io::stdin().read_to_end(&mut buf)?; + let len = buf.len() as u64; + (len, Box::new(std::io::Cursor::new(buf))) + }; + + eprintln!("Sending {} bytes (with 4-byte length prefix)", total_len); + + // Initialize client + let client = init_client(config).await?; + let pigeonhole = PigeonholeClient::new_in_memory(client.clone())?; + + // Create a temporary channel for copy stream operations + eprintln!("Creating temporary copy stream channel..."); + let channel = pigeonhole.create_channel("copycat-send").await?; + + // Create copy stream builder + eprintln!("Initializing copy stream builder..."); + let mut builder = channel.copy_stream_builder().await?; + + // Calculate total size with 4-byte length prefix + let total_with_prefix = 4 + total_len as usize; + let total_chunks = (total_with_prefix + CHUNK_SIZE - 1) / CHUNK_SIZE; + + eprintln!("Uploading {} bytes in {} chunk(s)...", total_with_prefix, total_chunks); + + // Stream data in chunks + let mut bytes_sent: usize = 0; + let mut chunk_num = 0; + let mut chunk_buf = vec![0u8; CHUNK_SIZE]; + let mut first_chunk = true; + + loop { + // Build the current chunk + let mut payload = Vec::with_capacity(CHUNK_SIZE); + + // First chunk includes the 4-byte length prefix + if first_chunk { + payload.extend_from_slice(&(total_len as 
u32).to_be_bytes()); + first_chunk = false; + } + + // Fill remaining space in chunk from input + let space_remaining = CHUNK_SIZE - payload.len(); + let bytes_to_read = space_remaining.min(total_len as usize - (bytes_sent.saturating_sub(4).min(total_len as usize))); + + if bytes_to_read > 0 { + let n = input_reader.read(&mut chunk_buf[..bytes_to_read])?; + if n > 0 { + payload.extend_from_slice(&chunk_buf[..n]); + } + } + + if payload.is_empty() { + break; + } + + bytes_sent += payload.len(); + let is_last = bytes_sent >= total_with_prefix; + + // Use add_multi_payload for efficient packing + let destinations = vec![(payload.as_slice(), write_cap.as_slice(), start_index.as_slice())]; + let envelopes_written = builder + .add_multi_payload(destinations, is_last) + .await?; + + let progress_pct = (bytes_sent as f64 / total_with_prefix as f64 * 100.0).min(100.0); + eprintln!( + "Chunk {}/{}: {} bytes, {} envelopes ({:.1}%)", + chunk_num + 1, total_chunks, payload.len(), envelopes_written, progress_pct + ); + + chunk_num += 1; + + if is_last { + break; + } + } + + // Execute the copy command + eprintln!("Sending Copy command to courier..."); + let total_boxes = builder.finish().await?; + eprintln!("Copy command completed successfully ({} boxes written)", total_boxes); + + Ok(()) +} + +/// Receive messages from a channel and write to stdout +/// +/// This function reads boxes with retry logic until all data specified +/// by the length prefix has been received. 
+async fn run_receive( + config: PathBuf, + read_cap_b64: String, + start_index_b64: Option, +) -> Result<(), Box> { + // Decode read capability + let read_cap_bytes = BASE64.decode(&read_cap_b64)?; + + // BACAP ReadCap is 136 bytes: 32-byte PublicKey + 104-byte MessageBoxIndex + // Extract start_index from the read_cap if not provided explicitly + const PUBLIC_KEY_SIZE: usize = 32; + const MESSAGE_BOX_INDEX_SIZE: usize = 104; + const READ_CAP_SIZE: usize = PUBLIC_KEY_SIZE + MESSAGE_BOX_INDEX_SIZE; + + let start_index = if let Some(idx_b64) = start_index_b64 { + BASE64.decode(&idx_b64)? + } else if read_cap_bytes.len() == READ_CAP_SIZE { + // Extract firstMessageBoxIndex from bytes 32-136 of the ReadCap + read_cap_bytes[PUBLIC_KEY_SIZE..].to_vec() + } else { + return Err(format!( + "Invalid read capability size: {} bytes (expected {} bytes). \ + Either provide a full ReadCap or use -i flag to specify start index.", + read_cap_bytes.len(), READ_CAP_SIZE + ).into()); + }; + + // Initialize client + let client = init_client(config).await?; + let pigeonhole = PigeonholeClient::new_in_memory(client.clone())?; + + let read_capability = ReadCapability { + read_cap: read_cap_bytes, + start_index: start_index.clone(), + name: Some("copycat-receive".to_string()), + }; + + // Import the channel + let mut channel = pigeonhole.import_channel("copycat-receive", &read_capability)?; + + eprintln!("Reading with length prefix..."); + + // Buffer to accumulate all received data + let mut received_data = Vec::new(); + let mut expected_len: Option = None; + let mut box_num = 0; + + const MAX_RETRIES: u32 = 6; + const BASE_DELAY_MS: u64 = 500; + + // Keep reading until we have all expected data + loop { + let mut plaintext: Option> = None; + + // Try to read the next box with retries + // Use receive_no_retry() for immediate feedback, then manually retry with backoff + for attempt in 0..MAX_RETRIES { + eprintln!("Attempting to read box {} (attempt {}/{})...", box_num, attempt + 1, 
MAX_RETRIES); + + match channel.receive_no_retry().await { + Ok(data) if !data.is_empty() => { + plaintext = Some(data); + break; + } + Ok(_) => { + // Empty data = tombstone, treat as end of stream + eprintln!("Box {} is a tombstone (empty), stopping", box_num); + plaintext = Some(Vec::new()); + break; + } + Err(PigeonholeDbError::ThinClient(ThinClientError::BoxNotFound)) => { + // Box doesn't exist yet - retry with backoff + eprintln!("Box {} not found (attempt {}/{})", box_num, attempt + 1, MAX_RETRIES); + if attempt < MAX_RETRIES - 1 { + let delay = BASE_DELAY_MS * (1 << attempt.min(6)); + eprintln!("Retrying in {}ms...", delay); + sleep(Duration::from_millis(delay)).await; + } + } + Err(e) => { + // Other errors - log and retry + eprintln!("Box {} error: {:?} (attempt {}/{})", box_num, e, attempt + 1, MAX_RETRIES); + if attempt < MAX_RETRIES - 1 { + let delay = BASE_DELAY_MS * (1 << attempt.min(6)); + eprintln!("Retrying in {}ms...", delay); + sleep(Duration::from_millis(delay)).await; + } + } + } + } + + let data = plaintext.ok_or_else(|| { + format!("Failed to read box {} after {} retries", box_num, MAX_RETRIES) + })?; + + // Tombstone = end of stream + if data.is_empty() { + eprintln!("Reached tombstone at box {}", box_num); + break; + } + + // Accumulate received data + let data_len = data.len(); + received_data.extend_from_slice(&data); + box_num += 1; + + // Check if we now know the expected length + if expected_len.is_none() && received_data.len() >= 4 { + let len = u32::from_be_bytes([ + received_data[0], + received_data[1], + received_data[2], + received_data[3], + ]); + expected_len = Some(len); + eprintln!("Expected payload length: {} bytes", len); + } + + // Print progress + if let Some(len) = expected_len { + let total_expected = 4 + len as usize; + let percent = (received_data.len() as f64 / total_expected as f64 * 100.0).min(100.0); + eprintln!( + "Box {}: received {} bytes ({}/{} bytes, {:.1}%)", + box_num, data_len, received_data.len(), 
total_expected, percent + ); + } else { + eprintln!("Box {}: received {} bytes (total so far: {} bytes)", box_num, data_len, received_data.len()); + } + + // Check if we have all the data (4-byte prefix + expected_len bytes) + if let Some(len) = expected_len { + if received_data.len() >= 4 + len as usize { + eprintln!("Received all {} bytes in {} boxes", len, box_num); + break; + } + } + } + + // Strip the 4-byte length prefix and write the actual payload to stdout + let expected = expected_len.ok_or("No data received")? as usize; + if received_data.len() < 4 + expected { + return Err(format!( + "Received data too short: {} bytes, expected {}", + received_data.len(), + 4 + expected + ).into()); + } + + let payload = &received_data[4..4 + expected]; + io::stdout().write_all(payload)?; + + eprintln!("Done"); + Ok(()) +} + diff --git a/src/chat/mod.rs b/src/chat/mod.rs new file mode 100644 index 0000000..3eacd5a --- /dev/null +++ b/src/chat/mod.rs @@ -0,0 +1,155 @@ +// SPDX-FileCopyrightText: Copyright (C) 2026 David Stainton +// SPDX-License-Identifier: AGPL-3.0-only + +//! Group chat built on top of [`GroupChannel`]. +//! +//! [`GroupChat`] is a thin newtype around `GroupChannel` that adds +//! typed convenience methods (`send_text`, `send_introduction`) so callers +//! never have to construct [`ChatEvent`] variants by hand. +//! +//! # Example +//! +//! ```ignore +//! let chat = GroupChat::create(&pigeonhole, "my-room", "Alice").await?; +//! let intro = chat.my_introduction(); +//! +//! // share `intro` out-of-band, then: +//! chat.add_member(&pigeonhole, &bob_intro)?; +//! +//! chat.send_text("hello everyone!").await?; +//! +//! let events = chat.receive_from_all().await?; +//! for e in events { +//! if let ChatEvent::Text(msg) = e.event { +//! println!("{}: {}", e.sender, msg); +//! } +//! } +//! 
``` + +use std::time::Duration; + +use serde::{Deserialize, Serialize}; + +use crate::group::channel::{GroupChannel, ReceivedGroupEvent}; +use crate::group::Introduction; +use crate::persistent::error::Result; +use crate::persistent::PigeonholeClient; + +// --------------------------------------------------------------------------- +// Event type +// --------------------------------------------------------------------------- + +/// The event type carried by every [`GroupChat`] channel. +/// +/// `Text` is a plain UTF-8 message. `Introduction` carries the read +/// capability of a new member so that existing members can add them without +/// a separate out-of-band exchange. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ChatEvent { + Text(String), + Introduction(Introduction), +} + +// --------------------------------------------------------------------------- +// GroupChat +// --------------------------------------------------------------------------- + +/// A group chat room backed by a [`GroupChannel`]. +/// +/// Wraps all membership and receive operations from the inner channel and +/// adds typed send helpers so callers work with chat concepts rather than +/// raw event envelopes. +pub struct GroupChat(GroupChannel); + +impl GroupChat { + /// Create a new chat room and generate the local member's channel. + pub async fn create( + pigeonhole: &PigeonholeClient, + room_name: &str, + my_display_name: &str, + ) -> Result { + GroupChannel::create(pigeonhole, room_name, my_display_name) + .await + .map(Self) + } + + /// Restore a previously created chat room from persisted channels. + pub async fn restore( + pigeonhole: &PigeonholeClient, + room_name: &str, + my_display_name: &str, + member_intros: &[Introduction], + ) -> Result { + GroupChannel::restore(pigeonhole, room_name, my_display_name, member_intros) + .await + .map(Self) + } + + /// Return an [`Introduction`] suitable for sharing with new members. 
+ pub fn my_introduction(&self) -> Introduction { + self.0.my_introduction() + } + + /// Number of remote member channels currently tracked. + pub fn member_count(&self) -> usize { + self.0.member_count() + } + + /// Import a member's read capability and start tracking their channel. + /// + /// Pass [`Introduction::member_id`] to `remove_member` to undo this. + pub fn add_member(&self, pigeonhole: &PigeonholeClient, intro: &Introduction) -> Result<()> { + self.0.add_member(pigeonhole, intro) + } + + /// Remove a member's channel from local tracking. + /// + /// `member_id` is [`Introduction::member_id`] for the member to remove. + /// Returns `true` if the member was present. + pub fn remove_member(&self, member_id: &str) -> bool { + self.0.remove_member(member_id) + } + + // ----------------------------------------------------------------------- + // Typed send helpers + // ----------------------------------------------------------------------- + + /// Send a plain-text message to the group. + pub async fn send_text(&self, text: &str) -> Result<()> { + self.0.send(&ChatEvent::Text(text.to_string())).await + } + + /// Broadcast an [`Introduction`] so existing members can add the newcomer. + pub async fn send_introduction(&self, intro: &Introduction) -> Result<()> { + self.0.send(&ChatEvent::Introduction(intro.clone())).await + } + + // ----------------------------------------------------------------------- + // Receive + // ----------------------------------------------------------------------- + + /// Block until every member delivers one event. + pub async fn receive_from_all(&self) -> Result>> { + self.0.receive_from_all().await + } + + /// Like [`receive_from_all`] but returns after `timeout` with partial results. + pub async fn receive_from_all_timeout( + &self, + timeout: Duration, + ) -> Result>> { + self.0.receive_from_all_timeout(timeout).await + } + + /// Block until any member delivers an event. 
+ pub async fn receive_any(&self) -> Result> { + self.0.receive_any().await + } + + /// Block until the member identified by `member_id` delivers an event. + /// + /// `member_id` is [`Introduction::member_id`] for the target member. + pub async fn receive_from(&self, member_id: &str) -> Result> { + self.0.receive_from(member_id).await + } +} diff --git a/src/core.rs b/src/core.rs new file mode 100644 index 0000000..d8c4d2a --- /dev/null +++ b/src/core.rs @@ -0,0 +1,616 @@ +// SPDX-FileCopyrightText: Copyright (C) 2026 David Stainton +// SPDX-License-Identifier: AGPL-3.0-only + +//! This module provides the main ThinClient struct and core functionality for +//! connecting to the client daemon, managing events, and sending messages. + +use std::collections::{BTreeMap, HashMap}; +use std::sync::{Arc, atomic::{AtomicBool, Ordering}}; + +use serde_cbor::{from_slice, Value}; + +use tokio::sync::{Mutex, RwLock, mpsc, oneshot}; +use tokio::task::JoinHandle; +use tokio::net::{TcpStream, UnixStream}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::tcp::{OwnedReadHalf as TcpReadHalf, OwnedWriteHalf as TcpWriteHalf}; +use tokio::net::unix::{OwnedReadHalf as UnixReadHalf, OwnedWriteHalf as UnixWriteHalf}; + +use rand::RngCore; +use log::{debug, error}; + +use crate::error::ThinClientError; +use crate::{Config, ServiceDescriptor, PigeonholeGeometry}; +use crate::helpers::find_services; + +/// The size in bytes of a SURB (Single-Use Reply Block) identifier. +const SURB_ID_SIZE: usize = 16; + +/// The size in bytes of a message identifier. +const MESSAGE_ID_SIZE: usize = 16; + +/// The size in bytes of a query identifier. +const QUERY_ID_SIZE: usize = 16; + +/// This represent the read half of our network socket. +pub enum ReadHalf { + Tcp(TcpReadHalf), + Unix(UnixReadHalf), +} + +/// This represent the write half of our network socket. 
+pub enum WriteHalf { + Tcp(TcpWriteHalf), + Unix(UnixWriteHalf), +} + +/// Wrapper for event sink receiver that automatically removes the drain when dropped +pub struct EventSinkReceiver { + receiver: mpsc::UnboundedReceiver>, + sender: mpsc::UnboundedSender>, + drain_remove: mpsc::UnboundedSender>>, +} + +impl EventSinkReceiver { + /// Receive the next event from the sink + pub async fn recv(&mut self) -> Option> { + self.receiver.recv().await + } +} + +impl Drop for EventSinkReceiver { + fn drop(&mut self) { + // Remove the drain when the receiver is dropped + if let Err(_) = self.drain_remove.send(self.sender.clone()) { + debug!("Failed to remove drain channel - event sink worker may be stopped"); + } + } +} + +/// This is our ThinClient type which encapsulates our thin client +/// connection management and message processing. +pub struct ThinClient { + read_half: Mutex, + write_half: Mutex, + config: Config, + pki_doc: Arc>>>, + worker_task: Mutex>>, + event_sink_task: Mutex>>, + shutdown: Arc, + is_connected: Arc, + // Event system like Go implementation + event_sink: mpsc::UnboundedSender>, + drain_add: mpsc::UnboundedSender>>, + drain_remove: mpsc::UnboundedSender>>, + // Response routing like Python implementation - keyed by query_id + response_channels: Arc, oneshot::Sender>>>>, +} + + +impl ThinClient { + + /// Create a new thin cilent and connect it to the client daemon. 
+ pub async fn new(config: Config) -> Result, Box> { + // Create event system channels like Go implementation + let (event_sink_tx, event_sink_rx) = mpsc::unbounded_channel(); + let (drain_add_tx, drain_add_rx) = mpsc::unbounded_channel(); + let (drain_remove_tx, drain_remove_rx) = mpsc::unbounded_channel(); + + // Shared response channels map + let response_channels = Arc::new(Mutex::new(HashMap::new())); + + let client = match config.network.to_uppercase().as_str() { + "TCP" => { + let socket = TcpStream::connect(&config.address).await?; + let (read_half, write_half) = socket.into_split(); + Arc::new(Self { + read_half: Mutex::new(ReadHalf::Tcp(read_half)), + write_half: Mutex::new(WriteHalf::Tcp(write_half)), + config, + pki_doc: Arc::new(RwLock::new(None)), + worker_task: Mutex::new(None), + event_sink_task: Mutex::new(None), + shutdown: Arc::new(AtomicBool::new(false)), + is_connected: Arc::new(AtomicBool::new(false)), + event_sink: event_sink_tx.clone(), + drain_add: drain_add_tx.clone(), + drain_remove: drain_remove_tx.clone(), + response_channels: response_channels.clone(), + }) + } + "UNIX" => { + let path = if config.address.starts_with('@') { + let mut p = String::from("\0"); + p.push_str(&config.address[1..]); + p + } else { + config.address.clone() + }; + let socket = UnixStream::connect(path).await?; + let (read_half, write_half) = socket.into_split(); + Arc::new(Self { + read_half: Mutex::new(ReadHalf::Unix(read_half)), + write_half: Mutex::new(WriteHalf::Unix(write_half)), + config, + pki_doc: Arc::new(RwLock::new(None)), + worker_task: Mutex::new(None), + event_sink_task: Mutex::new(None), + shutdown: Arc::new(AtomicBool::new(false)), + is_connected: Arc::new(AtomicBool::new(false)), + event_sink: event_sink_tx, + drain_add: drain_add_tx, + drain_remove: drain_remove_tx, + response_channels, + }) + } + _ => { + return Err(format!("Unknown network type: {}", config.network).into()); + } + }; + + // Start worker loop + let client_clone = 
Arc::clone(&client); + let task = tokio::spawn(async move { client_clone.worker_loop().await }); + *client.worker_task.lock().await = Some(task); + + // Start event sink worker + let client_clone2 = Arc::clone(&client); + let event_sink_task = tokio::spawn(async move { + client_clone2.event_sink_worker(event_sink_rx, drain_add_rx, drain_remove_rx).await + }); + *client.event_sink_task.lock().await = Some(event_sink_task); + + debug!("✅ ThinClient initialized with worker loop and event sink started."); + Ok(client) + } + + /// Stop our async worker task and disconnect the thin client. + pub async fn stop(&self) { + debug!("Stopping ThinClient..."); + + self.shutdown.store(true, Ordering::Relaxed); + + let mut write_half = self.write_half.lock().await; + + let _ = match &mut *write_half { + WriteHalf::Tcp(wh) => wh.shutdown().await, + WriteHalf::Unix(wh) => wh.shutdown().await, + }; + + if let Some(worker) = self.worker_task.lock().await.take() { + worker.abort(); + } + + debug!("✅ ThinClient stopped."); + } + + /// Returns true if the daemon is connected to the mixnet. + pub fn is_connected(&self) -> bool { + self.is_connected.load(Ordering::Relaxed) + } + + /// Creates a new event channel that receives all events from the thin client + /// This mirrors the Go implementation's EventSink method + pub fn event_sink(&self) -> EventSinkReceiver { + let (tx, rx) = mpsc::unbounded_channel(); + if let Err(_) = self.drain_add.send(tx.clone()) { + debug!("Failed to add drain channel - event sink worker may be stopped"); + } + EventSinkReceiver { + receiver: rx, + sender: tx, + drain_remove: self.drain_remove.clone(), + } + } + + /// Generates a new message ID. + pub fn new_message_id() -> Vec { + let mut id = vec![0; MESSAGE_ID_SIZE]; + rand::thread_rng().fill_bytes(&mut id); + id + } + + /// Generates a new SURB ID. + pub fn new_surb_id() -> Vec { + let mut id = vec![0; SURB_ID_SIZE]; + rand::thread_rng().fill_bytes(&mut id); + id + } + + /// Generates a new query ID. 
+ pub fn new_query_id() -> Vec { + let mut id = vec![0; QUERY_ID_SIZE]; + rand::thread_rng().fill_bytes(&mut id); + id + } + + async fn update_pki_document(&self, new_pki_doc: BTreeMap) { + let mut pki_doc_lock = self.pki_doc.write().await; + *pki_doc_lock = Some(new_pki_doc); + debug!("PKI document updated."); + } + + /// Returns our latest retrieved PKI document. + pub async fn pki_document(&self) -> Result, ThinClientError> { + self.pki_doc.read().await.clone().ok_or(ThinClientError::MissingPkiDocument) + } + + /// Returns the pigeonhole geometry from the config. + /// This geometry defines the payload sizes and envelope formats for the pigeonhole protocol. + pub fn pigeonhole_geometry(&self) -> &PigeonholeGeometry { + &self.config.pigeonhole_geometry + } + + /// Given a service name this returns a ServiceDescriptor if the service exists + /// in the current PKI document. + pub async fn get_service(&self, service_name: &str) -> Result { + let doc = self.pki_doc.read().await.clone().ok_or(ThinClientError::MissingPkiDocument)?; + let services = find_services(service_name, &doc); + services.into_iter().next().ok_or(ThinClientError::ServiceNotFound) + } + + /// Returns a courier service destination for the current epoch. + /// This method finds and randomly selects a courier service from the current + /// PKI document. The returned destination information is used with SendChannelQuery + /// and SendChannelQueryAwaitReply to transmit prepared channel operations. + /// Returns (dest_node, dest_queue) on success. 
+ pub async fn get_courier_destination(&self) -> Result<(Vec, Vec), ThinClientError> { + let courier_service = self.get_service("courier").await?; + let (dest_node, dest_queue) = courier_service.to_destination(); + Ok((dest_node, dest_queue)) + } + + + pub(crate) async fn recv(&self) -> Result, ThinClientError> { + let mut length_prefix = [0; 4]; + { + let mut read_half = self.read_half.lock().await; + match &mut *read_half { + ReadHalf::Tcp(rh) => rh.read_exact(&mut length_prefix).await.map_err(ThinClientError::IoError)?, + ReadHalf::Unix(rh) => rh.read_exact(&mut length_prefix).await.map_err(ThinClientError::IoError)?, + }; + } + let message_length = u32::from_be_bytes(length_prefix) as usize; + let mut buffer = vec![0; message_length]; + { + let mut read_half = self.read_half.lock().await; + match &mut *read_half { + ReadHalf::Tcp(rh) => rh.read_exact(&mut buffer).await.map_err(ThinClientError::IoError)?, + ReadHalf::Unix(rh) => rh.read_exact(&mut buffer).await.map_err(ThinClientError::IoError)?, + }; + } + let response: BTreeMap = match from_slice(&buffer) { + Ok(parsed) => { + parsed + } + Err(err) => { + error!("❌ Failed to parse CBOR: {:?}", err); + return Err(ThinClientError::CborError(err)); + } + }; + Ok(response) + } + + fn parse_status(&self, event: &BTreeMap) { + let is_connected = event.get(&Value::Text("is_connected".to_string())) + .and_then(|v| match v { + Value::Bool(b) => Some(*b), + _ => None, + }) + .unwrap_or(false); + + // Update connection state + self.is_connected.store(is_connected, Ordering::Relaxed); + + if is_connected { + debug!("✅ Daemon is connected to mixnet - full functionality available."); + } else { + debug!("📴 Daemon is not connected to mixnet - entering offline mode (channel operations will work)."); + } + } + + async fn parse_pki_doc(&self, event: &BTreeMap) { + if let Some(Value::Bytes(payload)) = event.get(&Value::Text("payload".to_string())) { + match serde_cbor::from_slice::>(payload) { + Ok(raw_pki_doc) => { + 
self.update_pki_document(raw_pki_doc).await; + debug!("✅ PKI document successfully parsed."); + } + Err(err) => { + error!("❌ Failed to parse PKI document: {:?}", err); + } + } + } else { + error!("❌ Missing 'payload' field in PKI document event."); + } + } + + async fn handle_response(&self, response: BTreeMap) { + if response.is_empty() { + error!("❌ Received an empty response, ignoring"); + return; + } + + if let Some(Value::Map(event)) = response.get(&Value::Text("connection_status_event".to_string())) { + debug!("🔄 Connection status event received."); + self.parse_status(event); + if let Some(cb) = self.config.on_connection_status.as_ref() { + cb(event); + } + return; + } + + if let Some(Value::Map(event)) = response.get(&Value::Text("new_pki_document_event".to_string())) { + debug!("📜 New PKI document event received."); + self.parse_pki_doc(event).await; + if let Some(cb) = self.config.on_new_pki_document.as_ref() { + cb(event); + } + return; + } + + if let Some(Value::Map(event)) = response.get(&Value::Text("message_sent_event".to_string())) { + debug!("📨 Message sent event received."); + if let Some(cb) = self.config.on_message_sent.as_ref() { + cb(event); + } + return; + } + + if let Some(Value::Map(event)) = response.get(&Value::Text("message_reply_event".to_string())) { + debug!("📩 Message reply event received."); + if let Some(cb) = self.config.on_message_reply.as_ref() { + cb(event); + } + return; + } + + // Route replies to response_channels based on query_id (like Python implementation) + // This handles *_reply messages with query_id fields + for (key, value) in response.iter() { + if let Value::Text(reply_type) = key { + if reply_type.ends_with("_reply") { + if let Value::Map(reply_map) = value { + if let Some(Value::Bytes(query_id)) = reply_map.get(&Value::Text("query_id".to_string())) { + let mut channels = self.response_channels.lock().await; + if let Some(sender) = channels.remove(query_id) { + debug!("Routing {} to waiting caller", 
reply_type); + let _ = sender.send(reply_map.clone()); + return; + } + } + } + } + } + } + + debug!("Unhandled response (no matching query_id listener): {:?}", response.keys().collect::>()); + } + + async fn worker_loop(&self) { + debug!("Worker loop started"); + while !self.shutdown.load(Ordering::Relaxed) { + match self.recv().await { + Ok(response) => { + // Send all responses to event sink for distribution + if let Err(_) = self.event_sink.send(response.clone()) { + debug!("Event sink channel closed, stopping worker loop"); + break; + } + self.handle_response(response).await; + }, + Err(_) if self.shutdown.load(Ordering::Relaxed) => break, + Err(err) => error!("Error in recv: {}", err), + } + } + debug!("Worker loop exited."); + } + + /// Event sink worker that distributes events to multiple drain channels + /// This mirrors the Go implementation's eventSinkWorker + async fn event_sink_worker( + &self, + mut event_sink_rx: mpsc::UnboundedReceiver>, + mut drain_add_rx: mpsc::UnboundedReceiver>>, + mut drain_remove_rx: mpsc::UnboundedReceiver>>, + ) { + debug!("Event sink worker started"); + let mut drains: HashMap>> = HashMap::new(); + let mut next_id = 0usize; + + loop { + tokio::select! 
{ + // Handle shutdown + _ = async { while !self.shutdown.load(Ordering::Relaxed) { tokio::time::sleep(std::time::Duration::from_millis(100)).await; } } => { + debug!("Event sink worker shutting down"); + break; + } + + // Add new drain channel + Some(drain) = drain_add_rx.recv() => { + drains.insert(next_id, drain); + next_id += 1; + debug!("Added new drain channel, total drains: {}", drains.len()); + } + + // Remove drain channel when EventSinkReceiver is dropped + Some(drain_to_remove) = drain_remove_rx.recv() => { + drains.retain(|_, drain| !std::ptr::addr_eq(drain, &drain_to_remove)); + debug!("Removed drain channel, total drains: {}", drains.len()); + } + + // Distribute events to all drain channels + Some(event) = event_sink_rx.recv() => { + let mut bad_drains = Vec::new(); + + for (id, drain) in &drains { + if let Err(_) = drain.send(event.clone()) { + // Channel is closed, mark for removal + bad_drains.push(*id); + } + } + + // Remove closed channels + for id in bad_drains { + drains.remove(&id); + } + } + } + } + debug!("Event sink worker exited."); + } + + pub(crate) async fn send_cbor_request(&self, request: BTreeMap) -> Result<(), ThinClientError> { + let encoded_request = serde_cbor::to_vec(&serde_cbor::Value::Map(request))?; + let length_prefix = (encoded_request.len() as u32).to_be_bytes(); + + let mut write_half = self.write_half.lock().await; + + match &mut *write_half { + WriteHalf::Tcp(wh) => { + wh.write_all(&length_prefix).await?; + wh.write_all(&encoded_request).await?; + } + WriteHalf::Unix(wh) => { + wh.write_all(&length_prefix).await?; + wh.write_all(&encoded_request).await?; + } + } + + debug!("✅ Request sent successfully."); + Ok(()) + } + + /// Send a CBOR request and wait for a reply with the matching query_id. + /// This uses direct response routing via query_id (like Python's _send_and_wait). 
+ pub(crate) async fn send_and_wait_direct(&self, query_id: Vec, request: BTreeMap) -> Result, ThinClientError> { + // Create oneshot channel for receiving the reply + let (tx, rx) = oneshot::channel(); + + // Register the channel BEFORE sending the request (like Python) + { + let mut channels = self.response_channels.lock().await; + channels.insert(query_id.clone(), tx); + } + + // Send the request + if let Err(e) = self.send_cbor_request(request).await { + // Clean up on failure + let mut channels = self.response_channels.lock().await; + channels.remove(&query_id); + return Err(e); + } + + debug!("send_and_wait_direct: request sent, waiting for reply with query_id {:?}", &query_id[..std::cmp::min(8, query_id.len())]); + + // Wait for the reply (no timeout - block forever like Go/Python) + match rx.await { + Ok(reply) => { + debug!("send_and_wait_direct: received reply"); + Ok(reply) + } + Err(_) => { + // Channel was dropped without sending - clean up + let mut channels = self.response_channels.lock().await; + channels.remove(&query_id); + Err(ThinClientError::Other("Response channel closed without reply".to_string())) + } + } + } + + /// Sends a message encapsulated in a Sphinx packet without any SURB. + /// No reply will be possible. This method requires mixnet connectivity. 
+ pub async fn send_message_without_reply( + &self, + payload: &[u8], + dest_node: Vec, + dest_queue: Vec + ) -> Result<(), ThinClientError> { + // Check if we're in offline mode + if !self.is_connected() { + return Err(ThinClientError::OfflineMode("cannot send message in offline mode - daemon not connected to mixnet".to_string())); + } + // Create the SendMessage structure + let mut send_message = BTreeMap::new(); + send_message.insert(Value::Text("id".to_string()), Value::Null); // No ID for fire-and-forget messages + send_message.insert(Value::Text("with_surb".to_string()), Value::Bool(false)); + send_message.insert(Value::Text("surbid".to_string()), Value::Null); // No SURB ID for fire-and-forget messages + send_message.insert(Value::Text("destination_id_hash".to_string()), Value::Bytes(dest_node)); + send_message.insert(Value::Text("recipient_queue_id".to_string()), Value::Bytes(dest_queue)); + send_message.insert(Value::Text("payload".to_string()), Value::Bytes(payload.to_vec())); + + // Wrap in the new Request structure + let mut request = BTreeMap::new(); + request.insert(Value::Text("send_message".to_string()), Value::Map(send_message)); + + self.send_cbor_request(request).await + } + + /// This method takes a message payload, a destination node, + /// destination queue ID and a SURB ID and sends a message along + /// with a SURB so that you can later receive the reply along with + /// the SURBID you choose. This method of sending messages should + /// be considered to be asynchronous because it does NOT actually + /// wait until the client daemon sends the message. Nor does it + /// wait for a reply. The only blocking aspect to it's behavior is + /// merely blocking until the client daemon receives our request + /// to send a message. This method requires mixnet connectivity. 
+ pub async fn send_message( + &self, + surb_id: Vec, + payload: &[u8], + dest_node: Vec, + dest_queue: Vec + ) -> Result<(), ThinClientError> { + // Check if we're in offline mode + if !self.is_connected() { + return Err(ThinClientError::OfflineMode("cannot send message in offline mode - daemon not connected to mixnet".to_string())); + } + // Create the SendMessage structure + let mut send_message = BTreeMap::new(); + send_message.insert(Value::Text("id".to_string()), Value::Null); // No ID for regular messages + send_message.insert(Value::Text("with_surb".to_string()), Value::Bool(true)); + send_message.insert(Value::Text("surbid".to_string()), Value::Bytes(surb_id)); + send_message.insert(Value::Text("destination_id_hash".to_string()), Value::Bytes(dest_node)); + send_message.insert(Value::Text("recipient_queue_id".to_string()), Value::Bytes(dest_queue)); + send_message.insert(Value::Text("payload".to_string()), Value::Bytes(payload.to_vec())); + + // Wrap in the new Request structure + let mut request = BTreeMap::new(); + request.insert(Value::Text("send_message".to_string()), Value::Map(send_message)); + + self.send_cbor_request(request).await + } + + /// This method takes a message payload, a destination node, + /// destination queue ID and a message ID and reliably sends a message. + /// This uses a simple ARQ to resend the message if a reply wasn't received. + /// The given message ID will be used to identify the reply since a SURB ID + /// can only be used once. This method requires mixnet connectivity. 
+ pub async fn send_reliable_message( + &self, + message_id: Vec, + payload: &[u8], + dest_node: Vec, + dest_queue: Vec + ) -> Result<(), ThinClientError> { + // Check if we're in offline mode + if !self.is_connected() { + return Err(ThinClientError::OfflineMode("cannot send reliable message in offline mode - daemon not connected to mixnet".to_string())); + } + // Create the SendARQMessage structure + let mut send_arq_message = BTreeMap::new(); + send_arq_message.insert(Value::Text("id".to_string()), Value::Bytes(message_id)); + send_arq_message.insert(Value::Text("with_surb".to_string()), Value::Bool(true)); + send_arq_message.insert(Value::Text("surbid".to_string()), Value::Null); // ARQ messages don't use SURB IDs directly + send_arq_message.insert(Value::Text("destination_id_hash".to_string()), Value::Bytes(dest_node)); + send_arq_message.insert(Value::Text("recipient_queue_id".to_string()), Value::Bytes(dest_queue)); + send_arq_message.insert(Value::Text("payload".to_string()), Value::Bytes(payload.to_vec())); + + // Wrap in the new Request structure + let mut request = BTreeMap::new(); + request.insert(Value::Text("send_arq_message".to_string()), Value::Map(send_arq_message)); + + self.send_cbor_request(request).await + } +} diff --git a/src/doodle/mod.rs b/src/doodle/mod.rs new file mode 100644 index 0000000..93b1e26 --- /dev/null +++ b/src/doodle/mod.rs @@ -0,0 +1,364 @@ +// SPDX-FileCopyrightText: Copyright (C) 2026 David Stainton +// SPDX-License-Identifier: AGPL-3.0-only + +//! Distributed Doodle-style meeting poll built on [`GroupChannel`]. +//! +//! A poll creator proposes a fixed set of [`TimeSlot`]s. Every group +//! member — including the creator — publishes a [`DoodleEvent`] on their +//! own channel. Any participant can derive the current poll state by +//! folding all events observed across the group's streams: +//! +//! ```text +//! state = fold(events) +//! ``` +//! +//! # Protocol +//! +//! 1. 
The creator calls [`DoodlePoll::new_poll`], which broadcasts a +//! `CreatePoll` event containing the poll title and the list of slots. +//! 2. Other participants call [`DoodlePoll::join_poll`] (or +//! [`DoodlePoll::restore`] for a previously persisted poll) and +//! exchange [`Introduction`]s with one another via [`GroupChannel`]. +//! 3. Each participant calls [`DoodlePoll::cast_ballot`] to publish a +//! [`CastBallot`] event mapping each slot ID to their [`Availability`]. +//! Ballots may be updated by publishing a new one; later ballots +//! supersede earlier ones (last-write-wins, guaranteed by channel +//! ordering). +//! 4. All participants call [`DoodlePoll::receive_and_apply`] (or the +//! partial-results variant [`DoodlePoll::receive_and_apply_timeout`]) +//! to pull in remote events and update the local [`PollState`]. +//! 5. Whoever wants a summary calls [`DoodlePoll::tally`] for per-slot +//! counts or [`DoodlePoll::best_slot`] for the slot with the most +//! "Yes" votes. +//! +//! # Example +//! +//! ```ignore +//! // Creator +//! let mut poll = DoodlePoll::new_poll(&ph, "standup", "Alice", +//! "Weekly standup", +//! vec![ +//! TimeSlot::new("mon-9", "Monday 09:00"), +//! TimeSlot::new("tue-9", "Tuesday 09:00"), +//! ], +//! ).await?; +//! +//! // Share poll.my_introduction() out-of-band, then: +//! poll.add_member(&ph, &bob_intro)?; +//! +//! // Bob +//! let mut bob_poll = DoodlePoll::join_poll(&ph, "standup", "Bob", &alice_intro).await?; +//! bob_poll.cast_ballot(hashmap! { +//! "mon-9".to_string() => Availability::Yes, +//! "tue-9".to_string() => Availability::No, +//! }).await?; +//! +//! // Alice receives and tallies +//! poll.receive_and_apply().await?; +//! println!("{:#?}", poll.tally()); +//! 
``` + +use std::collections::HashMap; +use std::time::Duration; + +use serde::{Deserialize, Serialize}; + +use crate::group::channel::{GroupChannel, ReceivedGroupEvent}; +use crate::group::Introduction; +use crate::persistent::error::Result; +use crate::persistent::PigeonholeClient; + +// --------------------------------------------------------------------------- +// Domain types +// --------------------------------------------------------------------------- + +/// A candidate meeting time proposed by the poll creator. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct TimeSlot { + /// Stable identifier used as the key in ballot maps. + pub id: String, + /// Human-readable label shown to participants. + pub label: String, +} + +impl TimeSlot { + pub fn new(id: impl Into, label: impl Into) -> Self { + Self { id: id.into(), label: label.into() } + } +} + +/// A participant's answer for a single time slot. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum Availability { + Yes, + No, + Maybe, +} + +// --------------------------------------------------------------------------- +// Event type +// --------------------------------------------------------------------------- + +/// Events published on a [`DoodlePoll`] group channel. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum DoodleEvent { + /// Emitted by the poll creator to establish the poll title and slot list. + /// Every participant's local state is initialized from the first + /// `CreatePoll` event they observe (subsequent ones are ignored). + CreatePoll { + title: String, + slots: Vec, + }, + /// A participant's complete ballot. Each entry maps a slot ID to their + /// availability. A later ballot from the same sender completely replaces + /// their earlier one (last-write-wins, guaranteed by channel ordering). + CastBallot { + /// Maps slot id → availability for every slot in the poll. 
+ votes: HashMap, + }, +} + +// --------------------------------------------------------------------------- +// State: fold(events) +// --------------------------------------------------------------------------- + +/// The current poll state, derived by folding all observed events. +/// +/// `title` and `slots` are set by the first `CreatePoll` event; subsequent +/// `CreatePoll` events are ignored (the creator's initial broadcast is +/// authoritative). `ballots` is updated by every `CastBallot` event. +#[derive(Debug, Clone, Default)] +pub struct PollState { + /// Poll title, empty until the first `CreatePoll` is applied. + pub title: String, + /// Ordered list of candidate slots, empty until `CreatePoll` is applied. + pub slots: Vec, + /// Latest ballot per sender. Key: display name; Value: slot-id → availability. + pub ballots: HashMap>, +} + +impl PollState { + /// Fold one `(sender, event)` pair into the state. + pub fn apply(&mut self, sender: &str, event: DoodleEvent) { + match event { + DoodleEvent::CreatePoll { title, slots } => { + // Only the first CreatePoll initializes the poll. + if self.title.is_empty() { + self.title = title; + self.slots = slots; + } + } + DoodleEvent::CastBallot { votes } => { + // Last ballot from this sender wins. + self.ballots.insert(sender.to_string(), votes); + } + } + } + + /// Per-slot tally of Yes / No / Maybe counts across all known ballots. + /// + /// Returns one [`SlotTally`] per slot in creation order. Participants + /// who have not yet cast a ballot are excluded from the counts. 
+ pub fn tally(&self) -> Vec { + self.slots.iter().map(|slot| { + let mut yes = 0u32; + let mut no = 0u32; + let mut maybe = 0u32; + for ballot in self.ballots.values() { + match ballot.get(&slot.id) { + Some(Availability::Yes) => yes += 1, + Some(Availability::No) => no += 1, + Some(Availability::Maybe) => maybe += 1, + None => {} + } + } + SlotTally { slot: slot.clone(), yes, no, maybe } + }).collect() + } + + /// The slot with the highest `Yes` count, breaking ties by slot order. + /// + /// Returns `None` if the poll has no slots or no `Yes` votes at all. + pub fn best_slot(&self) -> Option<&TimeSlot> { + self.tally() + .into_iter() + .enumerate() + .filter(|(_, t)| t.yes > 0) + .max_by_key(|(idx, t)| (t.yes, -((*idx) as i64))) + .map(|(idx, _)| &self.slots[idx]) + } +} + +/// Vote counts for a single time slot. +#[derive(Debug, Clone)] +pub struct SlotTally { + pub slot: TimeSlot, + pub yes: u32, + pub no: u32, + pub maybe: u32, +} + +// --------------------------------------------------------------------------- +// DoodlePoll +// --------------------------------------------------------------------------- + +/// A distributed Doodle-style meeting poll backed by a +/// [`GroupChannel`]. +/// +/// The local poll state is updated by calling [`receive_one_and_apply`], +/// [`receive_and_apply`], or [`receive_and_apply_timeout`]. State is +/// inspected with [`poll_state`], [`tally`], and [`best_slot`]. +pub struct DoodlePoll { + channel: GroupChannel, + state: PollState, +} + +impl DoodlePoll { + // ----------------------------------------------------------------------- + // Constructors + // ----------------------------------------------------------------------- + + /// Create a new poll, broadcasting the slot list to the group channel. + /// + /// The creator is automatically recorded as the poll initiator. Call + /// [`add_member`] for each additional participant after sharing your + /// [`my_introduction`] with them out-of-band. 
+ pub async fn new_poll( + pigeonhole: &PigeonholeClient, + poll_name: &str, + my_display_name: &str, + title: impl Into, + slots: Vec, + ) -> Result { + let channel = GroupChannel::create(pigeonhole, poll_name, my_display_name).await?; + let title = title.into(); + let event = DoodleEvent::CreatePoll { title: title.clone(), slots: slots.clone() }; + channel.send(&event).await?; + let mut state = PollState::default(); + state.apply(my_display_name, event); + Ok(Self { channel, state }) + } + + /// Join an existing poll, given the creator's introduction. + /// + /// Call this after receiving the creator's [`Introduction`] out-of-band. + /// To receive the `CreatePoll` event and initialize the local state, call + /// [`receive_one_and_apply`] (or [`receive_and_apply`]). + pub async fn join_poll( + pigeonhole: &PigeonholeClient, + poll_name: &str, + my_display_name: &str, + creator_intro: &Introduction, + ) -> Result { + let channel = GroupChannel::create(pigeonhole, poll_name, my_display_name).await?; + channel.add_member(pigeonhole, creator_intro)?; + Ok(Self { channel, state: PollState::default() }) + } + + /// Restore a previously persisted poll from the database. + pub async fn restore( + pigeonhole: &PigeonholeClient, + poll_name: &str, + my_display_name: &str, + member_intros: &[Introduction], + ) -> Result { + let channel = GroupChannel::restore(pigeonhole, poll_name, my_display_name, member_intros).await?; + Ok(Self { channel, state: PollState::default() }) + } + + // ----------------------------------------------------------------------- + // Membership + // ----------------------------------------------------------------------- + + /// Return an [`Introduction`] for sharing with other participants. + pub fn my_introduction(&self) -> Introduction { + self.channel.my_introduction() + } + + /// Number of remote member channels currently tracked. 
+ pub fn member_count(&self) -> usize { + self.channel.member_count() + } + + /// Import a member's read capability and start tracking their channel. + pub fn add_member(&self, pigeonhole: &PigeonholeClient, intro: &Introduction) -> Result<()> { + self.channel.add_member(pigeonhole, intro) + } + + /// Remove a member's channel from local tracking. + pub fn remove_member(&self, member_id: &str) -> bool { + self.channel.remove_member(member_id) + } + + // ----------------------------------------------------------------------- + // Voting + // ----------------------------------------------------------------------- + + /// Publish a ballot mapping each slot ID to an [`Availability`] value. + /// + /// The ballot is applied to the local state immediately so that + /// `tally()` and `best_slot()` include the caller's own votes without + /// requiring a round-trip through the mixnet. + /// + /// A later `cast_ballot` call replaces the previous one; last-write-wins + /// is guaranteed by channel ordering. + pub async fn cast_ballot(&mut self, votes: HashMap) -> Result<()> { + let event = DoodleEvent::CastBallot { votes }; + self.channel.send(&event).await?; + let my_name = self.channel.my_display_name.clone(); + self.state.apply(&my_name, event); + Ok(()) + } + + // ----------------------------------------------------------------------- + // Receive + fold + // ----------------------------------------------------------------------- + + /// Block until any member delivers one event, then fold it into the local + /// state. Returns a reference to the updated [`PollState`]. + pub async fn receive_one_and_apply(&mut self) -> Result<&PollState> { + let ReceivedGroupEvent { sender, event } = self.channel.receive_any().await?; + self.state.apply(&sender, event); + Ok(&self.state) + } + + /// Block until every member has delivered one event, then fold all of + /// them into the local state. Returns a reference to the updated + /// [`PollState`]. 
+ pub async fn receive_and_apply(&mut self) -> Result<&PollState> { + let events = self.channel.receive_from_all().await?; + for ReceivedGroupEvent { sender, event } in events { + self.state.apply(&sender, event); + } + Ok(&self.state) + } + + /// Like [`receive_and_apply`] but returns partial results after `timeout`. + pub async fn receive_and_apply_timeout(&mut self, timeout: Duration) -> Result<&PollState> { + let events = self.channel.receive_from_all_timeout(timeout).await?; + for ReceivedGroupEvent { sender, event } in events { + self.state.apply(&sender, event); + } + Ok(&self.state) + } + + // ----------------------------------------------------------------------- + // State inspection + // ----------------------------------------------------------------------- + + /// Read-only view of the current poll state. + pub fn poll_state(&self) -> &PollState { + &self.state + } + + /// Per-slot vote counts derived from the current state. + pub fn tally(&self) -> Vec { + self.state.tally() + } + + /// The slot with the most "Yes" votes. Returns `None` if no votes have + /// been cast yet or no slot has a "Yes". 
+ pub fn best_slot(&self) -> Option<&TimeSlot> { + self.state.best_slot() + } +} diff --git a/src/error.rs b/src/error.rs index d4f8342..d362fe6 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,9 +12,60 @@ pub enum ThinClientError { MissingPkiDocument, ServiceNotFound, OfflineMode(String), + + // Pigeonhole replica error codes (from pigeonhole/errors.go) + /// Box ID not found on the replica (error code 1) + BoxNotFound, + /// Invalid box ID format (error code 2) + InvalidBoxId, + /// Invalid or missing signature (error code 3) + InvalidSignature, + /// Database operation failed (error code 4) + DatabaseFailure, + /// Invalid payload data (error code 5) + InvalidPayload, + /// Storage capacity exceeded (error code 6) + StorageFull, + /// Internal replica error (error code 7) + ReplicaInternalError, + /// Invalid epoch (error code 8) + InvalidEpoch, + /// Replication to other replicas failed (error code 9) + ReplicationFailed, + /// Box already exists / already written (error code 10) + BoxAlreadyExists, + /// MKEM decryption failed (error code 22) + MkemDecryptionFailed, + /// BACAP decryption failed (error code 23) + BacapDecryptionFailed, + /// Operation was cancelled (error code 24) + StartResendingCancelled, + Other(String), } +/// Maps daemon error codes to ThinClientError variants. +/// This matches the Go `errorCodeToSentinel` function. 
+pub fn error_code_to_error(error_code: u8) -> ThinClientError { + match error_code { + 0 => ThinClientError::Other("unexpected success code in error path".to_string()), + 1 => ThinClientError::BoxNotFound, + 2 => ThinClientError::InvalidBoxId, + 3 => ThinClientError::InvalidSignature, + 4 => ThinClientError::DatabaseFailure, + 5 => ThinClientError::InvalidPayload, + 6 => ThinClientError::StorageFull, + 7 => ThinClientError::ReplicaInternalError, + 8 => ThinClientError::InvalidEpoch, + 9 => ThinClientError::ReplicationFailed, + 10 => ThinClientError::BoxAlreadyExists, + 22 => ThinClientError::MkemDecryptionFailed, + 23 => ThinClientError::BacapDecryptionFailed, + 24 => ThinClientError::StartResendingCancelled, + code => ThinClientError::Other(format!("unknown error code: {}", code)), + } +} + impl fmt::Display for ThinClientError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -24,6 +75,19 @@ impl fmt::Display for ThinClientError { ThinClientError::MissingPkiDocument => write!(f, "Missing PKI document."), ThinClientError::ServiceNotFound => write!(f, "Service not found."), ThinClientError::OfflineMode(msg) => write!(f, "Offline mode error: {}", msg), + ThinClientError::BoxNotFound => write!(f, "Box ID not found"), + ThinClientError::InvalidBoxId => write!(f, "Invalid box ID"), + ThinClientError::InvalidSignature => write!(f, "Invalid signature"), + ThinClientError::DatabaseFailure => write!(f, "Database failure"), + ThinClientError::InvalidPayload => write!(f, "Invalid payload"), + ThinClientError::StorageFull => write!(f, "Storage full"), + ThinClientError::ReplicaInternalError => write!(f, "Replica internal error"), + ThinClientError::InvalidEpoch => write!(f, "Invalid epoch"), + ThinClientError::ReplicationFailed => write!(f, "Replication failed"), + ThinClientError::BoxAlreadyExists => write!(f, "Box already exists"), + ThinClientError::MkemDecryptionFailed => write!(f, "MKEM decryption failed"), + ThinClientError::BacapDecryptionFailed 
=> write!(f, "BACAP decryption failed"), + ThinClientError::StartResendingCancelled => write!(f, "Start resending cancelled"), ThinClientError::Other(msg) => write!(f, "Error: {}", msg), } } diff --git a/src/group/channel.rs b/src/group/channel.rs new file mode 100644 index 0000000..76f9ef0 --- /dev/null +++ b/src/group/channel.rs @@ -0,0 +1,467 @@ +// SPDX-FileCopyrightText: Copyright (C) 2026 David Stainton +// SPDX-License-Identifier: AGPL-3.0-only + +//! Generic group channel: each member owns one `EventChannel` for writing; +//! the others hold imported read-only `EventChannel`s for every peer. + +use std::collections::{HashMap, VecDeque}; +use std::sync::{Arc, RwLock}; +use std::time::Duration; + +use blake2::{Blake2s256, Digest}; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use tokio::sync::Mutex; + +use crate::persistent::error::{PigeonholeDbError, Result}; +use crate::persistent::{PigeonholeClient, ReadCapability}; + +use super::event_channel::EventChannel; + +/// Out-of-band introduction: a member's display name, read capability, and +/// starting index. Exchanged directly (e.g. QR code, secure side-channel) +/// to bootstrap group membership, or sent in-band as an application event +/// when one member wants to introduce another. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Introduction { + pub display_name: String, + #[serde(with = "serde_bytes")] + pub read_cap: Vec, + #[serde(with = "serde_bytes")] + pub start_index: Vec, +} + +impl Introduction { + pub fn new(display_name: &str, read_cap: Vec, start_index: Vec) -> Self { + Self { + display_name: display_name.to_string(), + read_cap, + start_index, + } + } + + /// A stable, collision-resistant identifier derived from the read + /// capability bytes. Two `Introduction`s with the same `display_name` + /// but different read caps produce different IDs, preventing silent map + /// collisions. Used as the `HashMap` key and as the channel name + /// component in the database. 
+ pub fn member_id(&self) -> String { + hex::encode(Blake2s256::digest(&self.read_cap)) + } +} + +/// An event received from a specific group member. +#[derive(Debug, Clone)] +pub struct ReceivedGroupEvent { + pub sender: String, + pub event: E, +} + +/// A group where every member publishes to their own `EventChannel` and +/// reads from every other member's channel. +/// +/// # Type parameter +/// +/// `E` is the application event type. It must be serializable with +/// `serde` (CBOR encoding is used on the wire). Examples: +/// +/// - A simple `enum ChatEvent { Text(String), Introduction(Introduction) }` +/// for plain group chat. +/// - A CRDT operation type such as `Dot` (a `GCounter` op) +/// for replicated-state applications. +/// +/// # Receiving +/// +/// Use [`receive_from`] to block until a specific member's next message +/// arrives, or [`receive_any`] to race all member channels and return +/// whichever delivers first. Both rely on the daemon's ARQ mechanism +/// rather than an application-level sleep/poll loop. +/// +/// # Channel rotation (future work) +/// +/// For post-compromise security, each member should periodically rotate to a +/// freshly generated channel. The rotation handshake requires the writer to +/// receive an explicit ACK from every reader confirming they have imported the +/// new read cap before the old channel is retired. This is not yet +/// implemented; the current design keeps a single `EventChannel` per +/// member for simplicity. +pub struct GroupChannel { + pub name: String, + pub my_display_name: String, + /// Cached at creation; immutable, so no lock needed to share it. + my_introduction: Introduction, + /// Wrapped in `Arc>` so `send` takes `&self` and can be + /// called concurrently with `receive_from_all` in `tokio::join!`. + my_channel: Arc>>, + /// The member map is wrapped in `Arc>` so that `add_member` + /// and `remove_member` take `&self` and can run concurrently with `send`. 
+ /// Receive methods take a snapshot (clone of `Arc`s) under a brief read + /// lock, then do all async work outside it. A `std::sync::RwLock` is + /// used rather than `tokio::sync::RwLock` because every map operation is + /// synchronous; the async work lives inside the per-channel `Mutex`. + /// + /// Key: `Introduction::member_id()` (Blake2s-256 hex of the read cap). + /// Value: `(display_name, channel)` — display_name is stored alongside + /// the channel so that `ReceivedGroupEvent::sender` stays human-readable. + member_channels: Arc>>)>>>, + /// Events that were received by a `receive_any` call alongside the winner + /// but not yet returned to the caller. Because `ChannelHandle::receive` + /// advances the persistent read cursor before it returns, any result that + /// is produced must eventually be delivered — it cannot be silently + /// discarded. Buffered events are drained in FIFO order by the next + /// `receive_any` call before new network I/O is issued. + receive_any_buffer: Mutex>>, +} + +impl GroupChannel { + /// Create a new group and generate the local member's channel. + pub async fn create( + pigeonhole: &PigeonholeClient, + group_name: &str, + my_display_name: &str, + ) -> Result { + let my_channel_name = format!("group:{}:self", group_name); + let handle = pigeonhole.create_channel(&my_channel_name).await?; + let read_cap = handle.share_read_capability(); + let my_introduction = Introduction::new(my_display_name, read_cap.read_cap, read_cap.start_index); + let my_channel = Arc::new(Mutex::new(EventChannel::new(handle))); + + Ok(Self { + name: group_name.to_string(), + my_display_name: my_display_name.to_string(), + my_introduction, + my_channel, + member_channels: Arc::new(RwLock::new(HashMap::new())), + receive_any_buffer: Mutex::new(VecDeque::new()), + }) + } + + /// Restore a previously created group from persisted channels in the + /// database. 
+ /// + /// Unlike [`create`], this does not need network I/O — all data is already + /// in the local DB. The signature is `async` purely for API symmetry so + /// callers can treat both constructors the same way. + pub async fn restore( + pigeonhole: &PigeonholeClient, + group_name: &str, + my_display_name: &str, + member_intros: &[Introduction], + ) -> Result { + let my_channel_name = format!("group:{}:self", group_name); + let handle = pigeonhole.get_channel(&my_channel_name)?; + let read_cap = handle.share_read_capability(); + let my_introduction = Introduction::new(my_display_name, read_cap.read_cap, read_cap.start_index); + let my_channel = Arc::new(Mutex::new(EventChannel::new(handle))); + + let mut map = HashMap::new(); + for intro in member_intros { + let id = intro.member_id(); + let member_channel_name = format!("group:{}:member:{}", group_name, id); + let handle = pigeonhole.get_channel(&member_channel_name)?; + map.insert( + id, + (intro.display_name.clone(), Arc::new(Mutex::new(EventChannel::new(handle)))), + ); + } + + Ok(Self { + name: group_name.to_string(), + my_display_name: my_display_name.to_string(), + my_introduction, + my_channel, + member_channels: Arc::new(RwLock::new(map)), + receive_any_buffer: Mutex::new(VecDeque::new()), + }) + } + + /// Return an `Introduction` suitable for sharing with new members so they + /// can import this member's channel. + pub fn my_introduction(&self) -> Introduction { + self.my_introduction.clone() + } + + /// Number of remote member channels currently tracked. + pub fn member_count(&self) -> usize { + self.member_channels.read().expect("member_channels lock poisoned").len() + } + + /// Import a member's read capability and start tracking their channel. + /// + /// The member is keyed internally by [`Introduction::member_id`] (a hash + /// of the read cap), not by `display_name`. Two members may share a + /// display name without colliding. 
+ pub fn add_member(&self, pigeonhole: &PigeonholeClient, intro: &Introduction) -> Result<()> { + let id = intro.member_id(); + let channel_name = format!("group:{}:member:{}", self.name, id); + let read_cap = ReadCapability { + read_cap: intro.read_cap.clone(), + start_index: intro.start_index.clone(), + name: Some(intro.display_name.clone()), + }; + let handle = pigeonhole.import_channel(&channel_name, &read_cap)?; + self.member_channels + .write() + .expect("member_channels lock poisoned") + .insert(id, (intro.display_name.clone(), Arc::new(Mutex::new(EventChannel::new(handle))))); + Ok(()) + } + + /// Remove a member's channel from local tracking. + /// + /// Pass [`Introduction::member_id`] as `member_id`. + /// Returns `true` if the member was present. + pub fn remove_member(&self, member_id: &str) -> bool { + self.member_channels + .write() + .expect("member_channels lock poisoned") + .remove(member_id) + .is_some() + } + + /// Send an event on the local member's channel. + pub async fn send(&self, event: &E) -> Result<()> { + self.my_channel.lock().await.send(event).await + } + + /// Block until the next event from `member_id` arrives, using the daemon's + /// ARQ mechanism. Returns immediately if a message is already waiting. + /// + /// Pass [`Introduction::member_id`] as `member_id`. + pub async fn receive_from(&self, member_id: &str) -> Result> { + let (display_name, channel) = self.member_channels + .read() + .expect("member_channels lock poisoned") + .get(member_id) + .ok_or_else(|| PigeonholeDbError::Other(format!("No member '{}' in group", member_id)))? + .clone(); + let mut ch = channel.lock().await; + let event = ch.receive().await?; + Ok(ReceivedGroupEvent { sender: display_name, event }) + } + + /// Block until every member has delivered one event, receiving from all + /// member channels concurrently. + /// + /// All ARQ requests are started simultaneously. 
Results are collected in + /// completion order (fastest channel first) so no channel waits on + /// another. Returns one event per member in the order they arrived. + pub async fn receive_from_all(&self) -> Result>> { + // snapshot: (display_name, channel) + let snapshot: Vec<(String, Arc>>)> = self.member_channels + .read() + .expect("member_channels lock poisoned") + .values() + .map(|(display_name, ch)| (display_name.clone(), ch.clone())) + .collect(); + + if snapshot.is_empty() { + return Ok(vec![]); + } + + let mut set = tokio::task::JoinSet::new(); + + let results_cap = snapshot.len(); + for (name, channel) in snapshot { + set.spawn(async move { + let mut ch = channel.lock().await; + ch.receive().await + .map(|event| ReceivedGroupEvent { sender: name, event }) + }); + } + + let mut results = Vec::with_capacity(results_cap); + while let Some(res) = set.join_next().await { + match res { + Ok(Ok(event)) => results.push(event), + Ok(Err(e)) => return Err(e), + Err(join_err) => return Err(PigeonholeDbError::Other( + format!("receive_from_all task panicked: {}", join_err) + )), + } + } + Ok(results) + } + + /// Like [`receive_from_all`] but returns after `timeout` with whatever + /// results have arrived, rather than blocking until every member delivers. + /// + /// Returns `Ok(events)` where `events` contains one entry per member that + /// delivered within the deadline; members that did not deliver are simply + /// absent from the result (their tasks are aborted before their + /// `ChannelHandle::receive` completes, so no cursor is advanced and no + /// message is lost on those channels). + /// + /// **No-loss guarantee**: any member whose `receive()` completed + /// *concurrently* with the timeout — meaning its read cursor has already + /// advanced — has its event placed in the internal pending buffer and will + /// be returned by the next [`receive_any`] call. 
+ /// + /// If any member's `receive()` returns an error, all already-collected + /// events are moved to the pending buffer and the error is propagated. + pub async fn receive_from_all_timeout( + &self, + timeout: Duration, + ) -> Result>> { + // snapshot: (display_name, channel) + let snapshot: Vec<(String, Arc>>)> = self.member_channels + .read() + .expect("member_channels lock poisoned") + .values() + .map(|(dn, ch)| (dn.clone(), ch.clone())) + .collect(); + + if snapshot.is_empty() { + return Ok(vec![]); + } + + let n = snapshot.len(); + // Capacity = n: completing tasks send without blocking, preserving + // results even if we time out before draining them. + let (tx, mut rx) = tokio::sync::mpsc::channel(n); + let mut handles = Vec::with_capacity(n); + + for (display_name, channel) in snapshot { + let tx = tx.clone(); + let handle = tokio::spawn(async move { + let mut ch = channel.lock().await; + let result = ch.receive().await + .map(|event| ReceivedGroupEvent { sender: display_name, event }); + let _ = tx.send(result).await; + }); + handles.push(handle); + } + drop(tx); + + let deadline = tokio::time::Instant::now() + timeout; + let mut results = Vec::with_capacity(n); + let mut first_err: Option = None; + + loop { + if results.len() == n { + break; // all members delivered + } + match tokio::time::timeout_at(deadline, rx.recv()).await { + Ok(Some(Ok(event))) => results.push(event), + Ok(Some(Err(e))) => { first_err = Some(e); break; } + Ok(None) | Err(_) => break, // channel closed or deadline + } + } + + // Abort tasks that haven't finished receive() yet. + for handle in &handles { + handle.abort(); + } + + // Drain any events that landed just before/during abort (their cursors + // are already advanced and must not be discarded). 
+ let mut late: Vec> = Vec::new(); + while let Ok(extra) = rx.try_recv() { + if let Ok(event) = extra { + late.push(event); + } + } + + if let Some(e) = first_err { + // Move all successfully received events to the pending buffer so + // they are not lost despite the error return. + let mut buf = self.receive_any_buffer.lock().await; + for event in results.into_iter().chain(late) { + buf.push_back(event); + } + Err(e) + } else { + if !late.is_empty() { + let mut buf = self.receive_any_buffer.lock().await; + for event in late { + buf.push_back(event); + } + } + Ok(results) + } + } + + /// Block until any member sends an event, racing all member channels + /// concurrently. The first to deliver wins; remaining tasks are aborted. + /// + /// **No-loss guarantee**: `ChannelHandle::receive` advances the persistent + /// read cursor before it returns. Any task that completes a receive must + /// therefore have its result delivered to the caller — discarding it would + /// permanently skip that message. To handle races where multiple channels + /// deliver simultaneously, the channel passed to spawned tasks has capacity + /// equal to the member count, so every completing task can send its result + /// without blocking. After returning the first result, any extras already + /// in the channel are drained into an internal buffer and returned by + /// subsequent `receive_any` calls before new network I/O is issued. + /// + /// Each member's channel uses the daemon's ARQ mechanism, so there is no + /// application-level sleep or timeout — the daemon retries automatically + /// until the box is available. + pub async fn receive_any(&self) -> Result> { + // Drain the buffer before doing any network I/O. 
+ { + let mut buf = self.receive_any_buffer.lock().await; + if let Some(event) = buf.pop_front() { + return Ok(event); + } + } + + // snapshot: (display_name, channel) + let snapshot: Vec<(String, Arc>>)> = self.member_channels + .read() + .expect("member_channels lock poisoned") + .values() + .map(|(display_name, ch)| (display_name.clone(), ch.clone())) + .collect(); + + if snapshot.is_empty() { + return Err(PigeonholeDbError::Other("Group has no members".to_string())); + } + + // One slot per member so that a completing task's tx.send() resolves + // without yielding. A non-yielding send completes before the tokio + // scheduler can switch to our task, so by the time rx.recv() wakes us + // up every task that finished receive() has already placed its result + // in the channel. + let n = snapshot.len(); + let (tx, mut rx) = tokio::sync::mpsc::channel(n); + let mut handles = Vec::with_capacity(n); + + for (name, channel) in snapshot { + let tx = tx.clone(); + let handle = tokio::spawn(async move { + let mut ch = channel.lock().await; + let result = ch.receive().await + .map(|event| ReceivedGroupEvent { sender: name, event }); + let _ = tx.send(result).await; + }); + handles.push(handle); + } + drop(tx); + + let first = rx.recv().await + .ok_or_else(|| PigeonholeDbError::Other("All member channels failed".to_string()))?; + + // Abort tasks that haven't completed receive() yet. Tasks that already + // completed sent their result non-blocking (capacity = n), so their + // results are already in `rx` and will be captured by try_recv below. + for handle in &handles { + handle.abort(); + } + + // Drain results that arrived alongside the winner. These come from + // tasks whose receive() completed before the abort fired; their read + // cursors have already advanced and the messages must not be discarded. 
+ { + let mut buf = self.receive_any_buffer.lock().await; + while let Ok(extra) = rx.try_recv() { + if let Ok(event) = extra { + buf.push_back(event); + } + // An Err result means that receive() failed before advancing + // the cursor, so nothing was consumed and we can drop it. + } + } + + first + } +} diff --git a/src/group/event_channel.rs b/src/group/event_channel.rs new file mode 100644 index 0000000..a51891e --- /dev/null +++ b/src/group/event_channel.rs @@ -0,0 +1,211 @@ +// SPDX-FileCopyrightText: Copyright (C) 2026 David Stainton +// SPDX-License-Identifier: AGPL-3.0-only + +//! Typed event channel wrapping a persistent ChannelHandle. +//! +//! `EventChannel` serializes events of type `E` to CBOR before writing +//! to the underlying pigeonhole box, and deserializes on read. This is the +//! primitive used by `GroupChannel` and any higher-level protocol that +//! wants typed, ordered streams over Pigeonhole. +//! +//! # Channel rotation (future work) +//! +//! For post-compromise security each member should periodically rotate to a +//! fresh channel. The rotation handshake requires the writer to receive an +//! ACK from every reader confirming they have imported the new read cap before +//! the old channel can be retired. `EventChannel` is intentionally kept +//! simple so that rotation can be layered on top without changing its core +//! encode/decode contract. + +use std::marker::PhantomData; + +use serde::{Serialize, de::DeserializeOwned}; + +use crate::persistent::{ChannelHandle, ReadCapability}; +use crate::persistent::error::{PigeonholeDbError, Result}; + +/// A typed, CBOR-encoded wrapper around a [`ChannelHandle`]. +/// +/// Every event sent through an `EventChannel` is serialized with +/// `serde_cbor` before being handed to the underlying channel, and +/// deserialized on receipt. The wire format is therefore opaque bytes from +/// the pigeonhole layer's perspective. 
+pub struct EventChannel { + handle: ChannelHandle, + _phantom: PhantomData, +} + +impl EventChannel { + /// Wrap an existing `ChannelHandle`. + pub fn new(handle: ChannelHandle) -> Self { + Self { handle, _phantom: PhantomData } + } + + /// Consume the wrapper and return the underlying handle. + pub fn into_inner(self) -> ChannelHandle { + self.handle + } + + /// Borrow the underlying handle. + pub fn inner(&self) -> &ChannelHandle { + &self.handle + } + + /// Mutably borrow the underlying handle. + pub fn inner_mut(&mut self) -> &mut ChannelHandle { + &mut self.handle + } + + pub fn name(&self) -> &str { + self.handle.name() + } + + pub fn is_owned(&self) -> bool { + self.handle.is_owned() + } + + pub fn read_cap(&self) -> &[u8] { + self.handle.read_cap() + } + + pub fn share_read_capability(&self) -> ReadCapability { + self.handle.share_read_capability() + } + + /// Serialize `event` to CBOR and write it to the next box in the channel. + pub async fn send(&mut self, event: &E) -> Result<()> { + let bytes = serde_cbor::to_vec(event) + .map_err(|e| PigeonholeDbError::Other(format!("CBOR encode error: {}", e)))?; + self.handle.send(&bytes).await + } + + /// Read the next box and deserialize it as `E`. + /// + /// Blocks (with ARQ retries) until a message arrives. + pub async fn receive(&mut self) -> Result { + let bytes = self.handle.receive().await?; + serde_cbor::from_slice(&bytes) + .map_err(|e| PigeonholeDbError::Other(format!("CBOR decode error: {}", e))) + } + + /// Like [`receive`] but returns [`ThinClientError::BoxNotFound`] immediately + /// instead of retrying when the box does not exist yet. 
+ pub async fn receive_no_retry(&mut self) -> Result { + let bytes = self.handle.receive_no_retry().await?; + serde_cbor::from_slice(&bytes) + .map_err(|e| PigeonholeDbError::Other(format!("CBOR decode error: {}", e))) + } +} + +// ============================================================================ +// Unit tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use serde::{Deserialize, Serialize}; + + + // ----------------------------------------------------------------------- + // Serialization round-trip + // ----------------------------------------------------------------------- + + #[test] + fn test_cbor_roundtrip_string() { + let original = "hello pigeonhole".to_string(); + let bytes = serde_cbor::to_vec(&original).unwrap(); + let decoded: String = serde_cbor::from_slice(&bytes).unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn test_cbor_roundtrip_enum() { + #[derive(Debug, PartialEq, Serialize, Deserialize)] + enum TestEvent { + Increment { actor: String, amount: u64 }, + Reset, + } + + let cases = vec![ + TestEvent::Increment { actor: "Alice".to_string(), amount: 42 }, + TestEvent::Reset, + ]; + + for event in cases { + let bytes = serde_cbor::to_vec(&event).unwrap(); + let decoded: TestEvent = serde_cbor::from_slice(&bytes).unwrap(); + assert_eq!(event, decoded); + } + } + + #[test] + fn test_cbor_roundtrip_gcounter_dot() { + // GCounter::Op = Dot — verify it is round-trippable. 
+ use crdts::Dot; + let op: Dot = Dot::new("Alice".to_string(), 7); + let bytes = serde_cbor::to_vec(&op).unwrap(); + let decoded: Dot = serde_cbor::from_slice(&bytes).unwrap(); + assert_eq!(op, decoded); + } + + // ----------------------------------------------------------------------- + // CRDT fold logic (no network required) + // ----------------------------------------------------------------------- + + #[test] + fn test_gcounter_fold_over_ops() { + use crdts::{CmRDT, Dot, GCounter}; + + // Simulate three members each broadcasting one increment op through + // their own stream. A reader collects all ops and folds them. + let ops = vec![ + Dot::new("Alice".to_string(), 1u64), + Dot::new("Bob".to_string(), 1u64), + Dot::new("Carol".to_string(), 1u64), + ]; + + let mut counter: GCounter = GCounter::new(); + for op in ops { + counter.apply(op); + } + + assert_eq!(counter.read().to_string(), "3"); + } + + #[test] + fn test_gcounter_fold_respects_per_actor_max() { + use crdts::{CmRDT, Dot, GCounter}; + + // If we receive two ops from the same actor, only the larger one wins. + // This mirrors what happens when a member resends an op and the + // GCounter's max-per-actor semantics deduplicate it. + let mut counter: GCounter = GCounter::new(); + counter.apply(Dot::new("Alice".to_string(), 3u64)); + counter.apply(Dot::new("Alice".to_string(), 1u64)); // lower — ignored + counter.apply(Dot::new("Bob".to_string(), 2u64)); + + // GCounter keeps the max per actor: Alice=3, Bob=2 → total=5 + assert_eq!(counter.read().to_string(), "5"); + } + + #[test] + fn test_state_fold_pattern() { + // Demonstrate `state = fold(events)` concretely without any network. + // Each event is a Dot serialized to CBOR, as it would arrive + // from an EventChannel. 
+ use crdts::{CmRDT, Dot, GCounter}; + + let events_on_wire: Vec> = vec![ + serde_cbor::to_vec(&Dot::new("Alice".to_string(), 2u64)).unwrap(), + serde_cbor::to_vec(&Dot::new("Bob".to_string(), 5u64)).unwrap(), + ]; + + let mut counter: GCounter = GCounter::new(); + for bytes in &events_on_wire { + let op: Dot = serde_cbor::from_slice(bytes).unwrap(); + counter.apply(op); + } + + assert_eq!(counter.read().to_string(), "7"); + } +} diff --git a/src/group/mod.rs b/src/group/mod.rs new file mode 100644 index 0000000..8c02b50 --- /dev/null +++ b/src/group/mod.rs @@ -0,0 +1,11 @@ +// SPDX-FileCopyrightText: Copyright (C) 2026 David Stainton +// SPDX-License-Identifier: AGPL-3.0-only + +//! Group channel: each member has their own typed BACAP stream. + +pub mod channel; +pub mod event_channel; + +pub use channel::{GroupChannel, Introduction, ReceivedGroupEvent}; +pub use event_channel::EventChannel; + diff --git a/src/helpers.rs b/src/helpers.rs new file mode 100644 index 0000000..0e34982 --- /dev/null +++ b/src/helpers.rs @@ -0,0 +1,127 @@ +// SPDX-FileCopyrightText: Copyright (C) 2026 David Stainton +// SPDX-License-Identifier: AGPL-3.0-only + +//! Helper functions for working with PKI documents and service discovery. + +use std::collections::BTreeMap; +use serde_cbor::{from_slice, Value}; +use serde_json::json; + +use crate::ServiceDescriptor; + +/// Find a specific mixnet service if it exists. 
+pub fn find_services(capability: &str, doc: &BTreeMap) -> Vec { + let mut services = Vec::new(); + + let Some(Value::Array(nodes)) = doc.get(&Value::Text("ServiceNodes".to_string())) else { + println!("❌ No 'ServiceNodes' found in PKI document."); + return services; + }; + + for node in nodes { + let Value::Bytes(node_bytes) = node else { continue }; + let Ok(mynode) = from_slice::>(node_bytes) else { continue }; + + // 🔍 Print available capabilities in each node + if let Some(Value::Map(details)) = mynode.get(&Value::Text("Kaetzchen".to_string())) { + println!("🔍 Available Capabilities: {:?}", details.keys()); + } + + let Some(Value::Map(details)) = mynode.get(&Value::Text("Kaetzchen".to_string())) else { continue }; + let Some(Value::Map(service)) = details.get(&Value::Text(capability.to_string())) else { continue }; + let Some(Value::Text(endpoint)) = service.get(&Value::Text("endpoint".to_string())) else { continue }; + + println!("returning a service descriptor!"); + + services.push(ServiceDescriptor { + recipient_queue_id: endpoint.as_bytes().to_vec(), + mix_descriptor: mynode, + }); + } + + services +} + +fn convert_to_pretty_json(value: &Value) -> serde_json::Value { + match value { + Value::Text(s) => serde_json::Value::String(s.clone()), + Value::Integer(i) => json!(*i), + Value::Bytes(b) => json!(hex::encode(b)), // Encode byte arrays as hex strings + Value::Array(arr) => serde_json::Value::Array(arr.iter().map(convert_to_pretty_json).collect()), + Value::Map(map) => { + let converted_map: serde_json::Map = map + .iter() + .map(|(key, value)| { + let key_str = match key { + Value::Text(s) => s.clone(), + _ => format!("{:?}", key), + }; + (key_str, convert_to_pretty_json(value)) + }) + .collect(); + serde_json::Value::Object(converted_map) + } + _ => serde_json::Value::Null, // Handle unexpected CBOR types + } +} + +fn decode_cbor_nodes(nodes: &[Value]) -> Vec { + nodes + .iter() + .filter_map(|node| match node { + Value::Bytes(blob) => 
serde_cbor::from_slice::>(blob) + .ok() + .map(Value::Map), + _ => Some(node.clone()), // Preserve non-CBOR values as they are + }) + .collect() +} + +/// Pretty prints a PKI document which you can gather from the client +/// with it's `pki_document` method, documented above. +pub fn pretty_print_pki_doc(doc: &BTreeMap) { + let mut new_doc = BTreeMap::new(); + + // Decode "GatewayNodes" + if let Some(Value::Array(gateway_nodes)) = doc.get(&Value::Text("GatewayNodes".to_string())) { + new_doc.insert(Value::Text("GatewayNodes".to_string()), Value::Array(decode_cbor_nodes(gateway_nodes))); + } + + // Decode "ServiceNodes" + if let Some(Value::Array(service_nodes)) = doc.get(&Value::Text("ServiceNodes".to_string())) { + new_doc.insert(Value::Text("ServiceNodes".to_string()), Value::Array(decode_cbor_nodes(service_nodes))); + } + + // Decode "Topology" (flatten nested arrays of CBOR blobs) + if let Some(Value::Array(topology_layers)) = doc.get(&Value::Text("Topology".to_string())) { + let decoded_topology: Vec = topology_layers + .iter() + .flat_map(|layer| match layer { + Value::Array(layer_nodes) => decode_cbor_nodes(layer_nodes), + _ => vec![], + }) + .collect(); + + new_doc.insert(Value::Text("Topology".to_string()), Value::Array(decoded_topology)); + } + + // Copy and decode all other fields that might contain CBOR blobs + for (key, value) in doc.iter() { + if !matches!(key, Value::Text(s) if ["GatewayNodes", "ServiceNodes", "Topology"].contains(&s.as_str())) { + let key_str = key.clone(); + let decoded_value = match value { + Value::Bytes(blob) => serde_cbor::from_slice::>(blob) + .ok() + .map(Value::Map) + .unwrap_or(value.clone()), // Fallback to original if not CBOR + _ => value.clone(), + }; + + new_doc.insert(key_str, decoded_value); + } + } + + // Convert to pretty JSON format right before printing + let pretty_json = convert_to_pretty_json(&Value::Map(new_doc)); + println!("{}", serde_json::to_string_pretty(&pretty_json).unwrap()); +} diff --git 
a/src/lib.rs b/src/lib.rs index e35c8c2..937daa4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: Copyright (C) 2025 David Stainton +// SPDX-FileCopyrightText: Copyright (C) 2025, 2026 David Stainton // SPDX-License-Identifier: AGPL-3.0-only //! A thin client for sending and receiving messages via a Katzenpost @@ -16,140 +16,6 @@ //! with mixnet services concurrently. //! //! This example can be found here: https://github.com/katzenpost/thin_client/blob/main/examples/echo_ping.rs -//! Thin client example usage:: -//! -//! -//! ```rust,no_run -//! use std::env; -//! use std::collections::BTreeMap; -//! use std::sync::{Arc, Mutex}; -//! use std::process; -//! -//! use tokio::time::{timeout, Duration}; -//! use tokio::runtime::Runtime; -//! -//! use serde_cbor::Value; -//! -//! use katzenpost_thin_client::{ThinClient, Config, pretty_print_pki_doc}; -//! -//! struct ClientState { -//! reply_message: Arc>>>, -//! pki_received: Arc>, -//! } -//! -//! impl ClientState { -//! fn new() -> Self { -//! Self { -//! reply_message: Arc::new(Mutex::new(None)), -//! pki_received: Arc::new(Mutex::new(false)), -//! } -//! } -//! -//! fn save_reply(&self, reply: &BTreeMap) { -//! let mut stored_reply = self.reply_message.lock().unwrap(); -//! *stored_reply = Some(reply.clone()); -//! } -//! -//! fn set_pki_received(&self) { -//! let mut pki_flag = self.pki_received.lock().unwrap(); -//! *pki_flag = true; -//! } -//! -//! fn is_pki_received(&self) -> bool { -//! *self.pki_received.lock().unwrap() -//! } -//! -//! fn await_message_reply(&self) -> Option> { -//! let stored_reply = self.reply_message.lock().unwrap(); -//! stored_reply.clone() -//! } -//! } -//! -//! fn main() { -//! let args: Vec = env::args().collect(); -//! if args.len() != 2 { -//! eprintln!("Usage: {} ", args[0]); -//! process::exit(1); -//! } -//! let config_path = &args[1]; -//! -//! let rt = Runtime::new().unwrap(); -//! rt.block_on(run_client(config_path)).unwrap(); -//! } -//! 
-//! async fn run_client(config_path: &str) -> Result<(), Box> { -//! let state = Arc::new(ClientState::new()); -//! let state_for_reply = Arc::clone(&state); -//! let state_for_pki = Arc::clone(&state); -//! -//! let mut cfg = Config::new(config_path)?; -//! cfg.on_new_pki_document = Some(Arc::new(move |_pki_doc| { -//! println!("✅ PKI document received."); -//! state_for_pki.set_pki_received(); -//! })); -//! cfg.on_message_reply = Some(Arc::new(move |reply| { -//! println!("📩 Received a reply!"); -//! state_for_reply.save_reply(reply); -//! })); -//! -//! println!("🚀 Initializing ThinClient..."); -//! let client = ThinClient::new(cfg).await?; -//! -//! println!("⏳ Waiting for PKI document..."); -//! let result = timeout(Duration::from_secs(5), async { -//! loop { -//! if state.is_pki_received() { -//! break; -//! } -//! tokio::task::yield_now().await; -//! } -//! }) -//! .await; -//! -//! if result.is_err() { -//! return Err("❌ PKI document not received in time.".into()); -//! } -//! -//! println!("✅ Pretty printing PKI document:"); -//! let doc = client.pki_document().await; -//! pretty_print_pki_doc(&doc); -//! println!("AFTER Pretty printing PKI document"); -//! -//! let service_desc = client.get_service("echo").await?; -//! println!("got service descriptor for echo service"); -//! -//! let surb_id = ThinClient::new_surb_id(); -//! let payload = b"hello".to_vec(); -//! let (dest_node, dest_queue) = service_desc.to_destination(); -//! -//! println!("before calling send_message"); -//! client.send_message(surb_id, &payload, dest_node, dest_queue).await?; -//! println!("after calling send_message"); -//! -//! println!("⏳ Waiting for message reply..."); -//! let state_for_reply_wait = Arc::clone(&state); -//! -//! let result = timeout(Duration::from_secs(5), async move { -//! loop { -//! if let Some(reply) = state_for_reply_wait.await_message_reply() { -//! if let Some(Value::Bytes(payload2)) = reply.get(&Value::Text("payload".to_string())) { -//! 
let payload2 = &payload2[..payload.len()]; -//! assert_eq!(payload, payload2, "Reply does not match payload!"); -//! println!("✅ Received valid reply, stopping client."); -//! return Ok::<(), Box>(()); -//! } -//! } -//! tokio::task::yield_now().await; -//! } -//! }).await; -//! -//! result.map_err(|e| Box::new(e))??; -//! client.stop().await; -//! println!("✅ Client stopped successfully."); -//! Ok(()) -//! } -//! ``` -//! //! //! # See Also //! @@ -158,7 +24,45 @@ //! - [katzepost client integration guide](https://katzenpost.network/docs/client_integration/) //! - [katzenpost thin client protocol specification](https://katzenpost.network/docs/specs/connector.html) +// ======================================================================== +// Module declarations +// ======================================================================== + pub mod error; +pub mod core; +pub mod pigeonhole; +pub mod persistent; +pub mod helpers; +pub mod group; +pub mod chat; +pub mod doodle; + +// ======================================================================== +// Re-exports for public API +// ======================================================================== + +pub use crate::core::{ThinClient, EventSinkReceiver}; +pub use crate::error::ThinClientError; +pub use crate::helpers::{find_services, pretty_print_pki_doc}; +pub use crate::pigeonhole::TombstoneRangeResult; + +// ======================================================================== +// Imports for types defined in this file +// ======================================================================== + +use std::collections::BTreeMap; +use std::sync::Arc; +use std::fs; + +use serde::Deserialize; +use serde_cbor::Value; + +use blake2::{Blake2b, Digest}; +use generic_array::typenum::U32; + +// ======================================================================== +// Error codes +// ======================================================================== // Thin client error codes provide standardized 
error reporting across the protocol. // These codes are used in response messages to indicate the success or failure @@ -221,6 +125,51 @@ pub const THIN_CLIENT_ERROR_COURIER_CACHE_CORRUPTION: u8 = 12; /// propagated to replicas. pub const THIN_CLIENT_PROPAGATION_ERROR: u8 = 13; +/// ThinClientErrorInvalidWriteCapability indicates that the provided write +/// capability is invalid. +pub const THIN_CLIENT_ERROR_INVALID_WRITE_CAPABILITY: u8 = 14; + +/// ThinClientErrorInvalidReadCapability indicates that the provided read +/// capability is invalid. +pub const THIN_CLIENT_ERROR_INVALID_READ_CAPABILITY: u8 = 15; + +/// ThinClientErrorInvalidResumeWriteChannelRequest indicates that the provided +/// ResumeWriteChannel request is invalid. +pub const THIN_CLIENT_ERROR_INVALID_RESUME_WRITE_CHANNEL_REQUEST: u8 = 16; + +/// ThinClientErrorInvalidResumeReadChannelRequest indicates that the provided +/// ResumeReadChannel request is invalid. +pub const THIN_CLIENT_ERROR_INVALID_RESUME_READ_CHANNEL_REQUEST: u8 = 17; + +/// ThinClientImpossibleHashError indicates that the provided hash is impossible +/// to compute, such as when the hash of a write capability is provided but +/// the write capability itself is not provided. +pub const THIN_CLIENT_IMPOSSIBLE_HASH_ERROR: u8 = 18; + +/// ThinClientImpossibleNewWriteCapError indicates that the daemon was unable +/// to create a new write capability. +pub const THIN_CLIENT_IMPOSSIBLE_NEW_WRITE_CAP_ERROR: u8 = 19; + +/// ThinClientImpossibleNewStatefulWriterError indicates that the daemon was unable +/// to create a new stateful writer. +pub const THIN_CLIENT_IMPOSSIBLE_NEW_STATEFUL_WRITER_ERROR: u8 = 20; + +/// ThinClientCapabilityAlreadyInUse indicates that the provided capability +/// is already in use. +pub const THIN_CLIENT_CAPABILITY_ALREADY_IN_USE: u8 = 21; + +/// ThinClientErrorMKEMDecryptionFailed indicates that MKEM decryption failed. +/// This occurs when the MKEM envelope cannot be decrypted with any of the replica keys. 
+pub const THIN_CLIENT_ERROR_MKEM_DECRYPTION_FAILED: u8 = 22; + +/// ThinClientErrorBACAPDecryptionFailed indicates that BACAP decryption failed. +/// This occurs when the BACAP payload cannot be decrypted or signature verification fails. +pub const THIN_CLIENT_ERROR_BACAP_DECRYPTION_FAILED: u8 = 23; + +/// ThinClientErrorStartResendingCancelled indicates that a StartResendingEncryptedMessage +/// or StartResendingCopyCommand operation was cancelled before completion. +pub const THIN_CLIENT_ERROR_START_RESENDING_CANCELLED: u8 = 24; + /// Converts a thin client error code to a human-readable string. /// This function provides consistent error message formatting across the thin client /// protocol and is used for logging and error reporting. @@ -240,70 +189,24 @@ pub fn thin_client_error_to_string(error_code: u8) -> &'static str { THIN_CLIENT_ERROR_DUPLICATE_CAPABILITY => "Duplicate capability", THIN_CLIENT_ERROR_COURIER_CACHE_CORRUPTION => "Courier cache corruption", THIN_CLIENT_PROPAGATION_ERROR => "Propagation error", + THIN_CLIENT_ERROR_INVALID_WRITE_CAPABILITY => "Invalid write capability", + THIN_CLIENT_ERROR_INVALID_READ_CAPABILITY => "Invalid read capability", + THIN_CLIENT_ERROR_INVALID_RESUME_WRITE_CHANNEL_REQUEST => "Invalid resume write channel request", + THIN_CLIENT_ERROR_INVALID_RESUME_READ_CHANNEL_REQUEST => "Invalid resume read channel request", + THIN_CLIENT_IMPOSSIBLE_HASH_ERROR => "Impossible hash error", + THIN_CLIENT_IMPOSSIBLE_NEW_WRITE_CAP_ERROR => "Failed to create new write capability", + THIN_CLIENT_IMPOSSIBLE_NEW_STATEFUL_WRITER_ERROR => "Failed to create new stateful writer", + THIN_CLIENT_CAPABILITY_ALREADY_IN_USE => "Capability already in use", + THIN_CLIENT_ERROR_MKEM_DECRYPTION_FAILED => "MKEM decryption failed", + THIN_CLIENT_ERROR_BACAP_DECRYPTION_FAILED => "BACAP decryption failed", + THIN_CLIENT_ERROR_START_RESENDING_CANCELLED => "Start resending cancelled", _ => "Unknown thin client error code", } } -use 
std::collections::{BTreeMap, HashMap}; -use std::sync::{Arc, atomic::{AtomicBool, Ordering}}; -use std::fs; - -use serde::Deserialize; -use serde_json::json; -use serde_cbor::{from_slice, Value}; - -use tokio::sync::{Mutex, RwLock, mpsc}; -use tokio::task::JoinHandle; -use tokio::net::{TcpStream, UnixStream}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::tcp::{OwnedReadHalf as TcpReadHalf, OwnedWriteHalf as TcpWriteHalf}; -use tokio::net::unix::{OwnedReadHalf as UnixReadHalf, OwnedWriteHalf as UnixWriteHalf}; - -use blake2::{Blake2b, Digest}; -use generic_array::typenum::U32; -use rand::RngCore; -use log::{debug, error}; - -use crate::error::ThinClientError; - -/// Reply from WriteChannel operation, matching Go WriteChannelReply -#[derive(Debug, Clone)] -pub struct WriteChannelReply { - pub send_message_payload: Vec, - pub current_message_index: Vec, - pub next_message_index: Vec, - pub envelope_descriptor: Vec, - pub envelope_hash: Vec, -} - -/// Reply from ReadChannel operation, matching Go ReadChannelReply -#[derive(Debug, Clone)] -pub struct ReadChannelReply { - pub send_message_payload: Vec, - pub current_message_index: Vec, - pub next_message_index: Vec, - pub reply_index: Option, - pub envelope_descriptor: Vec, - pub envelope_hash: Vec, -} - -/// The size in bytes of a SURB (Single-Use Reply Block) identifier. -/// -/// SURB IDs are used to correlate replies with the original message sender. -/// Each SURB ID must be unique and is typically randomly generated. -const SURB_ID_SIZE: usize = 16; - -/// The size in bytes of a message identifier. -/// -/// Message IDs are used to track outbound messages and correlate them with replies. -/// Like SURB IDs, these are expected to be randomly generated and unique. -const MESSAGE_ID_SIZE: usize = 16; - -/// The size in bytes of a query identifier. -/// -/// Query IDs are used to correlate channel operation requests with their responses. -/// Each query should have a unique ID. 
-const QUERY_ID_SIZE: usize = 16; +// ======================================================================== +// Public types +// ======================================================================== /// ServiceDescriptor is used when we are searching the PKI /// document for a specific service. @@ -379,11 +282,90 @@ pub struct Geometry { pub kem_name: String, } +/// PigeonholeGeometry describes the geometry of a Pigeonhole envelope. +/// +/// This provides mathematically precise geometry calculations using trunnel's +/// fixed binary format. +/// +/// It supports 3 distinct use cases: +/// 1. Given MaxPlaintextPayloadLength → compute all envelope sizes +/// 2. Given precomputed Pigeonhole Geometry → derive accommodating Sphinx Geometry +/// 3. Given Sphinx Geometry constraint → derive optimal Pigeonhole Geometry +#[derive(Debug, Clone, Deserialize)] +pub struct PigeonholeGeometry { + /// The maximum usable plaintext payload size within a Box. + #[serde(rename = "MaxPlaintextPayloadLength")] + pub max_plaintext_payload_length: usize, + + /// The size of a CourierQuery containing a ReplicaRead. + #[serde(rename = "CourierQueryReadLength")] + pub courier_query_read_length: usize, + + /// The size of a CourierQuery containing a ReplicaWrite. + #[serde(rename = "CourierQueryWriteLength")] + pub courier_query_write_length: usize, + + /// The size of a CourierQueryReply containing a ReplicaReadReply. + #[serde(rename = "CourierQueryReplyReadLength")] + pub courier_query_reply_read_length: usize, + + /// The size of a CourierQueryReply containing a ReplicaWriteReply. + #[serde(rename = "CourierQueryReplyWriteLength")] + pub courier_query_reply_write_length: usize, + + /// The NIKE scheme name used in MKEM for encrypting to multiple storage replicas. + #[serde(rename = "NIKEName")] + pub nike_name: String, + + /// The signature scheme used for BACAP (always "Ed25519"). 
+ #[serde(rename = "SignatureSchemeName")] + pub signature_scheme_name: String, +} + +impl PigeonholeGeometry { + /// Creates a new PigeonholeGeometry with the given parameters. + /// + /// Note: In a real application, the courier query lengths would be computed + /// from the max_plaintext_payload_length using the geometry calculations. + /// This constructor is primarily for testing where those values may be + /// provided directly or defaulted to 0. + pub fn new(max_plaintext_payload_length: usize, nike_name: &str) -> Self { + Self { + max_plaintext_payload_length, + courier_query_read_length: 0, + courier_query_write_length: 0, + courier_query_reply_read_length: 0, + courier_query_reply_write_length: 0, + nike_name: nike_name.to_string(), + signature_scheme_name: "Ed25519".to_string(), + } + } + + /// Validates that the geometry has valid parameters. + pub fn validate(&self) -> Result<(), &'static str> { + if self.max_plaintext_payload_length == 0 { + return Err("MaxPlaintextPayloadLength must be positive"); + } + if self.nike_name.is_empty() { + return Err("NIKEName must be set"); + } + if self.signature_scheme_name != "Ed25519" { + return Err("SignatureSchemeName must be Ed25519"); + } + Ok(()) + } +} + + + #[derive(Debug, Deserialize)] pub struct ConfigFile { #[serde(rename = "SphinxGeometry")] pub sphinx_geometry: Geometry, + #[serde(rename = "PigeonholeGeometry")] + pub pigeonhole_geometry: PigeonholeGeometry, + #[serde(rename = "Network")] pub network: String, @@ -407,6 +389,7 @@ pub struct Config { pub network: String, pub address: String, pub sphinx_geometry: Geometry, + pub pigeonhole_geometry: PigeonholeGeometry, pub on_connection_status: Option) + Send + Sync>>, pub on_new_pki_document: Option) + Send + Sync>>, @@ -423,6 +406,7 @@ impl Config { network: parsed.network, address: parsed.address, sphinx_geometry: parsed.sphinx_geometry, + pigeonhole_geometry: parsed.pigeonhole_geometry, on_connection_status: None, on_new_pki_document: None, 
on_message_sent: None, @@ -431,1188 +415,7 @@ impl Config { } } -/// This represent the read half of our network socket. -pub enum ReadHalf { - Tcp(TcpReadHalf), - Unix(UnixReadHalf), -} - -/// This represent the write half of our network socket. -pub enum WriteHalf { - Tcp(TcpWriteHalf), - Unix(UnixWriteHalf), -} - -/// Wrapper for event sink receiver that automatically removes the drain when dropped -pub struct EventSinkReceiver { - receiver: mpsc::UnboundedReceiver>, - sender: mpsc::UnboundedSender>, - drain_remove: mpsc::UnboundedSender>>, -} - -impl EventSinkReceiver { - /// Receive the next event from the sink - pub async fn recv(&mut self) -> Option> { - self.receiver.recv().await - } -} - -impl Drop for EventSinkReceiver { - fn drop(&mut self) { - // Remove the drain when the receiver is dropped - if let Err(_) = self.drain_remove.send(self.sender.clone()) { - debug!("Failed to remove drain channel - event sink worker may be stopped"); - } - } -} - -/// This is our ThinClient type which encapsulates our thin client -/// connection management and message processing. -pub struct ThinClient { - read_half: Mutex, - write_half: Mutex, - config: Config, - pki_doc: Arc>>>, - worker_task: Mutex>>, - event_sink_task: Mutex>>, - shutdown: Arc, - is_connected: Arc, - // Event system like Go implementation - event_sink: mpsc::UnboundedSender>, - drain_add: mpsc::UnboundedSender>>, - drain_remove: mpsc::UnboundedSender>>, -} - -impl ThinClient { - - /// Create a new thin cilent and connect it to the client daemon. 
- pub async fn new(config: Config) -> Result, Box> { - // Create event system channels like Go implementation - let (event_sink_tx, event_sink_rx) = mpsc::unbounded_channel(); - let (drain_add_tx, drain_add_rx) = mpsc::unbounded_channel(); - let (drain_remove_tx, drain_remove_rx) = mpsc::unbounded_channel(); - - let client = match config.network.to_uppercase().as_str() { - "TCP" => { - let socket = TcpStream::connect(&config.address).await?; - let (read_half, write_half) = socket.into_split(); - Arc::new(Self { - read_half: Mutex::new(ReadHalf::Tcp(read_half)), - write_half: Mutex::new(WriteHalf::Tcp(write_half)), - config, - pki_doc: Arc::new(RwLock::new(None)), - worker_task: Mutex::new(None), - event_sink_task: Mutex::new(None), - shutdown: Arc::new(AtomicBool::new(false)), - is_connected: Arc::new(AtomicBool::new(false)), - event_sink: event_sink_tx.clone(), - drain_add: drain_add_tx.clone(), - drain_remove: drain_remove_tx.clone(), - }) - } - "UNIX" => { - let path = if config.address.starts_with('@') { - let mut p = String::from("\0"); - p.push_str(&config.address[1..]); - p - } else { - config.address.clone() - }; - let socket = UnixStream::connect(path).await?; - let (read_half, write_half) = socket.into_split(); - Arc::new(Self { - read_half: Mutex::new(ReadHalf::Unix(read_half)), - write_half: Mutex::new(WriteHalf::Unix(write_half)), - config, - pki_doc: Arc::new(RwLock::new(None)), - worker_task: Mutex::new(None), - event_sink_task: Mutex::new(None), - shutdown: Arc::new(AtomicBool::new(false)), - is_connected: Arc::new(AtomicBool::new(false)), - event_sink: event_sink_tx, - drain_add: drain_add_tx, - drain_remove: drain_remove_tx, - }) - } - _ => { - return Err(format!("Unknown network type: {}", config.network).into()); - } - }; - - // Start worker loop - let client_clone = Arc::clone(&client); - let task = tokio::spawn(async move { client_clone.worker_loop().await }); - *client.worker_task.lock().await = Some(task); - - // Start event sink worker - 
let client_clone2 = Arc::clone(&client); - let event_sink_task = tokio::spawn(async move { - client_clone2.event_sink_worker(event_sink_rx, drain_add_rx, drain_remove_rx).await - }); - *client.event_sink_task.lock().await = Some(event_sink_task); - - debug!("✅ ThinClient initialized with worker loop and event sink started."); - Ok(client) - } - - /// Stop our async worker task and disconnect the thin client. - pub async fn stop(&self) { - debug!("Stopping ThinClient..."); - - self.shutdown.store(true, Ordering::Relaxed); - let mut write_half = self.write_half.lock().await; - let _ = match &mut *write_half { - WriteHalf::Tcp(wh) => wh.shutdown().await, - WriteHalf::Unix(wh) => wh.shutdown().await, - }; - - if let Some(worker) = self.worker_task.lock().await.take() { - worker.abort(); - } - debug!("✅ ThinClient stopped."); - } - - /// Returns true if the daemon is connected to the mixnet. - pub fn is_connected(&self) -> bool { - self.is_connected.load(Ordering::Relaxed) - } - - /// Creates a new event channel that receives all events from the thin client - /// This mirrors the Go implementation's EventSink method - pub fn event_sink(&self) -> EventSinkReceiver { - let (tx, rx) = mpsc::unbounded_channel(); - if let Err(_) = self.drain_add.send(tx.clone()) { - debug!("Failed to add drain channel - event sink worker may be stopped"); - } - EventSinkReceiver { - receiver: rx, - sender: tx, - drain_remove: self.drain_remove.clone(), - } - } - - /// Generates a new message ID. - pub fn new_message_id() -> Vec { - let mut id = vec![0; MESSAGE_ID_SIZE]; - rand::thread_rng().fill_bytes(&mut id); - id - } - - /// Generates a new SURB ID. - pub fn new_surb_id() -> Vec { - let mut id = vec![0; SURB_ID_SIZE]; - rand::thread_rng().fill_bytes(&mut id); - id - } - - /// Generates a new query ID. 
- pub fn new_query_id() -> Vec { - let mut id = vec![0; QUERY_ID_SIZE]; - rand::thread_rng().fill_bytes(&mut id); - id - } - - async fn update_pki_document(&self, new_pki_doc: BTreeMap) { - let mut pki_doc_lock = self.pki_doc.write().await; - *pki_doc_lock = Some(new_pki_doc); - debug!("PKI document updated."); - } - - /// Returns our latest retrieved PKI document. - pub async fn pki_document(&self) -> BTreeMap { - self.pki_doc.read().await.clone().expect("❌ PKI document is missing!") - } - - /// Given a service name this returns a ServiceDescriptor if the service exists - /// in the current PKI document. - pub async fn get_service(&self, service_name: &str) -> Result { - let doc = self.pki_doc.read().await.clone().ok_or(ThinClientError::MissingPkiDocument)?; - let services = find_services(service_name, &doc); - services.into_iter().next().ok_or(ThinClientError::ServiceNotFound) - } - - /// Returns a courier service destination for the current epoch. - /// This method finds and randomly selects a courier service from the current - /// PKI document. The returned destination information is used with SendChannelQuery - /// and SendChannelQueryAwaitReply to transmit prepared channel operations. - /// Returns (dest_node, dest_queue) on success. 
- pub async fn get_courier_destination(&self) -> Result<(Vec, Vec), ThinClientError> { - let courier_service = self.get_service("courier").await?; - let (dest_node, dest_queue) = courier_service.to_destination(); - Ok((dest_node, dest_queue)) - } - - async fn recv(&self) -> Result, ThinClientError> { - let mut length_prefix = [0; 4]; - { - let mut read_half = self.read_half.lock().await; - match &mut *read_half { - ReadHalf::Tcp(rh) => rh.read_exact(&mut length_prefix).await.map_err(ThinClientError::IoError)?, - ReadHalf::Unix(rh) => rh.read_exact(&mut length_prefix).await.map_err(ThinClientError::IoError)?, - }; - } - let message_length = u32::from_be_bytes(length_prefix) as usize; - let mut buffer = vec![0; message_length]; - { - let mut read_half = self.read_half.lock().await; - match &mut *read_half { - ReadHalf::Tcp(rh) => rh.read_exact(&mut buffer).await.map_err(ThinClientError::IoError)?, - ReadHalf::Unix(rh) => rh.read_exact(&mut buffer).await.map_err(ThinClientError::IoError)?, - }; - } - let response: BTreeMap = match from_slice(&buffer) { - Ok(parsed) => { - parsed - } - Err(err) => { - error!("❌ Failed to parse CBOR: {:?}", err); - return Err(ThinClientError::CborError(err)); - } - }; - Ok(response) - } - - fn parse_status(&self, event: &BTreeMap) { - let is_connected = event.get(&Value::Text("is_connected".to_string())) - .and_then(|v| match v { - Value::Bool(b) => Some(*b), - _ => None, - }) - .unwrap_or(false); - - // Update connection state - self.is_connected.store(is_connected, Ordering::Relaxed); - - if is_connected { - debug!("✅ Daemon is connected to mixnet - full functionality available."); - } else { - debug!("📴 Daemon is not connected to mixnet - entering offline mode (channel operations will work)."); - } - } - - async fn parse_pki_doc(&self, event: &BTreeMap) { - if let Some(Value::Bytes(payload)) = event.get(&Value::Text("payload".to_string())) { - match serde_cbor::from_slice::>(payload) { - Ok(raw_pki_doc) => { - 
self.update_pki_document(raw_pki_doc).await; - debug!("✅ PKI document successfully parsed."); - } - Err(err) => { - error!("❌ Failed to parse PKI document: {:?}", err); - } - } - } else { - error!("❌ Missing 'payload' field in PKI document event."); - } - } - - async fn handle_response(&self, response: BTreeMap) { - assert!(!response.is_empty(), "❌ Received an empty response!"); - - if let Some(Value::Map(event)) = response.get(&Value::Text("connection_status_event".to_string())) { - debug!("🔄 Connection status event received."); - self.parse_status(event); - if let Some(cb) = self.config.on_connection_status.as_ref() { - cb(event); - } - return; - } - - if let Some(Value::Map(event)) = response.get(&Value::Text("new_pki_document_event".to_string())) { - debug!("📜 New PKI document event received."); - self.parse_pki_doc(event).await; - if let Some(cb) = self.config.on_new_pki_document.as_ref() { - cb(event); - } - return; - } - - if let Some(Value::Map(event)) = response.get(&Value::Text("message_sent_event".to_string())) { - debug!("📨 Message sent event received."); - if let Some(cb) = self.config.on_message_sent.as_ref() { - cb(event); - } - return; - } - - if let Some(Value::Map(event)) = response.get(&Value::Text("message_reply_event".to_string())) { - debug!("📩 Message reply event received."); - if let Some(cb) = self.config.on_message_reply.as_ref() { - cb(event); - } - return; - } - - error!("❌ Unknown event type received: {:?}", response); - } - async fn worker_loop(&self) { - debug!("Worker loop started"); - while !self.shutdown.load(Ordering::Relaxed) { - match self.recv().await { - Ok(response) => { - // Send all responses to event sink for distribution - if let Err(_) = self.event_sink.send(response.clone()) { - debug!("Event sink channel closed, stopping worker loop"); - break; - } - self.handle_response(response).await; - }, - Err(_) if self.shutdown.load(Ordering::Relaxed) => break, - Err(err) => error!("Error in recv: {}", err), - } - } - 
debug!("Worker loop exited."); - } - - /// Event sink worker that distributes events to multiple drain channels - /// This mirrors the Go implementation's eventSinkWorker - async fn event_sink_worker( - &self, - mut event_sink_rx: mpsc::UnboundedReceiver>, - mut drain_add_rx: mpsc::UnboundedReceiver>>, - mut drain_remove_rx: mpsc::UnboundedReceiver>>, - ) { - debug!("Event sink worker started"); - let mut drains: HashMap>> = HashMap::new(); - let mut next_id = 0usize; - - loop { - tokio::select! { - // Handle shutdown - _ = async { while !self.shutdown.load(Ordering::Relaxed) { tokio::time::sleep(std::time::Duration::from_millis(100)).await; } } => { - debug!("Event sink worker shutting down"); - break; - } - - // Add new drain channel - Some(drain) = drain_add_rx.recv() => { - drains.insert(next_id, drain); - next_id += 1; - debug!("Added new drain channel, total drains: {}", drains.len()); - } - - // Remove drain channel when EventSinkReceiver is dropped - Some(drain_to_remove) = drain_remove_rx.recv() => { - drains.retain(|_, drain| !std::ptr::addr_eq(drain, &drain_to_remove)); - debug!("Removed drain channel, total drains: {}", drains.len()); - } - - // Distribute events to all drain channels - Some(event) = event_sink_rx.recv() => { - let mut bad_drains = Vec::new(); - - for (id, drain) in &drains { - if let Err(_) = drain.send(event.clone()) { - // Channel is closed, mark for removal - bad_drains.push(*id); - } - } - - // Remove closed channels - for id in bad_drains { - drains.remove(&id); - } - } - } - } - debug!("Event sink worker exited."); - } - - async fn send_cbor_request(&self, request: BTreeMap) -> Result<(), ThinClientError> { - let encoded_request = serde_cbor::to_vec(&serde_cbor::Value::Map(request))?; - let length_prefix = (encoded_request.len() as u32).to_be_bytes(); - - let mut write_half = self.write_half.lock().await; - - match &mut *write_half { - WriteHalf::Tcp(wh) => { - wh.write_all(&length_prefix).await?; - 
wh.write_all(&encoded_request).await?; - } - WriteHalf::Unix(wh) => { - wh.write_all(&length_prefix).await?; - wh.write_all(&encoded_request).await?; - } - } - - debug!("✅ Request sent successfully."); - Ok(()) - } - - /// Sends a message encapsulated in a Sphinx packet without any SURB. - /// No reply will be possible. This method requires mixnet connectivity. - pub async fn send_message_without_reply( - &self, - payload: &[u8], - dest_node: Vec, - dest_queue: Vec - ) -> Result<(), ThinClientError> { - // Check if we're in offline mode - if !self.is_connected() { - return Err(ThinClientError::OfflineMode("cannot send message in offline mode - daemon not connected to mixnet".to_string())); - } - // Create the SendMessage structure - let mut send_message = BTreeMap::new(); - send_message.insert(Value::Text("id".to_string()), Value::Null); // No ID for fire-and-forget messages - send_message.insert(Value::Text("with_surb".to_string()), Value::Bool(false)); - send_message.insert(Value::Text("surbid".to_string()), Value::Null); // No SURB ID for fire-and-forget messages - send_message.insert(Value::Text("destination_id_hash".to_string()), Value::Bytes(dest_node)); - send_message.insert(Value::Text("recipient_queue_id".to_string()), Value::Bytes(dest_queue)); - send_message.insert(Value::Text("payload".to_string()), Value::Bytes(payload.to_vec())); - - // Wrap in the new Request structure - let mut request = BTreeMap::new(); - request.insert(Value::Text("send_message".to_string()), Value::Map(send_message)); - - self.send_cbor_request(request).await - } - - /// This method takes a message payload, a destination node, - /// destination queue ID and a SURB ID and sends a message along - /// with a SURB so that you can later receive the reply along with - /// the SURBID you choose. This method of sending messages should - /// be considered to be asynchronous because it does NOT actually - /// wait until the client daemon sends the message. 
Nor does it - /// wait for a reply. The only blocking aspect to it's behavior is - /// merely blocking until the client daemon receives our request - /// to send a message. This method requires mixnet connectivity. - pub async fn send_message( - &self, - surb_id: Vec, - payload: &[u8], - dest_node: Vec, - dest_queue: Vec - ) -> Result<(), ThinClientError> { - // Check if we're in offline mode - if !self.is_connected() { - return Err(ThinClientError::OfflineMode("cannot send message in offline mode - daemon not connected to mixnet".to_string())); - } - // Create the SendMessage structure - let mut send_message = BTreeMap::new(); - send_message.insert(Value::Text("id".to_string()), Value::Null); // No ID for regular messages - send_message.insert(Value::Text("with_surb".to_string()), Value::Bool(true)); - send_message.insert(Value::Text("surbid".to_string()), Value::Bytes(surb_id)); - send_message.insert(Value::Text("destination_id_hash".to_string()), Value::Bytes(dest_node)); - send_message.insert(Value::Text("recipient_queue_id".to_string()), Value::Bytes(dest_queue)); - send_message.insert(Value::Text("payload".to_string()), Value::Bytes(payload.to_vec())); - - // Wrap in the new Request structure - let mut request = BTreeMap::new(); - request.insert(Value::Text("send_message".to_string()), Value::Map(send_message)); - - self.send_cbor_request(request).await - } - - /// This method takes a message payload, a destination node, - /// destination queue ID and a message ID and reliably sends a message. - /// This uses a simple ARQ to resend the message if a reply wasn't received. - /// The given message ID will be used to identify the reply since a SURB ID - /// can only be used once. This method requires mixnet connectivity. 
- pub async fn send_reliable_message( - &self, - message_id: Vec, - payload: &[u8], - dest_node: Vec, - dest_queue: Vec - ) -> Result<(), ThinClientError> { - // Check if we're in offline mode - if !self.is_connected() { - return Err(ThinClientError::OfflineMode("cannot send reliable message in offline mode - daemon not connected to mixnet".to_string())); - } - // Create the SendARQMessage structure - let mut send_arq_message = BTreeMap::new(); - send_arq_message.insert(Value::Text("id".to_string()), Value::Bytes(message_id)); - send_arq_message.insert(Value::Text("with_surb".to_string()), Value::Bool(true)); - send_arq_message.insert(Value::Text("surbid".to_string()), Value::Null); // ARQ messages don't use SURB IDs directly - send_arq_message.insert(Value::Text("destination_id_hash".to_string()), Value::Bytes(dest_node)); - send_arq_message.insert(Value::Text("recipient_queue_id".to_string()), Value::Bytes(dest_queue)); - send_arq_message.insert(Value::Text("payload".to_string()), Value::Bytes(payload.to_vec())); - - // Wrap in the new Request structure - let mut request = BTreeMap::new(); - request.insert(Value::Text("send_arq_message".to_string()), Value::Map(send_arq_message)); - - self.send_cbor_request(request).await - } - - /*** Channel API ***/ - - /// Creates a new Pigeonhole write channel for sending messages. - /// Returns (channel_id, read_cap, write_cap) on success. 
- pub async fn create_write_channel(&self) -> Result<(u16, Vec, Vec), ThinClientError> { - let query_id = Self::new_query_id(); - - let mut create_write_channel = BTreeMap::new(); - create_write_channel.insert(Value::Text("query_id".to_string()), Value::Bytes(query_id.clone())); - - let mut request = BTreeMap::new(); - request.insert(Value::Text("create_write_channel".to_string()), Value::Map(create_write_channel)); - - self.send_cbor_request(request).await?; - - // Wait for CreateWriteChannelReply using event sink - let mut event_sink = self.event_sink(); - - loop { - let response = event_sink.recv().await - .ok_or_else(|| ThinClientError::Other("Event sink closed".to_string()))?; - - if let Some(Value::Map(reply)) = response.get(&Value::Text("create_write_channel_reply".to_string())) { - // Check for error first - if let Some(Value::Integer(error_code)) = reply.get(&Value::Text("error_code".to_string())) { - if *error_code != 0 { - return Err(ThinClientError::Other(format!("CreateWriteChannel failed with error code: {}", error_code))); - } - } - - if let Some(Value::Text(err)) = reply.get(&Value::Text("err".to_string())) { - return Err(ThinClientError::Other(format!("CreateWriteChannel failed: {}", err))); - } - - let channel_id = reply.get(&Value::Text("channel_id".to_string())) - .and_then(|v| match v { Value::Integer(i) => Some(*i as u16), _ => None }) - .ok_or_else(|| ThinClientError::Other("Missing channel_id in response".to_string()))?; - - let read_cap = match reply.get(&Value::Text("read_cap".to_string())) { - Some(Value::Bytes(bytes)) => bytes.clone(), - Some(_) => return Err(ThinClientError::Other("read_cap is unexpected type".to_string())), - None => return Err(ThinClientError::Other("Missing read_cap in response".to_string())), - }; - - let write_cap = match reply.get(&Value::Text("write_cap".to_string())) { - Some(Value::Bytes(bytes)) => bytes.clone(), - Some(_) => return Err(ThinClientError::Other("write_cap is unexpected type".to_string())), - None 
=> return Err(ThinClientError::Other("Missing write_cap in response".to_string())), - }; - - return Ok((channel_id, read_cap, write_cap)); - } - - // If we get here, it wasn't the reply we were looking for - } - } - - /// Creates a read channel from a read capability. - /// Returns channel_id on success. - pub async fn create_read_channel(&self, read_cap: Vec) -> Result { - let query_id = Self::new_query_id(); - - let mut create_read_channel = BTreeMap::new(); - create_read_channel.insert(Value::Text("query_id".to_string()), Value::Bytes(query_id.clone())); - create_read_channel.insert(Value::Text("read_cap".to_string()), Value::Bytes(read_cap)); - - let mut request = BTreeMap::new(); - request.insert(Value::Text("create_read_channel".to_string()), Value::Map(create_read_channel)); - - self.send_cbor_request(request).await?; - - // Wait for CreateReadChannelReply using event sink - let mut event_sink = self.event_sink(); - - loop { - let response = event_sink.recv().await - .ok_or_else(|| ThinClientError::Other("Event sink closed".to_string()))?; - - if let Some(Value::Map(reply)) = response.get(&Value::Text("create_read_channel_reply".to_string())) { - // Check for error first - if let Some(Value::Integer(error_code)) = reply.get(&Value::Text("error_code".to_string())) { - if *error_code != 0 { - return Err(ThinClientError::Other(format!("CreateReadChannel failed with error code: {}", error_code))); - } - } - - if let Some(Value::Text(err)) = reply.get(&Value::Text("err".to_string())) { - return Err(ThinClientError::Other(format!("CreateReadChannel failed: {}", err))); - } - - let channel_id = reply.get(&Value::Text("channel_id".to_string())) - .and_then(|v| match v { Value::Integer(i) => Some(*i as u16), _ => None }) - .ok_or_else(|| ThinClientError::Other("Missing channel_id in response".to_string()))?; - - return Ok(channel_id); - } - - // If we get here, it wasn't the reply we were looking for - } - } - - /// Prepares a message for writing to a Pigeonhole 
channel. - /// Returns WriteChannelReply matching the Go API. - pub async fn write_channel(&self, channel_id: u16, payload: &[u8]) -> Result { - let query_id = Self::new_query_id(); - - let mut write_channel = BTreeMap::new(); - write_channel.insert(Value::Text("channel_id".to_string()), Value::Integer(channel_id.into())); - write_channel.insert(Value::Text("query_id".to_string()), Value::Bytes(query_id.clone())); - write_channel.insert(Value::Text("payload".to_string()), Value::Bytes(payload.to_vec())); - - let mut request = BTreeMap::new(); - request.insert(Value::Text("write_channel".to_string()), Value::Map(write_channel)); - - self.send_cbor_request(request).await?; - - // Wait for WriteChannelReply using event sink - let mut event_sink = self.event_sink(); - - loop { - let response = event_sink.recv().await - .ok_or_else(|| ThinClientError::Other("Event sink closed".to_string()))?; - - if let Some(Value::Map(reply)) = response.get(&Value::Text("write_channel_reply".to_string())) { - if let Some(Value::Text(err)) = reply.get(&Value::Text("err".to_string())) { - return Err(ThinClientError::Other(format!("WriteChannel failed: {}", err))); - } - - let send_message_payload = reply.get(&Value::Text("send_message_payload".to_string())) - .and_then(|v| match v { Value::Bytes(b) => Some(b.clone()), _ => None }) - .ok_or_else(|| ThinClientError::Other("Missing send_message_payload in response".to_string()))?; - - let current_message_index = match reply.get(&Value::Text("current_message_index".to_string())) { - Some(Value::Bytes(bytes)) => bytes.clone(), - Some(_) => return Err(ThinClientError::Other("current_message_index is unexpected type".to_string())), - None => return Err(ThinClientError::Other("Missing current_message_index in response".to_string())), - }; - - let next_message_index = match reply.get(&Value::Text("next_message_index".to_string())) { - Some(Value::Bytes(bytes)) => bytes.clone(), - Some(_) => return Err(ThinClientError::Other("next_message_index is 
unexpected type".to_string())), - None => return Err(ThinClientError::Other("Missing next_message_index in response".to_string())), - }; - - let envelope_descriptor = reply.get(&Value::Text("envelope_descriptor".to_string())) - .and_then(|v| match v { Value::Bytes(b) => Some(b.clone()), _ => None }) - .ok_or_else(|| ThinClientError::Other("Missing envelope_descriptor in response".to_string()))?; - - let envelope_hash = reply.get(&Value::Text("envelope_hash".to_string())) - .and_then(|v| match v { Value::Bytes(b) => Some(b.clone()), _ => None }) - .ok_or_else(|| ThinClientError::Other("Missing envelope_hash in response".to_string()))?; - - return Ok(WriteChannelReply { - send_message_payload, - current_message_index, - next_message_index, - envelope_descriptor, - envelope_hash, - }); - } - - // If we get here, it wasn't the reply we were looking for - } - } - - /// Prepares a read query for a Pigeonhole channel. - /// Returns ReadChannelReply matching the Go API. - pub async fn read_channel(&self, channel_id: u16, message_box_index: Option<&[u8]>, reply_index: Option) -> Result { - let query_id = Self::new_query_id(); - - let mut read_channel = BTreeMap::new(); - read_channel.insert(Value::Text("channel_id".to_string()), Value::Integer(channel_id.into())); - read_channel.insert(Value::Text("query_id".to_string()), Value::Bytes(query_id.clone())); - - if let Some(index) = message_box_index { - read_channel.insert(Value::Text("message_box_index".to_string()), Value::Bytes(index.to_vec())); - } - - if let Some(idx) = reply_index { - read_channel.insert(Value::Text("reply_index".to_string()), Value::Integer(idx.into())); - } - - let mut request = BTreeMap::new(); - request.insert(Value::Text("read_channel".to_string()), Value::Map(read_channel)); - - self.send_cbor_request(request).await?; - - // Wait for ReadChannelReply using event sink - let mut event_sink = self.event_sink(); - - loop { - let response = event_sink.recv().await - .ok_or_else(|| 
ThinClientError::Other("Event sink closed".to_string()))?; - - if let Some(Value::Map(reply)) = response.get(&Value::Text("read_channel_reply".to_string())) { - if let Some(Value::Text(err)) = reply.get(&Value::Text("err".to_string())) { - return Err(ThinClientError::Other(format!("ReadChannel failed: {}", err))); - } - - let send_message_payload = reply.get(&Value::Text("send_message_payload".to_string())) - .and_then(|v| match v { Value::Bytes(b) => Some(b.clone()), _ => None }) - .ok_or_else(|| ThinClientError::Other("Missing send_message_payload in response".to_string()))?; - - let current_message_index = match reply.get(&Value::Text("current_message_index".to_string())) { - Some(Value::Bytes(bytes)) => bytes.clone(), - Some(_) => return Err(ThinClientError::Other("current_message_index is unexpected type".to_string())), - None => return Err(ThinClientError::Other("Missing current_message_index in response".to_string())), - }; - - let next_message_index = match reply.get(&Value::Text("next_message_index".to_string())) { - Some(Value::Bytes(bytes)) => bytes.clone(), - Some(_) => return Err(ThinClientError::Other("next_message_index is unexpected type".to_string())), - None => return Err(ThinClientError::Other("Missing next_message_index in response".to_string())), - }; - - let used_reply_index = reply.get(&Value::Text("reply_index".to_string())) - .and_then(|v| match v { Value::Integer(i) => Some(*i as u8), _ => None }); - - let envelope_descriptor = reply.get(&Value::Text("envelope_descriptor".to_string())) - .and_then(|v| match v { Value::Bytes(b) => Some(b.clone()), _ => None }) - .ok_or_else(|| ThinClientError::Other("Missing envelope_descriptor in response".to_string()))?; - - let envelope_hash = reply.get(&Value::Text("envelope_hash".to_string())) - .and_then(|v| match v { Value::Bytes(b) => Some(b.clone()), _ => None }) - .ok_or_else(|| ThinClientError::Other("Missing envelope_hash in response".to_string()))?; - - return Ok(ReadChannelReply { - 
send_message_payload, - current_message_index, - next_message_index, - reply_index: used_reply_index, - envelope_descriptor, - envelope_hash, - }); - } - - // If we get here, it wasn't the reply we were looking for - } - } - - /// Resumes a write channel from a previous session. - /// Returns channel_id on success. - pub async fn resume_write_channel(&self, write_cap: Vec, message_box_index: Option>) -> Result { - let query_id = Self::new_query_id(); - - let mut resume_write_channel = BTreeMap::new(); - resume_write_channel.insert(Value::Text("query_id".to_string()), Value::Bytes(query_id.clone())); - resume_write_channel.insert(Value::Text("write_cap".to_string()), Value::Bytes(write_cap)); - if let Some(index) = message_box_index { - resume_write_channel.insert(Value::Text("message_box_index".to_string()), Value::Bytes(index)); - } - - let mut request = BTreeMap::new(); - request.insert(Value::Text("resume_write_channel".to_string()), Value::Map(resume_write_channel)); - - self.send_cbor_request(request).await?; - - // Wait for ResumeWriteChannelReply using event sink - let mut event_sink = self.event_sink(); - - loop { - let response = event_sink.recv().await - .ok_or_else(|| ThinClientError::Other("Event sink closed".to_string()))?; - - if let Some(Value::Map(reply)) = response.get(&Value::Text("resume_write_channel_reply".to_string())) { - // Check for error first - if let Some(Value::Integer(error_code)) = reply.get(&Value::Text("error_code".to_string())) { - if *error_code != 0 { - return Err(ThinClientError::Other(format!("ResumeWriteChannel failed with error code: {}", error_code))); - } - } - - if let Some(Value::Text(err)) = reply.get(&Value::Text("err".to_string())) { - return Err(ThinClientError::Other(format!("ResumeWriteChannel failed: {}", err))); - } - - let channel_id = reply.get(&Value::Text("channel_id".to_string())) - .and_then(|v| match v { Value::Integer(i) => Some(*i as u16), _ => None }) - .ok_or_else(|| ThinClientError::Other("Missing 
channel_id in response".to_string()))?; - - return Ok(channel_id); - } - - // If we get here, it wasn't the reply we were looking for - } - } - - /// Resumes a read channel from a previous session. - /// Returns channel_id on success. - pub async fn resume_read_channel(&self, read_cap: Vec, next_message_index: Option>, reply_index: Option) -> Result { - let query_id = Self::new_query_id(); - - let mut resume_read_channel = BTreeMap::new(); - resume_read_channel.insert(Value::Text("query_id".to_string()), Value::Bytes(query_id.clone())); - resume_read_channel.insert(Value::Text("read_cap".to_string()), Value::Bytes(read_cap)); - if let Some(index) = next_message_index { - resume_read_channel.insert(Value::Text("next_message_index".to_string()), Value::Bytes(index)); - } - if let Some(index) = reply_index { - resume_read_channel.insert(Value::Text("reply_index".to_string()), Value::Integer(index.into())); - } - - let mut request = BTreeMap::new(); - request.insert(Value::Text("resume_read_channel".to_string()), Value::Map(resume_read_channel)); - - self.send_cbor_request(request).await?; - - // Wait for ResumeReadChannelReply using event sink - let mut event_sink = self.event_sink(); - - loop { - let response = event_sink.recv().await - .ok_or_else(|| ThinClientError::Other("Event sink closed".to_string()))?; - - if let Some(Value::Map(reply)) = response.get(&Value::Text("resume_read_channel_reply".to_string())) { - // Check for error first - if let Some(Value::Integer(error_code)) = reply.get(&Value::Text("error_code".to_string())) { - if *error_code != 0 { - return Err(ThinClientError::Other(format!("ResumeReadChannel failed with error code: {}", error_code))); - } - } - - if let Some(Value::Text(err)) = reply.get(&Value::Text("err".to_string())) { - return Err(ThinClientError::Other(format!("ResumeReadChannel failed: {}", err))); - } - - let channel_id = reply.get(&Value::Text("channel_id".to_string())) - .and_then(|v| match v { Value::Integer(i) => Some(*i as 
u16), _ => None }) - .ok_or_else(|| ThinClientError::Other("Missing channel_id in response".to_string()))?; - - return Ok(channel_id); - } - - // If we get here, it wasn't the reply we were looking for - } - } - - /// Resumes a write channel with a specific query state. - /// This method provides more granular resumption control than ResumeWriteChannel - /// by allowing the application to resume from a specific query state, including - /// the envelope descriptor and hash. This is useful when resuming from a partially - /// completed write operation that was interrupted during transmission. - /// Returns channel_id on success. - pub async fn resume_write_channel_query( - &self, - write_cap: Vec, - message_box_index: Vec, - envelope_descriptor: Vec, - envelope_hash: Vec, - ) -> Result { - let query_id = Self::new_query_id(); - - let mut resume_write_channel_query = BTreeMap::new(); - resume_write_channel_query.insert(Value::Text("query_id".to_string()), Value::Bytes(query_id.clone())); - resume_write_channel_query.insert(Value::Text("write_cap".to_string()), Value::Bytes(write_cap)); - resume_write_channel_query.insert(Value::Text("message_box_index".to_string()), Value::Bytes(message_box_index)); - resume_write_channel_query.insert(Value::Text("envelope_descriptor".to_string()), Value::Bytes(envelope_descriptor)); - resume_write_channel_query.insert(Value::Text("envelope_hash".to_string()), Value::Bytes(envelope_hash)); - - let mut request = BTreeMap::new(); - request.insert(Value::Text("resume_write_channel_query".to_string()), Value::Map(resume_write_channel_query)); - - self.send_cbor_request(request).await?; - - // Wait for ResumeWriteChannelQueryReply using event sink - let mut event_sink = self.event_sink(); - - loop { - let response = event_sink.recv().await - .ok_or_else(|| ThinClientError::Other("Event sink closed".to_string()))?; - - if let Some(Value::Map(reply)) = response.get(&Value::Text("resume_write_channel_query_reply".to_string())) { - // Check 
for error first - if let Some(Value::Integer(error_code)) = reply.get(&Value::Text("error_code".to_string())) { - if *error_code != 0 { - return Err(ThinClientError::Other(format!("ResumeWriteChannelQuery failed with error code: {}", error_code))); - } - } - - if let Some(Value::Text(err)) = reply.get(&Value::Text("err".to_string())) { - return Err(ThinClientError::Other(format!("ResumeWriteChannelQuery failed: {}", err))); - } - - let channel_id = reply.get(&Value::Text("channel_id".to_string())) - .and_then(|v| match v { Value::Integer(i) => Some(*i as u16), _ => None }) - .ok_or_else(|| ThinClientError::Other("Missing channel_id in response".to_string()))?; - - return Ok(channel_id); - } - - // If we get here, it wasn't the reply we were looking for - } - } - - /// Resumes a read channel with a specific query state. - /// This method provides more granular resumption control than ResumeReadChannel - /// by allowing the application to resume from a specific query state, including - /// the envelope descriptor and hash. This is useful when resuming from a partially - /// completed read operation that was interrupted during transmission. - /// Returns channel_id on success. 
- pub async fn resume_read_channel_query( - &self, - read_cap: Vec, - next_message_index: Vec, - reply_index: Option, - envelope_descriptor: Vec, - envelope_hash: Vec, - ) -> Result { - let query_id = Self::new_query_id(); - - let mut resume_read_channel_query = BTreeMap::new(); - resume_read_channel_query.insert(Value::Text("query_id".to_string()), Value::Bytes(query_id.clone())); - resume_read_channel_query.insert(Value::Text("read_cap".to_string()), Value::Bytes(read_cap)); - resume_read_channel_query.insert(Value::Text("next_message_index".to_string()), Value::Bytes(next_message_index)); - if let Some(index) = reply_index { - resume_read_channel_query.insert(Value::Text("reply_index".to_string()), Value::Integer(index.into())); - } - resume_read_channel_query.insert(Value::Text("envelope_descriptor".to_string()), Value::Bytes(envelope_descriptor)); - resume_read_channel_query.insert(Value::Text("envelope_hash".to_string()), Value::Bytes(envelope_hash)); - - let mut request = BTreeMap::new(); - request.insert(Value::Text("resume_read_channel_query".to_string()), Value::Map(resume_read_channel_query)); - - self.send_cbor_request(request).await?; - - // Wait for ResumeReadChannelQueryReply using event sink - let mut event_sink = self.event_sink(); - - loop { - let response = event_sink.recv().await - .ok_or_else(|| ThinClientError::Other("Event sink closed".to_string()))?; - - if let Some(Value::Map(reply)) = response.get(&Value::Text("resume_read_channel_query_reply".to_string())) { - // Check for error first - if let Some(Value::Integer(error_code)) = reply.get(&Value::Text("error_code".to_string())) { - if *error_code != 0 { - return Err(ThinClientError::Other(format!("ResumeReadChannelQuery failed with error code: {}", error_code))); - } - } - - if let Some(Value::Text(err)) = reply.get(&Value::Text("err".to_string())) { - return Err(ThinClientError::Other(format!("ResumeReadChannelQuery failed: {}", err))); - } - - let channel_id = 
reply.get(&Value::Text("channel_id".to_string())) - .and_then(|v| match v { Value::Integer(i) => Some(*i as u16), _ => None }) - .ok_or_else(|| ThinClientError::Other("Missing channel_id in response".to_string()))?; - - return Ok(channel_id); - } - - // If we get here, it wasn't the reply we were looking for - } - } - - /// Sends a prepared channel query to the mixnet without waiting for a reply. - pub async fn send_channel_query( - &self, - channel_id: u16, - payload: &[u8], - dest_node: Vec, - dest_queue: Vec, - message_id: Vec, - ) -> Result<(), ThinClientError> { - // Check if we're in offline mode - if !self.is_connected() { - return Err(ThinClientError::OfflineMode("cannot send channel query in offline mode - daemon not connected to mixnet".to_string())); - } - - let mut send_channel_query = BTreeMap::new(); - send_channel_query.insert(Value::Text("message_id".to_string()), Value::Bytes(message_id)); - send_channel_query.insert(Value::Text("channel_id".to_string()), Value::Integer(channel_id.into())); - send_channel_query.insert(Value::Text("destination_id_hash".to_string()), Value::Bytes(dest_node)); - send_channel_query.insert(Value::Text("recipient_queue_id".to_string()), Value::Bytes(dest_queue)); - send_channel_query.insert(Value::Text("payload".to_string()), Value::Bytes(payload.to_vec())); - - let mut request = BTreeMap::new(); - request.insert(Value::Text("send_channel_query".to_string()), Value::Map(send_channel_query)); - - self.send_cbor_request(request).await - } - - /// Sends a channel query and waits for the reply. - /// This combines send_channel_query with event handling to wait for the response. 
- pub async fn send_channel_query_await_reply( - &self, - channel_id: u16, - payload: &[u8], - dest_node: Vec, - dest_queue: Vec, - message_id: Vec, - ) -> Result, ThinClientError> { - // Create an event sink to listen for the reply - let mut event_sink = self.event_sink(); - - // Send the channel query - self.send_channel_query(channel_id, payload, dest_node, dest_queue, message_id.clone()).await?; - - // Wait for the reply - loop { - match event_sink.recv().await { - Some(response) => { - // Check for ChannelQuerySentEvent first - if let Some(Value::Map(event)) = response.get(&Value::Text("channel_query_sent_event".to_string())) { - if let Some(Value::Bytes(reply_message_id)) = event.get(&Value::Text("message_id".to_string())) { - if reply_message_id == &message_id { - // Check for error in sent event - if let Some(Value::Integer(error_code)) = event.get(&Value::Text("error_code".to_string())) { - if *error_code != 0 { - return Err(ThinClientError::Other(format!("Channel query send failed with error code: {}", error_code))); - } - } - // Continue waiting for the reply - continue; - } - } - } - - // Check for ChannelQueryReplyEvent - if let Some(Value::Map(event)) = response.get(&Value::Text("channel_query_reply_event".to_string())) { - if let Some(Value::Bytes(reply_message_id)) = event.get(&Value::Text("message_id".to_string())) { - if reply_message_id == &message_id { - // Check for error code - if let Some(Value::Integer(error_code)) = event.get(&Value::Text("error_code".to_string())) { - if *error_code != 0 { - return Err(ThinClientError::Other(format!("Channel query failed with error code: {}", error_code))); - } - } - - // Extract the payload - if let Some(Value::Bytes(reply_payload)) = event.get(&Value::Text("payload".to_string())) { - return Ok(reply_payload.clone()); - } else { - return Err(ThinClientError::Other("Missing payload in channel query reply".to_string())); - } - } - } - } - - // Ignore other events and continue waiting - } - None => { - 
return Err(ThinClientError::Other("Event sink closed while waiting for reply".to_string())); - } - } - } - } - - /// Closes a pigeonhole channel and cleans up its resources. - /// This helps avoid running out of channel IDs by properly releasing them. - pub async fn close_channel(&self, channel_id: u16) -> Result<(), ThinClientError> { - let mut close_channel = BTreeMap::new(); - close_channel.insert(Value::Text("channel_id".to_string()), Value::Integer(channel_id.into())); - - let mut request = BTreeMap::new(); - request.insert(Value::Text("close_channel".to_string()), Value::Map(close_channel)); - - self.send_cbor_request(request).await - } -} - -/// Find a specific mixnet service if it exists. -pub fn find_services(capability: &str, doc: &BTreeMap) -> Vec { - let mut services = Vec::new(); - - let Some(Value::Array(nodes)) = doc.get(&Value::Text("ServiceNodes".to_string())) else { - println!("❌ No 'ServiceNodes' found in PKI document."); - return services; - }; - - for node in nodes { - let Value::Bytes(node_bytes) = node else { continue }; - let Ok(mynode) = from_slice::>(node_bytes) else { continue }; - - // 🔍 Print available capabilities in each node - if let Some(Value::Map(details)) = mynode.get(&Value::Text("Kaetzchen".to_string())) { - println!("🔍 Available Capabilities: {:?}", details.keys()); - } - - let Some(Value::Map(details)) = mynode.get(&Value::Text("Kaetzchen".to_string())) else { continue }; - let Some(Value::Map(service)) = details.get(&Value::Text(capability.to_string())) else { continue }; - let Some(Value::Text(endpoint)) = service.get(&Value::Text("endpoint".to_string())) else { continue }; - - println!("returning a service descriptor!"); - - services.push(ServiceDescriptor { - recipient_queue_id: endpoint.as_bytes().to_vec(), - mix_descriptor: mynode, - }); - } - - services -} - -fn convert_to_pretty_json(value: &Value) -> serde_json::Value { - match value { - Value::Text(s) => serde_json::Value::String(s.clone()), - Value::Integer(i) => 
json!(*i), - Value::Bytes(b) => json!(hex::encode(b)), // Encode byte arrays as hex strings - Value::Array(arr) => serde_json::Value::Array(arr.iter().map(convert_to_pretty_json).collect()), - Value::Map(map) => { - let converted_map: serde_json::Map = map - .iter() - .map(|(key, value)| { - let key_str = match key { - Value::Text(s) => s.clone(), - _ => format!("{:?}", key), - }; - (key_str, convert_to_pretty_json(value)) - }) - .collect(); - serde_json::Value::Object(converted_map) - } - _ => serde_json::Value::Null, // Handle unexpected CBOR types - } -} - -fn decode_cbor_nodes(nodes: &[Value]) -> Vec { - nodes - .iter() - .filter_map(|node| match node { - Value::Bytes(blob) => serde_cbor::from_slice::>(blob) - .ok() - .map(Value::Map), - _ => Some(node.clone()), // Preserve non-CBOR values as they are - }) - .collect() -} - -/// Pretty prints a PKI document which you can gather from the client -/// with it's `pki_document` method, documented above. -pub fn pretty_print_pki_doc(doc: &BTreeMap) { - let mut new_doc = BTreeMap::new(); - - // Decode "GatewayNodes" - if let Some(Value::Array(gateway_nodes)) = doc.get(&Value::Text("GatewayNodes".to_string())) { - new_doc.insert(Value::Text("GatewayNodes".to_string()), Value::Array(decode_cbor_nodes(gateway_nodes))); - } - - // Decode "ServiceNodes" - if let Some(Value::Array(service_nodes)) = doc.get(&Value::Text("ServiceNodes".to_string())) { - new_doc.insert(Value::Text("ServiceNodes".to_string()), Value::Array(decode_cbor_nodes(service_nodes))); - } - - // Decode "Topology" (flatten nested arrays of CBOR blobs) - if let Some(Value::Array(topology_layers)) = doc.get(&Value::Text("Topology".to_string())) { - let decoded_topology: Vec = topology_layers - .iter() - .flat_map(|layer| match layer { - Value::Array(layer_nodes) => decode_cbor_nodes(layer_nodes), - _ => vec![], - }) - .collect(); - - new_doc.insert(Value::Text("Topology".to_string()), Value::Array(decoded_topology)); - } - - // Copy and decode all other 
fields that might contain CBOR blobs - for (key, value) in doc.iter() { - if !matches!(key, Value::Text(s) if ["GatewayNodes", "ServiceNodes", "Topology"].contains(&s.as_str())) { - let key_str = key.clone(); - let decoded_value = match value { - Value::Bytes(blob) => serde_cbor::from_slice::>(blob) - .ok() - .map(Value::Map) - .unwrap_or(value.clone()), // Fallback to original if not CBOR - _ => value.clone(), - }; - - new_doc.insert(key_str, decoded_value); - } - } - - // Convert to pretty JSON format right before printing - let pretty_json = convert_to_pretty_json(&Value::Map(new_doc)); - println!("{}", serde_json::to_string_pretty(&pretty_json).unwrap()); -} diff --git a/src/persistent/README.md b/src/persistent/README.md new file mode 100644 index 0000000..0d91e25 --- /dev/null +++ b/src/persistent/README.md @@ -0,0 +1,236 @@ +# Persistent Pigeonhole API + +This module provides a high-level API for pigeonhole messaging with automatic state persistence via SQLite. + +## API Summary + +```rust +// PigeonholeClient +PigeonholeClient::new(client, db) -> Self +PigeonholeClient::new_in_memory(client) -> Result +client.create_channel(name) -> Result +client.import_channel(name, &read_cap) -> Result +client.get_channel(name) -> Result +client.list_channels() -> Result> +client.delete_channel(name) -> Result<()> + +// ChannelHandle - State +channel.name() -> &str +channel.is_owned() -> bool +channel.refresh() -> Result<()> +channel.share_read_capability() -> ReadCapability +channel.write_cap() -> Option<&[u8]> +channel.read_cap() -> &[u8] +channel.write_index() -> Option<&[u8]> +channel.read_index() -> &[u8] + +// ChannelHandle - Messaging +channel.send(&plaintext) -> Result<()> +channel.receive() -> Result> +channel.write_box(&plaintext, &index) -> Result> +channel.read_box(&index) -> Result<(Vec, Vec)> +channel.get_unread_messages() -> Result> +channel.get_all_messages() -> Result> +channel.mark_message_read(id) -> Result<()> + +// ChannelHandle - Tombstones 
+channel.tombstone_current() -> Result<()> +channel.tombstone_range(count) -> Result + +// ChannelHandle - Copy +channel.copy_stream_builder() -> Result +channel.execute_copy(courier_hash, queue_id) -> Result<()> +channel.cancel_copy(&write_cap_hash) -> Result<()> + +// CopyStreamBuilder +builder.add_payload(&data, &dest_cap, &dest_idx, is_last) -> Result +builder.add_multi_payload(destinations, is_last) -> Result +builder.finish() -> Result +builder.finish_with_courier(&hash, &queue) -> Result +builder.buffer() -> &[u8] +builder.stream_id() -> &[u8; 16] +builder.temp_write_cap() -> &[u8] +``` + +## Overview + +The persistent API simplifies pigeonhole operations by: + +- **Automatic index tracking**: Write and read indices are managed automatically +- **Database persistence**: All state survives restarts +- **Pending message recovery**: Unsent messages can be retried after crashes +- **Message history**: Received messages are stored and can be queried + +## Quick Start + +```rust +use katzenpost_thin_client::persistent::{PigeonholeClient, Database}; + +// Open database and create client +let db = Database::open("my_app.db")?; +let client = PigeonholeClient::new(thin_client, db); + +// Create a channel (you own this - can send and receive) +let mut alice_channel = client.create_channel("alice-inbox").await?; + +// Send a message +alice_channel.send(b"Hello, world!").await?; + +// Share read capability with someone else +let read_cap = alice_channel.share_read_capability(); +println!("Share this: {:?}", read_cap.to_bytes()); +``` + +## Channel Types + +### Owned Channels + +Created with `create_channel()`. You have full read/write access. + +```rust +let mut channel = client.create_channel("my-channel").await?; +channel.send(b"message").await?; // ✓ Can send +let msg = channel.receive().await?; // ✓ Can receive +``` + +### Imported Channels (Read-Only) + +Created by importing someone else's `ReadCapability`. You can only receive. 
+ +```rust +let read_cap = ReadCapability::from_bytes(&shared_bytes)?; +let channel = client.import_channel("friend-channel", &read_cap)?; +let msg = channel.receive().await?; // ✓ Can receive +// channel.send(b"x").await?; // ✗ Error: read-only +``` + +## Core Operations + +### High-Level Send/Receive + +The simplest way to use channels: + +```rust +// Send (owned channels only) +channel.send(b"Hello!").await?; + +// Receive (advances read index automatically) +let plaintext = channel.receive().await?; +``` + +### Low-Level Box Operations + +For precise control over message indices: + +```rust +// Write to a specific box (does NOT advance write index) +let next_idx = channel.write_box(b"payload", &box_index).await?; + +// Read from a specific box (does NOT advance read index) +let (plaintext, next_idx) = channel.read_box(&box_index).await?; +``` + +### Message History + +Query received messages from the database: + +```rust +// Get unread messages +let unread = channel.get_unread_messages()?; + +// Get all messages +let all = channel.get_all_messages()?; + +// Mark as read +channel.mark_message_read(message.id)?; +``` + +## Tombstones (Deletion) + +Tombstones delete messages by writing empty payloads with valid signatures. + +```rust +// Delete the current write position +channel.tombstone_current().await?; + +// Delete a range of boxes (returns count of successful tombstones) +let deleted = channel.tombstone_range(10).await?; +``` + +Reading a tombstoned box returns an empty `Vec`. + +## Copy Streams (Large Payloads) + +For payloads larger than a single box, use `CopyStreamBuilder`: + +```rust +let mut builder = channel.copy_stream_builder().await?; + +// Stream data in chunks (e.g., reading from a file) +while let Some(chunk) = file.read_chunk()? 
{ + let is_last = file.is_eof(); + builder.add_payload(&chunk, &dest_write_cap, &dest_index, is_last).await?; +} + +// Execute the copy command +let boxes_written = builder.finish().await?; +``` + +### Multi-Destination Copy + +Send to multiple destinations efficiently: + +```rust +let destinations = vec![ + (payload1.as_slice(), dest1_write_cap.as_slice(), dest1_index.as_slice()), + (payload2.as_slice(), dest2_write_cap.as_slice(), dest2_index.as_slice()), +]; +builder.add_multi_payload(destinations, true).await?; +``` + +### Crash Recovery + +The `CopyStreamBuilder` exposes its internal buffer for persistence: + +```rust +// After each add_payload call, save the buffer +let buffer = builder.buffer().to_vec(); +db.save_stream_state(&stream_id, &buffer)?; + +// On restart, restore the buffer before continuing +thin_client.set_stream_buffer(&stream_id, &saved_buffer).await?; +``` + +## Database Schema + +Three tables are used: + +| Table | Purpose | +|-------|---------| +| `channels` | Channel state (capabilities, indices, ownership) | +| `pending_messages` | Outgoing messages awaiting confirmation | +| `received_messages` | Incoming messages with read/unread status | + +## Error Handling + +All operations return `Result`: + +```rust +match client.get_channel("nonexistent") { + Ok(ch) => { /* use channel */ } + Err(PigeonholeDbError::ChannelNotFound(name)) => { + println!("Channel {} not found", name); + } + Err(e) => return Err(e.into()), +} +``` + +## Testing + +Use `new_in_memory()` for tests: + +```rust +let client = PigeonholeClient::new_in_memory(thin_client)?; +// All data is lost when client is dropped +``` + diff --git a/src/persistent/channel.rs b/src/persistent/channel.rs new file mode 100644 index 0000000..3aebc2b --- /dev/null +++ b/src/persistent/channel.rs @@ -0,0 +1,1106 @@ +// SPDX-FileCopyrightText: Copyright (C) 2026 David Stainton +// SPDX-License-Identifier: AGPL-3.0-only + +//! High-level Channel API for simplified pigeonhole operations. 
+ +use std::sync::Arc; + +use rand::RngCore; + +use crate::core::ThinClient; +use crate::pigeonhole::TombstoneRangeResult; +use super::db::Database; +use super::error::{PigeonholeDbError, Result}; +use super::models::{Channel as ChannelModel, ReadCapability, ReceivedMessage}; + +/// High-level pigeonhole client with database persistence. +/// +/// This struct provides a simplified API for pigeonhole operations, +/// automatically managing state (indices, capabilities) via SQLite. +pub struct PigeonholeClient { + /// The underlying thin client for network operations. + client: Arc, + /// Database for state persistence. + db: Database, +} + +impl PigeonholeClient { + /// Create a new PigeonholeClient. + /// + /// # Arguments + /// * `client` - The underlying ThinClient for network operations. + /// * `db` - Database handle for state persistence. + pub fn new(client: Arc, db: Database) -> Self { + Self { client, db } + } + + /// Create a new PigeonholeClient with an in-memory database (for testing). + pub fn new_in_memory(client: Arc) -> Result { + let db = Database::open_in_memory()?; + Ok(Self { client, db }) + } + + /// Get a reference to the database. + pub fn db(&self) -> &Database { + &self.db + } + + /// Get a reference to the underlying thin client. + pub fn thin_client(&self) -> &Arc { + &self.client + } + + /// Create a new owned channel. + /// + /// This generates a new keypair and creates a channel that you own + /// (can both send and receive messages). + /// + /// # Arguments + /// * `name` - Human-readable name for the channel. + /// + /// # Returns + /// A `ChannelHandle` for interacting with the channel. 
+ pub async fn create_channel(&self, name: &str) -> Result { + // Generate random seed + let mut seed = [0u8; 32]; + rand::thread_rng().fill_bytes(&mut seed); + + // Create keypair via thin client + let (write_cap, read_cap, first_index) = self.client.new_keypair(&seed).await?; + + // Store in database + let channel = self.db.create_channel(name, &write_cap, &read_cap, &first_index)?; + + Ok(ChannelHandle { + channel, + client: self.client.clone(), + db: self.db.clone(), + }) + } + + /// Import a channel from a shared read capability. + /// + /// This creates a read-only channel that you can receive messages from + /// but cannot send to. + /// + /// # Arguments + /// * `name` - Human-readable name for the channel. + /// * `read_capability` - The shared read capability. + /// + /// # Returns + /// A `ChannelHandle` for interacting with the channel. + pub fn import_channel(&self, name: &str, read_capability: &ReadCapability) -> Result { + let channel = self.db.import_channel(name, &read_capability.read_cap, &read_capability.start_index)?; + + Ok(ChannelHandle { + channel, + client: self.client.clone(), + db: self.db.clone(), + }) + } + + /// Get an existing channel by name. + pub fn get_channel(&self, name: &str) -> Result { + let channel = self.db.get_channel(name)?; + + Ok(ChannelHandle { + channel, + client: self.client.clone(), + db: self.db.clone(), + }) + } + + /// List all channels. + pub fn list_channels(&self) -> Result> { + self.db.list_channels() + } + + /// Delete a channel and all its messages. + pub fn delete_channel(&self, name: &str) -> Result<()> { + self.db.delete_channel(name) + } +} + +/// Handle for interacting with a specific channel. +/// +/// This provides the main send/receive API with automatic state management. +pub struct ChannelHandle { + channel: ChannelModel, + client: Arc, + db: Database, +} + +impl ChannelHandle { + /// Get the channel model. + pub fn channel(&self) -> &ChannelModel { + &self.channel + } + + /// Get the channel name. 
+ pub fn name(&self) -> &str { + &self.channel.name + } + + /// Check if this is an owned channel (can send messages). + pub fn is_owned(&self) -> bool { + self.channel.is_owned + } + + /// Refresh the channel data from the database. + pub fn refresh(&mut self) -> Result<()> { + self.channel = self.db.get_channel_by_id(self.channel.id)?; + Ok(()) + } + + /// Get the read capability for sharing with others. + /// + /// Share this with someone to allow them to read messages from this channel. + pub fn share_read_capability(&self) -> ReadCapability { + ReadCapability { + read_cap: self.channel.read_cap.clone(), + start_index: self.channel.read_index.clone(), + name: Some(self.channel.name.clone()), + } + } + + /// Get the write capability for this channel. + /// + /// Returns the write capability if this is an owned channel, or `None` if + /// this is an imported read-only channel. + /// + /// The write capability is needed for operations like: + /// - The Copy command, which copies data from a temporary channel to a destination + /// - Resuming write operations after a restart + /// - Advanced ARQ scenarios + /// + /// # Security Note + /// The write capability grants full write access to the channel. Only share + /// it with trusted parties or use it in secure contexts like the Copy command. + pub fn write_cap(&self) -> Option<&[u8]> { + self.channel.write_cap.as_deref() + } + + /// Get the read capability bytes for this channel. + /// + /// This returns the raw read capability bytes, which can be used for + /// low-level operations or when you need the capability without the + /// additional metadata included in [`share_read_capability`]. + pub fn read_cap(&self) -> &[u8] { + &self.channel.read_cap + } + + /// Get the current write index for this channel. + /// + /// This is the next message box index that will be used when sending. + /// Returns `None` if this is a read-only channel. 
+ pub fn write_index(&self) -> Option<&[u8]> { + if self.channel.is_owned { + Some(&self.channel.write_index) + } else { + None + } + } + + /// Get the current read index for this channel. + /// + /// This is the next message box index that will be read from. + pub fn read_index(&self) -> &[u8] { + &self.channel.read_index + } + + // ======================================================================== + // Low-Level Box Operations (single box, no state management) + // ======================================================================== + + /// Write a single box payload at a specific index (low-level). + /// + /// This is the low-level primitive for writing to a pigeonhole box. + /// It does NOT update the channel's write index - use this when you need + /// precise control over box indices. + /// + /// # Arguments + /// * `plaintext` - The payload to write. Must be at most + /// `PigeonholeGeometry.max_plaintext_payload_length` bytes. + /// * `box_index` - The specific box index to write to. + /// + /// # Returns + /// The next box index after this write. + /// + /// # Errors + /// Returns an error if this is a read-only channel or the operation fails. + pub async fn write_box(&self, plaintext: &[u8], box_index: &[u8]) -> Result> { + let write_cap = self.channel.write_cap.as_ref().ok_or_else(|| { + PigeonholeDbError::Other("Cannot write on a read-only channel".to_string()) + })?; + + let (message_ciphertext, envelope_descriptor, envelope_hash) = self + .client + .encrypt_write(plaintext, write_cap, box_index) + .await?; + + self.client + .start_resending_encrypted_message( + None, + Some(write_cap), + None, + Some(0), + &envelope_descriptor, + &message_ciphertext, + &envelope_hash, + ) + .await?; + + let next_index = self.client.next_message_box_index(box_index).await?; + Ok(next_index) + } + + /// Write a single box payload at a specific index, returning BoxAlreadyExists as error. 
+ /// + /// Like `write_box`, but returns `BoxAlreadyExistsError` if the box already + /// contains data, instead of treating it as an idempotent success. + /// + /// # Arguments + /// * `plaintext` - The payload to write. + /// * `box_index` - The specific box index to write to. + /// + /// # Returns + /// The next box index after this write. + /// + /// # Errors + /// Returns `BoxAlreadyExistsError` if the box is already written. + pub async fn write_box_return_box_exists(&self, plaintext: &[u8], box_index: &[u8]) -> Result> { + let write_cap = self.channel.write_cap.as_ref().ok_or_else(|| { + PigeonholeDbError::Other("Cannot write on a read-only channel".to_string()) + })?; + + let (message_ciphertext, envelope_descriptor, envelope_hash) = self + .client + .encrypt_write(plaintext, write_cap, box_index) + .await?; + + self.client + .start_resending_encrypted_message_return_box_exists( + None, + Some(write_cap), + None, + Some(0), + &envelope_descriptor, + &message_ciphertext, + &envelope_hash, + ) + .await?; + + let next_index = self.client.next_message_box_index(box_index).await?; + Ok(next_index) + } + + /// Read a single box payload at a specific index (low-level). + /// + /// This is the low-level primitive for reading from a pigeonhole box. + /// It does NOT update the channel's read index - use this when you need + /// precise control over box indices. + /// + /// # Arguments + /// * `box_index` - The specific box index to read from. + /// + /// # Returns + /// A tuple of (plaintext, next_box_index). + /// + /// # Errors + /// Returns an error if the read operation fails. 
+ pub async fn read_box(&self, box_index: &[u8]) -> Result<(Vec, Vec)> { + let (message_ciphertext, next_message_index, envelope_descriptor, envelope_hash) = self + .client + .encrypt_read(&self.channel.read_cap, box_index) + .await?; + + let plaintext = self + .client + .start_resending_encrypted_message( + Some(&self.channel.read_cap), + None, + Some(&next_message_index), + Some(0), + &envelope_descriptor, + &message_ciphertext, + &envelope_hash, + ) + .await?; + + let next_index = self.client.next_message_box_index(box_index).await?; + Ok((plaintext, next_index)) + } + + /// Read a single box without automatic retries on BoxIDNotFound. + /// + /// Like `read_box`, but returns `BoxIDNotFoundError` immediately instead + /// of retrying (which normally accounts for mixnet replication lag). + /// + /// Use this when you need to quickly check if a box exists without waiting + /// for potential retries. + /// + /// # Arguments + /// * `box_index` - The specific box index to read from. + /// + /// # Returns + /// A tuple of (plaintext, next_box_index). + /// + /// # Errors + /// Returns `BoxIDNotFoundError` immediately if box doesn't exist. 
+ pub async fn read_box_no_retry(&self, box_index: &[u8]) -> Result<(Vec, Vec)> { + let (message_ciphertext, next_message_index, envelope_descriptor, envelope_hash) = self + .client + .encrypt_read(&self.channel.read_cap, box_index) + .await?; + + let plaintext = self + .client + .start_resending_encrypted_message_no_retry( + Some(&self.channel.read_cap), + None, + Some(&next_message_index), + Some(0), + &envelope_descriptor, + &message_ciphertext, + &envelope_hash, + ) + .await?; + + let next_index = self.client.next_message_box_index(box_index).await?; + Ok((plaintext, next_index)) + } + + // ======================================================================== + // High-Level Send/Receive (with state management) + // ======================================================================== + + /// Send a message on this channel (high-level). + /// + /// This method: + /// 1. Encrypts the message using the current write index + /// 2. Stores it as a pending message in the database + /// 3. Sends it via ARQ (automatic repeat request) + /// 4. Updates the write index on success + /// 5. Removes the pending message on success + /// + /// # Plaintext Size Constraint + /// + /// The `plaintext` must not exceed `PigeonholeGeometry.max_plaintext_payload_length` bytes. + /// For larger payloads, use the copy stream API via `CopyStreamBuilder`. + /// + /// # Arguments + /// * `plaintext` - The message to send. + /// + /// # Errors + /// Returns an error if this is a read-only channel or the operation fails. 
+ pub async fn send(&mut self, plaintext: &[u8]) -> Result<()> { + let write_cap = self.channel.write_cap.as_ref().ok_or_else(|| { + PigeonholeDbError::Other("Cannot send on a read-only channel".to_string()) + })?; + + let (message_ciphertext, envelope_descriptor, envelope_hash) = self + .client + .encrypt_write(plaintext, write_cap, &self.channel.write_index) + .await?; + + let pending = self.db.create_pending_message( + self.channel.id, + plaintext, + &message_ciphertext, + &envelope_descriptor, + &envelope_hash, + &self.channel.write_index, + )?; + + self.db.update_pending_message_status(pending.id, "sending")?; + + let result = self + .client + .start_resending_encrypted_message( + None, + Some(write_cap), + None, + Some(0), + &envelope_descriptor, + &message_ciphertext, + &envelope_hash, + ) + .await; + + match result { + Ok(_) => { + let next_index = self.client.next_message_box_index(&self.channel.write_index).await?; + self.db.update_write_index(self.channel.id, &next_index)?; + self.db.delete_pending_message(pending.id)?; + self.channel.write_index = next_index; + Ok(()) + } + Err(e) => { + self.db.update_pending_message_status(pending.id, "failed")?; + Err(e.into()) + } + } + } + + /// Send a message, returning BoxAlreadyExists as error if box is occupied. + /// + /// Like `send`, but returns `BoxAlreadyExistsError` if the box already + /// contains data, instead of treating it as an idempotent success. + /// + /// # Arguments + /// * `plaintext` - The message to send. + /// + /// # Errors + /// Returns `BoxAlreadyExistsError` if the box is already written. 
+ pub async fn send_return_box_exists(&mut self, plaintext: &[u8]) -> Result<()> { + let write_cap = self.channel.write_cap.as_ref().ok_or_else(|| { + PigeonholeDbError::Other("Cannot send on a read-only channel".to_string()) + })?; + + let (message_ciphertext, envelope_descriptor, envelope_hash) = self + .client + .encrypt_write(plaintext, write_cap, &self.channel.write_index) + .await?; + + let pending = self.db.create_pending_message( + self.channel.id, + plaintext, + &message_ciphertext, + &envelope_descriptor, + &envelope_hash, + &self.channel.write_index, + )?; + + self.db.update_pending_message_status(pending.id, "sending")?; + + let result = self + .client + .start_resending_encrypted_message_return_box_exists( + None, + Some(write_cap), + None, + Some(0), + &envelope_descriptor, + &message_ciphertext, + &envelope_hash, + ) + .await; + + match result { + Ok(_) => { + let next_index = self.client.next_message_box_index(&self.channel.write_index).await?; + self.db.update_write_index(self.channel.id, &next_index)?; + self.db.delete_pending_message(pending.id)?; + self.channel.write_index = next_index; + Ok(()) + } + Err(e) => { + self.db.update_pending_message_status(pending.id, "failed")?; + Err(e.into()) + } + } + } + + /// Receive the next message from this channel (high-level). + /// + /// This method reads from the current read index, stores the message, + /// and advances the read index. + /// + /// # Returns + /// The decrypted message plaintext. + /// + /// # Errors + /// Returns an error if the read operation fails. 
+ pub async fn receive(&mut self) -> Result> { + let (message_ciphertext, next_message_index, envelope_descriptor, envelope_hash) = self + .client + .encrypt_read(&self.channel.read_cap, &self.channel.read_index) + .await?; + + let plaintext = self + .client + .start_resending_encrypted_message( + Some(&self.channel.read_cap), + None, + Some(&next_message_index), + Some(0), + &envelope_descriptor, + &message_ciphertext, + &envelope_hash, + ) + .await?; + + self.db.create_received_message( + self.channel.id, + &plaintext, + &self.channel.read_index, + )?; + + let next_index = self.client.next_message_box_index(&self.channel.read_index).await?; + self.db.update_read_index(self.channel.id, &next_index)?; + self.channel.read_index = next_index; + + Ok(plaintext) + } + + /// Receive the next message without automatic retries on BoxIDNotFound. + /// + /// Like `receive`, but returns `BoxIDNotFoundError` immediately instead + /// of retrying (which normally accounts for mixnet replication lag). + /// + /// Use this when you need to quickly check if a message exists without + /// waiting for potential retries. + /// + /// # Returns + /// The decrypted message plaintext. + /// + /// # Errors + /// Returns `BoxIDNotFoundError` immediately if no message exists. 
+ pub async fn receive_no_retry(&mut self) -> Result> { + let (message_ciphertext, next_message_index, envelope_descriptor, envelope_hash) = self + .client + .encrypt_read(&self.channel.read_cap, &self.channel.read_index) + .await?; + + let plaintext = self + .client + .start_resending_encrypted_message_no_retry( + Some(&self.channel.read_cap), + None, + Some(&next_message_index), + Some(0), + &envelope_descriptor, + &message_ciphertext, + &envelope_hash, + ) + .await?; + + self.db.create_received_message( + self.channel.id, + &plaintext, + &self.channel.read_index, + )?; + + let next_index = self.client.next_message_box_index(&self.channel.read_index).await?; + self.db.update_read_index(self.channel.id, &next_index)?; + self.channel.read_index = next_index; + + Ok(plaintext) + } + + /// Get unread messages from the database (already received). + pub fn get_unread_messages(&self) -> Result> { + self.db.get_unread_messages(self.channel.id) + } + + /// Get all received messages from the database. + pub fn get_all_messages(&self) -> Result> { + self.db.get_all_messages(self.channel.id) + } + + /// Mark a message as read. + pub fn mark_message_read(&self, message_id: i64) -> Result<()> { + self.db.mark_message_read(message_id) + } + + // ======================================================================== + // Tombstone Operations + // ======================================================================== + + /// Tombstone (delete) the current write position. + /// + /// This writes an empty payload to the current write index, effectively + /// deleting the message at that position. The write index is then advanced. + /// + /// # Errors + /// Returns an error if this is a read-only channel or the operation fails. 
    pub async fn tombstone_current(&mut self) -> Result<()> {
        // Tombstoning overwrites a box, which requires the write capability.
        let write_cap = self.channel.write_cap.as_ref().ok_or_else(|| {
            PigeonholeDbError::Other("Cannot tombstone on a read-only channel".to_string())
        })?;

        // Create and send the tombstone
        let (ciphertext, env_desc, env_hash) = self
            .client
            .tombstone_box(write_cap, &self.channel.write_index)
            .await?;

        // NOTE(review): copy_from_slice panics if `env_hash` is not exactly
        // 32 bytes — assumes the client API guarantees a 32-byte hash; confirm.
        let mut hash_arr = [0u8; 32];
        hash_arr.copy_from_slice(&env_hash);

        self.client
            .start_resending_encrypted_message(
                None,
                Some(write_cap),
                None,
                None, // No reply expected for tombstone
                &env_desc,
                &ciphertext,
                &hash_arr,
            )
            .await?;

        // Update write index (persisted first, then the in-memory copy).
        let next_index = self.client.next_message_box_index(&self.channel.write_index).await?;
        self.db.update_write_index(self.channel.id, &next_index)?;
        self.channel.write_index = next_index;

        Ok(())
    }

    /// Tombstone a range of boxes starting from the current write position.
    ///
    /// This creates tombstones for up to `count` boxes and sends them all.
    /// The write index is advanced past all tombstoned boxes.
    ///
    /// # Arguments
    /// * `count` - Maximum number of boxes to tombstone.
    ///
    /// # Returns
    /// The number of boxes successfully tombstoned.
    ///
    /// # Errors
    /// Returns an error if this is a read-only channel.
    pub async fn tombstone_range(&mut self, count: u32) -> Result<u32> {
        let write_cap = self.channel.write_cap.as_ref().ok_or_else(|| {
            PigeonholeDbError::Other("Cannot tombstone on a read-only channel".to_string())
        })?;

        // Ask the client to build tombstone envelopes for the whole range.
        let result: TombstoneRangeResult = self
            .client
            .tombstone_range(write_cap, &self.channel.write_index, count)
            .await;

        let mut sent_count = 0u32;

        // Send all the tombstone envelopes
        for envelope in &result.envelopes {
            let mut hash_arr = [0u8; 32];
            hash_arr.copy_from_slice(&envelope.envelope_hash);

            match self.client.start_resending_encrypted_message(
                None,
                Some(write_cap),
                None,
                None,
                &envelope.envelope_descriptor,
                &envelope.message_ciphertext,
                &hash_arr,
            ).await {
                Ok(_) => sent_count += 1,
                Err(e) => {
                    // Partial failure: persist the index of the envelope that
                    // failed so a later attempt can resume from that box.
                    if sent_count > 0 {
                        self.db.update_write_index(self.channel.id, &envelope.box_index)?;
                        self.channel.write_index = envelope.box_index.clone();
                    }
                    return Err(e.into());
                }
            }
        }

        // Update write index to the final position
        if sent_count > 0 {
            self.db.update_write_index(self.channel.id, &result.next)?;
            self.channel.write_index = result.next;
        }

        Ok(sent_count)
    }

    /// Tombstone (delete) a specific box by its index.
    ///
    /// This writes an empty payload to the specified box index, effectively
    /// deleting the message at that position. This does NOT update the channel's
    /// write index - use this when you need to delete a specific previously-written box.
    ///
    /// # Arguments
    /// * `box_index` - The specific box index to tombstone.
    ///
    /// # Errors
    /// Returns an error if this is a read-only channel or the operation fails.
+ pub async fn tombstone_at(&self, box_index: &[u8]) -> Result<()> { + let write_cap = self.channel.write_cap.as_ref().ok_or_else(|| { + PigeonholeDbError::Other("Cannot tombstone on a read-only channel".to_string()) + })?; + + // Create and send the tombstone + let (ciphertext, env_desc, env_hash) = self + .client + .tombstone_box(write_cap, box_index) + .await?; + + let mut hash_arr = [0u8; 32]; + hash_arr.copy_from_slice(&env_hash); + + self.client + .start_resending_encrypted_message( + None, + Some(write_cap), + None, + None, // No reply expected for tombstone + &env_desc, + &ciphertext, + &hash_arr, + ) + .await?; + + Ok(()) + } + + /// Tombstone a range of boxes starting from a specific index. + /// + /// This creates tombstones for up to `count` boxes starting from `start_index` + /// and sends them all. This does NOT update the channel's write index - use this + /// when you need to delete specific previously-written boxes. + /// + /// # Arguments + /// * `start_index` - The box index to start tombstoning from. + /// * `count` - Maximum number of boxes to tombstone. + /// + /// # Returns + /// The number of boxes successfully tombstoned. + /// + /// # Errors + /// Returns an error if this is a read-only channel. 
    pub async fn tombstone_from(&self, start_index: &[u8], count: u32) -> Result<u32> {
        // Tombstoning overwrites boxes, which requires the write capability.
        let write_cap = self.channel.write_cap.as_ref().ok_or_else(|| {
            PigeonholeDbError::Other("Cannot tombstone on a read-only channel".to_string())
        })?;

        let result: TombstoneRangeResult = self
            .client
            .tombstone_range(write_cap, start_index, count)
            .await;

        let mut sent_count = 0u32;

        // Send all the tombstone envelopes. Unlike `tombstone_range`, this
        // method never touches the channel's persisted write index, even on
        // partial failure.
        for envelope in &result.envelopes {
            let mut hash_arr = [0u8; 32];
            hash_arr.copy_from_slice(&envelope.envelope_hash);

            match self.client.start_resending_encrypted_message(
                None,
                Some(write_cap),
                None,
                None,
                &envelope.envelope_descriptor,
                &envelope.message_ciphertext,
                &hash_arr,
            ).await {
                Ok(_) => sent_count += 1,
                Err(e) => {
                    return Err(e.into());
                }
            }
        }

        Ok(sent_count)
    }

    // ========================================================================
    // Copy Stream Operations
    // ========================================================================

    /// Create a new CopyStreamBuilder for streaming large payloads.
    ///
    /// Use this for payloads of any size. The builder allows you to add
    /// payloads incrementally (streaming from disk, network, etc.) without
    /// loading everything into memory at once.
    ///
    /// # Example
    /// ```ignore
    /// let mut builder = channel.copy_stream_builder().await?;
    ///
    /// // Stream data in chunks (e.g., reading from a file)
    /// while let Some(chunk) = file.read_chunk() {
    ///     builder.add_payload(&chunk, dest_write_cap, dest_start_index, false).await?;
    /// }
    ///
    /// // Finalize and execute the copy
    /// builder.finish().await?;
    /// ```
    pub async fn copy_stream_builder(&self) -> Result<CopyStreamBuilder> {
        CopyStreamBuilder::new(self.client.clone()).await
    }

    /// Execute a Copy command using this channel's write capability as the source.
    ///
    /// This is useful when this channel has been used as a temporary copy stream
    /// and you want to trigger the courier to copy from it to the destination(s)
    /// encoded in the stream.
    ///
    /// # Arguments
    /// * `courier_identity_hash` - Optional specific courier to use.
    /// * `courier_queue_id` - Optional queue ID for the specific courier.
    ///
    /// # Errors
    /// Returns an error if this is a read-only channel or the operation fails.
    pub async fn execute_copy(
        &self,
        courier_identity_hash: Option<&[u8]>,
        courier_queue_id: Option<&[u8]>,
    ) -> Result<()> {
        // The copy command is authenticated by the source write capability.
        let write_cap = self.channel.write_cap.as_ref().ok_or_else(|| {
            PigeonholeDbError::Other("Cannot execute copy on a read-only channel".to_string())
        })?;

        self.client
            .start_resending_copy_command(write_cap, courier_identity_hash, courier_queue_id)
            .await?;

        Ok(())
    }

    /// Cancel a Copy command in progress.
    ///
    /// This stops the automatic repeat request (ARQ) for a previously started
    /// copy command.
    ///
    /// # Arguments
    /// * `write_cap_hash` - 32-byte hash of the WriteCap used in execute_copy.
    ///
    /// # Errors
    /// Returns an error if the operation fails.
    pub async fn cancel_copy(&self, write_cap_hash: &[u8; 32]) -> Result<()> {
        self.client
            .cancel_resending_copy_command(write_cap_hash)
            .await?;

        Ok(())
    }
}

/// Builder for creating copy streams that can handle arbitrarily large payloads.
///
/// This builder uses the daemon's internal buffer (correlated by stream ID) to
/// efficiently pack data into the temporary channel. You can call `add_payload`
/// multiple times with chunks of data, and the daemon handles the packing.
///
/// # Memory Efficiency
/// Unlike passing large buffers, this approach:
/// - Uses stream ID correlation to maintain state in the daemon
/// - Allows streaming data from disk/network without loading everything into memory
/// - Packs multiple payloads efficiently into copy stream boxes
///
/// # Crash Recovery
/// When `is_last=false` is passed to `add_payload` or `add_multi_payload`, partial
/// data may be buffered by the daemon. The buffer is saved after each call and can
/// be accessed via `buffer()`. To recover after a crash, persist the buffer and
/// restore it via `ThinClient::set_stream_buffer` before continuing the stream.
///
/// # Example
/// ```ignore
/// let mut builder = channel.copy_stream_builder().await?;
///
/// // Add multiple payloads to different destinations
/// builder.add_payload(payload1, dest1_write_cap, dest1_index, false).await?;
/// builder.add_payload(payload2, dest2_write_cap, dest2_index, false).await?;
///
/// // Finalize and send the copy command
/// let boxes_written = builder.finish().await?;
/// ```
pub struct CopyStreamBuilder {
    // Shared thin-client handle used for every daemon call.
    client: Arc<ThinClient>,
    // Correlates add_payload calls with the daemon's internal pack buffer.
    stream_id: [u8; 16],
    // Write capability of the temporary channel that holds the packed stream.
    temp_write_cap: Vec<u8>,
    // Next box index to write in the temporary channel.
    temp_index: Vec<u8>,
    // Running count of boxes written to the temporary channel.
    total_boxes: usize,
    /// Buffer containing data that hasn't been output yet.
    /// This can be persisted for crash recovery.
    buffer: Vec<u8>,
}

impl CopyStreamBuilder {
    /// Create a new CopyStreamBuilder.
    ///
    /// Generates a fresh random keypair for the temporary channel and a new
    /// stream ID for daemon-side buffer correlation.
    async fn new(client: Arc<ThinClient>) -> Result<Self> {
        let mut seed = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut seed);
        let (temp_write_cap, _temp_read_cap, temp_first_index) =
            client.new_keypair(&seed).await?;

        Ok(Self {
            client,
            stream_id: ThinClient::new_stream_id(),
            temp_write_cap,
            temp_index: temp_first_index,
            total_boxes: 0,
            buffer: Vec::new(),
        })
    }

    /// Add a payload to the copy stream.
    ///
    /// This can be called multiple times to stream data incrementally.
    /// Each call creates courier envelopes and writes them to the temporary
    /// channel immediately.
    ///
    /// # Arguments
    /// * `payload` - The payload chunk to add (max 10MB per call).
    /// * `dest_write_cap` - Write capability for the destination.
    /// * `dest_start_index` - Starting index in the destination.
    /// * `is_last` - True if this is the final payload for this destination.
    ///
    /// # Returns
    /// The number of boxes written for this payload.
    pub async fn add_payload(
        &mut self,
        payload: &[u8],
        dest_write_cap: &[u8],
        dest_start_index: &[u8],
        is_last: bool,
    ) -> Result<usize> {
        // Ask the daemon to pack this chunk into courier envelopes.
        let result = self.client.create_courier_envelopes_from_payload(
            &self.stream_id,
            payload,
            dest_write_cap,
            dest_start_index,
            is_last,
        ).await?;

        let chunk_count = result.envelopes.len();

        // Save the buffer for crash recovery
        self.buffer = result.buffer;

        // Write each packed envelope into the temporary channel in order,
        // advancing the temporary index after every successful write.
        for chunk in result.envelopes {
            let (ciphertext, env_desc, env_hash) = self
                .client
                .encrypt_write(&chunk, &self.temp_write_cap, &self.temp_index)
                .await?;

            self.client
                .start_resending_encrypted_message(
                    None,
                    Some(&self.temp_write_cap),
                    None,
                    Some(0),
                    &env_desc,
                    &ciphertext,
                    &env_hash,
                )
                .await?;

            self.temp_index = self.client.next_message_box_index(&self.temp_index).await?;
        }

        self.total_boxes += chunk_count;
        Ok(chunk_count)
    }

    /// Add multiple payloads to different destinations efficiently.
    ///
    /// This packs all payloads together, which is more space-efficient than
    /// calling `add_payload` multiple times because envelopes from different
    /// destinations are packed together without wasting space.
    ///
    /// # Arguments
    /// * `destinations` - List of (payload, dest_write_cap, dest_start_index) tuples.
    /// * `is_last` - True if this is the final set of payloads.
    ///
    /// # Returns
    /// The number of boxes written.
    pub async fn add_multi_payload(
        &mut self,
        destinations: Vec<(&[u8], &[u8], &[u8])>,
        is_last: bool,
    ) -> Result<usize> {
        // Nothing to pack; avoid a pointless daemon round trip.
        if destinations.is_empty() {
            return Ok(0);
        }

        let result = self.client.create_courier_envelopes_from_multi_payload(
            &self.stream_id,
            destinations,
            is_last,
        ).await?;

        let chunk_count = result.envelopes.len();

        // Save the buffer for crash recovery
        self.buffer = result.buffer;

        // Same write loop as `add_payload`: each envelope goes into the
        // temporary channel and the index advances after every write.
        for chunk in result.envelopes {
            let (ciphertext, env_desc, env_hash) = self
                .client
                .encrypt_write(&chunk, &self.temp_write_cap, &self.temp_index)
                .await?;

            self.client
                .start_resending_encrypted_message(
                    None,
                    Some(&self.temp_write_cap),
                    None,
                    Some(0),
                    &env_desc,
                    &ciphertext,
                    &env_hash,
                )
                .await?;

            self.temp_index = self.client.next_message_box_index(&self.temp_index).await?;
        }

        self.total_boxes += chunk_count;
        Ok(chunk_count)
    }

    /// Finalize the copy stream and execute the Copy command.
    ///
    /// This sends the Copy command to the courier, which will read the
    /// temporary channel and execute all the write operations atomically.
    ///
    /// # Returns
    /// The total number of boxes written to the temporary channel.
    pub async fn finish(self) -> Result<usize> {
        self.client
            .start_resending_copy_command(&self.temp_write_cap, None, None)
            .await?;

        Ok(self.total_boxes)
    }

    /// Finalize with a specific courier.
    ///
    /// # Arguments
    /// * `courier_identity_hash` - Identity hash of the courier to use.
    /// * `courier_queue_id` - Queue ID for the courier.
    pub async fn finish_with_courier(
        self,
        courier_identity_hash: &[u8],
        courier_queue_id: &[u8],
    ) -> Result<usize> {
        self.client
            .start_resending_copy_command(
                &self.temp_write_cap,
                Some(courier_identity_hash),
                Some(courier_queue_id),
            )
            .await?;

        Ok(self.total_boxes)
    }

    /// Get the temporary channel's write capability.
    ///
    /// This can be used to cancel the copy operation if needed.
    pub fn temp_write_cap(&self) -> &[u8] {
        &self.temp_write_cap
    }

    /// Get the stream ID for this copy stream.
    pub fn stream_id(&self) -> &[u8; 16] {
        &self.stream_id
    }

    /// Get the current buffer contents for crash recovery.
    ///
    /// When `is_last=false` is passed to `add_payload` or `add_multi_payload`,
    /// partial data may be buffered. This buffer can be persisted and restored
    /// via `ThinClient::set_stream_buffer` on restart to continue the stream.
    pub fn buffer(&self) -> &[u8] {
        &self.buffer
    }
}

diff --git a/src/persistent/db.rs b/src/persistent/db.rs
new file mode 100644
index 0000000..180ec48
--- /dev/null
+++ b/src/persistent/db.rs
@@ -0,0 +1,464 @@
// SPDX-FileCopyrightText: Copyright (C) 2026 David Stainton
// SPDX-License-Identifier: AGPL-3.0-only

//! Database layer for pigeonhole state persistence.

use std::path::Path;
use std::sync::{Arc, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};

use rusqlite::{Connection, params};

use super::error::{PigeonholeDbError, Result};
use super::models::{Channel, PendingMessage, ReceivedMessage};

/// Database handle for pigeonhole state.
///
/// This struct manages SQLite database operations for storing
/// channels, pending messages, and received messages.
///
/// Cloning is cheap: all clones share the same underlying connection
/// behind an `Arc<Mutex<_>>`, so queries are serialized across clones.
#[derive(Clone)]
pub struct Database {
    // Single shared SQLite connection; the mutex serializes all access.
    conn: Arc<Mutex<Connection>>,
}

impl Database {
    /// Open or create a database at the given path.
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
        let conn = Connection::open(path)?;
        let db = Self {
            conn: Arc::new(Mutex::new(conn)),
        };
        // Schema creation is idempotent (CREATE ... IF NOT EXISTS).
        db.init_schema()?;
        Ok(db)
    }

    /// Open an in-memory database (useful for testing).
    pub fn open_in_memory() -> Result<Self> {
        let conn = Connection::open_in_memory()?;
        let db = Self {
            conn: Arc::new(Mutex::new(conn)),
        };
        db.init_schema()?;
        Ok(db)
    }

    /// Initialize the database schema.
+ fn init_schema(&self) -> Result<()> { + let conn = self.conn.lock().unwrap(); + conn.execute_batch( + r#" + CREATE TABLE IF NOT EXISTS channels ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + write_cap BLOB, + read_cap BLOB NOT NULL, + write_index BLOB NOT NULL, + read_index BLOB NOT NULL, + is_owned INTEGER NOT NULL DEFAULT 1, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL + ); + + CREATE TABLE IF NOT EXISTS pending_messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + channel_id INTEGER NOT NULL, + plaintext BLOB NOT NULL, + message_ciphertext BLOB NOT NULL, + envelope_descriptor BLOB NOT NULL, + envelope_hash BLOB NOT NULL UNIQUE, + box_index BLOB NOT NULL, + attempts INTEGER NOT NULL DEFAULT 0, + status TEXT NOT NULL DEFAULT 'pending', + created_at INTEGER NOT NULL, + last_attempt_at INTEGER, + FOREIGN KEY (channel_id) REFERENCES channels(id) ON DELETE CASCADE + ); + + CREATE TABLE IF NOT EXISTS received_messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + channel_id INTEGER NOT NULL, + plaintext BLOB NOT NULL, + box_index BLOB NOT NULL, + received_at INTEGER NOT NULL, + is_read INTEGER NOT NULL DEFAULT 0, + FOREIGN KEY (channel_id) REFERENCES channels(id) ON DELETE CASCADE + ); + + CREATE INDEX IF NOT EXISTS idx_pending_status ON pending_messages(status); + CREATE INDEX IF NOT EXISTS idx_pending_channel ON pending_messages(channel_id); + CREATE INDEX IF NOT EXISTS idx_received_channel ON received_messages(channel_id); + CREATE INDEX IF NOT EXISTS idx_received_unread ON received_messages(is_read); + "#, + )?; + Ok(()) + } + + /// Get current Unix timestamp. + fn now() -> i64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i64 + } + + // ======================================================================== + // Channel Operations + // ======================================================================== + + /// Create a new owned channel. 
    pub fn create_channel(
        &self,
        name: &str,
        write_cap: &[u8],
        read_cap: &[u8],
        first_index: &[u8],
    ) -> Result<Channel> {
        let conn = self.conn.lock().unwrap();
        let now = Self::now();

        // Both write and read indices start at the channel's first box index.
        conn.execute(
            r#"INSERT INTO channels (name, write_cap, read_cap, write_index, read_index, is_owned, created_at, updated_at)
               VALUES (?1, ?2, ?3, ?4, ?5, 1, ?6, ?7)"#,
            params![name, write_cap, read_cap, first_index, first_index, now, now],
        ).map_err(|e| {
            // The UNIQUE constraint on `name` surfaces as a constraint
            // violation; map it to a domain-level error. NOTE(review): any
            // other constraint violation on this table would be reported the
            // same way — confirm `name` is the only realistic trigger here.
            if let rusqlite::Error::SqliteFailure(ref err, _) = e {
                if err.code == rusqlite::ErrorCode::ConstraintViolation {
                    return PigeonholeDbError::ChannelAlreadyExists(name.to_string());
                }
            }
            PigeonholeDbError::Database(e)
        })?;

        let id = conn.last_insert_rowid();
        Ok(Channel {
            id,
            name: name.to_string(),
            write_cap: Some(write_cap.to_vec()),
            read_cap: read_cap.to_vec(),
            write_index: first_index.to_vec(),
            read_index: first_index.to_vec(),
            is_owned: true,
            created_at: now,
            updated_at: now,
        })
    }

    /// Import a read-only channel from a shared read capability.
    pub fn import_channel(
        &self,
        name: &str,
        read_cap: &[u8],
        start_index: &[u8],
    ) -> Result<Channel> {
        let conn = self.conn.lock().unwrap();
        let now = Self::now();

        // write_cap is NULL and is_owned = 0: imported channels are read-only.
        conn.execute(
            r#"INSERT INTO channels (name, write_cap, read_cap, write_index, read_index, is_owned, created_at, updated_at)
               VALUES (?1, NULL, ?2, ?3, ?4, 0, ?5, ?6)"#,
            params![name, read_cap, start_index, start_index, now, now],
        ).map_err(|e| {
            // UNIQUE(name) violation -> domain-level "already exists" error.
            if let rusqlite::Error::SqliteFailure(ref err, _) = e {
                if err.code == rusqlite::ErrorCode::ConstraintViolation {
                    return PigeonholeDbError::ChannelAlreadyExists(name.to_string());
                }
            }
            PigeonholeDbError::Database(e)
        })?;

        let id = conn.last_insert_rowid();
        Ok(Channel {
            id,
            name: name.to_string(),
            write_cap: None,
            read_cap: read_cap.to_vec(),
            write_index: start_index.to_vec(),
            read_index: start_index.to_vec(),
            is_owned: false,
            created_at: now,
            updated_at: now,
        })
    }

    /// Get a channel by name.
    pub fn get_channel(&self, name: &str) -> Result<Channel> {
        let conn = self.conn.lock().unwrap();
        let mut stmt = conn.prepare(
            "SELECT id, name, write_cap, read_cap, write_index, read_index, is_owned, created_at, updated_at FROM channels WHERE name = ?1"
        )?;

        stmt.query_row(params![name], |row| {
            Ok(Channel {
                id: row.get(0)?,
                name: row.get(1)?,
                write_cap: row.get(2)?,
                read_cap: row.get(3)?,
                write_index: row.get(4)?,
                read_index: row.get(5)?,
                // is_owned is stored as INTEGER 0/1.
                is_owned: row.get::<_, i64>(6)? != 0,
                created_at: row.get(7)?,
                updated_at: row.get(8)?,
            })
        }).map_err(|e| match e {
            // A missing row is a domain error, not a database failure.
            rusqlite::Error::QueryReturnedNoRows => PigeonholeDbError::ChannelNotFound(name.to_string()),
            _ => PigeonholeDbError::Database(e),
        })
    }

    /// Get a channel by ID.
+ pub fn get_channel_by_id(&self, id: i64) -> Result { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare( + "SELECT id, name, write_cap, read_cap, write_index, read_index, is_owned, created_at, updated_at FROM channels WHERE id = ?1" + )?; + + stmt.query_row(params![id], |row| { + Ok(Channel { + id: row.get(0)?, + name: row.get(1)?, + write_cap: row.get(2)?, + read_cap: row.get(3)?, + write_index: row.get(4)?, + read_index: row.get(5)?, + is_owned: row.get::<_, i64>(6)? != 0, + created_at: row.get(7)?, + updated_at: row.get(8)?, + }) + }).map_err(|e| match e { + rusqlite::Error::QueryReturnedNoRows => PigeonholeDbError::ChannelNotFound(format!("id={}", id)), + _ => PigeonholeDbError::Database(e), + }) + } + + /// List all channels. + pub fn list_channels(&self) -> Result> { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare( + "SELECT id, name, write_cap, read_cap, write_index, read_index, is_owned, created_at, updated_at FROM channels ORDER BY name" + )?; + + let channels = stmt.query_map([], |row| { + Ok(Channel { + id: row.get(0)?, + name: row.get(1)?, + write_cap: row.get(2)?, + read_cap: row.get(3)?, + write_index: row.get(4)?, + read_index: row.get(5)?, + is_owned: row.get::<_, i64>(6)? != 0, + created_at: row.get(7)?, + updated_at: row.get(8)?, + }) + })?.collect::, _>>()?; + + Ok(channels) + } + + /// Update the write index for a channel. + pub fn update_write_index(&self, channel_id: i64, new_index: &[u8]) -> Result<()> { + let conn = self.conn.lock().unwrap(); + let now = Self::now(); + conn.execute( + "UPDATE channels SET write_index = ?1, updated_at = ?2 WHERE id = ?3", + params![new_index, now, channel_id], + )?; + Ok(()) + } + + /// Update the read index for a channel. 
+ pub fn update_read_index(&self, channel_id: i64, new_index: &[u8]) -> Result<()> { + let conn = self.conn.lock().unwrap(); + let now = Self::now(); + conn.execute( + "UPDATE channels SET read_index = ?1, updated_at = ?2 WHERE id = ?3", + params![new_index, now, channel_id], + )?; + Ok(()) + } + + /// Delete a channel and all its messages. + pub fn delete_channel(&self, name: &str) -> Result<()> { + let conn = self.conn.lock().unwrap(); + let rows = conn.execute("DELETE FROM channels WHERE name = ?1", params![name])?; + if rows == 0 { + return Err(PigeonholeDbError::ChannelNotFound(name.to_string())); + } + Ok(()) + } + + // ======================================================================== + // Pending Message Operations + // ======================================================================== + + /// Create a pending message. + pub fn create_pending_message( + &self, + channel_id: i64, + plaintext: &[u8], + message_ciphertext: &[u8], + envelope_descriptor: &[u8], + envelope_hash: &[u8], + box_index: &[u8], + ) -> Result { + let conn = self.conn.lock().unwrap(); + let now = Self::now(); + + conn.execute( + r#"INSERT INTO pending_messages + (channel_id, plaintext, message_ciphertext, envelope_descriptor, envelope_hash, box_index, attempts, status, created_at) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, 0, 'pending', ?7)"#, + params![channel_id, plaintext, message_ciphertext, envelope_descriptor, envelope_hash, box_index, now], + )?; + + let id = conn.last_insert_rowid(); + Ok(PendingMessage { + id, + channel_id, + plaintext: plaintext.to_vec(), + message_ciphertext: message_ciphertext.to_vec(), + envelope_descriptor: envelope_descriptor.to_vec(), + envelope_hash: envelope_hash.to_vec(), + box_index: box_index.to_vec(), + attempts: 0, + status: "pending".to_string(), + created_at: now, + last_attempt_at: None, + }) + } + + /// Get all pending messages for a channel. 
    pub fn get_pending_messages(&self, channel_id: i64) -> Result<Vec<PendingMessage>> {
        let conn = self.conn.lock().unwrap();
        let mut stmt = conn.prepare(
            r#"SELECT id, channel_id, plaintext, message_ciphertext, envelope_descriptor,
                      envelope_hash, box_index, attempts, status, created_at, last_attempt_at
               FROM pending_messages WHERE channel_id = ?1 ORDER BY created_at"#
        )?;

        let messages = stmt.query_map(params![channel_id], |row| {
            Ok(PendingMessage {
                id: row.get(0)?,
                channel_id: row.get(1)?,
                plaintext: row.get(2)?,
                message_ciphertext: row.get(3)?,
                envelope_descriptor: row.get(4)?,
                envelope_hash: row.get(5)?,
                box_index: row.get(6)?,
                attempts: row.get(7)?,
                status: row.get(8)?,
                created_at: row.get(9)?,
                last_attempt_at: row.get(10)?,
            })
        })?.collect::<Result<Vec<_>, _>>()?;

        Ok(messages)
    }

    /// Update pending message status.
    ///
    /// Note: every status change also increments `attempts` and refreshes
    /// `last_attempt_at`, so this is intended to be called once per send
    /// attempt.
    pub fn update_pending_message_status(&self, id: i64, status: &str) -> Result<()> {
        let conn = self.conn.lock().unwrap();
        let now = Self::now();
        conn.execute(
            "UPDATE pending_messages SET status = ?1, attempts = attempts + 1, last_attempt_at = ?2 WHERE id = ?3",
            params![status, now, id],
        )?;
        Ok(())
    }

    /// Delete a pending message (after successful send).
    ///
    /// Succeeds silently if no row with this id exists.
    pub fn delete_pending_message(&self, id: i64) -> Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute("DELETE FROM pending_messages WHERE id = ?1", params![id])?;
        Ok(())
    }

    /// Delete a pending message by envelope hash.
    ///
    /// Succeeds silently if the hash is unknown (unlike `delete_channel`,
    /// which reports a missing row as an error).
    pub fn delete_pending_message_by_hash(&self, envelope_hash: &[u8]) -> Result<()> {
        let conn = self.conn.lock().unwrap();
        conn.execute("DELETE FROM pending_messages WHERE envelope_hash = ?1", params![envelope_hash])?;
        Ok(())
    }

    // ========================================================================
    // Received Message Operations
    // ========================================================================

    /// Store a received message.
+ pub fn create_received_message( + &self, + channel_id: i64, + plaintext: &[u8], + box_index: &[u8], + ) -> Result { + let conn = self.conn.lock().unwrap(); + let now = Self::now(); + + conn.execute( + r#"INSERT INTO received_messages (channel_id, plaintext, box_index, received_at, is_read) + VALUES (?1, ?2, ?3, ?4, 0)"#, + params![channel_id, plaintext, box_index, now], + )?; + + let id = conn.last_insert_rowid(); + Ok(ReceivedMessage { + id, + channel_id, + plaintext: plaintext.to_vec(), + box_index: box_index.to_vec(), + received_at: now, + is_read: false, + }) + } + + /// Get unread messages for a channel. + pub fn get_unread_messages(&self, channel_id: i64) -> Result> { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare( + r#"SELECT id, channel_id, plaintext, box_index, received_at, is_read + FROM received_messages WHERE channel_id = ?1 AND is_read = 0 ORDER BY received_at"# + )?; + + let messages = stmt.query_map(params![channel_id], |row| { + Ok(ReceivedMessage { + id: row.get(0)?, + channel_id: row.get(1)?, + plaintext: row.get(2)?, + box_index: row.get(3)?, + received_at: row.get(4)?, + is_read: row.get::<_, i64>(5)? != 0, + }) + })?.collect::, _>>()?; + + Ok(messages) + } + + /// Mark a message as read. + pub fn mark_message_read(&self, id: i64) -> Result<()> { + let conn = self.conn.lock().unwrap(); + conn.execute("UPDATE received_messages SET is_read = 1 WHERE id = ?1", params![id])?; + Ok(()) + } + + /// Get all messages for a channel (including read ones). 
    pub fn get_all_messages(&self, channel_id: i64) -> Result<Vec<ReceivedMessage>> {
        let conn = self.conn.lock().unwrap();
        let mut stmt = conn.prepare(
            r#"SELECT id, channel_id, plaintext, box_index, received_at, is_read
               FROM received_messages WHERE channel_id = ?1 ORDER BY received_at"#
        )?;

        let messages = stmt.query_map(params![channel_id], |row| {
            Ok(ReceivedMessage {
                id: row.get(0)?,
                channel_id: row.get(1)?,
                plaintext: row.get(2)?,
                box_index: row.get(3)?,
                received_at: row.get(4)?,
                // is_read is stored as INTEGER 0/1.
                is_read: row.get::<_, i64>(5)? != 0,
            })
        })?.collect::<Result<Vec<_>, _>>()?;

        Ok(messages)
    }
}

diff --git a/src/persistent/error.rs b/src/persistent/error.rs
new file mode 100644
index 0000000..2b362f2
--- /dev/null
+++ b/src/persistent/error.rs
@@ -0,0 +1,77 @@
// SPDX-FileCopyrightText: Copyright (C) 2026 David Stainton
// SPDX-License-Identifier: AGPL-3.0-only

//! Error types for the persistent pigeonhole module.

use std::fmt;

/// Errors that can occur in the persistent pigeonhole module.
#[derive(Debug)]
pub enum PigeonholeDbError {
    /// Database error from rusqlite.
    Database(rusqlite::Error),
    /// Channel not found in database.
    ChannelNotFound(String),
    /// Channel already exists with the given name.
    ChannelAlreadyExists(String),
    /// Message not found.
    MessageNotFound(i64),
    /// Invalid capability data.
    InvalidCapability(String),
    /// Thin client error (from underlying pigeonhole operations).
    ThinClient(crate::error::ThinClientError),
    /// I/O error.
    Io(std::io::Error),
    /// Other error with message.
    Other(String),
}

impl fmt::Display for PigeonholeDbError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            PigeonholeDbError::Database(e) => write!(f, "Database error: {}", e),
            PigeonholeDbError::ChannelNotFound(name) => write!(f, "Channel not found: {}", name),
            PigeonholeDbError::ChannelAlreadyExists(name) => {
                write!(f, "Channel already exists: {}", name)
            }
            PigeonholeDbError::MessageNotFound(id) => write!(f, "Message not found: {}", id),
            PigeonholeDbError::InvalidCapability(msg) => write!(f, "Invalid capability: {}", msg),
            PigeonholeDbError::ThinClient(e) => write!(f, "Thin client error: {}", e),
            PigeonholeDbError::Io(e) => write!(f, "I/O error: {}", e),
            PigeonholeDbError::Other(msg) => write!(f, "{}", msg),
        }
    }
}

impl std::error::Error for PigeonholeDbError {
    // Expose the wrapped error for variants that carry one, so callers can
    // walk the error chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            PigeonholeDbError::Database(e) => Some(e),
            PigeonholeDbError::ThinClient(e) => Some(e),
            PigeonholeDbError::Io(e) => Some(e),
            _ => None,
        }
    }
}

// The `From` impls below let `?` convert the underlying error types
// automatically throughout this module.
impl From<rusqlite::Error> for PigeonholeDbError {
    fn from(err: rusqlite::Error) -> Self {
        PigeonholeDbError::Database(err)
    }
}

impl From<crate::error::ThinClientError> for PigeonholeDbError {
    fn from(err: crate::error::ThinClientError) -> Self {
        PigeonholeDbError::ThinClient(err)
    }
}

impl From<std::io::Error> for PigeonholeDbError {
    fn from(err: std::io::Error) -> Self {
        PigeonholeDbError::Io(err)
    }
}

/// Result type for persistent pigeonhole operations.
pub type Result<T> = std::result::Result<T, PigeonholeDbError>;

diff --git a/src/persistent/mod.rs b/src/persistent/mod.rs
new file mode 100644
index 0000000..96d7cdf
--- /dev/null
+++ b/src/persistent/mod.rs
@@ -0,0 +1,76 @@
// SPDX-FileCopyrightText: Copyright (C) 2026 David Stainton
// SPDX-License-Identifier: AGPL-3.0-only

//! High-level Pigeonhole API with database persistence.
//!
//! This module provides a simplified API for the Pigeonhole protocol,
//! automatically managing state (capabilities, indices) via SQLite.
//!
//! # Overview
//!
//! The low-level pigeonhole API in [`crate::pigeonhole`] requires manual
//! management of write/read capabilities and message box indices. This
//! module wraps that API with automatic state persistence, making it
//! much easier to build applications.
//!
//! # Features
//!
//! - **Automatic index management**: Write and read indices are automatically
//!   persisted and incremented after each operation.
//! - **Channel persistence**: Channels (with their capabilities) are stored in
//!   SQLite and can be recovered after application restart.
//! - **Tombstone support**: Delete messages by overwriting them with zeros.
//! - **Copy command support**: Send large payloads that span multiple boxes
//!   using the Copy command with automatic chunking.
//! - **Message history**: Received messages are stored for later retrieval.
//!
//! # Example
//!
//! ```rust,ignore
//! use katzenpost_thin_client::persistent::{PigeonholeClient, Database};
//!
//! // Open database and create client
//! let db = Database::open("pigeonhole.db")?;
//! let pigeonhole = PigeonholeClient::new(thin_client, db);
//!
//! // Create a channel (generates keypair automatically)
//! let mut channel = pigeonhole.create_channel("my-channel").await?;
//!
//! // Send a message (indices managed automatically)
//! channel.send(b"Hello, world!").await?;
//!
//! // Share read capability with someone else
//! let read_cap = channel.share_read_capability();
//! let read_cap_bytes = read_cap.to_bytes();
//!
//! // On the receiver side:
//! let read_cap = ReadCapability::from_bytes(&read_cap_bytes)?;
//! let mut their_channel = pigeonhole.import_channel("from-alice", &read_cap)?;
//! let message = their_channel.receive().await?;
//!
//! // Tombstone (delete) the last written message
//! channel.tombstone_current().await?;
//! ```
//!
//! # Plaintext Size Constraints
//!
+//! Single messages sent via [`ChannelHandle::send`] must not exceed
+//! `PigeonholeGeometry.max_plaintext_payload_length` bytes.
+//!
+//! # Database Schema
+//!
+//! The module creates three tables:
+//! - `channels`: Stores channel metadata (name, capabilities, indices)
+//! - `pending_messages`: Messages waiting to be sent or acknowledged
+//! - `received_messages`: Messages received from channels
+
+pub mod channel;
+pub mod db;
+pub mod error;
+pub mod models;
+
+pub use channel::{ChannelHandle, CopyStreamBuilder, PigeonholeClient};
+pub use db::Database;
+pub use error::{PigeonholeDbError, Result};
+pub use models::{Channel, PendingMessage, ReadCapability, ReceivedMessage};
+
diff --git a/src/persistent/models.rs b/src/persistent/models.rs
new file mode 100644
index 0000000..405414b
--- /dev/null
+++ b/src/persistent/models.rs
@@ -0,0 +1,104 @@
+// SPDX-FileCopyrightText: Copyright (C) 2026 David Stainton
+// SPDX-License-Identifier: AGPL-3.0-only
+
+//! Database models for the persistent pigeonhole module.
+
+use serde::{Deserialize, Serialize};
+
+/// A pigeonhole channel stored in the database.
+///
+/// Channels represent a communication endpoint with write and/or read capabilities.
+/// The owner of a channel has the write_cap and can send messages.
+/// They can share the read_cap with others to allow reading.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Channel {
+    /// Unique database ID.
+    pub id: i64,
+    /// Human-readable name for the channel.
+    pub name: String,
+    /// Write capability (only present if we own the channel).
+    pub write_cap: Option<Vec<u8>>,
+    /// Read capability (always present).
+    pub read_cap: Vec<u8>,
+    /// Current write index (for sending messages).
+    pub write_index: Vec<u8>,
+    /// Current read index (for receiving messages).
+    pub read_index: Vec<u8>,
+    /// Whether this is an owned channel (we have write_cap) or imported (read-only).
+    pub is_owned: bool,
+    /// Creation timestamp (Unix epoch seconds).
+    pub created_at: i64,
+    /// Last activity timestamp (Unix epoch seconds).
+    pub updated_at: i64,
+}
+
+/// A pending outgoing message waiting to be sent or acknowledged.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct PendingMessage {
+    /// Unique database ID.
+    pub id: i64,
+    /// Channel ID this message belongs to.
+    pub channel_id: i64,
+    /// The plaintext message content.
+    pub plaintext: Vec<u8>,
+    /// The encrypted message ciphertext.
+    pub message_ciphertext: Vec<u8>,
+    /// Envelope descriptor for decryption.
+    pub envelope_descriptor: Vec<u8>,
+    /// Envelope hash for cancellation/tracking.
+    pub envelope_hash: Vec<u8>,
+    /// The message box index this was sent to.
+    pub box_index: Vec<u8>,
+    /// Number of send attempts.
+    pub attempts: i32,
+    /// Current status: "pending", "sending", "sent", "failed".
+    pub status: String,
+    /// Creation timestamp (Unix epoch seconds).
+    pub created_at: i64,
+    /// Last attempt timestamp (Unix epoch seconds).
+    pub last_attempt_at: Option<i64>,
+}
+
+/// A received message from a channel.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ReceivedMessage {
+    /// Unique database ID.
+    pub id: i64,
+    /// Channel ID this message was received from.
+    pub channel_id: i64,
+    /// The decrypted plaintext message content.
+    pub plaintext: Vec<u8>,
+    /// The message box index this was read from.
+    pub box_index: Vec<u8>,
+    /// Reception timestamp (Unix epoch seconds).
+    pub received_at: i64,
+    /// Whether the message has been read/processed by the application.
+    pub is_read: bool,
+}
+
+/// Read capability that can be shared with others.
+///
+/// This is a serializable structure containing all information
+/// needed to import and read from a channel.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ReadCapability {
+    /// The read capability bytes.
+    pub read_cap: Vec<u8>,
+    /// The starting message index for reading.
+    pub start_index: Vec<u8>,
+    /// Optional human-readable name/description.
+    pub name: Option<String>,
+}
+
+impl ReadCapability {
+    /// Serialize to bytes for sharing (e.g., as a QR code or file).
+    pub fn to_bytes(&self) -> Vec<u8> {
+        serde_cbor::to_vec(self).expect("Failed to serialize ReadCapability")
+    }
+
+    /// Deserialize from bytes.
+    pub fn from_bytes(bytes: &[u8]) -> Result<Self, serde_cbor::Error> {
+        serde_cbor::from_slice(bytes)
+    }
+}
+
diff --git a/src/pigeonhole.rs b/src/pigeonhole.rs
new file mode 100644
index 0000000..6684bda
--- /dev/null
+++ b/src/pigeonhole.rs
@@ -0,0 +1,1188 @@
+// SPDX-FileCopyrightText: Copyright (C) 2026 David Stainton
+// SPDX-License-Identifier: AGPL-3.0-only
+
+//! Pigeonhole protocol API for the thin client.
+//!
+//! This module provides methods for interacting with the Pigeonhole protocol,
+//! including key generation, encryption, and ARQ (Automatic Repeat Request)
+//! for reliable message delivery to the courier.
+
+use std::collections::BTreeMap;
+use serde_cbor::Value;
+use rand::RngCore;
+use log::debug;
+
+use crate::error::{ThinClientError, error_code_to_error};
+use crate::core::ThinClient;
+
+// ========================================================================
+// Helper module for serializing Option<Vec<u8>> as CBOR byte strings
+// ========================================================================
+
+mod optional_bytes {
+    use serde::{Deserialize, Deserializer, Serialize, Serializer};
+
+    pub fn serialize<S>(value: &Option<Vec<u8>>, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        match value {
+            Some(bytes) => serde_bytes::serialize(bytes, serializer),
+            None => Option::<&[u8]>::None.serialize(serializer),
+        }
+    }
+
+    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let opt: Option<serde_bytes::ByteBuf> = Option::deserialize(deserializer)?;
+        Ok(opt.map(|b| b.into_vec()))
+    }
+}
+
+// ========================================================================
+// NEW Pigeonhole API Protocol Message Structs
+// ========================================================================
+
+/// Request to create a new keypair for the Pigeonhole protocol. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct NewKeypairRequest { + #[serde(with = "serde_bytes")] + query_id: Vec, + #[serde(with = "serde_bytes")] + seed: Vec, +} + +/// Reply containing the generated keypair and first message index. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct NewKeypairReply { + #[serde(with = "serde_bytes")] + query_id: Vec, + #[serde(default, with = "optional_bytes")] + write_cap: Option>, + #[serde(default, with = "optional_bytes")] + read_cap: Option>, + #[serde(default, with = "optional_bytes")] + first_message_index: Option>, + #[serde(default)] + error_code: u8, +} + +/// Request to encrypt a read operation. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct EncryptReadRequest { + #[serde(with = "serde_bytes")] + query_id: Vec, + #[serde(with = "serde_bytes")] + read_cap: Vec, + #[serde(with = "serde_bytes")] + message_box_index: Vec, +} + +/// Reply containing the encrypted read operation. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct EncryptReadReply { + #[serde(with = "serde_bytes")] + query_id: Vec, + #[serde(default, with = "optional_bytes")] + message_ciphertext: Option>, + #[serde(default, with = "optional_bytes")] + next_message_index: Option>, + #[serde(default, with = "optional_bytes")] + envelope_descriptor: Option>, + #[serde(default, with = "optional_bytes")] + envelope_hash: Option>, + #[serde(default)] + error_code: u8, +} + +/// Request to encrypt a write operation. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct EncryptWriteRequest { + #[serde(with = "serde_bytes")] + query_id: Vec, + #[serde(with = "serde_bytes")] + plaintext: Vec, + #[serde(with = "serde_bytes")] + write_cap: Vec, + #[serde(with = "serde_bytes")] + message_box_index: Vec, +} + +/// Reply containing the encrypted write operation. 
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct EncryptWriteReply { + #[serde(with = "serde_bytes")] + query_id: Vec, + #[serde(default, with = "optional_bytes")] + message_ciphertext: Option>, + #[serde(default, with = "optional_bytes")] + envelope_descriptor: Option>, + #[serde(default, with = "optional_bytes")] + envelope_hash: Option>, + #[serde(default)] + error_code: u8, +} + +/// Request to start resending an encrypted message via ARQ. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct StartResendingEncryptedMessageRequest { + #[serde(with = "serde_bytes")] + query_id: Vec, + #[serde(skip_serializing_if = "Option::is_none", with = "optional_bytes")] + read_cap: Option>, + #[serde(skip_serializing_if = "Option::is_none", with = "optional_bytes")] + write_cap: Option>, + #[serde(skip_serializing_if = "Option::is_none", with = "optional_bytes")] + next_message_index: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + reply_index: Option, + #[serde(with = "serde_bytes")] + envelope_descriptor: Vec, + #[serde(with = "serde_bytes")] + message_ciphertext: Vec, + #[serde(with = "serde_bytes")] + envelope_hash: Vec, + /// If true, BoxIDNotFound errors on reads trigger immediate error instead of automatic retries. + #[serde(skip_serializing_if = "std::ops::Not::not")] + no_retry_on_box_id_not_found: bool, + /// If true, BoxAlreadyExists errors on writes are returned as errors instead of idempotent success. + #[serde(skip_serializing_if = "std::ops::Not::not")] + no_idempotent_box_already_exists: bool, +} + +/// Reply containing the plaintext from a resent encrypted message. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct StartResendingEncryptedMessageReply { + #[serde(with = "serde_bytes")] + query_id: Vec, + #[serde(default, with = "optional_bytes")] + plaintext: Option>, + error_code: u8, +} + +/// Request to cancel resending an encrypted message. 
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct CancelResendingEncryptedMessageRequest { + #[serde(with = "serde_bytes")] + query_id: Vec, + #[serde(with = "serde_bytes")] + envelope_hash: Vec, +} + +/// Reply confirming cancellation of resending. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct CancelResendingEncryptedMessageReply { + #[serde(with = "serde_bytes")] + query_id: Vec, + error_code: u8, +} + +/// Request to increment a MessageBoxIndex. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct NextMessageBoxIndexRequest { + #[serde(with = "serde_bytes")] + query_id: Vec, + #[serde(with = "serde_bytes")] + message_box_index: Vec, +} + +/// Reply containing the incremented MessageBoxIndex. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct NextMessageBoxIndexReply { + #[serde(with = "serde_bytes")] + query_id: Vec, + #[serde(default, with = "optional_bytes")] + next_message_box_index: Option>, + #[serde(default)] + error_code: u8, +} + +/// Request to start resending a copy command. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct StartResendingCopyCommandRequest { + #[serde(with = "serde_bytes")] + query_id: Vec, + #[serde(with = "serde_bytes")] + write_cap: Vec, + #[serde(skip_serializing_if = "Option::is_none", default, with = "optional_bytes")] + courier_identity_hash: Option>, + #[serde(skip_serializing_if = "Option::is_none", default, with = "optional_bytes")] + courier_queue_id: Option>, +} + +/// Reply confirming start of copy command resending. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct StartResendingCopyCommandReply { + #[serde(with = "serde_bytes")] + query_id: Vec, + error_code: u8, +} + +/// Request to cancel resending a copy command. 
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct CancelResendingCopyCommandRequest { + #[serde(with = "serde_bytes")] + query_id: Vec, + #[serde(with = "serde_bytes")] + write_cap_hash: Vec, +} + +/// Reply confirming cancellation of copy command resending. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct CancelResendingCopyCommandReply { + #[serde(with = "serde_bytes")] + query_id: Vec, + error_code: u8, +} + +/// Request to create courier envelopes from a payload. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct CreateCourierEnvelopesFromPayloadRequest { + #[serde(with = "serde_bytes")] + query_id: Vec, + #[serde(with = "serde_bytes")] + stream_id: Vec, + #[serde(with = "serde_bytes")] + payload: Vec, + #[serde(with = "serde_bytes")] + dest_write_cap: Vec, + #[serde(with = "serde_bytes")] + dest_start_index: Vec, + is_last: bool, +} + +/// Reply containing the created courier envelopes and buffer state. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct CreateCourierEnvelopesFromPayloadReply { + #[serde(with = "serde_bytes")] + query_id: Vec, + /// Envelopes is None when the daemon returns an error. + envelopes: Option>, + /// Buffer contains any data buffered by the encoder that hasn't been output yet. + /// None when the daemon returns an error. + #[serde(default, with = "optional_bytes")] + buffer: Option>, + #[serde(default)] + error_code: u8, +} + +/// A destination for creating courier envelopes. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct EnvelopeDestination { + #[serde(with = "serde_bytes")] + payload: Vec, + #[serde(with = "serde_bytes")] + write_cap: Vec, + #[serde(with = "serde_bytes")] + start_index: Vec, +} + +/// Request to create courier envelopes from multiple payloads. 
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct CreateCourierEnvelopesFromPayloadsRequest { + #[serde(with = "serde_bytes")] + query_id: Vec, + #[serde(with = "serde_bytes")] + stream_id: Vec, + destinations: Vec, + is_last: bool, +} + +/// Reply containing the created courier envelopes from multiple payloads and buffer state. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct CreateCourierEnvelopesFromPayloadsReply { + #[serde(with = "serde_bytes")] + query_id: Vec, + /// Envelopes is None when the daemon returns an error. + envelopes: Option>, + /// Buffer contains any data buffered by the encoder that hasn't been output yet. + /// None when the daemon returns an error. + #[serde(default, with = "optional_bytes")] + buffer: Option>, + #[serde(default)] + error_code: u8, +} + +/// Request to set/restore the buffered state for a stream (for crash recovery). +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct SetStreamBufferRequest { + #[serde(with = "serde_bytes")] + query_id: Vec, + #[serde(with = "serde_bytes")] + stream_id: Vec, + #[serde(with = "serde_bytes")] + buffer: Vec, +} + +/// Reply confirming the buffer state has been restored. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct SetStreamBufferReply { + #[serde(with = "serde_bytes")] + query_id: Vec, + error_code: u8, +} + +/// Result of creating courier envelopes, including the envelopes and buffer for crash recovery. +#[derive(Debug, Clone)] +pub struct CreateEnvelopesResult { + /// The serialized CopyStreamElements to send to the network. + pub envelopes: Vec>, + /// The buffered data that hasn't been output yet. Persist this for crash recovery. 
+ pub buffer: Vec, +} + +// ======================================================================== +// NEW Pigeonhole API Methods +// ======================================================================== + +impl ThinClient { + /// Creates a new keypair for use with the Pigeonhole protocol. + /// + /// This method generates a WriteCap and ReadCap from the provided seed using + /// the BACAP (Blinding-and-Capability) protocol. The WriteCap should be stored + /// securely for writing messages, while the ReadCap can be shared with others + /// to allow them to read messages. + /// + /// # Arguments + /// * `seed` - 32-byte seed used to derive the keypair + /// + /// # Returns + /// * `Ok((write_cap, read_cap, first_message_index))` on success + /// * `Err(ThinClientError)` on failure + pub async fn new_keypair(&self, seed: &[u8; 32]) -> Result<(Vec, Vec, Vec), ThinClientError> { + let query_id = Self::new_query_id(); + + let request_inner = NewKeypairRequest { + query_id: query_id.clone(), + seed: seed.to_vec(), + }; + + let request_value = serde_cbor::value::to_value(&request_inner) + .map_err(|e| ThinClientError::CborError(e))?; + + let mut request = BTreeMap::new(); + request.insert(Value::Text("new_keypair".to_string()), request_value); + + let reply_map = self.send_and_wait_direct(query_id, request).await?; + + let reply: NewKeypairReply = serde_cbor::value::from_value(Value::Map(reply_map)) + .map_err(|e| ThinClientError::CborError(e))?; + + if reply.error_code != 0 { + return Err(ThinClientError::Other(format!("new_keypair failed with error code: {}", reply.error_code))); + } + + let write_cap = reply.write_cap.ok_or_else(|| ThinClientError::Other("new_keypair: write_cap is None".to_string()))?; + let read_cap = reply.read_cap.ok_or_else(|| ThinClientError::Other("new_keypair: read_cap is None".to_string()))?; + let first_message_index = reply.first_message_index.ok_or_else(|| ThinClientError::Other("new_keypair: first_message_index is None".to_string()))?; 
+ + Ok((write_cap, read_cap, first_message_index)) + } + + /// Encrypts a read operation for a given read capability. + /// + /// This method prepares an encrypted read request that can be sent to the + /// courier service to retrieve a message from a pigeonhole box. + /// + /// # Arguments + /// * `read_cap` - Read capability that grants access to the channel + /// * `message_box_index` - Starting read position for the channel + /// + /// # Returns + /// * `Ok((message_ciphertext, next_message_index, envelope_descriptor, envelope_hash))` on success + /// * `Err(ThinClientError)` on failure + pub async fn encrypt_read( + &self, + read_cap: &[u8], + message_box_index: &[u8] + ) -> Result<(Vec, Vec, Vec, [u8; 32]), ThinClientError> { + let query_id = Self::new_query_id(); + + let request_inner = EncryptReadRequest { + query_id: query_id.clone(), + read_cap: read_cap.to_vec(), + message_box_index: message_box_index.to_vec(), + }; + + let request_value = serde_cbor::value::to_value(&request_inner) + .map_err(|e| ThinClientError::CborError(e))?; + + let mut request = BTreeMap::new(); + request.insert(Value::Text("encrypt_read".to_string()), request_value); + + let reply_map = self.send_and_wait_direct(query_id, request).await?; + + let reply: EncryptReadReply = serde_cbor::value::from_value(Value::Map(reply_map)) + .map_err(|e| ThinClientError::CborError(e))?; + + if reply.error_code != 0 { + return Err(ThinClientError::Other(format!("encrypt_read failed with error code: {}", reply.error_code))); + } + + let message_ciphertext = reply.message_ciphertext.ok_or_else(|| ThinClientError::Other("encrypt_read: message_ciphertext is None".to_string()))?; + let next_message_index = reply.next_message_index.ok_or_else(|| ThinClientError::Other("encrypt_read: next_message_index is None".to_string()))?; + let envelope_descriptor = reply.envelope_descriptor.ok_or_else(|| ThinClientError::Other("encrypt_read: envelope_descriptor is None".to_string()))?; + let envelope_hash_vec = 
reply.envelope_hash.ok_or_else(|| ThinClientError::Other("encrypt_read: envelope_hash is None".to_string()))?; + + let mut envelope_hash = [0u8; 32]; + envelope_hash.copy_from_slice(&envelope_hash_vec[..32]); + + Ok(( + message_ciphertext, + next_message_index, + envelope_descriptor, + envelope_hash + )) + } + + /// Encrypts a write operation for a given write capability. + /// + /// This method prepares an encrypted write request that can be sent to the + /// courier service to store a message in a pigeonhole box. + /// + /// # Plaintext Size Constraint + /// + /// The `plaintext` must not exceed `PigeonholeGeometry.max_plaintext_payload_length` bytes. + /// The daemon internally adds a 4-byte big-endian length prefix before padding and + /// encryption, so the actual wire format is `[4-byte length][plaintext][zero padding]`. + /// + /// If the plaintext exceeds the maximum size, the daemon will return + /// `ThinClientErrorInvalidRequest`. + /// + /// # Arguments + /// * `plaintext` - The plaintext message to encrypt. Must be at most + /// `PigeonholeGeometry.max_plaintext_payload_length` bytes. + /// * `write_cap` - Write capability that grants access to the channel. + /// * `message_box_index` - The message box index for this write operation. 
+ /// + /// # Returns + /// * `Ok((message_ciphertext, envelope_descriptor, envelope_hash))` on success + /// * `Err(ThinClientError)` on failure (including if plaintext is too large) + pub async fn encrypt_write( + &self, + plaintext: &[u8], + write_cap: &[u8], + message_box_index: &[u8] + ) -> Result<(Vec, Vec, [u8; 32]), ThinClientError> { + let query_id = Self::new_query_id(); + + let request_inner = EncryptWriteRequest { + query_id: query_id.clone(), + plaintext: plaintext.to_vec(), + write_cap: write_cap.to_vec(), + message_box_index: message_box_index.to_vec(), + }; + + let request_value = serde_cbor::value::to_value(&request_inner) + .map_err(|e| ThinClientError::CborError(e))?; + + let mut request = BTreeMap::new(); + request.insert(Value::Text("encrypt_write".to_string()), request_value); + + let reply_map = self.send_and_wait_direct(query_id, request).await?; + + let reply: EncryptWriteReply = serde_cbor::value::from_value(Value::Map(reply_map)) + .map_err(|e| ThinClientError::CborError(e))?; + + if reply.error_code != 0 { + return Err(ThinClientError::Other(format!("encrypt_write failed with error code: {}", reply.error_code))); + } + + let message_ciphertext = reply.message_ciphertext.ok_or_else(|| ThinClientError::Other("encrypt_write: message_ciphertext is None".to_string()))?; + let envelope_descriptor = reply.envelope_descriptor.ok_or_else(|| ThinClientError::Other("encrypt_write: envelope_descriptor is None".to_string()))?; + let envelope_hash_vec = reply.envelope_hash.ok_or_else(|| ThinClientError::Other("encrypt_write: envelope_hash is None".to_string()))?; + + let mut envelope_hash = [0u8; 32]; + envelope_hash.copy_from_slice(&envelope_hash_vec[..32]); + + Ok(( + message_ciphertext, + envelope_descriptor, + envelope_hash + )) + } + + /// Starts resending an encrypted message via ARQ (Automatic Repeat Request). 
+ /// + /// This method initiates automatic repeat request for an encrypted message, + /// which will be resent periodically until either a reply is received or + /// the operation is cancelled. + /// + /// # Arguments + /// * `read_cap` - Optional read capability (for read operations) + /// * `write_cap` - Optional write capability (for write operations) + /// * `next_message_index` - Optional next message index (for read operations) + /// * `reply_index` - Reply index for the operation (None for tombstone writes) + /// * `envelope_descriptor` - Envelope descriptor from encrypt_read/encrypt_write + /// * `message_ciphertext` - Encrypted message from encrypt_read/encrypt_write + /// * `envelope_hash` - Envelope hash from encrypt_read/encrypt_write + /// + /// # Returns + /// * `Ok(plaintext)` - For read operations, the decrypted plaintext message + /// (at most `PigeonholeGeometry.max_plaintext_payload_length` bytes). + /// For write operations, returns an empty vector on success. + /// * `Err(ThinClientError)` on failure + /// Sends an encrypted message via ARQ and blocks until completion. + /// + /// This method BLOCKS until a reply is received from the daemon. 
+ /// The message will be resent periodically until either: + /// - A successful response is received (plaintext for reads, ACK for writes) + /// - An error response is received from the daemon + /// - The operation is cancelled via cancel_resending_encrypted_message + pub async fn start_resending_encrypted_message( + &self, + read_cap: Option<&[u8]>, + write_cap: Option<&[u8]>, + next_message_index: Option<&[u8]>, + reply_index: Option, + envelope_descriptor: &[u8], + message_ciphertext: &[u8], + envelope_hash: &[u8; 32] + ) -> Result, ThinClientError> { + self.start_resending_encrypted_message_with_options( + read_cap, + write_cap, + next_message_index, + reply_index, + envelope_descriptor, + message_ciphertext, + envelope_hash, + false, + false, + ).await + } + + /// Like `start_resending_encrypted_message` but returns BoxAlreadyExists errors. + /// + /// Use this when you want to detect whether a write was actually performed + /// or if the box already existed. + /// + /// # Arguments + /// Same as `start_resending_encrypted_message` + /// + /// # Returns + /// * `Ok(plaintext)` on success + /// * `Err(ThinClientError::BoxAlreadyExists)` if the box already contains data + /// * `Err(ThinClientError)` on other failures + pub async fn start_resending_encrypted_message_return_box_exists( + &self, + read_cap: Option<&[u8]>, + write_cap: Option<&[u8]>, + next_message_index: Option<&[u8]>, + reply_index: Option, + envelope_descriptor: &[u8], + message_ciphertext: &[u8], + envelope_hash: &[u8; 32] + ) -> Result, ThinClientError> { + self.start_resending_encrypted_message_with_options( + read_cap, + write_cap, + next_message_index, + reply_index, + envelope_descriptor, + message_ciphertext, + envelope_hash, + false, + true, // no_idempotent_box_already_exists + ).await + } + + /// Like `start_resending_encrypted_message` but disables automatic retries on BoxIDNotFound. 
+ /// + /// Use this when you want immediate error feedback rather than waiting for + /// potential replication lag to resolve. + /// + /// # Arguments + /// Same as `start_resending_encrypted_message` + /// + /// # Returns + /// * `Ok(plaintext)` on success + /// * `Err(ThinClientError::BoxIdNotFound)` if the box does not exist (no automatic retries) + /// * `Err(ThinClientError)` on other failures + pub async fn start_resending_encrypted_message_no_retry( + &self, + read_cap: Option<&[u8]>, + write_cap: Option<&[u8]>, + next_message_index: Option<&[u8]>, + reply_index: Option, + envelope_descriptor: &[u8], + message_ciphertext: &[u8], + envelope_hash: &[u8; 32] + ) -> Result, ThinClientError> { + self.start_resending_encrypted_message_with_options( + read_cap, + write_cap, + next_message_index, + reply_index, + envelope_descriptor, + message_ciphertext, + envelope_hash, + true, // no_retry_on_box_id_not_found + false, + ).await + } + + /// Internal method with all options for start_resending_encrypted_message. 
+ async fn start_resending_encrypted_message_with_options( + &self, + read_cap: Option<&[u8]>, + write_cap: Option<&[u8]>, + next_message_index: Option<&[u8]>, + reply_index: Option, + envelope_descriptor: &[u8], + message_ciphertext: &[u8], + envelope_hash: &[u8; 32], + no_retry_on_box_id_not_found: bool, + no_idempotent_box_already_exists: bool, + ) -> Result, ThinClientError> { + let query_id = Self::new_query_id(); + + let request_inner = StartResendingEncryptedMessageRequest { + query_id: query_id.clone(), + read_cap: read_cap.map(|rc| rc.to_vec()), + write_cap: write_cap.map(|wc| wc.to_vec()), + next_message_index: next_message_index.map(|nmi| nmi.to_vec()), + reply_index, + envelope_descriptor: envelope_descriptor.to_vec(), + message_ciphertext: message_ciphertext.to_vec(), + envelope_hash: envelope_hash.to_vec(), + no_retry_on_box_id_not_found, + no_idempotent_box_already_exists, + }; + + let request_value = serde_cbor::value::to_value(&request_inner) + .map_err(|e| ThinClientError::CborError(e))?; + + let mut request = BTreeMap::new(); + request.insert(Value::Text("start_resending_encrypted_message".to_string()), request_value); + + // Use direct response routing (like Python's _send_and_wait) + // This blocks until the daemon sends a reply with matching query_id + let reply_map = self.send_and_wait_direct(query_id, request).await?; + + // Parse the reply + let reply: StartResendingEncryptedMessageReply = serde_cbor::value::from_value(Value::Map(reply_map)) + .map_err(|e| ThinClientError::CborError(e))?; + + debug!("start_resending_encrypted_message: received reply, error_code={}, plaintext_len={}", + reply.error_code, reply.plaintext.as_ref().map(|p| p.len()).unwrap_or(0)); + + if reply.error_code != 0 { + return Err(error_code_to_error(reply.error_code)); + } + + Ok(reply.plaintext.unwrap_or_default()) + } + + /// Cancels ARQ resending for an encrypted message. 
+ /// + /// This method stops the automatic repeat request for a previously started + /// encrypted message transmission. + /// + /// # Arguments + /// * `envelope_hash` - Hash of the courier envelope to cancel + /// + /// # Returns + /// * `Ok(())` on success + /// * `Err(ThinClientError)` on failure + pub async fn cancel_resending_encrypted_message(&self, envelope_hash: &[u8; 32]) -> Result<(), ThinClientError> { + let query_id = Self::new_query_id(); + + let request_inner = CancelResendingEncryptedMessageRequest { + query_id: query_id.clone(), + envelope_hash: envelope_hash.to_vec(), + }; + + let request_value = serde_cbor::value::to_value(&request_inner) + .map_err(|e| ThinClientError::CborError(e))?; + + let mut request = BTreeMap::new(); + request.insert(Value::Text("cancel_resending_encrypted_message".to_string()), request_value); + + let reply_map = self.send_and_wait_direct(query_id, request).await?; + + let reply: CancelResendingEncryptedMessageReply = serde_cbor::value::from_value(Value::Map(reply_map)) + .map_err(|e| ThinClientError::CborError(e))?; + + if reply.error_code != 0 { + return Err(ThinClientError::Other(format!("cancel_resending_encrypted_message failed with error code: {}", reply.error_code))); + } + + Ok(()) + } + + /// Increments a MessageBoxIndex using the BACAP NextIndex method. + /// + /// This method is used when sending multiple messages to different mailboxes using + /// the same WriteCap or ReadCap. 
It properly advances the cryptographic state by: + /// - Incrementing the Idx64 counter + /// - Deriving new encryption and blinding keys using HKDF + /// - Updating the HKDF state for the next iteration + /// + /// # Arguments + /// * `message_box_index` - Current message box index to increment + /// + /// # Returns + /// * `Ok(next_message_box_index)` - The incremented message box index + /// * `Err(ThinClientError)` on failure + pub async fn next_message_box_index(&self, message_box_index: &[u8]) -> Result, ThinClientError> { + let query_id = Self::new_query_id(); + + let request_inner = NextMessageBoxIndexRequest { + query_id: query_id.clone(), + message_box_index: message_box_index.to_vec(), + }; + + let request_value = serde_cbor::value::to_value(&request_inner) + .map_err(|e| ThinClientError::CborError(e))?; + + let mut request = BTreeMap::new(); + request.insert(Value::Text("next_message_box_index".to_string()), request_value); + + let reply_map = self.send_and_wait_direct(query_id, request).await?; + + let reply: NextMessageBoxIndexReply = serde_cbor::value::from_value(Value::Map(reply_map)) + .map_err(|e| ThinClientError::CborError(e))?; + + if reply.error_code != 0 { + return Err(ThinClientError::Other(format!("next_message_box_index failed with error code: {}", reply.error_code))); + } + + let next_index = reply.next_message_box_index.ok_or_else(|| ThinClientError::Other("next_message_box_index: next_message_box_index is None".to_string()))?; + Ok(next_index) + } + + /// Starts resending a copy command to a courier via ARQ. + /// + /// This method instructs a courier to read data from a temporary channel + /// (identified by the write_cap) and write it to the destination channel. + /// The command is automatically retransmitted until acknowledged. + /// + /// If courier_identity_hash and courier_queue_id are both provided, + /// the copy command is sent to that specific courier. Otherwise, a + /// random courier is selected. 
+ /// + /// # Arguments + /// * `write_cap` - Write capability for the temporary channel containing the data + /// * `courier_identity_hash` - Optional identity hash of a specific courier to use + /// * `courier_queue_id` - Optional queue ID for the specified courier + /// + /// # Returns + /// * `Ok(())` on success + /// * `Err(ThinClientError)` on failure + pub async fn start_resending_copy_command( + &self, + write_cap: &[u8], + courier_identity_hash: Option<&[u8]>, + courier_queue_id: Option<&[u8]> + ) -> Result<(), ThinClientError> { + let query_id = Self::new_query_id(); + + let request_inner = StartResendingCopyCommandRequest { + query_id: query_id.clone(), + write_cap: write_cap.to_vec(), + courier_identity_hash: courier_identity_hash.map(|h| h.to_vec()), + courier_queue_id: courier_queue_id.map(|q| q.to_vec()), + }; + + let request_value = serde_cbor::value::to_value(&request_inner) + .map_err(|e| ThinClientError::CborError(e))?; + + let mut request = BTreeMap::new(); + request.insert(Value::Text("start_resending_copy_command".to_string()), request_value); + + let reply_map = self.send_and_wait_direct(query_id, request).await?; + + let reply: StartResendingCopyCommandReply = serde_cbor::value::from_value(Value::Map(reply_map)) + .map_err(|e| ThinClientError::CborError(e))?; + + if reply.error_code != 0 { + return Err(ThinClientError::Other(format!("start_resending_copy_command failed with error code: {}", reply.error_code))); + } + + Ok(()) + } + + /// Cancels ARQ resending for a copy command. + /// + /// This method stops the automatic repeat request (ARQ) for a previously started + /// copy command. 
+ /// + /// # Arguments + /// * `write_cap_hash` - Hash of the WriteCap used in start_resending_copy_command + /// + /// # Returns + /// * `Ok(())` on success + /// * `Err(ThinClientError)` on failure + pub async fn cancel_resending_copy_command(&self, write_cap_hash: &[u8; 32]) -> Result<(), ThinClientError> { + let query_id = Self::new_query_id(); + + let request_inner = CancelResendingCopyCommandRequest { + query_id: query_id.clone(), + write_cap_hash: write_cap_hash.to_vec(), + }; + + let request_value = serde_cbor::value::to_value(&request_inner) + .map_err(|e| ThinClientError::CborError(e))?; + + let mut request = BTreeMap::new(); + request.insert(Value::Text("cancel_resending_copy_command".to_string()), request_value); + + let reply_map = self.send_and_wait_direct(query_id, request).await?; + + let reply: CancelResendingCopyCommandReply = serde_cbor::value::from_value(Value::Map(reply_map)) + .map_err(|e| ThinClientError::CborError(e))?; + + if reply.error_code != 0 { + return Err(ThinClientError::Other(format!("cancel_resending_copy_command failed with error code: {}", reply.error_code))); + } + + Ok(()) + } + + /// Creates multiple CourierEnvelopes from a payload of any size. + /// + /// The payload is automatically chunked and each chunk is wrapped in a + /// CourierEnvelope. Each returned chunk is a serialized CopyStreamElement + /// ready to be written to a box. + /// + /// Multiple calls can be made with the same stream_id to build up a stream + /// incrementally. The first call creates a new encoder (first element gets + /// IsStart=true). The final call should have is_last=true (last element + /// gets IsFinal=true). + /// + /// # Crash Recovery + /// + /// When `is_last=false`, the daemon buffers the last partial box's payload + /// internally so that subsequent writes can be packed efficiently. The + /// `buffer` in the result contains this buffered data which you should + /// persist for crash recovery. 
On restart, use `set_stream_buffer` to restore + /// the state before continuing the stream. + /// + /// # Arguments + /// * `stream_id` - 16-byte identifier for the encoder instance + /// * `payload` - The data to be encoded into courier envelopes + /// * `dest_write_cap` - Write capability for the destination channel + /// * `dest_start_index` - Starting index in the destination channel + /// * `is_last` - Whether this is the last payload in the sequence + /// + /// # Returns + /// * `Ok(CreateEnvelopesResult)` - Contains envelopes and buffer state for crash recovery + /// * `Err(ThinClientError)` on failure + pub async fn create_courier_envelopes_from_payload( + &self, + stream_id: &[u8; 16], + payload: &[u8], + dest_write_cap: &[u8], + dest_start_index: &[u8], + is_last: bool + ) -> Result { + let query_id = Self::new_query_id(); + + let request_inner = CreateCourierEnvelopesFromPayloadRequest { + query_id: query_id.clone(), + stream_id: stream_id.to_vec(), + payload: payload.to_vec(), + dest_write_cap: dest_write_cap.to_vec(), + dest_start_index: dest_start_index.to_vec(), + is_last, + }; + + let request_value = serde_cbor::value::to_value(&request_inner) + .map_err(|e| ThinClientError::CborError(e))?; + + let mut request = BTreeMap::new(); + request.insert(Value::Text("create_courier_envelopes_from_payload".to_string()), request_value); + + let reply_map = self.send_and_wait_direct(query_id, request).await?; + + let reply: CreateCourierEnvelopesFromPayloadReply = serde_cbor::value::from_value(Value::Map(reply_map)) + .map_err(|e| ThinClientError::CborError(e))?; + + if reply.error_code != 0 { + return Err(ThinClientError::Other(format!("create_courier_envelopes_from_payload failed with error code: {}", reply.error_code))); + } + + Ok(CreateEnvelopesResult { + envelopes: reply.envelopes.unwrap_or_default().into_iter().map(|b| b.into_vec()).collect(), + buffer: reply.buffer.unwrap_or_default(), + }) + } + + /// Creates CourierEnvelopes from multiple payloads 
going to different destinations. + /// + /// This is more space-efficient than calling create_courier_envelopes_from_payload + /// multiple times because envelopes from different destinations are packed + /// together in the copy stream without wasting space. + /// + /// # Crash Recovery + /// + /// When `is_last=false`, the daemon buffers the last partial box's payload + /// internally so that subsequent writes can be packed efficiently. The + /// `buffer` in the result contains this buffered data which you should + /// persist for crash recovery. On restart, use `set_stream_buffer` to restore + /// the state before continuing the stream. + /// + /// # Arguments + /// * `stream_id` - 16-byte identifier for the encoder instance + /// * `destinations` - List of (payload, write_cap, start_index) tuples + /// * `is_last` - Whether this is the last set of payloads in the sequence + /// + /// # Returns + /// * `Ok(CreateEnvelopesResult)` - Contains envelopes and buffer state for crash recovery + /// * `Err(ThinClientError)` on failure + pub async fn create_courier_envelopes_from_multi_payload( + &self, + stream_id: &[u8; 16], + destinations: Vec<(&[u8], &[u8], &[u8])>, + is_last: bool + ) -> Result { + let query_id = Self::new_query_id(); + + let destinations_inner: Vec = destinations + .into_iter() + .map(|(payload, write_cap, start_index)| EnvelopeDestination { + payload: payload.to_vec(), + write_cap: write_cap.to_vec(), + start_index: start_index.to_vec(), + }) + .collect(); + + let request_inner = CreateCourierEnvelopesFromPayloadsRequest { + query_id: query_id.clone(), + stream_id: stream_id.to_vec(), + destinations: destinations_inner, + is_last, + }; + + let request_value = serde_cbor::value::to_value(&request_inner) + .map_err(|e| ThinClientError::CborError(e))?; + + let mut request = BTreeMap::new(); + request.insert(Value::Text("create_courier_envelopes_from_multi_payload".to_string()), request_value); + + let reply_map = self.send_and_wait_direct(query_id, 
request).await?; + + let reply: CreateCourierEnvelopesFromPayloadsReply = serde_cbor::value::from_value(Value::Map(reply_map)) + .map_err(|e| ThinClientError::CborError(e))?; + + if reply.error_code != 0 { + return Err(ThinClientError::Other(format!("create_courier_envelopes_from_multi_payload failed with error code: {}", reply.error_code))); + } + + Ok(CreateEnvelopesResult { + envelopes: reply.envelopes.unwrap_or_default().into_iter().map(|b| b.into_vec()).collect(), + buffer: reply.buffer.unwrap_or_default(), + }) + } + + /// Generates a new random 16-byte stream ID. + pub fn new_stream_id() -> [u8; 16] { + let mut stream_id = [0u8; 16]; + rand::thread_rng().fill_bytes(&mut stream_id); + stream_id + } + + /// Restores the buffered state for a given stream ID. + /// + /// This is useful for crash recovery: after restart, call this method with the + /// buffer that was returned by `create_courier_envelopes_from_payload` or + /// `create_courier_envelopes_from_multi_payload` before the crash/shutdown. + /// + /// Note: This will create a new encoder if one doesn't exist for this stream_id, + /// or replace the buffer contents if one already exists. 
+ /// + /// # Arguments + /// * `stream_id` - 16-byte identifier for the encoder instance + /// * `buffer` - The buffered data to restore (from `CreateEnvelopesResult.buffer`) + /// + /// # Returns + /// * `Ok(())` on success + /// * `Err(ThinClientError)` on failure + /// + /// # Example + /// ```ignore + /// // During streaming, save the buffer from each call + /// let result = client.create_courier_envelopes_from_payload(&stream_id, data, ..., false).await?; + /// save_to_disk(&stream_id, &result.buffer)?; + /// + /// // On restart, restore the stream state + /// let buffer = load_from_disk(&stream_id)?; + /// client.set_stream_buffer(&stream_id, buffer).await?; + /// // Now continue streaming from where we left off + /// client.create_courier_envelopes_from_payload(&stream_id, more_data, ..., true).await?; + /// ``` + pub async fn set_stream_buffer( + &self, + stream_id: &[u8; 16], + buffer: Vec, + ) -> Result<(), ThinClientError> { + let query_id = Self::new_query_id(); + + let request_inner = SetStreamBufferRequest { + query_id: query_id.clone(), + stream_id: stream_id.to_vec(), + buffer, + }; + + let request_value = serde_cbor::value::to_value(&request_inner) + .map_err(|e| ThinClientError::CborError(e))?; + + let mut request = BTreeMap::new(); + request.insert(Value::Text("set_stream_buffer".to_string()), request_value); + + let reply_map = self.send_and_wait_direct(query_id, request).await?; + + let reply: SetStreamBufferReply = serde_cbor::value::from_value(Value::Map(reply_map)) + .map_err(|e| ThinClientError::CborError(e))?; + + if reply.error_code != 0 { + return Err(ThinClientError::Other(format!( + "set_stream_buffer failed with error code: {}", reply.error_code + ))); + } + + Ok(()) + } + + /// Create an encrypted tombstone for a single pigeonhole box. + /// + /// This method creates an encrypted zero-filled payload for overwriting + /// the specified box. 
The caller must send the returned values via + /// start_resending_encrypted_message to complete the tombstone operation. + /// + /// # Arguments + /// * `geometry` - Pigeonhole geometry defining payload size + /// * `write_cap` - Write capability for the box + /// * `box_index` - Index of the box to tombstone + /// + /// # Returns + /// * `Ok((ciphertext, envelope_descriptor, envelope_hash))` on success + /// * `Err(ThinClientError)` on failure + /// Create a tombstone for a single pigeonhole box. + /// + /// This method creates a tombstone (empty payload with signature) for deleting + /// the specified box. The caller must send the returned values via + /// `start_resending_encrypted_message` to complete the tombstone operation. + /// + /// # Arguments + /// * `write_cap` - Write capability for the box + /// * `box_index` - Index of the box to tombstone + /// + /// # Returns + /// * `Ok((ciphertext, envelope_descriptor, envelope_hash))` on success + /// * `Err(ThinClientError)` on failure + pub async fn tombstone_box( + &self, + write_cap: &[u8], + box_index: &[u8] + ) -> Result<(Vec, Vec, Vec), ThinClientError> { + // Tombstones are created by sending an empty plaintext to encrypt_write + // The daemon will detect this and sign an empty payload instead of encrypting + let (ciphertext, env_desc, env_hash) = self + .encrypt_write(&[], write_cap, box_index).await?; + + Ok((ciphertext, env_desc, env_hash.to_vec())) + } +} + +/// A single tombstone envelope ready to be sent. +#[derive(Debug, Clone)] +pub struct TombstoneEnvelope { + /// The encrypted tombstone payload. + pub message_ciphertext: Vec, + /// The envelope descriptor. + pub envelope_descriptor: Vec, + /// The envelope hash for cancellation. + pub envelope_hash: Vec, + /// The box index this envelope is for. + pub box_index: Vec, +} + +/// Result of a tombstone_range operation. +#[derive(Debug)] +pub struct TombstoneRangeResult { + /// List of tombstone envelopes ready to be sent. 
+ pub envelopes: Vec, + /// The next MessageBoxIndex after the last processed. + pub next: Vec, + /// Error message if the operation failed partway through. + pub error: Option, +} + +impl ThinClient { + /// Create tombstones for a range of pigeonhole boxes. + /// + /// This method creates tombstones for up to max_count boxes, + /// starting from the specified box index and advancing through consecutive + /// indices. The caller must send each envelope via start_resending_encrypted_message + /// to complete the tombstone operations. + /// + /// If an error occurs during the operation, a partial result is returned + /// containing the envelopes created so far and the next index. + /// + /// # Arguments + /// * `write_cap` - Write capability for the boxes + /// * `start` - Starting MessageBoxIndex + /// * `max_count` - Maximum number of boxes to tombstone + /// + /// # Returns + /// * `TombstoneRangeResult` containing the envelopes and next index + pub async fn tombstone_range( + &self, + write_cap: &[u8], + start: &[u8], + max_count: u32 + ) -> TombstoneRangeResult { + if max_count == 0 { + return TombstoneRangeResult { + envelopes: Vec::new(), + next: start.to_vec(), + error: None, + }; + } + + let mut cur = start.to_vec(); + let mut envelopes: Vec = Vec::with_capacity(max_count as usize); + + while (envelopes.len() as u32) < max_count { + match self.tombstone_box(write_cap, &cur).await { + Ok((ciphertext, env_desc, env_hash)) => { + envelopes.push(TombstoneEnvelope { + message_ciphertext: ciphertext, + envelope_descriptor: env_desc, + envelope_hash: env_hash, + box_index: cur.clone(), + }); + } + Err(e) => { + let count = envelopes.len(); + return TombstoneRangeResult { + envelopes, + next: cur, + error: Some(format!("Error creating tombstone at index {}: {:?}", count, e)), + }; + } + } + + match self.next_message_box_index(&cur).await { + Ok(next) => cur = next, + Err(e) => { + return TombstoneRangeResult { + envelopes, + next: cur, + error: Some(format!("Error 
getting next index after creating tombstone: {:?}", e)), + }; + } + } + } + + TombstoneRangeResult { + envelopes, + next: cur, + error: None, + } + } +} \ No newline at end of file diff --git a/testdata/thinclient.toml b/testdata/thinclient.toml index 06bab92..513a2a5 100644 --- a/testdata/thinclient.toml +++ b/testdata/thinclient.toml @@ -18,10 +18,10 @@ Address = "localhost:64331" KEMName = "" [PigeonholeGeometry] - BoxPayloadLength = 1556 - CourierQueryReadLength = 360 + MaxPlaintextPayloadLength = 1553 + CourierQueryReadLength = 359 CourierQueryWriteLength = 2000 - CourierQueryReplyReadLength = 1701 + CourierQueryReplyReadLength = 1698 CourierQueryReplyWriteLength = 50 NIKEName = "CTIDH1024-X25519" SignatureSchemeName = "Ed25519" diff --git a/tests/channel_api_test.rs b/tests/channel_api_test.rs index dfb56fc..3984e72 100644 --- a/tests/channel_api_test.rs +++ b/tests/channel_api_test.rs @@ -1,16 +1,29 @@ // SPDX-FileCopyrightText: Copyright (C) 2025 David Stainton // SPDX-License-Identifier: AGPL-3.0-only -//! Channel API integration tests for the Rust thin client -//! -//! These tests mirror the Go tests in courier_docker_test.go and require -//! a running mixnet with client daemon for integration testing. +//! NEW Pigeonhole API integration tests for the Rust thin client +//! +//! These tests verify the NEW Pigeonhole API: +//! 1. new_keypair - Generate WriteCap and ReadCap from seed +//! 2. encrypt_read - Encrypt a read operation +//! 3. encrypt_write - Encrypt a write operation +//! 4. start_resending_encrypted_message - Send encrypted message with ARQ +//! 5. cancel_resending_encrypted_message - Cancel ARQ for a message +//! 6. next_message_box_index - Increment MessageBoxIndex for multiple messages +//! 7. start_resending_copy_command - Send copy command via ARQ +//! 8. cancel_resending_copy_command - Cancel copy command ARQ +//! 9. create_courier_envelopes_from_payload - Chunk payload into courier envelopes +//! 10. 
create_courier_envelopes_from_multi_payload - Chunk multiple payloads efficiently +//! +//! Helper functions: +//! - tombstone_box - Create a tombstone (empty payload with valid signature) +//! - tombstone_range - Create tombstones for a range of boxes +//! +//! These tests require a running mixnet with client daemon for integration testing. use std::time::Duration; use katzenpost_thin_client::{ThinClient, Config}; - - /// Test helper to setup a thin client for integration tests async fn setup_thin_client() -> Result, Box> { let config = Config::new("testdata/thinclient.toml")?; @@ -22,798 +35,596 @@ async fn setup_thin_client() -> Result, Box Result<(), Box> { - let alice_thin_client = setup_thin_client().await?; - let bob_thin_client = setup_thin_client().await?; - - // Wait for PKI documents to be available and connection to mixnet - println!("Waiting for daemon to connect to mixnet..."); - let mut attempts = 0; - while !alice_thin_client.is_connected() && attempts < 30 { - tokio::time::sleep(Duration::from_secs(1)).await; - attempts += 1; - } +async fn test_new_keypair_basic() { + println!("\n=== Test: new_keypair basic functionality ==="); - if !alice_thin_client.is_connected() { - return Err("Daemon failed to connect to mixnet within 30 seconds".into()); - } + let client = setup_thin_client().await.expect("Failed to setup client"); - println!("✅ Daemon connected to mixnet, using current PKI document"); - - // Alice creates write channel - println!("Alice: Creating write channel"); - let (alice_channel_id, read_cap, _write_cap) = alice_thin_client.create_write_channel().await?; - println!("Alice: Created write channel {}", alice_channel_id); - - // Bob creates read channel using the read capability from Alice's write channel - println!("Bob: Creating read channel"); - let bob_channel_id = bob_thin_client.create_read_channel(read_cap).await?; - println!("Bob: Created read channel {}", bob_channel_id); - - // Alice writes first message - let original_message = 
b"hello1"; - println!("Alice: Writing first message and waiting for completion"); - - let write_reply1 = alice_thin_client.write_channel(alice_channel_id, original_message).await?; - println!("Alice: Write operation completed successfully"); - - // Get the courier service from PKI - let courier_service = alice_thin_client.get_service("courier").await?; - let (dest_node, dest_queue) = courier_service.to_destination(); - - let alice_message_id1 = ThinClient::new_message_id(); - - let _reply1 = alice_thin_client.send_channel_query_await_reply( - alice_channel_id, - &write_reply1.send_message_payload, - dest_node.clone(), - dest_queue.clone(), - alice_message_id1 - ).await?; - - // Alice writes a second message - let second_message = b"hello2"; - println!("Alice: Writing second message and waiting for completion"); - - let write_reply2 = alice_thin_client.write_channel(alice_channel_id, second_message).await?; - println!("Alice: Second write operation completed successfully"); - - let alice_message_id2 = ThinClient::new_message_id(); - - let _reply2 = alice_thin_client.send_channel_query_await_reply( - alice_channel_id, - &write_reply2.send_message_payload, - dest_node.clone(), - dest_queue.clone(), - alice_message_id2 - ).await?; - - // Wait for message propagation to storage replicas - println!("Waiting for message propagation to storage replicas"); - tokio::time::sleep(Duration::from_secs(10)).await; - - // Bob reads first message - println!("Bob: Reading first message"); - let read_reply1 = bob_thin_client.read_channel(bob_channel_id, None, None).await?; - - let bob_message_id1 = ThinClient::new_message_id(); - - // In a real implementation, you'd retry the SendChannelQueryAwaitReply until you get a response - let mut bob_reply_payload1 = vec![]; - for i in 0..10 { - match alice_thin_client.send_channel_query_await_reply( - bob_channel_id, - &read_reply1.send_message_payload, - dest_node.clone(), - dest_queue.clone(), - bob_message_id1.clone() - ).await { - 
Ok(payload) if !payload.is_empty() => { - bob_reply_payload1 = payload; - break; - } - Ok(_) => { - println!("Bob: Read attempt {} returned empty payload, retrying...", i + 1); - tokio::time::sleep(Duration::from_millis(500)).await; - } - Err(e) => return Err(e.into()), - } - } + // Generate a random 32-byte seed + let seed: [u8; 32] = rand::random(); - assert_eq!(original_message, bob_reply_payload1.as_slice(), "Bob: Reply payload mismatch"); - - // Bob reads second message - println!("Bob: Reading second message"); - let read_reply2 = bob_thin_client.read_channel(bob_channel_id, None, None).await?; - - let bob_message_id2 = ThinClient::new_message_id(); - let mut bob_reply_payload2 = vec![]; - - for i in 0..10 { - println!("Bob: second read attempt {}", i + 1); - match alice_thin_client.send_channel_query_await_reply( - bob_channel_id, - &read_reply2.send_message_payload, - dest_node.clone(), - dest_queue.clone(), - bob_message_id2.clone() - ).await { - Ok(payload) if !payload.is_empty() => { - bob_reply_payload2 = payload; - break; - } - Ok(_) => { - tokio::time::sleep(Duration::from_millis(500)).await; - } - Err(e) => return Err(e.into()), - } + // Create a new keypair + let result = client.new_keypair(&seed).await; + if let Err(ref e) = result { + println!("new_keypair error: {:?}", e); } + assert!(result.is_ok(), "new_keypair should succeed: {:?}", result.err()); - assert_eq!(second_message, bob_reply_payload2.as_slice(), "Bob: Second reply payload mismatch"); + let (write_cap, read_cap, first_index) = result.unwrap(); - // Clean up channels - alice_thin_client.close_channel(alice_channel_id).await?; - bob_thin_client.close_channel(bob_channel_id).await?; + // Verify we got non-empty capabilities + assert!(!write_cap.is_empty(), "WriteCap should not be empty"); + assert!(!read_cap.is_empty(), "ReadCap should not be empty"); + assert!(!first_index.is_empty(), "First message index should not be empty"); - alice_thin_client.stop().await; - 
bob_thin_client.stop().await; + println!("✓ Created keypair successfully"); + println!(" WriteCap length: {}", write_cap.len()); + println!(" ReadCap length: {}", read_cap.len()); + println!(" First index length: {}", first_index.len()); +} - println!("✅ Channel API basics test completed successfully"); - Ok(()) +#[tokio::test] +async fn test_alice_sends_bob_complete_workflow() { + println!("\n=== Test: Complete Alice sends to Bob workflow ==="); + + let alice_client = setup_thin_client().await.expect("Failed to setup Alice client"); + let bob_client = setup_thin_client().await.expect("Failed to setup Bob client"); + + // Alice creates a keypair + let alice_seed: [u8; 32] = rand::random(); + let (alice_write_cap, bob_read_cap, first_index) = alice_client.new_keypair(&alice_seed).await + .expect("Failed to create Alice's keypair"); + println!("✓ Alice created keypair"); + + // Alice encrypts and sends a message + let message = b"Hello Bob, this is Alice!"; + let (ciphertext, env_desc, env_hash) = alice_client + .encrypt_write(message, &alice_write_cap, &first_index).await + .expect("Failed to encrypt write"); + println!("✓ Alice encrypted message"); + + // Alice starts resending the encrypted message + let _alice_plaintext = alice_client.start_resending_encrypted_message( + None, + Some(&alice_write_cap), + None, + Some(0), + &env_desc, + &ciphertext, + &env_hash + ).await.expect("Failed to start resending"); + + println!("✓ Alice sent message via ARQ"); + + // Wait for message propagation + println!("Waiting for message propagation..."); + tokio::time::sleep(Duration::from_secs(5)).await; + + // Bob encrypts a read operation + let (bob_ciphertext, bob_next_index, bob_env_desc, bob_env_hash) = bob_client + .encrypt_read(&bob_read_cap, &first_index).await + .expect("Failed to encrypt read"); + println!("✓ Bob encrypted read operation"); + + // Bob starts resending to retrieve the message + let bob_plaintext = bob_client.start_resending_encrypted_message( + 
Some(&bob_read_cap), + None, + Some(&bob_next_index), + Some(0), + &bob_env_desc, + &bob_ciphertext, + &bob_env_hash + ).await.expect("Failed to retrieve message"); + + println!("✓ Bob received message"); + + // Verify the message matches + assert_eq!(bob_plaintext, message, "Bob should receive Alice's message"); + + println!("✅ Complete workflow test passed!"); + println!(" Message sent: {:?}", String::from_utf8_lossy(message)); + println!(" Message received: {:?}", String::from_utf8_lossy(&bob_plaintext)); } -/// Test resuming a write channel - equivalent to TestResumeWriteChannel from Go -/// This test demonstrates the write channel resumption workflow: -/// 1. Create a write channel -/// 2. Write the first message onto the channel -/// 3. Close the channel -/// 4. Resume the channel -/// 5. Write the second message onto the channel -/// 6. Create a read channel -/// 7. Read first and second message from the channel -/// 8. Verify payloads match #[tokio::test] -async fn test_resume_write_channel() -> Result<(), Box> { - let alice_thin_client = setup_thin_client().await?; - let bob_thin_client = setup_thin_client().await?; - - // Wait for PKI documents to be available and connection to mixnet - println!("Waiting for daemon to connect to mixnet..."); - let mut attempts = 0; - while !alice_thin_client.is_connected() && attempts < 30 { - tokio::time::sleep(Duration::from_secs(1)).await; - attempts += 1; - } +async fn test_next_message_box_index() { + println!("\n=== Test: next_message_box_index ==="); - if !alice_thin_client.is_connected() { - return Err("Daemon failed to connect to mixnet within 30 seconds".into()); - } + let client = setup_thin_client().await.expect("Failed to setup client"); - println!("✅ Daemon connected to mixnet, using current PKI document"); - - // Alice creates write channel - println!("Alice: Creating write channel"); - let (alice_channel_id, read_cap, write_cap) = alice_thin_client.create_write_channel().await?; - println!("Alice: Created 
write channel {}", alice_channel_id); - - // Alice writes first message - let alice_payload1 = b"Hello, Bob!"; - println!("Alice: Writing first message"); - let write_reply1 = alice_thin_client.write_channel(alice_channel_id, alice_payload1).await?; - - // Get courier destination - let (dest_node, dest_queue) = alice_thin_client.get_courier_destination().await?; - let alice_message_id1 = ThinClient::new_message_id(); - - // Send first message - let _reply1 = alice_thin_client.send_channel_query_await_reply( - alice_channel_id, - &write_reply1.send_message_payload, - dest_node.clone(), - dest_queue.clone(), - alice_message_id1 - ).await?; - - println!("Waiting for first message propagation to storage replicas"); - tokio::time::sleep(Duration::from_secs(3)).await; - - // Close the channel - alice_thin_client.close_channel(alice_channel_id).await?; - - // Resume the write channel - println!("Alice: Resuming write channel"); - let alice_channel_id = alice_thin_client.resume_write_channel( - write_cap, - Some(write_reply1.next_message_index) - ).await?; - println!("Alice: Resumed write channel with ID {}", alice_channel_id); - - // Write second message after resume - println!("Alice: Writing second message after resume"); - let alice_payload2 = b"Second message from Alice!"; - let write_reply2 = - alice_thin_client.write_channel(alice_channel_id, alice_payload2).await?; - - let alice_message_id2 = ThinClient::new_message_id(); - let _reply2 = alice_thin_client.send_channel_query_await_reply( - alice_channel_id, - &write_reply2.send_message_payload, - dest_node.clone(), - dest_queue.clone(), - alice_message_id2 - ).await?; - println!("Alice: Second write operation completed successfully"); - - println!("Waiting for second message propagation to storage replicas"); - tokio::time::sleep(Duration::from_secs(3)).await; - - // Bob creates read channel - println!("Bob: Creating read channel"); - let bob_channel_id = bob_thin_client.create_read_channel(read_cap).await?; - 
println!("Bob: Created read channel {}", bob_channel_id); - - // Bob reads first message - println!("Bob: Reading first message"); - let read_reply1 = - bob_thin_client.read_channel(bob_channel_id, None, None).await?; - - let bob_message_id1 = ThinClient::new_message_id(); - let mut bob_reply_payload1 = vec![]; - - for i in 0..10 { - match alice_thin_client.send_channel_query_await_reply( - bob_channel_id, - &read_reply1.send_message_payload, - dest_node.clone(), - dest_queue.clone(), - bob_message_id1.clone() - ).await { - Ok(payload) if !payload.is_empty() => { - bob_reply_payload1 = payload; - break; - } - Ok(_) => { - println!("Bob: First read attempt {} returned empty payload, retrying...", i + 1); - tokio::time::sleep(Duration::from_millis(500)).await; - } - Err(e) => return Err(e.into()), - } - } + // Generate keypair to get a first_index + let seed: [u8; 32] = rand::random(); + let (_write_cap, _read_cap, first_index) = client.new_keypair(&seed).await + .expect("Failed to create keypair"); - assert_eq!(alice_payload1, bob_reply_payload1.as_slice(), "Bob: First message payload mismatch"); - - // Bob reads second message - println!("Bob: Reading second message"); - let read_reply2 = - bob_thin_client.read_channel(bob_channel_id, None, None).await?; - - let bob_message_id2 = ThinClient::new_message_id(); - let mut bob_reply_payload2 = vec![]; - - for i in 0..10 { - println!("Bob: second message read attempt {}", i + 1); - match alice_thin_client.send_channel_query_await_reply( - bob_channel_id, - &read_reply2.send_message_payload, - dest_node.clone(), - dest_queue.clone(), - bob_message_id2.clone() - ).await { - Ok(payload) if !payload.is_empty() => { - bob_reply_payload2 = payload; - break; - } - Ok(_) => { - tokio::time::sleep(Duration::from_millis(500)).await; - } - Err(e) => return Err(e.into()), - } - } + println!("✓ Created keypair"); + println!(" First index length: {}", first_index.len()); - // Verify the second message content matches - 
assert_eq!(alice_payload2, bob_reply_payload2.as_slice(), "Bob: Second message payload mismatch"); - println!("Bob: Successfully received and verified second message"); + // Increment the index + let second_index = client.next_message_box_index(&first_index).await + .expect("Failed to get next message box index"); - // Clean up channels - alice_thin_client.close_channel(alice_channel_id).await?; - bob_thin_client.close_channel(bob_channel_id).await?; + assert!(!second_index.is_empty(), "Second index should not be empty"); + assert_ne!(first_index, second_index, "Second index should differ from first"); + println!("✓ Got second index (length: {})", second_index.len()); - alice_thin_client.stop().await; - bob_thin_client.stop().await; + // Increment again + let third_index = client.next_message_box_index(&second_index).await + .expect("Failed to get third message box index"); - println!("✅ Resume write channel test completed successfully"); - Ok(()) + assert!(!third_index.is_empty(), "Third index should not be empty"); + assert_ne!(second_index, third_index, "Third index should differ from second"); + println!("✓ Got third index (length: {})", third_index.len()); + + println!("✅ next_message_box_index test passed!"); } -/// Test resuming a write channel with query state - equivalent to TestResumeWriteChannelQuery from Go -/// This test demonstrates the write channel query resumption workflow: -/// 1. Create write channel -/// 2. Create first write query message but do not send to channel yet -/// 3. Close channel -/// 4. Resume write channel with query via ResumeWriteChannelQuery -/// 5. Send resumed write query to channel -/// 6. Send second message to channel -/// 7. Create read channel -/// 8. Read both messages from channel -/// 9. 
Verify payloads match #[tokio::test] -async fn test_resume_write_channel_query() -> Result<(), Box> { - let alice_thin_client = setup_thin_client().await?; - let bob_thin_client = setup_thin_client().await?; - - // Wait for PKI documents to be available and connection to mixnet - println!("Waiting for daemon to connect to mixnet..."); - let mut attempts = 0; - while !alice_thin_client.is_connected() && attempts < 30 { - tokio::time::sleep(Duration::from_secs(1)).await; - attempts += 1; - } - - if !alice_thin_client.is_connected() { - return Err("Daemon failed to connect to mixnet within 30 seconds".into()); +async fn test_create_courier_envelopes_from_payload() { + println!("\n=== Test: create_courier_envelopes_from_payload with Copy Command ==="); + + let alice_client = setup_thin_client().await.expect("Failed to setup Alice client"); + let bob_client = setup_thin_client().await.expect("Failed to setup Bob client"); + + // Step 1: Alice creates destination channel + println!("\n--- Step 1: Creating destination channel ---"); + let dest_seed: [u8; 32] = rand::random(); + let (dest_write_cap, dest_read_cap, dest_first_index) = alice_client.new_keypair(&dest_seed).await + .expect("Failed to create destination keypair"); + println!("✓ Alice created destination channel"); + + // Step 2: Alice creates temporary copy stream channel + println!("\n--- Step 2: Creating temporary copy stream channel ---"); + let temp_seed: [u8; 32] = rand::random(); + let (temp_write_cap, _temp_read_cap, temp_first_index) = alice_client.new_keypair(&temp_seed).await + .expect("Failed to create temp keypair"); + println!("✓ Alice created temporary copy stream channel"); + + // Step 3: Create a payload with length prefix (like Go/Python tests) + println!("\n--- Step 3: Creating payload ---"); + let random_data: Vec = (0..100).map(|_| rand::random::()).collect(); + let mut large_payload = Vec::new(); + large_payload.extend_from_slice(&(random_data.len() as u32).to_be_bytes()); + 
large_payload.extend_from_slice(&random_data); + println!("✓ Alice created payload ({} bytes)", large_payload.len()); + + // Step 4: Create copy stream chunks from the payload + println!("\n--- Step 4: Creating copy stream chunks ---"); + let stream_id = ThinClient::new_stream_id(); + let copy_stream_result = alice_client.create_courier_envelopes_from_payload( + &stream_id, + &large_payload, + &dest_write_cap, + &dest_first_index, + true // is_last + ).await.expect("Failed to create courier envelopes from payload"); + + assert!(!copy_stream_result.envelopes.is_empty(), "Should have at least one chunk"); + println!("✓ Alice created {} copy stream chunks", copy_stream_result.envelopes.len()); + + // Step 5: Write all copy stream chunks to the temporary channel + println!("\n--- Step 5: Writing copy stream chunks to temp channel ---"); + let mut temp_index = temp_first_index.clone(); + for (i, chunk) in copy_stream_result.envelopes.iter().enumerate() { + let (ciphertext, env_desc, env_hash) = alice_client + .encrypt_write(chunk, &temp_write_cap, &temp_index).await + .expect("Failed to encrypt chunk"); + + let _ = alice_client.start_resending_encrypted_message( + None, + Some(&temp_write_cap), + None, + Some(0), + &env_desc, + &ciphertext, + &env_hash + ).await.expect("Failed to send chunk via ARQ"); + + println!(" ✓ Wrote chunk {} ({} bytes)", i + 1, chunk.len()); + + // Advance to next index for next chunk + temp_index = alice_client.next_message_box_index(&temp_index).await + .expect("Failed to get next index"); } - println!("✅ Daemon connected to mixnet, using current PKI document"); - - // Alice creates write channel - println!("Alice: Creating write channel"); - let (alice_channel_id, read_cap, write_cap) = alice_thin_client.create_write_channel().await?; - println!("Alice: Created write channel {}", alice_channel_id); - - // Alice prepares first message but doesn't send it yet - let alice_payload1 = b"Hello, Bob!"; - let write_reply = 
alice_thin_client.write_channel(alice_channel_id, alice_payload1).await?; - - // Get courier destination - let (courier_node, courier_queue_id) = alice_thin_client.get_courier_destination().await?; - let alice_message_id1 = ThinClient::new_message_id(); - - // Close the channel immediately (like in Go test - no waiting for propagation) - alice_thin_client.close_channel(alice_channel_id).await?; - - // Resume the write channel with query state using current_message_index like Go test - println!("Alice: Resuming write channel"); - let alice_channel_id = alice_thin_client.resume_write_channel_query( - write_cap, - write_reply.current_message_index, // Use current_message_index like in Go test - write_reply.envelope_descriptor, - write_reply.envelope_hash - ).await?; - println!("Alice: Resumed write channel with ID {}", alice_channel_id); - - // Send the first message after resume - println!("Alice: Writing first message after resume"); - let _reply1 = alice_thin_client.send_channel_query_await_reply( - alice_channel_id, - &write_reply.send_message_payload, - courier_node.clone(), - courier_queue_id.clone(), - alice_message_id1 - ).await?; - - // Write second message - println!("Alice: Writing second message"); - let alice_payload2 = b"Second message from Alice!"; - let write_reply2 = - alice_thin_client.write_channel(alice_channel_id, alice_payload2).await?; - - let alice_message_id2 = ThinClient::new_message_id(); - let _reply2 = alice_thin_client.send_channel_query_await_reply( - alice_channel_id, - &write_reply2.send_message_payload, - courier_node.clone(), - courier_queue_id.clone(), - alice_message_id2 - ).await?; - println!("Alice: Second write operation completed successfully"); - - println!("Waiting for second message propagation to storage replicas"); - tokio::time::sleep(Duration::from_secs(3)).await; - - // Bob creates read channel - println!("Bob: Creating read channel"); - let bob_channel_id = bob_thin_client.create_read_channel(read_cap).await?; - 
println!("Bob: Created read channel {}", bob_channel_id); - - // Bob reads first message - println!("Bob: Reading first message"); - let read_reply1 = - bob_thin_client.read_channel(bob_channel_id, None, None).await?; - - let bob_message_id1 = ThinClient::new_message_id(); - let mut bob_reply_payload1 = vec![]; - - for i in 0..10 { - match alice_thin_client.send_channel_query_await_reply( - bob_channel_id, - &read_reply1.send_message_payload, - courier_node.clone(), - courier_queue_id.clone(), - bob_message_id1.clone() - ).await { - Ok(payload) if !payload.is_empty() => { - bob_reply_payload1 = payload; - break; - } - Ok(_) => { - println!("Bob: First read attempt {} returned empty payload, retrying...", i + 1); - tokio::time::sleep(Duration::from_millis(500)).await; - } - Err(e) => return Err(e.into()), - } - } + // Wait for chunks to propagate + println!("\n--- Waiting for copy stream chunks to propagate (30 seconds) ---"); + tokio::time::sleep(Duration::from_secs(30)).await; + + // Step 6: Send Copy command to courier + println!("\n--- Step 6: Sending Copy command to courier via ARQ ---"); + alice_client.start_resending_copy_command(&temp_write_cap, None, None).await + .expect("Failed to send copy command"); + println!("✓ Alice copy command completed"); + + // Wait for copy command to execute + println!("\n--- Waiting for copy command to execute (30 seconds) ---"); + tokio::time::sleep(Duration::from_secs(30)).await; + + // Step 7: Bob reads from destination channel + println!("\n--- Step 7: Bob reads from destination channel ---"); + let (bob_ciphertext, bob_next_index, bob_env_desc, bob_env_hash) = bob_client + .encrypt_read(&dest_read_cap, &dest_first_index).await + .expect("Failed to encrypt read"); + + let bob_plaintext = bob_client.start_resending_encrypted_message( + Some(&dest_read_cap), + None, + Some(&bob_next_index), + Some(0), + &bob_env_desc, + &bob_ciphertext, + &bob_env_hash + ).await.expect("Failed to retrieve message"); + + println!("✓ Bob 
received {} bytes", bob_plaintext.len()); + + // Verify the payload matches + assert_eq!(bob_plaintext, large_payload, "Received payload should match original"); + + println!("✅ create_courier_envelopes_from_payload test passed!"); +} - assert_eq!(alice_payload1, bob_reply_payload1.as_slice(), "Bob: First message payload mismatch"); - - // Bob reads second message - println!("Bob: Reading second message"); - let read_reply2 = - bob_thin_client.read_channel(bob_channel_id, None, None).await?; - - let bob_message_id2 = ThinClient::new_message_id(); - let mut bob_reply_payload2 = vec![]; - - for i in 0..10 { - println!("Bob: second message read attempt {}", i + 1); - match alice_thin_client.send_channel_query_await_reply( - bob_channel_id, - &read_reply2.send_message_payload, - courier_node.clone(), - courier_queue_id.clone(), - bob_message_id2.clone() - ).await { - Ok(payload) if !payload.is_empty() => { - bob_reply_payload2 = payload; - break; - } - Ok(_) => { - tokio::time::sleep(Duration::from_millis(500)).await; - } - Err(e) => return Err(e.into()), - } +#[tokio::test] +async fn test_create_courier_envelopes_from_multi_payload_multi_channel() { + println!("\n=== Test: create_courier_envelopes_from_multi_payload (efficient multi-channel) ==="); + + let alice_client = setup_thin_client().await.expect("Failed to setup Alice client"); + let bob_client = setup_thin_client().await.expect("Failed to setup Bob client"); + + // Step 1: Create two destination channels + println!("\n--- Step 1: Creating two destination channels ---"); + let chan1_seed: [u8; 32] = rand::random(); + let (chan1_write_cap, chan1_read_cap, chan1_first_index) = alice_client.new_keypair(&chan1_seed).await + .expect("Failed to create channel 1 keypair"); + println!("✓ Created Channel 1"); + + let chan2_seed: [u8; 32] = rand::random(); + let (chan2_write_cap, chan2_read_cap, chan2_first_index) = alice_client.new_keypair(&chan2_seed).await + .expect("Failed to create channel 2 keypair"); + 
println!("✓ Created Channel 2"); + + // Step 2: Create temporary copy stream channel + println!("\n--- Step 2: Creating temporary copy stream channel ---"); + let temp_seed: [u8; 32] = rand::random(); + let (temp_write_cap, _temp_read_cap, temp_first_index) = alice_client.new_keypair(&temp_seed).await + .expect("Failed to create temp keypair"); + println!("✓ Created temporary copy stream channel"); + + // Step 3: Create payloads for each channel + println!("\n--- Step 3: Creating payloads ---"); + let payload1 = b"Hello from Channel 1! This is payload one.".to_vec(); + let payload2 = b"Hello from Channel 2! This is payload two.".to_vec(); + println!("✓ Created payload1 ({} bytes) and payload2 ({} bytes)", payload1.len(), payload2.len()); + + // Step 4: Create copy stream chunks using efficient multi-destination API + println!("\n--- Step 4: Creating copy stream chunks using efficient API ---"); + let stream_id = ThinClient::new_stream_id(); + + let destinations = vec![ + (payload1.as_slice(), chan1_write_cap.as_slice(), chan1_first_index.as_slice()), + (payload2.as_slice(), chan2_write_cap.as_slice(), chan2_first_index.as_slice()), + ]; + + let result = alice_client.create_courier_envelopes_from_multi_payload( + &stream_id, + destinations, + true // is_last + ).await.expect("Failed to create courier envelopes from multi payload"); + + assert!(!result.envelopes.is_empty(), "Should have at least one chunk"); + println!("✓ Created {} copy stream chunks for both destinations", result.envelopes.len()); + + // Step 5: Write all chunks to temporary channel + println!("\n--- Step 5: Writing copy stream chunks to temp channel ---"); + let mut temp_index = temp_first_index.clone(); + for (i, chunk) in result.envelopes.iter().enumerate() { + let (ciphertext, env_desc, env_hash) = alice_client + .encrypt_write(chunk, &temp_write_cap, &temp_index).await + .expect("Failed to encrypt chunk"); + + let _ = alice_client.start_resending_encrypted_message( + None, + 
Some(&temp_write_cap), + None, + Some(0), + &env_desc, + &ciphertext, + &env_hash + ).await.expect("Failed to send chunk via ARQ"); + + println!(" ✓ Wrote chunk {} ({} bytes)", i + 1, chunk.len()); + + temp_index = alice_client.next_message_box_index(&temp_index).await + .expect("Failed to get next index"); } - // Verify the second message content matches - assert_eq!(alice_payload2, bob_reply_payload2.as_slice(), "Bob: Second message payload mismatch"); - println!("Bob: Successfully received and verified second message"); - - // Clean up channels - alice_thin_client.close_channel(alice_channel_id).await?; - bob_thin_client.close_channel(bob_channel_id).await?; - - alice_thin_client.stop().await; - bob_thin_client.stop().await; - - println!("✅ Resume write channel query test completed successfully"); - Ok(()) + // Wait for chunks to propagate + println!("\n--- Waiting for copy stream chunks to propagate (30 seconds) ---"); + tokio::time::sleep(Duration::from_secs(30)).await; + + // Step 6: Send Copy command + println!("\n--- Step 6: Sending Copy command via ARQ ---"); + alice_client.start_resending_copy_command(&temp_write_cap, None, None).await + .expect("Failed to send copy command"); + println!("✓ Copy command completed"); + + // Wait for copy command to execute + println!("\n--- Waiting for copy command to execute (30 seconds) ---"); + tokio::time::sleep(Duration::from_secs(30)).await; + + // Step 7: Bob reads from Channel 1 + println!("\n--- Step 7: Bob reads from Channel 1 ---"); + let (bob1_ciphertext, bob1_next_index, bob1_env_desc, bob1_env_hash) = bob_client + .encrypt_read(&chan1_read_cap, &chan1_first_index).await + .expect("Failed to encrypt read for channel 1"); + + let bob1_plaintext = bob_client.start_resending_encrypted_message( + Some(&chan1_read_cap), + None, + Some(&bob1_next_index), + Some(0), + &bob1_env_desc, + &bob1_ciphertext, + &bob1_env_hash + ).await.expect("Failed to retrieve from channel 1"); + + println!("✓ Bob received from Channel 
1: {:?}", String::from_utf8_lossy(&bob1_plaintext)); + assert_eq!(bob1_plaintext, payload1, "Channel 1 payload mismatch"); + + // Step 8: Bob reads from Channel 2 + println!("\n--- Step 8: Bob reads from Channel 2 ---"); + let (bob2_ciphertext, bob2_next_index, bob2_env_desc, bob2_env_hash) = bob_client + .encrypt_read(&chan2_read_cap, &chan2_first_index).await + .expect("Failed to encrypt read for channel 2"); + + let bob2_plaintext = bob_client.start_resending_encrypted_message( + Some(&chan2_read_cap), + None, + Some(&bob2_next_index), + Some(0), + &bob2_env_desc, + &bob2_ciphertext, + &bob2_env_hash + ).await.expect("Failed to retrieve from channel 2"); + + println!("✓ Bob received from Channel 2: {:?}", String::from_utf8_lossy(&bob2_plaintext)); + assert_eq!(bob2_plaintext, payload2, "Channel 2 payload mismatch"); + + println!("✅ create_courier_envelopes_from_multi_payload multi-channel test passed!"); } -/// Test resuming a read channel - equivalent to TestResumeReadChannel from Go -/// This test demonstrates the read channel resumption workflow: -/// 1. Create a write channel -/// 2. Write two messages to the channel -/// 3. Create a read channel -/// 4. Read the first message from the channel -/// 5. Verify payload matches -/// 6. Close the read channel -/// 7. Resume the read channel -/// 8. Read the second message from the channel -/// 9. Verify payload matches +// TestTombstoning tests the tombstoning API: +// 1. Alice writes a message to a box +// 2. Bob reads and verifies the message +// 3. Alice tombstones the box (deletes it with an empty payload) +// 4. 
Bob reads again and verifies the tombstone #[tokio::test] -async fn test_resume_read_channel() -> Result<(), Box> { - let alice_thin_client = setup_thin_client().await?; - let bob_thin_client = setup_thin_client().await?; - - // Wait for PKI documents to be available and connection to mixnet - println!("Waiting for daemon to connect to mixnet..."); - let mut attempts = 0; - while !alice_thin_client.is_connected() && attempts < 30 { - tokio::time::sleep(Duration::from_secs(1)).await; - attempts += 1; +async fn test_tombstone_box() { + let alice = setup_thin_client().await.expect("Failed to setup Alice client"); + let bob = setup_thin_client().await.expect("Failed to setup Bob client"); + + // Create keypair + let seed: [u8; 32] = rand::random(); + let (write_cap, read_cap, first_index) = alice.new_keypair(&seed).await + .expect("Failed to create keypair"); + println!("✓ Created keypair"); + + // Step 1: Alice writes a message + let message = b"Secret message that will be tombstoned"; + let (ciphertext, env_desc, env_hash) = alice + .encrypt_write(message, &write_cap, &first_index).await + .expect("Failed to encrypt write"); + + let reply_index: u8 = 0; + alice.start_resending_encrypted_message( + None, + Some(&write_cap), + None, + Some(reply_index), + &env_desc, + &ciphertext, + &env_hash + ).await.expect("Failed to send message"); + println!("✓ Alice wrote message"); + + println!("Waiting for 30 seconds for message propagation..."); + tokio::time::sleep(Duration::from_secs(30)).await; + + // Step 2: Bob reads and verifies + let (bob_ciphertext, bob_next_index, bob_env_desc, bob_env_hash) = bob + .encrypt_read(&read_cap, &first_index).await + .expect("Failed to encrypt read"); + + let plaintext = bob.start_resending_encrypted_message( + Some(&read_cap), + None, + Some(&bob_next_index), + Some(reply_index), + &bob_env_desc, + &bob_ciphertext, + &bob_env_hash + ).await.expect("Failed to read message"); + + assert_eq!(plaintext, message, "Message mismatch"); + 
println!("✓ Bob read message: {:?}", String::from_utf8_lossy(&plaintext)); + + // Step 3: Alice tombstones the box + let (tomb_ciphertext, tomb_env_desc, tomb_env_hash) = alice + .tombstone_box(&write_cap, &first_index).await + .expect("Failed to create tombstone"); + + let tomb_env_hash_arr: [u8; 32] = tomb_env_hash.try_into() + .expect("envelope_hash should be 32 bytes"); + + alice.start_resending_encrypted_message( + None, + Some(&write_cap), + None, + None, // reply_index is nil for tombstone writes + &tomb_env_desc, + &tomb_ciphertext, + &tomb_env_hash_arr + ).await.expect("Failed to send tombstone"); + println!("✓ Alice tombstoned the box"); + + // Step 4: Bob polls for tombstone with retries (matching Go test) + const MAX_ATTEMPTS: u32 = 6; + const POLL_INTERVAL_SECS: u64 = 10; + let mut tombstone_verified = false; + + for attempt in 1..=MAX_ATTEMPTS { + println!("Polling for tombstone (attempt {}/{})...", attempt, MAX_ATTEMPTS); + tokio::time::sleep(Duration::from_secs(POLL_INTERVAL_SECS)).await; + + let (ciphertext2, next_idx2, env_desc2, env_hash2) = bob + .encrypt_read(&read_cap, &first_index).await + .expect("Failed to encrypt read for tombstone check"); + + let bob_plaintext2 = bob.start_resending_encrypted_message( + Some(&read_cap), + None, + Some(&next_idx2), + Some(reply_index), + &env_desc2, + &ciphertext2, + &env_hash2 + ).await.expect("Failed to read tombstone"); + + if bob_plaintext2.is_empty() { + tombstone_verified = true; + println!("✓ Bob verified tombstone on attempt {}", attempt); + break; + } + println!(" Still seeing original message ({} bytes), retrying...", bob_plaintext2.len()); } - if !alice_thin_client.is_connected() { - return Err("Daemon failed to connect to mixnet within 30 seconds".into()); - } + assert!(tombstone_verified, "Tombstone not propagated after {} attempts", MAX_ATTEMPTS); + println!("\n✅ Tombstoning test passed!"); +} - println!("✅ Daemon connected to mixnet, using current PKI document"); - - // Alice creates write 
channel - println!("Alice: Creating write channel"); - let (alice_channel_id, read_cap, _write_cap) = alice_thin_client.create_write_channel().await?; - println!("Alice: Created write channel {}", alice_channel_id); - - // Alice writes first message - let alice_payload1 = b"Hello, Bob!"; - let write_reply1 = - alice_thin_client.write_channel(alice_channel_id, alice_payload1).await?; - - let (dest_node, dest_queue) = alice_thin_client.get_courier_destination().await?; - let alice_message_id1 = ThinClient::new_message_id(); - - let _reply1 = alice_thin_client.send_channel_query_await_reply( - alice_channel_id, - &write_reply1.send_message_payload, - dest_node.clone(), - dest_queue.clone(), - alice_message_id1 - ).await?; - - println!("Waiting for first message propagation to storage replicas"); - tokio::time::sleep(Duration::from_secs(3)).await; - - // Alice writes second message - println!("Alice: Writing second message"); - let alice_payload2 = b"Second message from Alice!"; - let write_reply2 = - alice_thin_client.write_channel(alice_channel_id, alice_payload2).await?; - - let alice_message_id2 = ThinClient::new_message_id(); - let _reply2 = alice_thin_client.send_channel_query_await_reply( - alice_channel_id, - &write_reply2.send_message_payload, - dest_node.clone(), - dest_queue.clone(), - alice_message_id2 - ).await?; - println!("Alice: Second write operation completed successfully"); - - println!("Waiting for second message propagation to storage replicas"); - tokio::time::sleep(Duration::from_secs(3)).await; - - // Bob creates read channel - println!("Bob: Creating read channel"); - let bob_channel_id = bob_thin_client.create_read_channel(read_cap.clone()).await?; - println!("Bob: Created read channel {}", bob_channel_id); - - // Bob reads first message - println!("Bob: Reading first message"); - let read_reply1 = bob_thin_client.read_channel(bob_channel_id, None, None).await?; - - let bob_message_id1 = ThinClient::new_message_id(); - let mut 
bob_reply_payload1 = vec![]; - - for i in 0..10 { - match alice_thin_client.send_channel_query_await_reply( - bob_channel_id, - &read_reply1.send_message_payload, - dest_node.clone(), - dest_queue.clone(), - bob_message_id1.clone() - ).await { - Ok(payload) if !payload.is_empty() => { - bob_reply_payload1 = payload; - break; - } - Ok(_) => { - println!("Bob: First read attempt {} returned empty payload, retrying...", i + 1); - tokio::time::sleep(Duration::from_millis(500)).await; - } - Err(e) => return Err(e.into()), +#[tokio::test] +async fn test_tombstone_range() { + println!("\n=== Test: tombstone_range ==="); + + let alice_client = setup_thin_client().await.expect("Failed to setup Alice client"); + + // Get the geometry from the config + let _geometry = alice_client.pigeonhole_geometry().clone(); + + // Create keypair + let seed: [u8; 32] = rand::random(); + let (write_cap, _read_cap, first_index) = alice_client.new_keypair(&seed).await + .expect("Failed to create keypair"); + println!("✓ Created keypair"); + + // Write 3 messages to sequential boxes + let num_messages: u32 = 3; + let mut current_index = first_index.clone(); + + println!("\n--- Writing {} messages ---", num_messages); + for i in 0..num_messages { + let message = format!("Message {} to be tombstoned", i + 1); + let (ciphertext, env_desc, env_hash) = alice_client + .encrypt_write(message.as_bytes(), &write_cap, ¤t_index).await + .expect("Failed to encrypt write"); + + let _ = alice_client.start_resending_encrypted_message( + None, + Some(&write_cap), + None, + Some(0), + &env_desc, + &ciphertext, + &env_hash + ).await.expect("Failed to send message"); + println!("✓ Wrote message {}", i + 1); + + if i < num_messages - 1 { + current_index = alice_client.next_message_box_index(¤t_index).await + .expect("Failed to get next index"); } } - assert_eq!(alice_payload1, bob_reply_payload1.as_slice(), "Bob: First message payload mismatch"); - - // Close the read channel - 
bob_thin_client.close_channel(bob_channel_id).await?; - - // Resume the read channel - println!("Bob: Resuming read channel"); - let bob_channel_id = bob_thin_client.resume_read_channel( - read_cap, - Some(read_reply1.next_message_index), - read_reply1.reply_index - ).await?; - println!("Bob: Resumed read channel with ID {}", bob_channel_id); - - // Bob reads second message - println!("Bob: Reading second message"); - let read_reply2 = bob_thin_client.read_channel(bob_channel_id, None, None).await?; - - let bob_message_id2 = ThinClient::new_message_id(); - let mut bob_reply_payload2 = vec![]; - - for i in 0..10 { - println!("Bob: second message read attempt {}", i + 1); - match alice_thin_client.send_channel_query_await_reply( - bob_channel_id, - &read_reply2.send_message_payload, - dest_node.clone(), - dest_queue.clone(), - bob_message_id2.clone() - ).await { - Ok(payload) if !payload.is_empty() => { - bob_reply_payload2 = payload; - break; - } - Ok(_) => { - tokio::time::sleep(Duration::from_millis(500)).await; - } - Err(e) => return Err(e.into()), - } + // Wait for messages to propagate + println!("--- Waiting for message propagation (30 seconds) ---"); + tokio::time::sleep(Duration::from_secs(30)).await; + + // Tombstone the range - creates envelopes without sending + println!("\n--- Creating tombstones for {} boxes ---", num_messages); + let result = alice_client.tombstone_range(&write_cap, &first_index, num_messages).await; + + assert!(result.error.is_none(), "Unexpected error: {:?}", result.error); + assert_eq!(result.envelopes.len(), num_messages as usize, "Expected {} envelopes, got {}", num_messages, result.envelopes.len()); + assert!(!result.next.is_empty(), "Next index should not be empty"); + println!("✓ Created {} tombstone envelopes", result.envelopes.len()); + + // Send all tombstone envelopes + println!("\n--- Sending {} tombstone envelopes ---", num_messages); + for (i, envelope) in result.envelopes.iter().enumerate() { + // Convert envelope_hash 
Vec to [u8; 32] + let env_hash: [u8; 32] = envelope.envelope_hash.clone().try_into() + .expect("envelope_hash should be 32 bytes"); + alice_client.start_resending_encrypted_message( + None, + Some(&write_cap), + None, + None, // reply_index must be None for tombstone writes + &envelope.envelope_descriptor, + &envelope.message_ciphertext, + &env_hash + ).await.expect("Failed to send tombstone envelope"); + println!("✓ Sent tombstone envelope {}", i + 1); } - // Verify the second message content matches - assert_eq!(alice_payload2, bob_reply_payload2.as_slice(), "Bob: Second message payload mismatch"); - println!("Bob: Successfully received and verified second message"); - - // Clean up channels - alice_thin_client.close_channel(alice_channel_id).await?; - bob_thin_client.close_channel(bob_channel_id).await?; - - alice_thin_client.stop().await; - bob_thin_client.stop().await; - - println!("✅ Resume read channel test completed successfully"); - Ok(()) + println!("✅ tombstone_range test passed! Created and sent {} tombstones successfully!", num_messages); } -/// Test resuming a read channel with query state - equivalent to TestResumeReadChannelQuery from Go -/// This test demonstrates the read channel query resumption workflow: -/// 1. Create a write channel -/// 2. Write two messages to the channel -/// 3. Create read channel -/// 4. Make read query but do not send it -/// 5. Close read channel -/// 6. Resume read channel query with ResumeReadChannelQuery method -/// 7. Send previously made read query to channel -/// 8. Verify received payload matches -/// 9. Read second message from channel -/// 10. 
Verify received payload matches #[tokio::test] -async fn test_resume_read_channel_query() -> Result<(), Box> { - let alice_thin_client = setup_thin_client().await?; - let bob_thin_client = setup_thin_client().await?; - - // Wait for PKI documents to be available and connection to mixnet - println!("Waiting for daemon to connect to mixnet..."); - let mut attempts = 0; - while !alice_thin_client.is_connected() && attempts < 30 { - tokio::time::sleep(Duration::from_secs(1)).await; - attempts += 1; - } - - if !alice_thin_client.is_connected() { - return Err("Daemon failed to connect to mixnet within 30 seconds".into()); - } - - println!("✅ Daemon connected to mixnet, using current PKI document"); - - // Alice creates write channel - println!("Alice: Creating write channel"); - let (alice_channel_id, read_cap, _write_cap) = alice_thin_client.create_write_channel().await?; - println!("Alice: Created write channel {}", alice_channel_id); - - // Alice writes first message - let alice_payload1 = b"Hello, Bob!"; - let write_reply1 = - alice_thin_client.write_channel(alice_channel_id, alice_payload1).await?; - - let (dest_node, dest_queue) = alice_thin_client.get_courier_destination().await?; - let alice_message_id1 = ThinClient::new_message_id(); - - let _reply1 = alice_thin_client.send_channel_query_await_reply( - alice_channel_id, - &write_reply1.send_message_payload, - dest_node.clone(), - dest_queue.clone(), - alice_message_id1 - ).await?; - - println!("Waiting for first message propagation to storage replicas"); - tokio::time::sleep(Duration::from_secs(3)).await; - - // Alice writes second message - println!("Alice: Writing second message"); - let alice_payload2 = b"Second message from Alice!"; - let write_reply2 = - alice_thin_client.write_channel(alice_channel_id, alice_payload2).await?; - - let alice_message_id2 = ThinClient::new_message_id(); - let _reply2 = alice_thin_client.send_channel_query_await_reply( - alice_channel_id, - &write_reply2.send_message_payload, - 
dest_node.clone(), - dest_queue.clone(), - alice_message_id2 - ).await?; - println!("Alice: Second write operation completed successfully"); - - println!("Waiting for second message propagation to storage replicas"); - tokio::time::sleep(Duration::from_secs(3)).await; - - // Bob creates read channel - println!("Bob: Creating read channel"); - let bob_channel_id = bob_thin_client.create_read_channel(read_cap.clone()).await?; - println!("Bob: Created read channel {}", bob_channel_id); - - // Bob prepares first read query but doesn't send it yet - println!("Bob: Reading first message"); - let read_reply1 = bob_thin_client.read_channel(bob_channel_id, None, None).await?; - - // Close the read channel - bob_thin_client.close_channel(bob_channel_id).await?; - - // Resume the read channel with query state - println!("Bob: Resuming read channel"); - let bob_channel_id = bob_thin_client.resume_read_channel_query( - read_cap, - read_reply1.current_message_index, - read_reply1.reply_index, - read_reply1.envelope_descriptor, - read_reply1.envelope_hash - ).await?; - println!("Bob: Resumed read channel with ID {}", bob_channel_id); - - // Send the first read query and get the message payload - let bob_message_id1 = ThinClient::new_message_id(); - let mut bob_reply_payload1 = vec![]; - - for i in 0..10 { - println!("Bob: first message read attempt {}", i + 1); - match alice_thin_client.send_channel_query_await_reply( - bob_channel_id, - &read_reply1.send_message_payload, - dest_node.clone(), - dest_queue.clone(), - bob_message_id1.clone() - ).await { - Ok(payload) if !payload.is_empty() => { - bob_reply_payload1 = payload; - break; - } - Ok(_) => { - tokio::time::sleep(Duration::from_millis(500)).await; - } - Err(e) => return Err(e.into()), +async fn test_box_id_not_found_error() { + println!("\n=== Test: BoxIDNotFoundError ==="); + println!("This test verifies that reading from a non-existent box returns BoxNotFound error"); + + let client = 
setup_thin_client().await.expect("Failed to setup client"); + + // Create a fresh keypair - but do NOT write anything to it + let seed: [u8; 32] = rand::random(); + let (_write_cap, read_cap, first_index) = client.new_keypair(&seed).await + .expect("Failed to create keypair"); + println!("✓ Created fresh keypair (no messages written)"); + + // Encrypt a read request for the non-existent box + let (ciphertext, next_index, env_desc, env_hash) = client + .encrypt_read(&read_cap, &first_index).await + .expect("Failed to encrypt read"); + println!("✓ Encrypted read request for non-existent box"); + + // Attempt to read - this should return BoxNotFound error + // Use start_resending_encrypted_message_no_retry to get immediate error without retries + println!("--- Attempting to read from non-existent box ---"); + let result = client.start_resending_encrypted_message_no_retry( + Some(&read_cap), + None, + Some(&next_index), + Some(0), + &env_desc, + &ciphertext, + &env_hash + ).await; + + // Verify we got the expected error + match result { + Err(katzenpost_thin_client::ThinClientError::BoxNotFound) => { + println!("✓ Received expected BoxNotFound error"); + println!("✅ BoxIDNotFoundError test passed!"); } - } - - assert_eq!(alice_payload1, bob_reply_payload1.as_slice(), "Bob: First message payload mismatch"); - - // Bob reads second message - println!("Bob: Reading second message"); - let read_reply2 = bob_thin_client.read_channel(bob_channel_id, None, None).await?; - - let bob_message_id2 = ThinClient::new_message_id(); - let mut bob_reply_payload2 = vec![]; - - for i in 0..10 { - println!("Bob: second message read attempt {}", i + 1); - match alice_thin_client.send_channel_query_await_reply( - bob_channel_id, - &read_reply2.send_message_payload, - dest_node.clone(), - dest_queue.clone(), - bob_message_id2.clone() - ).await { - Ok(payload) if !payload.is_empty() => { - bob_reply_payload2 = payload; - break; - } - Ok(_) => { - 
tokio::time::sleep(Duration::from_millis(500)).await; - } - Err(e) => return Err(e.into()), + Err(e) => { + panic!("Expected BoxNotFound error but got: {:?}", e); + } + Ok(plaintext) => { + panic!("Expected BoxNotFound error but got success with plaintext len: {}", plaintext.len()); } } - - // Verify the second message content matches - assert_eq!(alice_payload2, bob_reply_payload2.as_slice(), "Bob: Second message payload mismatch"); - println!("Bob: Successfully received and verified second message"); - - // Clean up channels - alice_thin_client.close_channel(alice_channel_id).await?; - bob_thin_client.close_channel(bob_channel_id).await?; - - alice_thin_client.stop().await; - bob_thin_client.stop().await; - - println!("✅ Resume read channel query test completed successfully"); - Ok(()) } diff --git a/tests/doodle_test.rs b/tests/doodle_test.rs new file mode 100644 index 0000000..7f03d53 --- /dev/null +++ b/tests/doodle_test.rs @@ -0,0 +1,386 @@ +// SPDX-FileCopyrightText: Copyright (C) 2026 David Stainton +// SPDX-License-Identifier: AGPL-3.0-only + +//! Doodle protocol integration tests. +//! +//! Tests the distributed meeting-poll protocol over a live mixnet. +//! +//! # Running tests one at a time with output +//! +//! cargo test --test doodle_test test_doodle_two_voters -- --nocapture +//! cargo test --test doodle_test test_doodle_three_voters -- --nocapture +//! cargo test --test doodle_test test_doodle_ballot_update -- --nocapture +//! 
cargo test --test doodle_test test_doodle_best_slot -- --nocapture + +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use katzenpost_thin_client::doodle::{Availability, DoodlePoll, TimeSlot}; +use katzenpost_thin_client::persistent::PigeonholeClient; +use katzenpost_thin_client::{Config, ThinClient}; + +// --------------------------------------------------------------------------- +// Helper +// --------------------------------------------------------------------------- + +async fn setup_client(label: &str) -> Result, Box> { + println!("[{}] ThinClient::new ...", label); + let config = Config::new("testdata/thinclient.toml")?; + let client = ThinClient::new(config).await?; + tokio::time::sleep(Duration::from_secs(2)).await; + println!("[{}] setup_client done", label); + Ok(client) +} + +fn two_slots() -> Vec { + vec![ + TimeSlot::new("mon-9", "Monday 09:00"), + TimeSlot::new("tue-9", "Tuesday 09:00"), + ] +} + +fn three_slots() -> Vec { + vec![ + TimeSlot::new("mon-9", "Monday 09:00"), + TimeSlot::new("tue-9", "Tuesday 09:00"), + TimeSlot::new("wed-9", "Wednesday 09:00"), + ] +} + +// --------------------------------------------------------------------------- +// Test 1: creator + one voter, verify tally +// --------------------------------------------------------------------------- + +/// Alice creates a poll with two slots and votes Yes on both. +/// Bob joins, receives Alice's CreatePoll, then casts his ballot. +/// Alice receives Bob's ballot and verifies the tally is correct. 
+#[tokio::test] +async fn test_doodle_two_voters() { + println!("\n=== Test: Doodle poll with Alice (creator) and Bob ==="); + + let (alice_thin, bob_thin) = tokio::join!( + setup_client("alice"), + setup_client("bob"), + ); + let alice_thin = alice_thin.expect("Alice ThinClient"); + let bob_thin = bob_thin.expect("Bob ThinClient"); + + let alice_ph = PigeonholeClient::new_in_memory(alice_thin.clone()).expect("Alice PigeonholeClient"); + let bob_ph = PigeonholeClient::new_in_memory(bob_thin.clone()).expect("Bob PigeonholeClient"); + + println!("\n--- Alice creates the poll ---"); + let mut alice_poll = DoodlePoll::new_poll( + &alice_ph, + "standup-2v", + "Alice", + "Weekly standup", + two_slots(), + ).await.expect("Alice new_poll"); + + let alice_intro = alice_poll.my_introduction(); + + println!("\n--- Bob joins and receives Alice's CreatePoll ---"); + let mut bob_poll = DoodlePoll::join_poll( + &bob_ph, + "standup-2v", + "Bob", + &alice_intro, + ).await.expect("Bob join_poll"); + + // Bob receives Alice's CreatePoll and folds it into his state. + bob_poll.receive_one_and_apply().await.expect("Bob receive CreatePoll"); + assert_eq!(bob_poll.poll_state().title, "Weekly standup"); + assert_eq!(bob_poll.poll_state().slots.len(), 2); + println!("✓ Bob sees poll: '{}' with {} slots", + bob_poll.poll_state().title, + bob_poll.poll_state().slots.len()); + + println!("\n--- Exchange introductions ---"); + let bob_intro = bob_poll.my_introduction(); + alice_poll.add_member(&alice_ph, &bob_intro).expect("Alice add Bob"); + // Bob already has Alice from join_poll; only Alice needs to add Bob. 
+ + println!("\n--- Alice and Bob cast ballots in parallel ---"); + let alice_votes = HashMap::from([ + ("mon-9".to_string(), Availability::Yes), + ("tue-9".to_string(), Availability::Maybe), + ]); + let bob_votes = HashMap::from([ + ("mon-9".to_string(), Availability::No), + ("tue-9".to_string(), Availability::Yes), + ]); + let (r1, r2) = tokio::join!( + alice_poll.cast_ballot(alice_votes.clone()), + bob_poll.cast_ballot(bob_votes.clone()), + ); + r1.expect("Alice cast_ballot"); + r2.expect("Bob cast_ballot"); + println!("✓ Alice and Bob sent their ballots"); + + println!("\n--- Alice and Bob receive from all members ---"); + let (r1, r2) = tokio::join!( + alice_poll.receive_and_apply(), + bob_poll.receive_and_apply(), + ); + r1.expect("Alice receive_and_apply"); + r2.expect("Bob receive_and_apply"); + + // cast_ballot applies to local state immediately, so Alice's own votes + // are already present. receive_and_apply added Bob's ballot. + alice_poll.poll_state().ballots.get("Alice").expect("Alice ballot in state"); + + // Verify Alice's tally + let alice_tally = alice_poll.tally(); + let mon = alice_tally.iter().find(|t| t.slot.id == "mon-9").expect("mon-9"); + let tue = alice_tally.iter().find(|t| t.slot.id == "tue-9").expect("tue-9"); + // Alice=Yes, Bob=No + assert_eq!(mon.yes, 1, "mon-9 yes"); + assert_eq!(mon.no, 1, "mon-9 no"); + assert_eq!(mon.maybe, 0, "mon-9 maybe"); + // Alice=Maybe, Bob=Yes + assert_eq!(tue.yes, 1, "tue-9 yes"); + assert_eq!(tue.no, 0, "tue-9 no"); + assert_eq!(tue.maybe, 1, "tue-9 maybe"); + println!("✓ Alice's tally: mon-9 Yes={} No={} Maybe={}", mon.yes, mon.no, mon.maybe); + println!("✓ Alice's tally: tue-9 Yes={} No={} Maybe={}", tue.yes, tue.no, tue.maybe); + + println!("\n✅ Two-voter Doodle test passed!"); +} + +// --------------------------------------------------------------------------- +// Test 2: three voters, unanimous best slot +// --------------------------------------------------------------------------- + +/// Alice, 
Bob, and Carol each vote. All three agree on Wednesday, so +/// `best_slot()` should return Wednesday for every participant. +#[tokio::test] +async fn test_doodle_three_voters() { + println!("\n=== Test: Doodle poll with three voters ==="); + + let (alice_thin, bob_thin, carol_thin) = tokio::join!( + setup_client("alice"), + setup_client("bob"), + setup_client("carol"), + ); + let alice_thin = alice_thin.expect("Alice ThinClient"); + let bob_thin = bob_thin.expect("Bob ThinClient"); + let carol_thin = carol_thin.expect("Carol ThinClient"); + + let alice_ph = PigeonholeClient::new_in_memory(alice_thin.clone()).expect("Alice ph"); + let bob_ph = PigeonholeClient::new_in_memory(bob_thin.clone()).expect("Bob ph"); + let carol_ph = PigeonholeClient::new_in_memory(carol_thin.clone()).expect("Carol ph"); + + println!("\n--- Alice creates the poll ---"); + let mut alice_poll = DoodlePoll::new_poll( + &alice_ph, "standup-3v", "Alice", + "Team standup", three_slots(), + ).await.expect("Alice new_poll"); + let alice_intro = alice_poll.my_introduction(); + + println!("\n--- Bob and Carol join in parallel ---"); + let (bob_poll, carol_poll) = tokio::join!( + DoodlePoll::join_poll(&bob_ph, "standup-3v", "Bob", &alice_intro), + DoodlePoll::join_poll(&carol_ph, "standup-3v", "Carol", &alice_intro), + ); + let mut bob_poll = bob_poll.expect("Bob join_poll"); + let mut carol_poll = carol_poll.expect("Carol join_poll"); + + let bob_intro = bob_poll.my_introduction(); + let carol_intro = carol_poll.my_introduction(); + + println!("\n--- Exchange introductions ---"); + alice_poll.add_member(&alice_ph, &bob_intro).expect("Alice add Bob"); + alice_poll.add_member(&alice_ph, &carol_intro).expect("Alice add Carol"); + // Bob and Carol already have Alice from join_poll. 
+ bob_poll.add_member(&bob_ph, &carol_intro).expect("Bob add Carol"); + carol_poll.add_member(&carol_ph, &bob_intro).expect("Carol add Bob"); + + println!("\n--- Bob and Carol receive Alice's CreatePoll ---"); + // Each of Bob and Carol needs to receive Alice's CreatePoll to learn the slots. + let (r1, r2) = tokio::join!( + bob_poll.receive_one_and_apply(), + carol_poll.receive_one_and_apply(), + ); + r1.expect("Bob receive CreatePoll"); + r2.expect("Carol receive CreatePoll"); + assert_eq!(bob_poll.poll_state().slots.len(), 3); + assert_eq!(carol_poll.poll_state().slots.len(), 3); + println!("✓ Bob and Carol know the 3 slots"); + + println!("\n--- All three cast ballots in parallel ---"); + // All three vote Yes on wed-9 (Wednesday). + let (r1, r2, r3) = tokio::join!( + alice_poll.cast_ballot(HashMap::from([ + ("mon-9".to_string(), Availability::No), + ("tue-9".to_string(), Availability::No), + ("wed-9".to_string(), Availability::Yes), + ])), + bob_poll.cast_ballot(HashMap::from([ + ("mon-9".to_string(), Availability::Maybe), + ("tue-9".to_string(), Availability::No), + ("wed-9".to_string(), Availability::Yes), + ])), + carol_poll.cast_ballot(HashMap::from([ + ("mon-9".to_string(), Availability::No), + ("tue-9".to_string(), Availability::Maybe), + ("wed-9".to_string(), Availability::Yes), + ])), + ); + r1.expect("Alice cast"); r2.expect("Bob cast"); r3.expect("Carol cast"); + println!("✓ All three cast Yes on Wednesday"); + + println!("\n--- All three receive from all members in parallel ---"); + let (r1, r2, r3) = tokio::join!( + alice_poll.receive_and_apply(), + bob_poll.receive_and_apply(), + carol_poll.receive_and_apply(), + ); + r1.expect("Alice receive"); r2.expect("Bob receive"); r3.expect("Carol receive"); + + println!("\n--- Verify best_slot is Wednesday for all three ---"); + for (name, poll) in [("Alice", &alice_poll), ("Bob", &bob_poll), ("Carol", &carol_poll)] { + let best = poll.best_slot().expect(&format!("{} best_slot is None", name)); + 
assert_eq!(best.id, "wed-9", "{} best_slot wrong", name); + println!("✓ {} best_slot: '{}'", name, best.label); + } + + println!("\n✅ Three-voter Doodle test passed!"); +} + +// --------------------------------------------------------------------------- +// Test 3: ballot update (last-write-wins) +// --------------------------------------------------------------------------- + +/// Bob first votes No on Monday, then changes his mind and votes Yes. +/// Alice should see only the latest ballot from Bob. +#[tokio::test] +async fn test_doodle_ballot_update() { + println!("\n=== Test: Ballot update (last-write-wins) ==="); + + let (alice_thin, bob_thin) = tokio::join!( + setup_client("alice"), + setup_client("bob"), + ); + let alice_thin = alice_thin.expect("Alice ThinClient"); + let bob_thin = bob_thin.expect("Bob ThinClient"); + + let alice_ph = PigeonholeClient::new_in_memory(alice_thin.clone()).expect("Alice ph"); + let bob_ph = PigeonholeClient::new_in_memory(bob_thin.clone()).expect("Bob ph"); + + let mut alice_poll = DoodlePoll::new_poll( + &alice_ph, "standup-upd", "Alice", + "Update test poll", two_slots(), + ).await.expect("Alice new_poll"); + + let alice_intro = alice_poll.my_introduction(); + let mut bob_poll = DoodlePoll::join_poll( + &bob_ph, "standup-upd", "Bob", &alice_intro, + ).await.expect("Bob join_poll"); + + let bob_intro = bob_poll.my_introduction(); + alice_poll.add_member(&alice_ph, &bob_intro).expect("Alice add Bob"); + // Bob already has Alice from join_poll. + + // Bob receives CreatePoll. + bob_poll.receive_one_and_apply().await.expect("Bob receive CreatePoll"); + + println!("\n--- Bob casts first ballot (No on Monday) ---"); + bob_poll.cast_ballot(HashMap::from([ + ("mon-9".to_string(), Availability::No), + ("tue-9".to_string(), Availability::Yes), + ])).await.expect("Bob first ballot"); + + // Alice receives Bob's first ballot. 
+ alice_poll.receive_and_apply().await.expect("Alice receive first ballot"); + let tally = alice_poll.tally(); + let mon = tally.iter().find(|t| t.slot.id == "mon-9").unwrap(); + assert_eq!(mon.no, 1, "Bob's first ballot: No on Monday"); + assert_eq!(mon.yes, 0); + println!("✓ Alice sees Bob's first ballot: mon-9 No={}", mon.no); + + println!("\n--- Bob changes his mind (Yes on Monday) ---"); + bob_poll.cast_ballot(HashMap::from([ + ("mon-9".to_string(), Availability::Yes), + ("tue-9".to_string(), Availability::Yes), + ])).await.expect("Bob second ballot"); + + // Alice receives Bob's updated ballot. + alice_poll.receive_and_apply().await.expect("Alice receive second ballot"); + let tally = alice_poll.tally(); + let mon = tally.iter().find(|t| t.slot.id == "mon-9").unwrap(); + // The updated ballot replaces the old one: only Yes should remain. + assert_eq!(mon.yes, 1, "Bob's updated ballot: Yes on Monday"); + assert_eq!(mon.no, 0, "Old No vote should be gone"); + println!("✓ Alice sees Bob's updated ballot: mon-9 Yes={} No={}", mon.yes, mon.no); + + println!("\n✅ Ballot update (LWW) test passed!"); +} + +// --------------------------------------------------------------------------- +// Test 4: best_slot tie-breaking +// --------------------------------------------------------------------------- + +/// Two slots each receive one Yes vote. `best_slot` should return the +/// first one in creation order (tie-break by index). +#[tokio::test] +async fn test_doodle_best_slot() { + println!("\n=== Test: best_slot tie-breaking ==="); + + let (alice_thin, bob_thin) = tokio::join!( + setup_client("alice"), + setup_client("bob"), + ); + let alice_thin = alice_thin.expect("Alice ThinClient"); + let bob_thin = bob_thin.expect("Bob ThinClient"); + + let alice_ph = PigeonholeClient::new_in_memory(alice_thin.clone()).expect("Alice ph"); + let bob_ph = PigeonholeClient::new_in_memory(bob_thin.clone()).expect("Bob ph"); + + // Alice votes Yes on Monday; Bob votes Yes on Tuesday. 
+ // Both slots have exactly 1 Yes; Monday (index 0) should win the tie. + let mut alice_poll = DoodlePoll::new_poll( + &alice_ph, "standup-tie", "Alice", + "Tie-break test", two_slots(), + ).await.expect("Alice new_poll"); + + let alice_intro = alice_poll.my_introduction(); + let mut bob_poll = DoodlePoll::join_poll( + &bob_ph, "standup-tie", "Bob", &alice_intro, + ).await.expect("Bob join_poll"); + + let bob_intro = bob_poll.my_introduction(); + alice_poll.add_member(&alice_ph, &bob_intro).expect("Alice add Bob"); + // Bob already has Alice from join_poll. + + bob_poll.receive_one_and_apply().await.expect("Bob receive CreatePoll"); + + let (r1, r2) = tokio::join!( + alice_poll.cast_ballot(HashMap::from([ + ("mon-9".to_string(), Availability::Yes), + ("tue-9".to_string(), Availability::No), + ])), + bob_poll.cast_ballot(HashMap::from([ + ("mon-9".to_string(), Availability::No), + ("tue-9".to_string(), Availability::Yes), + ])), + ); + r1.expect("Alice cast"); r2.expect("Bob cast"); + + let (r1, r2) = tokio::join!( + alice_poll.receive_and_apply(), + bob_poll.receive_and_apply(), + ); + r1.expect("Alice receive"); r2.expect("Bob receive"); + + // Both slots have 1 Yes; Monday comes first → it wins the tie. + let best = alice_poll.best_slot().expect("best_slot should exist"); + assert_eq!(best.id, "mon-9", "tie broken by creation order"); + println!("✓ Alice's best_slot: '{}' (tie broken by index)", best.label); + + let best = bob_poll.best_slot().expect("bob best_slot"); + assert_eq!(best.id, "mon-9", "Bob also sees Monday win the tie"); + println!("✓ Bob's best_slot: '{}'", best.label); + + println!("\n✅ best_slot tie-breaking test passed!"); +} diff --git a/tests/group_channel_test.rs b/tests/group_channel_test.rs new file mode 100644 index 0000000..12162f3 --- /dev/null +++ b/tests/group_channel_test.rs @@ -0,0 +1,327 @@ +// SPDX-FileCopyrightText: Copyright (C) 2026 David Stainton +// SPDX-License-Identifier: AGPL-3.0-only + +//! 
Group channel integration tests. +//! +//! Tests the generic `GroupChannel` abstraction over a live mixnet. +//! +//! # Event types used +//! +//! `ChatEvent` — a simple text/introduction enum used for the basic group chat +//! tests that mirror the original test suite. +//! +//! `Dot` — a `GCounter` operation used for the CRDT test +//! that demonstrates `state = fold(events)`. +//! +//! # Running tests one at a time with output +//! +//! cargo test --test group_channel_test test_group_channel_three_members -- --nocapture +//! cargo test --test group_channel_test test_group_channel_introduction -- --nocapture +//! cargo test --test group_channel_test test_group_crdt_gcounter -- --nocapture + +use std::sync::Arc; +use std::time::Duration; + +use katzenpost_thin_client::chat::{ChatEvent, GroupChat}; +use katzenpost_thin_client::group::GroupChannel; +use katzenpost_thin_client::persistent::PigeonholeClient; +use katzenpost_thin_client::{Config, ThinClient}; + +// --------------------------------------------------------------------------- +// Test helpers +// --------------------------------------------------------------------------- + +async fn setup_client(label: &str) -> Result, Box> { + println!("[{}] ThinClient::new ...", label); + let config = Config::new("testdata/thinclient.toml")?; + let client = ThinClient::new(config).await?; + tokio::time::sleep(Duration::from_secs(2)).await; + println!("[{}] setup_client done", label); + Ok(client) +} + +// --------------------------------------------------------------------------- +// Basic chat tests (ChatEvent) +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_group_channel_three_members() { + println!("\n=== Test: Group channel with Alice, Bob, and Carol ==="); + + println!("\n--- Setup: create three clients in parallel ---"); + let (alice_thin, bob_thin, carol_thin) = tokio::join!( + setup_client("alice"), + setup_client("bob"), + setup_client("carol"), + 
); + let alice_thin = alice_thin.expect("Failed to setup Alice"); + let bob_thin = bob_thin.expect("Failed to setup Bob"); + let carol_thin = carol_thin.expect("Failed to setup Carol"); + + let alice_ph = PigeonholeClient::new_in_memory(alice_thin.clone()).expect("Alice PigeonholeClient"); + let bob_ph = PigeonholeClient::new_in_memory(bob_thin.clone()).expect("Bob PigeonholeClient"); + let carol_ph = PigeonholeClient::new_in_memory(carol_thin.clone()).expect("Carol PigeonholeClient"); + + println!("\n--- Step 1: All members create group channels in parallel ---"); + let (alice_group, bob_group, carol_group) = tokio::join!( + GroupChat::create(&alice_ph, "test-room", "Alice"), + GroupChat::create(&bob_ph, "test-room", "Bob"), + GroupChat::create(&carol_ph, "test-room", "Carol"), + ); + let alice_group: GroupChat = alice_group.expect("Alice group"); + let bob_group: GroupChat = bob_group.expect("Bob group"); + let carol_group: GroupChat = carol_group.expect("Carol group"); + println!("✓ Alice, Bob, Carol created group 'test-room'"); + + println!("\n--- Step 2: Exchange read capabilities ---"); + let alice_intro = alice_group.my_introduction(); + let bob_intro = bob_group.my_introduction(); + let carol_intro = carol_group.my_introduction(); + + alice_group.add_member(&alice_ph, &bob_intro).expect("Alice add Bob"); + alice_group.add_member(&alice_ph, &carol_intro).expect("Alice add Carol"); + bob_group.add_member(&bob_ph, &alice_intro).expect("Bob add Alice"); + bob_group.add_member(&bob_ph, &carol_intro).expect("Bob add Carol"); + carol_group.add_member(&carol_ph, &alice_intro).expect("Carol add Alice"); + carol_group.add_member(&carol_ph, &bob_intro).expect("Carol add Bob"); + + println!("✓ Alice has {} members", alice_group.member_count()); + println!("✓ Bob has {} members", bob_group.member_count()); + println!("✓ Carol has {} members", carol_group.member_count()); + + println!("\n--- Step 3: Alice sends a message ---"); + let alice_msg = "Hello everyone!"; + 
alice_group.send_text(alice_msg).await.expect("Alice send"); + println!("✓ Alice sent: '{}'", alice_msg); + + // Bob and Carol each have Alice as a member — receive_from_all races their + // respective member channels concurrently. + println!("\n--- Step 4: Bob and Carol receive Alice's message in parallel ---"); + let (bob_events, carol_events) = tokio::join!( + bob_group.receive_from_all(), + carol_group.receive_from_all(), + ); + let bob_events = bob_events.expect("Bob receive_from_all"); + let carol_events = carol_events.expect("Carol receive_from_all"); + + // Bob and Carol each only have one member (Alice) at this point, so each + // Vec has exactly one event. + let bob_event = bob_events.into_iter().find(|e| e.sender == "Alice").expect("Bob: no event from Alice"); + let carol_event = carol_events.into_iter().find(|e| e.sender == "Alice").expect("Carol: no event from Alice"); + + let ChatEvent::Text(ref bob_text) = bob_event.event else { panic!("Expected Text") }; + let ChatEvent::Text(ref carol_text) = carol_event.event else { panic!("Expected Text") }; + assert_eq!(bob_text, alice_msg); + assert_eq!(carol_text, alice_msg); + println!("✓ Bob received from Alice: '{}'", bob_text); + println!("✓ Carol received from Alice: '{}'", carol_text); + + println!("\n--- Step 5+6: Bob and Carol reply in parallel ---"); + let bob_msg = "Hi from Bob!"; + let carol_msg = "Hey, Carol here!"; + let (r1, r2) = tokio::join!( + bob_group.send_text(bob_msg), + carol_group.send_text(carol_msg), + ); + r1.expect("Bob send"); r2.expect("Carol send"); + println!("✓ Bob sent: '{}'", bob_msg); + println!("✓ Carol sent: '{}'", carol_msg); + + // Alice receives from all members (Bob and Carol) concurrently. 
+ println!("\n--- Step 7: Alice receives from all members ---"); + let alice_events = alice_group.receive_from_all().await.expect("Alice receive_from_all"); + + let from_bob = alice_events.iter().find(|e| e.sender == "Bob").expect("Alice: no event from Bob"); + let from_carol = alice_events.iter().find(|e| e.sender == "Carol").expect("Alice: no event from Carol"); + + let ChatEvent::Text(ref bob_reply) = from_bob.event else { panic!("Expected Text from Bob") }; + let ChatEvent::Text(ref carol_reply) = from_carol.event else { panic!("Expected Text from Carol") }; + assert_eq!(bob_reply, bob_msg); + assert_eq!(carol_reply, carol_msg); + println!("✓ Alice received from Bob: '{}'", bob_reply); + println!("✓ Alice received from Carol: '{}'", carol_reply); + + println!("\n✅ Group channel three-member test passed!"); +} + +#[tokio::test] +async fn test_group_channel_introduction() { + println!("\n=== Test: Alice introduces Carol to Bob ==="); + + println!("\n--- Setup: create three clients in parallel ---"); + let (alice_thin, bob_thin, carol_thin) = tokio::join!( + setup_client("alice"), + setup_client("bob"), + setup_client("carol"), + ); + let alice_thin = alice_thin.expect("Failed to setup Alice"); + let bob_thin = bob_thin.expect("Failed to setup Bob"); + let carol_thin = carol_thin.expect("Failed to setup Carol"); + + let alice_ph = PigeonholeClient::new_in_memory(alice_thin.clone()).unwrap(); + let bob_ph = PigeonholeClient::new_in_memory(bob_thin.clone()).unwrap(); + let carol_ph = PigeonholeClient::new_in_memory(carol_thin.clone()).unwrap(); + + println!("\n--- Create group channels in parallel ---"); + let (alice_group, bob_group, carol_group) = tokio::join!( + GroupChat::create(&alice_ph, "intro-test", "Alice"), + GroupChat::create(&bob_ph, "intro-test", "Bob"), + GroupChat::create(&carol_ph, "intro-test", "Carol"), + ); + let alice_group: GroupChat = alice_group.unwrap(); + let bob_group: GroupChat = bob_group.unwrap(); + let carol_group: GroupChat = 
carol_group.unwrap(); + + let alice_intro = alice_group.my_introduction(); + let bob_intro = bob_group.my_introduction(); + let carol_intro = carol_group.my_introduction(); + + // Alice knows everyone; Bob only knows Alice initially. + alice_group.add_member(&alice_ph, &bob_intro).unwrap(); + alice_group.add_member(&alice_ph, &carol_intro).unwrap(); + bob_group.add_member(&bob_ph, &alice_intro).unwrap(); + carol_group.add_member(&carol_ph, &alice_intro).unwrap(); + carol_group.add_member(&carol_ph, &bob_intro).unwrap(); + println!("✓ Initial setup: Alice knows everyone, Bob only knows Alice"); + + println!("\n--- Alice sends Carol's introduction ---"); + alice_group.send_introduction(&carol_intro).await.unwrap(); + println!("✓ Alice sent Carol's introduction"); + + println!("\n--- Bob receives from all members (just Alice) ---"); + let msgs = bob_group.receive_from_all().await.unwrap(); + let msg = msgs.into_iter().find(|e| e.sender == "Alice").expect("Bob: no event from Alice"); + + let ChatEvent::Introduction(ref intro) = msg.event else { + panic!("Expected ChatEvent::Introduction"); + }; + println!("✓ Bob received introduction for: '{}'", intro.display_name); + assert_eq!(intro.display_name, "Carol"); + + bob_group.add_member(&bob_ph, intro).unwrap(); + println!("✓ Bob added Carol (member count: {})", bob_group.member_count()); + assert_eq!(bob_group.member_count(), 2); // Alice + Carol + + println!("\n--- Bob sends message to group (now including Carol) ---"); + let bob_text = "Hi Carol, nice to meet you!"; + bob_group.send_text(bob_text).await.unwrap(); + println!("✓ Bob sent: '{}'", bob_text); + + println!("\n--- Carol receives from all members ---"); + let carol_msgs = carol_group.receive_from_all().await.unwrap(); + let from_bob = carol_msgs.into_iter().find(|e| e.sender == "Bob").expect("Carol: no event from Bob"); + + let ChatEvent::Text(ref carol_got) = from_bob.event else { + panic!("Expected ChatEvent::Text from Bob"); + }; + assert_eq!(carol_got, 
bob_text); + println!("✓ Carol got Bob's message: '{}'", carol_got); + + println!("\n✅ Introduction test passed!"); +} + +// --------------------------------------------------------------------------- +// CRDT integration test +// --------------------------------------------------------------------------- + +/// Each of three participants broadcasts a GCounter increment operation on +/// their own channel. After all messages propagate through the mixnet, every +/// participant folds the received operations into a local GCounter and verifies +/// the total matches the expected sum. +/// +/// This directly demonstrates the `state = fold(events)` pattern described in +/// the Pigeonhole blog post. +#[tokio::test] +async fn test_group_crdt_gcounter() { + use crdts::{CmRDT, Dot, GCounter}; + + type CounterOp = Dot; + + println!("\n=== Test: CRDT GCounter over group channel ==="); + + println!("\n--- Setup: create three clients in parallel ---"); + let (alice_thin, bob_thin, carol_thin) = tokio::join!( + setup_client("alice"), + setup_client("bob"), + setup_client("carol"), + ); + let alice_thin = alice_thin.expect("Failed to setup Alice"); + let bob_thin = bob_thin.expect("Failed to setup Bob"); + let carol_thin = carol_thin.expect("Failed to setup Carol"); + + let alice_ph = PigeonholeClient::new_in_memory(alice_thin.clone()).unwrap(); + let bob_ph = PigeonholeClient::new_in_memory(bob_thin.clone()).unwrap(); + let carol_ph = PigeonholeClient::new_in_memory(carol_thin.clone()).unwrap(); + + println!("\n--- Create group channels in parallel ---"); + let (alice_group, bob_group, carol_group) = tokio::join!( + GroupChannel::create(&alice_ph, "crdt-test", "Alice"), + GroupChannel::create(&bob_ph, "crdt-test", "Bob"), + GroupChannel::create(&carol_ph, "crdt-test", "Carol"), + ); + let alice_group: GroupChannel = alice_group.unwrap(); + let bob_group: GroupChannel = bob_group.unwrap(); + let carol_group: GroupChannel = carol_group.unwrap(); + + let alice_intro = 
alice_group.my_introduction(); + let bob_intro = bob_group.my_introduction(); + let carol_intro = carol_group.my_introduction(); + + alice_group.add_member(&alice_ph, &bob_intro).unwrap(); + alice_group.add_member(&alice_ph, &carol_intro).unwrap(); + bob_group.add_member(&bob_ph, &alice_intro).unwrap(); + bob_group.add_member(&bob_ph, &carol_intro).unwrap(); + carol_group.add_member(&carol_ph, &alice_intro).unwrap(); + carol_group.add_member(&carol_ph, &bob_intro).unwrap(); + println!("✓ All members joined group 'crdt-test'"); + + let alice_gen: GCounter = GCounter::new(); + let bob_gen: GCounter = GCounter::new(); + let carol_gen: GCounter = GCounter::new(); + + let alice_op = alice_gen.inc("Alice".to_string()); + let bob_op = bob_gen.inc("Bob".to_string()); + let carol_op = carol_gen.inc("Carol".to_string()); + + println!("\n--- All three send in parallel ---"); + let (r1, r2, r3) = tokio::join!( + alice_group.send(&alice_op), + bob_group.send(&bob_op), + carol_group.send(&carol_op), + ); + r1.expect("Alice send"); r2.expect("Bob send"); r3.expect("Carol send"); + println!("✓ Alice, Bob, Carol each sent their counter op"); + + println!("\n--- All three receive in parallel ---"); + let (alice_received, bob_received, carol_received) = tokio::join!( + alice_group.receive_from_all(), + bob_group.receive_from_all(), + carol_group.receive_from_all(), + ); + let alice_received = alice_received.expect("Alice receive_from_all"); + let bob_received = bob_received.expect("Bob receive_from_all"); + let carol_received = carol_received.expect("Carol receive_from_all"); + println!("✓ All ops received"); + + // Fold: each participant applies their own op plus the received ops. 
+ let mut alice_counter: GCounter = GCounter::new(); + alice_counter.apply(alice_op.clone()); + for e in &alice_received { alice_counter.apply(e.event.clone()); } + assert_eq!(alice_counter.read().to_string(), "3"); + println!("✓ Alice's counter = {} (expected 3)", alice_counter.read()); + + let mut bob_counter: GCounter = GCounter::new(); + bob_counter.apply(bob_op.clone()); + for e in &bob_received { bob_counter.apply(e.event.clone()); } + assert_eq!(bob_counter.read().to_string(), "3"); + println!("✓ Bob's counter = {} (expected 3)", bob_counter.read()); + + let mut carol_counter: GCounter = GCounter::new(); + carol_counter.apply(carol_op.clone()); + for e in &carol_received { carol_counter.apply(e.event.clone()); } + assert_eq!(carol_counter.read().to_string(), "3"); + println!("✓ Carol's counter = {} (expected 3)", carol_counter.read()); + + println!("\n✅ CRDT GCounter group test passed!"); + println!(" All three participants independently converged to state = 3"); +} diff --git a/tests/high_level_api_test.rs b/tests/high_level_api_test.rs new file mode 100644 index 0000000..19dc35a --- /dev/null +++ b/tests/high_level_api_test.rs @@ -0,0 +1,676 @@ +// SPDX-FileCopyrightText: Copyright (C) 2026 David Stainton +// SPDX-License-Identifier: AGPL-3.0-only + +//! High-level PigeonholeClient API integration tests +//! These tests require a running mixnet with client daemon for integration testing. 
+ +use std::sync::Arc; +use std::time::Duration; +use katzenpost_thin_client::{ThinClient, Config}; +use katzenpost_thin_client::persistent::PigeonholeClient; + +/// Test helper to setup thin clients for integration tests +async fn setup_clients() -> Result<(Arc, Arc), Box> { + let alice_config = Config::new("testdata/thinclient.toml")?; + let alice_client = ThinClient::new(alice_config).await?; + + let bob_config = Config::new("testdata/thinclient.toml")?; + let bob_client = ThinClient::new(bob_config).await?; + + // Wait for initial connection and PKI document + tokio::time::sleep(Duration::from_secs(2)).await; + + Ok((alice_client, bob_client)) +} + +#[tokio::test] +async fn test_high_level_send_receive() { + println!("\n=== Test: High-level API - Alice sends message to Bob ==="); + + let (alice_thin, bob_thin) = setup_clients().await.expect("Failed to setup clients"); + + // Create high-level clients with in-memory databases + let alice = PigeonholeClient::new_in_memory(alice_thin.clone()) + .expect("Failed to create Alice's PigeonholeClient"); + let bob = PigeonholeClient::new_in_memory(bob_thin.clone()) + .expect("Failed to create Bob's PigeonholeClient"); + + // Step 1: Alice creates a channel + println!("\n--- Step 1: Alice creates a channel ---"); + let mut alice_channel = alice.create_channel("alice-to-bob").await + .expect("Failed to create channel"); + println!("✓ Alice created channel: {}", alice_channel.name()); + + // Step 2: Alice shares the read capability with Bob + println!("\n--- Step 2: Alice shares read capability with Bob ---"); + let read_cap = alice_channel.share_read_capability(); + println!("✓ Alice shared read capability"); + + // Step 3: Bob imports the channel + println!("\n--- Step 3: Bob imports the channel ---"); + let mut bob_channel = bob.import_channel("messages-from-alice", &read_cap) + .expect("Failed to import channel"); + println!("✓ Bob imported channel: {}", bob_channel.name()); + assert!(!bob_channel.is_owned(), "Bob's 
channel should be read-only"); + + // Step 4: Alice sends a message using high-level API + println!("\n--- Step 4: Alice sends a message ---"); + let message = b"Hello Bob! This is a secret message from Alice."; + alice_channel.send(message).await.expect("Failed to send message"); + println!("✓ Alice sent message: {:?}", String::from_utf8_lossy(message)); + + // Wait for message propagation through the mixnet + println!("\n--- Waiting 30 seconds for message propagation ---"); + tokio::time::sleep(Duration::from_secs(30)).await; + + // Step 5: Bob receives the message using high-level API + println!("\n--- Step 5: Bob receives the message ---"); + let received = bob_channel.receive().await.expect("Failed to receive message"); + println!("✓ Bob received message: {:?}", String::from_utf8_lossy(&received)); + + // Verify the message content + assert_eq!(received, message, "Received message should match sent message"); + println!("\n✅ High-level send/receive test passed!"); +} + +#[tokio::test] +async fn test_high_level_multiple_messages() { + println!("\n=== Test: High-level API - Multiple sequential messages ==="); + + let (alice_thin, bob_thin) = setup_clients().await.expect("Failed to setup clients"); + + let alice = PigeonholeClient::new_in_memory(alice_thin.clone()) + .expect("Failed to create Alice's PigeonholeClient"); + let bob = PigeonholeClient::new_in_memory(bob_thin.clone()) + .expect("Failed to create Bob's PigeonholeClient"); + + // Alice creates a channel and Bob imports it + let mut alice_channel = alice.create_channel("multi-msg-channel").await + .expect("Failed to create channel"); + let read_cap = alice_channel.share_read_capability(); + let mut bob_channel = bob.import_channel("multi-msg-channel", &read_cap) + .expect("Failed to import channel"); + + // Alice sends multiple messages + let messages = vec![ + b"Message 1: Hello!".to_vec(), + b"Message 2: How are you?".to_vec(), + b"Message 3: Goodbye!".to_vec(), + ]; + + println!("\n--- Alice sends {} 
messages ---", messages.len()); + for (i, msg) in messages.iter().enumerate() { + alice_channel.send(msg).await.expect("Failed to send message"); + println!("✓ Sent message {}: {:?}", i + 1, String::from_utf8_lossy(msg)); + } + + // Wait for propagation + println!("\n--- Waiting 45 seconds for message propagation ---"); + tokio::time::sleep(Duration::from_secs(45)).await; + + // Bob receives all messages in order + println!("\n--- Bob receives messages ---"); + for (i, expected_msg) in messages.iter().enumerate() { + let received = bob_channel.receive().await.expect("Failed to receive message"); + println!("✓ Received message {}: {:?}", i + 1, String::from_utf8_lossy(&received)); + assert_eq!(&received, expected_msg, "Message {} mismatch", i + 1); + } + + println!("\n✅ Multiple messages test passed!"); +} + +#[tokio::test] +async fn test_low_level_box_operations() { + println!("\n=== Test: Low-level box operations (write_box / read_box) ==="); + + let (alice_thin, bob_thin) = setup_clients().await.expect("Failed to setup clients"); + + let alice = PigeonholeClient::new_in_memory(alice_thin.clone()) + .expect("Failed to create Alice's PigeonholeClient"); + let bob = PigeonholeClient::new_in_memory(bob_thin.clone()) + .expect("Failed to create Bob's PigeonholeClient"); + + // Alice creates a channel + let alice_channel = alice.create_channel("low-level-test").await + .expect("Failed to create channel"); + + // Get the initial indices + let write_index = alice_channel.write_index().unwrap().to_vec(); + let read_cap = alice_channel.share_read_capability(); + let bob_channel = bob.import_channel("low-level-test", &read_cap) + .expect("Failed to import channel"); + + // Alice writes directly to a specific box using low-level API + println!("\n--- Alice writes to box using write_box ---"); + let message = b"Direct box write test"; + alice_channel.write_box(message, &write_index).await + .expect("Failed to write box"); + println!("✓ Alice wrote to box at index"); + + // 
Wait for propagation + println!("\n--- Waiting 30 seconds for propagation ---"); + tokio::time::sleep(Duration::from_secs(30)).await; + + // Bob reads from the specific box using low-level API + println!("\n--- Bob reads from box using read_box ---"); + let read_index = bob_channel.read_index().to_vec(); + let (received, _next_index) = bob_channel.read_box(&read_index).await + .expect("Failed to read box"); + println!("✓ Bob read from box: {:?}", String::from_utf8_lossy(&received)); + + assert_eq!(received, message, "Box content mismatch"); + println!("\n✅ Low-level box operations test passed!"); +} + +#[tokio::test] +async fn test_copy_stream_multi_payload() { + println!("\n=== Test: Copy stream with multiple payloads (add_multi_payload) ==="); + + let (alice_thin, bob_thin) = setup_clients().await.expect("Failed to setup clients"); + + let alice = PigeonholeClient::new_in_memory(alice_thin.clone()) + .expect("Failed to create Alice's PigeonholeClient"); + let bob = PigeonholeClient::new_in_memory(bob_thin.clone()) + .expect("Failed to create Bob's PigeonholeClient"); + + // Create two destination channels + let channel1 = alice.create_channel("multi-dest-1").await + .expect("Failed to create channel 1"); + let channel2 = alice.create_channel("multi-dest-2").await + .expect("Failed to create channel 2"); + + let dest1_write_cap = channel1.write_cap().unwrap().to_vec(); + let dest1_index = channel1.write_index().unwrap().to_vec(); + let dest2_write_cap = channel2.write_cap().unwrap().to_vec(); + let dest2_index = channel2.write_index().unwrap().to_vec(); + + // Bob imports both channels + let read_cap1 = channel1.share_read_capability(); + let read_cap2 = channel2.share_read_capability(); + let bob_channel1 = bob.import_channel("multi-dest-1", &read_cap1) + .expect("Failed to import channel 1"); + let bob_channel2 = bob.import_channel("multi-dest-2", &read_cap2) + .expect("Failed to import channel 2"); + + // Create payloads for each destination + let payload1 = 
b"Secret message for Channel 1"; + let payload2 = b"Secret message for Channel 2"; + + println!("\n--- Creating copy stream with multiple destinations ---"); + + // Use add_multi_payload for efficient packing + let mut builder = channel1.copy_stream_builder().await + .expect("Failed to create copy stream builder"); + + let destinations: Vec<(&[u8], &[u8], &[u8])> = vec![ + (payload1.as_slice(), &dest1_write_cap, &dest1_index), + (payload2.as_slice(), &dest2_write_cap, &dest2_index), + ]; + + builder.add_multi_payload(destinations, true).await + .expect("Failed to add multi payload"); + println!("✓ Added payloads for both destinations in single call"); + + let boxes_written = builder.finish().await + .expect("Failed to finish copy stream"); + println!("✓ Copy stream finished, {} boxes written", boxes_written); + + // Wait for courier to process + println!("\n--- Waiting 60 seconds for copy command execution ---"); + tokio::time::sleep(Duration::from_secs(60)).await; + + // Bob reads from both channels + println!("\n--- Bob reads from Channel 1 ---"); + let (received1, _) = bob_channel1.read_box(bob_channel1.read_index()).await + .expect("Failed to read from channel 1"); + println!("✓ Channel 1: {:?}", String::from_utf8_lossy(&received1)); + + println!("\n--- Bob reads from Channel 2 ---"); + let (received2, _) = bob_channel2.read_box(bob_channel2.read_index()).await + .expect("Failed to read from channel 2"); + println!("✓ Channel 2: {:?}", String::from_utf8_lossy(&received2)); + + // Verify + assert_eq!(received1, payload1.to_vec(), "Channel 1 payload mismatch"); + assert_eq!(received2, payload2.to_vec(), "Channel 2 payload mismatch"); + + println!("\n✅ Multi-payload copy stream test passed!"); +} + +#[tokio::test] +async fn test_tombstone_single_box() { + println!("\n=== Test: Tombstoning a single box ==="); + + let (alice_thin, bob_thin) = setup_clients().await.expect("Failed to setup clients"); + + let alice = PigeonholeClient::new_in_memory(alice_thin.clone()) 
+ .expect("Failed to create Alice's PigeonholeClient"); + let bob = PigeonholeClient::new_in_memory(bob_thin.clone()) + .expect("Failed to create Bob's PigeonholeClient"); + + // Alice creates a channel + let mut alice_channel = alice.create_channel("tombstone-test").await + .expect("Failed to create channel"); + let read_cap = alice_channel.share_read_capability(); + let mut bob_channel = bob.import_channel("tombstone-test", &read_cap) + .expect("Failed to import channel"); + + // Step 1: Alice sends a message + println!("\n--- Step 1: Alice sends a message ---"); + let message = b"This message will be tombstoned"; + alice_channel.send(message).await.expect("Failed to send message"); + println!("✓ Alice sent message"); + + // Wait for propagation + println!("\n--- Waiting 30 seconds for propagation ---"); + tokio::time::sleep(Duration::from_secs(30)).await; + + // Step 2: Bob reads the message + println!("\n--- Step 2: Bob reads the message ---"); + let received = bob_channel.receive().await.expect("Failed to receive"); + println!("✓ Bob received: {:?}", String::from_utf8_lossy(&received)); + assert_eq!(received, message); + + // Step 3: Alice tombstones the box at the first index + println!("\n--- Step 3: Alice tombstones the box ---"); + let first_index = read_cap.start_index.clone(); + alice_channel.tombstone_at(&first_index).await + .expect("Failed to tombstone"); + println!("✓ Alice tombstoned the box"); + + // Wait for tombstone propagation + println!("\n--- Waiting 60 seconds for tombstone propagation ---"); + tokio::time::sleep(Duration::from_secs(60)).await; + + // Step 4: Bob reads again and sees tombstone + println!("\n--- Step 4: Bob reads the tombstoned box ---"); + bob_channel.refresh().expect("Failed to refresh"); + // Reset read index to re-read the same box + let first_index = read_cap.start_index.clone(); + let (tombstone_content, _) = bob_channel.read_box(&first_index).await + .expect("Failed to read tombstoned box"); + + // A tombstone is an 
empty payload with a valid signature + assert!( + tombstone_content.is_empty(), + "Expected tombstone (empty payload), got {} bytes", + tombstone_content.len() + ); + println!("✓ Bob verified tombstone (content is empty)"); + + println!("\n✅ Tombstone single box test passed!"); +} + +#[tokio::test] +async fn test_tombstone_range() { + println!("\n=== Test: Tombstoning a range of boxes ==="); + + let (alice_thin, bob_thin) = setup_clients().await.expect("Failed to setup clients"); + + let alice = PigeonholeClient::new_in_memory(alice_thin.clone()) + .expect("Failed to create Alice's PigeonholeClient"); + let bob = PigeonholeClient::new_in_memory(bob_thin.clone()) + .expect("Failed to create Bob's PigeonholeClient"); + + // Alice creates a channel + let mut alice_channel = alice.create_channel("tombstone-range-test").await + .expect("Failed to create channel"); + let read_cap = alice_channel.share_read_capability(); + let first_index = read_cap.start_index.clone(); + let bob_channel = bob.import_channel("tombstone-range-test", &read_cap) + .expect("Failed to import channel"); + + // Step 1: Alice sends multiple messages + let num_messages = 3; + println!("\n--- Step 1: Alice sends {} messages ---", num_messages); + for i in 0..num_messages { + let msg = format!("Message {} to be tombstoned", i + 1); + alice_channel.send(msg.as_bytes()).await.expect("Failed to send"); + println!("✓ Sent message {}", i + 1); + } + + // Wait for propagation + println!("\n--- Waiting 45 seconds for propagation ---"); + tokio::time::sleep(Duration::from_secs(45)).await; + + // Step 2: Verify Bob can read all messages + println!("\n--- Step 2: Verify Bob can read messages ---"); + let mut bob_channel = bob_channel; // Make mutable for receive + for i in 0..num_messages { + let received = bob_channel.receive().await.expect("Failed to receive"); + println!("✓ Read message {}: {:?}", i + 1, String::from_utf8_lossy(&received)); + } + + // Step 3: Alice tombstones the range starting from the 
first index + println!("\n--- Step 3: Alice tombstones {} boxes ---", num_messages); + alice_channel.tombstone_from(&first_index, num_messages).await + .expect("Failed to tombstone range"); + println!("✓ Alice sent tombstone range"); + + // Wait for tombstone propagation + println!("\n--- Waiting 60 seconds for tombstone propagation ---"); + tokio::time::sleep(Duration::from_secs(60)).await; + + // Step 4: Verify all boxes are tombstoned + println!("\n--- Step 4: Verify all boxes are tombstoned ---"); + let mut current_index = first_index; + for i in 0..num_messages { + let (content, next_idx) = bob_channel.read_box(&current_index).await + .expect("Failed to read box"); + // A tombstone is an empty payload with a valid signature + assert!( + content.is_empty(), + "Box {} should be tombstoned (empty), got {} bytes", i + 1, content.len() + ); + println!("✓ Box {} is tombstoned", i + 1); + + if i < num_messages - 1 { + current_index = next_idx; + } + } + + println!("\n✅ Tombstone range test passed!"); +} + +#[tokio::test] +async fn test_stream_buffer_set_and_restore() { + println!("\n=== Test: Set stream buffer for recovery ==="); + + let (alice_thin, _bob_thin) = setup_clients().await.expect("Failed to setup clients"); + + // Generate a stream ID + let stream_id = ThinClient::new_stream_id(); + println!("Using stream_id: {:?}", &stream_id[..4]); + + // Set a buffer state (simulating restoration from persisted state) + let test_buffer = b"test buffer data for crash recovery".to_vec(); + + println!("Setting buffer: {} bytes", test_buffer.len()); + alice_thin.set_stream_buffer(&stream_id, test_buffer.clone()).await + .expect("Failed to set stream buffer"); + println!("✓ Buffer set successfully - encoder created/updated in daemon"); + + println!("\n✅ Set stream buffer test passed!"); +} + +#[tokio::test] +async fn test_stream_buffer_returned_from_payload() { + println!("\n=== Test: Buffer state returned from create_courier_envelopes_from_payload ==="); + + let (alice_thin,
_bob_thin) = setup_clients().await.expect("Failed to setup clients"); + + // Create a channel for the write capability + let alice = katzenpost_thin_client::persistent::PigeonholeClient::new_in_memory(alice_thin.clone()) + .expect("Failed to create Alice's PigeonholeClient"); + + let alice_channel = alice.create_channel("buffer-test").await + .expect("Failed to create channel"); + let write_cap = alice_channel.write_cap().expect("Channel should have write cap").to_vec(); + let start_index = alice_channel.write_index().expect("Channel should have write index").to_vec(); + + // Create envelopes with is_last=false to trigger buffering + let stream_id = ThinClient::new_stream_id(); + let payload = b"Test payload data for buffering".to_vec(); + + let result = alice_thin.create_courier_envelopes_from_payload( + &stream_id, + &payload, + &write_cap, + &start_index, + false, // is_last=false triggers buffering + ).await.expect("Failed to create envelopes"); + + println!("✓ Got {} envelopes", result.envelopes.len()); + println!("✓ Buffer: {} bytes", result.buffer.len()); + + // The buffer should be available for persistence + // (actual buffer contents depend on payload size vs geometry) + println!("\n✅ Buffer returned from payload test passed!"); +} + +#[tokio::test] +async fn test_stream_buffer_recovery_workflow() { + println!("\n=== Test: Stream buffer crash recovery workflow ==="); + + let (alice_thin, _bob_thin) = setup_clients().await.expect("Failed to setup clients"); + + // Step 1: Alice creates a channel and gets write capability + println!("\n--- Step 1: Setup channel ---"); + let alice = katzenpost_thin_client::persistent::PigeonholeClient::new_in_memory(alice_thin.clone()) + .expect("Failed to create Alice's PigeonholeClient"); + + let alice_channel = alice.create_channel("recovery-test").await + .expect("Failed to create channel"); + let write_cap = alice_channel.write_cap().expect("Channel should have write cap").to_vec(); + let start_index = 
alice_channel.write_index().expect("Channel should have write index").to_vec(); + println!("✓ Channel created"); + + // Step 2: Start a stream with is_last=false (simulating partial write) + println!("\n--- Step 2: Start streaming with is_last=false ---"); + let stream_id = ThinClient::new_stream_id(); + let first_payload = b"First chunk of data for crash recovery test".to_vec(); + + let result = alice_thin.create_courier_envelopes_from_payload( + &stream_id, + &first_payload, + &write_cap, + &start_index, + false, // is_last=false, so buffer will be retained + ).await.expect("Failed to create envelopes"); + println!("✓ First chunk written with is_last=false"); + println!(" Envelopes: {}, Buffer: {} bytes", + result.envelopes.len(), result.buffer.len()); + + // Step 3: Save the buffer (simulating checkpoint before crash) + println!("\n--- Step 3: Checkpoint - save buffer ---"); + let saved_buffer = result.buffer.clone(); + println!("✓ Saved buffer: {} bytes", saved_buffer.len()); + + // Step 4: Simulate restart by setting buffer on a "new" stream + // In real crash recovery, this would be a new client instance + println!("\n--- Step 4: Restore buffer (simulating restart) ---"); + let new_stream_id = ThinClient::new_stream_id(); + alice_thin.set_stream_buffer( + &new_stream_id, + saved_buffer.clone(), + ).await.expect("Failed to restore stream buffer"); + println!("✓ Buffer restored to new stream"); + + // Step 5: Continue the stream with more data and finish + println!("\n--- Step 5: Continue stream and finalize ---"); + let second_payload = b"Second chunk completing the stream".to_vec(); + let final_result = alice_thin.create_courier_envelopes_from_payload( + &new_stream_id, + &second_payload, + &write_cap, + &start_index, + true, // is_last=true to finalize + ).await.expect("Failed to finalize stream"); + + println!("✓ Stream finalized with {} envelopes", final_result.envelopes.len()); + println!("✓ Final buffer: {} bytes (should be 0 after flush)", + 
final_result.buffer.len()); + + println!("\n✅ Stream buffer crash recovery workflow test passed!"); +} + +#[tokio::test] +async fn test_read_box_no_retry() { + println!("\n=== Test: read_box_no_retry returns immediate error for non-existent box ==="); + + let (alice_thin, _bob_thin) = setup_clients().await.expect("Failed to setup clients"); + + let alice = PigeonholeClient::new_in_memory(alice_thin.clone()) + .expect("Failed to create Alice's PigeonholeClient"); + + // Create a channel + let alice_channel = alice.create_channel("read-no-retry-test").await + .expect("Failed to create channel"); + let read_cap = alice_channel.share_read_capability(); + + // Try to read from a box that doesn't exist yet (nothing was written) + // With no_retry, this should fail immediately with BoxNotFound + println!("\n--- Attempting read_box_no_retry on empty box ---"); + let result = alice_channel.read_box_no_retry(&read_cap.start_index).await; + + match result { + Err(e) => { + let err_str = format!("{:?}", e); + println!("✓ Got expected error: {}", err_str); + assert!( + err_str.contains("BoxNotFound") || err_str.contains("box id not found"), + "Expected BoxNotFound error, got: {}", err_str + ); + } + Ok(_) => { + panic!("Expected BoxNotFound error, but read succeeded"); + } + } + + println!("\n✅ read_box_no_retry test passed!"); +} + +#[tokio::test] +async fn test_receive_no_retry() { + println!("\n=== Test: receive_no_retry returns immediate error for non-existent message ==="); + + let (alice_thin, _bob_thin) = setup_clients().await.expect("Failed to setup clients"); + + let alice = PigeonholeClient::new_in_memory(alice_thin.clone()) + .expect("Failed to create Alice's PigeonholeClient"); + + // Create a channel + let mut alice_channel = alice.create_channel("receive-no-retry-test").await + .expect("Failed to create channel"); + + // Try to receive when nothing was sent + // With no_retry, this should fail immediately with BoxNotFound + println!("\n--- Attempting 
receive_no_retry on empty channel ---"); + let result = alice_channel.receive_no_retry().await; + + match result { + Err(e) => { + let err_str = format!("{:?}", e); + println!("✓ Got expected error: {}", err_str); + assert!( + err_str.contains("BoxNotFound") || err_str.contains("box id not found"), + "Expected BoxNotFound error, got: {}", err_str + ); + } + Ok(_) => { + panic!("Expected BoxNotFound error, but receive succeeded"); + } + } + + println!("\n✅ receive_no_retry test passed!"); +} + +#[tokio::test] +async fn test_write_box_return_box_exists() { + println!("\n=== Test: write_box_return_box_exists returns error on duplicate write ==="); + + let (alice_thin, _bob_thin) = setup_clients().await.expect("Failed to setup clients"); + + let alice = PigeonholeClient::new_in_memory(alice_thin.clone()) + .expect("Failed to create Alice's PigeonholeClient"); + + // Create a channel + let alice_channel = alice.create_channel("write-box-exists-test").await + .expect("Failed to create channel"); + let start_index = alice_channel.read_index().to_vec(); + + // First write should succeed + println!("\n--- First write_box ---"); + let message1 = b"First message"; + alice_channel.write_box(message1, &start_index).await + .expect("First write should succeed"); + println!("✓ First write succeeded"); + + // Wait for propagation + println!("\n--- Waiting 30 seconds for propagation ---"); + tokio::time::sleep(Duration::from_secs(30)).await; + + // Second write to same index with return_box_exists should fail + println!("\n--- Second write_box_return_box_exists to same index ---"); + let message2 = b"Second message"; + let result = alice_channel.write_box_return_box_exists(message2, &start_index).await; + + match result { + Err(e) => { + let err_str = format!("{:?}", e); + println!("✓ Got expected error: {}", err_str); + assert!( + err_str.contains("BoxAlreadyExists") || err_str.contains("box already exists"), + "Expected BoxAlreadyExists error, got: {}", err_str + ); + } + Ok(_) 
=> { + panic!("Expected BoxAlreadyExists error, but write succeeded"); + } + } + + println!("\n✅ write_box_return_box_exists test passed!"); +} + +#[tokio::test] +async fn test_send_return_box_exists() { + println!("\n=== Test: send_return_box_exists returns error on duplicate send ==="); + + let (alice_thin, bob_thin) = setup_clients().await.expect("Failed to setup clients"); + + let alice = PigeonholeClient::new_in_memory(alice_thin.clone()) + .expect("Failed to create Alice's PigeonholeClient"); + let bob = PigeonholeClient::new_in_memory(bob_thin.clone()) + .expect("Failed to create Bob's PigeonholeClient"); + + // Create a channel + let mut alice_channel = alice.create_channel("send-box-exists-test").await + .expect("Failed to create channel"); + let read_cap = alice_channel.share_read_capability(); + let _bob_channel = bob.import_channel("send-box-exists-test", &read_cap) + .expect("Failed to import channel"); + + // First send should succeed + println!("\n--- First send ---"); + let message1 = b"First message"; + alice_channel.send(message1).await + .expect("First send should succeed"); + println!("✓ First send succeeded"); + + // Wait for propagation + println!("\n--- Waiting 30 seconds for propagation ---"); + tokio::time::sleep(Duration::from_secs(30)).await; + + // Now manually write to the CURRENT write index (which was just advanced) + // to set up a conflict scenario. 
We need to use write_box to write at the + // new write_index, then try send_return_box_exists which will try to write there + let current_write_index = alice_channel.write_index().unwrap().to_vec(); + println!("\n--- Writing directly to current write index to create conflict ---"); + alice_channel.write_box(b"Conflict message", &current_write_index).await + .expect("Direct write should succeed"); + println!("✓ Conflict message written"); + + // Wait for propagation + println!("\n--- Waiting 30 seconds for propagation ---"); + tokio::time::sleep(Duration::from_secs(30)).await; + + // Now send_return_box_exists should fail because the box is occupied + println!("\n--- Attempting send_return_box_exists ---"); + let result = alice_channel.send_return_box_exists(b"This should fail").await; + + match result { + Err(e) => { + let err_str = format!("{:?}", e); + println!("✓ Got expected error: {}", err_str); + assert!( + err_str.contains("BoxAlreadyExists") || err_str.contains("box already exists"), + "Expected BoxAlreadyExists error, got: {}", err_str + ); + } + Ok(_) => { + panic!("Expected BoxAlreadyExists error, but send succeeded"); + } + } + + println!("\n✅ send_return_box_exists test passed!"); +} diff --git a/tests/test_channel_api.py b/tests/test_channel_api.py deleted file mode 100644 index 0e05ba5..0000000 --- a/tests/test_channel_api.py +++ /dev/null @@ -1,345 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-FileCopyrightText: Copyright (C) 2025 David Stainton -# SPDX-License-Identifier: AGPL-3.0-only - -""" -Channel API integration tests for the Python thin client. - -These tests mirror the Rust tests in channel_api_test.rs and require -a running mixnet with client daemon for integration testing.
-""" - -import asyncio -import pytest -from katzenpost_thinclient import ThinClient, Config - - -async def setup_thin_client(): - """Test helper to setup a thin client for integration tests.""" - config = Config("testdata/thinclient.toml") - client = ThinClient(config) - - # Start the client and wait a bit for initial connection and PKI document - loop = asyncio.get_running_loop() - await client.start(loop) - await asyncio.sleep(2) - - return client - - -@pytest.mark.asyncio -async def test_channel_api_basics(): - """ - Test basic channel API operations - equivalent to TestChannelAPIBasics from Rust. - This test demonstrates the full channel workflow: Alice creates a write channel, - Bob creates a read channel, Alice writes messages, Bob reads them back. - """ - alice_thin_client = await setup_thin_client() - bob_thin_client = await setup_thin_client() - - # Wait for PKI documents to be available and connection to mixnet - print("Waiting for daemon to connect to mixnet...") - attempts = 0 - while not alice_thin_client.is_connected() and attempts < 30: - await asyncio.sleep(1) - attempts += 1 - - if not alice_thin_client.is_connected(): - raise Exception("Daemon failed to connect to mixnet within 30 seconds") - - print("✅ Daemon connected to mixnet, using current PKI document") - - # Alice creates write channel - print("Alice: Creating write channel") - alice_channel_id, read_cap, _write_cap = await alice_thin_client.create_write_channel() - print(f"Alice: Created write channel {alice_channel_id}") - - # Bob creates read channel using the read capability from Alice's write channel - print("Bob: Creating read channel") - bob_channel_id = await bob_thin_client.create_read_channel(read_cap) - print(f"Bob: Created read channel {bob_channel_id}") - - # Alice writes first message - original_message = b"hello1" - print("Alice: Writing first message and waiting for completion") - - write_reply1 = await alice_thin_client.write_channel(alice_channel_id, original_message) - 
print("Alice: Write operation completed successfully") - - # Get the courier service from PKI - courier_service = alice_thin_client.get_service("courier") - dest_node, dest_queue = courier_service.to_destination() - - alice_message_id1 = ThinClient.new_message_id() - - _reply1 = await alice_thin_client.send_channel_query_await_reply( - alice_channel_id, - write_reply1.send_message_payload, - dest_node, - dest_queue, - alice_message_id1 - ) - - # Alice writes a second message - second_message = b"hello2" - print("Alice: Writing second message and waiting for completion") - - write_reply2 = await alice_thin_client.write_channel(alice_channel_id, second_message) - print("Alice: Second write operation completed successfully") - - alice_message_id2 = ThinClient.new_message_id() - - _reply2 = await alice_thin_client.send_channel_query_await_reply( - alice_channel_id, - write_reply2.send_message_payload, - dest_node, - dest_queue, - alice_message_id2 - ) - - # Wait for message propagation to storage replicas - print("Waiting for message propagation to storage replicas") - await asyncio.sleep(10) - - # Bob reads first message - print("Bob: Reading first message") - read_reply1 = await bob_thin_client.read_channel(bob_channel_id, None, None) - - bob_message_id1 = ThinClient.new_message_id() - - # In a real implementation, you'd retry the send_channel_query_await_reply until you get a response - bob_reply_payload1 = b"" - for i in range(10): - try: - payload = await alice_thin_client.send_channel_query_await_reply( - bob_channel_id, - read_reply1.send_message_payload, - dest_node, - dest_queue, - bob_message_id1 - ) - if payload: - bob_reply_payload1 = payload - break - else: - print(f"Bob: Read attempt {i + 1} returned empty payload, retrying...") - await asyncio.sleep(0.5) - except Exception as e: - raise e - - assert original_message == bob_reply_payload1, "Bob: Reply payload mismatch" - - # Bob closes and resumes read channel to advance to second message - await 
bob_thin_client.close_channel(bob_channel_id) - - print("Bob: Resuming read channel to read second message") - bob_channel_id = await bob_thin_client.resume_read_channel( - read_cap, - read_reply1.next_message_index, - read_reply1.reply_index - ) - - # Bob reads second message - print("Bob: Reading second message") - read_reply2 = await bob_thin_client.read_channel(bob_channel_id, None, None) - - bob_message_id2 = ThinClient.new_message_id() - bob_reply_payload2 = b"" - - for i in range(10): - print(f"Bob: second read attempt {i + 1}") - try: - payload = await alice_thin_client.send_channel_query_await_reply( - bob_channel_id, - read_reply2.send_message_payload, - dest_node, - dest_queue, - bob_message_id2 - ) - if payload: - bob_reply_payload2 = payload - break - else: - await asyncio.sleep(0.5) - except Exception as e: - raise e - - assert second_message == bob_reply_payload2, "Bob: Second reply payload mismatch" - - # Clean up channels - await alice_thin_client.close_channel(alice_channel_id) - await bob_thin_client.close_channel(bob_channel_id) - - alice_thin_client.stop() - bob_thin_client.stop() - - print("✅ Channel API basics test completed successfully") - - -@pytest.mark.asyncio -async def test_resume_write_channel(): - """ - Test resuming a write channel - equivalent to TestResumeWriteChannel from Rust. - This test demonstrates the write channel resumption workflow: - 1. Create a write channel - 2. Write the first message onto the channel - 3. Close the channel - 4. Resume the channel - 5. Write the second message onto the channel - 6. Create a read channel - 7. Read first and second message from the channel - 8. 
Verify payloads match - """ - alice_thin_client = await setup_thin_client() - bob_thin_client = await setup_thin_client() - - # Wait for PKI documents to be available and connection to mixnet - print("Waiting for daemon to connect to mixnet...") - attempts = 0 - while not alice_thin_client.is_connected() and attempts < 30: - await asyncio.sleep(1) - attempts += 1 - - if not alice_thin_client.is_connected(): - raise Exception("Daemon failed to connect to mixnet within 30 seconds") - - print("✅ Daemon connected to mixnet, using current PKI document") - - # Alice creates write channel - print("Alice: Creating write channel") - alice_channel_id, read_cap, write_cap = await alice_thin_client.create_write_channel() - print(f"Alice: Created write channel {alice_channel_id}") - - # Alice writes first message - alice_payload1 = b"Hello, Bob!" - print("Alice: Writing first message") - write_reply1 = await alice_thin_client.write_channel(alice_channel_id, alice_payload1) - - # Get courier destination - dest_node, dest_queue = await alice_thin_client.get_courier_destination() - alice_message_id1 = ThinClient.new_message_id() - - # Send first message - _reply1 = await alice_thin_client.send_channel_query_await_reply( - alice_channel_id, - write_reply1.send_message_payload, - dest_node, - dest_queue, - alice_message_id1 - ) - - print("Waiting for first message propagation to storage replicas") - await asyncio.sleep(3) - - # Close the channel - await alice_thin_client.close_channel(alice_channel_id) - - # Resume the write channel - print("Alice: Resuming write channel") - alice_channel_id = await alice_thin_client.resume_write_channel( - write_cap, - write_reply1.next_message_index - ) - print(f"Alice: Resumed write channel with ID {alice_channel_id}") - - # Write second message after resume - print("Alice: Writing second message after resume") - alice_payload2 = b"Second message from Alice!" 
- write_reply2 = await alice_thin_client.write_channel(alice_channel_id, alice_payload2) - - alice_message_id2 = ThinClient.new_message_id() - _reply2 = await alice_thin_client.send_channel_query_await_reply( - alice_channel_id, - write_reply2.send_message_payload, - dest_node, - dest_queue, - alice_message_id2 - ) - print("Alice: Second write operation completed successfully") - - print("Waiting for second message propagation to storage replicas") - await asyncio.sleep(3) - - # Bob creates read channel - print("Bob: Creating read channel") - bob_channel_id = await bob_thin_client.create_read_channel(read_cap) - print(f"Bob: Created read channel {bob_channel_id}") - - # Bob reads first message - print("Bob: Reading first message") - read_reply1 = await bob_thin_client.read_channel(bob_channel_id, None, None) - - bob_message_id1 = ThinClient.new_message_id() - bob_reply_payload1 = b"" - - for i in range(10): - try: - payload = await alice_thin_client.send_channel_query_await_reply( - bob_channel_id, - read_reply1.send_message_payload, - dest_node, - dest_queue, - bob_message_id1 - ) - if payload: - bob_reply_payload1 = payload - break - else: - print(f"Bob: First read attempt {i + 1} returned empty payload, retrying...") - await asyncio.sleep(0.5) - except Exception as e: - raise e - - assert alice_payload1 == bob_reply_payload1, "Bob: First message payload mismatch" - - # Bob closes and resumes read channel to advance to second message - await bob_thin_client.close_channel(bob_channel_id) - - print("Bob: Resuming read channel to read second message") - bob_channel_id = await bob_thin_client.resume_read_channel( - read_cap, - read_reply1.next_message_index, - read_reply1.reply_index - ) - - # Bob reads second message - print("Bob: Reading second message") - read_reply2 = await bob_thin_client.read_channel(bob_channel_id, None, None) - - bob_message_id2 = ThinClient.new_message_id() - bob_reply_payload2 = b"" - - for i in range(10): - print(f"Bob: second message read 
attempt {i + 1}") - try: - payload = await alice_thin_client.send_channel_query_await_reply( - bob_channel_id, - read_reply2.send_message_payload, - dest_node, - dest_queue, - bob_message_id2 - ) - if payload: - bob_reply_payload2 = payload - break - else: - await asyncio.sleep(0.5) - except Exception as e: - raise e - - # Verify the second message content matches - assert alice_payload2 == bob_reply_payload2, "Bob: Second message payload mismatch" - print("Bob: Successfully received and verified second message") - - # Clean up channels - await alice_thin_client.close_channel(alice_channel_id) - await bob_thin_client.close_channel(bob_channel_id) - - alice_thin_client.stop() - bob_thin_client.stop() - - print("✅ Resume write channel test completed successfully") - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/test_channel_api_extended.py b/tests/test_channel_api_extended.py deleted file mode 100644 index 14b6304..0000000 --- a/tests/test_channel_api_extended.py +++ /dev/null @@ -1,492 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-FileCopyrightText: Copyright (C) 2025 David Stainton -# SPDX-License-Identifier: AGPL-3.0-only - -""" -Extended channel API integration tests for the Python thin client. -These tests cover the more advanced channel resumption scenarios. -""" - -import asyncio -import pytest -from katzenpost_thinclient import ThinClient, Config - - -async def setup_thin_client(): - """Test helper to setup a thin client for integration tests.""" - config = Config("testdata/thinclient.toml") - client = ThinClient(config) - - # Start the client and wait a bit for initial connection and PKI document - loop = asyncio.get_running_loop() - await client.start(loop) - await asyncio.sleep(2) - - return client - - -@pytest.mark.asyncio -async def test_resume_write_channel_query(): - """ - Test resuming a write channel with query state - equivalent to TestResumeWriteChannelQuery from Rust. 
- This test demonstrates the write channel query resumption workflow: - 1. Create write channel - 2. Create first write query message but do not send to channel yet - 3. Close channel - 4. Resume write channel with query via resume_write_channel_query - 5. Send resumed write query to channel - 6. Send second message to channel - 7. Create read channel - 8. Read both messages from channel - 9. Verify payloads match - """ - alice_thin_client = await setup_thin_client() - bob_thin_client = await setup_thin_client() - - # Wait for PKI documents to be available and connection to mixnet - print("Waiting for daemon to connect to mixnet...") - attempts = 0 - while not alice_thin_client.is_connected() and attempts < 30: - await asyncio.sleep(1) - attempts += 1 - - if not alice_thin_client.is_connected(): - raise Exception("Daemon failed to connect to mixnet within 30 seconds") - - print("✅ Daemon connected to mixnet, using current PKI document") - - # Alice creates write channel - print("Alice: Creating write channel") - alice_channel_id, read_cap, write_cap = await alice_thin_client.create_write_channel() - print(f"Alice: Created write channel {alice_channel_id}") - - # Alice prepares first message but doesn't send it yet - alice_payload1 = b"Hello, Bob!" 
- write_reply = await alice_thin_client.write_channel(alice_channel_id, alice_payload1) - - # Get courier destination - courier_node, courier_queue_id = await alice_thin_client.get_courier_destination() - alice_message_id1 = ThinClient.new_message_id() - - # Close the channel immediately (like in Rust test - no waiting for propagation) - await alice_thin_client.close_channel(alice_channel_id) - - # Resume the write channel with query state using current_message_index like Rust test - print("Alice: Resuming write channel") - alice_channel_id = await alice_thin_client.resume_write_channel_query( - write_cap, - write_reply.current_message_index, # Use current_message_index like in Rust test - write_reply.envelope_descriptor, - write_reply.envelope_hash - ) - print(f"Alice: Resumed write channel with ID {alice_channel_id}") - - # Send the first message after resume - print("Alice: Writing first message after resume") - _reply1 = await alice_thin_client.send_channel_query_await_reply( - alice_channel_id, - write_reply.send_message_payload, - courier_node, - courier_queue_id, - alice_message_id1 - ) - - # Write second message - print("Alice: Writing second message") - alice_payload2 = b"Second message from Alice!" 
- write_reply2 = await alice_thin_client.write_channel(alice_channel_id, alice_payload2) - - alice_message_id2 = ThinClient.new_message_id() - _reply2 = await alice_thin_client.send_channel_query_await_reply( - alice_channel_id, - write_reply2.send_message_payload, - courier_node, - courier_queue_id, - alice_message_id2 - ) - print("Alice: Second write operation completed successfully") - - print("Waiting for second message propagation to storage replicas") - await asyncio.sleep(3) - - # Bob creates read channel - print("Bob: Creating read channel") - bob_channel_id = await bob_thin_client.create_read_channel(read_cap) - print(f"Bob: Created read channel {bob_channel_id}") - - # Bob reads first message - print("Bob: Reading first message") - read_reply1 = await bob_thin_client.read_channel(bob_channel_id, None, None) - - bob_message_id1 = ThinClient.new_message_id() - bob_reply_payload1 = b"" - - for i in range(10): - try: - payload = await alice_thin_client.send_channel_query_await_reply( - bob_channel_id, - read_reply1.send_message_payload, - courier_node, - courier_queue_id, - bob_message_id1 - ) - if payload: - bob_reply_payload1 = payload - break - else: - print(f"Bob: First read attempt {i + 1} returned empty payload, retrying...") - await asyncio.sleep(0.5) - except Exception as e: - raise e - - assert alice_payload1 == bob_reply_payload1, "Bob: First message payload mismatch" - - # Bob reads second message - print("Bob: Reading second message") - read_reply2 = await bob_thin_client.read_channel(bob_channel_id, None, None) - - bob_message_id2 = ThinClient.new_message_id() - bob_reply_payload2 = b"" - - for i in range(10): - print(f"Bob: second message read attempt {i + 1}") - try: - payload = await alice_thin_client.send_channel_query_await_reply( - bob_channel_id, - read_reply2.send_message_payload, - courier_node, - courier_queue_id, - bob_message_id2 - ) - if payload: - bob_reply_payload2 = payload - break - else: - await asyncio.sleep(0.5) - except 
Exception as e: - raise e - - # Verify the second message content matches - assert alice_payload2 == bob_reply_payload2, "Bob: Second message payload mismatch" - print("Bob: Successfully received and verified second message") - - # Clean up channels - await alice_thin_client.close_channel(alice_channel_id) - await bob_thin_client.close_channel(bob_channel_id) - - alice_thin_client.stop() - bob_thin_client.stop() - - print("✅ Resume write channel query test completed successfully") - - -@pytest.mark.asyncio -async def test_resume_read_channel(): - """ - Test resuming a read channel - equivalent to TestResumeReadChannel from Rust. - This test demonstrates the read channel resumption workflow: - 1. Create a write channel - 2. Write two messages to the channel - 3. Create a read channel - 4. Read the first message from the channel - 5. Verify payload matches - 6. Close the read channel - 7. Resume the read channel - 8. Read the second message from the channel - 9. Verify payload matches - """ - alice_thin_client = await setup_thin_client() - bob_thin_client = await setup_thin_client() - - # Wait for PKI documents to be available and connection to mixnet - print("Waiting for daemon to connect to mixnet...") - attempts = 0 - while not alice_thin_client.is_connected() and attempts < 30: - await asyncio.sleep(1) - attempts += 1 - - if not alice_thin_client.is_connected(): - raise Exception("Daemon failed to connect to mixnet within 30 seconds") - - print("✅ Daemon connected to mixnet, using current PKI document") - - # Alice creates write channel - print("Alice: Creating write channel") - alice_channel_id, read_cap, _write_cap = await alice_thin_client.create_write_channel() - print(f"Alice: Created write channel {alice_channel_id}") - - # Alice writes first message - alice_payload1 = b"Hello, Bob!" 
- write_reply1 = await alice_thin_client.write_channel(alice_channel_id, alice_payload1) - - dest_node, dest_queue = await alice_thin_client.get_courier_destination() - alice_message_id1 = ThinClient.new_message_id() - - _reply1 = await alice_thin_client.send_channel_query_await_reply( - alice_channel_id, - write_reply1.send_message_payload, - dest_node, - dest_queue, - alice_message_id1 - ) - - print("Waiting for first message propagation to storage replicas") - await asyncio.sleep(3) - - # Alice writes second message - print("Alice: Writing second message") - alice_payload2 = b"Second message from Alice!" - write_reply2 = await alice_thin_client.write_channel(alice_channel_id, alice_payload2) - - alice_message_id2 = ThinClient.new_message_id() - _reply2 = await alice_thin_client.send_channel_query_await_reply( - alice_channel_id, - write_reply2.send_message_payload, - dest_node, - dest_queue, - alice_message_id2 - ) - print("Alice: Second write operation completed successfully") - - print("Waiting for second message propagation to storage replicas") - await asyncio.sleep(3) - - # Bob creates read channel - print("Bob: Creating read channel") - bob_channel_id = await bob_thin_client.create_read_channel(read_cap) - print(f"Bob: Created read channel {bob_channel_id}") - - # Bob reads first message - print("Bob: Reading first message") - read_reply1 = await bob_thin_client.read_channel(bob_channel_id, None, None) - - bob_message_id1 = ThinClient.new_message_id() - bob_reply_payload1 = b"" - - for i in range(10): - try: - payload = await alice_thin_client.send_channel_query_await_reply( - bob_channel_id, - read_reply1.send_message_payload, - dest_node, - dest_queue, - bob_message_id1 - ) - if payload: - bob_reply_payload1 = payload - break - else: - print(f"Bob: First read attempt {i + 1} returned empty payload, retrying...") - await asyncio.sleep(0.5) - except Exception as e: - raise e - - assert alice_payload1 == bob_reply_payload1, "Bob: First message payload 
mismatch" - - # Close the read channel - await bob_thin_client.close_channel(bob_channel_id) - - # Resume the read channel - print("Bob: Resuming read channel") - bob_channel_id = await bob_thin_client.resume_read_channel( - read_cap, - read_reply1.next_message_index, - read_reply1.reply_index - ) - print(f"Bob: Resumed read channel with ID {bob_channel_id}") - - # Bob reads second message - print("Bob: Reading second message") - read_reply2 = await bob_thin_client.read_channel(bob_channel_id, None, None) - - bob_message_id2 = ThinClient.new_message_id() - bob_reply_payload2 = b"" - - for i in range(10): - print(f"Bob: second message read attempt {i + 1}") - try: - payload = await alice_thin_client.send_channel_query_await_reply( - bob_channel_id, - read_reply2.send_message_payload, - dest_node, - dest_queue, - bob_message_id2 - ) - if payload: - bob_reply_payload2 = payload - break - else: - await asyncio.sleep(0.5) - except Exception as e: - raise e - - # Verify the second message content matches - assert alice_payload2 == bob_reply_payload2, "Bob: Second message payload mismatch" - print("Bob: Successfully received and verified second message") - - # Clean up channels - await alice_thin_client.close_channel(alice_channel_id) - await bob_thin_client.close_channel(bob_channel_id) - - alice_thin_client.stop() - bob_thin_client.stop() - - print("✅ Resume read channel test completed successfully") - - -@pytest.mark.asyncio -async def test_resume_read_channel_query(): - """ - Test resuming a read channel with query state - equivalent to TestResumeReadChannelQuery from Rust. - This test demonstrates the read channel query resumption workflow: - 1. Create a write channel - 2. Write two messages to the channel - 3. Create read channel - 4. Make read query but do not send it - 5. Close read channel - 6. Resume read channel query with resume_read_channel_query method - 7. Send previously made read query to channel - 8. Verify received payload matches - 9. 
Read second message from channel - 10. Verify received payload matches - """ - alice_thin_client = await setup_thin_client() - bob_thin_client = await setup_thin_client() - - # Wait for PKI documents to be available and connection to mixnet - print("Waiting for daemon to connect to mixnet...") - attempts = 0 - while not alice_thin_client.is_connected() and attempts < 30: - await asyncio.sleep(1) - attempts += 1 - - if not alice_thin_client.is_connected(): - raise Exception("Daemon failed to connect to mixnet within 30 seconds") - - print("✅ Daemon connected to mixnet, using current PKI document") - - # Alice creates write channel - print("Alice: Creating write channel") - alice_channel_id, read_cap, _write_cap = await alice_thin_client.create_write_channel() - print(f"Alice: Created write channel {alice_channel_id}") - - # Alice writes first message - alice_payload1 = b"Hello, Bob!" - write_reply1 = await alice_thin_client.write_channel(alice_channel_id, alice_payload1) - - dest_node, dest_queue = await alice_thin_client.get_courier_destination() - alice_message_id1 = ThinClient.new_message_id() - - _reply1 = await alice_thin_client.send_channel_query_await_reply( - alice_channel_id, - write_reply1.send_message_payload, - dest_node, - dest_queue, - alice_message_id1 - ) - - print("Waiting for first message propagation to storage replicas") - await asyncio.sleep(3) - - # Alice writes second message - print("Alice: Writing second message") - alice_payload2 = b"Second message from Alice!" 
- write_reply2 = await alice_thin_client.write_channel(alice_channel_id, alice_payload2) - - alice_message_id2 = ThinClient.new_message_id() - _reply2 = await alice_thin_client.send_channel_query_await_reply( - alice_channel_id, - write_reply2.send_message_payload, - dest_node, - dest_queue, - alice_message_id2 - ) - print("Alice: Second write operation completed successfully") - - print("Waiting for second message propagation to storage replicas") - await asyncio.sleep(3) - - # Bob creates read channel - print("Bob: Creating read channel") - bob_channel_id = await bob_thin_client.create_read_channel(read_cap) - print(f"Bob: Created read channel {bob_channel_id}") - - # Bob prepares first read query but doesn't send it yet - print("Bob: Reading first message") - read_reply1 = await bob_thin_client.read_channel(bob_channel_id, None, None) - - # Close the read channel - await bob_thin_client.close_channel(bob_channel_id) - - # Resume the read channel with query state - print("Bob: Resuming read channel") - bob_channel_id = await bob_thin_client.resume_read_channel_query( - read_cap, - read_reply1.current_message_index, - read_reply1.reply_index, - read_reply1.envelope_descriptor, - read_reply1.envelope_hash - ) - print(f"Bob: Resumed read channel with ID {bob_channel_id}") - - # Send the first read query and get the message payload - bob_message_id1 = ThinClient.new_message_id() - bob_reply_payload1 = b"" - - for i in range(10): - print(f"Bob: first message read attempt {i + 1}") - try: - payload = await alice_thin_client.send_channel_query_await_reply( - bob_channel_id, - read_reply1.send_message_payload, - dest_node, - dest_queue, - bob_message_id1 - ) - if payload: - bob_reply_payload1 = payload - break - else: - await asyncio.sleep(0.5) - except Exception as e: - raise e - - assert alice_payload1 == bob_reply_payload1, "Bob: First message payload mismatch" - - # Bob reads second message - print("Bob: Reading second message") - read_reply2 = await 
bob_thin_client.read_channel(bob_channel_id, None, None) - - bob_message_id2 = ThinClient.new_message_id() - bob_reply_payload2 = b"" - - for i in range(10): - print(f"Bob: second message read attempt {i + 1}") - try: - payload = await alice_thin_client.send_channel_query_await_reply( - bob_channel_id, - read_reply2.send_message_payload, - dest_node, - dest_queue, - bob_message_id2 - ) - if payload: - bob_reply_payload2 = payload - break - else: - await asyncio.sleep(0.5) - except Exception as e: - raise e - - # Verify the second message content matches - assert alice_payload2 == bob_reply_payload2, "Bob: Second message payload mismatch" - print("Bob: Successfully received and verified second message") - - # Clean up channels - await alice_thin_client.close_channel(alice_channel_id) - await bob_thin_client.close_channel(bob_channel_id) - - alice_thin_client.stop() - bob_thin_client.stop() - - print("✅ Resume read channel query test completed successfully") diff --git a/tests/test_core.py b/tests/test_core.py index 070ccaa..efa4c4b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -37,6 +37,21 @@ async def test_thin_client_send_receive_integration_test(): try: await client.start(loop) + # Wait for daemon to connect to mixnet and receive PKI document + print("Waiting for daemon to connect to mixnet...") + attempts = 0 + while (not client.is_connected() or client.pki_document() is None) and attempts < 30: + await asyncio.sleep(1) + attempts += 1 + + if not client.is_connected(): + raise Exception("Daemon failed to connect to mixnet within 30 seconds") + + if client.pki_document() is None: + raise Exception("PKI document not received within 30 seconds") + + print("✅ Daemon connected to mixnet, using current PKI document") + service_desc = client.get_service("echo") surb_id = client.new_surb_id() payload = "hello" @@ -78,3 +93,236 @@ async def dummy_callback(event): assert cfg_with_callbacks is not None, "Config with callbacks should work" # Configuration 
validation passed + + +def test_error_codes_completeness(): + """ + Test that all error codes 0-24 are defined and have corresponding error strings. + + This is a unit test that doesn't require a daemon connection. + It verifies error code consistency between constants and the error string function. + """ + from katzenpost_thinclient import ( + THIN_CLIENT_SUCCESS, + THIN_CLIENT_ERROR_CONNECTION_LOST, + THIN_CLIENT_ERROR_TIMEOUT, + THIN_CLIENT_ERROR_INVALID_REQUEST, + THIN_CLIENT_ERROR_INTERNAL_ERROR, + THIN_CLIENT_ERROR_MAX_RETRIES, + THIN_CLIENT_ERROR_INVALID_CHANNEL, + THIN_CLIENT_ERROR_CHANNEL_NOT_FOUND, + THIN_CLIENT_ERROR_PERMISSION_DENIED, + THIN_CLIENT_ERROR_INVALID_PAYLOAD, + THIN_CLIENT_ERROR_SERVICE_UNAVAILABLE, + THIN_CLIENT_ERROR_DUPLICATE_CAPABILITY, + THIN_CLIENT_ERROR_COURIER_CACHE_CORRUPTION, + THIN_CLIENT_PROPAGATION_ERROR, + THIN_CLIENT_ERROR_INVALID_WRITE_CAPABILITY, + THIN_CLIENT_ERROR_INVALID_READ_CAPABILITY, + THIN_CLIENT_ERROR_INVALID_RESUME_WRITE_CHANNEL_REQUEST, + THIN_CLIENT_ERROR_INVALID_RESUME_READ_CHANNEL_REQUEST, + THIN_CLIENT_IMPOSSIBLE_HASH_ERROR, + THIN_CLIENT_IMPOSSIBLE_NEW_WRITE_CAP_ERROR, + THIN_CLIENT_IMPOSSIBLE_NEW_STATEFUL_WRITER_ERROR, + THIN_CLIENT_CAPABILITY_ALREADY_IN_USE, + THIN_CLIENT_ERROR_MKEM_DECRYPTION_FAILED, + THIN_CLIENT_ERROR_BACAP_DECRYPTION_FAILED, + THIN_CLIENT_ERROR_START_RESENDING_CANCELLED, + thin_client_error_to_string + ) + + # Verify all error codes have sequential values 0-24 + expected_codes = { + THIN_CLIENT_SUCCESS: 0, + THIN_CLIENT_ERROR_CONNECTION_LOST: 1, + THIN_CLIENT_ERROR_TIMEOUT: 2, + THIN_CLIENT_ERROR_INVALID_REQUEST: 3, + THIN_CLIENT_ERROR_INTERNAL_ERROR: 4, + THIN_CLIENT_ERROR_MAX_RETRIES: 5, + THIN_CLIENT_ERROR_INVALID_CHANNEL: 6, + THIN_CLIENT_ERROR_CHANNEL_NOT_FOUND: 7, + THIN_CLIENT_ERROR_PERMISSION_DENIED: 8, + THIN_CLIENT_ERROR_INVALID_PAYLOAD: 9, + THIN_CLIENT_ERROR_SERVICE_UNAVAILABLE: 10, + THIN_CLIENT_ERROR_DUPLICATE_CAPABILITY: 11, + THIN_CLIENT_ERROR_COURIER_CACHE_CORRUPTION: 
12, + THIN_CLIENT_PROPAGATION_ERROR: 13, + THIN_CLIENT_ERROR_INVALID_WRITE_CAPABILITY: 14, + THIN_CLIENT_ERROR_INVALID_READ_CAPABILITY: 15, + THIN_CLIENT_ERROR_INVALID_RESUME_WRITE_CHANNEL_REQUEST: 16, + THIN_CLIENT_ERROR_INVALID_RESUME_READ_CHANNEL_REQUEST: 17, + THIN_CLIENT_IMPOSSIBLE_HASH_ERROR: 18, + THIN_CLIENT_IMPOSSIBLE_NEW_WRITE_CAP_ERROR: 19, + THIN_CLIENT_IMPOSSIBLE_NEW_STATEFUL_WRITER_ERROR: 20, + THIN_CLIENT_CAPABILITY_ALREADY_IN_USE: 21, + THIN_CLIENT_ERROR_MKEM_DECRYPTION_FAILED: 22, + THIN_CLIENT_ERROR_BACAP_DECRYPTION_FAILED: 23, + THIN_CLIENT_ERROR_START_RESENDING_CANCELLED: 24, + } + + for const, expected_value in expected_codes.items(): + assert const == expected_value, f"Error code constant has wrong value: expected {expected_value}, got {const}" + + # Verify all error codes have non-empty, non-"Unknown" error strings + for code in range(25): + error_str = thin_client_error_to_string(code) + assert error_str, f"Error code {code} has empty error string" + assert "Unknown" not in error_str, f"Error code {code} has 'Unknown' in error string: {error_str}" + + # Verify specific error strings for cancel behavior + assert thin_client_error_to_string(THIN_CLIENT_ERROR_START_RESENDING_CANCELLED) == "Start resending cancelled" + + print("✅ All error codes 0-24 are defined with proper error strings") + + +class TestGracefulShutdown: + """ + Unit tests for graceful shutdown behavior. + + These tests verify that BrokenPipeError and other connection errors + are handled gracefully during shutdown without printing tracebacks. 
+ """ + + def test_stopping_flag_initially_false(self): + """Test that _stopping flag is False after initialization.""" + from .conftest import get_config_path + + config_path = get_config_path() + cfg = Config(config_path) + client = ThinClient(cfg) + + assert client._stopping is False, "_stopping should be False initially" + + # Cleanup - close socket without starting + client.socket.close() + print("✅ _stopping flag is False on initialization") + + def test_stop_sets_stopping_flag(self): + """Test that stop() sets _stopping flag to True before closing.""" + from .conftest import get_config_path + import socket as sock_module + + config_path = get_config_path() + cfg = Config(config_path) + client = ThinClient(cfg) + + # Create a mock task to avoid AttributeError + class MockTask: + def cancel(self): + pass + client.task = MockTask() + + assert client._stopping is False, "_stopping should be False before stop()" + + client.stop() + + assert client._stopping is True, "_stopping should be True after stop()" + print("✅ stop() sets _stopping flag correctly") + + @pytest.mark.asyncio + async def test_worker_loop_handles_broken_pipe_during_shutdown(self): + """Test that worker_loop handles BrokenPipeError gracefully when stopping.""" + from .conftest import get_config_path + from unittest.mock import AsyncMock, patch + + config_path = get_config_path() + cfg = Config(config_path) + client = ThinClient(cfg) + + # Set stopping flag to True (simulating shutdown in progress) + client._stopping = True + + # Mock recv to raise BrokenPipeError + async def mock_recv_broken_pipe(loop): + raise BrokenPipeError("Connection closed") + + client.recv = mock_recv_broken_pipe + + loop = asyncio.get_running_loop() + + # worker_loop should exit gracefully without raising + # when _stopping is True and BrokenPipeError occurs + await client.worker_loop(loop) + + # If we get here, the test passed - worker_loop handled the error gracefully + client.socket.close() + print("✅ worker_loop 
handles BrokenPipeError gracefully during shutdown") + + @pytest.mark.asyncio + async def test_worker_loop_raises_broken_pipe_when_not_stopping(self): + """Test that worker_loop raises BrokenPipeError when not in shutdown.""" + from .conftest import get_config_path + + config_path = get_config_path() + cfg = Config(config_path) + client = ThinClient(cfg) + + # Ensure stopping flag is False (not in shutdown) + client._stopping = False + + # Mock recv to raise BrokenPipeError + async def mock_recv_broken_pipe(loop): + raise BrokenPipeError("Connection closed") + + client.recv = mock_recv_broken_pipe + + loop = asyncio.get_running_loop() + + # worker_loop should raise BrokenPipeError when _stopping is False + with pytest.raises(BrokenPipeError): + await client.worker_loop(loop) + + client.socket.close() + print("✅ worker_loop raises BrokenPipeError when not stopping") + + @pytest.mark.asyncio + async def test_worker_loop_handles_connection_reset_during_shutdown(self): + """Test that worker_loop handles ConnectionResetError gracefully when stopping.""" + from .conftest import get_config_path + + config_path = get_config_path() + cfg = Config(config_path) + client = ThinClient(cfg) + + # Set stopping flag to True + client._stopping = True + + # Mock recv to raise ConnectionResetError + async def mock_recv_conn_reset(loop): + raise ConnectionResetError("Connection reset by peer") + + client.recv = mock_recv_conn_reset + + loop = asyncio.get_running_loop() + + # Should exit gracefully + await client.worker_loop(loop) + + client.socket.close() + print("✅ worker_loop handles ConnectionResetError gracefully during shutdown") + + @pytest.mark.asyncio + async def test_worker_loop_handles_os_error_during_shutdown(self): + """Test that worker_loop handles OSError gracefully when stopping.""" + from .conftest import get_config_path + + config_path = get_config_path() + cfg = Config(config_path) + client = ThinClient(cfg) + + # Set stopping flag to True + client._stopping = True + 
+ # Mock recv to raise OSError (e.g., bad file descriptor) + async def mock_recv_os_error(loop): + raise OSError("Bad file descriptor") + + client.recv = mock_recv_os_error + + loop = asyncio.get_running_loop() + + # Should exit gracefully + await client.worker_loop(loop) + + client.socket.close() + print("✅ worker_loop handles OSError gracefully during shutdown") diff --git a/tests/test_new_pigeonhole_api.py b/tests/test_new_pigeonhole_api.py new file mode 100644 index 0000000..7f24c9d --- /dev/null +++ b/tests/test_new_pigeonhole_api.py @@ -0,0 +1,1359 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (C) 2025 David Stainton +# SPDX-License-Identifier: AGPL-3.0-only + +""" +NEW Pigeonhole API integration tests for the Python thin client. + +These tests verify the 5-function NEW Pigeonhole API: +1. new_keypair - Generate WriteCap and ReadCap from seed +2. encrypt_read - Encrypt a read operation +3. encrypt_write - Encrypt a write operation +4. start_resending_encrypted_message - Send encrypted message with ARQ +5. cancel_resending_encrypted_message - Cancel ARQ for a message + +These tests require a running mixnet with client daemon for integration testing. 
+""" + +import asyncio +import pytest +import os +from katzenpost_thinclient import ThinClient, Config + + +async def setup_thin_client(): + """Test helper to setup a thin client for integration tests.""" + from .conftest import get_config_path + + config_path = get_config_path() + config = Config(config_path) + client = ThinClient(config) + + # Start the client and wait for connection and PKI document + loop = asyncio.get_running_loop() + await client.start(loop) + + # Wait for daemon to connect to mixnet and receive PKI document + print("Waiting for daemon to connect to mixnet...") + attempts = 0 + while (not client.is_connected() or client.pki_document() is None) and attempts < 30: + await asyncio.sleep(1) + attempts += 1 + + if not client.is_connected(): + raise Exception("Daemon failed to connect to mixnet within 30 seconds") + + if client.pki_document() is None: + raise Exception("PKI document not received within 30 seconds") + + print("✅ Daemon connected to mixnet, using current PKI document") + + return client + + +@pytest.mark.asyncio +async def test_new_keypair_basic(): + """ + Test basic keypair generation using new_keypair. + + This test verifies: + 1. Keypair can be generated from a 32-byte seed + 2. WriteCap, ReadCap, and FirstMessageIndex are returned + 3. 
The returned values have the expected sizes + """ + client = await setup_thin_client() + + try: + print("\n=== Test: new_keypair basic functionality ===") + + # Generate a 32-byte seed + seed = os.urandom(32) + print(f"Generated seed: {len(seed)} bytes") + + # Create keypair + keypair = await client.new_keypair(seed) + + print(f"✓ WriteCap size: {len(keypair.write_cap)} bytes") + print(f"✓ ReadCap size: {len(keypair.read_cap)} bytes") + print(f"✓ FirstMessageIndex size: {len(keypair.first_message_index)} bytes") + + # Verify the returned values are not empty + assert len(keypair.write_cap) > 0, "WriteCap should not be empty" + assert len(keypair.read_cap) > 0, "ReadCap should not be empty" + assert len(keypair.first_message_index) > 0, "FirstMessageIndex should not be empty" + + print("✅ new_keypair test completed successfully") + + finally: + client.stop() + + +@pytest.mark.asyncio +async def test_alice_sends_bob_complete_workflow(): + """ + Test complete end-to-end workflow: Alice sends a message to Bob. + + This test demonstrates the full NEW Pigeonhole API workflow: + 1. Alice creates a WriteCap and derives a ReadCap for Bob + 2. Alice encrypts a message using encrypt_write + 3. Alice sends the encrypted message via start_resending_encrypted_message + 4. Bob encrypts a read request using encrypt_read + 5. Bob sends the read request and receives Alice's encrypted message + 6. 
Bob verifies the received message + + This mirrors the Go test: TestNewPigeonholeAPIAliceSendsBob + """ + alice_client = await setup_thin_client() + bob_client = await setup_thin_client() + + try: + print("\n=== Test: Alice sends message to Bob (complete workflow) ===") + + # Step 1: Alice creates WriteCap and derives ReadCap for Bob + print("\n--- Step 1: Alice creates keypair ---") + alice_seed = os.urandom(32) + alice_keypair = await alice_client.new_keypair(alice_seed) + print(f"✓ Alice created WriteCap and derived ReadCap for Bob") + + # Step 2: Alice encrypts a message for Bob + print("\n--- Step 2: Alice encrypts message ---") + alice_message = b"Bob, Beware they are jamming GPS." + print(f"Alice's message: {alice_message.decode()}") + + alice_result = await alice_client.encrypt_write( + alice_message, alice_keypair.write_cap, alice_keypair.first_message_index + ) + print(f"✓ Alice encrypted message (ciphertext: {len(alice_result.message_ciphertext)} bytes)") + + # Step 3: Alice sends the encrypted message via start_resending_encrypted_message + print("\n--- Step 3: Alice sends encrypted message to courier/replicas ---") + reply_index = 0 + + alice_plaintext = await alice_client.start_resending_encrypted_message( + read_cap=None, # None for write operations + write_cap=alice_keypair.write_cap, + next_message_index=None, # Not needed for writes + reply_index=reply_index, + envelope_descriptor=alice_result.envelope_descriptor, + message_ciphertext=alice_result.message_ciphertext, + envelope_hash=alice_result.envelope_hash + ) + + # For write operations, plaintext should be empty (ACK only) + print(f"✓ Alice received ACK (plaintext length: {len(alice_plaintext) if alice_plaintext else 0})") + + # Wait for message propagation to storage replicas + print("\n--- Waiting for message propagation to storage replicas ---") + await asyncio.sleep(5) + + # Step 4: Bob encrypts a read request + print("\n--- Step 4: Bob encrypts read request ---") + bob_result = await 
bob_client.encrypt_read( + alice_keypair.read_cap, alice_keypair.first_message_index + ) + print(f"✓ Bob encrypted read request (ciphertext: {len(bob_result.message_ciphertext)} bytes)") + + # Step 5: Bob sends the read request and receives Alice's encrypted message + print("\n--- Step 5: Bob sends read request and receives encrypted message ---") + bob_plaintext = await bob_client.start_resending_encrypted_message( + read_cap=alice_keypair.read_cap, + write_cap=None, # None for read operations + next_message_index=bob_result.next_message_index, + reply_index=reply_index, + envelope_descriptor=bob_result.envelope_descriptor, + message_ciphertext=bob_result.message_ciphertext, + envelope_hash=bob_result.envelope_hash + ) + + # Step 6: Verify Bob received Alice's message + print(f"\n--- Step 6: Verify received message ---") + print(f"Bob received: {bob_plaintext.decode() if bob_plaintext else '(empty)'}") + + assert bob_plaintext == alice_message, f"Message mismatch! Expected: {alice_message}, Got: {bob_plaintext}" + + print("✅ Complete workflow test passed - Bob successfully received Alice's message!") + + finally: + alice_client.stop() + bob_client.stop() + + +@pytest.mark.asyncio +async def test_cancel_resending_encrypted_message(): + """ + Test cancelling ARQ for an encrypted message. + + This test verifies: + 1. An encrypted message can be prepared + 2. The ARQ can be cancelled using cancel_resending_encrypted_message + 3. 
The cancellation completes without error + """ + client = await setup_thin_client() + + try: + print("\n=== Test: cancel_resending_encrypted_message ===") + + # Generate keypair and encrypt a message + seed = os.urandom(32) + keypair = await client.new_keypair(seed) + + plaintext = b"This message will be cancelled" + result = await client.encrypt_write( + plaintext, keypair.write_cap, keypair.first_message_index + ) + + print(f"✓ Encrypted message for cancellation test") + print(f"EnvelopeHash: {result.envelope_hash.hex()}") + + # Cancel the message (before sending it) + # Note: In practice, you would start_resending first, then cancel + # But for this test, we just verify the cancel API works + await client.cancel_resending_encrypted_message(result.envelope_hash) + + print("✅ cancel_resending_encrypted_message completed successfully") + + finally: + client.stop() + + +@pytest.mark.asyncio +@pytest.mark.timeout(60) # Prevent test from hanging in CI +async def test_cancel_causes_start_resending_to_return_error(): + """ + Test that calling cancel causes start_resending to return with error code 24. + + This test verifies the core cancel behavior: + 1. Start a start_resending_encrypted_message call (which blocks waiting for reply) + 2. Call cancel_resending_encrypted_message from another task + 3. Verify that the original start_resending call returns with StartResendingCancelledError + + This requires a running daemon but does NOT require a full mixnet since we're + testing the cancel behavior before any reply is received from the mixnet. 
+ """ + from katzenpost_thinclient import StartResendingCancelledError + + client = await setup_thin_client() + + try: + print("\n=== Test: cancel causes start_resending to return error ===") + + # Generate keypair and encrypt a message + seed = os.urandom(32) + keypair = await client.new_keypair(seed) + + plaintext = b"This message will be cancelled while sending" + result = await client.encrypt_write( + plaintext, keypair.write_cap, keypair.first_message_index + ) + + print(f"✓ Encrypted message") + print(f"EnvelopeHash: {result.envelope_hash.hex()}") + + # Track the result of the start_resending call + start_resending_result = None # Will be "success", "cancelled", or an exception + start_resending_completed = asyncio.Event() + + async def start_resending_task(): + """Task that calls start_resending and captures the result.""" + nonlocal start_resending_result + try: + await client.start_resending_encrypted_message( + read_cap=None, + write_cap=keypair.write_cap, + next_message_index=None, + reply_index=0, + envelope_descriptor=result.envelope_descriptor, + message_ciphertext=result.message_ciphertext, + envelope_hash=result.envelope_hash + ) + # If we get here without error, the message completed before cancel + start_resending_result = "success" + except StartResendingCancelledError: + # This is the expected case when cancel works + start_resending_result = "cancelled" + except Exception as e: + # Unexpected error + start_resending_result = e + finally: + start_resending_completed.set() + + # Start the start_resending task + print("--- Starting start_resending_encrypted_message task ---") + resend_task = asyncio.create_task(start_resending_task()) + + # Retry cancel until start_resending returns. + # The daemon only sends StartResendingEncryptedMessageReply with error code 24 + # if it finds the envelope in arqEnvelopeHashMap. If cancel arrives before the + # daemon has fully registered the request, it won't find it and won't wake up + # the waiting caller. 
By retrying, we ensure we eventually hit after registration. + print("--- Calling cancel_resending_encrypted_message (with retry) ---") + max_attempts = 20 + for attempt in range(max_attempts): + # Small delay between attempts to avoid spamming the daemon + await asyncio.sleep(0.1) + + try: + await asyncio.wait_for( + client.cancel_resending_encrypted_message(result.envelope_hash), + timeout=5.0 + ) + except asyncio.TimeoutError: + resend_task.cancel() + raise Exception("cancel_resending_encrypted_message timed out") + + # Check if start_resending has completed + try: + await asyncio.wait_for(start_resending_completed.wait(), timeout=0.5) + print(f"✓ Cancel succeeded on attempt {attempt + 1}") + break # start_resending returned! + except asyncio.TimeoutError: + # Cancel didn't find the envelope yet (not registered), retry + continue + else: + resend_task.cancel() + raise Exception(f"start_resending did not return after {max_attempts} cancel attempts") + + # Verify the result + print(f"--- Verifying result ---") + print(f"Result received: {start_resending_result}") + + assert start_resending_result is not None, "Expected a result but got None" + + # The test can have two valid outcomes: + # 1. Cancel happened before ACK: start_resending raises StartResendingCancelledError + # 2. ACK arrived before cancel: start_resending completes successfully + # + # Both are valid behaviors - the cancel feature works correctly in case 1, + # and in case 2, the message simply completed before we could cancel it. + # This can happen in fast environments (like CI with local mixnet). 
+ if start_resending_result == "success": + print("⚠️ Message completed before cancel took effect (ACK arrived quickly)") + print("✅ Test passed - cancel was called but message completed first (valid race condition)") + elif start_resending_result == "cancelled": + print("✅ start_resending returned with StartResendingCancelledError (error code 24)") + else: + # Unexpected error + raise AssertionError(f"Unexpected error: {start_resending_result}") + + finally: + client.stop() + + +@pytest.mark.asyncio +@pytest.mark.timeout(60) # Prevent test from hanging in CI +async def test_cancel_causes_start_resending_copy_command_to_return_error(): + """ + Test that calling cancel causes start_resending_copy_command to return with error. + + This test verifies the cancel behavior for copy commands: + 1. Create a temporary channel and write some data to it + 2. Start a start_resending_copy_command call (which blocks) + 3. Call cancel_resending_copy_command from another task + 4. Verify that the original start_resending call returns with error code 24 + """ + from hashlib import blake2b + + client = await setup_thin_client() + + try: + print("\n=== Test: cancel causes start_resending_copy_command to return error ===") + + # Create temporary channel + temp_seed = os.urandom(32) + temp_keypair = await client.new_keypair(temp_seed) + print("✓ Created temporary copy stream WriteCap") + + # Compute write_cap_hash for cancel + write_cap_hash = blake2b(temp_keypair.write_cap, digest_size=32).digest() + print(f"WriteCapHash: {write_cap_hash.hex()}") + + # Track whether the start_resending returned with the expected error + start_resending_error = None + start_resending_completed = asyncio.Event() + + async def start_resending_copy_task(): + """Task that calls start_resending_copy_command and captures any error.""" + nonlocal start_resending_error + try: + await client.start_resending_copy_command(temp_keypair.write_cap) + # If we get here without error, that's unexpected + 
start_resending_error = "No error raised" + except Exception as e: + start_resending_error = str(e) + finally: + start_resending_completed.set() + + # Start the start_resending_copy_command task + print("--- Starting start_resending_copy_command task ---") + resend_task = asyncio.create_task(start_resending_copy_task()) + + # Retry cancel until start_resending returns. + # The daemon only sends StartResendingCopyCommandReply with error code 24 + # if it finds the write_cap_hash in arqWriteCapHashMap. If cancel arrives before + # the daemon has fully registered the request, it won't find it and won't wake up + # the waiting caller. By retrying, we ensure we eventually hit after registration. + print("--- Calling cancel_resending_copy_command (with retry) ---") + max_attempts = 20 + for attempt in range(max_attempts): + # Small delay between attempts to avoid spamming the daemon + await asyncio.sleep(0.1) + + try: + await asyncio.wait_for( + client.cancel_resending_copy_command(write_cap_hash), + timeout=5.0 + ) + except asyncio.TimeoutError: + resend_task.cancel() + raise Exception("cancel_resending_copy_command timed out") + + # Check if start_resending has completed + try: + await asyncio.wait_for(start_resending_completed.wait(), timeout=0.5) + print(f"✓ Cancel succeeded on attempt {attempt + 1}") + break # start_resending returned! + except asyncio.TimeoutError: + # Cancel didn't find the write_cap_hash yet (not registered), retry + continue + else: + resend_task.cancel() + raise Exception(f"start_resending_copy_command did not return after {max_attempts} cancel attempts") + + # Verify the result + print(f"--- Verifying result ---") + print(f"Result received: {start_resending_error}") + + assert start_resending_error is not None, "Expected a result but got None" + + # The test can have two valid outcomes: + # 1. Cancel happened before ACK: start_resending returns error code 24 + # 2. 
ACK arrived before cancel: start_resending completes successfully (no error) + # + # Both are valid behaviors - the cancel feature works correctly in case 1, + # and in case 2, the message simply completed before we could cancel it. + # This can happen in fast environments (like CI with local mixnet). + if start_resending_error == "No error raised": + print("⚠️ Copy command completed before cancel took effect (ACK arrived quickly)") + print("✅ Test passed - cancel was called but copy command completed first (valid race condition)") + elif "Start resending cancelled" in start_resending_error: + print("✅ start_resending_copy_command returned with expected error code 24 (Start resending cancelled)") + else: + # Unexpected error + raise AssertionError(f"Unexpected error: {start_resending_error}") + + finally: + client.stop() + + +@pytest.mark.asyncio +async def test_multiple_messages_sequence(): + """ + Test sending multiple messages with incrementing indices. + + This test verifies: + 1. Multiple messages can be sent using the same WriteCap + 2. Each message is written to a different MessageBoxIndex + 3. All messages can be read back in sequence + 4. The messages are reassembled correctly + + Note: Each MessageBoxIndex holds one message. To send multiple messages, + you must increment the index for each new message. 
+ """ + alice_client = await setup_thin_client() + bob_client = await setup_thin_client() + + try: + print("\n=== Test: Multiple messages with incrementing indices ===") + + # Alice creates keypair + alice_seed = os.urandom(32) + alice_keypair = await alice_client.new_keypair(alice_seed) + print(f"✓ Alice created keypair") + + num_messages = 3 + messages = [ + b"Message 1 from Alice to Bob", + b"Message 2 from Alice to Bob", + b"Message 3 from Alice to Bob" + ] + + # Alice sends multiple messages, each to a different index + # We increment the index for each message using the BACAP HKDF logic + current_index = alice_keypair.first_message_index + indices_used = [current_index] # Track all indices for reading later + + for i, message in enumerate(messages): + print(f"\n--- Sending message {i+1}/{num_messages} ---") + print(f"Message: {message.decode()}") + + # Encrypt and send to current index + write_result = await alice_client.encrypt_write( + message, alice_keypair.write_cap, current_index + ) + + alice_plaintext = await alice_client.start_resending_encrypted_message( + read_cap=None, + write_cap=alice_keypair.write_cap, + next_message_index=None, + reply_index=0, + envelope_descriptor=write_result.envelope_descriptor, + message_ciphertext=write_result.message_ciphertext, + envelope_hash=write_result.envelope_hash + ) + + print(f"✓ Message {i+1} sent to index successfully") + + # Increment index for next message + if i < num_messages - 1: # Don't increment after last message + current_index = await alice_client.next_message_box_index(current_index) + indices_used.append(current_index) + + print("\n--- Waiting for message propagation ---") + await asyncio.sleep(5) + + # Bob reads all messages from their respective indices + print("\n--- Bob reads all messages ---") + received_messages = [] + bob_current_index = alice_keypair.first_message_index + + for i in range(num_messages): + print(f"\nReading message {i+1}/{num_messages}...") + read_result = await 
bob_client.encrypt_read( + alice_keypair.read_cap, bob_current_index + ) + + bob_plaintext = await bob_client.start_resending_encrypted_message( + read_cap=alice_keypair.read_cap, + write_cap=None, + next_message_index=read_result.next_message_index, + reply_index=0, + envelope_descriptor=read_result.envelope_descriptor, + message_ciphertext=read_result.message_ciphertext, + envelope_hash=read_result.envelope_hash + ) + + print(f"Bob received: {bob_plaintext.decode() if bob_plaintext else '(empty)'}") + received_messages.append(bob_plaintext) + + # Increment index for next read + if i < num_messages - 1: + bob_current_index = await bob_client.next_message_box_index(bob_current_index) + + # Verify all messages were received correctly + for i, (sent, received) in enumerate(zip(messages, received_messages)): + assert received == sent, f"Message {i+1} mismatch: expected {sent}, got {received}" + + print("\n✅ Multiple messages test completed successfully!") + print(f"✅ All {num_messages} messages sent and received correctly with proper index incrementing!") + + finally: + alice_client.stop() + bob_client.stop() + + +@pytest.mark.asyncio +async def test_create_courier_envelopes_from_payload(): + """ + Test the CreateCourierEnvelopesFromPayload API. + + This test verifies: + 1. Alice creates a large payload that will be automatically chunked + 2. Alice calls create_courier_envelopes_from_payload to get copy stream chunks + 3. Alice writes all copy stream chunks to a temporary copy stream channel + 4. Alice sends the Copy command to the courier + 5. 
Bob reads all chunks from the destination channel and reconstructs the payload + + This mirrors the Go test: TestCreateCourierEnvelopesFromPayload + """ + import struct + + alice_client = await setup_thin_client() + bob_client = await setup_thin_client() + + try: + print("\n=== Test: CreateCourierEnvelopesFromPayload ===") + + # Step 1: Alice creates destination WriteCap for the final payload + print("\n--- Step 1: Alice creates destination WriteCap ---") + dest_seed = os.urandom(32) + dest_keypair = await alice_client.new_keypair(dest_seed) + print("✓ Alice created destination WriteCap and derived ReadCap for Bob") + + # Step 2: Alice creates temporary copy stream + print("\n--- Step 2: Alice creates temporary copy stream ---") + temp_seed = os.urandom(32) + temp_keypair = await alice_client.new_keypair(temp_seed) + print("✓ Alice created temporary copy stream WriteCap") + + # Step 3: Create a large payload that will be chunked + print("\n--- Step 3: Creating large payload ---") + # Create a payload large enough to require multiple chunks + # Use a 4-byte length prefix so Bob knows when to stop reading + random_data = os.urandom(5 * 1024) # 5KB of random data + # Length-prefix the payload: [4 bytes length][random data] + large_payload = struct.pack(">I", len(random_data)) + random_data + print(f"✓ Alice created large payload ({len(large_payload)} bytes = 4 byte length prefix + {len(random_data)} bytes data)") + + # Step 4: Create copy stream chunks from the large payload + print("\n--- Step 4: Creating copy stream chunks from large payload ---") + query_id = alice_client.new_query_id() + stream_id = alice_client.new_stream_id() + result = await alice_client.create_courier_envelopes_from_payload( + query_id, stream_id, large_payload, dest_keypair.write_cap, dest_keypair.first_message_index, True # is_last + ) + assert result.envelopes, "create_courier_envelopes_from_payload returned empty chunks" + copy_stream_chunks = result.envelopes + num_chunks = 
len(copy_stream_chunks) + print(f"✓ Alice created {num_chunks} copy stream chunks from {len(large_payload)} byte payload") + + # Step 5: Write all copy stream chunks to the temporary copy stream + print("\n--- Step 5: Writing copy stream chunks to temporary channel ---") + temp_index = temp_keypair.first_message_index + + for i, chunk in enumerate(copy_stream_chunks): + print(f"--- Writing copy stream chunk {i+1}/{num_chunks} to temporary channel ---") + + # Encrypt the chunk for the copy stream + write_result = await alice_client.encrypt_write( + chunk, temp_keypair.write_cap, temp_index + ) + print(f"✓ Alice encrypted copy stream chunk {i+1} ({len(chunk)} bytes plaintext -> {len(write_result.message_ciphertext)} bytes ciphertext)") + + # Send the encrypted chunk to the copy stream + await alice_client.start_resending_encrypted_message( + read_cap=None, + write_cap=temp_keypair.write_cap, + next_message_index=None, + reply_index=0, + envelope_descriptor=write_result.envelope_descriptor, + message_ciphertext=write_result.message_ciphertext, + envelope_hash=write_result.envelope_hash + ) + print(f"✓ Alice sent copy stream chunk {i+1} to temporary channel") + + # Increment temp index for next chunk + temp_index = await alice_client.next_message_box_index(temp_index) + + # Wait for all chunks to propagate to the copy stream + print("\n--- Waiting for copy stream chunks to propagate (30 seconds) ---") + await asyncio.sleep(30) + + # Step 6: Send Copy command to courier using ARQ + print("\n--- Step 6: Sending Copy command to courier via ARQ ---") + await alice_client.start_resending_copy_command(temp_keypair.write_cap) + print("✓ Alice copy command completed successfully via ARQ") + + # Step 7: Bob reads chunks until we have the full payload (based on length prefix) + print("\n--- Step 7: Bob reads all chunks and reconstructs payload ---") + bob_index = dest_keypair.first_message_index + reconstructed_payload = b"" + expected_length = 0 + chunk_num = 0 + + while True: 
+ chunk_num += 1 + print(f"--- Bob reading chunk {chunk_num} ---") + + # Bob encrypts read request + read_result = await bob_client.encrypt_read( + dest_keypair.read_cap, bob_index + ) + print(f"✓ Bob encrypted read request {chunk_num}") + + # Bob sends read request and receives chunk + bob_plaintext = await bob_client.start_resending_encrypted_message( + read_cap=dest_keypair.read_cap, + write_cap=None, + next_message_index=read_result.next_message_index, + reply_index=0, + envelope_descriptor=read_result.envelope_descriptor, + message_ciphertext=read_result.message_ciphertext, + envelope_hash=read_result.envelope_hash + ) + assert bob_plaintext, f"Bob: Failed to receive chunk {chunk_num}" + print(f"✓ Bob received and decrypted chunk {chunk_num} ({len(bob_plaintext)} bytes)") + + # Append chunk to reconstructed payload + reconstructed_payload += bob_plaintext + + # Extract expected length from the first 4 bytes once we have them + if expected_length == 0 and len(reconstructed_payload) >= 4: + expected_length = struct.unpack(">I", reconstructed_payload[:4])[0] + print(f"✓ Bob: Expected payload length is {expected_length} bytes (+ 4 byte prefix = {expected_length + 4} total)") + + # Check if we have the full payload (4 byte prefix + expected_length bytes) + if expected_length > 0 and len(reconstructed_payload) >= expected_length + 4: + print(f"✓ Bob: Received full payload after {chunk_num} chunks") + break + + # Advance to next chunk + bob_index = await bob_client.next_message_box_index(bob_index) + + # Verify the reconstructed payload matches the original + print(f"\n--- Verifying reconstructed payload ({len(reconstructed_payload)} bytes) ---") + assert reconstructed_payload == large_payload, "Reconstructed payload doesn't match original" + print(f"✅ CreateCourierEnvelopesFromPayload test passed! 
Large payload ({len(random_data)} bytes data) encoded into {num_chunks} copy stream chunks and reconstructed successfully!") + + finally: + alice_client.stop() + bob_client.stop() + + +@pytest.mark.asyncio +async def test_copy_command_multi_channel(): + """ + Test the Copy Command API with multiple destination channels. + + This test verifies: + 1. Alice creates two destination channels (chan1 and chan2) + 2. Alice creates a temporary copy stream channel + 3. Alice creates two payloads - one for each destination channel + 4. Alice calls create_courier_envelopes_from_payload twice with the same streamID but different WriteCaps + 5. Alice writes all copy stream chunks to the temporary channel + 6. Alice sends the Copy command to the courier + 7. Bob reads from both destination channels and verifies the payloads + + This mirrors the Go test: TestCopyCommandMultiChannel + """ + alice_client = await setup_thin_client() + bob_client = await setup_thin_client() + + try: + print("\n=== Test: Copy Command Multi-Channel ===") + + # Step 1: Alice creates two destination channels + print("\n--- Step 1: Alice creates two destination channels ---") + + # Channel 1 + chan1_seed = os.urandom(32) + chan1_keypair = await alice_client.new_keypair(chan1_seed) + print("✓ Alice created Channel 1 (WriteCap and ReadCap)") + + # Channel 2 + chan2_seed = os.urandom(32) + chan2_keypair = await alice_client.new_keypair(chan2_seed) + print("✓ Alice created Channel 2 (WriteCap and ReadCap)") + + # Step 2: Alice creates temporary copy stream + print("\n--- Step 2: Alice creates temporary copy stream ---") + temp_seed = os.urandom(32) + temp_keypair = await alice_client.new_keypair(temp_seed) + print("✓ Alice created temporary copy stream WriteCap") + + # Step 3: Create two payloads - one for each destination channel + print("\n--- Step 3: Creating payloads for each channel ---") + payload1 = b"This is the secret message for Channel 1. It contains important information." 
+ print(f"✓ Alice created payload1 for Channel 1 ({len(payload1)} bytes)") + payload2 = b"This is the confidential data for Channel 2. Handle with care and discretion." + print(f"✓ Alice created payload2 for Channel 2 ({len(payload2)} bytes)") + + # Step 4: Create copy stream chunks using same streamID but different WriteCaps + print("\n--- Step 4: Creating copy stream chunks for both channels ---") + query_id = alice_client.new_query_id() + stream_id = alice_client.new_stream_id() + + # First call: payload1 -> channel 1 (is_last=False) + result1 = await alice_client.create_courier_envelopes_from_payload( + query_id, stream_id, payload1, chan1_keypair.write_cap, chan1_keypair.first_message_index, False + ) + assert result1.envelopes, "create_courier_envelopes_from_payload returned empty chunks for channel 1" + print(f"✓ Alice created {len(result1.envelopes)} chunks for Channel 1") + + # Second call: payload2 -> channel 2 (is_last=True) + result2 = await alice_client.create_courier_envelopes_from_payload( + query_id, stream_id, payload2, chan2_keypair.write_cap, chan2_keypair.first_message_index, True + ) + assert result2.envelopes, "create_courier_envelopes_from_payload returned empty chunks for channel 2" + print(f"✓ Alice created {len(result2.envelopes)} chunks for Channel 2") + + # Combine all chunks + all_chunks = result1.envelopes + result2.envelopes + print(f"✓ Alice total chunks to write to temp channel: {len(all_chunks)}") + + # Step 5: Write all copy stream chunks to the temporary channel + print("\n--- Step 5: Writing all chunks to temporary channel ---") + temp_index = temp_keypair.first_message_index + + for i, chunk in enumerate(all_chunks): + print(f"--- Writing chunk {i+1}/{len(all_chunks)} to temporary channel ---") + + # Encrypt the chunk for the copy stream + write_result = await alice_client.encrypt_write( + chunk, temp_keypair.write_cap, temp_index + ) + print(f"✓ Alice encrypted chunk {i+1} ({len(chunk)} bytes plaintext -> 
{len(write_result.message_ciphertext)} bytes ciphertext)") + + # Send the encrypted chunk to the copy stream + await alice_client.start_resending_encrypted_message( + read_cap=None, + write_cap=temp_keypair.write_cap, + next_message_index=None, + reply_index=0, + envelope_descriptor=write_result.envelope_descriptor, + message_ciphertext=write_result.message_ciphertext, + envelope_hash=write_result.envelope_hash + ) + print(f"✓ Alice sent chunk {i+1} to temporary channel") + + # Increment temp index for next chunk + temp_index = await alice_client.next_message_box_index(temp_index) + + # Wait for chunks to propagate + print("\n--- Waiting for copy stream chunks to propagate (30 seconds) ---") + await asyncio.sleep(30) + + # Step 6: Send Copy command to courier using ARQ + print("\n--- Step 6: Sending Copy command to courier via ARQ ---") + await alice_client.start_resending_copy_command(temp_keypair.write_cap) + print("✓ Alice copy command completed successfully via ARQ") + + # Step 7: Bob reads from both channels and verifies payloads + print("\n--- Step 7: Bob reads from both channels ---") + + # Read from Channel 1 + print("--- Bob reading from Channel 1 ---") + bob1_read_result = await bob_client.encrypt_read( + chan1_keypair.read_cap, chan1_keypair.first_message_index + ) + assert bob1_read_result.message_ciphertext, "Bob: EncryptRead returned empty ciphertext for Channel 1" + + bob1_plaintext = await bob_client.start_resending_encrypted_message( + read_cap=chan1_keypair.read_cap, + write_cap=None, + next_message_index=bob1_read_result.next_message_index, + reply_index=0, + envelope_descriptor=bob1_read_result.envelope_descriptor, + message_ciphertext=bob1_read_result.message_ciphertext, + envelope_hash=bob1_read_result.envelope_hash + ) + assert bob1_plaintext, "Bob: Failed to receive data from Channel 1" + print(f"✓ Bob received from Channel 1: {bob1_plaintext.decode()} ({len(bob1_plaintext)} bytes)") + + # Verify Channel 1 payload + assert bob1_plaintext == 
payload1, "Channel 1 payload doesn't match" + print("✓ Channel 1 payload verified!") + + # Read from Channel 2 + print("--- Bob reading from Channel 2 ---") + bob2_read_result = await bob_client.encrypt_read( + chan2_keypair.read_cap, chan2_keypair.first_message_index + ) + assert bob2_read_result.message_ciphertext, "Bob: EncryptRead returned empty ciphertext for Channel 2" + + bob2_plaintext = await bob_client.start_resending_encrypted_message( + read_cap=chan2_keypair.read_cap, + write_cap=None, + next_message_index=bob2_read_result.next_message_index, + reply_index=0, + envelope_descriptor=bob2_read_result.envelope_descriptor, + message_ciphertext=bob2_read_result.message_ciphertext, + envelope_hash=bob2_read_result.envelope_hash + ) + assert bob2_plaintext, "Bob: Failed to receive data from Channel 2" + print(f"✓ Bob received from Channel 2: {bob2_plaintext.decode()} ({len(bob2_plaintext)} bytes)") + + # Verify Channel 2 payload + assert bob2_plaintext == payload2, "Channel 2 payload doesn't match" + print("✓ Channel 2 payload verified!") + + print("\n✅ Multi-channel Copy Command test passed! Payload1 written to Channel 1 and Payload2 written to Channel 2 atomically!") + + finally: + alice_client.stop() + bob_client.stop() + + +@pytest.mark.asyncio +async def test_copy_command_multi_channel_efficient(): + """ + Test the space-efficient multi-channel copy command using + create_courier_envelopes_from_multi_payload which packs envelopes from different + destinations together without wasting space in the copy stream. 
+ + This test verifies: + - The create_courier_envelopes_from_multi_payload API works correctly + - Multiple destination payloads are packed efficiently into the copy stream + - The courier processes all envelopes and writes to the correct destinations + + This mirrors the Go test: TestCopyCommandMultiChannelEfficient + """ + alice_client = await setup_thin_client() + bob_client = await setup_thin_client() + + try: + print("\n=== Test: Efficient Multi-Channel Copy Command ===") + + # Step 1: Alice creates two destination channels + print("\n--- Step 1: Alice creates two destination channels ---") + + # Channel 1 + chan1_seed = os.urandom(32) + chan1_keypair = await alice_client.new_keypair(chan1_seed) + print("✓ Alice created Channel 1 (WriteCap and ReadCap)") + + # Channel 2 + chan2_seed = os.urandom(32) + chan2_keypair = await alice_client.new_keypair(chan2_seed) + print("✓ Alice created Channel 2 (WriteCap and ReadCap)") + + # Step 2: Alice creates temporary copy stream + print("\n--- Step 2: Alice creates temporary copy stream ---") + temp_seed = os.urandom(32) + temp_keypair = await alice_client.new_keypair(temp_seed) + print("✓ Alice created temporary copy stream WriteCap") + + # Step 3: Create two payloads - one for each destination channel + print("\n--- Step 3: Creating payloads for each channel ---") + payload1 = b"This is the secret message for Channel 1 using the efficient multi-channel API." + print(f"✓ Alice created payload1 for Channel 1 ({len(payload1)} bytes)") + payload2 = b"This is the confidential data for Channel 2 packed efficiently with payload1." 
+ print(f"✓ Alice created payload2 for Channel 2 ({len(payload2)} bytes)") + + # Step 4: Create copy stream chunks using efficient multi-destination API + print("\n--- Step 4: Creating copy stream chunks using efficient multi-destination API ---") + stream_id = alice_client.new_stream_id() + + # Create destinations list with both payloads + destinations = [ + { + "payload": payload1, + "write_cap": chan1_keypair.write_cap, + "start_index": chan1_keypair.first_message_index, + }, + { + "payload": payload2, + "write_cap": chan2_keypair.write_cap, + "start_index": chan2_keypair.first_message_index, + }, + ] + + # Single call packs all envelopes efficiently + result = await alice_client.create_courier_envelopes_from_multi_payload( + stream_id, destinations, True # is_last + ) + assert result.envelopes, "create_courier_envelopes_from_multi_payload returned empty chunks" + all_chunks = result.envelopes + print(f"✓ Alice created {len(all_chunks)} chunks for both channels (packed efficiently)") + + # Step 5: Write all copy stream chunks to the temporary channel + print("\n--- Step 5: Writing all chunks to temporary channel ---") + temp_index = temp_keypair.first_message_index + + for i, chunk in enumerate(all_chunks): + print(f"--- Writing chunk {i+1}/{len(all_chunks)} to temporary channel ---") + + # Encrypt the chunk for the copy stream + write_result = await alice_client.encrypt_write( + chunk, temp_keypair.write_cap, temp_index + ) + print(f"✓ Alice encrypted chunk {i+1} ({len(chunk)} bytes plaintext -> {len(write_result.message_ciphertext)} bytes ciphertext)") + + # Send the encrypted chunk to the copy stream + await alice_client.start_resending_encrypted_message( + read_cap=None, + write_cap=temp_keypair.write_cap, + next_message_index=None, + reply_index=0, + envelope_descriptor=write_result.envelope_descriptor, + message_ciphertext=write_result.message_ciphertext, + envelope_hash=write_result.envelope_hash + ) + print(f"✓ Alice sent chunk {i+1} to temporary 
channel") + + # Increment temp index for next chunk + temp_index = await alice_client.next_message_box_index(temp_index) + + # Wait for chunks to propagate + print("\n--- Waiting for copy stream chunks to propagate (30 seconds) ---") + await asyncio.sleep(30) + + # Step 6: Send Copy command to courier using ARQ + print("\n--- Step 6: Sending Copy command to courier via ARQ ---") + await alice_client.start_resending_copy_command(temp_keypair.write_cap) + print("✓ Alice copy command completed successfully via ARQ") + + # Step 7: Bob reads from both channels and verifies payloads + print("\n--- Step 7: Bob reads from both channels ---") + + # Read from Channel 1 + print("--- Bob reading from Channel 1 ---") + bob1_read_result = await bob_client.encrypt_read( + chan1_keypair.read_cap, chan1_keypair.first_message_index + ) + + bob1_plaintext = await bob_client.start_resending_encrypted_message( + read_cap=chan1_keypair.read_cap, + write_cap=None, + next_message_index=bob1_read_result.next_message_index, + reply_index=0, + envelope_descriptor=bob1_read_result.envelope_descriptor, + message_ciphertext=bob1_read_result.message_ciphertext, + envelope_hash=bob1_read_result.envelope_hash + ) + assert bob1_plaintext, "Bob: Failed to receive data from Channel 1" + print(f"✓ Bob received from Channel 1: {bob1_plaintext.decode()} ({len(bob1_plaintext)} bytes)") + assert bob1_plaintext == payload1, "Channel 1 payload doesn't match" + print("✓ Channel 1 payload verified!") + + # Read from Channel 2 + print("--- Bob reading from Channel 2 ---") + bob2_read_result = await bob_client.encrypt_read( + chan2_keypair.read_cap, chan2_keypair.first_message_index + ) + + bob2_plaintext = await bob_client.start_resending_encrypted_message( + read_cap=chan2_keypair.read_cap, + write_cap=None, + next_message_index=bob2_read_result.next_message_index, + reply_index=0, + envelope_descriptor=bob2_read_result.envelope_descriptor, + message_ciphertext=bob2_read_result.message_ciphertext, + 
envelope_hash=bob2_read_result.envelope_hash + ) + assert bob2_plaintext, "Bob: Failed to receive data from Channel 2" + print(f"✓ Bob received from Channel 2: {bob2_plaintext.decode()} ({len(bob2_plaintext)} bytes)") + assert bob2_plaintext == payload2, "Channel 2 payload doesn't match" + print("✓ Channel 2 payload verified!") + + print("\n✅ Efficient multi-channel Copy Command test passed! Both payloads packed efficiently and delivered to correct channels!") + + finally: + alice_client.stop() + bob_client.stop() + + +@pytest.mark.asyncio +async def test_tombstoning(): + """ + Test the tombstoning API. + + This test verifies: + 1. Alice writes a message to a box + 2. Bob reads and verifies the message + 3. Alice tombstones the box (overwrites with zeros) + 4. Bob reads again and verifies the tombstone + + This mirrors the Go test: TestTombstoning + """ + from katzenpost_thinclient import PigeonholeGeometry + + alice_client = await setup_thin_client() + bob_client = await setup_thin_client() + + try: + print("\n=== Test: Tombstoning ===") + + # Create a geometry with a reasonable payload size + # In a real scenario, this would come from the PKI document + geometry = PigeonholeGeometry( + max_plaintext_payload_length=1024, + nike_name="x25519" + ) + + # Create keypair + seed = os.urandom(32) + keypair = await alice_client.new_keypair(seed) + print("✓ Created keypair") + + # Step 1: Alice writes a message + print("\n--- Step 1: Alice writes a message ---") + message = b"Secret message that will be tombstoned" + write_result = await alice_client.encrypt_write( + message, keypair.write_cap, keypair.first_message_index + ) + + await alice_client.start_resending_encrypted_message( + read_cap=None, + write_cap=keypair.write_cap, + next_message_index=None, + reply_index=0, + envelope_descriptor=write_result.envelope_descriptor, + message_ciphertext=write_result.message_ciphertext, + envelope_hash=write_result.envelope_hash + ) + print("✓ Alice wrote message") + + # Wait 
for message propagation + print("--- Waiting for message propagation (5 seconds) ---") + await asyncio.sleep(5) + + # Step 2: Bob reads and verifies + print("\n--- Step 2: Bob reads and verifies ---") + read_result = await bob_client.encrypt_read( + keypair.read_cap, keypair.first_message_index + ) + bob_plaintext = await bob_client.start_resending_encrypted_message( + read_cap=keypair.read_cap, + write_cap=None, + next_message_index=read_result.next_message_index, + reply_index=0, + envelope_descriptor=read_result.envelope_descriptor, + message_ciphertext=read_result.message_ciphertext, + envelope_hash=read_result.envelope_hash + ) + assert bob_plaintext == message, f"Message mismatch: expected {message}, got {bob_plaintext}" + print(f"✓ Bob read message: {bob_plaintext.decode()}") + + # Step 3: Alice tombstones the box + print("\n--- Step 3: Alice tombstones the box ---") + tomb_result = await alice_client.tombstone_box( + keypair.write_cap, keypair.first_message_index + ) + await alice_client.start_resending_encrypted_message( + read_cap=None, + write_cap=keypair.write_cap, + next_message_index=None, + reply_index=None, + envelope_descriptor=tomb_result.envelope_descriptor, + message_ciphertext=tomb_result.message_ciphertext, + envelope_hash=tomb_result.envelope_hash + ) + print("✓ Alice tombstoned the box") + + # Wait for tombstone propagation + print("--- Waiting for tombstone propagation (30 seconds) ---") + await asyncio.sleep(30) + + # Step 4: Bob reads again and verifies tombstone + print("\n--- Step 4: Bob reads again and verifies tombstone ---") + read_result2 = await bob_client.encrypt_read( + keypair.read_cap, keypair.first_message_index + ) + bob_plaintext2 = await bob_client.start_resending_encrypted_message( + read_cap=keypair.read_cap, + write_cap=None, + next_message_index=read_result2.next_message_index, + reply_index=0, + envelope_descriptor=read_result2.envelope_descriptor, + message_ciphertext=read_result2.message_ciphertext, + 
envelope_hash=read_result2.envelope_hash + ) + + assert len(bob_plaintext2) == 0, "Expected tombstone plaintext (empty)" + print("✓ Bob verified tombstone (empty payload)") + + print("\n✅ Tombstoning test passed!") + + finally: + alice_client.stop() + bob_client.stop() + + +@pytest.mark.asyncio +async def test_tombstone_range(): + """ + Test the tombstone_range API. + + This test verifies: + 1. Alice writes multiple messages to sequential boxes + 2. Alice tombstones a range of boxes + 3. The result shows the correct number of tombstoned boxes + + This mirrors the Go TombstoneRange functionality. + """ + from katzenpost_thinclient import PigeonholeGeometry + + alice_client = await setup_thin_client() + + try: + print("\n=== Test: Tombstone Range ===") + + # Create a geometry with a reasonable payload size + geometry = PigeonholeGeometry( + max_plaintext_payload_length=1024, + nike_name="x25519" + ) + + # Create keypair + seed = os.urandom(32) + keypair = await alice_client.new_keypair(seed) + print("✓ Created keypair") + + # Write 3 messages to sequential boxes + num_messages = 3 + current_index = keypair.first_message_index + + print(f"\n--- Writing {num_messages} messages ---") + for i in range(num_messages): + message = f"Message {i+1} to be tombstoned".encode() + write_result = await alice_client.encrypt_write( + message, keypair.write_cap, current_index + ) + await alice_client.start_resending_encrypted_message( + read_cap=None, + write_cap=keypair.write_cap, + next_message_index=None, + reply_index=0, + envelope_descriptor=write_result.envelope_descriptor, + message_ciphertext=write_result.message_ciphertext, + envelope_hash=write_result.envelope_hash + ) + print(f"✓ Wrote message {i+1}") + + if i < num_messages - 1: + current_index = await alice_client.next_message_box_index(current_index) + + # Wait for messages to propagate + print("--- Waiting for message propagation (30 seconds) ---") + await asyncio.sleep(30) + + # Tombstone the range - creates envelopes 
without sending + print(f"\n--- Creating tombstones for {num_messages} boxes ---") + result = await alice_client.tombstone_range(keypair.write_cap, keypair.first_message_index, num_messages) + + assert 'envelopes' in result, "Result should contain 'envelopes' list" + assert len(result['envelopes']) == num_messages, f"Expected {num_messages} envelopes, got {len(result['envelopes'])}" + assert 'next' in result, "Result should contain 'next' index" + print(f"✓ Created {len(result['envelopes'])} tombstone envelopes") + + # Send all tombstone envelopes + print(f"\n--- Sending {num_messages} tombstone envelopes ---") + for i, envelope in enumerate(result['envelopes']): + await alice_client.start_resending_encrypted_message( + read_cap=None, + write_cap=keypair.write_cap, + next_message_index=None, + reply_index=None, + envelope_descriptor=envelope['envelope_descriptor'], + message_ciphertext=envelope['message_ciphertext'], + envelope_hash=envelope['envelope_hash'] + ) + print(f"✓ Sent tombstone envelope {i+1}") + + print(f"\n✅ Tombstone range test passed! Created and sent {num_messages} tombstones successfully!") + + finally: + alice_client.stop() + + +@pytest.mark.asyncio +async def test_box_id_not_found_error(): + """ + Test that we receive a BoxIDNotFoundError when reading from a box that doesn't exist. + + This test verifies: + 1. A new keypair is created (but no message is written) + 2. Attempting to read from the non-existent box raises BoxIDNotFoundError + 3. 
The error can be caught using isinstance() similar to Go's errors.Is() + + This mirrors the Go test: TestBoxIDNotFoundError + """ + from katzenpost_thinclient import BoxIDNotFoundError + + client = await setup_thin_client() + + try: + print("\n=== Test: BoxIDNotFoundError ===") + + # Create a fresh keypair - but do NOT write anything to it + seed = os.urandom(32) + keypair = await client.new_keypair(seed) + print("✓ Created fresh keypair (no messages written)") + + # Encrypt a read request for the non-existent box + read_result = await client.encrypt_read( + keypair.read_cap, keypair.first_message_index + ) + print("✓ Encrypted read request for non-existent box") + + # Attempt to read - this should raise BoxIDNotFoundError + # Use start_resending_encrypted_message_no_retry to get immediate error without retries + print("--- Attempting to read from non-existent box ---") + try: + await client.start_resending_encrypted_message_no_retry( + read_cap=keypair.read_cap, + write_cap=None, + next_message_index=read_result.next_message_index, + reply_index=0, + envelope_descriptor=read_result.envelope_descriptor, + message_ciphertext=read_result.message_ciphertext, + envelope_hash=read_result.envelope_hash + ) + # If we get here, the test failed - we expected an error + raise AssertionError("Expected BoxIDNotFoundError but no exception was raised") + except BoxIDNotFoundError as e: + # This is the expected case + print(f"✓ Received expected BoxIDNotFoundError: {e}") + print("✅ BoxIDNotFoundError test passed!") + + finally: + client.stop() + + +@pytest.mark.asyncio +async def test_box_already_exists_error(): + """ + Test that we receive a BoxAlreadyExistsError when writing to a box that already has data. + + This test verifies: + 1. A new keypair is created and a message is successfully written + 2. Attempting to write to the same box again raises BoxAlreadyExistsError + 3. 
The error can be caught using isinstance() similar to Go's errors.Is() + + This mirrors the Go test: TestBoxAlreadyExistsError + """ + from katzenpost_thinclient import BoxAlreadyExistsError + + client = await setup_thin_client() + + try: + print("\n=== Test: BoxAlreadyExistsError ===") + + # Create a fresh keypair + seed = os.urandom(32) + keypair = await client.new_keypair(seed) + print("✓ Created keypair") + + # First write - should succeed + print("--- First write (should succeed) ---") + message1 = b"First message - this should work" + write_result1 = await client.encrypt_write( + message1, keypair.write_cap, keypair.first_message_index + ) + print("✓ Encrypted first message") + + await client.start_resending_encrypted_message( + read_cap=None, + write_cap=keypair.write_cap, + next_message_index=None, + reply_index=None, + envelope_descriptor=write_result1.envelope_descriptor, + message_ciphertext=write_result1.message_ciphertext, + envelope_hash=write_result1.envelope_hash + ) + print("✓ First write succeeded") + + # Wait for propagation + print("Waiting for message propagation...") + await asyncio.sleep(5) + + # Second write to the SAME box - should fail with BoxAlreadyExists + print("--- Second write to same box (should fail) ---") + message2 = b"Second message - this should fail" + write_result2 = await client.encrypt_write( + message2, keypair.write_cap, keypair.first_message_index + ) + print("✓ Encrypted second message") + + # Send the second write - should fail with BoxAlreadyExists + # Use start_resending_encrypted_message_return_box_exists to get the error instead of + # treating it as idempotent success + try: + await client.start_resending_encrypted_message_return_box_exists( + read_cap=None, + write_cap=keypair.write_cap, + next_message_index=None, + reply_index=None, + envelope_descriptor=write_result2.envelope_descriptor, + message_ciphertext=write_result2.message_ciphertext, + envelope_hash=write_result2.envelope_hash + ) + # If we get here, 
the test failed - we expected an error + raise AssertionError("Expected BoxAlreadyExistsError but no exception was raised") + except BoxAlreadyExistsError as e: + # This is the expected case + print(f"✓ Received expected BoxAlreadyExistsError: {e}") + print("✅ BoxAlreadyExistsError test passed!") + + finally: + client.stop()