Refactor model class hierarchy into composable Forecaster/StepPredictor layers#208
Conversation
- Update test_datasets.py to use ForecasterModule instead of GraphLAM
- Update test_plotting.py to use ForecasterModule instead of GraphLAM
- Fix interior_mask_bool property shape (1,) -> (N,) for correct loss masking
- Fix all_gather_cat to handle single-device runs without incorrect dim collapse
@joeloskarsson I have made the changes for issue #49. Please check if this is up to the mark; I would love to hear your feedback.
@observingClouds and @leifdenby I have made the changes explained in issue #49. If there are any mistakes, please tell me about them; it would be helpful to know what I have done wrong. Thank you, and sorry for pestering you.
leifdenby
left a comment
Thank you for working on this @Sir-Sloth-The-Lazy!
This PR changes quite a lot (as expected of course because you are redesigning the class hierarchy for the models). I would like to give this a more thorough review, but first I need to get a clearer overview.
Would you be able to make a diagram that explains what each of the classes (forecaster, step-predictor) takes as input and returns? By this I mean details of the shapes of the inputs and outputs too. I think we should put that in the documentation so that, going forward, the class hierarchy of the model architectures is clearer.
Thanks!
Sir-Sloth-The-Lazy
left a comment
I have provided clarification for the design choices. If any feel unsatisfactory, please let me know; I would be happy to find some other way ;) @leifdenby
@leifdenby sorry for disturbing you, just a reminder: I wanted to know if this is the right way, or whether I should find another way to redesign the class.
joeloskarsson
left a comment
Thanks for starting all the work with this! I think this is a huge improvement to the codebase, and many things clearly become cleaner 😄 I added a number of comments and some points for discussion.
This is quite a large change that touches much of the codebase. There are many open PRs that will cause changes to the code being refactored here. Mostly these PRs will make this change simpler, I think, which is good. But because of this, some planning will be needed about when to merge this, and we will need to discuss it in our monthly dev meetings. It should be no issue to get this to a mergeable state and then make a plan for when to merge it!
@joeloskarsson thank you for this detailed review. Honestly, I was not feeling motivated enough to code today, but the sheer amount of effort you have put into this feedback made me instantly want to work again. The time you must have taken to read this code change is huge! Thanks again to all of you, @leifdenby @joeloskarsson. I will make the recommended changes and come back again. Cheers! ;)
Hi @Sir-Sloth-The-Lazy, fantastic work getting this foundational PR started! I've been following the discussions in #49 closely, as I am planning to focus my GSoC proposal on extending this exact hierarchy to support the probabilistic and ensemble models. (@joeloskarsson, regarding your comment on the I want to make sure I accurately map out the data flow for the generative architectures in my proposal draft!
…r hierarchy
- Replace opaque argparse.Namespace with explicit keyword arguments in StepPredictor, BaseGraphModel, BaseHiGraphModel, GraphLAM, HiLAM, and HiLAMParallel __init__ methods
- Reorder methods in step_predictor.py: forward/expand_to_batch now appear before clamping methods
- Update all instantiation sites (train_model.py, test_training.py, test_prediction_model_classes.py) to pass explicit kwargs
- HiLAM helper methods (make_same/up/down_gnns) now use self.hidden_dim and self.hidden_layers instead of args parameter

Addresses review comments on PR mllam#208.
- Rename border to boundary in Forecaster
- Pass Forecaster object to ForecasterModule init instead of Predictor
- Remove inline imports in ForecasterModule
- Move loss-related pred_std logic fully into ForecasterModule
- Delete obsolete test_refactored_hierarchy.py
Following the review thread and thinking ahead to probabilistic support (#62): would it make sense to formalize the interface as
@AdMub regarding using
@kshirajahere yes, I agree with you on this. @Sir-Sloth-The-Lazy I interpret your review request to mean that you have made changes and I should have a look again? Could you leave a short comment on each of my requested changes, just telling me how you fixed each (if you want, link to a commit)? That makes it a lot easier for me to review again 🫶 Thanks!
While input here is appreciated, I think this PR is at a stage now where we need to be very careful about what is suggested; otherwise we will never be able to merge this and will just keep iterating forever. Please separate:
1. Bugs and regressions: these are good to fix before merging (@kshirajahere nicely pointed out a few things that I had missed; these probably should be fixed before merging).
2. Suggestions for improvements: making the code more flexible or further changing the overall design.
I would recommend not bringing up suggestions of type 2 here, but rather raising those as follow-up issues. Better to get this merged, and then have more small PRs improving things later.
…hierarchy-issue-49
…escription, and test coverage
- Re-export BaseGraphModel and BaseHiGraphModel from neural_lam/models/__init__.py
- Fix Forecaster.forward() return type to Optional[torch.Tensor] for pred_std
- Save args in save_hyperparameters so checkpoints are self-describing
- Pass args from train_model.py into ForecasterModule to carry architecture kwargs
- Add load_forecaster_module_from_checkpoint helper to train_model.py
- Thread metrics_watch and var_leads_metrics_watch through run_simple_training
- Add test_graph_lam_no_static_features using real GraphLAM to prove empty static tensor flows through GNN
@kshirajahere @joeloskarsson I have taken all the requested enhancements into account with the latest commits. I hope this addresses the concerns!
# makes the checkpoint self-describing: it carries model, graph_name,
# hidden_dim, etc. so the caller can reconstruct the exact forecaster
# architecture from the checkpoint alone.
self.save_hyperparameters(ignore=["datastore", "forecaster"])
The placement is intentional. args is still bound when save_hyperparameters runs at line 85, so the namespace is captured. Running it earlier would capture the default values for loss/lr/etc. before the unpack at lines 56-68 has had a chance to override them; that is exactly the legacy-checkpoint regression the unpack block exists to fix (see comment at lines 50-55). The ignore list is also required because datastore and forecaster aren't serializable as hparams.
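The capture-order argument can be illustrated with a plain-Python stand-in (no Lightning needed); the `loss`/`lr` parameters and the unpack logic here are illustrative stand-ins, not the real ForecasterModule code:

```python
# Stdlib stand-in for the capture-order issue described above: the hparams
# snapshot must happen *after* the legacy-checkpoint unpack has overridden
# the constructor defaults. All names here are illustrative.
def init_module(args: dict, loss: str = "wmse", lr: float = 1e-3) -> dict:
    # Unpack block: legacy checkpoints carry loss/lr inside `args`,
    # and those values must win over the defaults.
    loss = args.get("loss", loss)
    lr = args.get("lr", lr)
    # Snapshot AFTER the unpack, so overridden values are captured
    # (snapshotting before would record the defaults instead).
    return {"args": args, "loss": loss, "lr": lr}

hparams = init_module({"loss": "mae"})
print(hparams["loss"], hparams["lr"])  # mae 0.001
```

Snapshotting before the `args.get` lines would freeze `"wmse"` into the checkpoint even though the legacy args said `"mae"`, which is the regression described above.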
| """ | ||
| return x.unsqueeze(0).expand(batch_size, -1, -1) | ||
|
|
||
| @abstractmethod |
The return type is already explicit at line 92: tuple[torch.Tensor, Optional[torch.Tensor]]. The other files in neural_lam/models/ (e.g. base_graph_model.py, interaction_net.py) use the same compact inline-Returns: style as the current docstring, so reformatting just this one to NumPy style would make the module less consistent, not more. Happy to switch the whole module's style in a separate PR if there's a project preference, @joeloskarsson, but right now I would prefer timely merging of this PR, as much depends on it!
| from .base_graph_model import BaseGraphModel | ||
| from .base_hi_graph_model import BaseHiGraphModel |
There's nothing to deprecate here: base_graph_model.py and base_hi_graph_model.py are still in neural_lam/models/, not in any neural_lam.models.graph subpackage. The current __init__.py already re-exports both names, so existing from neural_lam.models import BaseGraphModel imports continue to work. If a relocation lands in a future PR, we can add a shim then; at that point the right pattern is a module-level __getattr__ (PEP 562), not a top-level warnings.warn that fires on every package import regardless of whether the deprecated name is actually accessed.
| "graph_lam": GraphLAM, | ||
| "hi_lam": HiLAM, | ||
| "hi_lam_parallel": HiLAMParallel, | ||
| } |
The proposed __all__ includes Forecaster and StepPredictor, but neither is currently imported in this __init__.py, so from neural_lam.models import * would fail with AttributeError. We'd need to add those imports first.
Beyond that, no other __init__.py in the project defines __all__, so adding it just here introduces an inconsistency. If the project wants explicit export lists for docs tooling, it should be a coordinated sweep across all __init__.py files. @joeloskarsson, I would love to hear your opinion on this 😁 I lean towards doing this in a future PR!
@GiGiKoneti Thanks for taking the time to review. I want to engage with these in good faith,
Extracting Happy to make any of these changes if I'm missing something, just want to make sure we're
Hey @Sir-Sloth-The-Lazy, thanks for the detailed breakdown! You're absolutely right on the To answer your question on
Yes, there's exactly one: When a user loads a checkpoint trained on an older version of
Also, the
Other than that, the refactor is looking great and I'm fully on board with the architecture changes! Let's get this merged once those robustness checks are in. Kudos to you.
@GiGiKoneti Did a quick empirical check on this 😄. There's already an existing test at I also ran a one-off probe specifically for

>>> loaded.hparams.keys()
['args', 'config', 'create_gif', 'loss', 'lr', 'metrics_watch',
'n_example_pred', 'restore_opt', 'val_steps_to_log', 'var_leads_metrics_watch']
>>> loaded.hparams.metrics_watch
[]
>>> "metrics_watch" in loaded.hparams
True

The reason no So I'd like to keep the attribute access — the

Let me know if you spot something I've missed!👌🏻
First, @Sir-Sloth-The-Lazy, thank you so much for sticking with this and all your hard work. I have been re-reading through your code and reminding myself of the main change here. That prompted me to dig out some code I wrote to parse the class inheritance tree and render it out. I can add the script I used to generate the visualisation below if people think that is useful (I was thinking it could be used for generating documentation). The main change seems to be the introduction of the abstract base classes Forecaster and StepPredictor. Those seem like really good ideas to me (as we've discussed many times). I will continue my review tomorrow and focus on whether the functionality that currently sits in each class is defined where I would expect it to be, based on this class hierarchy. Hope this is along the lines of what you were hoping for from my angle on reviewing this substantial piece of work :)
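A minimal version of such a class-tree renderer can be sketched with `__subclasses__`; the empty classes below are stand-ins that only mirror names from this PR, not the real implementations:

```python
# Toy class-hierarchy renderer in the spirit of the script mentioned above;
# the stand-in classes only borrow names from this PR for illustration.
class StepPredictor: ...
class BaseGraphModel(StepPredictor): ...
class BaseHiGraphModel(BaseGraphModel): ...
class GraphLAM(BaseGraphModel): ...
class HiLAM(BaseHiGraphModel): ...
class HiLAMParallel(BaseHiGraphModel): ...

def render_tree(cls, indent=0):
    """Recursively render a class and its subclasses as an indented tree."""
    lines = [" " * indent + cls.__name__]
    for sub in cls.__subclasses__():
        lines.extend(render_tree(sub, indent + 4))
    return lines

print("\n".join(render_tree(StepPredictor)))
```

For documentation generation, the same walk could emit Graphviz `dot` nodes instead of indented text.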
@Sir-Sloth-The-Lazy You're right, I missed that
@leifdenby thank you so much for your kind words and for taking the time to review this work!
Yes, absolutely! Looking forward to it 😄
@GiGiKoneti Please don't apologize for reviewing my work and voicing your doubts. As you said, you learnt something from our conversation, and as long as that is happening, I would be really happy to answer any more of your questions! Thank you for your time in making this work better! 😄
leifdenby
left a comment
Looks great! I have only one minor change I think we should do, and a suggestion that wouldn't be a lot of work:
- I think we should put clear docstrings on every nn.Module.forward(...) method that defines not only what the function does (its role and responsibility), but also details the shape of every call argument (and what each dimension means).
- Given that this PR introduces clearer separation between classes by introducing a hierarchy of classes building on base classes, I think we should consider reorganising the Python module structure here too, to more closely reflect what the different torch.nn.Module-derived classes actually do and where they sit in the hierarchy. Currently, neural_lam/models/ has a flat structure where we just have all modules next to each other, which makes it hard to reason about where to go. I would suggest a layout like this:
neural_lam/models/
__init__.py
module.py # ForecasterModule
forecasters/
__init__.py
base.py # Forecaster
autoregressive.py # ARForecaster
step_predictors/
__init__.py
base.py # StepPredictor
graph/
__init__.py
base.py # BaseGraphModel
hierarchical.py # BaseHiGraphModel
graph_lam.py # GraphLAM
hi_lam.py # HiLAM
hi_lam_parallel.py # HiLAMParallel
def predicts_std(self) -> bool:
    return self.predictor.predicts_std

def forward(
I think it would be good to add numpy-style docstrings here, detailing what each argument is and its shape, because this function is the main entrypoint to the nn.Module, and having a clear interface defined here will make future development easier.
Actually, I would advocate that we start using jaxtyping with beartype to define the meaning of the dims. But we should do that in a separate PR to not hold up this work further. But we should at
Agree that numpy-style docstrings for these would be good. Agree that jaxtyping could be a nice addition, but should be saved for a future PR.
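For a feel of what such runtime dimension checking buys, here is a pure-stdlib toy of the kind of guarantee jaxtyping + beartype would provide from type annotations; here the shape specs are passed explicitly, FakeTensor stands in for torch.Tensor, and all names are illustrative:

```python
# Toy runtime dim checker: equal dimension names must have equal sizes
# across all declared tensor arguments, mimicking jaxtyping's contract.
import functools
from collections import namedtuple

FakeTensor = namedtuple("FakeTensor", ["shape"])  # stand-in for torch.Tensor

def check_dims(**specs):
    """Decorator: map each kwarg name to a tuple of dimension names."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(**tensors):
            sizes = {}
            for name, dim_names in specs.items():
                shape = tensors[name].shape
                if len(shape) != len(dim_names):
                    raise ValueError(f"{name}: expected rank {len(dim_names)}")
                for dim, size in zip(dim_names, shape):
                    # setdefault records the first size seen for each dim name
                    if sizes.setdefault(dim, size) != size:
                        raise ValueError(
                            f"dim {dim!r}: {sizes[dim]} != {size} in {name}"
                        )
            return fn(**tensors)
        return wrapper
    return decorator

@check_dims(
    init_states=("B", "two", "num_grid_nodes", "d_f"),
    forcing=("B", "pred_steps", "num_grid_nodes", "d_forcing"),
)
def forward(init_states, forcing):
    return "ok"

print(forward(init_states=FakeTensor((4, 2, 100, 8)),
              forcing=FakeTensor((4, 10, 100, 5))))  # ok
```

Passing tensors whose `B` dims disagree raises ValueError at call time, which is the class of bug the real jaxtyping annotations would catch.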
There are nice docstrings added now. For the tensor shapes I would however suggest being consistent with either 1) always writing out what each dimension is in the description of each tensor (this is mostly done now, but not for all), or 2) collecting all dimension explanations in one place in the docstring, so that you just list the tensors and their shapes first, and can then look up what each dimension means elsewhere in the docstring.
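Option 2 might look like the sketch below; the signature, shapes, and dimension names are examples only, not the actual neural_lam code:

```python
# Illustrative sketch of option 2: tensors list only their shapes, and all
# dimension letters are explained once in a shared Notes section.
def forward(self, init_states, forcing):
    """Produce a full forecast from the initial states.

    Parameters
    ----------
    init_states : torch.Tensor
        Shape ``(B, 2, num_grid_nodes, d_f)``.
    forcing : torch.Tensor
        Shape ``(B, pred_steps, num_grid_nodes, d_forcing)``.

    Returns
    -------
    torch.Tensor
        Prediction, shape ``(B, pred_steps, num_grid_nodes, d_f)``.

    Notes
    -----
    Dims: ``B`` batch size; ``pred_steps`` forecast steps;
    ``num_grid_nodes`` grid nodes in the domain; ``d_f`` state
    features; ``d_forcing`` forcing features.
    """
    raise NotImplementedError
```

The shared Dims block keeps each parameter entry short while still making every dimension letter discoverable in one place.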
super().__init__(args, config=config, datastore=datastore)
def __init__(
    self,
    config: NeuralLAMConfig,
great work on making all these architecture parameters explicit rather than passing them through args - is config: NeuralLAMConfig required here? Maybe we shouldn't propagate that down below the ForecasterModule, or at least not all the way down to the graph-based models, I think
I guess this was for the clamping-options. Now they are fed separately, which seems good
…ple config

Address PR mllam#208 review comments from @leifdenby:
- Reorganise neural_lam/models to mirror the new class hierarchy: forecasters/{base,autoregressive}.py, step_predictors/base.py, and step_predictors/graph/{base,hierarchical,graph_lam,hi_lam,hi_lam_parallel}.py. ForecasterModule moves to neural_lam/models/module.py. The package __init__.py re-exports the public API so callers can use `from neural_lam.models import ARForecaster, GraphLAM, ...` unchanged.
- Convert the four forward(...) docstrings on the nn.Module hierarchy (Forecaster, ARForecaster, StepPredictor, BaseGraphModel) to numpy-style with Parameters/Returns sections describing each argument's shape and the meaning of every dimension.
- Decouple NeuralLAMConfig from the predictor stack. StepPredictor and the graph-based predictors no longer accept `config: NeuralLAMConfig` — they take `output_clamping_lower` / `output_clamping_upper` dicts directly. NeuralLAMConfig now stops at ForecasterModule. Construction sites in train_model.py and the tests resolve the two dicts from config.training.output_clamping at the call site.
I have made the changes as asked, @leifdenby @joeloskarsson. Looking forward to your review!
…turn type
- Drop the unreachable (d_f,) shape from Forecaster.forward's pred_std docstring. The only concrete implementation (ARForecaster) returns (B, pred_steps, num_grid_nodes, d_f) or None — never a per-feature vector — so the wider contract was misleading.
- Fix ARForecaster.forward return annotation: tuple[torch.Tensor, torch.Tensor] -> tuple[torch.Tensor, Optional[torch.Tensor]], reflecting that pred_std is None when the wrapped predictor does not output an std.
----------
init_states : torch.Tensor
    Shape ``(B, 2, num_grid_nodes, d_f)``. The two initial states
    ``[X_{t-1}, X_t]`` used to seed the forecast. Dims: ``B`` is
I have never heard the terminology "seed the forecast", so maybe just writing something like "start the forecast from" is better to avoid confusion.
Sadly, there are again some conflicts to resolve. I have discussed with the other maintainers that we now need to pause merging other PRs so this one can make it in, so we don't have to keep resolving conflicts here.
Convert every remaining informal/Google-style docstring in models/, interaction_net, metrics, utils, vis, datastore/base, and custom_loggers to numpy-style (Parameters\n---------- / Returns\n-------). Also fill in missing inline Dims: explanations for all tensor parameters and return values so dimension letters are documented consistently throughout.
@leifdenby and @joeloskarsson I have tried to add numpy-style docstrings across the codebase. Please 🙏 give your feedback on whether we should go forward with this or revert to the way it was before this commit, keeping only the changes suggested by Joel. 😄


Describe your changes
Refactors the monolithic ARModel class into a composable hierarchy of smaller, focused components:
- ForecasterModule (pl.LightningModule): training loop, metrics, plotting, optimizer config
- ARForecaster (nn.Module): auto-regressive unrolling with boundary masking
- StepPredictor (nn.Module): single-step prediction, normalization, clamping

This separation makes it straightforward to add non-autoregressive forecasters, new step predictor architectures (e.g. Vision Transformers), or ensemble strategies without modifying the training infrastructure.
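The composition can be illustrated with a plain-Python stand-in; the real classes are pl.LightningModule / nn.Module subclasses operating on tensors, and the PersistenceStep toy predictor below is invented for the example:

```python
# Stdlib stand-in for the composable hierarchy: an ARForecaster unrolls
# any StepPredictor, so swapping architectures never touches the loop.
# Scalars replace tensors; all names besides the class names are examples.
from abc import ABC, abstractmethod

class StepPredictor(ABC):
    """Single-step prediction: (X_{t-1}, X_t, forcing) -> X_{t+1}."""
    @abstractmethod
    def predict_step(self, prev_prev, prev, forcing): ...

class ARForecaster:
    """Auto-regressively unrolls a StepPredictor over the forecast steps."""
    def __init__(self, predictor: StepPredictor):
        self.predictor = predictor

    def forward(self, init_states, forcing_seq):
        prev_prev, prev = init_states  # the two seed states [X_{t-1}, X_t]
        preds = []
        for forcing in forcing_seq:
            nxt = self.predictor.predict_step(prev_prev, prev, forcing)
            preds.append(nxt)
            prev_prev, prev = prev, nxt  # shift the state window forward
        return preds

class PersistenceStep(StepPredictor):
    """Toy predictor: next state = current state + forcing."""
    def predict_step(self, prev_prev, prev, forcing):
        return prev + forcing

forecaster = ARForecaster(PersistenceStep())
print(forecaster.forward((0.0, 1.0), [0.5, 0.5, 0.5]))  # [1.5, 2.0, 2.5]
```

A new architecture only has to implement predict_step; the unrolling, and in the real code the training loop in ForecasterModule, stay untouched.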
Also fixes two pre-existing bugs:
- interior_mask_bool shape (1,) → (N,) for correct loss masking
- all_gather_cat dimension collapse on single-device runs

Issue Link
Closes #49
Type of change
Checklist before requesting a review