feat: support SD3.5 GRPO training#4
Open
niehen6174 wants to merge 18 commits into
Open
Conversation
Made-with: Cursor
…tions + PipelineConfig) - sgl_d_dtype_patch: monkey-patch DenoisingStage so target_dtype follows pipeline_config.dit_precision instead of hardcoded bf16. Without this, fp16-trained models (e.g. SD3) get a systematic logprob mismatch vs the trainer's fp16 FSDP forward, blowing up approx_kl / clipfrac. - _compute_server_args: build PipelineConfig explicitly via from_kwargs and forward all --sglang-* PipelineConfig flags that differ from the base default. The previous loop dropped them on the floor, so --sglang-dit- precision fp16 never reached sglang. - _launch_server_target / scheduler wrapper: install monkey_patch_torch_ reductions in both the launch_server process and the scheduler grandchild. sgl-d's multimodal_gen weights_updater path (unlike srt) doesn't call it, so the receiver hits AttributeError(_rebuild_cuda_tensor_original) on the first cuda-IPC bucket. - Generalize _scheduler_process_with_qwen_image_patch into _scheduler_process_with_miles_patches: dtype + reductions patches always apply, qwen-image parity patch is opt-in via --apply-qwen-image-sgl-d-patch. Co-authored-by: Cursor <cursoragent@cursor.com>
…pt --colocate - Revert b4509b6's CPU pickle+base64 fallback in DiffusionUpdateWeightFromTensor: serializing every bucket to CPU and shipping ~5 GiB over HTTP took ~238 s per update_weights, dominating step time. Use MultiprocessingSerializer (cuda IPC, zero-copy) again. - The reason b4509b6 went to CPU was "Invalid device_uuid" from sglang. Real fix is --colocate: actor and sglang must share the same Ray placement-group bundle so they see the same CUDA_VISIBLE_DEVICES, which is exactly what the qwen-image script does. Without it Ray hands the two ray actors disjoint visible devices and CUDA IPC can't map the sender's GPU UUID. The companion piece (monkey_patch_torch_reductions on the receiver) lives in the previous commit. - scripts/run-diffusion-grpo-sd3-ocr-sglang.sh: * add --colocate * add --update-weight-buffer-size 2147483648 (2 GiB) so the ~5 GiB DiT fits in 3 buckets instead of ~20 * keep --diffusion-gradient-accumulation-steps 64 * NUM_ROLLOUT env var + MILES_DEBUG_ALIGNMENT=1 toggle for debug paths - One-shot LoRA-sync diagnostic log (only printed for weight_version<=2): sent / merged_lora / unmatched_base_layer counts. Cost is negligible and this is what surfaces LoRA-prefix bugs in the future. Net effect on the 20260508 run: update_weights_time 238s → 2.9s (~82×), overall step_time roughly halved. Co-authored-by: Cursor <cursoragent@cursor.com>
…ruction Reconstructing sigmas from rollout timesteps via timesteps/num_train_timesteps is brittle: for flow-match schedulers with use_dynamic_shifting=True the rollout's sigmas are post-shift, and small drift between rollout and trainer sigmas shows up as SDE logprob mismatch (approx_kl / ratio_abs_minus_1 inflation). Prefer the sigmas snapshot the rollout actually used (carried on DitTrajectory.sigmas), and fall back to the previous reconstruction only when it isn't present. Co-authored-by: Cursor <cursoragent@cursor.com>
…h, and use target branch seed logic - Remove all defensive getattr(args, ...) patterns in actor.py (args fields are registered) - Delete --qkv-format and --true-on-policy-mode CLI parameters from arguments.py - Replace _train_microgroup_seed with inline seed logic from diffusion_RL_v0.1 - Remove sgl_d_dtype_patch.py (dit_precision fix now upstream in sglang)
- sd3_pipeline_with_logprob: cast inputs to model dtype (fp16) before
transformer forward to match actor recompute precision
- router: migrate deprecated on_event('startup') to FastAPI lifespan
- arguments: add --debug-disable-weight-sync flag; derive
num_steps_per_rollout from global_batch_size when not set
- diffusion_rollout_response: simplify log_prob deserialization to
pass full dict to deserialize_func (fixes TypeError with sglang)
- sglang script: fix global-batch-size 128->64, add HF_TOKEN export,
add --sglang-pipeline-class-name StableDiffusion3Pipeline
- actor.py: trim verbose scheduler sigmas comment - diffusion_update_weight_utils.py: trim verbose CUDA IPC comment - placement_group.py: fix RolloutManager GPU allocation to support both sglang (no GPU) and local diffusion rollout (needs GPU from PG) - arguments.py: remove unused --debug-check-update-direction param - delete scripts/run-diffusion-grpo-sd3-ocr-debug.sh (no longer needed)
- Remove --diffusion-train (unused) - Remove --diffusion-timestep-batch (alias for --micro-batch-size-tstep) - Remove --diffusion-dtype (alias for --diffusion-forward-dtype) - Remove --diffusion-gradient-accumulation-steps (dead code, never read) - Remove --debug-disable-weight-sync and related actor.py logic - Simplify noise_level getattr in diffusion_rollout.py (remove non-existent param fallback)
SD3.5 is a gated model; sglang needs HF_TOKEN to fetch model_index.json for auto-detection. With the token set, explicit --pipeline-class-name is unnecessary — sglang resolves StableDiffusion3Pipeline from model_index.json.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
Miles_diffusion currently supports GRPO training for image diffusion models (e.g., Qwen-Image) but lacks support for Stable Diffusion 3.5 (SD3). This PR adds full SD3 GRPO training capability, supporting both local diffusers rollout and sglang-based rollout pipelines.
Modifications
Core: SD3 Model Support
SD3TrainPipelineConfigdefining transformer structure, VAE config, and LoRA target modulesCondKwargswithpooled_projectionsfield required by SD3's triple text encoder architectureCore: Actor Training Logic
_init_lora()/_save_local_rollout_lora()for SD3 LoRA weight initialization and syncShardedGradScalerfor fp16 training: prevents gradient underflow in fp16 policy gradients while keeping found_inf synchronized across FSDP ranks (no-op for bf16/fp32)--diffusion-kl-beta): computes reference model log-prob by disabling LoRA adapter, penalizes drift from base model via mean-squared difference in predicted meansCore: Local Diffusion Rollout
Fixes
sd3_pipeline_with_logprob.pydeserialize_func(sglang returns dict, not{"data": ...})[C, H, W](SD3) in addition to 4D[C, F, H, W](video models)num_steps_per_rollout=2(prevents reward collapse)Infra
num_steps_per_rolloutfromglobal_batch_sizewhen not explicitly seton_event("startup")tolifespancontext manager--diffusion-ignore-last,--diffusion-init-lora-weight,--diffusion-kl-betaCLI parametersExperimental Validation