feat: support SD3.5 GRPO training #4

Open
niehen6174 wants to merge 18 commits into main from feat/support_sd3


@niehen6174
Collaborator

Motivation

Miles_diffusion currently supports GRPO training for image diffusion models (e.g., Qwen-Image) but lacks support for Stable Diffusion 3.5 (SD3). This PR adds full SD3 GRPO training capability, supporting both local diffusers rollout and sglang-based rollout pipelines.

Modifications

Core: SD3 Model Support

  • Add SD3TrainPipelineConfig defining transformer structure, VAE config, and LoRA target modules
  • Extend CondKwargs with pooled_projections field required by SD3's triple text encoder architecture

Core: Actor Training Logic

  • Implement _init_lora() / _save_local_rollout_lora() for SD3 LoRA weight initialization and sync
  • Support two weight-sync paths: cuda-IPC for sglang, file-based for local rollout
  • Use exact sigmas from rollout snapshots (instead of reconstructing from timesteps) to prevent flow-match scheduler sigma drift — critical for log-prob consistency
  • Add ShardedGradScaler for fp16 training: prevents gradient underflow in fp16 policy gradients while keeping found_inf synchronized across FSDP ranks (no-op for bf16/fp32)
  • Implement KL divergence regularization (--diffusion-kl-beta): computes the reference model's log-prob with the LoRA adapter disabled and penalizes drift from the base model via the mean-squared difference in predicted means
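
The KL term in the last bullet can be illustrated with a minimal sketch. It assumes the policy and reference transitions are Gaussians that share one standard deviation, in which case the KL divergence reduces to a scaled squared difference of the means. The names `kl_from_means` and `regularized_loss`, the averaging over elements, and the `1 / (2 * std^2)` scaling are assumptions for illustration; the PR only states that the penalty is a mean-squared difference in predicted means.

```python
from typing import Sequence


def kl_from_means(mean: Sequence[float], ref_mean: Sequence[float], std: float) -> float:
    """KL(N(mean, std^2) || N(ref_mean, std^2)), averaged over elements.

    For two Gaussians with equal std this equals (mean - ref_mean)^2 / (2 * std^2).
    """
    if len(mean) != len(ref_mean):
        raise ValueError("mean and ref_mean must have the same length")
    if std <= 0:
        raise ValueError("std must be positive")
    squared = [(m - r) ** 2 for m, r in zip(mean, ref_mean)]
    return sum(squared) / len(squared) / (2.0 * std ** 2)


def regularized_loss(policy_loss: float, kl: float, kl_beta: float) -> float:
    # kl_beta stands in for --diffusion-kl-beta (assumed: 0.0 disables the penalty).
    return policy_loss + kl_beta * kl
```

With `kl_beta` at zero the loss is the plain policy loss, so the penalty only has an effect once the flag is set to a positive value.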

Core: Local Diffusion Rollout

  • Full local SD3 rollout pipeline: load model → encode prompt → denoise with log-prob → decode → compute reward
  • LoRA hot-reload from file for on-policy weight sync
  • Pipeline lifecycle management (offload/onload for memory efficiency)
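
The "denoise with log-prob" stage scores each sampled latent under the Gaussian transition of a stochastic sampler. A minimal sketch of such a per-step score follows. The function name `step_log_prob`, the scalar `std`, and the mean over elements are assumptions; the actual code in sd3_pipeline_with_logprob.py is not reproduced here.

```python
import math
from typing import Sequence


def step_log_prob(sample: Sequence[float], mean: Sequence[float], std: float) -> float:
    """Mean per-element log-density of `sample` under N(mean, std^2)."""
    if len(sample) != len(mean):
        raise ValueError("sample and mean must have the same length")
    if std <= 0:
        raise ValueError("std must be positive")
    # Normalisation constant of a 1-D Gaussian, in log space.
    log_norm = math.log(std) + 0.5 * math.log(2.0 * math.pi)
    terms = [-((x - m) ** 2) / (2.0 * std ** 2) - log_norm for x, m in zip(sample, mean)]
    return sum(terms) / len(terms)
```

Because `std` depends on the step's sigma, rollout and trainer have to evaluate this with the same sigma schedule for their log-probs to agree.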

Fixes

  • dtype mismatch: Cast transformer inputs to model dtype (fp16) before forward in sd3_pipeline_with_logprob.py
  • sglang response deserialization: Fix TypeError by passing dict directly to deserialize_func (sglang returns dict, not {"data": ...})
  • OCR reward: Support 3D tensor [C, H, W] (SD3) in addition to 4D [C, F, H, W] (video models)
  • global-batch-size: Reduce from 128 to 64 in the sglang training script so that num_steps_per_rollout=2 (prevents reward collapse)
  • Placement group: Properly differentiate sglang mode (no GPU/no PG for RolloutManager) vs local rollout mode (GPU + PG binding)
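
The relationship between global-batch-size and num_steps_per_rollout in the fix above can be sketched as a simple division. This assumes each rollout yields 128 training samples, a figure inferred from the numbers in the fix and not stated in the PR; the function name and the divisibility check are likewise illustrative.

```python
def steps_per_rollout(samples_per_rollout: int, global_batch_size: int) -> int:
    """Number of optimizer steps taken on the samples from one rollout."""
    if global_batch_size <= 0:
        raise ValueError("global_batch_size must be positive")
    if samples_per_rollout % global_batch_size != 0:
        raise ValueError("samples_per_rollout must be a multiple of global_batch_size")
    return samples_per_rollout // global_batch_size


# Assumed 128 samples per rollout: the old and new batch sizes compared.
old_steps = steps_per_rollout(128, 128)
new_steps = steps_per_rollout(128, 64)
```

Under that assumption the old setting gives a single step per rollout and the new one gives two.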

Infra

  • Auto-derive num_steps_per_rollout from global_batch_size when not explicitly set
  • Migrate FastAPI router from deprecated on_event("startup") to lifespan context manager
  • Add --diffusion-ignore-last, --diffusion-init-lora-weight, --diffusion-kl-beta CLI parameters
  • Add SD3 OCR training scripts (sglang + local rollout) and PickScore prompt conversion tool
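
The move from on_event("startup") to a lifespan context manager follows a general pattern that can be shown with the standard library alone. The sketch below is not the router's code: the handler body and the `events` list are hypothetical, and the manager is driven by hand so the example runs without FastAPI installed.

```python
import asyncio
from contextlib import asynccontextmanager

events: list[str] = []


@asynccontextmanager
async def lifespan(app):
    # Code before `yield` takes the place of an on_event("startup") handler.
    events.append("startup")
    yield
    # Code after `yield` takes the place of an on_event("shutdown") handler.
    events.append("shutdown")


# With FastAPI available, the manager is passed at construction time:
#     app = FastAPI(lifespan=lifespan)


async def _demo() -> list[str]:
    # Enter and exit the manager manually to show the ordering.
    async with lifespan(app=None):
        events.append("serving")
    return events


print(asyncio.run(_demo()))
```

Keeping startup and shutdown in one function lets both share local state, which the two separate on_event hooks could not do without globals.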

Experimental Validation

  • SGLang rollout: OCR reward rises from 0.38 → 0.70+ over 300+ rollouts (stable)
  • Local rollout: Training loop verified with correct log-prob computation and reward signals

niehen6174 and others added 18 commits April 29, 2026 03:31
…tions + PipelineConfig)

- sgl_d_dtype_patch: monkey-patch DenoisingStage so target_dtype follows
  pipeline_config.dit_precision instead of hardcoded bf16. Without this,
  fp16-trained models (e.g. SD3) get a systematic logprob mismatch vs the
  trainer's fp16 FSDP forward, blowing up approx_kl / clipfrac.

- _compute_server_args: build PipelineConfig explicitly via from_kwargs and
  forward all --sglang-* PipelineConfig flags that differ from the base
  default. The previous loop dropped them on the floor, so --sglang-dit-
  precision fp16 never reached sglang.

- _launch_server_target / scheduler wrapper: install monkey_patch_torch_
  reductions in both the launch_server process and the scheduler grandchild.
  sgl-d's multimodal_gen weights_updater path (unlike srt) doesn't call it,
  so the receiver hits AttributeError(_rebuild_cuda_tensor_original) on the
  first cuda-IPC bucket.

- Generalize _scheduler_process_with_qwen_image_patch into
  _scheduler_process_with_miles_patches: dtype + reductions patches always
  apply, qwen-image parity patch is opt-in via --apply-qwen-image-sgl-d-patch.

Co-authored-by: Cursor <cursoragent@cursor.com>
…pt --colocate

- Revert b4509b6's CPU pickle+base64 fallback in
  DiffusionUpdateWeightFromTensor: serializing every bucket to CPU and
  shipping ~5 GiB over HTTP took ~238 s per update_weights, dominating
  step time. Use MultiprocessingSerializer (cuda IPC, zero-copy) again.

- The reason b4509b6 went to CPU was "Invalid device_uuid" from sglang.
  Real fix is --colocate: actor and sglang must share the same Ray
  placement-group bundle so they see the same CUDA_VISIBLE_DEVICES, which
  is exactly what the qwen-image script does. Without it Ray hands the two
  ray actors disjoint visible devices and CUDA IPC can't map the sender's
  GPU UUID. The companion piece (monkey_patch_torch_reductions on the
  receiver) lives in the previous commit.

- scripts/run-diffusion-grpo-sd3-ocr-sglang.sh:
  * add --colocate
  * add --update-weight-buffer-size 2147483648 (2 GiB) so the ~5 GiB DiT
    fits in 3 buckets instead of ~20
  * keep --diffusion-gradient-accumulation-steps 64
  * NUM_ROLLOUT env var + MILES_DEBUG_ALIGNMENT=1 toggle for debug paths

- One-shot LoRA-sync diagnostic log (only printed for weight_version<=2):
  sent / merged_lora / unmatched_base_layer counts. Cost is negligible and
  this is what surfaces LoRA-prefix bugs in the future.

Net effect on the 20260508 run: update_weights_time 238s → 2.9s (~82×),
overall step_time roughly halved.
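
The bucket counts and the speedup quoted in this commit message can be checked with a little arithmetic. The sketch treats the DiT as exactly 5 GiB and assumes a previous buffer of 256 MiB, a value inferred from the "~20" figure and not stated in the commit. Real bucketing packs whole tensors, so the result is a lower bound.

```python
GIB = 1024 ** 3
MIB = 1024 ** 2


def min_buckets(total_bytes: int, buffer_bytes: int) -> int:
    """Lower bound on bucket count: ceiling division of total by buffer size."""
    if buffer_bytes <= 0:
        raise ValueError("buffer_bytes must be positive")
    return -(-total_bytes // buffer_bytes)


dit_bytes = 5 * GIB                                  # "~5 GiB DiT", treated as exact
new_count = min_buckets(dit_bytes, 2147483648)       # --update-weight-buffer-size (2 GiB)
old_count = min_buckets(dit_bytes, 256 * MIB)        # assumed earlier buffer size
speedup = 238 / 2.9                                  # update_weights_time before / after
```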

Co-authored-by: Cursor <cursoragent@cursor.com>
…ruction

Reconstructing sigmas from rollout timesteps via timesteps/num_train_timesteps
is brittle: for flow-match schedulers with use_dynamic_shifting=True the
rollout's sigmas are post-shift, and small drift between rollout and trainer
sigmas shows up as SDE logprob mismatch (approx_kl / ratio_abs_minus_1
inflation).

Prefer the sigmas snapshot the rollout actually used (carried on
DitTrajectory.sigmas), and fall back to the previous reconstruction only
when it isn't present.
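
The "prefer the snapshot, fall back to reconstruction" rule described in this commit can be sketched as follows. `Trajectory` is a hypothetical stand-in for DitTrajectory with only the two fields used here, and `resolve_sigmas` and the default of 1000 training timesteps are assumed for illustration.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Trajectory:
    """Hypothetical stand-in for DitTrajectory; only the fields used below."""
    timesteps: List[float]
    sigmas: Optional[List[float]] = None


def resolve_sigmas(traj: Trajectory, num_train_timesteps: int = 1000) -> List[float]:
    # Prefer the exact sigmas the rollout used, when they were recorded.
    if traj.sigmas is not None:
        return list(traj.sigmas)
    # Fallback: the earlier reconstruction from timesteps.
    return [t / num_train_timesteps for t in traj.timesteps]


with_snapshot = resolve_sigmas(Trajectory(timesteps=[900.0, 500.0], sigmas=[0.9003, 0.5001]))
without_snapshot = resolve_sigmas(Trajectory(timesteps=[900.0, 500.0]))
```

In the made-up example the two results differ slightly, which is the kind of mismatch the commit attributes to reconstruction.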

Co-authored-by: Cursor <cursoragent@cursor.com>
…h, and use target branch seed logic

- Remove all defensive getattr(args, ...) patterns in actor.py (args fields are registered)
- Delete --qkv-format and --true-on-policy-mode CLI parameters from arguments.py
- Replace _train_microgroup_seed with inline seed logic from diffusion_RL_v0.1
- Remove sgl_d_dtype_patch.py (dit_precision fix now upstream in sglang)
- sd3_pipeline_with_logprob: cast inputs to model dtype (fp16) before
  transformer forward to match actor recompute precision
- router: migrate deprecated on_event('startup') to FastAPI lifespan
- arguments: add --debug-disable-weight-sync flag; derive
  num_steps_per_rollout from global_batch_size when not set
- diffusion_rollout_response: simplify log_prob deserialization to
  pass full dict to deserialize_func (fixes TypeError with sglang)
- sglang script: fix global-batch-size 128->64, add HF_TOKEN export,
  add --sglang-pipeline-class-name StableDiffusion3Pipeline
- actor.py: trim verbose scheduler sigmas comment
- diffusion_update_weight_utils.py: trim verbose CUDA IPC comment
- placement_group.py: fix RolloutManager GPU allocation to support both
  sglang (no GPU) and local diffusion rollout (needs GPU from PG)
- arguments.py: remove unused --debug-check-update-direction param
- delete scripts/run-diffusion-grpo-sd3-ocr-debug.sh (no longer needed)
- Remove --diffusion-train (unused)
- Remove --diffusion-timestep-batch (alias for --micro-batch-size-tstep)
- Remove --diffusion-dtype (alias for --diffusion-forward-dtype)
- Remove --diffusion-gradient-accumulation-steps (dead code, never read)
- Remove --debug-disable-weight-sync and related actor.py logic
- Simplify noise_level getattr in diffusion_rollout.py (remove non-existent param fallback)
SD3.5 is a gated model; sglang needs HF_TOKEN to fetch model_index.json
for auto-detection. With the token set, explicit --pipeline-class-name is
unnecessary — sglang resolves StableDiffusion3Pipeline from model_index.json.
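
For context, a diffusers-style model_index.json records the pipeline class under the `_class_name` key, which is what makes auto-detection possible once the file can be fetched. The sketch below only reads that key from a sample document; it is not sglang's resolution code, and the helper name and sample JSON are illustrative.

```python
import json


def pipeline_class_from_model_index(text: str) -> str:
    """Return the pipeline class name recorded in a model_index.json document."""
    index = json.loads(text)
    name = index.get("_class_name")
    if not isinstance(name, str):
        raise KeyError("model_index.json has no string '_class_name' entry")
    return name


# Illustrative content only; a real file also lists the pipeline's components.
sample_index = '{"_class_name": "StableDiffusion3Pipeline"}'
detected = pipeline_class_from_model_index(sample_index)
```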