4 changes: 4 additions & 0 deletions exploration/create_dataset.py
@@ -54,6 +54,7 @@ def parse_args():
p.add_argument("--batch", type=int, default=4)
p.add_argument("--validate", type=int, default=20, help="Print class sequences for first N problems")
p.add_argument("--max_new_tokens", type=int, default=1024)
p.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility (e.g. generation with temperature)")
return p.parse_args()

# ── HELPERS ─────────────────────────────────────────────────
@@ -208,6 +209,9 @@ def main():
args = parse_args()
os.makedirs(args.out, exist_ok=True)

torch.manual_seed(args.seed)
np.random.seed(args.seed)

ckpt_raw = os.path.join(args.out, "raw_extractions.pkl")
ckpt_features = os.path.join(args.out, "all_sentences_features.pkl")
ckpt_with_neu = os.path.join(args.out, "all_sentences_features_with_neutral.pkl")
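The new --seed flag and the torch.manual_seed / np.random.seed calls above make sampled generation repeatable across runs. A minimal, self-contained sketch of that behaviour; "sshleifer/tiny-gpt2" is only an illustrative stand-in for the much larger model used by create_dataset.py:

# Sketch (not part of the diff): resetting the seed before each sampled
# generation yields identical outputs.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
inputs = tok("2 + 2 =", return_tensors="pt")

torch.manual_seed(42)
a = model.generate(**inputs, do_sample=True, temperature=0.7,
                   max_new_tokens=16, pad_token_id=tok.eos_token_id)
torch.manual_seed(42)
b = model.generate(**inputs, do_sample=True, temperature=0.7,
                   max_new_tokens=16, pad_token_id=tok.eos_token_id)
assert torch.equal(a, b)  # same seed, same sampled tokens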
203 changes: 203 additions & 0 deletions exploration/create_dataset_multigpu.py
@@ -0,0 +1,203 @@
"""
Multi-GPU launcher for create_dataset: shards problems across N GPUs (e.g. 8x A100),
runs Phase 1 + Phase 2 per shard, then merges and writes final outputs.

Usage:
python create_dataset_multigpu.py \
--model deepseek-ai/DeepSeek-R1-Distill-Qwen-14B \
--base Qwen/Qwen2.5-14B \
--dataset openai/gsm8k \
--split train \
--n 7473 \
--layer 28 \
--out ./gsm8k_output \
--ngpus 8
"""

from __future__ import annotations

import os
import sys
import pickle
import argparse
import subprocess
import tempfile

# Import from create_dataset for merge/flatten logic and validation
from create_dataset import (
CLASSES_ORDERED,
parse_args as base_parse_args,
print_validation,
)


def parse_args():
p = argparse.ArgumentParser(
description="Run create_dataset across multiple GPUs (shard by problem index)."
)
p.add_argument("--model", default="deepseek-ai/DeepSeek-R1-Distill-Qwen-14B")
p.add_argument("--base", default="Qwen/Qwen2.5-14B")
p.add_argument("--clf", default="Qwen/Qwen2.5-7B-Instruct")
p.add_argument("--dataset", default="HuggingFaceH4/MATH-500")
p.add_argument("--split", default="test")
p.add_argument("--layer", type=int, default=28)
p.add_argument("--n", type=int, default=500)
p.add_argument("--out", default="./dataset_output")
p.add_argument("--batch", type=int, default=4)
p.add_argument("--validate", type=int, default=20)
p.add_argument("--max_new_tokens", type=int, default=1024)
p.add_argument(
"--ngpus",
type=int,
default=8,
help="Number of GPUs to use (default 8 for 8x A100)",
)
p.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility (passed to workers)")
return p.parse_args()


def main():
args = parse_args()
args.out = os.path.abspath(args.out)
os.makedirs(args.out, exist_ok=True)

# Load dataset once to get size and validate
from datasets import load_dataset

config = "default" if "MATH" in args.dataset else "main"
ds = load_dataset(args.dataset, config, split=f"{args.split}[:{args.n}]")
key = "problem" if "problem" in ds[0] else "question"
n_problems = len(ds)
print(f"Dataset: {args.dataset} {args.split}, {n_problems} problems")
print(f"Sharding across {args.ngpus} GPUs\n")

ngpus = min(args.ngpus, n_problems)
if ngpus < args.ngpus:
print(f"Using {ngpus} GPUs (n_problems={n_problems} < ngpus={args.ngpus})")

# Shard boundaries: [0, s1), [s1, s2), ..., [s_{n-1}, n_problems)
shard_size = (n_problems + ngpus - 1) // ngpus
ranges = []
for i in range(ngpus):
start = i * shard_size
end = min(start + shard_size, n_problems)
if start < n_problems:
ranges.append((i, start, end))

# Args for workers (out must be absolute since workers run with cwd=script_dir)
args_dict = {
"out": os.path.abspath(args.out),
"dataset": args.dataset,
"split": args.split,
"n": args.n,
"base": args.base,
"model": args.model,
"clf": args.clf,
"layer": args.layer,
"batch": args.batch,
"seed": args.seed,
}

script_dir = os.path.dirname(os.path.abspath(__file__))
worker_script = os.path.join(script_dir, "create_dataset_worker.py")

with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as f:
args_pickle = f.name
try:
pickle.dump(args_dict, open(args_pickle, "wb"))

procs = []
for rank, start_idx, end_idx in ranges:
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = str(rank)
cmd = [
sys.executable,
worker_script,
str(rank),
str(start_idx),
str(end_idx),
args_pickle,
]
p = subprocess.Popen(cmd, env=env, cwd=script_dir)
procs.append((rank, p))

for rank, p in procs:
p.wait()
if p.returncode != 0:
raise RuntimeError(f"Worker rank {rank} exited with code {p.returncode}")

print("\nAll workers finished. Merging shards...")
finally:
if os.path.exists(args_pickle):
os.unlink(args_pickle)

# Merge raw_extractions from all shards
all_extractions = {}
for rank, start_idx, end_idx in ranges:
shard_raw = os.path.join(args.out, f"shard_{rank}", "raw_extractions.pkl")
if not os.path.exists(shard_raw):
continue
shard_data = pickle.load(open(shard_raw, "rb"))
for pid, data in shard_data.items():
all_extractions[pid] = data

# Flatten to final format (same as create_dataset)
ckpt_features = os.path.join(args.out, "all_sentences_features.pkl")
ckpt_with_neu = os.path.join(args.out, "all_sentences_features_with_neutral.pkl")
ckpt_cot = os.path.join(args.out, "cot_data.pkl")
ckpt_raw_merged = os.path.join(args.out, "raw_extractions.pkl")

pickle.dump(all_extractions, open(ckpt_raw_merged, "wb"))

all_features = []
all_features_with_neutral = []
for pid, data in all_extractions.items():
for feat in data["sentence_features"]:
if "stage" not in feat:
continue
entry = {
"hidden_state": feat["hidden_state"],
"hidden_state_last": feat["hidden_state_last"],
"problem_id": pid,
"sentence_idx": feat["sentence_idx"],
"sentence": feat["sentence"],
"stage": feat["stage"],
"is_anchor": feat.get("is_anchor", feat["stage"] != "NEUTRAL"),
}
all_features_with_neutral.append(entry)
if feat["stage"] != "NEUTRAL":
all_features.append(entry)

pickle.dump(all_features, open(ckpt_features, "wb"))
pickle.dump(all_features_with_neutral, open(ckpt_with_neu, "wb"))
pickle.dump(
{
pid: {
"problem": d["problem"],
"cot": d["cot"],
"sentences": d["sentences"],
}
for pid, d in all_extractions.items()
},
open(ckpt_cot, "wb"),
)

print_validation(all_extractions, n=args.validate)

stage_counts = {}
for f in all_features_with_neutral:
stage_counts[f["stage"]] = stage_counts.get(f["stage"], 0) + 1

print(f"\n{'='*60}")
print("DONE (multi-GPU merge)")
print(f" Non-neutral features : {len(all_features)}")
print(f" All features : {len(all_features_with_neutral)}")
print(f" Output dir : {args.out}")
print(f"\nStage distribution:")
for cls in CLASSES_ORDERED:
print(f" {cls:30s}: {stage_counts.get(cls, 0)}")
print("="*60)


if __name__ == "__main__":
main()
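For clarity, the ceil-division sharding in main() worked through as a standalone sketch; shard_ranges is a hypothetical helper for illustration, not part of the PR. With the docstring's example of 7473 GSM8K problems on 8 GPUs, ranks 0-6 each get 935 problems and rank 7 gets the remaining 928:

# Sketch of the shard arithmetic used by the launcher above.
def shard_ranges(n_problems, ngpus):
    ngpus = min(ngpus, n_problems)
    shard_size = (n_problems + ngpus - 1) // ngpus  # ceil division
    ranges = []
    for rank in range(ngpus):
        start = rank * shard_size
        end = min(start + shard_size, n_problems)
        if start < n_problems:
            ranges.append((rank, start, end))
    return ranges

r = shard_ranges(7473, 8)
assert r[0] == (0, 0, 935)        # ranks 0-6: 935 problems each
assert r[-1] == (7, 6545, 7473)   # rank 7: the remaining 928 problems
assert sum(end - start for _, start, end in r) == 7473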
154 changes: 154 additions & 0 deletions exploration/create_dataset_worker.py
@@ -0,0 +1,154 @@
"""
Worker for create_dataset_multigpu: runs Phase 1 + Phase 2 on a single GPU
for a slice of problem indices. Invoked by the launcher with
CUDA_VISIBLE_DEVICES set so this process sees only one GPU.

Usage (called by create_dataset_multigpu.py, not directly):
CUDA_VISIBLE_DEVICES=<gpu_id> python create_dataset_worker.py <rank> <start_idx> <end_idx> <args_pickle>
"""

from __future__ import annotations

import os
import sys
import gc
import pickle
import argparse

# CUDA_VISIBLE_DEVICES is set by the launcher before spawning this process, so the torch/CUDA imports below see only this worker's GPU
import torch
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

from create_dataset import (
CLASSES_ORDERED,
split_into_sentences,
get_sentence_token_ranges,
process_problem,
get_classification_prompt,
classify_sentences,
)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16


def parse_worker_args():
p = argparse.ArgumentParser()
p.add_argument("rank", type=int)
p.add_argument("start_idx", type=int)
p.add_argument("end_idx", type=int)
p.add_argument("args_pickle", type=str)
return p.parse_args()


def run_worker(rank: int, start_idx: int, end_idx: int, args_dict: dict):
out = args_dict["out"]
shard_dir = os.path.join(out, f"shard_{rank}")
os.makedirs(shard_dir, exist_ok=True)
ckpt_raw = os.path.join(shard_dir, "raw_extractions.pkl")

from datasets import load_dataset

dataset_name = args_dict["dataset"]
split = args_dict["split"]
n = args_dict["n"]
config = "default" if "MATH" in dataset_name else "main"
ds = load_dataset(dataset_name, config, split=f"{split}[:{n}]")
key = "problem" if "problem" in ds[0] else "question"
problems = [ds[i][key] for i in range(len(ds))]

my_pids = list(range(start_idx, min(end_idx, len(problems))))
if not my_pids:
print(f"[Rank {rank}] No problems in range [{start_idx}, {end_idx})")
return

seed = args_dict.get("seed", 42)

if os.path.exists(ckpt_raw):
all_extractions = pickle.load(open(ckpt_raw, "rb"))
done_pids = set(all_extractions.keys()) & set(my_pids)
else:
all_extractions = {}
done_pids = set()

remaining = [(pid, problems[pid]) for pid in my_pids if pid not in done_pids]
print(f"[Rank {rank}] Problems {start_idx}-{end_idx}: {len(remaining)} remaining, {len(done_pids)} done")

# Phase 1: generation + extraction
if remaining:
tokenizer = AutoTokenizer.from_pretrained(args_dict["base"], trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
args_dict["model"],
torch_dtype=dtype,
low_cpu_mem_usage=True,
).to(DEVICE)
model.eval()

for pid, problem in tqdm(remaining, desc=f"Rank {rank} Phase1"):
torch.manual_seed(seed + pid)
np.random.seed(seed + pid)
result = process_problem(problem, model, tokenizer, args_dict["layer"])
if result:
all_extractions[pid] = result
if (pid + 1) % 25 == 0:
pickle.dump(all_extractions, open(ckpt_raw, "wb"))
torch.cuda.empty_cache()
gc.collect()

pickle.dump(all_extractions, open(ckpt_raw, "wb"))
del model
torch.cuda.empty_cache()
gc.collect()

# Phase 2: classification
unclassified = []
unclassified_refs = []
for pid, data in all_extractions.items():
for i, feat in enumerate(data["sentence_features"]):
if "stage" not in feat:
unclassified.append(feat["sentence"])
unclassified_refs.append((pid, i))

if unclassified:
clf_tokenizer = AutoTokenizer.from_pretrained(args_dict["clf"], trust_remote_code=True)
clf_tokenizer.pad_token = clf_tokenizer.eos_token
clf_tokenizer.padding_side = "left"
classifier = AutoModelForCausalLM.from_pretrained(
args_dict["clf"],
torch_dtype=dtype,
low_cpu_mem_usage=True,
).to(DEVICE)
classifier.eval()

classifications = classify_sentences(
unclassified,
clf_tokenizer,
classifier,
batch_size=args_dict.get("batch", 4),
)

for (pid, feat_idx), cls in zip(unclassified_refs, classifications):
all_extractions[pid]["sentence_features"][feat_idx]["stage"] = cls
all_extractions[pid]["sentence_features"][feat_idx]["is_anchor"] = cls != "NEUTRAL"

pickle.dump(all_extractions, open(ckpt_raw, "wb"))
del classifier, clf_tokenizer
torch.cuda.empty_cache()
gc.collect()

print(f"[Rank {rank}] Done. Extracted {len(all_extractions)} problems.")


def main():
args = parse_worker_args()
with open(args.args_pickle, "rb") as f:
args_dict = pickle.load(f)
run_worker(args.rank, args.start_idx, args.end_idx, args_dict)


if __name__ == "__main__":
main()
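Downstream, the merged pickles written by the launcher can be loaded directly. A sketch of inspecting all_sentences_features.pkl; the ./gsm8k_output path follows the usage example in the launcher's docstring, and the field names follow the merge code above:

# Sketch (not part of the PR): inspect the merged non-NEUTRAL sentence features.
import pickle
from collections import Counter

with open("./gsm8k_output/all_sentences_features.pkl", "rb") as f:
    features = pickle.load(f)  # list of dicts, one per non-NEUTRAL sentence

print(len(features), "anchor sentences")
print(Counter(feat["stage"] for feat in features))         # per-stage counts
print(features[0]["problem_id"], features[0]["sentence"])  # first entry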