From e0d50e00494bb33081c470f3b96328bf33668a58 Mon Sep 17 00:00:00 2001 From: Joseph Lemaitre Date: Tue, 28 Apr 2026 14:17:55 -0400 Subject: [PATCH] scoring- test with scroring utils --- .github/workflows/scoringutils-wis-parity.yml | 39 +++++ app/package.json | 1 + app/scripts/test-scoringutils-wis.mjs | 158 ++++++++++++++++++ app/src/lib/forecast-components/scoring.js | 103 +----------- app/src/utils/forecastleScoring.js | 40 +++-- 5 files changed, 228 insertions(+), 113 deletions(-) create mode 100644 .github/workflows/scoringutils-wis-parity.yml create mode 100644 app/scripts/test-scoringutils-wis.mjs diff --git a/.github/workflows/scoringutils-wis-parity.yml b/.github/workflows/scoringutils-wis-parity.yml new file mode 100644 index 00000000..cdacf395 --- /dev/null +++ b/.github/workflows/scoringutils-wis-parity.yml @@ -0,0 +1,39 @@ +name: Scoringutils WIS Parity + +on: + pull_request: + push: + branches: + - main + +jobs: + scoringutils-wis-parity: + runs-on: ubuntu-latest + defaults: + run: + working-directory: app + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: 'lts/*' + cache: 'npm' + cache-dependency-path: app/package-lock.json + + - name: Set up R + uses: r-lib/actions/setup-r@v2 + with: + use-public-rspm: true + + - name: Install R packages + run: Rscript -e 'install.packages("scoringutils", repos = "https://cloud.r-project.org")' + + - name: Install app dependencies + run: npm ci + + - name: Run scoringutils WIS parity test + run: npm run test:scoringutils diff --git a/app/package.json b/app/package.json index 08fb6618..9de8e8b0 100644 --- a/app/package.json +++ b/app/package.json @@ -8,6 +8,7 @@ "build": "vite build", "build:staging": "vite build --mode staging", "lint": "eslint .", + "test:scoringutils": "node scripts/test-scoringutils-wis.mjs", "preview": "vite preview", "format": "prettier . --write", "format:check": "prettier . --check --log-level warn", diff --git a/app/scripts/test-scoringutils-wis.mjs b/app/scripts/test-scoringutils-wis.mjs new file mode 100644 index 00000000..ebbeb3dc --- /dev/null +++ b/app/scripts/test-scoringutils-wis.mjs @@ -0,0 +1,158 @@ +import assert from "node:assert/strict"; +import { spawnSync } from "node:child_process"; + +import { calculateWIS } from "../src/utils/forecastleScoring.js"; +import { calculateWIS as calculateSharedWIS } from "../src/lib/forecast-components/scoring.js"; + +const rscriptCandidates = [process.env.SCORINGUTILS_RSCRIPT, "Rscript"].filter( + Boolean, +); + +const findScoringutilsRscript = () => { + for (const rscript of rscriptCandidates) { + const check = spawnSync( + rscript, + [ + "-e", + "quit(status = !requireNamespace('scoringutils', quietly = TRUE))", + ], + { encoding: "utf8" }, + ); + if (check.status === 0) { + return rscript; + } + } + return null; +}; + +const rscript = findScoringutilsRscript(); + +if (!rscript) { + if (process.env.CI) { + throw new Error("No Rscript with scoringutils found."); + } else { + console.log( + "Skipping scoringutils WIS parity test: no Rscript with scoringutils found.", + ); + process.exit(0); + } +} + +const cases = [ + { + observed: 100, + median: 105, + lower50: 90, + upper50: 115, + lower95: 80, + upper95: 140, + }, + { + observed: 70, + median: 105, + lower50: 90, + upper50: 115, + lower95: 80, + upper95: 140, + }, + { + observed: 160, + median: 105, + lower50: 90, + upper50: 115, + lower95: 80, + upper95: 140, + }, +]; + +const rProgram = ` +library(scoringutils) +cases <- list( + list(observed=100, predicted=c(80, 90, 105, 115, 140)), + list(observed=70, predicted=c(80, 90, 105, 115, 140)), + list(observed=160, predicted=c(80, 90, 105, 115, 140)) +) +quantile_level <- c(0.025, 0.25, 0.5, 0.75, 0.975) +for (case in cases) { + result <- wis( + observed = case$observed, + predicted = case$predicted, + quantile_level = quantile_level, + separate_results = TRUE + ) + cat(sprintf( + "%.12f,%.12f,%.12f,%.12f\\n", + result$wis, + result$dispersion, + result$underprediction, + result$overprediction + )) +} +`; + +const result = spawnSync(rscript, ["-e", rProgram], { encoding: "utf8" }); + +if (result.status !== 0) { + throw new Error( + `scoringutils parity reference failed:\n${result.stderr || result.stdout}`, + ); +} + +const expected = result.stdout + .trim() + .split("\n") + .map((line) => { + const [wis, dispersion, underprediction, overprediction] = line + .split(",") + .map(Number); + return { wis, dispersion, underprediction, overprediction }; + }); + +const assertClose = (actual, expectedValue, label) => { + assert.ok( + Math.abs(actual - expectedValue) < 1e-9, + `${label}: expected ${expectedValue}, got ${actual}`, + ); +}; + +cases.forEach((testCase, index) => { + const forecastle = calculateWIS( + testCase.observed, + testCase.median, + testCase.lower50, + testCase.upper50, + testCase.lower95, + testCase.upper95, + ); + const shared = calculateSharedWIS( + testCase.observed, + testCase.median, + testCase.lower50, + testCase.upper50, + testCase.lower95, + testCase.upper95, + ); + + for (const implementation of [forecastle, shared]) { + assertClose(implementation.wis, expected[index].wis, `case ${index} wis`); + assertClose( + implementation.dispersion, + expected[index].dispersion, + `case ${index} dispersion`, + ); + assertClose( + implementation.underprediction, + expected[index].underprediction, + `case ${index} underprediction`, + ); + assertClose( + implementation.overprediction, + expected[index].overprediction, + `case ${index} overprediction`, + ); + } +}); + +assert.equal(calculateWIS(100, NaN, 90, 115, 80, 140), null); + +console.log(`scoringutils WIS parity passed using ${rscript}`); diff --git a/app/src/lib/forecast-components/scoring.js b/app/src/lib/forecast-components/scoring.js index 2a70ebc2..797acdb5 100644 --- a/app/src/lib/forecast-components/scoring.js +++ b/app/src/lib/forecast-components/scoring.js @@ -1,108 +1,11 @@ /** * Reusable Forecast Scoring Utilities - * Extracted from forecastleScoring.js for use across the app + * WIS is re-exported from the canonical Forecastle scoring utility. */ -/** - * Calculate interval score for a single prediction interval - * @param {number} observed - Observed value - * @param {number} lower - Lower bound of prediction interval - * @param {number} upper - Upper bound of prediction interval - * @param {number} alpha - Alpha level (e.g., 0.5 for 50% interval, 0.05 for 95% interval) - * @returns {Object} Interval score with components {score, dispersion, underprediction, overprediction} - */ -const calculateIntervalScore = (observed, lower, upper, alpha) => { - if ( - !Number.isFinite(observed) || - !Number.isFinite(lower) || - !Number.isFinite(upper) - ) { - return null; - } - - const dispersion = upper - lower; - const underprediction = - observed < lower ? (2 / alpha) * (lower - observed) : 0; - const overprediction = - observed > upper ? (2 / alpha) * (observed - upper) : 0; - const score = dispersion + underprediction + overprediction; - - return { - score, - dispersion, - underprediction, - overprediction, - }; -}; - -/** - * Calculate WIS (Weighted Interval Score) for a single forecast - * @param {number} observed - Observed value - * @param {number} median - Median prediction - * @param {number} lower50 - Lower bound of 50% interval (0.25 quantile) - * @param {number} upper50 - Upper bound of 50% interval (0.75 quantile) - * @param {number} lower95 - Lower bound of 95% interval (0.025 quantile) - * @param {number} upper95 - Upper bound of 95% interval (0.975 quantile) - * @returns {Object} WIS with components {wis, dispersion, underprediction, overprediction} - */ -export const calculateWIS = ( - observed, - median, - lower50, - upper50, - lower95, - upper95, -) => { - if (!Number.isFinite(observed)) { - return null; - } - - // Calculate interval scores for each interval - const interval50 = calculateIntervalScore(observed, lower50, upper50, 0.5); - const interval95 = calculateIntervalScore(observed, lower95, upper95, 0.05); - - if (!interval50 || !interval95) { - return null; - } - - // Median absolute error (treated as 0-width interval) - const medianAE = Number.isFinite(median) ? Math.abs(observed - median) : 0; - - // Weights: alpha/2 for each interval, 0.5 for median - const weight50 = 0.5 / 2; // 0.25 - const weight95 = 0.05 / 2; // 0.025 - const weightMedian = 0.5; // 0.5 - - // Weighted sum - const totalWeight = weight50 + weight95 + weightMedian; - const wis = - (weight50 * interval50.score + - weight95 * interval95.score + - weightMedian * medianAE) / - totalWeight; - - // Calculate aggregate components - const dispersion = - (weight50 * interval50.dispersion + weight95 * interval95.dispersion) / - totalWeight; - - const underprediction = - (weight50 * interval50.underprediction + - weight95 * interval95.underprediction) / - totalWeight; - - const overprediction = - (weight50 * interval50.overprediction + - weight95 * interval95.overprediction) / - totalWeight; +import { calculateWIS } from "../../utils/forecastleScoring.js"; - return { - wis, - dispersion, - underprediction, - overprediction, - }; -}; +export { calculateWIS }; /** * Validate forecast intervals diff --git a/app/src/utils/forecastleScoring.js b/app/src/utils/forecastleScoring.js index a67cd913..6bc3918c 100644 --- a/app/src/utils/forecastleScoring.js +++ b/app/src/utils/forecastleScoring.js @@ -89,7 +89,7 @@ export const calculateWIS = ( lower95, upper95, ) => { - if (!Number.isFinite(observed)) { + if (!Number.isFinite(observed) || !Number.isFinite(median)) { return null; } @@ -101,35 +101,40 @@ export const calculateWIS = ( return null; } - // Median absolute error (treated as 0-width interval) - const medianAE = Number.isFinite(median) ? Math.abs(observed - median) : 0; + const medianAE = Math.abs(observed - median); - // Weights: alpha/2 for each interval, 0.5 for median - // Total weight = 0.25 + 0.025 + 0.5 = 0.775 + // Weights match scoringutils::wis(..., weigh = TRUE, + // count_median_twice = FALSE), normalized by K + 0.5. const weight50 = 0.5 / 2; // 0.25 const weight95 = 0.05 / 2; // 0.025 const weightMedian = 0.5; // 0.5 + const normalizer = 2 + 0.5; // two central intervals plus median half-weight // Weighted sum - const totalWeight = weight50 + weight95 + weightMedian; const wis = (weight50 * interval50.score + weight95 * interval95.score + weightMedian * medianAE) / - totalWeight; + normalizer; // Aggregate components const dispersion = (weight50 * interval50.dispersion + weight95 * interval95.dispersion) / - totalWeight; + normalizer; + const medianUnderprediction = + observed > median ? (weightMedian * (observed - median)) / normalizer : 0; + const medianOverprediction = + observed < median ? (weightMedian * (median - observed)) / normalizer : 0; const underprediction = - (weight50 * interval50.underprediction + - weight95 * interval95.underprediction) / - totalWeight; - const overprediction = (weight50 * interval50.overprediction + weight95 * interval95.overprediction) / - totalWeight; + normalizer + + medianUnderprediction; + const overprediction = + (weight50 * interval50.underprediction + + weight95 * interval95.underprediction) / + normalizer + + medianOverprediction; return { wis, @@ -307,6 +312,15 @@ export const scoreModels = (modelForecasts, horizons, groundTruthValues) => { return null; } + if ( + lower95 > lower50 || + lower50 > median || + median > upper50 || + upper50 > upper95 + ) { + return null; + } + return { median, lower50,