From f622e4dde27185f012cbfdac79b9fdbb7479a56f Mon Sep 17 00:00:00 2001 From: LenWilliamson Date: Sun, 3 May 2026 19:14:09 +0200 Subject: [PATCH 1/2] bump version to 1.1.3 --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 935efcb..4b78e47 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -467,7 +467,7 @@ dependencies = [ [[package]] name = "chapaty" -version = "1.1.2" +version = "1.1.3" dependencies = [ "anyhow", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index 9a11c31..edf3186 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "chapaty" -version = "1.1.2" +version = "1.1.3" edition = "2024" authors = ["Len Williamson "] description = "An event-driven Rust engine for building and evaluating quantitative trading agents. Features a Gym-style API for algorithmic backtesting and reinforcement learning." From dbc15df86f741dc6f4d0431cf809189ef27a570f Mon Sep 17 00:00:00 2001 From: LenWilliamson Date: Wed, 13 May 2026 21:19:27 +0200 Subject: [PATCH 2/2] refactor: update evaluate_agents API and fix progress bar starvation - Changed `evaluate_agents` to accept a `Vec<(usize, T)>` instead of an `impl ParallelIterator`. - Removed the `stream_len` argument, as the exact bounds are now natively known from the vector's length. - Fixed the UI freezing issue by replacing the manual `pb.inc(1)` inside the CPU-bound `try_fold` with `indicatif`'s `.progress_with()`. This ensures smooth real-time progress bar updates during parallel execution. - Added new environment presets (`envpreset`) for backtesting. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- Cargo.lock | 126 ++++++++++++------------- Cargo.toml | 6 +- examples/news_breakout_grid.rs | 11 +-- examples/news_fade_grid.rs | 11 +-- examples/noop_grid.rs | 7 +- src/gym/trading/agent/news/breakout.rs | 56 +++-------- src/gym/trading/agent/news/fade.rs | 54 +++-------- src/gym/trading/agent/news/hybrid.rs | 39 ++++---- src/gym/trading/config.rs | 126 ++++++++++++++++++++++--- src/gym/trading/env.rs | 21 ++--- 10 files changed, 240 insertions(+), 217 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4b78e47..5866400 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -358,6 +358,15 @@ dependencies = [ "alloc-stdlib", ] +[[package]] +name = "bs58" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf88ba1141d185c399bee5288d850d63b8369520c1eafc32a0430b5b6c287bf4" +dependencies = [ + "tinyvec", +] + [[package]] name = "bumpalo" version = "3.20.2" @@ -432,9 +441,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.61" +version = "1.2.62" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d16d90359e986641506914ba71350897565610e87ce0ad9e6f28569db3dd5c6d" +checksum = "a1dce859f0832a7d088c4f1119888ab94ef4b5d6795d1ce05afb7fe159d79f98" dependencies = [ "find-msvc-tools", "jobserver", @@ -1223,9 +1232,9 @@ checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" [[package]] name = "h2" -version = "0.4.13" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" +checksum = "171fefbc92fe4a4de27e0698d6a5b392d6a0e333506bc49133760b3bcf948733" dependencies = [ "atomic-waker", "bytes", @@ -1295,9 +1304,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.17.0" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" +checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a" [[package]] name = "heapless" @@ -1664,7 +1673,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" dependencies = [ "equivalent", - "hashbrown 0.17.0", + "hashbrown 0.17.1", "serde", "serde_core", ] @@ -1677,6 +1686,7 @@ checksum = "25470f23803092da7d239834776d653104d551bc4d7eacaf31e6837854b8e9eb" dependencies = [ "console", "portable-atomic", + "rayon", "unicode-width", "unit-prefix", "web-time", @@ -1688,16 +1698,6 @@ version = "2.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" -[[package]] -name = "iri-string" -version = "0.7.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25e659a4bb38e810ebc252e53b5814ff908a8c58c2a9ce2fae1bbec24cbf4e20" -dependencies = [ - "memchr", - "serde", -] - [[package]] name = "itertools" version = "0.14.0" @@ -1725,9 +1725,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.97" +version = "0.3.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1840c94c045fbcf8ba2812c95db44499f7c64910a912551aaaa541decebcacf" +checksum = "67df7112613f8bfd9150013a0314e196f4800d3201ae742489d999db2f979f08" dependencies = [ "cfg-if", "futures-util", @@ -2069,15 +2069,14 @@ checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" [[package]] name = "openssl" -version = "0.10.78" +version = "0.10.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f38c4372413cdaaf3cc79dd92d29d7d9f5ab09b51b10dded508fb90bb70b9222" +checksum = "bf0b434746ee2832f4f0baf10137e1cabb18cbe6912c69e2e33263c45250f542" dependencies = [ "bitflags", "cfg-if", "foreign-types", "libc", - "once_cell", "openssl-macros", "openssl-sys", ] @@ -2101,9 +2100,9 @@ checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" [[package]] name = "openssl-sys" -version = "0.9.114" +version = "0.9.115" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13ce1245cd07fcc4cfdb438f7507b0c7e4f3849a69fd84d52374c66d83741bb6" +checksum = "158fe5b292746440aa6e7a7e690e55aeb72d41505e2804c23c6973ad0e9c9781" dependencies = [ "cc", "libc", @@ -2203,18 +2202,18 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.1.11" +version = "1.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1749c7ed4bcaf4c3d0a3efc28538844fb29bcdd7d2b67b2be7e20ba861ff517" +checksum = "cbf0d9e68100b3a7989b4901972f265cd542e560a3a8a724e1e20322f4d06ce9" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.11" +version = "1.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6" +checksum = "a990e22f43e84855daf260dded30524ef4a9021cc7541c26540500a50b624389" dependencies = [ "proc-macro2", "quote", @@ -3402,9 +3401,9 @@ dependencies = [ [[package]] name = "rust_decimal" -version = "1.41.0" +version = "1.42.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ce901f9a19d251159075a4c37af514c3b8ef99c22e02dd8c19161cf397ee94a" +checksum = "0c5108e3d4d903e21aac27f12ba5377b6b34f9f44b325e4894c7924169d06995" dependencies = [ "arrayvec", "borsh", @@ -3666,11 +3665,12 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.19.0" +version = "3.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f05839ce67618e14a09b286535c0d9c94e85ef25469b0e13cb4f844e5593eb19" +checksum = "e72c1c2cb7b223fafb600a619537a871c2818583d619401b785e7c0b746ccde2" dependencies = [ "base64", + "bs58", "chrono", "hex", "indexmap 1.9.3", @@ -3685,9 +3685,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.19.0" +version = "3.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf2ebbe86054f9b45bc3881e865683ccfaccce97b9b4cb53f3039d67f355a334" +checksum = "b90c488738ecb4fb0262f41f43bc40efc5868d9fb744319ddf5f5317f417bfac" dependencies = [ "darling", "proc-macro2", @@ -3771,9 +3771,9 @@ checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" [[package]] name = "siphasher" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" +checksum = "8ee5873ec9cce0195efcb7a4e9507a04cd49aec9c83d0389df45b1ef7ba2e649" [[package]] name = "slab" @@ -4130,9 +4130,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.52.1" +version = "1.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b67dee974fe86fd92cc45b7a95fdd2f99a36a6d7b0d431a231178d3d670bbcc6" +checksum = "8fc7f01b389ac15039e4dc9531aa973a135d7a4135281b12d7c1bc79fd57fffe" dependencies = [ "bytes", "libc", @@ -4232,9 +4232,9 @@ dependencies = [ [[package]] name = "tonic" -version = "0.14.5" +version = "0.14.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fec7c61a0695dc1887c1b53952990f3ad2e3a31453e1f49f10e75424943a93ec" +checksum = "ac2a5518c70fa84342385732db33fb3f44bc4cc748936eb5833d2df34d6445ef" dependencies = [ "async-trait", "axum", @@ -4261,9 +4261,9 @@ dependencies = [ [[package]] name = "tonic-build" -version = "0.14.5" +version = "0.14.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1882ac3bf5ef12877d7ed57aad87e75154c11931c2ba7e6cde5e22d63522c734" +checksum = "c68f61875ac5293cf72e6c8cf0158086428c82c37229e98c840878f1706b0322" dependencies = [ "prettyplease", "proc-macro2", @@ -4273,9 +4273,9 @@ dependencies = [ [[package]] name = "tonic-prost" -version = "0.14.5" +version = "0.14.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a55376a0bbaa4975a3f10d009ad763d8f4108f067c7c2e74f3001fb49778d309" +checksum = "50849f68853be452acf590cde0b146665b8d507b3b8af17261df47e02c209ea0" dependencies = [ "bytes", "prost", @@ -4284,9 +4284,9 @@ dependencies = [ [[package]] name = "tonic-prost-build" -version = "0.14.5" +version = "0.14.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3144df636917574672e93d0f56d7edec49f90305749c668df5101751bb8f95a" +checksum = "654e5643eff75d7f8c99197ce1440ed19a3474eada74c12bbac488b2cafdae27" dependencies = [ "prettyplease", "proc-macro2", @@ -4319,20 +4319,20 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.8" +version = "0.6.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" +checksum = "68d6fdd9f81c2819c9a8b0e0cd91660e7746a8e6ea2ba7c6b2b057985f6bcb51" dependencies = [ "bitflags", "bytes", "futures-util", "http", "http-body", - "iri-string", "pin-project-lite", "tower", "tower-layer", "tower-service", + "url", ] [[package]] @@ -4666,9 +4666,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.120" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df52b6d9b87e0c74c9edfa1eb2d9bf85e5d63515474513aa50fa181b3c4f5db1" +checksum = "49ace1d07c165b0864824eee619580c4689389afa9dc9ed3a4c75040d82e6790" dependencies = [ "cfg-if", "once_cell", @@ -4680,9 +4680,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.70" +version = "0.4.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af934872acec734c2d80e6617bbb5ff4f12b052dd8e6332b0817bce889516084" +checksum = "96492d0d3ffba25305a7dc88720d250b1401d7edca02cc3bcd50633b424673b8" dependencies = [ "js-sys", "wasm-bindgen", @@ -4690,9 +4690,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.120" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78b1041f495fb322e64aca85f5756b2172e35cd459376e67f2a6c9dffcedb103" +checksum = "8e68e6f4afd367a562002c05637acb8578ff2dea1943df76afb9e83d177c8578" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4700,9 +4700,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.120" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dcd0ff20416988a18ac686d4d4d0f6aae9ebf08a389ff5d29012b05af2a1b41" +checksum = "d95a9ec35c64b2a7cb35d3fead40c4238d0940c86d107136999567a4703259f2" dependencies = [ "bumpalo", "proc-macro2", @@ -4713,9 +4713,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.120" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49757b3c82ebf16c57d69365a142940b384176c24df52a087fb748e2085359ea" +checksum = "c4e0100b01e9f0d03189a92b96772a1fb998639d981193d7dbab487302513441" dependencies = [ "unicode-ident", ] @@ -4769,9 +4769,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.97" +version = "0.3.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2eadbac71025cd7b0834f20d1fe8472e8495821b4e9801eb0a60bd1f19827602" +checksum = "4b572dff8bcf38bad0fa19729c89bb5748b2b9b1d8be70cf90df697e3a8f32aa" dependencies = [ "js-sys", "wasm-bindgen", @@ -5240,9 +5240,9 @@ dependencies = [ [[package]] name = "zerofrom" -version = "0.1.7" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69faa1f2a1ea75661980b013019ed6687ed0e83d069bc1114e2cc74c6c04c4df" +checksum = "0ec05a11813ea801ff6d75110ad09cd0824ddba17dfe17128ea0d5f68e6c5272" dependencies = [ "zerofrom-derive", ] diff --git a/Cargo.toml b/Cargo.toml index edf3186..675e1a7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,7 +40,7 @@ crossbeam = "0.8" futures = "0.3" hf-hub = "0.5.0" humantime = "2.3.0" -indicatif = "0.18" +indicatif = {version ="0.18", features = ["rayon"]} itertools = "0.14" ndarray = "0.17" object_store = "0.12.5" @@ -72,10 +72,10 @@ prost-types = "0.14" rand = "0.10" rayon = "1.12" regex = "1.12" -rust_decimal = "1.41" +rust_decimal = "1.42" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -serde_with = { version = "3.18", features = ["chrono"] } +serde_with = { version = "3.20", features = ["chrono"] } smallvec = { version = "1.15", features = ["serde"] } strum = { version = "0.28", features = ["derive"] } strum_macros = "0.28" diff --git a/examples/news_breakout_grid.rs b/examples/news_breakout_grid.rs index 098a0d3..4144dee 100644 --- a/examples/news_breakout_grid.rs +++ b/examples/news_breakout_grid.rs @@ -5,7 +5,6 @@ use chapaty::{ gym::trading::agent::news::breakout::{NewsBreakout, NewsBreakoutGrid}, prelude::*, }; -use rayon::iter::{ParallelBridge, ParallelIterator}; // === BEGIN JEMALLOC CONFIG === #[cfg(target_os = "linux")] @@ -24,13 +23,9 @@ async fn main() -> Result<()> { let mut env = environment().await?; let build_time = build_start.elapsed(); - let (stream_len, agents_iter) = news_breakout_grid(); + let agents = news_breakout_grid(); let grid_backtest_start = Instant::now(); - let leaderboard = env.evaluate_agents( - agents_iter.collect::>().into_iter().par_bridge(), - 100, - stream_len as u64, - )?; + let leaderboard = env.evaluate_agents(agents, 100)?; let grid_backtest_time = grid_backtest_start.elapsed(); let path = Path::new("examples/reports/news_breakout"); @@ -47,7 +42,7 @@ async fn main() -> Result<()> { // Helper Functions // ================================================================================================ -fn news_breakout_grid() -> (usize, impl ParallelIterator) { +fn news_breakout_grid() -> Vec<(usize, NewsBreakout)> { NewsBreakoutGrid::baseline(economic_calendar_id(), ohlcv_id()) .expect("Failed to create baseline grid") // Optional: Constrain the grid for a quick demo run diff --git a/examples/news_fade_grid.rs b/examples/news_fade_grid.rs index 9c3a964..6abb794 100644 --- a/examples/news_fade_grid.rs +++ b/examples/news_fade_grid.rs @@ -5,7 +5,6 @@ use chapaty::{ gym::trading::agent::news::fade::{NewsFade, NewsFadeGrid}, prelude::*, }; -use rayon::iter::{ParallelBridge, ParallelIterator}; // === BEGIN JEMALLOC CONFIG === #[cfg(target_os = "linux")] @@ -24,13 +23,9 @@ async fn main() -> Result<()> { let mut env = environment().await?; let build_time = build_start.elapsed(); - let (stream_len, agents_iter) = news_fade_grid(); + let agents = news_fade_grid(); let grid_backtest_start = Instant::now(); - let leaderboard = env.evaluate_agents( - agents_iter.collect::>().into_iter().par_bridge(), - 100, - stream_len as u64, - )?; + let leaderboard = env.evaluate_agents(agents, 100)?; let grid_backtest_time = grid_backtest_start.elapsed(); let path = Path::new("examples/reports/news_fade"); @@ -47,7 +42,7 @@ async fn main() -> Result<()> { // Helper Functions // ================================================================================================ -fn news_fade_grid() -> (usize, impl ParallelIterator) { +fn news_fade_grid() -> Vec<(usize, NewsFade)> { NewsFadeGrid::baseline(economic_calendar_id(), ohlcv_id()) .expect("Failed to create baseline grid") // Optional: Constrain the grid for a quick demo run diff --git a/examples/noop_grid.rs b/examples/noop_grid.rs index 214e110..e61dc33 100644 --- a/examples/noop_grid.rs +++ b/examples/noop_grid.rs @@ -1,6 +1,5 @@ use anyhow::{Context, Result}; use chapaty::prelude::*; -use rayon::iter::ParallelBridge; use serde::Serialize; use std::path::Path; use std::sync::Arc; @@ -27,11 +26,7 @@ async fn main() -> Result<()> { let num_agents = 5; let agents: Vec<(usize, NoOpAgent)> = (0..num_agents).map(|uid| (uid, NoOpAgent)).collect(); - let leaderboard = env.evaluate_agents( - agents.into_iter().par_bridge(), - 10, // top_k - num_agents as u64, - )?; + let leaderboard = env.evaluate_agents(agents, 10)?; println!( "Evaluation complete. Leaderboard size: {}", diff --git a/src/gym/trading/agent/news/breakout.rs b/src/gym/trading/agent/news/breakout.rs index 51df17a..ec22922 100644 --- a/src/gym/trading/agent/news/breakout.rs +++ b/src/gym/trading/agent/news/breakout.rs @@ -2,8 +2,6 @@ use std::sync::Arc; use chrono::Duration; use itertools::iproduct; -use rand::seq::SliceRandom; -use rayon::iter::{IntoParallelIterator, ParallelIterator}; use serde::Serialize; use serde_with::{DurationSeconds, serde_as}; @@ -445,11 +443,11 @@ impl NewsBreakoutGrid { } } - pub fn build(self) -> (usize, impl ParallelIterator) { + pub fn build(self) -> Vec<(usize, NewsBreakout)> { let (start_earliest, end_earliest) = self.earliest_entry; let (start_latest, end_latest) = self.latest_entry; - // === 1. Generate Axes === + // === Generate Axes === let stop_loss_risk_factors = self.stop_loss_risk_factor.generate(); let risk_reward_ratios = self.risk_reward_ratio.generate(); @@ -461,8 +459,11 @@ impl NewsBreakoutGrid { .map(Duration::minutes) .collect::>(); - // === 2. Eagerly Collect Valid Args (The "Fat" Vector) === - let mut args = iproduct!( + // === Eagerly Collect Valid Args === + let cal_id = self.cal_id; + let market_id = self.market_id; + + iproduct!( risk_reward_ratios, stop_loss_risk_factors, latest_entries, @@ -470,44 +471,17 @@ impl NewsBreakoutGrid { ) .filter(|(_, _, latest, earliest)| earliest < latest) .enumerate() - .map(|(uid, (rrr, slrf, latest, earliest))| NewsBreakoutArgs { - uid, - rrr, - slrf, - latest, - earliest, - }) - .collect::>(); - - let mut rng = rand::rng(); - args.shuffle(&mut rng); - - let total_combinations = args.len(); - let cal_id = self.cal_id; - let market_id = self.market_id; - - // === 3. Simple Parallel Iterator === - let iterator = args.into_par_iter().map(move |arg| { + .map(|(uid, (rrr, slrf, latest, earliest))| { ( - arg.uid, + uid, NewsBreakout::baseline(cal_id, market_id) - .with_earliest_entry_candle(arg.earliest) - .with_latest_entry_candle(arg.latest) - .with_stop_loss_risk_factor(arg.slrf) - .with_risk_reward_ratio(arg.rrr) + .with_earliest_entry_candle(earliest) + .with_latest_entry_candle(latest) + .with_stop_loss_risk_factor(slrf) + .with_risk_reward_ratio(rrr) .expect("Valid grid parameters"), ) - }); - - (total_combinations, iterator) + }) + .collect::>() } } - -#[derive(Debug, Clone, Copy)] -struct NewsBreakoutArgs { - uid: usize, - rrr: f64, - slrf: f64, - latest: Duration, - earliest: Duration, -} diff --git a/src/gym/trading/agent/news/fade.rs b/src/gym/trading/agent/news/fade.rs index 3e97f00..c5b305d 100644 --- a/src/gym/trading/agent/news/fade.rs +++ b/src/gym/trading/agent/news/fade.rs @@ -3,8 +3,6 @@ use std::sync::Arc; use chrono::{DateTime, Duration, Utc}; use itertools::iproduct; -use rand::seq::SliceRandom; -use rayon::iter::{IntoParallelIterator, ParallelIterator}; use serde::Serialize; use serde_with::{DurationSeconds, serde_as}; @@ -435,7 +433,7 @@ impl NewsFadeGrid { } } - pub fn build(self) -> (usize, impl ParallelIterator) { + pub fn build(self) -> Vec<(usize, NewsFade)> { let (start_wait, end_wait) = self.wait_duration; // === 1. Generate Axes === @@ -446,44 +444,22 @@ impl NewsFadeGrid { let take_profit_factors = self.tp_risk_factor.generate(); let risk_rewards = self.risk_reward.generate(); - // === 2. Eagerly Collect Valid Args (The "Fat" Vector) === - let mut args = iproduct!(risk_rewards, candles_after_news, take_profit_factors) - .enumerate() - .map(|(uid, (rrr, wait, tprf))| NewsFadeArgs { - uid, - rrr, - wait, - tprf, - }) - .collect::>(); - - let mut rng = rand::rng(); - args.shuffle(&mut rng); - - let total_combinations = args.len(); + // === 2. Eagerly Collect Valid Args === let cal_id = self.cal_id; let ohlcv_id = self.ohlcv_id; - // === 3. Simple Parallel Iterator === - let iterator = args.into_par_iter().map(move |arg| { - ( - arg.uid, - NewsFade::baseline(cal_id, ohlcv_id) - .with_candles_after_news(arg.wait) - .with_take_profit_risk_factor(arg.tprf) - .with_risk_reward_ratio(arg.rrr) - .expect("Valid grid parameters"), - ) - }); - - (total_combinations, iterator) + iproduct!(risk_rewards, candles_after_news, take_profit_factors) + .enumerate() + .map(|(uid, (rrr, wait, tprf))| { + ( + uid, + NewsFade::baseline(cal_id, ohlcv_id) + .with_candles_after_news(wait) + .with_take_profit_risk_factor(tprf) + .with_risk_reward_ratio(rrr) + .expect("Valid grid parameters"), + ) + }) + .collect::>() } } - -#[derive(Debug, Clone, Copy)] -struct NewsFadeArgs { - uid: usize, - rrr: f64, - tprf: f64, - wait: Duration, -} diff --git a/src/gym/trading/agent/news/hybrid.rs b/src/gym/trading/agent/news/hybrid.rs index 73e2dbb..2898ec0 100644 --- a/src/gym/trading/agent/news/hybrid.rs +++ b/src/gym/trading/agent/news/hybrid.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use itertools::iproduct; use serde::Serialize; use crate::{ @@ -153,28 +153,21 @@ pub struct NewsHybridGrid { } impl NewsHybridGrid { - pub fn build(self) -> (usize, impl ParallelIterator) { - let (len_breakout, iter_breakout) = self.breakout.build(); - let (len_fade, iter_fade) = self.fade.build(); - let total_combinations = len_breakout * len_fade; - - let fade_agents = iter_fade.map(|(_, agent)| agent).collect::>(); - let fade_arc = Arc::new(fade_agents); - let iterator = iter_breakout.flat_map(move |(b_uid, breakout)| { - let fade_ref = fade_arc.clone(); - let len_fade = fade_ref.len(); - - (0..len_fade).into_par_iter().map(move |f_uid| { - let fade = fade_ref[f_uid]; - - // === Deterministic UID Calculation === - // Since 'b_uid' is 0..M and 'f_uid' is 0..N, - // we can mathematically map them to a unique 0..Total sequence. - let hybrid_uid = (b_uid * len_fade) + f_uid; - (hybrid_uid, NewsHybrid { breakout, fade }) + pub fn build(self) -> Vec<(usize, NewsHybrid)> { + let breakout_agents = self.breakout.build(); + let fade_agents = self.fade.build(); + + iproduct!(breakout_agents, fade_agents) + .enumerate() + .map(|(uid, (breakout, fade))| { + ( + uid, + NewsHybrid { + breakout: breakout.1, + fade: fade.1, + }, + ) }) - }); - - (total_combinations, iterator) + .collect() } } diff --git a/src/gym/trading/config.rs b/src/gym/trading/config.rs index bdb1621..b5b8453 100644 --- a/src/gym/trading/config.rs +++ b/src/gym/trading/config.rs @@ -158,6 +158,59 @@ pub enum EnvPreset { /// ``` BinanceBtcUsdt1d, + /// **BTC/USDT 1-Minute Spot (Binance)** + /// + /// A high-frequency intraday environment for scalping or short-term momentum strategies + /// on Bitcoin spot markets. Each episode covers a single trading day. + /// + /// # Episode Length + /// + /// [`EpisodeLength::Infinite`] + /// + /// # Available IDs + /// + /// ```rust + /// # use chapaty::prelude::*; + /// let ohlcv_id = OhlcvId { + /// broker: DataBroker::Binance, + /// exchange: Exchange::Binance, + /// symbol: Symbol::Spot(SpotPair::BtcUsdt), + /// period: Period::Minute(1), + /// }; + /// ``` + BinanceBtcUsdt1m, + + /// **BTC/USDT 1-Minute + 15-Minute Spot (Binance)** + /// + /// A multi-resolution intraday environment combining 1-minute and 15-minute BTC/USDT + /// OHLCV data. The 15-minute timeframe provides trend context while the 1-minute + /// timeframe is used for precise entry and exit timing. Each episode covers a single + /// trading day. + /// + /// # Episode Length + /// + /// [`EpisodeLength::Infinite`] + /// + /// # Available IDs + /// + /// ```rust + /// # use chapaty::prelude::*; + /// let ohlcv_1m_id = OhlcvId { + /// broker: DataBroker::Binance, + /// exchange: Exchange::Binance, + /// symbol: Symbol::Spot(SpotPair::BtcUsdt), + /// period: Period::Minute(1), + /// }; + /// + /// let ohlcv_15m_id = OhlcvId { + /// broker: DataBroker::Binance, + /// exchange: Exchange::Binance, + /// symbol: Symbol::Spot(SpotPair::BtcUsdt), + /// period: Period::Minute(15), + /// }; + /// ``` + BinanceBtcUsdt1m15m, + /// **EUR/USD 1-Minute + 5-Minute Futures with US Employment News — Unrestricted (NinjaTrader, CME 6eh6)** /// /// A multi-resolution intraday environment with 1-minute and 5-minute EUR/USD futures @@ -179,7 +232,7 @@ pub enum EnvPreset { /// exchange: Exchange::Cme, /// symbol: Symbol::Future(FutureContract { /// root: FutureRoot::EurUsd, - /// month: ContractMonth::March, + /// month: ContractMonth::June, /// year: ContractYear::Y6, /// }), /// period: Period::Minute(1), @@ -190,7 +243,7 @@ pub enum EnvPreset { /// exchange: Exchange::Cme, /// symbol: Symbol::Future(FutureContract { /// root: FutureRoot::EurUsd, - /// month: ContractMonth::March, + /// month: ContractMonth::June, /// year: ContractYear::Y6, /// }), /// period: Period::Minute(5), @@ -232,7 +285,7 @@ pub enum EnvPreset { /// exchange: Exchange::Cme, /// symbol: Symbol::Future(FutureContract { /// root: FutureRoot::EurUsd, - /// month: ContractMonth::March, + /// month: ContractMonth::June, /// year: ContractYear::Y6, /// }), /// period: Period::Minute(1), @@ -273,7 +326,7 @@ pub enum EnvPreset { /// exchange: Exchange::Cme, /// symbol: Symbol::Future(FutureContract { /// root: FutureRoot::EurUsd, - /// month: ContractMonth::March, + /// month: ContractMonth::June, /// year: ContractYear::Y6, /// }), /// period: Period::Minute(1), @@ -284,7 +337,7 @@ pub enum EnvPreset { /// exchange: Exchange::Cme, /// symbol: Symbol::Future(FutureContract { /// root: FutureRoot::EurUsd, - /// month: ContractMonth::March, + /// month: ContractMonth::June, /// year: ContractYear::Y6, /// }), /// period: Period::Minute(5), @@ -438,7 +491,7 @@ pub enum EnvPreset { /// exchange: Exchange::Cme, /// symbol: Symbol::Future(FutureContract { /// root: FutureRoot::EurUsd, - /// month: ContractMonth::March, + /// month: ContractMonth::June, /// year: ContractYear::Y6, /// }), /// period: Period::Minute(1), @@ -449,7 +502,7 @@ pub enum EnvPreset { /// exchange: Exchange::Cme, /// symbol: Symbol::Future(FutureContract { /// root: FutureRoot::EurUsd, - /// month: ContractMonth::March, + /// month: ContractMonth::June, /// year: ContractYear::Y6, /// }), /// aggregation: ProfileAggregation { @@ -492,12 +545,57 @@ impl From for EnvConfig { .with_episode_length(EpisodeLength::Infinite) .with_filter_config(filter) } + EnvPreset::BinanceBtcUsdt1m => { + let market_config = OhlcvSpotConfig { + broker: DataBroker::Binance, + symbol: Symbol::Spot(SpotPair::BtcUsdt), + period: Period::Minute(1), + batch_size: 1000, + exchange: Some(Exchange::Binance), + indicators: Vec::new(), + }; + let filter = FilterConfig { + allowed_years: Some((2017..=2026).collect::>()), + ..FilterConfig::default() + }; + EnvConfig::default() + .add_ohlcv_spot(source.clone(), market_config) + .with_episode_length(EpisodeLength::Infinite) + .with_filter_config(filter) + } + EnvPreset::BinanceBtcUsdt1m15m => { + let ohlcv_1m = OhlcvSpotConfig { + broker: DataBroker::Binance, + symbol: Symbol::Spot(SpotPair::BtcUsdt), + exchange: Some(Exchange::Binance), + period: Period::Minute(1), + batch_size: 1000, + indicators: Vec::new(), + }; + let ohlcv_15m = OhlcvSpotConfig { + broker: DataBroker::Binance, + symbol: Symbol::Spot(SpotPair::BtcUsdt), + exchange: Some(Exchange::Binance), + period: Period::Minute(15), + batch_size: 1000, + indicators: Vec::new(), + }; + let filter = FilterConfig { + allowed_years: Some((2017..=2026).collect::>()), + ..FilterConfig::default() + }; + EnvConfig::default() + .add_ohlcv_spot(source.clone(), ohlcv_1m) + .add_ohlcv_spot(source.clone(), ohlcv_15m) + .with_episode_length(EpisodeLength::Infinite) + .with_filter_config(filter) + } EnvPreset::NinjaTraderCme6eh61m5mUsEmpHigh => { let ohlcv_1m = OhlcvFutureConfig { broker: DataBroker::NinjaTrader, symbol: Symbol::Future(FutureContract { root: FutureRoot::EurUsd, - month: ContractMonth::March, + month: ContractMonth::June, year: ContractYear::Y6, }), exchange: Some(Exchange::Cme), @@ -509,7 +607,7 @@ impl From for EnvConfig { broker: DataBroker::NinjaTrader, symbol: Symbol::Future(FutureContract { root: FutureRoot::EurUsd, - month: ContractMonth::March, + month: ContractMonth::June, year: ContractYear::Y6, }), exchange: Some(Exchange::Cme), @@ -542,7 +640,7 @@ impl From for EnvConfig { broker: DataBroker::NinjaTrader, symbol: Symbol::Future(FutureContract { root: FutureRoot::EurUsd, - month: ContractMonth::March, + month: ContractMonth::June, year: ContractYear::Y6, }), exchange: Some(Exchange::Cme), @@ -575,7 +673,7 @@ impl From for EnvConfig { broker: DataBroker::NinjaTrader, symbol: Symbol::Future(FutureContract { root: FutureRoot::EurUsd, - month: ContractMonth::March, + month: ContractMonth::June, year: ContractYear::Y6, }), exchange: Some(Exchange::Cme), @@ -587,7 +685,7 @@ impl From for EnvConfig { broker: DataBroker::NinjaTrader, symbol: Symbol::Future(FutureContract { root: FutureRoot::EurUsd, - month: ContractMonth::March, + month: ContractMonth::June, year: ContractYear::Y6, }), exchange: Some(Exchange::Cme), @@ -720,7 +818,7 @@ impl From for EnvConfig { broker: DataBroker::NinjaTrader, symbol: Symbol::Future(FutureContract { root: FutureRoot::EurUsd, - month: ContractMonth::March, + month: ContractMonth::June, year: ContractYear::Y6, }), exchange: Some(Exchange::Cme), @@ -732,7 +830,7 @@ impl From for EnvConfig { broker: DataBroker::NinjaTrader, symbol: Symbol::Future(FutureContract { root: FutureRoot::EurUsd, - month: ContractMonth::March, + month: ContractMonth::June, year: ContractYear::Y6, }), exchange: Some(Exchange::Cme), diff --git a/src/gym/trading/env.rs b/src/gym/trading/env.rs index 984671a..d1f2d24 100644 --- a/src/gym/trading/env.rs +++ b/src/gym/trading/env.rs @@ -1,7 +1,7 @@ use std::{fmt::Debug, sync::Arc}; -use indicatif::{ProgressBar, ProgressStyle}; -use rayon::iter::ParallelIterator; +use indicatif::{ParallelProgressIterator, ProgressBar, ProgressStyle}; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; use serde::Serialize; use strum::{EnumCount, IntoEnumIterator}; use tracing::warn; @@ -119,12 +119,10 @@ impl Environment { /// /// # Arguments /// - /// * `agents` - A parallel iterator yielding `(usize, Agent)`. The `usize` is treated as + /// * `agents` - A vector of `(usize, Agent)`. The `usize` is treated as /// the unique **Agent UID**. This is typically created by calling `.enumerate()` on your - /// configuration iterator (e.g., `args.enumerate()`) before converting it to parallel. + /// configuration of the agent grid. /// * `top_k` - The maximum number of agents to retain in the leaderboard. - /// * `stream_len` - The total number of agents expected. This is used solely to initialize - /// the progress bar's length, as parallel iterators may not always know their exact bounds. /// /// # Runtime Estimation /// @@ -137,28 +135,27 @@ impl Environment { /// 2. Measure the time it takes to run `env.evaluate_agent(&mut agent)`. /// 3. Estimate your total wait time: `(Single Time * Total Agents) / CPU Cores`. /// - /// This simple check prevents surprises—like discovering a 1M run will take 2 weeks - /// instead of 2 hours. + /// This simple check prevents surprises—like discovering a 1M run will take 2 weeks instead of 2 hours. pub fn evaluate_agents( &mut self, - agents: impl ParallelIterator, + agents: Vec<(usize, T)>, top_k: usize, - stream_len: u64, ) -> ChapatyResult where T: Agent + Send + Serialize, { self.reset()?; - let pb = progress_bar(stream_len)?; + let pb = progress_bar(agents.len() as u64)?; pb.set_message("Running evaluation..."); let agent_leaderboard = agents + .into_par_iter() + .progress_with(pb.clone()) .try_fold( || AgentLeaderboard::new(top_k), |mut board, (uid, mut agent)| { let entries = self.worker(&mut agent, uid as u64)?; board.update(&entries, agent); - pb.inc(1); Ok(board) }, )