diff --git a/README.md b/README.md index e4c6b89..3077c1d 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ This repository implements a small, deterministic C++ limit-order-book engine fo - aggregated bid/ask levels plus order-ID lookup - two price-level backends: `std::map` and flat sorted `std::vector` - rolling analytics and CSV export after every processed message +- optional post-replay prediction summary reporting by message horizon - deterministic C++ and Python integration tests - replay benchmark tooling and a hand-maintained benchmark reproducibility note @@ -61,6 +62,19 @@ Export analytics rows after every processed message: If `--backend both` is selected, the CLI writes one CSV per backend by suffixing the output path. +Emit a separate prediction summary after replay without changing the analytics CSV rows: + +```bash +"$build_dir/lob_engine" \ + data/AAPL_sample_messages.csv \ + --backend map \ + --analytics-out "$build_dir/analytics.csv" \ + --prediction-report-out "$build_dir/prediction_report.csv" \ + --prediction-horizons 100,500 +``` + +`--prediction-report-out` requires `--prediction-horizons`. If both flags are omitted, prediction work stays disabled. + ## Analytics Each processed message produces a row with: @@ -78,6 +92,8 @@ The default rolling windows match the project objective: - trailing `1000` messages for trade-based metrics - trailing `300` seconds for realized volatility +Prediction reporting is a separate CSV keyed by message horizon. For each row `t`, the label is the sign of the first non-zero mid-price move found in `t+1 ... t+H` relative to mid at `t`. Rows with invalid current mid or no non-zero future move inside the horizon are skipped. The report includes labeled sample counts, up/down move counts, hit rate from `sign(order_imbalance_top5)` on non-zero-signal rows, and information coefficient computed as the Pearson correlation between the raw top-5 imbalance value and the future move sign. Zero-signal rows stay in the labeled sample and IC calculation but increment `skipped_zero_signal` so they are excluded from the hit-rate denominator. + ## Backends Two backends are implemented behind the same `OrderBook` interface: diff --git a/include/lob/analytics.hpp b/include/lob/analytics.hpp index 58adda8..c66d9b1 100644 --- a/include/lob/analytics.hpp +++ b/include/lob/analytics.hpp @@ -1,9 +1,12 @@ #pragma once #include +#include +#include #include #include #include +#include #include #include "lob/order_book.hpp" @@ -11,11 +14,208 @@ namespace lob { +struct OptionalStringSetting { + OptionalStringSetting() = default; + OptionalStringSetting(std::nullopt_t) noexcept {} + OptionalStringSetting(const std::optional& text) + : value_(text) {} + OptionalStringSetting(std::optional&& text) noexcept + : value_(std::move(text)) {} + OptionalStringSetting(const std::string& text) + : value_(text) {} + OptionalStringSetting(std::string&& text) noexcept + : value_(std::move(text)) {} + OptionalStringSetting(const char* text) + : value_(text == nullptr ? std::optional{} : std::optional{text}) {} + + OptionalStringSetting& operator=(std::nullopt_t) noexcept { + value_.reset(); + return *this; + } + + OptionalStringSetting& operator=(const std::optional& text) { + value_ = text; + return *this; + } + + OptionalStringSetting& operator=(std::optional&& text) noexcept { + value_ = std::move(text); + return *this; + } + + OptionalStringSetting& operator=(const std::string& text) { + value_ = text; + return *this; + } + + OptionalStringSetting& operator=(std::string&& text) noexcept { + value_ = std::move(text); + return *this; + } + + OptionalStringSetting& operator=(const char* text) { + value_ = text == nullptr ? std::optional{} : std::optional{text}; + return *this; + } + + bool has_value() const noexcept { + return value_.has_value(); + } + + bool empty() const noexcept { + return !value_.has_value() || value_->empty(); + } + + void reset() noexcept { + value_.reset(); + } + + const std::string& value() const { + return value_.value(); + } + + std::string value_or(std::string default_value) const { + return value_.value_or(std::move(default_value)); + } + + const std::string& operator*() const { + return value(); + } + + std::string& operator*() { + return value_.value(); + } + + const std::string* operator->() const { + return &value(); + } + + std::string* operator->() { + return &value_.value(); + } + + explicit operator bool() const noexcept { + return value_.has_value(); + } + + friend bool operator==(const OptionalStringSetting& lhs, std::nullopt_t) noexcept { + return !lhs.value_.has_value(); + } + + friend bool operator==(std::nullopt_t, const OptionalStringSetting& rhs) noexcept { + return rhs == std::nullopt; + } + + friend bool operator!=(const OptionalStringSetting& lhs, std::nullopt_t) noexcept { + return !(lhs == std::nullopt); + } + + friend bool operator!=(std::nullopt_t, const OptionalStringSetting& rhs) noexcept { + return !(rhs == std::nullopt); + } + + friend bool operator==(const OptionalStringSetting& lhs, const std::string& rhs) { + return lhs.value_ == rhs; + } + + friend bool operator==(const std::string& lhs, const OptionalStringSetting& rhs) { + return rhs == lhs; + } + + friend bool operator!=(const OptionalStringSetting& lhs, const std::string& rhs) { + return !(lhs == rhs); + } + + friend bool operator!=(const std::string& lhs, const OptionalStringSetting& rhs) { + return !(rhs == lhs); + } + + friend bool operator==(const OptionalStringSetting& lhs, const char* rhs) { + return lhs == std::string(rhs == nullptr ? "" : rhs); + } + + friend bool operator==(const char* lhs, const OptionalStringSetting& rhs) { + return rhs == lhs; + } + + friend bool operator!=(const OptionalStringSetting& lhs, const char* rhs) { + return !(lhs == rhs); + } + + friend bool operator!=(const char* lhs, const OptionalStringSetting& rhs) { + return !(rhs == lhs); + } + + friend bool operator==(const OptionalStringSetting& lhs, const std::optional& rhs) { + return lhs.value_ == rhs; + } + + friend bool operator==(const std::optional& lhs, const OptionalStringSetting& rhs) { + return rhs == lhs; + } + + friend bool operator!=(const OptionalStringSetting& lhs, const std::optional& rhs) { + return !(lhs == rhs); + } + + friend bool operator!=(const std::optional& lhs, const OptionalStringSetting& rhs) { + return !(rhs == lhs); + } + +private: + std::optional value_{}; +}; + struct AnalyticsConfig { std::size_t trade_window_messages{1000}; double realized_vol_window_seconds{300.0}; std::size_t depth_levels{10}; std::size_t expected_messages{0}; + std::vector prediction_horizons{}; + OptionalStringSetting prediction_report_out{}; + std::vector prediction_horizons_messages{}; + + bool prediction_report_output_enabled() const noexcept { + return prediction_report_out.has_value() && !prediction_report_out.empty(); + } + + std::vector resolved_prediction_horizons() const { + const bool use_message_horizons = !prediction_horizons_messages.empty(); + std::vector horizons; + + if (use_message_horizons) { + horizons.reserve(prediction_horizons_messages.size()); + for (const int horizon : prediction_horizons_messages) { + if (horizon > 0) { + horizons.push_back(static_cast(horizon)); + } + } + return horizons; + } + + horizons.reserve(prediction_horizons.size()); + for (const std::size_t horizon : prediction_horizons) { + if (horizon > 0 && + horizon <= static_cast(std::numeric_limits::max())) { + horizons.push_back(horizon); + } + } + return horizons; + } + + std::vector resolved_prediction_horizons_messages() const { + const std::vector resolved_horizons = resolved_prediction_horizons(); + std::vector horizons; + horizons.reserve(resolved_horizons.size()); + for (const std::size_t horizon : resolved_horizons) { + horizons.push_back(static_cast(horizon)); + } + return horizons; + } + + bool prediction_reporting_enabled() const { + return prediction_report_output_enabled() && !resolved_prediction_horizons().empty(); + } }; struct AnalyticsRow { @@ -36,6 +236,29 @@ struct AnalyticsRow { std::optional rolling_realized_vol; }; +struct PredictionSnapshot { + std::size_t message_index{0}; + std::optional mid_price; + std::optional order_imbalance_top5; +}; + +struct PredictionSummaryRow { + std::size_t horizon_messages{0}; + std::size_t total_rows_seen{0}; + std::size_t eligible_rows_with_valid_mid{0}; + std::size_t labeled_rows{0}; + std::size_t skipped_no_valid_mid{0}; + std::size_t skipped_no_future_move_within_horizon{0}; + std::size_t skipped_zero_signal{0}; + std::size_t up_moves{0}; + std::size_t down_moves{0}; + std::size_t correct_predictions{0}; + std::size_t incorrect_predictions{0}; + double hit_rate{0.0}; + double information_coefficient{0.0}; + double coverage_vs_total{0.0}; +}; + class AnalyticsEngine { public: explicit AnalyticsEngine(AnalyticsConfig config = {}); @@ -64,4 +287,22 @@ std::vector replay_with_analytics( void write_analytics_csv(const std::vector& rows, const std::string& output_path); +std::vector collect_prediction_snapshots(const std::vector& rows); + +std::vector summarize_prediction_horizons( + const std::vector& snapshots, + const std::vector& horizons); + +std::vector summarize_prediction_horizons( + const std::vector& snapshots, + const std::vector& horizons); + +std::vector summarize_prediction_horizons( + const std::vector& snapshots, + std::initializer_list horizons); + +void write_prediction_report_csv( + const std::vector& rows, + const std::string& output_path); + } // namespace lob diff --git a/report/benchmark_report.md b/report/benchmark_report.md index 3d61b3e..b620f26 100644 --- a/report/benchmark_report.md +++ b/report/benchmark_report.md @@ -43,6 +43,27 @@ ctest --test-dir "$build_dir" --output-on-failure -C Release python -m pytest tests -q --tb=short ``` +## Prediction Reporting Feature Gate + +The new prediction labeling/reporting path is outside the replay-only benchmark timer and remains optional. The core `lob_benchmark` command is still the replay hot-path check: + +```bash +taskset -c 0 "$build_dir/lob_benchmark" --dataset data/AAPL_sample_messages.csv --backend both --reserve on --depth 5 --repeat 100000 +``` + +To exercise the same dataset through the normal replay CLI with prediction reporting disabled versus enabled: + +```bash +"$build_dir/lob_engine" data/AAPL_sample_messages.csv --backend map --analytics-out "$build_dir/analytics.csv" +"$build_dir/lob_engine" data/AAPL_sample_messages.csv --backend map --analytics-out "$build_dir/analytics.csv" --prediction-report-out "$build_dir/prediction_report.csv" --prediction-horizons 100 +``` + +Expected behavior: + +- without prediction flags, the CLI emits the existing analytics CSV only +- with prediction flags, the analytics CSV stays unchanged and a separate prediction report CSV is added +- any extra work is feature-gated to the prediction-enabled CLI path; the replay-only benchmark command above remains valid and unchanged + ## Measurement methodology - baseline variant: clean `origin/main` tree at commit `d627b73` diff --git a/src/analytics.cpp b/src/analytics.cpp index 79e3793..d898eac 100644 --- a/src/analytics.cpp +++ b/src/analytics.cpp @@ -28,6 +28,46 @@ bool is_trade_event(EventType event_type) noexcept { event_type == EventType::CrossTrade; } +int sign_of(double value) noexcept { + return (value > 0.0) - (value < 0.0); +} + +class RunningPearsonCorrelation { +public: + void observe(double x, double y) noexcept { + ++count_; + sum_x_ += x; + sum_y_ += y; + sum_xy_ += x * y; + sum_x2_ += x * x; + sum_y2_ += y * y; + } + + double value() const noexcept { + if (count_ < 2) { + return 0.0; + } + + const double count = static_cast(count_); + const double numerator = count * sum_xy_ - (sum_x_ * sum_y_); + const double variance_x = count * sum_x2_ - (sum_x_ * sum_x_); + const double variance_y = count * sum_y2_ - (sum_y_ * sum_y_); + if (variance_x <= 0.0 || variance_y <= 0.0) { + return 0.0; + } + + return numerator / std::sqrt(variance_x * variance_y); + } + +private: + std::size_t count_{0}; + double sum_x_{0.0}; + double sum_y_{0.0}; + double sum_xy_{0.0}; + double sum_x2_{0.0}; + double sum_y2_{0.0}; +}; + template void write_optional(std::ofstream& output, const std::optional& value) { if (value.has_value()) { @@ -92,6 +132,113 @@ class SlidingWindowBuffer { std::size_t start_index_{0}; }; +std::optional first_future_label( + const std::vector& snapshots, + std::size_t message_index, + std::size_t horizon) { + if (message_index >= snapshots.size() || !snapshots[message_index].mid_price.has_value()) { + return std::nullopt; + } + + const double current_mid = *snapshots[message_index].mid_price; + const std::size_t end_index = std::min(snapshots.size(), message_index + horizon + 1); + for (std::size_t future_index = message_index + 1; future_index < end_index; ++future_index) { + if (!snapshots[future_index].mid_price.has_value()) { + continue; + } + + const int move_sign = sign_of(*snapshots[future_index].mid_price - current_mid); + if (move_sign != 0) { + return move_sign; + } + } + + return std::nullopt; +} + +std::vector validate_prediction_horizons(const std::vector& horizons) { + std::vector validated; + validated.reserve(horizons.size()); + for (const int horizon : horizons) { + if (horizon <= 0) { + throw std::invalid_argument("Prediction horizons must be positive"); + } + validated.push_back(static_cast(horizon)); + } + return validated; +} + +std::vector summarize_prediction_horizons_impl( + const std::vector& snapshots, + const std::vector& horizons) { + std::vector summaries; + summaries.reserve(horizons.size()); + + for (const std::size_t horizon : horizons) { + if (horizon == 0) { + throw std::invalid_argument("Prediction horizons must be positive"); + } + + PredictionSummaryRow summary; + RunningPearsonCorrelation information_coefficient; + summary.horizon_messages = horizon; + summary.total_rows_seen = snapshots.size(); + + for (std::size_t index = 0; index < snapshots.size(); ++index) { + const PredictionSnapshot& snapshot = snapshots[index]; + if (!snapshot.mid_price.has_value()) { + ++summary.skipped_no_valid_mid; + continue; + } + + ++summary.eligible_rows_with_valid_mid; + const std::optional label = first_future_label(snapshots, index, horizon); + if (!label.has_value()) { + ++summary.skipped_no_future_move_within_horizon; + continue; + } + + ++summary.labeled_rows; + if (*label > 0) { + ++summary.up_moves; + } else if (*label < 0) { + ++summary.down_moves; + } + + if (snapshot.order_imbalance_top5.has_value()) { + information_coefficient.observe(*snapshot.order_imbalance_top5, static_cast(*label)); + } + + const int signal_sign = snapshot.order_imbalance_top5.has_value() + ? sign_of(*snapshot.order_imbalance_top5) + : 0; + if (signal_sign == 0) { + ++summary.skipped_zero_signal; + continue; + } + + if (signal_sign == *label) { + ++summary.correct_predictions; + } else { + ++summary.incorrect_predictions; + } + } + + const std::size_t directional_observations = summary.correct_predictions + summary.incorrect_predictions; + summary.hit_rate = directional_observations > 0 + ? static_cast(summary.correct_predictions) / static_cast(directional_observations) + : 0.0; + summary.information_coefficient = information_coefficient.value(); + summary.coverage_vs_total = summary.total_rows_seen > 0 + ? static_cast(summary.labeled_rows) / static_cast(summary.total_rows_seen) + : 0.0; + + summaries.push_back(summary); + } + + return summaries; +} + } // namespace struct AnalyticsEngine::Impl { @@ -283,4 +430,66 @@ void write_analytics_csv(const std::vector& rows, const std::strin } } +std::vector collect_prediction_snapshots(const std::vector& rows) { + std::vector snapshots; + snapshots.reserve(rows.size()); + for (std::size_t index = 0; index < rows.size(); ++index) { + snapshots.push_back(PredictionSnapshot{ + index, + rows[index].mid_price, + rows[index].order_imbalance, + }); + } + return snapshots; +} + +std::vector summarize_prediction_horizons( + const std::vector& snapshots, + const std::vector& horizons) { + return summarize_prediction_horizons_impl(snapshots, horizons); +} + +std::vector summarize_prediction_horizons( + const std::vector& snapshots, + const std::vector& horizons) { + return summarize_prediction_horizons_impl(snapshots, validate_prediction_horizons(horizons)); +} + +std::vector summarize_prediction_horizons( + const std::vector& snapshots, + std::initializer_list horizons) { + return summarize_prediction_horizons(snapshots, std::vector(horizons)); +} + +void write_prediction_report_csv( + const std::vector& rows, + const std::string& output_path) { + std::ofstream output(output_path); + if (!output.is_open()) { + throw std::runtime_error("Could not open prediction report output path"); + } + + output << "horizon_messages,total_rows_seen,eligible_rows_with_valid_mid,labeled_rows," + "skipped_no_valid_mid,skipped_no_future_move_within_horizon,skipped_zero_signal," + "up_moves,down_moves,correct_predictions,incorrect_predictions,hit_rate," + "information_coefficient,coverage_vs_total\n"; + output << std::fixed << std::setprecision(6); + for (const PredictionSummaryRow& row : rows) { + output << row.horizon_messages << ',' + << row.total_rows_seen << ',' + << row.eligible_rows_with_valid_mid << ',' + << row.labeled_rows << ',' + << row.skipped_no_valid_mid << ',' + << row.skipped_no_future_move_within_horizon << ',' + << row.skipped_zero_signal << ',' + << row.up_moves << ',' + << row.down_moves << ',' + << row.correct_predictions << ',' + << row.incorrect_predictions << ',' + << row.hit_rate << ',' + << row.information_coefficient << ',' + << row.coverage_vs_total << '\n'; + } +} + } // namespace lob diff --git a/src/main.cpp b/src/main.cpp index 3c1f1f0..e316a72 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,8 +1,10 @@ #include +#include #include #include #include #include +#include #include #include #include @@ -22,12 +24,15 @@ struct CliOptions { std::string analytics_out; std::size_t trade_window_messages{1000}; double realized_vol_window_seconds{300.0}; + std::vector prediction_horizons_messages; + std::string prediction_report_out; }; void print_usage() { std::cerr << "Usage: lob_engine [--backend map|flat|both] [--depth N] [--repeat N] " - "[--analytics-out PATH] [--trade-window-messages N] [--realized-vol-window-seconds N]\n"; + "[--analytics-out PATH] [--trade-window-messages N] [--realized-vol-window-seconds N] " + "[--prediction-report-out PATH] [--prediction-horizons H1,H2,...]\n"; } std::optional parse_positive_size(const std::string& value) { @@ -56,8 +61,62 @@ std::optional parse_positive_double(const std::string& value) { return parsed; } -bool parse_args(int argc, char* argv[], CliOptions& options) { +std::string trim_ascii_whitespace(const std::string& value) { + std::size_t start = 0; + while (start < value.size() && std::isspace(static_cast(value[start])) != 0) { + ++start; + } + + std::size_t end = value.size(); + while (end > start && std::isspace(static_cast(value[end - 1])) != 0) { + --end; + } + + return value.substr(start, end - start); +} + +std::optional> parse_prediction_horizons( + const std::string& value, + std::string& error) { + if (value.empty()) { + error = "Prediction horizons must be a comma-separated list of positive integers"; + return std::nullopt; + } + + std::vector horizons; + std::size_t start = 0; + while (start <= value.size()) { + const std::size_t delimiter = value.find(',', start); + const std::string token = trim_ascii_whitespace( + value.substr(start, delimiter == std::string::npos ? std::string::npos : delimiter - start)); + if (token.empty()) { + error = "Prediction horizons must not contain empty entries"; + return std::nullopt; + } + + const auto parsed = parse_positive_size(token); + if (!parsed.has_value()) { + error = "Invalid prediction horizon: " + token; + return std::nullopt; + } + if (*parsed > static_cast(std::numeric_limits::max())) { + error = "Invalid prediction horizon: " + token; + return std::nullopt; + } + horizons.push_back(static_cast(*parsed)); + + if (delimiter == std::string::npos) { + break; + } + start = delimiter + 1; + } + + return horizons; +} + +bool parse_args(int argc, char* argv[], CliOptions& options, std::string& error) { if (argc < 2) { + error = "Missing input dataset path"; return false; } @@ -67,6 +126,7 @@ bool parse_args(int argc, char* argv[], CliOptions& options) { const std::string arg = argv[index]; if (arg == "--backend") { if (index + 1 >= argc) { + error = "Missing value for --backend"; return false; } options.backend_name = argv[++index]; @@ -75,10 +135,12 @@ bool parse_args(int argc, char* argv[], CliOptions& options) { if (arg == "--depth") { if (index + 1 >= argc) { + error = "Missing value for --depth"; return false; } const auto parsed = parse_positive_size(argv[++index]); if (!parsed.has_value()) { + error = "Invalid value for --depth"; return false; } options.depth = *parsed; @@ -87,10 +149,12 @@ bool parse_args(int argc, char* argv[], CliOptions& options) { if (arg == "--repeat") { if (index + 1 >= argc) { + error = "Missing value for --repeat"; return false; } const auto parsed = parse_positive_size(argv[++index]); if (!parsed.has_value()) { + error = "Invalid value for --repeat"; return false; } options.repeats = *parsed; @@ -99,6 +163,7 @@ bool parse_args(int argc, char* argv[], CliOptions& options) { if (arg == "--analytics-out") { if (index + 1 >= argc) { + error = "Missing value for --analytics-out"; return false; } options.analytics_out = argv[++index]; @@ -107,10 +172,12 @@ bool parse_args(int argc, char* argv[], CliOptions& options) { if (arg == "--trade-window-messages") { if (index + 1 >= argc) { + error = "Missing value for --trade-window-messages"; return false; } const auto parsed = parse_positive_size(argv[++index]); if (!parsed.has_value()) { + error = "Invalid value for --trade-window-messages"; return false; } options.trade_window_messages = *parsed; @@ -119,16 +186,50 @@ bool parse_args(int argc, char* argv[], CliOptions& options) { if (arg == "--realized-vol-window-seconds") { if (index + 1 >= argc) { + error = "Missing value for --realized-vol-window-seconds"; return false; } const auto parsed = parse_positive_double(argv[++index]); if (!parsed.has_value()) { + error = "Invalid value for --realized-vol-window-seconds"; return false; } options.realized_vol_window_seconds = *parsed; continue; } + if (arg == "--prediction-report-out") { + if (index + 1 >= argc) { + error = "Missing value for --prediction-report-out"; + return false; + } + options.prediction_report_out = argv[++index]; + continue; + } + + if (arg == "--prediction-horizons") { + if (index + 1 >= argc) { + error = "Missing value for --prediction-horizons"; + return false; + } + const auto parsed = parse_prediction_horizons(argv[++index], error); + if (!parsed.has_value()) { + return false; + } + options.prediction_horizons_messages = *parsed; + continue; + } + + error = "Unknown argument: " + arg; + return false; + } + + if (!options.prediction_report_out.empty() && options.prediction_horizons_messages.empty()) { + error = "--prediction-report-out requires --prediction-horizons"; + return false; + } + if (options.prediction_report_out.empty() && !options.prediction_horizons_messages.empty()) { + error = "--prediction-horizons requires --prediction-report-out"; return false; } @@ -156,7 +257,7 @@ std::string format_level(const std::optional& level) { " (" + std::to_string(level->order_count) + ")"; } -std::string analytics_output_path(const std::string& base, lob::OrderBookBackend backend, bool multiple_backends) { +std::string backend_output_path(const std::string& base, lob::OrderBookBackend backend, bool multiple_backends) { if (!multiple_backends) { return base; } @@ -170,8 +271,17 @@ std::string analytics_output_path(const std::string& base, lob::OrderBookBackend } // namespace int main(int argc, char* argv[]) { + if (argc == 2 && std::string(argv[1]) == "--help") { + print_usage(); + return 0; + } + CliOptions options; - if (!parse_args(argc, argv, options)) { + std::string parse_error; + if (!parse_args(argc, argv, options, parse_error)) { + if (!parse_error.empty()) { + std::cerr << parse_error << '\n'; + } print_usage(); return 1; } @@ -230,24 +340,47 @@ int main(int argc, char* argv[]) { << " ask=" << format_level(summary.final_snapshot.best_ask) << " active_orders=" << summary.final_snapshot.active_order_count << '\n'; - if (!options.analytics_out.empty()) { + const bool needs_post_replay_analytics = + !options.analytics_out.empty() || !options.prediction_report_out.empty(); + if (needs_post_replay_analytics) { std::unique_ptr book = lob::make_order_book(backend, build_config); + lob::AnalyticsConfig analytics_config; + analytics_config.trade_window_messages = options.trade_window_messages; + analytics_config.realized_vol_window_seconds = options.realized_vol_window_seconds; + analytics_config.depth_levels = std::max(options.depth, 10); + analytics_config.expected_messages = messages.size(); + analytics_config.prediction_horizons_messages = options.prediction_horizons_messages; + analytics_config.prediction_report_out = options.prediction_report_out; + const std::vector analytics_rows = - lob::replay_with_analytics( - messages, - *book, - lob::AnalyticsConfig{ - options.trade_window_messages, - options.realized_vol_window_seconds, - std::max(options.depth, 10), - messages.size(), - }); - const std::string output_path = analytics_output_path( - options.analytics_out, - backend, - backends.size() > 1); - lob::write_analytics_csv(analytics_rows, output_path); - std::cout << "Analytics CSV=" << output_path << " rows=" << analytics_rows.size() << '\n'; + lob::replay_with_analytics(messages, *book, analytics_config); + + if (!options.analytics_out.empty()) { + const std::string output_path = backend_output_path( + options.analytics_out, + backend, + backends.size() > 1); + lob::write_analytics_csv(analytics_rows, output_path); + std::cout << "Analytics CSV=" << output_path << " rows=" << analytics_rows.size() << '\n'; + } + + if (analytics_config.prediction_reporting_enabled()) { + const std::vector resolved_prediction_horizons = + analytics_config.resolved_prediction_horizons_messages(); + const std::vector prediction_snapshots = + lob::collect_prediction_snapshots(analytics_rows); + const std::vector prediction_report = + lob::summarize_prediction_horizons( + prediction_snapshots, + resolved_prediction_horizons); + const std::string output_path = backend_output_path( + analytics_config.prediction_report_out.value(), + backend, + backends.size() > 1); + lob::write_prediction_report_csv(prediction_report, output_path); + std::cout << "Prediction report=" << output_path + << " rows=" << prediction_report.size() << '\n'; + } } } diff --git a/tests/test_analytics.cpp b/tests/test_analytics.cpp index fce978c..52cc6af 100644 --- a/tests/test_analytics.cpp +++ b/tests/test_analytics.cpp @@ -1,6 +1,11 @@ #include +#include #include +#include #include +#include +#include +#include #include #include "lob/analytics.hpp" @@ -14,6 +19,23 @@ std::filesystem::path source_root() { return std::filesystem::path(LOB_ENGINE_SOURCE_DIR); } +std::filesystem::path make_temp_file(const std::string& stem) { + static int counter = 0; + return std::filesystem::temp_directory_path() / + (stem + "_" + std::to_string(counter++) + ".csv"); +} + +bool almost_equal(double lhs, double rhs, double tolerance = 1e-9) { + return std::fabs(lhs - rhs) <= tolerance; +} + +lob::PredictionSnapshot make_prediction_snapshot( + std::size_t message_index, + std::optional mid_price, + std::optional order_imbalance_top5) { + return lob::PredictionSnapshot{message_index, mid_price, order_imbalance_top5}; +} + void test_analytics_rows_cover_every_message() { const auto sample_path = source_root() / "data" / "sample_messages.csv"; lob::LobsterParser parser; @@ -84,12 +106,348 @@ void test_analytics_outputs_match_across_backends() { assert(map_book.snapshot(10) == flat_book.snapshot(10)); } +void test_prediction_config_defaults_and_round_trip() { + lob::AnalyticsConfig config; + assert(config.prediction_horizons_messages.empty()); + assert(config.prediction_horizons.empty()); + assert(config.prediction_report_out == std::nullopt); + assert(config.prediction_report_out.empty()); + assert(!config.prediction_reporting_enabled()); + + config.prediction_horizons_messages = {1, 3}; + config.prediction_report_out = "predictions.csv"; + assert(config.prediction_reporting_enabled()); + assert(config.prediction_report_out.has_value()); + assert(config.prediction_horizons_messages == std::vector({1, 3})); + assert(config.resolved_prediction_horizons_messages() == std::vector({1, 3})); + assert(config.prediction_report_out == "predictions.csv"); + + lob::AnalyticsConfig legacy_config; + legacy_config.prediction_horizons = {2, 4}; + legacy_config.prediction_report_out = "legacy.csv"; + assert(legacy_config.prediction_reporting_enabled()); + assert(legacy_config.resolved_prediction_horizons_messages() == std::vector({2, 4})); + + lob::AnalyticsConfig precedence_config; + precedence_config.prediction_horizons = {7, 9}; + precedence_config.prediction_horizons_messages = {1, 3}; + precedence_config.prediction_report_out = "precedence.csv"; + assert(precedence_config.prediction_reporting_enabled()); + assert(precedence_config.resolved_prediction_horizons_messages() == std::vector({1, 3})); + + lob::AnalyticsConfig filtered_legacy_config; + filtered_legacy_config.prediction_horizons = { + 0, + 2, + static_cast(std::numeric_limits::max()) + 1u, + }; + filtered_legacy_config.prediction_report_out = "filtered.csv"; + assert(filtered_legacy_config.prediction_reporting_enabled()); + assert(filtered_legacy_config.resolved_prediction_horizons_messages() == std::vector({2})); + + lob::AnalyticsConfig invalid_config; + invalid_config.prediction_horizons_messages = {0, -1}; + invalid_config.prediction_report_out = "invalid.csv"; + assert(!invalid_config.prediction_reporting_enabled()); + assert(invalid_config.resolved_prediction_horizons_messages().empty()); + + std::vector messages = { + {100.0, lob::EventType::NewOrder, 1, 50, 10000, lob::Side::Buy}, + {100.1, lob::EventType::NewOrder, 2, 60, 10100, lob::Side::Sell}, + {100.2, lob::EventType::NewOrder, 3, 40, 10050, lob::Side::Buy}, + }; + + lob::MapOrderBook book; + const std::vector rows = lob::replay_with_analytics(messages, book, config); + const std::vector snapshots = lob::collect_prediction_snapshots(rows); + + assert(rows.size() == messages.size()); + assert(snapshots.size() == rows.size()); + for (std::size_t index = 0; index < snapshots.size(); ++index) { + assert(snapshots[index].message_index == index); + assert(snapshots[index].mid_price == rows[index].mid_price); + assert(snapshots[index].order_imbalance_top5 == rows[index].order_imbalance); + } +} + +void test_prediction_positive_label_when_first_non_zero_future_move_is_up() { + const std::vector snapshots = { + make_prediction_snapshot(0, 100.0, 0.8), + make_prediction_snapshot(1, 100.0, -0.4), + make_prediction_snapshot(2, 101.0, 0.2), + }; + + const std::vector summaries = + lob::summarize_prediction_horizons(snapshots, {2}); + assert(summaries.size() == 1); + + const lob::PredictionSummaryRow& summary = summaries.front(); + assert(summary.horizon_messages == 2); + assert(summary.total_rows_seen == 3); + assert(summary.eligible_rows_with_valid_mid == 3); + assert(summary.labeled_rows == 2); + assert(summary.skipped_no_valid_mid == 0); + assert(summary.skipped_no_future_move_within_horizon == 1); + assert(summary.skipped_zero_signal == 0); + assert(summary.up_moves == 2); + assert(summary.down_moves == 0); + assert(summary.correct_predictions == 1); + assert(summary.incorrect_predictions == 1); + assert(almost_equal(summary.hit_rate, 0.5)); + assert(almost_equal(summary.information_coefficient, 0.0)); + assert(almost_equal(summary.coverage_vs_total, 2.0 / 3.0)); +} + +void test_prediction_negative_label_when_first_non_zero_future_move_is_down() { + const std::vector snapshots = { + make_prediction_snapshot(0, 100.0, -0.7), + make_prediction_snapshot(1, 100.0, 0.6), + make_prediction_snapshot(2, 99.0, -0.2), + }; + + const std::vector summaries = + lob::summarize_prediction_horizons(snapshots, {2}); + assert(summaries.size() == 1); + + const lob::PredictionSummaryRow& summary = summaries.front(); + assert(summary.horizon_messages == 2); + assert(summary.total_rows_seen == 3); + assert(summary.eligible_rows_with_valid_mid == 3); + assert(summary.labeled_rows == 2); + assert(summary.skipped_no_valid_mid == 0); + assert(summary.skipped_no_future_move_within_horizon == 1); + assert(summary.skipped_zero_signal == 0); + assert(summary.up_moves == 0); + assert(summary.down_moves == 2); + assert(summary.correct_predictions == 1); + assert(summary.incorrect_predictions == 1); + assert(almost_equal(summary.hit_rate, 0.5)); + assert(almost_equal(summary.information_coefficient, 0.0)); + assert(almost_equal(summary.coverage_vs_total, 2.0 / 3.0)); +} + +void test_prediction_skips_invalid_current_mid() { + const std::vector snapshots = { + make_prediction_snapshot(0, std::nullopt, 0.6), + make_prediction_snapshot(1, 100.0, 0.5), + make_prediction_snapshot(2, 101.0, 0.4), + }; + + const std::vector summaries = + lob::summarize_prediction_horizons(snapshots, {2}); + const lob::PredictionSummaryRow& summary = summaries.front(); + + assert(summary.total_rows_seen == 3); + assert(summary.eligible_rows_with_valid_mid == 2); + assert(summary.labeled_rows == 1); + assert(summary.skipped_no_valid_mid == 1); + assert(summary.skipped_no_future_move_within_horizon == 1); + assert(summary.skipped_zero_signal == 0); + assert(summary.up_moves == 1); + assert(summary.down_moves == 0); + assert(summary.correct_predictions == 1); + assert(summary.incorrect_predictions == 0); + assert(almost_equal(summary.hit_rate, 1.0)); + assert(almost_equal(summary.information_coefficient, 0.0)); + assert(almost_equal(summary.coverage_vs_total, 1.0 / 3.0)); +} + +void test_prediction_skips_when_horizon_expires_with_only_zero_moves() { + const std::vector snapshots = { + make_prediction_snapshot(0, 100.0, 0.5), + make_prediction_snapshot(1, 100.0, 0.3), + make_prediction_snapshot(2, 100.0, -0.2), + }; + + const std::vector summaries = + lob::summarize_prediction_horizons(snapshots, {2}); + const lob::PredictionSummaryRow& summary = summaries.front(); + + assert(summary.total_rows_seen == 3); + assert(summary.eligible_rows_with_valid_mid == 3); + assert(summary.labeled_rows == 0); + assert(summary.skipped_no_valid_mid == 0); + assert(summary.skipped_no_future_move_within_horizon == 3); + assert(summary.skipped_zero_signal == 0); + assert(summary.up_moves == 0); + assert(summary.down_moves == 0); + assert(summary.correct_predictions == 0); + assert(summary.incorrect_predictions == 0); + assert(almost_equal(summary.hit_rate, 0.0)); + assert(almost_equal(summary.information_coefficient, 0.0)); + assert(almost_equal(summary.coverage_vs_total, 0.0)); +} + +void test_prediction_zero_signal_is_skipped_instead_of_labeled() { + const std::vector snapshots = { + make_prediction_snapshot(0, 100.0, 0.0), + make_prediction_snapshot(1, 101.0, 0.5), + make_prediction_snapshot(2, 102.0, -0.5), + }; + + const std::vector summaries = + lob::summarize_prediction_horizons(snapshots, {1}); + const lob::PredictionSummaryRow& summary = summaries.front(); + + assert(summary.total_rows_seen == 3); + assert(summary.eligible_rows_with_valid_mid == 3); + assert(summary.labeled_rows == 2); + assert(summary.skipped_no_valid_mid == 0); + assert(summary.skipped_no_future_move_within_horizon == 1); + assert(summary.skipped_zero_signal == 1); + assert(summary.up_moves == 2); + assert(summary.down_moves == 0); + assert(summary.correct_predictions == 1); + assert(summary.incorrect_predictions == 0); + assert(almost_equal(summary.hit_rate, 1.0)); + assert(almost_equal(summary.information_coefficient, 0.0)); + assert(almost_equal(summary.coverage_vs_total, 2.0 / 3.0)); +} + +void test_prediction_uses_first_non_zero_future_move_even_if_later_moves_reverse() { + const std::vector snapshots = { + make_prediction_snapshot(0, 100.0, 0.5), + make_prediction_snapshot(1, 101.0, -0.4), + make_prediction_snapshot(2, 99.0, 0.7), + }; + + const std::vector summaries = + lob::summarize_prediction_horizons(snapshots, {2}); + const lob::PredictionSummaryRow& summary = summaries.front(); + + assert(summary.total_rows_seen == 3); + assert(summary.eligible_rows_with_valid_mid == 3); + assert(summary.labeled_rows == 2); + assert(summary.skipped_no_valid_mid == 0); + assert(summary.skipped_no_future_move_within_horizon == 1); + assert(summary.skipped_zero_signal == 0); + assert(summary.up_moves == 1); + assert(summary.down_moves == 1); + assert(summary.correct_predictions == 2); + assert(summary.incorrect_predictions == 0); + assert(almost_equal(summary.hit_rate, 1.0)); + assert(almost_equal(summary.information_coefficient, 1.0)); + assert(almost_equal(summary.coverage_vs_total, 2.0 / 3.0)); +} + +void test_prediction_multi_horizon_case_can_skip_shorter_horizon_and_label_longer_one() { + const std::vector snapshots = { + make_prediction_snapshot(0, 100.0, 0.5), + make_prediction_snapshot(1, 100.0, 0.2), + make_prediction_snapshot(2, 100.0, -0.3), + make_prediction_snapshot(3, 101.0, 0.4), + }; + + const std::vector summaries = + lob::summarize_prediction_horizons(snapshots, {1, 3}); + assert(summaries.size() == 2); + + const lob::PredictionSummaryRow& short_horizon = summaries[0]; + assert(short_horizon.horizon_messages == 1); + assert(short_horizon.total_rows_seen == 4); + assert(short_horizon.eligible_rows_with_valid_mid == 4); + assert(short_horizon.labeled_rows == 1); + assert(short_horizon.skipped_no_valid_mid == 0); + assert(short_horizon.skipped_no_future_move_within_horizon == 3); + assert(short_horizon.skipped_zero_signal == 0); + assert(short_horizon.up_moves == 1); + assert(short_horizon.down_moves == 0); + assert(short_horizon.correct_predictions == 0); + assert(short_horizon.incorrect_predictions == 1); + assert(almost_equal(short_horizon.hit_rate, 0.0)); + assert(almost_equal(short_horizon.information_coefficient, 0.0)); + assert(almost_equal(short_horizon.coverage_vs_total, 0.25)); + + const lob::PredictionSummaryRow& long_horizon = summaries[1]; + assert(long_horizon.horizon_messages == 3); + assert(long_horizon.total_rows_seen == 4); + assert(long_horizon.eligible_rows_with_valid_mid == 4); + assert(long_horizon.labeled_rows == 3); + assert(long_horizon.skipped_no_valid_mid == 0); + assert(long_horizon.skipped_no_future_move_within_horizon == 1); + assert(long_horizon.skipped_zero_signal == 0); + assert(long_horizon.up_moves == 3); + assert(long_horizon.down_moves == 0); + assert(long_horizon.correct_predictions == 2); + assert(long_horizon.incorrect_predictions == 1); + assert(almost_equal(long_horizon.hit_rate, 2.0 / 3.0)); + assert(almost_equal(long_horizon.information_coefficient, 0.0)); + assert(almost_equal(long_horizon.coverage_vs_total, 0.75)); +} + +void test_prediction_information_coefficient_uses_raw_imbalance_values() { + const std::vector snapshots = { + make_prediction_snapshot(0, 100.0, -1.0), + make_prediction_snapshot(1, 99.0, -1.0), + make_prediction_snapshot(2, 98.0, 1.0), + make_prediction_snapshot(3, 99.0, 1.0), + make_prediction_snapshot(4, 100.0, 0.1), + }; + + const std::vector summaries = + lob::summarize_prediction_horizons(snapshots, {1}); + const lob::PredictionSummaryRow& summary = summaries.front(); + + assert(summary.labeled_rows == 4); + assert(summary.up_moves == 2); + assert(summary.down_moves == 2); + assert(summary.correct_predictions == 4); + assert(summary.information_coefficient > 0.999999); +} + +void test_prediction_report_writer_emits_expected_header_and_rows() { + const std::vector snapshots = { + make_prediction_snapshot(0, 100.0, 0.5), + make_prediction_snapshot(1, 101.0, 0.0), + make_prediction_snapshot(2, 101.0, -0.5), + make_prediction_snapshot(3, 100.0, 0.5), + make_prediction_snapshot(4, std::nullopt, 0.2), + }; + + const std::vector summaries = + lob::summarize_prediction_horizons(snapshots, {1, 2}); + const auto output_path = make_temp_file("prediction_report"); + lob::write_prediction_report_csv(summaries, output_path.string()); + + std::ifstream input(output_path); + assert(input.is_open()); + + std::string header; + std::string first_row; + std::string second_row; + assert(std::getline(input, header)); + assert(std::getline(input, first_row)); + assert(std::getline(input, second_row)); + + assert( + header == + "horizon_messages,total_rows_seen,eligible_rows_with_valid_mid,labeled_rows," + "skipped_no_valid_mid,skipped_no_future_move_within_horizon,skipped_zero_signal," + "up_moves,down_moves,correct_predictions,incorrect_predictions,hit_rate," + "information_coefficient,coverage_vs_total"); + assert(first_row == "1,5,4,2,1,2,0,1,1,2,0,1.000000,1.000000,0.400000"); + assert(second_row == "2,5,4,3,1,1,1,1,2,2,0,1.000000,0.866025,0.600000"); + + input.close(); + std::filesystem::remove(output_path); +} + } // namespace int main() { test_analytics_rows_cover_every_message(); test_trade_metrics_and_realized_vol_are_populated(); test_analytics_outputs_match_across_backends(); + test_prediction_config_defaults_and_round_trip(); + test_prediction_positive_label_when_first_non_zero_future_move_is_up(); + test_prediction_negative_label_when_first_non_zero_future_move_is_down(); + test_prediction_skips_invalid_current_mid(); + test_prediction_skips_when_horizon_expires_with_only_zero_moves(); + test_prediction_zero_signal_is_skipped_instead_of_labeled(); + test_prediction_uses_first_non_zero_future_move_even_if_later_moves_reverse(); + test_prediction_multi_horizon_case_can_skip_shorter_horizon_and_label_longer_one(); + test_prediction_information_coefficient_uses_raw_imbalance_values(); + test_prediction_report_writer_emits_expected_header_and_rows(); std::cout << "ALL TESTS PASSED\n"; return 0; } diff --git a/tests/test_parser.py b/tests/test_parser.py index ee18161..a459201 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -31,6 +31,8 @@ def test_cmake_build_and_cpp_binaries(): cli_binary = build_dir / executable_name("lob_engine") benchmark_binary = build_dir / executable_name("lob_benchmark") analytics_output = build_dir / "analytics_map.csv" + prediction_analytics_output = build_dir / "analytics_prediction.csv" + prediction_report_output = build_dir / "prediction_report.csv" cpp_tests = run_command([str(test_binary)], cwd=build_dir) assert "ALL TESTS PASSED" in cpp_tests.stdout @@ -58,8 +60,71 @@ def test_cmake_build_and_cpp_binaries(): cwd=build_dir, ) assert "Analytics CSV=" in cli_export.stdout - assert analytics_output.with_name("analytics_map_map.csv").exists() - assert analytics_output.with_name("analytics_map_flat_vector.csv").exists() + baseline_map_csv_path = analytics_output.with_name("analytics_map_map.csv") + baseline_flat_csv_path = analytics_output.with_name("analytics_map_flat_vector.csv") + assert baseline_map_csv_path.exists() + assert baseline_flat_csv_path.exists() + baseline_analytics_csv = baseline_map_csv_path.read_text() + baseline_flat_analytics_csv = baseline_flat_csv_path.read_text() + assert baseline_analytics_csv.startswith( + "timestamp,best_bid,best_ask,spread,mid,bid_depth_1,bid_depth_5,bid_depth_10," + "ask_depth_1,ask_depth_5,ask_depth_10,order_imbalance,rolling_vwap,trade_flow_imbalance," + "rolling_realized_vol\n" + ) + assert baseline_flat_analytics_csv == baseline_analytics_csv + assert len(baseline_analytics_csv.strip().splitlines()) == 21 + + cli_prediction = run_command( + [ + str(cli_binary), + str(sample_file), + "--analytics-out", + str(prediction_analytics_output), + "--prediction-report-out", + str(prediction_report_output), + "--prediction-horizons", + "100,500", + "--backend", + "both", + ], + cwd=build_dir, + ) + assert "Prediction report=" in cli_prediction.stdout + prediction_map_csv = prediction_analytics_output.with_name("analytics_prediction_map.csv") + prediction_report_map = prediction_report_output.with_name("prediction_report_map.csv") + prediction_report_flat = prediction_report_output.with_name("prediction_report_flat_vector.csv") + assert prediction_map_csv.exists() + assert prediction_report_map.exists() + assert prediction_report_flat.exists() + assert prediction_map_csv.read_text() == baseline_analytics_csv + assert prediction_analytics_output.with_name("analytics_prediction_flat_vector.csv").read_text() == baseline_flat_analytics_csv + prediction_report_lines = prediction_report_map.read_text().strip().splitlines() + prediction_report_flat_lines = prediction_report_flat.read_text().strip().splitlines() + assert prediction_report_lines[0] == ( + "horizon_messages,total_rows_seen,eligible_rows_with_valid_mid,labeled_rows," + "skipped_no_valid_mid,skipped_no_future_move_within_horizon,skipped_zero_signal," + "up_moves,down_moves,correct_predictions,incorrect_predictions,hit_rate," + "information_coefficient,coverage_vs_total" + ) + assert len(prediction_report_lines) == 3 + assert prediction_report_flat_lines == prediction_report_lines + + bad_prediction = subprocess.run( + [ + str(cli_binary), + str(sample_file), + "--prediction-report-out", + str(prediction_report_output), + "--prediction-horizons", + "1,,2", + ], + cwd=build_dir, + check=False, + text=True, + capture_output=True, + ) + assert bad_prediction.returncode != 0 + assert "Prediction horizons must not contain empty entries" in bad_prediction.stderr cli_both = run_command( [str(cli_binary), str(sample_file), "--backend", "both", "--repeat", "2"],