
Refactor model class hierarchy into composable Forecaster/StepPredictor layers #208

Open

Sir-Sloth-The-Lazy wants to merge 40 commits into mllam:main from
Sir-Sloth-The-Lazy:refactor/model-class-hierarchy-issue-49

Conversation

@Sir-Sloth-The-Lazy (Contributor) commented Feb 21, 2026:

Describe your changes

Refactors the monolithic ARModel class into a composable hierarchy of smaller, focused components:

  • ForecasterModule (pl.LightningModule): Training loop, metrics, plotting, optimizer config
  • ARForecaster (nn.Module): Auto-regressive unrolling with boundary masking
  • StepPredictor (nn.Module): Single-step prediction, normalization, clamping

This separation makes it straightforward to add non-autoregressive forecasters, new step predictor architectures (e.g. Vision Transformers), or ensemble strategies without modifying the training infrastructure.
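A minimal sketch of how the three layers compose; signatures are simplified and partly assumed for illustration (boundary forcing, normalization, and std handling are omitted), so treat this as a reading aid rather than the PR's exact interfaces:

import torch
import torch.nn as nn
import pytorch_lightning as pl


class StepPredictor(nn.Module):
    """One-step map: (X_{t-1}, X_t, forcing) -> X_{t+1}."""

    def forward(self, prev_prev_state, prev_state, forcing):
        raise NotImplementedError


class ARForecaster(nn.Module):
    """Unrolls a StepPredictor auto-regressively over pred_steps."""

    def __init__(self, predictor: StepPredictor):
        super().__init__()
        self.predictor = predictor

    def forward(self, init_states, forcing, pred_steps):
        # init_states: (B, 2, num_grid_nodes, d_f)
        prev_prev, prev = init_states[:, 0], init_states[:, 1]
        preds = []
        for t in range(pred_steps):
            nxt = self.predictor(prev_prev, prev, forcing[:, t])
            # (boundary masking of nxt happens here in the real class)
            preds.append(nxt)
            prev_prev, prev = prev, nxt
        return torch.stack(preds, dim=1)  # (B, pred_steps, num_grid_nodes, d_f)


class ForecasterModule(pl.LightningModule):
    """Owns the training loop, metrics and optimizer; wraps any forecaster."""

    def __init__(self, forecaster: nn.Module):
        super().__init__()
        self.forecaster = forecaster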

Also fixes two pre-existing bugs:

  • interior_mask_bool shape fixed from (1,) to (N,) for correct loss masking (illustrated in the sketch after this list)
  • all_gather_cat dimension collapse on single-device runs
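A minimal illustration of the mask-shape bug; the variable names and the weighted-mean usage are assumptions for the sketch, not the PR's code:

import torch

entry_loss = torch.rand(4, 100)              # (batch, N) per-node loss
bad_mask = torch.ones(1)                     # (1,): broadcasts over every node
good_mask = (torch.rand(100) > 0.1).float()  # (N,): 1.0 only for interior nodes

# Weighted mean over nodes: with the (1,) mask every node (boundary
# included) contributes and the normaliser is wrong; with the (N,)
# mask only interior nodes enter the loss, as intended.
bad = (entry_loss * bad_mask).sum(dim=1) / bad_mask.sum()
good = (entry_loss * good_mask).sum(dim=1) / good_mask.sum()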

Issue Link

Closes #49

Type of change

  • 🐛 Bug fix (non-breaking change that fixes an issue)
  • ✨ New feature (non-breaking change that adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • 📖 Documentation (Addition or improvements to documentation)

Checklist before requesting a review

  • My branch is up-to-date with the target branch
  • I have performed a self-review of my code
  • For any new/modified functions/classes I have added docstrings that clearly describe their purpose, expected inputs, and returned values
  • I have placed in-line comments to clarify the intent of any hard-to-understand passages of my code
  • I have updated the README to cover introduced code changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have given the PR a name that clearly describes the change, written in imperative form
  • I have requested a reviewer and an assignee

@Sir-Sloth-The-Lazy (Contributor, Author):

@joeloskarsson I have made the changes for issue #49. Please check whether this is up to the mark; I would love to hear your feedback.

@Sir-Sloth-The-Lazy (Contributor, Author):

@observingClouds and @leifdenby, I have made the changes explained in issue #49. If there are any mistakes, please tell me about them; it would be helpful to know what I have done wrong. Thank you, and sorry for pestering you.

@leifdenby (Member) left a comment:

Thank you for working on this @Sir-Sloth-The-Lazy!

This PR changes quite a lot (as expected of course because you are redesigning the class hierarchy for the models). I would like to give this a more thorough review, but first I need to get a clearer overview.

Would you be able to make a diagram explaining what each of the classes (forecaster, step-predictor) takes as input and returns? By this I mean details of the shapes of the inputs and outputs too. I think we should put that in some documentation so that, going forward, the class hierarchy of the model architectures is clearer.

Thanks!

Comment thread neural_lam/models/archive/ar_model.py Outdated
Comment thread neural_lam/models/forecaster_module.py Outdated
@Sir-Sloth-The-Lazy (Contributor, Author) left a comment:

I have provided clarification for the design choices. If any feel unsatisfactory, please let me know; I would be happy to find some other way ;) @leifdenby

@Sir-Sloth-The-Lazy (Contributor, Author):

@leifdenby sorry for disturbing you, just a reminder: I wanted to know whether this is the right way, or whether I should find another way to redesign the class.

@joeloskarsson (Collaborator) left a comment:

Thanks for starting all the work with this! I think this is a huge improvement to the codebase, and many things clearly become cleaner 😄 I added a number of comments and some points for discussion.

This is quite a large change to the codebase that touches much of the code. There are many open PRs that will cause changes to the code being refactored here. Mostly these PRs will make this change simpler, I think, which is good. But because of this, some planning will be needed about when to merge this, and we will need to discuss it in our monthly dev meetings. It should be no issue, though, to get this to a mergeable state and then make a plan for when to merge it!

Comment thread neural_lam/models/forecaster.py Outdated
Comment thread neural_lam/models/forecaster_module.py Outdated
Comment thread neural_lam/models/forecaster_module.py Outdated
Comment thread neural_lam/models/forecaster_module.py Outdated
Comment thread neural_lam/models/forecaster_module.py Outdated
Comment thread neural_lam/models/step_predictor.py Outdated
Comment thread neural_lam/models/step_predictor.py Outdated
Comment thread neural_lam/models/step_predictor.py Outdated
Comment thread tests/test_refactored_hierarchy.py
Comment thread tests/test_training.py Outdated
@Sir-Sloth-The-Lazy (Contributor, Author) commented Mar 1, 2026:

@joeloskarsson thank you for this detailed review. Honestly, I was not feeling motivated to code today, but the sheer amount of effort you have put into this feedback instantly made me want to work again. The time you must have taken to read this code change is huge! Thanks again to all of you @leifdenby @joeloskarsson. I will make the recommended changes and come back. Cheers! ;)

@AdMub commented Mar 1, 2026:

Hi @Sir-Sloth-The-Lazy, fantastic work getting this foundational PR started! I've been following the discussions in #49 closely, as I am planning to focus my GSoC proposal on extending this exact hierarchy to support the probabilistic and ensemble models (ARProbModel, Diffusion-LAM).

@joeloskarsson, regarding your comment on the pred_std handling: If the StepPredictor returns both next_state and pred_std, would you prefer the Forecaster to blindly pass both up to the ForecasterModule for loss computation, or should the Forecaster (specifically the ARForecaster subclass) actively use that pred_std during its unrolling loop to sample the next state for ensemble generation?

I want to make sure I accurately map out the data flow for the generative architectures in my proposal draft!

…r hierarchy

- Replace opaque argparse.Namespace with explicit keyword arguments in
  StepPredictor, BaseGraphModel, BaseHiGraphModel, GraphLAM, HiLAM,
  and HiLAMParallel __init__ methods
- Reorder methods in step_predictor.py: forward/expand_to_batch now
  appear before clamping methods
- Update all instantiation sites (train_model.py, test_training.py,
  test_prediction_model_classes.py) to pass explicit kwargs
- HiLAM helper methods (make_same/up/down_gnns) now use self.hidden_dim
  and self.hidden_layers instead of args parameter

Addresses review comments on PR mllam#208.
- Rename border to boundary in Forecaster
- Pass Forecaster object to ForecasterModule init instead of Predictor
- Remove inline imports in ForecasterModule
- Move loss-related pred_std logic fully into ForecasterModule
- Delete obsolete test_refactored_hierarchy.py
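A hedged before/after illustration of the Namespace-to-kwargs change described in this commit (the signature below is a guessed subset, not the PR's exact code):

import torch.nn as nn

# Before: opaque container, unclear which fields a model actually reads
#   def __init__(self, args: argparse.Namespace): ...

# After: explicit keyword arguments (illustrative subset)
class GraphLAM(nn.Module):
    def __init__(self, hidden_dim: int, hidden_layers: int, mesh_aggr: str = "sum"):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.hidden_layers = hidden_layers
        self.mesh_aggr = mesh_aggr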
@kshirajahere (Contributor):

Following the review thread and thinking ahead to probabilistic support (#62): would it make sense to formalize the interface as StepPredictor -> (pred_state, pred_std | None), then keep loss-weighting and mask handling in ForecasterModule (with boundary mask read directly from datastore)? @joeloskarsson That seems to reduce coupling and may make ensemble/probabilistic evaluation plumbing cleaner later.
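A sketch of that contract, with names assumed from this thread (not the PR's exact code):

from typing import Optional

import torch
import torch.nn as nn


class StepPredictor(nn.Module):
    def forward(
        self,
        prev_prev_state: torch.Tensor,  # (B, num_grid_nodes, d_f)
        prev_state: torch.Tensor,       # (B, num_grid_nodes, d_f)
        forcing: torch.Tensor,          # (B, num_grid_nodes, d_forcing)
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Deterministic predictors return (pred_state, None); probabilistic
        # ones also return pred_std. Loss weighting and boundary-mask
        # handling stay in ForecasterModule.
        raise NotImplementedError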

@joeloskarsson (Collaborator):

@AdMub regarding using pred_std during unrolling: while sampling using pred_std (just adding Gaussian noise with this std) would be the theoretically sound thing to do, it is a very bad idea in practice and just makes us end up with noisy forecasts. So I see no need for this, even optionally.
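(For concreteness, the sampling step being discussed would be roughly the following; variable names and toy shapes are illustrative, and as noted above it is discouraged:)

import torch

pred_mean = torch.zeros(2, 10, 3)       # (B, num_grid_nodes, d_f), toy shapes
pred_std = 0.1 * torch.ones(2, 10, 3)

# Perturb the deterministic one-step prediction with Gaussian noise scaled
# by the predicted std; over the unrolling loop this accumulates into
# noisy forecasts.
next_state = pred_mean + pred_std * torch.randn_like(pred_mean)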

@joeloskarsson (Collaborator):

@kshirajahere yes, I agree with you on this.

@Sir-Sloth-The-Lazy I interpret your review request to mean that you have made changes and I should have a look again? Could you leave a short comment on each of my requested changes, just saying how you fixed it (with a link to a commit if you want)? That makes it a lot easier for me to review again 🫶 Thanks!

@joeloskarsson (Collaborator):

While input here is appreciated, I think this PR is at a stage where we need to be very careful about what is suggested; otherwise we will never be able to merge this and will just keep iterating forever. Please separate: 1. bugs and regressions, which are good to fix before merging (@kshirajahere nicely pointed out a few things that I had missed, and these probably should be fixed before merging), and 2. suggestions for improvements that make the code more flexible or further change the overall design. I would recommend not bringing up suggestions of type 2 here, but rather raising those as follow-up issues. Better to get this merged and then have more small PRs improving things later.

…escription, and test coverage

- Re-export BaseGraphModel and BaseHiGraphModel from neural_lam/models/__init__.py
- Fix Forecaster.forward() return type to Optional[torch.Tensor] for pred_std
- Save args in save_hyperparameters so checkpoints are self-describing
- Pass args from train_model.py into ForecasterModule to carry architecture kwargs
- Add load_forecaster_module_from_checkpoint helper to train_model.py
- Thread metrics_watch and var_leads_metrics_watch through run_simple_training
- Add test_graph_lam_no_static_features using real GraphLAM to prove empty static tensor flows through GNN
@Sir-Sloth-The-Lazy (Contributor, Author):

@kshirajahere @joeloskarsson I have taken all the requested enhancements into account in the latest commits. I hope this addresses the concerns!

Comment thread neural_lam/models/module.py
Comment thread neural_lam/models/module.py
Comment thread neural_lam/models/module.py
# makes the checkpoint self-describing: it carries model, graph_name,
# hidden_dim, etc. so the caller can reconstruct the exact forecaster
# architecture from the checkpoint alone.
self.save_hyperparameters(ignore=["datastore", "forecaster"])
@Sir-Sloth-The-Lazy (Contributor, Author):

The placement is intentional. args is still bound when save_hyperparameters runs at line 85, so the namespace is captured. Running it earlier would capture the default values for loss/lr/etc. before the unpack at lines 56-68 has had a chance to override them; that is exactly the legacy-checkpoint regression the unpack block exists to fix (see the comment at lines 50-55). The ignore list is also required because datastore and forecaster aren't serializable as hparams.
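A simplified sketch of that ordering (parameter set reduced and partly assumed): save_hyperparameters captures the __init__ frame's locals by parameter name, so calling it after the unpack stores the legacy values rather than the signature defaults.

import pytorch_lightning as pl


class ForecasterModule(pl.LightningModule):
    def __init__(self, args=None, loss="wmse", lr=1e-3):
        super().__init__()
        if args is not None:  # legacy checkpoints pass a single Namespace
            loss = getattr(args, "loss", loss)
            lr = getattr(args, "lr", lr)
        # Called after the unpack: the reassigned locals (not the
        # defaults) are what end up in self.hparams and the checkpoint.
        self.save_hyperparameters()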

"""
return x.unsqueeze(0).expand(batch_size, -1, -1)

@abstractmethod
@Sir-Sloth-The-Lazy (Contributor, Author):

The return type is already explicit at line 92 — tuple[torch.Tensor, Optional[torch.Tensor]]. The other files in neural_lam/models/ (e.g. base_graph_model.py, interaction_net.py) use the same compact inline Returns: style as the current docstring, so reformatting just this one to NumPy style would make the module less consistent, not more. Happy to switch the whole module's style in a separate PR if there's a project preference, @joeloskarsson, but right now I would prefer a timely merge of this PR, as much depends on it!

Comment thread neural_lam/models/__init__.py Outdated
Comment on lines 2 to 3
from .base_graph_model import BaseGraphModel
from .base_hi_graph_model import BaseHiGraphModel
@Sir-Sloth-The-Lazy (Contributor, Author):

There's nothing to deprecate here: base_graph_model.py and base_hi_graph_model.py are still in neural_lam/models/, not in any neural_lam.models.graph subpackage. The current __init__.py already re-exports both names, so existing from neural_lam.models import BaseGraphModel imports continue to work. If a relocation lands in a future PR we can add a shim then, and at that point the right pattern is a module-level __getattr__ (PEP 562), not a top-level warnings.warn that fires on every package import regardless of whether the deprecated name is actually accessed.
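For reference, the PEP 562 shim would look roughly like this; the relocated path is purely hypothetical:

# neural_lam/models/__init__.py (hypothetical future shim)
import warnings


def __getattr__(name):
    # Runs only when `name` is not found normally, so the warning fires
    # solely for actual accesses of the deprecated alias.
    if name == "BaseGraphModel":
        warnings.warn(
            "BaseGraphModel has moved; import it from "
            "neural_lam.models.step_predictors.graph instead",
            DeprecationWarning,
            stacklevel=2,
        )
        from .step_predictors.graph.base import BaseGraphModel

        return BaseGraphModel
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")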

"graph_lam": GraphLAM,
"hi_lam": HiLAM,
"hi_lam_parallel": HiLAMParallel,
}
@Sir-Sloth-The-Lazy (Contributor, Author):

The proposed __all__ includes Forecaster and StepPredictor, but neither is currently imported in this __init__.py, so from neural_lam.models import * would fail with an AttributeError. We'd need to add those imports first (see the sketch below).

Beyond that, no other __init__.py in the project defines __all__, so adding it just here introduces an inconsistency. If the project wants explicit export lists for docs tooling, it should be a coordinated sweep across all __init__.py files. @joeloskarsson, I would love to hear your opinion on this 😁 I lean towards doing this in a future PR!
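For concreteness, a correct export list would need the imports first; the module paths here are illustrative, not the repo's actual layout:

# neural_lam/models/__init__.py (sketch)
from .forecaster import Forecaster          # illustrative paths; the real
from .step_predictor import StepPredictor   # modules may live elsewhere
from .graph_lam import GraphLAM
from .hi_lam import HiLAM
from .hi_lam_parallel import HiLAMParallel

__all__ = ["Forecaster", "StepPredictor", "GraphLAM", "HiLAM", "HiLAMParallel"]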

@Sir-Sloth-The-Lazy (Contributor, Author) commented Apr 28, 2026:

@GiGiKoneti Thanks for taking the time to review. I want to engage with these in good faith,
but on a careful read I'm having trouble seeing the reasoning behind several of them, so I'd
appreciate it if you could take another look 😊. To summarise where I'm getting stuck:


save_hyperparameters placement — Moving the call to before the args unpack would cause
self.hparams.loss/lr/etc. to capture the parameter defaults rather than the values from
args, which is exactly the legacy-checkpoint regression the unpack block is there to prevent.
Could you walk me through the scenario where the current placement causes a problem?

self.hparams.get("metrics_watch", []) — metrics_watch is a named parameter in
__init__ with a default of None (resolved to [] at line 76), and save_hyperparameters
captures every signature parameter except those explicitly ignored. So
self.hparams.metrics_watch is always present. Is there a code path where it isn't?

sync_dist=False, rank_zero_only=True in training_step — training_step runs on
every rank with a per-rank batch_loss. sync_dist=True all-reduces across ranks so the
logged value reflects the loss over the full effective batch. Switching to
rank_zero_only=True would log only rank 0's local slice. The deadlock concern that motivated
the change in aggregate_and_plot_metrics doesn't apply here because that call is inside
if is_global_zero with values pre-aggregated, whereas training_step is called on all
ranks. Could you clarify what failure mode you're trying to address?

forward return type — The annotation tuple[torch.Tensor, Optional[torch.Tensor]] is
already on the method, and the docstring already documents both outputs and their shapes. Was
there something specific you'd like changed beyond the docstring style? (Note that the
surrounding files in models/ use the compact inline Returns: style, not NumPy-style.)

__init__.py backward-compat shim — BaseGraphModel/BaseHiGraphModel haven't been
moved; they still live at neural_lam.models.base_graph_model /
neural_lam.models.base_hi_graph_model, and __init__.py re-exports both. There's no
neural_lam.models.graph subpackage, so from .graph import … would fail. What relocation
are you assuming?

__all__ export list — The proposed list includes Forecaster and StepPredictor, but
those aren't currently imported in __init__.py, so from neural_lam.models import * would
raise AttributeError. Also, no other __init__.py in the project defines __all__, so
this would be a one-off inconsistency. Is there a specific tooling need driving this?

Extracting per_var_std init into a helper — Happy to do this if you feel strongly @joeloskarsson (would also ask you for your opinion please 🙏 on this, sorry for the hassle), but
the block is used exactly once and already has a leading comment. I'd lean toward leaving it
inline unless there's a concrete readability or reuse benefit you have in mind.


Happy to make any of these changes if I'm missing something, just want to make sure we're
aligned on the reasoning before churning the PR.😁

@GiGiKoneti (Contributor):

Hey @Sir-Sloth-The-Lazy, thanks for the detailed breakdown!

You're absolutely right on the training_step DDP sync, the save_hyperparameters placement for the legacy unpack, and the project-wide docstring style. I agree we should prioritize consistency with the existing repo, so I'll drop those suggestions.

To answer your question on metrics_watch:

So self.hparams.metrics_watch is always present. Is there a code path where it isn't?

Yes, there's exactly one: load_from_checkpoint("legacy_model.ckpt").

When a user loads a checkpoint trained on an older version of neural-lam (pre-refactor), PTL restores the hparams from the dictionary saved in the file. Because old models were saved before metrics_watch was added to the signature, that key will be missing from self.hparams. Accessing it via attribute will trigger an AttributeError and crash the evaluation loop. Using .get("metrics_watch", []) acts as a simple backward-compatibility shim for those legacy models.

Also, the self.trainer check is just to prevent crashes if users run the module in a pure PyTorch script or notebook w/o a Lightning Trainer attached (where self.trainer is None).

Other than that, the refactor is looking great and I'm fully on board with the architecture changes! Let's get this merged once those robustness checks are in. Kudos to you.

@Sir-Sloth-The-Lazy (Contributor, Author) commented Apr 28, 2026:

@GiGiKoneti Did a quick empirical check on this 😄. There's already an existing test at tests/test_prediction_model_classes.py::test_forecaster_module_old_checkpoint that loads a checkpoint in the legacy {"args": Namespace(...)} format and asserts loaded_model.hparams.loss, .lr, .val_steps_to_log work via attribute access (lines 299-301). It passes.

I also ran a one-off probe specifically for metrics_watch on a synthetic legacy checkpoint:

>>> loaded.hparams.keys()
['args', 'config', 'create_gif', 'loss', 'lr', 'metrics_watch',
 'n_example_pred', 'restore_opt', 'val_steps_to_log', 'var_leads_metrics_watch']
>>> loaded.hparams.metrics_watch
[]
>>> "metrics_watch" in loaded.hparams
True

The reason no AttributeError fires: self.hparams isn't deserialized from the checkpoint dict — load_from_checkpoint calls __init__ again, the unpack at lines 56-68 pulls values out of args, and save_hyperparameters rebuilds self.hparams from the current __init__ signature. So self.hparams.metrics_watch is always present regardless of what shape the on-disk hparams had.

So I'd like to keep the attribute access — the .get() shim would be guarding against a failure that the existing mechanic already prevents. On self.trainer: same conclusion, on_validation_epoch_end is a Lightning hook only invoked from a Trainer, so self.trainer is always bound there.

Let me know if you spot something I've missed!👌🏻

@leifdenby (Member):

First @Sir-Sloth-The-Lazy - thank you so much for sticking with this and all your hard work. I have been re-reading through your code and re-reminding myself of the main change here. That prompted me to dig out some code I wrote to parse the class inheritance tree and render it out. I can add the script I used to generate the visualisation below if people think that is useful (I was thinking it could be used for generating documentation).

The main change seems to be the introduction of the abstract base classes Forecaster and StepPredictor. Those seem like really good ideas to me (as we've discussed many times). I will continue my review tomorrow and focus on whether the functionality that currently sits in each is defined where I would expect, based on this class hierarchy.

Hope this is along the lines of what you were hoping for from my angle on reviewing this substantial piece of work :)

[Visualisation: class-inheritance trees for main vs. refactor/model-class-hierarchy]

@GiGiKoneti (Contributor):

@Sir-Sloth-The-Lazy You're right, I missed that load_from_checkpoint re-invokes __init__ rather than directly deserializing hparams, so the current signature defaults always populate. Thanks for walking me through the mechanics, and for the patience; apologies for taking up your time. I have learnt a lot from your work. Nothing else from my side: the refactor is solid. 👍

@Sir-Sloth-The-Lazy (Contributor, Author):

@leifdenby thank you so much for your kind words and for taking the time to review this work!

Hope this is along the lines of what you were hoping for from my angle on reviewing this substantial piece of work :)

Yes, absolutely! Looking forward to it 😄

@Sir-Sloth-The-Lazy (Contributor, Author):

@GiGiKoneti Please don't apologise for reviewing my work and voicing your doubts. As you said, you learnt something from our conversation; as long as that is happening, I would be really happy to answer any more of your questions! Thank you for your time in making this work better! 😄

@leifdenby (Member) left a comment:

Looks great! I have only one minor change I think we should make, and a suggestion that wouldn't be a lot of work:

  1. I think we should put clear docstrings on every nn.Module.forward(...) method that define not only what the function does (its role and responsibility) but also detail the shape of every call argument (and what each dimension means).
  2. Given that this PR introduces clearer separation between classes by building a hierarchy on base classes, I think we should consider reorganising the Python module structure too, so it more closely reflects what the different torch.nn.Module-derived classes actually do and where they sit in the hierarchy. Currently, neural_lam/models/ has a flat structure where all modules sit next to each other, which makes it hard to reason about where to look. I would suggest a layout like this:
neural_lam/models/
  __init__.py
  module.py                         # ForecasterModule
  forecasters/
    __init__.py
    base.py                         # Forecaster
    autoregressive.py               # ARForecaster
  step_predictors/
    __init__.py
    base.py                         # StepPredictor
    graph/
      __init__.py
      base.py                       # BaseGraphModel
      hierarchical.py               # BaseHiGraphModel
      graph_lam.py                  # GraphLAM
      hi_lam.py                     # HiLAM
      hi_lam_parallel.py            # HiLAMParallel

def predicts_std(self) -> bool:
    return self.predictor.predicts_std

def forward(
@leifdenby (Member):

I think it would be good to add numpy-style docstrings here, detailing what each argument is and its shape, because this function is the main entrypoint to the nn.Module, and having a clear interface defined here will make future development easier.

Actually, I would advocate that we start using jaxtyping with beartype to define the meaning of the dims. But we should do that in a separate PR so as not to hold up this work further; here we should at least get the docstrings in.
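For that future PR, a rough sketch of what jaxtyping + beartype annotations could look like (the decorator usage and dimension names are an assumption about how we'd adopt it, not code from this PR):

import torch
from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor


class ARForecaster(torch.nn.Module):
    @jaxtyped(typechecker=beartype)
    def forward(
        self,
        init_states: Float[Tensor, "batch 2 num_grid_nodes d_f"],
        forcing: Float[Tensor, "batch pred_steps num_grid_nodes d_forcing"],
    ) -> Float[Tensor, "batch pred_steps num_grid_nodes d_f"]:
        # jaxtyping checks that dimension names agree across arguments
        # and the return value at call time.
        ...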

@joeloskarsson (Collaborator):

Agree that numpy-style docstrings for these would be good. Agree that jaxtyping could be a nice addition, but should be saved for a future PR.

@joeloskarsson (Collaborator):

There are nice docstrings added now. For the tensor shapes, however, I would suggest being consistent with either 1) always writing out what each dimension is in the description of each tensor (this is mostly done now, but not for all of them), or 2) collecting all dimension explanations in one place in the docstring, so that you first list the tensors and their shapes and can then look up what each dimension means in that one place. (A sketch of option 2 follows below.)
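As an illustration of option 2 (the names and shapes follow this PR's forward methods, but the exact docstring is hypothetical):

def forward(self, init_states, forcing):
    """Roll out a full forecast.

    Dimensions: B = batch size, num_grid_nodes = interior grid nodes,
    d_f = state features, d_forcing = forcing features.

    Parameters
    ----------
    init_states : torch.Tensor
        Shape (B, 2, num_grid_nodes, d_f).
    forcing : torch.Tensor
        Shape (B, pred_steps, num_grid_nodes, d_forcing).

    Returns
    -------
    torch.Tensor
        Shape (B, pred_steps, num_grid_nodes, d_f).
    """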

super().__init__(args, config=config, datastore=datastore)
def __init__(
    self,
    config: NeuralLAMConfig,
@leifdenby (Member):

great work on making all these architecture parameters explicit rather than passing them through args. Is config: NeuralLAMConfig required here? Maybe we shouldn't propagate that below the ForecasterModule, or at least not all the way down to the graph-based models, I think.

@joeloskarsson (Collaborator):

I guess this was for the clamping options. Now they are fed separately, which seems good.

…ple config

Address PR mllam#208 review comments from @leifdenby:

- Reorganise neural_lam/models to mirror the new class hierarchy:
  forecasters/{base,autoregressive}.py, step_predictors/base.py, and
  step_predictors/graph/{base,hierarchical,graph_lam,hi_lam,hi_lam_parallel}.py.
  ForecasterModule moves to neural_lam/models/module.py. The package
  __init__.py re-exports the public API so callers can use
  `from neural_lam.models import ARForecaster, GraphLAM, ...` unchanged.

- Convert the four forward(...) docstrings on the nn.Module hierarchy
  (Forecaster, ARForecaster, StepPredictor, BaseGraphModel) to
  numpy-style with Parameters/Returns sections describing each
  argument's shape and the meaning of every dimension.

- Decouple NeuralLAMConfig from the predictor stack. StepPredictor and
  the graph-based predictors no longer accept `config: NeuralLAMConfig`
  — they take `output_clamping_lower` / `output_clamping_upper` dicts
  directly. NeuralLAMConfig now stops at ForecasterModule. Construction
  sites in train_model.py and the tests resolve the two dicts from
  config.training.output_clamping at the call site.
@Sir-Sloth-The-Lazy (Contributor, Author):

I have made the changes as asked, @leifdenby @joeloskarsson. Looking forward to your review.

…turn type

- Drop the unreachable (d_f,) shape from Forecaster.forward's pred_std
  docstring. The only concrete implementation (ARForecaster) returns
  (B, pred_steps, num_grid_nodes, d_f) or None — never a per-feature
  vector — so the wider contract was misleading.

- Fix ARForecaster.forward return annotation:
  tuple[torch.Tensor, torch.Tensor] -> tuple[torch.Tensor, Optional[torch.Tensor]],
  reflecting that pred_std is None when the wrapped predictor does not
  output an std.
Comment thread neural_lam/models/forecasters/base.py Outdated
----------
init_states : torch.Tensor
Shape ``(B, 2, num_grid_nodes, d_f)``. The two initial states
``[X_{t-1}, X_t]`` used to seed the forecast. Dims: ``B`` is
@joeloskarsson (Collaborator):

I have never heard the terminology "seed the forecast", so maybe just writing something like "start the forecast from" is better to avoid confusion.

@joeloskarsson (Collaborator):

Sadly, there are again some conflicts to resolve. I have discussed with the other maintainers that we now need to pause merging other PRs so this one can make it in, so we don't have to keep resolving conflicts here.

Convert every remaining informal/Google-style docstring in models/,
interaction_net, metrics, utils, vis, datastore/base, and
custom_loggers to numpy-style (Parameters\n---------- /
Returns\n-------). Also fill in missing inline Dims: explanations
for all tensor parameters and return values so dimension letters are
documented consistently throughout.
@Sir-Sloth-The-Lazy (Contributor, Author):

@leifdenby and @joeloskarsson, I have tried to add numpy-style docstrings across the codebase. Please 🙏 give your feedback on whether we should go forward with this or revert to the way it was before this commit, keeping only the changes suggested by Joel. 😄


Labels

enhancement (New feature or request)


Development

Successfully merging this pull request may close these issues.

Refactor model class hierarchy