Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 28 additions & 9 deletions examples/inference/gpt/gpt_dynamic_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,6 @@ def main():
args_defaults={'no_load_rng': True, 'no_load_optim': True},
)

# Start Nsight profiler.
if os.environ.get("NSIGHT_PREFIX"):
torch.cuda.cudart().cudaProfilerStart()

level_str = os.getenv("LOG_LEVEL", "INFO").upper()
level = getattr(logging, level_str, logging.INFO)
logging.basicConfig(level=level, force=True)
Expand Down Expand Up @@ -350,28 +346,53 @@ def main():
print(setup_prefix)
print("~~~")

# Warmup: run one untimed iteration so CUDA caches, JIT kernels, and
# allocator pools are ready before the measured runs.
if args.inference_repeat_n > 1:
print("Running warmup iteration ...")
engine.reset()
run_inference(requests, engine)
torch.cuda.synchronize()
engine.reset()

# Start CUDA profiler after warmup so nsys traces only the measured runs.
if os.environ.get("NSIGHT_PREFIX"):
torch.cuda.cudart().cudaProfilerStart()

# Run and time test, optionally `args.inference_repeat_n` times.
throughputs = []
cuda_start_event = torch.cuda.Event(enable_timing=True)
cuda_end_event = torch.cuda.Event(enable_timing=True)
for _ in range(args.inference_repeat_n):

# Reset engine.
engine.reset()

torch.cuda.reset_peak_memory_stats()

# Trial.
# Synchronize before starting the timer to avoid measuring stale GPU work.
torch.cuda.synchronize()

# Trial — use both wall-clock and CUDA events for accurate GPU timing.
t = get_curr_time()
cuda_start_event.record()
result = run_inference(requests, engine)
cuda_end_event.record()
step_times = result["step_times"]
add_times = result["add_times"]
output_times = result["output_times"]
total_output_tokens = result["total_output_tokens"]
torch.cuda.synchronize()
total_time = get_curr_time() - t
cuda_elapsed_ms = cuda_start_event.elapsed_time(cuda_end_event)
stats = torch.cuda.memory_stats()
throughput = total_output_tokens / total_time
throughputs.append(throughput)

# Stop CUDA profiler after measured runs.
if os.environ.get("NSIGHT_PREFIX"):
torch.cuda.cudart().cudaProfilerStop()

# Validate all requests finished.
for request in requests:
assert request.state == "finished", f"request.state == '{request.state}' != 'finished'."
Expand Down Expand Up @@ -505,19 +526,17 @@ def escape_str(s):
# f"count [ p {p_count}, d {d_count} ]."
# )
capture_str = f"{engine.capture_stats['time']:.2f} sec" if engine.capture_stats else "--"
cuda_throughput = total_output_tokens / (cuda_elapsed_ms / 1000.0)
print(
f"{setup_prefix} … " f"throughput: {throughput:.3f} tok/s … ",
f"total time: {total_time:.3f}s … "
f"cuda time: {cuda_elapsed_ms:.1f}ms ({cuda_throughput:.3f} tok/s) … "
f"mem {peak_alloc_gb:.1f}/{peak_resvd_gb:.1f} GB … "
f"steps: {engine.context.step_count:d} … "
f"capture {capture_str}",
)
print("~~~")

# Stop Nsight profiler.
if os.environ.get("NSIGHT_PREFIX"):
torch.cuda.cudart().cudaProfilerStop()


if __name__ == "__main__":
main()
1 change: 0 additions & 1 deletion examples/inference/gpt/gpt_dynamic_inference_12b.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
set -u

# Libraries.
pip install simpy
pip install sentencepiece
pip install tiktoken

Expand Down
1 change: 0 additions & 1 deletion examples/inference/gpt/gpt_dynamic_inference_357m.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
set -u

# Libraries.
pip install simpy
pip install sentencepiece
pip install tiktoken

Expand Down
17 changes: 6 additions & 11 deletions examples/inference/gpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,13 @@ def get_time_offsets(

random.seed(seed)

import simpy # Guard against this import in test case

# Generate random time offsets.
def arrival(r):
while True:
yield env.timeout(random.expovariate(r))
time_offsets.append(env.now)

# Generate Poisson arrival times by accumulating exponential inter-arrival intervals.
time_offsets = []
env = simpy.Environment()
env.process(arrival(incoming_requests_per_sec))
env.run(incoming_requests_duration)
current_time = 0.0
while current_time < incoming_requests_duration:
current_time += random.expovariate(incoming_requests_per_sec)
if current_time < incoming_requests_duration:
time_offsets.append(current_time)

# Ensure at least a single request.
if len(time_offsets) == 0:
Expand Down
94 changes: 94 additions & 0 deletions megatron/core/inference/engines/dynamic_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
internal_api,
is_row_parallel_linear,
trace_async_exceptions,
unwrap_model,
)

from .async_zmq_communicator import AsyncZMQCommunicator
Expand Down Expand Up @@ -452,6 +453,11 @@ def create_cuda_graphs(self, reset_context: bool = True):
if is_inference_optimized_ep:
unset_inference_cuda_graphed_iteration_for_ep_inference(unwrapped_model)

# MTP CUDA graph warmup: capture graphs for the MTP TransformerLayers
# used during speculative decoding. This must happen after decoder graph
# warmup so that the MTP graphs are captured independently.
self._create_mtp_cuda_graphs(controller, context)

# Memory usage.
time_end = time.time()
mem_stats_end = torch.cuda.memory_stats()
Expand All @@ -476,6 +482,94 @@ def create_cuda_graphs(self, reset_context: bool = True):

self.capture_stats = capture_stats

def _create_mtp_cuda_graphs(self, controller, context):
    """Capture CUDA graphs for MTP layers used in speculative decoding.

    Derives the set of MTP batch sizes from the decoder CUDA graph batch
    dimensions, then runs ``compute_mtp_single_step`` per batch size to
    trigger graph capture. With ``mtp_use_repeated_layer`` one call covers
    every depth; with unique layers the remaining depths capture lazily.

    Returns without doing anything when the controller has no MTP heads or
    no speculative tokens, when the unwrapped model has no ``mtp``
    attribute, when ``cuda_graph_impl`` is not ``"local"``, or when no
    graph dimension has a positive request count.

    Args:
        controller: Text generation controller. Read for ``num_mtp_heads``,
            ``num_speculative_tokens`` and ``inference_wrapped_model``;
            ``_has_mtp_cuda_graphs`` is set to ``True`` on it once at least
            one batch size has been selected for capture.
        context: Inference context providing
            ``cuda_graph_batch_dimensions_list``.
    """
    num_mtp_heads = controller.num_mtp_heads
    num_spec_tokens = controller.num_speculative_tokens or 0
    if num_mtp_heads == 0 or num_spec_tokens == 0:
        return

    model = controller.inference_wrapped_model.model
    unwrapped = unwrap_model(model)
    if not hasattr(unwrapped, 'mtp'):
        return

    model_config = model.config

    # Only proceed when local CUDA graphs are enabled.
    if model_config.cuda_graph_impl != "local":
        return

    # Collect batch sizes from all graph dimensions. MTP serial forward
    # runs on all active requests (decode + prefill), so we need graphs
    # for total request counts, not just decode-only counts.
    tp_group = controller.inference_wrapped_model.tp_group
    tp_size = get_pg_size(tp_group)
    sp_enabled = model_config.sequence_parallel and tp_size > 1
    mtp_batch_sizes = set()
    for dim in context.cuda_graph_batch_dimensions_list:
        n = dim.req_count
        if n > 0:
            if sp_enabled:
                # Round up to the next multiple of tp_size so the batch can
                # be split evenly by the sequence-parallel scatter below.
                n += (tp_size - n % tp_size) % tp_size
            mtp_batch_sizes.add(n)
    if not mtp_batch_sizes:
        return

    # Flag that MTP CUDA graphs are available. The actual padded count is
    # re-derived at runtime from padded_batch_dimensions.req_count.
    controller._has_mtp_cuda_graphs = True

    device = torch.cuda.current_device()
    dtype = model_config.params_dtype
    hidden_size = model_config.hidden_size

    # Function-local imports, resolved once before any state is toggled and
    # outside the per-batch-size loop.
    from megatron.core.transformer.cuda_graphs import _set_capture_end, _set_capture_start

    if sp_enabled:
        from megatron.core.tensor_parallel.mappings import (
            scatter_to_sequence_parallel_region,
        )

    # Enable inference dispatcher for EP during MTP graph capture.
    is_inference_optimized_ep = (
        model_config.transformer_impl == "inference_optimized"
        and model_config.expert_model_parallel_size > 1
    )
    if is_inference_optimized_ep:
        set_inference_cuda_graphed_iteration_for_ep_inference(model)

    logging.info("> MTP CUDA graph warmup: %d batch size(s)", len(mtp_batch_sizes))

    # The try/finally pairs guarantee that the capture flag and the EP
    # dispatcher flag are undone even when a warmup step raises; otherwise
    # a failed capture would leave both switched on for later code.
    try:
        _set_capture_start()
        try:
            for batch_size in sorted(mtp_batch_sizes):
                dummy_hidden = torch.zeros(
                    (batch_size, 1, hidden_size), device=device, dtype=dtype
                )
                if sp_enabled:
                    dummy_hidden = scatter_to_sequence_parallel_region(
                        dummy_hidden, group=tp_group
                    )
                dummy_token_ids = torch.zeros(
                    (1, batch_size), device=device, dtype=torch.long
                )
                dummy_position_ids = torch.zeros(
                    (1, batch_size), device=device, dtype=torch.int64
                )

                # One call per batch size; depth=0 warms the shared layer
                # (repeated mode) or the first unique layer (non-repeated
                # mode).
                unwrapped.compute_mtp_single_step(
                    hidden_states=dummy_hidden,
                    next_token_ids=dummy_token_ids,
                    position_ids=dummy_position_ids,
                    depth=0,
                )
        finally:
            _set_capture_end()
    finally:
        if is_inference_optimized_ep:
            unset_inference_cuda_graphed_iteration_for_ep_inference(model)

    logging.info("> MTP CUDA graph warmup complete")

@internal_api
async def start_listening_to_data_parallel_coordinator(
self,
Expand Down
Loading