diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bc0ad3a..0d4c856 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -77,12 +77,25 @@ jobs: - name: Run Tests run: cargo test --all-features --verbose - # 6. Build Documentation + # 6. Build Documentation (Stable) - name: Verify Documentation Build run: cargo doc --no-deps --document-private-items env: - RUSTDOCFLAGS: "-D warnings" # Fail if docs have broken links + RUSTDOCFLAGS: "-D warnings" - # 7. Check Examples (Compilation Only) + # 7 Verify docs.rs Compatibility (Nightly - SOFT FAIL) + # docs.rs uses the nightly compiler. Upstream crates frequently break here. + # We use continue-on-error so a nightly upstream break doesn't block our PRs. + - name: Check Nightly Docs Compatibility + continue-on-error: true + run: | + rustup toolchain install nightly --profile minimal + cargo +nightly doc --no-deps + env: + # We don't use -D warnings here because nightly often introduces + # new experimental lints that we don't want to enforce yet. + RUSTDOCFLAGS: "" + + # 8. Check Examples (Compilation Only) - name: Check Examples Compile run: cargo build --examples --verbose diff --git a/Cargo.lock b/Cargo.lock index 83b237c..b85e11c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -467,7 +467,7 @@ dependencies = [ [[package]] name = "chapaty" -version = "1.1.0" +version = "1.1.1" dependencies = [ "anyhow", "async-channel", @@ -1725,9 +1725,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.95" +version = "0.3.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2964e92d1d9dc3364cae4d718d93f227e3abb088e747d92e0395bfdedf1c12ca" +checksum = "a1840c94c045fbcf8ba2812c95db44499f7c64910a912551aaaa541decebcacf" dependencies = [ "cfg-if", "futures-util", @@ -3447,9 +3447,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.39" +version = "0.23.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c2c118cb077cca2822033836dfb1b975355dfb784b5e8da48f7b6c5db74e60e" +checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b" dependencies = [ "log", "once_cell", @@ -3666,9 +3666,9 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.18.0" +version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd5414fad8e6907dbdd5bc441a50ae8d6e26151a03b1de04d89a5576de61d01f" +checksum = "f05839ce67618e14a09b286535c0d9c94e85ef25469b0e13cb4f844e5593eb19" dependencies = [ "base64", "chrono", @@ -3685,9 +3685,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.18.0" +version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3db8978e608f1fe7357e211969fd9abdcae80bac1ba7a3369bb7eb6b404eb65" +checksum = "cf2ebbe86054f9b45bc3881e865683ccfaccce97b9b4cb53f3039d67f355a334" dependencies = [ "darling", "proc-macro2", @@ -4666,9 +4666,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.118" +version = "0.2.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf938a0bacb0469e83c1e148908bd7d5a6010354cf4fb73279b7447422e3a89" +checksum = "df52b6d9b87e0c74c9edfa1eb2d9bf85e5d63515474513aa50fa181b3c4f5db1" dependencies = [ "cfg-if", "once_cell", @@ -4680,9 +4680,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.68" +version = "0.4.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f371d383f2fb139252e0bfac3b81b265689bf45b6874af544ffa4c975ac1ebf8" +checksum = "af934872acec734c2d80e6617bbb5ff4f12b052dd8e6332b0817bce889516084" dependencies = [ "js-sys", "wasm-bindgen", @@ -4690,9 +4690,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.118" +version = "0.2.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eeff24f84126c0ec2db7a449f0c2ec963c6a49efe0698c4242929da037ca28ed" +checksum = "78b1041f495fb322e64aca85f5756b2172e35cd459376e67f2a6c9dffcedb103" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4700,9 +4700,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.118" +version = "0.2.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d08065faf983b2b80a79fd87d8254c409281cf7de75fc4b773019824196c904" +checksum = "9dcd0ff20416988a18ac686d4d4d0f6aae9ebf08a389ff5d29012b05af2a1b41" dependencies = [ "bumpalo", "proc-macro2", @@ -4713,9 +4713,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.118" +version = "0.2.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fd04d9e306f1907bd13c6361b5c6bfc7b3b3c095ed3f8a9246390f8dbdee129" +checksum = "49757b3c82ebf16c57d69365a142940b384176c24df52a087fb748e2085359ea" dependencies = [ "unicode-ident", ] @@ -4769,9 +4769,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.95" +version = "0.3.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f2dfbb17949fa2088e5d39408c48368947b86f7834484e87b73de55bc14d97d" +checksum = "2eadbac71025cd7b0834f20d1fe8472e8495821b4e9801eb0a60bd1f19827602" dependencies = [ "js-sys", "wasm-bindgen", diff --git a/Cargo.toml b/Cargo.toml index bed08cb..c8a1d69 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,9 +1,9 @@ [package] name = "chapaty" -version = "1.1.0" +version = "1.1.1" edition = "2024" authors = ["Len Williamson "] -description = "High-performance backtesting and financial simulation framework for trading strategies and reinforcement learning agents. Async-first, Gym-like API in Rust." +description = "An event-driven Rust engine for building and evaluating quantitative trading agents. Features a Gym-style API for algorithmic backtesting and reinforcement learning." license = "Apache-2.0" readme = "README.md" # homepage = "https://www.chapaty.com" diff --git a/README.md b/README.md index cdf9cf7..a02400e 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ # Chapaty -[![Discord](https://img.shields.io/discord/1495690333911257108.svg?label=Discord&logo=discord&color=7289da&logoColor=white)](https://discord.gg/k7GWpDQC) +[![Discord](https://img.shields.io/discord/1495690333911257108.svg?label=Discord&logo=discord&color=7289da&logoColor=white)][discord] [![Crates.io](https://img.shields.io/crates/v/chapaty.svg)](https://crates.io/crates/chapaty) [![Docs.rs](https://img.shields.io/docsrs/chapaty)](https://docs.rs/chapaty) [![CI (Main)](https://github.com/LenWilliamson/chapaty/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/LenWilliamson/chapaty/actions/workflows/ci.yml) [![CI (Develop)](https://github.com/LenWilliamson/chapaty/actions/workflows/ci.yml/badge.svg?branch=develop)](https://github.com/LenWilliamson/chapaty/actions/workflows/ci.yml) -**Chapaty** is an Rust engine for **building and evaluating quantitative trading agents**. Designed with a familiar, [**Gym-style API**][gymnasiumLink], Chapaty brings the rigor of standardized simulation interfaces to **event-driven financial backtesting**. +**Chapaty** is a Rust engine for **building and evaluating quantitative trading agents**. Designed with a familiar [**Gym-style API**][gymnasiumLink], Chapaty brings the rigor of standardized simulation interfaces to **event-driven financial backtesting**. ## Getting Started @@ -111,7 +111,7 @@ For practical, _ready-to-run agents_, check out the `examples/` directory to get ## Community -If you are excited about the project, don't hesitate to join our [Discord](https://discord.gg/k7GWpDQC)! It is the perfect place to ask questions, file data requests, discuss new features, and share what you have built with the community. +If you are excited about the project, don't hesitate to join our [Discord][discord]! It is the perfect place to ask questions, file data requests, discuss new features, and share what you have built with the community. ## Contributing @@ -133,5 +133,6 @@ This software is provided **“AS IS”**, without warranties or conditions of a By using Chapaty, you acknowledge that **you are solely responsible for any trading decisions, strategies, or outcomes**. +[discord]: https://discord.gg/MmMAB6NCuK [gymnasiumLink]: https://github.com/Farama-Foundation/Gymnasium [deepmindLink]: https://github.com/deepmind/dm_control diff --git a/bin/pre-push.sh b/bin/pre-push.sh index a2c3a5e..41015a9 100755 --- a/bin/pre-push.sh +++ b/bin/pre-push.sh @@ -15,7 +15,7 @@ echo -e "${BLUE}>>> Starting Local CI Pipeline for Chapaty...${NC}" # JOB 1: Compliance & Security # ============================================================================== -echo -e "\n${YELLOW}[1/10] Checking Security (Secrets)...${NC}" +echo -e "\n${YELLOW}[1/11] Checking Security (Secrets)...${NC}" # Check if .cargo/config.toml is tracked by git if git ls-files --error-unmatch .cargo/config.toml > /dev/null 2>&1; then echo -e "${RED}[FAIL] CRITICAL: .cargo/config.toml is being tracked by git! Remove it immediately.${NC}" @@ -23,12 +23,12 @@ if git ls-files --error-unmatch .cargo/config.toml > /dev/null 2>&1; then fi echo -e "${GREEN}[OK] No leaked secrets in git index.${NC}" -echo -e "\n${YELLOW}[2/10] Checking Formatting...${NC}" +echo -e "\n${YELLOW}[2/11] Checking Formatting...${NC}" # Fails if code is not formatted. Remove '--check' to auto-format instead. cargo fmt -- --check || { echo -e "${RED}[FAIL] Formatting invalid. Run 'cargo fmt' to fix.${NC}"; exit 1; } echo -e "${GREEN}[OK] Formatting is correct.${NC}" -echo -e "\n${YELLOW}[3/10] Checking Architecture Guardrails...${NC}" +echo -e "\n${YELLOW}[3/11] Checking Architecture Guardrails...${NC}" # Prevent circular dependencies via prelude imports within the library if grep -r "use crate::prelude::" src/; then echo -e "${RED}[FAIL] Architecture violation: Internal imports from 'crate::prelude' found.${NC}" @@ -40,7 +40,7 @@ echo -e "${GREEN}[OK] Architecture compliant.${NC}" # JOB 2: Build, Test & Verify # ============================================================================== -echo -e "\n${YELLOW}[4/10] Security Audit (Dependencies)...${NC}" +echo -e "\n${YELLOW}[4/11] Security Audit (Dependencies)...${NC}" # Check if cargo-audit is installed if ! command -v cargo-audit &> /dev/null; then echo -e "${RED}[FAIL] 'cargo-audit' is not installed.${NC}" @@ -50,52 +50,70 @@ fi cargo audit echo -e "${GREEN}[OK] Dependencies audited.${NC}" -echo -e "\n${YELLOW}[5/10] Linting (Clippy)...${NC}" +echo -e "\n${YELLOW}[5/11] Linting (Clippy)...${NC}" # Deny warnings to match CI strictness cargo clippy --all-targets --all-features -- -D warnings echo -e "${GREEN}[OK] Code is clean.${NC}" -echo -e "\n${YELLOW}[6/10] Building Workspace...${NC}" +echo -e "\n${YELLOW}[6/11] Building Workspace...${NC}" cargo build --all-features echo -e "${GREEN}[OK] Workspace compiled successfully.${NC}" -echo -e "\n${YELLOW}[7/10] Running Unit Tests...${NC}" +echo -e "\n${YELLOW}[7/11] Running Unit Tests...${NC}" cargo test --all-features echo -e "${GREEN}[OK] All tests passed.${NC}" -echo -e "\n${YELLOW}[8/10] Verifying Documentation...${NC}" +echo -e "\n${YELLOW}[8/11] Verifying Documentation...${NC}" # Ensure documentation builds without warnings (broken links, etc.) export RUSTDOCFLAGS="-D warnings" cargo doc --no-deps --document-private-items echo -e "${GREEN}[OK] Documentation builds successfully.${NC}" +echo -e "\n${YELLOW}[9/11] Verifying Docs.rs Compatibility (Nightly)...${NC}" +# docs.rs strictly uses the nightly compiler. We run a soft-fail check here. +if rustup toolchain list | grep -q nightly; then + # We suppress stdout to keep it clean, but let stderr show if it fails. + # We inline RUSTDOCFLAGS="" to override the strict warnings exported in Step 8. + if env RUSTDOCFLAGS="" cargo +nightly doc --no-deps > /dev/null 2>&1; then + echo -e "${GREEN}[OK] Nightly docs build successfully.${NC}" + else + echo -e "${YELLOW}[WARN] Nightly docs build failed!${NC}" + echo -e "${YELLOW} Your code compiles on Stable, but your docs.rs page will likely fail.${NC}" + echo -e "${YELLOW} This is usually an upstream dependency breaking on Nightly.${NC}" + echo -e "${YELLOW} -> Pipeline continuing because Stable is intact.${NC}" + fi +else + echo -e "${BLUE}[SKIP] Nightly toolchain not installed.${NC}" + echo -e "${BLUE} Run 'rustup toolchain install nightly' to enable docs.rs dry-runs.${NC}" +fi + # ============================================================================== # NEW STEP: Build All Examples, Then Run (excluding grids) # ============================================================================== -echo -e "\n${YELLOW}[9/10] Compiling All Examples...${NC}" +echo -e "\n${YELLOW}[10/11] Compiling All Examples...${NC}" # This mirrors CI Step 7: Ensures even grid.rs examples compile properly cargo build --examples echo -e "${GREEN}[OK] All examples compiled.${NC}" -echo -e "\n${YELLOW}[10/10] Running Examples (skipping *grid.rs)...${NC}" +echo -e "\n${YELLOW}[11/11] Running Examples (skipping *grid.rs except noop_grid)...${NC}" # Iterate over all .rs files in the examples directory for file in examples/*.rs; do # 1. Extract filename (e.g., "news_breakout_grid.rs") filename=$(basename "$file") - + # 2. Extract example name (remove .rs extension) example_name="${filename%.*}" - # 3. Filter: Check if filename contains "grid.rs" - if [[ "$filename" == *"grid.rs"* ]]; then + # 3. Filter: Check if filename contains "grid.rs", but explicitly ALLOW noop_grid + if [[ "$filename" == *"grid.rs"* && "$filename" != "noop_grid.rs" ]]; then echo -e "${BLUE}[SKIP] Long-running example: $example_name (Compiled, but not run)${NC}" continue fi echo -ne " Running example: $example_name ... " - + # 4. Run the example # Redirect stdout to /dev/null to keep terminal clean, but keep stderr for errors. if cargo run --example "$example_name" > /dev/null; then @@ -107,4 +125,4 @@ for file in examples/*.rs; do fi done -echo -e "\n${GREEN}>>> SUCCESS! All checks passed. Ready to push.${NC}" \ No newline at end of file +echo -e "\n${GREEN}>>> SUCCESS! All checks passed. Ready to push.${NC}" diff --git a/examples/noop_grid.rs b/examples/noop_grid.rs new file mode 100644 index 0000000..214e110 --- /dev/null +++ b/examples/noop_grid.rs @@ -0,0 +1,69 @@ +use anyhow::{Context, Result}; +use chapaty::prelude::*; +use rayon::iter::ParallelBridge; +use serde::Serialize; +use std::path::Path; +use std::sync::Arc; + +#[derive(Clone, Serialize)] +struct NoOpAgent; + +impl Agent for NoOpAgent { + fn identifier(&self) -> AgentIdentifier { + AgentIdentifier::Named(Arc::new("NoOpAgent".to_string())) + } + + fn reset(&mut self) {} + + fn act(&mut self, _obs: Observation) -> ChapatyResult { + // Return no actions, guaranteeing 0 trades + Ok(Actions::no_op()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let mut env = environment().await?; + let num_agents = 5; + let agents: Vec<(usize, NoOpAgent)> = (0..num_agents).map(|uid| (uid, NoOpAgent)).collect(); + + let leaderboard = env.evaluate_agents( + agents.into_iter().par_bridge(), + 10, // top_k + num_agents as u64, + )?; + + println!( + "Evaluation complete. Leaderboard size: {}", + leaderboard.as_df().height() + ); + + let export_dir = Path::new("examples/reports/noop_grid"); + let file_cfg = FileConfig::default().with_dir(export_dir); + leaderboard.to_file_sync(&file_cfg)?; + + println!("Saved leaderboard to {}", export_dir.display()); + + Ok(()) +} + +async fn environment() -> Result { + let preset = EnvPreset::BinanceBtcUsdt1d; + let file_stem = preset.to_string(); + let loc = StorageLocation::HuggingFace { version: None }; + let cfg = IoConfig::new(loc).with_file_stem(&file_stem); + + chapaty::load(preset, &cfg) + .await + .context("Failed to load trading environment") +} + +#[allow(dead_code)] // Provided for completeness +fn ohlcv_id() -> OhlcvId { + OhlcvId { + broker: DataBroker::Binance, + exchange: Exchange::Binance, + symbol: Symbol::Spot(SpotPair::BtcUsdt), + period: Period::Day(1), + } +} diff --git a/src/data/event.rs b/src/data/event.rs index 580060c..c645c6e 100644 --- a/src/data/event.rs +++ b/src/data/event.rs @@ -20,6 +20,7 @@ use crate::{ indicator::{EmaWindow, RsiWindow, SmaWindow}, }, error::{ChapatyError, ChapatyResult, DataError}, + gym::trading::types::TradeType, }; // ================================================================================================ @@ -28,8 +29,8 @@ use crate::{ /// Capability to check if a specific price was traded within an event's range. pub trait PriceReachable { - /// Returns true if the given `price` was observed by this market event. - fn price_reached(&self, price: Price) -> bool; + /// Returns true if the given `price` was reached or breached based on the intended trade direction. + fn price_reached(&self, price: Price, direction: TradeType) -> bool; } /// Capability to provide a "Close" price for resolving market state. @@ -38,6 +39,15 @@ pub trait ClosePriceProvider { fn close_timestamp(&self) -> DateTime; } +/// Capability to provide a computed technical indicator value at a specific point in time. +pub trait IndicatorValueProvider { + /// The computed value of the indicator (e.g., the EMA line level). + fn value(&self) -> Price; + + /// The timestamp at which this indicator value was recorded/calculated. + fn timestamp(&self) -> DateTime; +} + /// Defines the temporal properties of any financial event. pub trait MarketEvent { /// The canonical timestamp when the event is finished and the data @@ -112,7 +122,7 @@ pub struct Ohlcv { } impl PriceReachable for Ohlcv { - fn price_reached(&self, price: Price) -> bool { + fn price_reached(&self, price: Price, _direction: TradeType) -> bool { self.low.0 <= price.0 && price.0 <= self.high.0 } } @@ -201,8 +211,11 @@ pub struct TradeEvent { } impl PriceReachable for TradeEvent { - fn price_reached(&self, price: Price) -> bool { - self.price.0 == price.0 + fn price_reached(&self, target_price: Price, direction: TradeType) -> bool { + match direction { + TradeType::Long => self.price.0 <= target_price.0, + TradeType::Short => self.price.0 >= target_price.0, + } } } @@ -691,16 +704,19 @@ pub struct Ema { } impl PriceReachable for Ema { - fn price_reached(&self, price: Price) -> bool { - self.price.0 == price.0 + fn price_reached(&self, target_price: Price, direction: TradeType) -> bool { + match direction { + TradeType::Long => self.price.0 <= target_price.0, + TradeType::Short => self.price.0 >= target_price.0, + } } } -impl ClosePriceProvider for Ema { - fn close_price(&self) -> Price { +impl IndicatorValueProvider for Ema { + fn value(&self) -> Price { self.price } - fn close_timestamp(&self) -> DateTime { + fn timestamp(&self) -> DateTime { self.timestamp } } @@ -735,16 +751,19 @@ pub struct Rsi { } impl PriceReachable for Rsi { - fn price_reached(&self, price: Price) -> bool { - self.price.0 == price.0 + fn price_reached(&self, target_price: Price, direction: TradeType) -> bool { + match direction { + TradeType::Long => self.price.0 <= target_price.0, + TradeType::Short => self.price.0 >= target_price.0, + } } } -impl ClosePriceProvider for Rsi { - fn close_price(&self) -> Price { +impl IndicatorValueProvider for Rsi { + fn value(&self) -> Price { self.price } - fn close_timestamp(&self) -> DateTime { + fn timestamp(&self) -> DateTime { self.timestamp } } @@ -780,16 +799,19 @@ pub struct Sma { } impl PriceReachable for Sma { - fn price_reached(&self, price: Price) -> bool { - self.price.0 == price.0 + fn price_reached(&self, target_price: Price, direction: TradeType) -> bool { + match direction { + TradeType::Long => self.price.0 <= target_price.0, + TradeType::Short => self.price.0 >= target_price.0, + } } } -impl ClosePriceProvider for Sma { - fn close_price(&self) -> Price { +impl IndicatorValueProvider for Sma { + fn value(&self) -> Price { self.price } - fn close_timestamp(&self) -> DateTime { + fn timestamp(&self) -> DateTime { self.timestamp } } @@ -955,14 +977,15 @@ impl From for MarketId { mod test { use super::*; + /// Parse RFC3339 timestamp string to DateTime. + fn ts(s: &str) -> DateTime { + DateTime::parse_from_rfc3339(s).unwrap().with_timezone(&Utc) + } + #[test] fn tpo_as_df() { - let open_ts = DateTime::parse_from_rfc3339("2023-01-01T00:00:00Z") - .unwrap() - .with_timezone(&Utc); - let close_ts = DateTime::parse_from_rfc3339("2023-01-01T01:00:00Z") - .unwrap() - .with_timezone(&Utc); + let open_ts = ts("2023-01-01T00:00:00Z"); + let close_ts = ts("2023-01-01T01:00:00Z"); let tpo = Tpo { open_timestamp: open_ts, @@ -1017,12 +1040,8 @@ mod test { #[test] fn vp_as_df() { - let open_ts = DateTime::parse_from_rfc3339("2023-01-02T12:00:00Z") - .unwrap() - .with_timezone(&Utc); - let close_ts = DateTime::parse_from_rfc3339("2023-01-02T16:00:00Z") - .unwrap() - .with_timezone(&Utc); + let open_ts = ts("2023-01-02T12:00:00Z"); + let close_ts = ts("2023-01-02T16:00:00Z"); let vp = VolumeProfile { open_timestamp: open_ts, @@ -1082,4 +1101,204 @@ mod test { assert!(tb_quote.f64().expect("to be f64").get(0).is_none()); // First bin was None assert_eq!(tb_quote.f64().expect("to be f64").get(1), Some(60000.0)); // Second bin was Some } + + // ============================================================================ + // Price Reachability Tests + // ============================================================================ + + fn mock_trade(price: f64) -> TradeEvent { + TradeEvent { + timestamp: ts("2026-05-01T00:00:00Z"), + price: Price(price), + quantity: Quantity(1.0), + trade_id: None, + quote_asset_volume: None, + is_buyer_maker: None, + is_best_match: None, + } + } + + fn mock_ohlcv(low: f64, high: f64) -> Ohlcv { + Ohlcv { + open_timestamp: ts("2026-05-01T00:00:00Z"), + close_timestamp: ts("2026-05-01T00:00:00Z"), + open: Price(0.0), // Irrelevant for reachability + high: Price(high), + low: Price(low), + close: Price(0.0), // Irrelevant for reachability + volume: Quantity(0.0), + quote_asset_volume: None, + number_of_trades: None, + taker_buy_base_asset_volume: None, + taker_buy_quote_asset_volume: None, + } + } + + fn mock_sma(price: f64) -> Sma { + Sma { + timestamp: ts("2026-05-01T00:00:00Z"), + price: Price(price), + } + } + + fn mock_ema(price: f64) -> Ema { + Ema { + timestamp: ts("2026-05-01T00:00:00Z"), + price: Price(price), + } + } + + fn mock_rsi(value: f64) -> Rsi { + Rsi { + timestamp: ts("2026-05-01T00:00:00Z"), + price: Price(value), + } + } + + #[test] + fn test_ohlcv_reachability() { + let target = Price(50000.0); + + // 1. Exact Wick Touches (Edge Cases) + assert!( + mock_ohlcv(49000.0, 50000.0).price_reached(target, TradeType::Long), + "High wick exactly touches target" + ); + assert!( + mock_ohlcv(50000.0, 51000.0).price_reached(target, TradeType::Short), + "Low wick exactly touches target" + ); + + // 2. Complete Engulfing (Target is inside the candle body/wicks) + assert!(mock_ohlcv(49000.0, 51000.0).price_reached(target, TradeType::Long)); + + // 3. Flat Candle / Zero Variance (Doji tick) + assert!(mock_ohlcv(50000.0, 50000.0).price_reached(target, TradeType::Long)); + + // 4. Undershoots / Misses + assert!( + !mock_ohlcv(49000.0, 49999.999999).price_reached(target, TradeType::Long), + "Wick high barely misses" + ); + assert!( + !mock_ohlcv(50000.000001, 51000.0).price_reached(target, TradeType::Short), + "Wick low barely misses" + ); + } + + #[test] + fn test_trade_long_reachability() { + let target = Price(50000.0); + + // 1. Miss: Market price hasn't dropped enough. + assert!(!mock_trade(50000.000001).price_reached(target, TradeType::Long)); + + // 2. Exact Touch: Market prints exactly at our limit. + assert!(mock_trade(50000.0).price_reached(target, TradeType::Long)); + + // 3. Overshoot (Slippage/Gap in our favor): Market blew past our entry, offering a better price. + assert!(mock_trade(49990.0).price_reached(target, TradeType::Long)); + } + + #[test] + fn test_trade_short_reachability() { + let target = Price(50000.0); + + // 1. Miss: Market price hasn't risen enough. + assert!(!mock_trade(49999.999999).price_reached(target, TradeType::Short)); + + // 2. Exact Touch: Market prints exactly at our limit. + assert!(mock_trade(50000.0).price_reached(target, TradeType::Short)); + + // 3. Overshoot (Slippage/Gap in our favor): Market blew past our entry, offering a better price. + assert!(mock_trade(50010.0).price_reached(target, TradeType::Short)); + } + + #[test] + fn test_sma_long_reachability() { + // We want to trigger a Long when SMA drops to 50000.0 or below + let target = Price(50000.0); + + // 1. Undershoot (Miss): SMA is at 50000.1, hasn't dropped enough. + assert!(!mock_sma(50000.1).price_reached(target, TradeType::Long)); + + // 2. Exact Touch: SMA hits exactly 50000.0. + assert!(mock_sma(50000.0).price_reached(target, TradeType::Long)); + + // 3. Overshoot (Gap down): SMA gaps down to 49000.0, completely skipping 50000.0. + assert!(mock_sma(49000.0).price_reached(target, TradeType::Long)); + } + + #[test] + fn test_sma_short_reachability() { + // We want to trigger a Short when SMA rises to 50000.0 or above + let target = Price(50000.0); + + // 1. Undershoot (Miss): SMA is at 49999.9, hasn't risen enough. + assert!(!mock_sma(49999.9).price_reached(target, TradeType::Short)); + + // 2. Exact Touch: SMA hits exactly 50000.0. + assert!(mock_sma(50000.0).price_reached(target, TradeType::Short)); + + // 3. Overshoot (Gap up): SMA gaps up to 51000.0, completely skipping 50000.0. + assert!(mock_sma(51000.0).price_reached(target, TradeType::Short)); + } + + #[test] + fn test_ema_long_reachability() { + let target = Price(100.5); + + // Test precision boundaries often encountered in floating-point math + assert!(!mock_ema(100.50000001).price_reached(target, TradeType::Long)); + assert!(mock_ema(100.5).price_reached(target, TradeType::Long)); + assert!(mock_ema(100.49999999).price_reached(target, TradeType::Long)); + } + + #[test] + fn test_ema_short_reachability() { + let target = Price(100.5); + + assert!( + !mock_ema(100.49999999).price_reached(target, TradeType::Short), + "EMA is just below target" + ); + assert!( + mock_ema(100.5).price_reached(target, TradeType::Short), + "EMA exactly hits target" + ); + assert!( + mock_ema(100.50000001).price_reached(target, TradeType::Short), + "EMA spikes just above target" + ); + } + + #[test] + fn test_rsi_oversold_long() { + // Classic strategy: Buy when RSI drops below 30 + let target = Price(30.0); + + // RSI is 31 (Not oversold enough) + assert!(!mock_rsi(31.0).price_reached(target, TradeType::Long)); + + // RSI is exactly 30 (Trigger) + assert!(mock_rsi(30.0).price_reached(target, TradeType::Long)); + + // RSI plummets to 15 (Trigger) + assert!(mock_rsi(15.0).price_reached(target, TradeType::Long)); + } + + #[test] + fn test_rsi_overbought_short() { + // Classic strategy: Sell when RSI spikes above 70 + let target = Price(70.0); + + // RSI is 69.9 (Not overbought enough) + assert!(!mock_rsi(69.9).price_reached(target, TradeType::Short)); + + // RSI is exactly 70.0 (Trigger) + assert!(mock_rsi(70.0).price_reached(target, TradeType::Short)); + + // RSI rockets to 85.5 (Trigger) + assert!(mock_rsi(85.5).price_reached(target, TradeType::Short)); + } } diff --git a/src/data/view.rs b/src/data/view.rs index d9f45c2..4b68c3c 100644 --- a/src/data/view.rs +++ b/src/data/view.rs @@ -11,6 +11,7 @@ use crate::{ }, }, error::{ChapatyError, ChapatyResult, DataError, SystemError}, + gym::trading::types::TradeType, sim::{ cursor::{Cursor, StreamEntity}, cursor_group::CursorGroup, @@ -67,6 +68,7 @@ pub trait PriceCheckableView { &self, target_symbol: &Symbol, price: Price, + direction: TradeType, since_ts: DateTime, ) -> bool; } @@ -111,6 +113,7 @@ where &self, target_symbol: &Symbol, price: Price, + direction: TradeType, since_ts: DateTime, ) -> bool { // Linear scan of all streams in this view is cheap (M < 100). @@ -126,7 +129,7 @@ where .iter() .rev() .take_while(|e| e.point_in_time() > since_ts) - .any(|e| e.price_reached(price)) + .any(|e| e.price_reached(price, direction)) }) } } @@ -209,17 +212,22 @@ impl<'env> MarketView<'env> { } /// Returns `true` if `price` was reached by any *new* event since the last step. - pub fn reached_price(&self, price: Price, target_symbol: &Symbol) -> bool { + pub fn reached_price( + &self, + price: Price, + target_symbol: &Symbol, + direction: TradeType, + ) -> bool { let prev = self.previous_timestamp(); self.all_price_checkable_views() .into_iter() - .any(|view| view.reached_price_since(target_symbol, price, prev)) + .any(|view| view.reached_price_since(target_symbol, price, direction, prev)) } /// Resolves the most recent, non-leaky close price. pub fn try_resolved_close_price(&self, target_symbol: &Symbol) -> ChapatyResult { let best_price = self - .all_close_price_views() + .close_price_views() .into_iter() .filter_map(|view| view.latest_price_for_symbol(target_symbol)) .max_by_key(|(ts, _)| *ts) @@ -261,8 +269,8 @@ impl<'env> MarketView<'env> { /// Returns a stack-allocated array of all views that provide a canonical market "Close" price. #[inline] - fn all_close_price_views(&self) -> [&dyn ClosePriceView; 5] { - [&self.ohlcv, &self.trades, &self.ema, &self.sma, &self.rsi] + fn close_price_views(&self) -> [&dyn ClosePriceView; 2] { + [&self.ohlcv, &self.trades] } } @@ -398,12 +406,12 @@ mod test { // Focus: PriceCheckableView and ClosePriceView implementations // ============================================================================ - // ========================================================================== + // ============================================================================ // Test: reached_price (Reverse Iteration) // Constraint: Must iterate in reverse // The view checks if any NEW event (prev_ts < event_ts <= current_ts) // hit the price. - // ========================================================================== + // ============================================================================ #[test] fn test_reached_price_basic_hit() { @@ -434,23 +442,23 @@ mod test { ); assert!( - market_view.reached_price(Price(110.0), &symbol), + market_view.reached_price(Price(110.0), &symbol, TradeType::Long), "High (110.0) should be reached" ); assert!( - market_view.reached_price(Price(90.0), &symbol), + market_view.reached_price(Price(90.0), &symbol, TradeType::Long), "Low (90.0) should be reached" ); assert!( - market_view.reached_price(Price(100.0), &symbol), + market_view.reached_price(Price(100.0), &symbol, TradeType::Long), "Price in range (100.0) should be reached" ); assert!( - !market_view.reached_price(Price(120.0), &symbol), + !market_view.reached_price(Price(120.0), &symbol, TradeType::Long), "Price above high (120.0) should NOT be reached" ); assert!( - !market_view.reached_price(Price(80.0), &symbol), + !market_view.reached_price(Price(80.0), &symbol, TradeType::Long), "Price below low (80.0) should NOT be reached" ); } @@ -496,11 +504,11 @@ mod test { ); assert!( - !market_view.reached_price(Price(150.0), &symbol), + !market_view.reached_price(Price(150.0), &symbol, TradeType::Long), "Price 150 is in OLD candle (<=previous_ts), should be IGNORED" ); assert!( - market_view.reached_price(Price(115.0), &symbol), + market_view.reached_price(Price(115.0), &symbol, TradeType::Long), "Price 115 is in NEW candle, should be reached" ); } @@ -566,11 +574,11 @@ mod test { // Both candles are "new" (point_in_time > previous_ts), so both should be checked assert!( - market_view.reached_price(Price(500.0), &symbol), + market_view.reached_price(Price(500.0), &symbol, TradeType::Long), "3m candle (new) contains 500, should be reached" ); assert!( - market_view.reached_price(Price(200.0), &symbol), + market_view.reached_price(Price(200.0), &symbol, TradeType::Long), "5m candle (new) contains 200, should be reached" ); } @@ -618,11 +626,11 @@ mod test { ); assert!( - !market_view.reached_price(Price(111.0), &symbol), + !market_view.reached_price(Price(111.0), &symbol, TradeType::Long), "Price 111 is in candle at EXACTLY previous_ts, should be EXCLUDED (not > since_ts)" ); assert!( - market_view.reached_price(Price(222.0), &symbol), + market_view.reached_price(Price(222.0), &symbol, TradeType::Long), "Price 222 is in candle 1 second after previous_ts, should be INCLUDED" ); } diff --git a/src/gym/trading/agent/crossover.rs b/src/gym/trading/agent/crossover.rs index 729c496..c40edae 100644 --- a/src/gym/trading/agent/crossover.rs +++ b/src/gym/trading/agent/crossover.rs @@ -6,7 +6,7 @@ use serde::Serialize; use crate::{ data::{ domain::{Quantity, TradeId}, - event::{ClosePriceProvider, OhlcvId, SmaId}, + event::{IndicatorValueProvider, OhlcvId, SmaId}, view::StreamView, }, error::ChapatyResult, @@ -186,8 +186,8 @@ impl Agent for PrecomputedCrossover { return Ok(Actions::no_op()); }; - let fast = fast_evt.close_price(); - let slow = slow_evt.close_price(); + let fast = fast_evt.value(); + let slow = slow_evt.value(); // 2. Position Management let agent_id = self.identifier(); diff --git a/src/gym/trading/env.rs b/src/gym/trading/env.rs index 9ca3c42..984671a 100644 --- a/src/gym/trading/env.rs +++ b/src/gym/trading/env.rs @@ -335,12 +335,11 @@ impl Environment { thread_env.eval(agent)?; let pp = thread_env.portfolio_performance()?; - let accessor = pp.accessor()?; let mut entries = Vec::with_capacity(PortfolioPerformanceCol::COUNT); for metric in PortfolioPerformanceCol::iter() { - if let Some(value) = accessor.get(metric) { + if let Some(value) = pp.first(metric) { entries.push(LeaderboardEntry { agent_uid, metric, diff --git a/src/gym/trading/state/active.rs b/src/gym/trading/state/active.rs index a1ab891..a8e73c2 100644 --- a/src/gym/trading/state/active.rs +++ b/src/gym/trading/state/active.rs @@ -210,12 +210,12 @@ pub(super) fn update( // A. Detect Triggers (Independent Checks) let tp_exit = trade .take_profit - .filter(|&tp| ctx.market.reached_price(tp, symbol)) + .filter(|&tp| ctx.market.reached_price(tp, symbol, trade.trade_type)) .map(|tp| (TerminationReason::TakeProfit, tp.0)); let sl_exit = trade .stop_loss - .filter(|&sl| ctx.market.reached_price(sl, symbol)) + .filter(|&sl| ctx.market.reached_price(sl, symbol, trade.trade_type)) .map(|sl| (TerminationReason::StopLoss, sl.0)); // B. Resolve Conflict (Priority Logic) diff --git a/src/gym/trading/state/pending.rs b/src/gym/trading/state/pending.rs index cd371fc..6d0fbd7 100644 --- a/src/gym/trading/state/pending.rs +++ b/src/gym/trading/state/pending.rs @@ -121,7 +121,9 @@ pub(super) fn update( ctx: &UpdateCtx, ) -> ChapatyResult<(State, f64)> { let limit_price = trade.state.limit_price; - let hit_entry = ctx.market.reached_price(limit_price, &m_id.symbol); + let hit_entry = ctx + .market + .reached_price(limit_price, &m_id.symbol, trade.trade_type); if !hit_entry { return Ok((State::Pending(trade), 0.0)); diff --git a/src/io.rs b/src/io.rs index ef97688..c565cfd 100644 --- a/src/io.rs +++ b/src/io.rs @@ -133,7 +133,7 @@ pub enum StorageLocation<'a> { options: CloudOptions, }, /// Local storage location (directory only, not a file path). - Local(&'a Path), + Local { path: &'a Path }, /// Hugging Face Hosted Dataset. /// @@ -161,7 +161,7 @@ impl<'a> StorageLocation<'a> { }) .map_err(|e| ChapatyError::Io(IoError::WriterCreation(e.to_string()))) } - Self::Local(path) => { + Self::Local { path } => { if !path.exists() { std::fs::create_dir_all(path).map_err(|e| { ChapatyError::Io(IoError::WriterCreation(format!( @@ -202,7 +202,7 @@ impl<'a> StorageLocation<'a> { None, )) } - Self::Local(path) => { + Self::Local { path } => { let full_path = path.join(filename); open_local_file(&full_path, buffer_size) } diff --git a/src/prelude.rs b/src/prelude.rs index c8512e0..1a39391 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -4,11 +4,13 @@ pub use crate::report::io::*; pub use crate::transport::source::*; // 2. The Core "Loop", Agents & States -pub use crate::data::episode::EpisodeLength; -pub use crate::gym::AgentIdentifier; +pub use crate::data::episode::*; pub use crate::gym::trading::{ Env, action::*, agent::*, config::*, env::*, observation::*, state::*, types::*, }; +pub use crate::gym::{ + AgentIdentifier, EnvStatus, GridAxis, InvalidActionPenalty, Reward, StepOutcome, +}; // 3. Financial Domain Types (Primitives & Classifications) // Safely pulls in Price, Quantity, Tick, Volume, TradeId, SpotPair, etc. diff --git a/src/report/leaderboard.rs b/src/report/leaderboard.rs index 2eb6fa4..55419b5 100644 --- a/src/report/leaderboard.rs +++ b/src/report/leaderboard.rs @@ -139,13 +139,13 @@ impl Report for Leaderboard { /// A report tracking the top-k performing agents for each performance metric. /// -/// This report maintains a **min-heap** (`BinaryHeap>`) -/// to efficiently track the **top-k best-performing agents** per metric. +/// This report maintains a **min-heap** (`BinaryHeap>`) +/// to efficiently track the **top-k best-performing agents** per metric. #[derive(Clone, Debug)] pub(crate) struct AgentLeaderboard { /// A mapping from performance metrics to **min-heaps** tracking the top-k performing agents. /// - /// Each entry in the heap is wrapped in `Reverse` to ensure that the smallest (i.e. + /// Each entry in the heap is wrapped in `Reverse` to ensure that the smallest (i.e. /// the worst-performing among the top-k) entry is always at the top. pub top_per_metric: SortedVecMap>>, @@ -191,6 +191,11 @@ impl AgentLeaderboard { } pub(crate) fn update(&mut self, new_entries: &[LeaderboardEntry], agent: T) { + // GUARD: If the agent produced no trades/metrics, ignore it entirely. + if new_entries.is_empty() { + return; + } + let mut is_global_winner = false; let mut potentially_evicted = SmallVec::<[u64; METRIC_COUNT]>::new(); @@ -488,6 +493,39 @@ mod tests { // 2. `update` Logic (The Hot Loop) // ============================================================================================ + #[test] + fn test_update_with_empty_entries_does_not_panic() { + // Arrange + let k = 3; + let mut board = AgentLeaderboard::::new(k); + let metric = PortfolioPerformanceCol::TradeSharpeRatio; + + // Fill with one valid agent + let valid_entry = make_entry(1, metric, 10.0); + board.update(&[valid_entry], TestAgent::new(1)); + + // Act: Attempt to update with an empty array (the 0-trade edge case) + board.update(&[], TestAgent::new(2)); + + // Assert: + // 1. We did not panic. + // 2. Agent 2 is not in the heap. + // 3. Agent 2 is not in the cache. + let heap = board.top_per_metric.get(&metric).unwrap(); + assert_eq!(heap.len(), 1, "Heap should only contain the valid agent"); + + let uids = heap.iter().map(|r| r.0.agent_uid).collect::>(); + assert!( + !uids.contains(&2), + "Zero-entry agent should not be in the heap" + ); + + assert!( + !board.agent_data.contains_key(&2), + "Zero-entry agent should not be cached" + ); + } + #[test] fn test_update_fills_capacity() { // Arrange diff --git a/src/report/portfolio_performance.rs b/src/report/portfolio_performance.rs index c7ef490..08e0744 100644 --- a/src/report/portfolio_performance.rs +++ b/src/report/portfolio_performance.rs @@ -31,6 +31,23 @@ pub struct PortfolioPerformance { pub df: DataFrame, } +impl PortfolioPerformance { + /// Efficiently extracts a metric value from a specific row. + pub fn get(&self, metric: PortfolioPerformanceCol, row: usize) -> Option { + // GUARD: Polars ChunkedArray::get() will panic if called on an empty series. + // We must explicitly verify the row exists within the dataframe bounds first. + if row >= self.df.height() { + return None; + } + + self.df.column(metric.as_str()).ok()?.f64().ok()?.get(row) + } + + /// Safely extracts the metric value from the first row. + pub fn first(&self, metric: PortfolioPerformanceCol) -> Option { + self.get(metric, 0) + } +} impl Default for PortfolioPerformance { fn default() -> Self { let df = DataFrame::empty_with_schema(&Self::to_schema()); @@ -172,38 +189,6 @@ impl TryFrom<&GroupedJournal<'_>> for PortfolioPerformance { } } -pub struct PortfolioPerformanceAccessor<'a> { - df: &'a DataFrame, -} - -impl PortfolioPerformance { - /// Creates a safe accessor for scalar value extraction. - /// - /// # Errors - /// Returns an error if the report is **Grouped** (rows > 1) or **Empty**. - /// This prevents logical errors where users might mistakenly read the first group's - /// result as a global metric. - pub fn accessor(&self) -> ChapatyResult> { - match self.df.height() { - 1 => Ok(PortfolioPerformanceAccessor { df: &self.df }), - 0 => Err(DataError::DataFrame("Report is empty".to_string()).into()), - n => Err(DataError::DataFrame(format!( - "Cannot extract scalar from grouped report (rows={n})." - )) - .into()), - } - } -} - -impl<'a> PortfolioPerformanceAccessor<'a> { - /// Efficiently extracts a metric value from the single-row report. - /// - /// Returns `None` if the value is null (e.g., Sharpe Ratio with 0 volatility). - pub fn get(&self, metric: PortfolioPerformanceCol) -> Option { - self.df.column(metric.as_str()).ok()?.f64().ok()?.get(0) - } -} - fn exprs(cfg: RiskMetricsConfig) -> Vec { let return_col = JournalCol::RealizedReturnDollars; let exit_reason_col = JournalCol::ExitReason; @@ -1471,10 +1456,8 @@ mod tests { let journal = load_journal_fixture(); let perf = PortfolioPerformance::try_from(&journal).expect("Conversion failed"); - let accessor = perf.accessor().expect("Should create accessor"); - - let net_profit = accessor - .get(PortfolioPerformanceCol::NetProfit) + let net_profit = perf + .first(PortfolioPerformanceCol::NetProfit) .expect("Net profit should be available"); assert_eq!(net_profit, 2000.0, "Net profit via accessor should be 2000"); @@ -1498,6 +1481,37 @@ mod tests { assert_eq!(df.height(), 0, "Empty journal should produce 0 rows"); } + // ======================================================================== + // Test: Empty PortfolioPerformance + // ======================================================================== + + #[test] + fn test_get_and_first_safely_handle_empty_dataframe() { + let perf = PortfolioPerformance::default(); + let metric = PortfolioPerformanceCol::TradeSharpeRatio; + + // Act & Assert 1: .first() must return None without panicking + let first_val = perf.first(metric); + assert_eq!( + first_val, None, + "first() on an empty dataframe must return None" + ); + + // Act & Assert 2: .get(0) must return None without panicking + let get_zero = perf.get(metric, 0); + assert_eq!( + get_zero, None, + "get(0) on an empty dataframe must return None" + ); + + // Act & Assert 3: .get() out of bounds on empty dataframe must return None + let get_out_of_bounds = perf.get(metric, 99); + assert_eq!( + get_out_of_bounds, None, + "get() out of bounds must return None" + ); + } + // ======================================================================== // Test: Trade Return Variance // ======================================================================== diff --git a/src/sim/data.rs b/src/sim/data.rs index e723fd8..e8fda35 100644 --- a/src/sim/data.rs +++ b/src/sim/data.rs @@ -774,7 +774,7 @@ mod tests { // 2. Set up temp directory for cache let temp_dir = std::env::temp_dir().join("chapaty_test_cache"); - let storage = StorageLocation::Local(&temp_dir); + let storage = StorageLocation::Local { path: &temp_dir }; let io_cfg = IoConfig::new(storage); // 3. Write to file using SimulationData::write() @@ -823,7 +823,7 @@ mod tests { // 2. Set up temp directory with a custom file stem let temp_dir = std::env::temp_dir().join("chapaty_test_cache_custom"); - let storage = StorageLocation::Local(&temp_dir); + let storage = StorageLocation::Local { path: &temp_dir }; let io_cfg = IoConfig::new(storage).with_file_stem(CUSTOM_NAME); // 3. Write using the custom filename