transpose forcing data in WeatherDataset __getitem__#556
transpose forcing data in WeatherDataset __getitem__#556Anushka1324 wants to merge 4 commits intomllam:mainfrom
Conversation
|
Hi! I see there’s already an active PR (#556) addressing this. I’d like to contribute by improving the robustness of the solution:
Let me know if that would be helpful — happy to extend the current implementation. |
sadamov
left a comment
There was a problem hiding this comment.
The create_dataarray_from_tensor fix duplicates #309, which addresses the same hardcoded .state_feature bug.
The one non-duplicate contribution is the transpose in __getitem__, but it contains a bug: please scope the PR title and description to that change only if you want to continue.
| da_forcing_windowed = da_forcing_windowed.transpose( | ||
| "time", "grid_index", "forcing_feature_windowed" | ||
| ) |
There was a problem hiding this comment.
When self.da_forcing is None, _build_item_dataarrays produces an empty DataArray with dim "forcing_feature" (line 463). The transpose here requests "forcing_feature_windowed", which only exists after the .stack() on line 455 (i.e. when forcing data is present). xarray will raise ValueError on the no-forcing path.
| da_forcing_windowed = da_forcing_windowed.transpose( | |
| "time", "grid_index", "forcing_feature_windowed" | |
| ) | |
| if "forcing_feature_windowed" in da_forcing_windowed.dims: | |
| da_forcing_windowed = da_forcing_windowed.transpose( | |
| "time", "grid_index", "forcing_feature_windowed" | |
| ) | |
| else: | |
| da_forcing_windowed = da_forcing_windowed.transpose( | |
| "time", "grid_index", "forcing_feature" | |
| ) |
64b8cc4 to
d5fca7a
Compare
|
@Anushka1324 please fix the pre-commits |
d5fca7a to
1f3dbb0
Compare
done |
|
sadamov
left a comment
There was a problem hiding this comment.
CHANGELOG entry missing.
| # For forcing feature dimension was renamed in _build_item_dataarrays | ||
| if "forcing_feature_windowed" in da_forcing_windowed.dims: | ||
| da_forcing_windowed = da_forcing_windowed.transpose( | ||
| "time", "grid_index", "forcing_feature_windowed" | ||
| ) | ||
| else: | ||
| da_forcing_windowed = da_forcing_windowed.transpose( | ||
| "time", "grid_index", "forcing_feature" | ||
| ) |
There was a problem hiding this comment.
The if/else on dim name is fragile — a name change silently falls through to the wrong branch. xarray ... handles both the windowed and empty-forcing cases:
| # For forcing feature dimension was renamed in _build_item_dataarrays | |
| if "forcing_feature_windowed" in da_forcing_windowed.dims: | |
| da_forcing_windowed = da_forcing_windowed.transpose( | |
| "time", "grid_index", "forcing_feature_windowed" | |
| ) | |
| else: | |
| da_forcing_windowed = da_forcing_windowed.transpose( | |
| "time", "grid_index", "forcing_feature" | |
| ) | |
| da_forcing_windowed = da_forcing_windowed.transpose( | |
| "time", "grid_index", ... | |
| ) |
Describe your changes
This PR addresses a bug where raw Xarray DataArrays were converted to PyTorch tensors without enforcing a consistent dimension order in the WeatherDataset class.
Changes:
Updated the getitem method to explicitly transpose state and forcing DataArrays to the expected (time, grid_index, feature) dimension order before tensor conversion.
Added a conditional check to safely handle cases where forcing data is absent, preventing a ValueError during transposition when the forcing_feature_windowed dimension does not exist.
Closes #536
Type of change
Checklist before requesting a review
pullwith--rebaseoption if possible).Checklist for reviewers
Each PR comes with its own improvements and flaws. The reviewer should check the following:
Author checklist after completed review
reflecting type of change (add section where missing):
Checklist for assignee