Draft
Changes from all commits (26 commits)
feabe5a  refactor: removed unnecessary all-reduce ops and improved accuracy of…  (le1nux, Dec 6, 2025)
a4775ad  chore: added documentation and renamed pytorch rms norm key  (le1nux, Dec 6, 2025)
719e35e  feat: added timestamp and dtype to debugged model for input/output ac…  (le1nux, Dec 6, 2025)
03354c1  feat: steppable component can now perform backward pass and optimizer…  (le1nux, Dec 6, 2025)
9e661dc  feat: added fused and foreach options to Adam and AdamW optimizers  (le1nux, Dec 13, 2025)
3518d00  refactor: profilers are now components  (le1nux, Dec 19, 2025)
02e8fdd  feat: logger outputs now rank info  (le1nux, Dec 19, 2025)
37b25d8  refactor: step information in profiling now part of the config instea…  (le1nux, Dec 19, 2025)
52924ea  refactor: added new profiling setup to the profiling tutorial's config  (le1nux, Dec 19, 2025)
3cfa305  refactor: experiments_root_path now passed in from outside  (le1nux, Dec 21, 2025)
361ddc5  feat: profiling now available also in training loop  (le1nux, Dec 21, 2025)
68dd9d2  feat: added memory profiling to kernel profiler  (le1nux, Dec 22, 2025)
e4fe4b0  refactor: added experiments_root_path to warmstart API and improved e…  (le1nux, Dec 29, 2025)
fbab937  refactor: refactored wamstart tutorial scripts  (le1nux, Dec 29, 2025)
6b359c9  chore: Merge remote-tracking branch 'refs/remotes/origin/main'  (le1nux, Dec 30, 2025)
93bd721  chore: Merge branch 'main' into monitoring_improvements  (le1nux, Dec 30, 2025)
6bef8b0  fix: HSDP was not applied at all due to wrong condition check  (le1nux, Jan 4, 2026)
a400a7c  refactor: allow data_parallel_replicate_degree to be -1 for auto-calc…  (le1nux, Jan 6, 2026)
6d0e864  chore: improved device mesh logging  (le1nux, Jan 16, 2026)
b893532  fix: in case of tp, we DP_SHARD > 1. Fixed the validation logic accor…  (le1nux, Jan 16, 2026)
85388cd  fix: tp can now be used with dp_shard or dp_replicate  (le1nux, Jan 17, 2026)
eb747fd  chore: improved tokenizer vocabulary warning  (le1nux, Jan 23, 2026)
6b058ae  chore: Merge remote-tracking branch 'refs/remotes/origin/main'  (le1nux, Jan 23, 2026)
9420c0b  chore: removed fixme since it was invalid  (le1nux, Jan 23, 2026)
0e153ef  chore: Merge branch 'main' into monitoring_improvements  (le1nux, Jan 23, 2026)
212bd12  fix: fixed merge conflict bug  (le1nux, Jan 25, 2026)
5 changes: 4 additions & 1 deletion .gitignore
@@ -172,4 +172,7 @@ config_files/instruction_tuning
data/lorem_ipsum_instruct.jsonl
tutorials/scaling_up/logs*
tutorials/scaling_up/experiments_old/*
results/*
results/*
tutorials/einsum_transformer/experiments/*
tutorials/warmstart/experiments/*

145 changes: 70 additions & 75 deletions src/modalities/__main__.py
@@ -54,10 +54,10 @@ def main() -> None:
help="Path to the YAML training config file.",
)
@click.option(
"--test_comm",
is_flag=True,
default=False,
help="If set, run a communication test before training.",
"--experiments_root_path",
type=click_pathlib.Path(exists=True),
required=True,
help="Path to the root directory where experiment folders will be created.",
)
@click.option(
"--experiment_id",
@@ -71,61 +71,51 @@ def main() -> None:
default=None,
help="Optional path to a folder where error logs will be written.",
)
@click.option(
"--test_comm",
is_flag=True,
default=False,
help="If set, run a communication test before training.",
)
def CMD_entry_point_run_modalities(
config_file_path: Path,
test_comm: bool = False,
experiments_root_path: Path,
experiment_id: Optional[str] = None,
error_log_folder: Optional[Path] = None,
test_comm: bool = False,
):
"""Entrypoint to run the model training.

Args:
config_file_path (Path): Path to the YAML training config file.
test_comm (bool): If set, run a communication test before training.
experiments_root_path (Path): Path to the root directory where experiment folders will be created.
experiment_id (Optional[str]): Optional experiment ID to use for this run.
If not provided it will be generated. Default is None.
error_log_folder (Optional[Path]): Optional path to a folder where error logs will be written.
test_comm (bool): If set, run a communication test before training.
"""

def _format_exception_as_json(e: Exception, environment: dict[str, Any]) -> str:
# Format an exception into a structured JSON string with error message, type, and stack trace.
error = {
"error": str(e),
"type": type(e).__name__,
"stacktrace": traceback.format_exception(type(e), e, e.__traceback__),
}

return json.dumps({"environment": environment, "error": error}, indent=2)

try:
with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl):
if test_comm:
print_rank_0("Running communication test...")
run_communication_test()
print_rank_0("Communication test succeeded.")

main_obj = Main(config_file_path, experiment_id=experiment_id)
main_obj = Main(config_file_path, experiments_root_path=experiments_root_path, experiment_id=experiment_id)
components = main_obj.build_components(components_model_type=TrainingComponentsInstantiationModel)
main_obj.run(components)
except Exception as e:
if error_log_folder is not None:
environment = {
"rank": int(os.environ["RANK"] if "RANK" in os.environ else -1),
"local_rank": int(os.environ["LOCAL_RANK"] if "LOCAL_RANK" in os.environ else -1),
"world_size": int(os.environ["WORLD_SIZE"] if "WORLD_SIZE" in os.environ else -1),
"hostname": socket.gethostname(),
}
error_log_folder = (
error_log_folder / f"error_logs_{environment['hostname']}_{environment['local_rank']}.log"
)
error_log_folder.parent.mkdir(parents=True, exist_ok=True)
with open(error_log_folder, "w", encoding="utf-8") as f:
f.write(_format_exception_as_json(e, environment))

raise RuntimeError(f"An error occurred while running the training: {e}. ") from e
_exception_handling(e, error_log_folder)


@main.command(name="warmstart")
@click.option(
"--experiments_root_path",
type=click_pathlib.Path(exists=True),
required=True,
help="Path to the root directory where experiment folders will be created.",
)
@click.option(
"--config_file_path",
type=click_pathlib.Path(exists=True),
@@ -138,10 +128,22 @@ def _format_exception_as_json(e: Exception, environment: dict[str, Any]) -> str:
required=True,
help="Path to the file containing the model and optimizer checkpoint paths from the last successful checkpoint.",
)
def CMD_entry_point_warmstart_modalities(config_file_path: Path, last_checkpoint_info_file_path: Path):
@click.option(
"--error_log_folder",
type=click_pathlib.Path(),
default=None,
help="Optional path to a folder where error logs will be written.",
)
def CMD_entry_point_warmstart_modalities(
experiments_root_path: Path,
config_file_path: Path,
last_checkpoint_info_file_path: Path,
error_log_folder: Optional[Path] = None,
):
"""Entrypoint to run the model warmstart.

Args:
experiments_root_path (Path): Path to the root directory where experiment folders will be created.
config_file_path (Path): Path to the YAML warmstart config file.
last_checkpoint_info_file_path (Path): Path to the file containing the model and
optimizer checkpoint paths from the last successful checkpoint.
@@ -159,10 +161,15 @@ def get_last_checkpoint_resolver_fun(var_name: str, last_checkpoint_info_file_pa
get_last_checkpoint_resolver_fun, last_checkpoint_info_file_path=last_checkpoint_info_file_path
)
}
with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl):
main_obj = Main(config_file_path, additional_resolver_funs=resolver_funs)
components = main_obj.build_components(components_model_type=TrainingComponentsInstantiationModel)
main_obj.run(components)
try:
with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl):
main_obj = Main(
config_file_path, experiments_root_path=experiments_root_path, additional_resolver_funs=resolver_funs
)
components = main_obj.build_components(components_model_type=TrainingComponentsInstantiationModel)
main_obj.run(components)
except Exception as e:
_exception_handling(e, error_log_folder)


@main.command(name="generate_text")
@@ -705,54 +712,42 @@ def profile():
required=True,
help="Path to the experiment output directory.",
)
@click.option(
"--num_wait_steps",
type=int,
default=1,
show_default=True,
help="Number of wait steps to skip in profiling.",
)
@click.option(
"--num_warmup_steps",
type=int,
default=1,
show_default=True,
help="Number of warmup steps to skip in profiling. Already recording but dropping the data.",
)
@click.option(
"--num_measurement_steps",
type=int,
default=3,
show_default=True,
help="Number of steps to measure during profiling.",
)
@click.option(
"--profiled_ranks",
type=str,
default="0",
help="Comma-separated list of profiled ranks (must not have spaces), e.g. --profiled_ranks '2,4,8'",
)
def CMD_entry_point_run_train_step_profiler(
config_file_path: Path,
experiment_root_path: Path,
num_wait_steps: int,
num_warmup_steps: int,
num_measurement_steps: int,
profiled_ranks: str,
):
"""Run train step profiler and write result to JSON if RANK=0."""
profiled_ranks_list = [int(i) for i in profiled_ranks.split(",")] if profiled_ranks != "" else [0]
logger.info(f"Running distributed profiling on ranks {profiled_ranks_list}")

ModalitiesProfilerStarter.run_distributed(
config_file_path=config_file_path,
num_measurement_steps=num_measurement_steps,
num_wait_steps=num_wait_steps,
num_warmup_steps=num_warmup_steps,
experiment_root_path=experiment_root_path,
profiled_ranks=profiled_ranks_list,
)
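For reference, the three step-count options follow the usual profiler phases: wait, then warmup, then active measurement. The sketch below maps the CLI defaults onto `torch.profiler.schedule`; this mapping is an assumption, since the actual wiring inside `ModalitiesProfilerStarter` is not shown in this diff.

```python
# Reference sketch only; assumes the CLI step counts behave like
# torch.profiler's wait/warmup/active phases.
import torch

schedule = torch.profiler.schedule(
    wait=1,    # --num_wait_steps: steps skipped entirely
    warmup=1,  # --num_warmup_steps: recorded, but the data is dropped
    active=3,  # --num_measurement_steps: steps kept as the measurement
)
```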


def _format_exception_as_json(e: Exception, environment: dict[str, Any]) -> str:
# Format an exception into a structured JSON string with error message, type, and stack trace.
error = {
"error": str(e),
"type": type(e).__name__,
"stacktrace": traceback.format_exception(type(e), e, e.__traceback__),
}
return json.dumps({"environment": environment, "error": error}, indent=2)


def _exception_handling(e: Exception, error_log_folder: Path | None):
if error_log_folder is not None:
environment = {
"rank": int(os.environ["RANK"] if "RANK" in os.environ else -1),
"local_rank": int(os.environ["LOCAL_RANK"] if "LOCAL_RANK" in os.environ else -1),
"world_size": int(os.environ["WORLD_SIZE"] if "WORLD_SIZE" in os.environ else -1),
"hostname": socket.gethostname(),
}
error_log_folder = error_log_folder / f"error_logs_{environment['hostname']}_{environment['local_rank']}.log"
error_log_folder.parent.mkdir(parents=True, exist_ok=True)
with open(error_log_folder, "w", encoding="utf-8") as f:
f.write(_format_exception_as_json(e, environment))

raise RuntimeError(f"An error occurred while running the training: {e}. ") from e


if __name__ == "__main__":
main()
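Both entry points now require `--experiments_root_path`. Below is a minimal sketch of exercising the updated CLI in-process with click's test runner; the command name "run" and the example paths are assumptions, only the flags themselves come from this diff.

```python
# Minimal sketch, not part of the PR: invokes the updated CLI in-process.
# Assumptions: the click group `main` is importable from modalities.__main__,
# the training command is registered as "run", and both paths exist on disk.
from click.testing import CliRunner

from modalities.__main__ import main

runner = CliRunner()
result = runner.invoke(
    main,
    [
        "run",
        "--config_file_path", "configs/train.yaml",
        "--experiments_root_path", "experiments/",
        "--test_comm",  # optional: run the communication test before training
    ],
)
print(result.exit_code, result.output)
```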
28 changes: 21 additions & 7 deletions src/modalities/config/config.py
@@ -152,8 +152,8 @@ class AdamOptimizerConfig(BaseModel):
eps: float
weight_decay: float
weight_decay_groups_excluded: list[str]
# foreach: bool | None = None
# fused: bool | None = None
foreach: bool | None = None
fused: bool | None = None


class AdamWOptimizerConfig(BaseModel):
@@ -163,8 +163,8 @@ class AdamWOptimizerConfig(BaseModel):
eps: float
weight_decay: float
weight_decay_groups_excluded: list[str]
# foreach: bool | None = None
# fused: bool | None = None
foreach: bool | None = None
fused: bool | None = None


class DummyLRSchedulerConfig(BaseModel):
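With `foreach` and `fused` now exposed on both optimizer configs, the YAML can select PyTorch's optimizer implementation. For reference, this is what the two knobs choose at the `torch.optim` level; how the config values are forwarded is not shown in this excerpt, so treat the sketch as an assumption.

```python
# Reference sketch: the two flags select PyTorch's optimizer implementation.
# Assumption: the config fields are passed through unchanged to torch.optim.AdamW.
import torch

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    eps=1e-8,
    weight_decay=0.1,
    foreach=True,  # multi-tensor implementation; fused=True would select the
                   # single fused kernel instead (typically CUDA-only)
)
```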
@@ -318,8 +318,17 @@ def validate_tp_mesh_existence(self) -> "GPT2ModelTPConfig":
raise ValueError(f"Device mesh {self.device_mesh=} has no defined mesh_dim_names.")
if ParallelismDegrees.TP.value not in mesh_dim_names:
raise ValueError(f"Tensor parallelism key '{ParallelismDegrees.TP.value}' not in {self.device_mesh=}")
if ParallelismDegrees.DP_REPLICATE.value in mesh_dim_names:
raise ValueError("data_parallel_replicate_degree > 1 cannot be used with Tensor Parallelism.")
if (
ParallelismDegrees.DP_SHARD.value in mesh_dim_names
and self.device_mesh[ParallelismDegrees.DP_SHARD.value].size() > 1
) and (
ParallelismDegrees.DP_REPLICATE.value in mesh_dim_names
and self.device_mesh[ParallelismDegrees.DP_REPLICATE.value].size() > 1
):
raise ValueError(
"Either dp_replicate_degree > 1 or data_parallel_shard_degree > 1 can be "
"used with Tensor Parallelism. Not both."
)
return self
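The validator now permits tensor parallelism together with either sharded or replicated data parallelism, but rejects the combination of both. Below is a standalone restatement of the rule; the dimension names are assumed to be the string values of `ParallelismDegrees`, and the helper is illustrative rather than part of the codebase.

```python
# Illustrative helper, not in the PR: restates the TP validation rule above.
# Assumption: mesh dimension names are "tp", "dp_shard" and "dp_replicate".
def tp_mesh_is_valid(mesh_dim_sizes: dict[str, int]) -> bool:
    """TP may combine with dp_shard > 1 or dp_replicate > 1, but not with both."""
    if "tp" not in mesh_dim_sizes:
        return False
    both_dp = (
        mesh_dim_sizes.get("dp_shard", 1) > 1 and mesh_dim_sizes.get("dp_replicate", 1) > 1
    )
    return not both_dp


assert tp_mesh_is_valid({"tp": 2, "dp_shard": 4})                         # TP + FSDP sharding
assert tp_mesh_is_valid({"tp": 2, "dp_replicate": 4})                     # TP + replication
assert not tp_mesh_is_valid({"tp": 2, "dp_shard": 2, "dp_replicate": 2})  # rejected
```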


@@ -508,13 +517,16 @@ class ParallelDegreeConfig(BaseModel):

def load_app_config_dict(
config_file_path: Path,
experiment_id: Optional[str] = None,
experiments_root_path: Path | None = None,
experiment_id: str | None = None,
additional_resolver_funs: Optional[dict[str, Resolver]] = None,
) -> dict[str, YAMLValue]:
"""Load the application configuration from the given YAML file.

Args:
config_file_path (Path): YAML config file.
experiments_root_path: (Path, optional): The path to the experiments root directory.
Defaults to None.
experiment_id (str, optional): The experiment_id of the current run.
additional_resolver_funs (dict[str, Resolver], optional): Additional resolver functions.

@@ -541,6 +553,8 @@ def node_env_resolver_fun(var_name: str) -> int | None:
"config_file_path": config_file_path,
"config_folder_path": config_file_path.parent,
}
if experiments_root_path is not None:
modalities_env_kwargs["experiments_root_path"] = experiments_root_path
if experiment_id is not None:
modalities_env_kwargs["experiment_id"] = experiment_id
OmegaConf.register_new_resolver(
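`load_app_config_dict` now forwards `experiments_root_path` into the OmegaConf resolver kwargs alongside the optional `experiment_id`. A call-site sketch follows; the paths are placeholders.

```python
# Call-site sketch, not part of the PR; paths are placeholders.
from pathlib import Path

from modalities.config.config import load_app_config_dict

config_dict = load_app_config_dict(
    config_file_path=Path("configs/train.yaml"),
    experiments_root_path=Path("experiments/"),  # new: exposed to the config resolvers
    experiment_id=None,                          # generated downstream when omitted
)
```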
13 changes: 8 additions & 5 deletions src/modalities/config/instantiation_models.py
@@ -18,12 +18,14 @@
PydanticPipelineType,
PydanticPytorchDeviceType,
PydanticPytorchModuleType,
PydanticSteppableProfilerIFType,
PydanticTextInferenceComponentType,
PydanticTokenizerIFType,
)
from modalities.config.utils import parse_torch_device
from modalities.dataloader.dataset import Dataset
from modalities.util import warn_rank_0
from modalities.utils.profilers.profilers import SteppableNoProfiler


class CudaEnvSettings(BaseModel):
@@ -67,7 +69,7 @@ class TrainingProgress(BaseModel):
class TrainingComponentsInstantiationModel(BaseModel):
class Settings(BaseModel):
class Paths(BaseModel):
checkpoint_saving_path: Path # Explicitly defined field
experiments_root_path: Path # Explicitly defined field

class Config:
extra = "allow"
@@ -182,13 +184,14 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel
evaluation_subscriber: PydanticMessageSubscriberIFType
checkpoint_saving: PydanticCheckpointSavingIFType
gradient_clipper: PydanticGradientClipperIFType
mfu_calculator: Optional[PydanticMFUCalculatorABCType] = None
scheduled_pipeline: Optional[PydanticPipelineType] = None
device_mesh: Optional[PydanticDeviceMeshIFType] = None
profiler: PydanticSteppableProfilerIFType = SteppableNoProfiler()
mfu_calculator: PydanticMFUCalculatorABCType | None = None
scheduled_pipeline: PydanticPipelineType | None = None
device_mesh: PydanticDeviceMeshIFType | None = None
model_raw: PydanticPytorchModuleType

@model_validator(mode="after")
def _check_token_amount_in_dataset(self) -> "TrainingComponentsInstantiationModel.Settings":
def _check_token_amount_in_dataset(self) -> "TrainingComponentsInstantiationModel":
if (
len(self.train_dataset) * self.settings.step_profile.sequence_length
< self.settings.training_target.num_target_tokens
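The token-amount validator compares the number of tokens available in the training dataset against the training target. A quick worked example with illustrative numbers (the exact handling of a shortfall is cut off in this excerpt).

```python
# Worked example (illustrative numbers) for the dataset-size check above.
num_samples = 1_000_000
sequence_length = 2_048
available_tokens = num_samples * sequence_length  # 2_048_000_000 tokens
num_target_tokens = 2_000_000_000
assert available_tokens >= num_target_tokens      # the check above would not trigger
```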
2 changes: 2 additions & 0 deletions src/modalities/config/pydantic_if_types.py
@@ -29,6 +29,7 @@
from modalities.utils.debug_components import Debugging
from modalities.utils.mfu import MFUCalculatorABC
from modalities.utils.profilers.batch_generator import DatasetBatchGeneratorIF
from modalities.utils.profilers.profilers import SteppableProfilerIF
from modalities.utils.profilers.steppable_components import SteppableComponentIF


@@ -92,6 +93,7 @@ def __get_pydantic_core_schema__(
PydanticPipelineType = Annotated[Pipeline, PydanticThirdPartyTypeIF(Pipeline)]
PydanticPipelineStageType = Annotated[PipelineStage, PydanticThirdPartyTypeIF(PipelineStage)]
PydanticSteppableComponentIFType = Annotated[SteppableComponentIF, PydanticThirdPartyTypeIF(SteppableComponentIF)]
PydanticSteppableProfilerIFType = Annotated[SteppableProfilerIF, PydanticThirdPartyTypeIF(SteppableProfilerIF)]
PydanticRemovableHandleType = Annotated[
torch.utils.hooks.RemovableHandle, PydanticThirdPartyTypeIF(torch.utils.hooks.RemovableHandle)
]