Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
use_flow_sigmas: bool = False,
use_dynamic_shifting: bool = False,
use_empirical_mu: bool = False,
shift_terminal: float | None = None,
order: int = 1,
**unused_kwargs,
) -> None:
Expand All @@ -46,6 +47,8 @@ def __init__(
use_flow_sigmas: Whether to use flow sigmas.
use_dynamic_shifting: Whether to use dynamic shifting.
use_empirical_mu: Whether to use empirical mu.
shift_terminal: If set, stretch shifted sigmas so the last
sigma equals this value instead of 1/num_steps.
order: Order of the scheduler.
**unused_kwargs: Unused keyword arguments.
"""
Expand All @@ -56,6 +59,7 @@ def __init__(
self.use_flow_sigmas = use_flow_sigmas
self.use_dynamic_shifting = use_dynamic_shifting
self.use_empirical_mu = use_empirical_mu
self.shift_terminal = shift_terminal
self.order = order

self._use_flow_sigmas = use_flow_sigmas
Expand Down Expand Up @@ -145,6 +149,14 @@ def retrieve_timesteps_and_sigmas(
if self._use_dynamic_shifting:
mu = self._calculate_mu(image_seq_len, num_inference_steps)
sigmas = self._time_shift_exponential(mu, 1.0, sigmas)

# Stretch sigmas so the last value equals shift_terminal
# (matches diffusers stretch_shift_to_terminal)
if self.shift_terminal is not None and self.shift_terminal > 0:
one_minus_z = 1.0 - sigmas
scale_factor = one_minus_z[-1] / (1.0 - self.shift_terminal)
sigmas = (1.0 - (one_minus_z / scale_factor)).astype(np.float32)

timesteps = sigmas * 1000.0
if reverse:
timesteps = ((1000.0 - timesteps) / 1000.0).astype(np.float32)
Expand Down
Loading