
Adds PropagationNet GNN layer and makes it optionally usable in existing deterministic models #507

Open
Sir-Sloth-The-Lazy wants to merge 25 commits into mllam:main from Sir-Sloth-The-Lazy:refactor/batch-fold-ensemble-prep

Conversation

@Sir-Sloth-The-Lazy (Contributor) commented Mar 24, 2026

Describe your changes

Adds PropagationNet GNN layer and makes it optionally usable in existing deterministic models, as outlined in #62.

It is integrated into the existing model hierarchy from #208 and can be enabled via the vertical_propnets flag.

Depends on #208.
For changes on top of #208 only, see:
Sir-Sloth-The-Lazy/neural-lam@refactor/model-class-hierarchy-issue-49...refactor/batch-fold-ensemble-prep

Issue Link

Contributes to #62

Type of change

  • 🐛 Bug fix (non-breaking change that fixes an issue)
  • ✨ New feature (non-breaking change that adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • 📖 Documentation (Addition or improvements to documentation)

Checklist before requesting a review

  • My branch is up-to-date with the target branch - if not, update your fork with the changes from the target branch (use pull with the --rebase option if possible).
  • I have performed a self-review of my code
  • For any new/modified functions/classes I have added docstrings that clearly describe their purpose, expected inputs and returned values
  • I have placed in-line comments to clarify the intent of any hard-to-understand passages of my code
  • I have updated the README to cover introduced code changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have given the PR a name that clearly describes the change, written in imperative form (context).
  • I have requested a reviewer and an assignee (assignee is responsible for merging). This applies only if you have write access to the repo, otherwise feel free to tag a maintainer to add a reviewer and assignee.

Checklist for reviewers

Each PR comes with its own improvements and flaws. The reviewer should check the following:

  • the code is readable
  • the code is well tested
  • the code is documented (including return types and parameters)
  • the code is easy to maintain

Author checklist after completed review

  • I have added a line to the CHANGELOG describing this change, in a section
    reflecting type of change (add section where missing):
    • added: when you have added new functionality
    • changed: when default behaviour of the code has been changed
    • fixes: when your contribution fixes a bug
    • maintenance: when your contribution relates to repo maintenance, e.g. CI/CD or documentation

Checklist for assignee

  • PR is up to date with the base branch
  • the tests pass
  • (if the PR is not just maintenance/bugfix) the PR is assigned to the next milestone. If it is not, propose it for a future milestone.
  • author has added an entry to the changelog (and designated the change as added, changed, fixed or maintenance)
  • Once the PR is ready to be merged, squash commits and merge the PR.

Sir-Sloth-The-Lazy and others added 22 commits February 21, 2026 17:42
- Update test_datasets.py to use ForecasterModule instead of GraphLAM
- Update test_plotting.py to use ForecasterModule instead of GraphLAM
- Fix interior_mask_bool property shape (1,) -> (N,) for correct loss masking
- Fix all_gather_cat to handle single-device runs without incorrect dim collapse
…r hierarchy

- Replace opaque argparse.Namespace with explicit keyword arguments in
  StepPredictor, BaseGraphModel, BaseHiGraphModel, GraphLAM, HiLAM,
  and HiLAMParallel __init__ methods
- Reorder methods in step_predictor.py: forward/expand_to_batch now
  appear before clamping methods
- Update all instantiation sites (train_model.py, test_training.py,
  test_prediction_model_classes.py) to pass explicit kwargs
- HiLAM helper methods (make_same/up/down_gnns) now use self.hidden_dim
  and self.hidden_layers instead of args parameter

Addresses review comments on PR mllam#208.
- Rename border to boundary in Forecaster
- Pass Forecaster object to ForecasterModule init instead of Predictor
- Remove inline imports in ForecasterModule
- Move loss-related pred_std logic fully into ForecasterModule
- Delete obsolete test_refactored_hierarchy.py
Co-authored-by: Joel Oskarsson <joel.oskarsson@outlook.com>
- Add predicts_std property to StepPredictor, Forecaster and ARForecaster
  so ForecasterModule can query the forecaster instead of taking output_std
  as a separate constructor argument
- Remove output_std parameter from ForecasterModule; use
  self._forecaster.predicts_std throughout
- Move fallback per_var_std logic out of forecast_for_batch into each
  step method so pred_std is None before fallback, enabling direct None
  checks instead of hparam checks
- Replace len(datastore.boundary_mask) with datastore.num_grid_points in
  StepPredictor to avoid relying on boundary_mask
- Move get_state_feature_weighting and ARForecaster inline imports to
  module-level imports in forecaster_module.py and train_model.py
- Fix statement ordering in StepPredictor.__init__ so register_buffer for
  grid_static_features appears directly after building the tensor
- Replace dict+loop pattern for registering state_mean/state_std buffers
  with two direct register_buffer calls
- Remove all internal Item N checklist references from comments
- Remove TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD env var hack; pass
  weights_only=False explicitly to load_from_checkpoint calls and
  weights_only=True to torch.load in test_graph_creation.py
- Add test_step_predictor_no_static_features to verify models initialise
  and run correctly when the datastore returns None for static features
- Fix graph= -> graph_name= and model.forecaster -> model._forecaster in
  tests to match current API
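
As a rough, standalone sketch of the predicts_std pattern this commit describes (stripped-down, hypothetical classes; the real StepPredictor, ARForecaster and ForecasterModule carry much more state):

class StepPredictor:
    def __init__(self, output_std: bool):
        self._output_std = output_std

    @property
    def predicts_std(self) -> bool:
        """Whether this predictor outputs a per-variable std. estimate."""
        return self._output_std

class ARForecaster:
    def __init__(self, predictor: StepPredictor):
        self._predictor = predictor

    @property
    def predicts_std(self) -> bool:
        # Forward the query, so ForecasterModule can ask the forecaster
        # instead of taking output_std as a separate constructor argument
        return self._predictor.predicts_std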
…r_batch

Makes the forecasting path tolerant to batch-folded execution so that
future ensemble generation can fold (S, B) into (S*B) before calling
ARForecaster, without any changes to ARForecaster or StepPredictor.

Prediction is kept folded through the existing deterministic logging and
aggregation paths so all dim assumptions in training_step, validation_step,
and test_step remain correct. Unfolding to (*leading, T, N, F) is deferred
to ensemble-specific subclasses (e.g. EnsForecasterModule).

Adds test_fold_unfold_equivalence to confirm ARForecaster's rollout is
rank-transparent under a pre-entry fold.
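
To make the fold/unfold idea concrete, a minimal sketch with a stand-in rollout (hypothetical shapes and names, not the actual ARForecaster API):

import torch

S, B, N, F, T = 4, 2, 10, 3, 5  # ensemble samples, batch, nodes, features, steps
init_states = torch.randn(S, B, N, F)

def rollout(states: torch.Tensor, n_steps: int) -> torch.Tensor:
    # Stand-in for an autoregressive rollout: any per-step op that only
    # touches the trailing (N, F) dims is transparent to the leading rank.
    return torch.stack([states + t for t in range(n_steps)], dim=1)

# Fold (S, B) -> (S*B,) before entering the rollout, unfold afterwards
pred = rollout(init_states.reshape(S * B, N, F), T).reshape(S, B, T, N, F)

# Rank-transparency: the folded rollout matches rolling out each sample
per_sample = torch.stack([rollout(init_states[s], T) for s in range(S)])
assert torch.allclose(pred, per_sample)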
…stic models

- Port PropagationNet as InteractionNet subclass (mean aggr, sender residual
  in messages, aggregation residual in forward)
- Add --vertical_propnets CLI flag to select PropagationNet for grid-mesh
  and vertical message passing edges
- Wire flag through model hierarchy: BaseGraphModel (g2m/m2g),
  BaseHiGraphModel (mesh init), HiLAM (up GNNs)
- Add 13 tests covering unit behavior and backward compatibility
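
For readers unfamiliar with the two layer types, a self-contained sketch of the residual patterns listed above (the Tiny* classes are illustrative stand-ins, not the actual neural-lam implementations, which are considerably more involved):

import torch
import torch.nn as nn

class TinyInteractionNet(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.edge_mlp = nn.Linear(3 * dim, dim)  # [sender, receiver, edge] -> message
        self.node_mlp = nn.Linear(2 * dim, dim)  # [receiver, aggregated] -> update

    def message(self, send_rep, rec_rep, edge_rep):
        return self.edge_mlp(torch.cat([send_rep, rec_rep, edge_rep], dim=-1))

    def aggregate(self, msg, rec_idx, num_nodes):
        # sum aggregation of messages per receiver node
        return msg.new_zeros(num_nodes, msg.shape[-1]).index_add_(0, rec_idx, msg)

    def forward(self, send_rep, rec_rep, edge_rep, edge_index):
        send_idx, rec_idx = edge_index
        msg = self.message(send_rep[send_idx], rec_rep[rec_idx], edge_rep)
        aggr = self.aggregate(msg, rec_idx, rec_rep.shape[0])
        # residual on the receiver representation
        return rec_rep + self.node_mlp(torch.cat([rec_rep, aggr], dim=-1))

class TinyPropagationNet(TinyInteractionNet):
    def message(self, send_rep, rec_rep, edge_rep):
        # sender residual in the messages
        return send_rep + super().message(send_rep, rec_rep, edge_rep)

    def aggregate(self, msg, rec_idx, num_nodes):
        # mean instead of sum aggregation
        counts = msg.new_zeros(num_nodes, 1).index_add_(
            0, rec_idx, msg.new_ones(len(rec_idx), 1)
        )
        return super().aggregate(msg, rec_idx, num_nodes) / counts.clamp(min=1)

    def forward(self, send_rep, rec_rep, edge_rep, edge_index):
        send_idx, rec_idx = edge_index
        msg = self.message(send_rep[send_idx], rec_rep[rec_idx], edge_rep)
        aggr = self.aggregate(msg, rec_idx, rec_rep.shape[0])
        # aggregation residual in forward, on the aggregated messages
        return aggr + self.node_mlp(torch.cat([rec_rep, aggr], dim=-1))

The duplicated forward() in this sketch is exactly what the later node_residual_target() refactor (below) removes.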
@Sir-Sloth-The-Lazy (Contributor, Author) commented Mar 24, 2026

@joeloskarsson @sadamov @observingClouds please have a look and see if this qualifies as the next step in ensemble prep 😄. Would be grateful for your feedback!

@Debadri-das commented:

@joeloskarsson @sadamov would request your further review on this PR!

@joeloskarsson (Collaborator) left a comment

Thanks for looking at this, and sorry for being slow with reviews 😅 Shared some first thoughts here on how I think we can best integrate this. Happy to hear input also from others, as there are some non-trivial design choices around this (e.g. how to choose the GNN type for each sub-graph).

Comment thread neural_lam/models/base_graph_model.py Outdated
Comment thread neural_lam/interaction_net.py Outdated
…trings

Replace the single `vertical_propnets: bool` flag with per-direction
string-based GNN type parameters (g2m_gnn_type, m2g_gnn_type,
mesh_up_gnn_type, mesh_down_gnn_type) for more flexible and
future-proof GNN selection. Add GNN_TYPES registry and get_gnn_class()
lookup in interaction_net.py. Refactor PropagationNet to eliminate
duplicated forward() by extracting node_residual_target() override.
@Sir-Sloth-The-Lazy (Contributor, Author) commented:

Refactor: replace vertical_propnets bool with fine-grained GNN type strings

Overview

Replaced the single --vertical_propnets boolean flag with four per-direction string parameters, enabling independent GNN selection per message-passing direction.


Changes

  • Per-direction GNN type parameters
    Replaced --vertical_propnets with four string args: --g2m_gnn_type, --m2g_gnn_type, --mesh_up_gnn_type, --mesh_down_gnn_type. Each defaults to "InteractionNet" and can be set to "PropagationNet".

  • GNN_TYPES registry + get_gnn_class() lookup (interaction_net.py)
    Added a registry to resolve GNN type strings to their corresponding classes, making it straightforward to register new GNN types in the future.

  • PropagationNet forward() deduplication
    Eliminated the duplicated forward() method by extracting a node_residual_target() hook (see the sketch after this list):

    • InteractionNet.node_residual_target() → returns the receiver representation (rec_rep)
    • PropagationNet.node_residual_target() → returns the aggregated edge messages (edge_rep_aggr)
    • Shared forward() now delegates to this hook.
  • CLI wiring + propagation through model hierarchy
    Threaded the per-direction types through train_model.py, BaseGraphModel, BaseHiGraphModel, HiLAM, HiLAMParallel, and GraphLAM.

  • Integration tests updated
    Tests now exercise the string-based API, including per-direction overrides and invalid-type error handling.
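
A minimal sketch of the registry and hook pattern (simplified signatures; the real classes also build MLPs and perform the actual message passing):

import torch.nn as nn

class InteractionNet(nn.Module):
    def node_residual_target(self, rec_rep, edge_rep_aggr):
        return rec_rep  # residual on the receiver representation

class PropagationNet(InteractionNet):
    def node_residual_target(self, rec_rep, edge_rep_aggr):
        return edge_rep_aggr  # residual on the aggregated edge messages

# String-to-class registry: adding a new GNN type is one new entry
GNN_TYPES = {
    "InteractionNet": InteractionNet,
    "PropagationNet": PropagationNet,
}

def get_gnn_class(name: str) -> type:
    """Resolve a GNN type string to its class, failing loudly on typos."""
    try:
        return GNN_TYPES[name]
    except KeyError:
        raise ValueError(
            f"Unknown GNN type {name!r}, expected one of {list(GNN_TYPES)}"
        ) from None

With this hook, the shared forward() can end in something like node_residual_target(rec_rep, edge_rep_aggr) + node_update, so the two classes differ only in where the residual attaches.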


Files Changed (8)

File | Nature of change
interaction_net.py | Added GNN_TYPES registry, get_gnn_class(), node_residual_target() hook; unified forward()
base_graph_model.py | Threaded per-direction GNN type args
base_hi_graph_model.py | Threaded per-direction GNN type args
graph_lam.py | Wired new GNN type params
hi_lam.py | Wired new GNN type params
hi_lam_parallel.py | Wired new GNN type params
train_model.py | Replaced --vertical_propnets with four --*_gnn_type CLI args
test_propagation_net.py | Updated tests for string-based API and error handling

@Sir-Sloth-The-Lazy (Contributor, Author) commented Apr 5, 2026

Proposal: GNN type selection per sub-graph

Context

With the current commit (3472387), four per-direction CLI args cover the encoder/decoder
and inter-level GNNs. However, three sub-graphs remain hardcoded to InteractionNet:

Sub-graph | Location | Currently configurable?
m2m same-level | HiLAM.make_same_gnns | No
m2m processor | GraphLAM.__init__ | No
processor (all edges) | HiLAMParallel.__init__ | No

Note on m2m same-level edges: The original --vertical_propnets flag deliberately
excluded same-level m2m processing: nodes at the same mesh level are peers with no
clear sender→receiver directionality, so PropagationNet (designed for directional
vertical message passing) may not be semantically meaningful here. Any global default
that switches everything would affect m2m same-level edges too, which is worth keeping in mind.


Option A: Single --gnn_type default + per-direction overrides

Add one global --gnn_type arg (default "InteractionNet"). Per-direction args
(--g2m_gnn_type, --m2m_gnn_type, etc.) default to None and override the global
when set. Currently-hardcoded sub-graphs in HiLAM and GraphLAM fall back to the
global default.

Resolution logic:

g2m_gnn_type = g2m_gnn_type or gnn_type  # per-direction wins if set
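
Applied to all four per-direction args, the same resolution could look like the following (hypothetical helper, mirroring the one-liner above):

def resolve_gnn_types(gnn_type, **per_direction):
    """Per-direction value wins when set; otherwise fall back to the global."""
    return {name: value or gnn_type for name, value in per_direction.items()}

resolved = resolve_gnn_types(
    "PropagationNet",
    g2m_gnn_type=None,              # -> "PropagationNet" (global default)
    m2g_gnn_type="InteractionNet",  # -> "InteractionNet" (override wins)
    mesh_up_gnn_type=None,
    mesh_down_gnn_type=None,
)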

Pros:

  • Clean UX for the common case — --gnn_type PropagationNet switches everything in one flag
  • Fine-grained control still available via per-direction overrides
  • Small diff on top of the current implementation
  • Closes the hardcoded gaps (except HiLAMParallel) without adding more args

Cons:

  • The None-vs-default fallback logic adds a layer of indirection
  • A blanket --gnn_type PropagationNet would also switch m2m same-level edges, which may
    not be desirable (mitigated by --m2m_gnn_type InteractionNet override, but the user
    has to know to do this)
  • Still CLI-arg-driven, which will eventually need migration if the codebase moves fully
    to config files

Option B: YAML config-driven (via NeuralLAMConfig)

Move GNN type selection into the existing NeuralLAMConfig YAML config. The codebase
already uses dataclass_wizard.YAMLWizard with Union-based type selection for training
config (loss weighting, output clamping). Model architecture config would follow the same
pattern:

model:
  hidden_dim: 64
  processor_layers: 4
  gnn_types:
    g2m: PropagationNet
    m2g: InteractionNet
    mesh_up: PropagationNet
    mesh_down: InteractionNet
    m2m: InteractionNet
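
For illustration, the dataclass side of such a config might look roughly like this (hypothetical names; NeuralLAMConfig's actual structure may differ, and dataclass_wizard's YAMLWizard mixin is assumed, as in the existing config handling):

from dataclasses import dataclass, field
from dataclass_wizard import YAMLWizard

@dataclass
class GNNTypes:
    g2m: str = "InteractionNet"
    m2g: str = "InteractionNet"
    mesh_up: str = "InteractionNet"
    mesh_down: str = "InteractionNet"
    m2m: str = "InteractionNet"

@dataclass
class ModelConfig(YAMLWizard):
    hidden_dim: int = 64
    processor_layers: int = 4
    gnn_types: GNNTypes = field(default_factory=GNNTypes)

# e.g. config = ModelConfig.from_yaml_file("model.yaml")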

Pros:

  • Consistent with the direction the codebase is already moving (NeuralLAMConfig handles
    datastore, loss weighting, output clamping via YAML)
  • Scales cleanly to any number of sub-graphs or future model config without CLI bloat
  • Config files are versioned, shareable, and reproducible — easier to track experiment setups
  • No fallback/override logic needed; each sub-graph has an explicit entry
  • Makes the m2m same-level choice explicit rather than inheriting from a global default

Cons:

  • Larger refactor — needs new dataclasses, migration of existing CLI model args
  • Scope creep risk if done within this PR (--hidden_dim, --processor_layers would
    ideally move at the same time)
  • Diverges from the current CLI-driven workflow that other model args still use

Recommendation

Option A now, Option B later.

Option A is a small, self-contained change that fits the existing CLI pattern and closes
the hardcoded gaps in HiLAM and GraphLAM immediately. Option B is the right
long-term direction but is better done as a dedicated refactor that moves all model
hyperparameters into NeuralLAMConfig at once, rather than partially migrating just
GNN types.

HiLAMParallel remains out of scope for both options: its processor fuses all
edge types (m2m + up + down) into a single InteractionNet with
edge_chunk_sizes/aggr_chunk_sizes, so per-direction GNN type selection is not
applicable without rethinking its architecture.

@joeloskarsson (Collaborator) left a comment

Looking really nice now! Previous comments are addressed; I just added some small suggestions here that are really just polishing.

The outstanding question is indeed what to do with the remaining GNN layers, if we should make them all configurable. In my opinion there is not that much point in making the GNNs for m2m edges propagation-networks, so for most practical use cases the present options are sufficient. But it is a bit strange that some of the GNNs can be changed but not all. I think I am leaning towards trying to merge this as is atm, and then leaving a further generalization (e.g. through a model-config file) as a future change. But this requires some discussion I think. I will propose this for the next release version, and also see what others at the dev meeting think about the configurability question.

Comment thread neural_lam/train_model.py
Comment thread neural_lam/gnn_layers.py
@joeloskarsson joeloskarsson added this to the v0.7.0 (proposed) milestone Apr 7, 2026
Use argparse `choices=list(GNN_TYPES.keys())` on the four per-direction
GNN type args so invalid values are rejected at the CLI boundary with a
clear error message, rather than failing later inside get_gnn_class().
Drops the enumerated options from the help strings since argparse now
surfaces them automatically in --help.
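
The change amounts to roughly the following (a sketch with a stand-in registry; flag names as introduced in this PR):

import argparse

GNN_TYPES = {"InteractionNet": object, "PropagationNet": object}  # stand-in registry

parser = argparse.ArgumentParser()
for direction in ("g2m", "m2g", "mesh_up", "mesh_down"):
    parser.add_argument(
        f"--{direction}_gnn_type",
        type=str,
        default="InteractionNet",
        choices=list(GNN_TYPES.keys()),  # argparse rejects invalid values itself
        help=f"GNN class to use for {direction} edges",
    )

# An invalid value such as --g2m_gnn_type FooNet now fails at parse time with
# a clear "invalid choice" error, instead of later inside get_gnn_class()
args = parser.parse_args(["--g2m_gnn_type", "PropagationNet"])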
Rename neural_lam/interaction_net.py to neural_lam/gnn_layers.py since
the module now houses multiple GNN layer classes (InteractionNet,
PropagationNet, SplitMLPs) and the GNN_TYPES registry. Also rename
tests/test_propagation_net.py to tests/test_gnn_layers.py to match,
since it covers both InteractionNet and PropagationNet behavior.

Update all imports across neural_lam/__init__.py, train_model.py, and
the model files. Pre-commit hooks also auto-formatted a few unrelated
files (forecaster_module.py, test_plotting.py,
test_prediction_model_classes.py).
@Sir-Sloth-The-Lazy (Contributor, Author) commented:

Hi @joeloskarsson, made the changes; agreed with your verdict and hope you found the proposal a little helpful! ;)

@sadamov sadamov added the enhancement New feature or request label Apr 13, 2026
@joeloskarsson (Collaborator) commented:

From the dev meeting there seems to be agreement to stick with the current design of CLI arguments for choosing some GNN layers, and to leave the broader config to later.

@Sir-Sloth-The-Lazy (Contributor, Author) commented:

Understood, I would love to work on that when the time comes 😄

@joeloskarsson (Collaborator) left a comment

The implementation here is all good now, so approving this. This should however go in after #208, so it needs to wait for possible conflict resolution with that (although there might not be much, as it is built on top of #208).

The one thing missing is a changelog entry :)

