-
Notifications
You must be signed in to change notification settings - Fork 12
Multi stage pipeline parallelism support #418
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
919c6bd
68b02aa
c9677b1
f885ca8
a915805
bc7089b
07ef847
7f0518b
e7cb524
56660fb
d24c7bc
534bf5a
30bb0ac
8aa7302
aaafde3
0e527a9
a8913e0
92a0c2c
ad9dfb3
6080f19
b88a060
32206df
292fcac
2b821aa
7537637
3a85058
d0499d1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,8 @@ | |
| from torch.optim import Optimizer | ||
| from torch.optim.lr_scheduler import LRScheduler | ||
|
|
||
| from modalities.optimizers.optimizer_list import OptimizersList | ||
|
|
||
|
|
||
| class StatefulComponents(Enum): | ||
| MODEL = "model" | ||
|
|
@@ -34,15 +36,18 @@ class AppState(Stateful): | |
| https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html | ||
| """ | ||
|
|
||
| def __init__(self, model: nn.Module, optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None): | ||
| def __init__( | ||
| self, model: nn.Module | list[nn.Module], optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None | ||
| ): | ||
| """Initializes the AppState object. | ||
|
|
||
| Args: | ||
| model (nn.Module): The model can be either a non-sharded model, FSDP1 or FSDP2 model. | ||
| model (nn.Module | list[nn.Module]): The model or model parts can be either | ||
| a non-sharded model, FSDP1 or FSDP2 model. | ||
| optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer. | ||
| lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None. | ||
| """ | ||
| self._model = model | ||
| self._model_parts = list(model) if isinstance(model, list) else [model] | ||
| self._optimizer = optimizer | ||
| self._lr_scheduler = lr_scheduler | ||
| self._is_loaded = False | ||
|
|
@@ -56,8 +61,8 @@ def is_loaded(self) -> bool: | |
| return self._is_loaded | ||
|
|
||
| @property | ||
| def model(self) -> nn.Module: | ||
| return self._model | ||
| def model_parts(self) -> list[nn.Module]: | ||
| return self._model_parts | ||
|
|
||
| @property | ||
| def optimizer(self) -> Optimizer: | ||
|
|
@@ -153,15 +158,18 @@ def get_state_dict(app_state: AppState) -> dict[str, Any]: | |
| class ModelStateRetriever(StateRetrieverIF): | ||
| @staticmethod | ||
| def get_state_dict(app_state: AppState) -> dict[str, Any]: | ||
| """Returns the state dict of the model in the AppState object. | ||
| """Returns the flattened state dicts of the model parts in the AppState object. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. flattened keys or tensors?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I guess flattened keys. Though, I'm not sure if I would call it that. We are mapping from a list of dicts to a single dict. Flattened keys sounds more like flattening a dict of dicts. |
||
|
|
||
| Args: | ||
| app_state (AppState): The app_state object containing the model. | ||
|
|
||
| Returns: | ||
| dict[str, Any]: The state dict of the model in the AppState object. | ||
| """ | ||
| return get_model_state_dict(model=app_state.model) | ||
| state_dicts = list(map(get_model_state_dict, app_state.model_parts)) | ||
| state_dict_keys = sum((list(sd.keys()) for sd in state_dicts), []) | ||
| assert len(state_dict_keys) == len(set(state_dict_keys)), "State dict keys are not unique across model parts." | ||
| return {k: v for sd in state_dicts for k, v in sd.items()} | ||
|
|
||
| @staticmethod | ||
| def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None: | ||
|
|
@@ -171,7 +179,8 @@ def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None: | |
| app_state (AppState): The app_state object containing the model. | ||
| state_dict (dict[str, Any]): The state dict to load into the model. | ||
| """ | ||
| set_model_state_dict(model=app_state.model, model_state_dict=state_dict, options=StateDictOptions(strict=False)) | ||
| for model in app_state.model_parts: | ||
| set_model_state_dict(model=model, model_state_dict=state_dict, options=StateDictOptions(strict=False)) | ||
|
|
||
|
|
||
| class OptimizerStateRetriever(StateRetrieverIF): | ||
|
|
@@ -185,13 +194,17 @@ def get_state_dict(app_state: AppState) -> dict[str, Any]: | |
| Returns: | ||
| dict[str, Any]: The state dict of the optimizer in the AppState object. | ||
| """ | ||
| sd = get_optimizer_state_dict( | ||
| model=app_state.model, | ||
| optimizers=app_state.optimizer, | ||
| # NOTE: Flattening is required for pipeline parallelism to work correctly. | ||
| # see https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/components/checkpoint.py#L193-L214 | ||
| options=StateDictOptions(flatten_optimizer_state_dict=True), | ||
| ) | ||
| if isinstance(app_state.optimizer, OptimizersList): | ||
| sd = app_state.optimizer.state_dict() | ||
| else: | ||
| assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer." | ||
| sd = get_optimizer_state_dict( | ||
| model=app_state.model_parts[0], | ||
| optimizers=app_state.optimizer, | ||
| # NOTE: Flattening is required for pipeline parallelism to work correctly. | ||
| # see https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/components/checkpoint.py#L193-L214 | ||
| options=StateDictOptions(flatten_optimizer_state_dict=True), | ||
| ) | ||
|
Comment on lines
+199
to
+207
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we remove this, since in case of PP we now always have an optimizer list which takes care of the flattening?
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. sounds good! |
||
| return sd | ||
|
|
||
| @staticmethod | ||
|
|
@@ -202,12 +215,16 @@ def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None: | |
| app_state (AppState): The app_state object containing the optimizer. | ||
| state_dict (dict[str, Any]): The state dict to load into the optimizer. | ||
| """ | ||
| set_optimizer_state_dict( | ||
| model=app_state.model, | ||
| optimizers=app_state.optimizer, | ||
| optim_state_dict=state_dict, | ||
| options=StateDictOptions(flatten_optimizer_state_dict=True), | ||
| ) | ||
| if isinstance(app_state.optimizer, OptimizersList): | ||
| app_state.optimizer.load_state_dict(state_dict) | ||
| else: | ||
| assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer." | ||
| set_optimizer_state_dict( | ||
| model=app_state.model_parts[0], | ||
| optimizers=app_state.optimizer, | ||
| optim_state_dict=state_dict, | ||
| options=StateDictOptions(flatten_optimizer_state_dict=True), | ||
| ) | ||
|
|
||
|
Comment on lines
+221
to
228
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Given your comment above, I assume the else case can also be removed here? |
||
|
|
||
| class LRSchedulerStateRetriever(StateRetrieverIF): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| from typing import Any, Type, TypeVar | ||
|
|
||
| from pydantic import BaseModel | ||
| from pydantic import AliasChoices, BaseModel | ||
| from pydantic.fields import FieldInfo | ||
|
|
||
| from modalities.registry.registry import Registry | ||
| from modalities.util import print_rank_0 | ||
|
|
@@ -164,30 +165,53 @@ def _instantiate_component_config(self, component_key: str, variant_key: str, co | |
| config_dict=config_dict, | ||
| component_config_type=component_config_type, | ||
| ) | ||
| comp_config = component_config_type(**config_dict, strict=True) | ||
| comp_config = component_config_type.model_validate(config_dict, extra="forbid") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good catch! |
||
| return comp_config | ||
|
|
||
| def _assert_valid_config_keys( | ||
| self, component_key: str, variant_key: str, config_dict: dict, component_config_type: Type[BaseModelChild] | ||
| ) -> None: | ||
| required_keys = [] | ||
| optional_keys = [] | ||
| for key, field in component_config_type.model_fields.items(): | ||
| # Collect required and optional keys, including aliases if defined. | ||
| required_keys: list[str] = [] | ||
| optional_keys: list[str] = [] | ||
| # Map aliases to canonical field names for clearer error messages. | ||
| alias_to_field: dict[str, str] = {} | ||
|
|
||
| for field_name, field in component_config_type.model_fields.items(): | ||
| names_for_field = self._parse_str_aliases(alias_to_field, field_name, field) | ||
| if field.is_required(): | ||
| required_keys.append(key) | ||
| required_keys.extend(names_for_field) | ||
| else: | ||
| optional_keys.append(key) | ||
| optional_keys.extend(names_for_field) | ||
|
|
||
| invalid_keys = [] | ||
| for key in config_dict.keys(): | ||
| if key not in required_keys and key not in optional_keys: | ||
| invalid_keys.append(key) | ||
| all_valid_keys = set(required_keys) | set(optional_keys) | ||
|
|
||
| invalid_keys = [key for key in config_dict.keys() if key not in all_valid_keys] | ||
| if len(invalid_keys) > 0: | ||
| message = f"Invalid keys {invalid_keys} for config `{component_key}.{variant_key}`" | ||
| message += f" of type {component_config_type}:\n{config_dict}\n" | ||
| message += f"Required keys: {required_keys}\nOptional keys: {optional_keys}" | ||
| if alias_to_field: | ||
| message += f"Alias to field mapping: {alias_to_field}\n" | ||
| message += f"Required keys (including aliases): {required_keys}\n" | ||
| message += f"Optional keys (including aliases): {optional_keys}\n" | ||
| raise ValueError(message) | ||
|
|
||
| def _parse_str_aliases(self, alias_to_field: dict[str, str], field_name: str, field: FieldInfo) -> set[str]: | ||
| names_for_field = {field_name} | ||
| if field.alias and field.alias != field_name: | ||
| names_for_field.add(field.alias) | ||
| alias_to_field[field.alias] = field_name | ||
| if field.validation_alias and field.validation_alias != field_name: | ||
| if isinstance(field.validation_alias, str): | ||
| names_for_field.add(field.validation_alias) | ||
| alias_to_field[field.validation_alias] = field_name | ||
| elif isinstance(field.validation_alias, AliasChoices): | ||
| for alias in field.validation_alias.choices: | ||
| if isinstance(alias, str): | ||
| names_for_field.add(alias) | ||
| alias_to_field[alias] = field_name | ||
| return names_for_field | ||
|
|
||
| def _instantiate_component(self, component_key: str, variant_key: str, component_config: BaseModel) -> Any: | ||
| component_type: Type = self.registry.get_component(component_key, variant_key) | ||
| component_config_dict = self._base_model_to_dict(component_config) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think creating a new list here is safer in case an outside context accidentally changes the input list.