diff --git a/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py b/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py index f0b67bb23b4..84a12e0d2e6 100644 --- a/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py +++ b/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py @@ -33,6 +33,7 @@ def __init__( use_flow_sigmas: bool = False, use_dynamic_shifting: bool = False, use_empirical_mu: bool = False, + shift_terminal: float | None = None, order: int = 1, **unused_kwargs, ) -> None: @@ -46,6 +47,8 @@ def __init__( use_flow_sigmas: Whether to use flow sigmas. use_dynamic_shifting: Whether to use dynamic shifting. use_empirical_mu: Whether to use empirical mu. + shift_terminal: If set, stretch shifted sigmas so the last + sigma equals this value instead of 1/num_steps. order: Order of the scheduler. **unused_kwargs: Unused keyword arguments. """ @@ -56,6 +59,7 @@ def __init__( self.use_flow_sigmas = use_flow_sigmas self.use_dynamic_shifting = use_dynamic_shifting self.use_empirical_mu = use_empirical_mu + self.shift_terminal = shift_terminal self.order = order self._use_flow_sigmas = use_flow_sigmas @@ -145,6 +149,14 @@ def retrieve_timesteps_and_sigmas( if self._use_dynamic_shifting: mu = self._calculate_mu(image_seq_len, num_inference_steps) sigmas = self._time_shift_exponential(mu, 1.0, sigmas) + + # Stretch sigmas so the last value equals shift_terminal + # (matches diffusers stretch_shift_to_terminal) + if self.shift_terminal is not None and self.shift_terminal > 0: + one_minus_z = 1.0 - sigmas + scale_factor = one_minus_z[-1] / (1.0 - self.shift_terminal) + sigmas = (1.0 - (one_minus_z / scale_factor)).astype(np.float32) + timesteps = sigmas * 1000.0 if reverse: timesteps = ((1000.0 - timesteps) / 1000.0).astype(np.float32)