From 5f30022afe099e8efa57b1d9a0793521b650e694 Mon Sep 17 00:00:00 2001 From: jglee-sqbits Date: Tue, 10 Mar 2026 13:15:59 +0000 Subject: [PATCH 1/2] Add shift_terminal support to FlowMatchEulerDiscreteScheduler --- .../scheduling_flow_match_euler_discrete.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py b/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py index f0b67bb23b4..e78e2e28a1f 100644 --- a/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py +++ b/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py @@ -33,6 +33,7 @@ def __init__( use_flow_sigmas: bool = False, use_dynamic_shifting: bool = False, use_empirical_mu: bool = False, + shift_terminal: float | None = None, order: int = 1, **unused_kwargs, ) -> None: @@ -46,6 +47,8 @@ def __init__( use_flow_sigmas: Whether to use flow sigmas. use_dynamic_shifting: Whether to use dynamic shifting. use_empirical_mu: Whether to use empirical mu. + shift_terminal: If set, stretch shifted sigmas so the last + sigma equals this value instead of 1/num_steps. order: Order of the scheduler. **unused_kwargs: Unused keyword arguments. """ @@ -56,6 +59,7 @@ def __init__( self.use_flow_sigmas = use_flow_sigmas self.use_dynamic_shifting = use_dynamic_shifting self.use_empirical_mu = use_empirical_mu + self.shift_terminal = shift_terminal self.order = order self._use_flow_sigmas = use_flow_sigmas @@ -145,6 +149,16 @@ def retrieve_timesteps_and_sigmas( if self._use_dynamic_shifting: mu = self._calculate_mu(image_seq_len, num_inference_steps) sigmas = self._time_shift_exponential(mu, 1.0, sigmas) + + # Stretch sigmas so the last value equals shift_terminal + # (matches diffusers stretch_shift_to_terminal) + if self.shift_terminal is not None and self.shift_terminal > 0: + one_minus_z = 1.0 - sigmas + scale_factor = one_minus_z[-1] / (1.0 - self.shift_terminal) + sigmas = (1.0 - (one_minus_z / scale_factor)).astype( + np.float32 + ) + timesteps = sigmas * 1000.0 if reverse: timesteps = ((1000.0 - timesteps) / 1000.0).astype(np.float32) From ba87ec68be8fc00ac527872e091763fa98d21a71 Mon Sep 17 00:00:00 2001 From: jglee-sqbits Date: Wed, 11 Mar 2026 07:12:38 +0000 Subject: [PATCH 2/2] Format flow match scheduler --- .../scheduling_flow_match_euler_discrete.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py b/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py index e78e2e28a1f..84a12e0d2e6 100644 --- a/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py +++ b/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py @@ -155,9 +155,7 @@ def retrieve_timesteps_and_sigmas( if self.shift_terminal is not None and self.shift_terminal > 0: one_minus_z = 1.0 - sigmas scale_factor = one_minus_z[-1] / (1.0 - self.shift_terminal) - sigmas = (1.0 - (one_minus_z / scale_factor)).astype( - np.float32 - ) + sigmas = (1.0 - (one_minus_z / scale_factor)).astype(np.float32) timesteps = sigmas * 1000.0 if reverse: