diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0d4c856..0ccd1f4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -96,6 +96,6 @@ jobs: # new experimental lints that we don't want to enforce yet. RUSTDOCFLAGS: "" - # 8. Check Examples (Compilation Only) - - name: Check Examples Compile + # 8. Check Example (Compilation Only) + - name: Check Example Compiles run: cargo build --examples --verbose diff --git a/.gitignore b/.gitignore index b9921a1..28c8683 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,8 @@ *.bak # Cargo local config with secrets -/.cargo/config.toml \ No newline at end of file +/.cargo/config.toml + +# Generated example reports (keep the dir, ignore contents) +/examples/reports/* +!/examples/reports/.gitkeep diff --git a/Cargo.lock b/Cargo.lock index 19fc3bd..2d1e67b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -184,9 +184,9 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "autocfg" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +checksum = "f2032f911046de80f0a198e0901378627c33f59ea0ac00e363d481118bd70a53" [[package]] name = "axum" @@ -339,9 +339,9 @@ checksum = "36f64beae40a84da1b4b26ff2761a5b895c12adc41dc25aaee1c4f2bbfe97a6e" [[package]] name = "brotli" -version = "8.0.2" +version = "8.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bd8b9603c7aa97359dbd97ecf258968c95f3adddd6db2f7e7a5bef101c84560" +checksum = "8119e4516436f5708bbc474a9d395bf12f1b5395e93a92a56e647ac3388c8610" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -350,9 +350,9 @@ dependencies = [ [[package]] name = "brotli-decompressor" -version = "5.0.0" +version = "5.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "874bb8112abecc98cbd6d81ea4fa7e94fb9449648c93cc89aa40c81c24d7de03" +checksum = "5962523e1b92ce1b5e793d9169b9943eece10d39f62550bc04bb605d75b94924" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -369,9 +369,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.20.2" +version = "3.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" +checksum = "72f5acc6cb2ba439de613abc23857ec3d78374d8ed5ac84e9d11336e87da8649" [[package]] name = "bytecheck" @@ -441,9 +441,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.62" +version = "1.2.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1dce859f0832a7d088c4f1119888ab94ef4b5d6795d1ce05afb7fe159d79f98" +checksum = "556e016178bb5662a08681bbe0f00f8e17631781a4dfc8c45e466e4b185ec27f" dependencies = [ "find-msvc-tools", "jobserver", @@ -476,7 +476,7 @@ dependencies = [ [[package]] name = "chapaty" -version = "1.1.4" +version = "1.2.0" dependencies = [ "anyhow", "async-channel", @@ -509,7 +509,6 @@ dependencies = [ "strum", "strum_macros 0.28.0", "thiserror", - "tikv-jemallocator", "time", "tokio", "tokio-util", @@ -567,9 +566,9 @@ dependencies = [ [[package]] name = "compact_str" -version = "0.9.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a" +checksum = "9dfdd1c2274d9aa354115b09dc9a901d6c5576818cdf70d14cae2bdb47df00ab" dependencies = [ "castaway", "cfg-if", @@ -877,9 +876,9 @@ dependencies = [ [[package]] name = "displaydoc" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +checksum = "1ac70aa55017e108007fbaf5aa0f54b021c98f92ff8af59d42eda9da96e3dd4f" dependencies = [ "proc-macro2", "quote", @@ -903,9 +902,9 @@ checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" [[package]] name = "either" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e" [[package]] name = "embedded-io" @@ -1375,9 +1374,9 @@ dependencies = [ [[package]] name = "http" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +checksum = "8be7462df143984c4598a256ef469b251d7d7f9e271135073e78fc535414f3d0" dependencies = [ "bytes", "itoa", @@ -1426,9 +1425,9 @@ checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" [[package]] name = "hyper" -version = "1.9.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6299f016b246a94207e63da54dbe807655bf9e00044f73ded42c3ac5305fbcca" +checksum = "55281c53a1894c864990125767da440a4e630446785086f52523b20033b74498" dependencies = [ "atomic-waker", "bytes", @@ -1725,9 +1724,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.98" +version = "0.3.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67df7112613f8bfd9150013a0314e196f4800d3201ae742489d999db2f979f08" +checksum = "142bc4740e452c1e57ade0cbc129f139c9093e354346f0872ef985f4f5cf5f11" dependencies = [ "cfg-if", "futures-util", @@ -1772,9 +1771,9 @@ checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" [[package]] name = "libredox" -version = "0.1.16" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e02f3bb43d335493c96bf3fd3a321600bf6bd07ed34bc64118e9293bdffea46c" +checksum = "f02ab6bace2054fb888a3c16f990117b579d14a3088e472d63c6011fa185c9d3" dependencies = [ "libc", ] @@ -1808,9 +1807,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.29" +version = "0.4.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +checksum = "616ec5685824bcc94416c6d4a7a446eea774a31efd7062c8480ba6fd06d7a6e5" [[package]] name = "lru-slab" @@ -1874,9 +1873,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.8.0" +version = "2.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +checksum = "6b947ae49db0d222b1dbc6b113ce7248a3fc3a6ca21b696717bfc000ba4484d8" [[package]] name = "memmap2" @@ -1905,9 +1904,9 @@ dependencies = [ [[package]] name = "mio" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +checksum = "02bd0af71c67b473010cbbc60715ee815645a4dc942899111f494b4b737d6fda" dependencies = [ "libc", "wasi", @@ -1981,9 +1980,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6673768db2d862beb9b39a78fdcb1a69439615d5794a1be50caa9bc92c81967" +checksum = "521739c6d2bac4aa25192232afe6841231376b2b26d4d9fae5ecf8ca5772e441" [[package]] name = "num-integer" @@ -2069,9 +2068,9 @@ checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" [[package]] name = "openssl" -version = "0.10.79" +version = "0.10.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf0b434746ee2832f4f0baf10137e1cabb18cbe6912c69e2e33263c45250f542" +checksum = "a45fa2aa886c42762255da344f0a0d313e254066c46aad76f300c3d3da62d967" dependencies = [ "bitflags", "cfg-if", @@ -2100,9 +2099,9 @@ checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" [[package]] name = "openssl-sys" -version = "0.9.115" +version = "0.9.116" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "158fe5b292746440aa6e7a7e690e55aeb72d41505e2804c23c6973ad0e9c9781" +checksum = "f28a22dc7140cda5f096e5e7724a6962ca81a7f8bfd2979f9b18c11af56318c4" dependencies = [ "cc", "libc", @@ -2202,18 +2201,18 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.1.12" +version = "1.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbf0d9e68100b3a7989b4901972f265cd542e560a3a8a724e1e20322f4d06ce9" +checksum = "2466b2336ed02bcdca6b294417127b90ec92038d1d5c4fbeac971a922e0e0924" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.12" +version = "1.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a990e22f43e84855daf260dded30524ef4a9021cc7541c26540500a50b624389" +checksum = "c96395f0a926bc13b1c17622aaddda1ecb55d49c8f1bf9777e4d877800a43f8b" dependencies = [ "proc-macro2", "quote", @@ -2959,9 +2958,9 @@ dependencies = [ [[package]] name = "pulldown-cmark" -version = "0.13.3" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c3a14896dfa883796f1cb410461aef38810ea05f2b2c33c5aded3649095fdad" +checksum = "e9f068eba8e7071c5f9511831b44f32c740d5adf574e990f946ddb53db2f314e" dependencies = [ "bitflags", "memchr", @@ -3628,9 +3627,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.149" +version = "1.0.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +checksum = "e8014e44b4736ed0538adeecded0fce2a272f22dc9578a7eb6b2d9993c74cfb9" dependencies = [ "indexmap 2.14.0", "itoa", @@ -3717,9 +3716,9 @@ dependencies = [ [[package]] name = "shlex" -version = "1.3.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +checksum = "f8fadd59c855ef2080decdef8ff161eb6661b86933c9d82e5ba29dc602a55aba" [[package]] name = "signal-hook" @@ -3807,9 +3806,9 @@ checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" [[package]] name = "socket2" -version = "0.6.3" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +checksum = "52d1cfed4120b4d927bf7c0f86d2087a4a7d6027c906d9f9d525a80573b9be51" dependencies = [ "libc", "windows-sys 0.61.2", @@ -4052,26 +4051,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "tikv-jemalloc-sys" -version = "0.6.1+5.3.0-1-ge13ca993e8ccb9ba9847cc330696e02839f328f7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd8aa5b2ab86a2cefa406d889139c162cbb230092f7d1d7cbc1716405d852a3b" -dependencies = [ - "cc", - "libc", -] - -[[package]] -name = "tikv-jemallocator" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0359b4327f954e0567e69fb191cf1436617748813819c94b8cd4a431422d053a" -dependencies = [ - "libc", - "tikv-jemalloc-sys", -] - [[package]] name = "time" version = "0.3.47" @@ -4211,9 +4190,9 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.25.11+spec-1.1.0" +version = "0.25.12+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b59c4d22ed448339746c59b905d24568fcbb3ab65a500494f7b8c3e97739f2b" +checksum = "d2153edc6955a6c354fad8f5efd38b6a8769bdccf9fe50f8e1329f81b0baa5d7" dependencies = [ "indexmap 2.14.0", "toml_datetime", @@ -4319,9 +4298,9 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.10" +version = "0.6.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68d6fdd9f81c2819c9a8b0e0cd91660e7746a8e6ea2ba7c6b2b057985f6bcb51" +checksum = "4cfcf7e2740e6fc6d4d688b4ef00650406bb94adf4731e43c096c3a19fe40840" dependencies = [ "bitflags", "bytes", @@ -4443,9 +4422,9 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "typenum" -version = "1.20.0" +version = "1.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" +checksum = "b6f5e870be6c3b371b77fe0ee0bafb859fa4964b4404c27de1d380043c4dda20" [[package]] name = "unicase" @@ -4575,9 +4554,9 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.23.1" +version = "1.23.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddd74a9687298c6858e9b88ec8935ec45d22e8fd5e6394fa1bd4e99a87789c76" +checksum = "d258b83ceec21034727ecee8c382cfa6c3e133699b0742c64571814fb420c9f7" dependencies = [ "getrandom 0.4.2", "js-sys", @@ -4666,9 +4645,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49ace1d07c165b0864824eee619580c4689389afa9dc9ed3a4c75040d82e6790" +checksum = "3ed04576f974d2b2fba0f38c51dbc5518011e38c36bf1143164be765528fd409" dependencies = [ "cfg-if", "once_cell", @@ -4680,9 +4659,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.71" +version = "0.4.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96492d0d3ffba25305a7dc88720d250b1401d7edca02cc3bcd50633b424673b8" +checksum = "9473dbd2991ae90b6291c3c32c30c6187ac49aa32f9905d1cce280ec1e110b0f" dependencies = [ "js-sys", "wasm-bindgen", @@ -4690,9 +4669,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e68e6f4afd367a562002c05637acb8578ff2dea1943df76afb9e83d177c8578" +checksum = "916151b09da36bd82f6615cbf3a419e2f0ba23a03c6160e8e92eb6bd4aa1dec6" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4700,9 +4679,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d95a9ec35c64b2a7cb35d3fead40c4238d0940c86d107136999567a4703259f2" +checksum = "299047362ccbfce148b67ab7e73349f77748e00c8296f9542adfad2ad82c5c5e" dependencies = [ "bumpalo", "proc-macro2", @@ -4713,9 +4692,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.121" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4e0100b01e9f0d03189a92b96772a1fb998639d981193d7dbab487302513441" +checksum = "9a929b2c61f11ba3e9bc35b50c1f25cb38e0e892c0c231ae2b8cf78d5dad4437" dependencies = [ "unicode-ident", ] @@ -4769,9 +4748,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.98" +version = "0.3.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b572dff8bcf38bad0fa19729c89bb5748b2b9b1d8be70cf90df697e3a8f32aa" +checksum = "6d621441cfc37b84979402712047321980c178f299193a3589d05b99e8763436" dependencies = [ "js-sys", "wasm-bindgen", @@ -5073,9 +5052,9 @@ checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" [[package]] name = "winnow" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ee1708bef14716a11bae175f579062d4554d95be2c6829f518df847b7b3fdd0" +checksum = "0592e1c9d151f854e6fd382574c3a0855250e1d9b2f99d9281c6e6391af352f1" dependencies = [ "memchr", ] @@ -5220,18 +5199,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.48" +version = "0.8.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" +checksum = "3b065d4f0e55f82fae73202e189638116a87c55ab6b8e6c2721e13dd9d854ad1" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.48" +version = "0.8.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" +checksum = "0b631b19d36a892ab55420c92dbc83ccd79274f25be714855d3074aa71cab639" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index f6ddc7e..57fb6ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "chapaty" -version = "1.1.4" +version = "1.2.0" edition = "2024" authors = ["Len Williamson "] description = "An event-driven Rust engine for building and evaluating quantitative trading agents. Features a Gym-style API for algorithmic backtesting and reinforcement learning." @@ -95,7 +95,3 @@ time = { version = "0.3", features = ["macros", "formatting"] } tracing-appender = "0.2" tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "time"] } anyhow = "1.0" - -# Only pull jemalloc on Linux -[target.'cfg(target_os = "linux")'.dev-dependencies] -tikv-jemallocator = "0.6" diff --git a/README.md b/README.md index a02400e..1ffc576 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ ## Getting Started -> **Fast Track:** Use the [**Chapaty Starter Template**](https://github.com/LenWilliamson/chapaty-template) to instantly bootstrap a new project. It includes pre-configured AI prompts for backtesting with a LLM of your choice, built-in dashboard setups with [Quantstats](https://github.com/ranaroussi/quantstats), and best-practice strategy examples. +> **Fast Track:** Use the [**Chapaty Starter Template**][chapatyTemplateLink] to instantly bootstrap a new project. It includes pre-configured AI prompts for backtesting with a LLM of your choice and built-in dashboard setups with [QuantStats][quantstatsLink]. For a library of ready-to-run strategies, including the top TradingView setups backtested across million-agent grids, see [**chapaty-zoo**][chapatyZooLink]. Chapaty supports two primary workflows: **Parallel Backtesting** for evaluating agent grids, and the **Canonical Gym Loop** for step-by-step control over the environment. @@ -18,11 +18,15 @@ Chapaty supports two primary workflows: **Parallel Backtesting** for evaluating For grid searches, Chapaty leverages `rayon` to evaluate agents in parallel, automatically tracking the top performers. -**Run this example:** [`examples/news_breakout_grid.rs`](examples/news_breakout_grid.rs) +**Run this example:** [`examples/quickstart.rs`](examples/quickstart.rs) -```rust -use std::path::Path; +```bash +cargo run --release --example quickstart +``` + +Under the hood, the parallel path builds the environment, generates a grid of agents, and evaluates them in one call: +```rust use chapaty::prelude::*; #[tokio::main] @@ -31,16 +35,17 @@ async fn main() -> Result<()> { // See the full example file for the 'environment()' helper implementation. let mut env = environment().await?; - // 2. Generate the Agent Grid - // Creates a lazy parallel iterator of 1,000,000+ distinct parameter combinations. - // This allows efficient streaming without loading all agents into RAM. - let (count, agents) = news_breakout_grid(); - - println!("Evaluating {count} agents..."); + // 2. Create the Agent Grid + // Creates a vector of 1M distinct parameter combinations. `NoOpAgent` is a + // placeholder. Swap in your own strategy (see chapaty-zoo for examples). + let num_agents = 1_000_000; + let agents = (0..num_agents) + .map(|uid| (uid, NoOpAgent::default())) + .collect::>(); // 3. Execute Parallel Evaluation - // Chapaty manages the batching and threading, retaining the Top-100 agents - let leaderboard = env.evaluate_agents(agents, 100, count as u64)?; + // Chapaty manages the batching and threading, retaining the Top-100 agents. + let leaderboard = env.evaluate_agents(agents, 100)?; // 4. Export the Leaderboard // Results are saved as a structured CSV dataset. @@ -55,8 +60,6 @@ async fn main() -> Result<()> { For custom integrations or those who prefer full control over the observation-action transition loop, Chapaty implements a standard API inspired by OpenAI Gym. ```rust -use std::path::Path; - use chapaty::prelude::*; #[tokio::main] @@ -99,7 +102,7 @@ async fn main() -> ChapatyResult<()> { > **Note:** Environments are **async** because they stream large financial datasets directly from cloud storage (e.g. GCS, BigQuery, HuggingFace). -For practical, _ready-to-run agents_, check out the `examples/` directory to get started quickly. +The [`examples/quickstart.rs`](examples/quickstart.rs) file demonstrates both workflows end to end for a single-agent baseline, a parallel grid, report export, and logging setup. For real, ready-to-run strategies, see [**chapaty-zoo**][chapatyZooLink]. ## Related Projects @@ -136,3 +139,6 @@ By using Chapaty, you acknowledge that **you are solely responsible for any trad [discord]: https://discord.gg/MmMAB6NCuK [gymnasiumLink]: https://github.com/Farama-Foundation/Gymnasium [deepmindLink]: https://github.com/deepmind/dm_control +[quantstatsLink]: https://github.com/ranaroussi/quantstats +[chapatyTemplateLink]: https://github.com/LenWilliamson/chapaty-template +[chapatyZooLink]: https://github.com/LenWilliamson/chapaty-zoo diff --git a/bin/pre-push.sh b/bin/pre-push.sh index 41015a9..3fe5b70 100755 --- a/bin/pre-push.sh +++ b/bin/pre-push.sh @@ -15,7 +15,7 @@ echo -e "${BLUE}>>> Starting Local CI Pipeline for Chapaty...${NC}" # JOB 1: Compliance & Security # ============================================================================== -echo -e "\n${YELLOW}[1/11] Checking Security (Secrets)...${NC}" +echo -e "\n${YELLOW}[1/10] Checking Security (Secrets)...${NC}" # Check if .cargo/config.toml is tracked by git if git ls-files --error-unmatch .cargo/config.toml > /dev/null 2>&1; then echo -e "${RED}[FAIL] CRITICAL: .cargo/config.toml is being tracked by git! Remove it immediately.${NC}" @@ -23,12 +23,12 @@ if git ls-files --error-unmatch .cargo/config.toml > /dev/null 2>&1; then fi echo -e "${GREEN}[OK] No leaked secrets in git index.${NC}" -echo -e "\n${YELLOW}[2/11] Checking Formatting...${NC}" +echo -e "\n${YELLOW}[2/10] Checking Formatting...${NC}" # Fails if code is not formatted. Remove '--check' to auto-format instead. cargo fmt -- --check || { echo -e "${RED}[FAIL] Formatting invalid. Run 'cargo fmt' to fix.${NC}"; exit 1; } echo -e "${GREEN}[OK] Formatting is correct.${NC}" -echo -e "\n${YELLOW}[3/11] Checking Architecture Guardrails...${NC}" +echo -e "\n${YELLOW}[3/10] Checking Architecture Guardrails...${NC}" # Prevent circular dependencies via prelude imports within the library if grep -r "use crate::prelude::" src/; then echo -e "${RED}[FAIL] Architecture violation: Internal imports from 'crate::prelude' found.${NC}" @@ -40,7 +40,7 @@ echo -e "${GREEN}[OK] Architecture compliant.${NC}" # JOB 2: Build, Test & Verify # ============================================================================== -echo -e "\n${YELLOW}[4/11] Security Audit (Dependencies)...${NC}" +echo -e "\n${YELLOW}[4/10] Security Audit (Dependencies)...${NC}" # Check if cargo-audit is installed if ! command -v cargo-audit &> /dev/null; then echo -e "${RED}[FAIL] 'cargo-audit' is not installed.${NC}" @@ -50,26 +50,26 @@ fi cargo audit echo -e "${GREEN}[OK] Dependencies audited.${NC}" -echo -e "\n${YELLOW}[5/11] Linting (Clippy)...${NC}" +echo -e "\n${YELLOW}[5/10] Linting (Clippy)...${NC}" # Deny warnings to match CI strictness cargo clippy --all-targets --all-features -- -D warnings echo -e "${GREEN}[OK] Code is clean.${NC}" -echo -e "\n${YELLOW}[6/11] Building Workspace...${NC}" +echo -e "\n${YELLOW}[6/10] Building Workspace...${NC}" cargo build --all-features echo -e "${GREEN}[OK] Workspace compiled successfully.${NC}" -echo -e "\n${YELLOW}[7/11] Running Unit Tests...${NC}" +echo -e "\n${YELLOW}[7/10] Running Unit Tests...${NC}" cargo test --all-features echo -e "${GREEN}[OK] All tests passed.${NC}" -echo -e "\n${YELLOW}[8/11] Verifying Documentation...${NC}" +echo -e "\n${YELLOW}[8/10] Verifying Documentation...${NC}" # Ensure documentation builds without warnings (broken links, etc.) export RUSTDOCFLAGS="-D warnings" cargo doc --no-deps --document-private-items echo -e "${GREEN}[OK] Documentation builds successfully.${NC}" -echo -e "\n${YELLOW}[9/11] Verifying Docs.rs Compatibility (Nightly)...${NC}" +echo -e "\n${YELLOW}[9/10] Verifying Docs.rs Compatibility (Nightly)...${NC}" # docs.rs strictly uses the nightly compiler. We run a soft-fail check here. if rustup toolchain list | grep -q nightly; then # We suppress stdout to keep it clean, but let stderr show if it fails. @@ -88,41 +88,14 @@ else fi # ============================================================================== -# NEW STEP: Build All Examples, Then Run (excluding grids) +# Build & Dry-Run the Quickstart Example # ============================================================================== -echo -e "\n${YELLOW}[10/11] Compiling All Examples...${NC}" -# This mirrors CI Step 7: Ensures even grid.rs examples compile properly -cargo build --examples -echo -e "${GREEN}[OK] All examples compiled.${NC}" - -echo -e "\n${YELLOW}[11/11] Running Examples (skipping *grid.rs except noop_grid)...${NC}" - -# Iterate over all .rs files in the examples directory -for file in examples/*.rs; do - # 1. Extract filename (e.g., "news_breakout_grid.rs") - filename=$(basename "$file") - - # 2. Extract example name (remove .rs extension) - example_name="${filename%.*}" - - # 3. Filter: Check if filename contains "grid.rs", but explicitly ALLOW noop_grid - if [[ "$filename" == *"grid.rs"* && "$filename" != "noop_grid.rs" ]]; then - echo -e "${BLUE}[SKIP] Long-running example: $example_name (Compiled, but not run)${NC}" - continue - fi - - echo -ne " Running example: $example_name ... " - - # 4. Run the example - # Redirect stdout to /dev/null to keep terminal clean, but keep stderr for errors. - if cargo run --example "$example_name" > /dev/null; then - echo -e "${GREEN}[PASS]${NC}" - else - echo -e "${RED}[FAIL]${NC}" - # Exit immediately if an example fails - exit 1 - fi -done +echo -e "\n${YELLOW}[10/10] Building & Dry-Running Quickstart Example...${NC}" +# Compile first so a build error is distinct from a runtime error. +cargo build --example quickstart +# Then run it to verify the full logic path (environment load, eval, export) works. +cargo run --release --example quickstart > /dev/null +echo -e "${GREEN}[OK] Quickstart example ran successfully.${NC}" echo -e "\n${GREEN}>>> SUCCESS! All checks passed. Ready to push.${NC}" diff --git a/examples/crossover.rs b/examples/crossover.rs deleted file mode 100644 index ca9098e..0000000 --- a/examples/crossover.rs +++ /dev/null @@ -1,63 +0,0 @@ -use anyhow::{Context, Result}; -use chapaty::{ - gym::trading::agent::crossover::{PrecomputedCrossover, StreamingCrossover}, - prelude::*, -}; -use std::path::Path; - -#[tokio::main] -async fn main() -> Result<()> { - let ohlcv_id = ohlcv_id(); - let fast_sma = SmaWindow(20); - let slow_sma = SmaWindow(50); - let fast_sma_id = SmaId { - parent: ohlcv_id, - length: fast_sma, - }; - let slow_sma_id = SmaId { - parent: ohlcv_id, - length: slow_sma, - }; - - let mut env = environment().await?; - - println!("Running Streaming Crossover Agent..."); - let mut streaming_agent = StreamingCrossover::new(ohlcv_id, fast_sma.0, slow_sma.0); - let journal_stream = env.evaluate_agent(&mut streaming_agent)?; - let file_cfg = FileConfig::default().with_dir(Path::new("examples/reports/streaming_cross")); - journal_stream.to_file_sync(&file_cfg)?; - env.equity_curve_report()? - .into_eod()? - .to_file_sync(&file_cfg)?; - - println!("Running Precomputed Crossover Agent..."); - let mut env_agent = PrecomputedCrossover::new(ohlcv_id, fast_sma_id, slow_sma_id); - let journal_env = env.evaluate_agent(&mut env_agent)?; - let file_cfg = FileConfig::default().with_dir(Path::new("examples/reports/precomputed_cross")); - journal_env.to_file_sync(&file_cfg)?; - env.equity_curve_report()? - .into_eod()? - .to_file_sync(&file_cfg)?; - - Ok(()) -} - -async fn environment() -> Result { - let preset = EnvPreset::BinanceBtcUsdt1dSma20Sma50; - let file_stem = preset.to_string(); - let loc = StorageLocation::HuggingFace { version: None }; - let cfg = IoConfig::new(loc).with_file_stem(&file_stem); - - chapaty::load(preset, &cfg) - .await - .context("Failed to load trading environment") -} - -fn ohlcv_id() -> OhlcvId { - OhlcvId { - broker: DataBroker::Binance, - exchange: Exchange::Binance, - symbol: Symbol::Spot(SpotPair::BtcUsdt), - period: Period::Day(1), - } -} diff --git a/examples/logging.rs b/examples/logging.rs deleted file mode 100644 index 4de5104..0000000 --- a/examples/logging.rs +++ /dev/null @@ -1,154 +0,0 @@ -use std::{env, fs, path::Path, time::Instant}; - -use anyhow::{Context, Result}; -use chapaty::{gym::trading::agent::news::fade::NewsFade, prelude::*}; -use chrono::Duration; -use time::macros::format_description; -use tracing::info; -use tracing_appender::non_blocking::WorkerGuard; -use tracing_subscriber::EnvFilter; - -#[tokio::main] -async fn main() -> Result<()> { - // Create simple logging subscriber - let _guard = init_tracing()?; - - println!("Starting evaluation process..."); - - let build_start = Instant::now(); - let mut env = environment().await?; - let build_time = build_start.elapsed(); - - let mut agent = news_fade()?; - let fade_start = Instant::now(); - let journal = env.evaluate_agent(&mut agent)?; - let fade_time = fade_start.elapsed(); - - let path = Path::new("examples/reports/news_fade"); - journal.to_file_sync(&FileConfig::default().with_dir(path))?; - - println!("\n--- Evaluation Timings ---"); - println!("1. Environment build time: {build_time:?}"); - println!("2. Fade agent run time: {fade_time:?}"); - - // The WorkerGuard ensures all buffered logs are flushed when dropped. - drop(_guard); - - Ok(()) -} - -// ================================================================================================ -// Tracing Configuration -// ================================================================================================ - -fn init_tracing() -> Result> { - let app_name = "chapaty"; - - // Detect if running in container - let in_container = - env::var("CONTAINER").is_ok() || std::path::Path::new("/.dockerenv").exists(); - - let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); - - if in_container { - // Container mode: log to stdout - tracing_subscriber::fmt() - .json() - .with_env_filter(env_filter) - .with_span_events(tracing_subscriber::fmt::format::FmtSpan::NONE) - .with_current_span(true) - .with_thread_ids(true) - .with_thread_names(true) - .with_timer(tracing_subscriber::fmt::time::UtcTime::rfc_3339()) - .init(); - - info!("Logging to stdout (container mode)"); - Ok(None) - } else { - // Local mode: log to file - let log_dir = dirs::state_dir() - .map(|mut p| { - p.push(app_name); - p.push("logs"); - p - }) - .unwrap_or_else(|| { - let mut home = dirs::home_dir().expect("Failed to find home directory"); - home.push(format!(".local/state/{app_name}/logs")); - home - }); - fs::create_dir_all(&log_dir)?; - - let timestamp = time::OffsetDateTime::now_utc() - .format(&format_description!( - "[year][month][day]-[hour][minute][second]" - )) - .context("Failed to format timestamp")?; - let file_name = format!("{app_name}-{timestamp}.log"); - let file_path = log_dir.join(file_name); - - let file_appender = - tracing_appender::rolling::never(log_dir.clone(), file_path.file_name().unwrap()); - let (non_blocking, guard) = tracing_appender::non_blocking(file_appender); - - tracing_subscriber::fmt() - .json() - .with_env_filter(env_filter) - .with_writer(non_blocking) - .with_span_events(tracing_subscriber::fmt::format::FmtSpan::NONE) - .with_current_span(true) - .with_thread_ids(true) - .with_thread_names(true) - .with_timer(tracing_subscriber::fmt::time::UtcTime::rfc_3339()) - .init(); - - info!(log_file = %file_path.display(), "Logging to file (local mode)"); - Ok(Some(guard)) - } -} - -// ================================================================================================ -// Helper Functions -// ================================================================================================ - -fn news_fade() -> Result { - let agent = NewsFade::baseline(economic_calendar_id(), ohlcv_id()) - .with_candles_after_news(Duration::minutes(14)) - .with_take_profit_risk_factor(0.0) - .with_risk_reward_ratio(0.1)?; - Ok(agent) -} - -async fn environment() -> Result { - let preset = EnvPreset::NinjaTraderCme6eh61mUsEmpHighEventsOnly; - let file_stem = preset.to_string(); - let loc = StorageLocation::HuggingFace { version: None }; - let cfg = IoConfig::new(loc).with_file_stem(&file_stem); - - chapaty::load(preset, &cfg) - .await - .context("Failed to load trading environment") -} - -fn economic_calendar_id() -> EconomicCalendarId { - EconomicCalendarId { - broker: DataBroker::InvestingCom, - data_source: None, - country_code: Some(CountryCode::Us), - category: Some(EconomicCategory::Employment), - importance: Some(EconomicEventImpact::High), - } -} - -fn ohlcv_id() -> OhlcvId { - OhlcvId { - broker: DataBroker::NinjaTrader, - exchange: Exchange::Cme, - symbol: Symbol::Future(FutureContract { - root: FutureRoot::EurUsd, - month: ContractMonth::June, - year: ContractYear::Y6, - }), - period: Period::Minute(1), - } -} diff --git a/examples/news_breakout.rs b/examples/news_breakout.rs deleted file mode 100644 index 79fd348..0000000 --- a/examples/news_breakout.rs +++ /dev/null @@ -1,82 +0,0 @@ -use std::{path::Path, time::Instant}; - -use anyhow::{Context, Result}; -use chapaty::{gym::trading::agent::news::breakout::NewsBreakout, prelude::*}; -use chrono::Duration; - -#[tokio::main] -async fn main() -> Result<()> { - println!("Starting evaluation process..."); - - let build_start = Instant::now(); - let mut env = environment().await?; - let build_time = build_start.elapsed(); - - let mut agent = news_breakout()?; - let breakout_start = Instant::now(); - let journal = env.evaluate_agent(&mut agent)?; - let breakout_time = breakout_start.elapsed(); - - let path = Path::new("examples/reports/news_breakout"); - let file_cfg = FileConfig::default().with_dir(path); - journal.to_file_sync(&file_cfg)?; - journal.cumulative_returns()?.to_file_sync(&file_cfg)?; - journal.portfolio_performance()?.to_file_sync(&file_cfg)?; - journal.trade_stats()?.to_file_sync(&file_cfg)?; - env.equity_curve_report()? - .into_eod()? - .to_file_sync(&file_cfg)?; - - println!("\n--- Evaluation Timings ---"); - println!("1. Environment build time: {build_time:?}"); - println!("2. Breakout agent run time: {breakout_time:?}"); - - Ok(()) -} - -// ================================================================================================ -// Helper Functions -// ================================================================================================ - -fn news_breakout() -> Result { - let agent = NewsBreakout::baseline(economic_calendar_id(), ohlcv_id()) - .with_earliest_entry_candle(Duration::minutes(10)) - .with_latest_entry_candle(Duration::minutes(50)) - .with_stop_loss_risk_factor(1.15) - .with_risk_reward_ratio(1.0 / 0.7)?; - Ok(agent) -} - -async fn environment() -> Result { - let preset = EnvPreset::NinjaTraderCme6eh61mUsEmpHighEventsOnly; - let file_stem = preset.to_string(); - let loc = StorageLocation::HuggingFace { version: None }; - let cfg = IoConfig::new(loc).with_file_stem(&file_stem); - - chapaty::load(preset, &cfg) - .await - .context("Failed to load trading environment") -} - -fn economic_calendar_id() -> EconomicCalendarId { - EconomicCalendarId { - broker: DataBroker::InvestingCom, - data_source: None, - country_code: Some(CountryCode::Us), - category: Some(EconomicCategory::Employment), - importance: Some(EconomicEventImpact::High), - } -} - -fn ohlcv_id() -> OhlcvId { - OhlcvId { - broker: DataBroker::NinjaTrader, - exchange: Exchange::Cme, - symbol: Symbol::Future(FutureContract { - root: FutureRoot::EurUsd, - month: ContractMonth::June, - year: ContractYear::Y6, - }), - period: Period::Minute(1), - } -} diff --git a/examples/news_breakout_grid.rs b/examples/news_breakout_grid.rs deleted file mode 100644 index 4975507..0000000 --- a/examples/news_breakout_grid.rs +++ /dev/null @@ -1,86 +0,0 @@ -use std::{path::Path, time::Instant}; - -use anyhow::{Context, Result}; -use chapaty::{ - gym::trading::agent::news::breakout::{NewsBreakout, NewsBreakoutGrid}, - prelude::*, -}; - -// === BEGIN JEMALLOC CONFIG === -#[cfg(target_os = "linux")] -use tikv_jemallocator::Jemalloc; - -#[cfg(target_os = "linux")] -#[global_allocator] -static GLOBAL: Jemalloc = Jemalloc; -// === END JEMALLOC CONFIG === - -#[tokio::main] -async fn main() -> Result<()> { - println!("Starting evaluation process..."); - - let build_start = Instant::now(); - let mut env = environment().await?; - let build_time = build_start.elapsed(); - - let agents = news_breakout_grid(); - let grid_backtest_start = Instant::now(); - let leaderboard = env.evaluate_agents(agents, 100)?; - let grid_backtest_time = grid_backtest_start.elapsed(); - - let path = Path::new("examples/reports/news_breakout"); - leaderboard.to_file_sync(&FileConfig::default().with_dir(path))?; - - println!("\n--- Evaluation Timings ---"); - println!("1. Environment build time: {build_time:?}"); - println!("2. Breakout agents run time: {grid_backtest_time:?}"); - - Ok(()) -} - -// ================================================================================================ -// Helper Functions -// ================================================================================================ - -fn news_breakout_grid() -> Vec<(usize, NewsBreakout)> { - NewsBreakoutGrid::baseline(economic_calendar_id(), ohlcv_id()) - .expect("Failed to create baseline grid") - // Optional: Constrain the grid for a quick demo run - // .with_stop_loss_risk_factor(GridAxis::new("0.5", "1.5", "0.01").expect("Invalid stop loss axis")) - // .with_risk_reward_ratio(GridAxis::new("0.5", "1.0", "0.1").expect("Invalid RRR axis")) - .build() -} - -async fn environment() -> Result { - let preset = EnvPreset::NinjaTraderCme6eh61mUsEmpHighEventsOnly; - let file_stem = preset.to_string(); - let loc = StorageLocation::HuggingFace { version: None }; - let cfg = IoConfig::new(loc).with_file_stem(&file_stem); - - chapaty::load(preset, &cfg) - .await - .context("Failed to load trading environment") -} - -fn economic_calendar_id() -> EconomicCalendarId { - EconomicCalendarId { - broker: DataBroker::InvestingCom, - data_source: None, - country_code: Some(CountryCode::Us), - category: Some(EconomicCategory::Employment), - importance: Some(EconomicEventImpact::High), - } -} - -fn ohlcv_id() -> OhlcvId { - OhlcvId { - broker: DataBroker::NinjaTrader, - exchange: Exchange::Cme, - symbol: Symbol::Future(FutureContract { - root: FutureRoot::EurUsd, - month: ContractMonth::June, - year: ContractYear::Y6, - }), - period: Period::Minute(1), - } -} diff --git a/examples/news_fade.rs b/examples/news_fade.rs deleted file mode 100644 index 3fb27c1..0000000 --- a/examples/news_fade.rs +++ /dev/null @@ -1,81 +0,0 @@ -use std::{path::Path, time::Instant}; - -use anyhow::{Context, Result}; -use chapaty::{gym::trading::agent::news::fade::NewsFade, prelude::*}; -use chrono::Duration; - -#[tokio::main] -async fn main() -> Result<()> { - println!("Starting evaluation process..."); - - let build_start = Instant::now(); - let mut env = environment().await?; - let build_time = build_start.elapsed(); - - let mut agent = news_fade()?; - let fade_start = Instant::now(); - let journal = env.evaluate_agent(&mut agent)?; - let fade_time = fade_start.elapsed(); - - let path = Path::new("examples/reports/news_fade"); - let file_cfg = FileConfig::default().with_dir(path); - journal.to_file_sync(&file_cfg)?; - journal.cumulative_returns()?.to_file_sync(&file_cfg)?; - journal.portfolio_performance()?.to_file_sync(&file_cfg)?; - journal.trade_stats()?.to_file_sync(&file_cfg)?; - env.equity_curve_report()? - .into_eod()? - .to_file_sync(&file_cfg)?; - - println!("\n--- Evaluation Timings ---"); - println!("1. Environment build time: {build_time:?}"); - println!("2. Fade agent run time: {fade_time:?}"); - - Ok(()) -} - -// ================================================================================================ -// Helper Functions -// ================================================================================================ - -fn news_fade() -> Result { - let agent = NewsFade::baseline(economic_calendar_id(), ohlcv_id()) - .with_candles_after_news(Duration::minutes(8)) - .with_take_profit_risk_factor(1.25) - .with_risk_reward_ratio(1. / 2.8)?; - Ok(agent) -} - -async fn environment() -> Result { - let preset = EnvPreset::NinjaTraderCme6eh61mUsEmpHighEventsOnly; - let file_stem = preset.to_string(); - let loc = StorageLocation::HuggingFace { version: None }; - let cfg = IoConfig::new(loc).with_file_stem(&file_stem); - - chapaty::load(preset, &cfg) - .await - .context("Failed to load trading environment") -} - -fn economic_calendar_id() -> EconomicCalendarId { - EconomicCalendarId { - broker: DataBroker::InvestingCom, - data_source: None, - country_code: Some(CountryCode::Us), - category: Some(EconomicCategory::Employment), - importance: Some(EconomicEventImpact::High), - } -} - -fn ohlcv_id() -> OhlcvId { - OhlcvId { - broker: DataBroker::NinjaTrader, - exchange: Exchange::Cme, - symbol: Symbol::Future(FutureContract { - root: FutureRoot::EurUsd, - month: ContractMonth::June, - year: ContractYear::Y6, - }), - period: Period::Minute(1), - } -} diff --git a/examples/news_fade_grid.rs b/examples/news_fade_grid.rs deleted file mode 100644 index c0c3aae..0000000 --- a/examples/news_fade_grid.rs +++ /dev/null @@ -1,86 +0,0 @@ -use std::{path::Path, time::Instant}; - -use anyhow::{Context, Result}; -use chapaty::{ - gym::trading::agent::news::fade::{NewsFade, NewsFadeGrid}, - prelude::*, -}; - -// === BEGIN JEMALLOC CONFIG === -#[cfg(target_os = "linux")] -use tikv_jemallocator::Jemalloc; - -#[cfg(target_os = "linux")] -#[global_allocator] -static GLOBAL: Jemalloc = Jemalloc; -// === END JEMALLOC CONFIG === - -#[tokio::main] -async fn main() -> Result<()> { - println!("Starting evaluation process..."); - - let build_start = Instant::now(); - let mut env = environment().await?; - let build_time = build_start.elapsed(); - - let agents = news_fade_grid(); - let grid_backtest_start = Instant::now(); - let leaderboard = env.evaluate_agents(agents, 100)?; - let grid_backtest_time = grid_backtest_start.elapsed(); - - let path = Path::new("examples/reports/news_fade"); - leaderboard.to_file_sync(&FileConfig::default().with_dir(path))?; - - println!("\n--- Evaluation Timings ---"); - println!("1. Environment build time: {build_time:?}"); - println!("2. Fade agents run time: {grid_backtest_time:?}"); - - Ok(()) -} - -// ================================================================================================ -// Helper Functions -// ================================================================================================ - -fn news_fade_grid() -> Vec<(usize, NewsFade)> { - NewsFadeGrid::baseline(economic_calendar_id(), ohlcv_id()) - .expect("Failed to create baseline grid") - // Optional: Constrain the grid for a quick demo run - // .with_take_profit_risk_factor(GridAxis::new("0.5", "3.0", "0.1").expect("Invalid TP axis")) - // .with_risk_reward_ratio(GridAxis::new("0.1", "1.0", "0.1").expect("Invalid RRR axis")) - .build() -} - -async fn environment() -> Result { - let preset = EnvPreset::NinjaTraderCme6eh61mUsEmpHighEventsOnly; - let file_stem = preset.to_string(); - let loc = StorageLocation::HuggingFace { version: None }; - let cfg = IoConfig::new(loc).with_file_stem(&file_stem); - - chapaty::load(preset, &cfg) - .await - .context("Failed to load trading environment") -} - -fn economic_calendar_id() -> EconomicCalendarId { - EconomicCalendarId { - broker: DataBroker::InvestingCom, - data_source: None, - country_code: Some(CountryCode::Us), - category: Some(EconomicCategory::Employment), - importance: Some(EconomicEventImpact::High), - } -} - -fn ohlcv_id() -> OhlcvId { - OhlcvId { - broker: DataBroker::NinjaTrader, - exchange: Exchange::Cme, - symbol: Symbol::Future(FutureContract { - root: FutureRoot::EurUsd, - month: ContractMonth::June, - year: ContractYear::Y6, - }), - period: Period::Minute(1), - } -} diff --git a/examples/news_hybrid.rs b/examples/news_hybrid.rs deleted file mode 100644 index d5bae30..0000000 --- a/examples/news_hybrid.rs +++ /dev/null @@ -1,93 +0,0 @@ -use std::{path::Path, time::Instant}; - -use anyhow::{Context, Result}; -use chapaty::{ - gym::trading::agent::news::{breakout::NewsBreakout, fade::NewsFade, hybrid::NewsHybrid}, - prelude::*, -}; -use chrono::Duration; - -#[tokio::main] -async fn main() -> Result<()> { - println!("Starting evaluation process..."); - - let build_start = Instant::now(); - let mut env = environment().await?; - let build_time = build_start.elapsed(); - - let mut agent = news_hybrid()?; - let decision_start = Instant::now(); - let journal = env.evaluate_agent(&mut agent)?; - let decision_time = decision_start.elapsed(); - - let path = Path::new("examples/reports/news_hybrid"); - let file_cfg = FileConfig::default().with_dir(path); - journal.to_file_sync(&file_cfg)?; - journal.cumulative_returns()?.to_file_sync(&file_cfg)?; - journal.portfolio_performance()?.to_file_sync(&file_cfg)?; - journal.trade_stats()?.to_file_sync(&file_cfg)?; - env.equity_curve_report()? - .into_eod()? - .to_file_sync(&file_cfg)?; - - println!("\n--- Evaluation Timings ---"); - println!("1. Environment build time: {build_time:?}"); - println!("2. Hybrid agent run time: {decision_time:?}"); - - Ok(()) -} - -// ================================================================================================ -// Helper Functions -// ================================================================================================ - -fn news_hybrid() -> Result { - let cal_id = economic_calendar_id(); - - let fade = NewsFade::baseline(cal_id, ohlcv_id(Period::Minute(1))) - .with_candles_after_news(Duration::minutes(7)) - .with_take_profit_risk_factor(1.27) - .with_risk_reward_ratio(0.276)?; - - let breakout = NewsBreakout::baseline(cal_id, ohlcv_id(Period::Minute(5))) - .with_earliest_entry_candle(Duration::minutes(8)) - .with_latest_entry_candle(Duration::minutes(50)) - .with_stop_loss_risk_factor(0.89) - .with_risk_reward_ratio(0.726)?; - - Ok(NewsHybrid { breakout, fade }) -} - -async fn environment() -> Result { - let preset = EnvPreset::NinjaTraderCme6eh61m5mUsEmpHighEventsOnly; - let file_stem = preset.to_string(); - let loc = StorageLocation::HuggingFace { version: None }; - let cfg = IoConfig::new(loc).with_file_stem(&file_stem); - - chapaty::load(preset, &cfg) - .await - .context("Failed to load trading environment") -} - -fn economic_calendar_id() -> EconomicCalendarId { - EconomicCalendarId { - broker: DataBroker::InvestingCom, - data_source: None, - country_code: Some(CountryCode::Us), - category: Some(EconomicCategory::Employment), - importance: Some(EconomicEventImpact::High), - } -} - -fn ohlcv_id(period: Period) -> OhlcvId { - OhlcvId { - broker: DataBroker::NinjaTrader, - exchange: Exchange::Cme, - symbol: Symbol::Future(FutureContract { - root: FutureRoot::EurUsd, - month: ContractMonth::June, - year: ContractYear::Y6, - }), - period, - } -} diff --git a/examples/noop_grid.rs b/examples/noop_grid.rs deleted file mode 100644 index e61dc33..0000000 --- a/examples/noop_grid.rs +++ /dev/null @@ -1,64 +0,0 @@ -use anyhow::{Context, Result}; -use chapaty::prelude::*; -use serde::Serialize; -use std::path::Path; -use std::sync::Arc; - -#[derive(Clone, Serialize)] -struct NoOpAgent; - -impl Agent for NoOpAgent { - fn identifier(&self) -> AgentIdentifier { - AgentIdentifier::Named(Arc::new("NoOpAgent".to_string())) - } - - fn reset(&mut self) {} - - fn act(&mut self, _obs: Observation) -> ChapatyResult { - // Return no actions, guaranteeing 0 trades - Ok(Actions::no_op()) - } -} - -#[tokio::main] -async fn main() -> Result<()> { - let mut env = environment().await?; - let num_agents = 5; - let agents: Vec<(usize, NoOpAgent)> = (0..num_agents).map(|uid| (uid, NoOpAgent)).collect(); - - let leaderboard = env.evaluate_agents(agents, 10)?; - - println!( - "Evaluation complete. Leaderboard size: {}", - leaderboard.as_df().height() - ); - - let export_dir = Path::new("examples/reports/noop_grid"); - let file_cfg = FileConfig::default().with_dir(export_dir); - leaderboard.to_file_sync(&file_cfg)?; - - println!("Saved leaderboard to {}", export_dir.display()); - - Ok(()) -} - -async fn environment() -> Result { - let preset = EnvPreset::BinanceBtcUsdt1d; - let file_stem = preset.to_string(); - let loc = StorageLocation::HuggingFace { version: None }; - let cfg = IoConfig::new(loc).with_file_stem(&file_stem); - - chapaty::load(preset, &cfg) - .await - .context("Failed to load trading environment") -} - -#[allow(dead_code)] // Provided for completeness -fn ohlcv_id() -> OhlcvId { - OhlcvId { - broker: DataBroker::Binance, - exchange: Exchange::Binance, - symbol: Symbol::Spot(SpotPair::BtcUsdt), - period: Period::Day(1), - } -} diff --git a/examples/quickstart.rs b/examples/quickstart.rs new file mode 100644 index 0000000..42faeb2 --- /dev/null +++ b/examples/quickstart.rs @@ -0,0 +1,189 @@ +use anyhow::{Context, Result}; +use chapaty::prelude::*; +use serde::Serialize; +use std::{env, fs, path::Path, sync::Arc, time::Instant}; +use time::macros::format_description; +use tracing::{debug, info}; +use tracing_appender::non_blocking::WorkerGuard; +use tracing_subscriber::EnvFilter; + +const LEADERBOARD_TOP_K: usize = 10; +const GRID_SIZE: usize = 400; +const REPORTS_SUBDIR: &str = "examples/reports/quickstart"; + +// ================================================================================================ +// No-Op Agent +// +// A placeholder agent that never trades. It exists only to demonstrate the evaluation API +// (single-agent journals + parallel leaderboards) and the logging setup, without bundling any +// real strategy logic into the core crate. +// +// For real, ready-to-run strategies, see chapaty-zoo: +// https://github.com/LenWilliamson/chapaty-zoo +// ================================================================================================ + +#[derive(Clone, Serialize)] +struct NoOpAgent { + #[serde(skip)] + agent_id: AgentIdentifier, +} + +impl Default for NoOpAgent { + fn default() -> Self { + Self { + agent_id: AgentIdentifier::Named(Arc::new("NoOpAgent".to_string())), + } + } +} + +impl Agent for NoOpAgent { + fn identifier(&self) -> AgentIdentifier { + self.agent_id.clone() + } + + fn reset(&mut self) {} + + // `act` is called millions of times. + // Keep logging here at `debug` so it stays silent under the default `info` filter. + #[tracing::instrument(skip_all)] + fn act(&mut self, _obs: Observation) -> ChapatyResult { + debug!("Returning no actions, guaranteeing 0 trades"); + Ok(Actions::no_op()) + } +} + +// ================================================================================================ +// Main +// ================================================================================================ + +#[tokio::main] +async fn main() -> Result<()> { + let _guard = init_tracing()?; + info!("Starting evaluation example..."); + + let build_start = Instant::now(); + let mut env = environment().await?; + info!(build_time = ?build_start.elapsed(), "Environment ready"); + + let reports_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join(REPORTS_SUBDIR); + let file_cfg = FileConfig::default().with_dir(&reports_dir); + + // === 1. Single-agent baseline: full journal + reports === + let mut baseline = NoOpAgent::default(); + let label = baseline.identifier(); + + let baseline_start = Instant::now(); + info!(%label, "Running baseline backtest..."); + + let journal = env.evaluate_agent(&mut baseline)?; + journal.to_file_sync(&file_cfg)?; + journal.cumulative_returns()?.to_file_sync(&file_cfg)?; + journal.portfolio_performance()?.to_file_sync(&file_cfg)?; + journal.trade_stats()?.to_file_sync(&file_cfg)?; + env.equity_curve_report()? + .into_eod()? + .to_file_sync(&file_cfg)?; + + info!(%label, elapsed = ?baseline_start.elapsed(), "Baseline backtest complete"); + + // === 2. Parallel grid: ranked leaderboard === + let agents = (0..GRID_SIZE) + .map(|uid| (uid, NoOpAgent::default())) + .collect::>(); + + let grid_start = Instant::now(); + info!(grid_size = GRID_SIZE, "Evaluating agents in parallel..."); + + let leaderboard = env.evaluate_agents(agents, LEADERBOARD_TOP_K)?; + leaderboard.to_file_sync(&file_cfg)?; + + info!( + elapsed = ?grid_start.elapsed(), + rows = leaderboard.as_df().height(), + dir = %file_cfg.dir.display(), + "Grid evaluation complete; leaderboard saved" + ); + + // The WorkerGuard ensures all buffered logs are flushed when dropped. + drop(_guard); + Ok(()) +} + +async fn environment() -> Result { + let preset = EnvPreset::BinanceBtcUsdt1d; + let file_stem = preset.to_string(); + let loc = StorageLocation::HuggingFace { version: None }; + let cfg = IoConfig::new(loc).with_file_stem(&file_stem); + + chapaty::load(preset, &cfg) + .await + .context("Failed to load trading environment") +} + +// ================================================================================================ +// Tracing Configuration +// +// JSON to stdout in containers, or to a timestamped file under the OS state dir locally. +// ================================================================================================ + +fn init_tracing() -> Result> { + let app_name = "chapaty"; + + let in_container = env::var("CONTAINER").is_ok() || Path::new("/.dockerenv").exists(); + + let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); + + if in_container { + // Container mode: log to stdout + tracing_subscriber::fmt() + .json() + .with_env_filter(env_filter) + .with_span_events(tracing_subscriber::fmt::format::FmtSpan::NONE) + .with_current_span(true) + .with_thread_ids(true) + .with_thread_names(true) + .with_timer(tracing_subscriber::fmt::time::UtcTime::rfc_3339()) + .init(); + + info!("Logging to stdout (container mode)"); + Ok(None) + } else { + // Local mode: log to file + let log_dir = dirs::state_dir() + .map(|mut p| { + p.push(app_name); + p.push("logs"); + p + }) + .unwrap_or_else(|| { + let mut home = dirs::home_dir().expect("Failed to find home directory"); + home.push(format!(".local/state/{app_name}/logs")); + home + }); + fs::create_dir_all(&log_dir)?; + + let timestamp = time::OffsetDateTime::now_utc() + .format(&format_description!( + "[year][month][day]-[hour][minute][second]" + )) + .context("Failed to format timestamp")?; + let file_name = format!("{app_name}-{timestamp}.log"); + + let file_appender = tracing_appender::rolling::never(&log_dir, &file_name); + let (non_blocking, guard) = tracing_appender::non_blocking(file_appender); + + tracing_subscriber::fmt() + .json() + .with_env_filter(env_filter) + .with_writer(non_blocking) + .with_span_events(tracing_subscriber::fmt::format::FmtSpan::NONE) + .with_current_span(true) + .with_thread_ids(true) + .with_thread_names(true) + .with_timer(tracing_subscriber::fmt::time::UtcTime::rfc_3339()) + .init(); + + info!(log_file = %log_dir.join(&file_name).display(), "Logging to file (local mode)"); + Ok(Some(guard)) + } +} diff --git a/src/data.rs b/src/data.rs index 93cee4f..3d36987 100644 --- a/src/data.rs +++ b/src/data.rs @@ -1,8 +1,8 @@ +pub mod batch_indicator; pub mod common; -pub mod config; pub mod domain; pub mod episode; pub mod event; pub mod filter; -pub mod indicator; +pub mod query; pub mod view; diff --git a/src/data/indicator.rs b/src/data/batch_indicator.rs similarity index 82% rename from src/data/indicator.rs rename to src/data/batch_indicator.rs index 7bba875..41717f6 100644 --- a/src/data/indicator.rs +++ b/src/data/batch_indicator.rs @@ -16,14 +16,36 @@ pub struct SmaWindow(pub u16); pub struct RsiWindow(pub u16); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] -pub enum TechnicalIndicator { +pub enum BatchOhlcvIndicator { Ema(EmaWindow), Sma(SmaWindow), Rsi(RsiWindow), } +/// A trait for data configurations that support derived technical analysis. +impl BatchOhlcvIndicator { + pub fn pre_compute(&self, lf: LazyFrame) -> ChapatyResult { + match self { + BatchOhlcvIndicator::Ema(ema) => ema.pre_compute_ema(lf), + BatchOhlcvIndicator::Sma(sma) => sma.pre_compute_sma(lf), + BatchOhlcvIndicator::Rsi(rsi) => rsi.pre_compute_rsi(lf), + } + } +} + +pub trait WithBatchIndicators: Sized { + type BatchIndicator: Clone; + + fn with_indicator(self, kind: Self::BatchIndicator) -> Self; + fn with_indicators(self, kinds: &[Self::BatchIndicator]) -> Self { + kinds + .iter() + .fold(self, |acc, kind| acc.with_indicator(kind.clone())) + } +} + impl EmaWindow { - pub fn pre_compute_ema(&self, lf: LazyFrame) -> ChapatyResult { + fn pre_compute_ema(&self, lf: LazyFrame) -> ChapatyResult { let window = self.0; // Standard EMA formula: alpha = 2 / (span + 1) @@ -60,7 +82,7 @@ impl EmaWindow { } } impl SmaWindow { - pub fn pre_compute_sma(&self, lf: LazyFrame) -> ChapatyResult { + fn pre_compute_sma(&self, lf: LazyFrame) -> ChapatyResult { let window = self.0; let options = RollingOptionsFixedWindow { window_size: window as usize, @@ -86,7 +108,7 @@ impl SmaWindow { } impl RsiWindow { - pub fn pre_compute_rsi(&self, lf: LazyFrame) -> ChapatyResult { + fn pre_compute_rsi(&self, lf: LazyFrame) -> ChapatyResult { let window = self.0; // Wilder's Smoothing for RSI: alpha = 1 / N let alpha = 1.0 / (window as f64); diff --git a/src/data/domain.rs b/src/data/domain.rs index faa7e8d..679b921 100644 --- a/src/data/domain.rs +++ b/src/data/domain.rs @@ -1,6 +1,6 @@ use serde::{Deserialize, Serialize}; use std::{fmt, str::FromStr}; -use strum::{Display, EnumIter, IntoStaticStr}; +use strum::{AsRefStr, Display, EnumIter, IntoStaticStr}; use strum_macros::EnumString; use crate::{ @@ -404,6 +404,32 @@ impl TryFrom for EconomicDataSource { } } +#[derive( + Copy, + Clone, + Debug, + Hash, + PartialEq, + Eq, + Deserialize, + Serialize, + PartialOrd, + Ord, + EnumIter, + EnumString, + Display, + AsRefStr, + IntoStaticStr, + Default, +)] +pub enum PriceSource { + /// Evaluate the absolute extremes (High for peaks, Low for valleys) + #[default] + HighLow, + /// Evaluate the candle bodies (Close for peaks, Open/Close for valleys) + OpenClose, +} + #[derive( Copy, Clone, @@ -616,6 +642,10 @@ pub enum FutureRoot { NzdUsd, #[strum(serialize = "btc")] Btc, + #[strum(serialize = "es")] + EminiSp500, + #[strum(serialize = "nq")] + EminiNasdaq100, } #[derive( @@ -879,12 +909,18 @@ impl TryFrom for EconomicEventImpact { pub enum CountryCode { /// Australia Au, + /// Brazil + Br, /// Canada Ca, + /// China + Cn, /// Euro Zone Ez, /// United Kingdom Gb, + /// India + In, /// Japan Jp, /// New Zealand @@ -959,7 +995,6 @@ impl Instrument for SpotPair { } fn tick_value_usd(&self) -> f64 { - // For spot, 1 unit of movement = 1 USD per unit held self.tick_size() } } @@ -973,13 +1008,18 @@ impl Instrument for FutureRoot { FutureRoot::GbpUsd => 0.0001, FutureRoot::JpyUsd => 0.0000005, FutureRoot::Btc => 5.0, + FutureRoot::EminiSp500 | FutureRoot::EminiNasdaq100 => 0.25, } } fn tick_value_usd(&self) -> f64 { match self { FutureRoot::EurUsd | FutureRoot::GbpUsd | FutureRoot::JpyUsd => 6.25, - FutureRoot::AudUsd | FutureRoot::CadUsd | FutureRoot::NzdUsd => 5.0, + FutureRoot::AudUsd + | FutureRoot::CadUsd + | FutureRoot::NzdUsd + | FutureRoot::EminiNasdaq100 => 5.0, + FutureRoot::EminiSp500 => 12.50, FutureRoot::Btc => 25.0, } } diff --git a/src/data/event.rs b/src/data/event.rs index c645c6e..79ef5a1 100644 --- a/src/data/event.rs +++ b/src/data/event.rs @@ -11,13 +11,13 @@ use strum::{Display, EnumString, IntoStaticStr}; use crate::{ data::{ + batch_indicator::{EmaWindow, RsiWindow, SmaWindow}, common::{ProfileAggregation, ProfileBinStats}, domain::{ CandleDirection, Count, CountryCode, DataBroker, EconomicCategory, EconomicDataSource, EconomicEventImpact, EconomicValue, Exchange, ExecutionDepth, LiquiditySide, MarketType, Period, Price, Quantity, Symbol, TradeId, Volume, }, - indicator::{EmaWindow, RsiWindow, SmaWindow}, }, error::{ChapatyError, ChapatyResult, DataError}, gym::trading::types::TradeType, @@ -170,6 +170,23 @@ impl Ohlcv { } } +/// A wrapper around an OHLCV candle that includes its absolute index in the stream. +#[derive(Debug, Clone, Copy)] +pub struct IndexedOhlcv { + pub candle: Ohlcv, + pub index: usize, +} + +impl MarketEvent for IndexedOhlcv { + fn point_in_time(&self) -> DateTime { + self.candle.point_in_time() + } + + fn opened_at(&self) -> DateTime { + self.candle.opened_at() + } +} + // ================================================================================================ // Trade // ================================================================================================ diff --git a/src/data/config.rs b/src/data/query.rs similarity index 69% rename from src/data/config.rs rename to src/data/query.rs index f7e0ee7..fbafc6c 100644 --- a/src/data/config.rs +++ b/src/data/query.rs @@ -4,13 +4,13 @@ use serde::{Deserialize, Serialize}; use crate::{ data::{ + batch_indicator::{BatchOhlcvIndicator, WithBatchIndicators}, common::ProfileAggregation, domain::{ CountryCode, DataBroker, EconomicCategory, EconomicDataSource, EconomicEventImpact, Exchange, Period, Symbol, }, event::{EconomicCalendarId, OhlcvId, TpoId, TradesId, VolumeProfileId}, - indicator::{EmaWindow, RsiWindow, SmaWindow, TechnicalIndicator}, }, error::ChapatyResult, }; @@ -24,7 +24,7 @@ use crate::{ /// OHLCV data represents aggregated price and volume information over specified time periods, /// commonly used for candlestick charts and technical analysis. #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] -pub struct OhlcvSpotConfig { +pub struct OhlcvSpotQuery { /// The data broker to query from. pub broker: DataBroker, @@ -43,7 +43,7 @@ pub struct OhlcvSpotConfig { pub batch_size: i32, // Data configurations that support derived technical analysis. - pub indicators: Vec, + pub indicators: Vec, } /// Configuration for retrieving OHLCV (Open, High, Low, Close, Volume) data from futures markets. @@ -51,7 +51,7 @@ pub struct OhlcvSpotConfig { /// Similar to spot OHLCV data, but specifically for futures contracts which include /// additional fields like open interest and funding rates. #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] -pub struct OhlcvFutureConfig { +pub struct OhlcvFutureQuery { /// The data broker to query from. pub broker: DataBroker, @@ -70,7 +70,7 @@ pub struct OhlcvFutureConfig { pub batch_size: i32, // Data configurations that support derived technical analysis. - pub indicators: Vec, + pub indicators: Vec, } // ================================================================================================ @@ -82,7 +82,7 @@ pub struct OhlcvFutureConfig { /// Trade data represents individual trades or price updates at the finest granularity, /// capturing every market transaction with microsecond precision. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] -pub struct TradeSpotConfig { +pub struct TradeSpotQuery { /// The data broker to query from. pub broker: DataBroker, @@ -109,7 +109,7 @@ pub struct TradeSpotConfig { /// and time, showing where trading activity has occurred and helping identify /// key support/resistance levels. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] -pub struct TpoSpotConfig { +pub struct TpoSpotQuery { /// The data broker to query from. pub broker: DataBroker, @@ -134,7 +134,7 @@ pub struct TpoSpotConfig { /// /// TPO data for futures markets, providing Market Profile insights for futures contracts. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] -pub struct TpoFutureConfig { +pub struct TpoFutureQuery { /// The data broker to query from. pub broker: DataBroker, @@ -165,7 +165,7 @@ pub struct TpoFutureConfig { /// helping identify high-volume nodes (HVN) and low-volume nodes (LVN) that often act /// as support or resistance. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] -pub struct VolumeProfileSpotConfig { +pub struct VolumeProfileSpotQuery { /// The data broker to query from. pub broker: DataBroker, @@ -195,7 +195,7 @@ pub struct VolumeProfileSpotConfig { /// Economic calendar data provides scheduled releases of economic indicators, /// central bank announcements, and other market-moving events. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] -pub struct EconomicCalendarConfig { +pub struct EconomicCalendarQuery { /// The data broker to query from. pub broker: DataBroker, @@ -230,181 +230,21 @@ pub struct EconomicCalendarConfig { // Traits // ================================================================================================ -/// A trait for data configurations that support derived technical analysis. -/// -/// This trait enables fluent, compile-time checked configuration of indicators -/// on OHLCV data streams. -pub trait TechnicalAnalysis { - /// Adds a technical indicator to be computed for this data stream. - fn add_indicator(&mut self, kind: TechnicalIndicator); - - /// Fluent builder version of `add_indicator`. - /// - /// # Example - /// ``` - /// # use chapaty::prelude::*; - /// let config = OhlcvSpotConfig { - /// broker: DataBroker::Binance, - /// symbol: Symbol::Spot(SpotPair::BtcUsdt), - /// exchange: None, - /// period: Period::Minute(1), - /// batch_size: 1000, - /// indicators: vec![], - /// } - /// .with_indicator(TechnicalIndicator::Sma(SmaWindow(20))); - /// ``` - fn with_indicator(mut self, kind: TechnicalIndicator) -> Self - where - Self: Sized, - { - self.add_indicator(kind); - self - } - - // === Ergonomic Sugar Helpers === - - /// Adds a Simple Moving Average (SMA) indicator. - /// - /// # Example - /// ``` - /// # use chapaty::prelude::*; - /// let config = OhlcvSpotConfig { - /// broker: DataBroker::Binance, - /// symbol: Symbol::Spot(SpotPair::BtcUsdt), - /// exchange: None, - /// period: Period::Minute(1), - /// batch_size: 1000, - /// indicators: vec![], - /// } - /// .with_sma(20); - /// ``` - fn with_sma(self, window: u16) -> Self - where - Self: Sized, - { - self.with_indicator(TechnicalIndicator::Sma(SmaWindow(window))) - } - - /// Adds an Exponential Moving Average (EMA) indicator. - /// - /// # Example - /// ``` - /// # use chapaty::prelude::*; - /// let config = OhlcvSpotConfig { - /// broker: DataBroker::Binance, - /// symbol: Symbol::Spot(SpotPair::BtcUsdt), - /// exchange: None, - /// period: Period::Minute(1), - /// batch_size: 1000, - /// indicators: vec![], - /// } - /// .with_ema(12); - /// ``` - fn with_ema(self, window: u16) -> Self - where - Self: Sized, - { - self.with_indicator(TechnicalIndicator::Ema(EmaWindow(window))) - } - - /// Adds a Relative Strength Index (RSI) indicator. - /// - /// # Example - /// ``` - /// # use chapaty::prelude::*; - /// let config = OhlcvSpotConfig { - /// broker: DataBroker::Binance, - /// symbol: Symbol::Spot(SpotPair::BtcUsdt), - /// exchange: None, - /// period: Period::Minute(1), - /// batch_size: 1000, - /// indicators: vec![], - /// } - /// .with_rsi(14); - /// ``` - fn with_rsi(self, window: u16) -> Self - where - Self: Sized, - { - self.with_indicator(TechnicalIndicator::Rsi(RsiWindow(window))) - } +impl WithBatchIndicators for OhlcvSpotQuery { + type BatchIndicator = BatchOhlcvIndicator; - // === Multi-indicator Helpers === - - /// Adds multiple indicators at once. - /// - /// # Example - /// ``` - /// # use chapaty::prelude::*; - /// let config = OhlcvSpotConfig { - /// broker: DataBroker::Binance, - /// symbol: Symbol::Spot(SpotPair::BtcUsdt), - /// exchange: None, - /// period: Period::Minute(1), - /// batch_size: 1000, - /// indicators: vec![], - /// } - /// .with_indicators(vec![ - /// TechnicalIndicator::Sma(SmaWindow(20)), - /// TechnicalIndicator::Ema(EmaWindow(12)), - /// ]); - /// ``` - fn with_indicators(mut self, kinds: Vec) -> Self - where - Self: Sized, - { - for kind in kinds { - self.add_indicator(kind); - } - self - } - - /// Chainable helper to add multiple SMAs. - /// - /// # Example - /// ``` - /// # use chapaty::prelude::*; - /// let config = OhlcvSpotConfig { - /// broker: DataBroker::Binance, - /// symbol: Symbol::Spot(SpotPair::BtcUsdt), - /// exchange: None, - /// period: Period::Minute(1), - /// batch_size: 1000, - /// indicators: vec![], - /// } - /// .with_smas(&[20, 50, 200]); - /// ``` - fn with_smas(mut self, windows: &[u16]) -> Self - where - Self: Sized, - { - for &window in windows { - self.add_indicator(TechnicalIndicator::Sma(SmaWindow(window))); - } - self - } - - /// Chainable helper to add multiple EMAs. - fn with_emas(mut self, windows: &[u16]) -> Self - where - Self: Sized, - { - for &window in windows { - self.add_indicator(TechnicalIndicator::Ema(EmaWindow(window))); - } + fn with_indicator(mut self, kind: Self::BatchIndicator) -> Self { + self.indicators.push(kind); self } } -impl TechnicalAnalysis for OhlcvSpotConfig { - fn add_indicator(&mut self, kind: TechnicalIndicator) { - self.indicators.push(kind); - } -} +impl WithBatchIndicators for OhlcvFutureQuery { + type BatchIndicator = BatchOhlcvIndicator; -impl TechnicalAnalysis for OhlcvFutureConfig { - fn add_indicator(&mut self, kind: TechnicalIndicator) { + fn with_indicator(mut self, kind: Self::BatchIndicator) -> Self { self.indicators.push(kind); + self } } @@ -413,7 +253,7 @@ impl TechnicalAnalysis for OhlcvFutureConfig { /// This trait enables type-safe conversion from user-facing configuration /// (which includes wire protocol details like batch_size) to internal /// domain identifiers used for stream management. -pub trait ConfigId { +pub trait QueryId { /// The unique identifier type for this configuration's data stream. type Id: Copy + PartialEq + Eq + Hash + PartialOrd + Ord + Debug + Send + Sync; @@ -428,7 +268,7 @@ pub trait ConfigId { /// combinations or if required conversions fail. fn to_id(&self) -> ChapatyResult; } -impl ConfigId for OhlcvSpotConfig { +impl QueryId for OhlcvSpotQuery { type Id = OhlcvId; fn to_id(&self) -> ChapatyResult { @@ -446,7 +286,7 @@ impl ConfigId for OhlcvSpotConfig { } } -impl ConfigId for OhlcvFutureConfig { +impl QueryId for OhlcvFutureQuery { type Id = OhlcvId; fn to_id(&self) -> ChapatyResult { @@ -464,7 +304,7 @@ impl ConfigId for OhlcvFutureConfig { } } -impl ConfigId for TradeSpotConfig { +impl QueryId for TradeSpotQuery { type Id = TradesId; fn to_id(&self) -> ChapatyResult { @@ -481,7 +321,7 @@ impl ConfigId for TradeSpotConfig { } } -impl ConfigId for TpoSpotConfig { +impl QueryId for TpoSpotQuery { type Id = TpoId; fn to_id(&self) -> ChapatyResult { @@ -499,7 +339,7 @@ impl ConfigId for TpoSpotConfig { } } -impl ConfigId for TpoFutureConfig { +impl QueryId for TpoFutureQuery { type Id = TpoId; fn to_id(&self) -> ChapatyResult { @@ -517,7 +357,7 @@ impl ConfigId for TpoFutureConfig { } } -impl ConfigId for VolumeProfileSpotConfig { +impl QueryId for VolumeProfileSpotQuery { type Id = VolumeProfileId; fn to_id(&self) -> ChapatyResult { @@ -535,7 +375,7 @@ impl ConfigId for VolumeProfileSpotConfig { } } -impl ConfigId for EconomicCalendarConfig { +impl QueryId for EconomicCalendarQuery { type Id = EconomicCalendarId; fn to_id(&self) -> ChapatyResult { diff --git a/src/data/view.rs b/src/data/view.rs index 4b68c3c..df1ffee 100644 --- a/src/data/view.rs +++ b/src/data/view.rs @@ -31,14 +31,16 @@ pub trait StreamView<'env> { /// Returns the slice of events visible at the current time step. fn get_slice(&self, id: &Self::Id) -> Option<&'env [Self::Event]>; - #[inline] + fn len(&self, id: &Self::Id) -> usize { + self.get_slice(id).map_or(0, |s| s.len()) + } + fn last_event(&self, id: &Self::Id) -> Option<&'env Self::Event> { self.get_slice(id).and_then(|s| s.last()) } /// Returns an iterator over events in reverse chronological order (Newest -> Oldest). /// This is the primary access pattern for RL agents reacting to recent news. - #[inline] fn rev_iter( &self, id: &Self::Id, @@ -48,7 +50,6 @@ pub trait StreamView<'env> { /// Returns all *new* events with point in time after `since_ts`. /// Useful for agents to react to the latest news or ticks. - #[inline] fn new_events_since( &self, id: &Self::Id, @@ -98,7 +99,6 @@ impl<'env, S: StreamId + 'env> StreamView<'env> for View<'env, S> { type Id = S; type Event = S::Event; - #[inline] fn get_slice(&self, id: &S) -> Option<&'env [S::Event]> { self.data.get(id).copied() } @@ -262,13 +262,11 @@ impl<'env> MarketView<'env> { impl<'env> MarketView<'env> { /// Returns a stack-allocated array of all views that support price checking. - #[inline] fn all_price_checkable_views(&self) -> [&dyn PriceCheckableView; 5] { [&self.ohlcv, &self.trades, &self.ema, &self.sma, &self.rsi] } /// Returns a stack-allocated array of all views that provide a canonical market "Close" price. - #[inline] fn close_price_views(&self) -> [&dyn ClosePriceView; 2] { [&self.ohlcv, &self.trades] } diff --git a/src/gym/trading.rs b/src/gym/trading.rs index a71e189..e68f469 100644 --- a/src/gym/trading.rs +++ b/src/gym/trading.rs @@ -1,9 +1,6 @@ use crate::{ error::ChapatyResult, - gym::{ - Reward, StepOutcome, - trading::{action::Actions, observation::Observation}, - }, + gym::{Reward, StepOutcome}, }; pub mod action; @@ -18,7 +15,15 @@ pub mod observation; pub mod state; pub mod types; -pub use factory::{load, make}; +pub use action::*; +pub use action_space::*; +pub use agent::*; +pub use config::*; +pub use env::*; +pub use factory::*; +pub use observation::*; +pub use state::*; +pub use types::*; pub trait Env { fn reset(&mut self) -> ChapatyResult<(Observation<'_>, Reward, StepOutcome)>; diff --git a/src/gym/trading/agent.rs b/src/gym/trading/agent.rs index 4ddc7fb..aa77842 100644 --- a/src/gym/trading/agent.rs +++ b/src/gym/trading/agent.rs @@ -1,6 +1,3 @@ -pub mod crossover; -pub mod news; - use std::sync::Arc; use crate::{ diff --git a/src/gym/trading/agent/crossover.rs b/src/gym/trading/agent/crossover.rs deleted file mode 100644 index c40edae..0000000 --- a/src/gym/trading/agent/crossover.rs +++ /dev/null @@ -1,230 +0,0 @@ -use std::sync::Arc; - -use chrono::{DateTime, Utc}; -use serde::Serialize; - -use crate::{ - data::{ - domain::{Quantity, TradeId}, - event::{IndicatorValueProvider, OhlcvId, SmaId}, - view::StreamView, - }, - error::ChapatyResult, - gym::{ - AgentIdentifier, - trading::{ - action::{Action, Actions, MarketCloseCmd, OpenCmd}, - agent::Agent, - observation::Observation, - types::TradeType, - }, - }, - math::indicator::{StreamingIndicator, StreamingSma}, -}; - -// ================================================================================================ -// Streaming SMA Crossover -// ================================================================================================ -#[derive(Debug, Clone, Serialize)] -pub struct StreamingCrossover { - #[serde(skip)] - ohlcv_id: OhlcvId, - fast_period: u16, - slow_period: u16, - #[serde(skip)] - fast_sma: StreamingSma, - #[serde(skip)] - slow_sma: StreamingSma, - #[serde(skip)] - current_fast: Option, - #[serde(skip)] - current_slow: Option, - #[serde(skip)] - trade_counter: i64, - #[serde(skip)] - last_processed_ts: Option>, -} - -impl StreamingCrossover { - pub fn new(ohlcv_id: OhlcvId, fast_period: u16, slow_period: u16) -> Self { - Self { - ohlcv_id, - fast_sma: StreamingSma::new(fast_period), - slow_sma: StreamingSma::new(slow_period), - fast_period, - slow_period, - trade_counter: 0, - current_fast: None, - current_slow: None, - last_processed_ts: None, - } - } -} - -impl Agent for StreamingCrossover { - fn identifier(&self) -> AgentIdentifier { - AgentIdentifier::Named(Arc::new("StreamingCrossover".to_string())) - } - - fn reset(&mut self) { - self.fast_sma.reset(); - self.slow_sma.reset(); - self.trade_counter = 0; - self.current_fast = None; - self.current_slow = None; - self.last_processed_ts = None; - } - - fn act(&mut self, obs: Observation) -> ChapatyResult { - let market_view = &obs.market_view; - - // 1. Fetch the latest candle - let Some(candle) = market_view.ohlcv().last_event(&self.ohlcv_id) else { - return Ok(Actions::no_op()); - }; - - // 2. Update Internal State (Idempotency check) - // We only push to the SMA buffer if we have moved to a new timestamp. - if self.last_processed_ts != Some(candle.close_timestamp) { - self.current_fast = self.fast_sma.update(candle.close.0); - self.current_slow = self.slow_sma.update(candle.close.0); - self.last_processed_ts = Some(candle.close_timestamp); - } - - // 3. Check Signal Validity - let (Some(fast), Some(slow)) = (self.current_fast, self.current_slow) else { - // SMAs are not warm yet - return Ok(Actions::no_op()); - }; - - // 4. Determine Position Status - let agent_id = self.identifier(); - let active_trade = obs.states.find_active_trade_for_agent(&agent_id); - - // 5. Signal Logic - if fast > slow { - // Golden Cross (Bullish): Fast > Slow - // If we are not already Long, we enter. - if active_trade.is_none() { - self.trade_counter += 1; - - let cmd = OpenCmd { - agent_id, - trade_id: TradeId(self.trade_counter), - trade_type: TradeType::Long, - quantity: Quantity(1.0), - entry_price: None, // Market Order - stop_loss: None, - take_profit: None, - }; - - return Ok(Actions::from((self.ohlcv_id.into(), Action::Open(cmd)))); - } - } else if fast < slow { - // Death Cross (Bearish): Fast < Slow - // If we have an open Long position, close it. - if let Some((_, state)) = active_trade { - let cmd = MarketCloseCmd { - agent_id, - trade_id: state.trade_id(), - quantity: None, // Close Full Position - }; - return Ok(Actions::from(( - self.ohlcv_id.into(), - Action::MarketClose(cmd), - ))); - } - } - - Ok(Actions::no_op()) - } -} - -// ================================================================================================ -// Precomputed SMA Crossover -// ================================================================================================ - -#[derive(Debug, Clone, Serialize)] -pub struct PrecomputedCrossover { - #[serde(skip)] - ohlcv_id: OhlcvId, - fast_sma_id: SmaId, - slow_sma_id: SmaId, - - #[serde(skip)] - trade_counter: i64, -} - -impl PrecomputedCrossover { - pub fn new(ohlcv_id: OhlcvId, fast_sma_id: SmaId, slow_sma_id: SmaId) -> Self { - Self { - ohlcv_id, - fast_sma_id, - slow_sma_id, - trade_counter: 0, - } - } -} - -impl Agent for PrecomputedCrossover { - fn identifier(&self) -> AgentIdentifier { - AgentIdentifier::Named(Arc::new("PrecomputedCrossover".to_string())) - } - - fn reset(&mut self) { - self.trade_counter = 0; - } - - fn act(&mut self, obs: Observation) -> ChapatyResult { - let view = &obs.market_view; - - // 1. Get pre-computed values directly from the environment - let fast_event = view.sma().last_event(&self.fast_sma_id); - let slow_event = view.sma().last_event(&self.slow_sma_id); - - let (Some(fast_evt), Some(slow_evt)) = (fast_event, slow_event) else { - return Ok(Actions::no_op()); - }; - - let fast = fast_evt.value(); - let slow = slow_evt.value(); - - // 2. Position Management - let agent_id = self.identifier(); - let active_trade = obs.states.find_active_trade_for_agent(&agent_id); - - // 3. Signal Logic - if fast > slow { - // Buy Signal - if active_trade.is_none() { - self.trade_counter += 1; - - let cmd = OpenCmd { - agent_id, - trade_id: TradeId(self.trade_counter), - trade_type: TradeType::Long, - quantity: Quantity(1.0), - entry_price: None, - stop_loss: None, - take_profit: None, - }; - return Ok(Actions::from((self.ohlcv_id.into(), Action::Open(cmd)))); - } - } else if fast < slow { - // Sell Signal - if let Some((_, state)) = active_trade { - let cmd = MarketCloseCmd { - agent_id, - trade_id: state.trade_id(), - quantity: None, - }; - return Ok(Actions::from(( - self.ohlcv_id.into(), - Action::MarketClose(cmd), - ))); - } - } - - Ok(Actions::no_op()) - } -} diff --git a/src/gym/trading/agent/news.rs b/src/gym/trading/agent/news.rs deleted file mode 100644 index d542edd..0000000 --- a/src/gym/trading/agent/news.rs +++ /dev/null @@ -1,21 +0,0 @@ -use chrono::{DateTime, Utc}; - -use crate::data::event::Ohlcv; - -pub mod breakout; -pub mod fade; -pub mod hybrid; - -#[derive(Debug, Copy, Clone, Default)] -enum NewsPhase { - /// The agent is waiting for a news event to occur. - #[default] - AwaitingNews, - - /// A news event has been observed. The agent is now waiting for the - /// `wait_duration` to elapse before entering a trade. - PostNews { - news_time: DateTime, - news_candle: Option, - }, -} diff --git a/src/gym/trading/agent/news/breakout.rs b/src/gym/trading/agent/news/breakout.rs deleted file mode 100644 index ec22922..0000000 --- a/src/gym/trading/agent/news/breakout.rs +++ /dev/null @@ -1,487 +0,0 @@ -use std::sync::Arc; - -use chrono::Duration; -use itertools::iproduct; -use serde::Serialize; -use serde_with::{DurationSeconds, serde_as}; - -use crate::{ - data::{ - domain::{CandleDirection, Price, Quantity, TradeId}, - event::{EconomicCalendarId, MarketId, Ohlcv, OhlcvId}, - view::StreamView, - }, - error::{AgentError, ChapatyResult}, - gym::{ - AgentIdentifier, GridAxis, - trading::{ - action::{Action, Actions, OpenCmd}, - agent::{Agent, news::NewsPhase}, - observation::Observation, - types::TradeType, - }, - }, -}; - -#[serde_as] -#[derive(Debug, Clone, Copy, Serialize)] -pub struct NewsBreakout { - #[serde(skip)] - economic_cal_id: EconomicCalendarId, - #[serde(skip)] - ohlcv_id: OhlcvId, - /// Earliest allowed entry time after the news event. - #[serde_as(as = "DurationSeconds")] - earliest_entry: Duration, - /// Latest allowed entry time after the news event. - #[serde_as(as = "DurationSeconds")] - latest_entry: Duration, - - /// A factor that defines the portion of the news candle's body to risk before a stop-loss is triggered. - /// - /// The calculation starts from the news candle's **close price** and moves towards - /// (or beyond) its **open price**. A higher value means a wider stop-loss and more risk. - /// - /// - **`-0.5`**: Places the stop-loss **beyond the close price**, creating an extra safety margin equal to **50%** of the candle's body size. - /// - **`0.0`**: Places the stop-loss at the **close price**. This risks **0%** of the candle body. - /// - **`0.5`**: Places the stop-loss at the **midpoint** of the body. This risks **50%** of the body. - /// - **`1.0`**: Places the stop-loss at the **open price**. This risks **100%** of the candle body. - /// - **`1.5`**: Places the stop-loss **beyond the open price**, creating an extra risk margin equal to **50%** of the candle's body size. - /// - /// # Formulas - /// Let `body_size = |news_open - news_close|`. - /// - For **Long** trades: `StopLoss = news_close - body_size * stop_loss_risk_factor` - /// - For **Short** trades: `StopLoss = news_close + body_size * stop_loss_risk_factor` - /// - /// # Long Trade Example (Bullish News Candle: Open=100, Close=110, Body=10) - /// - `stop_loss_risk_factor = 0.0` -> SL is `110 - 10 * 0.0 = 110` - /// - `stop_loss_risk_factor = 1.0` -> SL is `110 - 10 * 1.0 = 100` - /// - `stop_loss_risk_factor = -0.2` -> SL is `110 - 10 * -0.2 = 112` (Extra safety margin) - /// - /// # Short Trade Example (Bearish News Candle: Open=100, Close=90, Body=10) - /// - `stop_loss_risk_factor = 0.0` -> SL is `90 + 10 * 0.0 = 90` - /// - `stop_loss_risk_factor = 1.0` -> SL is `90 + 10 * 1.0 = 100` - /// - `stop_loss_risk_factor = -0.2` -> SL is `90 + 10 * -0.2 = 88` (Extra safety margin) - stop_loss_risk_factor: f64, - - /// Risk-Reward Ratio (RRR) for the strategy. - /// - /// The Risk-Reward Ratio defines the relationship between the potential **loss** (risk) and - /// the potential **gain** (reward) of a trade. It is used to calculate the **take-profit** - /// level given a known entry price and stop-loss price. - /// - /// # Formula - /// ```text - /// RRR = |risk| / |reward| - /// - /// where: - /// - risk = |entry_price - stop_loss_price| - /// - reward = |take_profit_price - entry_price| - /// ``` - /// - /// # Interpretation - /// - `risk_reward_ratio > 1.0` -> risking more than potential reward (caution) - /// - `risk_reward_ratio = 1.0` -> risk equals reward - /// - `risk_reward_ratio < 1.0` -> potential reward exceeds risk (favorable) - /// - /// # Valid Values - /// Must be strictly positive (`risk_reward_ratio > 0.0`), otherwise the trade setup is invalid. - /// - /// # Take-Profit Calculation - /// Given a trade entry and stop-loss (from `stop_loss_risk_factor`): - /// - /// - **Long Trade** (bullish entry): - /// ```text - /// take_profit = entry_price + (entry_price - stop_loss_price) / risk_reward_ratio - /// ``` - /// - **Short Trade** (bearish entry): - /// ```text - /// take_profit = entry_price - (stop_loss_price - entry_price) / risk_reward_ratio - /// ``` - /// - /// # Examples - /// Long trade: entry = 100, stop-loss = 95, risk_reward_ratio = 0.5 - /// ```text - /// take_profit = 100 + (100 - 95) / 0.5 = 110 - /// ``` - /// - /// Short trade: entry = 100, stop-loss = 105, risk_reward_ratio = 2.0 - /// ```text - /// take_profit = 100 - (100 - 105) / 2.0 = 97.5 - /// ``` - risk_reward_ratio: f64, - - // === Internal only === - #[serde(skip)] - phase: NewsPhase, - - #[serde(skip)] - trade_counter: i64, -} - -impl NewsBreakout { - pub fn baseline(economic_cal_id: EconomicCalendarId, ohlcv_id: OhlcvId) -> Self { - Self { - economic_cal_id, - ohlcv_id, - earliest_entry: Duration::seconds(480), - latest_entry: Duration::seconds(3000), - stop_loss_risk_factor: 0.89, - risk_reward_ratio: 0.726, - phase: NewsPhase::default(), - trade_counter: 0, - } - } - - pub fn economic_calendar_id(&self) -> EconomicCalendarId { - self.economic_cal_id - } - - pub fn ohlcv_id(&self) -> OhlcvId { - self.ohlcv_id - } - - pub fn earliest_entry(&self) -> Duration { - self.earliest_entry - } - - pub fn latest_entry(&self) -> Duration { - self.latest_entry - } - - pub fn stop_loss_risk_factor(&self) -> f64 { - self.stop_loss_risk_factor - } - - pub fn risk_reward_ratio(&self) -> f64 { - self.risk_reward_ratio - } - - pub fn with_calendar_id(self, economic_cal_id: EconomicCalendarId) -> Self { - Self { - economic_cal_id, - ..self - } - } - - pub fn with_ohlcv_id(self, ohlcv_id: OhlcvId) -> Self { - Self { ohlcv_id, ..self } - } - - pub fn with_earliest_entry_candle(self, duration: Duration) -> Self { - Self { - earliest_entry: duration, - ..self - } - } - - pub fn with_latest_entry_candle(self, duration: Duration) -> Self { - Self { - latest_entry: duration, - ..self - } - } - - pub fn with_stop_loss_risk_factor(self, factor: f64) -> Self { - Self { - stop_loss_risk_factor: factor, - ..self - } - } - - pub fn with_risk_reward_ratio(self, ratio: f64) -> ChapatyResult { - if ratio <= 0.0 { - return Err( - AgentError::InvalidInput("risk_reward_ratio must be > 0.0".to_string()).into(), - ); - } - Ok(Self { - risk_reward_ratio: ratio, - ..self - }) - } -} - -impl NewsBreakout { - /// Computes the **stop-loss target** for following a news candle. - /// - /// This includes both the stop-loss **price** and the **trade type**, - /// because the trade direction (Long vs Short) is determined by the - /// candle’s direction: - /// - /// - **Bearish candle** -> follow downward -> `Short` - /// - **Bullish candle** -> follow upward -> `Long` - /// - /// Returns `None` if the candle has no clear direction (e.g., a doji). - fn stop_loss_target(&self, news_candle: &Ohlcv) -> Option { - let open = news_candle.open.0; - let close = news_candle.close.0; - let body_size = (open - close).abs(); - - match news_candle.direction() { - CandleDirection::Bearish => { - let price = close + body_size * self.stop_loss_risk_factor; - Some(StopLossTarget { - stop_loss_price: Price(price), - trade_type: TradeType::Short, - }) - } - CandleDirection::Bullish => { - let price = close - body_size * self.stop_loss_risk_factor; - Some(StopLossTarget { - stop_loss_price: Price(price), - trade_type: TradeType::Long, - }) - } - CandleDirection::Doji => None, - } - } -} -impl Agent for NewsBreakout { - fn act(&mut self, obs: Observation) -> ChapatyResult { - let economic_cal_id = self.economic_cal_id; - let ohlcv_id = self.ohlcv_id; - - let current_time = obs.market_view.current_timestamp(); - - // === Early return: skip if already in trade === - if obs.states.any_active_trade_for_agent(&self.identifier()) { - return Ok(Actions::no_op()); - } - - // === 1. Update phase === - if let NewsPhase::AwaitingNews = self.phase - && let Some(news_event) = obs.market_view.economic_news().last_event(&economic_cal_id) - { - let news_candle_candidate = obs - .market_view - .ohlcv() - .last_event(&ohlcv_id) - .filter(|candle| candle.open_timestamp == news_event.timestamp) - .copied(); - self.phase = NewsPhase::PostNews { - news_time: news_event.timestamp, - news_candle: news_candle_candidate, - }; - } - - // === 2. Decide action == - let (news_time, candle) = if let NewsPhase::PostNews { - news_time, - news_candle: Some(candle), - } = self.phase - { - (news_time, candle) - } else { - self.phase = NewsPhase::AwaitingNews; - return Ok(Actions::no_op()); - }; - - let time_since_news = current_time - news_time; - - if time_since_news < self.earliest_entry { - return Ok(Actions::no_op()); - } - if time_since_news > self.latest_entry { - self.phase = NewsPhase::AwaitingNews; - return Ok(Actions::no_op()); - } - - // 1. Get Current Price (for Breakout Check & Math) - let entry_price = obs.market_view.try_resolved_close_price(&ohlcv_id.symbol)?; - - let breakout_up = entry_price.0 > candle.high.0; - let breakout_down = entry_price.0 < candle.low.0; - if !breakout_up && !breakout_down { - return Ok(Actions::no_op()); // no breakout - } - - let sl_target = match self.stop_loss_target(&candle) { - Some(tp) => tp, - None => { - self.phase = NewsPhase::AwaitingNews; - return Ok(Actions::no_op()); - } - }; - - // 2. Generate Unique ID - self.trade_counter += 1; - let trade_id = TradeId(self.trade_counter); - - // 3. Define Quantity - let quantity = Quantity(1.0); - - // 4. Construct Command (Struct Init) - let cmd = OpenCmd { - agent_id: self.identifier(), - trade_id, - trade_type: sl_target.trade_type, - quantity, - - // EXECUTION: Market Order (None) - // A breakout strategy must enter immediately. Waiting for a limit - // at the breakout level might miss the momentum. - entry_price: None, - - // MATH: We use the calculated targets - stop_loss: Some(sl_target.stop_loss_price), - // Note: entry_price is passed here purely for the math calculation - take_profit: Some(sl_target.take_profit_price(entry_price, self.risk_reward_ratio)), - }; - - self.phase = NewsPhase::AwaitingNews; - - let market_id: MarketId = ohlcv_id.into(); - Ok(Actions::from((market_id, Action::Open(cmd)))) - } - - fn identifier(&self) -> AgentIdentifier { - AgentIdentifier::Named(Arc::new("NewsBreakout".to_string())) - } - - fn reset(&mut self) { - self.phase = NewsPhase::AwaitingNews; - self.trade_counter = 0; - } -} - -// ================================================================================================ -// Helper Structs -// ================================================================================================ - -/// Result of a stop-loss calculation. -/// -/// Includes both the target price and the trade direction, -/// since the direction is implied by the news candle. -struct StopLossTarget { - stop_loss_price: Price, - trade_type: TradeType, -} - -impl StopLossTarget { - /// Computes the take-profit price for this stop-loss target, given the - /// trade entry price and a risk-reward ratio (RRR). - /// - /// # Formula - /// - **Long Trade**: - /// ```text - /// take_profit = entry_price + (entry_price - stop_loss_price) / risk_reward_ratio - /// ``` - /// - /// - **Short Trade**: - /// ```text - /// take_profit = entry_price - (stop_loss_price - entry_price) / risk_reward_ratio - /// ``` - fn take_profit_price(&self, entry_price: Price, risk_reward_ratio: f64) -> Price { - let sl = self.stop_loss_price.0; - let entry = entry_price.0; - - let sl = match self.trade_type { - TradeType::Long => entry + (entry - sl) / risk_reward_ratio, - TradeType::Short => entry - (sl - entry) / risk_reward_ratio, - }; - - Price(sl) - } -} - -// ================================================================================================ -// Grid Generator -// ================================================================================================ - -pub struct NewsBreakoutGrid { - cal_id: EconomicCalendarId, - market_id: OhlcvId, - earliest_entry: (Duration, Duration), - latest_entry: (Duration, Duration), - stop_loss_risk_factor: GridAxis, - risk_reward_ratio: GridAxis, -} - -impl NewsBreakoutGrid { - /// Creates a grid generator with a default "Baseline" search space. - pub fn baseline(cal_id: EconomicCalendarId, market_id: OhlcvId) -> ChapatyResult { - Ok(Self { - cal_id, - market_id, - earliest_entry: (Duration::minutes(1), Duration::minutes(6)), - latest_entry: (Duration::minutes(20), Duration::minutes(28)), - stop_loss_risk_factor: GridAxis::new("0.5", "1.5", "0.01")?, - risk_reward_ratio: GridAxis::new("0.1", "2.6", "0.01")?, - }) - } - - /// Overrides the range of earliest entry times. Range is `[start, end)`. - pub fn with_earliest_entry_range(self, start: Duration, end: Duration) -> Self { - Self { - earliest_entry: (start, end), - ..self - } - } - - /// Overrides the range of latest entry times. Range is `[start, end)`. - pub fn with_latest_entry_range(self, start: Duration, end: Duration) -> Self { - Self { - latest_entry: (start, end), - ..self - } - } - - /// Overrides the stop-loss risk factor range. Range is `[start, end)`. - pub fn with_stop_loss_risk_factor(self, axis: GridAxis) -> Self { - Self { - stop_loss_risk_factor: axis, - ..self - } - } - - /// Overrides the risk reward ratio range. Range is `[start, end)`. - pub fn with_risk_reward_ratio(self, axis: GridAxis) -> Self { - Self { - risk_reward_ratio: axis, - ..self - } - } - - pub fn build(self) -> Vec<(usize, NewsBreakout)> { - let (start_earliest, end_earliest) = self.earliest_entry; - let (start_latest, end_latest) = self.latest_entry; - - // === Generate Axes === - let stop_loss_risk_factors = self.stop_loss_risk_factor.generate(); - let risk_reward_ratios = self.risk_reward_ratio.generate(); - - let earliest_entries = (start_earliest.num_minutes()..end_earliest.num_minutes()) - .map(Duration::minutes) - .collect::>(); - - let latest_entries = (start_latest.num_minutes()..end_latest.num_minutes()) - .map(Duration::minutes) - .collect::>(); - - // === Eagerly Collect Valid Args === - let cal_id = self.cal_id; - let market_id = self.market_id; - - iproduct!( - risk_reward_ratios, - stop_loss_risk_factors, - latest_entries, - earliest_entries - ) - .filter(|(_, _, latest, earliest)| earliest < latest) - .enumerate() - .map(|(uid, (rrr, slrf, latest, earliest))| { - ( - uid, - NewsBreakout::baseline(cal_id, market_id) - .with_earliest_entry_candle(earliest) - .with_latest_entry_candle(latest) - .with_stop_loss_risk_factor(slrf) - .with_risk_reward_ratio(rrr) - .expect("Valid grid parameters"), - ) - }) - .collect::>() - } -} diff --git a/src/gym/trading/agent/news/fade.rs b/src/gym/trading/agent/news/fade.rs deleted file mode 100644 index c5b305d..0000000 --- a/src/gym/trading/agent/news/fade.rs +++ /dev/null @@ -1,465 +0,0 @@ -use std::sync::Arc; - -use chrono::{DateTime, Duration, Utc}; - -use itertools::iproduct; -use serde::Serialize; -use serde_with::{DurationSeconds, serde_as}; - -use crate::{ - data::{ - domain::{CandleDirection, Price, Quantity, TradeId}, - event::{EconomicCalendarId, Ohlcv, OhlcvId}, - view::StreamView, - }, - error::{AgentError, ChapatyResult}, - gym::{ - AgentIdentifier, GridAxis, - trading::{ - action::{Action, Actions, OpenCmd}, - agent::{Agent, news::NewsPhase}, - observation::Observation, - types::TradeType, - }, - }, -}; - -#[serde_as] -#[derive(Debug, Clone, Copy, Serialize)] -pub struct NewsFade { - #[serde(skip)] - economic_cal_id: EconomicCalendarId, - #[serde(skip)] - ohlcv_id: OhlcvId, - /// Duration to wait after the news release before entering a trade. - /// - /// The entry price is taken from the **first close observed - /// after this duration has elapsed** since the news timestamp. - /// - /// # Examples - /// - /// - `wait_duration = Duration::zero()`: enter immediately on the news candle. - /// - `wait_duration = Duration::seconds(60)`: enter 1 minute after news. - /// - `wait_duration = Duration::minutes(5)`: enter 5 minutes after news. - #[serde_as(as = "DurationSeconds")] - wait_duration: Duration, - - /// A factor that defines the portion of the news candle's body to capture before a take-profit is triggered. - /// - /// The calculation starts from the news candle's **close price** and moves towards - /// (or beyond) its **open price**. A higher value means aiming for a larger reversal move - /// (wider take-profit target). - /// - /// - **`-0.5`**: Take-profit is set **past the close price**, i.e. on the wrong side of the reversal. - /// - **`0.0`**: Take-profit at the **close price** — reversal ends exactly at the close. - /// - **`0.5`**: Take-profit at the **midpoint** of the candle body (captures 50% of the body). - /// - **`1.0`**: Take-profit at the **open price** — a full reversal of the news candle. - /// - **`1.5`**: Take-profit **beyond the open price**, anticipating an overshoot beyond full reversal. - /// - /// # Formulas - /// Let `body_size = |news_open - news_close|`. - /// - For **Long** trades (fading a bearish candle): `TakeProfit = news_close + body_size * take_profit_risk_factor` - /// - For **Short** trades (fading a bullish candle): `TakeProfit = news_close - body_size * take_profit_risk_factor` - /// - /// # Long Trade Example (Bearish News Candle: Open=100, Close=90, Body=10) - /// - `take_profit_risk_factor = 0.0`: TP = 90 - /// - `take_profit_risk_factor = 1.0`: TP = 100 - /// - `take_profit_risk_factor = 1.5`: TP = 105 - /// - /// # Short Trade Example (Bullish News Candle: Open=100, Close=110, Body=10) - /// - `take_profit_risk_factor = 0.0`: TP = 110 - /// - `take_profit_risk_factor = 1.0`: TP = 100 - /// - `take_profit_risk_factor = 1.5`: TP = 95 - take_profit_risk_factor: f64, - - /// Risk-Reward Ratio (RRR) for the strategy. - /// - /// The Risk-Reward Ratio defines the relationship between the potential **loss** (risk) and - /// the potential **gain** (reward) of a trade. It is used to calculate the **stop-loss** - /// level given a known entry price and take-profit price. - /// - /// # Formula - /// ```text - /// RRR = |risk| / |reward| - /// - /// where: - /// - risk = |entry_price - stop_loss_price| - /// - reward = |take_profit_price - entry_price| - /// ``` - /// - /// # Interpretation - /// - `risk_reward_ratio > 1.0` -> risking more than potential reward (caution) - /// - `risk_reward_ratio = 1.0` -> risk equals reward - /// - `risk_reward_ratio < 1.0` -> potential reward exceeds risk (favorable) - /// - /// # Valid Values - /// Must be strictly positive (`risk_reward_ratio > 0.0`), otherwise the trade setup is invalid. - /// - /// # Stop-Loss Calculation - /// Given a trade entry and take-profit (from `take_profit_risk_factor`): - /// - /// - **Long Trade** (fading a bearish candle): - /// ```text - /// stop_loss = entry_price - (take_profit_price - entry_price) * risk_reward_ratio - /// ``` - /// - /// - **Short Trade** (fading a bullish candle): - /// ```text - /// stop_loss = entry_price + (entry_price - take_profit_price) * risk_reward_ratio - /// ``` - /// - /// # Examples - /// Long trade: entry = 100, take-profit = 110, RRR = 2.0 - /// ```text - /// stop_loss = 100 - (110 - 100) * 2.0 = 80 - /// ``` - /// - /// Short trade: entry = 100, take-profit = 90, RRR = 0.5 - /// ```text - /// stop_loss = 100 + (100 - 90) * 0.5 = 105 - /// ``` - risk_reward_ratio: f64, - - // === Internal only === - #[serde(skip)] - phase: NewsPhase, - - #[serde(skip)] - trade_counter: i64, - - ///Track the last news we already handled to prevent re-entry - #[serde(skip)] - last_processed_news: Option>, -} - -impl NewsFade { - pub fn baseline(economic_cal_id: EconomicCalendarId, ohlcv_id: OhlcvId) -> Self { - Self { - economic_cal_id, - ohlcv_id, - wait_duration: Duration::seconds(420), - take_profit_risk_factor: 1.27, - risk_reward_ratio: 0.276, - phase: NewsPhase::default(), - trade_counter: 0, - last_processed_news: None, - } - } - - pub fn economic_calendar_id(&self) -> EconomicCalendarId { - self.economic_cal_id - } - - pub fn ohlcv_id(&self) -> OhlcvId { - self.ohlcv_id - } - - pub fn wait_duration(&self) -> Duration { - self.wait_duration - } - - pub fn take_profit_risk_factor(&self) -> f64 { - self.take_profit_risk_factor - } - - pub fn risk_reward_ratio(&self) -> f64 { - self.risk_reward_ratio - } - - pub fn with_calendar_id(self, economic_cal_id: EconomicCalendarId) -> Self { - Self { - economic_cal_id, - ..self - } - } - - pub fn with_ohlcv_id(self, ohlcv_id: OhlcvId) -> Self { - Self { ohlcv_id, ..self } - } - - pub fn with_candles_after_news(self, duration: Duration) -> Self { - Self { - wait_duration: duration, - ..self - } - } - - pub fn with_take_profit_risk_factor(self, factor: f64) -> Self { - Self { - take_profit_risk_factor: factor, - ..self - } - } - - pub fn with_risk_reward_ratio(self, ratio: f64) -> ChapatyResult { - if ratio <= 0.0 { - return Err( - AgentError::InvalidInput("risk_reward_ratio must be > 0.0".to_string()).into(), - ); - } - Ok(Self { - risk_reward_ratio: ratio, - ..self - }) - } -} - -impl NewsFade { - /// Computes the **take-profit target** for fading a news candle. - /// - /// This includes both the take-profit **price** and the **trade type**, - /// because the trade direction (Long vs Short) is determined by the - /// candle’s direction: - /// - /// - **Bearish candle** -> fade upward -> `Long` - /// - **Bullish candle** -> fade downward -> `Short` - /// - /// Returns `None` if the candle has no clear direction (e.g., a doji). - fn take_profit_target(&self, news_candle: &Ohlcv) -> Option { - let open = news_candle.open.0; - let close = news_candle.close.0; - let body_size = (open - close).abs(); - - match news_candle.direction() { - CandleDirection::Bearish => { - let price = close + body_size * self.take_profit_risk_factor; - Some(TakeProfitTarget { - take_profit_price: Price(price), - trade_type: TradeType::Long, - }) - } - CandleDirection::Bullish => { - let price = close - body_size * self.take_profit_risk_factor; - Some(TakeProfitTarget { - take_profit_price: Price(price), - trade_type: TradeType::Short, - }) - } - CandleDirection::Doji => None, - } - } -} - -impl Agent for NewsFade { - fn act(&mut self, obs: Observation) -> ChapatyResult { - let current_time = obs.market_view.current_timestamp(); - - // === Early return: skip if already in trade === - if obs.states.any_active_trade_for_agent(&self.identifier()) { - return Ok(Actions::no_op()); - } - - // === 1. Update Phase === - if let NewsPhase::AwaitingNews = self.phase - && let Some(news_event) = obs - .market_view - .economic_news() - .last_event(&self.economic_cal_id) - { - // Check if we already processed this specific event - if Some(news_event.timestamp) == self.last_processed_news { - return Ok(Actions::no_op()); - } - - let news_candle = obs - .market_view - .ohlcv() - .last_event(&self.ohlcv_id) - .filter(|candle| candle.open_timestamp == news_event.timestamp) - .copied(); - - self.phase = NewsPhase::PostNews { - news_time: news_event.timestamp, - news_candle, - }; - } - - // === 2. Decision Phase === - let (news_time, candle) = if let NewsPhase::PostNews { - news_time, - news_candle: Some(candle), - } = self.phase - { - (news_time, candle) - } else { - // Candle not found yet? - // We simply revert to Awaiting to retry the fetch in step 1. - self.phase = NewsPhase::AwaitingNews; - return Ok(Actions::no_op()); - }; - - // Check Wait Duration - if current_time < news_time + self.wait_duration { - return Ok(Actions::no_op()); - } - - // === 3. Execution Phase === - let tp_target = match self.take_profit_target(&candle) { - Some(tp) => tp, - None => { - // Invalid candle (Doji) -> Mark news as processed so we don't retry forever - self.last_processed_news = Some(news_time); - self.phase = NewsPhase::AwaitingNews; - return Ok(Actions::no_op()); - } - }; - - self.trade_counter += 1; - let trade_id = TradeId(self.trade_counter); - let quantity = Quantity(1.0); - let estimated_entry = obs - .market_view - .try_resolved_close_price(&self.ohlcv_id.symbol)?; - - let cmd = OpenCmd { - agent_id: self.identifier(), - trade_id, - trade_type: tp_target.trade_type, - quantity, - entry_price: None, - stop_loss: Some(tp_target.stop_loss_price(estimated_entry, self.risk_reward_ratio)), - take_profit: Some(tp_target.take_profit_price), - }; - - // Mark this news event as processed - self.last_processed_news = Some(news_time); - self.phase = NewsPhase::AwaitingNews; - - Ok(Actions::from((self.ohlcv_id.into(), Action::Open(cmd)))) - } - - fn identifier(&self) -> AgentIdentifier { - AgentIdentifier::Named(Arc::new("NewsFade".to_string())) - } - - fn reset(&mut self) { - self.phase = NewsPhase::AwaitingNews; - self.trade_counter = 0; - self.last_processed_news = None; - } -} - -// ================================================================================================ -// Helper Structs -// ================================================================================================ - -/// Result of a take-profit calculation. -/// -/// Includes both the target price and the trade direction, -/// since the direction is implied by the news candle. -struct TakeProfitTarget { - take_profit_price: Price, - trade_type: TradeType, -} - -impl TakeProfitTarget { - /// Computes the stop-loss price for this take-profit target, given the - /// trade entry price and a risk-reward ratio (RRR). - /// - /// # Formula - /// - **Long Trade**: - /// ```text - /// stop_loss = entry_price - (take_profit_price - entry_price) * risk_reward_ratio - /// ``` - /// - /// - **Short Trade**: - /// ```text - /// stop_loss = entry_price + (entry_price - take_profit_price) * risk_reward_ratio - /// ``` - fn stop_loss_price(&self, entry_price: Price, risk_reward_ratio: f64) -> Price { - let tp = self.take_profit_price.0; - let entry = entry_price.0; - - let sl = match self.trade_type { - TradeType::Long => entry - (tp - entry) * risk_reward_ratio, - TradeType::Short => entry + (entry - tp) * risk_reward_ratio, - }; - - Price(sl) - } -} - -// ================================================================================================ -// Grid Generator -// ================================================================================================ - -pub struct NewsFadeGrid { - cal_id: EconomicCalendarId, - ohlcv_id: OhlcvId, - wait_duration: (Duration, Duration), - tp_risk_factor: GridAxis, - risk_reward: GridAxis, -} - -impl NewsFadeGrid { - /// Creates a grid generator with a default "Baseline" search space. - /// - /// This pre-populates the ranges with standard values, ensuring the grid - /// is valid immediately. - pub fn baseline(cal_id: EconomicCalendarId, ohlcv_id: OhlcvId) -> ChapatyResult { - Ok(Self { - cal_id, - ohlcv_id, - wait_duration: (Duration::minutes(5), Duration::minutes(30)), - tp_risk_factor: GridAxis::new("0.5", "3.0", "0.01")?, - risk_reward: GridAxis::new("0.1", "1.0", "0.01")?, - }) - } - - /// Overrides the range of candles to consider after a news event. - /// Range is `[start, end)`. - pub fn with_candles_after_news(self, start: Duration, end: Duration) -> Self { - Self { - wait_duration: (start, end), - ..self - } - } - - /// Overrides the take-profit risk factor parameter range. - /// Range is `[start, end)`. - pub fn with_take_profit_risk_factor(self, axis: GridAxis) -> Self { - Self { - tp_risk_factor: axis, - ..self - } - } - - /// Overrides the risk reward ratio parameter range. - /// Range is `[start, end)`. - pub fn with_risk_reward_ratio(self, axis: GridAxis) -> Self { - Self { - risk_reward: axis, - ..self - } - } - - pub fn build(self) -> Vec<(usize, NewsFade)> { - let (start_wait, end_wait) = self.wait_duration; - - // === 1. Generate Axes === - let candles_after_news = (start_wait.num_minutes()..end_wait.num_minutes()) - .map(Duration::minutes) - .collect::>(); - - let take_profit_factors = self.tp_risk_factor.generate(); - let risk_rewards = self.risk_reward.generate(); - - // === 2. Eagerly Collect Valid Args === - let cal_id = self.cal_id; - let ohlcv_id = self.ohlcv_id; - - iproduct!(risk_rewards, candles_after_news, take_profit_factors) - .enumerate() - .map(|(uid, (rrr, wait, tprf))| { - ( - uid, - NewsFade::baseline(cal_id, ohlcv_id) - .with_candles_after_news(wait) - .with_take_profit_risk_factor(tprf) - .with_risk_reward_ratio(rrr) - .expect("Valid grid parameters"), - ) - }) - .collect::>() - } -} diff --git a/src/gym/trading/agent/news/hybrid.rs b/src/gym/trading/agent/news/hybrid.rs deleted file mode 100644 index 2898ec0..0000000 --- a/src/gym/trading/agent/news/hybrid.rs +++ /dev/null @@ -1,173 +0,0 @@ -use std::sync::Arc; - -use itertools::iproduct; -use serde::Serialize; - -use crate::{ - error::ChapatyResult, - gym::{ - AgentIdentifier, - trading::{ - action::{Action, Actions, MarketCloseCmd}, - agent::{ - Agent, - news::{ - breakout::{NewsBreakout, NewsBreakoutGrid}, - fade::{NewsFade, NewsFadeGrid}, - }, - }, - observation::Observation, - }, - }, -}; - -/// A decision agent that coordinates between [`NewsFade`] and -/// [`NewsBreakout`] strategies. -/// -/// This agent implements a **priority policy** for handling overlapping signals: -/// -/// # Policy -/// - **Breakout-first (or simultaneous):** -/// If [`NewsBreakout`] produces an entry signal before (or at the same -/// step as) [`NewsFade`], the breakout signal is executed and the fade -/// signal is ignored. -/// - **Fade-first, then Breakout:** -/// If [`NewsFade`] produces a signal first, the fade trade is opened. -/// If a breakout signal occurs afterwards, the fade trade is closed and replaced -/// with the breakout trade (“pivot”). -/// - **Fade-only:** -/// If only [`NewsFade`] signals, its trade is executed and maintained. -/// - **Breakout-only:** -/// If only [`NewsBreakout`] signals, its trade is executed. -/// - **Otherwise:** -/// The agent performs [`Actions::no_op`]. -/// -/// # Motivation -/// The policy reflects the assumption that a breakout move carries stronger -/// informational value than a mean-reversion fade. Breakout signals therefore -/// dominate whenever they appear, even retroactively displacing an open fade trade. -/// -/// # Example Timeline -/// ```text -/// t0: Fade signals -> enter Fade trade -/// t1: Breakout signals -> close Fade, enter Breakout -/// t2: No new signals -> hold Breakout -/// ``` -/// -/// # Design Notes & Limitations -/// -/// This agent's pivot logic is designed for scenarios where news events -/// are distinct and don't result in overlapping, long-lived trades. -/// -/// When using `EpisodeLength::Infinite`, it's possible for a trade -/// from a much earlier news event to remain open. The current implementation -/// would incorrectly close this old trade if a new breakout signal appears, -/// as it doesn't correlate signals to specific trades. -/// -/// For typical backtesting with finite episode lengths (e.g., daily, weekly, -/// or monthly resets), this is not an issue. -/// -/// See also: [`NewsFade`], [`NewsBreakout`]. -#[derive(Debug, Clone, Copy, Serialize)] -pub struct NewsHybrid { - pub breakout: NewsBreakout, - pub fade: NewsFade, -} - -impl Agent for NewsHybrid { - fn act(&mut self, obs: Observation) -> ChapatyResult { - // 1. Get Proposals (Ask both sub-agents) - // We clone 'obs' because the sub-agents need their own view - let fade_actions = self.fade.act(obs.clone())?; - let breakout_actions = self.breakout.act(obs.clone())?; - - let any_breakout_signal = - breakout_actions.any_open_action(&self.breakout.ohlcv_id().into()); - let any_fade_signal = fade_actions.any_open_action(&self.fade.ohlcv_id().into()); - - // === PRIORITY 1: Breakout Signal === - if any_breakout_signal { - // "Pivot" Logic: - // If the FADE agent is currently in a trade, we must close it - // to make room for the Breakout trade. - let fade_agent_id = self.fade.identifier(); - - if let Some((market_id, state)) = obs.states.find_active_trade_for_agent(&fade_agent_id) - { - // Construct Close Command - let close_cmd = MarketCloseCmd { - agent_id: fade_agent_id, - trade_id: state.trade_id(), - // Use the helper we defined earlier (State::quantity) - quantity: Some(state.quantity()), - }; - - return Ok(breakout_actions.with_action(market_id, Action::MarketClose(close_cmd))); - } else { - // No conflict, just execute breakout - return Ok(breakout_actions); - } - } - - // === PRIORITY 2: Fade Signal === - if any_fade_signal { - // Dominance Check: - // If the BREAKOUT agent is already in a trade, ignore the fade signal. - // Breakout trades are "stronger" and shouldn't be interrupted by a fade. - let breakout_id = self.breakout.identifier(); - - if obs - .states - .find_active_trade_for_agent(&breakout_id) - .is_some() - { - // Breakout dominates. Ignore Fade signal. - return Ok(Actions::no_op()); - } else { - // No conflict, execute fade - return Ok(fade_actions); - } - } - - // === Default === - Ok(Actions::no_op()) - } - - fn identifier(&self) -> AgentIdentifier { - AgentIdentifier::Named(Arc::new("NewsHybrid".to_string())) - } - - fn reset(&mut self) { - self.breakout.reset(); - self.fade.reset(); - } -} - -// ================================================================================================ -// Builder for `pub struct AdaptiveNewsAgent` Agent -// ================================================================================================ - -pub struct NewsHybridGrid { - pub fade: NewsFadeGrid, - pub breakout: NewsBreakoutGrid, -} - -impl NewsHybridGrid { - pub fn build(self) -> Vec<(usize, NewsHybrid)> { - let breakout_agents = self.breakout.build(); - let fade_agents = self.fade.build(); - - iproduct!(breakout_agents, fade_agents) - .enumerate() - .map(|(uid, (breakout, fade))| { - ( - uid, - NewsHybrid { - breakout: breakout.1, - fade: fade.1, - }, - ) - }) - .collect() - } -} diff --git a/src/gym/trading/config.rs b/src/gym/trading/config.rs index b5b8453..87f6197 100644 --- a/src/gym/trading/config.rs +++ b/src/gym/trading/config.rs @@ -10,18 +10,18 @@ use strum::{Display, EnumCount, EnumIter, EnumString, IntoStaticStr}; use crate::{ ApiKey, EndpointUrl, SelfHostedApi, data::{ + batch_indicator::{BatchOhlcvIndicator, SmaWindow}, common::{ProfileAggregation, RiskMetricsConfig}, - config::{ - EconomicCalendarConfig, OhlcvFutureConfig, OhlcvSpotConfig, TpoFutureConfig, - TpoSpotConfig, TradeSpotConfig, VolumeProfileSpotConfig, - }, domain::{ ContractMonth, ContractYear, CountryCode, DataBroker, EconomicCategory, EconomicEventImpact, Exchange, FutureContract, FutureRoot, Period, SpotPair, Symbol, }, episode::EpisodeLength, filter::{EconomicCalendarPolicy, FilterConfig}, - indicator::{SmaWindow, TechnicalIndicator}, + query::{ + EconomicCalendarQuery, OhlcvFutureQuery, OhlcvSpotQuery, TpoFutureQuery, TpoSpotQuery, + TradeSpotQuery, VolumeProfileSpotQuery, + }, }, error::{ChapatyResult, EnvError}, gym::InvalidActionPenalty, @@ -527,7 +527,7 @@ impl From for EnvConfig { let source = self_hosted_source(); match preset { EnvPreset::BinanceBtcUsdt1d => { - let market_config = OhlcvSpotConfig { + let market_config = OhlcvSpotQuery { broker: DataBroker::Binance, symbol: Symbol::Spot(SpotPair::BtcUsdt), period: Period::Day(1), @@ -546,7 +546,7 @@ impl From for EnvConfig { .with_filter_config(filter) } EnvPreset::BinanceBtcUsdt1m => { - let market_config = OhlcvSpotConfig { + let market_config = OhlcvSpotQuery { broker: DataBroker::Binance, symbol: Symbol::Spot(SpotPair::BtcUsdt), period: Period::Minute(1), @@ -564,7 +564,7 @@ impl From for EnvConfig { .with_filter_config(filter) } EnvPreset::BinanceBtcUsdt1m15m => { - let ohlcv_1m = OhlcvSpotConfig { + let ohlcv_1m = OhlcvSpotQuery { broker: DataBroker::Binance, symbol: Symbol::Spot(SpotPair::BtcUsdt), exchange: Some(Exchange::Binance), @@ -572,7 +572,7 @@ impl From for EnvConfig { batch_size: 1000, indicators: Vec::new(), }; - let ohlcv_15m = OhlcvSpotConfig { + let ohlcv_15m = OhlcvSpotQuery { broker: DataBroker::Binance, symbol: Symbol::Spot(SpotPair::BtcUsdt), exchange: Some(Exchange::Binance), @@ -591,7 +591,7 @@ impl From for EnvConfig { .with_filter_config(filter) } EnvPreset::NinjaTraderCme6eh61m5mUsEmpHigh => { - let ohlcv_1m = OhlcvFutureConfig { + let ohlcv_1m = OhlcvFutureQuery { broker: DataBroker::NinjaTrader, symbol: Symbol::Future(FutureContract { root: FutureRoot::EurUsd, @@ -603,7 +603,7 @@ impl From for EnvConfig { batch_size: 1000, indicators: vec![], }; - let ohlcv_5m = OhlcvFutureConfig { + let ohlcv_5m = OhlcvFutureQuery { broker: DataBroker::NinjaTrader, symbol: Symbol::Future(FutureContract { root: FutureRoot::EurUsd, @@ -615,7 +615,7 @@ impl From for EnvConfig { batch_size: 1000, indicators: vec![], }; - let calendar = EconomicCalendarConfig { + let calendar = EconomicCalendarQuery { broker: DataBroker::InvestingCom, data_source: None, country_code: Some(CountryCode::Us), @@ -636,7 +636,7 @@ impl From for EnvConfig { .with_trade_hint(4) } EnvPreset::NinjaTraderCme6eh61mUsEmpHighEventsOnly => { - let ohlcv = OhlcvFutureConfig { + let ohlcv = OhlcvFutureQuery { broker: DataBroker::NinjaTrader, symbol: Symbol::Future(FutureContract { root: FutureRoot::EurUsd, @@ -648,7 +648,7 @@ impl From for EnvConfig { batch_size: 1000, indicators: vec![], }; - let calendar = EconomicCalendarConfig { + let calendar = EconomicCalendarQuery { broker: DataBroker::InvestingCom, data_source: None, country_code: Some(CountryCode::Us), @@ -669,7 +669,7 @@ impl From for EnvConfig { .with_trade_hint(2) } EnvPreset::NinjaTraderCme6eh61m5mUsEmpHighEventsOnly => { - let ohlcv_1m = OhlcvFutureConfig { + let ohlcv_1m = OhlcvFutureQuery { broker: DataBroker::NinjaTrader, symbol: Symbol::Future(FutureContract { root: FutureRoot::EurUsd, @@ -681,7 +681,7 @@ impl From for EnvConfig { batch_size: 1000, indicators: vec![], }; - let ohlcv_5m = OhlcvFutureConfig { + let ohlcv_5m = OhlcvFutureQuery { broker: DataBroker::NinjaTrader, symbol: Symbol::Future(FutureContract { root: FutureRoot::EurUsd, @@ -693,7 +693,7 @@ impl From for EnvConfig { batch_size: 1000, indicators: vec![], }; - let calendar = EconomicCalendarConfig { + let calendar = EconomicCalendarQuery { broker: DataBroker::InvestingCom, data_source: None, country_code: Some(CountryCode::Us), @@ -715,15 +715,15 @@ impl From for EnvConfig { .with_trade_hint(4) } EnvPreset::BinanceBtcUsdt1dSma20Sma50 => { - let market_config = OhlcvSpotConfig { + let market_config = OhlcvSpotQuery { broker: DataBroker::Binance, symbol: Symbol::Spot(SpotPair::BtcUsdt), exchange: Some(Exchange::Binance), period: Period::Day(1), batch_size: 1000, indicators: vec![ - TechnicalIndicator::Sma(SmaWindow(20)), - TechnicalIndicator::Sma(SmaWindow(50)), + BatchOhlcvIndicator::Sma(SmaWindow(20)), + BatchOhlcvIndicator::Sma(SmaWindow(50)), ], }; let filter = FilterConfig { @@ -736,7 +736,7 @@ impl From for EnvConfig { .with_filter_config(filter) } EnvPreset::BinanceBtcUsdt1h1mVolumeProfile1d100Usdt => { - let ohlcv_1h = OhlcvSpotConfig { + let ohlcv_1h = OhlcvSpotQuery { broker: DataBroker::Binance, symbol: Symbol::Spot(SpotPair::BtcUsdt), exchange: Some(Exchange::Binance), @@ -744,7 +744,7 @@ impl From for EnvConfig { batch_size: 1000, indicators: vec![], }; - let ohlcv_1m = OhlcvSpotConfig { + let ohlcv_1m = OhlcvSpotQuery { broker: DataBroker::Binance, symbol: Symbol::Spot(SpotPair::BtcUsdt), exchange: Some(Exchange::Binance), @@ -752,7 +752,7 @@ impl From for EnvConfig { batch_size: 1000, indicators: vec![], }; - let vp = VolumeProfileSpotConfig { + let vp = VolumeProfileSpotQuery { broker: DataBroker::Binance, symbol: Symbol::Spot(SpotPair::BtcUsdt), exchange: Some(Exchange::Binance), @@ -775,7 +775,7 @@ impl From for EnvConfig { .with_filter_config(filter) } EnvPreset::BinanceBtcUsdt1h1mTpo1d1Usdt => { - let ohlcv_1h = OhlcvSpotConfig { + let ohlcv_1h = OhlcvSpotQuery { broker: DataBroker::Binance, symbol: Symbol::Spot(SpotPair::BtcUsdt), exchange: Some(Exchange::Binance), @@ -783,7 +783,7 @@ impl From for EnvConfig { batch_size: 1000, indicators: vec![], }; - let ohlcv_1m = OhlcvSpotConfig { + let ohlcv_1m = OhlcvSpotQuery { broker: DataBroker::Binance, symbol: Symbol::Spot(SpotPair::BtcUsdt), exchange: Some(Exchange::Binance), @@ -791,7 +791,7 @@ impl From for EnvConfig { batch_size: 1000, indicators: vec![], }; - let tpo = TpoSpotConfig { + let tpo = TpoSpotQuery { broker: DataBroker::Binance, symbol: Symbol::Spot(SpotPair::BtcUsdt), exchange: Some(Exchange::Binance), @@ -814,7 +814,7 @@ impl From for EnvConfig { .with_filter_config(filter) } EnvPreset::NinjaTraderCme6eh61mTpo1d => { - let ohlcv = OhlcvFutureConfig { + let ohlcv = OhlcvFutureQuery { broker: DataBroker::NinjaTrader, symbol: Symbol::Future(FutureContract { root: FutureRoot::EurUsd, @@ -826,7 +826,7 @@ impl From for EnvConfig { batch_size: 1000, indicators: vec![], }; - let tpo = TpoFutureConfig { + let tpo = TpoFutureQuery { broker: DataBroker::NinjaTrader, symbol: Symbol::Future(FutureContract { root: FutureRoot::EurUsd, @@ -886,31 +886,31 @@ pub struct EnvConfig { // Market Data (RPC) + Computed Indicators // ======================================================================== /// OHLCV data from spot markets, optionally with technical indicators. - ohlcv_spot: Vec>, + ohlcv_spot: Vec>, /// OHLCV data from futures markets, optionally with technical indicators. - ohlcv_future: Vec>, + ohlcv_future: Vec>, /// Trade-level trade execution data. - trade_spot: Vec>, + trade_spot: Vec>, // ======================================================================== // Profile Data (External RPC) // ======================================================================== /// Time Price Opportunity (Market Profile) data for spot markets. - tpo_spot: Vec>, + tpo_spot: Vec>, /// Time Price Opportunity (Market Profile) data for futures markets. - tpo_future: Vec>, + tpo_future: Vec>, /// Volume Profile data for spot markets. - volume_profile_spot: Vec>, + volume_profile_spot: Vec>, // ======================================================================== // External Event Data // ======================================================================== /// Economic calendar events and news releases. - economic_calendar: Vec>, + economic_calendar: Vec>, // ======================================================================== // Processing Pipelines @@ -961,7 +961,7 @@ impl Default for EnvConfig { impl EnvConfig { /// Adds OHLCV spot market data from a specific source. - pub fn add_ohlcv_spot(self, source: DataSource, config: OhlcvSpotConfig) -> Self { + pub fn add_ohlcv_spot(self, source: DataSource, config: OhlcvSpotQuery) -> Self { Self { ohlcv_spot: update_source_group(self.ohlcv_spot, source, config), ..self @@ -969,7 +969,7 @@ impl EnvConfig { } /// Adds OHLCV futures market data from a specific source. - pub fn add_ohlcv_future(self, source: DataSource, config: OhlcvFutureConfig) -> Self { + pub fn add_ohlcv_future(self, source: DataSource, config: OhlcvFutureQuery) -> Self { Self { ohlcv_future: update_source_group(self.ohlcv_future, source, config), ..self @@ -977,7 +977,7 @@ impl EnvConfig { } /// Adds trade-level spot market data from a specific source. - pub fn add_trade_spot(self, source: DataSource, config: TradeSpotConfig) -> Self { + pub fn add_trade_spot(self, source: DataSource, config: TradeSpotQuery) -> Self { Self { trade_spot: update_source_group(self.trade_spot, source, config), ..self @@ -985,7 +985,7 @@ impl EnvConfig { } /// Adds TPO (Market Profile) spot data from a specific source. - pub fn add_tpo_spot(self, source: DataSource, config: TpoSpotConfig) -> Self { + pub fn add_tpo_spot(self, source: DataSource, config: TpoSpotQuery) -> Self { Self { tpo_spot: update_source_group(self.tpo_spot, source, config), ..self @@ -993,7 +993,7 @@ impl EnvConfig { } /// Adds TPO (Market Profile) futures data from a specific source. - pub fn add_tpo_future(self, source: DataSource, config: TpoFutureConfig) -> Self { + pub fn add_tpo_future(self, source: DataSource, config: TpoFutureQuery) -> Self { Self { tpo_future: update_source_group(self.tpo_future, source, config), ..self @@ -1004,7 +1004,7 @@ impl EnvConfig { pub fn add_volume_profile_spot( self, source: DataSource, - config: VolumeProfileSpotConfig, + config: VolumeProfileSpotQuery, ) -> Self { Self { volume_profile_spot: update_source_group(self.volume_profile_spot, source, config), @@ -1013,7 +1013,7 @@ impl EnvConfig { } /// Adds economic calendar events from a specific source. - pub fn add_economic_calendar(self, source: DataSource, config: EconomicCalendarConfig) -> Self { + pub fn add_economic_calendar(self, source: DataSource, config: EconomicCalendarQuery) -> Self { Self { economic_calendar: update_source_group(self.economic_calendar, source, config), ..self @@ -1084,31 +1084,31 @@ impl EnvConfig { // ================================================================================================ impl EnvConfig { - pub fn ohlcv_spot(&self) -> &[SourceGroup] { + pub fn ohlcv_spot(&self) -> &[SourceGroup] { &self.ohlcv_spot } - pub fn ohlcv_future(&self) -> &[SourceGroup] { + pub fn ohlcv_future(&self) -> &[SourceGroup] { &self.ohlcv_future } - pub fn trade_spot(&self) -> &[SourceGroup] { + pub fn trade_spot(&self) -> &[SourceGroup] { &self.trade_spot } - pub fn tpo_spot(&self) -> &[SourceGroup] { + pub fn tpo_spot(&self) -> &[SourceGroup] { &self.tpo_spot } - pub fn tpo_future(&self) -> &[SourceGroup] { + pub fn tpo_future(&self) -> &[SourceGroup] { &self.tpo_future } - pub fn volume_profile_spot(&self) -> &[SourceGroup] { + pub fn volume_profile_spot(&self) -> &[SourceGroup] { &self.volume_profile_spot } - pub fn economic_calendar(&self) -> &[SourceGroup] { + pub fn economic_calendar(&self) -> &[SourceGroup] { &self.economic_calendar } diff --git a/src/gym/trading/factory.rs b/src/gym/trading/factory.rs index 6036d9c..9d86b23 100644 --- a/src/gym/trading/factory.rs +++ b/src/gym/trading/factory.rs @@ -1,7 +1,7 @@ use crate::{ data::{ + batch_indicator::{BatchOhlcvIndicator, EmaWindow, RsiWindow, SmaWindow}, common::ProfileAggregation, - config::ConfigId, domain::{ Count, CountryCode, EconomicEventImpact, EconomicValue, ExecutionDepth, LiquiditySide, Price, Quantity, TradeId, @@ -13,7 +13,7 @@ use crate::{ VolumeProfileId, }, filter::{EconomicCalendarPolicy, TradingWindow, Weekday}, - indicator::{EmaWindow, RsiWindow, SmaWindow, TechnicalIndicator}, + query::QueryId, }, error::{ChapatyError, ChapatyResult, DataError, EnvError}, gym::trading::{ @@ -218,27 +218,27 @@ impl BuildCtx { // 2. Define a helper to process indicators for a specific parent LazyFrame let mut process_indicators = |parent_id: OhlcvId, source_lf: LazyFrame, - indicators: &[TechnicalIndicator]| + indicators: &[BatchOhlcvIndicator]| -> ChapatyResult<()> { for &ind in indicators { - let lf_result = compute_indicator(source_lf.clone(), ind)?; + let lf_result = ind.pre_compute(source_lf.clone())?; match ind { - TechnicalIndicator::Ema(EmaWindow(w)) => { + BatchOhlcvIndicator::Ema(EmaWindow(w)) => { let id = EmaId { parent: parent_id, length: EmaWindow(w), }; ema_map.insert(id, (schema.clone(), lf_result)); } - TechnicalIndicator::Sma(SmaWindow(w)) => { + BatchOhlcvIndicator::Sma(SmaWindow(w)) => { let id = SmaId { parent: parent_id, length: SmaWindow(w), }; sma_map.insert(id, (schema.clone(), lf_result)); } - TechnicalIndicator::Rsi(RsiWindow(w)) => { + BatchOhlcvIndicator::Rsi(RsiWindow(w)) => { let id = RsiId { parent: parent_id, length: RsiWindow(w), @@ -654,14 +654,6 @@ async fn fetch_groups( Ok(aggregated_map) } -fn compute_indicator(lf: LazyFrame, indicator: TechnicalIndicator) -> ChapatyResult { - match indicator { - TechnicalIndicator::Ema(w) => w.pre_compute_ema(lf), - TechnicalIndicator::Sma(w) => w.pre_compute_sma(lf), - TechnicalIndicator::Rsi(w) => w.pre_compute_rsi(lf), - } -} - fn apply_overlay( map: &mut HashMap, news_lf: &LazyFrame, @@ -1684,7 +1676,7 @@ mod test { struct IndicatorTestCase { name: &'static str, - indicator: TechnicalIndicator, + indicator: BatchOhlcvIndicator, expected_file: &'static str, } @@ -1702,17 +1694,17 @@ mod test { let test_cases = vec![ IndicatorTestCase { name: "EMA-20", - indicator: TechnicalIndicator::Ema(EmaWindow(20)), + indicator: BatchOhlcvIndicator::Ema(EmaWindow(20)), expected_file: "ema_20_daily.csv", }, IndicatorTestCase { name: "SMA-14", - indicator: TechnicalIndicator::Sma(SmaWindow(14)), + indicator: BatchOhlcvIndicator::Sma(SmaWindow(14)), expected_file: "sma_14_daily.csv", }, IndicatorTestCase { name: "RSI-14", - indicator: TechnicalIndicator::Rsi(RsiWindow(14)), + indicator: BatchOhlcvIndicator::Rsi(RsiWindow(14)), expected_file: "rsi_14_daily.csv", }, ]; @@ -1724,7 +1716,9 @@ mod test { let input_lf = load_ohlcv_fixture("binance-btc-usdt-8h.csv"); // 2. Compute (Simulating internal build step) - let result_lf = compute_indicator(input_lf, case.indicator) + let result_lf = case + .indicator + .pre_compute(input_lf) .unwrap_or_else(|_| panic!("Failed to compute {}", case.name)); // 3. Assert diff --git a/src/math.rs b/src/math.rs index e69dc8e..49e6e09 100644 --- a/src/math.rs +++ b/src/math.rs @@ -1,2 +1,8 @@ -pub mod indicator; +pub mod fair_value_gap; pub mod market_profile; +pub mod moving_averages; +pub mod oscillators; +pub mod swing; +pub mod traits; + +pub use traits::*; diff --git a/src/math/fair_value_gap.rs b/src/math/fair_value_gap.rs new file mode 100644 index 0000000..0136630 --- /dev/null +++ b/src/math/fair_value_gap.rs @@ -0,0 +1,1308 @@ +use std::{collections::VecDeque, fmt::Debug}; + +use chrono::{DateTime, Duration, Utc}; + +use crate::{ + data::{ + domain::Price, + event::{IndexedOhlcv, MarketEvent, Ohlcv}, + }, + math::StreamingIndicator, +}; + +const LHS: usize = 0; +const RHS: usize = 2; +const PATTERN_LENGTH: usize = 3; + +/// Defines the time to live (ttl) condition under which a Fair Value Gap expires. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum TtlPolicy { + /// Expires after a specific number of bars have passed since creation. + Bars(usize), + /// Expires after a specific time duration has passed since creation. + Time(Duration), + /// Never expires automatically. Stays open until completely filled. + #[default] + Filled, +} + +pub trait FairValueGapState: Debug + Clone + Send + Sync + 'static {} + +#[derive(Debug, Clone, Copy, PartialEq, Default)] +pub struct OpenState { + max_fill_percentage: f64, + touch_count: u32, +} + +impl OpenState { + pub fn max_fill_percentage(&self) -> f64 { + self.max_fill_percentage + } + + pub fn touch_count(&self) -> u32 { + self.touch_count + } +} + +impl FairValueGapState for OpenState {} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct ClosedState { + closed_time: DateTime, + touch_count: u32, +} + +impl ClosedState { + pub fn closed_time(&self) -> DateTime { + self.closed_time + } + + pub const fn max_fill_percentage(&self) -> f64 { + 1.0 + } + + pub fn touch_count(&self) -> u32 { + self.touch_count + } +} + +impl FairValueGapState for ClosedState {} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct ExpiredState { + expired_time: DateTime, + touch_count: u32, + final_fill_percentage: f64, +} + +impl ExpiredState { + pub fn expired_time(&self) -> DateTime { + self.expired_time + } + pub fn final_fill_percentage(&self) -> f64 { + self.final_fill_percentage + } + pub fn touch_count(&self) -> u32 { + self.touch_count + } +} + +impl FairValueGapState for ExpiredState {} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FairValueGapDirection { + Bullish, + Bearish, +} + +/// Represents how a price candle interacted with a Fair Value Gap. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum GapInteraction { + /// The candle's price range completely missed the gap (no overlap). + Miss, + /// The candle's price range entered the gap, but did not pierce the far boundary. + Touch, + /// The candle's price range completely pierced the required boundary to fill the gap. + Fill, +} + +impl GapInteraction { + /// Returns true if the candle touched OR filled the gap. + pub fn is_touch(&self) -> bool { + matches!(self, Self::Touch | Self::Fill) + } + + /// Returns true strictly if the candle filled the gap. + pub fn is_fill(&self) -> bool { + matches!(self, Self::Fill) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct FairValueGap { + direction: FairValueGapDirection, + creation_time: DateTime, + creation_index: usize, + top: Price, + bottom: Price, + state: S, +} + +#[derive(Debug, Clone, Copy)] +pub enum FairValueGapStatus { + Open(FairValueGap), + Closed(FairValueGap), + Expired(FairValueGap), +} + +impl MarketEvent for FairValueGapStatus { + fn point_in_time(&self) -> DateTime { + match self { + FairValueGapStatus::Open(gap) => gap.point_in_time(), + FairValueGapStatus::Closed(gap) => gap.point_in_time(), + FairValueGapStatus::Expired(gap) => gap.point_in_time(), + } + } +} + +impl MarketEvent for FairValueGap { + fn point_in_time(&self) -> DateTime { + self.creation_time + } +} + +impl FairValueGap { + pub fn direction(&self) -> FairValueGapDirection { + self.direction + } + + pub fn creation_time(&self) -> DateTime { + self.creation_time + } + + /// Returns the index of the OHLCV candle that created this gap. + pub fn creation_index(&self) -> usize { + self.creation_index + } + + pub fn top(&self) -> Price { + self.top + } + + pub fn bottom(&self) -> Price { + self.bottom + } + + pub fn state(&self) -> &S { + &self.state + } + + pub fn gap_size(&self) -> f64 { + (self.top.0 - self.bottom.0).abs() + } + + pub fn map(self, f: F) -> FairValueGap + where + F: FnOnce(S) -> NewState, + { + FairValueGap { + direction: self.direction, + creation_time: self.creation_time, + creation_index: self.creation_index, + top: self.top, + bottom: self.bottom, + state: f(self.state), + } + } + + /// Evaluates how a given candle's price action interacts with the gap's price zone. + /// + /// # The Overlap Logic (Filtering Breakaway Gaps) + /// A candle's traded range is a continuous interval defined as `[low, high]`. + /// The gap's price zone is defined as `[bottom, top]`. + /// + /// For a candle to interact with the gap, the market must have physically traded + /// inside that zone. Mathematically, two 1-dimensional intervals `[A, B]` and + /// `[C, D]` intersect if and only if `A < D` AND `B > C`. + /// + /// Applying this to our market data: + /// `candle.low < gap.top` AND `candle.high > gap.bottom` + /// + /// **Why this is critical:** + /// In markets that close (like traditional equities) or over weekends (like Forex), + /// the price can open drastically lower or higher than the previous close. + /// If a Bullish Gap exists at `[10.0, 15.0]`, and the market violently crashes + /// overnight to open at `5.0` and wicks to a high of `8.0`, the price is technically + /// "below" the gap. But because `candle.high (8.0)` is NOT `> gap.bottom (10.0)`, + /// the overlap check correctly identifies that the market teleported _over_ the + /// zone without ever actually trading inside it. It remains an untouched Miss. + pub fn evaluate_interaction(&self, candle: &Ohlcv) -> GapInteraction { + let overlaps = candle.low < self.top && candle.high > self.bottom; + + if !overlaps { + return GapInteraction::Miss; + } + + let is_filled = match self.direction { + FairValueGapDirection::Bullish => candle.low <= self.bottom, + FairValueGapDirection::Bearish => candle.high >= self.top, + }; + + if is_filled { + GapInteraction::Fill + } else { + GapInteraction::Touch + } + } +} + +impl FairValueGap { + /// Evaluates the incoming indexed candle against the open gap, considering TTL. + fn process_candle(self, indexed_candle: &IndexedOhlcv, ttl: TtlPolicy) -> FairValueGapStatus { + let candle = &indexed_candle.candle; + + // 1. Evaluate Price Action First via the interaction helper + let updated_gap = match self.evaluate_interaction(candle) { + GapInteraction::Fill => { + // Early return: If it fully fills, it closes immediately before TTL checks. + return FairValueGapStatus::Closed(self.into_closed(candle.point_in_time())); + } + GapInteraction::Touch => { + let gap_size = self.gap_size(); + let current_fill_pct = match self.direction { + FairValueGapDirection::Bullish => (self.top.0 - candle.low.0) / gap_size, + FairValueGapDirection::Bearish => (candle.high.0 - self.bottom.0) / gap_size, + }; + self.with_partial_fill(current_fill_pct) + } + GapInteraction::Miss => self, // pass-through + }; + + // 2. Evaluate TTL Expiration + let is_expired = match ttl { + TtlPolicy::Bars(limit) => { + indexed_candle + .index + .saturating_sub(updated_gap.creation_index()) + >= limit + } + TtlPolicy::Time(limit) => { + candle + .close_timestamp + .signed_duration_since(updated_gap.creation_time()) + >= limit + } + TtlPolicy::Filled => false, + }; + + if is_expired { + FairValueGapStatus::Expired(updated_gap.into_expired(candle.close_timestamp)) + } else { + FairValueGapStatus::Open(updated_gap) + } + } + + fn with_partial_fill(self, fill_pct: f64) -> Self { + let max_fill_percentage = self.state.max_fill_percentage.max(fill_pct.clamp(0.0, 1.0)); + self.map(|s| OpenState { + max_fill_percentage, + touch_count: s.touch_count + 1, + }) + } + + fn into_closed(self, closed_time: DateTime) -> FairValueGap { + self.map(|s| ClosedState { + closed_time, + touch_count: s.touch_count + 1, + }) + } + + fn into_expired(self, expired_time: DateTime) -> FairValueGap { + self.map(|s| ExpiredState { + expired_time, + touch_count: s.touch_count, + final_fill_percentage: s.max_fill_percentage, + }) + } +} + +#[derive(Debug, Clone)] +pub struct StreamingFairValueGap { + min_gap_size: f64, + ttl_policy: TtlPolicy, + buffer: VecDeque, + active_gaps: Vec>, + closed_gaps: Vec>, + expired_gaps: Vec>, +} + +impl Default for StreamingFairValueGap { + fn default() -> Self { + Self { + min_gap_size: f64::EPSILON, + ttl_policy: TtlPolicy::default(), + buffer: VecDeque::with_capacity(PATTERN_LENGTH), + active_gaps: Vec::new(), + closed_gaps: Vec::new(), + expired_gaps: Vec::new(), + } + } +} + +impl StreamingFairValueGap { + /// Sets the minimum gap size for the indicator. + /// + /// # Arguments + /// * `min_gap_size` - The minimum gap size to set. Must be > 0.0. + /// + /// # Panics + /// Panics if `min_gap_size` <= 0.0. + pub fn with_min_gap_size(self, min_gap_size: f64) -> Self { + assert!( + min_gap_size > 0.0, + "min_gap_size must be strictly positive (got {min_gap_size} which is <= 0.0)" + ); + Self { + min_gap_size, + ..self + } + } + + pub fn with_ttl_policy(self, ttl_policy: TtlPolicy) -> Self { + Self { ttl_policy, ..self } + } + + // Accessors for agent state inspection... + pub fn active_gaps(&self) -> &[FairValueGap] { + &self.active_gaps + } + pub fn closed_gaps(&self) -> &[FairValueGap] { + &self.closed_gaps + } + pub fn expired_gaps(&self) -> &[FairValueGap] { + &self.expired_gaps + } + + fn detect_gap(&self) -> Option> { + if self.buffer.len() < PATTERN_LENGTH { + return None; + } + + let lhs = &self.buffer[LHS].candle; + let rhs = &self.buffer[RHS].candle; + let rhs_index = self.buffer[RHS].index; + + let gap_up = rhs.low.0 - lhs.high.0; + let gap_down = lhs.low.0 - rhs.high.0; + + // A bullish and bearish gap can't coexist for the same candle pair. + debug_assert!( + !(gap_up >= self.min_gap_size && gap_down >= self.min_gap_size), + "detected bullish and bearish gap simultaneously (gap_up={gap_up}, gap_down={gap_down})" + ); + + let (direction, top, bottom) = if gap_up >= self.min_gap_size { + (FairValueGapDirection::Bullish, rhs.low, lhs.high) + } else if gap_down >= self.min_gap_size { + (FairValueGapDirection::Bearish, lhs.low, rhs.high) + } else { + return None; + }; + + Some(FairValueGap { + direction, + creation_time: rhs.close_timestamp, + creation_index: rhs_index, + top, + bottom, + state: OpenState::default(), + }) + } +} + +impl StreamingIndicator for StreamingFairValueGap { + type Input = IndexedOhlcv; + type Output<'a> = &'a [FairValueGap]; + + fn update(&mut self, indexed_candle: Self::Input) -> Self::Output<'_> { + // 1. Process active gaps against the new candle + let ttl = self.ttl_policy; + let closed_gaps = &mut self.closed_gaps; + let expired_gaps = &mut self.expired_gaps; + + self.active_gaps.retain_mut(|gap_ref| { + match gap_ref.process_candle(&indexed_candle, ttl) { + FairValueGapStatus::Open(updated_gap) => { + *gap_ref = updated_gap; + true // Keep in active + } + FairValueGapStatus::Closed(closed_gap) => { + closed_gaps.push(closed_gap); + false // Remove from active + } + FairValueGapStatus::Expired(expired_gap) => { + expired_gaps.push(expired_gap); + false // Remove from active + } + } + }); + + // 2. Update buffer and detect new gaps + if self.buffer.len() >= PATTERN_LENGTH { + self.buffer.pop_front(); + } + self.buffer.push_back(indexed_candle); + + if let Some(new_gap) = self.detect_gap() { + self.active_gaps.push(new_gap); + } + + self.active_gaps.as_slice() + } + + fn reset(&mut self) { + self.buffer.clear(); + self.active_gaps.clear(); + self.closed_gaps.clear(); + self.expired_gaps.clear(); + } +} +#[cfg(test)] +mod tests { + use super::*; + use crate::data::{domain::Quantity, event::Ohlcv}; + + // ========================================== + // === 1. Mocks & Helpers === + // ========================================== + + /// Parse RFC3339 timestamp string to DateTime. + fn ts(s: &str) -> DateTime { + DateTime::parse_from_rfc3339(s).unwrap().with_timezone(&Utc) + } + + /// A rapid builder for Indexed OHLCV candles to keep our test trajectories readable. + fn candle( + index: usize, + time: &str, + open: f64, + high: f64, + low: f64, + close: f64, + ) -> IndexedOhlcv { + assert!(high >= low, "Invalid mock candle: high {high} < low {low}"); + IndexedOhlcv { + index, + candle: Ohlcv { + open_timestamp: ts(time), + close_timestamp: ts(time), + open: Price(open), + high: Price(high), + low: Price(low), + close: Price(close), + volume: Quantity(100.0), // Adjust if your Volume wrapper is different + quote_asset_volume: None, + number_of_trades: None, + taker_buy_base_asset_volume: None, + taker_buy_quote_asset_volume: None, + }, + } + } + + /// Helper to assert floats with epsilon tolerance + fn assert_f64_eq(a: f64, b: f64) { + assert!( + (a - b).abs() < f64::EPSILON, + "Expected {} to equal {}", + a, + b + ); + } + + // ========================================== + // === 2. Core Invariant Proofs === + // ========================================== + + #[test] + fn simultaneous_bullish_and_bearish_gap_is_impossible() { + // PROOF: + // Bullish Gap requires: gap_up > 0 => C3.low > C1.high + // Bearish Gap requires: gap_down > 0 => C1.low > C3.high + // For a valid candle, High >= Low always. + // If both gaps existed: + // - gap_up: C3.low > C1.high > 0 + // - gap_down: C1.low > C3.high > 0 + // + // As C1.low is greater than C3.high (gap_down) and C3.high >= C3.low (valid candle), + // we get C1.low > C3.high >= C3.low > C1.high > 0, by extending the left side of the + // inequality of gap_up. + // + // This transitively means C1.low > C1.high, which is a contradiction. Hence, the gap_up + // and gap_down cannot both exist simultaneously. + + let mut fvg = StreamingFairValueGap::default().with_min_gap_size(0.1); + + // Feed an erratic sequence to ensure the math holds and the debug_assert never fires + let trajectory = vec![ + candle(1, "2026-05-24T10:00:00Z", 50., 100., 10., 50.), // Massive range + candle(2, "2026-05-24T10:01:00Z", 50., 50., 50., 50.), // Inside doji + candle(3, "2026-05-24T10:02:00Z", 10., 10., 10., 10.), // Exact bottom touch + ]; + + for c in trajectory { + let _ = fvg.update(c); + } + + assert_eq!(fvg.active_gaps.len(), 0); + } + + // ========================================== + // === 3. Detection & Noise Filtering === + // ========================================== + + #[test] + fn filters_noise_below_min_gap_size() { + let mut indicator = StreamingFairValueGap::default().with_min_gap_size(2.0); + + // Gap size will be 11.0 - 10.0 = 1.0. + // Since 1.0 < min_gap_size (2.0), it must be rejected as noise. + indicator.update(candle(1, "2026-05-24T10:00:00Z", 10., 10., 5., 8.)); // C1 High = 10 + indicator.update(candle(2, "2026-05-24T10:01:00Z", 10., 12., 8., 11.)); // C2 + indicator.update(candle(3, "2026-05-24T10:02:00Z", 12., 15., 11., 14.)); // C3 Low = 11 + + assert!(indicator.active_gaps.is_empty()); + } + + #[test] + fn detects_bullish_and_bearish_fvgs() { + let mut indicator = StreamingFairValueGap::default().with_min_gap_size(1.0); + + // === Bullish Sequence === + indicator.update(candle(1, "2026-05-24T10:00:00Z", 10., 10., 5., 8.)); // C1 High = 10 + indicator.update(candle(2, "2026-05-24T10:01:00Z", 10., 12., 8., 11.)); // C2 + indicator.update(candle(3, "2026-05-24T10:02:00Z", 15., 20., 15., 18.)); // C3 Low = 15 + + assert_eq!(indicator.active_gaps.len(), 1); + let gap = indicator.active_gaps[0]; + assert_eq!(gap.direction(), FairValueGapDirection::Bullish); + assert_eq!(gap.bottom().0, 10.0); + assert_eq!(gap.top().0, 15.0); + assert_f64_eq(gap.gap_size(), 5.0); + + indicator.reset(); + + // === Bearish Sequence === + indicator.update(candle(4, "2026-05-24T10:00:00Z", 20., 25., 20., 22.)); // C1 Low = 20 + indicator.update(candle(5, "2026-05-24T10:01:00Z", 18., 22., 15., 16.)); // C2 + indicator.update(candle(6, "2026-05-24T10:02:00Z", 12., 15., 10., 11.)); // C3 High = 15 + + assert_eq!(indicator.active_gaps.len(), 1); + let gap = indicator.active_gaps[0]; + assert_eq!(gap.direction(), FairValueGapDirection::Bearish); + assert_eq!(gap.top().0, 20.0); + assert_eq!(gap.bottom().0, 15.0); + assert_f64_eq(gap.gap_size(), 5.0); + } + + // ========================================== + // === 4. State Management (Active/Hist) === + // ========================================== + + #[test] + fn partial_fill_updates_active_state_and_clamps() { + let mut indicator = StreamingFairValueGap::default().with_min_gap_size(1.0); + + // 1. Create Bullish Gap: Top=15.0, Bottom=10.0, Size=5.0 + indicator.update(candle(1, "2026-05-24T10:00:00Z", 10., 10., 5., 8.)); + indicator.update(candle(2, "2026-05-24T10:01:00Z", 10., 12., 8., 11.)); + indicator.update(candle(3, "2026-05-24T10:02:00Z", 15., 20., 15., 18.)); + + // Verify Setup Assumption + assert_eq!( + indicator.active_gaps().len(), + 1, + "Assumption failed: Bullish gap was not created" + ); + let initial_gap = indicator.active_gaps()[0]; + assert_eq!(initial_gap.direction(), FairValueGapDirection::Bullish); + assert_eq!(initial_gap.top().0, 15.0); + assert_eq!(initial_gap.bottom().0, 10.0); + assert_f64_eq(initial_gap.gap_size(), 5.0); + + // 2. Partial Fill: Wick down to 12.5 (50% fill) + indicator.update(candle(4, "2026-05-24T10:03:00Z", 18., 18., 12.5, 17.)); + + assert_eq!(indicator.active_gaps().len(), 1); + assert_eq!(indicator.closed_gaps().len(), 0); // Still active + + let gap = indicator.active_gaps()[0]; + assert_eq!(gap.state().touch_count(), 1); + assert_f64_eq(gap.state().max_fill_percentage(), 0.5); // (15 - 12.5) / 5 + + // 3. Lesser Fill: Wick down to 14.0 (20% fill). Should NOT reduce max_fill. + indicator.update(candle(5, "2026-05-24T10:04:00Z", 18., 18., 14.0, 17.)); + + let gap = indicator.active_gaps()[0]; + assert_eq!(gap.state().touch_count(), 2); + assert_f64_eq(gap.state().max_fill_percentage(), 0.5); // Retains 50% max + } + + #[test] + fn full_fill_migrates_gap_to_closed() { + let mut indicator = StreamingFairValueGap::default().with_min_gap_size(1.0); + + // 1. Create Bearish Gap: Top=20.0, Bottom=15.0, Size=5.0 + indicator.update(candle(1, "2026-05-24T10:00:00Z", 20., 25., 20., 22.)); // C1 Low=20 + indicator.update(candle(2, "2026-05-24T10:01:00Z", 18., 22., 12., 16.)); // C2 Low down to 12 + indicator.update(candle(3, "2026-05-24T10:02:00Z", 12., 15., 10., 11.)); // C3 High=15 + + // Verify Setup Assumption + assert_eq!( + indicator.active_gaps().len(), + 1, + "Assumption failed: Bearish gap was not created" + ); + let initial_gap = indicator.active_gaps()[0]; + assert_eq!(initial_gap.direction(), FairValueGapDirection::Bearish); + assert_eq!(initial_gap.top().0, 20.0); + assert_eq!(initial_gap.bottom().0, 15.0); + assert_eq!(indicator.closed_gaps().len(), 0); + + // 2. Miss (Price drops further away from the gap) + indicator.update(candle(4, "2026-05-24T10:03:00Z", 10., 12., 5., 8.)); + assert_eq!(indicator.active_gaps()[0].state().touch_count(), 0); + + // 3. Full Fill (Price violently rallies through Top of 20.0) + indicator.update(candle(5, "2026-05-24T10:04:00Z", 12., 21., 12., 21.)); // High = 21 >= 20 + + // Assert Migration + assert_eq!( + indicator.active_gaps().len(), + 0, + "Gap should be removed from active pool" + ); + assert_eq!( + indicator.closed_gaps().len(), + 1, + "Gap should be migrated to history" + ); + + let closed = indicator.closed_gaps()[0]; + assert_eq!(closed.direction(), FairValueGapDirection::Bearish); + assert_f64_eq(closed.state().max_fill_percentage(), 1.0); // Full fill is exactly 1.0 + assert_eq!(closed.state().touch_count(), 1); // Only took 1 touch to close + assert_eq!(closed.state().closed_time(), ts("2026-05-24T10:04:00Z")); // Time of the violating candle + } + + #[test] + fn boundary_exact_tick_is_a_miss() { + let mut indicator = StreamingFairValueGap::default().with_min_gap_size(1.0); + + // Create Bullish Gap: Top=15.0, Bottom=10.0 + indicator.update(candle(1, "2026-05-26T10:01:00Z", 10., 10., 5., 8.)); + indicator.update(candle(2, "2026-05-26T10:02:00Z", 10., 12., 8., 11.)); + indicator.update(candle(3, "2026-05-26T10:03:00Z", 15., 20., 15., 18.)); + + // Verify Setup Assumption + let initial_gap = indicator.active_gaps()[0]; + assert_eq!( + indicator.active_gaps().len(), + 1, + "Assumption failed: Bullish gap was not created" + ); + assert_eq!(initial_gap.direction(), FairValueGapDirection::Bullish); + assert_eq!(initial_gap.top().0, 15.0); + assert_eq!(initial_gap.bottom().0, 10.0); + + // Send a candle that wicks to EXACTLY 15.0 + // Because process_candle uses `candle.low < self.top`, this evaluates to false. + // It is mathematically defined as a Miss, NOT a touch/partial fill. + indicator.update(candle(4, "2026-05-26T10:04:00Z", 20., 20., 15.0, 20.)); + + let gap = indicator.active_gaps()[0]; + assert_eq!( + gap.state().touch_count(), + 0, + "Exact tick overlap should not increment touches" + ); + assert_f64_eq(gap.state().max_fill_percentage(), 0.0); + } + + #[test] + fn multiple_gaps_tracked_and_filled_independently() { + let mut indicator = StreamingFairValueGap::default().with_min_gap_size(1.0); + + // 1. Create Bullish Gap A (10 -> 15) + indicator.update(candle(1, "2026-05-26T10:01:00Z", 10., 10., 5., 8.)); + indicator.update(candle(2, "2026-05-26T10:02:00Z", 10., 20., 8., 11.)); // C2 High up to 20 + indicator.update(candle(3, "2026-05-26T10:03:00Z", 15., 22., 15., 18.)); // C3 High up to 22 + + // Verify Setup Assumption A + assert_eq!( + indicator.active_gaps().len(), + 1, + "Assumption failed: Gap A not created" + ); + assert_eq!( + indicator.active_gaps()[0].direction(), + FairValueGapDirection::Bullish + ); + assert_eq!(indicator.active_gaps()[0].top().0, 15.0); + assert_eq!(indicator.active_gaps()[0].bottom().0, 10.0); + + // 2. Create Bullish Gap B (25 -> 30) further up the trend + indicator.update(candle(4, "2026-05-26T10:04:00Z", 25., 25., 20., 22.)); + indicator.update(candle(5, "2026-05-26T10:05:00Z", 25., 28., 22., 26.)); + indicator.update(candle(6, "2026-05-26T10:06:00Z", 30., 35., 30., 32.)); + + // Verify Setup Assumption B + assert_eq!( + indicator.active_gaps().len(), + 2, + "Assumption failed: Gap B not created" + ); + assert_eq!( + indicator.active_gaps()[1].direction(), + FairValueGapDirection::Bullish + ); + assert_eq!(indicator.active_gaps()[1].top().0, 30.0); + assert_eq!(indicator.active_gaps()[1].bottom().0, 25.0); + + // 3. Price drops to 20. This completely fills Gap B (25->30), but only misses Gap A (10->15) + indicator.update(candle(7, "2026-05-26T10:07:00Z", 30., 30., 20., 25.)); + + assert_eq!(indicator.active_gaps().len(), 1, "Gap B should be closed"); + assert_eq!( + indicator.closed_gaps().len(), + 1, + "Gap B should be in history" + ); + + // Verify Gap A is still active and untouched (passed by value since it is Copy) + let active_gap = indicator.active_gaps()[0]; + assert_eq!(active_gap.bottom().0, 10.0); + assert_eq!(active_gap.top().0, 15.0); + assert_eq!(active_gap.state().touch_count(), 0); + + // Verify Gap B is closed (passed by value) + let closed_gap = indicator.closed_gaps()[0]; + assert_eq!(closed_gap.bottom().0, 25.0); + assert_eq!(closed_gap.top().0, 30.0); + assert_eq!(closed_gap.state().touch_count(), 1); + assert_f64_eq(closed_gap.state().max_fill_percentage(), 1.0); + } + + // ========================================== + // === 5. Time-To-Live (TTL) Expiration === + // ========================================== + + #[test] + fn ttl_expires_after_n_bars() { + // Expire if 2 or more bars have closed since creation + let mut indicator = StreamingFairValueGap::default() + .with_min_gap_size(1.0) + .with_ttl_policy(TtlPolicy::Bars(2)); + + // 1. Create Bullish Gap: Top=15.0, Bottom=10.0 + // C3 is the RHS candle, so creation_index = 3 + indicator.update(candle(1, "2026-05-26T10:01:00Z", 10., 10., 5., 8.)); + indicator.update(candle(2, "2026-05-26T10:02:00Z", 10., 20., 8., 11.)); // C2 High up to 20 + indicator.update(candle(3, "2026-05-26T10:03:00Z", 15., 20., 15., 18.)); + + // Verify Setup Assumption + assert_eq!( + indicator.active_gaps().len(), + 1, + "Assumption failed: Gap not created" + ); + assert_eq!( + indicator.active_gaps()[0].direction(), + FairValueGapDirection::Bullish + ); + assert_eq!(indicator.active_gaps()[0].creation_index(), 3); + assert_eq!(indicator.expired_gaps().len(), 0); + assert_eq!(indicator.closed_gaps().len(), 0, "No closed gaps at setup"); + + // 2. Candle 4 (Index 4). Diff = 4 - 3 = 1 bar. + // 1 < 2, so the gap remains active. + indicator.update(candle(4, "2026-05-26T10:04:00Z", 20., 25., 20., 22.)); + + assert_eq!(indicator.active_gaps().len(), 1); + assert_eq!(indicator.expired_gaps().len(), 0); + assert_eq!( + indicator.closed_gaps().len(), + 0, + "No closed gaps mid-flight" + ); + + // 3. Candle 5 (Index 5). Diff = 5 - 3 = 2 bars. + // 2 >= 2, so the gap should immediately expire. + indicator.update(candle(5, "2026-05-26T10:05:00Z", 20., 25., 20., 22.)); + + assert_eq!( + indicator.active_gaps().len(), + 0, + "Gap should be removed from active" + ); + assert_eq!( + indicator.expired_gaps().len(), + 1, + "Gap should be migrated to expired" + ); + assert_eq!( + indicator.closed_gaps().len(), + 0, + "No closed gaps after expiration" + ); + + let expired = indicator.expired_gaps()[0]; + assert_eq!(expired.creation_index(), 3); + assert_eq!(expired.state().expired_time(), ts("2026-05-26T10:05:00Z")); + } + + #[test] + fn ttl_expires_after_time_duration() { + // Expire if 5 minutes have passed since creation + let mut indicator = StreamingFairValueGap::default() + .with_min_gap_size(1.0) + .with_ttl_policy(TtlPolicy::Time(Duration::minutes(5))); + + // 1. Create Bullish Gap: Top=15.0, Bottom=10.0 + // C3 close_timestamp = "10:03:00Z" + indicator.update(candle(1, "2026-05-26T10:01:00Z", 10., 10., 5., 8.)); + indicator.update(candle(2, "2026-05-26T10:02:00Z", 10., 20., 8., 11.)); // C2 High up to 20 + indicator.update(candle(3, "2026-05-26T10:03:00Z", 15., 20., 15., 18.)); + + // Verify Setup Assumption + assert_eq!( + indicator.active_gaps().len(), + 1, + "Assumption failed: Gap not created" + ); + assert_eq!( + indicator.active_gaps()[0].direction(), + FairValueGapDirection::Bullish + ); + assert_eq!( + indicator.active_gaps()[0].creation_time(), + ts("2026-05-26T10:03:00Z") + ); + assert_eq!(indicator.closed_gaps().len(), 0, "No closed gaps at setup"); + + // 2. Candle at 10:07:00Z. Diff = 4 mins. + // 4 mins < 5 mins, so it remains active. + indicator.update(candle(4, "2026-05-26T10:07:00Z", 20., 25., 20., 22.)); + assert_eq!(indicator.active_gaps().len(), 1); + assert_eq!( + indicator.closed_gaps().len(), + 0, + "No closed gaps mid-flight" + ); + + // 3. Candle at 10:08:00Z. Diff = 5 mins. + // 5 mins >= 5 mins, gap expires. + indicator.update(candle(5, "2026-05-26T10:08:00Z", 20., 25., 20., 22.)); + assert_eq!(indicator.active_gaps().len(), 0); + assert_eq!(indicator.expired_gaps().len(), 1); + assert_eq!( + indicator.closed_gaps().len(), + 0, + "No closed gaps after expiration" + ); + } + + #[test] + fn expired_state_preserves_partial_fill_history() { + let mut indicator = StreamingFairValueGap::default() + .with_min_gap_size(1.0) + .with_ttl_policy(TtlPolicy::Bars(2)); + + // 1. Create Bullish Gap: Top=15.0, Bottom=10.0, Size=5.0 + indicator.update(candle(1, "2026-05-26T10:01:00Z", 10., 10., 5., 8.)); + indicator.update(candle(2, "2026-05-26T10:02:00Z", 10., 12., 8., 11.)); + indicator.update(candle(3, "2026-05-26T10:03:00Z", 15., 20., 15., 18.)); + + // Verify Setup Assumption + assert_eq!( + indicator.active_gaps().len(), + 1, + "Assumption failed: Gap not created" + ); + assert_eq!( + indicator.active_gaps()[0].direction(), + FairValueGapDirection::Bullish + ); + assert_eq!(indicator.closed_gaps().len(), 0, "No closed gaps at setup"); + assert_eq!( + indicator.expired_gaps().len(), + 0, + "No expired gaps at setup" + ); + + // 2. Partial Fill: Wick down to 12.5 (50% fill) on the very next bar + // This is 1 bar after creation, so it does NOT expire yet. + indicator.update(candle(4, "2026-05-26T10:04:00Z", 18., 18., 12.5, 17.)); + + // Verify state prior to expiration + assert_eq!(indicator.active_gaps().len(), 1); + assert_eq!(indicator.active_gaps()[0].state().touch_count(), 1); + + // 3. Expiration: Next bar runs away but triggers the 2-bar expiration limit. + indicator.update(candle(5, "2026-05-26T10:05:00Z", 20., 25., 20., 22.)); + + assert_eq!(indicator.active_gaps().len(), 0); + assert_eq!(indicator.closed_gaps().len(), 0); + assert_eq!(indicator.expired_gaps().len(), 1); + + // Verify that the ExpiredState successfully inherited the fill data from OpenState + let expired = indicator.expired_gaps()[0]; + assert_eq!( + expired.state().touch_count(), + 1, + "Should preserve the touch count before expiration" + ); + assert_f64_eq(expired.state().final_fill_percentage(), 0.5); + } + + #[test] + fn ttl_policy_filled_never_expires() { + // TtlPolicy::Filled ist der Standard. Die Lücke darf niemals von allein verfallen. + let mut indicator = StreamingFairValueGap::default() + .with_min_gap_size(1.0) + .with_ttl_policy(TtlPolicy::Filled); + + // 1. Bullish Gap erstellen: Top=15.0, Bottom=10.0 (Creation Index = 3) + indicator.update(candle(1, "2026-05-26T10:01:00Z", 10., 10., 5., 8.)); + indicator.update(candle(2, "2026-05-26T10:02:00Z", 10., 20., 8., 11.)); // C2 High up to 20 + indicator.update(candle(3, "2026-05-26T10:03:00Z", 15., 20., 15., 18.)); + + // Verify Setup Assumption + assert_eq!( + indicator.active_gaps().len(), + 1, + "Assumption failed: Gap not created" + ); + assert_eq!( + indicator.active_gaps()[0].direction(), + FairValueGapDirection::Bullish + ); + assert_eq!(indicator.active_gaps()[0].top().0, 15.0); + + // 2. Einen gewaltigen Sprung in die Zukunft simulieren (Index 1000, 10 Stunden später) + // Der Preis bleibt weit über der Lücke, sodass sie nicht gefüllt wird. + indicator.update(candle(1000, "2026-05-26T20:00:00Z", 20., 25., 20., 22.)); + + assert_eq!( + indicator.active_gaps().len(), + 1, + "Gap with TtlPolicy::Filled must remain active indefinitely" + ); + assert_eq!(indicator.expired_gaps().len(), 0); + } + + // ========================================== + // === 6. Edge Cases & Invariants === + // ========================================== + + #[test] + fn simultaneous_full_fill_and_expiration_results_in_closed_gap() { + // If a gap completely fills on the exact same candle that triggers its expiration, + // the fill wins. The price action happened *during* the candle. + let mut indicator = StreamingFairValueGap::default() + .with_min_gap_size(1.0) + .with_ttl_policy(TtlPolicy::Bars(2)); + + // 1. Create Bullish Gap: Top=15.0, Bottom=10.0 (Creation Index = 3) + indicator.update(candle(1, "2026-05-26T10:01:00Z", 10., 10., 5., 8.)); + indicator.update(candle(2, "2026-05-26T10:02:00Z", 10., 20., 8., 11.)); // C2 High up to 20 + indicator.update(candle(3, "2026-05-26T10:03:00Z", 15., 20., 15., 18.)); + + // Verify Setup Assumption + assert_eq!( + indicator.active_gaps().len(), + 1, + "Assumption failed: Gap not created" + ); + assert_eq!( + indicator.active_gaps()[0].direction(), + FairValueGapDirection::Bullish + ); + assert_eq!(indicator.active_gaps()[0].top().0, 15.0); + assert_eq!(indicator.active_gaps()[0].bottom().0, 10.0); + assert_eq!( + indicator.closed_gaps().len(), + 0, + "Assumption failed: closed_gaps should be empty" + ); + assert_eq!( + indicator.expired_gaps().len(), + 0, + "Assumption failed: expired_gaps should be empty" + ); + + // 2. The very next candle misses the gap completely. + indicator.update(candle(4, "2026-05-26T10:04:00Z", 20., 25., 20., 22.)); + + // 3. The expiry candle! Index 5 triggers the 2-bar expiration. + // AT THE EXACT SAME TIME, it has a violent wick down to 5.0, fully covering the gap. + indicator.update(candle(5, "2026-05-26T10:05:00Z", 20., 20., 5.0, 10.)); + + // Verify the invariants + assert_eq!(indicator.active_gaps().len(), 0); + assert_eq!( + indicator.expired_gaps().len(), + 0, + "Gap must NOT be expired. It was fully filled during the candle lifespan." + ); + assert_eq!( + indicator.closed_gaps().len(), + 1, + "Gap MUST be closed because the fill happened before the candle closed." + ); + + let closed = indicator.closed_gaps()[0]; + assert_f64_eq(closed.state().max_fill_percentage(), 1.0); + } + + #[test] + fn simultaneous_partial_fill_and_expiration_preserves_final_action() { + // If a gap partially fills on the exact same candle that triggers its expiration, + // it must expire, BUT it must successfully capture the partial fill from its final moments. + let mut indicator = StreamingFairValueGap::default() + .with_min_gap_size(1.0) + .with_ttl_policy(TtlPolicy::Bars(2)); + + // 1. Create Bullish Gap: Top=15.0, Bottom=10.0 (Creation Index = 3) + indicator.update(candle(1, "2026-05-26T10:01:00Z", 10., 10., 5., 8.)); + indicator.update(candle(2, "2026-05-26T10:02:00Z", 10., 20., 8., 11.)); // Raise C2 High to 20 to close C2-C4 distance + indicator.update(candle(3, "2026-05-26T10:03:00Z", 15., 20., 15., 18.)); + + // Verify Setup Assumption + assert_eq!( + indicator.active_gaps().len(), + 1, + "Assumption failed: Gap not created" + ); + assert_eq!( + indicator.active_gaps()[0].direction(), + FairValueGapDirection::Bullish + ); + assert_eq!(indicator.active_gaps()[0].top().0, 15.0); + assert_eq!(indicator.active_gaps()[0].bottom().0, 10.0); + assert_eq!( + indicator.closed_gaps().len(), + 0, + "Assumption failed: closed_gaps should be empty" + ); + assert_eq!( + indicator.expired_gaps().len(), + 0, + "Assumption failed: expired_gaps should be empty" + ); + + // 2. The very next candle misses the gap completely. + indicator.update(candle(4, "2026-05-26T10:04:00Z", 20., 25., 20., 22.)); + + // 3. The expiry candle! Index 5 triggers the 2-bar expiration. + // It drops to 12.5, filling exactly 50% of the gap right before time runs out. + indicator.update(candle(5, "2026-05-26T10:05:00Z", 20., 20., 12.5, 18.)); + + // Verify the invariants + assert_eq!(indicator.active_gaps().len(), 0); + assert_eq!(indicator.closed_gaps().len(), 0); + assert_eq!(indicator.expired_gaps().len(), 1); + + let expired = indicator.expired_gaps()[0]; + assert_eq!( + expired.state().touch_count(), + 1, + "Must register the touch from the expiring candle" + ); + assert_f64_eq(expired.state().final_fill_percentage(), 0.5); // Correctly captured the 50% fill right before death + } + + #[test] + fn ttl_policy_filled_migrates_to_closed_on_full_fill() { + // Rule: A gap with TtlPolicy::Filled can NEVER expire. + // When it eventually fills, it must explicitly migrate to Closed. + let mut indicator = StreamingFairValueGap::default() + .with_min_gap_size(1.0) + .with_ttl_policy(TtlPolicy::Filled); + + // 1. Create Bullish Gap: Top=15.0, Bottom=10.0 + indicator.update(candle(1, "2026-05-26T10:01:00Z", 10., 10., 5., 8.)); + indicator.update(candle(2, "2026-05-26T10:02:00Z", 10., 20., 8., 11.)); // C2 High up to 20 + indicator.update(candle(3, "2026-05-26T10:03:00Z", 15., 20., 15., 18.)); + + // Verify Setup Assumption + assert_eq!( + indicator.active_gaps().len(), + 1, + "Assumption failed: Gap not created" + ); + assert_eq!( + indicator.active_gaps()[0].direction(), + FairValueGapDirection::Bullish + ); + assert_eq!(indicator.active_gaps()[0].top().0, 15.0); + assert_eq!(indicator.active_gaps()[0].bottom().0, 10.0); + assert_eq!( + indicator.closed_gaps().len(), + 0, + "Assumption failed: closed_gaps should be empty" + ); + assert_eq!( + indicator.expired_gaps().len(), + 0, + "Assumption failed: expired_gaps should be empty" + ); + + // 2. Advance far into the future (Index 1000) - Gap remains open + indicator.update(candle(1000, "2026-05-26T20:00:00Z", 20., 25., 20., 22.)); + assert_eq!(indicator.active_gaps().len(), 1); + assert_eq!(indicator.closed_gaps().len(), 0); + assert_eq!(indicator.expired_gaps().len(), 0); + + // 3. Price finally crashes down and fills the gap + indicator.update(candle(1001, "2026-05-26T20:01:00Z", 20., 20., 8.0, 10.)); + + assert_eq!( + indicator.active_gaps().len(), + 0, + "Gap should be removed from active pool" + ); + assert_eq!( + indicator.expired_gaps().len(), + 0, + "Gap with TtlPolicy::Filled must NEVER enter Expired state" + ); + assert_eq!( + indicator.closed_gaps().len(), + 1, + "Gap MUST be correctly migrated to Closed state upon fill" + ); + + let closed = indicator.closed_gaps()[0]; + assert_f64_eq(closed.state().max_fill_percentage(), 1.0); + } + + #[test] + fn breakaway_gaps_do_not_touch_or_fill_fvg() { + // This tests the non-continuous pricing invariant. + // If the market completely teleports over the FVG zone without trading inside it, + // the gap must remain open and untouched. + let mut indicator = StreamingFairValueGap::default().with_min_gap_size(1.0); + + // ========================================== + // SCENARIO A: Bullish FVG Bypassed + // ========================================== + + // 1. Create Bullish Gap: Top=15.0, Bottom=10.0 + indicator.update(candle(1, "2026-05-26T10:01:00Z", 10., 10., 5., 8.)); + indicator.update(candle(2, "2026-05-26T10:02:00Z", 10., 12., 8., 11.)); + indicator.update(candle(3, "2026-05-26T10:03:00Z", 15., 20., 15., 18.)); + + assert_eq!(indicator.active_gaps().len(), 1, "Bullish gap created"); + + // 2. A massive gap DOWN completely below the FVG (High=8.0, Low=5.0) + indicator.update(candle(4, "2026-05-26T10:04:00Z", 8., 8., 5., 6.)); + + assert_eq!( + indicator.active_gaps().len(), + 1, + "Bullish gap must remain active because it was leaped over" + ); + let bullish_gap = indicator.active_gaps()[0]; + assert_eq!( + bullish_gap.state().touch_count(), + 0, + "The market never traded inside the Bullish gap" + ); + assert_f64_eq(bullish_gap.state().max_fill_percentage(), 0.0); + + indicator.reset(); + + // ========================================== + // SCENARIO B: Bearish FVG Bypassed + // ========================================== + + // 1. Create Bearish Gap: Top=20.0, Bottom=15.0 + indicator.update(candle(1, "2026-05-26T10:01:00Z", 20., 25., 20., 22.)); + indicator.update(candle(2, "2026-05-26T10:02:00Z", 18., 25., 15., 16.)); // C2 High up to 25 + indicator.update(candle(3, "2026-05-26T10:03:00Z", 12., 15., 10., 11.)); + + assert_eq!(indicator.active_gaps().len(), 1, "Bearish gap created"); + + // 2. A massive gap UP completely above the FVG (High=30.0, Low=25.0) + indicator.update(candle(4, "2026-05-26T10:04:00Z", 25., 30., 25., 28.)); + + assert_eq!( + indicator.active_gaps().len(), + 1, + "Bearish gap must remain active because it was leaped over" + ); + let bearish_gap = indicator.active_gaps()[0]; + assert_eq!( + bearish_gap.state().touch_count(), + 0, + "The market never traded inside the Bearish gap" + ); + assert_f64_eq(bearish_gap.state().max_fill_percentage(), 0.0); + } + + #[test] + fn gap_interaction_evaluates_overlap_and_fills_correctly() { + // --- 1. Bullish Gap Setup (Top=15.0, Bottom=10.0) --- + let bullish_gap = FairValueGap { + direction: FairValueGapDirection::Bullish, + creation_time: ts("2026-05-24T10:00:00Z"), + creation_index: 0, + top: Price(15.0), + bottom: Price(10.0), + state: OpenState::default(), + }; + + // A. Bullish Miss (Price stays entirely above the gap) + let miss_above = candle(1, "2026-05-24T10:01:00Z", 20., 25., 15.0, 22.).candle; + let interaction = bullish_gap.evaluate_interaction(&miss_above); + assert_eq!(interaction, GapInteraction::Miss); + assert!(!interaction.is_touch()); + + // B. Bullish Breakaway Miss (Price teleports completely below the gap) + let breakaway_below = candle(2, "2026-05-24T10:02:00Z", 5., 8., 2., 6.).candle; + let interaction = bullish_gap.evaluate_interaction(&breakaway_below); + assert_eq!(interaction, GapInteraction::Miss); + + // C. Bullish Touch (Wick enters the gap: low is 12.0) + let touch_candle = candle(3, "2026-05-24T10:03:00Z", 18., 18., 12., 15.).candle; + let interaction = bullish_gap.evaluate_interaction(&touch_candle); + assert_eq!(interaction, GapInteraction::Touch); + assert!(interaction.is_touch()); + assert!(!interaction.is_fill()); + + // D. Bullish Fill (Wick drops below the bottom of 10.0) + let fill_candle = candle(4, "2026-05-24T10:04:00Z", 18., 18., 9., 15.).candle; + let interaction = bullish_gap.evaluate_interaction(&fill_candle); + assert_eq!(interaction, GapInteraction::Fill); + assert!(interaction.is_touch()); // A fill MUST register as a touch + assert!(interaction.is_fill()); + + // --- 2. Bearish Gap Setup (Top=20.0, Bottom=15.0) --- + let bearish_gap = FairValueGap { + direction: FairValueGapDirection::Bearish, + creation_time: ts("2026-05-24T10:00:00Z"), + creation_index: 0, + top: Price(20.0), + bottom: Price(15.0), + state: OpenState::default(), + }; + + // A. Bearish Miss (Price stays entirely below the gap) + let miss_below = candle(5, "2026-05-24T10:01:00Z", 10., 15.0, 5., 12.).candle; + assert_eq!( + bearish_gap.evaluate_interaction(&miss_below), + GapInteraction::Miss + ); + + // B. Bearish Breakaway Miss (Price teleports completely above the gap) + let breakaway_above = candle(6, "2026-05-24T10:02:00Z", 25., 30., 22., 28.).candle; + assert_eq!( + bearish_gap.evaluate_interaction(&breakaway_above), + GapInteraction::Miss + ); + + // C. Bearish Touch (Wick enters the gap: high is 18.0) + let touch_bear = candle(7, "2026-05-24T10:03:00Z", 10., 18., 10., 12.).candle; + assert_eq!( + bearish_gap.evaluate_interaction(&touch_bear), + GapInteraction::Touch + ); + + // D. Bearish Fill (Wick spikes above the top of 20.0) + let fill_bear = candle(8, "2026-05-24T10:04:00Z", 10., 21., 10., 12.).candle; + let interaction = bearish_gap.evaluate_interaction(&fill_bear); + assert_eq!(interaction, GapInteraction::Fill); + assert!(interaction.is_touch()); + assert!(interaction.is_fill()); + } +} diff --git a/src/math/indicator.rs b/src/math/indicator.rs deleted file mode 100644 index 2e53c67..0000000 --- a/src/math/indicator.rs +++ /dev/null @@ -1,225 +0,0 @@ -use serde::{Deserialize, Serialize}; -use std::collections::VecDeque; - -/// A trait for incremental indicators. -/// Designed to be object-safe so agents can hold `Box`. -pub trait StreamingIndicator: std::fmt::Debug + Send + Sync { - /// Update the indicator with the latest scalar value (e.g., close price). - /// Returns `Some(value)` if the indicator is warm (enough data seen), otherwise `None`. - fn update(&mut self, value: f64) -> Option; - - /// Reset the internal state to clear history (e.g., for a new trading session). - fn reset(&mut self); -} - -// ================================================================================================ -// SMA: Simple Moving Average -// ================================================================================================ - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct StreamingSma { - window_size: usize, - buffer: VecDeque, - sum: f64, -} - -impl StreamingSma { - pub fn new(window_size: u16) -> Self { - let size = window_size as usize; - Self { - window_size: size, - buffer: VecDeque::with_capacity(size), - sum: 0.0, - } - } -} - -impl StreamingIndicator for StreamingSma { - fn update(&mut self, value: f64) -> Option { - // 1. Add new value to window - self.buffer.push_back(value); - self.sum += value; - - // 2. Remove old value if we exceeded window size - if self.buffer.len() > self.window_size { - // Safety: We just pushed, so unwrap is safe, but idiomatic rust prefers matching - if let Some(removed) = self.buffer.pop_front() { - self.sum -= removed; - } - } - - // 3. Check readiness - if self.buffer.len() >= self.window_size { - Some(self.sum / self.buffer.len() as f64) - } else { - None - } - } - - fn reset(&mut self) { - self.buffer.clear(); - self.sum = 0.0; - } -} - -// ================================================================================================ -// SHARED: Exponential Weighted Mean (Base Logic) -// ================================================================================================ - -/// Internal helper for EMA-like calculations (Standard EMA and Wilder's Smoothing). -/// Implements the recursive formula: $y_t = \alpha * x_t + (1 - \alpha) * y_{t-1}$. -#[derive(Debug, Clone, Serialize, Deserialize)] -struct StreamingEwm { - alpha: f64, - current_mean: f64, - initialized: bool, - window_size: usize, - count: usize, -} - -impl StreamingEwm { - fn new(alpha: f64, window_size: usize) -> Self { - Self { - alpha, - current_mean: 0.0, - initialized: false, - window_size, - count: 0, - } - } - - fn update(&mut self, value: f64) -> Option { - if !self.initialized { - // Per Polars/Pandas `adjust=false`: initialize with the first value - self.current_mean = value; - self.initialized = true; - self.count = 1; - } else { - // Recursive update: Mean = Alpha * Val + (1 - Alpha) * Prev - self.current_mean = self.alpha * value + (1.0 - self.alpha) * self.current_mean; - self.count += 1; - } - - if self.count >= self.window_size { - Some(self.current_mean) - } else { - None - } - } - - fn reset(&mut self) { - self.initialized = false; - self.current_mean = 0.0; - self.count = 0; - } -} - -// ================================================================================================ -// EMA: Exponential Moving Average -// ================================================================================================ - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct StreamingEma { - inner: StreamingEwm, -} - -impl StreamingEma { - pub fn new(window_size: u16) -> Self { - // Standard EMA Alpha = 2 / (Span + 1) - let alpha = 2.0 / (window_size as f64 + 1.0); - Self { - inner: StreamingEwm::new(alpha, window_size as usize), - } - } -} - -impl StreamingIndicator for StreamingEma { - fn update(&mut self, value: f64) -> Option { - self.inner.update(value) - } - - fn reset(&mut self) { - self.inner.reset(); - } -} - -// ================================================================================================ -// RSI: Relative Strength Index -// ================================================================================================ - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct StreamingRsi { - prev_price: Option, - avg_gain: StreamingEwm, - avg_loss: StreamingEwm, -} - -impl StreamingRsi { - pub fn new(window_size: u16) -> Self { - // Wilder's Smoothing Alpha = 1 / N - // This differs from standard EMA! - let alpha = 1.0 / (window_size as f64); - let win = window_size as usize; - - Self { - prev_price: None, - avg_gain: StreamingEwm::new(alpha, win), - avg_loss: StreamingEwm::new(alpha, win), - } - } -} - -impl StreamingIndicator for StreamingRsi { - fn update(&mut self, value: f64) -> Option { - let prev = match self.prev_price { - Some(p) => p, - None => { - // First trade: just store the price, we cannot calculate delta yet - self.prev_price = Some(value); - return None; - } - }; - - // 1. Calculate Delta - let delta = value - prev; - self.prev_price = Some(value); - - // 2. Separate Gain/Loss - let (gain, loss) = if delta > 0.0 { - (delta, 0.0) - } else { - (0.0, delta.abs()) - }; - - // 3. Update Wilder's Smoothers - // We capture the Option from both. If both are Some, we have enough data. - let g_val = self.avg_gain.update(gain); - let l_val = self.avg_loss.update(loss); - - match (g_val, l_val) { - (Some(avg_gain), Some(avg_loss)) => { - // 4. Calculate RSI - // Prevent division by zero if avg_loss is 0 (Monotonic Up-trend) - if avg_loss == 0.0 { - if avg_gain == 0.0 { - // Flat line - Some(50.0) - } else { - // Pure gain - Some(100.0) - } - } else { - let rs = avg_gain / avg_loss; - Some(100.0 - (100.0 / (1.0 + rs))) - } - } - _ => None, - } - } - - fn reset(&mut self) { - self.prev_price = None; - self.avg_gain.reset(); - self.avg_loss.reset(); - } -} diff --git a/src/math/moving_averages.rs b/src/math/moving_averages.rs new file mode 100644 index 0000000..c4cc2a9 --- /dev/null +++ b/src/math/moving_averages.rs @@ -0,0 +1,141 @@ +use std::collections::VecDeque; + +use serde::{Deserialize, Serialize}; + +use crate::math::StreamingIndicator; + +// ================================================================================================ +// SHARED: Exponential Weighted Mean (Base Logic) +// ================================================================================================ + +/// Internal helper for EMA-like calculations (Standard EMA and Wilder's Smoothing). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct StreamingEwm { + alpha: f64, + current_mean: Option, + window_size: usize, + count: usize, +} + +impl StreamingEwm { + pub(crate) fn new(alpha: f64, window_size: usize) -> Self { + Self { + alpha, + current_mean: None, + window_size, + count: 0, + } + } +} + +impl StreamingIndicator for StreamingEwm { + type Input = f64; + type Output<'a> = Option; + + fn update(&mut self, value: Self::Input) -> Self::Output<'_> { + self.count += 1; + + match self.current_mean { + None => { + // First trade: initialize the mean + self.current_mean = Some(value); + } + Some(prev) => { + // Recursive update + self.current_mean = Some(self.alpha * value + (1.0 - self.alpha) * prev); + } + } + + if self.count >= self.window_size { + self.current_mean + } else { + None + } + } + + fn reset(&mut self) { + self.current_mean = None; + self.count = 0; + } +} + +// ================================================================================================ +// SMA: Simple Moving Average +// ================================================================================================ + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StreamingSma { + window_size: usize, + buffer: VecDeque, + sum: f64, +} + +impl StreamingSma { + pub fn new(window_size: u16) -> Self { + let size = window_size as usize; + Self { + window_size: size, + buffer: VecDeque::with_capacity(size), + sum: 0.0, + } + } +} + +impl StreamingIndicator for StreamingSma { + type Input = f64; + type Output<'a> = Option; + + fn update(&mut self, value: Self::Input) -> Self::Output<'_> { + self.buffer.push_back(value); + self.sum += value; + + if self.buffer.len() > self.window_size + && let Some(removed) = self.buffer.pop_front() + { + self.sum -= removed; + } + + if self.buffer.len() >= self.window_size { + Some(self.sum / self.buffer.len() as f64) + } else { + None + } + } + + fn reset(&mut self) { + self.buffer.clear(); + self.sum = 0.0; + } +} + +// ================================================================================================ +// EMA: Exponential Moving Average +// ================================================================================================ + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StreamingEma { + inner: StreamingEwm, +} + +impl StreamingEma { + pub fn new(window_size: u16) -> Self { + // Standard EMA Alpha = 2 / (Span + 1) + let alpha = 2.0 / (window_size as f64 + 1.0); + Self { + inner: StreamingEwm::new(alpha, window_size as usize), + } + } +} + +impl StreamingIndicator for StreamingEma { + type Input = f64; + type Output<'a> = Option; + + fn update(&mut self, value: Self::Input) -> Self::Output<'_> { + self.inner.update(value) + } + + fn reset(&mut self) { + self.inner.reset(); + } +} diff --git a/src/math/oscillators.rs b/src/math/oscillators.rs new file mode 100644 index 0000000..ff9073b --- /dev/null +++ b/src/math/oscillators.rs @@ -0,0 +1,77 @@ +use serde::{Deserialize, Serialize}; + +use crate::math::{StreamingIndicator, moving_averages::StreamingEwm}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StreamingRsi { + prev_price: Option, + avg_gain: StreamingEwm, + avg_loss: StreamingEwm, +} + +impl StreamingRsi { + pub fn new(window_size: u16) -> Self { + // Wilder's Smoothing Alpha = 1 / N + // This differs from standard EMA! + let alpha = 1.0 / (window_size as f64); + let win = window_size as usize; + + Self { + prev_price: None, + avg_gain: StreamingEwm::new(alpha, win), + avg_loss: StreamingEwm::new(alpha, win), + } + } +} + +impl StreamingIndicator for StreamingRsi { + type Input = f64; + type Output<'a> = Option; + + fn update(&mut self, value: Self::Input) -> Self::Output<'_> { + let prev = match self.prev_price { + Some(p) => p, + None => { + self.prev_price = Some(value); + return None; + } + }; + + let delta = value - prev; + self.prev_price = Some(value); + + let (gain, loss) = if delta > 0.0 { + (delta, 0.0) + } else { + (0.0, delta.abs()) + }; + + let g_val = self.avg_gain.update(gain); + let l_val = self.avg_loss.update(loss); + + match (g_val, l_val) { + (Some(avg_gain), Some(avg_loss)) => { + // Prevent division by zero if avg_loss is 0 (Monotonic Up-trend) + if avg_loss == 0.0 { + if avg_gain == 0.0 { + // Flat line + Some(50.0) + } else { + // Pure gain + Some(100.0) + } + } else { + let rs = avg_gain / avg_loss; + Some(100.0 - (100.0 / (1.0 + rs))) + } + } + _ => None, + } + } + + fn reset(&mut self) { + self.prev_price = None; + self.avg_gain.reset(); + self.avg_loss.reset(); + } +} diff --git a/src/math/swing.rs b/src/math/swing.rs new file mode 100644 index 0000000..75de2f2 --- /dev/null +++ b/src/math/swing.rs @@ -0,0 +1,1982 @@ +use std::{cmp::Ordering, collections::VecDeque}; + +use crate::{ + data::{ + domain::{CandleDirection, Price, PriceSource}, + event::{IndexedOhlcv, MarketEvent, Ohlcv}, + }, + math::StreamingIndicator, +}; +use chrono::{DateTime, Utc}; + +/// Represents the geometric type of a single pivot point. +/// +/// [`PivotType`] is required for the [`AlternationMode`] filter. +/// To enforce alternation (High -> Low -> High -> Low), the algorithm +/// needs to know the type of the current pivot to check if it +/// violates the sequence (e.g., detecting two `High`s in a row). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PivotType { + High, + Low, +} + +impl From for PivotType { + fn from(sequence: MarketStructureSequence) -> Self { + use MarketStructureSequence::*; + match sequence { + LowerHigh | HigherHigh | EqualHigh | UnclassifiedHigh => PivotType::High, + HigherLow | LowerLow | EqualLow | UnclassifiedLow => PivotType::Low, + } + } +} + +impl PivotType { + /// Extracts the relevant price for peak/trough finding given the + /// configured [`PriceSource`] and the candle's own direction. + fn extract_price(self, candle: Ohlcv, source: PriceSource) -> Price { + match (source, self, candle.direction()) { + (PriceSource::HighLow, PivotType::High, _) => candle.high, + (PriceSource::HighLow, PivotType::Low, _) => candle.low, + + (PriceSource::OpenClose, PivotType::High, CandleDirection::Bullish) => candle.close, + (PriceSource::OpenClose, PivotType::High, CandleDirection::Bearish) => candle.open, + (PriceSource::OpenClose, PivotType::High, CandleDirection::Doji) => candle.close, + + (PriceSource::OpenClose, PivotType::Low, CandleDirection::Bullish) => candle.open, + (PriceSource::OpenClose, PivotType::Low, CandleDirection::Bearish) => candle.close, + (PriceSource::OpenClose, PivotType::Low, CandleDirection::Doji) => candle.close, + } + } +} + +/// Represents the relative sequence that defines the market's overall direction. +/// +/// While [`PivotType`] tells us the basic shape (peak or trough), +/// [`MarketStructureSequence`] provides the **trend context** by comparing it +/// to historical pivots. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MarketStructureSequence { + HigherHigh, + LowerHigh, + EqualHigh, + HigherLow, + LowerLow, + EqualLow, + UnclassifiedHigh, + UnclassifiedLow, +} + +impl MarketStructureSequence { + pub fn as_pivot_type(&self) -> PivotType { + (*self).into() + } +} + +/// Defines how the indicator handles consecutive pivots of the same type +/// (e.g., detecting two `PivotType::High`s in a row without a `PivotType::Low` in between). +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum AlternationMode { + /// Forces alternating `High -> Low -> High -> Low` sequences (ZigZag behavior). + /// + /// If the algorithm detects a new `PivotType::High`, but the last confirmed pivot + /// was also a `PivotType::High`, it evaluates both and only keeps the one with + /// the higher price. The lesser pivot is discarded as market noise. + #[default] + Alternating, + + /// No alternation filtering. Every detected [`PivotType`] is kept. + /// + /// If the algorithm detects two `PivotType::High`s in a row, the second `PivotType::High` + /// is simply classified against the first `PivotType::High` (resulting in a HH or LH), + /// regardless of the missing `PivotType::Low`. + Consecutive, +} + +/// Represents structural breakthrough events detected in the price series. +/// +/// A `MarketStructureEvent` is emitted alongside every new confirmed pivot. +/// It answers the question: _"Did this new pivot break a meaningful prior level, +/// and if so, does it continue or contradict the prevailing trend?"_ +/// +/// # Background +/// +/// Most price action consists of small zig-zags within a range. Occasionally, +/// a new swing high breaks above the previous swing high (or a new swing low +/// breaks below the previous swing low). Those breakouts are the structural +/// events this enum classifies. +/// +/// The classification depends on two things: +/// 1. Whether the new pivot exceeds the most recent same-side pivot +/// (e.g. a new high above the last confirmed high). +/// 2. What the prevailing trend looked like just before the break, inferred +/// from the most recent opposite-side pivot. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MarketStructureEvent { + /// The new pivot extends the prevailing trend, or establishes the first directional trend. + /// + /// A _Break of Structure_ (BOS) is the classification for any new pivot + /// that exceeds its confirmed same-side predecessor. It fires when: + /// + /// - **Continuation**: the market was already trending in the same direction + /// (e.g., a new Higher High after a recent Higher Low). + /// - **Initiation**: the pivot establishes the first directional trend by breaking + /// the initial anchor (e.g., a Higher High occurs, but the prior opposite-side + /// pivot is still `UnclassifiedLow` or `EqualLow`). + BreakOfStructure, + + /// The new pivot reverses a previously confirmed trend. + /// + /// Also known as a _Change of Character_ (CHoCH). This is the strict case: + /// it fires only when the prior opposite-side pivot was itself a trend-confirming + /// break, so there is concrete evidence of a trend to reverse. + /// + /// - **Bullish Shift**: a new swing high prints above the most recent swing high, + /// and the most recent low was a `LowerLow` (downtrend -> potential uptrend). + /// - **Bearish Shift**: a new swing low prints below the most recent swing low, + /// and the most recent high was a `HigherHigh` (uptrend -> potential downtrend). + MarketStructureShift, + + /// The new pivot did not break structure or shift the trend. + /// + /// Returned when the candidate forms a Lower High, Equal High, Higher Low, or + /// Equal Low. This is also emitted for the **very first detected pivot** (unclassified), + /// as there is no prior structure on that side to compare against. The pivot is + /// still recorded and added to the history. It just doesn't represent a breakout. + NoChange, +} + +/// Tiebreaker policy when two adjacent bars share the same extreme price. +/// +/// Affects two situations: +/// 1. **Plateaus inside the lookback window**: when a sequence of adjacent bars shares the +/// exact same extreme price, they will all eventually pass through the center candidate +/// position. This rule applies strict/inclusive inequalities to ensure only one of them +/// (the earliest or latest) actually triggers a valid pivot. +/// 2. **Conflicts under [`AlternationMode::Alternating`]**: when a new pivot of the same +/// type as the active one is detected and their prices are equal, this rule decides +/// which one wins. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum ExtremeTiebreaker { + /// On a tie, the later bar wins. Pivots track the most recent occurrence + /// of a repeated extreme. + #[default] + Latest, + /// On a tie, the earlier bar wins. Pivots anchor to the first occurrence + /// of a repeated extreme. + Earliest, +} + +/// A confirmed swing point in the price series. +/// +/// Carries the originating candle, the price that triggered the pivot (which may be +/// `high`/`low` or `open`/`close` depending on the configured [`PriceSource`]), +/// and the trend-relative classification ([`MarketStructureSequence`]). +/// +/// Note: The geometric type ([`PivotType::High`] or [`PivotType::Low`]) is intrinsically +/// implied by the `trend` and can be accessed via [`Self::pivot_type`]. +#[derive(Debug, Clone, Copy)] +pub struct PivotPoint { + pub indexed_candle: IndexedOhlcv, + pub price: Price, + pub price_source: PriceSource, + pub trend: MarketStructureSequence, +} +impl MarketEvent for PivotPoint { + fn point_in_time(&self) -> DateTime { + self.indexed_candle.point_in_time() + } +} + +impl PivotPoint { + /// Returns the geometric type of the pivot, derived directly from its trend sequence. + pub fn pivot_type(&self) -> PivotType { + self.trend.into() + } + + /// Generates a linear interpolation function based on bar indices. + /// + /// Returns a zero-allocation closure that takes a target bar index (usize) + /// and returns the interpolated/extrapolated price at that index. + pub fn price_line_by_index(&self, target: &PivotPoint) -> impl Fn(usize) -> Price { + let p0 = self.price.0; + let p1 = target.price.0; + let x0 = self.indexed_candle.index as f64; + let x1 = target.indexed_candle.index as f64; + + let dx = x1 - x0; + let m = if dx == 0.0 { 0.0 } else { (p1 - p0) / dx }; + + move |x: usize| -> Price { + let current_dx = (x as f64) - x0; + Price(p0 + m * current_dx) + } + } + + /// Generates a linear interpolation function based on exact point-in-time timestamps. + /// + /// Returns a zero-allocation closure that takes a target point in time + /// (`DateTime`) and returns the interpolated/extrapolated price. + /// Uses chrono::Duration to safely compute the time deltas in milliseconds. + pub fn price_line_by_point_in_time( + &self, + target: &PivotPoint, + ) -> impl Fn(DateTime) -> Price { + let p0 = self.price.0; + let p1 = target.price.0; + + let t0 = self.point_in_time(); + let t1 = target.point_in_time(); + + let dx = (t1 - t0).num_milliseconds() as f64; + let m = if dx == 0.0 { 0.0 } else { (p1 - p0) / dx }; + + move |t: DateTime| -> Price { + let current_dx = (t - t0).num_milliseconds() as f64; + Price(p0 + m * current_dx) + } + } +} + +/// Lookback and lookforward window for swing detection. +/// +/// A bar is treated as a candidate pivot only if it is the most extreme bar within +/// a window of `left_bars` preceding bars and `right_bars` following bars. Larger +/// values produce fewer, more significant pivots and smaller values are more responsive +/// but noisier. Default is a symmetric window with `5` bars on each side. +/// +/// Note that the indicator must buffer `left_bars + right_bars + 1` candles before +/// it can emit its first result, since the candidate sits in the middle of the window. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ZigZagPeriod { + pub left_bars: u16, + pub right_bars: u16, +} + +impl Default for ZigZagPeriod { + fn default() -> Self { + Self::symmetric(5) + } +} + +impl ZigZagPeriod { + /// Creates a `ZigZagPeriod` with equal lookback and lookforward windows. + /// + /// Sets both `left_bars` and `right_bars` to `bars`, producing a symmetric + /// window where the candidate pivot sits exactly in the middle. + /// + /// The indicator will buffer `2 * bars + 1` candles before emitting its first + /// result. + pub fn symmetric(bars: u16) -> Self { + Self { + left_bars: bars, + right_bars: bars, + } + } + + fn buffer_size(&self) -> usize { + (self.left_bars + self.right_bars + 1) as usize + } + + fn mid_index(&self) -> usize { + self.left_bars as usize + } +} + +/// A streaming Higher-High / Lower-Low indicator over OHLCV bars. +/// +/// Consumes a stream of [`IndexedOhlcv`] bars and emits a confirmed [`PivotPoint`] +/// together with a [`MarketStructureEvent`] whenever a new swing point is identified +/// and either continues or breaks the prevailing trend. +/// +/// # How it works +/// +/// 1. Each incoming bar is appended to an internal rolling window of size +/// `left_bars + right_bars + 1` (see [`ZigZagPeriod`]). +/// 2. The bar at the center of that window is the _candidate_. It is considered a +/// pivot if it is the most extreme bar in the window (highest for a swing high, +/// lowest for a swing low). Ties are resolved per [`ExtremeTiebreaker`]. +/// 3. When a candidate qualifies, it is classified relative to the most recent +/// confirmed pivot of the same kind (Higher High / Lower High / Equal High, etc.) +/// and a [`MarketStructureEvent`] is emitted. +/// 4. The [`AlternationMode`] controls what happens when consecutive same-type +/// pivots appear without an intervening opposite-type pivot. +/// +/// Because the candidate sits at the middle of the window, every emitted pivot is +/// confirmed with a lag of `right_bars` bars. +/// +/// # Output stream +/// +/// [`StreamingIndicator::update`] returns `Some((event, pivot))` whenever a pivot +/// is confirmed, and `None` while the window is still filling or no pivot is detected. +/// The full chronological history is available via [`Self::history`]. +#[derive(Debug, Clone)] +pub struct StreamingHhll { + zig_zag_period: ZigZagPeriod, + price_source: PriceSource, + tiebreaker: ExtremeTiebreaker, + alternation_mode: AlternationMode, + + // === Internal State === + buffer: VecDeque, + active_pivot: Option, + anchor_high: Option, + anchor_low: Option, + history: Vec, +} + +impl Default for StreamingHhll { + fn default() -> Self { + let zig_zag_period = ZigZagPeriod::default(); + Self { + zig_zag_period, + price_source: PriceSource::default(), + tiebreaker: ExtremeTiebreaker::default(), + alternation_mode: AlternationMode::default(), + buffer: VecDeque::with_capacity(zig_zag_period.buffer_size()), + active_pivot: None, + anchor_high: None, + anchor_low: None, + history: Vec::new(), + } + } +} + +impl StreamingHhll { + pub fn with_zig_zag_period(self, zig_zag_period: ZigZagPeriod) -> Self { + Self { + zig_zag_period, + buffer: VecDeque::with_capacity(zig_zag_period.buffer_size()), + ..self + } + } + + pub fn with_price_source(self, price_source: PriceSource) -> Self { + Self { + price_source, + ..self + } + } + + pub fn with_tiebreaker(self, tiebreaker: ExtremeTiebreaker) -> Self { + Self { tiebreaker, ..self } + } + + pub fn with_alternation_mode(self, alternation_mode: AlternationMode) -> Self { + Self { + alternation_mode, + ..self + } + } + + /// The active pivot currently tracking the trailing edge of the market structure. + /// + /// If [`AlternationMode::Alternating`] is active, this pivot remains mutable. + /// If a consecutive vertex of the same [`PivotType`] appears, this [`PivotPoint`] + /// may be overwritten or extended based on the [`ExtremeTiebreaker`]. + pub fn active_pivot(&self) -> Option { + self.active_pivot + } + + /// The historical, safely locked-in `PivotType::High` used as a baseline for relative classification. + /// + /// When a new High vertex is detected, it is compared against this anchor to determine if it + /// is a `HigherHigh`, `LowerHigh`, or `EqualHigh`. + pub fn anchor_high(&self) -> Option { + self.anchor_high + } + + /// The historical, safely locked-in `PivotType::Low` used as a baseline for relative classification. + /// + /// When a new Low vertex is detected, it is compared against this anchor to determine if it + /// is a `HigherLow`, `LowerLow`, or `EqualLow`. + pub fn anchor_low(&self) -> Option { + self.anchor_low + } + + /// Chronological history of all safely locked-in pivots. + /// + /// This vector guarantees perfect time-order. To iterate from latest to earliest, + /// you simply call `self.history.iter().rev()`. + pub fn history(&self) -> &[PivotPoint] { + &self.history + } +} + +impl StreamingHhll { + fn candidate(&self) -> IndexedOhlcv { + let mid_idx = self.zig_zag_period.mid_index(); + self.buffer[mid_idx] + } + + /// Yields the left side of the rolling window (before the candidate). + fn left_partition(&self) -> impl Iterator + '_ { + self.buffer + .iter() + .take(self.zig_zag_period.left_bars as usize) + .copied() + } + + /// Yields the right side of the rolling window (after the candidate). + fn right_partition(&self) -> impl Iterator + '_ { + self.buffer + .iter() + .rev() + .take(self.zig_zag_period.right_bars as usize) + .copied() + } + + /// Checks if the candidate price is a valid extremum against its neighbors. + fn check_extremum(&self, pivot_type: PivotType) -> bool { + let candidate = self.candidate(); + let candidate_price = pivot_type.extract_price(candidate.candle, self.price_source); + + // Determine which side of the window requires a STRICT inequality based on the tiebreaker. + let (strict_left, strict_right) = match self.tiebreaker { + ExtremeTiebreaker::Earliest => (true, false), + ExtremeTiebreaker::Latest => (false, true), + }; + + let is_valid = |neighbor: IndexedOhlcv, strict: bool| -> bool { + let neighbor_price = pivot_type.extract_price(neighbor.candle, self.price_source); + match (pivot_type, strict) { + (PivotType::High, true) => candidate_price > neighbor_price, + (PivotType::High, false) => candidate_price >= neighbor_price, + (PivotType::Low, true) => candidate_price < neighbor_price, + (PivotType::Low, false) => candidate_price <= neighbor_price, + } + }; + + self.left_partition().all(|c| is_valid(c, strict_left)) + && self.right_partition().all(|c| is_valid(c, strict_right)) + } + + #[tracing::instrument(skip(self), fields(ts = %self.candidate().candle.close_timestamp))] + fn process_high(&mut self) -> Option<(MarketStructureEvent, PivotPoint)> { + let candidate = self.candidate(); + let current_high_price = PivotType::High.extract_price(candidate.candle, self.price_source); + + // 1. Evaluate the candidate against the active pivot + let resolution = match self.active_pivot { + Some(active) + if self.alternation_mode == AlternationMode::Alternating + && active.pivot_type() == PivotType::High => + { + // Alternation Conflict: We have two Highs in a row. + let overwrite = match self.tiebreaker { + ExtremeTiebreaker::Earliest => current_high_price > active.price, + ExtremeTiebreaker::Latest => current_high_price >= active.price, + }; + + if overwrite { + CandidateResolution::ReplaceActive + } else { + CandidateResolution::Discard + } + } + Some(_) | None => CandidateResolution::ConfirmActive, + }; + + // 2. Execute the state transition + match resolution { + CandidateResolution::Discard => return None, + CandidateResolution::ReplaceActive => { + // Do nothing to the history. We will simply overwrite `self.active_pivot` + // with the new candidate at the end of the method. + } + CandidateResolution::ConfirmActive => { + // The active pivot is safe. Lock it into the anchors and history. + if let Some(active) = self.active_pivot { + match active.pivot_type() { + PivotType::High => self.anchor_high = Some(active), + PivotType::Low => self.anchor_low = Some(active), + } + self.history.push(active); + } + } + } + + // 3. Classify the new pivot + let (trend, event) = match self.anchor_high { + Some(anchor) => match current_high_price.partial_cmp(&anchor.price) { + Some(Ordering::Greater) => { + use MarketStructureSequence::*; + + let market_structure_event = match self.anchor_low.map(|l| l.trend) { + Some(LowerLow) => MarketStructureEvent::MarketStructureShift, + + Some(HigherLow | EqualLow | UnclassifiedLow) | None => { + MarketStructureEvent::BreakOfStructure + } + + Some( + invalid_trend @ (HigherHigh | LowerHigh | EqualHigh | UnclassifiedHigh), + ) => { + tracing::error!( + reason = "corrupted_state", + anchor_low_trend = ?invalid_trend, + "anchor_low contains a High pivot sequence. Defaulting to BreakOfStructure." + ); + MarketStructureEvent::BreakOfStructure + } + }; + + (HigherHigh, market_structure_event) + } + Some(Ordering::Less) => ( + MarketStructureSequence::LowerHigh, + MarketStructureEvent::NoChange, + ), + Some(Ordering::Equal) => ( + MarketStructureSequence::EqualHigh, + MarketStructureEvent::NoChange, + ), + None => { + tracing::warn!( + reason = "nan_detected", + candidate_price = ?current_high_price, + anchor_price = ?anchor.price, + "Invalid float (NaN) detected. Discarding pivot to prevent state poisoning." + ); + return None; + } + }, + None => ( + MarketStructureSequence::UnclassifiedHigh, + MarketStructureEvent::NoChange, + ), + }; + + // 4. Construct and emit the new pivot + let new_pivot = PivotPoint { + indexed_candle: candidate, + price: current_high_price, + price_source: self.price_source, + trend, + }; + + self.active_pivot = Some(new_pivot); + + if event != MarketStructureEvent::NoChange { + tracing::debug!( + event = ?event, + trend = ?trend, + price = ?current_high_price, + "Market Structure Extracted" + ); + } + + Some((event, new_pivot)) + } + + #[tracing::instrument(skip(self), fields(ts = %self.candidate().candle.close_timestamp))] + fn process_low(&mut self) -> Option<(MarketStructureEvent, PivotPoint)> { + let candidate = self.candidate(); + let current_low_price = PivotType::Low.extract_price(candidate.candle, self.price_source); + + // 1. Evaluate the candidate against the active pivot + let resolution = match self.active_pivot { + Some(active) + if self.alternation_mode == AlternationMode::Alternating + && active.pivot_type() == PivotType::Low => + { + // Alternation Conflict: We have two Lows in a row. + let overwrite = match self.tiebreaker { + ExtremeTiebreaker::Earliest => current_low_price < active.price, + ExtremeTiebreaker::Latest => current_low_price <= active.price, + }; + + if overwrite { + CandidateResolution::ReplaceActive + } else { + CandidateResolution::Discard + } + } + Some(_) | None => CandidateResolution::ConfirmActive, + }; + + // 2. Execute the state transition + match resolution { + CandidateResolution::Discard => return None, + CandidateResolution::ReplaceActive => { + // Do nothing to the history. We will simply overwrite `self.active_pivot` + // with the new candidate at the end of the method. + } + CandidateResolution::ConfirmActive => { + // The active pivot is safe. Lock it into the anchors and history. + if let Some(active) = self.active_pivot { + match active.pivot_type() { + PivotType::High => self.anchor_high = Some(active), + PivotType::Low => self.anchor_low = Some(active), + } + self.history.push(active); + } + } + } + + // 3. Classify the new pivot + let (trend, event) = match self.anchor_low { + Some(anchor) => match current_low_price.partial_cmp(&anchor.price) { + Some(Ordering::Less) => { + use MarketStructureSequence::*; + + // Explicitly match all variants to prevent black holes + let market_structure_event = match self.anchor_high.map(|h| h.trend) { + Some(HigherHigh) => MarketStructureEvent::MarketStructureShift, + + Some(LowerHigh | EqualHigh | UnclassifiedHigh) | None => { + MarketStructureEvent::BreakOfStructure + } + + // Explicitly trap invalid states that the old `_` was swallowing + Some( + invalid_trend @ (HigherLow | LowerLow | EqualLow | UnclassifiedLow), + ) => { + tracing::error!( + reason = "corrupted_state", + anchor_high_trend = ?invalid_trend, + "anchor_high contains a Low pivot sequence. Defaulting to BreakOfStructure." + ); + MarketStructureEvent::BreakOfStructure + } + }; + + (LowerLow, market_structure_event) + } + Some(Ordering::Greater) => ( + MarketStructureSequence::HigherLow, + MarketStructureEvent::NoChange, + ), + Some(Ordering::Equal) => ( + MarketStructureSequence::EqualLow, + MarketStructureEvent::NoChange, + ), + None => { + tracing::warn!( + reason = "nan_detected", + candidate_price = ?current_low_price, + anchor_price = ?anchor.price, + "Invalid float (NaN) detected. Discarding pivot to prevent state poisoning." + ); + return None; + } + }, + None => ( + MarketStructureSequence::UnclassifiedLow, + MarketStructureEvent::NoChange, + ), + }; + + // 4. Construct and emit the new pivot + let new_pivot = PivotPoint { + indexed_candle: candidate, + price: current_low_price, + price_source: self.price_source, + trend, + }; + + self.active_pivot = Some(new_pivot); + + if event != MarketStructureEvent::NoChange { + tracing::debug!( + event = ?event, + trend = ?trend, + price = ?current_low_price, + "Market Structure Extracted" + ); + } + + Some((event, new_pivot)) + } +} + +impl StreamingIndicator for StreamingHhll { + type Input = IndexedOhlcv; + type Output<'a> = Option<(MarketStructureEvent, PivotPoint)>; + + fn update(&mut self, indexed_candle: Self::Input) -> Self::Output<'_> { + let window_size = self.zig_zag_period.buffer_size(); + self.buffer.push_back(indexed_candle); + + if self.buffer.len() < window_size { + return None; + } + if self.buffer.len() > window_size { + self.buffer.pop_front(); + } + + let is_swing_high = self.check_extremum(PivotType::High); + let is_swing_low = self.check_extremum(PivotType::Low); + + match (is_swing_high, is_swing_low) { + (true, true) => { + // The candidate is BOTH a Swing High and a Swing Low (Mega Bar). + let candidate = self.candidate(); + match candidate.candle.direction() { + CandleDirection::Bullish => self.process_high(), + CandleDirection::Bearish => self.process_low(), + CandleDirection::Doji => { + // Assumption: Extend the current market structure + match self.active_pivot.map(|p| p.pivot_type()) { + Some(PivotType::Low) => self.process_low(), + Some(PivotType::High) => self.process_high(), + None => None, // A doji candle and no history. + } + } + } + } + (true, false) => self.process_high(), + (false, true) => self.process_low(), + (false, false) => None, + } + } + + fn reset(&mut self) { + self.buffer.clear(); + self.active_pivot = None; + self.anchor_high = None; + self.anchor_low = None; + self.history.clear(); + } +} + +// ================================================================================================ +// Helper Enum +// ================================================================================================ + +/// Represents how a new candidate resolves against the currently active pivot. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CandidateResolution { + /// The candidate violates alternation and is weaker. Discard it. + Discard, + /// The candidate violates alternation but is stronger. Overwrite the active pivot. + ReplaceActive, + /// The candidate respects alternation. Lock the active pivot into history. + ConfirmActive, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::data::domain::Quantity; + + // ========================================== + // === 1. Mocks & Helpers === + // ========================================== + + /// Parse RFC3339 timestamp string to DateTime. + fn ts(s: &str) -> DateTime { + DateTime::parse_from_rfc3339(s).unwrap().with_timezone(&Utc) + } + + /// A rapid builder for Indexed OHLCV candles to keep our test trajectories readable. + fn candle( + index: usize, + time: &str, + open: f64, + high: f64, + low: f64, + close: f64, + ) -> IndexedOhlcv { + IndexedOhlcv { + index, + candle: Ohlcv { + open_timestamp: ts(time), + close_timestamp: ts(time), + open: Price(open), + high: Price(high), + low: Price(low), + close: Price(close), + volume: Quantity(100.0), + quote_asset_volume: None, + number_of_trades: None, + taker_buy_base_asset_volume: None, + taker_buy_quote_asset_volume: None, + }, + } + } + + /// Helper to assert floats with epsilon tolerance + fn assert_f64_eq(a: f64, b: f64) { + assert!( + (a - b).abs() < f64::EPSILON, + "Expected {} to equal {}", + a, + b + ); + } + + fn create_indicator(left: u16, right: u16, tiebreaker: ExtremeTiebreaker) -> StreamingHhll { + let indicator = StreamingHhll::default() + .with_zig_zag_period(ZigZagPeriod { + left_bars: left, + right_bars: right, + }) + .with_tiebreaker(tiebreaker) + .with_alternation_mode(AlternationMode::Alternating) + .with_price_source(PriceSource::HighLow); + + // ALWAYS Verify strictly correct initial state + assert!(indicator.active_pivot().is_none()); + assert!(indicator.anchor_high().is_none()); + assert!(indicator.anchor_low().is_none()); + assert!(indicator.history().is_empty()); + assert!(indicator.buffer.is_empty()); + + indicator + } + + // ========================================== + // === 2. Enum/Math/Classification Tests === + // ========================================== + + #[test] + fn test_pivot_type_extract_price() { + // Create mock candles with explicit directions + let bullish_candle = candle(0, "2026-05-24T10:00:00Z", 10., 20., 5., 15.).candle; + let bearish_candle = candle(1, "2026-05-24T10:01:00Z", 15., 20., 5., 10.).candle; + let doji_candle = candle(2, "2026-05-24T10:02:00Z", 15., 20., 5., 15.).candle; + + // === HighLow PriceSource (Always extracts High/Low regardless of direction) === + assert_f64_eq( + PivotType::High + .extract_price(bullish_candle, PriceSource::HighLow) + .0, + 20., + ); + assert_f64_eq( + PivotType::Low + .extract_price(bullish_candle, PriceSource::HighLow) + .0, + 5., + ); + + // === OpenClose PriceSource === + // Bullish: High -> Close(15), Low -> Open(10) + assert_f64_eq( + PivotType::High + .extract_price(bullish_candle, PriceSource::OpenClose) + .0, + 15., + ); + assert_f64_eq( + PivotType::Low + .extract_price(bullish_candle, PriceSource::OpenClose) + .0, + 10., + ); + + // Bearish: High -> Open(15), Low -> Close(10) + assert_f64_eq( + PivotType::High + .extract_price(bearish_candle, PriceSource::OpenClose) + .0, + 15., + ); + assert_f64_eq( + PivotType::Low + .extract_price(bearish_candle, PriceSource::OpenClose) + .0, + 10., + ); + + // Doji: High -> Close(15), Low -> Close(15) + assert_f64_eq( + PivotType::High + .extract_price(doji_candle, PriceSource::OpenClose) + .0, + 15., + ); + assert_f64_eq( + PivotType::Low + .extract_price(doji_candle, PriceSource::OpenClose) + .0, + 15., + ); + } + + /// Verifies that all `MarketStructureSequence` variants correctly map to their implied geometric `PivotType`. + #[test] + fn test_market_structure_sequence_to_pivot_type() { + use MarketStructureSequence::*; + + let highs = vec![HigherHigh, LowerHigh, EqualHigh, UnclassifiedHigh]; + for h in highs { + assert_eq!(h.as_pivot_type(), PivotType::High); + assert_eq!(PivotType::from(h), PivotType::High); + } + + let lows = vec![HigherLow, LowerLow, EqualLow, UnclassifiedLow]; + for l in lows { + assert_eq!(l.as_pivot_type(), PivotType::Low); + assert_eq!(PivotType::from(l), PivotType::Low); + } + } + + /// Verifies the initial struct classification accurately returns NoChange for the first unclassified points. + #[test] + fn test_initial_classification_bos_vs_nochange() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Latest); + + // 1. Fire the first High. It has no prior structure to compare against. + assert!( + hhll.update(candle(0, "2026-05-24T10:00:00Z", 10., 10., 10., 10.)) + .is_none() + ); + + let event_1 = hhll.update(candle(1, "2026-05-24T10:01:00Z", 20., 20., 20., 20.)); + assert!(event_1.is_none(), "Window is not full, must not emit"); + + let (e1, p1) = hhll + .update(candle(2, "2026-05-24T10:02:00Z", 10., 10., 10., 10.)) + .unwrap(); + + // Assert INITIAL classification behavior + assert_eq!(e1, MarketStructureEvent::NoChange); + assert_eq!(p1.trend, MarketStructureSequence::UnclassifiedHigh); + + // 2. Fire the first Low to establish opposite-side anchor. + hhll.update(candle(3, "2026-05-24T10:03:00Z", 5., 5., 5., 5.)); + let (e2, p2) = hhll + .update(candle(4, "2026-05-24T10:04:00Z", 10., 10., 10., 10.)) + .unwrap(); + + assert_eq!(e2, MarketStructureEvent::NoChange); + assert_eq!(p2.trend, MarketStructureSequence::UnclassifiedLow); + + // 3. Fire a SECOND High that exceeds the FIRST High. + // This initiates the trend and must trigger a BreakOfStructure. + hhll.update(candle(5, "2026-05-24T10:05:00Z", 30., 30., 30., 30.)); + let (e3, p3) = hhll + .update(candle(6, "2026-05-24T10:06:00Z", 10., 10., 10., 10.)) + .unwrap(); + + assert_eq!(e3, MarketStructureEvent::BreakOfStructure); + assert_eq!(p3.trend, MarketStructureSequence::HigherHigh); + + // 4. Fire a SECOND Low that exceeds (is lower than) the FIRST Low. + hhll.update(candle(7, "2026-05-24T10:07:00Z", 2., 2., 2., 2.)); + let (e4, p4) = hhll + .update(candle(8, "2026-05-24T10:08:00Z", 10., 10., 10., 10.)) + .unwrap(); + + assert_eq!(e4, MarketStructureEvent::MarketStructureShift); + assert_eq!(p4.trend, MarketStructureSequence::LowerLow); + } + + #[test] + fn test_pivot_point_interpolation() { + let p1 = PivotPoint { + indexed_candle: candle(10, "2026-05-24T15:00:00Z", 100., 100., 100., 100.), + price: Price(100.0), + price_source: PriceSource::HighLow, + trend: MarketStructureSequence::LowerLow, + }; + + // Target pivot is exactly 10 bars (and 10 minutes) later, price has risen by 50. + let p2 = PivotPoint { + indexed_candle: candle(20, "2026-05-24T15:10:00Z", 150., 150., 150., 150.), + price: Price(150.0), + price_source: PriceSource::HighLow, + trend: MarketStructureSequence::HigherLow, + }; + + // === 1. Test Index Based Interpolation === + // Slope = (150 - 100) / (20 - 10) = 5.0 per bar + let line_by_idx = p1.price_line_by_index(&p2); + + assert_f64_eq(line_by_idx(10).0, 100.0); // Start point + assert_f64_eq(line_by_idx(15).0, 125.0); // Exact midpoint + assert_f64_eq(line_by_idx(20).0, 150.0); // Target point + assert_f64_eq(line_by_idx(25).0, 175.0); // Extrapolation into the future! + + // === 2. Test Time Based Interpolation === + let line_by_time = p1.price_line_by_point_in_time(&p2); + + assert_f64_eq(line_by_time(ts("2026-05-24T15:00:00Z")).0, 100.0); // Start + assert_f64_eq(line_by_time(ts("2026-05-24T15:05:00Z")).0, 125.0); // Midpoint (5 mins) + assert_f64_eq(line_by_time(ts("2026-05-24T15:10:00Z")).0, 150.0); // Target + assert_f64_eq(line_by_time(ts("2026-05-24T15:20:00Z")).0, 200.0); // Extrapolation into future + } + + #[test] + fn test_pivot_point_flat_line_and_zero_division() { + let p1 = PivotPoint { + indexed_candle: candle(5, "2026-05-24T15:00:00Z", 100., 100., 100., 100.), + price: Price(100.0), + price_source: PriceSource::HighLow, + trend: MarketStructureSequence::LowerLow, + }; + + let p2 = PivotPoint { + indexed_candle: candle(5, "2026-05-24T15:00:00Z", 100., 100., 100., 100.), + price: Price(100.0), + price_source: PriceSource::HighLow, + trend: MarketStructureSequence::LowerLow, + }; + + // Same index/time should result in flat line, NOT a NaN/Inf panic + let line = p1.price_line_by_index(&p2); + assert_f64_eq(line(10).0, 100.0); + } + + // ========================================== + // === 3. Partitions & Tiebreaker Microstructure Tests === + // ========================================== + + /// Validates the internal left/right partitioning logic and `candidate()` pointer. + #[test] + fn test_streaming_hhll_partitions_and_candidate() { + let mut hhll = create_indicator(2, 2, ExtremeTiebreaker::Latest); + + // Fill buffer exactly to window size (5 bars) + for i in 0..5 { + hhll.buffer + .push_back(candle(i, "2026-05-24T10:00:00Z", 10., 10., 10., 10.)); + } + + // With left_bars=2, mid index must be 2. + assert_eq!(hhll.candidate().index, 2); + + // Left partition should take `left_bars` from the front + let left = hhll.left_partition().map(|c| c.index).collect::>(); + assert_eq!(left, vec![0, 1]); + + // Right partition takes `right_bars` from the end (in reverse iteration) + let right = hhll.right_partition().map(|c| c.index).collect::>(); + assert_eq!(right, vec![4, 3]); + } + + /// Evaluates the `check_extremum` logic in strict isolation to prove micro-swing detection. + /// + /// Note: Because this tests a private internal method, we manually populate the buffer + /// to bypass the `update()` state machine and avoid triggering `process_high/low`. + #[test] + fn test_check_extremum_isolated() { + // === Case 1: Clear Swing High === + { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Latest); + hhll.buffer + .push_back(candle(0, "2026-05-24T10:00:00Z", 10., 10., 10., 10.)); + hhll.buffer + .push_back(candle(1, "2026-05-24T10:01:00Z", 20., 20., 20., 20.)); + hhll.buffer + .push_back(candle(2, "2026-05-24T10:02:00Z", 10., 10., 10., 10.)); + + assert!( + hhll.check_extremum(PivotType::High), + "Clear peak should be High" + ); + assert!( + !hhll.check_extremum(PivotType::Low), + "Clear peak is not a Low" + ); + } + + // === Case 2: Clear Swing Low === + { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Latest); + hhll.buffer + .push_back(candle(0, "2026-05-24T10:00:00Z", 20., 20., 20., 20.)); + hhll.buffer + .push_back(candle(1, "2026-05-24T10:01:00Z", 10., 10., 10., 10.)); + hhll.buffer + .push_back(candle(2, "2026-05-24T10:02:00Z", 20., 20., 20., 20.)); + + assert!( + hhll.check_extremum(PivotType::Low), + "Clear trough should be Low" + ); + assert!( + !hhll.check_extremum(PivotType::High), + "Clear trough is not a High" + ); + } + + // === Case 3: Tiebreaker on a Flat Top Plateau [20, 20(Candidate), 10] === + { + // Sub-case: Latest Tiebreaker + let mut hhll_latest = create_indicator(1, 1, ExtremeTiebreaker::Latest); + hhll_latest + .buffer + .push_back(candle(0, "2026-05-24T10:00:00Z", 20., 20., 20., 20.)); + hhll_latest + .buffer + .push_back(candle(1, "2026-05-24T10:01:00Z", 20., 20., 20., 20.)); + hhll_latest + .buffer + .push_back(candle(2, "2026-05-24T10:02:00Z", 10., 10., 10., 10.)); + + // Latest: Left is inclusive (20 >= 20 = Pass). Right is strict (20 > 10 = Pass). + assert!( + hhll_latest.check_extremum(PivotType::High), + "Latest should pass on flat left side" + ); + + // Sub-case: Earliest Tiebreaker + let mut hhll_earliest = create_indicator(1, 1, ExtremeTiebreaker::Earliest); + hhll_earliest + .buffer + .push_back(candle(0, "2026-05-24T10:00:00Z", 20., 20., 20., 20.)); + hhll_earliest + .buffer + .push_back(candle(1, "2026-05-24T10:01:00Z", 20., 20., 20., 20.)); + hhll_earliest + .buffer + .push_back(candle(2, "2026-05-24T10:02:00Z", 10., 10., 10., 10.)); + + // Earliest: Left is strict (20 > 20 = FAIL). Right is inclusive (20 >= 10 = Pass). + assert!( + !hhll_earliest.check_extremum(PivotType::High), + "Earliest should fail on flat left side" + ); + } + } + + /// Comprehensive edge case: Plateau inside lookback window under Earliest policy. + /// Scenario: Prices `[10, 15(A), 15(B), 15(C), 10]`. + /// `Earliest` dictates that ONLY 15(A) should be emitted. + #[test] + fn test_sliding_window_plateaus_earliest() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Earliest); + + // Pre-fill to start the sliding window + hhll.update(candle(0, "2026-05-24T10:00:00Z", 10., 10., 10., 10.)); // Buffer: [10] + hhll.update(candle(1, "2026-05-24T10:01:00Z", 15., 15., 15., 15.)); // Buffer: [10, 15(A)] + + // Tick 1: Buffer is [10, 15(A), 15(B)]. Candidate is 15(A). + // Left is 10. 15 > 10 (Pass). Right is 15. 15 >= 15 (Pass). + let event_a = hhll.update(candle(2, "2026-05-24T10:02:00Z", 15., 15., 15., 15.)); + assert!( + event_a.is_some(), + "Earliest policy should emit the first peak (A)." + ); + assert_eq!(event_a.unwrap().1.indexed_candle.index, 1); + + // Tick 2: Buffer is [15(A), 15(B), 15(C)]. Candidate is 15(B). + // Left is 15(A). 15 > 15 (FAIL). + let event_b = hhll.update(candle(3, "2026-05-24T10:03:00Z", 15., 15., 15., 15.)); + assert!( + event_b.is_none(), + "Earliest policy must discard middle plateau bars (B)." + ); + + // Tick 3: Buffer is [15(B), 15(C), 10]. Candidate is 15(C). + // Left is 15(B). 15 > 15 (FAIL). + let event_c = hhll.update(candle(4, "2026-05-24T10:04:00Z", 10., 10., 10., 10.)); + assert!( + event_c.is_none(), + "Earliest policy must discard final plateau bars (C)." + ); + } + + #[test] + fn test_sliding_window_plateaus_earliest_low() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Earliest); + + hhll.update(candle(0, "2026-05-24T10:00:00Z", 20., 20., 20., 20.)); // Buffer: [20] + hhll.update(candle(1, "2026-05-24T10:01:00Z", 10., 10., 10., 10.)); // Buffer: [20, 10(A)] + + // Tick 1: Buffer is [20, 10(A), 10(B)]. Candidate is 10(A). + // Left is 20. 10 < 20 (Pass). Right is 10. 10 <= 10 (Pass). + let event_a = hhll.update(candle(2, "2026-05-24T10:02:00Z", 10., 10., 10., 10.)); + assert!( + event_a.is_some(), + "Earliest policy should emit the first trough (A)." + ); + assert_eq!(event_a.unwrap().1.indexed_candle.index, 1); + + // Tick 2: Buffer is [10(A), 10(B), 10(C)]. Candidate is 10(B). + // Left is 10(A). 10 < 10 (FAIL). + let event_b = hhll.update(candle(3, "2026-05-24T10:03:00Z", 10., 10., 10., 10.)); + assert!(event_b.is_none()); + + // Tick 3: Buffer is [10(B), 10(C), 20]. Candidate is 10(C). + let event_c = hhll.update(candle(4, "2026-05-24T10:04:00Z", 20., 20., 20., 20.)); + assert!(event_c.is_none()); + } + + /// Comprehensive edge case: Plateau inside lookback window under Latest policy. + /// Scenario: Prices `[10, 15(A), 15(B), 15(C), 10]`. + /// `Latest` dictates that ONLY 15(C) should be emitted. + #[test] + fn test_sliding_window_plateaus_latest() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Latest); + + // Pre-fill to start the sliding window + hhll.update(candle(0, "2026-05-24T10:00:00Z", 10., 10., 10., 10.)); // Buffer: [10] + hhll.update(candle(1, "2026-05-24T10:01:00Z", 15., 15., 15., 15.)); // Buffer: [10, 15(A)] + + // Tick 1: Buffer is [10, 15(A), 15(B)]. Candidate is 15(A). + // Right is 15. 15 > 15 (FAIL under Latest strict right). + let event_a = hhll.update(candle(2, "2026-05-24T10:02:00Z", 15., 15., 15., 15.)); + assert!( + event_a.is_none(), + "Latest policy must discard early plateau bars (A)." + ); + + // Tick 2: Buffer is [15(A), 15(B), 15(C)]. Candidate is 15(B). + // Right is 15. 15 > 15 (FAIL). + let event_b = hhll.update(candle(3, "2026-05-24T10:03:00Z", 15., 15., 15., 15.)); + assert!( + event_b.is_none(), + "Latest policy must discard middle plateau bars (B)." + ); + + // Tick 3: Buffer is [15(B), 15(C), 10]. Candidate is 15(C). + // Left is 15. 15 >= 15 (Pass). Right is 10. 15 > 10 (Pass). + let event_c = hhll.update(candle(4, "2026-05-24T10:04:00Z", 10., 10., 10., 10.)); + assert!( + event_c.is_some(), + "Latest policy should emit the final peak (C)." + ); + assert_eq!(event_c.unwrap().1.indexed_candle.index, 3); + } + + #[test] + fn test_sliding_window_plateaus_latest_low() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Latest); + + hhll.update(candle(0, "2026-05-24T10:00:00Z", 20., 20., 20., 20.)); // Buffer: [20] + hhll.update(candle(1, "2026-05-24T10:01:00Z", 10., 10., 10., 10.)); // Buffer: [20, 10(A)] + + // Tick 1: Buffer is [20, 10(A), 10(B)]. Candidate is 10(A). + // Right is 10. 10 < 10 (FAIL under Latest strict right). + let event_a = hhll.update(candle(2, "2026-05-24T10:02:00Z", 10., 10., 10., 10.)); + assert!( + event_a.is_none(), + "Latest policy must discard early plateau bars (A)." + ); + + // Tick 2: Buffer is [10(A), 10(B), 10(C)]. Candidate is 10(B). + // Right is 10. 10 < 10 (FAIL). + let event_b = hhll.update(candle(3, "2026-05-24T10:03:00Z", 10., 10., 10., 10.)); + assert!(event_b.is_none()); + + // Tick 3: Buffer is [10(B), 10(C), 20]. Candidate is 10(C). + // Left is 10. 10 <= 10 (Pass). Right is 20. 10 < 20 (Pass). + let event_c = hhll.update(candle(4, "2026-05-24T10:04:00Z", 20., 20., 20., 20.)); + assert!( + event_c.is_some(), + "Latest policy should emit the final trough (C)." + ); + assert_eq!(event_c.unwrap().1.indexed_candle.index, 3); + } + + /// Verifies the update block bootstrapping logic explicitly routes Mega Bars + /// to the correct processor based on Bullish / Bearish candle closes. + #[test] + fn test_mega_bar_bullish_and_bearish_routing() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Latest); + + // === Subtest: Bullish Mega Bar (Close > Open) === + hhll.update(candle(0, "2026-05-24T10:00:00Z", 20., 20., 20., 20.)); + // Outside Bar: Higher High (30>20) AND Lower Low (5<20). Bullish Close (25 > 10). + hhll.update(candle(1, "2026-05-24T10:01:00Z", 10., 30., 5., 25.)); + let event_bullish = hhll.update(candle(2, "2026-05-24T10:02:00Z", 20., 20., 20., 20.)); + + // Assert: Bullish Mega Bar routes to `process_high` + assert!(event_bullish.is_some()); + assert_eq!(event_bullish.unwrap().1.pivot_type(), PivotType::High); + assert_eq!(event_bullish.unwrap().1.price.0, 30.0); + + hhll.reset(); + + // === Subtest: Bearish Mega Bar (Close < Open) === + hhll.update(candle(0, "2026-05-24T10:00:00Z", 20., 20., 20., 20.)); + // Outside Bar: Higher High (30>20) AND Lower Low (5<20). Bearish Close (10 < 25). + hhll.update(candle(1, "2026-05-24T10:01:00Z", 25., 30., 5., 10.)); + let event_bearish = hhll.update(candle(2, "2026-05-24T10:02:00Z", 20., 20., 20., 20.)); + + // Assert: Bearish Mega Bar routes to `process_low` + assert!(event_bearish.is_some()); + assert_eq!(event_bearish.unwrap().1.pivot_type(), PivotType::Low); + assert_eq!(event_bearish.unwrap().1.price.0, 5.0); + } + + // ========================================== + // === 4. Macro Tiebreaker Tests === + // ========================================== + + #[test] + fn test_basic_swing_high_detection() { + // Window size 5: 2 left, 1 candidate, 2 right. + let mut hhll = create_indicator(2, 2, ExtremeTiebreaker::Latest); + + let trajectory = vec![ + candle(0, "2026-05-24T15:01:00Z", 10., 10., 10., 10.), // L1 + candle(1, "2026-05-24T15:02:00Z", 15., 15., 15., 15.), // L2 + candle(2, "2026-05-24T15:03:00Z", 20., 20., 20., 20.), // Peak (Candidate) + candle(3, "2026-05-24T15:04:00Z", 15., 15., 15., 15.), // R1 + ]; + + for c in trajectory { + assert!( + hhll.update(c).is_none(), + "Should not emit before right window is full" + ); + } + + // Pushing R2 completes the right window for the Candidate (20.0). + // It should immediately trigger the Swing High evaluation. + let event = hhll + .update(candle(4, "2026-05-24T15:05:00Z", 10., 10., 10., 10.)) + .unwrap(); + + assert_eq!(event.1.pivot_type(), PivotType::High); + assert_eq!(event.1.price.0, 20.0); + assert_eq!(event.1.point_in_time(), ts("2026-05-24T15:03:00Z")); + } + + #[test] + fn test_tiebreaker_double_top() { + // Window is 2-2. Requires 5 bars. To correctly center C2 as the first candidate, + // we must pre-pad with two left bars. + let trajectory = vec![ + candle(0, "2026-05-24T15:00:00Z", 10., 10., 10., 10.), + candle(1, "2026-05-24T15:01:00Z", 10., 10., 10., 10.), + candle(2, "2026-05-24T15:02:00Z", 20., 20., 20., 20.), // Peak 1 + candle(3, "2026-05-24T15:03:00Z", 20., 20., 20., 20.), // Peak 2 (Double Top) + candle(4, "2026-05-24T15:04:00Z", 10., 10., 10., 10.), + candle(5, "2026-05-24T15:05:00Z", 10., 10., 10., 10.), + ]; + + // Test Earliest: Should capture Peak 1 (Index 2) + let mut hhll_early = create_indicator(2, 2, ExtremeTiebreaker::Earliest); + let mut early_result = None; + for &c in &trajectory { + if let Some(res) = hhll_early.update(c) { + early_result = Some(res); + } + } + assert_eq!( + early_result.unwrap().1.point_in_time(), + ts("2026-05-24T15:02:00Z") + ); + + // Test Latest: Should capture Peak 2 (Index 3) + let mut hhll_late = create_indicator(2, 2, ExtremeTiebreaker::Latest); + let mut late_result = None; + for &c in &trajectory { + if let Some(res) = hhll_late.update(c) { + late_result = Some(res); + } + } + assert_eq!( + late_result.unwrap().1.point_in_time(), + ts("2026-05-24T15:03:00Z") + ); + } + + #[test] + fn test_tiebreaker_double_bottom() { + // Symmetrical test to double_top. We ensure decreasing Highs to avoid + // generating micro-swing highs during the test window. + let trajectory = vec![ + candle(0, "2026-05-24T15:00:00Z", 20., 50., 20., 20.), + candle(1, "2026-05-24T15:01:00Z", 20., 45., 20., 20.), + candle(2, "2026-05-24T15:02:00Z", 10., 40., 10., 10.), // Trough 1 + candle(3, "2026-05-24T15:03:00Z", 10., 35., 10., 10.), // Trough 2 (Double Bottom) + candle(4, "2026-05-24T15:04:00Z", 20., 30., 20., 20.), + candle(5, "2026-05-24T15:05:00Z", 20., 25., 20., 20.), + ]; + + // Test Earliest: Should capture Trough 1 (Index 2) + let mut hhll_early = create_indicator(2, 2, ExtremeTiebreaker::Earliest); + let mut early_result = None; + for &c in &trajectory { + if let Some(res) = hhll_early.update(c) { + early_result = Some(res); + } + } + assert_eq!( + early_result.unwrap().1.point_in_time(), + ts("2026-05-24T15:02:00Z") + ); + + // Test Latest: Should capture Trough 2 (Index 3) + let mut hhll_late = create_indicator(2, 2, ExtremeTiebreaker::Latest); + let mut late_result = None; + for &c in &trajectory { + if let Some(res) = hhll_late.update(c) { + late_result = Some(res); + } + } + assert_eq!( + late_result.unwrap().1.point_in_time(), + ts("2026-05-24T15:03:00Z") + ); + } + + #[test] + fn test_alternation_filter_overwrites_noise() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Latest); + + // To test alternation, we must prevent intermediate Swing Lows from triggering. + // We do this by ensuring the `low` values strictly increase. + hhll.update(candle(0, "2026-05-24T10:00:00Z", 10., 10., 1., 10.)); + + // Peak 1: High=20 + let p1 = hhll.update(candle(1, "2026-05-24T10:01:00Z", 20., 20., 2., 20.)); + assert!(p1.is_none()); // Wait for right buffer + + let p1_confirmed = hhll + .update(candle(2, "2026-05-24T10:02:00Z", 15., 15., 3., 15.)) + .unwrap() + .1; + + // Assert strict initial state before the conflict + assert_eq!(p1_confirmed.price.0, 20.0); + assert_eq!(p1_confirmed.pivot_type(), PivotType::High); + assert_eq!(hhll.active_pivot().unwrap().price.0, 20.0); + assert!( + hhll.history().is_empty(), + "History must be empty before confirmation" + ); + + // Dip in Highs, but NOT Lows. This prevents a Trough from confirming! + hhll.update(candle(3, "2026-05-24T10:03:00Z", 15., 15., 4., 15.)); + + // Peak 2: Higher High (30) + let p2 = hhll.update(candle(4, "2026-05-24T10:04:00Z", 30., 30., 5., 30.)); + assert!( + p2.is_none(), + "Candidate is the intermediate dip, must not emit" + ); + + let p2_confirmed = hhll + .update(candle(5, "2026-05-24T10:05:00Z", 10., 10., 6., 10.)) + .unwrap() + .1; + + // Because Alternation is active and no Trough fired, P2 replaces P1. + assert_eq!(p2_confirmed.price.0, 30.0); + assert_eq!(p2_confirmed.pivot_type(), PivotType::High); + + // Ensure the lesser peak was successfully overwritten and NOT pushed to history + assert_eq!(hhll.history().len(), 0); + assert_eq!(hhll.active_pivot().unwrap().price.0, 30.0); + } + + /// Tests Macro Tiebreaker Resolution: Earliest. + #[test] + fn test_macro_tiebreaker_equal_peaks_earliest() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Earliest); + + let trajectory = vec![ + candle(0, "2026-05-24T10:00:00Z", 10., 10., 1., 10.), + candle(1, "2026-05-24T10:01:00Z", 20., 20., 2., 20.), // Peak 1 + candle(2, "2026-05-24T10:02:00Z", 15., 15., 3., 15.), + candle(3, "2026-05-24T10:03:00Z", 15., 15., 4., 15.), // Flat dip in Highs (no trough) + candle(4, "2026-05-24T10:04:00Z", 20., 20., 5., 20.), // Peak 2 (Double Top) + candle(5, "2026-05-24T10:05:00Z", 10., 10., 6., 10.), + ]; + + let mut events = Vec::new(); + for c in trajectory { + if let Some(event) = hhll.update(c) { + events.push(event); + } + } + + // Only Peak 1 should have been emitted. Peak 2 evaluates in the state machine + // but `20.0 > 20.0` is false, so it is discarded. + assert_eq!(events.len(), 1); + assert_eq!(events[0].1.point_in_time(), ts("2026-05-24T10:01:00Z")); + assert_eq!( + hhll.active_pivot.unwrap().point_in_time(), + ts("2026-05-24T10:01:00Z") + ); + } + + /// Tests Macro Tiebreaker Resolution: Latest. + #[test] + fn test_macro_tiebreaker_equal_peaks_latest() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Latest); + + let trajectory = vec![ + candle(0, "2026-05-24T10:00:00Z", 10., 10., 1., 10.), + candle(1, "2026-05-24T10:01:00Z", 20., 20., 2., 20.), // Peak 1 + candle(2, "2026-05-24T10:02:00Z", 15., 15., 3., 15.), + candle(3, "2026-05-24T10:03:00Z", 15., 15., 4., 15.), // Flat dip in Highs (no trough) + candle(4, "2026-05-24T10:04:00Z", 20., 20., 5., 20.), // Peak 2 (Double Top) + candle(5, "2026-05-24T10:05:00Z", 10., 10., 6., 10.), + ]; + + let mut events = Vec::new(); + for c in trajectory { + if let Some(event) = hhll.update(c) { + events.push(event); + } + } + + // Peak 1 is emitted first. Then Peak 2 arrives and is ALSO emitted because + // it overwrites Peak 1 as the new active_pivot. + assert_eq!(events.len(), 2); + assert_eq!(events[0].1.point_in_time(), ts("2026-05-24T10:01:00Z")); // First emission + assert_eq!(events[1].1.point_in_time(), ts("2026-05-24T10:04:00Z")); // Overwrite emission + + // The active pivot currently tracking the market should be Peak 2. + assert_eq!( + hhll.active_pivot.unwrap().point_in_time(), + ts("2026-05-24T10:04:00Z") + ); + } + + #[test] + fn test_history_invariant_alternating_mode() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Latest); + + let trajectory = vec![ + candle(0, "2026-05-24T10:00:00Z", 10., 10., 1., 10.), + candle(1, "2026-05-24T10:01:00Z", 20., 20., 2., 20.), // Peak 1 + candle(2, "2026-05-24T10:02:00Z", 15., 15., 3., 15.), // Dip (No confirmed Low yet) + candle(3, "2026-05-24T10:03:00Z", 30., 30., 4., 30.), // Peak 2 (Higher High) + candle(4, "2026-05-24T10:04:00Z", 10., 10., 5., 10.), + ]; + + for c in trajectory { + let _ = hhll.update(c); + } + + // Active pivot is Peak 2. + assert_eq!(hhll.active_pivot.unwrap().price.0, 30.0); + + // History should be EMPTY. Peak 1 was discarded and never locked in. + assert_eq!( + hhll.history().len(), + 0, + "History should not contain overwritten pivots" + ); + + // Now, push a confirmed Trough to lock in Peak 2. + let _ = hhll.update(candle(5, "2026-05-24T10:05:00Z", 10., 10., -10., 10.)); // Swing Low + let _ = hhll.update(candle(6, "2026-05-24T10:06:00Z", 15., 15., -5., 15.)); // Right window to trigger Low + + // NOW Peak 2 should be safely locked in history. + assert_eq!(hhll.history().len(), 1); + assert_eq!(hhll.history()[0].price.0, 30.0); + assert_eq!(hhll.history()[0].pivot_type(), PivotType::High); + } + + #[test] + fn test_history_invariant_consecutive_mode() { + let mut hhll = StreamingHhll::default() + .with_zig_zag_period(ZigZagPeriod { + left_bars: 1, + right_bars: 1, + }) + .with_alternation_mode(AlternationMode::Consecutive) // Unfiltered + .with_tiebreaker(ExtremeTiebreaker::Latest) + .with_price_source(PriceSource::HighLow); + + let trajectory = vec![ + candle(0, "2026-05-24T10:00:00Z", 10., 10., 1., 10.), + candle(1, "2026-05-24T10:01:00Z", 20., 20., 2., 20.), // Peak 1 + candle(2, "2026-05-24T10:02:00Z", 15., 15., 3., 15.), + candle(3, "2026-05-24T10:03:00Z", 30., 30., 4., 30.), // Peak 2 + candle(4, "2026-05-24T10:04:00Z", 10., 10., 5., 10.), + candle(5, "2026-05-24T10:05:00Z", 40., 40., 6., 40.), // Peak 3 + candle(6, "2026-05-24T10:06:00Z", 10., 10., 7., 10.), + ]; + + for c in trajectory { + let _ = hhll.update(c); + } + + // Active pivot is tracking Peak 3. + assert_eq!(hhll.active_pivot.unwrap().price.0, 40.0); + + // History should contain Peak 1 and Peak 2, despite them all being Highs. + assert_eq!( + hhll.history().len(), + 2, + "Consecutive mode should lock all previous peaks" + ); + assert_eq!(hhll.history()[0].price.0, 20.0); + assert_eq!(hhll.history()[1].price.0, 30.0); + } + + // ========================================== + // === 5. Deep Edge Cases & Mega Bars === + // ========================================== + + /// Tests the Mega Bar (Outside Bar) anomaly when the candidate closes as a Doji. + /// + /// # The Microstructure Logic + /// When a candidate is mathematically BOTH a Swing High and a Swing Low (an outside bar), + /// the algorithm must choose which extremum logically extends the current market structure. + /// If the candle is a Doji (Open == Close, indicating no directional conviction), + /// the algorithm evaluates it against the active trend to see if it extends it. + /// + /// # Scenario A: Mega Doji Extends a High + #[test] + fn test_mega_doji_extends_high() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Latest); + + // Naturally setup the active pivot to be a High at 50.0 + hhll.update(candle(0, "2026-05-24T13:59:00Z", 10., 10., 10., 10.)); + hhll.update(candle(1, "2026-05-24T14:00:00Z", 50., 50., 10., 50.)); + let p1 = hhll.update(candle(2, "2026-05-24T14:01:00Z", 20., 20., 10., 20.)); + + // Assert natural setup is valid + assert!(p1.is_some()); + assert_eq!(hhll.active_pivot().unwrap().pivot_type(), PivotType::High); + assert_eq!(hhll.active_pivot().unwrap().price.0, 50.0); + + // 1. Push Left Window (prevents Low from firing) + assert!( + hhll.update(candle(3, "2026-05-24T15:01:00Z", 20., 20., 10., 20.)) + .is_none() + ); + + // 2. Push Mega Doji Candidate (High > neighbors, Low < neighbors, Open == Close) + // High is 60 (Extends the 50 High). + assert!( + hhll.update(candle(4, "2026-05-24T15:02:00Z", 25., 60., 5., 25.)) + .is_none() + ); + + // 3. Push Right Window -> Triggers evaluation of the Mega Doji + let event = hhll.update(candle(5, "2026-05-24T15:03:00Z", 20., 20., 10., 20.)); + + // The algorithm routes to `process_high` and evaluates 60 >= 50 (True). + assert!( + event.is_some(), + "Expected Doji to successfully extend the High, but it was discarded" + ); + let (_, pivot) = event.unwrap(); + assert_eq!(pivot.pivot_type(), PivotType::High); + assert_eq!(pivot.price.0, 60.0); + } + + /// # Scenario B: Mega Doji Extends a Low (Sweeps Liquidity) + #[test] + fn test_mega_doji_extends_low() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Latest); + + // Naturally setup the active pivot to be a Low at 10.0 + hhll.update(candle(0, "2026-05-24T13:59:00Z", 50., 50., 50., 50.)); + hhll.update(candle(1, "2026-05-24T14:00:00Z", 50., 50., 10., 50.)); + let p1 = hhll.update(candle(2, "2026-05-24T14:01:00Z", 50., 50., 20., 50.)); + + // Assert natural setup is valid + assert!(p1.is_some()); + assert_eq!(hhll.active_pivot().unwrap().pivot_type(), PivotType::Low); + assert_eq!(hhll.active_pivot().unwrap().price.0, 10.0); + + assert!( + hhll.update(candle(3, "2026-05-24T15:01:00Z", 50., 50., 20., 50.)) + .is_none() + ); + + // Candidate Low (5.0) < Active Pivot Low (10.0). + assert!( + hhll.update(candle(4, "2026-05-24T15:02:00Z", 25., 60., 5., 25.)) + .is_none() + ); + + let event = hhll.update(candle(5, "2026-05-24T15:03:00Z", 50., 50., 20., 50.)); + + assert!(event.is_some(), "Expected Doji to extend the Low"); + let (_, pivot) = event.unwrap(); + assert_eq!(pivot.pivot_type(), PivotType::Low); + assert_eq!(pivot.price.0, 5.0); + } + + /// # Scenario C: Mega Doji Discarded as Internal Noise + #[test] + fn test_mega_doji_discarded_as_noise() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Latest); + + // Naturally setup a very deep, established macro Low at price 1.0 + hhll.update(candle(0, "2026-05-24T13:59:00Z", 50., 50., 50., 50.)); + hhll.update(candle(1, "2026-05-24T14:00:00Z", 50., 50., 1., 50.)); + let p1 = hhll.update(candle(2, "2026-05-24T14:01:00Z", 50., 50., 20., 50.)); + + assert!(p1.is_some()); + assert_eq!(hhll.active_pivot().unwrap().price.0, 1.0); + + assert!( + hhll.update(candle(3, "2026-05-24T17:01:00Z", 50., 50., 20., 50.)) + .is_none() + ); + + // Candidate Low (5.0) is NOT < Active Pivot Low (1.0). + assert!( + hhll.update(candle(4, "2026-05-24T17:02:00Z", 25., 60., 5., 25.)) + .is_none() + ); + + let event_noise = hhll.update(candle(5, "2026-05-24T17:03:00Z", 50., 50., 20., 50.)); + + assert!( + event_noise.is_none(), + "Expected Doji to be discarded as internal noise" + ); + } + + /// # Scenario D: Mega Doji Extends a High under Earliest Tiebreaker + #[test] + fn test_mega_doji_extends_high_earliest() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Earliest); + + // Naturally setup the active pivot to be a High at 50.0 + hhll.update(candle(0, "2026-05-24T13:59:00Z", 10., 10., 10., 10.)); + hhll.update(candle(1, "2026-05-24T14:00:00Z", 50., 50., 10., 50.)); + let p1 = hhll.update(candle(2, "2026-05-24T14:01:00Z", 20., 20., 10., 20.)); + + assert!(p1.is_some()); + + // 1. Push Left Window (prevents Low from firing) + assert!( + hhll.update(candle(3, "2026-05-24T15:01:00Z", 20., 20., 10., 20.)) + .is_none() + ); + + // 2. Push Mega Doji Candidate + // High is 60. Under Earliest, 60 > 50 is TRUE. + assert!( + hhll.update(candle(4, "2026-05-24T15:02:00Z", 25., 60., 5., 25.)) + .is_none() + ); + + // 3. Push Right Window + let event = hhll.update(candle(5, "2026-05-24T15:03:00Z", 20., 20., 10., 20.)); + + assert!( + event.is_some(), + "Expected Doji to successfully extend the High under Earliest" + ); + assert_eq!(event.unwrap().1.price.0, 60.0); + } + + /// # Scenario E: Mega Doji Exact Tie under Earliest (Discarded) + #[test] + fn test_mega_doji_exact_tie_earliest_discarded() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Earliest); + + // Naturally setup the active pivot to be a High at 50.0 + hhll.update(candle(0, "2026-05-24T13:59:00Z", 10., 10., 10., 10.)); + hhll.update(candle(1, "2026-05-24T14:00:00Z", 50., 50., 10., 50.)); + hhll.update(candle(2, "2026-05-24T14:01:00Z", 20., 20., 10., 20.)); + + hhll.update(candle(3, "2026-05-24T15:01:00Z", 20., 20., 10., 20.)); + + // 2. Push Mega Doji Candidate + // High is EXACTLY 50. Under Earliest, 50 > 50 is FALSE. + hhll.update(candle(4, "2026-05-24T15:02:00Z", 25., 50., 5., 25.)); + + // 3. Push Right Window + let event = hhll.update(candle(5, "2026-05-24T15:03:00Z", 20., 20., 10., 20.)); + + assert!( + event.is_none(), + "Expected exact tie Doji to be discarded under Earliest tiebreaker" + ); + } + + /// # Scenario F: Mega Doji Exact Tie under Latest (Overwrites) + #[test] + fn test_mega_doji_exact_tie_latest_overwrites() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Latest); + + // Naturally setup the active pivot to be a High at 50.0 + hhll.update(candle(0, "2026-05-24T13:59:00Z", 10., 10., 10., 10.)); + hhll.update(candle(1, "2026-05-24T14:00:00Z", 50., 50., 10., 50.)); + hhll.update(candle(2, "2026-05-24T14:01:00Z", 20., 20., 10., 20.)); + + hhll.update(candle(3, "2026-05-24T15:01:00Z", 20., 20., 10., 20.)); + + // 2. Push Mega Doji Candidate + // High is EXACTLY 50. Under Latest, 50 >= 50 is TRUE. + hhll.update(candle(4, "2026-05-24T15:02:00Z", 25., 50., 5., 25.)); + + // 3. Push Right Window + let event = hhll.update(candle(5, "2026-05-24T15:03:00Z", 20., 20., 10., 20.)); + + assert!( + event.is_some(), + "Expected exact tie Doji to overwrite under Latest tiebreaker" + ); + let (_, pivot) = event.unwrap(); + assert_eq!(pivot.price.0, 50.0); + // Ensure it is actually the NEW candle by checking the timestamp + assert_eq!(pivot.point_in_time(), ts("2026-05-24T15:02:00Z")); + } + + /// # Scenario G: Mega Doji Extends a Low under Earliest Tiebreaker + #[test] + fn test_mega_doji_extends_low_earliest() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Earliest); + + // Naturally setup the active pivot to be a Low at 10.0 + hhll.update(candle(0, "2026-05-24T13:59:00Z", 50., 50., 50., 50.)); + hhll.update(candle(1, "2026-05-24T14:00:00Z", 50., 50., 10., 50.)); + let p1 = hhll.update(candle(2, "2026-05-24T14:01:00Z", 50., 50., 20., 50.)); + + assert!(p1.is_some()); + assert_eq!(hhll.active_pivot().unwrap().pivot_type(), PivotType::Low); + assert_eq!(hhll.active_pivot().unwrap().price.0, 10.0); + + // 1. Push Left Window (prevents High from firing) + assert!( + hhll.update(candle(3, "2026-05-24T15:01:00Z", 50., 50., 20., 50.)) + .is_none() + ); + + // 2. Push Mega Doji Candidate + // Low is 5. Under Earliest, 5 < 10 is TRUE. + assert!( + hhll.update(candle(4, "2026-05-24T15:02:00Z", 25., 60., 5., 25.)) + .is_none() + ); + + // 3. Push Right Window + let event = hhll.update(candle(5, "2026-05-24T15:03:00Z", 50., 50., 20., 50.)); + + assert!( + event.is_some(), + "Expected Doji to successfully extend the Low under Earliest" + ); + assert_eq!(event.unwrap().1.price.0, 5.0); + } + + /// # Scenario H: Mega Doji Exact Tie under Earliest (Discarded for Low) + #[test] + fn test_mega_doji_exact_tie_earliest_discarded_low() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Earliest); + + // Naturally setup the active pivot to be a Low at 10.0 + hhll.update(candle(0, "2026-05-24T13:59:00Z", 50., 50., 50., 50.)); + hhll.update(candle(1, "2026-05-24T14:00:00Z", 50., 50., 10., 50.)); + hhll.update(candle(2, "2026-05-24T14:01:00Z", 50., 50., 20., 50.)); + + hhll.update(candle(3, "2026-05-24T15:01:00Z", 50., 50., 20., 50.)); + + // 2. Push Mega Doji Candidate + // Low is EXACTLY 10. Under Earliest, 10 < 10 is FALSE. + hhll.update(candle(4, "2026-05-24T15:02:00Z", 25., 60., 10., 25.)); + + // 3. Push Right Window + let event = hhll.update(candle(5, "2026-05-24T15:03:00Z", 50., 50., 20., 50.)); + + assert!( + event.is_none(), + "Expected exact tie Doji to be discarded under Earliest tiebreaker" + ); + } + + /// # Scenario I: Mega Doji Exact Tie under Latest (Overwrites for Low) + #[test] + fn test_mega_doji_exact_tie_latest_overwrites_low() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Latest); + + // Naturally setup the active pivot to be a Low at 10.0 + hhll.update(candle(0, "2026-05-24T13:59:00Z", 50., 50., 50., 50.)); + hhll.update(candle(1, "2026-05-24T14:00:00Z", 50., 50., 10., 50.)); + hhll.update(candle(2, "2026-05-24T14:01:00Z", 50., 50., 20., 50.)); + + hhll.update(candle(3, "2026-05-24T15:01:00Z", 50., 50., 20., 50.)); + + // 2. Push Mega Doji Candidate + // Low is EXACTLY 10. Under Latest, 10 <= 10 is TRUE. + hhll.update(candle(4, "2026-05-24T15:02:00Z", 25., 60., 10., 25.)); + + // 3. Push Right Window + let event = hhll.update(candle(5, "2026-05-24T15:03:00Z", 50., 50., 20., 50.)); + + assert!( + event.is_some(), + "Expected exact tie Doji to overwrite under Latest tiebreaker" + ); + let (_, pivot) = event.unwrap(); + assert_eq!(pivot.price.0, 10.0); + // Ensure it is actually the NEW candle by checking the timestamp + assert_eq!(pivot.point_in_time(), ts("2026-05-24T15:02:00Z")); + } + + /// Verifies that if the VERY FIRST extremum detected is a Mega Bar Doji, + /// the algorithm safely ignores it because there is no prior trend to derive inertia from. + #[test] + fn test_initial_orphaned_mega_doji_safely_ignored() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Latest); + + // Ensure strictly clean state + assert!(hhll.active_pivot().is_none()); + + hhll.update(candle(0, "2026-05-24T10:00:00Z", 20., 20., 20., 20.)); + + // Candidate: Mega Doji (High = 30, Low = 10, Open = 20, Close = 20) + hhll.update(candle(1, "2026-05-24T10:01:00Z", 20., 30., 10., 20.)); + + // Right boundary closes the window + let event = hhll.update(candle(2, "2026-05-24T10:02:00Z", 20., 20., 20., 20.)); + + // Because active_pivot is None, `CandleDirection::Doji` logic explicitly returns `None`. + assert!( + event.is_none(), + "Expected initial Mega Doji to be safely discarded, but it emitted an event." + ); + } + + /// Geometrically proves that an Inside Bar can never be evaluated as an extremum. + #[test] + fn test_inside_bar_never_triggers() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Latest); + + // Huge outer bar bounds the window + hhll.update(candle(0, "2026-05-24T10:00:00Z", 50., 100., 10., 50.)); + + // Candidate is strictly inside the previous bar's range + hhll.update(candle(1, "2026-05-24T10:01:00Z", 50., 60., 40., 50.)); + + // Right boundary closes the window + let event = hhll.update(candle(2, "2026-05-24T10:02:00Z", 50., 50., 50., 50.)); + + assert!( + event.is_none(), + "An inside bar should never emit a pivot event" + ); + } + + /// Verifies that state-poisoning (NaN prices) are trapped by the partial_cmp logic + /// and gracefully return None without panicking the application. + #[test] + fn test_nan_price_corruption_resistance() { + let mut hhll = create_indicator(1, 1, ExtremeTiebreaker::Latest); + + // Naturally setup a baseline High and establish the anchor + hhll.update(candle(0, "2026-05-24T13:59:00Z", 10., 10., 10., 10.)); + hhll.update(candle(1, "2026-05-24T14:00:00Z", 50., 50., 50., 50.)); + hhll.update(candle(2, "2026-05-24T14:01:00Z", 10., 10., 10., 10.)); + + // Push a confirmed Low to lock the High into `anchor_high` + hhll.update(candle(3, "2026-05-24T14:02:00Z", 5., 5., 5., 5.)); + hhll.update(candle(4, "2026-05-24T14:03:00Z", 10., 10., 10., 10.)); + + assert!( + hhll.anchor_high().is_some(), + "Anchor High must be securely established" + ); + + // Introduce a corrupt candidate with f64::NAN + hhll.update(candle(5, "2026-05-24T15:01:00Z", 10., 10., 10., 10.)); + hhll.update(candle( + 6, + "2026-05-24T15:02:00Z", + f64::NAN, + f64::NAN, + f64::NAN, + f64::NAN, + )); + let event = hhll.update(candle(7, "2026-05-24T15:03:00Z", 10., 10., 10., 10.)); + + // The float partial_cmp(50.0, f64::NAN) yields None. + // The code logs a warning and returns None. + assert!( + event.is_none(), + "Indicator failed to trap NaN price and allowed state corruption." + ); + } +} diff --git a/src/math/traits.rs b/src/math/traits.rs new file mode 100644 index 0000000..bfb19be --- /dev/null +++ b/src/math/traits.rs @@ -0,0 +1,14 @@ +/// A generic trait for incremental indicators. +/// Designed to be object-safe so agents can hold `Box>`. +pub trait StreamingIndicator: std::fmt::Debug + Send + Sync { + type Input; + type Output<'a> + where + Self: 'a; + + /// Update the indicator with the latest data point. + fn update(&mut self, input: Self::Input) -> Self::Output<'_>; + + /// Reset the internal state to clear history (e.g., for a new trading session). + fn reset(&mut self); +} diff --git a/src/prelude.rs b/src/prelude.rs index 1a39391..cd6ee34 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -5,35 +5,31 @@ pub use crate::transport::source::*; // 2. The Core "Loop", Agents & States pub use crate::data::episode::*; -pub use crate::gym::trading::{ - Env, action::*, agent::*, config::*, env::*, observation::*, state::*, types::*, -}; -pub use crate::gym::{ - AgentIdentifier, EnvStatus, GridAxis, InvalidActionPenalty, Reward, StepOutcome, -}; +// Pulls in Reward, EnvStatus, StepOutcome, GridAxis, AgentIdentifier, etc. +pub use crate::gym::*; +// Pulls in Env, Actions, Observation, Agent, load, make, etc. +pub use crate::gym::trading::*; // 3. Financial Domain Types (Primitives & Classifications) -// Safely pulls in Price, Quantity, Tick, Volume, TradeId, SpotPair, etc. pub use crate::data::domain::*; // 4. Events & Views -// Pulls in Ohlcv, Trade, Tpo, MarketView, StreamView, ClosePriceProvider, etc. pub use crate::data::event::*; pub use crate::data::view::*; // 5. Data Configurations & Filters -// Pulls in TechnicalAnalysis, FilterConfig, TradingWindow, OhlcvSpotConfig, etc. pub use crate::data::common::*; -pub use crate::data::config::*; pub use crate::data::filter::*; +pub use crate::data::query::*; // 6. Technical Indicators -// Automatically exposes StreamingSma, StreamingEma, StreamingRsi, StreamingIndicator, etc. -pub use crate::data::indicator::*; -pub use crate::math::indicator::*; +pub use crate::data::batch_indicator::*; +pub use crate::math::fair_value_gap::*; +pub use crate::math::market_profile::*; +pub use crate::math::moving_averages::*; +pub use crate::math::oscillators::*; +pub use crate::math::swing::*; +pub use crate::math::traits::*; // 7. Errors pub use crate::error::*; - -// 8. Factories -pub use crate::gym::trading::{load, make}; diff --git a/src/sim/data.rs b/src/sim/data.rs index e8fda35..b8b2a4b 100644 --- a/src/sim/data.rs +++ b/src/sim/data.rs @@ -444,12 +444,12 @@ mod tests { use crate::{ DataSource, SelfHostedApi, StorageLocation, data::{ - config::{EconomicCalendarConfig, OhlcvSpotConfig}, domain::{ CountryCode, DataBroker, EconomicCategory, EconomicEventImpact, Exchange, Period, Price, Quantity, SpotPair, Symbol, }, event::{EconomicCalendarId, EconomicEvent, Ohlcv, OhlcvId, TradeEvent, TradesId}, + query::{EconomicCalendarQuery, OhlcvSpotQuery}, }, transport::source::EndpointUrl, }; @@ -712,7 +712,7 @@ mod tests { endpoint: EndpointUrl::from("http://test:50051"), api_key: None, }), - OhlcvSpotConfig { + OhlcvSpotQuery { broker: DataBroker::Binance, symbol: Symbol::Spot(SpotPair::BtcUsdt), exchange: Some(Exchange::Binance), @@ -726,7 +726,7 @@ mod tests { endpoint: EndpointUrl::from("http://test:50051"), api_key: None, }), - EconomicCalendarConfig { + EconomicCalendarQuery { broker: DataBroker::InvestingCom, data_source: None, country_code: Some(CountryCode::Us), diff --git a/src/sorted_vec_map.rs b/src/sorted_vec_map.rs index 81c918c..009aee1 100644 --- a/src/sorted_vec_map.rs +++ b/src/sorted_vec_map.rs @@ -67,7 +67,6 @@ impl SortedVecMap { /// let map: SortedVecMap = SortedVecMap::new(); /// assert!(map.is_empty()); /// ``` - #[inline] pub const fn new() -> Self { Self { inner: SmallVec::new_const(), @@ -102,7 +101,6 @@ impl SortedVecMap { /// map.insert(1, "a"); /// assert_eq!(map.len(), 1); /// ``` - #[inline] pub fn len(&self) -> usize { self.inner.len() } @@ -118,13 +116,11 @@ impl SortedVecMap { /// map.insert(1, "a"); /// assert!(!map.is_empty()); /// ``` - #[inline] pub fn is_empty(&self) -> bool { self.inner.is_empty() } /// Returns the number of elements the map can hold without reallocating. - #[inline] pub fn capacity(&self) -> usize { self.inner.capacity() } @@ -140,7 +136,6 @@ impl SortedVecMap { /// map.clear(); /// assert!(map.is_empty()); /// ``` - #[inline] pub fn clear(&mut self) { self.inner.clear(); } @@ -162,7 +157,6 @@ impl SortedVecMap { /// assert!(map.contains_key(&1)); /// assert!(!map.contains_key(&2)); /// ``` - #[inline] pub fn contains_key(&self, key: &K) -> bool { self.inner.iter().any(|(k, _)| k == key) } @@ -178,7 +172,6 @@ impl SortedVecMap { /// assert_eq!(map.get(&1), Some(&"a")); /// assert_eq!(map.get(&2), None); /// ``` - #[inline] pub fn get(&self, key: &K) -> Option<&V> { self.inner.iter().find(|(k, _)| k == key).map(|(_, v)| v) } @@ -196,7 +189,6 @@ impl SortedVecMap { /// } /// assert_eq!(map.get(&1), Some(&"b")); /// ``` - #[inline] pub fn get_mut(&mut self, key: &K) -> Option<&mut V> { self.inner .iter_mut() @@ -214,7 +206,6 @@ impl SortedVecMap { /// map.insert(1, "a"); /// assert_eq!(map.get_key_value(&1), Some((&1, &"a"))); /// ``` - #[inline] pub fn get_key_value(&self, key: &K) -> Option<(&K, &V)> { self.inner .iter() @@ -405,7 +396,6 @@ impl SortedVecMap { /// let keys: Vec<_> = map.keys().copied().collect(); /// assert_eq!(keys, vec![1, 2]); /// ``` - #[inline] pub fn keys(&self) -> impl Iterator { self.inner.iter().map(|(k, _)| k) } @@ -423,7 +413,6 @@ impl SortedVecMap { /// let values: Vec<_> = map.values().copied().collect(); /// assert_eq!(values, vec!["a", "b"]); /// ``` - #[inline] pub fn values(&self) -> impl Iterator { self.inner.iter().map(|(_, v)| v) } @@ -444,7 +433,6 @@ impl SortedVecMap { /// /// assert_eq!(map.get(&1), Some(&2)); /// ``` - #[inline] pub fn values_mut(&mut self) -> impl Iterator { self.inner.iter_mut().map(|(_, v)| v) } @@ -463,7 +451,6 @@ impl SortedVecMap { /// println!("{}: {}", key, value); /// } /// ``` - #[inline] pub fn iter(&self) -> impl Iterator { self.inner.iter().map(|(k, v)| (k, v)) } @@ -481,7 +468,6 @@ impl SortedVecMap { /// *value = "b"; /// } /// ``` - #[inline] pub fn iter_mut(&mut self) -> impl Iterator { self.inner.iter_mut().map(|(k, v)| (&*k, v)) } @@ -504,7 +490,6 @@ impl SortedVecMap { /// Returns a parallel iterator over the key-value pairs. /// /// Requires the `rayon` feature to be enabled. - #[inline] pub fn par_iter(&self) -> impl ParallelIterator { self.inner.par_iter().map(|(k, v)| (k, v)) } diff --git a/src/transport/fetcher.rs b/src/transport/fetcher.rs index 6fbb51c..a86d2b5 100644 --- a/src/transport/fetcher.rs +++ b/src/transport/fetcher.rs @@ -1,7 +1,7 @@ use crate::{ - data::config::{ - ConfigId, EconomicCalendarConfig, OhlcvFutureConfig, OhlcvSpotConfig, TpoFutureConfig, - TpoSpotConfig, TradeSpotConfig, VolumeProfileSpotConfig, + data::query::{ + EconomicCalendarQuery, OhlcvFutureQuery, OhlcvSpotQuery, QueryId, TpoFutureQuery, + TpoSpotQuery, TradeSpotQuery, VolumeProfileSpotQuery, }, error::ChapatyResult, generated::chapaty::{ @@ -29,7 +29,7 @@ use tonic::async_trait; /// Defines how a specific Config/Spec fetches its data. #[async_trait] -pub trait Fetchable: ConfigId + Clone + Send + Sync + Debug + 'static { +pub trait Fetchable: QueryId + Clone + Send + Sync + Debug + 'static { /// The Protobuf Response type (e.g., OhlcvSpotResponse) type Response: ProtoBatch + Send; /// The Protobuf Request type (e.g., OhlcvSpotRequest) @@ -52,7 +52,7 @@ pub trait Fetchable: ConfigId + Clone + Send + Sync + Debug + 'static { // ================================================================================================ #[async_trait] -impl Fetchable for OhlcvSpotConfig { +impl Fetchable for OhlcvSpotQuery { type Response = OhlcvSpotResponse; type Request = OhlcvSpotRequest; @@ -90,7 +90,7 @@ impl Fetchable for OhlcvSpotConfig { // ================================================================================================ #[async_trait] -impl Fetchable for OhlcvFutureConfig { +impl Fetchable for OhlcvFutureQuery { type Response = OhlcvFutureResponse; type Request = OhlcvFutureRequest; @@ -128,7 +128,7 @@ impl Fetchable for OhlcvFutureConfig { // ================================================================================================ #[async_trait] -impl Fetchable for TradeSpotConfig { +impl Fetchable for TradeSpotQuery { type Response = TradesSpotResponse; type Request = TradesSpotRequest; @@ -165,7 +165,7 @@ impl Fetchable for TradeSpotConfig { // ================================================================================================ #[async_trait] -impl Fetchable for TpoSpotConfig { +impl Fetchable for TpoSpotQuery { type Response = TpoSpotResponse; type Request = TpoSpotRequest; @@ -218,7 +218,7 @@ impl Fetchable for TpoSpotConfig { // ================================================================================================ #[async_trait] -impl Fetchable for TpoFutureConfig { +impl Fetchable for TpoFutureQuery { type Response = TpoFutureResponse; type Request = TpoFutureRequest; @@ -271,7 +271,7 @@ impl Fetchable for TpoFutureConfig { // ================================================================================================ #[async_trait] -impl Fetchable for VolumeProfileSpotConfig { +impl Fetchable for VolumeProfileSpotQuery { type Response = VolumeProfileSpotResponse; type Request = VolumeProfileSpotRequest; @@ -324,7 +324,7 @@ impl Fetchable for VolumeProfileSpotConfig { // ================================================================================================ #[async_trait] -impl Fetchable for EconomicCalendarConfig { +impl Fetchable for EconomicCalendarQuery { type Response = EconomicCalendarResponse; type Request = EconomicCalendarRequest;