
perf(training): parallelize SFT dataset rendering with multiprocessing#358

Open
xiaoyifan wants to merge 3 commits into main from parallel-sft-rendering

Conversation

@xiaoyifan
Contributor

Description

Parallelize the SFT dataset rendering phase in sft_loop.py using multiprocessing.Pool. The rendering loop tokenizes each example sequentially on a single CPU core, which takes ~6 hours for 110K multi-turn examples. With 8 parallel workers, this drops to ~45 minutes (~6.4x speedup observed).

How it works:

  • Each worker process initializes its own tokenizer and renderer via _init_render_worker (avoids pickling non-serializable objects)
  • Pool.imap with chunksize=100 streams results back to the main process for progress tracking
  • Falls back to single-threaded rendering for small datasets or single-CPU environments
  • Worker count capped at min(os.cpu_count(), 8) to respect container CPU limits
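The bullets above can be sketched roughly as follows. This is an illustrative outline, not the actual sft_loop.py code: `ToyRenderer` and the `vocab` argument stand in for the real (non-picklable) tokenizer/renderer objects, and the helper names besides `_init_render_worker` are hypothetical.

```python
import multiprocessing as mp
import os

# Toy stand-in for the real tokenizer/renderer, which is heavyweight
# and not picklable (an assumption for this sketch).
class ToyRenderer:
    def __init__(self, vocab):
        self.vocab = vocab

    def render(self, example):
        return [self.vocab.get(tok, 0) for tok in example.split()]

_renderer = None  # per-process global, populated by the initializer

def _init_render_worker(vocab):
    # Each worker builds its own renderer locally instead of receiving
    # a pickled copy from the parent process.
    global _renderer
    _renderer = ToyRenderer(vocab)

def _render_one(example):
    return _renderer.render(example)

def render_dataset(examples, vocab, min_parallel=1000):
    n_workers = min(os.cpu_count() or 1, 8)  # respect container CPU limits
    if len(examples) < min_parallel or n_workers < 2:
        # Fallback: single-threaded path for small datasets or single-CPU hosts.
        _init_render_worker(vocab)
        return [_render_one(ex) for ex in examples]
    with mp.Pool(n_workers, initializer=_init_render_worker,
                 initargs=(vocab,)) as pool:
        # imap streams results back in order for progress tracking;
        # chunksize=100 amortizes IPC overhead per task batch.
        return list(pool.imap(_render_one, examples, chunksize=100))
```
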

Memory consideration: Parallel workers increase peak memory (~2x vs single-threaded) due to per-worker tokenizer copies and IPC deserialization overhead. The orchestrator memory allocation formula in the control plane should account for this (tracked separately).

Type of Change

  • Bug fix
  • New feature
  • Breaking change
  • Refactoring

Testing

  • Added/updated tests
  • Tested manually
  • No testing needed

Tested on a 12.2 GiB / 110K example multi-turn dataset (qwen3.5-397b):

  • Single-threaded: ~405 examples/min → ~4.5h for rendering
  • 8 workers: ~2,580 examples/min → ~43min for rendering (6.4x speedup)
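As a sanity check, the quoted durations follow directly from the throughput numbers (pure arithmetic, no project code involved):

```python
n_examples = 110_000

single_min = n_examples / 405      # single-threaded minutes
parallel_min = n_examples / 2_580  # 8-worker minutes
speedup = 2_580 / 405

print(round(single_min / 60, 1))   # hours for single-threaded run
print(round(parallel_min))         # minutes for 8-worker run
print(round(speedup, 1))           # observed speedup factor
```
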

Checklist

  • Self-reviewed my code
  • Change is the minimum necessary diff
  • Added tests for my changes
  • No secrets or credentials in the diff

Made with Cursor

The sft_loop rendering phase tokenizes each example sequentially on a
single CPU core, taking ~6 hours for 110K multi-turn examples.

Use multiprocessing.Pool with imap to distribute rendering across
up to 8 worker processes.  Each worker initializes its own tokenizer
and renderer to avoid pickling issues.  Results stream back to the
main process via imap for progress tracking.

Falls back to single-threaded rendering for small datasets or
single-CPU environments.

Made-with: Cursor
…tion

After fork(), Python's cyclic GC walks the parent heap in each worker,
triggering copy-on-write page faults that duplicate ~37 GiB per worker.
With 8 workers this adds ~300 GiB of overhead, causing OOMKill at 14%
rendering progress even with a 428 GiB memory limit.

Disabling GC in workers keeps shared pages shared, reducing per-worker
overhead from ~37 GiB to ~0.5 GiB. Workers are short-lived and don't
create reference cycles, so GC is unnecessary.

Made-with: Cursor
gc.disable() was insufficient — Python's reference counting still
triggers COW page faults when workers access any inherited object.
Switching to spawn eliminates COW entirely: each worker starts as a
fresh process with no inherited heap, reducing per-worker overhead
from ~37 GiB to ~0.8 GiB.

Made-with: Cursor