
Add QuantileOutput support to DeepAR #3280

Open
timoschowski wants to merge 8 commits into awslabs:dev from timoschowski:feature/deepar-quantile-output

Conversation

@timoschowski
Contributor

Summary

  • Add support for QuantileOutput in DeepAR, enabling direct quantile regression (pinball loss) as an alternative to distribution-based outputs trained with NLL loss
  • Refactor forward() into _forward_distribution() (existing path, unchanged) and _forward_quantile() (new path using median for autoregressive feedback)
  • Wire QuantileForecastGenerator in the estimator's create_predictor() so end-to-end training and inference works out of the box
  • Add unit tests and two comparison example scripts (synthetic data + electricity dataset)

Motivation

DeepAR currently only supports distribution-based outputs (e.g. StudentTOutput, NormalOutput). This PR gives users a simpler, non-parametric alternative — QuantileOutput — that directly predicts quantile values via pinball loss, similar to what MQ-CNN already supports. This is useful when users want specific quantile forecasts without assuming a parametric distribution family.
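For reference, the pinball (quantile) loss mentioned above penalizes under- and over-prediction asymmetrically per quantile level. A minimal NumPy sketch of the loss function (an illustration, not the GluonTS implementation):

```python
import numpy as np

def pinball_loss(y_true: np.ndarray, y_pred: np.ndarray, q: float) -> float:
    """Pinball loss at quantile level q: asymmetric penalty on the residual."""
    diff = y_true - y_pred
    return float(np.mean(np.maximum(q * diff, (q - 1) * diff)))

# Over-predicting by 1 is costly for the 0.1 quantile but cheap for the 0.9 quantile.
y, yhat = np.array([0.0]), np.array([1.0])
print(pinball_loss(y, yhat, 0.1))  # ~0.9
print(pinball_loss(y, yhat, 0.9))  # ~0.1
```

Minimizing this loss at level q yields the q-th conditional quantile, which is why no parametric distribution family is needed.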

Usage

```python
from gluonts.torch.model.deepar import DeepAREstimator
from gluonts.torch.distributions.quantile_output import QuantileOutput

estimator = DeepAREstimator(
    freq="H",
    prediction_length=24,
    distr_output=QuantileOutput(quantiles=[0.1, 0.5, 0.9]),
)
predictor = estimator.train(training_data)
```

Changes

| File | Change |
| --- | --- |
| src/gluonts/torch/model/deepar/module.py | Widen distr_output type to Output, dispatch forward() to distribution/quantile paths, add _forward_quantile(), guard output_distribution() and log_prob() |
| src/gluonts/torch/model/deepar/estimator.py | Widen distr_output type, select QuantileForecastGenerator vs SampleForecastGenerator in create_predictor() |
| test/torch/model/test_deepar_modules.py | Add test_deepar_quantile_output() covering shapes, loss, and log_prob error |
| examples/deepar_quantile_comparison.py | Comparison on synthetic sine data: NormalOutput vs QuantileOutput |
| examples/deepar_electricity_studentt_vs_quantile.py | Comparison on electricity dataset: StudentTOutput vs QuantileOutput with GluonTS Evaluator |

Design decisions

  1. Median for autoregressive feedback: The P50 quantile (or closest to 0.5) is fed back into the RNN at each prediction step
  2. Return normalized predictions: _forward_quantile() returns quantile values in scale-normalized space; QuantileForecastGenerator handles scale multiplication — consistent with how QuantileOutput.loss() already works
  3. num_parallel_samples ignored for QuantileOutput: Quantile predictions are deterministic, no sampling needed
  4. Branching via isinstance: Simple, localized, fully backward-compatible — no changes to the Output class hierarchy
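Decisions 1 and 4 can be sketched roughly as follows. The class and method names below are simplified stand-ins for illustration, not the actual module code:

```python
# Simplified stand-ins for the GluonTS Output hierarchy; illustration only.
class Output: ...
class DistributionOutput(Output): ...

class QuantileOutput(Output):
    def __init__(self, quantiles):
        self.quantiles = quantiles

def median_index(quantiles):
    """Index of the quantile level closest to 0.5 (design decision 1)."""
    return min(range(len(quantiles)), key=lambda i: abs(quantiles[i] - 0.5))

class DeepARModel:
    def __init__(self, distr_output):
        self.distr_output = distr_output

    def forward(self, inputs):
        # isinstance dispatch (design decision 4): no change to Output itself.
        if isinstance(self.distr_output, QuantileOutput):
            return self._forward_quantile(inputs)
        return self._forward_distribution(inputs)

    def _forward_quantile(self, inputs):
        # The real path feeds the selected (near-)median quantile back into
        # the RNN at each autoregressive prediction step.
        return ("quantile", median_index(self.distr_output.quantiles))

    def _forward_distribution(self, inputs):
        return ("distribution", None)

model = DeepARModel(QuantileOutput([0.1, 0.5, 0.9]))
print(model.forward(None))  # ('quantile', 1)
```

Because the branching lives entirely inside the model, existing distribution-based configurations go through `_forward_distribution()` untouched.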

Test plan

  • python -m pytest test/torch/model/test_deepar_modules.py -v — all 11 tests pass (4 existing distribution tests + 6 RNN input tests + 1 new quantile test)
  • examples/deepar_quantile_comparison.py — trains both models, produces forecasts, prints metrics, saves plot
  • examples/deepar_electricity_studentt_vs_quantile.py — end-to-end on electricity dataset with GluonTS Evaluator metrics

🤖 Generated with Claude Code

timjan-db and others added 8 commits January 31, 2026 06:09
…data.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Key fixes:
- Changed scaling parameter default from True to None (defaults to False for quantile output, matching MXNet's NOPScaler behavior)
- Removed incorrect scale multiplication before quantile projection (MXNet applies quantile_proj directly to decoder output)
- Added checkpoint loading error handling in PyTorchLightningEstimator
- Added AddSeriesScale transformation for forking sequence models

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Added static features (including log(scale)) to RNN encoder input
to match CNN encoder and MXNet MQRNN behavior.

Changes:
- RNNEncoder.forward() now concatenates target + static_features + dynamic_features
- Matches CNN encoder feature concatenation pattern
- Matches MXNet RNNEncoder._assemble_inputs() behavior

Results:
- Reduced MAE difference from 32.85% to 21.16% on electricity dataset
- RMSE difference improved to 8.62%
- Still investigating remaining ~20% difference

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Adds unit tests, integration tests, and parity tests to ensure MQCNN
and MQRNN PyTorch implementations work correctly and maintain parity
with MXNet reference implementations.

Key additions:
- Regression test for lazy initialization optimizer bug
- Integration tests for both MQCNN and MQRNN estimators
- MXNet vs PyTorch parity tests with documented tolerances
- Tests verify RNN parameters are trained correctly

Test coverage:
- test_optimizer_includes_all_parameters: Critical regression test
- test_mq_dnn_estimator_constant_dataset: End-to-end estimator tests
- test_mq_dnn_mxnet_pytorch_parity: Framework parity verification
- Additional tests for various configurations and edge cases

All tests include clear assertions and failure messages to facilitate
debugging if issues arise.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Remove backup files and debug output that should not be included
in the pull request:
- estimator.py.backup (backup file)
- extreme_scales_output.txt (debug output)
- MQ_DNN_MIGRATION_SUMMARY.md (internal documentation)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Enable DeepAR to use QuantileOutput (pinball loss) as an alternative to
distribution-based outputs (e.g. StudentTOutput with NLL loss). This gives
users a simpler, non-parametric option for probabilistic forecasting with
DeepAR's autoregressive architecture.

Changes:
- module.py: Widen distr_output type to Output, refactor forward() into
  _forward_distribution() and _forward_quantile() paths, add assertion
  to output_distribution(), raise NotImplementedError in log_prob() for
  QuantileOutput
- estimator.py: Widen distr_output type, select QuantileForecastGenerator
  vs SampleForecastGenerator in create_predictor()
- test_deepar_modules.py: Add test_deepar_quantile_output() covering
  shapes, loss, and log_prob error
- examples/: Add comparison scripts on synthetic and electricity data

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@timoschowski
Contributor Author

[Attached plots: deepar_electricity_studentt_vs_quantile, deepar_quantile_comparison]

@carlocav

carlocav commented Mar 2, 2026

Hello. Is it possible to use quantile loss for computing the validation loss while keeping likelihood as the training loss? DeepAR is often unstable, and minimizing a validation quantile loss while still training the model to predict the full distribution rather than quantiles could be useful. This is implemented in Nixtla's neuralforecast, which however lacks essential functionality such as generating dependent sample paths.
