diff --git a/.gitignore b/.gitignore index 31caeb05..1b275edb 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,6 @@ # config /cs2kz-api.toml + +examples/ +testdata/ \ No newline at end of file diff --git a/.sqlx/query-40dc82fa83d150c24dab57f4a610f981a505fdb2593449bbe901b7e6c1fa847d.json b/.sqlx/query-40dc82fa83d150c24dab57f4a610f981a505fdb2593449bbe901b7e6c1fa847d.json new file mode 100644 index 00000000..5adf54b9 --- /dev/null +++ b/.sqlx/query-40dc82fa83d150c24dab57f4a610f981a505fdb2593449bbe901b7e6c1fa847d.json @@ -0,0 +1,54 @@ +{ + "db_name": "MySQL", + "query": "SELECT\n filter_id AS `filter_id: CourseFilterId`,\n player_id AS `player_id: PlayerId`,\n record_id AS `record_id: RecordId`,\n time\n FROM BestNubRecords\n WHERE filter_id = ?\n ORDER BY time ASC", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "filter_id: CourseFilterId", + "type_info": { + "type": "Short", + "flags": "NOT_NULL | PRIMARY_KEY | UNSIGNED | NO_DEFAULT_VALUE", + "max_size": 5 + } + }, + { + "ordinal": 1, + "name": "player_id: PlayerId", + "type_info": { + "type": "LongLong", + "flags": "NOT_NULL | PRIMARY_KEY | MULTIPLE_KEY | UNSIGNED | NO_DEFAULT_VALUE", + "max_size": 20 + } + }, + { + "ordinal": 2, + "name": "record_id: RecordId", + "type_info": { + "type": "String", + "flags": "NOT_NULL | MULTIPLE_KEY | BINARY | NO_DEFAULT_VALUE", + "max_size": 16 + } + }, + { + "ordinal": 3, + "name": "time", + "type_info": { + "type": "Double", + "flags": "NOT_NULL | MULTIPLE_KEY | NO_DEFAULT_VALUE", + "max_size": 22 + } + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "40dc82fa83d150c24dab57f4a610f981a505fdb2593449bbe901b7e6c1fa847d" +} diff --git a/.sqlx/query-5fc61264428ce56fd2b3b4a6d7bf2a9cfb5e537551932bb2f133f6030fa6ddb3.json b/.sqlx/query-5fc61264428ce56fd2b3b4a6d7bf2a9cfb5e537551932bb2f133f6030fa6ddb3.json deleted file mode 100644 index c8055805..00000000 --- a/.sqlx/query-5fc61264428ce56fd2b3b4a6d7bf2a9cfb5e537551932bb2f133f6030fa6ddb3.json +++ /dev/null @@ -1,84 +0,0 @@ -{ - "db_name": "MySQL", - "query": "WITH BanCounts AS (\n SELECT b.player_id, COUNT(*) AS count\n FROM Bans AS b\n RIGHT JOIN Unbans AS ub ON ub.ban_id = b.id\n WHERE (b.id IS NULL OR b.expires_at > NOW())\n )\n SELECT\n p.id AS `id: PlayerId`,\n p.name,\n p.vnl_rating,\n p.ckz_rating,\n (COALESCE(BanCounts.count, 0) > 0) AS `is_banned!: bool`,\n p.first_joined_at,\n p.last_joined_at\n FROM Players AS p\n LEFT JOIN BanCounts ON BanCounts.player_id = p.id WHERE p.id = ?", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "id: PlayerId", - "type_info": { - "type": "LongLong", - "flags": "NOT_NULL | PRIMARY_KEY | UNSIGNED | NO_DEFAULT_VALUE", - "max_size": 20 - } - }, - { - "ordinal": 1, - "name": "name", - "type_info": { - "type": "VarString", - "flags": "NOT_NULL | MULTIPLE_KEY | NO_DEFAULT_VALUE", - "max_size": 1020 - } - }, - { - "ordinal": 2, - "name": "vnl_rating", - "type_info": { - "type": "Double", - "flags": "NOT_NULL", - "max_size": 22 - } - }, - { - "ordinal": 3, - "name": "ckz_rating", - "type_info": { - "type": "Double", - "flags": "NOT_NULL", - "max_size": 22 - } - }, - { - "ordinal": 4, - "name": "is_banned!: bool", - "type_info": { - "type": "Long", - "flags": "BINARY", - "max_size": 1 - } - }, - { - "ordinal": 5, - "name": "first_joined_at", - "type_info": { - "type": "Timestamp", - "flags": "NOT_NULL | UNSIGNED | BINARY | TIMESTAMP", - "max_size": 19 - } - }, - { - "ordinal": 6, - "name": "last_joined_at", - "type_info": { - "type": "Timestamp", - "flags": "NOT_NULL | UNSIGNED | BINARY | TIMESTAMP", - "max_size": 19 - } - } - ], - "parameters": { - "Right": 1 - }, - "nullable": [ - false, - false, - false, - false, - true, - false, - false - ] - }, - "hash": "5fc61264428ce56fd2b3b4a6d7bf2a9cfb5e537551932bb2f133f6030fa6ddb3" -} diff --git a/.sqlx/query-8d0c9e3b783635e444300881be5f7ec190f757fcb1e97a954b1a83d8fe970a75.json b/.sqlx/query-8d0c9e3b783635e444300881be5f7ec190f757fcb1e97a954b1a83d8fe970a75.json new file mode 100644 index 00000000..bd604b7b --- /dev/null +++ b/.sqlx/query-8d0c9e3b783635e444300881be5f7ec190f757fcb1e97a954b1a83d8fe970a75.json @@ -0,0 +1,12 @@ +{ + "db_name": "MySQL", + "query": "INSERT INTO PointDistributionData (\n filter_id, is_pro_leaderboard, a, b, loc, scale, top_scale\n )\n VALUES (?, TRUE, ?, ?, ?, ?, ?)\n ON DUPLICATE KEY UPDATE\n a = VALUES(a),\n b = VALUES(b),\n loc = VALUES(loc),\n scale = VALUES(scale),\n top_scale = VALUES(top_scale)", + "describe": { + "columns": [], + "parameters": { + "Right": 6 + }, + "nullable": [] + }, + "hash": "8d0c9e3b783635e444300881be5f7ec190f757fcb1e97a954b1a83d8fe970a75" +} diff --git a/.sqlx/query-90ad65e337a74790d72d43e4881f4084fc4334a0a404cd8c416865049daccf2f.json b/.sqlx/query-90ad65e337a74790d72d43e4881f4084fc4334a0a404cd8c416865049daccf2f.json new file mode 100644 index 00000000..42ee3584 --- /dev/null +++ b/.sqlx/query-90ad65e337a74790d72d43e4881f4084fc4334a0a404cd8c416865049daccf2f.json @@ -0,0 +1,12 @@ +{ + "db_name": "MySQL", + "query": "INSERT INTO PointDistributionData (\n filter_id, is_pro_leaderboard, a, b, loc, scale, top_scale\n )\n VALUES (?, FALSE, ?, ?, ?, ?, ?)\n ON DUPLICATE KEY UPDATE\n a = VALUES(a),\n b = VALUES(b),\n loc = VALUES(loc),\n scale = VALUES(scale),\n top_scale = VALUES(top_scale)", + "describe": { + "columns": [], + "parameters": { + "Right": 6 + }, + "nullable": [] + }, + "hash": "90ad65e337a74790d72d43e4881f4084fc4334a0a404cd8c416865049daccf2f" +} diff --git a/.sqlx/query-91f17d399f1dc6daf4ca3363cb1c6bfbce0f6dc2a14aa104e269530eb266ceee.json b/.sqlx/query-91f17d399f1dc6daf4ca3363cb1c6bfbce0f6dc2a14aa104e269530eb266ceee.json new file mode 100644 index 00000000..01c78a73 --- /dev/null +++ b/.sqlx/query-91f17d399f1dc6daf4ca3363cb1c6bfbce0f6dc2a14aa104e269530eb266ceee.json @@ -0,0 +1,34 @@ +{ + "db_name": "MySQL", + "query": "SELECT\n nub_tier AS `nub_tier: Tier`,\n pro_tier AS `pro_tier: Tier`\n FROM CourseFilters\n WHERE id = ?", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "nub_tier: Tier", + "type_info": { + "type": "Tiny", + "flags": "NOT_NULL | UNSIGNED | NO_DEFAULT_VALUE", + "max_size": 3 + } + }, + { + "ordinal": 1, + "name": "pro_tier: Tier", + "type_info": { + "type": "Tiny", + "flags": "NOT_NULL | UNSIGNED | NO_DEFAULT_VALUE", + "max_size": 3 + } + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false, + false + ] + }, + "hash": "91f17d399f1dc6daf4ca3363cb1c6bfbce0f6dc2a14aa104e269530eb266ceee" +} diff --git a/.sqlx/query-f1c0f7b12cf83a9d41773011224756e681d4a715a28c6fa57fd76473af674c58.json b/.sqlx/query-a9027da8009faf6e3ef578e5994e7de555e0bff68270ed4594620de6c07228ab.json similarity index 90% rename from .sqlx/query-f1c0f7b12cf83a9d41773011224756e681d4a715a28c6fa57fd76473af674c58.json rename to .sqlx/query-a9027da8009faf6e3ef578e5994e7de555e0bff68270ed4594620de6c07228ab.json index e385779e..5dd6467f 100644 --- a/.sqlx/query-f1c0f7b12cf83a9d41773011224756e681d4a715a28c6fa57fd76473af674c58.json +++ b/.sqlx/query-a9027da8009faf6e3ef578e5994e7de555e0bff68270ed4594620de6c07228ab.json @@ -1,6 +1,6 @@ { "db_name": "MySQL", - "query": "SELECT a, b, loc, scale, top_scale\n FROM PointDistributionData\n WHERE filter_id = ?\n AND (NOT is_pro_leaderboard)", + "query": "SELECT a, b, loc, scale, top_scale\n FROM PointDistributionData\n WHERE filter_id = ?\n AND is_pro_leaderboard", "describe": { "columns": [ { @@ -60,5 +60,5 @@ false ] }, - "hash": "f1c0f7b12cf83a9d41773011224756e681d4a715a28c6fa57fd76473af674c58" + "hash": "a9027da8009faf6e3ef578e5994e7de555e0bff68270ed4594620de6c07228ab" } diff --git a/.sqlx/query-b340e131ba26df0a381b1be9e26b10d600f0ec9f1ab78fa81f1629f476ae3edb.json b/.sqlx/query-b340e131ba26df0a381b1be9e26b10d600f0ec9f1ab78fa81f1629f476ae3edb.json new file mode 100644 index 00000000..c4292bda --- /dev/null +++ b/.sqlx/query-b340e131ba26df0a381b1be9e26b10d600f0ec9f1ab78fa81f1629f476ae3edb.json @@ -0,0 +1,54 @@ +{ + "db_name": "MySQL", + "query": "SELECT a, b, loc, scale\n FROM PointDistributionData\n WHERE filter_id = ? AND (NOT is_pro_leaderboard)", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "a", + "type_info": { + "type": "Double", + "flags": "NOT_NULL | NO_DEFAULT_VALUE", + "max_size": 22 + } + }, + { + "ordinal": 1, + "name": "b", + "type_info": { + "type": "Double", + "flags": "NOT_NULL | NO_DEFAULT_VALUE", + "max_size": 22 + } + }, + { + "ordinal": 2, + "name": "loc", + "type_info": { + "type": "Double", + "flags": "NOT_NULL | NO_DEFAULT_VALUE", + "max_size": 22 + } + }, + { + "ordinal": 3, + "name": "scale", + "type_info": { + "type": "Double", + "flags": "NOT_NULL | NO_DEFAULT_VALUE", + "max_size": 22 + } + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "b340e131ba26df0a381b1be9e26b10d600f0ec9f1ab78fa81f1629f476ae3edb" +} diff --git a/.sqlx/query-efd574ea4128095b642003df53ce4eed58f18536c81e4f7cc10906aa86faff2f.json b/.sqlx/query-efd574ea4128095b642003df53ce4eed58f18536c81e4f7cc10906aa86faff2f.json new file mode 100644 index 00000000..4ae40411 --- /dev/null +++ b/.sqlx/query-efd574ea4128095b642003df53ce4eed58f18536c81e4f7cc10906aa86faff2f.json @@ -0,0 +1,54 @@ +{ + "db_name": "MySQL", + "query": "SELECT\n filter_id AS `filter_id: CourseFilterId`,\n player_id AS `player_id: PlayerId`,\n record_id AS `record_id: RecordId`,\n time\n FROM BestProRecords\n WHERE filter_id = ?\n ORDER BY time ASC", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "filter_id: CourseFilterId", + "type_info": { + "type": "Short", + "flags": "NOT_NULL | PRIMARY_KEY | UNSIGNED | NO_DEFAULT_VALUE", + "max_size": 5 + } + }, + { + "ordinal": 1, + "name": "player_id: PlayerId", + "type_info": { + "type": "LongLong", + "flags": "NOT_NULL | PRIMARY_KEY | MULTIPLE_KEY | UNSIGNED | NO_DEFAULT_VALUE", + "max_size": 20 + } + }, + { + "ordinal": 2, + "name": "record_id: RecordId", + "type_info": { + "type": "String", + "flags": "NOT_NULL | MULTIPLE_KEY | BINARY | NO_DEFAULT_VALUE", + "max_size": 16 + } + }, + { + "ordinal": 3, + "name": "time", + "type_info": { + "type": "Double", + "flags": "NOT_NULL | MULTIPLE_KEY | NO_DEFAULT_VALUE", + "max_size": 22 + } + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "efd574ea4128095b642003df53ce4eed58f18536c81e4f7cc10906aa86faff2f" +} diff --git a/.sqlx/query-faf48343593496abef8a35d44cc2e30379638b9f72a15be83220d1900a6104f1.json b/.sqlx/query-faf48343593496abef8a35d44cc2e30379638b9f72a15be83220d1900a6104f1.json deleted file mode 100644 index 45bf854a..00000000 --- a/.sqlx/query-faf48343593496abef8a35d44cc2e30379638b9f72a15be83220d1900a6104f1.json +++ /dev/null @@ -1,84 +0,0 @@ -{ - "db_name": "MySQL", - "query": "WITH BanCounts AS (\n SELECT b.player_id, COUNT(*) AS count\n FROM Bans AS b\n RIGHT JOIN Unbans AS ub ON ub.ban_id = b.id\n WHERE (b.id IS NULL OR b.expires_at > NOW())\n )\n SELECT\n p.id AS `id: PlayerId`,\n p.name,\n p.vnl_rating,\n p.ckz_rating,\n (COALESCE(BanCounts.count, 0) > 0) AS `is_banned!: bool`,\n p.first_joined_at,\n p.last_joined_at\n FROM Players AS p\n LEFT JOIN BanCounts ON BanCounts.player_id = p.id WHERE p.name LIKE ?", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "id: PlayerId", - "type_info": { - "type": "LongLong", - "flags": "NOT_NULL | PRIMARY_KEY | UNSIGNED | NO_DEFAULT_VALUE", - "max_size": 20 - } - }, - { - "ordinal": 1, - "name": "name", - "type_info": { - "type": "VarString", - "flags": "NOT_NULL | MULTIPLE_KEY | NO_DEFAULT_VALUE", - "max_size": 1020 - } - }, - { - "ordinal": 2, - "name": "vnl_rating", - "type_info": { - "type": "Double", - "flags": "NOT_NULL", - "max_size": 22 - } - }, - { - "ordinal": 3, - "name": "ckz_rating", - "type_info": { - "type": "Double", - "flags": "NOT_NULL", - "max_size": 22 - } - }, - { - "ordinal": 4, - "name": "is_banned!: bool", - "type_info": { - "type": "Long", - "flags": "BINARY", - "max_size": 1 - } - }, - { - "ordinal": 5, - "name": "first_joined_at", - "type_info": { - "type": "Timestamp", - "flags": "NOT_NULL | UNSIGNED | BINARY | TIMESTAMP", - "max_size": 19 - } - }, - { - "ordinal": 6, - "name": "last_joined_at", - "type_info": { - "type": "Timestamp", - "flags": "NOT_NULL | UNSIGNED | BINARY | TIMESTAMP", - "max_size": 19 - } - } - ], - "parameters": { - "Right": 1 - }, - "nullable": [ - false, - false, - false, - false, - true, - false, - false - ] - }, - "hash": "faf48343593496abef8a35d44cc2e30379638b9f72a15be83220d1900a6104f1" -} diff --git a/.sqlx/query-fc902d6636c7b4df7fa8fbaf1fe3d1b5eddbbe16122998033aa8f4c9991f3b36.json b/.sqlx/query-fc902d6636c7b4df7fa8fbaf1fe3d1b5eddbbe16122998033aa8f4c9991f3b36.json new file mode 100644 index 00000000..d264dcf6 --- /dev/null +++ b/.sqlx/query-fc902d6636c7b4df7fa8fbaf1fe3d1b5eddbbe16122998033aa8f4c9991f3b36.json @@ -0,0 +1,54 @@ +{ + "db_name": "MySQL", + "query": "SELECT a, b, loc, scale\n FROM PointDistributionData\n WHERE filter_id = ? AND is_pro_leaderboard", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "a", + "type_info": { + "type": "Double", + "flags": "NOT_NULL | NO_DEFAULT_VALUE", + "max_size": 22 + } + }, + { + "ordinal": 1, + "name": "b", + "type_info": { + "type": "Double", + "flags": "NOT_NULL | NO_DEFAULT_VALUE", + "max_size": 22 + } + }, + { + "ordinal": 2, + "name": "loc", + "type_info": { + "type": "Double", + "flags": "NOT_NULL | NO_DEFAULT_VALUE", + "max_size": 22 + } + }, + { + "ordinal": 3, + "name": "scale", + "type_info": { + "type": "Double", + "flags": "NOT_NULL | NO_DEFAULT_VALUE", + "max_size": 22 + } + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "fc902d6636c7b4df7fa8fbaf1fe3d1b5eddbbe16122998033aa8f4c9991f3b36" +} diff --git a/Cargo.lock b/Cargo.lock index 887293c7..124733cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1224,6 +1224,7 @@ dependencies = [ "futures-util", "lettre", "md-5 0.10.6", + "nig", "pin-project", "semver", "serde", @@ -2529,6 +2530,15 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "nig" +version = "0.1.0" +dependencies = [ + "rand 0.8.5", + "serde", + "tracing", +] + [[package]] name = "nom" version = "7.1.3" diff --git a/README.md b/README.md index ade0fcd3..d74c2a16 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ Questions and feedback are appreciated! Feel free to open an issue or join [our The API uses a configuration file called `cs2kz-api.toml`. An example configuration file is provided with all the default values filled in, -copy and modify it as you see fit. `.example.env` and `.docker.example.env` +copy and modify it as you see fit. `.example.env` and `.example.docker.env` should be copied to `.env` and `.docker.env` respectively. Again, change the default values as you see fit. @@ -32,6 +32,9 @@ Install docker and run the following command: ```sh docker compose up -d database + +# required for SQLx compile-time query checking +sqlx migrate run --source crates/cs2kz/migrations ``` To compile the API itself, you can use `cargo`: diff --git a/crates/cs2kz-api/src/bin/generator/main.rs b/crates/cs2kz-api/src/bin/generator/main.rs index 8cf46e35..0e0d99d5 100644 --- a/crates/cs2kz-api/src/bin/generator/main.rs +++ b/crates/cs2kz-api/src/bin/generator/main.rs @@ -80,7 +80,6 @@ async fn main() -> anyhow::Result<()> { min_connections: 1, max_connections: Some(NonZero::::MIN), }, - points: Default::default(), replay_storage: None, })?; @@ -381,7 +380,6 @@ async fn create_records( match records::submit(cx, record).await { Ok(SubmittedRecord { record_id: id, .. }) => info!(%id, "created record"), - Err(SubmitRecordError::CalculatePoints(error)) => return Err(error.into()), Err(SubmitRecordError::CalculateRating(error)) => return Err(error.into()), Err(SubmitRecordError::Database(error)) => return Err(error.into()), } diff --git a/crates/cs2kz-api/src/bin/server/cli.rs b/crates/cs2kz-api/src/bin/server/cli.rs index 2cdb2ffe..49f57fe0 100644 --- a/crates/cs2kz-api/src/bin/server/cli.rs +++ b/crates/cs2kz-api/src/bin/server/cli.rs @@ -33,10 +33,6 @@ pub struct Args { /// Path to the `DepotDownloader` executable the API should use. #[arg(long)] pub depot_downloader_path: Option, - - /// Path to a directory containing the `calc_filter.py` and `calc_run.py` scripts. - #[arg(long)] - pub scripts_path: Option, } impl Args { @@ -53,10 +49,5 @@ impl Args { if let Some(ref path) = self.depot_downloader_path { config.depot_downloader.exe_path = path.clone(); } - - if let Some(ref path) = self.scripts_path { - config.cs2kz.points.calc_filter_path = Some(path.join("calc_filter.py")); - config.cs2kz.points.calc_run_path = Some(path.join("calc_run.py")); - } } } diff --git a/crates/cs2kz/Cargo.toml b/crates/cs2kz/Cargo.toml index 31220f8e..9ed8705f 100644 --- a/crates/cs2kz/Cargo.toml +++ b/crates/cs2kz/Cargo.toml @@ -47,6 +47,9 @@ futures-util.workspace = true lettre.workspace = true sqlx.workspace = true +[dependencies.nig] +path = "../nig" + [dependencies.steam-id] path = "../steam-id" features = ["serde"] diff --git a/crates/cs2kz/src/config/mod.rs b/crates/cs2kz/src/config/mod.rs index 90e6d3d0..a6316585 100644 --- a/crates/cs2kz/src/config/mod.rs +++ b/crates/cs2kz/src/config/mod.rs @@ -1,9 +1,6 @@ mod database; pub use database::DatabaseConfig; -mod points; -pub use points::PointsConfig; - mod replays; pub use replays::ReplayStorageConfig; @@ -13,9 +10,6 @@ pub struct Config { #[serde(default)] pub database: DatabaseConfig, - #[serde(default)] - pub points: PointsConfig, - #[serde(default)] pub replay_storage: Option, } diff --git a/crates/cs2kz/src/config/points.rs b/crates/cs2kz/src/config/points.rs deleted file mode 100644 index d4e277a0..00000000 --- a/crates/cs2kz/src/config/points.rs +++ /dev/null @@ -1,10 +0,0 @@ -use std::path::PathBuf; - -use serde::Deserialize; - -#[derive(Debug, Default, Deserialize)] -#[serde(default, rename_all = "kebab-case", deny_unknown_fields)] -pub struct PointsConfig { - pub calc_filter_path: Option, - pub calc_run_path: Option, -} diff --git a/crates/cs2kz/src/context.rs b/crates/cs2kz/src/context.rs index 09d99e91..d79671d4 100644 --- a/crates/cs2kz/src/context.rs +++ b/crates/cs2kz/src/context.rs @@ -1,7 +1,8 @@ +use std::fmt; use std::sync::Arc; use std::time::Duration; -use std::{fmt, io}; +use tokio::sync::Notify; use tokio::task; use tokio_util::sync::CancellationToken; use tokio_util::task::TaskTracker; @@ -14,7 +15,6 @@ use crate::database::{ DatabaseConnectionOptions, EstablishDatabaseConnectionError, }; -use crate::points; mod inner { use super::*; @@ -25,8 +25,7 @@ mod inner { pub(super) database: Database, pub(super) shutdown_token: CancellationToken, pub(super) tasks: TaskTracker, - pub(super) points_calculator: Option, - pub(super) points_daemon: points::daemon::PointsDaemonHandle, + pub(super) points_recalculation_notify: Notify, pub(super) s3_client: Option, } } @@ -42,9 +41,6 @@ pub enum InitializeContextError { #[display("failed to run database migrations: {_0}")] RunDatabaseMigrations(sqlx::migrate::MigrateError), - - #[display("failed to initialize points calculator: {_0}")] - InitializePointsCalculator(io::Error), } impl Context { @@ -72,21 +68,6 @@ impl Context { database::MIGRATIONS.run(database.as_ref()).await?; let tasks = TaskTracker::new(); - let points_calculator = points::calculator::PointsCalculator::new(&config) - .await? - .map(|calc| { - let handle = calc.handle(); - let cancellation_token = shutdown_token.child_token(); - let task = tasks.track_future(calc.run(cancellation_token)); - - task::Builder::new() - .name("cs2kz::points_calculator") - .spawn(task) - .expect("failed to spawn tokio task"); - - handle - }); - let points_daemon = points::daemon::PointsDaemonHandle::new(); let s3_client = if let Some(ref cfg) = config.replay_storage { let config = aws_config::defaults(aws_config::BehaviorVersion::v2026_01_12()) @@ -112,8 +93,7 @@ impl Context { database, shutdown_token, tasks, - points_calculator, - points_daemon, + points_recalculation_notify: Notify::new(), s3_client, }))) } @@ -126,12 +106,12 @@ impl Context { &self.0.database } - pub fn points_calculator(&self) -> Option<&points::calculator::PointsCalculatorHandle> { - self.0.points_calculator.as_ref() + pub(crate) fn notify_points_recalculation(&self) { + self.0.points_recalculation_notify.notify_waiters(); } - pub fn points_daemon(&self) -> &points::daemon::PointsDaemonHandle { - &self.0.points_daemon + pub(crate) async fn wait_for_points_recalculation(&self) { + self.0.points_recalculation_notify.notified().await; } pub fn s3_client(&self) -> &aws_sdk_s3::Client { diff --git a/crates/cs2kz/src/lib.rs b/crates/cs2kz/src/lib.rs index 4c5bc066..351c1e4d 100644 --- a/crates/cs2kz/src/lib.rs +++ b/crates/cs2kz/src/lib.rs @@ -66,7 +66,5 @@ pub mod steam; pub mod styles; pub mod time; -mod python; - mod fmt; mod num; diff --git a/crates/cs2kz/src/points/calculator.rs b/crates/cs2kz/src/points/calculator.rs deleted file mode 100644 index 31fe8e5f..00000000 --- a/crates/cs2kz/src/points/calculator.rs +++ /dev/null @@ -1,135 +0,0 @@ -use std::io; -use std::time::Duration; - -use futures_util::TryFutureExt as _; -use tokio::sync::{mpsc, oneshot}; -use tokio::time::sleep; -use tokio_util::sync::CancellationToken; - -use crate::Config; -use crate::maps::courses::Tier; -use crate::points::DistributionParameters; -use crate::python::Python; - -type Message = (Request, oneshot::Sender); - -#[derive(Debug)] -pub struct PointsCalculator { - python: Python, - chan: (mpsc::Sender, mpsc::Receiver), -} - -#[derive(Debug, Clone)] -pub struct PointsCalculatorHandle { - chan: mpsc::Sender, -} - -#[derive(Debug, Display, Error)] -pub enum Error { - #[display("python error")] - Python(io::Error), -} - -#[derive(Debug, Clone, serde::Serialize)] -pub struct Request { - pub time: f64, - pub nub_data: LeaderboardData, - pub pro_data: Option, -} - -#[derive(Debug, serde::Deserialize)] -pub struct Response { - pub nub_fraction: f64, - pub pro_fraction: Option, -} - -#[derive(Debug, Display, Error)] -#[display("failed to calculate points ({_variant})")] -pub enum CalculatePointsError { - #[display("calculator unavailable")] - CalculatorUnavailable, -} - -#[derive(Debug, Clone, serde::Serialize)] -pub struct LeaderboardData { - pub dist_params: Option, - #[serde(serialize_with = "Tier::serialize_as_integer")] - pub tier: Tier, - pub leaderboard_size: u64, - #[serde(rename = "wr")] - pub top_time: f64, -} - -impl PointsCalculator { - pub async fn new(config: &Config) -> io::Result> { - let Some(script_path) = config.points.calc_run_path.as_deref() else { - tracing::warn!( - "no `points.calc-run-path` configured; points calculator will be disabled" - ); - return Ok(None); - }; - - let python = Python::new(script_path.to_owned(), config.database.url.clone()).await?; - let chan = mpsc::channel(128); - - Ok(Some(Self { python, chan })) - } - - pub fn handle(&self) -> PointsCalculatorHandle { - PointsCalculatorHandle { chan: self.chan.0.clone() } - } - - #[tracing::instrument(skip_all)] - pub async fn run(mut self, cancellation_token: CancellationToken) -> Result<(), Error> { - loop { - select! { - () = cancellation_token.cancelled() => { - tracing::debug!("cancelled"); - break Ok(()); - }, - - Some((request, response_tx)) = self.chan.1.recv() => { - let mut attempts = 0; - - if let Err(_err) = loop { - match self.python.send_request(&request).await { - Ok(response) => { - _ = response_tx.send(response); - break Ok(()); - }, - Err(err) => { - tracing::error!(%err, "failed to execute python request"); - self.python.reset().map_err(Error::Python).await?; - - attempts += 1; - - if attempts == 3 { - break Err(err); - } - - sleep(Duration::from_secs(1)).await; - }, - } - } { - tracing::error!("giving up after 3 failed attempts"); - } - }, - }; - } - } -} - -impl PointsCalculatorHandle { - pub async fn calculate(&self, request: Request) -> Result { - let (response_tx, response_rx) = oneshot::channel::(); - - if let Err(_send_err) = self.chan.send((request, response_tx)).await { - return Err(CalculatePointsError::CalculatorUnavailable); - } - - match response_rx.await { - Ok(response) => Ok(response), - Err(_recv_err) => Err(CalculatePointsError::CalculatorUnavailable), - } - } -} diff --git a/crates/cs2kz/src/points/daemon.rs b/crates/cs2kz/src/points/daemon.rs index b44733e2..c4097e32 100644 --- a/crates/cs2kz/src/points/daemon.rs +++ b/crates/cs2kz/src/points/daemon.rs @@ -1,48 +1,32 @@ -use std::io; -use std::sync::Arc; use std::time::Duration; use futures_util::TryFutureExt as _; -use tokio::sync::Notify; -use tokio::time::{interval, sleep}; +use nig::nig::NigParams; +use tokio::time::interval; use tokio_util::sync::CancellationToken; use crate::maps::CourseFilterId; -use crate::maps::courses::filters::GetCourseFiltersError; +use crate::maps::courses::Tier; use crate::mode::Mode; -use crate::python::Python; -use crate::records::GetRecordsError; +use crate::players::PlayerId; +use crate::points::{self}; +use crate::records::RecordId; use crate::{Context, database, players}; -#[derive(Debug, Clone)] -pub struct PointsDaemonHandle { - notifications: Arc, -} - -impl PointsDaemonHandle { - #[expect(clippy::new_without_default)] - pub fn new() -> Self { - Self { - notifications: Arc::new(Notifications { record_submitted: Notify::new() }), - } - } +const UPSERT_CHUNK_SIZE: usize = 5_000; // should prob put this somewhe - pub fn notify_record_submitted(&self) { - self.notifications.record_submitted.notify_waiters(); - } -} - -#[derive(Debug)] -struct Notifications { - record_submitted: Notify, +#[derive(Debug, Clone, Copy)] +struct BestRecordRow { + filter_id: CourseFilterId, + player_id: PlayerId, + record_id: RecordId, + time: f64, } #[derive(Debug, Display, Error, From)] pub enum Error { - GetCourseFilter(GetCourseFiltersError), - GetRecords(GetRecordsError), DetermineFilterToRecalculate(DetermineFilterToRecalculateError), - Python(io::Error), + ProcessFilter(database::Error), } #[derive(Debug, Display, Error, From)] @@ -50,54 +34,8 @@ pub enum Error { #[from(forward)] pub struct DetermineFilterToRecalculateError(database::Error); -#[derive(Debug, serde::Serialize)] -struct PythonRequest { - filter_id: CourseFilterId, -} - -#[derive(Debug, serde::Deserialize)] -struct PythonResponse { - #[expect(dead_code, reason = "included in tracing events")] - filter_id: CourseFilterId, - - #[expect(dead_code, reason = "included in tracing events")] - timings: PythonTimings, -} - -#[derive(Debug, serde::Deserialize)] -#[expect(dead_code, reason = "included in tracing events")] -struct PythonTimings { - #[serde(rename = "db_query_ms", deserialize_with = "deserialize_millis")] - db_query: Duration, - - #[serde(rename = "nub_fit_ms", deserialize_with = "deserialize_millis")] - nub_fit: Duration, - - #[serde(rename = "nub_compute_ms", deserialize_with = "deserialize_millis")] - nub_compute: Duration, - - #[serde(rename = "pro_fit_ms", deserialize_with = "deserialize_millis")] - pro_fit: Duration, - - #[serde(rename = "pro_compute_ms", deserialize_with = "deserialize_millis")] - pro_compute: Duration, - - #[serde(rename = "db_write_ms", deserialize_with = "deserialize_millis")] - db_write: Duration, -} - #[tracing::instrument(skip_all, err)] pub async fn run(cx: Context, cancellation_token: CancellationToken) -> Result<(), Error> { - let Some(script_path) = cx.config().points.calc_filter_path.as_deref() else { - tracing::warn!("no `points.calc-filter-path` configured; points daemon will be disabled"); - return Ok(()); - }; - - let mut python = Python::::new( - script_path.to_owned(), - cx.config().database.url.clone(), - ) - .await?; let mut recalc_ratings_interval = interval(Duration::from_secs(10)); loop { @@ -114,7 +52,7 @@ pub async fn run(cx: Context, cancellation_token: CancellationToken) -> Result<( res = determine_filter_to_recalculate(&cx) => { let (filter_id, priority) = res?; - process_filter(&mut python, &cancellation_token, filter_id).await?; + process_filter(&cx, filter_id).await?; update_filters_to_recalculate(&cx, filter_id, priority).await; }, }; @@ -153,12 +91,7 @@ async fn determine_filter_to_recalculate( break Ok(data); } - () = cx - .points_daemon() - .notifications - .record_submitted - .notified() - .await; + cx.wait_for_points_recalculation().await; tracing::trace!("received notification about submitted record"); } @@ -184,42 +117,211 @@ async fn update_filters_to_recalculate( } } -#[tracing::instrument(skip(python))] -async fn process_filter( - python: &mut Python, - cancellation_token: &CancellationToken, - filter_id: CourseFilterId, -) -> Result<(), Error> { - let request = PythonRequest { filter_id }; +#[tracing::instrument(skip(cx))] +async fn process_filter(cx: &Context, filter_id: CourseFilterId) -> Result<(), database::Error> { + tracing::debug!(%filter_id, "recalculating filter"); - loop { - tracing::debug!(?request); - match cancellation_token - .run_until_cancelled(python.send_request(&request)) - .await - { - None => { - tracing::debug!("cancelled"); - break Ok(()); - }, - Some(Ok(response)) => { - tracing::debug!(?response); - break Ok(()); - }, - Some(Err(err)) => { - tracing::error!(%err, "failed to execute python request"); - python.reset().map_err(Error::Python).await?; - sleep(Duration::from_secs(1)).await; - }, + let db = cx.database().as_ref(); + + let nub_rows = sqlx::query_as!( + BestRecordRow, + "SELECT + filter_id AS `filter_id: CourseFilterId`, + player_id AS `player_id: PlayerId`, + record_id AS `record_id: RecordId`, + time + FROM BestNubRecords + WHERE filter_id = ? + ORDER BY time ASC", + filter_id, + ) + .fetch_all(db) + .await?; + + let nub_times = nub_rows.iter().map(|row| row.time).collect::>(); + + // Pro records (sorted by time ASC) + let pro_rows = sqlx::query_as!( + BestRecordRow, + "SELECT + filter_id AS `filter_id: CourseFilterId`, + player_id AS `player_id: PlayerId`, + record_id AS `record_id: RecordId`, + time + FROM BestProRecords + WHERE filter_id = ? + ORDER BY time ASC", + filter_id, + ) + .fetch_all(db) + .await?; + + let pro_times = pro_rows.iter().map(|row| row.time).collect::>(); + + // Filter tiers + let tiers_row = sqlx::query!( + "SELECT + nub_tier AS `nub_tier: Tier`, + pro_tier AS `pro_tier: Tier` + FROM CourseFilters + WHERE id = ?", + filter_id, + ) + .fetch_optional(db) + .await?; + + let Some(tiers_row) = tiers_row else { + tracing::warn!(%filter_id, "filter not found in CourseFilters"); + return Ok(()); + }; + + let nub_tier = tiers_row.nub_tier; + let pro_tier = tiers_row.pro_tier; + + // Previous distribution parameters for warm start + let prev_nub_params = sqlx::query_as!( + NigParams, + "SELECT a, b, loc, scale + FROM PointDistributionData + WHERE filter_id = ? AND (NOT is_pro_leaderboard)", + filter_id, + ) + .fetch_optional(db) + .await?; + + let prev_pro_params = sqlx::query_as!( + NigParams, + "SELECT a, b, loc, scale + FROM PointDistributionData + WHERE filter_id = ? AND is_pro_leaderboard", + filter_id, + ) + .fetch_optional(db) + .await?; + + let (nub_result, pro_result) = tokio::task::spawn_blocking(move || { + let nub_result = points::recalculate_leaderboard(&nub_times, nub_tier, prev_nub_params); + + let mut pro_result = points::recalculate_leaderboard(&pro_times, pro_tier, prev_pro_params); + + for (time, recalculated_points) in pro_times.iter().zip(pro_result.records.iter_mut()) { + let nub_fraction = points::calculate_fraction(*time, &nub_result.leaderboard); + *recalculated_points = (*recalculated_points).max(nub_fraction); } - } + + (nub_result, pro_result) + }) + .await + .map_err(|_| { + database::Error::decode(std::io::Error::other("points recalculation task panicked")) + })?; + + tracing::debug!( + %filter_id, + nub_fitted = nub_result.leaderboard.dist_params.is_some(), + pro_fitted = pro_result.leaderboard.dist_params.is_some(), + "recalculation complete, writing to DB" + ); + + cx.database_transaction(async move |conn| -> Result<_, database::Error> { + upsert_best_records( + conn, + "INSERT INTO BestNubRecords (filter_id, player_id, record_id, points, time)", + &nub_rows, + &nub_result.records, + ) + .await?; + + upsert_best_records( + conn, + "INSERT INTO BestProRecords (filter_id, player_id, record_id, points, time)", + &pro_rows, + &pro_result.records, + ) + .await?; + + if let Some(params) = nub_result.leaderboard.dist_params { + sqlx::query!( + "INSERT INTO PointDistributionData ( + filter_id, is_pro_leaderboard, a, b, loc, scale, top_scale + ) + VALUES (?, FALSE, ?, ?, ?, ?, ?) + ON DUPLICATE KEY UPDATE + a = VALUES(a), + b = VALUES(b), + loc = VALUES(loc), + scale = VALUES(scale), + top_scale = VALUES(top_scale)", + filter_id, + params.a, + params.b, + params.loc, + params.scale, + params.top_scale, + ) + .execute(&mut *conn) + .await?; + } + + if let Some(params) = pro_result.leaderboard.dist_params { + sqlx::query!( + "INSERT INTO PointDistributionData ( + filter_id, is_pro_leaderboard, a, b, loc, scale, top_scale + ) + VALUES (?, TRUE, ?, ?, ?, ?, ?) + ON DUPLICATE KEY UPDATE + a = VALUES(a), + b = VALUES(b), + loc = VALUES(loc), + scale = VALUES(scale), + top_scale = VALUES(top_scale)", + filter_id, + params.a, + params.b, + params.loc, + params.scale, + params.top_scale, + ) + .execute(&mut *conn) + .await?; + } + + Ok(()) + }) + .await?; + + Ok(()) } -fn deserialize_millis<'de, D>(deserializer: D) -> Result -where - D: serde::Deserializer<'de>, -{ - >::deserialize(deserializer) - .map(|millis| millis / 1000.0) - .map(Duration::from_secs_f64) +async fn upsert_best_records( + conn: &mut database::Connection, + insert_prefix: &'static str, + rows: &[BestRecordRow], + recalculated_points: &[f64], +) -> Result<(), database::Error> { + if rows.len() != recalculated_points.len() { + return Err(database::Error::decode(std::io::Error::other( + "recalculated record count does not match fetched best record rows", + ))); + } + + for (row_chunk, points_chunk) in rows + .chunks(UPSERT_CHUNK_SIZE) + .zip(recalculated_points.chunks(UPSERT_CHUNK_SIZE)) + { + let mut query = database::QueryBuilder::new(insert_prefix); + + query.push_values(row_chunk.iter().zip(points_chunk.iter()), |mut query, (row, points)| { + query.push_bind(row.filter_id); + query.push_bind(row.player_id); + query.push_bind(row.record_id); + query.push_bind(points); + query.push_bind(row.time); + }); + + query.push(" ON DUPLICATE KEY UPDATE points = VALUES(points)"); + query.build().persistent(false).execute(&mut *conn).await?; + } + + Ok(()) } diff --git a/crates/cs2kz/src/points.rs b/crates/cs2kz/src/points/mod.rs similarity index 51% rename from crates/cs2kz/src/points.rs rename to crates/cs2kz/src/points/mod.rs index c2952332..4c04b71f 100644 --- a/crates/cs2kz/src/points.rs +++ b/crates/cs2kz/src/points/mod.rs @@ -1,19 +1,11 @@ +use ::nig::nig::{self, NigParams}; + use crate::maps::courses::Tier; pub mod daemon; -pub mod calculator; -/// The maximum points for any record. -pub const MAX: f64 = 10_000.0; - -/// Threshold for what counts as a "small" leaderboard. -pub const SMALL_LEADERBOARD_THRESHOLD: u64 = 50; - -/// [Normal-inverse Gaussian distribution][norminvgauss] parameters. -/// -/// [norminvgauss]: https://en.wikipedia.org/wiki/Normal-inverse_Gaussian_distribution #[derive(Debug, Clone, Copy, serde::Serialize)] -pub struct DistributionParameters { +pub struct NigData { pub a: f64, pub b: f64, pub loc: f64, @@ -21,6 +13,40 @@ pub struct DistributionParameters { pub top_scale: f64, } +impl NigData { + pub fn params(&self) -> nig::NigParams { + nig::NigParams { + a: self.a, + b: self.b, + loc: self.loc, + scale: self.scale, + } + } +} + +#[derive(Debug, Clone, serde::Serialize)] +pub struct LeaderboardData { + pub dist_params: Option, + #[serde(serialize_with = "Tier::serialize_as_integer")] + pub tier: Tier, + pub leaderboard_size: u64, + #[serde(rename = "wr")] + pub top_time: f64, +} + +/// Result of a leaderboard recalculation. +#[derive(Debug, Clone)] +pub struct RecalculatedLeaderboard { + pub leaderboard: LeaderboardData, + pub records: Vec, +} + +/// The maximum points for any record. +pub const MAX: f64 = 10_000.0; + +/// Threshold for what counts as a "small" leaderboard. +pub const SMALL_LEADERBOARD_THRESHOLD: u64 = 50; + /// "Completes" pre-calculated distribution points cached in the database. /// /// # Panics @@ -86,3 +112,58 @@ pub fn for_small_leaderboard(tier: Tier, top_time: f64, time: f64) -> f64 { y / z } + +pub fn calculate_fraction(time: f64, leaderboard: &LeaderboardData) -> f64 { + if leaderboard.leaderboard_size < SMALL_LEADERBOARD_THRESHOLD { + return for_small_leaderboard(leaderboard.tier, leaderboard.top_time, time); + } + + let Some(dist) = leaderboard.dist_params else { + return for_small_leaderboard(leaderboard.tier, leaderboard.top_time, time); + }; + + nig::sf(&dist.params(), time) / dist.top_scale +} + +/// Recompute point fractions for a single leaderboard. +/// +/// `times` must be sorted ascending (fastest first). +pub fn recalculate_leaderboard( + times: &[f64], + tier: Tier, + prev_params: Option, +) -> RecalculatedLeaderboard { + let params = fit_distribution(times, prev_params); + + let leaderboard = LeaderboardData { + dist_params: params, + tier, + leaderboard_size: times.len() as u64, + top_time: times.first().copied().unwrap_or(0.0), + }; + + let records = times + .iter() + .map(|&time| calculate_fraction(time, &leaderboard)) + .collect(); + + RecalculatedLeaderboard { leaderboard, records } +} + +fn fit_distribution(times: &[f64], prev_params: Option) -> Option { + if times.len() < SMALL_LEADERBOARD_THRESHOLD as usize { + return None; + } + + let p = nig::fit(times, prev_params); + let sf = nig::sf(&p, times[0]); + let top_scale = if sf <= 0.0 { 1.0 } else { sf }; + + Some(NigData { + a: p.a, + b: p.b, + loc: p.loc, + scale: p.scale, + top_scale, + }) +} diff --git a/crates/cs2kz/src/python.rs b/crates/cs2kz/src/python.rs deleted file mode 100644 index eb99f5b5..00000000 --- a/crates/cs2kz/src/python.rs +++ /dev/null @@ -1,223 +0,0 @@ -use std::bstr::ByteStr; -use std::ffi::OsStr; -use std::marker::PhantomData; -use std::path::{Path, PathBuf}; -use std::process::Stdio; -use std::time::Duration; -use std::{fmt, io, mem}; - -use futures_util::FutureExt as _; -use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; -use tokio::process::{Child, ChildStderr, ChildStdout, Command}; -use tokio::sync::OnceCell; -use tokio::task; -use tokio_util::task::AbortOnDropHandle; -use tokio_util::time::FutureExt as _; -use tracing::Instrument as _; -use url::Url; - -#[derive(Debug)] -pub struct Python { - script_path: PathBuf, - database_url: Url, - process: Child, - process_stdout: BufReader, - process_stderr_reader_task: AbortOnDropHandle>, - _request: PhantomData Request>, - _response: PhantomData Response>, -} - -#[derive(Debug, serde::Deserialize)] -#[serde(untagged)] -enum PythonResponse { - Success(T), - Error { - #[serde(rename = "error")] - message: String, - }, -} - -impl Python { - pub async fn new(script_path: PathBuf, database_url: Url) -> io::Result { - let (mut process, process_stderr_reader_task) = - spawn_script(&script_path, &database_url).await?; - - let process_stdout = process - .stdout - .take() - .map(BufReader::new) - .expect("we only take stdout once"); - - Ok(Self { - script_path, - database_url, - process, - process_stdout, - process_stderr_reader_task: AbortOnDropHandle::new(process_stderr_reader_task), - _request: PhantomData, - _response: PhantomData, - }) - } - - #[tracing::instrument(skip(self), err)] - pub async fn send_request(&mut self, request: &Request) -> io::Result - where - Request: fmt::Debug + serde::Serialize, - Response: for<'de> serde::Deserialize<'de>, - { - let mut serialized_request = - serde_json::to_vec(request).expect("requests should serialize to JSON"); - serialized_request.push(b'\n'); - - let mut serialized_response = Vec::with_capacity(128); - - 'outer: loop { - serialized_response.clear(); - - if let Some(exit_status) = self.process.try_wait()? { - tracing::warn!(?exit_status, "python process exited"); - self.reset().await?; - continue; - } - - tracing::trace!( - request = str::from_utf8(&serialized_request).unwrap(), - "writing request to python stdin" - ); - { - let stdin = self.process.stdin.as_mut().expect("we never close stdin"); - stdin.write_all(&serialized_request[..]).await?; - stdin.flush().await?; - } - - tracing::trace!("reading response from python stdout"); - for _ in 0..3 { - match self - .process_stdout - .read_until(b'\n', &mut serialized_response) - .timeout(Duration::from_secs(10)) - .await - { - Ok(Ok(_)) => break, - Ok(Err(err)) => { - tracing::error!(%err, "failed to read from stdout"); - self.reset().await?; - continue 'outer; - }, - Err(_elapsed) => { - tracing::warn!( - stdout = ?ByteStr::new(self.process_stdout.buffer()), - response_buf = ?ByteStr::new(&serialized_response), - "still waiting for response", - ); - }, - } - } - - break if serialized_response.is_empty() { - Err(io::Error::new(io::ErrorKind::TimedOut, "did not complete request in time")) - } else { - serde_json::from_slice::>(&serialized_response[..]) - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) - .and_then(|response| match response { - PythonResponse::Success(res) => Ok(res), - PythonResponse::Error { message } => Err(io::Error::other(message)), - }) - }; - } - } - - pub async fn reset(&mut self) -> io::Result<()> { - let (process, process_stderr_reader_task) = - spawn_script(&self.script_path, &self.database_url).await?; - self.process = process; - self.process_stdout = self - .process - .stdout - .take() - .map(BufReader::new) - .expect("we only take stdout once"); - - let old_process_stderr_reader_task = mem::replace( - &mut self.process_stderr_reader_task, - AbortOnDropHandle::new(process_stderr_reader_task), - ); - - if let Some(Ok(Err(err))) = old_process_stderr_reader_task.now_or_never() { - tracing::error!(%err, "python stderr task encountered an error"); - } - - Ok(()) - } -} - -async fn spawn_script( - path: &Path, - database_url: &Url, -) -> io::Result<(Child, task::JoinHandle>)> { - let span = tracing::debug_span!("python", script_path = %path.display()); - let executable_name = resolve_executable_name().await?; - let mut child = Command::new(executable_name) - .arg(path) - .env("DATABASE_URL", database_url.as_str()) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn()?; - - let stderr = child.stderr.take().expect("we only take stderr once"); - let task = task::spawn(read_from_stderr(stderr).instrument(span)); - - Ok((child, task)) -} - -async fn resolve_executable_name() -> io::Result<&'static OsStr> { - #[cfg(unix)] - const EXECUTABLE_NAMES: &[&str] = &["python3", "python", "py"]; - - #[cfg(windows)] - const EXECUTABLE_NAMES: &[&str] = &["python3.exe", "python.exe", "py.exe"]; - - static EXECUTABLE_NAME: OnceCell<&OsStr> = OnceCell::const_new(); - - EXECUTABLE_NAME - .get_or_try_init(async || { - for name in EXECUTABLE_NAMES { - match Command::new(name).arg("--version").output().await { - Ok(output) => { - if output.status.success() { - return Ok(OsStr::new(name)); - } - }, - - Err(err) if err.kind() == io::ErrorKind::NotFound => { - continue; - }, - - Err(err) => return Err(err), - } - } - - Err(io::Error::other("failed to find suitable python executable")) - }) - .await - .copied() -} - -async fn read_from_stderr(stderr: ChildStderr) -> io::Result<()> { - let mut stderr = BufReader::new(stderr); - let mut line = String::new(); - - while let 1.. = stderr.read_line(&mut line).await? { - if let Some(c) = line.pop() - && c != '\n' - { - line.push(c); - } - - tracing::debug!(line); - line.clear(); - } - - Ok(()) -} diff --git a/crates/cs2kz/src/records.rs b/crates/cs2kz/src/records.rs index 257148c9..9029701d 100644 --- a/crates/cs2kz/src/records.rs +++ b/crates/cs2kz/src/records.rs @@ -14,16 +14,17 @@ use crate::num::AsF64; use crate::pagination::{Limit, Offset, Paginated}; use crate::players::{CalculateRatingError, PlayerId, PlayerInfo}; use crate::plugin::PluginVersionId; -use crate::points::calculator::{ - CalculatePointsError, - Request as CalculatePointsRequest, - Response as CalculatePointsResponse, -}; -use crate::points::{self, DistributionParameters}; +use crate::points::{self, NigData}; use crate::servers::{ServerId, ServerInfo}; use crate::styles::{ClientStyleInfo, Styles}; use crate::time::Seconds; +#[derive(Debug, Clone, Copy)] +struct CalculatedPoints { + nub_fraction: f64, + pro_fraction: Option, +} + define_id_type! { /// A unique identifier for records. pub struct RecordId(Uuid); @@ -197,10 +198,6 @@ pub struct SubmittedPB { #[derive(Debug, Display, Error, From)] pub enum SubmitRecordError { - #[display("{_0}")] - #[from] - CalculatePoints(CalculatePointsError), - #[display("{_0}")] #[from] CalculateRating(CalculateRatingError), @@ -288,7 +285,7 @@ pub async fn submit( .await?; let nub_dist = sqlx::query_as!( - DistributionParameters, + NigData, "SELECT a, b, loc, scale, top_scale FROM PointDistributionData WHERE filter_id = ? @@ -332,7 +329,7 @@ pub async fn submit( (None, Some(_)) => unreachable!(), }; - let nub_data = points::calculator::LeaderboardData { + let nub_data = points::LeaderboardData { dist_params: nub_dist, tier: nub_tier, leaderboard_size: nub_leaderboard_size, @@ -341,11 +338,11 @@ pub async fn submit( let pro_data = if teleports == 0 { let pro_dist = sqlx::query_as!( - DistributionParameters, + NigData, "SELECT a, b, loc, scale, top_scale FROM PointDistributionData WHERE filter_id = ? - AND (NOT is_pro_leaderboard)", + AND is_pro_leaderboard", filter_id, ) .fetch_optional(&mut *conn) @@ -376,7 +373,7 @@ pub async fn submit( row.map_or((0, None), |row| (row.size as u64, row.top_time)) })?; - Some(points::calculator::LeaderboardData { + Some(points::LeaderboardData { dist_params: pro_dist, tier: pro_tier, leaderboard_size: pro_leaderboard_size + u64::from(pro_pb_time.is_none()), @@ -386,37 +383,12 @@ pub async fn submit( None }; - let points = if nub_leaderboard_size > points::SMALL_LEADERBOARD_THRESHOLD - && nub_data.dist_params.is_some() - && let Some(calc) = cx.points_calculator() - { - let request = CalculatePointsRequest { - time: time.as_f64(), - nub_data, - pro_data: pro_data.clone().filter(|data| data.dist_params.is_some()), - }; - - let mut response = calc.calculate(request.clone()).await?; - - if let Some(ref pro_data) = pro_data && request.pro_data.is_none() { - response.pro_fraction = Some(points::for_small_leaderboard( - pro_data.tier, - pro_data.top_time, - time.as_f64(), - )); - } - - response - } else { - let nub_fraction = - points::for_small_leaderboard(nub_data.tier, nub_data.top_time, time.as_f64()); - - let pro_fraction = pro_data.clone().map(|pro_data| { - points::for_small_leaderboard(pro_data.tier, pro_data.top_time, time.as_f64()) - }); - - CalculatePointsResponse { nub_fraction, pro_fraction } - }; + let nub_fraction = points::calculate_fraction(time.as_f64(), &nub_data); + let pro_fraction = pro_data + .as_ref() + .map(|leaderboard| points::calculate_fraction(time.as_f64(), leaderboard)) + .map(|fraction| fraction.max(nub_fraction)); + let points = CalculatedPoints { nub_fraction, pro_fraction }; let is_nub_pb = nub_pb_time.is_none_or(|nub_pb_time| nub_pb_time > time.as_f64()); @@ -528,7 +500,7 @@ pub async fn submit( } } - cx.points_daemon().notify_record_submitted(); + cx.notify_points_recalculation(); events::dispatch(Event::NewRecord { player_id, diff --git a/crates/nig/Cargo.toml b/crates/nig/Cargo.toml new file mode 100644 index 00000000..fee065b0 --- /dev/null +++ b/crates/nig/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "nig" +version = "0.1.0" +authors = ["Praetor"] +edition.workspace = true +homepage.workspace = true +repository.workspace = true +license = "AGPL-3.0-only" + +[dependencies] +rand = { version = "0.8", features = ["small_rng"] } +serde = { workspace = true, features = ["derive"] } +tracing.workspace = true diff --git a/crates/nig/src/bessel.rs b/crates/nig/src/bessel.rs new file mode 100644 index 00000000..420cc041 --- /dev/null +++ b/crates/nig/src/bessel.rs @@ -0,0 +1,176 @@ +// Boost Software License - Version 1.0 - August 17th, 2003 + +// Permission is hereby granted, free of charge, to any person or organization +// obtaining a copy of the software and accompanying documentation covered by +// this license (the "Software") to use, reproduce, display, distribute, +// execute, and transmit the Software, and to prepare derivative works of the +// Software, and to permit third-parties to whom the Software is furnished to +// do so, all subject to the following: + +// The copyright notices in the Software and this entire statement, including +// the above license grant, this restriction and the following disclaimer, +// must be included in all copies of the Software, in whole or in part, and +// all derivative works of the Software, unless such copies or derivative +// works are solely in the form of machine-executable object code generated by +// a source language processor. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +// SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +// FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +/// Evaluates the rational function P(z)/Q(z) using Horner's method. +/// https://github.com/boostorg/math/blob/develop/include/boost/math/tools/rational.hpp +pub(crate) fn evaluate_rational(p: &[f64], q: &[f64], z: f64) -> f64 { + if p.is_empty() || q.is_empty() { + return 0.0; + } + + let mut pn = p[p.len() - 1]; + for i in (0..p.len() - 1).rev() { + pn = pn * z + p[i]; + } + + let mut qn = q[q.len() - 1]; + for i in (0..q.len() - 1).rev() { + qn = qn * z + q[i]; + } + + pn / qn +} + +/// Bessel K1 functions adapted from Boost +/// https://github.com/boostorg/math/blob/develop/include/boost/math/special_functions/detail/bessel_k1.hpp +pub(crate) fn bessel_k1(x: f64) -> f64 { + if x <= 0.0 { + return f64::INFINITY; + } + + if x <= 1.0 { + bessel_k1_small(x) + } else { + bessel_k1_large(x) + } +} + +fn bessel_k1_small(x: f64) -> f64 { + const Y: f64 = 8.69547128677368164e-02; + + const P1: [f64; 4] = [ + -3.62137953440350228e-03, + 7.11842087490330300e-03, + 1.00302560256614306e-05, + 1.77231085381040811e-06, + ]; + const Q1: [f64; 4] = [ + 1.00000000000000000e+00, + -4.80414794429043831e-02, + 9.85972641934416525e-04, + -8.91196859397070326e-06, + ]; + + const P2: [f64; 4] = [ + -3.07965757829206184e-01, + -7.80929703673074907e-02, + -2.70619343754051620e-03, + -2.49549522229072008e-05, + ]; + const Q2: [f64; 4] = [ + 1.00000000000000000e+00, + -2.36316836412163098e-02, + 2.64524577525962719e-04, + -1.49749618004162787e-06, + ]; + + let a = x * x / 4.0; + let log_term = ((evaluate_rational(&P1, &Q1, a) + Y) * a * a + a / 2.0 + 1.0) * x / 2.0; + + evaluate_rational(&P2, &Q2, x * x) * x + 1.0 / x + x.ln() * log_term +} + +fn bessel_k1_large(x: f64) -> f64 { + let scaled = bessel_k1e_large(x); + if x < 709.0 { + return scaled * (-x).exp(); + } + + let exp_half = (-x / 2.0).exp(); + scaled * exp_half * exp_half +} + +fn bessel_k1e_large(x: f64) -> f64 { + const Y: f64 = 1.45034217834472656e+00; + + const P: [f64; 9] = [ + -1.97028041029226295e-01, + -2.32408961548087617e+00, + -7.98269784507699938e+00, + -2.39968410774221632e+00, + 3.28314043780858713e+01, + 5.67713761158496058e+01, + 3.30907788466509823e+01, + 6.62582288933739787e+00, + 3.08851840645286691e-01, + ]; + const Q: [f64; 9] = [ + 1.00000000000000000e+00, + 1.41811409298826118e+01, + 7.35979466317556420e+01, + 1.77821793937080859e+02, + 2.11014501598705982e+02, + 1.19425262951064454e+02, + 2.88448064302447607e+01, + 2.27912927104139732e+00, + 2.50358186953478678e-02, + ]; + + (evaluate_rational(&P, &Q, 1.0 / x) + Y) / x.sqrt() +} + +pub fn bessel_k1e(x: f64) -> f64 { + if x <= 0.0 { + return f64::INFINITY; + } + + if x <= 1.0 { + return bessel_k1(x) * x.exp(); + } + + bessel_k1e_large(x) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn assert_rel_close(actual: f64, expected: f64, tolerance: f64) { + let rel_error = if expected == 0.0 { + actual.abs() + } else { + (actual - expected).abs() / expected.abs() + }; + assert!( + rel_error <= tolerance, + "expected {expected:.15e}, got {actual:.15e}, rel error {rel_error:.2e}", + ); + } + + #[test] + fn bessel_k1_matches_reference_values() { + for (x, expected) in [ + (0.001, 9.999962381560855e+02), + (0.1, 9.853844780870606e+00), + (1.0, 6.019072301972346e-01), + (2.0, 1.398658818165225e-01), + (5.0, 4.044613445452163e-03), + (20.0, 5.883057969557038e-10), + (100.0, 4.679853735636910e-45), + (709.0, 5.730317612554602e-310), + ] { + assert_rel_close(bessel_k1(x), expected, 1e-14); + } + } +} diff --git a/crates/nig/src/differential_evo.rs b/crates/nig/src/differential_evo.rs new file mode 100644 index 00000000..2af0b4aa --- /dev/null +++ b/crates/nig/src/differential_evo.rs @@ -0,0 +1,113 @@ +use rand::rngs::SmallRng; +use rand::{Rng, SeedableRng}; + +/// Differential Evolution optimizer +/// +/// Storn, R. and Price, K. Differential Evolution - A Simple and Efficient +/// Heuristic for Global Optimization over Continuous Spaces. Journal of +/// Global Optimization 11, 341-359 (1997). +pub fn differential_evolution( + objective: impl Fn(&[f64; N]) -> f64, + bounds: &[(f64, f64); N], + tol: f64, + max_iter: usize, + pop_factor: usize, + mutation: (f64, f64), + crossover: f64, + seed: u64, + inits: &[[f64; N]], +) -> ([f64; N], f64, usize) { + const MAX_STAGNANT_GENERATIONS: usize = 25; + + let (f_lo, f_hi) = mutation; + + let mut rng = SmallRng::seed_from_u64(seed); + let n = N; + let np_pop = (pop_factor * n).max(inits.len() + 4); + + let lo: [f64; N] = core::array::from_fn(|i| bounds[i].0); + let hi: [f64; N] = core::array::from_fn(|i| bounds[i].1); + let span: [f64; N] = core::array::from_fn(|i| hi[i] - lo[i]); + + // Initialize population, seeding the first members with the provided + // initial guesses (clamped to bounds). + let mut pop: Vec<[f64; N]> = (0..np_pop) + .map(|_| core::array::from_fn(|i| lo[i] + rng.r#gen::() * span[i])) + .collect(); + for (member, init) in pop.iter_mut().zip(inits) { + *member = core::array::from_fn(|i| init[i].clamp(lo[i], hi[i])); + } + let mut fitness: Vec = pop.iter().map(|p| objective(p)).collect(); + let mut nfev = np_pop; + let mut stagnant_generations = 0usize; + + for _ in 0..max_iter { + let mut improved = false; + + for i in 0..np_pop { + // Pick three distinct random candidates (not i) + let mut candidates: Vec = (0..np_pop).filter(|&j| j != i).collect(); + for k in 0..3 { + let idx = rng.gen_range(k..candidates.len()); + candidates.swap(k, idx); + } + let a = candidates[0]; + let b = candidates[1]; + let c = candidates[2]; + + let f = f_lo + rng.r#gen::() * (f_hi - f_lo); + + let mutant: [f64; N] = core::array::from_fn(|j| { + (pop[a][j] + f * (pop[b][j] - pop[c][j])).clamp(lo[j], hi[j]) + }); + + let mut cross_mask: [bool; N] = + core::array::from_fn(|_| rng.r#gen::() < crossover); + // ensure at least one dimension is taken from mutant + cross_mask[rng.gen_range(0..n)] = true; + + let trial: [f64; N] = + core::array::from_fn(|j| if cross_mask[j] { mutant[j] } else { pop[i][j] }); + + let f_trial = objective(&trial); + nfev += 1; + + if f_trial <= fitness[i] { + pop[i] = trial; + fitness[i] = f_trial; + improved = true; + } + } + + let best_idx = argmin(&fitness); + if improved { + stagnant_generations = 0; + } else { + stagnant_generations += 1; + if stagnant_generations >= MAX_STAGNANT_GENERATIONS { + break; + } + } + + // Population spread check + let spread = (fitness + .iter() + .fold(0.0_f64, |acc, &fv| (fv - fitness[best_idx]).abs().max(acc))) + / (1.0 + fitness[best_idx].abs()); + if spread <= tol { + break; + } + } + + let best_idx = argmin(&fitness); + (pop[best_idx], fitness[best_idx], nfev) +} + +fn argmin(slice: &[f64]) -> usize { + slice + .iter() + .enumerate() + .min_by(|(_, a), (_, b)| a.total_cmp(b)) + .map(|(i, _)| i) + .unwrap_or(0) +} diff --git a/crates/nig/src/lib.rs b/crates/nig/src/lib.rs new file mode 100644 index 00000000..6694b0b0 --- /dev/null +++ b/crates/nig/src/lib.rs @@ -0,0 +1,5 @@ +mod bessel; +mod differential_evo; +mod nelder_mead; +mod quad; +pub mod nig; diff --git a/crates/nig/src/nelder_mead.rs b/crates/nig/src/nelder_mead.rs new file mode 100644 index 00000000..64e29f96 --- /dev/null +++ b/crates/nig/src/nelder_mead.rs @@ -0,0 +1,102 @@ +/// Nelder, J. A. and Mead, R. A Simplex Method for Function Minimization. +/// The Computer Journal 7, 308-313 (1965). +pub fn nelder_mead( + objective: impl Fn(&[f64; N]) -> f64, + start: &[f64; N], + max_iter: usize, + tol: f64, +) -> ([f64; N], f64, usize) { + const ALPHA: f64 = 1.0; // reflection + const GAMMA: f64 = 2.0; // expansion + const RHO: f64 = 0.5; // contraction + const SIGMA: f64 = 0.5; // shrink + + // Initial simplex: perturb each coordinate (scipy-style). + let mut simplex: Vec<[f64; N]> = Vec::with_capacity(N + 1); + simplex.push(*start); + for i in 0..N { + let mut vertex = *start; + if vertex[i] != 0.0 { + vertex[i] *= 1.05; + } else { + vertex[i] = 0.00025; + } + simplex.push(vertex); + } + + let mut values: Vec = simplex.iter().map(&objective).collect(); + let mut nfev = N + 1; + + for _ in 0..max_iter { + // Sort vertices by objective value. + let mut order: Vec = (0..=N).collect(); + order.sort_by(|&a, &b| values[a].total_cmp(&values[b])); + let simplex_sorted: Vec<[f64; N]> = order.iter().map(|&i| simplex[i]).collect(); + let values_sorted: Vec = order.iter().map(|&i| values[i]).collect(); + simplex = simplex_sorted; + values = values_sorted; + + let best = values[0]; + let worst = values[N]; + if (worst - best).abs() <= tol * (1.0 + best.abs()) { + break; + } + + // Centroid of all vertices except the worst. + let centroid: [f64; N] = + core::array::from_fn(|j| simplex[..N].iter().map(|v| v[j]).sum::() / N as f64); + + let reflected: [f64; N] = + core::array::from_fn(|j| centroid[j] + ALPHA * (centroid[j] - simplex[N][j])); + let f_reflected = objective(&reflected); + nfev += 1; + + if f_reflected < values[0] { + // Try expanding further in the same direction. + let expanded: [f64; N] = + core::array::from_fn(|j| centroid[j] + GAMMA * (reflected[j] - centroid[j])); + let f_expanded = objective(&expanded); + nfev += 1; + + if f_expanded < f_reflected { + simplex[N] = expanded; + values[N] = f_expanded; + } else { + simplex[N] = reflected; + values[N] = f_reflected; + } + } else if f_reflected < values[N - 1] { + simplex[N] = reflected; + values[N] = f_reflected; + } else { + // Contract towards the centroid. + let contracted: [f64; N] = + core::array::from_fn(|j| centroid[j] + RHO * (simplex[N][j] - centroid[j])); + let f_contracted = objective(&contracted); + nfev += 1; + + if f_contracted < values[N] { + simplex[N] = contracted; + values[N] = f_contracted; + } else { + // Shrink the whole simplex towards the best vertex. + for i in 1..=N { + simplex[i] = core::array::from_fn(|j| { + simplex[0][j] + SIGMA * (simplex[i][j] - simplex[0][j]) + }); + values[i] = objective(&simplex[i]); + } + nfev += N; + } + } + } + + let mut best_idx = 0; + for i in 1..=N { + if values[i] < values[best_idx] { + best_idx = i; + } + } + + (simplex[best_idx], values[best_idx], nfev) +} diff --git a/crates/nig/src/nig.rs b/crates/nig/src/nig.rs new file mode 100644 index 00000000..b3f368dd --- /dev/null +++ b/crates/nig/src/nig.rs @@ -0,0 +1,324 @@ +use serde::Serialize; + +use crate::bessel::bessel_k1e; +use crate::differential_evo::differential_evolution; +use crate::nelder_mead::nelder_mead; +use crate::quad; + +/// NIG distribution parameters (scipy loc-scale parameterization). +#[derive(Debug, Clone, Copy, Serialize)] +pub struct NigParams { + pub a: f64, + pub b: f64, + pub loc: f64, + pub scale: f64, +} + +#[derive(Debug, Clone, Copy)] +pub(crate) struct NigParamsReparametrized { + pub log_a: f64, + pub skew_raw: f64, + pub loc: f64, + pub log_scale: f64, +} + +pub fn pdf(p: &NigParams, x: f64) -> f64 { + if p.a <= 0.0 || p.scale <= 0.0 || p.b.abs() >= p.a { + return 0.0; + } + + let gamma = (p.a * p.a - p.b * p.b).sqrt(); + let z = (x - p.loc) / p.scale; + let sqrt_z2p1 = (z * z + 1.0).sqrt(); + let y = p.a * sqrt_z2p1; + let scaled_bessel = bessel_k1e(y); + + if scaled_bessel <= 0.0 { + return 0.0; + } + + let net_exp = gamma + p.b * z - y; + let log_pdf = p.a.ln() - std::f64::consts::PI.ln() - p.scale.ln() - sqrt_z2p1.ln() + + net_exp + + scaled_bessel.ln(); + + if log_pdf < -745.0 { + // exp(-745) underflows to 0 in f64 + return 0.0; + } + + log_pdf.exp() +} + +pub fn cdf(p: &NigParams, x: f64) -> f64 { + if p.a <= 0.0 || p.scale <= 0.0 || p.b.abs() >= p.a { + return 0.0; + } + + // The exp-sinh quadrature clusters its nodes near the finite endpoint, so + // always integrate the side of the distribution whose mass lies closest to + // `x`. + // `loc + scale * b / a` is a cheap proxy for the mode: exact for b = 0 and + // bounded by `loc ± scale`, unlike the mean which diverges as |b| -> a. + let peak = p.loc + p.scale * p.b / p.a; + + if x <= peak { + quad::quad(&mut |t| pdf(p, t), f64::NEG_INFINITY, x, 7, 1e-10, None).clamp(0.0, 1.0) + } else { + 1.0 - quad::quad(&mut |t| pdf(p, t), x, f64::INFINITY, 7, 1e-10, None).clamp(0.0, 1.0) + } +} + +pub fn sf(p: &NigParams, x: f64) -> f64 { + 1.0 - cdf(p, x) +} + +fn encode_nig_params(p: &NigParams) -> NigParamsReparametrized { + let safe_a = p.a.max(1e-6); + let safe_scale = p.scale.max(1e-6); + let beta_ratio = (p.b / safe_a).clamp(-1.0 + 1e-12, 1.0 - 1e-12); + NigParamsReparametrized { + log_a: safe_a.ln(), + skew_raw: beta_ratio.atanh(), + loc: p.loc, + log_scale: safe_scale.ln(), + } +} + +fn decode_nig_params(pr: &NigParamsReparametrized) -> NigParams { + let a = pr.log_a.exp(); + let b = a * pr.skew_raw.tanh(); + let scale = pr.log_scale.exp(); + NigParams { a, b, loc: pr.loc, scale } +} + +/// Moment-based initial parameter estimate for the NIG distribution. +fn estimate_nig_start(times: &[f64]) -> NigParams { + let n = times.len() as f64; + let mean = times.iter().sum::() / n; + let m2 = times.iter().map(|&t| (t - mean).powi(2)).sum::() / n; + + if m2 < 1e-12 { + return NigParams { a: 1.0, b: 0.0, loc: mean, scale: 1.0 }; + } + + let m3 = times.iter().map(|&t| (t - mean).powi(3)).sum::() / n; + let m4 = times.iter().map(|&t| (t - mean).powi(4)).sum::() / n; + let skewness = m3 / m2.powf(1.5); + let excess_kurtosis = m4 / (m2 * m2) - 3.0; + let denominator = excess_kurtosis - (4.0 * skewness * skewness) / 3.0; + + if denominator <= 1e-8 { + let scale = m2.sqrt(); + return NigParams { + a: (1.0 / scale).max(1e-3), + b: 0.0, + loc: mean, + scale, + }; + } + + let delta_gamma = 3.0 / denominator; + let beta_ratio = (skewness * delta_gamma.sqrt() / 3.0).clamp(-1.0 + 1e-8, 1.0 - 1e-8); + let cos_theta = (1.0 - beta_ratio * beta_ratio).max(1e-12).sqrt(); + let alpha = (delta_gamma / (m2 * cos_theta.powi(4))).max(1e-12).sqrt(); + let beta = alpha * beta_ratio; + let scale = (delta_gamma / (alpha * cos_theta)).max(1e-12); + let loc = mean - scale * beta_ratio / cos_theta; + + NigParams { a: alpha, b: beta, loc, scale } +} + +fn neg_log_likelihood(times: &[f64], p: &NigParams) -> f64 { + if p.a <= 0.0 || p.scale <= 0.0 || p.b.abs() >= p.a { + return f64::INFINITY; + } + + let gamma = (p.a * p.a - p.b * p.b).sqrt(); + let mut nll = 0.0; + + for &x in times { + let z = (x - p.loc) / p.scale; + let sqrt_z2p1 = (z * z + 1.0).sqrt(); + let y = p.a * sqrt_z2p1; + let scaled_bessel = bessel_k1e(y); + if scaled_bessel <= 0.0 { + return f64::INFINITY; + } + let log_pdf = + p.a.ln() - std::f64::consts::PI.ln() - p.scale.ln() - sqrt_z2p1.ln() + gamma + p.b * z + - y + + scaled_bessel.ln(); + nll -= log_pdf; + } + + nll +} + +fn de_bounds(times: &[f64]) -> [(f64, f64); 4] { + let n = times.len() as f64; + let mean = times.iter().sum::() / n; + let variance = times.iter().map(|&t| (t - mean).powi(2)).sum::() / n; + let std = variance.sqrt().max(1e-6); + let data_range = (times.iter().fold(f64::INFINITY, |a, &b| a.min(b)) + - times.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))) + .abs() + .max(1e-6); + [ + (-2.0, 12.0), + (-8.0, 8.0), + (mean - 5.0 * std, mean + 5.0 * std), + ((data_range / 100.0).ln(), (data_range * 100.0).ln()), + ] +} + +fn optimize(times: &[f64], inits: &[NigParams]) -> Result<(NigParams, usize), ()> { + const MAX_ITER: usize = 1000; + const TOL: f64 = 1e-4; + const POLISH_MAX_ITER: usize = 2000; + const POLISH_TOL: f64 = 1e-10; + + let bounds = de_bounds(times); + + let objective = |values: &[f64; 4]| -> f64 { + let pr = NigParamsReparametrized { + log_a: values[0], + skew_raw: values[1], + loc: values[2], + log_scale: values[3], + }; + neg_log_likelihood(times, &decode_nig_params(&pr)) + }; + + let init_points: Vec<[f64; 4]> = inits + .iter() + .map(|p| { + let pr = encode_nig_params(p); + [pr.log_a, pr.skew_raw, pr.loc, pr.log_scale] + }) + .collect(); + + let (mut optimum, best_ll, mut nfev) = differential_evolution( + objective, + &bounds, + TOL, + MAX_ITER, + 15, + (0.5, 1.0), + 0.7, + 0, + &init_points, + ); + + if !best_ll.is_finite() { + return Err(()); + } + + let (polished, polished_ll, polish_nfev) = + nelder_mead(objective, &optimum, POLISH_MAX_ITER, POLISH_TOL); + nfev += polish_nfev; + + if polished_ll.is_finite() && polished_ll < best_ll { + optimum = polished; + } + + let pr = NigParamsReparametrized { + log_a: optimum[0], + skew_raw: optimum[1], + loc: optimum[2], + log_scale: optimum[3], + }; + + Ok((decode_nig_params(&pr), nfev)) +} + +pub fn fit(times: &[f64], params: Option) -> NigParams { + fit_with_stats(times, params).0 +} + +pub fn fit_with_stats(times: &[f64], params: Option) -> (NigParams, usize) { + let moment_estimate = estimate_nig_start(times); + + let mut inits = vec![moment_estimate]; + if let Some(prev) = params + && prev.a > 0.0 + && prev.scale > 0.0 + && prev.b.abs() < prev.a + { + inits.push(prev); + } + + match optimize(times, &inits) { + Ok((optimized, nfev)) => (optimized, nfev), + Err(()) => { + tracing::warn!( + samples = times.len(), + "NIG optimization failed; using initial estimates", + ); + (moment_estimate, 0) + }, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn assert_rel_close(actual: f64, expected: f64, tolerance: f64) { + let rel_error = if expected == 0.0 { + actual.abs() + } else { + (actual - expected).abs() / expected.abs() + }; + assert!( + rel_error <= tolerance, + "expected {expected:.15e}, got {actual:.15e}, rel error {rel_error:.2e}", + ); + } + + fn assert_abs_close(actual: f64, expected: f64, tolerance: f64) { + let abs_error = (actual - expected).abs(); + assert!( + abs_error <= tolerance, + "expected {expected:.15e}, got {actual:.15e}, abs error {abs_error:.2e}", + ); + } + + #[test] + fn pdf_matches_reference_values() { + let p = NigParams { + a: 33.53900289787477, + b: 33.52140111667502, + loc: 6.3663207368487065, + scale: 0.4480388195262859, + }; + + for (x, expected) in [ + (7.648, 9.314339782198335e-03), + (8.0, 2.138356268395934e-02), + (10.0, 7.240000069597700e-02), + (20.0, 3.070336727949191e-02), + ] { + assert_rel_close(pdf(&p, x), expected, 1e-10); + } + } + + #[test] + fn sf_matches_reference_values() { + let p = NigParams { + a: 33.53900289787477, + b: 33.52140111667502, + loc: 6.3663207368487065, + scale: 0.4480388195262859, + }; + + for (x, expected) in [ + (7.0, 9.999892785756547e-01), + (7.648, 9.979326056403205e-01), + (10.0, 8.873317615160712e-01), + (20.0, 3.429376452167427e-01), + ] { + assert_abs_close(sf(&p, x), expected, 1e-5); + } + } +} diff --git a/crates/nig/src/quad.rs b/crates/nig/src/quad.rs new file mode 100644 index 00000000..7970c4a8 --- /dev/null +++ b/crates/nig/src/quad.rs @@ -0,0 +1,269 @@ +// BSD 3-Clause License + +// Copyright (c) 2022, Robert A. van Engelen +// All rights reserved. + +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: + +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. + +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. + +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. + +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// article: https://www.genivia.com/qthsh.html +// original code: https://github.com/Robert-van-Engelen/Tanh-Sinh + +const FUDGE1: f64 = 10.0; +const FUDGE2: f64 = 1.0; + +fn exp_sinh_opt_d(f: &mut impl FnMut(f64) -> f64, a: f64, eps: f64, mut d: f64) -> f64 { + let mut _ev = 2; + // const base = 2; // 2 or 3 or exp(1) for example + let h2 = f(a + d / 2.0) - f(a + d * 2.0) * 4.0; + let mut i = 1; + let mut j = 32; // j=32 is optimal to search for r + + if h2.is_finite() && h2.abs() > 1e-5 { + // if |h2| > 2^-16 + let mut fl: f64; + let mut fr: f64; + let mut h: f64; + let mut s = 0.0; + let mut lfl: f64; + let mut lfr: f64; + let mut lr = 2.0; + + // find max j such that fl and fr are finite + loop { + j /= 2; + let r = (1 << (i + j)) as f64; + fl = f(a + d / r); + fr = f(a + d * r) * r * r; + _ev += 2; + h = fl - fr; + if j <= 1 || h.is_finite() { + break; + } + } + + if j > 1 && h.is_finite() && h.signum() != h2.signum() { + lfl = fl; // last fl=f(a+d/r) + lfr = fr; // last fr=f(a+d*r)*r*r + + // bisect in 4 iterations + loop { + j /= 2; + let r = (1 << (i + j)) as f64; + fl = f(a + d / r); + fr = f(a + d * r) * r * r; + _ev += 2; + h = fl - fr; + if h.is_finite() { + s += h.abs(); // sum |h| to remove noisy cases + if h.signum() == h2.signum() { + i += j; // search right half + } else { + // search left half + lfl = fl; // record last fl=f(a+d/r) + lfr = fr; // record last fr=f(a+d*r)*r*r + lr = r; // record last r + } + } + if j <= 1 { + break; + } + } + + if s > eps { + // if sum of |h| > eps + h = lfl - lfr; // use last fl and fr before the sign change + let mut r = lr; // use last r before the sign change + if h != 0.0 { + // if last difference was nonzero, back up r by one step + r /= 2.0; + } + if lfl.abs() < lfr.abs() { + d /= r; // move d closer to the finite endpoint + } else { + d *= r; // move d closer to the infinite endpoint + } + } + } + } + + d +} + +/// Integrate function `f` over the range `a..b`. +/// +/// `n` is the max number of levels (2 to 7, 6 is recommended). +/// `eps` is the relative error tolerance. +/// If `err` is `Some`, the estimated relative error is written to it. +pub fn quad(f: &mut impl FnMut(f64) -> f64, a: f64, b: f64, n: i32, eps: f64, err: Option<&mut f64>) -> f64 { + let tol = FUDGE1 * eps; + let mut c = 0.0; + let mut d = 1.0; + let mut sign = 1.0; + let mut h: f64 = 2.0; + let mut k = 0; + let mut mode = 0; // Tanh-Sinh = 0, Exp-Sinh = 1, Sinh-Sinh = 2 + + let (mut a, mut b) = (a, b); + if b < a { + // swap bounds + let v = b; + b = a; + a = v; + sign = -1.0; + } + + let mut v: f64; + if a.is_finite() && b.is_finite() { + c = (a + b) / 2.0; + d = (b - a) / 2.0; + v = c; + } else if a.is_finite() { + mode = 1; // Exp-Sinh + d = exp_sinh_opt_d(f, a, eps, d); + c = a; + v = a + d; + } else if b.is_finite() { + mode = 1; // Exp-Sinh + d = exp_sinh_opt_d(f, b, eps, -d); + sign = -sign; + c = b; + v = b + d; + } else { + mode = 2; // Sinh-Sinh + v = 0.0; + } + + let mut s = f(v); + + loop { + let mut p = 0.0; + let mut q: f64; + let mut fp = 0.0; + let mut fm = 0.0; + let mut t: f64; + let mut eh: f64; + + h /= 2.0; + t = h.exp(); + eh = t; + if k > 0 { + eh *= eh; + } + + if mode == 0 { + // Tanh-Sinh + loop { + let u = (1.0 / t - t).exp(); // = exp(-2*sinh(j*h)) = 1/exp(sinh(j*h))^2 + let r = 2.0 * u / (1.0 + u); // = 1 - tanh(sinh(j*h)) + let w = (t + 1.0 / t) * r / (1.0 + u); // = cosh(j*h)/cosh(sinh(j*h))^2 + let x = d * r; + + if a + x > a { + // if too close to a then reuse previous fp + let y = f(a + x); + if y.is_finite() { + fp = y; // if f(x) is finite, add to local sum + } + } + if b - x < b { + // if too close to b then reuse previous fm + let y = f(b - x); + if y.is_finite() { + fm = y; // if f(x) is finite, add to local sum + } + } + + q = w * (fp + fm); + p += q; + t *= eh; + if q.abs() <= eps * p.abs() { + break; + } + } + } else { + t /= 2.0; + loop { + let mut r = (t - 0.25 / t).exp(); // = exp(sinh(j*h)) + let mut x: f64; + let mut y: f64; + let mut w = r; + q = 0.0; + + if mode == 1 { + // Exp-Sinh + x = c + d / r; + if x == c { + // if x hit the finite endpoint then break + break; + } + y = f(x); + if y.is_finite() { + // if f(x) is finite, add to local sum + q += y / w; + } + } else { + // Sinh-Sinh + r = (r - 1.0 / r) / 2.0; // = sinh(sinh(j*h)) + w = (w + 1.0 / w) / 2.0; // = cosh(sinh(j*h)) + x = c - d * r; + y = f(x); + if y.is_finite() { + // if f(x) is finite, add to local sum + q += y * w; + } + } + + x = c + d * r; + y = f(x); + if y.is_finite() { + // if f(x) is finite, add to local sum + q += y * w; + } + q *= t + 0.25 / t; // q *= cosh(j*h) + p += q; + t *= eh; + if q.abs() <= eps * p.abs() { + break; + } + } + } + + v = s - p; + s += p; + k += 1; + if v.abs() <= tol * s.abs() || k > n { + break; + } + } + + // if the estimated relative error is desired, then return it + if let Some(err) = err { + *err = v.abs() / (FUDGE2 * s.abs() + eps); + } + + // result with estimated relative error err + sign * d * s * h +} diff --git a/cs2kz-api.example.toml b/cs2kz-api.example.toml index e722fcf7..b6e34e55 100644 --- a/cs2kz-api.example.toml +++ b/cs2kz-api.example.toml @@ -106,7 +106,3 @@ min-connections = 1 # # 0 will let the API choose an amount max-connections = 0 - -[points] -calc-filter-path = "scripts/calc_filter.py" -calc-run-path = "scripts/calc_run.py" diff --git a/scripts/.gitignore b/scripts/.gitignore deleted file mode 100644 index 225fc6f6..00000000 --- a/scripts/.gitignore +++ /dev/null @@ -1 +0,0 @@ -/__pycache__ diff --git a/scripts/calc_filter.py b/scripts/calc_filter.py deleted file mode 100644 index 290b8127..00000000 --- a/scripts/calc_filter.py +++ /dev/null @@ -1,410 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright (C) zer0.k, AlphaKeks - -import common -import json -import mariadb -import numpy as np -import os -import sys -import time -import traceback - -from scipy import stats -from typing import Any, List, Tuple -from urllib.parse import urlparse - -def warn(msg): - sys.stderr.write(json.dumps({'warning': msg}) + '\n') - -def open_database_conn(): - DATABASE_URL = os.getenv('DATABASE_URL', 'mysql://schnose:csgo-kz-is-dead-boys@127.0.0.1:3306/cs2kz') - - database_url = urlparse(DATABASE_URL) - return mariadb.connect( - user = database_url.username, - password = database_url.password, - host = database_url.hostname, - port = database_url.port or 3306, - database = database_url.path.lstrip('/'), - reconnect = True - ) - -def process_input(database_conn, line): - """ - Processes a single line read from stdin. - - The line is expected to contain a JSON object with the following keys: - * `filter_id` - ID of the filter to calculate - - An example object could look like this: - ```json - { - "filter_id": 74 - } - ``` - - The function will write a single line response to stdout. - - That line is a JSON object with the following keys: - * `filter_id` - the ID of the calculated filter - * `timings` - an object containing timings of the individual operations - performed - - The `timings` object will contain the following keys: - * `db_query_ms` - the time it took to query the database for the - required information, in milliseconds, as a floating - point number - * `nub_fit_ms` - the time it took to fit the NUB distribution, in - milliseconds, as a floating point number - * `nub_compute_ms` - the time it took to calculate new points for the - NUB leaderboard, in milliseconds, as a floating - point number - * `pro_fit_ms` - the time it took to fit the PRO distribution, in - milliseconds, as a floating point number - * `pro_compute_ms` - the time it took to calculate new points for the - PRO leaderboard, in milliseconds, as a floating - point number - * `db_write_ms` - the time it took to write everything back to the - database, in milliseconds, as a floating point number - """ - - cursor = database_conn.cursor() - - response = { - 'filter_id': None, - 'timings': { - 'db_query_ms': 0.0, - 'nub_fit_ms': 0.0, - 'nub_compute_ms': 0.0, - 'pro_fit_ms': 0.0, - 'pro_compute_ms': 0.0, - 'db_write_ms': 0.0 - } - } - data = json.loads(line) - filter_id = data['filter_id'] - response['filter_id'] = filter_id - - start = time.time() - cursor.execute(""" - SELECT - bnr.record_id, - bnr.time, - bnr.points - FROM - BestNubRecords AS bnr - WHERE - bnr.filter_id = ? - ORDER BY - bnr.time ASC - """, ( - filter_id, - )) - nub_records: List[Tuple[Any, float, float]] = cursor.fetchall() - cursor.execute(""" - SELECT - bpr.record_id, - bpr.time, - bpr.points - FROM - BestProRecords AS bpr - WHERE - bpr.filter_id = ? - ORDER BY - bpr.time ASC - """, ( - filter_id, - )) - pro_records: List[Tuple[Any, float, float]] = cursor.fetchall() - cursor.execute(""" - SELECT - cf.nub_tier, - cf.pro_tier - FROM - CourseFilters cf - WHERE - cf.id = ? - """, ( - filter_id, - )) - filter_row = cursor.fetchone() - - # Fetch previous distribution parameters for warm start - # (both nub and pro in one query) - cursor.execute(""" - SELECT - is_pro_leaderboard, - a, - b, - loc, - scale, - top_scale - FROM - PointDistributionData - WHERE - filter_id = ? - ORDER BY - is_pro_leaderboard - """, ( - filter_id, - )) - dist_params_rows = cursor.fetchall() - - prev_nub_params = None - prev_pro_params = None - for row in dist_params_rows: - if row[0] == 0: # is_pro_leaderboard = 0 (nub) - prev_nub_params = (row[1], row[2], row[3], row[4], row[5]) - elif row[0] == 1: # is_pro_leaderboard = 1 (pro) - prev_pro_params = (row[1], row[2], row[3], row[4], row[5]) - - response['timings']['db_query_ms'] = (time.time() - start) * 1000 - - if filter_row is None: - warn(f'Filter ID {filter_id} not found in CourseFilters.') - return response - - nub_times = [row[1] for row in nub_records] - pro_times = [row[1] for row in pro_records] - nub_tier = filter_row[0] - pro_tier = filter_row[1] - - ''' - There are 3 possible cases: - 1. Less than 50 nub times (and therefore less than 50 pro times as well) - -> do not fit distribution, use sigmoid function - 2. 50 or more nub times but less than 50 pro times - -> fit nub distribution, use sigmoid for pro - 3. 50 or more nub times and 50 or more pro times - -> fit both distributions - - Overall/nub portion only depends on its own distribution. - Pro portion takes the higher of the two distributions - ((un)fitted nub or (un)fitted pro) to avoid situations where pro portion - is lower than nub portion. - - ''' - if len(nub_times) >= 50: - start = time.time() - nub_dist, nub_params = refit_dist(nub_times, prev_nub_params) - response['timings']['nub_fit_ms'] = (time.time() - start) * 1000 - elif len(nub_times) > 0: - nub_dist, nub_params = None, (0,0,0,0,0) - response['timings']['nub_fit_ms'] = 0 - else: - warn(f'No overall records found for filter ID {filter_id}.') - return response - - # Compute nub fractions - start = time.time() - nub_times_array = np.array(nub_times) - new_fractions = common.get_dist_points_portion( - nub_times_array, - nub_times[0], - nub_dist, - nub_tier, - nub_params[4], - len(nub_times) - ) - nub_records = [ - (record_id, time, fraction) - for (record_id, time, _), fraction - in zip(nub_records, new_fractions) - ] - response['timings']['nub_compute_ms'] = (time.time() - start) * 1000 - - # Fit pro distribution - if len(pro_times) >= 50: - start = time.time() - pro_dist, pro_params = refit_dist(pro_times, prev_pro_params) - response['timings']['pro_fit_ms'] = (time.time() - start) * 1000 - elif len(pro_times) > 0: - pro_dist, pro_params = None, (0,0,0,0,0) - - # Compute pro fractions if there are any pro records - if len(pro_times) > 0: - start = time.time() - pro_times_array = np.array(pro_times) - new_fractions = np.maximum( - common.get_dist_points_portion( - pro_times_array, - pro_times[0], - pro_dist, - pro_tier, - pro_params[4], - len(pro_times) - ), - common.get_dist_points_portion( - pro_times_array, - nub_times[0], - nub_dist, - nub_tier, - nub_params[4], - len(nub_times) - ) - ) - pro_records = [ - (record_id, time, fraction) - for (record_id, time, _), fraction - in zip(pro_records, new_fractions) - ] - response['timings']['pro_compute_ms'] = (time.time() - start) * 1000 - - # Database write timing - start = time.time() - if len(nub_records) > 0: - cursor.executemany(""" - UPDATE BestNubRecords SET - points = ? - WHERE - record_id = ? - """, [ - (points, record_id) - for record_id, time, points - in nub_records - ]) - if len(pro_records) > 0: - cursor.executemany(""" - UPDATE BestProRecords SET - points = ? - WHERE - record_id = ? - """, [ - (points, record_id) - for record_id, time, points - in pro_records - ]) - if len(nub_times) >= 50: - cursor.execute(""" - INSERT INTO PointDistributionData ( - filter_id, - is_pro_leaderboard, - a, - b, - loc, - scale, - top_scale - ) VALUES ( - ?, - 0, - ?, - ?, - ?, - ?, - ? - ) - ON DUPLICATE KEY UPDATE - a = VALUES(a), - b = VALUES(b), - loc = VALUES(loc), - scale = VALUES(scale), - top_scale = VALUES(top_scale) - """, ( - filter_id, - nub_params[0], - nub_params[1], - nub_params[2], - nub_params[3], - nub_params[4] - )) - if len(pro_times) >= 50: - cursor.execute(""" - INSERT INTO PointDistributionData ( - filter_id, - is_pro_leaderboard, - a, - b, - loc, - scale, - top_scale - ) VALUES ( - ?, - 1, - ?, - ?, - ?, - ?, - ? - ) - ON DUPLICATE KEY UPDATE - a = VALUES(a), - b = VALUES(b), - loc = VALUES(loc), - scale = VALUES(scale), - top_scale = VALUES(top_scale) - """, ( - filter_id, - pro_params[0], - pro_params[1], - pro_params[2], - pro_params[3], - pro_params[4] - )) - database_conn.commit() - response['timings']['db_write_ms'] = (time.time() - start) * 1000 - response['timings']['total_ms'] = sum(response['timings'].values()) - return response - -def refit_dist(times, prev_params=None): - if prev_params is not None: - # Use previous parameters as initial guess for faster convergence - a_init, b_init, loc_init, scale_init, _ = prev_params - norminvgauss_params = stats.norminvgauss.fit( - times, - a_init, b_init, - loc=loc_init, - scale=scale_init - ) - else: - # Cold start - no initial parameters - norminvgauss_params = stats.norminvgauss.fit(times) - - norminvgauss_dist = stats.norminvgauss(*norminvgauss_params) - top_scale = norminvgauss_dist.sf(times[0]) - # Sanity safeguard - if top_scale <= 0: - warn('Fitted top_scale <= 0, resetting to 1') - top_scale = 1 - - a, b, loc, scale = norminvgauss_params - return norminvgauss_dist, (a, b, loc, scale, top_scale) - -if __name__ == '__main__': - database_conn = None - - try: - database_conn = open_database_conn() - except mariadb.Error as e: - sys.stderr.write(f'Error connecting to database: {e}\n') - sys.stderr.write(traceback.format_exc() + '\n') - sys.exit(1) - - for line in sys.stdin: - try: - response = process_input(database_conn, line) - sys.stderr.flush() - sys.stdout.write(json.dumps(response) + '\n') - sys.stdout.flush() - except KeyError as e: - sys.stderr.write(f'Missing key in input data: {e}\n') - sys.stderr.write(traceback.format_exc() + '\n') - sys.exit(1) - except json.JSONDecodeError as e: - sys.stderr.write(f'JSON decode error: {e}\n') - sys.stderr.write(traceback.format_exc() + '\n') - sys.exit(1) - except mariadb.Error as e: - sys.stderr.write(f'Database error: {e}\n') - sys.stderr.write(traceback.format_exc() + '\n') - sys.exit(1) - except Exception as e: - sys.stderr.write(f'An unexpected error occurred: {e}\n') - sys.stderr.write(traceback.format_exc() + '\n') - sys.exit(1) - - try: - database_conn.close() - except mariadb.Error as e: - sys.stderr.write(f'Failed to close database connection: {e}\n') - sys.stderr.write(traceback.format_exc() + '\n') diff --git a/scripts/calc_run.py b/scripts/calc_run.py deleted file mode 100644 index ebf967d4..00000000 --- a/scripts/calc_run.py +++ /dev/null @@ -1,138 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright (C) zer0.k, AlphaKeks - -import common -import json -import scipy.stats as stats -import sys -import sys -import traceback - -def process_input(line): - """ - Processes a single line read from stdin. - - The line is expected to contain a JSON object with the following keys: - * `time` - the time of the run, in seconds, as a floating point number - * `nub_data` - an object containing information about the nub - leaderboard the run belongs to - * `pro_data` - an object containing information about the pro - leaderboard the run belongs to - - Both `nub_data` and `pro_data` should contain the following keys: - * `tier` - the filter tier - * `wr` - the time of the world record run, in seconds, as a floating - point number - * `leaderboard_size` - the number of runs on the leaderboard - * `dist_params` - distribution parameters calculated by `calc_filter.py` - - An example object could look like this: - ```json - { - "time": 8.609375, - "nub_data": { - "tier": 1, - "wr": 7.6484375, - "leaderboard_size": 224, - "dist_params": { - "a": 33.53900289787477, - "b": 33.52140111667502, - "loc": 6.3663207368487065, - "scale": 0.4480388195262859, - "top_scale": 0.9979285278452101 - } - }, - "pro_data": { - "tier": 1, - "wr": 7.6484375, - "leaderboard_size": 165, - "dist_params": { - "a": 2.6294814553333743, - "b": 2.511121972118702, - "loc": 8.713014153227697, - "scale": 2.2226724397990805, - "top_scale": 0.9952929135343108 - } - } - } - ``` - - The function will write a single line response to stdout. - - That line is a JSON object with the following keys: - * `nub_fraction` - a floating point number - * `pro_fraction` - a floating point number - - An example object could look like this: - ```json - { - "nub_fraction": 0.9745534941686896, - "pro_fraction": 0.9760910013054752 - } - ``` - """ - - data = json.loads(line) - - nub_data = data['nub_data'] - nub_dist = stats.norminvgauss( - a = nub_data['dist_params']['a'], - b = nub_data['dist_params']['b'], - loc = nub_data['dist_params']['loc'], - scale = nub_data['dist_params']['scale'] - ) - nub_fraction = common.get_dist_points_portion(data['time'], - nub_data['wr'], - nub_dist, - nub_data['tier'], - nub_data['dist_params']['top_scale'], - nub_data['leaderboard_size']) - if 'pro_data' in data and data['pro_data'] is not None: - pro_data = data['pro_data'] - pro_dist = stats.norminvgauss( - a = pro_data['dist_params']['a'], - b = pro_data['dist_params']['b'], - loc = pro_data['dist_params']['loc'], - scale = pro_data['dist_params']['scale'] - ) - pro_fraction = common.get_dist_points_portion(data['time'], - pro_data['wr'], - pro_dist, - pro_data['tier'], - pro_data['dist_params']['top_scale'], - pro_data['leaderboard_size']) - response = { - 'nub_fraction': nub_fraction, - # Pro run in the pro leaderboard should never be worth less than - # the same run in the nub leaderboard. - 'pro_fraction': max(nub_fraction, pro_fraction) - } - return response - response = { - 'nub_fraction': nub_fraction, - 'pro_fraction': None - } - return response - -def main(): - for line in sys.stdin: - try: - response = process_input(line) - sys.stderr.flush() - sys.stdout.write(json.dumps(response) + '\n') - sys.stdout.flush() - except KeyError as e: - sys.stderr.write(f'Missing key in input data: {e}\n') - sys.stderr.write(traceback.format_exc() + '\n') - sys.exit(1) - except json.JSONDecodeError as e: - sys.stderr.write(f'JSON decode error: {e}\n') - sys.stderr.write(traceback.format_exc() + '\n') - sys.exit(1) - except Exception as e: - sys.stderr.write(f'An unexpected error occurred: {e}\n') - sys.stderr.write(traceback.format_exc() + '\n') - sys.exit(1) - -if __name__ == '__main__': - main() diff --git a/scripts/common.py b/scripts/common.py deleted file mode 100644 index 63284150..00000000 --- a/scripts/common.py +++ /dev/null @@ -1,15 +0,0 @@ -# SPDX-License-Identifier: AGPL-3.0-only -# Copyright (C) zer0.k, AlphaKeks - -import numpy as np - -from scipy import stats - -def get_distribution_points_portion_under_50(time, wr_time, tier): - return ((1+np.exp((2.1 - 0.25 * tier) * -0.5))/(1+np.exp((2.1 - 0.25 * tier) * (time/wr_time-1.5)))) - -def get_dist_points_portion(time, wr_time, dist: stats.rv_continuous, tier, top_scale, total): - if total < 50: - return get_distribution_points_portion_under_50(time, wr_time, tier) - else: - return np.clip(dist.sf(time) / top_scale, 0, 1) diff --git a/scripts/requirements.txt b/scripts/requirements.txt deleted file mode 100644 index 58ce7356..00000000 --- a/scripts/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -scipy -mariadb