Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions .github/workflows/scoringutils-wis-parity.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
name: Scoringutils WIS Parity

on:
pull_request:
push:
branches:
- main

jobs:
scoringutils-wis-parity:
runs-on: ubuntu-latest
defaults:
run:
working-directory: app

steps:
- name: Checkout repository
uses: actions/checkout@v4

- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: 'lts/*'
cache: 'npm'
cache-dependency-path: app/package-lock.json

- name: Set up R
uses: r-lib/actions/setup-r@v2
with:
use-public-rspm: true

- name: Install R packages
run: Rscript -e 'install.packages("scoringutils", repos = "https://cloud.r-project.org")'

- name: Install app dependencies
run: npm ci

- name: Run scoringutils WIS parity test
run: npm run test:scoringutils
1 change: 1 addition & 0 deletions app/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"build": "vite build",
"build:staging": "vite build --mode staging",
"lint": "eslint .",
"test:scoringutils": "node scripts/test-scoringutils-wis.mjs",
"preview": "vite preview",
"format": "prettier . --write",
"format:check": "prettier . --check --log-level warn",
Expand Down
158 changes: 158 additions & 0 deletions app/scripts/test-scoringutils-wis.mjs
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import assert from "node:assert/strict";
import { spawnSync } from "node:child_process";

import { calculateWIS } from "../src/utils/forecastleScoring.js";
import { calculateWIS as calculateSharedWIS } from "../src/lib/forecast-components/scoring.js";

// Rscript executables to probe, in priority order: the optional
// SCORINGUTILS_RSCRIPT override (dropped when unset or empty), then `Rscript`.
const rscriptCandidates = [process.env.SCORINGUTILS_RSCRIPT, "Rscript"].filter(
  (candidate) => Boolean(candidate),
);

/**
 * Locate an Rscript executable that can load the scoringutils package.
 *
 * Each candidate runs a one-line R probe that exits with status 0 only when
 * `requireNamespace('scoringutils')` succeeds. Any candidate whose probe does
 * not report status 0 is passed over.
 *
 * @returns {string|null} The first usable candidate, or null when none qualifies.
 */
const findScoringutilsRscript = () => {
  const namespaceProbe =
    "quit(status = !requireNamespace('scoringutils', quietly = TRUE))";

  const usable = rscriptCandidates.find((candidate) => {
    const { status } = spawnSync(candidate, ["-e", namespaceProbe], {
      encoding: "utf8",
    });
    return status === 0;
  });

  return usable ?? null;
};

// Resolve the R interpreter once; the reference run below depends on it.
const rscript = findScoringutilsRscript();

// Without a usable interpreter the comparison cannot run: fail when the CI
// environment variable is set, otherwise report the skip and exit with
// status 0.
if (!rscript && process.env.CI) {
  throw new Error("No Rscript with scoringutils found.");
}
if (!rscript) {
  console.log(
    "Skipping scoringutils WIS parity test: no Rscript with scoringutils found.",
  );
  process.exit(0);
}

// Forecast fixtures scored by both implementations. All three share one
// forecast and place the observation inside the 50% interval, below the 95%
// interval, and above the 95% interval respectively.
const cases = [
  {
    observed: 100,
    median: 105,
    lower50: 90,
    upper50: 115,
    lower95: 80,
    upper95: 140,
  },
  {
    observed: 70,
    median: 105,
    lower50: 90,
    upper50: 115,
    lower95: 80,
    upper95: 140,
  },
  {
    observed: 160,
    median: 105,
    lower50: 90,
    upper50: 115,
    lower95: 80,
    upper95: 140,
  },
];

// Forecast fields in ascending quantile order, each paired with the quantile
// level it is handed to scoringutils as.
const quantileFields = [
  ["lower95", 0.025],
  ["lower50", 0.25],
  ["median", 0.5],
  ["upper50", 0.75],
  ["upper95", 0.975],
];

// Render one fixture as an R `list(observed=..., predicted=c(...))` literal.
const toRCase = (testCase) => {
  const predicted = quantileFields.map(([field]) => testCase[field]).join(", ");
  return `  list(observed=${testCase.observed}, predicted=c(${predicted}))`;
};

// R program that scores every fixture with scoringutils' wis() and prints one
// "wis,dispersion,underprediction,overprediction" line per case. The R
// fixtures and quantile levels are generated from `cases` and
// `quantileFields`, so the reference and the JavaScript implementations are
// always scored on identical inputs instead of two hand-maintained copies.
const rProgram = `
library(scoringutils)
cases <- list(
${cases.map(toRCase).join(",\n")}
)
quantile_level <- c(${quantileFields.map(([, level]) => level).join(", ")})
for (case in cases) {
  result <- wis(
    observed = case$observed,
    predicted = case$predicted,
    quantile_level = quantile_level,
    separate_results = TRUE
  )
  cat(sprintf(
    "%.12f,%.12f,%.12f,%.12f\\n",
    result$wis,
    result$dispersion,
    result$underprediction,
    result$overprediction
  ))
}
`;

// Score the fixtures with the scoringutils reference implementation.
const result = spawnSync(rscript, ["-e", rProgram], { encoding: "utf8" });

if (result.status !== 0) {
  // When the process could not be run, spawnSync reports the reason on
  // `result.error` and the output streams may be empty, so fall through a
  // chain of sources rather than interpolating a bare null.
  const detail =
    result.error?.message ||
    result.stderr ||
    result.stdout ||
    `exit status ${result.status}, signal ${result.signal}`;
  throw new Error(`scoringutils parity reference failed:\n${detail}`);
}

// Parse one "wis,dispersion,underprediction,overprediction" line per case.
const expected = result.stdout
  .trim()
  .split("\n")
  .map((line) => {
    const [wis, dispersion, underprediction, overprediction] = line
      .split(",")
      .map(Number);
    return { wis, dispersion, underprediction, overprediction };
  });

// Reject truncated or malformed reference output here, with the raw output in
// the message, instead of failing later on a missing row or a NaN comparison.
if (expected.length !== cases.length) {
  throw new Error(
    `scoringutils parity reference returned ${expected.length} rows for ${cases.length} cases:\n${result.stdout}`,
  );
}
const malformedRow = expected.findIndex(
  (scores) => !Object.values(scores).every((value) => Number.isFinite(value)),
);
if (malformedRow !== -1) {
  throw new Error(
    `scoringutils parity reference row ${malformedRow} is not four finite numbers:\n${result.stdout}`,
  );
}

/**
 * Assert that two numbers differ by less than 1e-9 in absolute terms.
 *
 * @param {number} actual - Value produced by the implementation under test.
 * @param {number} expectedValue - Reference value to compare against.
 * @param {string} label - Prefix identifying the comparison in the failure message.
 */
const assertClose = (actual, expectedValue, label) => {
  const delta = Math.abs(actual - expectedValue);
  const failureMessage = `${label}: expected ${expectedValue}, got ${actual}`;
  // A NaN delta compares false, so non-numeric inputs fail the assertion.
  assert.ok(delta < 1e-9, failureMessage);
};

// Both exports are checked against the same reference. Each entry carries a
// name so that a failure message says which implementation diverged.
const wisImplementations = [
  ["utils/forecastleScoring", calculateWIS],
  ["lib/forecast-components/scoring", calculateSharedWIS],
];

// Score components compared against the reference output.
const scoreComponents = [
  "wis",
  "dispersion",
  "underprediction",
  "overprediction",
];

cases.forEach((testCase, index) => {
  const { observed, median, lower50, upper50, lower95, upper95 } = testCase;
  const reference = expected[index];
  assert.ok(reference, `case ${index}: no scoringutils reference row`);

  for (const [implementationName, scoreForecast] of wisImplementations) {
    const scores = scoreForecast(
      observed,
      median,
      lower50,
      upper50,
      lower95,
      upper95,
    );
    // Fail with a readable message rather than a TypeError on `scores.wis`.
    assert.ok(
      scores,
      `${implementationName} case ${index}: expected a score object, got ${scores}`,
    );

    for (const component of scoreComponents) {
      assertClose(
        scores[component],
        reference[component],
        `${implementationName} case ${index} ${component}`,
      );
    }
  }
});

// A non-finite median must produce no score from either export.
for (const [implementationName, scoreForecast] of wisImplementations) {
  assert.equal(
    scoreForecast(100, NaN, 90, 115, 80, 140),
    null,
    `${implementationName}: expected null for a NaN median`,
  );
}

console.log(`scoringutils WIS parity passed using ${rscript}`);
103 changes: 3 additions & 100 deletions app/src/lib/forecast-components/scoring.js
Original file line number Diff line number Diff line change
@@ -1,108 +1,11 @@
/**
* Reusable Forecast Scoring Utilities
* Extracted from forecastleScoring.js for use across the app
* WIS is re-exported from the canonical Forecastle scoring utility.
*/

/**
* Calculate interval score for a single prediction interval
* @param {number} observed - Observed value
* @param {number} lower - Lower bound of prediction interval
* @param {number} upper - Upper bound of prediction interval
* @param {number} alpha - Alpha level (e.g., 0.5 for 50% interval, 0.05 for 95% interval)
* @returns {Object} Interval score with components {score, dispersion, underprediction, overprediction}
*/
const calculateIntervalScore = (observed, lower, upper, alpha) => {
if (
!Number.isFinite(observed) ||
!Number.isFinite(lower) ||
!Number.isFinite(upper)
) {
return null;
}

const dispersion = upper - lower;
const underprediction =
observed < lower ? (2 / alpha) * (lower - observed) : 0;
const overprediction =
observed > upper ? (2 / alpha) * (observed - upper) : 0;
const score = dispersion + underprediction + overprediction;

return {
score,
dispersion,
underprediction,
overprediction,
};
};

/**
* Calculate WIS (Weighted Interval Score) for a single forecast
* @param {number} observed - Observed value
* @param {number} median - Median prediction
* @param {number} lower50 - Lower bound of 50% interval (0.25 quantile)
* @param {number} upper50 - Upper bound of 50% interval (0.75 quantile)
* @param {number} lower95 - Lower bound of 95% interval (0.025 quantile)
* @param {number} upper95 - Upper bound of 95% interval (0.975 quantile)
* @returns {Object} WIS with components {wis, dispersion, underprediction, overprediction}
*/
export const calculateWIS = (
observed,
median,
lower50,
upper50,
lower95,
upper95,
) => {
if (!Number.isFinite(observed)) {
return null;
}

// Calculate interval scores for each interval
const interval50 = calculateIntervalScore(observed, lower50, upper50, 0.5);
const interval95 = calculateIntervalScore(observed, lower95, upper95, 0.05);

if (!interval50 || !interval95) {
return null;
}

// Median absolute error (treated as 0-width interval)
const medianAE = Number.isFinite(median) ? Math.abs(observed - median) : 0;

// Weights: alpha/2 for each interval, 0.5 for median
const weight50 = 0.5 / 2; // 0.25
const weight95 = 0.05 / 2; // 0.025
const weightMedian = 0.5; // 0.5

// Weighted sum
const totalWeight = weight50 + weight95 + weightMedian;
const wis =
(weight50 * interval50.score +
weight95 * interval95.score +
weightMedian * medianAE) /
totalWeight;

// Calculate aggregate components
const dispersion =
(weight50 * interval50.dispersion + weight95 * interval95.dispersion) /
totalWeight;

const underprediction =
(weight50 * interval50.underprediction +
weight95 * interval95.underprediction) /
totalWeight;

const overprediction =
(weight50 * interval50.overprediction +
weight95 * interval95.overprediction) /
totalWeight;
import { calculateWIS } from "../../utils/forecastleScoring.js";

return {
wis,
dispersion,
underprediction,
overprediction,
};
};
export { calculateWIS };

/**
* Validate forecast intervals
Expand Down
40 changes: 27 additions & 13 deletions app/src/utils/forecastleScoring.js
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ export const calculateWIS = (
lower95,
upper95,
) => {
if (!Number.isFinite(observed)) {
if (!Number.isFinite(observed) || !Number.isFinite(median)) {
return null;
}

Expand All @@ -101,35 +101,40 @@ export const calculateWIS = (
return null;
}

// Median absolute error (treated as 0-width interval)
const medianAE = Number.isFinite(median) ? Math.abs(observed - median) : 0;
const medianAE = Math.abs(observed - median);

// Weights: alpha/2 for each interval, 0.5 for median
// Total weight = 0.25 + 0.025 + 0.5 = 0.775
// Weights match scoringutils::wis(..., weigh = TRUE,
// count_median_twice = FALSE), normalized by K + 0.5.
const weight50 = 0.5 / 2; // 0.25
const weight95 = 0.05 / 2; // 0.025
const weightMedian = 0.5; // 0.5
const normalizer = 2 + 0.5; // two central intervals plus median half-weight

// Weighted sum
const totalWeight = weight50 + weight95 + weightMedian;
const wis =
(weight50 * interval50.score +
weight95 * interval95.score +
weightMedian * medianAE) /
totalWeight;
normalizer;

// Aggregate components
const dispersion =
(weight50 * interval50.dispersion + weight95 * interval95.dispersion) /
totalWeight;
normalizer;
const medianUnderprediction =
observed > median ? (weightMedian * (observed - median)) / normalizer : 0;
const medianOverprediction =
observed < median ? (weightMedian * (median - observed)) / normalizer : 0;
const underprediction =
(weight50 * interval50.underprediction +
weight95 * interval95.underprediction) /
totalWeight;
const overprediction =
(weight50 * interval50.overprediction +
weight95 * interval95.overprediction) /
totalWeight;
normalizer +
medianUnderprediction;
const overprediction =
(weight50 * interval50.underprediction +
weight95 * interval95.underprediction) /
normalizer +
medianOverprediction;

return {
wis,
Expand Down Expand Up @@ -307,6 +312,15 @@ export const scoreModels = (modelForecasts, horizons, groundTruthValues) => {
return null;
}

if (
lower95 > lower50 ||
lower50 > median ||
median > upper50 ||
upper50 > upper95
) {
return null;
}

return {
median,
lower50,
Expand Down
Loading