Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 43 additions & 10 deletions roll/configs/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
from roll.platforms import current_platform
from roll.utils.config_utils import (calculate_megatron_dp_size,
validate_megatron_batch_size)
from roll.utils.exceptions import (
RollConfigConflictError,
RollConfigValidationError,
RollPipelineError,
)
from roll.utils.logging import get_logger

logger = get_logger()
Expand Down Expand Up @@ -250,7 +255,13 @@ def to_dict(self):

def __post_init__(self):

assert self.response_length or self.sequence_length, "response_length or sequence_length must be set"
if not (self.response_length or self.sequence_length):
raise RollConfigValidationError(
field_name="response_length/sequence_length",
expected_type="at least one must be set",
actual_value=f"response_length={self.response_length}, sequence_length={self.sequence_length}",
message="Either response_length or sequence_length must be set"
)

if self.sequence_length is None:
self.sequence_length = self.response_length + self.prompt_length
Expand All @@ -259,12 +270,22 @@ def __post_init__(self):
self.response_length = None

if self.val_prompt_length is None:
assert self.val_sequence_length is None, "val_prompt_length and val_sequence_length must be set simultaneously"
if self.val_sequence_length is not None:
raise RollConfigConflictError(
field1="val_prompt_length",
field2="val_sequence_length",
reason="val_prompt_length is None but val_sequence_length is set"
)
self.val_prompt_length = self.prompt_length
self.val_sequence_length = self.sequence_length

if self.val_prompt_length is not None:
assert self.val_sequence_length, "val_prompt_length and val_sequence_length must be set simultaneously"
if not self.val_sequence_length:
raise RollConfigConflictError(
field1="val_prompt_length",
field2="val_sequence_length",
reason="val_prompt_length is set but val_sequence_length is None or empty"
)


if self.track_with == "tensorboard":
Expand Down Expand Up @@ -297,9 +318,12 @@ def __post_init__(self):
if hasattr(attribute, "training_args"):
setattr(attribute.training_args, "seed", self.seed)

assert not (
self.profiler_timeline and self.profiler_memory
), f"ensure that only one profiling mode is enabled at a time"
if self.profiler_timeline and self.profiler_memory:
raise RollConfigConflictError(
field1="profiler_timeline",
field2="profiler_memory",
reason="Only one profiling mode can be enabled at a time"
)

self.profiler_output_dir = os.path.join(
self.profiler_output_dir, self.exp_name, datetime.now().strftime("%Y%m%d-%H%M%S")
Expand Down Expand Up @@ -353,9 +377,13 @@ def __post_init__(self):

if hasattr(self, 'actor_infer') and isinstance(self.actor_infer, WorkerConfig) and self.actor_infer.strategy_args is not None:
strategy_name = self.actor_infer.strategy_args.strategy_name
assert strategy_name in ["vllm", "sglang"]
# Use max_running_requests+1 to reserve extra one for abort_requests.
# 1000 is ray_constants.DEFAULT_MAX_CONCURRENCY_ASYNC.
if strategy_name not in ["vllm", "sglang"]:
raise RollConfigValidationError(
field_name="actor_infer.strategy_args.strategy_name",
expected_type="one of ['vllm', 'sglang']",
actual_value=strategy_name,
message=f"Invalid inference strategy '{strategy_name}'. Only 'vllm' and 'sglang' are supported for actor_infer"
)
max_concurrency = max(self.max_running_requests + 1, 1000)
self.actor_infer.max_concurrency = max(self.actor_infer.max_concurrency, max_concurrency)
logger.info(f"Set max_concurrency of actor_infer to {self.actor_infer.max_concurrency}")
Expand Down Expand Up @@ -594,7 +622,12 @@ class PPOConfig(BaseConfig):

def __post_init__(self):
super().__post_init__()
assert self.async_generation_ratio == 0 or self.generate_opt_level == 1
if self.async_generation_ratio != 0 and self.generate_opt_level != 1:
raise RollConfigConflictError(
field1="async_generation_ratio",
field2="generate_opt_level",
reason="async_generation_ratio != 0 requires generate_opt_level == 1"
)

if (
self.actor_train.model_args.model_name_or_path is None
Expand Down
Loading