From bf1b92fbbc190fb2a88adddeefc8c1b4742686ef Mon Sep 17 00:00:00 2001 From: Shanmugam Ramasamy Date: Mon, 2 Feb 2026 12:55:50 -0800 Subject: [PATCH 01/76] Speculative decoding simple implementation --- .gitignore | 5 +- create_deepseek_dummy_ckpt.sh | 76 +++++ deepseek_inference.sh | 20 ++ .../checkpoints/iter_0000001/.metadata | Bin 0 -> 212642 bytes .../checkpoints/iter_0000001/common.pt | Bin 0 -> 25127 bytes .../checkpoints/iter_0000001/metadata.json | 1 + .../latest_checkpointed_iteration.txt | 1 + ....tfevents.1769037870.pool0-01476.3253909.0 | Bin 0 -> 62232 bytes ....tfevents.1769037919.pool0-01476.3254808.0 | Bin 0 -> 63278 bytes .../inference/gpt/gpt_dynamic_inference.py | 11 +- examples/inference/gpt/utils.py | 3 + .../inference/contexts/dynamic_context.py | 144 ++++++-- .../core/inference/engines/dynamic_engine.py | 31 +- .../text_generation_controller.py | 314 ++++++++++++++++-- .../common/embeddings/rotary_pos_embedding.py | 10 +- megatron/core/models/gpt/gpt_model.py | 108 +++--- .../text/libraries/huggingface_tokenizer.py | 4 +- megatron/training/checkpointing.py | 1 + 18 files changed, 605 insertions(+), 124 deletions(-) create mode 100644 create_deepseek_dummy_ckpt.sh create mode 100644 deepseek_inference.sh create mode 100644 deepseek_mtp_dummy_ckpt/checkpoints/iter_0000001/.metadata create mode 100644 deepseek_mtp_dummy_ckpt/checkpoints/iter_0000001/common.pt create mode 100644 deepseek_mtp_dummy_ckpt/checkpoints/iter_0000001/metadata.json create mode 100644 deepseek_mtp_dummy_ckpt/checkpoints/latest_checkpointed_iteration.txt create mode 100644 deepseek_mtp_dummy_ckpt/tensorboard/events.out.tfevents.1769037870.pool0-01476.3253909.0 create mode 100644 deepseek_mtp_dummy_ckpt/tensorboard/events.out.tfevents.1769037919.pool0-01476.3254808.0 diff --git a/.gitignore b/.gitignore index a9ce4aa0a93..bf4b473c583 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,7 @@ runs/ # Sphinx documentation docs/_build -docs/apidocs \ No newline at end of file +docs/apidocs +# Large checkpoint files +deepseek_mtp_dummy_ckpt/checkpoints/iter_0000001/__0_0.distcp +deepseek_mtp_dummy_ckpt/checkpoints/iter_0000001/__0_1.distcp diff --git a/create_deepseek_dummy_ckpt.sh b/create_deepseek_dummy_ckpt.sh new file mode 100644 index 00000000000..347f992ec69 --- /dev/null +++ b/create_deepseek_dummy_ckpt.sh @@ -0,0 +1,76 @@ +#!/bin/bash + +# --- User Configuration --- +# Path to your Megatron-LM repository +MEGATRON_PATH="/lustre/fsw/portfolios/coreai/users/shanmugamr/Megatron-LM/" +# Path for saving the checkpoint and logs +OUTPUT_PATH="/lustre/fsw/portfolios/coreai/users/shanmugamr/Megatron-LM/deepseek_mtp_dummy_ckpt" + +# Path to a dummy data file (can be a simple text file) +# Example: echo "hello world" > dummy_data.txt + +# --- Script --- +mkdir -p ${OUTPUT_PATH}/checkpoints +mkdir -p ${OUTPUT_PATH}/tensorboard + +# These arguments define a very small DeepSeek-like model with MTP heads. +# Model size is reduced for quick checkpoint creation. +PRETRAIN_ARGS=( + # Parallelism + --tensor-model-parallel-size 1 + --pipeline-model-parallel-size 1 + --use-mcore-models + + # Model Architecture (Small) + --num-layers 16 + --hidden-size 256 + --ffn-hidden-size 1024 + --num-attention-heads 8 + --seq-length 1024 + --max-position-embeddings 1024 + --position-embedding-type rope + --normalization RMSNorm + --swiglu + --untie-embeddings-and-output-weights + + # MTP Head Configuration + # These arguments are taken from the deepseek example script + --mtp-num-layers 3 + --mtp-loss-scaling-factor 0.1 + + # Training Configuration (Minimal) + --micro-batch-size 1 + --global-batch-size 1 + --train-iters 1 # Run for only 1 iteration to create the checkpoint + --lr 1e-4 + --lr-decay-style cosine + + # Data and Tokenizer + --tokenizer-type HuggingFaceTokenizer + --tokenizer-model deepseek-ai/deepseek-coder-6.7b-base # or another HF tokenizer + --mock-data + --split 100,0,0 + --no-create-attention-mask-in-dataloader + + # Checkpointing + --save ${OUTPUT_PATH}/checkpoints + --save-interval 1 # Save after the first iteration + --eval-interval 1 + + # Other settings + --use-flash-attn + --disable-bias-linear + --bf16 + --log-interval 1 + --tensorboard-dir ${OUTPUT_PATH}/tensorboard +) + +# --- Execution --- +cd ${MEGATRON_PATH} +export PYTHONPATH=${MEGATRON_PATH}:${PYTHONPATH} + +python ${MEGATRON_PATH}/pretrain_gpt.py ${PRETRAIN_ARGS[@]} + +echo "---" +echo "Dummy checkpoint created in: ${OUTPUT_PATH}/checkpoints" +echo "---" diff --git a/deepseek_inference.sh b/deepseek_inference.sh new file mode 100644 index 00000000000..9f735291060 --- /dev/null +++ b/deepseek_inference.sh @@ -0,0 +1,20 @@ +torchrun --nproc-per-node 1 \ + -m examples.inference.gpt.gpt_dynamic_inference \ + --load /lustre/fsw/portfolios/coreai/users/shanmugamr/Megatron-LM/deepseek_mtp_dummy_ckpt/checkpoints \ + --bf16 \ + --model-provider gpt \ + --tensor-model-parallel-size 1 \ + --micro-batch-size 16 \ + --num-tokens-to-generate 20 \ + --inference-dynamic-batching-buffer-size-gb 5 \ + --prompt-file /lustre/fsw/portfolios/llmservice/users/ksanthanam/megatron-lm/debug_prompts.jsonl \ + --use-checkpoint-args \ + --enable-cuda-graph \ + --incoming-requests-per-sec 16 \ + --dist-ckpt-strictness log_unexpected \ + --decode-only-cuda-graphs \ + --tokenizer-type HuggingFaceTokenizer \ + --tokenizer-model deepseek-ai/deepseek-coder-6.7b-base \ + --no-use-tokenizer-model-from-checkpoint-args \ + --output-path /lustre/fsw/portfolios/coreai/users/shanmugamr/Megatron-LM/output.json \ + --return-log-probs \ No newline at end of file diff --git a/deepseek_mtp_dummy_ckpt/checkpoints/iter_0000001/.metadata b/deepseek_mtp_dummy_ckpt/checkpoints/iter_0000001/.metadata new file mode 100644 index 0000000000000000000000000000000000000000..a9acd55675cee0f3b24570e1e5aa1e2fbfc23746 GIT binary patch literal 212642 zcmeHQ33wF6)(u1A@01=o!y`(G< z%IX_g9tpRjt*Q)G6bDL*&kux(%St2R=KdMMkiR4_x2!5;)@o5*i?qBjSQuFp86R0p zs1EYyj>z?&GIo6Ku!;T=c_)t@J}U22Qh#)cRU;J14>znVK0g?-xf@O}|GGSqyI@{! zO1;Z0q;P6MQB~>8%E(kx{y`_3f8jQ^aoP0gmBA2cX_wI{qf@F88ksESzjNYNO{Qq)Ge9=);NsaE+zc`Mn8l6)8A2Z3{&AN8=s{FXR^~^Q^ z>()oqZGfoT5K*@gqHZdpZev8NZEzO+(aefvDROQMZ-2ZjkuNFu0%i zgL6U^0lzha|;1*t~-2~v%?5~LdOB2G2afil6}8{VfI*b@h{TPMyPvFh%=4X(?WXSleGd4^k; zG0$-8GUgd>UB*1at;?8axOExx47V<0p5fMI%ro4&jCqDzmod+9>oVpUZe7Ma!>!Ah zXPk9SHyH|c6Yey#hD2K~cbipG9u4NE7i85n7+8cWE_K=f-%=|fysY>12ZN@ zE^s3ZFk=?w0;ght8B;D7xG@HpF+X#Gi9f(?%$Ur%z)dm0jG3Pc+zbQE2ot!#%`w1? zn1TzOh5=>-AY9-U7+^-U!Ub-L0cM0YT;NtU0Gn}8AMsc(4t|6aMK&&x)ex7+s;-_e zQ831SlJsKNWen%sx{MK;TbD8Da_ch2QEpwv(8#UJ811-q83P!%E@MpM)@2Mc+`5dB zgIkv|7;x({{F_^s;j-Mi3@>!ntu5R;rOwd zDY88>g#!m~3h0lSrf~dN%@n;nGKB*NZwly-nWk|3Sj`l@Ju-y@2X6}KkC~=${8-Ht zeLOOS0|##k=#QDEaQs-!6n#B1g#!m~3h0lSrf~dN%@qASGKB*NZwly-nWk|3Sj`mu zJu-y@2X6}KkC~=${8-Ht13WT?0|##k=#QDEaQyhBgoTe_L!J?*JWbLa7ttenMWFDy^JOI%KedcFx3?5ev#Hf{u3+h0sHB(D?TpEdx}) zqG`T;QiG|gYf@8#SHsB0hg~&(pTyc1;#D-0*TZqr zg^vbh2I!^{X&7ll%nUV*3e`1f7eqCVIorpk(5`WYO5TLuOUNsH9?n@mE2*=7lno&- znX`UY26EOX?KvQ`evDUWN$0Gem7#6c2S6XVSsy7*Xiw*?563+qvp!B9bmoc8jR_d)5a?7d{&2tY0HmvS)pOAYtR2 zvwrpgn)ThhLdosO!TItk>j*tLdJgP$jx+H!c)X(~jr04Lm4}LFkwfqm8P>7fjNqJd ze_-~E4Esg_?gayhCu>`W8(a4S^vWjZ^y5yxI}g`m1)jhW=&w8%uhH0&nv&v~w2>TX z_TEnQPL1G21#Vyxz}`s{UW~wg=eqjNh4`H6X!bkT?7HlCuGw|j?_9I%vfsI8 z*JZzR&92LS=bBxY{mwPJF8iHpc3t*6*LGbqZ9SUfLZ){|cS!5Jh-A)ftHtHb*(Vp+ zb=fBu*mdi?sC$ulCCfj=i%#Wu5zK|ml@DqjGS?S~vp=!w3F9y0><_mt<8#4j|u8r$1D6nl@cR_(| z3djWoLpaj_BbKV038I5AV|~MQ&AO{8L=m~E+ZMmt;<+PbL%paQQW$W zWE8h9BN@f5%Sc8!>zc_Z*#N4}bE1dss<#N)J)E)9bUuN&Y*n4-h&Dpl_C0`}(~Gn8 z$cQ_9QN!DS?s>*KyCQk`uP;ZExf(ifswQ((#Y?L31OniNL5?*Y2cqH}&i*t*^_f{> zdU|GvzOFKbquXW8+9&-Cj&wPbqxKtAj{lv-ku7cUq$wZ6QC^e5Ws?%QDam;GZ!D)( zfag)W8|KO5LWDrppe}hJ(W{4$B)O1jF?K`^ydRT5yp=t~+Uy3vgPjg~}o0EK7Fno;`pg@;|R(o;AQODK?Pq#D_dfm9|8>skxffTHL_m` zsYbRcA=St(C8QeJW`R^Ady|lAWLuI`&D@ck#Oe8_AA*i))pdS*8zas!>pFkJ>j$;4 z^FvQpB}9UCV?DX~HsR^kb@v2=)WYtLI0{wQ@1qMZBlpHKy*WZ$Ti;300yB#6TC^k^B?(&cIS(klr$ArL zkzF&$I9-sV+|{*eqj*TJxF_RcIS;098CwIlPHT~3-W^2o{W__G@?dC(O{yx$s7PMv z5PcXTsiR_2w7`s#gy9)De@1WB#%y?^qa4Uo!P>0%Hn%D>d;zD5+9@6~Rj`!T?x~^{ zDTYiH?u&r6x*Bk*NU|!(s7PLOs5MhXQnbK~QhTO~=vB;G`!8gwVC|j$W2TDQDIPLa zu$0&CsiGDshD;Ui+heu58gQyevMQIMrkD51_Lm1L0<#YE6|k=RD;?568>8hFWd`o{ zJ#tnNa|2SWdO{+War+*(F5`9zZe7Ohd)&H=+l07vNhW=CO8g~b+`8;zo9w$T*$1B4 zb=e1=*>%|mp4oNT2cFq=*$1B4b=e1=*>%~+HraLA2cFq=*$1BOy3E`6uwI0EztMc5 zCU!lY`AU-Ootngs7OP&f6LIS@oQPYO;Y8fJ3@75&WjGPHF2jkqbs0{?t;=vCZe4~G zaqBXih+CK8MBKUzC*sy+I1#rl!-=?c8BXM^%e*y!??ee98aSjdJ*VyhKIbI^+&h{W z4$QBMxxE=!3tV^Yl6}`u z-L*@$DGunhORoE4fHMK&{qpM1QB+?50*awXu+F|OguM`9*JZ>|TwKNw%&p59Rk?K; z10}aEW9;MBWejKBx{MKuTbB_-aqBW-C~jRw48^U>h@rT388H;ME+dBG)@8&{&bnp{ z<$4py;W-EHvX&`{d%ZkW(3cZ%Ne7Z|Apfs^I~@5vch_FOs&v#!Pvzbd*P$J_a@p|u z09_y=*)YUxVf7%%t2TfxndFNifG)|}m}GBD)H;`hORsG%$x{CR!X;~LSnYDjBp*@- zx+H4@=pW*eaOt(pC0WY=ZCiUbWlyw~!bf^LCikbbl?)+5U`KcO@ z*ZdRCPw_i{!Vp;3K<8jB!rAM`TId`s@wMAI zlKfUWutO&0WpJ>T+B0*omc#!T=U|Di-OiEZ_i)=+l+N&1h615rW@S;JqR`(nqo+SB zBP&w0CVgR~Dx6kc5-2SVR`?48p+F=roK_hss|d^p(!ZNuI!g(+$PI=9=6@emS{R%Y zSso$3B>yGcVEVbGk?=53cEHM&`-5{xloBt;pVc=KZd7a`=aKEaV+emiQAS~LWvHS! zzbX_g%qS=d7R)R!D=rOXOfN18R?e*q1!qOVY5oaR>!{M{WieGPDuN|}Q1R>_dD}Qt zWLER{B&{|93(6zm)UxT*D}$lP=#+43Nw9PVsgnD1qpEykONW#c+A!l=@{i#@3GGVg zR=l%Z28sG1QPHqhtke16ws=*d;qZJ`G!COzU5Hb(d7c#lM(_b_M(Cp7-o?$ z6%@+!o-NSOH$g+qK^&O(hzkY0ifxC+Q>hmVG)S}^%Deq?0k2})p%K-mO9UDu+79J? zX}N$`vF*^9>fRLs4H9jK@-AN~;8kopG^)Dr4uJ-VwnKS?_X&6v+YXJZ@*WXrkZ3!U zx8G9&Uhb89>?JZ}dD{5}%9ozXSRtp!Z3odrnmpC0mjzT4;;3reVJRr92>Jt|P_Q&q zTvnPPFRJ2AfxXA#L{)$9(adbi3!1V)AZU0Eg5oCKYy}m=HwsiZyG3*(M83g2z7TYa z;~Wi+y~HX{cKX+Xc8!Z8tFaF{Kkq58XvmKOMb769*qxA9(S55xMXxx!ByPvaleXC@ zc=FCp7QIMCo@(G80hM#?VWqhf!fDUOo`U@XqdUhQfG97ZJ%IIeQ~Zl2@Zo zPM_vN9o$TN&FR=2Uz;fOnS&TJUjM|yB*8i5{=n=RHQ%H`xvjJ_`QE(!p&PJn?hu0fA@b&_X>z1{z_pyib1@dAzt@ZpB{HfiXJ=`_K(A+Z%U2YqBTR5vAZML z>9Dpzq8-n1Tu=x-ykr{|+Hty22tBJ}9~Rni7*Pm){hXGKkp^*`T@*sUc##ha?KtWv zg#PK$3@o(cWTX(f%~gZ3(2j$YLg<0lr6@ zWmssaQ*11nkA;Rh#lF&ISZJtI98wh*ab|-?(|1(C$-=jmY!+*D-O1uuw>UERnq~g6e7jhtC#K{NHuNSQ;R^7G8~3YA z1Z>BsJnd6%ZPiRPC5C}dTeHlvJ3EPGLY{Vw8kL^*zwXaemssO=NBrU+>)S6qIz+4y z__Q_4{Lj;4#WEpJyGD&lPrKcVez8PXcE@q-P-7~8(5q!aUB+@(cmqx42j7~Hg@#5> zOW$9Hg@$H#p88}Z78)9qZTadUEHpH`)AYL+vCzmqv7GQYuQ~r7GR;F!=|tI)mUig z*mml`Td~m40jPZF!&qqOD6!&{wODAVE3X^#9u^wvZ@-M+B7}CY92`!OymjJ1XYNRj zOujDzmWn6(Q_^tDFt#ykT;e3GAL>=V&tcum4$wJ0*xD1FH3sJiI*~byWivawWnNb= zrCy5pKj)cX*Mf&_*BbRIJ^iQ8?dg{5RCy}!rJ#g22puo$IdcSg(U*9j;rZ9+4Yl>9 zvEB~~y55Wl zt`I_pb2vSi@v;N56GL~yQf`k3B^o!lIx*Bg((kzy3k~&;@ekdPg@*da1y4MUg@*da z>gQg;LPP!IotNIjLPP!I&)2`iLPPzd)4M+jp~H~xJhC$CcxM63@*_X`LtQN^6w#vg z+|N^5L}Mq`>;$qcYl7FJ_Qow~VzsW3(h-p)Bo6TcLCeaQf9N2TX#6BSQvUjDH!L)C zq&#?MPb@Taq#V3w5EdFbQqI^v5(^C-DX(rc2@4G!DW7Q`z(PYu%5U1t5<-U+ZXm?_ z3t+Hp-eImMf|E7^_4X|i3+|RM%z)lLb$r=NxG6fd+ch?O6V_TH>xjghmHB-b(7<)s zu`96H^`Y1g^<0I;Cf^2x>Hf6uJy>k=B{Udzy%Qh7VmE?fAA9mMSZwE~<%IoY!l;+9 z*v_xZ39&C6`z99K`GGkh_8pVn$6_~y+WXxppJB0^L9urSzQ$rZzceSbN#~-UvDj%) z-J?o>!(z99VpoLrV6mN_oDGjzdRZGG zxpA9ofQRno9-78<&BdCkf!EHrd~@!2~DV4!9}5kg?Q@zuiiL*G_LVK4#zI49` zLx+SOZO#-zhld^zzdybG6tS%FoE4CL^|Tu+IP3Pe@o;=l}9|BrL4+H`;Ly0g@lFm{|_>eu(1CB*8n6etp9f#j)aBv|J*zz zEUf>}o-Be5yI$moeh?v=A4vQh2z#tw6j>A*9}&KIxiViUB#RS}CA{xr&NrNUm;oBz zDDgl;+y}=K);<6Yzs(kE$OqX0`tcf2lnM}i=t7}r*;91bAr+wX465tPQT_Jp$15&hbK?H>hwQ8CKL_n z^eoXDI(_RkLZy&SpXFVrAMl1y4W!ePmu1u(FaHA%G$?lZtG*O!fOLBJCZJcH{_!7# zq9L80C0awL|7wR&DWub9d)Mhx{}QTk<(nKkMA1+yNvYS&Mt;GH$oNHa$uYSV5*D`P zSlAH>3tMvB+Z725TXKBZ0|^UTa{S#72@6|t9CZ>B7PjQb8zX{sFS+@yfE_hJnc?$J z7s)fab&*?g-+B%b7S{jYG?1{c{{MRj2@C80hs{UA!utR4#Yk9K|DUxS2@C80*WWCH z4G&0M;W7HZXA6)Q?-FY1EzlCTB2o;6etJYGL*Z@#><|H1LbiEcC>Cfl<#pS@*Mu^H zut&#QfhcV50?b$No>2E-sBX`EApE&d5b$|MKT<2b4u9f1p$wo8Bq_Mwv`r{>V3K3j z#(QizO&G5W+$=d~_pEnLG!%xx!shHVn<8OhbN0n;kg%{h`@W7ySlFEXQ5PgEY|h@B zj)aBH*+*xKV8e;0`mnPOC@>v6P^7{5ShADPISC01>*T9)kg%{$zJ5Fs7S_pko`r;k zb@Iatkg%{$K4K;k7S_p2D@CwjUvyCSZ05b;0wG~XFGI|=mk7C>M-o{V`e8j95D5Kz zl~9qh^~%9w+lY`W$6@K@M}#NbBox;*PMl|*uW+@H)oFg^*lx-FLM~_LOHu&$1?asjVe)1h(Hr%MD9B$^K8 zeQCLXSFP#LNb24d0tpgLhw?68Dd1IWIy9EL@D70liKat&gZBw|)tU~Crt%&UNRVhc zl(*kg0^W2-VuV~+WxWx{7|Y2}&RvZ=M)J-tP^$D)#v=T%~_X`v`pEpZU z-1Al#dvekm)f44mouiNY#aQXbk+Z8=l$JU3@ZQrkPhtmWFsp62tx_a~l@BhHs+8un+3H4~t#HH&J5PCk|@a7-^FlzKIgUo_>-Ki(SJv zQDWH3MrL5KYxpKg4EyQ4!C34XzKIgU{(52_7Q2RTqQtP9O__$puHl;~G3;J}QY>~2 z-$aRFpFQIuEOrgwM2TT9nza;*UBfp~V%QH>-h#!h;hQKi>`&%Ch{dkqnn9`I7PGg~wo_QkT8=T@AySuAoaN2Hm^Orp+z zpRqkTIyobb=wAPrHD8&!U#t@H$ZOmkiC-h4ANkgTR?S2cV;K0zb?ZFD=#(6tN{@U< zd8Sw;&z)~~!eWDy zKN~*C#A1V!KmYx72o@Wh{OPiJEEXG_{K@&=FT{2)b%BFFXG((Re=2_|6Y4Swj*`v) z-1OUgEHref`|9pxSZL@{ciY~TSZL@{w|&EhvCz<^?$BmyvCz<^ZgK1Pu+Y$@?zM+( z5kk9{6b`3I;wtjF!*_V16T9@GCpU|_n4cHoe&r#yjQsA%lr-E5jcv{vWpc(Gn93IE ztvykxdr{Fhv$I$y^iu)sm*~tZDxHh7*q*KMV^XoF?eEjm)|bY3i^zE1HF85Ig)2@N zgoTD~3Oq7wBo-PvDSUqFL@YFPQfQEO4i*|ZDI7ns1PcwF6egWD2MZ0I6fQdFQYM2X`p1!TAICyN{bTg}wODAV ze}opjjfIB#$IVxLiiL*y$7|PohlPgv$Mzd`2%*ESyVD&0QA5y#^cUEK-{F?OJrT?b zOSA|cwz^4+XdK0wS)mqYMJQSX8}}Y07MugmEFuL$hs@;XFmVr@mOf+ zkoo=d{jt!{A+zPn!?DoNA+z6`z=g(jbP5`0@Xwj4VBkDIX2 z(3oug|5jt6$*qAP)pw^nfQ5#RUGJwpfrTcQ5`t9!FYN^^G;~<(+U`{>G`aH-q--U%H7eazm zpEP(c78*JQ%^cpSWi+N_4XR<`h|W*Af!YVyzJq7gu4Sp>zHluik}) zhR*rFz4ZVV8an46`oUvZXy}}O@~6*Wp`mkr$)*>D&|&8t(5!dW@lF94L9hGnRk$kF zET=`)n*Y2dR&|&oKJsMn{&C0qHe*A^&=4M(m@cl~K}fuRf^84$dKlht4jQ{!@7^dB zXuK~yb`9A385SBkcI7wNgoTEVT~{^z1`7=xyB=@#0~Q)Oc71j5f3VQdu`BiPZCGgN z*p+$Y??UMCXpVQ7L0>{zR+ocJ?*6AxqtUFp?Bu>M<9}FasDs?wdmk1W>L4EutWUpP zL|&C;4QUL6Kpo`op{ZDCsDm7JYBMY})IsucTVkQ14szavwnFIeWR8Q>bnr{UTL82E zt!EyhE|;|c*CY2$|Doz~;lYb=vIbc9_n=QK*EL!?)?P^(Ee|U>QmE0GB0X9TuQ(bD z4IM3K%{dke4IM47zbFF>4IM3CyfhmN4IM3iy0R}88ai6Gxn>|18ai4IykW2qIy@zD zv~)O#Bl3r@l2l!~;O3J&QJlO%D17@VV#PfbM;d5mWVZ#orcTk`$`0Cs0PBrJJB0RVfr z4+%@&AOOIQ=q`c{PvZCiI49wqE0ku7lp3AV6EhPqSGZvi5*FqPYfnMK!d&6!aY$I0 zE3}=0goU}n2|*+*%oPgDkgzaUSTa`x8-^asO4zdYW*&I*5|K7zvSjDqyaWjg>-3 z0rUZQj?pV0m~yXBH_!)Ix;6N~rH>260)4>!VJqp*_rMyVAXk3MvB?G6n}IxKSgQ8p zH$+N}oMYwkQ~N(a!orp)-M>P@!j>qf{fLBxEm6+@4G9ZdqTK!$5*D^ZdApHcwDNKH z1lST~S1S>$dx_3>1?0d1>;Jw(Me>YGB>VrV-H@=b{$HMfgoXA075$O0u>Qa9BqS`X z|Nk-u2@C802c3b0h4udv10vY4D;#q4e=nEC(`O20xZ)6=jAXAzmsJUMyQX%YZqLHo zr{@a=0iS2|Yr4`A#@CC5GFM#R6@nyd>y#i%3~v`bDiky@$uaDM&)agEFkUzKTXGKn zVV!(rOC&6;lb?GC5*F6UZ}cHyVV(Ttqmi(%PX22K5*F6U5AG*|4R`c18XY`D$dwbv zmF$pn#;HPaUE{=g7FS+9R>)DBFmrn4d70&96`|>6CBMqpM&W^QmsAXHIS+G9*EnL~r+mBHXlKXEUAVb!czbNvK5lo_L_%q$KCEBrmJ z|7AtOjVlAQgBEYM-dWlCfu6x&kL<#J**$t^4H(d)f1t2ekE~vOdj|@859r@7zZdyI zNm-!K{6Sv0@vO4Ks*+%3WNM@;Kin!%QsOTuD+rYM%S!^K=D!VhBMmMn$|x)*4JyvB z3Iz)@Vp1~9YNf%7NVv%aB4Ql*Tg)nrh|F0gMt9>`GAoCho>4*m%_#CO?K;Ny8NKtxGT2ZSaUs1QJd=g8q zBMqpQ$dug6CE0PLdA%8vyleikf&7tYyl4L72jq`x#Es-1iv|(CzJza}3?IqyVfls= zzLrEl_oB8%9TpcIzBm-gm!*$C2mxyD7!T>=Z|nonPxO#J{;URw{%jBFNg#@S_JtnO z#~&vFwZ6nddJ->VTOaa}KK{fDsPz|mNZ%CS`b#~ek3W$(g0jsHRlPP<212f!k%;|Z6 zKw+&bP>9Ah6u71^;Z4$rm&rW+Iw(v?U~_H-f}Ys>Y?DcmlOO@t)XzL5cb+Bzwft)j z$w`2KYx&O}k~{NMpqBsUAvp;Sa4p~CA-OY^25GrbuO5lRnNhJ+hbOTCs_{lsZ~2>I z`P+EQPoe`{`#XBe?@UgDo_`l_`AK|$YkxZ7zb-X}{7L@jVMlbPQ$dEv^41U}RKOWx zfVce4yf3Kz!_@eVQRMfoc^(jXu1D`iADsf>Yjx-#aRhekx>1J? ze73>3O@|JA+QGP2hYo!H!FW`M4tyfQcvgoFd`7}pr$YxmHDRn**YN>?h|f_hK0a>8 z0g#Zwj7?S;-&@SAphZCiHfv#Q)u95LzA$#_ zP=U>37<+Z7z$P?|M)f(XDV5%d&1@Jgbf`e5I1bjK0-fkMLWc?xrr=hYN9j<3PI~mv zp#q)y=%YggIstN`4i)I^#|RxN&^eG?9V*b7kV)z)rV4kQfT9Wp zcWw%q#x9STTQ||ooEr9Aj8co76<8<;k;ux-)}lkgAQ$NnB9R`tOp6W)gDlk{L?Tah zy%rr32DwFt5Q&7@omzBA800}6LL@SGPioO2VURUCgh-?iU(=#P!XWSJ5F(Mw{8)<) z34?6XAw(jn`hyl75(e3(Lx@Cn_fIW4Bn+}ohY*Q0b7})#hCv}fY$jolmO6w;(UJ{>|N67NTA(IH`w3>`uwb^!Wn(IH`w!8(LUY(<=+MTdkz@^lE1*b_Njiw+5c zOj8#!oj@pW&J@`oj?EctF5_IQF~&@@@EK*K03}k~+!FbVxfVAokD{PMB5klxhYpF= z`eGeABo@in=+GgtHeI1Zhs2U_l@1*eE3tcY=#W?dJ)%R0#Jb}d9Xce!{g-s;kchJ1 z)S*KnxP4!T4vCoZGaWi4LbAnm*WIoJ!3_FjN=^aA8C zFF=m)0;H1{AYHrw>FNc@QC@%?;{`~%7a+%Z0n)<@kW4Q?vb+H46q`(W1pcf!ByZ|Zo0%WEQVsCQT-`oP7KM4_lqBNt_EPO@< zDL@+>7uvkhw?ZY3&s=1KNQ%GQRGnrlGz*`xh!i04%WPf+@z>ZOlHzYPRi_ydv+x=J zA_YkNZ8ooh_|-OuEgpLu-~m&2nsJv|_>B8V0h0fO&8#5*DK9{tu|cG5u#P|w8$53o zK4UE@KsH!!Gb`BOZ5zavkMsPErs_0fgIV~D4@m(M|CO5;+jV}R%v&h)cTnbSHnW1+ zcGw`2X4_+`PBV6zh0oYc3Xs|AHLA||jZU$Nv36@lFbH)+v+x;BNP(oTS50m0>f=uT zw6(Yyo6$vGa#Nb*rqqkVk!Q0mc2^_o_DADTf~A}#Kime9w&zh+n~Y9o;WN6D0%Xq~ zHnT!M%(Ovl`8fCOZ>ml+vdzM0^d$vI{7{=$LHtQJh@|+@rs_0fgjx8EQKSHgpJ4MU zh@WJGNQ$3ks!lT|n}yFfn-n1N(`{Y_@kKU>Egn0HoNMY%GiI6vd8LIEAo+7`W(D~d zcmXoc29dVGWdwrQV1Zfqj7v!YvcXcDS-}QZ+aR`loae7FRi_!t&BABgKnjrfRc>Bv z*SVK6ucpj*Lzy47nH9|TxD6s{wiis*X~t7#;WM5i1;}i#MtKz`?sqBm8fL1O#i&sLj^Z_UDI{74GWqx{WgR_KR2Z4g^N&VBzjRi_z$ znuX8!A1OfM8>V_-`&1i5QhX}{K*TpS3!l+~6d>{KZC(Y-ceFty#UE*^PBRWS3!l-M z6d>`(+Pn(lkGDZ=@z`G9$JCu>WSWJ~=tT;U{1a_v1^Gj~02yY3NZTNXKoA=YHw&LJ zk`y2tjJKH;Y%tLVvE}1Df2yfE%{ap>e8v<~fW!ygyx6W&N|}o&a|x7rw#}?ywmCM4 zq}djls?&@M%))2PCk4oCi=(^>6Zf^0I!vjr##1Xy+&5a>-b~!L5iI42d$kQ>KS*pp ze86gxahF+;q#-FlkMaqdS)m_3WrNuAaqhd;RGnr#YZg9Z4Jkn4->`WV#ILtOZ1Gsz ze`xAXGu|-^pYc8^K=Qv(Xa0dQZ=%fKL7BIyGw-I%zftC0Q0DzCX6(_CCXFq7*Eb8F zkxB}Xz1`drM@L#)+;r?PwwSGteX~SGip1%W4w|G$93bheNs7c7lJ1(MNE{>SsY!~& zNs|7Wq(~el8LCN&#CejDnxse^DH*3pio~gs$(p1{94t9UlN5=wB}JN~NE|OI*Ca*a zgvlIDQX~$UEbv6i6$FLG^Nw&38=zccNinW63!kx!6rcdb%`Fk2tgyH_0g8eYi2!Al zCMgmD$~~H-NCYU4Xp$lkpgf~VibR0&k|rq<0m_@2q(}rP?`x7G5ukjgNs2^(^0g)@ z5&_E3nxsesD8FfvA`zhM(IiD8K&jWn%oTV)9ZCc!O+As)hM*`1D2G~7jDyU=XLKM1 zC_r&@O9UuKTHKr`q98>gKsi>E6o~*OTay%t0A-*iDG~w7Nt&cc1SmP0q(}rP<26Z< z2vE+_Bt;@X$=4)BB0w=TNs$OpDl|!v2vE+~Bt;@Xxmc4Fi2&tFO;RKRlxsbaaw9=e z4p45lq!>4wh0j<;3Q&OJ=9UOh?z6Z#0g8eYi2&s>O;RKRl;<=_kqA&;)+9wDKzU1( z6o~-kLrqd70+cT_Ns$OpzSSf}B0%{?lN5;nWv3=75&_CznxsesC=Hqlas>(lzeIr2 zT$2=u0HvKKDG~w7VV+3oN>G#ol;bTa#?fZsGma$%C_rIwOI!ro+al*gC<(#;bs#}gFg1WO-Fijip+GMq)3=0SCbZr1j{5{QY6eWRg)Hp1WQnt6bZA;(xgQq!BVA5iiBC_Y0@H* zU|FO~iiBB~Xwo8)V7X406bZB3tVxSRg5?ffQY6gsfF>;x36>{xNs%zi3!1b@Bv@Y6 zB}Kw4?`YB@kzm=VONxY9zS5*cBEj;#E-4ab*{Vs4M1p0PE-4ab*{ey5M1rMJb3sx= zVR_aCB=BjEPTeRqyPmdZf=PHWrM}d2~ZTINCYUKXp$lkplsG8MIu1?QIiyj0A;%- zDG~w7ZcS1o0+julq(}rPjnf3_0tIK52vAySk|GhHwAUm>B0%Y+Ns2^(a*QS^5&=r4 zCMgmDNBptu5*=!ce!^~aI_`mXuM2J%Oq@t*mQACN!N7e?kqx*HqK zzmKzo7&&I)GxA6Q3Qs1hGY2U16w2&}G8d~eS5W2>%3Ka*K3|>r63RS}GB1QOU!~3* zq0CDu^R-aso79=_pv)^N^X*XP2cpah*RwxKsUN1)kK?Hou4i9kakFD@iEpC4M&K-0 zS1`v0OOo-XS@?{1NCEQo7`1{qKC!qxGsk8Er#SNcWGOPfF$s?K~oWj>lR9}8vftdb!1d?sa{3T2)VWmXv8$|zDu%KRH; z-UVgeug=`0r8TS9Hw*G55mJEqc56~m(W^RA=6001J(Rgilv!c4KbBH=qtxB;)C$8} zw#Dt)@HUXZsgCw1Tat_u&BA94BL%1r#Hba_G1}tx%p4O4oMOK`+frnlZWcb{EK-21 zQK-&5i!#rk%rl|PRqD(aQRX?6`9di3W$Mh!DDz^Ot*v>UCqL09776Fpy_5t2OwD%GpBbd z2tfxR12hRi2Oz^V2|))SqcjOY2Oy_u5`qpu&eS9X9e@Nh2|))S#hQem1CVny2|))S zb2SM;2OtYI2|))Si!}*B2O!sI5`qpuR%jA}4nS6UBIF)|q8NZYY6&s!Hw&NfFeyL* zh?^N5fIMq4a{>?rA?N^PohBjZ0A#%;A?N_)15HBE0m$c?grEbEZ!`%(2O$5^Bm^CR z{H{p|Iso~fCL!nmq<(9`dS78cLkA$uGzmcmAZ;}XK?fj*Y7&ADK#tTT1Ra1J>xqzT zf}$9JoL~tt`j~~!7(fb80ODpw2OuX~%$xv3K?phk8LdeOIsloVNeDUsnW9MuIshrq zBm^CR%+w?V9e`A75`qpuF3=A z{YsVh2g1e^cI-P~Lk~c^{>` z_fy`7p}fzk^1e)YU!c4%L3!U&<^72AzDs#OfbxE+%KJU#-AsAEh4OAy<^7ZL?x4JX zKzaA6@-}X3&4cyK!e=xh1!x{@MJg)IgB>VuTgrO~l((}gZ#w1eN_mfg@@A>>4y3%j zDQ|x$?@6k>qbctQ$~y|mJ3*EAY|48&l=nu;dkd8JZdKk#DDQog_aP|nGpfAnDDU%>cP*56y(;gA zl=mIV`#zNS3pX$JUC8e!^Crst4V3v;7W3$o%TcTEU6gw}<=!dJO_Kju?!6W_C;hJ= zg{1$nQW~{0Cui@akPQH=lopz#kSzeLl!G-%A)5eLDMx6MLbd_0QjXFjg=_?1rS#Aw zg=__2rS#Dxg=_|3rJSfq3fT_8N*SR^3fU0AO3BqEg=`66rA*Q!g=`97rA*Z%g=`C8 zr35ueAsYi&DYHD0QbkY_macdy7gmA`zh6sY!}NfbyUwDG~w7lbWPR1So4XNs$OpUehE+B0zaplN5;n znaAPSpPHD#qPoUz5|(Pwva!n@5TJq{JVPwkcd<^K{QE7k%&}&&?H47QrV_S zibSOHrzcYO5ftS}rSU;D!EY8M!A}ZMq{85q_@GZKi<}dxC}@$0n%e7;_#@B+Rl>lNN~;;$6C=NSNgz zOhF~6bZ9z)}%!uk?^A~DH3Mcu1SkT zQe(F+DH3Mc?}?Vi2YZlUX-z=LNrGl(;WJv20+e8Jb4w&OI#}GC{EC7Ui9|wYO;RKl z@7*;?kytSI)Fee>k=kFA6p4l5P)$-K7GondNs(9pjngDWV$m^KlN5 z_fqE7l=*Hb^P}p_&r#+lDDyK==9i<)3ZIX7n^M0cX_a%W- z?0nx_ij2)>;WNG^1*r3FRcHQ_GVh?we?Xb{sWUe|#99y3GYhieN(xY$Tak*2*6u)= z+fwF3pv;}4%nIE;ol7^X0Cj;Jb><0_ zIhQh@24$Y2&Rj^Dr%~p7D07KAa}{MSqs*01=6O+Og>HWtrCva(FU3zWRE5x zB;SJ*Qm?%kCVKC_B-?`%(o~ZWlIy_6(O) z91l)NmL?%2!-Eqtz!M?E2#R9Rm17AphMR@Y7)c6H(B)=E2O#4uX3hwvAOszNoTW(! zIsnPnBm^CR7@CBj1CR<$LeK%o`I>~F1CWa~2|))SS85W14nVHeBm^CR+^9(iIsmy% zlMr+Oa<3*K=m6wVO+wHC$g`S+paYO~o(Ne_P!a+V+&!ruS~84x%))29PYO`@@r5ey zca(P%<^2ZA`>QJNF3P)|^6rH4?p5VYC5H#eTSh5n;WHYN0%Yfw=9xqDAYz`e)%@(!oG zBcZ(GRe7gS-bs}AOek-GDsKtpolbcTC~ruWcOK;l=o&R?_H|A4^!TIDer?&-ltW0U!uIvQQjA!yl<-Ven5HOro8V#c|W&#qZgAe vKAw!68^|g0JmWp{&#L^O@fE3+XM9QioEHhFR#gV_OM;Ors`3L>`5FHQf>um` literal 0 HcmV?d00001 diff --git a/deepseek_mtp_dummy_ckpt/checkpoints/iter_0000001/common.pt b/deepseek_mtp_dummy_ckpt/checkpoints/iter_0000001/common.pt new file mode 100644 index 0000000000000000000000000000000000000000..650ff74769adcc5a4c629023626910850ffadcd1 GIT binary patch literal 25127 zcmcJ22b^4G)%S*OL-~azV&n>e_c;D~+e&6~d$v)>i{q%F5>xDCi9yT>~)KOFW{Tw@W z{8W9=>kayw8qp})IGpcZaQ?AV_nG{;?2t>7snaSLV=A&%nyjhguVG|qymmI~#py7r z$7_>APJd`}=%vYFr&ZPhu|MuryU|i?X>$0?y2`cD6!+_cMr_p*=!gJn#QTAcoUL3F zo!Xs7Bkotz&Lwej)Sjs|m1EJ^Z1$^Vz%}+v9acF$%o>eGaeveq040tZX|i_asLFAG z745aCJ|FiR$u;Y}xD}0T(BD`eSiEs$Q9o@CtQXtH7^_Uz-gY$VpYPaflcQ%2zdX6t zMaeN**!)7Z-j4eHxSJ-&&aAIo2UE0cFdkNuacoP~e7a=aN@ZPGn}dXwCdXl6?3H6y z2jkIjJgQ<%adPdI%FzI(gXX9g?W>NKhH-Ly;8zm=b9gCfQW3cE4be&lZ z7@9CH4sO0PtTru9k`reu*F|p*HXNiKolxAX#f?U%k4wES&yWcF1AX%Ai2!VXb(~Z= z1&9`RqfuueF1D(xt&MTIqck}ghjOC>qj3$5PCAPE^*Fiy%w-yGoW@lfj3S)f{!2)c z8?0!%0izbBadN{w>uymwhEs%YvDI<)$tlyje~VW9i0%Tk(;vmwjiZIQk=$sdaua|9 z1VdH3ATnVL2iz5?ry12p154J=>h^qf!?>E(qb?3CIdw%8$MeYAvMx8CK5so|$Nap) z{s8*|mTuuDvy~Ik7Bw0N*ky9l<;t}H?fIy2oeN2BHe0y?8>9KSx-h6mHC;9iaG{fS zYTdY6TS{(@LuAiBs2A(IL|aI1A$qxJ&)sLyoSYU`W4&rTOgmi+efspSAO7{N%U893 zr<00?!|sx^EeEP@878+}shr033Sx2EX^f+8wLSphYjxbksOXa1%C8Cg?MIzzEAEby zGXfRz_96Jr-k1(F6(ZSai>0yK5W{wQn^0o#Uu;iKF0lewcd?- z!|FoR9mmP$S@AWntGJOkvg)AOOyg0qWg0#8Q@IBI!dmxL)2JCITm3F;v5l&=PM@Sk z8If#T7CDR|HsWf*6V<5S$ar!w?zGyV>g~(Azuvez>R{UxH3j`AXRe4Id!u1l%I%mr z)L9SB-9ehJrkAs(cfa$Q{CSq7TW=6&BTMY<^nDg$q5<-Ns7TH(`jHEQM9iS)6i`y3 zZ=ecE8!aWbp5Fb4OMWo@x)UB)xgKDCyNiwLXi!b&OSgIMicYZJg%D{OPUD|Bi%jdp%;-hX=p$ZZN1elyRz3UE3kH2#B zW6yC(*J@OIgB}ia+)Hj-j3$N{gCTR^`2~>s?h%*{yPXjhzmVK+S!AJBa&13O!T5uP zxF)%?&au`V)aMsFkT)|<5L`;U5A%v7AR=7m+@(0>ldwSu zkn-)3`*2lM2zzG5;X@lVJKZ>$T@f3HU$q|m5!ooY>&#k-u?Y}$;Ds|sS5CkjJgMOl z1YW(3y=pj1?zSSKl68W5oKKUB_9l0qnTm8wIFy0m+Tz}TyabVszmtnsG-PN^`_V9M z4@T8tG-@aJ&>duL8;9U{Ec~8YT(48NLA4f*um(~7y=K+~79Z3gbxM#*fEdI9RBj1! z@7c-?u^-U6#ZVF@P_@$pF~kkX|NE>|c(8GAI9jTX``CQGiEg|Qca!_hOjR}jqu!1o zlEUe?qqJR(utukY`H8Apr~`sbllx;2 zB<_HYopj?gsm`h~-Uo8ddq2PxyBZr~bAgCCdYyinL_Vv4X{Xi4Nx?8_bYn|uE{#i| zq+WNNwRab4x;AOd9D8{ZUz9W}H|Cje;>lsEu!*X)54(d123*o|snCo1)CH}2(ssJb z8;XT@aE{{947VEdH0G8?s5}SlC8gw{sglg^O}aU0;ciDU1y6t6jjS_TN_w-h75rMb ztbRnj?P8VUT2iXNtRV_4^loV&*n@)=sZ-h?)~RuVlr{k_Hd8feGIW!Tuux`h3}SP+ zMoPxZmGwY`9)|t_0nR~aVdFxlr;&8p^Vz1d2ZW_4$)K~XEY#BXyo1SI~zW$BmwA=QR1>OoF8 z8N4>`&Iib)lVb|9s6G!~)^X;227?i9GNKT0v;HA6RQvUMH@Qs8nzD_j-|55ZEv%!5 z`i2@YP8*B_dg#<$T0d+>%Z84qgYZBhLmcjC0H;BOmOR`on(m-gQ2Hal&Gb!R5u^jF z4I;?JM#qxNEB8n1rtTPvj5jsY#ZAM3jhcgQXOM29+Yxm(VS$!zO2M_gaVzTCrai$g z*>IOVn;J2+N*vEuX;n1Fz1~u_K0h37DovC;(v?<}&Vw2lt&y8PqNWgG6%PAeK97#-ONk8vFwtVS8XzL>`@OA*Sx zl?HoUU=9gY=zc(8fl|VTWj@|TwM*R$!C}=7L}u~?Ddj%o$U>a_m!yqW-j%e550pI7 zX`ME=7VFkx@+7}SHp6ZZD0%XV6e^4$+C0fqTmgn(t5J8@j*_QN@4f<8iXSPb3^;0o zE!%@`BYE0P!wt{jsMCYvYRS{xd4{!=JVRGC$Uq{@AQqBmu83_5U;w~%pmf4zKC7?{ z=)iEJuCt3il!_%;lRSGx3NdiFgYwxMTiez4)<*IiM*?{bp{Jo$DUkgi?$q`m9G>f{ zX$XeK~3JGN)#?cfpwTr3~!tRGb#3}S*@!Lw*#`fGobo|F$a0k%VhYg_a?8&1stSQp~1jB z{d)4+6>0P)r@|TFud-KwTut5*+-h3f(7mHnvR6jK zTk|N(Q0gp)((X>)1}Nok%K&wn)dJr>b4bx%7#ztv&`!fw`0iFh&O-DoVsY;*{G$Wl zjE7it@~#4h{*}g=hmi_zD0%nv?h9*szVy3S{!@xXuuOJQ+XRW4yr=L~$LZ>9}b(-!*{T9xw({IKmqQ{v+hdZ0SPZGb=A9fL$gLUpfuH^kP zX%E&DUIhk7ojFD@jj8#P4|u#Hbb_xE8x2gKE%~6|ZU*fmCN}6dJFVnH-VG~k#LbAt z3P+^1H%2V#!-W)MWUP!efc!~50)yxPfcfO3cm5B4Bsb70BiBOq!8WFCm7hxTu@LAf zkQuj&U>QP{3Rl)Q13DjfZo#+)cLrgtgJOQd{TpE--cYTN8xa|7*iMsA7W50h3)Ti= zKiw{DxitAyzLtD_Ah#Ciyapd4#nBl=ki`6_S1LF6gMo^JxxhdfsQ?ZKwl^|xx8GBJ z5I2+0%z88nahcp=f$IRTGUJ0@C)m%QU6JF7rBj-fBxZe{XYjd|%IQAuY7qQj1W|&7 zb*4lFCHef!R#7+x%zzNT8Vu6AHAuc7&DZQS2gw(+sk0-WIJhsZRBjQroYVAywSw6C z%d^=^i^DvaTvdm!%okKF38DbegGH#=Z#s7aCsSgu+4{jts*Ygb zrOCJUOdTSH$3tg$w>F^v1nV^U_KIeMx!6d)v*$~PNi{<=n5afhLh{`eSs=x3%E0fr zwMD&MtN{L3W6+cNoP6KKHcjcUL&)t$a>exSr-Xo#JJunttWIUbAGl{y?!PZu=s4@thDSnuPzpfHP$<5uNJv>lQBU4+^b8B{0-KO8!_mL|aS# zG`;)e9asM12DjVb0W~`Cmi)QcBo!)Ff^13`j0xCZJaSvK(IP?nUuT`*;lU?=`~T%> z6~~w+e^+#d5d)Amlp&)T{{T7BOi-&q|6^@A@Skop!*Uf^x+?jXM}x9PaA5Lp-wDG5 z$$#=b{FQgt>DknpsVQe3HTF1*Tic*FuqT~DTO%YU&M>y{Ft#0{Rz`Xm^QglFOA)_n z1o28VI}|9(l23iaB!iJnn|WfZc33{rDm2b6?2K4A!ke zM3l6QcjUnaEa6et*ik?f_cX^{IuSN0;ib8LBP%&ExQh|oDs@>PeiLy2=3S_ zPIq0lF~KWDw^TvJKH2p&4CUzrl!R+133kiXu<;NYF8AiNMvBrf;K}IZwwkYNf|{Hm zO1rqeuz9Kn!?qg$W!cjr@^-bf8)|6>>nEl>MMIO7G<1}vV>kktJcW?8ah(RW-AIU( zLB)AHt0Vi7+IoN`p~47*d1zZ`NktFsR6M~!U6s0JJlx^H%4&BMUBuBB6R z*|a|2^!t?D^@T zy0i8Y6G%vXrgked!*ArJDLKT}RS8i$12`T{@$i@D{O9A4Sk01uwn2EbmWs??7P`ZN zIlgTa4#V*)d3~yjKx+w=;n}tEhN}i@@K3QFjQH?354WmVBlu7*v9*W@ibJCsj3l-l{TQ9yiHcwM#tOyf6fNbL?}3 zOe-vl;0?EW(B5th&}48g17+rB6Jh7Johv+w&%m`vdLR_);c~kTpzX^%2$ttW5aj9f zVw?q2Oh8X<7toMIq+41ZB|}SbQGF0&)MW62OliB(!&6h)%1i8R=V?)%szVH;NODva z=4V=h*&n;D@F?269q|OHAlyM0vDJ~C&tl(;oXGd6uUhj zZksm%@mA#=?GA!48bins@+kBX*()5!)B;-5+M-z|qZ)Yj$_FP*P-~GF|Gyp@;E|OKy=(6o0Xcg(!VMlR9KTCs6Jo6xxzp z0d?un?x{9%q<*h@qK?^1xIA_*Ks}kr9$ENKM!Qwila>{EESQ8MG%PQIb+yo`$98XF zvHEJ#lb0r^Rj3WbttGn;Fl8Il7oncF`?7_SoZPCgc_3>P95HrF622dBJVEE4#9*vC zq;ZqHxmS~)%FT^iEmmK^wnmZ>4t|JQ;a|)#x?DT)i}hosiLMxrRvNkJyUbtI_|k@Gd1Y4_f zNH1Am<53WmI)us>8vqCci_j+P-6pcuHq@w;EKCSg6-=>&T^e3Aq1ND*n-Q=^CY-`# z^7Aa!7-UCiHD@3jsZEuoq%Dln0pD;RV?oH@s`x@Uq69_AbClXbK49QAoY^AUJb#0P zA?b@W?Io97w8l!9H2INpSs@%3&1UsemsY@3gsa}p-dnOfed?R0biInq?5~XBqGp%0dAAF zn8(}Y(L5ALE6OgHkfhS9eI$Nie8>B$J3*LKB(DnSL2fA= zU(npxhOlYvDZrWDJqjUbQYO42ltD+mTBkK0j7gKQVHp%p?WtOu2Y{d&ur_XCBwAlg)YxS&E@P5*eAt&{&F4M{abuLPLoGgNm%WtDcZ+Kdxg4i98~GU;8f!))hKbO`e@9Eebd9x-rwd}P7(9_GxayFoTC;zvc)eSo~ zo>kkx3UzxOu_JD;Ddm})h`k*tfW=dhu zY2V(YUS!2l1|mGYrn+Z?{9=4TxVx4JNP093XfpqG-4@5PyFUP(Vt& z`vQkXA?Vg)Jei55pn|!H-)RMLL$(=6m7m7uHb{c7VKk2)ii3raq@M*V<5^b&l0 zNL*DAyGr;Vphbx{PmA`MDhMfc8iRUj9|DdidEgh|9Sv?w+EjY>~m~CRFV$sLm@J+{_|{HC#f5z$Wwgw1;7fmPHysDV<*BbFuyo~ zNg5Pp=9Af%0AH5C7hQt7X>;0_C%W;Z`0fKl35_6=mM8xeAbDzkf1syjz6uOa$MKMm zezvc%g=ZVOFd$ST{YXB3olrU_3y89OgTHWU6bga}ClJ3$7(=_GPJhY1#ouz{Mk9Tj zO?>SEBN%yP`woBSmC20ol6U(q8jE8jq0ms4O9PUux9xjC*mI>n^FjeFvs?6!c)fD9 zz7P1!14N!BFe0ER>cI_z;p_^ctX1|~C_3iSJ2LwL=a@Q5ha!yHU<97TQ zp-$lZSU7L|#pS2G<8HrngPMnOfZ_1i2j`dCPq?sCL|^4M9`;i$h#Hm6TUAnLKT{*^ zDy9h_5mp6Cb>u(ih-(x%3|Jxfo5xdy_6tYLnAA~I`z4xZJwC4oFL)~mUWpU?mE%+H z@hQxz$A0!}^`kB!Gzj1@za%~1XV6~NM?onzIR=au*i9IJG_*enLJG)6 z9}nwGC`|>+|5-39XTlITBEL|jfIm@?RR@l9@?U@`&X30X+_e8ITe9tho(=Kp-vDuo ztTbo{`utrJ(snPwTmi7$Q2z2iI5Hz?A(LVMWJ9*KFcr0n{R;s2EZNNj5M)|bzij{3 zd^9Rb^HEsqXjA(S;QlV61Q$&D%cc$?eANL-vju&X1s&E1k)i0)Kq1%^2#a;Yb_nqN zSy~BC6~!T+C`aK?U=`(mjKP(#5jdjNG;vFIn8wX0Iy(ti#JocEfO&!PiSVe;4i`3A zvotR4=L{P`G&@2_l-NYdE*>Tna&mSgR|t@Payqy~v=LB_0*d>%?2k%9WCBpTNfJYq zq+LVc>}nVkaGHj8rY&u2HGM`ZZk{3}CyfeZk=iwZBTWI%Erhj?W>ZG(;R0w|BV~Lo zjhOLSvx_i6X}~O(aSR8bstLW0PfX$pBSD7E8(%AjF#*=fXjSgI3TqX`K!WGnBDV?0xb1f!36Jtg)M>Rasw8kb50VI0H^Jj{7Rw9ZPWp<1(OK7|vkmyrRbCOuC(?nZ3B zOoaq!Mv#PArvkj5;Pc$$ELug!R!#CDSjfC5$mFoP?* z@}@)pzJ5gb9tP=kBQO-y0K$tm1I~)1m}m!Cxjv$6rC>}bx3-%D(bH{>6wWUrX<)VR z3f^uZD6PFRG(VZs1g1(TgC(d>7l?AY3{twq#mOy&K-MfHP!)pV@sR+@6}uI%yyS+Z z?|7_Rgp;K8&&a0*l0?^6PkXSLC<T`CZ6`qRa%m2-k+rAD5}mCMJi1zh%0tkV zQcic-qBOCc1F*l5;DeVi9@8$^)vW=`s{D|vuygr4C70WZ{!ofjECpfgdqQ(Pqw~Til@Rlp5hOrb-_u^{2d8dpn z-NH+r??2$sYbt_5ed5k z+AC}V}?JqdjSKn3o3Ft5qhXm;2NuMy?VQ7 z)Dt?pHz0E1QHUAaefS%NLkSILJ`AAy0sjS+`FaxR6$p3R3k>&hi}U9- zmO?xC2MWG_!eJ(HN|}@@aP~~CRWiCjYAk__F_#9TvH1YacKXCftC3Ky!Q8aON3~V6 z2z{o{yFL2gr?MG;9DD+ZV(gT4AQMB>FvP5$qFhZ^kAq2Q)wP>7dWCQ`J$O#=g2zZs zrq=kM%)uT`t`-abmRG+W>W|`TQ){9V6yJ`k?*(>86eJu)SRfvA%SLHo1dl#B9fR#* z9hT>43uDvOPjR4!7}nwH+$$(*Fkn8&hnNG<^Q_8Y)2<4|3bZ4YGC{H2ye6k$>2;BX zZ;$aA0~wkbLYH%_39%zDjT06wJwxefjQ@sEiXBUwhfC?B>z=9eDraB)?#rD!`--W9 z@EdTA>lEoJs3#N9@fST*vIhdeqmulf z4VDbbwg<71(M*mhx5zFuwVx!H!`G$JCa}1c6$VV z&8}&Cc1@R~QJ!Xhp*h@ckL0g@Y}*fTNsrPHeni`SW44So7ys~=!9by`G?>TA*sZ7= z6&-`vdd{hIFl~@W3q~#GFh0=EGZglid~k;<99CzA1vcpwA&=Dnv@{$gOX&E~+2a64 zU~01Z#ZU_sea z@vB!Ag`%NiPK37b2Ki|~C>&Ojb3mkEb(E)1pwL_iC=4Cw(GsFSG*0XpnoY4mr0wwu z9M}+ime8|j0s>!79mDc^Mu4y-cr;6?dbN6S%rXc>ql&}Va6+c{R3y1BF{9>c*`MznQ8wFI+4TNl?8VYfWrM(&G$mr0pDexG0&x*J0EyOr?6$Yz3 zbBwnV<37sa6k`-~h4OBDn-4_oQlMqAfDGyFMA|38S|Cx$c4KAHGrV^=-sBZY?ICQ| z#ks)seJ63&RMt;S2(bZ&(4BK|5+dmw(%$8`R8qyf<#^+}fe`!~R3p8IzfqkQ$og_Z zOph}7>h@kA(=DXze*Jq6>>~s)w1X!@ zU{>bz?4xYsDnsCd9~$7-g;M(%A)G`H<766=!ymlTSw5-K7mHpHuD`2 zH+<}}(*Go&p0LTFZW(6z*-zzg%G3-e*FpAaO-UNu4+d{gAOEq>0LPz%Heh9amcP;3 zNt+}4oVxSk@tb>0Dv!tQ;!_gbAD^y$o?U3f^;va-Z-Z`VZ&1ZN4Aw6YCJ~Ak#(7dy zXZ}Tks3E2=ym*f^%a_Th132~_AyAnX<7No{sLt=^yTYT2$X;?* zQDwv`3KTus_x6W+fdxMeXx|4^3PhHoUBO>VS6Rw^{{SGbQ4OnMj^cnhmMeuRKJ+W% zrveZF3AJKZ3Cxq!aUdX?75du`1)_lyxTggmt1f>e0Lv-hL(JQc@he!JJX;b@KVhqX zJdT{gVt0T3`%?gx6?HDVAo=?VzZH5PMTz>)d@NGD7p_b|f9_Cfu|YuZo0 z)c%Eb{{3(KEmu1@tjH@M%=jO)?zJh1)zdDuHT+BC!?FgyuvR1^FgpZ5|C&Sg6__9q z4h3LYu>gE@R8>xgp`|dhxV^0Y>F@~%-383VI8*j<06Rj{DDeh6uWNRsx`a;*w2SWw zwEt8%FC0j@`{J`Jb`<(8OGxl5K1g$GobXxG{41u9w6$oLT4B`nnncRJWZEu$S(Qwz zr&mVjIKXVgXozu4F{s()R)4J!l+Bq(5ZuOG#3BpQ{ z19IaXLdUN*|9pfUg8?qJW3|$Wxn)t{QbLK5hn&{|4`fdNpJeWh$FXru`>OU2<-(y4 zu1jC?;CWr!p~bC;z&)6;*s$XPF1{|oOV!sW8g{}2KuHk>HnNXR*mVGJ|53G_AgTc1 zgePjk@f=xf`uvN(u!F}T1qW(GatFmtSxxSBfqJQ34}YE}$;21*FkX?x@M<&F>ZNuP z5Zq}*-5}(BcuWR|Y=DY&g$(!E$v|2jk2hX8bL3Gk`2L-~v1{+t6#u*MpkHG@acWog zHTI2ejbxm9!TIN`{wVw8C)%&}w=G=4+kf*NKAf21e+OT}NmD!XC1j=fvVNF$zTo`r z|EpE(*vYsc2Jw{F|AdHd$gTQ=|9w(ZO_x9vRZtnJ%2pSgAWPOkDJ$Gq?J zr{Q$?--R=4xZwQ%IiO6%6rYnnj0czge*{){F4G@P!Q$ufV2aRr71^&IqfA0{1B##H z?q7_tAKk7Vrc4F2&WfKW6~pAzx4L)v9;tJSpA)P3A_Cp4?pr=xn!Cl%lZw6vzV{1f yjyU3sfPc!N8}gncW&UJ;^DPj6c0NFV;S!IaHl6xtFiXb3OP&xuq z6a^8)-ld5MQlw+S3jAm8o85Qk?LNM{{LUZGQ4WUB+&6b-?zB6%R}tmc%@=zuGAy6| zZpnv6AL&xPbe7NTaHrj$XLASIWeiR;HVxR^evdE9WA<6o{6Uw?>?=&m_c;PKUz!bn z%4ByqZ6@VUPIf=Eu@m7Q{;Km?!yki|ltSFbnZ|Tu6?f2Oa+(WmK7X@-tsr0`xUQv2 z+*2W;La`D>yB14t)u)8ujP)FhV3H=t-Z0WG0En(c&s*`DNfXH zYYb`COaAt((S+L3*xs0AwcQ@IOn9vc{XK4*y4n&~+SKbzxG7bBpJUkbFyXc|W*V#J zIILEi+vImluZQGdUX#watXr&e+CS({6X0bTOqxrZ>-M zp)PWzf8~9QrEZJF(qv?%IRHBzaCiXTX0~c}s$_VvMcb&T_YHWe?t`?{(i8j7GL*0r zYA<6CV=Wai%WTQDxviW5)x((Ox_uFB7(nS%k3N@DN}YfIfnBnH{?%mzyik!WGe zFjmdYGg)%XZnw>;S+FEpnHefvp|y9lz6AHjH_Ki%>>5XKU5xh`>ty>pL9c0i(B>;t z`N3CRHoWTMJ%c+us2Pq|Ej_0BJ4KCR>nl+v4Sk9Xx1w+I9*-t_Pd|| zlr7ipk(aA7*b@wRg8@^vIiQg;{pFVKx4b%&vo+%h8|AE;-(wHB%mt=Ep;rT0KFoo8 z^2~0B(`i;0Y@au?Q5WJtk=OUGF@Wq70JYa_t)M!^yw{1o*V! z<++A-c?8(m*vVLj0D%#pVz~~l$?mg_=NqOAeZ18(mVN}?(elk}hUW_ix~s9Xu^xWB z*W-5xm9V+8Y*wqoovqn637jk3>gz9DWZ*|8$l#H{Q5Y?O`#fGr(proub9FGSy#9rI z-ZIphNZ>t<_Zyq|Y)*5)k!MpySlG&}AShu#^8q?~@>tg|ZVVv`ra|{VZ5TR z|9czsrTo&H>tE_hL<)a%zhOA^BoS$CY-y}Y1hUM2&7vj5^yctO74aq$H=bu|USldD zqMFwTE!pO?fH~!&D03oBf-L-d`I6qW$X4~sEH@0AMv$G19gIl?>2Sl@>$I8kY#Ku< ztyN=^0>|5XM-cGzMB6sQ3)2a>zp<~eIRX1^<4sODqXYr32ZaYt0lV1}@c8%#ND|4_ z^u-d1`8|oy&iDUaXgDx~2z4;FGbSo~>9v{s7PAwyh@eGEOpIRYZNKw zY>=2c#21v*{hjp8cbLR2m_yhoan)T48*CQbHeE=*U*(EqeLINu>y@R4-!_z3K#)C* z-HnZ0=3JX8&toxX30n<{o#*g7vcTWVD&!O-4m0Kce#yGQ%;5;WLIluu>3A93@4^1?gR1V$Ir>)yZzCRCqz_=6tIVmJhF{Jx5Z>}nq6K~p4l1HoE_4Vxz;s0 zT5$23)mc44fe`bYdD^d)5N7uT0C z_x8Xp0!R0@hAZIJjY=g4KU1|n`&vyq$c~K$y9p4Fjd}>^^!WX;9&*joq~M{w}YA7YqA+Z8f<(F3|a)i$hBMgB7YD zT)Xk(Gnn|C`i=lme5$*G;33C{IPG-(^zI}+-A1G>Ff=?&h;5Cn!FC{`-|KV)Kzj1{ zA|;XJR<*~i&5tnu*mRUI(Lb&wFwysR+x!qacr>um?r40M)BSq?!knTxKM*cDMKzVR zf-7$W9#}O8ONDTHpoHHGE{)IZRM~16_heUQoe%yIh2SValLYlNpC^{3LN$;Ra@5++a@w5yB-IXvsYsF}}%vEU?M=nJ_W1 zsU{W-MmmxW6~pT^WH~*S+CD)y`GcT^u?a2LXLjdm zETSX>qneLV^(@7`rWTX06Mqt-$a8{Ld3<0JZ7vVGagerw|HisyvpNl*&+<04{~>J5 z+temdzuWBf=Xe4puQ`yzp{s>eqipD>(!AiUJEfR++rbufm`Bk2KO!QMsV;}b=P_lO z1HkP_??qB&>HL1=NW!~JC;~SL5`_Z1Z%-BkV)0=fnsTNaGY@Nuys8vF&^*-gL4*Ja zWj#z?9qah{iRvuH-e!$GH}uqljruSVDOH?kpolcE zxx9fwQ_u}^9TpTP6n>n1FC|t>`gh{D&7WE5tlN}Z@vx!IJ%s&`ahS24CC3J_8Cjtm zvp>gVhV2A5*Xo6E63#%2*{@`yAiCo8bXk_P+gqL;Sv~`i#3L&iOsNCT*FuQfM4Bsh z;Lj8MYni=!tt0`Xy{oDKV_8h2ni6LuAzWD^`xuk!Z%Y%J$mP&7L2R5h4W90o-}~w&Xo#mTe6P8qIPO#jK$axXqYNS3-BpWJJbo5)ICj{9mp;#?10- zDMUk@+$s%NdA9fnZLL;cGQOL$uQRLHo@h+iSg&n_=E}lU&wFx2QwOjn-MA+joK=zUQd9;ZjLfh6)$PIoE@`X`Qj`K=pEe&STw0cY}xG|h_>vk zpffibu?8x$X0$^&*7#T!qjFI5o z7K@ZUIq{EcFS9Dg1`#}3mAV*$qGfTop=m*P0gGsg?oBjABPA#rYUE2R(Laq6SrFt6{i zhX@|MzG_ZS_W$;CB>#Ns>z6t%C3{rtRJ|I8lOqXkPEq5FMZsdIZE;W*)i`-M;mG#k zX3ijNv8|cipWn>u@dfN2r^DlK264B|>}UpC1v=&j57!mUHoJVy`q`j&$LDdU_U+fq zYJ=pI&6aDz^uHBa?+Q(pTrXxm!X0^Pq<>-Qy9=1RlQ@PJEZiNJ*`Ev6In>Q#5BhIc zB<9LJXyvZr>}(urB4|<6&vHPNh6P1aG*tZ^o7?QbNTz(#%mgcJoPbp-fl+}a3Fgba z`9PPQ?7-S;AzVDLnxb~VLV+1d;E7zhgA(SgCP}-Pp8RAZD#G-O)l;D&zU5cFHmoz_ zcIh06+qI}(EZP#{*~FF&ov$+e=sS*JMRid04#biDEA#WyMcH}%l9S-@yl$X`@sMW+ zTqv7kOdp?c4E?5MQ&xlh*iHCYgVo?G&?W~&%iLT@Nn$sI5P4E!=|u5a8(1JyWIPcN zfsBw6r$|XaPZ>wmmAr~6Y3a<#$n+DYa5Av@L&$|v+~uJ9da|c@olY#Sw+0CnQw7x! z6)FV+IabJik7wS;av2Nq2~y;4p=^a^rdZM5n)yswmbUz|kkCZhlF&?Mr#HvUF{QL+ z^E4^uPF+-C0Apm;M0WV}eQ z1C9NS8D^1ARlE}65<5Dkxfkd*}!6 zQz@*Quo){>S!_DqT8TLv-_9j$Q6+`;#~|D(00)#7A1kgaZ|t9}4Y%&ggo|ysjZi2Y zqRqr&2|_V5)H)Q02bMhY@3HJO^e-d|q9e&_!}=)r&Q=qut_z(@Wg}u0oW?U&&S9+C zut;VNqB_AdwG>8$)mP2=2giKP4Efb1ge%fOp+HEfcgGyXuFo1e}%q<-KE|bqGitn-R z-srb`dDuy`ayy~oNt6s|kj8`3qF;pjk;MXO2jzP^HJj}}$fB+zI|&w}F1WYj75Tw1 zF!p!wCF=lpv#x8yeAv?O5HR|%H3hKZG(-x7R4I}A1bQBsIh!H2eK!J8Ep6~cMn}9q zyYdSx`|a67h$8zf)~Gb2Lf!y^u6(T!NT)7uX&e1Pi9Z?g%J*-B9Q1-v@CEB4Pq=4K zfNAl?4+vQVFJjHPP;G}6G>#|Ut;^z{!3|}>Jmxl zZnF8X%By@I$JXcG`N6)&n0+byn!rW9CIs$=dx0Le-Jzc!=WG7`*y7>LesB4fV9|cp zg$=RV>}D)@ATzKo8HC&(yUZ#{7Hh|P@XaYsrV0OiPtd3di88bY;uQ`m83ZYy?NnsM zB^E-YA0bE#A+(UW98xlbsZ!!I^)1WE(IhAbKmGf8!@^^PImkG`m;uWSoxEyn2rU3u zz6#4KC#1iLUwm}%2vfEL$BC9uHi;JA?=(Zr6Vjl9l<1F$;{wr`pU%x-UdzoBL`irp zK*(YWaP+;SZ{U78u`QKt!7 zbW8v=1TL8Hhf7ONClvf*MjrF-wqtA`W}MaRLwIZ9SjmQyI}7@9{GeJsh&v+F4*K8l zkR|sVldikZ6D5?cO!x*!;z5=iZx5nTg&rUftbh+3RK*5aJT^NQFJ1dHv)PZ!*jIJE zsQD_j<^X*$&|?`o;9ZVr?Ev~;_1WfPOZt<}?AuX%o#BbgM5&vxi?JSj75dy#d}+yz zLw5V!Th26w(9NTE?J(@QO6Y@)1C6Z!U5gA|vvAde!db<`f}0wgZIb%8luSzd^5wr{ zpYrFg(x=pqM$GKY_Hb-zOY9E2bc=gXtpNdsp7Mu;j*JQ^_MIrmtaa57Ge%kdlo-_l z&{Z*xo~?-V?W9lPC&ep&Oinh@u83=0OaCTfc&!U}MJ1*eI>te6u47z(Cd}XcOXSd0 z@^8ba0z}U3usa}*p+3x*d-f*(_RD7GDO9*gl+aUX3m+Cb0MrjuC@S?7zbQWlFOfpa z7MBWr{V6-TX?f-7GRAyPD5thoI(vG>Twrfa>m)KlJP>mJ*y?j1WbK3w#8) zH4ZzaGsr`PcJO>{Mo@pzeG80JSx#n00+A3`CxJvLj|9arH?=UC0#Jw3*k#FKXqher z5HIg`e!>F2vv(5#5%2*4ann7^gMCp@E{&Dj$^FVb)Q%*#!kqIb4Ydt~+sv41EDse& zjSnm7Kzc@|G)>~SdA~MfH;YD>BqZ!Gs{yrKpRnGdBP7=4b!u0%G|K=kC{3Uu1E>n! z9bo6)e5OCET76xXutn7h^P}pnL0g#B;{r8yXqwISW#ahx(d!vQ?yVp(q$Wa<978az z&t-^6BsthQcVbtTEqd@?0u_Dy;eH;m(pccRA^W!X?thpZKUbOHQH~qRYc4S7IUoc< z%MwT0^_dZMyR!tto~lFu6AU$+&?||zuYe8{dk5opNCS zTXDq(+7+99E@+mIUeVHBajnHGcCa&~M+)JJ1P-oP!HvR(=+V^=>)MxK(bW@;2@#{K zS}I~x21LiaHKFJx76g9WlrS*}L|q82GQ#Re#{|n*9A7)u&W<`mGlCRH9ah+U4y(-s z;hL~*F1SjR10|U_I;v-$O!=HePkqvfgovKdxfT7WUJv-u{})14y9_VSgkxGJ5fb7o zF8+k>3&e9}^RdHCM>FNy(~@vS83L~U+zD#EnH3K(%lTJpLPg72kKL|vxGlbD7EZb) z2Z;(Cn-WmHm{Wa=J-|-t&h3Z*p44TXnk1ZLEiq*#=gbBaEA4JiXMXj%YeyS=9SEtp zG0j*JGv8V?J$V~Ga7HDjXoorxlyKHyTVQsQ6C6)1pswn!-`Qd&+ntJC2vO{giWqH~ zx~SfXCA-X}Pws2T>t`6&jZmi-H9lPw93ALw=1$4~-?lAPMf~a(r+DzLc_dlF& zS9>8V#%&!z=%Q-whMTWc>c`soZ4^P5AqfZDr?B;38zrw_LjhYWqh!~0(fv=Z$>*Id z4->k$1foHwTZMv1M7yXijkz~I>&?z~?^vQCI$Ja};CeA6|KcREbIybjtPHU^i%4J@ zqKQ6<$h1$@uuoma9(2OEoKOI)n!EcNm8!&Auzs}?HEhAEgw^9f(5*4c`tjD7DO0~< z+17^H1Sgtap({evb4V+#EVvw5UDVyl!{0ehu&DNhT!IrV z+o38eMq4qoY;#Royfpa=v-$g7ge?-o8utvY3uzJ%`usKix;%wlG5O0w$auvh6_KI0 z3K7X#;e}4c3yg=%+6|XBF&S*?BQhw1Es>1KaA0#DwO=U4@pj~-t>El>WZ#NQ%xwz< zh?)p^#f3hc>frEWAC#`-=yUCd?OjW=&~#NE5fGOl@N@8r6=`a?cZUn6K77a8SK6@k zpDvKskMNS=6oQ)X@xh&(SP7&qJTijW&6IHJ4WT@^CjY`y27l-65|0Z00*y0G1vcNuFY z&$7t#FIWwUQwesMafq=sgAIBi^%ZXM5?g@bVzkfDm&@m;J@FMwQIDEN1VuGZ;X6AQ zlv7=AgY_wSmUVuP3FxBfL`gU%dX#kKhueo0s~(xnT=I`*5IVZ#Rg`YWI5wj6iGgjH zIr!^2!V~5o+G-=+D1BzMKK5QR8^bVo7Qu`0nF8MIhf&KAq*H_mfv7KNi7Z{~6B)Pt zkMG&u&Yn#~aBpj?N~<`9FlR<{#?%Tc*S)kDPc9@9A@?S(4XSF^sK)Ng50Tg|d$!pO z!%uUFQa|GZ#x#|Z2sEJ82^|e=1BccNY=GjMtp7!=22}XMN_4uJYW`^GKA-4}Fb+4i zLppE=R&nqnsbE#bWV72H7VPYdm1cu4hTX%W@W&SrO^m|J+r8LKmqyk$VV@VACK=sS?)mVIFm&C$sFYuOdjLNhJ&iO++{lnQy+t2OK@Cs>uAn zAYVrk-Lbu|T{L|CI>BZbn;Wa-6ylI47$~VxK`CJ%RTo}P%_+;Q)7@(cOjsu|gbBek z(<%8<^=06~z|zH8l+<%Q!HfQXwV`fzZ0)e(?2;&iq@B0mjgmB8J~xGvETzOla>Q}ESa6EG^ch(nbjC~y_P;j)3G z;&>wtzZgDaHk9gBfGSGku0p_C0))m)n~A)j!Uq$U@k(oTRDZ|x}f%KPVe)q z#pAmZ1S_WLS|J&Z1|sSs3%VP)aM9|b40O#YEvNz*WyqZ_=!V}LcVuSlr_%%~;&EK5 z7kqnKsW^wl%_E#O0e|Yx()gqLQnI1#!w-%nhcu=3{v(E(=ZMl!;2ch8K9h zW!Ce%t3(d1XOf}>ck&&+yFPl7om{1UBVciIh4xut4&9{^{&47+q}y(UTMn;kE*N{|7c44jT@w&z#SnF29MzQJy~uTf|XHIVJKmk8W;yuUh;% zD;Tc&o4`bQ)M=go&4x6Tg)dP0cJA`rJNfcdmKHz!FF|5jyai~E2j-;<#Z^r(2z-de z${1lzOrINB)$c0Dnmmj*iJEBgkg18t{f-G8>yU7UX>a=?%am?Up}nE>cZc6B##K>GVM80kqC({0`&33Wpg@k;l&ob4jOT|N?X7U z?z5As^k92ix-#L56GN%?dx9ca5VV8)7waBO>6iaEi+HrM{6uHUvkxw9!TfMz3gM!2 zU(G`E{37QhNq53WH3S!=p)`o}r(tLLjrnek2^393^BicYHp8@V5P0$YPvpuq`zdjS zi~f?U32m6Kl-HC9i5>&G6ofB)L46qJm6#%Z8(r*92|o>Gx%V|`ge_{qIfYr!WQqkQ zwtOy#`qp!0`tqVotBy4%bfwP)BNEjt;3OJxsT1NC?M!jq514XbXlv%vm&qg&!lxJV z4gnFgGT~9{h4CwzrZfH3AAW4iS=JHRx+OtlM`Tr+WF5~nxewikKW55#s|_KFjz|g7 zG!CZQa5n3*pnPi0R3_b>+Y_o72rQ#UduKpW8L55{=#)~Kofm~22~;^R@C7k8v6C-W zpR2!BYw!_kz*%!20b{FB6&x90iqgw1*J9+NNwK6Rt(gyWsSCl0MoR@Qjz-JBZfx+d zZo7Kj2~WI)BCh6$=A&4@x7MJ-4s3xVdq@kc40)&c=0&VHHHc4DezF(g2?H;VDBz(3 zT$2dzKz~@R|#Q_ce4_Hx4z9V`;C+_1)OL#i^J--H;yUEj;%?f2~%{6;5n$w9=ap+ z>J1vyWWM;0hY4Ny;)ss8c;8#PjT`%!wYpQzJSM&u#}dBiR6}g+6ag}JEZ`_Nm~g(y}s}K!pV2`tW%=F^5`To4H##rpJOl%=#R35T>v`E_l92Y$(@g>AdA| zbCz2NGmHLvF2RZZ2M`H-ZZ-Qh){MU;oWG-6S?2Xta}%Dp9TR#FJ#O3s#rOTiaQk>= z+)$A)23+osUhBB00doze`Un?Y1DMth`4V_Z1XNIm6Vh4Eo<{%o$8AL!WA_9k##UAC zV}W<8twBh5tWRUpM++yieJ_$vc*?%Ron7={G&v(_-{BEEX?o@LlJonqRA6c$!HG^1 zm{q7Xn!A$H1Y)T$l8r-@A8VXQuDPpCD+|_v$te z+;qo%!rWx6n0K{vHDnc%c25ze(y<7G*PttsuNR|v)h@J5ACTtG)62?3imklkmkc2Te8bR=nAj<;V_t zf>3OSqLpX|*C(JEwe-{SW_y@^`<^EPsNZnw05U4TecknH;9jdx<_A_cST~E{ux=pW zZfkqksM{r5o29cj;`nR=!#JWEuFda-@oTp;W(+XtpQ0^U>{xv+0gBjB0mSD86*I{{ zTC!=}me;#_br`Fr4x3M?qLmF%;lgecL%PM8Ej?nWu*GbZ zaQBaCGo0m}maimS%sUBOG|y3nR8$GGyL5*~;TEn3Iv)xGk4o6HH`lVUqyFkD!WKT1 z`ZhUVu&%)U^q0lXcu@BZi3hd7w!qVA(v#;LRyv+CGuS)lpMIa5&DQzg+Q@ZAK8UEh z&L!iglx28R*2{P<8$S9I2j2SAxgM6I+O(1I#PxWEkIGAv5$C;h=t#lp=|u}3Vyz*k zHWMPYhSX3m>_lcurDpJ6koJdETMw98Hw1^6V1--Nwl{w{wb@og`E}uokq$%Etwdv_ z@gaDFA@nztOF?|E(!#P;FZ7T?D!jN5{=I0hOKp|bIx95`LNt}=kM*`c#zdskTSQd2 z*%DDG6GfgIx?7r+x4t@4ZrcvR6|ty_i_fLPB``%{G>4ondSTj?cFaJ(u#3<|ix*Ar z4HMzKlH^)*)*b6Qz9&0PcD+Ne;xq|3as?zNBp`R9nL(qw*vGGoc%Gg3XWu1Mk(UXh zqF)pvH+YM5>o;E7!dgQs?Ii-(8d@(*AVzS07VcWO+|NqFz4j5fC<%vo3_d<26rmvF zb=aVLUfF_J8K2O%^v}$w7VIY?qVYi`l4BRs#V`X9EB4RcIdz%eyX-^47Je_ohQS5S zAdHW=9jElg$M*fpUZzWjJ|aq}OSJ+%cpL>PK`syY*{G1Qe!<1ePj)l0O!$;=MRL<@ zg-IHCPYpsDxcCw)>YTiZpD_*X`8lBq4b`ARBmg20`7;hL2|bs|GXrakXTDDUmxL~S z9lT4O10H`o#m*Mb{lbeHTfQbpjJN7%K_yt}BZf#5V*u>swra2EM>p@oYOmSK`fX2q zOVHSFo327*6%UR~rH4OM01}yap%2hj@~*n

5hv~)1>Jl5lHbl}o$|hiNg`dgHmXPV1!b5BXQNzf``6+>$A zIE;r6U4!Qx%!zpYG~uEX0guatQbbUk!mLTaP}sJlqEM?ohNjW+*CF&UbZZ0a`{V^?66-XP?6kJ zyS?H_m=^mbaXNDlp1Vw#=pZDCv@WC_^xi0zGE>a)ZMpN~vaI)b%TLBKbFY1e+gRt$6jAL(_Zc02MW#o)WdKL`--%vEr>vQ3({ zbm^lrGNi$Ge}CgY%8&*~vfdbR<#RuldvB(zr5&;i26Mz_tW5q8OHXR5wy zk7CbkXQ_#m{}LRgCMtutz(d&)CxE0l>C%n3LGzM-W}^Mo4FW^aP9)~*4C<0TnB`i( zxv$YV=GT-cvRvu@L%*h;qA4JU0zFRb{WqrbN0r+bBW!WP3hYqt6+fILnc+G73tv4} zj|t%;2?Qz9GvdGoUjyr&98vH4&kv~aG)s!_xSIe)?ln)ja0M9(QSFPv5MBQC$A$`c?aBdQ9(Xq^+%hcjKcf9|{-X6ApbNO;1`V>Mc|>cH7&(O4$w zqti3w0lv>n8(E5(`P9mUiDo_--kJ4)`vYejOj?Ep#E8Z^$ri}BW4-rR{hDR=v#Sy; zdYbhlSh$`DuMdFX0#A#pgaj?$l-YXSKAAz!t1dy;5zs=>_Vx$!0XNMvc}SOuza*Wb=E;Ic^)a9sgx_t|W4T}#}p zgR5>@Zb$}PO64z_bG|cU#sjq_X23;z*gkyoNEcS$hjYxEzmxev6YCPRh;9TlO&<){ zuqP5~5HRGP@1ewM$l^;!;=jWd4`On+t3IKN1hGJm6}GjygV&Smv}2uNXB!eMc7`?8 zTeL~2jONIcP*7LuwE5VAMe}u(wm=dFRLcuezL`@d@{MMZ) zh-;TfK@r68yY#WL#4*J^kR=_O_atPI-a=%2Hw8Lzpsb>?(E1$t^!jVO&tqI~LKVrl z5J!SuOCdb}8GHZhru{SbvAbZeK0vs57pw{5!W|HFW-#vpt{lA1VgdXp&X?qXnu7)K z^L}Yo=HsPZI~Z4v_mjAiDABP-x^=p`Q|0BiXRy>pg@IbG(CVYzqtNH}rH`vRnf$dF zOt{KpV=7mayaPVM7TJj@wF}bp3iWJEYHrSCHJyit5Cu`wk@$gW6E1s_C6tLe`O%i)VsR$XJhrCNsXZGU$^d zb)W^TNF(uO>xze?%5{pblm^h9$6OFJrq^%WC6w^eo$hHHC0=98TuIsV-XNIw6!f6(w zAesrKuYk)QSPx;wAlC1ABQ0$(t6l815fxFpkf>lS7p)9XgX}a6R(4E?tJV3<)mJ;R zfcn2|q9oc)$Fia<0e#*=R;I@eGG;qzypM6>fs4GzFb}8?&n&r zV;P~NUP2a?E=0yoTIDE(JR;VMIkY5`n(}4!goz(8&AP`=*r>e?5L>BV!i_4C90Nj; zXNT1O0L@6soo1Eab%x&fiX9-WfbaV=&q6dvGJ*0}8Z{*?Kt?CN9BCkS7RQ56dZ$0!^Gpc=u9%@DCzh5gC@ z9zVqtsoc{#Jcn+Ms|IbamfFy%%RX&1ZgM2aVXvki{2xLE+nn(+iz8=s=0}Bd!Dv?P1$o?L{Jm z+bnEa*oDUdCvJFh8ebCT_>u!Co#AKwZ*MbEiA20=HQ*(}7UL8ochOgD9CY z`Mud4&rbW6;hk6@;Z=o1j>t?#R8*w1nN`cSR^67$j{MS#2v1~Z0T0@)aiES4uC&Zn zQu1wY*{)?-&AZ(ag2gIQl{_6Zy;Z)@blzR;tPd85K2N|5VtYGRS4Nh7lSOx5 zewa@3#>o?)qXlL^DV^#m*&KN&c>F6bgR&TaQE z^FD5EBxK>J%E*xctS;)3#g)v=Jhpt3P|-Zrl~KhUlPI3GF7}_pT8(BwlxHhpiy%tI z#@rIPls-%`#c?GeeRIkL7Rj!7iztX786TiGLu(Vb*bvXaDI1yN)LnexFP-=R>q9#T z8YQGIK|=>mo(YD9L&Z5#ail{-vOZ7reB~XcHs#(XU{srA0><&4(FH_OM`kP^-?sfz zCV6dl6Rt2rgbNK=L2xLsb1G&t{OLPWrm+?0zbCD@HX);L8d|l!=PP@t#}ZPfuoW-d zE3G)hV#HNft&c1pM6E8KwzA;*qkV)d207-yICh^0-x=UMUP<9OO}GB=O1 zK1|2Ic$50dexf1<@!yV$(rIQjjq`hA?E+0W^}`;f|FsViHIe1SmJ#&w;3<@NYA=<1 zJ&G{>ANVn$3NIK@aRZ>w7*1NKEaSu!Ze#WRl_J;Y{>3^X3O*%l?1-p=*l=yf6U@#5 z7d}?dsara%X9wAa&k0o2p9C9VZ=sb@BP#mqoL})u9|l?ED+w|Y1`c9k1d(HY7Ifg{ zXOft!*7_TQ6eTUh!dEV}vuyO&*xu}0$8@FOI|3BvN|qU(ErY{K4CLvXE@m)tj!K4( z_iSvuF+P(Wh3gIxG8(#C(%0kc6s@>Q;WBUuhGwYY@`E?iGMI7x;0FR0?UqV~EXxD) zdIWEy_lHV0RC2tt2QIG4q_)Bdf)=Ni3=PeJVDN5R9pa5tu3qz!Sa>_+Bw>pJj*wb# zKYaL(D2-|pAHmtJzmva?x$!VdRLuLC@Wu2FA!h2R>)X-L-O_c%-Ot0-u@GyJVDfBXyc1T!xXK6-+U@jKuir&DPH zzQb=AThsR{tMi*L5e?Drh%|5%jOs(cpJ+#F>TmY)FTImksIlw{p<}2~Js(}a)sg=-amCo~NIY7klh zq@B##uJEXM~iog!V?W@uo0NtjAoQ2&EPfW z*1J}AWj#b|{vk+lFBW$jr;KH1$MN3dTkq{z!W8<#e}st&-6%ZsE`*Ee9vD3ZBbQ^Y zMc&i@Gw$a$xJ4w;l-3R}*<=a$puY~Af4MYC656U|%S`Dy>6z4d{C4w0MOP>thQiNP zp%wEEZW2sU|TWuSUQ=Ra}FkCKoYTOwqXOJqLI3u5*oWy+K zU+*DGO4=Gsgvb_(Q^DiW&#|RvG?xeM%@)ANkWa9$d}5?p*e!8DO%fV#uTQUPn98T;fL9wkqdff#b92d=sVR@0Yf9Vab~QB$&!pq4a7@7E;)!H zH?JfiCu7%v&kbn`l|;;I_J5zod}mt~LKcnkDj+l_diC$rxY=7zi3|QJ!mQt>Y6Obb zFHr)8BD+_US&|e(zAvz`!>HCQ)%|A;!W5})sc8gkUH~S3!VC`x%VXuk(G#zpV@_o6 zS_ChgNL9n}kk;$Qmo`socl>zPzenn4I~a&}Bq`FVZo=CRkAfq%RaQAGovaT}`2 zRgy@*UrL{_9@XaEUv4*szRx^r*ABy~CPZnlaiFoqZ78ASLGN4;MviHWh9n&PvzF~C znat`R$5V-zc%I>Z5K|5su0gt_UQO>*o+T-+Hz!hHFmhor@>HGy8|?UVxgFm2utQy`iqr6 zDzqbtqDfdCq6|6C$TGaHM0{K1Vj@e+_Uk~fA}uQ>K4T$(X5wPiN{oPTYK|FwR&BM; zx3HAL+)hM6_>?XuOa{#fx^tD1SX6MyWJ_LH;N+X-*bblWOu)(x)4MzZ92g!IS%2By zH>O*&X@_^N`}7+_gKmUA+BovICQcP65iCYiP>F1MtIRf1Lbyfel&ef!Lfl90=GfZo=FGcx3k7? zVSDyNZ+XvD3QD|5bU3oaJWT{p$!xEp?d@L%~W(~ZI1`r9!wp{jLkI{VUw|Gp(`5Pjcr-A&k9!MXznE- zQRT3ETSPTBk~Med!4Ii&FsS$0Y-Vbv_y~|pO^Z7=HN6dvjh)HDrr@g%y8{G`$xhIU z7K(`RPnVXeIe022^!ZKF{Fv9e=h=^F`*9cL!W>`;)!O$iH* z5C%<%GL2fYL0r&|B>$Yp>_h5gi5+4#rUw!W%JdFJFZjYGC6gNX?aeO!#+--jsRWD8 zLnCF;Ay~T5WQA_PhMgft&yvWFSmX%WmXTR_la!T~{JbQcHcf8Vu8V}iEm6#_*; zt4E;N4hl~ufKkPj$2>Xfv!A~Dnk76>E+%kHc-C~mm`pRC&0wu#?uI#eY)BTnM^s}e zVd6ca8VY9g1{)~`llu`jZ0$UO`Rt>Y6C(QTi3%d>TC`g}k}62K&AA5}Kh3sk(Mp08 z>0=cV7pgG`x}x=K&1ox{r8%*RKt%_N!%vOs(dNtPi~hl$i@i*18@xfNB7zH*&y*$8 zL*gg~nk)YL1RkCK&U98ZeRwSq5CfA6%`QmL+clRYbeXVWUhNl|Nnf#^&_r1agFyA* zH^U8am>~px6D}Op$5mK=#ew^_O=LOsOB;y>=F~wTs1XMyDMjXeq$!aQmsoPr>`-D0 z$lXyN6uv)`mF>sAN$}X=P)op5E72X{AGePAijQSwzTRb z6))t|aY>OQJ{z2dK94)qDNWj_G0a%V@@+eI6FlbIk|lVox<*#>rFHYeP?4P<|^qELdVQinuP9wSHmDJrp(aRNs>rT0h`Lwxtn1x zyNfdTGopgGb(%#|3Gpg=E=G?`xumZbu;BpHzaTO=93WjK1Ms5x4{EtEy~zp<9q|$R z>gEeAm~njPDendc`v6hS}m!iTyXqk_wlBj2w(IkA_K`4>fw6S@fURCLU^Ic=g(JT3y4 zC(Ztdd)z080F|gi+!2{3xU%5{+0?d>Np74J(Rinr_ORBX;7NkU)}mB(?cftbe~!4Y zi!HivUl>1ui%oPA>sF5}&Q`zcXKD3yR1&&(qomWr4MR!QY|eyYTYqNzTI4K2iwv0x z4Y$AHyP8RMBB74_rYMwGweBg@fx`=aLD_Gnv&dyg&qm z)hAy7loT#G=0?0OuaE0I|J;6dsBXPPxaeaff`C{|s_JW;bA8yCmkzVCDB|}k1dCBb zEy_VDZcAj!C~5;x6kgzNl;oky!UBVrLIrIq32m1C;x7O$2?9szg^QhslU+iPrf;pwfDSNf{?JiCk2p{ zhC_@iF=_xKha%@PYbZz+7&lIi#@Hbz6}rE(-ExtyXs>KQFB8dy&bCw=999n%5z4 z$nUtsxLWZ6v0`^=TCEd1eOHwyY|+?D*uZo#o&$#j#x0KXO7yzI*nXs<#P-ng;Ucbb zn^M$S`SK-EbVEF*vb}pwzjqjPxyn(X;nf_cQd5Z|2lDD07mO@^>r$1VMf^sL5sQrj z_4VOns<8uXLUlsL0}Rs!D1y0=XpT8qvuj{t7v>W0sY#INY}XD=*i*z>v4OhWS9AQu+KS0>~X#SUBl$Ab8W&FH$)<_Ex9_%F9`=XOB_!Q11r;Iz3{RQ2bK(?C=;pj z{Nse@nlX3mazjE^+%ZKw;vW{cy#CEGEKF|Lm~cgd2XPzk^8a5(HjLY|WeL;9(rH9R zxC}(aW4GgN%J>qfw=8Q1GcV23oS=n2DTEtRkfA5$b@{DHPJD05U(Qm4XH+}-+>^pzq6&> z3gLxb9Cgn{4ch61+r*KbtI=0q_3?u@*fyQ$M5J(=l9UP2a4$42(EE=3JAt*`*6l)| z*mj$q=YY#AaBjrDc8OEm5VbO-_GEU{)Qu>i9j&IHhOQcR)XOMo?kth#mYY%2iZ{12 zBlpVvge}?-!QSI@7mx}O^=RdRFHN{5@mvzWF2zfI$xP8dJqZ|1Q5}5E*n20qXX>0B z$QInbPt*mA$b?=bi)q0hym;|5W__mgC17EF#O#DDc!f*ta@W90W?5q<-tANDV|Lf@ zxBdi;cMYp3M{A_V*dV-pHT$8vIHb^POAD3SA#$Q;G*$Dz@PoArI*{e~@U7*84ebXJ zopfU}W94vmlxt}8{rmFya!;_FWx<1lg*nS~EVT(?!rNtD59AM_gcEw|80PpvlbfcL z6P?UC=g6jPe8>8sL|8EKsX!u1aMgO1+T; Ug(1qlxGKnE$AVO$SHl_q53@`WGXMYp literal 0 HcmV?d00001 diff --git a/deepseek_mtp_dummy_ckpt/tensorboard/events.out.tfevents.1769037919.pool0-01476.3254808.0 b/deepseek_mtp_dummy_ckpt/tensorboard/events.out.tfevents.1769037919.pool0-01476.3254808.0 new file mode 100644 index 0000000000000000000000000000000000000000..007f6ef83d639933293021fccc1687da06225d54 GIT binary patch literal 63278 zcmcJ22b5IB(!U6Tb3t;3gzf=JICv&i}o0zUT3q{Z-%Us_IHz-TfZ*KewLl zzRXlF^Q&^hMjq)<@qtXg#py}yl56(_TV)JLRT>2Co`Bb%>9zQ6sezE&ZShY?&GS2h zc7LiJe`R(!U3RnjZ>*O)HET<_hxfPt#q{OF^U5P`y~awqQqB`{n_ZR(c7LEz&^|tB zCb*78?k`lbK*@XW&)4zZ^kzNoR|+Zx@c;i{yKN=A5}Lp1hp(HCWD?qgN=v1x-DCCI z?0$2CsHRJGafj9N6Z)&hOJ9=o>MwWptRwkO;=<^Eh`DwIvQ zO_jz<`D~}nX7`u_&WU#Z@ef1*^Jg@bzP@VJW4pT$-mhn~*O^9-CA?Nj3#GEd;V~Ox zN{dWS%V?r4vZ$xtGRD%QE{Ubd$Vy8Pc0B0x0=(T~)9qB@$YS%g&`^Inm|b)rE%kvZ zeHNN#cnP(;(p9OVA!b^vId+eYGoWGwvv`MK(9==l-_4c<+?WYiEGo^`AKF{f~dGlf@gCw|IyFR@> zw@^<4oN0Ptv1#&D0&K6eRjLsnFalI8$LTXW{PuBt!wjJV+t(iHP0($u2i`Dknnuta zmG(*v{CJ-?;1nuhcW2sdHm4^`w`~$QSGb=lALz=!Z#*W0M+3)Tv;^+=`XotvWkH*b z18C(ne=W4p)ZqyN@1}H7>ig|3OVF8X*F;#@%FGZbVNmx0271zB_a}n}5e0MqE;CKR zSw!I>rJK?q1_fl7%^3($X3hXQu9 z-y4Ku=YMa5v6QcUCHJw8L}bEM&zq)`bBRcErKwV$2xM9Ux1psEO~i)=8~ni3 zyyAR9L^ZDzUb5Y91#`+pQRYOd1lgZQE-=5PX@on-&OH2E09@Uh)`>#m6D|HrO$2-SS>ElB93aj{?&VqjAm!{<>v?% z&+KH>iZEo%^-2tHAv65?=RHm^si?Dra8W8Mc)TEa0G)s_89ZR<*UVBomJ%W!EY&SG zTMQ}YY>=2c$RCo_y}(sZ7bbB_mJ>EgTt&Ca2D=rv%@C6B*RsQt>i4JpTKdZ2EvBNc z5M)=Svr^Y>$+4Sry;e)6u+^a0xz2zy6a2l*37mq&VW!-3cf9dBb2w}(hyeO<6+K?R z8zf<(5H*e{9WMiV59|X!)O7FR_j6}3PV9SK;sorQ-)*-00#28n5t4|d+Z(1e)gW(l zZharqy+W%9Qs|z==W|U^omYXL?&GVKfOWdJA)j}SnONMfU7aUPkH1O4gOmYEQx({6 z4>)Zhi_2{Ff)gWl5alGOQH+<|yleZuWFr0P+eA(nIuK`%#c9s6yF&3IzBjJr9>(me zYh-2z#)T|?yUh$|LEMnTyaQ8P6Hj(+Trt2@Z#_Y#D~*&=fjnoH9?VcH((0wwXCIuo zLH^)zA*VG5zBdZcc*jZ(lcSl7eC0zzMHjg&2-bLW zz~YE^kp)ZxpJt1!vPD{CQl{M>G-or2a%%JJ&a7;`-(pCp?C1jpnd9a8gb1NmSKS?Q1)ZQ>80%>M5&sBd3Aysn zsRGQsU9p|O(Y>wY4*CqEQpv$@tXZGECx=&J$HwK)2oR5r8VKm}1_JROa%rhj@X+)1 znwiYOPyd3jMIh<++QWQ^4o)TZauNQu-0c6DbY$)(P$9xvL}G(e2{w>(Ryr-%QghDF zo6k(mQ+o(i_*4|{CItEUt>WjHdHC#Wf)q!*DpY2O`YpN>P5N|DGl`!c-bn4$myDOI zTi<6)zkfrBot2JCP1T_1*~~$&d0bB727Fc7d;erPlxq749&;#VtuBaeviy35lQ^0# z8->gno06GP73jS9hGsSNhb?#H?O`15aZci}$_>oIc)*?pql8N`(2{$2<(I9=EU`N=4Rb%5-_HIeAVwS5$vX=gZ9QAG%1es@+FnfJ~_0A}l{dA~Aw7 zNpzX&-`BSGb!Ml^UM6g`Q?($tk&PdEo?x@0vOJQ^@e||wLM=|CJFBAgo?VB#89L1 zk}{^0I^6YUupY}qee@4uVkRmD)yy{mB4%qgNUh+3E_gO(!r=bTz0QnHwf_iO7@P2N z{T5G-<=O@Nn8FNzaDe-<8SNu-k+v@|>VmUO$*byW5Ly9HedFf8*V(ay{x_$faT5y@1y z)9Uw{Gc7^jcC_~*sj_r_k3G6}I1`F$1ql*`0=#c;CIn)MVg5XS={jZ}#@tVsXdbHg zAwqzJvJs}Cj{Wp^s*9!Ho>vGK({J^_Yw_D5WhDWdoer?_c3b?pp-E}0a+rvGTZCw! zh}5#XeZdLlkO$;CA}B5>{J8jDO01Uj@A%|d1+8?}tQar<{LeF0*%!l_FrYcjZ-JEQ{$>Q{s#ygxe-geVEDh+hqt%%cz0u6pm8wXPBI_<7Z6UuJ3y3%{EKW;__OcK*^~XTYp>Dy4vlDBcg!#Q#B%l1;ny$ ziw6@#Svo(+81K^yC+o1duV78W!?>?%1kXSyB)OF|dQj;FNz8br*CtTadkrtr3^h7q zmF3>6;d(!M<`9epPssCX}v-EZp?s*qhg`oJg$ z?HEjXLM{uGv~@?Hp{8Bxw(Ax%%cWBZ8qIQj)vTcqcr2JrS3`HqWJJep5)ICjEmtd^ zVP?5!8qp9Zw?+e2o+Um)U#m5ijQrLmud{mXv<$+=dTm`S-}s%G$eRFG8Ne!sGhRCF zHYV}W;WMS#o_^4TNZ_6(Y20_)gV|o=>=E}l=ie&-J2M)8HX~Ft8Yvp8QBQ!xZjNkT z`{8@+08ej8$m#(ekrwdqL-E_>4fRgH#&$9AAi;{m)#?jjiVy7sR72v)8A+-byxm~$ zP_h?iH$OYq^m-e@?xpllQmv5u7F7|b=R+bizLo)L44ZUAT)%?^l5nN zUDNS)M5d?mkkSw`*N{qs<9aN_cpxiqazS~~FSU|k$y&U%wPoK#w)~wK*a79i`zg|DMc+s}dn9mRHOqs@(9qJov z*_k0%PAp)C}YAqNU#Au9> z;5}BWls)M%_`QFaRVg`y;L)m7#}E`Pi_-&53kKCS37bpavz>pf!Tjd|!w6f{!EAQ0 zCgA455#hAPNOu_T;mUOrOE8`+93k-%CP!K2q#!R5{R-+qqdpAYQP;^Zr2kM8}sMkP~OGr=v&r#zhxEQZz=2W3&6lb0Qd zY)3b8gg3u?ogJ+?QhiE4!t{mucu+J z-i>T_NKV=9Ic7}%+o1Jsg4vqm!^}sdBTtL;KRY++TjuU`x6p!xyW_S5a=v?%IlIw4BKf}%MVs_~9h8l8QbDc|dMf)zGSz^awNn81<* z^X0x;qtzOAU>(dNTs*KEpmxAQff;JxiCnph5@zOwS z;HQk+ja(A9t5CgIv?V06iI*O)+>+_X1P{TA>Y(TyNFe+7baszCc3!XX5geY^wbU>k z^6Y>MWpj+_6BCa5oyYBDHQ4ih!p9n{4qt^fJ0V);=0Zvmdmx0!l@d$q$GuR71u_*v zL_h>GLP}gBB>_EU9Mw?r&R&{+n>iWn@(5En8QArWh*Q*#f#Q?fqR4L|e^YYh&7$ZwRDlxK37%`X)-ecV*S9+D5!jH@#I1!;m;KW%>UOxYP<2`@2`=USCVfo1ugpK}G zZ4H}np+yOUG<~~^%n7+FxV7N5Wh3ZxE9HIfW7Ex91m9Qbt)!vws~K#K5Qei-xDxM& zEGf|ZHVc*O&LK)7gGv6jiX7&@*_M%Q&d$*7sxh-_oa$bSg|m)%L<7UwdT?_3Ao~nC zSj-=&|DZPby7ivEwAb7#`g_`~*O?A3pHC!2y3Pl26*321L|uq1LGokjc+1{SW`Mq5 zNRYw+g(1ayy%OOPLcI>y?Q+>&p!6Ag z4EvRRb4O+rtE?teG>Yjyr_YWDw_d_U-r#@5Z`_)k?ozk~v+lOHh!R@&I@!?Y1X&KJ z7gLf@?Sx3d9W&cukoml{WYu0)Tz~oe}%q<-KZnNJd zitq96-l?0#CbN_1&=x|)lPDR`AdLs5#efL+ql*R74$Aj-%I@4UlSN$xJ|!Z4&M=lsFgPO zB4Z=2yjpGm%YM)QoDfC!TdYxSMuofq1YLQ0A&^d8-qO}++^0zl`Li$XfE@CHQ1AsC zBER^rGM#C0{x1nx1TSLEIZ$nf7Br412`sEt?!T-0yZ=LG-dcP`kZ9g&3dnHeiODL8 z^Etfa8M7wZ*=xIXVTXWoFM$isSb#%c3#Qc7{IfaF8HC;g$aF&vJP$5Oayrsy=z6R2 z4OzHpG&xm=)_Jy=7VRTS1C{%MyfM0L;+wFO;^PpK_n3*XzdisqpvG zK2;}@(A{MBW0hC^K8|h7y{$6~c$s}^d62+Gy(SFqfqQ{okHcx4ALnZ>_(+e5%zkGd zCRnuJ)nP+yc83KE9>@%=ONJn~$1bx&5Io=*bqY6DhAXq^^IH;-(vU=?fE?$QA$FX(u zH1<`t%et@9Y7WpB13i}E1K#b7)ed0%RXsWkoYseQ=Ccp)d&ji?52Dmb>7dkruR@<& ziZ87>3CQjrtsnYw6ro!l{^&!~?^g+ZfYMKC2IzWZ=$eJA9u&^19v0lx;B1rBzo}$W zF8$Jwf6BWzq)(|CilDYmN=Xa=@xgs`IlBQM*V$LLPti075i=!WcQKmRK}>Z zTN0z10J|Z^NhpM9$-MI3bRqKFqj#_G&II+?{y}`w9>x^b}gchlLLS?E_VcYCR=j z&dbJ2q|ma(rNUT$ethX#8grFz7bKGCd{^+>9T3#~k46z}4ba@3{MAY(vy@<(i4aAH z3w#8)HBJYmGsr`PcJMrXM$mZCwcBi-!g4aU!bCz`odgo$JQ5VgJk-Ku4niGHXO|_1 zp=r7tK$yEVW&z)V2Z(?O_<(@8>7MDuz9=Y{#>?%m|NMSqE0Wxji!Pip9VE{tEjR&3^vKz17=DEd6Nd%8_Tw7lAcuTGm zLJ+hp38Y=Ge%GrWOEC1TN(3;$P{{?ol6dxD33)qa2%S0;lh3(^g&?yBM59JFgJ-kbYv5Q0-suR4p;TKtS+zr2n1_VkmE?lT7 zaUm%JD4d?Sdwy5^*7q#45mo9Cs;HqeF2q=>q}g03=CrBra7bgq#2^rLA-u{ct0Nr~EMsANd*8S0sC%Xu~tL{I44ihfj|7yRh|3!y7Z9=XMYqgE>- zB*a-<{0ZL|NaV`q^nN*4nesVX6Rs#jz}26-L2ds`a4_=#mbWETw462A?JB3o>W^jN zq+4>3sKBu)LCuS?&HQs2JE@PhCjxj<7jx;7aFVsel$o3}E4FL4??F2AD^`Ddq^W2p zLP}Fol~S1b)}tB8TTNH~QB2WVb|EO?tiiUx>?9XBo_at-)$Q`+f0x+qWOgG&u{#=K ztYsRaF4=FIyO=(?mnpZmX~9E;Iy;{-GaooQ(A&(NlK;PLTbhalv@I?_v(C@$X_cbs ze|H~RB{uz6#5KaX4k@yzc6!OcNa>r4D>r_`40oUYgeMHQ!1F`06*fFJtRX+EuDqVY zOxN^*gp8&OCV7ZvanydNt|Ds)Sz~`+Cp%=f4JJr2B`ORFjvn79L%=7mWqrpK=GssK z6s|4yS;vWVW90ne>}%_c(0wg}CG@pV`rA90^JyJPkfOy|;JI;QFm7-|!xR)q!VQ1Y zusp+$Kk>uQUo&p59m-pjpFG$a+vo& zEVz2OA}hvKwGz6hntR~pE0y}Oc76v%Fl5NUAIfC1^$)ho>(^1h*2);!bwl(v`8pS8 zoh|dS2whwP(V;_linj^{k*Ic2LmGdc$ajdH?OUBhLv*(2Xu$R2#7Vpx$Sljs5P#(m z2`od@HzpCC_9-9nsf#*7E*O^+4uCat_l3(klC@x^dWagfV3on@aWLf3nPua6Yt_`L z*IBl8w2$CK(<^jEhvwXK-Cemw+(lZ-+X|7qTlRH75`}9WI#0@b6C@ z@g!USZ`0)UBfMldg`nnn{ctBIUIPEVKJgf{o0Vq}0g=we!JH6S2O;yXegY-~EOF;Z z>bjmB%9cOi32FItG|s57Kq=hZ#MdwFZyz>@)m8Mv%d9jwZ59#0(qMVPkXTDD9Wh^u z+Wd%1(M;F8;q7{})`OL^305?G3s|*W3TI*rju@`DM8!hW*m}R6E3LN*1IA-8PBttL z9WHGl&+?x63v8y_PZ8`8WsuUG!G?U0`iiu8i7mizG1g}o%jJzGy+^PVb<>4JP*n3& zzO!>dJ=Kji*qG9!`ya{A1l0KqQ4)@c5hX*$ar^L)DR&if$rmpobactfsojnVY{a>H zTJK=yVCQp$C(J>t)keEf#>^-*#@3pRVfgz6f*0d61-vByqn05^rw9`QQGdu9UAi_V zGN9?NyV>5REF~hiw<(&^s!k!ynbDmwRU*n&nik{DfkYzY-o&**P0i*!l3ZXIiS4U% z8qGC%mlLJlN>3$Kqa*?iXmvtI1KYr%^#U89_$C{F(K{c^>1!i8oy?U_m{z|`bcQQK zl~zax?!c-Jel!)Vs+jE#htrCko$=DF_|MIUSQP%%3ZjWocnOCuy;;0aFZUWdk3~8~ z-ykRv>0~^Nf~K>)w!K zUwK%>_s<$a7pXasdkBwU(;-W#64vr@Wz_F|nPpF1PmpSpN(2s?h;Se>-+YM=IC@r7 zk#9mTmTN?JL*B(8p;P zS&Xn@8_^M!FEkHk7o-Hlh!pB@i}xGWpWJhsHC|l)gs??h512dAJ&Nj0+75$ip}7Ob zy`nEHrf|he>)mgU_GA2NyhGwws^5+-s<11WA%_^2;U4eswV>>@z6;z~w!Y5lK87`z zxjrX4*a%oV-V(p07g8eaO3lHKF}+y!1%abp)Q!KS9~*xAs+KiaWcBrKf)|(fG3SYI z8#tVP2oxg35*};teUmaf*sdOhZAfZ*&q{Tf zv-A3XB8ARQI+6;^?r1& z4}7@zzi3{@$dT+IDshnL;6aoWzJ=w>PzM(9)f+3|`PJXAW&7UuFu}6RU18X8ZOZ{y zqIILdjA?xF&Cb7B*!9$pL_>sK@L^jSBkmn8Ceug5X7`o4oEYd;pZ*1E+??f$hcU&hU*e6r+3o zo3?KQwD>P-Fm)A7FP0-u9jPSn!)~gPJ*oha}lyU>M4}a1-04x zx~yj{9<_cYSTRl42FY+V5HTNFFxntD1%-AcJ2~@=6 zxKbbZ_Ow!Q4vU*dIBNp_Z6H(Uj~YwK;8QQnA43jlos^wNOo#s%7h zdHcZX3oLw`Q;@(!dDLZ@2+f8xl!Y%)`gZQ}boyX(f0h>Cbw5F3TD%Epju+;o3&qt; zFbI5@#p)PgPRy7a*ZzE#zY%gnA!?$@L#8Go_q!(a;-ifkG3~AY01*<}8%~dh58sI3 z$NM4w{+qV0Em!>}BqH9_lOt6y{s5|63_1D&oDV%sV$Xle&R3QqeJq>tg zsYq60ZAr+9Jxe!tQ0fW776{C2YRR9y1BbL`+EXfp2#GEN^zp=Hb2@P0#TL8{8g;lz zTfhzOpIftDV|)8hb;1`XhFb0ShQu2%AqTjB@$NzUvGd2Uh<9pD0u~W(Ae5Db)5)P! z3xn98#gs05z&}CtKkC&yY5m-2o?iXm6HJx1)h04Rm2zOF669oB;6(#Z0G?QZtTxByOg7c zh! zLIgW0B%7igC^bc2o|HD33+V~l@{`O?e3e0D&`wl?FiKp>i)0$2}C zb^9rCg^T{Vlr%H*mA14bLZZijE(PHWUr-;0c_rp(-^LI-uh)j7EcaflHDQaIaQ1{u zXfnkD6I(tPL}TmOy7@hmnO5~|OXzB!3q~ZGS-?p&;!-EXFZ!9{h9A&kiT8Wv(@$wn zB!o{dQmSPmk*Q{SQhe5iOq}nX2!(!OosLiCOl!_ z#SsNObb@OVddOa}A5h~U8@n(1jTzBjk+l=>1mf^SJ=K6_>nli;( zck0>`*x7+o&pgAVWk--eg#_aI@Oe5hhuT=1xm)ReY0WTZeg4fQOksW8@O+QhP_EI^ zdCTFd)T`E(S@Z@I2u}1rfJoqTt66ujX8bK-|H#VSnb+%_M0nzMO!z(YgmJfBDm|a! zzC1<74HpUHz~%nvGbbMlG1uVuRKi8q0H(D=z64$p0Tndhgmjj(r_tX(`HqP(w%T-w zvE|kKSm532YY-A1>(dyvXugH*yKM&HsrwFhcF~8?<&30#hezzB>6Kr;Hv9-n1+I94 z;6x`0Oss^b6Y-m)F{KrUL#k%q>ae);{c2Qw$T-#h|Vr(XzQtseu25_4TLh93CZ&$F1m&z?`v zsP7f+UbyLw`-Hj4crh;>FKA&ElAjh5rrNOxgV&%dlCKw|dCe}yOX+5lA|u&3RN+~o zgy&El5!2I2kr{Hj(4^yV)2L~Z(tb*@{Fw*VtYuDF|HXukPFXE)CeG%EWRq8vL=rZ& zjaz+DlO1z&o+o&buysS=C#DN%hh50;gBq_PbMSVWq%U(zovF@_xy?%mUmSDLR0D6t zTilt>tdKVZ#dauKiFR;f0=iL43y%)sj|^U1N(4~9;no3URDk=2>(#*>Y+w6nRyU~i zD#2mhK)~J6_OMa6w^#qSpT!Z6zeZpfM^wPI1$;1m?M}vw0Y3I|)@J7S?|6d%MeL{o z;`4&4ndBcW*)(p;SA2YR9;>G2Urnf@l?_qh!fqBmtr%mihOkd9+WH7HmCfEFSWz`W zSRDC@6ClRNxzgKTU4Mb8+LU()5>+h;AyvFsQxnQY+Ek0Q)_&PmbJ2_BJU5h~zrR+G z4*CGIFXBO0e<`6)XtZSkL%+ISj~;-et8;uj=;|*e^ya_Jaj&k ztg$tEozV{>8m_bG&KX@8UjL6}JhvSm{Ye0CNZkYXvmDiwPYF+4k5~DqzBCzi-b;s$ z6s(?`^n#bQhHThDh}arZNxQHUoh_A`!Mj7+A5v{yVQSqV9AbhMZcW=hytVk2{T}r{ zzwR61G+p|fXpB&X!5a+Wzfdm)@x4lo$X1Q;LkbPy#f8W}iw3)f%~G3ZHq3+&O(S~x zk9~PeP3rF=qQcFVh(eht`rI(w(#@aU|0z>$+n0nZVo?njpG(CtUgF>A4mm^gKgU?E zG6TK*D?%47UNpToLWJ{5l55>rcX-J=&Dm-4!(M_Fr%BM6BOoy$0l5?13>w45ni`qc zu@m3)Euo6MOavAEqByz1Tcr0No|d1rhW7i82w-byjR=7_!TDKOJp0J!tR(#0_XI9V z!Vw;Wj}Hk&D9HGncIcj0w;*1|&s+9gK4w%;93@j)Y!?GV$&Far=TcKupMKVyDx zwIhTr{9c9)g9}_C7$0#bP8o~O_2CUGGF`GCB}%AERf2wa90e*tZZG)RsF3k~!N6Mo ztzcrg>^R|yQEsP0Fj6M8HbmIp3CIH zl8+2zzD}v130?R)c$YdGJpM$Ao##9w*RgoZa)ux=-l~}im0-1x7$QxK0kD_broWya z+q~D=@XBn~Z@cUlg2sN^bPXD-cyMH@J^bMUkm&T~d@&Z1-{%#~V!`df^F#`R+cL2~7m^nRW=i6M%bq_BTtIr1iZ-aAW|3@EV*59}-E**6n%o z0@I5}eZ;dWkm} zuJvyTw<@$s*u;=pJPs4#!;B-XzhX|rjvIuFP6RwI7fumDaSF30L9Z_dCYM{`%~7~{ zrCYw#i7P(@`R(_0A7@&f|6ihkT3th<5fF3p@#ath#XsMi1AdlX?{SOJMRHT!z(m6*p0hcX9Y^}U3N4nkFt)`hf#(Hq55W{Nq!>oQaB zWxdDW79dzLa3lvNVdXyBwh2Y5z?VAHk@MCYR&_Mgx0D8 zIzZUi=&`$W!Y*0u#+vWCZAppOS!!ZyF@nR?L|G6QcqlvS1dtRbUAhrB_Wg3jnP^`s zL0~A_NyL1EL0!@Zvs|nH=}JwQUsJ6#fudhiL)8?JLxG+k_JhN2zQ<130c8nWoUj5r z+Fj+S*y}b$3Mq6uH-2^}-co zC`7d{0Yg+iHTz@cQ~z6;fYGO}5e*z;*ooA6Cio8%GBw$nu(|C}w(>evrIn`;0lb8f ziDyi7*e9_h*D2kse`EbaTiM|!9i#YEA3s->5&%v+CTfRvAtVRop5pQ z;6Y_b7e}vBnxX}TVVD4^abR_^TUS_6{z**&#AHNy0T`=uLi%u~3tiT)9>~o6U$qHO zn0c&5i&h;t`>Z<4Bz<&xhCIM`*=A>PX694s5hj}XWO!%R3+@m08lk<5Zj6#_fqXl5 zUCSEBnWx#U0l}iDSwn(_>xuCC02nUtw75!0(DF^01Ho3M81$4>3A&nq7J8Nqi8W|m zie9tCj0|{@L))-pZ$Gz(iR`L0f)*7h0iEd$f(z%tm*}IBrJZ9HsDZ5(^baSW09~d1U-)c4{bXi3FY+72vTsv3rTUG*#)v_y1y+ zXHa{BM9WhSAu|$sy6TJ0Ll?1p#*B^xD6YEV7zOZ-b@wZy#W}XjO}tc%3&xht1df7H zhrsDNZNfd3E1Q4$4dcc6t`aZGqlcqX(CwBaX*|Gd@#4(I46w#S5@02_7j+KQ4V(x2 zI4-@hzrUlaehH?`rF#(|T8U%?bU2`gHxBc|S9-bsC==`a?A3>WMdv31Q}j$c78wP% z-R8u)WAyeJvgBHaGn<%rOz%g8P&}$f6M}m=ac4!Ur%kr9^54e;2wW}yMSTTbV&bef zwo4DbdnJuoj{gP`zQ|nS;>8^TNHS#E-MZO=#;}LhWd79!)jDbR@nWfh%;Hs(lX zq51ikpLJ~%p^9W&m?I&dbpky98Grw4+Y+m~v%6sR#}F>w1*?y^a0djP8O*zYD+jN$ zSOGtV^CdZ;me6?keL$L(+4uRRzKkoDOo=N=5*=H#TW6>{tBZ^cFs>}J>A6Cyk9Cj2 znA`8)__#Ndzt0_nt3EcSaYfBL;3I6&otRR)AWg5(&PMx)@v~V?=SDVB5H%f%ADA}b zvL{(WshQU~?B90Gj?~K`#E2asUO<2neoUGJ8#APFp(k%J*Jg~Hz|n=QCM_4w1fh`e z9tpOumUhIgX$)LnRBh$W2l92paONkbRDmMx%^F ze15{UfJCF7#2LC@LdOL@Cd^3~bN|WG-#@~{F(sF1M9d*=t$3oCCW1w`9Y_?gjYD!> zYvYs?21_8Hj zOQnD+LlVI6KRj27F=Y7yi6JS7s@@lZt9WSY?}RJoBmYj`bd>q&yPhUokvcw&yl>E3VFSgfGUZiiLw? z6bS-QjbO%Rm{`2R{y^>Xe8z0Q7l=$mkx1AfMFahvE~iy@^HgFH_NGd{SGnTnMzR*e zi7yhidhY{Rq|WjZe+NY-)f&XR^zt81eUNeJ_%fM8;jswPPDJ8RycZTPnHdcq8_0}U zg_ntp@YvzmBzTA9u7;S@8`a@URxX|RDj|zOGjFLc?V z10~*%D0JLah}p|cuMsKSW?|DJE<6r6@xYVQ_>wTkmmEOp3}5;Ah8N9LA`!1z{qY82 zi}4g(Qeo&XUd4;NWCvSLk5zI~l2qx;b7#+}qAZlpSWV19V?k8oU4$)yC>a}bOW;!aF~yX?l>}SPwE--W9r-0u5J57eW#Igfgd{r0X}I{( zBc5?G37PQ~L8F9JCur#4$u+~UaHu#(D~@z%NY+O=zK7TEHtZ!}RGVZ1#_^u91w>Lu zW-JE|YUE^+cVQpl3Nu8w(0~;JhY~xd;x@xGf1X;7DSEZI&t}V{=zw&WH6<#o+;s!vUF`TqeS;mPe+{POFD<{5t^%Cod=zEf|u_K}q zV#Bo^ZzwApT=;lFuPn;B#SXGZPZ6l7KM6L#-a;#*PE?H7xok?yVGMG^X$dk31`c9k z1d-!@7W7ETOM95BcK$3uijo#$;VYN=SvJON9Ma)HJEkjD&J&7YRO?;(^Wcck?NOWuBQ=dTdz0g6Jr-(I%=fg~ znuwuB#XM-4$;5_LY&Yftq(m-c+@q<~k}q@KXT858@)I2P{#H=oz%!1DUt^-BKL6ZI z7Ou^?kI*put3+tyA?;+*cZJ8qE8nKXB!;)WP&A%+nkwo|r6Uut z`OMGVSdyTHpQ}Nu<{jK5n4-wRW83I^pIOt2DeuM71T2&n9;p|Xuti-1W^v)m7LXqa zz#Rd7?~k;2T}EWxz%&;tZp9 zlun#2OO9+{K5(}RL`h9sgNYE?LUAg1y~a7V^o%rl&`zm3r(R|?#PO8~SJV)p5RBR0vnnEKH@0zn`_1@2X=Sg)4Sv1OPfY6xeGrm*fVQ)SC z{mvinFzfe4Jpx7Rmn4Bgk=>`uEJ=zX-xrvi)^H0;bt?@CQ>413rV+4tL74anGdv(H zkCzXHA3c1FS-#GV2wpgmnug;cZPbk~ZJyNb*lOC9ZS3TKES;c5haZPF8oD8!G|=)W zm36iUl}4-{y1p@?i+ZSrZkqtdEAT}uK7nH0xe*_fWy$ArO^E;|pPLyIfN9k@iBkB) z<}Bke9Zgua?|gykzd9(`^*=astf@i^BKNQ|TxoL$ay|4PHRj-mn`~6b-g!QQAZSEczMc-$6_@fU^PqijW1C)MBlRHpC$%Edx5R4qt z84XD|_-Ez(qN{%#XiLPz^9=ukn0nA~4KgIvk=T5_F#? zSpzwx)A!i4cihCqHoeb|Gj-`i)H0PZN~b$e1MdiWib+<8oo1MKlR)Bl`oX-FT%o%V zapB6VhY@&RP*5`bF5}ZFeO9MBne=HB{u|>pZR$qoLzTfw>pOfJ3}tmVK?5ijiaUug z~{mF`x3OK14zIlx`PH2F(t6a@3MoOmN9$OJ3N^A=3c1!(a9zV0DM-T^<1r z43CLyyzKgK&*_y#JG}XwuMe2+A4upUl@WI|ahf=ZU@?}0N@UZUHEtm#gxfBz{yfu} z9)pQ2>P*)=2$oFNIbQqVSWdcu!VcN7tOgT?M~4z|6ov*G@o=<-z7Zq>!Lfn1Tz7Xj30RMW(c*$WR7y6%{@dcA&VaI#>KamV;5 zJ)?0G=~`EzQhyv}spQI|iGb+S(v}bXj?f;X&b*BbWRCYUP8@FVc1seyIfa`4VH!Aw zsH7?ll;TlsQijTH{UT#1bIa#u5)iuO~>wcZzI>|q!1t~v->yeWXtZWvC*};o*YA`-*^h$gvixfoD21k|li06lu zdB1(k;`l$u5gf+xNpNm^0&qq^Ml?n!QaQ5LlJXDL7(iQc?NXx-F<#dSO1w_OC5zfd zg3pJ%QQ-LpS#Ui%mtaM3tv(l~-8c#)^j_n@af2=}yS98hf#X>uhLlld$%50wAm4-a zljs*AV~a`3hvoaR(Ead4LdVd(7W!ryI<&SYf?fUiOgA$&C8rQJ8Ji}CqQTwRd>yvt zV^xl>Qwc~^IUK$wF^!F6%{{sBBTWuwzv=18Oik`I0whz@$b!|I`zNC!$B_ssrzX9t5P)Vgdt*suugZF;j?r zvj|#Sy}w2{EQoOE!oL+P57ytYz$&K3`&on2qU$#DY4#L)8nua7oFe2L7oV zi|t^}!{&tqi_Sw`bKnD!Lyg~#s9%Fguqb&*_b)>N*$GKrbayypoOO=oR&<%o*^gmXz@ zoV{@ATzfEtuIH|}6f)~^k+5ebW5A!PADF=od(l(a^! zDh>kndw+R#P5miz#V>j4HF__$sc;blpq(~oYkhoBtK`<1pW$#Vd!sRn{$@Gvoih=HmzcGOgkJW6&il&Z@L_iEonqYB5g5IIKBw@&ek{cKGU?zRh zhlD1|S{MXs2fqbwh{Fsa=$mlks6MX3#w%WFxyc{6+Vc_7z??b=1U2HoB&F!Qk2EC` z;u33issl=FLAg81v-Fh}tZaYJHiE|vhbjV|T8ZuofB(+Xi<$4*_7lQK-?bt%aKeLw zS$DVgGUwj!!paT7?J{e^*nx3(2Y*lfGXvPkwDL287k(F6cpTV)c?$iGNsOLLs9{7hD4ypcTd9A z#SnY*U0=eS=1~U-SHut+F6J2$er@BM$FEmrnX83A5ISbAQYCaRycz~+F?EKvL6Ssr z3fNSZDHmrgVs}x#Jw#OSwoapHDq&v5&czs!`Kw9B4K^IW^dphM;Q;9x8Gskfe^ASX z=}k6h=tzu^>+zA@m~m`#j0mA|tg8`%lhYr}gqi0C1f;`Wo}Y1hkkPr88r0I@!(EOs!DY#j@7MX>_#z+kFB+dCbP?of=$LVH z*~Qeyga}+}Q{)Hcar-V10V+|4xg#=7aAm^};oLBs8DcyCzh<%?(G>BlnrvVE=pC!q_!QA3BRJa7e3)Jnvi6b^1|ok!+W zJH@QNnG+|nLv_h@!bKk=2?WGy)>L2ToEyV#I&`uy ziz4>^L$DY{RG}P{>b69ujAAz6eYa@S1-8ao z;stCXaMRlaixYvPV542!3X!aGF_BdkPk4&eXu{X?*W??_3H9cCMtcgmoZpV#1&#pX z;!rUHU7T8EVH^jH#dQUx>z!S1pS!|zW@UcDMvYH_VNg)Y#=2?1=gfi1Kms>f#(vbh zIZIQXxsRYRO<9hF4z6qQWlK89JN}KUBjID-0jWV{mpDYstv`oATfafQL91LEPd;BeXdlb)tSU#)q6$OMG+8LzXc+H)cDcG> z9FvvP6$o6&3V}m@$1TRyiVuhvd*{sVzp>M|a%I96jlF~oOc&!ha7bXn;`oHOYD>oU z!Br%-S0fT|5m&uUDeA0z`I0DlARg1$US?7a-ZMNYIR-Smn&VPyDhcF3`V{s(%MwAW zs}Zz_-)Plhu?e7-{bxxkJHURZL8y3uVcGyiFb5LNaVKkjdh^Wh%q1>bn;_BIP6+}}Wg<CC-h>F+Yz}u9GB~X8!-n0_)(%x@D(88Y-!VM|N@DuZf{MIEWz61w} zrI3%eB3N;Ik35kq08YTY89wk0Pjjlm{&Mci;iiIZ2(*DxPbmSt1jduMBoA_O@iRkM z-$>JT1S9%Jv`az!VolGv;O#8AQOs`iR*HneYmi*5$5)mCUlv zOq5?f^jCJ*Fmnh&<6XmY>d_kQG1iJ~U*+EDE)JUDv!{kj?GQQ9Gn$(DublqLy{*Y| z?Ao|ufa&RBL?>Noq?CEK3 zW0>OyO>VkUPHZx%hfj=U<2x>oBr-U@ql^viT%zO|u2`~Bj~uRYj6qd5>!5^8!~>N( z&{~3pD5EU|g&|65Toq)o<3W06!R_Sgv?c3a%IjqMqR#~pVi(U>udGs9c>o6W_&vbH zhL{;BgTE~G^hfqo`jgN14E@K{qJQL1Qk1GndG;sCxq)O-OYzq`cQ0*_On+VXx7Gbj zbq7TLx|C8vQN$gJhG;tB_m8$~78v$+efs?qm3H(rUDbYHMX9Wm(SFZKJ=$(}k7aX{ z=ugKqI#Jz}F*ri1xKdOprtTwi9@?pht*^i?chhkkt9SGI@AHze}oN=gN#B>h+> zf`7Lx*LEhE{!V;Pm!Xlr6W@ar8j8^@;(Lm|_U&@^J^6-3elNa9`@KvGS>Ybi`md@) zs(;VUMlFmn>62h6^4yG(gB2&?@iwC^4R%}i33GeCJ!;a z=8ybcEv1H%6xq|r-^r{%i1Wvsiy7&CpQKd&AEd3H)$H)Zdp`|kq(^)f`MWqsBg7|D n=N@6CPY?V0f03Sdpk{|=Ig3s(($^l2{9PQR5u$kQtk?e^fD=yl literal 0 HcmV?d00001 diff --git a/examples/inference/gpt/gpt_dynamic_inference.py b/examples/inference/gpt/gpt_dynamic_inference.py index 679dd78b42b..d45f0398fb6 100644 --- a/examples/inference/gpt/gpt_dynamic_inference.py +++ b/examples/inference/gpt/gpt_dynamic_inference.py @@ -180,7 +180,7 @@ def get_inference_context( max_tokens=args.inference_dynamic_batching_max_tokens, tensor_model_parallel_size=args.tensor_model_parallel_size, pipeline_model_parallel_size=args.pipeline_model_parallel_size, - materialize_only_last_token_logits=not args.return_log_probs, + materialize_only_last_token_logits=False , #not args.return_log_probs, mamba_inference_state_config=mamba_inference_state_config, cache_mla_latent=args.multi_latent_attention and args.cache_mla_latents, kv_lora_rank=args.kv_lora_rank if args.multi_latent_attention else None, @@ -454,7 +454,7 @@ def main(): num_tokens_to_generate=args.num_tokens_to_generate, termination_id=args.termination_id if args.termination_id is not None else tokenizer.eod, top_n_logprobs=args.top_n_logprobs, - stop_words=args.stop_words, + stop_words=args.stop_words ) model = get_model() @@ -490,6 +490,7 @@ def main(): track_paused_request_events=args.inference_dynamic_batching_track_paused_request_events, enable_chunked_prefill=not args.disable_chunked_prefill, inference_logging_step_interval=args.inference_logging_step_interval, + num_speculative_tokens=args.num_speculative_tokens, ) setup_prefix = build_dynamic_engine_setup_prefix(args, model, context, requests) @@ -546,7 +547,7 @@ def escape_str(s): # ---- Prompt summary line ---- prompt_len = len(requests[request_idxs[0]].prompt_tokens) escaped_prompt_text = escape_str(prompt_text) - print(f"\n{unique_idx+1}/{len(unique_prompt_map)} [n {len(request_idxs)}, l {prompt_len}] {escaped_prompt_text}") + #print(f"\n{unique_idx+1}/{len(unique_prompt_map)} [n {len(request_idxs)}, l {prompt_len}] {escaped_prompt_text}") # ---- Group all outputs for this prompt ---- output_map = defaultdict(list) @@ -575,7 +576,7 @@ def escape_str(s): o_hash = "--" o_len = 0 escaped_output_text = "--" - print(f" >>>> [n {len(output_request_idxs)}, {o_len} tokens, hash {o_hash}{', ' if evicted else ''}] {escaped_output_text}") + #print(f" >>>> [n {len(output_request_idxs)}, {o_len} tokens, hash {o_hash}{', ' if evicted else ''}] {escaped_output_text}") text_hashes.append(o_hash) # Write results to JSON. Primarily used for functional testing. @@ -585,7 +586,7 @@ def escape_str(s): # Write every 'n' requests, plus the final request. for i, req in enumerate(requests): if i % args.output_every_n_results == 0 or i == len(requests) - 1: - print(f' Attributes of request {i}: {req.__dict__}') + #print(f' Attributes of request {i}: {req.__dict__}') result_dict = { "input_prompt": req.prompt_text, "generated_text": req.output_text, diff --git a/examples/inference/gpt/utils.py b/examples/inference/gpt/utils.py index a04b856c0a6..ee0dcf8b7f0 100644 --- a/examples/inference/gpt/utils.py +++ b/examples/inference/gpt/utils.py @@ -32,6 +32,9 @@ def add_common_inference_args(parser: ArgumentParser) -> ArgumentParser: default=False, help='Return the log probabilities of the final output tokens', ) + group.add_argument( + "--num-speculative-tokens", type=int, default=0, help='Number of speculative tokens to generate.', + ) group.add_argument( "--prompts", metavar='N', diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 4267f9d0952..9df78c2dd75 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -328,6 +328,8 @@ def __init__( else: pp_size = pipeline_model_parallel_size + self.num_speculative_tokens = 0 + # Cache the PP group we should use for PP collectives inside the context. # If the model provides a pg_collection with a pp group, prefer it. # Otherwise: @@ -1486,7 +1488,8 @@ def current_input_and_position_ids( self.token_to_pos_ids[:num_tokens].unsqueeze(0), ) - def last_token_logits(self, logits: Tensor) -> Tensor: + # TODO SHAN : Should do verification here for speculative tokens and get the indices. + def last_token_logits(self, logits: Tensor, mtp_logits: Optional[Tensor] = None) -> Tensor: """Last tokens of logits. Args: @@ -1495,16 +1498,16 @@ def last_token_logits(self, logits: Tensor) -> Tensor: Return: (Tensor) Last token logits. """ - # todo: @lmcafee, remove these asserts? assert logits.size(0) == 1, f"logits.size(0) ({tuple(logits.shape)}) != 1" assert logits.size(1) == self.padded_active_token_count, ( f"logits.size(1) ({tuple(logits.shape)}) != " f"padded_active_token_count ({self.padded_active_token_count})." ) + # Logits shape is [1, padded_active_token_count, vocab_size] # Last token logits. - logits = logits.squeeze(0) + logits = logits.squeeze(0) # [padded_active_token_count, vocaba_size] last_token_idxs = ( torch.cumsum( self.request_query_lengths[self.paused_request_count : self.total_request_count], @@ -1670,7 +1673,7 @@ def add_request(self, req: DynamicInferenceRequest, chunk_length: Optional[int] self.total_request_count += 0 if req.finished_chunk_token_count > 0 else 1 self.num_prefill_requests += 1 - def _move_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens): + def _move_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens, new_speculative_tokens = None): """ Move all the relevent booking tensors with src idxs to dst idxs """ @@ -1679,7 +1682,9 @@ def _move_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens): self.request_output_lengths[dst_idxs] = self.request_output_lengths[src_idxs] self.request_ids[dst_idxs] = self.request_ids[src_idxs] next_tokens[dst_idxs] = next_tokens[src_idxs] - + if new_speculative_tokens is not None: + # Handle multi-token next_tokens: shape [num_speculative_tokens, total_request_count] + new_speculative_tokens[:, dst_idxs] = new_speculative_tokens[:, src_idxs] self.request_to_kv_block_ids[dst_idxs] = self.request_to_kv_block_ids[src_idxs] self.request_kv_block_counts[dst_idxs] = self.request_kv_block_counts[src_idxs] self.request_last_kv_block_id[dst_idxs] = self.request_last_kv_block_id[src_idxs] @@ -1794,8 +1799,8 @@ def resume_paused_requests( self.request_last_kv_block_offset[ self.paused_request_count : (self.paused_request_count + resume_request_count) ] - == self.block_size_tokens - 1 - ), "The request_last_kv_block_offset should be 0 for the requests that just got resumed this step." + >= self.block_size_tokens - 1 - self.num_speculative_tokens + ), "The request_last_kv_block_offset should be greater than or equal to the block size tokens - 1 - num_speculative_tokens for the requests that just got resumed this step. (Currently its {self.request_last_kv_block_offset[self.paused_request_count : (self.paused_request_count + resume_request_count)]}), block size tokens: {self.block_size_tokens}, num_speculative_tokens: {self.num_speculative_tokens}" assert resume_request_count <= self.block_allocator.total_avail block_ids = self.block_allocator.allocate_memory_blocks(resume_request_count) @@ -1932,7 +1937,7 @@ def evict_overflow_paused_requests( return evict_request_ids - def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> Tensor: + def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor, new_speculative_tokens: Tensor = None) -> Tensor: """Update context state after calling engine.step(). This method is responsible for: @@ -1965,8 +1970,9 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T 8. We make relevant changes to the token bookkeeping tensors Args: - active_requests_mask (Tensor): 1D Mask tensor marking active requests. - new_tokens (Tensor): Newly sampled tokens, with one token per active request. + active_requests_mask (Tensor): 1D Mask tensor marking active requests. (Active request length) + new_tokens (Tensor): Newly sampled tokens, with one token per active request. (Active request length) + new_speculative_tokens (Tensor): Newly sampled speculative tokens, with one token per active request. (num_speculative_tokens, active_request_length) Return: (Tensor) Newly paused request IDs. @@ -2018,6 +2024,7 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T if self.paused_request_count != 0: assert self.paused_tokens is not None next_tokens = torch.cat((self.paused_tokens, new_tokens)) + new_speculative_tokens = torch.cat((self.paused_speculative_tokens, new_speculative_tokens), dim=1) else: next_tokens = new_tokens @@ -2048,6 +2055,7 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T src_idxs=active_idxs_on_right, dst_idxs=finished_idxs_on_left, next_tokens=next_tokens, + new_speculative_tokens=new_speculative_tokens, ) # Reset chunk ids for recently moved requests. @@ -2065,7 +2073,7 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T self.paused_request_count : (active_request_count + self.paused_request_count) ] active_requests_requiring_new_block = ( - num_tokens_in_last_block == self.block_size_tokens - 1 + num_tokens_in_last_block > self.block_size_tokens - 1 - self.num_speculative_tokens ).byte() if self.chunked_prefill_request_id != -1: @@ -2113,7 +2121,7 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T dst_idxs = torch.cat((active_request_ids_on_left, paused_requests_idxs_on_right)) src_idxs = torch.cat((paused_requests_idxs_on_right, active_request_ids_on_left)) self._move_book_keeping_tensors( - src_idxs=src_idxs, dst_idxs=dst_idxs, next_tokens=next_tokens + src_idxs=src_idxs, dst_idxs=dst_idxs, next_tokens=next_tokens, new_speculative_tokens=new_speculative_tokens ) self.paused_request_count += active_requests_requiring_new_block_count @@ -2122,6 +2130,15 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T # 6. Now that we have the requests in following order [Paused, Active, Finished] # We determine how many requests we can resume and resume them + # For multi-token generation: store previous block IDs BEFORE resume allocates new blocks. + # This allows us to know which block tokens should go to if they don't cross the boundary. + # After resume_paused_requests, request_last_kv_block_id will be updated to the NEW block + # for resumed requests, but we need the OLD block for tokens that don't cross. + prev_last_block_ids = None + if self.num_speculative_tokens > 1: + prev_last_block_ids = self.request_last_kv_block_id.clone() + + # 6.a. First, resume temporarily paused requests. active_request_count, newly_paused_request_ids = self.resume_paused_requests( active_request_count, newly_paused_request_ids, next_tokens @@ -2139,6 +2156,7 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T # 6.d. Swap the chunked prefill request to the end of the active requests # to obey the invariance. + # SHAN : Should check this if self.chunked_prefill_request_id != -1: self._swap_book_keeping_tensors( src_idxs=torch.tensor([self.get_index_of_chunked_prefill_request()]), @@ -2149,45 +2167,105 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T # 7. We make changes to the request book keeping tesnsors and setup the tokens for next iteration assert self.total_request_count == active_request_count + self.paused_request_count - # All these active requests are in decode phase, so they need only 1 token per request - self.active_token_count = active_request_count - # Always the first section of token input ids are only used. - self.token_to_input_ids[: self.active_token_count] = next_tokens[ - self.paused_request_count : self.total_request_count - ] - if self.paused_request_count > 0: self.paused_tokens = next_tokens[: self.paused_request_count] + self.paused_speculative_tokens = new_speculative_tokens[:, : self.paused_request_count] # add_ and fill_ calls seems to work as intended with sliced indexing (i.e. x[3:5].add(...) or x[3:5].fill_) # but when another tensor is used for indexing, it does not work as expected (i.e. x[y] if x and y are torch tensors) self.request_kv_length_offsets[self.paused_request_count : self.total_request_count].add_( self.request_query_lengths[self.paused_request_count : self.total_request_count] ) - self.request_query_lengths[self.paused_request_count : self.total_request_count].fill_(1) - self.token_to_pos_ids[: self.active_token_count] = self.request_kv_length_offsets[ - self.paused_request_count : self.total_request_count - ] + + self.request_query_lengths[self.paused_request_count : self.total_request_count].fill_(1 + self.num_speculative_tokens) + + old_offsets = self.request_last_kv_block_offset[self.paused_request_count : self.total_request_count].clone() self.request_last_kv_block_offset[self.paused_request_count : self.total_request_count] = ( - self.request_last_kv_block_offset[self.paused_request_count : self.total_request_count] - + 1 + old_offsets + 1 + self.num_speculative_tokens ) % self.block_size_tokens + # ================================================================ + self.active_token_count = active_request_count * (1 + self.num_speculative_tokens) + sampled_tokens = next_tokens[ + self.paused_request_count : self.total_request_count + ] + sampled_speculative_tokens = new_speculative_tokens[ + self.paused_request_count : self.total_request_count + ] + next_tokens = torch.vstack([sampled_tokens, sampled_speculative_tokens]).T.reshape(-1) # This will insert the speculative tokens after the sampled tokens + + self.token_to_input_ids[: self.active_token_count] = next_tokens + + # kv length offsets will tell the sequence length (query + generated_tokens) (During add request alone its 0) (It tells how many tokens there are in kv cache) + self.token_to_pos_ids[: self.active_token_count] = self.request_kv_length_offsets[ + self.paused_request_count : self.total_request_count + ].repeate_interleave(1 + self.num_speculative_tokens) + torch.arange(1 + self.num_speculative_tokens).repeat(active_request_count) + # 8. We make relevant changes to the token bookkeeping tensors self.token_to_request_idx[: self.active_token_count] = torch.arange( self.paused_request_count, self.total_request_count, device=torch.cuda.current_device() - ) + ).repeate_interleave(1 + self.num_speculative_tokens) + + # shan : Same as token_to_pos_ids ? self.token_to_position_in_request[: self.active_token_count] = ( self.request_kv_length_offsets[self.paused_request_count : self.total_request_count] - ) + ).repeate_interleave(1 + self.num_speculative_tokens) + torch.arange(1 + self.num_speculative_tokens).repeat(active_request_count) - self.token_to_block_idx[: self.active_token_count] = self.request_last_kv_block_id[ - self.paused_request_count : self.total_request_count - ] - self.token_to_local_position_within_kv_block[: self.active_token_count] = ( - self.request_last_kv_block_offset[self.paused_request_count : self.total_request_count] - ) + self.token_to_local_position_within_kv_block[: self.active_token_count] = self.token_to_pos_ids[: self.active_token_count] % self.block_size_tokens + + + current_block_ids = self.request_last_kv_block_id[self.paused_request_count : self.total_request_count] + raw_positions = old_offsets[:, None] + 1 + torch.arange(1 + self.num_speculative_tokens + 1 )[None, :] # [active_request_count, num_speculative_tokens + 1] (+1 for generated toekns) + # A token crosses to the next block if its raw_position >= block_size + crosses_boundary = raw_positions >= self.block_size_tokens + + # TOKEN TO BLOCK IDX alone is quite complex + if not crosses_boundary.any(): + # Fast path: no tokens cross block boundary, all use current block + self.token_to_block_idx[: self.active_token_count] = self.request_last_kv_block_id[ + self.paused_request_count : self.total_request_count + ].repeate_interleave(1 + self.num_speculative_tokens) + else: + + # Some tokens cross to the next block (this happens for resumed requests) + # + # When a request is paused and resumed: + # 1. It was paused because remaining_space < num_tokens_per_step + # 2. A NEW block is allocated in resume_paused_requests + # 3. request_last_kv_block_id is updated to the NEW block + # 4. The old offset is preserved (wasn't reset) + # + # So for resumed requests: + # - Tokens before the boundary (raw_pos < block_size): go to PREVIOUS block + # - Tokens at/after the boundary (raw_pos >= block_size): go to CURRENT (new) block + # + # For non-resumed requests (no boundary crossing): all go to current block + # + # We use prev_last_block_ids which was stored BEFORE resume_paused_requests + # was called, so it contains the OLD block IDs before new blocks were allocated. + + # Get previous block IDs (stored before resume_paused_requests) + prev_block_ids = prev_last_block_ids[self.paused_request_count : self.total_request_count] # [active_count] + + # For each request, check if ANY token crosses (i.e., request was resumed) + request_has_crossing = crosses_boundary.any(dim=1) # [active_count] + + # Build block_idx: [active_count, N] + # Start with current (new) block for all + block_idx = current_block_ids[:, None].expand(-1, 1 + self.num_speculative_tokens).clone() # [active_count, N] + + # For requests that have crossing, tokens BEFORE boundary use prev block + # crosses_boundary is False for tokens before boundary + # So: where request_has_crossing AND NOT crosses_boundary, use prev_block + use_prev_block = request_has_crossing[:, None] & ~crosses_boundary # [active_count, N] + + # Apply previous block IDs where needed + prev_block_ids_expanded = prev_block_ids[:, None].expand(-1, 1 + self.num_speculative_tokens) + block_idx = torch.where(use_prev_block, prev_block_ids_expanded, block_idx) + + self.token_to_block_idx[: self.active_token_count] = block_idx.flatten() + # ================================================================ return { "newly_paused_request_ids": newly_paused_request_ids, diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 1bb4ac73f44..d099b735cd7 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -148,6 +148,7 @@ def __init__( enable_chunked_prefill: bool = True, inference_logging_step_interval: int = 0, pg_collection: Optional[ProcessGroupCollection] = None, + num_speculative_tokens: Optional[int] = 0, ): assert isinstance( @@ -185,6 +186,16 @@ def __init__( self.inference_logging_step_interval = inference_logging_step_interval self.unified_memory_level = context.unified_memory_level self.persist_cuda_graphs = context.persist_cuda_graphs + self.num_speculative_tokens = num_speculative_tokens + + assert self.num_speculative_tokens >= 0, "Number of speculative tokens must be non-negative" + + if self.num_speculative_tokens > 0: + assert not self.context.materialize_only_last_token_logits, "Speculative decoding requires materialize_only_last_token_logits to be False" + assert self.num_speculative_tokens <= self.controller.num_mtp_heads, f"Number of speculative tokens {self.num_speculative_tokens} must be less than or equal to number of MTP heads {self.controller.num_mtp_heads}" + + self.context.num_speculative_tokens = num_speculative_tokens + self.controller.num_speculative_tokens = num_speculative_tokens if enable_cuda_graph is not None: self.cuda_graph_impl = "local" if enable_cuda_graph else "none" @@ -813,7 +824,7 @@ def post_process_requests( finished_request_ids (torch.Tensor): A list of finished request ids evict_request_ids (torch.Tensor): A list of evicted request ids. step_time (float): The latency of the last step - sample: (torch.Tensor): The newly generated tokens for each request + sample: List[Tensor]: The newly generated tokens for each request (Will include speculative tokens as well) log_probs: (List): Log probs for each request top_n_logprobs: (Dict): Top-n log probs for each request. Maps request_idx to list of (top_n_logprobs, top_n_indices) tuples. @@ -830,15 +841,18 @@ def post_process_requests( log_probs_iter = log_probs if log_probs else repeat(None) - for req_idx, (request_id, token, request_log_probs) in enumerate( + for req_idx, (request_id, tokens, request_log_probs) in enumerate( zip(request_ids.tolist(), sample.tolist(), log_probs_iter) ): request: DynamicInferenceRequest = self.get_request(request_id) if request_id != self.context.chunked_prefill_request_id: # Skip appending token for requests being finished due to stop words # (they already have their final token from the previous step) + # If the request already has more tokens, then we only append as much as is necessary + if len(request.generated_tokens) + len(tokens) >= request.sampling_params.max_tokens: + tokens = tokens[:request.sampling_params.max_tokens - len(request.generated_tokens)] if request_id not in self.stop_word_being_finished_ids: - request.generated_tokens.append(token) + request.generated_tokens.append(tokens) if request.tpot is None: request.tpot = [] request.tpot.append(step_time) @@ -1014,8 +1028,15 @@ def _check_stop_words_for_request_post_append(self, request: DynamicInferenceReq stop_len = len(stop_word_ids) if len(generated_tokens) >= stop_len: # Check if the last stop_len tokens match the stop word - if list(generated_tokens[-stop_len:]) == stop_word_ids: - return True + if stop_len > self.num_speculative_tokens: + if list(generated_tokens[-stop_len:]) == stop_word_ids: + return True + else: + # Need to check the last stop len tokens shifting by 1 up to num_speculative_tokens + # Check logic and vecotrize this if possible + for i in range(self.num_speculative_tokens): + if list(generated_tokens[-stop_len - i: -i]) == stop_word_ids: + return True return False diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index a5233983ed0..93757d0902a 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -30,6 +30,7 @@ from megatron.core.transformer.moe.moe_layer import BaseMoELayer from megatron.core.transformer.utils import set_model_to_sequence_parallel from megatron.core.utils import get_asyncio_loop, get_model_config, unwrap_model +from megatron.core.inference.contexts.dynamic_context import DynamicInferenceContext try: import transformer_engine as te # pylint: disable=unused-import @@ -63,7 +64,7 @@ def __init__( ): self.inference_wrapped_model = inference_wrapped_model self.tokenizer = tokenizer - + self.num_speculative_tokens = None self.pp_group = pp_group # For models without pipeline parallelism, is_first_stage and is_last_stage returns True @@ -74,10 +75,18 @@ def __init__( model_config = get_model_config(self.inference_wrapped_model.model) self.sampling_rng = torch.Generator(device=torch.cuda.current_device()) self.sampling_rng.manual_seed(model_config.inference_sampling_seed) + self.num_mtp_heads= self._get_mtp_num_heads() if self.inference_wrapped_model.inference_context.is_dynamic_batching(): self._init_dynamic_sampling_tensors() + def _get_mtp_num_heads(self) -> int: + """Get the number of MTP layers from the model config.""" + model = self.inference_wrapped_model.model + if hasattr(model, 'config') and hasattr(model.config, 'mtp_num_layers'): + return model.config.mtp_num_layers or 0 + return 0 + def set_stop_word_finished_ids_callback(self, callback): """Set a callback to get request IDs that should be marked as finished due to stop words. @@ -91,7 +100,7 @@ def set_stop_word_finished_ids_callback(self, callback): def _init_dynamic_sampling_tensors(self): """Initialize tensors needed for dynamic sampling.""" - context = self.inference_wrapped_model.inference_context + context: DynamicInferenceContext = self.inference_wrapped_model.inference_context max_requests = context.max_requests # Callback to get request IDs that should be marked as finished due to stop words @@ -104,6 +113,7 @@ def _init_dynamic_sampling_tensors(self): self._sampling_backend = "torch" self._sampled_tokens_cuda = torch.empty(max_requests, dtype=torch.int64, device=device) + self._sampled_mtp_tokens_cuda = torch.empty([self.num_mtp_heads, max_requests], dtype=torch.int64, device=device) # Keep track of request metadata. self._request_metadata: Dict[str, Tensor] = {} @@ -504,7 +514,7 @@ def _dynamic_step_context_init( input_ids (Tensor): The active input IDs. position_ids (Tensor): The active position IDs. """ - context = self.inference_wrapped_model.inference_context + context: DynamicInferenceContext = self.inference_wrapped_model.inference_context inference_wrapper_config = self.inference_wrapped_model.inference_wrapper_config active_request_slice = slice(context.paused_request_count, context.total_request_count) @@ -571,13 +581,23 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) """ inference_wrapper_config = self.inference_wrapped_model.inference_wrapper_config - context = self.inference_wrapped_model.inference_context + context: DynamicInferenceContext = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count with torch.inference_mode(): logits = self.inference_wrapped_model.run_one_forward_step( {"tokens": input_ids, "position_ids": position_ids, "attention_mask": None} ) + + if self.num_speculative_tokens > 0: + unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + assert hasattr(unwrapped_model, '_mtp_logits_cache'), "MTP logits cache not found" + mtp_logits = unwrapped_model._mtp_logits_cache + expected_mtp_logits_length, _, vocab_size = mtp_logits.shape + assert expected_mtp_logits_length == self.num_mtp_heads, f"MTP logits length mismatch. Expected mtp logits length {self.num_mtp_heads}, got {expected_mtp_logits_length}" + mtp_logits = mtp_logits[:self.num_speculative_tokens] + logits = torch.cat([logits, mtp_logits], dim = 0) + if self.model_is_pipeline_parallel: logits_seq_len = ( @@ -586,7 +606,7 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) else input_ids.shape[1] ) vocab_size = inference_wrapper_config.padded_vocab_size - logits_shape = [1, logits_seq_len, vocab_size] + logits_shape = [self.num_speculative_tokens + 1, logits_seq_len, vocab_size] if is_pipeline_last_stage(self.pp_group): assert logits is not None and torch.Size(logits_shape) == logits.shape @@ -602,7 +622,7 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) def _dynamic_step_sample_bookkeeping(self): """Perform bookkeeping necessary to sample logits for dynamic batching.""" - context = self.inference_wrapped_model.inference_context + context: DynamicInferenceContext = self.inference_wrapped_model.inference_context active_request_slice = slice(context.paused_request_count, context.total_request_count) if self._sampling_backend == "torch": @@ -616,18 +636,227 @@ def _dynamic_step_sample_bookkeeping(self): top_p = self._request_metadata["top_p"][active_request_slice].tolist() for i, (t, k, p) in enumerate(zip(temp, top_k, top_p)): - h = (t, k, p) - bucket = bucket_map.get(h, None) - if bucket is None: - bucket_map[h] = ([i], i) - else: - bucket[0].append(i) + sampling_params = (t, k, p) + bucket_map[sampling_params].append(i) - # Store the buckets and their equivalence class representatives. + # Just unpack the key directly! self._torch_sampling_buckets = ( - (indices, temp[rep], top_k[rep], top_p[rep]) for indices, rep in bucket_map.values() + (indices, *sampling_params) for sampling_params, indices in bucket_map.items() ) + def _update_kv_cache_bookkeeping_for_speculative_decoding(self): + """Update the KV cache bookkeeping for speculative decoding. + + After forward pass with speculative tokens, some tokens may be rejected. + This function "rewinds" the KV cache bookkeeping to reflect only the accepted tokens. + + When speculative tokens are rejected, we need to: + 1. Update request_kv_length_offsets (total sequence length) + 2. Update request_last_kv_block_offset (position within last block) + 3. If rewinding crosses a block boundary: + - Reduce request_kv_block_counts + - Update request_last_kv_block_id to point to the previous block + - Clear the entry in request_to_kv_block_ids for the released block + - Release the block back to the allocator + """ + context: DynamicInferenceContext = self.inference_wrapped_model.inference_context + active_request_count = context.total_request_count - context.paused_request_count + active_request_slice = slice(context.paused_request_count, context.total_request_count) + + # Get the accepted token counts for each request + # Note: _accepted_token_counts is indexed from 0 to active_request_count-1 + accepted_token_counts = self._accepted_token_counts[:active_request_count] + + # Number of tokens to rewind (rejected speculative tokens) + num_tokens_to_rewind = self.num_speculative_tokens - accepted_token_counts + + # Save the original offset BEFORE modifying to correctly detect block boundary crossing + original_offset = context.request_last_kv_block_offset[active_request_slice].clone() + + # Check which requests need to rewind to a previous block BEFORE modifying + # A request crosses back to a previous block if: original_offset - num_tokens_to_rewind < 0 + remove_allocated_blocks_mask = (original_offset - num_tokens_to_rewind) < 0 + + # Update the offsets + context.request_last_kv_block_offset[active_request_slice] = ( + original_offset - num_tokens_to_rewind + ) % context.block_size_tokens + + context.request_kv_length_offsets[active_request_slice] = ( + context.request_kv_length_offsets[active_request_slice] - num_tokens_to_rewind + ) + + # No need to update request_query_lengths (It will be set correctly in the next iteration) + + # For requests that crossed back to a previous block, we need to: + # 1. Reduce the block count by 1 + # 2. Get the block ID to release (current request_last_kv_block_id) + # 3. Update request_last_kv_block_id to point to the previous block + # 4. Clear the entry in request_to_kv_block_ids for the released block + # 5. Release the block back to the allocator + if remove_allocated_blocks_mask.any(): + # Get indices of requests that need to release a block (relative to active requests) + requests_needing_release = torch.nonzero(remove_allocated_blocks_mask, as_tuple=True)[0] + # Convert to absolute indices in the context tensors + absolute_indices = requests_needing_release + context.paused_request_count + + # Get the block IDs to release (current last block for these requests) + blocks_to_release = context.request_last_kv_block_id[absolute_indices].clone() + + # Reduce block counts for requests that crossed back + context.request_kv_block_counts[absolute_indices] -= 1 + + # Get the new block counts after decrement + new_block_counts = context.request_kv_block_counts[absolute_indices] + + # Update request_last_kv_block_id to point to the previous block + # and clear the released block entry in request_to_kv_block_ids + # TODO : This can be easily vectorized. + for i, req_idx in enumerate(absolute_indices): + new_count = new_block_counts[i].item() + if new_count > 0: + # Update to point to the previous block (at index new_count - 1) + context.request_last_kv_block_id[req_idx] = context.request_to_kv_block_ids[ + req_idx, new_count - 1 + ] + # Clear the released block entry (at index new_count, which was the old last block) + context.request_to_kv_block_ids[req_idx, new_count] = -1 + + # Release the blocks back to the allocator + context.block_allocator.release_memory_blocks(blocks_to_release) + + def _dynamic_step_sample_logits_with_speculative_tokens(self, logits: Tensor, mtp_logits: Tensor, input_ids: Tensor): + """Sample tokens from logits for dynamic batching with speculative tokens. + + # E.g lets say 3 requests are present (Total 11 tokens) + # Request 1 : [b1 b1s1 b1s2] (b1 is the generated token, and b1s1 and b1s2 are the speculative tokens which were generated last time when b1 was generated) + # Request 2 : [c1 c1s1 c1s2] (c1 is the generated token, and c1s1 and c1s2 are the speculative tokens which were generated last time when c1 was generated) + # Request 3 : [a1 a2 a3 a4 a5] (This is a new request, so all are input tokens) + # input ids : [b1 b1s1 b1s2 c1 c1s1 c1s2 a1 a2 a3 a4 a5] + # logits : Tensor of size [1, 11, vocab_size] where each position tells the probability of the tokens at the next position (e.g) Logits at b1 tell the probability of the tokens at b1s1 and so on . + # mtp_logits : Tensor of size [num_speculative_tokens, 11, vocab_size] where each position tells the next mtp heads probabilites. + + The idea here is to verify which tokens need to be accepted based on the input tokens sent (which includes speculative tokens as well) and the current logits and update the _sampled_tokens_cuda and _sampled_mtp_tokens_cuda tensors . + E.g for request 1, we need to accept b1s1 if the sampled logits at position 0 is b1s1. Lets say the sampled logit at position 1 is not b1s2, then we need to reject b1s2 and so just use the corresponding sampled logit at position 1, and the speculative tokens at position 1 for the next pass. For the last request, which is a new request, we just need to sample the logit at last position and the speculative tokens as well at that position. + + The final idea is : + 1. To populate the _sampled_tokens_cuda and _sampled_mtp_tokens_cuda tensors with the tokens that need to be accepted and the next pass of speculative tokens. + 2. To store _accepted_token_counts for verify_and_update_for_mtp_tokens to update KV cache bookkeeping. + + Args: + logits (Tensor): The logits from the forward pass. Shape: [1, seq_len, vocab_size] + mtp_logits (Tensor): The MTP logits from the forward pass. Shape: [num_speculative_tokens, seq_len, vocab_size] + input_ids (Tensor): The input IDs. Shape: [1, seq_len] + """ + context: DynamicInferenceContext = self.inference_wrapped_model.inference_context + active_request_count = context.total_request_count - context.paused_request_count + active_request_slice = slice(context.paused_request_count, context.total_request_count) + + # Get query lengths to identify decode vs prefill requests and token boundaries + query_lengths = context.request_query_lengths[active_request_slice] + query_cumsum = query_lengths.cumsum(dim=0) + query_starts = torch.cat([ + torch.zeros(1, device=query_cumsum.device, dtype=query_cumsum.dtype), + query_cumsum[:-1] + ]) + + # Squeeze logits and input_ids: [1, seq_len, vocab] -> [seq_len, vocab] + logits_2d = logits.squeeze(0) + input_ids_1d = input_ids.squeeze(0) + + device = self._sampled_tokens_cuda.device + num_spec_tokens = self.num_speculative_tokens + + # Initialize acceptance count tracker (will be used by verify_and_update_for_mtp_tokens) + # This stores how many speculative tokens were accepted for each request + self._accepted_token_counts = torch.zeros(active_request_count, dtype=torch.int32, device=device) + + # todo : tHIS IS NOT FOOL PROOF, Need to find another way to identify decode vs prefill requests + # Can create a new tensor which states prefill vs decode + expected_decode_length = 1 + num_spec_tokens + is_decode_mask = (query_lengths == expected_decode_length) + + # Process each request + for req_idx in range(active_request_count): + req_start = int(query_starts[req_idx].item()) + req_length = int(query_lengths[req_idx].item()) + + # Get sampling parameters for this request + ctx_req_idx = context.paused_request_count + req_idx + temp = float(self._request_metadata["temperature"][ctx_req_idx].item()) + top_k = int(self._request_metadata["top_k"][ctx_req_idx].item()) + top_p = float(self._request_metadata["top_p"][ctx_req_idx].item()) + + if is_decode_mask[req_idx]: + # ================================================================ + # DECODE REQUEST: Verify speculative tokens from previous step + # ================================================================ + # Token layout: [main_token, spec_token_1, ..., spec_token_k] + # logits[pos] predicts the token at position pos+1 + # We verify: does sample(logits[pos]) == input_ids[pos+1]? + + accepted_count = 0 + sample_pos = req_start # Position from which to get new speculative tokens + + for spec_idx in range(num_spec_tokens): + # Sample from logits at position (req_start + spec_idx) + # This predicts the token at position (req_start + spec_idx + 1) + pos = req_start + spec_idx + logit = logits_2d[pos:pos+1, :] + sampled_token = self._torch_sampling_func(logit, temp, top_k, top_p).item() + + # The speculative token we're verifying is at position (pos + 1) + spec_token = input_ids_1d[pos + 1].item() + + if sampled_token == spec_token: + # Speculative token matches! Accept it and continue verification + accepted_count += 1 + else: + # Rejection: sampled token differs from speculative token + # Use the sampled token as the next output token + self._sampled_tokens_cuda[req_idx] = sampled_token + sample_pos = pos + break + else: + # All speculative tokens were accepted + # Sample a new token from the position after the last speculative token + sample_pos = req_start + num_spec_tokens + logit = logits_2d[sample_pos:sample_pos+1, :] + sampled_token = self._torch_sampling_func(logit, temp, top_k, top_p).item() + self._sampled_tokens_cuda[req_idx] = sampled_token + + self._accepted_token_counts[req_idx] = accepted_count + + # Get new speculative tokens from MTP logits at the sample position + # These will be used as speculative tokens in the next forward pass + for mtp_idx in range(num_spec_tokens): + mtp_logit = mtp_logits[mtp_idx, sample_pos:sample_pos+1, :] + spec_token = self._torch_sampling_func(mtp_logit, temp, top_k, top_p).item() + self._sampled_mtp_tokens_cuda[mtp_idx, req_idx] = spec_token + + else: + # ================================================================ + # PREFILL REQUEST: Sample from the last position only + # ================================================================ + # No speculative tokens to verify for new requests + # Just sample the next token from the last position's logits + + last_pos = req_start + req_length - 1 + logit = logits_2d[last_pos:last_pos+1, :] + sampled_token = self._torch_sampling_func(logit, temp, top_k, top_p).item() + self._sampled_tokens_cuda[req_idx] = sampled_token + + # For prefill, acceptance count represents that all prompt tokens are in KV cache + # (though semantically different from decode's speculative token acceptance) + self._accepted_token_counts[req_idx] = 0 # No speculative tokens were verified + + # Get speculative tokens from MTP logits at the last position + # These will be used as speculative tokens in the next forward pass + for mtp_idx in range(num_spec_tokens): + mtp_logit = mtp_logits[mtp_idx, last_pos:last_pos+1, :] + spec_token = self._torch_sampling_func(mtp_logit, temp, top_k, top_p).item() + self._sampled_mtp_tokens_cuda[mtp_idx, req_idx] = spec_token + def _dynamic_step_sample_logits(self, logits: Tensor): """Sample tokens from logits for dynamic batching. @@ -638,28 +867,34 @@ def _dynamic_step_sample_logits(self, logits: Tensor): # and then broadcast the sampled tokens rather than broadcasting the raw logits. # Last token logits. - context = self.inference_wrapped_model.inference_context + context: DynamicInferenceContext = self.inference_wrapped_model.inference_context if context.materialize_only_last_token_logits: # When materialize_only_last_token_logits is true, last_token_logits is # already called in the forward pass of GPT. - last_token_logits = logits.squeeze(0) + required_token_indices = logits.squeeze(0) else: - last_token_logits = context.last_token_logits(logits) + # todo : Should do verification here and get approrpiate las token logits + required_token_indices = context.last_token_logits(logits) if self._sampling_backend == "torch": # Concatenate the outputs once to prevent repeated small writes. token_list = [] indices_list = [] + # e.g torch sample buckets will be + # i.e (for all unique comibnation of t, topk, topk what are the associated requests indices (based on the active slices) + # [ [req at index 0, req at index 2], t1, topk1, topp1 ]] + # [ [req at index 1, req at index 3, req at index 4] , t2, topk2, topp2] for indices, temp, top_k, top_p in self._torch_sampling_buckets: token_list.append( - self._torch_sampling_func(last_token_logits[indices, :], temp, top_k, top_p) + self._torch_sampling_func(required_token_indices[indices, :], temp, top_k, top_p) ) indices_list.append(torch.tensor(indices)) # Single write to the output tensor. sampled_tokens = torch.cat(token_list, dim=0) sampled_indices = torch.cat(indices_list, dim=0) + self._sampled_tokens_cuda[sampled_indices] = sampled_tokens def _dynamic_step_log_probs_bookkeeping(self) -> Tuple[bool, bool]: @@ -668,7 +903,7 @@ def _dynamic_step_log_probs_bookkeeping(self) -> Tuple[bool, bool]: Returns: return_log_probs (bool): Whether to return the sampled log_probs. """ - context = self.inference_wrapped_model.inference_context + context: DynamicInferenceContext = self.inference_wrapped_model.inference_context active_request_slice = slice(context.paused_request_count, context.total_request_count) return_log_probs = self._request_metadata["return_log_probs"][active_request_slice] @@ -678,7 +913,7 @@ def _dynamic_step_log_probs_bookkeeping(self) -> Tuple[bool, bool]: def _dynamic_step_calculate_log_probs(self, logits: Tensor) -> Optional[Tensor]: """Calculate log probs from logits.""" - context = self.inference_wrapped_model.inference_context + context: DynamicInferenceContext = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count return context.calculate_log_probs( @@ -707,7 +942,7 @@ def _dynamic_step_calculate_top_n_logprobs( "computing log_probs when return_top_n_logprobs is True." ) - context = self.inference_wrapped_model.inference_context + context: DynamicInferenceContext = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count active_request_slice = slice(context.paused_request_count, context.total_request_count) @@ -775,7 +1010,7 @@ def dummy_forward(self): """Perform a dummy forward pass. This is used in expert model parallelism on ranks that do not have any real requests.""" - context = self.inference_wrapped_model.inference_context + context: DynamicInferenceContext = self.inference_wrapped_model.inference_context # if no cuda graphs, directly use dummy forward if not context.cuda_graph_batch_dimensions_list: return self.inference_wrapped_model.dummy_forward() @@ -816,14 +1051,14 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: newly_paused_request_ids (Tensor): Newly paused request IDs. finished_request_ids (Tensor): Finished request IDs. """ - context = self.inference_wrapped_model.inference_context + context: DynamicInferenceContext = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count active_request_slice = slice(context.paused_request_count, context.total_request_count) # Active sequence lengths. active_request_ids = context.request_ids[active_request_slice].long() active_sequence_lengths = context.get_active_sequence_lengths() - active_sequence_lengths += 1 # Account for the token we just generated + active_sequence_lengths += self._accepted_token_counts + 1 # SHAN CHECK IF YOU NEED +1 max_sequence_lengths = context.get_max_sequence_lengths() # Request finished if termination_id or length >= max_sequence_length. @@ -834,6 +1069,7 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: ).byte() & torch.less(active_sequence_lengths, max_sequence_lengths).byte() # Mark requests as finished if they hit stop words (detected in previous step's post_process_requests) + # TODO : SHAN : Implement this if self._get_stop_word_finished_ids_callback is not None: request_ids_list = active_request_ids.tolist() stop_word_finished_ids = self._get_stop_word_finished_ids_callback(request_ids_list) @@ -851,7 +1087,7 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: new_sample_copy = self._sampled_tokens_cuda[:active_request_count].clone() # Update requests. - update_result = context.update_requests(active_request_mask, new_sample_copy) + update_result = context.update_requests(active_request_mask, new_sample_copy, self._sampled_mtp_tokens_cuda[:active_request_count]) return { "active_request_ids": active_request_ids, @@ -859,6 +1095,7 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: **(update_result or {}), } + @torch.inference_mode() async def async_generate_output_tokens_dynamic_batch( self, skip_bookkeeping: Optional[bool] = False @@ -877,7 +1114,7 @@ async def async_generate_output_tokens_dynamic_batch( log_probs (Optional[Tensor]): Log probabilities of the new sample, if requested. cuda_graph_request_count (Optional[int]): Size of cuda graph used for this step. """ - context = self.inference_wrapped_model.inference_context + context: DynamicInferenceContext = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count # No tokens? @@ -890,8 +1127,13 @@ async def async_generate_output_tokens_dynamic_batch( context.padded_active_request_count if context.is_decode_only() else None ) - logits = self._dynamic_step_forward_logits(input_ids, position_ids) - + logits_and_mtp_logits = self._dynamic_step_forward_logits(input_ids, position_ids) + mtp_logits = None + if logits_and_mtp_logits.shape[0] > 1: + logits = logits_and_mtp_logits[:1] + mtp_logits = logits_and_mtp_logits[1:] + print(f"mtp_logits: {mtp_logits.shape}", "logits: {logits.shape}") + # This is the best place to yield control back to event loop. # At this point we have enqueued FW pass GPU kernels asynchronously. # While they are running, we can do other useful CPU work. @@ -900,13 +1142,23 @@ async def async_generate_output_tokens_dynamic_batch( # Todo [Siddharth]: Can we condition the sleep on a cuda event? # NOTE [TDE]: This will be moved once CPU and GPU methods are separated. await asyncio.sleep(0) - + # For now lets not care about log probs and top n logprobs return_log_probs, return_top_n_logprobs = self._dynamic_step_log_probs_bookkeeping() self._dynamic_step_sample_bookkeeping() - self._dynamic_step_sample_logits(logits) + if self.num_speculative_tokens > 1: + self._dynamic_step_sample_logits_with_speculative_tokens(logits, mtp_logits, input_ids) + self._update_kv_cache_bookkeeping_for_speculative_decoding() + else: + self._dynamic_step_sample_logits(logits) + # Afer this you have + # self._sampled_tokens_cuda : [active_request_count] + # self._sampled_mtp_tokens_cuda : [num_mtp_heads, active_request_count] + + log_probs = None top_n_logprobs = None + # TODO SHAN : Implement all of this if return_log_probs or return_top_n_logprobs: log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs(logits) if return_top_n_logprobs: @@ -1190,7 +1442,7 @@ def generate_all_output_tokens_static_batch( self.inference_wrapped_model.inference_context.is_decode_only() or not (sampling_params.return_log_probs or sampling_params.top_n_logprobs > 0) ) - inference_context = self.inference_wrapped_model.inference_context + inference_context: DynamicInferenceContext = self.inference_wrapped_model.inference_context inference_context.materialize_only_last_token_logits = ( materialize_only_last_token_logits ) diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py index 05a7e8f60bb..0e560f939f2 100644 --- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py @@ -240,7 +240,15 @@ def get_rotary_seq_len( # by the tp and cp size. return max(packed_seq_params.max_seqlen_q, packed_seq_params.max_seqlen_kv) elif inference_context is not None: - rotary_seq_len = inference_context.max_sequence_length + # For dynamic batching, use the max of context's max_sequence_length and the actual + # input size to ensure rotary embeddings cover CUDA graph warmup token counts + context_max_seq_len = inference_context.max_sequence_length + input_seq_len = 0 + if transformer_input is not None: + input_seq_len = transformer_input.size(0) + elif transformer is not None and transformer.input_tensor is not None: + input_seq_len = transformer.input_tensor.size(0) + rotary_seq_len = max(context_max_seq_len, input_seq_len) else: if transformer is not None and transformer.input_tensor is not None: rotary_seq_len = transformer.input_tensor.size(0) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index e70221d2cfa..daef19bf2a3 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -569,7 +569,7 @@ def _postprocess( position_ids=position_ids, hidden_states=hidden_states, attention_mask=attention_mask, - inference_params=inference_params, + inference_params=None, # MTP layers don't use KV cache rotary_pos_emb=rotary_pos_emb, rotary_pos_cos=rotary_pos_cos, rotary_pos_sin=rotary_pos_sin, @@ -583,56 +583,72 @@ def _postprocess( return hidden_states if self.config.mtp_num_layers is not None: - mtp_labels = labels.clone() + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) hidden_states = hidden_states_list[0] - if loss_mask is None: - # if loss_mask is not provided, use all ones as loss_mask - loss_mask = torch.ones_like(mtp_labels) - for mtp_layer_number in range(self.config.mtp_num_layers): - # output - mtp_logits, _ = self.output_layer( - hidden_states_list[mtp_layer_number + 1], - weight=output_weight, - runtime_gather_output=runtime_gather_output, - ) - # Calc loss for the current Multi-Token Prediction (MTP) layers. - mtp_labels, _ = roll_tensor( - mtp_labels, - shifts=-1, - dims=-1, - cp_group=self.cp_group, - packed_seq_params=packed_seq_params, - ) - loss_mask, num_tokens = roll_tensor( - loss_mask, - shifts=-1, - dims=-1, - cp_group=self.cp_group, - packed_seq_params=packed_seq_params, - ) - mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits) - mtp_loss = loss_mask * mtp_loss - if self.training: - # TODO(shifangx): remove the use of parallel_state here - # after moving loss logging to loss_func in pretrain_gpt.py - MTPLossLoggingHelper.save_loss_to_tracker( - torch.sum(mtp_loss) / num_tokens, - mtp_layer_number, - self.config.mtp_num_layers, - avg_group=parallel_state.get_data_parallel_group( - with_context_parallel=True - ), + self._mtp_logits_cache = None + if in_inference_mode: + # For inference with speculative decoding, compute and cache MTP logits + mtp_inference_logits = [] + for mtp_layer_number in range(self.config.mtp_num_layers): + mtp_logits, _ = self.output_layer( + hidden_states_list[mtp_layer_number + 1], + weight=output_weight, + runtime_gather_output=runtime_gather_output, + ) + # mtp logits shape [b, 1, vocab size] + mtp_inference_logits.append(mtp_logits.squeeze(1).unsqueeze(0)) + self._mtp_logits_cache = torch.cat(mtp_inference_logits, dim=0) + else: + # Training mode - compute MTP loss + mtp_labels = labels.clone() + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(mtp_labels) + for mtp_layer_number in range(self.config.mtp_num_layers): + # output + mtp_logits, _ = self.output_layer( + hidden_states_list[mtp_layer_number + 1], + weight=output_weight, + runtime_gather_output=runtime_gather_output, ) - mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers - if self.config.calculate_per_token_loss: - hidden_states = MTPLossAutoScaler.apply( - hidden_states, mtp_loss_scale * mtp_loss + # Calc loss for the current Multi-Token Prediction (MTP) layers. + mtp_labels, _ = roll_tensor( + mtp_labels, + shifts=-1, + dims=-1, + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, ) - else: - hidden_states = MTPLossAutoScaler.apply( - hidden_states, mtp_loss_scale * mtp_loss / num_tokens + loss_mask, num_tokens = roll_tensor( + loss_mask, + shifts=-1, + dims=-1, + cp_group=self.cp_group, + packed_seq_params=packed_seq_params, ) + mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits) + mtp_loss = loss_mask * mtp_loss + if self.training: + # TODO(shifangx): remove the use of parallel_state here + # after moving loss logging to loss_func in pretrain_gpt.py + MTPLossLoggingHelper.save_loss_to_tracker( + torch.sum(mtp_loss) / num_tokens, + mtp_layer_number, + self.config.mtp_num_layers, + avg_group=parallel_state.get_data_parallel_group( + with_context_parallel=True + ), + ) + mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers + if self.config.calculate_per_token_loss: + hidden_states = MTPLossAutoScaler.apply( + hidden_states, mtp_loss_scale * mtp_loss + ) + else: + hidden_states = MTPLossAutoScaler.apply( + hidden_states, mtp_loss_scale * mtp_loss / num_tokens + ) sequence_parallel_override = False if in_inference_mode and inference_context.materialize_only_last_token_logits: diff --git a/megatron/core/tokenizers/text/libraries/huggingface_tokenizer.py b/megatron/core/tokenizers/text/libraries/huggingface_tokenizer.py index 458689fa1f4..4d1bc34eb25 100644 --- a/megatron/core/tokenizers/text/libraries/huggingface_tokenizer.py +++ b/megatron/core/tokenizers/text/libraries/huggingface_tokenizer.py @@ -245,9 +245,9 @@ def ids_to_text(self, ids: List[int], remove_special_tokens: bool = True) -> str """Converts list of ids to text.""" tokens = self.ids_to_tokens(ids) if remove_special_tokens: - tokens_clean = [t for t in tokens if t not in self.tokenizer.all_special_tokens] + tokens_clean = [t for t in tokens if t is not None and t not in self.tokenizer.all_special_tokens] else: - tokens_clean = tokens + tokens_clean = [t for t in tokens if t is not None] text = self.tokens_to_text(tokens_clean) return text diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index 66bfe185a3b..a3a8b0a495f 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -1363,6 +1363,7 @@ def _set_arg(arg_name, old_arg_name=None, force=False): _set_arg('hybrid_mlp_ratio', force=True) _set_arg('num_experts', force=True) + _set_arg('mtp_num_layers', force=True) _set_arg('moe_layer_freq', force=True) if getattr(checkpoint_args, 'num_experts', None) is not None: _set_arg('moe_ffn_hidden_size', force=True) From 194e0e4843f64d04eb87173b901a64506a4fe1f2 Mon Sep 17 00:00:00 2001 From: Shanmugam Ramasamy Date: Thu, 5 Feb 2026 15:58:36 -0800 Subject: [PATCH 02/76] Speculative decoding vectorized implementation --- .../inference/contexts/dynamic_context.py | 56 ++- .../core/inference/engines/dynamic_engine.py | 32 +- .../text_generation_controller.py | 377 ++++++++++-------- 3 files changed, 286 insertions(+), 179 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 9df78c2dd75..0eac6d1c70d 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -600,6 +600,7 @@ def allocate_all_tensors(self, *, is_init: bool) -> None: ) # request_query_lengths is the input prompt tokens length during prefill phase (1st step) and then 1 for the decode phase (i.e During generation) self.request_query_lengths = torch.empty_like(self.request_ids) + self.request_in_prefill_status_tensor = torch.empty_like(self.request_ids) # request_output_lengths is len(input_prompt_tokens) + num_tokens_to_generate self.request_output_lengths = torch.empty_like(self.request_ids) # request_kv_length_offsets is the same as query length during prefill phase (1st step) and then 1 for the decode phase (i.e During generation) @@ -1153,6 +1154,7 @@ def add_dummy_requests_parallel( self.request_ids[request_slice] = request_ids_tensor self.request_query_lengths[request_slice] = lengths_tensor + self.request_in_prefill_status_tensor[request_slice] = 1 self.request_output_lengths[request_slice] = lengths_tensor + tokens_to_generate_tensor self.request_kv_length_offsets[request_slice] = 0 self.request_kv_block_counts[request_slice] = block_counts @@ -1442,6 +1444,7 @@ def reset(self) -> None: self.request_last_kv_block_id.fill_(-1) self.request_last_kv_block_offset.fill_(0) self.request_to_kv_block_ids.fill_(-1) + self.request_in_prefill_status_tensor.fill_(-1) # Reset request metadata. for metadata_tensor in self.request_metadata.values(): @@ -1487,8 +1490,7 @@ def current_input_and_position_ids( self.token_to_input_ids[:num_tokens].unsqueeze(0), self.token_to_pos_ids[:num_tokens].unsqueeze(0), ) - - # TODO SHAN : Should do verification here for speculative tokens and get the indices. + def last_token_logits(self, logits: Tensor, mtp_logits: Optional[Tensor] = None) -> Tensor: """Last tokens of logits. @@ -1616,6 +1618,7 @@ def add_request(self, req: DynamicInferenceRequest, chunk_length: Optional[int] # Handle length and block assignments. self.request_query_lengths[current_id] = chunk_length + self.request_in_prefill_status_tensor[current_id] = 1 self.request_output_lengths[current_id] = ( req.finished_chunk_token_count + chunk_length @@ -1678,12 +1681,12 @@ def _move_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens, new_specul Move all the relevent booking tensors with src idxs to dst idxs """ self.request_kv_length_offsets[dst_idxs] = self.request_kv_length_offsets[src_idxs] + self.request_in_prefill_status_tensor[dst_idxs] = self.request_in_prefill_status_tensor[src_idxs] self.request_query_lengths[dst_idxs] = self.request_query_lengths[src_idxs] self.request_output_lengths[dst_idxs] = self.request_output_lengths[src_idxs] self.request_ids[dst_idxs] = self.request_ids[src_idxs] - next_tokens[dst_idxs] = next_tokens[src_idxs] + next_tokens[dst_idxs] = next_tokens[src_idxs] # num tokens sames as num samples if new_speculative_tokens is not None: - # Handle multi-token next_tokens: shape [num_speculative_tokens, total_request_count] new_speculative_tokens[:, dst_idxs] = new_speculative_tokens[:, src_idxs] self.request_to_kv_block_ids[dst_idxs] = self.request_to_kv_block_ids[src_idxs] self.request_kv_block_counts[dst_idxs] = self.request_kv_block_counts[src_idxs] @@ -1704,6 +1707,7 @@ def _swap_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens): """ tensor_swap(self.request_kv_length_offsets, src_idxs, dst_idxs) tensor_swap(self.request_query_lengths, src_idxs, dst_idxs) + tensor_swap(self.request_in_prefill_status_tensor, src_idxs, dst_idxs) tensor_swap(self.request_output_lengths, src_idxs, dst_idxs) tensor_swap(self.request_ids, src_idxs, dst_idxs) tensor_swap(next_tokens, src_idxs, dst_idxs) @@ -1937,6 +1941,7 @@ def evict_overflow_paused_requests( return evict_request_ids + def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor, new_speculative_tokens: Tensor = None) -> Tensor: """Update context state after calling engine.step(). @@ -1982,12 +1987,24 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor, new_ # active_request_count -> This corresponds to requests that have not reached EOD or max length # finished_request_count are requests that have reached the termination criterion + + # new_tokens : [ b4 , c4, a6] + # actgive_requesT_mask [ 0 1 0 ] + # [1 0 0 ] + # new_spec_Tokens : [ [b4s1, c4s1, a6s1], + # [b4s2, c4s2, a6s2]] + + ## Vijay : [b4 b4s1, b4s2, c4 , c4s1, c4s2, a6 , a6s1, a6s2] + # self.num_prefill_requests = 0 # all turns to decode + # All request that were in prefill become decode requests + self.request_in_prefill_status_tensor[self.request_in_prefill_status_tensor == 1] = 0 # TODO : Check how this works with chunked prefill if self.chunked_prefill_request_id != -1: active_requests_mask[-1] = ( 1 # must keep this, next iteration will add a new chunk to it ) + active_request_count = (active_requests_mask == 1).sum().item() finished_request_count = (active_requests_mask == 0).sum().item() assert ( @@ -2156,7 +2173,6 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor, new_ # 6.d. Swap the chunked prefill request to the end of the active requests # to obey the invariance. - # SHAN : Should check this if self.chunked_prefill_request_id != -1: self._swap_book_keeping_tensors( src_idxs=torch.tensor([self.get_index_of_chunked_prefill_request()]), @@ -2190,33 +2206,43 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor, new_ sampled_tokens = next_tokens[ self.paused_request_count : self.total_request_count ] - sampled_speculative_tokens = new_speculative_tokens[ - self.paused_request_count : self.total_request_count - ] - next_tokens = torch.vstack([sampled_tokens, sampled_speculative_tokens]).T.reshape(-1) # This will insert the speculative tokens after the sampled tokens + if self.num_speculative_tokens > 0: + # new_speculative_tokens has shape [num_spec_tokens, num_requests], slice the request dimension (dim 1) + sampled_speculative_tokens = new_speculative_tokens[ + :, self.paused_request_count : self.total_request_count + ] + next_tokens = torch.vstack([sampled_tokens.unsqueeze(0), sampled_speculative_tokens]).T.reshape(-1)# + + else: + next_tokens = sampled_tokens + self.token_to_input_ids[: self.active_token_count] = next_tokens # kv length offsets will tell the sequence length (query + generated_tokens) (During add request alone its 0) (It tells how many tokens there are in kv cache) self.token_to_pos_ids[: self.active_token_count] = self.request_kv_length_offsets[ self.paused_request_count : self.total_request_count - ].repeate_interleave(1 + self.num_speculative_tokens) + torch.arange(1 + self.num_speculative_tokens).repeat(active_request_count) + ].repeat_interleave(1 + self.num_speculative_tokens) + torch.arange(1 + self.num_speculative_tokens, device=torch.cuda.current_device()).repeat(active_request_count) + # - # 8. We make relevant changes to the token bookkeeping tensors + # 8. We make relevant changes to the token bookkeeping tensors [1 2 3] [1 1 1 2 2 2 ] self.token_to_request_idx[: self.active_token_count] = torch.arange( self.paused_request_count, self.total_request_count, device=torch.cuda.current_device() - ).repeate_interleave(1 + self.num_speculative_tokens) + ).repeat_interleave(1 + self.num_speculative_tokens) # shan : Same as token_to_pos_ids ? self.token_to_position_in_request[: self.active_token_count] = ( self.request_kv_length_offsets[self.paused_request_count : self.total_request_count] - ).repeate_interleave(1 + self.num_speculative_tokens) + torch.arange(1 + self.num_speculative_tokens).repeat(active_request_count) + ).repeat_interleave(1 + self.num_speculative_tokens) + torch.arange(1 + self.num_speculative_tokens, device=torch.cuda.current_device()).repeat(active_request_count) self.token_to_local_position_within_kv_block[: self.active_token_count] = self.token_to_pos_ids[: self.active_token_count] % self.block_size_tokens current_block_ids = self.request_last_kv_block_id[self.paused_request_count : self.total_request_count] - raw_positions = old_offsets[:, None] + 1 + torch.arange(1 + self.num_speculative_tokens + 1 )[None, :] # [active_request_count, num_speculative_tokens + 1] (+1 for generated toekns) + # 16 IS THE NUMBER OF TOKENS + # 4 speculative tokens + # 14 (2 ) + raw_positions = old_offsets[:, None] + 1 + torch.arange(1 + self.num_speculative_tokens + 1, device=torch.cuda.current_device())[None, :] # [active_request_count, num_speculative_tokens + 1] (+1 for generated toekns) # A token crosses to the next block if its raw_position >= block_size crosses_boundary = raw_positions >= self.block_size_tokens @@ -2225,7 +2251,7 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor, new_ # Fast path: no tokens cross block boundary, all use current block self.token_to_block_idx[: self.active_token_count] = self.request_last_kv_block_id[ self.paused_request_count : self.total_request_count - ].repeate_interleave(1 + self.num_speculative_tokens) + ].repeat_interleave(1 + self.num_speculative_tokens) else: # Some tokens cross to the next block (this happens for resumed requests) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index d099b735cd7..dddf185d50a 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -196,6 +196,8 @@ def __init__( self.context.num_speculative_tokens = num_speculative_tokens self.controller.num_speculative_tokens = num_speculative_tokens + # Initialize MTP sampling tensor now that num_speculative_tokens is set + self.controller._init_mtp_sampling_tensor() if enable_cuda_graph is not None: self.cuda_graph_impl = "local" if enable_cuda_graph else "none" @@ -813,6 +815,7 @@ def post_process_requests( evict_request_ids: torch.Tensor, step_time: float, sample: torch.Tensor, + accepted_tokens: torch.Tensor, log_probs: torch.Tensor, top_n_logprobs: Optional[Dict[int, List[Tuple[torch.Tensor, torch.Tensor]]]] = None, ) -> Tuple[List[DynamicInferenceRequest], List[DynamicInferenceRequest]]: @@ -824,7 +827,8 @@ def post_process_requests( finished_request_ids (torch.Tensor): A list of finished request ids evict_request_ids (torch.Tensor): A list of evicted request ids. step_time (float): The latency of the last step - sample: List[Tensor]: The newly generated tokens for each request (Will include speculative tokens as well) + sample: Tensor: The newly generated token for each request + accepted_tokens: Tensor: The additional accepted tokens for each request log_probs: (List): Log probs for each request top_n_logprobs: (Dict): Top-n log probs for each request. Maps request_idx to list of (top_n_logprobs, top_n_indices) tuples. @@ -841,18 +845,31 @@ def post_process_requests( log_probs_iter = log_probs if log_probs else repeat(None) - for req_idx, (request_id, tokens, request_log_probs) in enumerate( - zip(request_ids.tolist(), sample.tolist(), log_probs_iter) + # When accepted_tokens is None (no speculative decoding), use repeat([]) to provide + # empty lists for each request, so the zip produces the correct number of iterations + accepted_tokens_iter = repeat([]) if accepted_tokens is None else accepted_tokens.tolist() + + for req_idx, (request_id, tokens, accepted_tokens_list, request_log_probs) in enumerate( + zip(request_ids.tolist(), sample.tolist(), accepted_tokens_iter, log_probs_iter) ): + + # Ensure tokens is always a list for consistent handling + if not isinstance(tokens, list): + tokens = [tokens] + + if self.num_speculative_tokens > 0: + accepted_tokens = list(filter(lambda tok: tok != -1, accepted_tokens_list)) + tokens = tokens + accepted_tokens + request: DynamicInferenceRequest = self.get_request(request_id) if request_id != self.context.chunked_prefill_request_id: # Skip appending token for requests being finished due to stop words # (they already have their final token from the previous step) # If the request already has more tokens, then we only append as much as is necessary - if len(request.generated_tokens) + len(tokens) >= request.sampling_params.max_tokens: - tokens = tokens[:request.sampling_params.max_tokens - len(request.generated_tokens)] + if len(request.generated_tokens) + len(tokens) >= request.sampling_params.num_tokens_to_generate: + tokens = tokens[:request.sampling_params.num_tokens_to_generate - len(request.generated_tokens)] if request_id not in self.stop_word_being_finished_ids: - request.generated_tokens.append(tokens) + request.generated_tokens += tokens if request.tpot is None: request.tpot = [] request.tpot.append(step_time) @@ -1005,6 +1022,7 @@ def _get_and_clear_stop_word_finished_ids(self, active_request_ids: list[int]) - self.stop_word_finished_request_ids -= result return result + # TODO : We also might have to delete some tokens, if stop word hit in the middle (speculative case) def _check_stop_words_for_request_post_append(self, request: DynamicInferenceRequest) -> bool: """Check if a request should stop due to stop words (after token is appended). @@ -1223,6 +1241,7 @@ async def async_bookkeep( newly_paused_request_ids = step_result.get("newly_paused_request_ids") evict_request_ids = step_result.get("evict_request_ids") sample = step_result["sample"] + accepted_tokens = step_result["accepted_tokens"] log_probs = step_result["log_probs"] top_n_logprobs = step_result.get("top_n_logprobs", None) cuda_graph_request_count = step_result["cuda_graph_request_count"] @@ -1241,6 +1260,7 @@ async def async_bookkeep( evict_request_ids, step_time, sample, + accepted_tokens, log_probs, top_n_logprobs, ) diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 93757d0902a..4ed3ab4d77b 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -75,7 +75,7 @@ def __init__( model_config = get_model_config(self.inference_wrapped_model.model) self.sampling_rng = torch.Generator(device=torch.cuda.current_device()) self.sampling_rng.manual_seed(model_config.inference_sampling_seed) - self.num_mtp_heads= self._get_mtp_num_heads() + self.num_mtp_heads = self._get_mtp_num_heads() if self.inference_wrapped_model.inference_context.is_dynamic_batching(): self._init_dynamic_sampling_tensors() @@ -113,7 +113,10 @@ def _init_dynamic_sampling_tensors(self): self._sampling_backend = "torch" self._sampled_tokens_cuda = torch.empty(max_requests, dtype=torch.int64, device=device) - self._sampled_mtp_tokens_cuda = torch.empty([self.num_mtp_heads, max_requests], dtype=torch.int64, device=device) + # Speculative tokens tensor will be allocated later when num_speculative_tokens is set by the engine + self._accepted_tokens_per_request = None + # MTP tensor will be allocated later when num_speculative_tokens is set by the engine + self._sampled_mtp_tokens_cuda = None # Keep track of request metadata. self._request_metadata: Dict[str, Tensor] = {} @@ -127,7 +130,20 @@ def _init_dynamic_sampling_tensors(self): # Used for inefficient torch sampling. if self._sampling_backend == "torch": - self._torch_sampling_buckets: Iterator[Tuple] = [] + self._torch_sampling_buckets: List[Tuple] = [] + + def _init_mtp_sampling_tensor(self): + """Initialize the MTP sampling tensor after num_speculative_tokens is set.""" + if self.num_speculative_tokens is not None and self.num_speculative_tokens > 0: + context: DynamicInferenceContext = self.inference_wrapped_model.inference_context + max_requests = context.max_requests + device = torch.cuda.current_device() + self._sampled_mtp_tokens_cuda = torch.empty( + [self.num_speculative_tokens, max_requests], dtype=torch.int64, device=device + ) + self._accepted_tokens_per_request = torch.ones( + [max_requests, self.num_speculative_tokens], dtype=torch.int64, device=device + ) * -1 def tokenize_prompt(self, prompt: str, add_BOS: bool = False) -> List[int]: """Utility to tokenize the input prompts. @@ -588,6 +604,8 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) logits = self.inference_wrapped_model.run_one_forward_step( {"tokens": input_ids, "position_ids": position_ids, "attention_mask": None} ) + # [1, seq_len, vocab_size] (logits) + # [num_speculative_tokens, seq_len, vocab_size] (mtp_logits) if self.num_speculative_tokens > 0: unwrapped_model = unwrap_model(self.inference_wrapped_model.model) @@ -596,7 +614,7 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) expected_mtp_logits_length, _, vocab_size = mtp_logits.shape assert expected_mtp_logits_length == self.num_mtp_heads, f"MTP logits length mismatch. Expected mtp logits length {self.num_mtp_heads}, got {expected_mtp_logits_length}" mtp_logits = mtp_logits[:self.num_speculative_tokens] - logits = torch.cat([logits, mtp_logits], dim = 0) + logits = torch.cat([logits, mtp_logits], dim = 0) # [num_speculative_tokens + 1, seq_len, vocab_size] if self.model_is_pipeline_parallel: @@ -628,7 +646,7 @@ def _dynamic_step_sample_bookkeeping(self): if self._sampling_backend == "torch": # Bucketize the core sampling parameters. # Doing so via list comprehension is orders of magnitude faster than via torch. - bucket_map = {} + bucket_map = defaultdict(list) # Shorthands for the dictionary comprehension. temp = self._request_metadata["temperature"][active_request_slice].tolist() @@ -640,11 +658,11 @@ def _dynamic_step_sample_bookkeeping(self): bucket_map[sampling_params].append(i) # Just unpack the key directly! - self._torch_sampling_buckets = ( + self._torch_sampling_buckets = [ (indices, *sampling_params) for sampling_params, indices in bucket_map.items() - ) + ] - def _update_kv_cache_bookkeeping_for_speculative_decoding(self): + def _rewind_kv_cache(self): """Update the KV cache bookkeeping for speculative decoding. After forward pass with speculative tokens, some tokens may be rejected. @@ -665,10 +683,10 @@ def _update_kv_cache_bookkeeping_for_speculative_decoding(self): # Get the accepted token counts for each request # Note: _accepted_token_counts is indexed from 0 to active_request_count-1 - accepted_token_counts = self._accepted_token_counts[:active_request_count] + accepted_tokens_per_request = self._accepted_token_counts_per_request[:active_request_count] # Number of tokens to rewind (rejected speculative tokens) - num_tokens_to_rewind = self.num_speculative_tokens - accepted_token_counts + num_tokens_to_rewind = accepted_tokens_per_request - self.num_speculative_tokens # Save the original offset BEFORE modifying to correctly detect block boundary crossing original_offset = context.request_last_kv_block_offset[active_request_slice].clone() @@ -694,7 +712,7 @@ def _update_kv_cache_bookkeeping_for_speculative_decoding(self): # 3. Update request_last_kv_block_id to point to the previous block # 4. Clear the entry in request_to_kv_block_ids for the released block # 5. Release the block back to the allocator - if remove_allocated_blocks_mask.any(): + if remove_allocated_blocks_mask.any(): # Get indices of requests that need to release a block (relative to active requests) requests_needing_release = torch.nonzero(remove_allocated_blocks_mask, as_tuple=True)[0] # Convert to absolute indices in the context tensors @@ -711,151 +729,183 @@ def _update_kv_cache_bookkeeping_for_speculative_decoding(self): # Update request_last_kv_block_id to point to the previous block # and clear the released block entry in request_to_kv_block_ids - # TODO : This can be easily vectorized. - for i, req_idx in enumerate(absolute_indices): - new_count = new_block_counts[i].item() - if new_count > 0: - # Update to point to the previous block (at index new_count - 1) - context.request_last_kv_block_id[req_idx] = context.request_to_kv_block_ids[ - req_idx, new_count - 1 - ] - # Clear the released block entry (at index new_count, which was the old last block) - context.request_to_kv_block_ids[req_idx, new_count] = -1 + # Vectorized implementation using advanced indexing: + # Note: new_block_counts is guaranteed to be > 0 for all requests here, since + # crossing back to a previous block implies the request had at least 2 blocks. + + # Update request_last_kv_block_id to point to the previous block (at index new_count - 1) + context.request_last_kv_block_id[absolute_indices] = context.request_to_kv_block_ids[ + absolute_indices, new_block_counts - 1 + ] + + # Clear the released block entry (at index new_count, which was the old last block) + context.request_to_kv_block_ids[absolute_indices, new_block_counts] = -1 # Release the blocks back to the allocator context.block_allocator.release_memory_blocks(blocks_to_release) - def _dynamic_step_sample_logits_with_speculative_tokens(self, logits: Tensor, mtp_logits: Tensor, input_ids: Tensor): - """Sample tokens from logits for dynamic batching with speculative tokens. + def _dynamic_step_sample_logits_and_verify_tokens(self, logits: Tensor, mtp_logits: Tensor, input_ids: Tensor): + f"""Sample tokens from logits for dynamic batching with speculative tokens and verify the tokens. + """ + context: DynamicInferenceContext = self.inference_wrapped_model.inference_context + active_request_count = context.total_request_count - context.paused_request_count - # E.g lets say 3 requests are present (Total 11 tokens) - # Request 1 : [b1 b1s1 b1s2] (b1 is the generated token, and b1s1 and b1s2 are the speculative tokens which were generated last time when b1 was generated) - # Request 2 : [c1 c1s1 c1s2] (c1 is the generated token, and c1s1 and c1s2 are the speculative tokens which were generated last time when c1 was generated) - # Request 3 : [a1 a2 a3 a4 a5] (This is a new request, so all are input tokens) - # input ids : [b1 b1s1 b1s2 c1 c1s1 c1s2 a1 a2 a3 a4 a5] - # logits : Tensor of size [1, 11, vocab_size] where each position tells the probability of the tokens at the next position (e.g) Logits at b1 tell the probability of the tokens at b1s1 and so on . - # mtp_logits : Tensor of size [num_speculative_tokens, 11, vocab_size] where each position tells the next mtp heads probabilites. - The idea here is to verify which tokens need to be accepted based on the input tokens sent (which includes speculative tokens as well) and the current logits and update the _sampled_tokens_cuda and _sampled_mtp_tokens_cuda tensors . - E.g for request 1, we need to accept b1s1 if the sampled logits at position 0 is b1s1. Lets say the sampled logit at position 1 is not b1s2, then we need to reject b1s2 and so just use the corresponding sampled logit at position 1, and the speculative tokens at position 1 for the next pass. For the last request, which is a new request, we just need to sample the logit at last position and the speculative tokens as well at that position. + # ================ PART 1 The following part of the code is to get all the relevant logit indices alone ========= + # i.e For prefill requests just the last token logits are enough. + # i.e For decode requests we will need all tokens + # Decode request will always be on the left, followed by prefill requests + # In non speculative case, it was simple in the other function, we just always get the last token logits using query lengths. - The final idea is : - 1. To populate the _sampled_tokens_cuda and _sampled_mtp_tokens_cuda tensors with the tokens that need to be accepted and the next pass of speculative tokens. - 2. To store _accepted_token_counts for verify_and_update_for_mtp_tokens to update KV cache bookkeeping. + # 5 requests # Input ids shape : [1, 15] + # Assume input ids : [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d1 d2 | e1 e2 e3 e4] + # Request to prefill [ 0 | 0 | 0 | 1 | 1 ] + # Request query lengths [ 3 | 3 | 3 | 2 | 4 ] + # OUTPUT : required_logit_indices [ 0 1 2 | 3 4 5 | 6 7 8 | 10 | 14 ] + + request_in_prefill_status_tensor = context.request_in_prefill_status_tensor[context.paused_request_count : context.total_request_count] + request_query_lengths = context.request_query_lengths[context.paused_request_count : context.total_request_count] - Args: - logits (Tensor): The logits from the forward pass. Shape: [1, seq_len, vocab_size] - mtp_logits (Tensor): The MTP logits from the forward pass. Shape: [num_speculative_tokens, seq_len, vocab_size] - input_ids (Tensor): The input IDs. Shape: [1, seq_len] - """ - context: DynamicInferenceContext = self.inference_wrapped_model.inference_context - active_request_count = context.total_request_count - context.paused_request_count - active_request_slice = slice(context.paused_request_count, context.total_request_count) - - # Get query lengths to identify decode vs prefill requests and token boundaries - query_lengths = context.request_query_lengths[active_request_slice] - query_cumsum = query_lengths.cumsum(dim=0) - query_starts = torch.cat([ - torch.zeros(1, device=query_cumsum.device, dtype=query_cumsum.dtype), - query_cumsum[:-1] - ]) - - # Squeeze logits and input_ids: [1, seq_len, vocab] -> [seq_len, vocab] - logits_2d = logits.squeeze(0) - input_ids_1d = input_ids.squeeze(0) - - device = self._sampled_tokens_cuda.device - num_spec_tokens = self.num_speculative_tokens + num_prefill_requests = request_in_prefill_status_tensor.sum().item() + num_decode_requests = active_request_count - num_prefill_requests - # Initialize acceptance count tracker (will be used by verify_and_update_for_mtp_tokens) - # This stores how many speculative tokens were accepted for each request - self._accepted_token_counts = torch.zeros(active_request_count, dtype=torch.int32, device=device) + decode_request_indices = torch.arange(num_decode_requests * (self.num_speculative_tokens + 1), device=logits.device) + prefill_request_indices = request_query_lengths.cumsum(dim=0)[request_in_prefill_status_tensor == 1] -1 # Last token indices for prefill requests + required_logit_indices = torch.cat([decode_request_indices, prefill_request_indices]) + assert len(required_logit_indices) == num_decode_requests * (self.num_speculative_tokens + 1) + num_prefill_requests, f"Expected length of required_logit_indices to be num_decode_requests * (self.num_speculative_tokens + 1) + num_prefill_requests, but got {len(required_logit_indices)} for num_decode_requests {num_decode_requests} and num_prefill_requests {num_prefill_requests}" + + + required_logits = logits.squeeze(0)[required_logit_indices, :] # Shape [1, 11, vocab_size] + required_mtp_logits = mtp_logits[:, required_logit_indices, :] # Shape [num_speculative_tokens, 11, vocab_size] + + # ================ PART 1 The following part of the code is to sample the logits and mtp logits based on the sampling parameters ========= + + # request_indices will be 0, 1, 2, 3, 4 (since we have only 5 requests) + # For torch sampling buckets :-[request_indices, temp, top_k, top_p] + # [ + # [[0,2], temp1, top_k1, top_p1], + # [1], temp3, top_k3, top_p3] + # [3, 4], temp2, top_k2, top_p2], + # ] + + # Token to request idx : [ 0 0 0 | 1 1 1 | 2 2 2 | 3 | 4 ] + # required_logits : [ a5l a6l a7l | b3l b4l b5l | c6l c7l c8l | d2l | e4l ] # Shape [11, vocab_size] + # For first iteration : + # sampling buckets : [0,2], temp1, top_k1, top_p1 + # output_tokens_jumbled_list = [a5s a6s a7s c6s c7s c8s] #s->sampled tokens # + # request_order_list = [0, 2] + # token_order_list = [0, 1, 2, 6, 7, 8] + # For second iteration : + # sampling buckets : [1], temp3, top_k3, top_p3 + # output_tokens_jumbled_list = [b3s b4s b5s] + # request_order_list = [1] + # token_order_list = [3, 4, 5] + # For third iteration : + # sampling buckets : [3, 4], temp2, top_k2, top_p2 + # output_tokens_jumbled_list = [d2s e4s] #s->sampled tokens # + # request_order_list = [3, 4] + # token_order_list = [9,10] + # Final output tokens : [a5s a6s a7s c6s c7s c8s b3s b4s b5s d2s e4s] # Shape [11] + # Final request order list : [0, 2, 1, 3, 4] + # Final token order list : [0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10] + + + repeats = torch.where(request_in_prefill_status_tensor == 0, 1 + self.num_speculative_tokens, 1) + token_to_request_index = torch.repeat_interleave(torch.arange(len(request_in_prefill_status_tensor), device=request_in_prefill_status_tensor.device), repeats) + + output_tokens_jumbled_list = [] + mtp_output_tokens_jumbled_list = [] + token_order_list = [] + + # TODO : Maybe its okay to have a loop with num spec tokens ? (Since it will only be max 3 , so might be faster) + for request_indices, temp, top_k, top_p in self._torch_sampling_buckets: + request_indices_tensor = torch.tensor(request_indices, device=token_to_request_index.device) + required_indices = torch.where(torch.isin(token_to_request_index, request_indices_tensor))[0] + # TODO : Can maybe club the following two and then split later ? + # TODO : Can directly initzlie output tokens as a tensor and put the logits in the right place + output_tokens_jumbled_list.append(self._torch_sampling_func(required_logits[required_indices, :], temp, top_k, top_p)) + mtp_output_tokens_jumbled_list.append( + self._torch_sampling_func(required_mtp_logits[:, required_indices, :], temp, top_k, top_p) + ) + token_order_list.append(required_indices) + - # todo : tHIS IS NOT FOOL PROOF, Need to find another way to identify decode vs prefill requests - # Can create a new tensor which states prefill vs decode - expected_decode_length = 1 + num_spec_tokens - is_decode_mask = (query_lengths == expected_decode_length) + + output_tokens_jumbled = torch.cat(output_tokens_jumbled_list, dim=0) + output_tokens = torch.empty(len(output_tokens_jumbled), device=output_tokens_jumbled.device, dtype=output_tokens_jumbled.dtype) + token_order = torch.cat(token_order_list, dim=0) + # Rearrange output tokens because previously it will be in the order of the sampling_bucket request indices, but now we want to put them according to their corresponding input ids + output_tokens[token_order] = output_tokens_jumbled + + mtp_output_tokens_jumbled = torch.cat(mtp_output_tokens_jumbled_list, dim=1) # Shape [num_speculative_tokens, total_tokens] + mtp_output_tokens = torch.empty_like(mtp_output_tokens_jumbled) + mtp_output_tokens[:, token_order] = mtp_output_tokens_jumbled + + ### ================ PART 3 This part is to do the fowlling : ================ + # Create the accepted tokens tensor + # For prefill it is always set to 1 + # For decode, the first token is always accepted, then we compare with input tokens and accept the next tokens if its a match + # Then find the index of the last 1 in every request of the accepted tokens tensor + # Then these are the index of the tokens that will be sent to the next forward pass + # In the example (assume 1 spec token, 2 spec tokens and 0 sepc tokens are accepted in teh first 3 requests + + + # Assume input ids : [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d1 d2 | e1 e2 e3 e4] + # input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] # Size 11 + # Output tokens [ a6o a7o a8o | b40 b5o b6o | c7o c8o c9o | d3o | e5o ] # At every index we get next positions sample + # Output tokens right shift [ d3o a6o a7o | a8o b40 b5o | b6o c7o c8o | c9o | d3o ] + # Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] + # Last one indices [ 1 | 5 | 6 | 9 | 10 ] + - # Process each request - for req_idx in range(active_request_count): - req_start = int(query_starts[req_idx].item()) - req_length = int(query_lengths[req_idx].item()) - - # Get sampling parameters for this request - ctx_req_idx = context.paused_request_count + req_idx - temp = float(self._request_metadata["temperature"][ctx_req_idx].item()) - top_k = int(self._request_metadata["top_k"][ctx_req_idx].item()) - top_p = float(self._request_metadata["top_p"][ctx_req_idx].item()) - - if is_decode_mask[req_idx]: - # ================================================================ - # DECODE REQUEST: Verify speculative tokens from previous step - # ================================================================ - # Token layout: [main_token, spec_token_1, ..., spec_token_k] - # logits[pos] predicts the token at position pos+1 - # We verify: does sample(logits[pos]) == input_ids[pos+1]? - - accepted_count = 0 - sample_pos = req_start # Position from which to get new speculative tokens - - for spec_idx in range(num_spec_tokens): - # Sample from logits at position (req_start + spec_idx) - # This predicts the token at position (req_start + spec_idx + 1) - pos = req_start + spec_idx - logit = logits_2d[pos:pos+1, :] - sampled_token = self._torch_sampling_func(logit, temp, top_k, top_p).item() - - # The speculative token we're verifying is at position (pos + 1) - spec_token = input_ids_1d[pos + 1].item() - - if sampled_token == spec_token: - # Speculative token matches! Accept it and continue verification - accepted_count += 1 - else: - # Rejection: sampled token differs from speculative token - # Use the sampled token as the next output token - self._sampled_tokens_cuda[req_idx] = sampled_token - sample_pos = pos - break - else: - # All speculative tokens were accepted - # Sample a new token from the position after the last speculative token - sample_pos = req_start + num_spec_tokens - logit = logits_2d[sample_pos:sample_pos+1, :] - sampled_token = self._torch_sampling_func(logit, temp, top_k, top_p).item() - self._sampled_tokens_cuda[req_idx] = sampled_token - - self._accepted_token_counts[req_idx] = accepted_count - - # Get new speculative tokens from MTP logits at the sample position - # These will be used as speculative tokens in the next forward pass - for mtp_idx in range(num_spec_tokens): - mtp_logit = mtp_logits[mtp_idx, sample_pos:sample_pos+1, :] - spec_token = self._torch_sampling_func(mtp_logit, temp, top_k, top_p).item() - self._sampled_mtp_tokens_cuda[mtp_idx, req_idx] = spec_token - - else: - # ================================================================ - # PREFILL REQUEST: Sample from the last position only - # ================================================================ - # No speculative tokens to verify for new requests - # Just sample the next token from the last position's logits - - last_pos = req_start + req_length - 1 - logit = logits_2d[last_pos:last_pos+1, :] - sampled_token = self._torch_sampling_func(logit, temp, top_k, top_p).item() - self._sampled_tokens_cuda[req_idx] = sampled_token - - # For prefill, acceptance count represents that all prompt tokens are in KV cache - # (though semantically different from decode's speculative token acceptance) - self._accepted_token_counts[req_idx] = 0 # No speculative tokens were verified - - # Get speculative tokens from MTP logits at the last position - # These will be used as speculative tokens in the next forward pass - for mtp_idx in range(num_spec_tokens): - mtp_logit = mtp_logits[mtp_idx, last_pos:last_pos+1, :] - spec_token = self._torch_sampling_func(mtp_logit, temp, top_k, top_p).item() - self._sampled_mtp_tokens_cuda[mtp_idx, req_idx] = spec_token + input_tokens_required = input_ids[0, required_logit_indices] + if input_tokens_required.ndim == 2: + assert input_tokens_required.shape[0] == 1, f"Expected input_tokens_required to have 1 row, but got {input_tokens_required.shape}" + input_tokens_required = input_tokens_required.squeeze(0) + + # This is to get the place where the output sampled speculative token is equal to input token + output_right_shifted = output_tokens.roll(1) + accepted_tokens_mask = input_tokens_required == output_right_shifted + + # This is to make all prefill tokens accepted + token_to_prefill_idx = torch.repeat_interleave(request_in_prefill_status_tensor, repeats) + accepted_tokens_mask[token_to_prefill_idx == 1] = 1 + + # This is to make first decode token in all requests accepted + deocde_query_starts = torch.arange(num_decode_requests) * (1 + self.num_speculative_tokens) + accepted_tokens_mask[deocde_query_starts] = 1 + + # This is to find the index of the last 1 in every request + last_one_indices = torch.full((active_request_count,), -1, device=token_to_request_index.device) + last_one_indices[token_to_request_index[accepted_tokens_mask == 1]] = torch.where(accepted_tokens_mask == 1)[0] # [1, 5, 6] + + # These are the tokens (output + speculative tokens) that will be going to the next forward pass + final_sampled_tokens = output_tokens[last_one_indices] + self._sampled_tokens_cuda[:len(final_sampled_tokens)] = final_sampled_tokens + self._sampled_mtp_tokens_cuda[:, :len(final_sampled_tokens)] = mtp_output_tokens[:, last_one_indices] + + ### ================ PART 4 This part is to do the fowlling : ================ + # To fill the speculative otkens and accepted_token counts + # For prefill it is always set to 1 + # For decode, the first token is always accepted, then we compare with input tokens and accept the next tokens if its a match + # Then find the index of the last 1 in every request of the accepted tokens tensor + # Then these are the index of the tokens that will be sent to the next forward pass + # In the example (assume 1 spec token, 2 spec tokens and 0 sepc tokens are accepted in teh first 3 requests + + + # input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] # Size 11 + # Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] + # Accepted tokens [ [a6s -1] | [b4s b5s] | [-1 -1] ] # Only handle decod requests, (Prefill already defaults to -1s) + # Accepted token counts [ 1 | 2 | 0 ] # Prefill defaults to 0 + + # This part tis to extract the accepted tokens + input_tokens_required[accepted_tokens_mask == 0 ] = -1 # Masks out non accepted tokens + input_tokens_decode_mode = input_tokens_required[:num_decode_requests * (self.num_speculative_tokens + 1)] + input_tokens_reshaped = input_tokens_decode_mode.reshape(-1, self.num_speculative_tokens + 1) # shape : [num_decode_requests, num_speculative_tokens + 1] + + accepted_tokens = input_tokens_reshaped[: , 1:] # Skip the first token of every decode request (i.e a5, b3, c6) + self._accepted_tokens_per_request[:accepted_tokens.shape[0],:] = accepted_tokens + self._accepted_token_counts_per_request = (self._accepted_tokens_per_request != -1).sum(dim=1) def _dynamic_step_sample_logits(self, logits: Tensor): """Sample tokens from logits for dynamic batching. @@ -1058,7 +1108,12 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: # Active sequence lengths. active_request_ids = context.request_ids[active_request_slice].long() active_sequence_lengths = context.get_active_sequence_lengths() - active_sequence_lengths += self._accepted_token_counts + 1 # SHAN CHECK IF YOU NEED +1 + + if self.num_speculative_tokens > 0: + accepted_token_counts_per_request = self._accepted_token_counts_per_request[:active_request_count] + active_sequence_lengths += accepted_token_counts_per_request + 1 + else: + active_sequence_lengths += 1 max_sequence_lengths = context.get_max_sequence_lengths() # Request finished if termination_id or length >= max_sequence_length. @@ -1069,7 +1124,7 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: ).byte() & torch.less(active_sequence_lengths, max_sequence_lengths).byte() # Mark requests as finished if they hit stop words (detected in previous step's post_process_requests) - # TODO : SHAN : Implement this + # TODO : SHAN : Correclty implement this if self._get_stop_word_finished_ids_callback is not None: request_ids_list = active_request_ids.tolist() stop_word_finished_ids = self._get_stop_word_finished_ids_callback(request_ids_list) @@ -1087,7 +1142,12 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: new_sample_copy = self._sampled_tokens_cuda[:active_request_count].clone() # Update requests. - update_result = context.update_requests(active_request_mask, new_sample_copy, self._sampled_mtp_tokens_cuda[:active_request_count]) + # _sampled_mtp_tokens_cuda has shape [num_speculative_tokens, max_requests] + if self.num_speculative_tokens > 0: + sampled_mtp_tokens_cuda = self._sampled_mtp_tokens_cuda[:, :active_request_count] + else: + sampled_mtp_tokens_cuda = None + update_result = context.update_requests(active_request_mask, new_sample_copy, sampled_mtp_tokens_cuda) return { "active_request_ids": active_request_ids, @@ -1130,9 +1190,11 @@ async def async_generate_output_tokens_dynamic_batch( logits_and_mtp_logits = self._dynamic_step_forward_logits(input_ids, position_ids) mtp_logits = None if logits_and_mtp_logits.shape[0] > 1: - logits = logits_and_mtp_logits[:1] - mtp_logits = logits_and_mtp_logits[1:] - print(f"mtp_logits: {mtp_logits.shape}", "logits: {logits.shape}") + logits = logits_and_mtp_logits[:1] # [1, seq_len, vocab_size] + mtp_logits = logits_and_mtp_logits[1:] # [num_speculative_tokens, seq_len, vocab_size] + print(f"mtp_logits: {mtp_logits.shape}",f"logits: {logits.shape}") + else: + logits = logits_and_mtp_logits # This is the best place to yield control back to event loop. # At this point we have enqueued FW pass GPU kernels asynchronously. @@ -1144,21 +1206,18 @@ async def async_generate_output_tokens_dynamic_batch( await asyncio.sleep(0) # For now lets not care about log probs and top n logprobs return_log_probs, return_top_n_logprobs = self._dynamic_step_log_probs_bookkeeping() + self._dynamic_step_sample_bookkeeping() if self.num_speculative_tokens > 1: - self._dynamic_step_sample_logits_with_speculative_tokens(logits, mtp_logits, input_ids) - self._update_kv_cache_bookkeeping_for_speculative_decoding() + self._dynamic_step_sample_logits_and_verify_tokens(logits, mtp_logits, input_ids) + self._rewind_kv_cache() else: self._dynamic_step_sample_logits(logits) - # Afer this you have - # self._sampled_tokens_cuda : [active_request_count] - # self._sampled_mtp_tokens_cuda : [num_mtp_heads, active_request_count] log_probs = None top_n_logprobs = None - # TODO SHAN : Implement all of this if return_log_probs or return_top_n_logprobs: log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs(logits) if return_top_n_logprobs: @@ -1170,9 +1229,11 @@ async def async_generate_output_tokens_dynamic_batch( request_bookkeeping = {} else: request_bookkeeping = self._dynamic_step_context_bookkeeping() + sample = self._sampled_tokens_cuda[:active_request_count] ret = { - "sample": self._sampled_tokens_cuda[:active_request_count], + "sample": sample, + "accepted_tokens": self._accepted_tokens_per_request, "log_probs": log_probs, "top_n_logprobs": top_n_logprobs, "cuda_graph_request_count": cuda_graph_request_count, From e911a17dc2c961d23cf4ce11cd8d253ead66cb89 Mon Sep 17 00:00:00 2001 From: Shanmugam Ramasamy Date: Fri, 6 Feb 2026 10:34:57 -0800 Subject: [PATCH 03/76] Added comments and cleaned up code --- deepseek_inference.sh | 20 ----- ....tfevents.1769037870.pool0-01476.3253909.0 | Bin 62232 -> 0 bytes ....tfevents.1769037919.pool0-01476.3254808.0 | Bin 63278 -> 0 bytes .../inference/contexts/dynamic_context.py | 69 ++++++++---------- .../core/inference/engines/dynamic_engine.py | 1 + .../text_generation_controller.py | 48 ++++++------ 6 files changed, 55 insertions(+), 83 deletions(-) delete mode 100644 deepseek_inference.sh delete mode 100644 deepseek_mtp_dummy_ckpt/tensorboard/events.out.tfevents.1769037870.pool0-01476.3253909.0 delete mode 100644 deepseek_mtp_dummy_ckpt/tensorboard/events.out.tfevents.1769037919.pool0-01476.3254808.0 diff --git a/deepseek_inference.sh b/deepseek_inference.sh deleted file mode 100644 index 9f735291060..00000000000 --- a/deepseek_inference.sh +++ /dev/null @@ -1,20 +0,0 @@ -torchrun --nproc-per-node 1 \ - -m examples.inference.gpt.gpt_dynamic_inference \ - --load /lustre/fsw/portfolios/coreai/users/shanmugamr/Megatron-LM/deepseek_mtp_dummy_ckpt/checkpoints \ - --bf16 \ - --model-provider gpt \ - --tensor-model-parallel-size 1 \ - --micro-batch-size 16 \ - --num-tokens-to-generate 20 \ - --inference-dynamic-batching-buffer-size-gb 5 \ - --prompt-file /lustre/fsw/portfolios/llmservice/users/ksanthanam/megatron-lm/debug_prompts.jsonl \ - --use-checkpoint-args \ - --enable-cuda-graph \ - --incoming-requests-per-sec 16 \ - --dist-ckpt-strictness log_unexpected \ - --decode-only-cuda-graphs \ - --tokenizer-type HuggingFaceTokenizer \ - --tokenizer-model deepseek-ai/deepseek-coder-6.7b-base \ - --no-use-tokenizer-model-from-checkpoint-args \ - --output-path /lustre/fsw/portfolios/coreai/users/shanmugamr/Megatron-LM/output.json \ - --return-log-probs \ No newline at end of file diff --git a/deepseek_mtp_dummy_ckpt/tensorboard/events.out.tfevents.1769037870.pool0-01476.3253909.0 b/deepseek_mtp_dummy_ckpt/tensorboard/events.out.tfevents.1769037870.pool0-01476.3253909.0 deleted file mode 100644 index 10edc841e5f1e7c1ffe5291cf86921cb29757c12..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 62232 zcmcJ22Xs``^FLrvcIhCUkkCU(APK!mlO|n?G?&emWXU$Rq@V~E5D`>FiXb3OP&xuq z6a^8)-ld5MQlw+S3jAm8o85Qk?LNM{{LUZGQ4WUB+&6b-?zB6%R}tmc%@=zuGAy6| zZpnv6AL&xPbe7NTaHrj$XLASIWeiR;HVxR^evdE9WA<6o{6Uw?>?=&m_c;PKUz!bn z%4ByqZ6@VUPIf=Eu@m7Q{;Km?!yki|ltSFbnZ|Tu6?f2Oa+(WmK7X@-tsr0`xUQv2 z+*2W;La`D>yB14t)u)8ujP)FhV3H=t-Z0WG0En(c&s*`DNfXH zYYb`COaAt((S+L3*xs0AwcQ@IOn9vc{XK4*y4n&~+SKbzxG7bBpJUkbFyXc|W*V#J zIILEi+vImluZQGdUX#watXr&e+CS({6X0bTOqxrZ>-M zp)PWzf8~9QrEZJF(qv?%IRHBzaCiXTX0~c}s$_VvMcb&T_YHWe?t`?{(i8j7GL*0r zYA<6CV=Wai%WTQDxviW5)x((Ox_uFB7(nS%k3N@DN}YfIfnBnH{?%mzyik!WGe zFjmdYGg)%XZnw>;S+FEpnHefvp|y9lz6AHjH_Ki%>>5XKU5xh`>ty>pL9c0i(B>;t z`N3CRHoWTMJ%c+us2Pq|Ej_0BJ4KCR>nl+v4Sk9Xx1w+I9*-t_Pd|| zlr7ipk(aA7*b@wRg8@^vIiQg;{pFVKx4b%&vo+%h8|AE;-(wHB%mt=Ep;rT0KFoo8 z^2~0B(`i;0Y@au?Q5WJtk=OUGF@Wq70JYa_t)M!^yw{1o*V! z<++A-c?8(m*vVLj0D%#pVz~~l$?mg_=NqOAeZ18(mVN}?(elk}hUW_ix~s9Xu^xWB z*W-5xm9V+8Y*wqoovqn637jk3>gz9DWZ*|8$l#H{Q5Y?O`#fGr(proub9FGSy#9rI z-ZIphNZ>t<_Zyq|Y)*5)k!MpySlG&}AShu#^8q?~@>tg|ZVVv`ra|{VZ5TR z|9czsrTo&H>tE_hL<)a%zhOA^BoS$CY-y}Y1hUM2&7vj5^yctO74aq$H=bu|USldD zqMFwTE!pO?fH~!&D03oBf-L-d`I6qW$X4~sEH@0AMv$G19gIl?>2Sl@>$I8kY#Ku< ztyN=^0>|5XM-cGzMB6sQ3)2a>zp<~eIRX1^<4sODqXYr32ZaYt0lV1}@c8%#ND|4_ z^u-d1`8|oy&iDUaXgDx~2z4;FGbSo~>9v{s7PAwyh@eGEOpIRYZNKw zY>=2c#21v*{hjp8cbLR2m_yhoan)T48*CQbHeE=*U*(EqeLINu>y@R4-!_z3K#)C* z-HnZ0=3JX8&toxX30n<{o#*g7vcTWVD&!O-4m0Kce#yGQ%;5;WLIluu>3A93@4^1?gR1V$Ir>)yZzCRCqz_=6tIVmJhF{Jx5Z>}nq6K~p4l1HoE_4Vxz;s0 zT5$23)mc44fe`bYdD^d)5N7uT0C z_x8Xp0!R0@hAZIJjY=g4KU1|n`&vyq$c~K$y9p4Fjd}>^^!WX;9&*joq~M{w}YA7YqA+Z8f<(F3|a)i$hBMgB7YD zT)Xk(Gnn|C`i=lme5$*G;33C{IPG-(^zI}+-A1G>Ff=?&h;5Cn!FC{`-|KV)Kzj1{ zA|;XJR<*~i&5tnu*mRUI(Lb&wFwysR+x!qacr>um?r40M)BSq?!knTxKM*cDMKzVR zf-7$W9#}O8ONDTHpoHHGE{)IZRM~16_heUQoe%yIh2SValLYlNpC^{3LN$;Ra@5++a@w5yB-IXvsYsF}}%vEU?M=nJ_W1 zsU{W-MmmxW6~pT^WH~*S+CD)y`GcT^u?a2LXLjdm zETSX>qneLV^(@7`rWTX06Mqt-$a8{Ld3<0JZ7vVGagerw|HisyvpNl*&+<04{~>J5 z+temdzuWBf=Xe4puQ`yzp{s>eqipD>(!AiUJEfR++rbufm`Bk2KO!QMsV;}b=P_lO z1HkP_??qB&>HL1=NW!~JC;~SL5`_Z1Z%-BkV)0=fnsTNaGY@Nuys8vF&^*-gL4*Ja zWj#z?9qah{iRvuH-e!$GH}uqljruSVDOH?kpolcE zxx9fwQ_u}^9TpTP6n>n1FC|t>`gh{D&7WE5tlN}Z@vx!IJ%s&`ahS24CC3J_8Cjtm zvp>gVhV2A5*Xo6E63#%2*{@`yAiCo8bXk_P+gqL;Sv~`i#3L&iOsNCT*FuQfM4Bsh z;Lj8MYni=!tt0`Xy{oDKV_8h2ni6LuAzWD^`xuk!Z%Y%J$mP&7L2R5h4W90o-}~w&Xo#mTe6P8qIPO#jK$axXqYNS3-BpWJJbo5)ICj{9mp;#?10- zDMUk@+$s%NdA9fnZLL;cGQOL$uQRLHo@h+iSg&n_=E}lU&wFx2QwOjn-MA+joK=zUQd9;ZjLfh6)$PIoE@`X`Qj`K=pEe&STw0cY}xG|h_>vk zpffibu?8x$X0$^&*7#T!qjFI5o z7K@ZUIq{EcFS9Dg1`#}3mAV*$qGfTop=m*P0gGsg?oBjABPA#rYUE2R(Laq6SrFt6{i zhX@|MzG_ZS_W$;CB>#Ns>z6t%C3{rtRJ|I8lOqXkPEq5FMZsdIZE;W*)i`-M;mG#k zX3ijNv8|cipWn>u@dfN2r^DlK264B|>}UpC1v=&j57!mUHoJVy`q`j&$LDdU_U+fq zYJ=pI&6aDz^uHBa?+Q(pTrXxm!X0^Pq<>-Qy9=1RlQ@PJEZiNJ*`Ev6In>Q#5BhIc zB<9LJXyvZr>}(urB4|<6&vHPNh6P1aG*tZ^o7?QbNTz(#%mgcJoPbp-fl+}a3Fgba z`9PPQ?7-S;AzVDLnxb~VLV+1d;E7zhgA(SgCP}-Pp8RAZD#G-O)l;D&zU5cFHmoz_ zcIh06+qI}(EZP#{*~FF&ov$+e=sS*JMRid04#biDEA#WyMcH}%l9S-@yl$X`@sMW+ zTqv7kOdp?c4E?5MQ&xlh*iHCYgVo?G&?W~&%iLT@Nn$sI5P4E!=|u5a8(1JyWIPcN zfsBw6r$|XaPZ>wmmAr~6Y3a<#$n+DYa5Av@L&$|v+~uJ9da|c@olY#Sw+0CnQw7x! z6)FV+IabJik7wS;av2Nq2~y;4p=^a^rdZM5n)yswmbUz|kkCZhlF&?Mr#HvUF{QL+ z^E4^uPF+-C0Apm;M0WV}eQ z1C9NS8D^1ARlE}65<5Dkxfkd*}!6 zQz@*Quo){>S!_DqT8TLv-_9j$Q6+`;#~|D(00)#7A1kgaZ|t9}4Y%&ggo|ysjZi2Y zqRqr&2|_V5)H)Q02bMhY@3HJO^e-d|q9e&_!}=)r&Q=qut_z(@Wg}u0oW?U&&S9+C zut;VNqB_AdwG>8$)mP2=2giKP4Efb1ge%fOp+HEfcgGyXuFo1e}%q<-KE|bqGitn-R z-srb`dDuy`ayy~oNt6s|kj8`3qF;pjk;MXO2jzP^HJj}}$fB+zI|&w}F1WYj75Tw1 zF!p!wCF=lpv#x8yeAv?O5HR|%H3hKZG(-x7R4I}A1bQBsIh!H2eK!J8Ep6~cMn}9q zyYdSx`|a67h$8zf)~Gb2Lf!y^u6(T!NT)7uX&e1Pi9Z?g%J*-B9Q1-v@CEB4Pq=4K zfNAl?4+vQVFJjHPP;G}6G>#|Ut;^z{!3|}>Jmxl zZnF8X%By@I$JXcG`N6)&n0+byn!rW9CIs$=dx0Le-Jzc!=WG7`*y7>LesB4fV9|cp zg$=RV>}D)@ATzKo8HC&(yUZ#{7Hh|P@XaYsrV0OiPtd3di88bY;uQ`m83ZYy?NnsM zB^E-YA0bE#A+(UW98xlbsZ!!I^)1WE(IhAbKmGf8!@^^PImkG`m;uWSoxEyn2rU3u zz6#4KC#1iLUwm}%2vfEL$BC9uHi;JA?=(Zr6Vjl9l<1F$;{wr`pU%x-UdzoBL`irp zK*(YWaP+;SZ{U78u`QKt!7 zbW8v=1TL8Hhf7ONClvf*MjrF-wqtA`W}MaRLwIZ9SjmQyI}7@9{GeJsh&v+F4*K8l zkR|sVldikZ6D5?cO!x*!;z5=iZx5nTg&rUftbh+3RK*5aJT^NQFJ1dHv)PZ!*jIJE zsQD_j<^X*$&|?`o;9ZVr?Ev~;_1WfPOZt<}?AuX%o#BbgM5&vxi?JSj75dy#d}+yz zLw5V!Th26w(9NTE?J(@QO6Y@)1C6Z!U5gA|vvAde!db<`f}0wgZIb%8luSzd^5wr{ zpYrFg(x=pqM$GKY_Hb-zOY9E2bc=gXtpNdsp7Mu;j*JQ^_MIrmtaa57Ge%kdlo-_l z&{Z*xo~?-V?W9lPC&ep&Oinh@u83=0OaCTfc&!U}MJ1*eI>te6u47z(Cd}XcOXSd0 z@^8ba0z}U3usa}*p+3x*d-f*(_RD7GDO9*gl+aUX3m+Cb0MrjuC@S?7zbQWlFOfpa z7MBWr{V6-TX?f-7GRAyPD5thoI(vG>Twrfa>m)KlJP>mJ*y?j1WbK3w#8) zH4ZzaGsr`PcJO>{Mo@pzeG80JSx#n00+A3`CxJvLj|9arH?=UC0#Jw3*k#FKXqher z5HIg`e!>F2vv(5#5%2*4ann7^gMCp@E{&Dj$^FVb)Q%*#!kqIb4Ydt~+sv41EDse& zjSnm7Kzc@|G)>~SdA~MfH;YD>BqZ!Gs{yrKpRnGdBP7=4b!u0%G|K=kC{3Uu1E>n! z9bo6)e5OCET76xXutn7h^P}pnL0g#B;{r8yXqwISW#ahx(d!vQ?yVp(q$Wa<978az z&t-^6BsthQcVbtTEqd@?0u_Dy;eH;m(pccRA^W!X?thpZKUbOHQH~qRYc4S7IUoc< z%MwT0^_dZMyR!tto~lFu6AU$+&?||zuYe8{dk5opNCS zTXDq(+7+99E@+mIUeVHBajnHGcCa&~M+)JJ1P-oP!HvR(=+V^=>)MxK(bW@;2@#{K zS}I~x21LiaHKFJx76g9WlrS*}L|q82GQ#Re#{|n*9A7)u&W<`mGlCRH9ah+U4y(-s z;hL~*F1SjR10|U_I;v-$O!=HePkqvfgovKdxfT7WUJv-u{})14y9_VSgkxGJ5fb7o zF8+k>3&e9}^RdHCM>FNy(~@vS83L~U+zD#EnH3K(%lTJpLPg72kKL|vxGlbD7EZb) z2Z;(Cn-WmHm{Wa=J-|-t&h3Z*p44TXnk1ZLEiq*#=gbBaEA4JiXMXj%YeyS=9SEtp zG0j*JGv8V?J$V~Ga7HDjXoorxlyKHyTVQsQ6C6)1pswn!-`Qd&+ntJC2vO{giWqH~ zx~SfXCA-X}Pws2T>t`6&jZmi-H9lPw93ALw=1$4~-?lAPMf~a(r+DzLc_dlF& zS9>8V#%&!z=%Q-whMTWc>c`soZ4^P5AqfZDr?B;38zrw_LjhYWqh!~0(fv=Z$>*Id z4->k$1foHwTZMv1M7yXijkz~I>&?z~?^vQCI$Ja};CeA6|KcREbIybjtPHU^i%4J@ zqKQ6<$h1$@uuoma9(2OEoKOI)n!EcNm8!&Auzs}?HEhAEgw^9f(5*4c`tjD7DO0~< z+17^H1Sgtap({evb4V+#EVvw5UDVyl!{0ehu&DNhT!IrV z+o38eMq4qoY;#Royfpa=v-$g7ge?-o8utvY3uzJ%`usKix;%wlG5O0w$auvh6_KI0 z3K7X#;e}4c3yg=%+6|XBF&S*?BQhw1Es>1KaA0#DwO=U4@pj~-t>El>WZ#NQ%xwz< zh?)p^#f3hc>frEWAC#`-=yUCd?OjW=&~#NE5fGOl@N@8r6=`a?cZUn6K77a8SK6@k zpDvKskMNS=6oQ)X@xh&(SP7&qJTijW&6IHJ4WT@^CjY`y27l-65|0Z00*y0G1vcNuFY z&$7t#FIWwUQwesMafq=sgAIBi^%ZXM5?g@bVzkfDm&@m;J@FMwQIDEN1VuGZ;X6AQ zlv7=AgY_wSmUVuP3FxBfL`gU%dX#kKhueo0s~(xnT=I`*5IVZ#Rg`YWI5wj6iGgjH zIr!^2!V~5o+G-=+D1BzMKK5QR8^bVo7Qu`0nF8MIhf&KAq*H_mfv7KNi7Z{~6B)Pt zkMG&u&Yn#~aBpj?N~<`9FlR<{#?%Tc*S)kDPc9@9A@?S(4XSF^sK)Ng50Tg|d$!pO z!%uUFQa|GZ#x#|Z2sEJ82^|e=1BccNY=GjMtp7!=22}XMN_4uJYW`^GKA-4}Fb+4i zLppE=R&nqnsbE#bWV72H7VPYdm1cu4hTX%W@W&SrO^m|J+r8LKmqyk$VV@VACK=sS?)mVIFm&C$sFYuOdjLNhJ&iO++{lnQy+t2OK@Cs>uAn zAYVrk-Lbu|T{L|CI>BZbn;Wa-6ylI47$~VxK`CJ%RTo}P%_+;Q)7@(cOjsu|gbBek z(<%8<^=06~z|zH8l+<%Q!HfQXwV`fzZ0)e(?2;&iq@B0mjgmB8J~xGvETzOla>Q}ESa6EG^ch(nbjC~y_P;j)3G z;&>wtzZgDaHk9gBfGSGku0p_C0))m)n~A)j!Uq$U@k(oTRDZ|x}f%KPVe)q z#pAmZ1S_WLS|J&Z1|sSs3%VP)aM9|b40O#YEvNz*WyqZ_=!V}LcVuSlr_%%~;&EK5 z7kqnKsW^wl%_E#O0e|Yx()gqLQnI1#!w-%nhcu=3{v(E(=ZMl!;2ch8K9h zW!Ce%t3(d1XOf}>ck&&+yFPl7om{1UBVciIh4xut4&9{^{&47+q}y(UTMn;kE*N{|7c44jT@w&z#SnF29MzQJy~uTf|XHIVJKmk8W;yuUh;% zD;Tc&o4`bQ)M=go&4x6Tg)dP0cJA`rJNfcdmKHz!FF|5jyai~E2j-;<#Z^r(2z-de z${1lzOrINB)$c0Dnmmj*iJEBgkg18t{f-G8>yU7UX>a=?%am?Up}nE>cZc6B##K>GVM80kqC({0`&33Wpg@k;l&ob4jOT|N?X7U z?z5As^k92ix-#L56GN%?dx9ca5VV8)7waBO>6iaEi+HrM{6uHUvkxw9!TfMz3gM!2 zU(G`E{37QhNq53WH3S!=p)`o}r(tLLjrnek2^393^BicYHp8@V5P0$YPvpuq`zdjS zi~f?U32m6Kl-HC9i5>&G6ofB)L46qJm6#%Z8(r*92|o>Gx%V|`ge_{qIfYr!WQqkQ zwtOy#`qp!0`tqVotBy4%bfwP)BNEjt;3OJxsT1NC?M!jq514XbXlv%vm&qg&!lxJV z4gnFgGT~9{h4CwzrZfH3AAW4iS=JHRx+OtlM`Tr+WF5~nxewikKW55#s|_KFjz|g7 zG!CZQa5n3*pnPi0R3_b>+Y_o72rQ#UduKpW8L55{=#)~Kofm~22~;^R@C7k8v6C-W zpR2!BYw!_kz*%!20b{FB6&x90iqgw1*J9+NNwK6Rt(gyWsSCl0MoR@Qjz-JBZfx+d zZo7Kj2~WI)BCh6$=A&4@x7MJ-4s3xVdq@kc40)&c=0&VHHHc4DezF(g2?H;VDBz(3 zT$2dzKz~@R|#Q_ce4_Hx4z9V`;C+_1)OL#i^J--H;yUEj;%?f2~%{6;5n$w9=ap+ z>J1vyWWM;0hY4Ny;)ss8c;8#PjT`%!wYpQzJSM&u#}dBiR6}g+6ag}JEZ`_Nm~g(y}s}K!pV2`tW%=F^5`To4H##rpJOl%=#R35T>v`E_l92Y$(@g>AdA| zbCz2NGmHLvF2RZZ2M`H-ZZ-Qh){MU;oWG-6S?2Xta}%Dp9TR#FJ#O3s#rOTiaQk>= z+)$A)23+osUhBB00doze`Un?Y1DMth`4V_Z1XNIm6Vh4Eo<{%o$8AL!WA_9k##UAC zV}W<8twBh5tWRUpM++yieJ_$vc*?%Ron7={G&v(_-{BEEX?o@LlJonqRA6c$!HG^1 zm{q7Xn!A$H1Y)T$l8r-@A8VXQuDPpCD+|_v$te z+;qo%!rWx6n0K{vHDnc%c25ze(y<7G*PttsuNR|v)h@J5ACTtG)62?3imklkmkc2Te8bR=nAj<;V_t zf>3OSqLpX|*C(JEwe-{SW_y@^`<^EPsNZnw05U4TecknH;9jdx<_A_cST~E{ux=pW zZfkqksM{r5o29cj;`nR=!#JWEuFda-@oTp;W(+XtpQ0^U>{xv+0gBjB0mSD86*I{{ zTC!=}me;#_br`Fr4x3M?qLmF%;lgecL%PM8Ej?nWu*GbZ zaQBaCGo0m}maimS%sUBOG|y3nR8$GGyL5*~;TEn3Iv)xGk4o6HH`lVUqyFkD!WKT1 z`ZhUVu&%)U^q0lXcu@BZi3hd7w!qVA(v#;LRyv+CGuS)lpMIa5&DQzg+Q@ZAK8UEh z&L!iglx28R*2{P<8$S9I2j2SAxgM6I+O(1I#PxWEkIGAv5$C;h=t#lp=|u}3Vyz*k zHWMPYhSX3m>_lcurDpJ6koJdETMw98Hw1^6V1--Nwl{w{wb@og`E}uokq$%Etwdv_ z@gaDFA@nztOF?|E(!#P;FZ7T?D!jN5{=I0hOKp|bIx95`LNt}=kM*`c#zdskTSQd2 z*%DDG6GfgIx?7r+x4t@4ZrcvR6|ty_i_fLPB``%{G>4ondSTj?cFaJ(u#3<|ix*Ar z4HMzKlH^)*)*b6Qz9&0PcD+Ne;xq|3as?zNBp`R9nL(qw*vGGoc%Gg3XWu1Mk(UXh zqF)pvH+YM5>o;E7!dgQs?Ii-(8d@(*AVzS07VcWO+|NqFz4j5fC<%vo3_d<26rmvF zb=aVLUfF_J8K2O%^v}$w7VIY?qVYi`l4BRs#V`X9EB4RcIdz%eyX-^47Je_ohQS5S zAdHW=9jElg$M*fpUZzWjJ|aq}OSJ+%cpL>PK`syY*{G1Qe!<1ePj)l0O!$;=MRL<@ zg-IHCPYpsDxcCw)>YTiZpD_*X`8lBq4b`ARBmg20`7;hL2|bs|GXrakXTDDUmxL~S z9lT4O10H`o#m*Mb{lbeHTfQbpjJN7%K_yt}BZf#5V*u>swra2EM>p@oYOmSK`fX2q zOVHSFo327*6%UR~rH4OM01}yap%2hj@~*n

5hv~)1>Jl5lHbl}o$|hiNg`dgHmXPV1!b5BXQNzf``6+>$A zIE;r6U4!Qx%!zpYG~uEX0guatQbbUk!mLTaP}sJlqEM?ohNjW+*CF&UbZZ0a`{V^?66-XP?6kJ zyS?H_m=^mbaXNDlp1Vw#=pZDCv@WC_^xi0zGE>a)ZMpN~vaI)b%TLBKbFY1e+gRt$6jAL(_Zc02MW#o)WdKL`--%vEr>vQ3({ zbm^lrGNi$Ge}CgY%8&*~vfdbR<#RuldvB(zr5&;i26Mz_tW5q8OHXR5wy zk7CbkXQ_#m{}LRgCMtutz(d&)CxE0l>C%n3LGzM-W}^Mo4FW^aP9)~*4C<0TnB`i( zxv$YV=GT-cvRvu@L%*h;qA4JU0zFRb{WqrbN0r+bBW!WP3hYqt6+fILnc+G73tv4} zj|t%;2?Qz9GvdGoUjyr&98vH4&kv~aG)s!_xSIe)?ln)ja0M9(QSFPv5MBQC$A$`c?aBdQ9(Xq^+%hcjKcf9|{-X6ApbNO;1`V>Mc|>cH7&(O4$w zqti3w0lv>n8(E5(`P9mUiDo_--kJ4)`vYejOj?Ep#E8Z^$ri}BW4-rR{hDR=v#Sy; zdYbhlSh$`DuMdFX0#A#pgaj?$l-YXSKAAz!t1dy;5zs=>_Vx$!0XNMvc}SOuza*Wb=E;Ic^)a9sgx_t|W4T}#}p zgR5>@Zb$}PO64z_bG|cU#sjq_X23;z*gkyoNEcS$hjYxEzmxev6YCPRh;9TlO&<){ zuqP5~5HRGP@1ewM$l^;!;=jWd4`On+t3IKN1hGJm6}GjygV&Smv}2uNXB!eMc7`?8 zTeL~2jONIcP*7LuwE5VAMe}u(wm=dFRLcuezL`@d@{MMZ) zh-;TfK@r68yY#WL#4*J^kR=_O_atPI-a=%2Hw8Lzpsb>?(E1$t^!jVO&tqI~LKVrl z5J!SuOCdb}8GHZhru{SbvAbZeK0vs57pw{5!W|HFW-#vpt{lA1VgdXp&X?qXnu7)K z^L}Yo=HsPZI~Z4v_mjAiDABP-x^=p`Q|0BiXRy>pg@IbG(CVYzqtNH}rH`vRnf$dF zOt{KpV=7mayaPVM7TJj@wF}bp3iWJEYHrSCHJyit5Cu`wk@$gW6E1s_C6tLe`O%i)VsR$XJhrCNsXZGU$^d zb)W^TNF(uO>xze?%5{pblm^h9$6OFJrq^%WC6w^eo$hHHC0=98TuIsV-XNIw6!f6(w zAesrKuYk)QSPx;wAlC1ABQ0$(t6l815fxFpkf>lS7p)9XgX}a6R(4E?tJV3<)mJ;R zfcn2|q9oc)$Fia<0e#*=R;I@eGG;qzypM6>fs4GzFb}8?&n&r zV;P~NUP2a?E=0yoTIDE(JR;VMIkY5`n(}4!goz(8&AP`=*r>e?5L>BV!i_4C90Nj; zXNT1O0L@6soo1Eab%x&fiX9-WfbaV=&q6dvGJ*0}8Z{*?Kt?CN9BCkS7RQ56dZ$0!^Gpc=u9%@DCzh5gC@ z9zVqtsoc{#Jcn+Ms|IbamfFy%%RX&1ZgM2aVXvki{2xLE+nn(+iz8=s=0}Bd!Dv?P1$o?L{Jm z+bnEa*oDUdCvJFh8ebCT_>u!Co#AKwZ*MbEiA20=HQ*(}7UL8ochOgD9CY z`Mud4&rbW6;hk6@;Z=o1j>t?#R8*w1nN`cSR^67$j{MS#2v1~Z0T0@)aiES4uC&Zn zQu1wY*{)?-&AZ(ag2gIQl{_6Zy;Z)@blzR;tPd85K2N|5VtYGRS4Nh7lSOx5 zewa@3#>o?)qXlL^DV^#m*&KN&c>F6bgR&TaQE z^FD5EBxK>J%E*xctS;)3#g)v=Jhpt3P|-Zrl~KhUlPI3GF7}_pT8(BwlxHhpiy%tI z#@rIPls-%`#c?GeeRIkL7Rj!7iztX786TiGLu(Vb*bvXaDI1yN)LnexFP-=R>q9#T z8YQGIK|=>mo(YD9L&Z5#ail{-vOZ7reB~XcHs#(XU{srA0><&4(FH_OM`kP^-?sfz zCV6dl6Rt2rgbNK=L2xLsb1G&t{OLPWrm+?0zbCD@HX);L8d|l!=PP@t#}ZPfuoW-d zE3G)hV#HNft&c1pM6E8KwzA;*qkV)d207-yICh^0-x=UMUP<9OO}GB=O1 zK1|2Ic$50dexf1<@!yV$(rIQjjq`hA?E+0W^}`;f|FsViHIe1SmJ#&w;3<@NYA=<1 zJ&G{>ANVn$3NIK@aRZ>w7*1NKEaSu!Ze#WRl_J;Y{>3^X3O*%l?1-p=*l=yf6U@#5 z7d}?dsara%X9wAa&k0o2p9C9VZ=sb@BP#mqoL})u9|l?ED+w|Y1`c9k1d(HY7Ifg{ zXOft!*7_TQ6eTUh!dEV}vuyO&*xu}0$8@FOI|3BvN|qU(ErY{K4CLvXE@m)tj!K4( z_iSvuF+P(Wh3gIxG8(#C(%0kc6s@>Q;WBUuhGwYY@`E?iGMI7x;0FR0?UqV~EXxD) zdIWEy_lHV0RC2tt2QIG4q_)Bdf)=Ni3=PeJVDN5R9pa5tu3qz!Sa>_+Bw>pJj*wb# zKYaL(D2-|pAHmtJzmva?x$!VdRLuLC@Wu2FA!h2R>)X-L-O_c%-Ot0-u@GyJVDfBXyc1T!xXK6-+U@jKuir&DPH zzQb=AThsR{tMi*L5e?Drh%|5%jOs(cpJ+#F>TmY)FTImksIlw{p<}2~Js(}a)sg=-amCo~NIY7klh zq@B##uJEXM~iog!V?W@uo0NtjAoQ2&EPfW z*1J}AWj#b|{vk+lFBW$jr;KH1$MN3dTkq{z!W8<#e}st&-6%ZsE`*Ee9vD3ZBbQ^Y zMc&i@Gw$a$xJ4w;l-3R}*<=a$puY~Af4MYC656U|%S`Dy>6z4d{C4w0MOP>thQiNP zp%wEEZW2sU|TWuSUQ=Ra}FkCKoYTOwqXOJqLI3u5*oWy+K zU+*DGO4=Gsgvb_(Q^DiW&#|RvG?xeM%@)ANkWa9$d}5?p*e!8DO%fV#uTQUPn98T;fL9wkqdff#b92d=sVR@0Yf9Vab~QB$&!pq4a7@7E;)!H zH?JfiCu7%v&kbn`l|;;I_J5zod}mt~LKcnkDj+l_diC$rxY=7zi3|QJ!mQt>Y6Obb zFHr)8BD+_US&|e(zAvz`!>HCQ)%|A;!W5})sc8gkUH~S3!VC`x%VXuk(G#zpV@_o6 zS_ChgNL9n}kk;$Qmo`socl>zPzenn4I~a&}Bq`FVZo=CRkAfq%RaQAGovaT}`2 zRgy@*UrL{_9@XaEUv4*szRx^r*ABy~CPZnlaiFoqZ78ASLGN4;MviHWh9n&PvzF~C znat`R$5V-zc%I>Z5K|5su0gt_UQO>*o+T-+Hz!hHFmhor@>HGy8|?UVxgFm2utQy`iqr6 zDzqbtqDfdCq6|6C$TGaHM0{K1Vj@e+_Uk~fA}uQ>K4T$(X5wPiN{oPTYK|FwR&BM; zx3HAL+)hM6_>?XuOa{#fx^tD1SX6MyWJ_LH;N+X-*bblWOu)(x)4MzZ92g!IS%2By zH>O*&X@_^N`}7+_gKmUA+BovICQcP65iCYiP>F1MtIRf1Lbyfel&ef!Lfl90=GfZo=FGcx3k7? zVSDyNZ+XvD3QD|5bU3oaJWT{p$!xEp?d@L%~W(~ZI1`r9!wp{jLkI{VUw|Gp(`5Pjcr-A&k9!MXznE- zQRT3ETSPTBk~Med!4Ii&FsS$0Y-Vbv_y~|pO^Z7=HN6dvjh)HDrr@g%y8{G`$xhIU z7K(`RPnVXeIe022^!ZKF{Fv9e=h=^F`*9cL!W>`;)!O$iH* z5C%<%GL2fYL0r&|B>$Yp>_h5gi5+4#rUw!W%JdFJFZjYGC6gNX?aeO!#+--jsRWD8 zLnCF;Ay~T5WQA_PhMgft&yvWFSmX%WmXTR_la!T~{JbQcHcf8Vu8V}iEm6#_*; zt4E;N4hl~ufKkPj$2>Xfv!A~Dnk76>E+%kHc-C~mm`pRC&0wu#?uI#eY)BTnM^s}e zVd6ca8VY9g1{)~`llu`jZ0$UO`Rt>Y6C(QTi3%d>TC`g}k}62K&AA5}Kh3sk(Mp08 z>0=cV7pgG`x}x=K&1ox{r8%*RKt%_N!%vOs(dNtPi~hl$i@i*18@xfNB7zH*&y*$8 zL*gg~nk)YL1RkCK&U98ZeRwSq5CfA6%`QmL+clRYbeXVWUhNl|Nnf#^&_r1agFyA* zH^U8am>~px6D}Op$5mK=#ew^_O=LOsOB;y>=F~wTs1XMyDMjXeq$!aQmsoPr>`-D0 z$lXyN6uv)`mF>sAN$}X=P)op5E72X{AGePAijQSwzTRb z6))t|aY>OQJ{z2dK94)qDNWj_G0a%V@@+eI6FlbIk|lVox<*#>rFHYeP?4P<|^qELdVQinuP9wSHmDJrp(aRNs>rT0h`Lwxtn1x zyNfdTGopgGb(%#|3Gpg=E=G?`xumZbu;BpHzaTO=93WjK1Ms5x4{EtEy~zp<9q|$R z>gEeAm~njPDendc`v6hS}m!iTyXqk_wlBj2w(IkA_K`4>fw6S@fURCLU^Ic=g(JT3y4 zC(Ztdd)z080F|gi+!2{3xU%5{+0?d>Np74J(Rinr_ORBX;7NkU)}mB(?cftbe~!4Y zi!HivUl>1ui%oPA>sF5}&Q`zcXKD3yR1&&(qomWr4MR!QY|eyYTYqNzTI4K2iwv0x z4Y$AHyP8RMBB74_rYMwGweBg@fx`=aLD_Gnv&dyg&qm z)hAy7loT#G=0?0OuaE0I|J;6dsBXPPxaeaff`C{|s_JW;bA8yCmkzVCDB|}k1dCBb zEy_VDZcAj!C~5;x6kgzNl;oky!UBVrLIrIq32m1C;x7O$2?9szg^QhslU+iPrf;pwfDSNf{?JiCk2p{ zhC_@iF=_xKha%@PYbZz+7&lIi#@Hbz6}rE(-ExtyXs>KQFB8dy&bCw=999n%5z4 z$nUtsxLWZ6v0`^=TCEd1eOHwyY|+?D*uZo#o&$#j#x0KXO7yzI*nXs<#P-ng;Ucbb zn^M$S`SK-EbVEF*vb}pwzjqjPxyn(X;nf_cQd5Z|2lDD07mO@^>r$1VMf^sL5sQrj z_4VOns<8uXLUlsL0}Rs!D1y0=XpT8qvuj{t7v>W0sY#INY}XD=*i*z>v4OhWS9AQu+KS0>~X#SUBl$Ab8W&FH$)<_Ex9_%F9`=XOB_!Q11r;Iz3{RQ2bK(?C=;pj z{Nse@nlX3mazjE^+%ZKw;vW{cy#CEGEKF|Lm~cgd2XPzk^8a5(HjLY|WeL;9(rH9R zxC}(aW4GgN%J>qfw=8Q1GcV23oS=n2DTEtRkfA5$b@{DHPJD05U(Qm4XH+}-+>^pzq6&> z3gLxb9Cgn{4ch61+r*KbtI=0q_3?u@*fyQ$M5J(=l9UP2a4$42(EE=3JAt*`*6l)| z*mj$q=YY#AaBjrDc8OEm5VbO-_GEU{)Qu>i9j&IHhOQcR)XOMo?kth#mYY%2iZ{12 zBlpVvge}?-!QSI@7mx}O^=RdRFHN{5@mvzWF2zfI$xP8dJqZ|1Q5}5E*n20qXX>0B z$QInbPt*mA$b?=bi)q0hym;|5W__mgC17EF#O#DDc!f*ta@W90W?5q<-tANDV|Lf@ zxBdi;cMYp3M{A_V*dV-pHT$8vIHb^POAD3SA#$Q;G*$Dz@PoArI*{e~@U7*84ebXJ zopfU}W94vmlxt}8{rmFya!;_FWx<1lg*nS~EVT(?!rNtD59AM_gcEw|80PpvlbfcL z6P?UC=g6jPe8>8sL|8EKsX!u1aMgO1+T; Ug(1qlxGKnE$AVO$SHl_q53@`WGXMYp diff --git a/deepseek_mtp_dummy_ckpt/tensorboard/events.out.tfevents.1769037919.pool0-01476.3254808.0 b/deepseek_mtp_dummy_ckpt/tensorboard/events.out.tfevents.1769037919.pool0-01476.3254808.0 deleted file mode 100644 index 007f6ef83d639933293021fccc1687da06225d54..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 63278 zcmcJ22b5IB(!U6Tb3t;3gzf=JICv&i}o0zUT3q{Z-%Us_IHz-TfZ*KewLl zzRXlF^Q&^hMjq)<@qtXg#py}yl56(_TV)JLRT>2Co`Bb%>9zQ6sezE&ZShY?&GS2h zc7LiJe`R(!U3RnjZ>*O)HET<_hxfPt#q{OF^U5P`y~awqQqB`{n_ZR(c7LEz&^|tB zCb*78?k`lbK*@XW&)4zZ^kzNoR|+Zx@c;i{yKN=A5}Lp1hp(HCWD?qgN=v1x-DCCI z?0$2CsHRJGafj9N6Z)&hOJ9=o>MwWptRwkO;=<^Eh`DwIvQ zO_jz<`D~}nX7`u_&WU#Z@ef1*^Jg@bzP@VJW4pT$-mhn~*O^9-CA?Nj3#GEd;V~Ox zN{dWS%V?r4vZ$xtGRD%QE{Ubd$Vy8Pc0B0x0=(T~)9qB@$YS%g&`^Inm|b)rE%kvZ zeHNN#cnP(;(p9OVA!b^vId+eYGoWGwvv`MK(9==l-_4c<+?WYiEGo^`AKF{f~dGlf@gCw|IyFR@> zw@^<4oN0Ptv1#&D0&K6eRjLsnFalI8$LTXW{PuBt!wjJV+t(iHP0($u2i`Dknnuta zmG(*v{CJ-?;1nuhcW2sdHm4^`w`~$QSGb=lALz=!Z#*W0M+3)Tv;^+=`XotvWkH*b z18C(ne=W4p)ZqyN@1}H7>ig|3OVF8X*F;#@%FGZbVNmx0271zB_a}n}5e0MqE;CKR zSw!I>rJK?q1_fl7%^3($X3hXQu9 z-y4Ku=YMa5v6QcUCHJw8L}bEM&zq)`bBRcErKwV$2xM9Ux1psEO~i)=8~ni3 zyyAR9L^ZDzUb5Y91#`+pQRYOd1lgZQE-=5PX@on-&OH2E09@Uh)`>#m6D|HrO$2-SS>ElB93aj{?&VqjAm!{<>v?% z&+KH>iZEo%^-2tHAv65?=RHm^si?Dra8W8Mc)TEa0G)s_89ZR<*UVBomJ%W!EY&SG zTMQ}YY>=2c$RCo_y}(sZ7bbB_mJ>EgTt&Ca2D=rv%@C6B*RsQt>i4JpTKdZ2EvBNc z5M)=Svr^Y>$+4Sry;e)6u+^a0xz2zy6a2l*37mq&VW!-3cf9dBb2w}(hyeO<6+K?R z8zf<(5H*e{9WMiV59|X!)O7FR_j6}3PV9SK;sorQ-)*-00#28n5t4|d+Z(1e)gW(l zZharqy+W%9Qs|z==W|U^omYXL?&GVKfOWdJA)j}SnONMfU7aUPkH1O4gOmYEQx({6 z4>)Zhi_2{Ff)gWl5alGOQH+<|yleZuWFr0P+eA(nIuK`%#c9s6yF&3IzBjJr9>(me zYh-2z#)T|?yUh$|LEMnTyaQ8P6Hj(+Trt2@Z#_Y#D~*&=fjnoH9?VcH((0wwXCIuo zLH^)zA*VG5zBdZcc*jZ(lcSl7eC0zzMHjg&2-bLW zz~YE^kp)ZxpJt1!vPD{CQl{M>G-or2a%%JJ&a7;`-(pCp?C1jpnd9a8gb1NmSKS?Q1)ZQ>80%>M5&sBd3Aysn zsRGQsU9p|O(Y>wY4*CqEQpv$@tXZGECx=&J$HwK)2oR5r8VKm}1_JROa%rhj@X+)1 znwiYOPyd3jMIh<++QWQ^4o)TZauNQu-0c6DbY$)(P$9xvL}G(e2{w>(Ryr-%QghDF zo6k(mQ+o(i_*4|{CItEUt>WjHdHC#Wf)q!*DpY2O`YpN>P5N|DGl`!c-bn4$myDOI zTi<6)zkfrBot2JCP1T_1*~~$&d0bB727Fc7d;erPlxq749&;#VtuBaeviy35lQ^0# z8->gno06GP73jS9hGsSNhb?#H?O`15aZci}$_>oIc)*?pql8N`(2{$2<(I9=EU`N=4Rb%5-_HIeAVwS5$vX=gZ9QAG%1es@+FnfJ~_0A}l{dA~Aw7 zNpzX&-`BSGb!Ml^UM6g`Q?($tk&PdEo?x@0vOJQ^@e||wLM=|CJFBAgo?VB#89L1 zk}{^0I^6YUupY}qee@4uVkRmD)yy{mB4%qgNUh+3E_gO(!r=bTz0QnHwf_iO7@P2N z{T5G-<=O@Nn8FNzaDe-<8SNu-k+v@|>VmUO$*byW5Ly9HedFf8*V(ay{x_$faT5y@1y z)9Uw{Gc7^jcC_~*sj_r_k3G6}I1`F$1ql*`0=#c;CIn)MVg5XS={jZ}#@tVsXdbHg zAwqzJvJs}Cj{Wp^s*9!Ho>vGK({J^_Yw_D5WhDWdoer?_c3b?pp-E}0a+rvGTZCw! zh}5#XeZdLlkO$;CA}B5>{J8jDO01Uj@A%|d1+8?}tQar<{LeF0*%!l_FrYcjZ-JEQ{$>Q{s#ygxe-geVEDh+hqt%%cz0u6pm8wXPBI_<7Z6UuJ3y3%{EKW;__OcK*^~XTYp>Dy4vlDBcg!#Q#B%l1;ny$ ziw6@#Svo(+81K^yC+o1duV78W!?>?%1kXSyB)OF|dQj;FNz8br*CtTadkrtr3^h7q zmF3>6;d(!M<`9epPssCX}v-EZp?s*qhg`oJg$ z?HEjXLM{uGv~@?Hp{8Bxw(Ax%%cWBZ8qIQj)vTcqcr2JrS3`HqWJJep5)ICjEmtd^ zVP?5!8qp9Zw?+e2o+Um)U#m5ijQrLmud{mXv<$+=dTm`S-}s%G$eRFG8Ne!sGhRCF zHYV}W;WMS#o_^4TNZ_6(Y20_)gV|o=>=E}l=ie&-J2M)8HX~Ft8Yvp8QBQ!xZjNkT z`{8@+08ej8$m#(ekrwdqL-E_>4fRgH#&$9AAi;{m)#?jjiVy7sR72v)8A+-byxm~$ zP_h?iH$OYq^m-e@?xpllQmv5u7F7|b=R+bizLo)L44ZUAT)%?^l5nN zUDNS)M5d?mkkSw`*N{qs<9aN_cpxiqazS~~FSU|k$y&U%wPoK#w)~wK*a79i`zg|DMc+s}dn9mRHOqs@(9qJov z*_k0%PAp)C}YAqNU#Au9> z;5}BWls)M%_`QFaRVg`y;L)m7#}E`Pi_-&53kKCS37bpavz>pf!Tjd|!w6f{!EAQ0 zCgA455#hAPNOu_T;mUOrOE8`+93k-%CP!K2q#!R5{R-+qqdpAYQP;^Zr2kM8}sMkP~OGr=v&r#zhxEQZz=2W3&6lb0Qd zY)3b8gg3u?ogJ+?QhiE4!t{mucu+J z-i>T_NKV=9Ic7}%+o1Jsg4vqm!^}sdBTtL;KRY++TjuU`x6p!xyW_S5a=v?%IlIw4BKf}%MVs_~9h8l8QbDc|dMf)zGSz^awNn81<* z^X0x;qtzOAU>(dNTs*KEpmxAQff;JxiCnph5@zOwS z;HQk+ja(A9t5CgIv?V06iI*O)+>+_X1P{TA>Y(TyNFe+7baszCc3!XX5geY^wbU>k z^6Y>MWpj+_6BCa5oyYBDHQ4ih!p9n{4qt^fJ0V);=0Zvmdmx0!l@d$q$GuR71u_*v zL_h>GLP}gBB>_EU9Mw?r&R&{+n>iWn@(5En8QArWh*Q*#f#Q?fqR4L|e^YYh&7$ZwRDlxK37%`X)-ecV*S9+D5!jH@#I1!;m;KW%>UOxYP<2`@2`=USCVfo1ugpK}G zZ4H}np+yOUG<~~^%n7+FxV7N5Wh3ZxE9HIfW7Ex91m9Qbt)!vws~K#K5Qei-xDxM& zEGf|ZHVc*O&LK)7gGv6jiX7&@*_M%Q&d$*7sxh-_oa$bSg|m)%L<7UwdT?_3Ao~nC zSj-=&|DZPby7ivEwAb7#`g_`~*O?A3pHC!2y3Pl26*321L|uq1LGokjc+1{SW`Mq5 zNRYw+g(1ayy%OOPLcI>y?Q+>&p!6Ag z4EvRRb4O+rtE?teG>Yjyr_YWDw_d_U-r#@5Z`_)k?ozk~v+lOHh!R@&I@!?Y1X&KJ z7gLf@?Sx3d9W&cukoml{WYu0)Tz~oe}%q<-KZnNJd zitq96-l?0#CbN_1&=x|)lPDR`AdLs5#efL+ql*R74$Aj-%I@4UlSN$xJ|!Z4&M=lsFgPO zB4Z=2yjpGm%YM)QoDfC!TdYxSMuofq1YLQ0A&^d8-qO}++^0zl`Li$XfE@CHQ1AsC zBER^rGM#C0{x1nx1TSLEIZ$nf7Br412`sEt?!T-0yZ=LG-dcP`kZ9g&3dnHeiODL8 z^Etfa8M7wZ*=xIXVTXWoFM$isSb#%c3#Qc7{IfaF8HC;g$aF&vJP$5Oayrsy=z6R2 z4OzHpG&xm=)_Jy=7VRTS1C{%MyfM0L;+wFO;^PpK_n3*XzdisqpvG zK2;}@(A{MBW0hC^K8|h7y{$6~c$s}^d62+Gy(SFqfqQ{okHcx4ALnZ>_(+e5%zkGd zCRnuJ)nP+yc83KE9>@%=ONJn~$1bx&5Io=*bqY6DhAXq^^IH;-(vU=?fE?$QA$FX(u zH1<`t%et@9Y7WpB13i}E1K#b7)ed0%RXsWkoYseQ=Ccp)d&ji?52Dmb>7dkruR@<& ziZ87>3CQjrtsnYw6ro!l{^&!~?^g+ZfYMKC2IzWZ=$eJA9u&^19v0lx;B1rBzo}$W zF8$Jwf6BWzq)(|CilDYmN=Xa=@xgs`IlBQM*V$LLPti075i=!WcQKmRK}>Z zTN0z10J|Z^NhpM9$-MI3bRqKFqj#_G&II+?{y}`w9>x^b}gchlLLS?E_VcYCR=j z&dbJ2q|ma(rNUT$ethX#8grFz7bKGCd{^+>9T3#~k46z}4ba@3{MAY(vy@<(i4aAH z3w#8)HBJYmGsr`PcJMrXM$mZCwcBi-!g4aU!bCz`odgo$JQ5VgJk-Ku4niGHXO|_1 zp=r7tK$yEVW&z)V2Z(?O_<(@8>7MDuz9=Y{#>?%m|NMSqE0Wxji!Pip9VE{tEjR&3^vKz17=DEd6Nd%8_Tw7lAcuTGm zLJ+hp38Y=Ge%GrWOEC1TN(3;$P{{?ol6dxD33)qa2%S0;lh3(^g&?yBM59JFgJ-kbYv5Q0-suR4p;TKtS+zr2n1_VkmE?lT7 zaUm%JD4d?Sdwy5^*7q#45mo9Cs;HqeF2q=>q}g03=CrBra7bgq#2^rLA-u{ct0Nr~EMsANd*8S0sC%Xu~tL{I44ihfj|7yRh|3!y7Z9=XMYqgE>- zB*a-<{0ZL|NaV`q^nN*4nesVX6Rs#jz}26-L2ds`a4_=#mbWETw462A?JB3o>W^jN zq+4>3sKBu)LCuS?&HQs2JE@PhCjxj<7jx;7aFVsel$o3}E4FL4??F2AD^`Ddq^W2p zLP}Fol~S1b)}tB8TTNH~QB2WVb|EO?tiiUx>?9XBo_at-)$Q`+f0x+qWOgG&u{#=K ztYsRaF4=FIyO=(?mnpZmX~9E;Iy;{-GaooQ(A&(NlK;PLTbhalv@I?_v(C@$X_cbs ze|H~RB{uz6#5KaX4k@yzc6!OcNa>r4D>r_`40oUYgeMHQ!1F`06*fFJtRX+EuDqVY zOxN^*gp8&OCV7ZvanydNt|Ds)Sz~`+Cp%=f4JJr2B`ORFjvn79L%=7mWqrpK=GssK z6s|4yS;vWVW90ne>}%_c(0wg}CG@pV`rA90^JyJPkfOy|;JI;QFm7-|!xR)q!VQ1Y zusp+$Kk>uQUo&p59m-pjpFG$a+vo& zEVz2OA}hvKwGz6hntR~pE0y}Oc76v%Fl5NUAIfC1^$)ho>(^1h*2);!bwl(v`8pS8 zoh|dS2whwP(V;_linj^{k*Ic2LmGdc$ajdH?OUBhLv*(2Xu$R2#7Vpx$Sljs5P#(m z2`od@HzpCC_9-9nsf#*7E*O^+4uCat_l3(klC@x^dWagfV3on@aWLf3nPua6Yt_`L z*IBl8w2$CK(<^jEhvwXK-Cemw+(lZ-+X|7qTlRH75`}9WI#0@b6C@ z@g!USZ`0)UBfMldg`nnn{ctBIUIPEVKJgf{o0Vq}0g=we!JH6S2O;yXegY-~EOF;Z z>bjmB%9cOi32FItG|s57Kq=hZ#MdwFZyz>@)m8Mv%d9jwZ59#0(qMVPkXTDD9Wh^u z+Wd%1(M;F8;q7{})`OL^305?G3s|*W3TI*rju@`DM8!hW*m}R6E3LN*1IA-8PBttL z9WHGl&+?x63v8y_PZ8`8WsuUG!G?U0`iiu8i7mizG1g}o%jJzGy+^PVb<>4JP*n3& zzO!>dJ=Kji*qG9!`ya{A1l0KqQ4)@c5hX*$ar^L)DR&if$rmpobactfsojnVY{a>H zTJK=yVCQp$C(J>t)keEf#>^-*#@3pRVfgz6f*0d61-vByqn05^rw9`QQGdu9UAi_V zGN9?NyV>5REF~hiw<(&^s!k!ynbDmwRU*n&nik{DfkYzY-o&**P0i*!l3ZXIiS4U% z8qGC%mlLJlN>3$Kqa*?iXmvtI1KYr%^#U89_$C{F(K{c^>1!i8oy?U_m{z|`bcQQK zl~zax?!c-Jel!)Vs+jE#htrCko$=DF_|MIUSQP%%3ZjWocnOCuy;;0aFZUWdk3~8~ z-ykRv>0~^Nf~K>)w!K zUwK%>_s<$a7pXasdkBwU(;-W#64vr@Wz_F|nPpF1PmpSpN(2s?h;Se>-+YM=IC@r7 zk#9mTmTN?JL*B(8p;P zS&Xn@8_^M!FEkHk7o-Hlh!pB@i}xGWpWJhsHC|l)gs??h512dAJ&Nj0+75$ip}7Ob zy`nEHrf|he>)mgU_GA2NyhGwws^5+-s<11WA%_^2;U4eswV>>@z6;z~w!Y5lK87`z zxjrX4*a%oV-V(p07g8eaO3lHKF}+y!1%abp)Q!KS9~*xAs+KiaWcBrKf)|(fG3SYI z8#tVP2oxg35*};teUmaf*sdOhZAfZ*&q{Tf zv-A3XB8ARQI+6;^?r1& z4}7@zzi3{@$dT+IDshnL;6aoWzJ=w>PzM(9)f+3|`PJXAW&7UuFu}6RU18X8ZOZ{y zqIILdjA?xF&Cb7B*!9$pL_>sK@L^jSBkmn8Ceug5X7`o4oEYd;pZ*1E+??f$hcU&hU*e6r+3o zo3?KQwD>P-Fm)A7FP0-u9jPSn!)~gPJ*oha}lyU>M4}a1-04x zx~yj{9<_cYSTRl42FY+V5HTNFFxntD1%-AcJ2~@=6 zxKbbZ_Ow!Q4vU*dIBNp_Z6H(Uj~YwK;8QQnA43jlos^wNOo#s%7h zdHcZX3oLw`Q;@(!dDLZ@2+f8xl!Y%)`gZQ}boyX(f0h>Cbw5F3TD%Epju+;o3&qt; zFbI5@#p)PgPRy7a*ZzE#zY%gnA!?$@L#8Go_q!(a;-ifkG3~AY01*<}8%~dh58sI3 z$NM4w{+qV0Em!>}BqH9_lOt6y{s5|63_1D&oDV%sV$Xle&R3QqeJq>tg zsYq60ZAr+9Jxe!tQ0fW776{C2YRR9y1BbL`+EXfp2#GEN^zp=Hb2@P0#TL8{8g;lz zTfhzOpIftDV|)8hb;1`XhFb0ShQu2%AqTjB@$NzUvGd2Uh<9pD0u~W(Ae5Db)5)P! z3xn98#gs05z&}CtKkC&yY5m-2o?iXm6HJx1)h04Rm2zOF669oB;6(#Z0G?QZtTxByOg7c zh! zLIgW0B%7igC^bc2o|HD33+V~l@{`O?e3e0D&`wl?FiKp>i)0$2}C zb^9rCg^T{Vlr%H*mA14bLZZijE(PHWUr-;0c_rp(-^LI-uh)j7EcaflHDQaIaQ1{u zXfnkD6I(tPL}TmOy7@hmnO5~|OXzB!3q~ZGS-?p&;!-EXFZ!9{h9A&kiT8Wv(@$wn zB!o{dQmSPmk*Q{SQhe5iOq}nX2!(!OosLiCOl!_ z#SsNObb@OVddOa}A5h~U8@n(1jTzBjk+l=>1mf^SJ=K6_>nli;( zck0>`*x7+o&pgAVWk--eg#_aI@Oe5hhuT=1xm)ReY0WTZeg4fQOksW8@O+QhP_EI^ zdCTFd)T`E(S@Z@I2u}1rfJoqTt66ujX8bK-|H#VSnb+%_M0nzMO!z(YgmJfBDm|a! zzC1<74HpUHz~%nvGbbMlG1uVuRKi8q0H(D=z64$p0Tndhgmjj(r_tX(`HqP(w%T-w zvE|kKSm532YY-A1>(dyvXugH*yKM&HsrwFhcF~8?<&30#hezzB>6Kr;Hv9-n1+I94 z;6x`0Oss^b6Y-m)F{KrUL#k%q>ae);{c2Qw$T-#h|Vr(XzQtseu25_4TLh93CZ&$F1m&z?`v zsP7f+UbyLw`-Hj4crh;>FKA&ElAjh5rrNOxgV&%dlCKw|dCe}yOX+5lA|u&3RN+~o zgy&El5!2I2kr{Hj(4^yV)2L~Z(tb*@{Fw*VtYuDF|HXukPFXE)CeG%EWRq8vL=rZ& zjaz+DlO1z&o+o&buysS=C#DN%hh50;gBq_PbMSVWq%U(zovF@_xy?%mUmSDLR0D6t zTilt>tdKVZ#dauKiFR;f0=iL43y%)sj|^U1N(4~9;no3URDk=2>(#*>Y+w6nRyU~i zD#2mhK)~J6_OMa6w^#qSpT!Z6zeZpfM^wPI1$;1m?M}vw0Y3I|)@J7S?|6d%MeL{o z;`4&4ndBcW*)(p;SA2YR9;>G2Urnf@l?_qh!fqBmtr%mihOkd9+WH7HmCfEFSWz`W zSRDC@6ClRNxzgKTU4Mb8+LU()5>+h;AyvFsQxnQY+Ek0Q)_&PmbJ2_BJU5h~zrR+G z4*CGIFXBO0e<`6)XtZSkL%+ISj~;-et8;uj=;|*e^ya_Jaj&k ztg$tEozV{>8m_bG&KX@8UjL6}JhvSm{Ye0CNZkYXvmDiwPYF+4k5~DqzBCzi-b;s$ z6s(?`^n#bQhHThDh}arZNxQHUoh_A`!Mj7+A5v{yVQSqV9AbhMZcW=hytVk2{T}r{ zzwR61G+p|fXpB&X!5a+Wzfdm)@x4lo$X1Q;LkbPy#f8W}iw3)f%~G3ZHq3+&O(S~x zk9~PeP3rF=qQcFVh(eht`rI(w(#@aU|0z>$+n0nZVo?njpG(CtUgF>A4mm^gKgU?E zG6TK*D?%47UNpToLWJ{5l55>rcX-J=&Dm-4!(M_Fr%BM6BOoy$0l5?13>w45ni`qc zu@m3)Euo6MOavAEqByz1Tcr0No|d1rhW7i82w-byjR=7_!TDKOJp0J!tR(#0_XI9V z!Vw;Wj}Hk&D9HGncIcj0w;*1|&s+9gK4w%;93@j)Y!?GV$&Far=TcKupMKVyDx zwIhTr{9c9)g9}_C7$0#bP8o~O_2CUGGF`GCB}%AERf2wa90e*tZZG)RsF3k~!N6Mo ztzcrg>^R|yQEsP0Fj6M8HbmIp3CIH zl8+2zzD}v130?R)c$YdGJpM$Ao##9w*RgoZa)ux=-l~}im0-1x7$QxK0kD_broWya z+q~D=@XBn~Z@cUlg2sN^bPXD-cyMH@J^bMUkm&T~d@&Z1-{%#~V!`df^F#`R+cL2~7m^nRW=i6M%bq_BTtIr1iZ-aAW|3@EV*59}-E**6n%o z0@I5}eZ;dWkm} zuJvyTw<@$s*u;=pJPs4#!;B-XzhX|rjvIuFP6RwI7fumDaSF30L9Z_dCYM{`%~7~{ zrCYw#i7P(@`R(_0A7@&f|6ihkT3th<5fF3p@#ath#XsMi1AdlX?{SOJMRHT!z(m6*p0hcX9Y^}U3N4nkFt)`hf#(Hq55W{Nq!>oQaB zWxdDW79dzLa3lvNVdXyBwh2Y5z?VAHk@MCYR&_Mgx0D8 zIzZUi=&`$W!Y*0u#+vWCZAppOS!!ZyF@nR?L|G6QcqlvS1dtRbUAhrB_Wg3jnP^`s zL0~A_NyL1EL0!@Zvs|nH=}JwQUsJ6#fudhiL)8?JLxG+k_JhN2zQ<130c8nWoUj5r z+Fj+S*y}b$3Mq6uH-2^}-co zC`7d{0Yg+iHTz@cQ~z6;fYGO}5e*z;*ooA6Cio8%GBw$nu(|C}w(>evrIn`;0lb8f ziDyi7*e9_h*D2kse`EbaTiM|!9i#YEA3s->5&%v+CTfRvAtVRop5pQ z;6Y_b7e}vBnxX}TVVD4^abR_^TUS_6{z**&#AHNy0T`=uLi%u~3tiT)9>~o6U$qHO zn0c&5i&h;t`>Z<4Bz<&xhCIM`*=A>PX694s5hj}XWO!%R3+@m08lk<5Zj6#_fqXl5 zUCSEBnWx#U0l}iDSwn(_>xuCC02nUtw75!0(DF^01Ho3M81$4>3A&nq7J8Nqi8W|m zie9tCj0|{@L))-pZ$Gz(iR`L0f)*7h0iEd$f(z%tm*}IBrJZ9HsDZ5(^baSW09~d1U-)c4{bXi3FY+72vTsv3rTUG*#)v_y1y+ zXHa{BM9WhSAu|$sy6TJ0Ll?1p#*B^xD6YEV7zOZ-b@wZy#W}XjO}tc%3&xht1df7H zhrsDNZNfd3E1Q4$4dcc6t`aZGqlcqX(CwBaX*|Gd@#4(I46w#S5@02_7j+KQ4V(x2 zI4-@hzrUlaehH?`rF#(|T8U%?bU2`gHxBc|S9-bsC==`a?A3>WMdv31Q}j$c78wP% z-R8u)WAyeJvgBHaGn<%rOz%g8P&}$f6M}m=ac4!Ur%kr9^54e;2wW}yMSTTbV&bef zwo4DbdnJuoj{gP`zQ|nS;>8^TNHS#E-MZO=#;}LhWd79!)jDbR@nWfh%;Hs(lX zq51ikpLJ~%p^9W&m?I&dbpky98Grw4+Y+m~v%6sR#}F>w1*?y^a0djP8O*zYD+jN$ zSOGtV^CdZ;me6?keL$L(+4uRRzKkoDOo=N=5*=H#TW6>{tBZ^cFs>}J>A6Cyk9Cj2 znA`8)__#Ndzt0_nt3EcSaYfBL;3I6&otRR)AWg5(&PMx)@v~V?=SDVB5H%f%ADA}b zvL{(WshQU~?B90Gj?~K`#E2asUO<2neoUGJ8#APFp(k%J*Jg~Hz|n=QCM_4w1fh`e z9tpOumUhIgX$)LnRBh$W2l92paONkbRDmMx%^F ze15{UfJCF7#2LC@LdOL@Cd^3~bN|WG-#@~{F(sF1M9d*=t$3oCCW1w`9Y_?gjYD!> zYvYs?21_8Hj zOQnD+LlVI6KRj27F=Y7yi6JS7s@@lZt9WSY?}RJoBmYj`bd>q&yPhUokvcw&yl>E3VFSgfGUZiiLw? z6bS-QjbO%Rm{`2R{y^>Xe8z0Q7l=$mkx1AfMFahvE~iy@^HgFH_NGd{SGnTnMzR*e zi7yhidhY{Rq|WjZe+NY-)f&XR^zt81eUNeJ_%fM8;jswPPDJ8RycZTPnHdcq8_0}U zg_ntp@YvzmBzTA9u7;S@8`a@URxX|RDj|zOGjFLc?V z10~*%D0JLah}p|cuMsKSW?|DJE<6r6@xYVQ_>wTkmmEOp3}5;Ah8N9LA`!1z{qY82 zi}4g(Qeo&XUd4;NWCvSLk5zI~l2qx;b7#+}qAZlpSWV19V?k8oU4$)yC>a}bOW;!aF~yX?l>}SPwE--W9r-0u5J57eW#Igfgd{r0X}I{( zBc5?G37PQ~L8F9JCur#4$u+~UaHu#(D~@z%NY+O=zK7TEHtZ!}RGVZ1#_^u91w>Lu zW-JE|YUE^+cVQpl3Nu8w(0~;JhY~xd;x@xGf1X;7DSEZI&t}V{=zw&WH6<#o+;s!vUF`TqeS;mPe+{POFD<{5t^%Cod=zEf|u_K}q zV#Bo^ZzwApT=;lFuPn;B#SXGZPZ6l7KM6L#-a;#*PE?H7xok?yVGMG^X$dk31`c9k z1d-!@7W7ETOM95BcK$3uijo#$;VYN=SvJON9Ma)HJEkjD&J&7YRO?;(^Wcck?NOWuBQ=dTdz0g6Jr-(I%=fg~ znuwuB#XM-4$;5_LY&Yftq(m-c+@q<~k}q@KXT858@)I2P{#H=oz%!1DUt^-BKL6ZI z7Ou^?kI*put3+tyA?;+*cZJ8qE8nKXB!;)WP&A%+nkwo|r6Uut z`OMGVSdyTHpQ}Nu<{jK5n4-wRW83I^pIOt2DeuM71T2&n9;p|Xuti-1W^v)m7LXqa zz#Rd7?~k;2T}EWxz%&;tZp9 zlun#2OO9+{K5(}RL`h9sgNYE?LUAg1y~a7V^o%rl&`zm3r(R|?#PO8~SJV)p5RBR0vnnEKH@0zn`_1@2X=Sg)4Sv1OPfY6xeGrm*fVQ)SC z{mvinFzfe4Jpx7Rmn4Bgk=>`uEJ=zX-xrvi)^H0;bt?@CQ>413rV+4tL74anGdv(H zkCzXHA3c1FS-#GV2wpgmnug;cZPbk~ZJyNb*lOC9ZS3TKES;c5haZPF8oD8!G|=)W zm36iUl}4-{y1p@?i+ZSrZkqtdEAT}uK7nH0xe*_fWy$ArO^E;|pPLyIfN9k@iBkB) z<}Bke9Zgua?|gykzd9(`^*=astf@i^BKNQ|TxoL$ay|4PHRj-mn`~6b-g!QQAZSEczMc-$6_@fU^PqijW1C)MBlRHpC$%Edx5R4qt z84XD|_-Ez(qN{%#XiLPz^9=ukn0nA~4KgIvk=T5_F#? zSpzwx)A!i4cihCqHoeb|Gj-`i)H0PZN~b$e1MdiWib+<8oo1MKlR)Bl`oX-FT%o%V zapB6VhY@&RP*5`bF5}ZFeO9MBne=HB{u|>pZR$qoLzTfw>pOfJ3}tmVK?5ijiaUug z~{mF`x3OK14zIlx`PH2F(t6a@3MoOmN9$OJ3N^A=3c1!(a9zV0DM-T^<1r z43CLyyzKgK&*_y#JG}XwuMe2+A4upUl@WI|ahf=ZU@?}0N@UZUHEtm#gxfBz{yfu} z9)pQ2>P*)=2$oFNIbQqVSWdcu!VcN7tOgT?M~4z|6ov*G@o=<-z7Zq>!Lfn1Tz7Xj30RMW(c*$WR7y6%{@dcA&VaI#>KamV;5 zJ)?0G=~`EzQhyv}spQI|iGb+S(v}bXj?f;X&b*BbWRCYUP8@FVc1seyIfa`4VH!Aw zsH7?ll;TlsQijTH{UT#1bIa#u5)iuO~>wcZzI>|q!1t~v->yeWXtZWvC*};o*YA`-*^h$gvixfoD21k|li06lu zdB1(k;`l$u5gf+xNpNm^0&qq^Ml?n!QaQ5LlJXDL7(iQc?NXx-F<#dSO1w_OC5zfd zg3pJ%QQ-LpS#Ui%mtaM3tv(l~-8c#)^j_n@af2=}yS98hf#X>uhLlld$%50wAm4-a zljs*AV~a`3hvoaR(Ead4LdVd(7W!ryI<&SYf?fUiOgA$&C8rQJ8Ji}CqQTwRd>yvt zV^xl>Qwc~^IUK$wF^!F6%{{sBBTWuwzv=18Oik`I0whz@$b!|I`zNC!$B_ssrzX9t5P)Vgdt*suugZF;j?r zvj|#Sy}w2{EQoOE!oL+P57ytYz$&K3`&on2qU$#DY4#L)8nua7oFe2L7oV zi|t^}!{&tqi_Sw`bKnD!Lyg~#s9%Fguqb&*_b)>N*$GKrbayypoOO=oR&<%o*^gmXz@ zoV{@ATzfEtuIH|}6f)~^k+5ebW5A!PADF=od(l(a^! zDh>kndw+R#P5miz#V>j4HF__$sc;blpq(~oYkhoBtK`<1pW$#Vd!sRn{$@Gvoih=HmzcGOgkJW6&il&Z@L_iEonqYB5g5IIKBw@&ek{cKGU?zRh zhlD1|S{MXs2fqbwh{Fsa=$mlks6MX3#w%WFxyc{6+Vc_7z??b=1U2HoB&F!Qk2EC` z;u33issl=FLAg81v-Fh}tZaYJHiE|vhbjV|T8ZuofB(+Xi<$4*_7lQK-?bt%aKeLw zS$DVgGUwj!!paT7?J{e^*nx3(2Y*lfGXvPkwDL287k(F6cpTV)c?$iGNsOLLs9{7hD4ypcTd9A z#SnY*U0=eS=1~U-SHut+F6J2$er@BM$FEmrnX83A5ISbAQYCaRycz~+F?EKvL6Ssr z3fNSZDHmrgVs}x#Jw#OSwoapHDq&v5&czs!`Kw9B4K^IW^dphM;Q;9x8Gskfe^ASX z=}k6h=tzu^>+zA@m~m`#j0mA|tg8`%lhYr}gqi0C1f;`Wo}Y1hkkPr88r0I@!(EOs!DY#j@7MX>_#z+kFB+dCbP?of=$LVH z*~Qeyga}+}Q{)Hcar-V10V+|4xg#=7aAm^};oLBs8DcyCzh<%?(G>BlnrvVE=pC!q_!QA3BRJa7e3)Jnvi6b^1|ok!+W zJH@QNnG+|nLv_h@!bKk=2?WGy)>L2ToEyV#I&`uy ziz4>^L$DY{RG}P{>b69ujAAz6eYa@S1-8ao z;stCXaMRlaixYvPV542!3X!aGF_BdkPk4&eXu{X?*W??_3H9cCMtcgmoZpV#1&#pX z;!rUHU7T8EVH^jH#dQUx>z!S1pS!|zW@UcDMvYH_VNg)Y#=2?1=gfi1Kms>f#(vbh zIZIQXxsRYRO<9hF4z6qQWlK89JN}KUBjID-0jWV{mpDYstv`oATfafQL91LEPd;BeXdlb)tSU#)q6$OMG+8LzXc+H)cDcG> z9FvvP6$o6&3V}m@$1TRyiVuhvd*{sVzp>M|a%I96jlF~oOc&!ha7bXn;`oHOYD>oU z!Br%-S0fT|5m&uUDeA0z`I0DlARg1$US?7a-ZMNYIR-Smn&VPyDhcF3`V{s(%MwAW zs}Zz_-)Plhu?e7-{bxxkJHURZL8y3uVcGyiFb5LNaVKkjdh^Wh%q1>bn;_BIP6+}}Wg<CC-h>F+Yz}u9GB~X8!-n0_)(%x@D(88Y-!VM|N@DuZf{MIEWz61w} zrI3%eB3N;Ik35kq08YTY89wk0Pjjlm{&Mci;iiIZ2(*DxPbmSt1jduMBoA_O@iRkM z-$>JT1S9%Jv`az!VolGv;O#8AQOs`iR*HneYmi*5$5)mCUlv zOq5?f^jCJ*Fmnh&<6XmY>d_kQG1iJ~U*+EDE)JUDv!{kj?GQQ9Gn$(DublqLy{*Y| z?Ao|ufa&RBL?>Noq?CEK3 zW0>OyO>VkUPHZx%hfj=U<2x>oBr-U@ql^viT%zO|u2`~Bj~uRYj6qd5>!5^8!~>N( z&{~3pD5EU|g&|65Toq)o<3W06!R_Sgv?c3a%IjqMqR#~pVi(U>udGs9c>o6W_&vbH zhL{;BgTE~G^hfqo`jgN14E@K{qJQL1Qk1GndG;sCxq)O-OYzq`cQ0*_On+VXx7Gbj zbq7TLx|C8vQN$gJhG;tB_m8$~78v$+efs?qm3H(rUDbYHMX9Wm(SFZKJ=$(}k7aX{ z=ugKqI#Jz}F*ri1xKdOprtTwi9@?pht*^i?chhkkt9SGI@AHze}oN=gN#B>h+> zf`7Lx*LEhE{!V;Pm!Xlr6W@ar8j8^@;(Lm|_U&@^J^6-3elNa9`@KvGS>Ybi`md@) zs(;VUMlFmn>62h6^4yG(gB2&?@iwC^4R%}i33GeCJ!;a z=8ybcEv1H%6xq|r-^r{%i1Wvsiy7&CpQKd&AEd3H)$H)Zdp`|kq(^)f`MWqsBg7|D n=N@6CPY?V0f03Sdpk{|=Ig3s(($^l2{9PQR5u$kQtk?e^fD=yl diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 0eac6d1c70d..7969e5c5076 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -600,6 +600,7 @@ def allocate_all_tensors(self, *, is_init: bool) -> None: ) # request_query_lengths is the input prompt tokens length during prefill phase (1st step) and then 1 for the decode phase (i.e During generation) self.request_query_lengths = torch.empty_like(self.request_ids) + # True only for a new request , then after a forward pass it is set to False self.request_in_prefill_status_tensor = torch.empty_like(self.request_ids) # request_output_lengths is len(input_prompt_tokens) + num_tokens_to_generate self.request_output_lengths = torch.empty_like(self.request_ids) @@ -1491,7 +1492,7 @@ def current_input_and_position_ids( self.token_to_pos_ids[:num_tokens].unsqueeze(0), ) - def last_token_logits(self, logits: Tensor, mtp_logits: Optional[Tensor] = None) -> Tensor: + def last_token_logits(self, logits: Tensor) -> Tensor: """Last tokens of logits. Args: @@ -1506,10 +1507,9 @@ def last_token_logits(self, logits: Tensor, mtp_logits: Optional[Tensor] = None) f"logits.size(1) ({tuple(logits.shape)}) != " f"padded_active_token_count ({self.padded_active_token_count})." ) - # Logits shape is [1, padded_active_token_count, vocab_size] # Last token logits. - logits = logits.squeeze(0) # [padded_active_token_count, vocaba_size] + logits = logits.squeeze(0) last_token_idxs = ( torch.cumsum( self.request_query_lengths[self.paused_request_count : self.total_request_count], @@ -1977,7 +1977,7 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor, new_ Args: active_requests_mask (Tensor): 1D Mask tensor marking active requests. (Active request length) new_tokens (Tensor): Newly sampled tokens, with one token per active request. (Active request length) - new_speculative_tokens (Tensor): Newly sampled speculative tokens, with one token per active request. (num_speculative_tokens, active_request_length) + new_speculative_tokens (Tensor): Newly sampled speculative tokens, with num_speculative tokens per active request. (num_speculative_tokens, active_request_length) Return: (Tensor) Newly paused request IDs. @@ -1987,15 +1987,6 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor, new_ # active_request_count -> This corresponds to requests that have not reached EOD or max length # finished_request_count are requests that have reached the termination criterion - - # new_tokens : [ b4 , c4, a6] - # actgive_requesT_mask [ 0 1 0 ] - # [1 0 0 ] - # new_spec_Tokens : [ [b4s1, c4s1, a6s1], - # [b4s2, c4s2, a6s2]] - - ## Vijay : [b4 b4s1, b4s2, c4 , c4s1, c4s2, a6 , a6s1, a6s2] - # self.num_prefill_requests = 0 # all turns to decode # All request that were in prefill become decode requests self.request_in_prefill_status_tensor[self.request_in_prefill_status_tensor == 1] = 0 # TODO : Check how this works with chunked prefill @@ -2193,65 +2184,65 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor, new_ self.request_query_lengths[self.paused_request_count : self.total_request_count] ) - self.request_query_lengths[self.paused_request_count : self.total_request_count].fill_(1 + self.num_speculative_tokens) + num_generated_tokens = 1 + self.num_speculative_tokens + self.request_query_lengths[self.paused_request_count : self.total_request_count].fill_(num_generated_tokens) old_offsets = self.request_last_kv_block_offset[self.paused_request_count : self.total_request_count].clone() self.request_last_kv_block_offset[self.paused_request_count : self.total_request_count] = ( - old_offsets + 1 + self.num_speculative_tokens + old_offsets + num_generated_tokens ) % self.block_size_tokens # ================================================================ - self.active_token_count = active_request_count * (1 + self.num_speculative_tokens) + self.active_token_count = active_request_count * num_generated_tokens sampled_tokens = next_tokens[ self.paused_request_count : self.total_request_count ] + if self.num_speculative_tokens > 0: # new_speculative_tokens has shape [num_spec_tokens, num_requests], slice the request dimension (dim 1) sampled_speculative_tokens = new_speculative_tokens[ :, self.paused_request_count : self.total_request_count ] - - next_tokens = torch.vstack([sampled_tokens.unsqueeze(0), sampled_speculative_tokens]).T.reshape(-1)# - + # This will become [sampled, spec1, spec2, sampled, spec1, spec2 ...] # For every request we will have the sampled token followed by the speculative tokens (i.e next indices) + next_tokens = torch.vstack([sampled_tokens.unsqueeze(0), sampled_speculative_tokens]).T.reshape(-1) else: next_tokens = sampled_tokens self.token_to_input_ids[: self.active_token_count] = next_tokens - # kv length offsets will tell the sequence length (query + generated_tokens) (During add request alone its 0) (It tells how many tokens there are in kv cache) + # Req kv length offsets : [0, 5, 10 ... ] + # For num spec tokens = 2 , this will become [0, 1, 2, 5, 6, 7 10, 11, 12 ...] self.token_to_pos_ids[: self.active_token_count] = self.request_kv_length_offsets[ self.paused_request_count : self.total_request_count - ].repeat_interleave(1 + self.num_speculative_tokens) + torch.arange(1 + self.num_speculative_tokens, device=torch.cuda.current_device()).repeat(active_request_count) + ].repeat_interleave(num_generated_tokens) + torch.arange(num_generated_tokens, device=torch.cuda.current_device()).repeat(active_request_count) # - - # 8. We make relevant changes to the token bookkeeping tensors [1 2 3] [1 1 1 2 2 2 ] + # Token to request idx : [0, 0, 0, 1, 1, 1, 2, 2, 2 ...] self.token_to_request_idx[: self.active_token_count] = torch.arange( self.paused_request_count, self.total_request_count, device=torch.cuda.current_device() - ).repeat_interleave(1 + self.num_speculative_tokens) + ).repeat_interleave(num_generated_tokens) # shan : Same as token_to_pos_ids ? - self.token_to_position_in_request[: self.active_token_count] = ( - self.request_kv_length_offsets[self.paused_request_count : self.total_request_count] - ).repeat_interleave(1 + self.num_speculative_tokens) + torch.arange(1 + self.num_speculative_tokens, device=torch.cuda.current_device()).repeat(active_request_count) + self.token_to_position_in_request[: self.active_token_count] = self.token_to_pos_ids[: self.active_token_count] self.token_to_local_position_within_kv_block[: self.active_token_count] = self.token_to_pos_ids[: self.active_token_count] % self.block_size_tokens - current_block_ids = self.request_last_kv_block_id[self.paused_request_count : self.total_request_count] - # 16 IS THE NUMBER OF TOKENS - # 4 speculative tokens - # 14 (2 ) - raw_positions = old_offsets[:, None] + 1 + torch.arange(1 + self.num_speculative_tokens + 1, device=torch.cuda.current_device())[None, :] # [active_request_count, num_speculative_tokens + 1] (+1 for generated toekns) + + # raw positions shape : [active_request_count, num_generated_tokens] + # e.g block size 6, old_offsets = [1,5,2] , num_generated_tokens = 3 + # raw_positions = [[1, 2, 3], [5, 6, 7], [2, 3, 4]] + # crosses_boundary = [[False, False, False], [False, True, True], [False, False, False]] + raw_positions = old_offsets[:, None] + torch.arange(num_generated_tokens, device=torch.cuda.current_device())[None, :] + # # A token crosses to the next block if its raw_position >= block_size crosses_boundary = raw_positions >= self.block_size_tokens - # TOKEN TO BLOCK IDX alone is quite complex - if not crosses_boundary.any(): + if not crosses_boundary.any() or self.num_speculative_tokens == 0: # Fast path: no tokens cross block boundary, all use current block self.token_to_block_idx[: self.active_token_count] = self.request_last_kv_block_id[ self.paused_request_count : self.total_request_count - ].repeat_interleave(1 + self.num_speculative_tokens) + ].repeat_interleave(num_generated_tokens) else: # Some tokens cross to the next block (this happens for resumed requests) @@ -2279,7 +2270,9 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor, new_ # Build block_idx: [active_count, N] # Start with current (new) block for all - block_idx = current_block_ids[:, None].expand(-1, 1 + self.num_speculative_tokens).clone() # [active_count, N] + # Lets say current block ids is [a1, a2 , a3] and num generated_tokens is 3 + # This will be [[a1, a1, a1], [a2, a2, a2], [a3, a3, a3]] + block_idx = current_block_ids[:, None].expand(-1, num_generated_tokens).clone() # [active_count, N] # For requests that have crossing, tokens BEFORE boundary use prev block # crosses_boundary is False for tokens before boundary @@ -2287,11 +2280,11 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor, new_ use_prev_block = request_has_crossing[:, None] & ~crosses_boundary # [active_count, N] # Apply previous block IDs where needed - prev_block_ids_expanded = prev_block_ids[:, None].expand(-1, 1 + self.num_speculative_tokens) + prev_block_ids_expanded = prev_block_ids[:, None].expand(-1, num_generated_tokens) block_idx = torch.where(use_prev_block, prev_block_ids_expanded, block_idx) + # Convert back to 1d tensor self.token_to_block_idx[: self.active_token_count] = block_idx.flatten() - # ================================================================ return { "newly_paused_request_ids": newly_paused_request_ids, diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index dddf185d50a..473edbb356c 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -193,6 +193,7 @@ def __init__( if self.num_speculative_tokens > 0: assert not self.context.materialize_only_last_token_logits, "Speculative decoding requires materialize_only_last_token_logits to be False" assert self.num_speculative_tokens <= self.controller.num_mtp_heads, f"Number of speculative tokens {self.num_speculative_tokens} must be less than or equal to number of MTP heads {self.controller.num_mtp_heads}" + assert not self.enable_chunked_prefill, "Chunked prefill is not supported with speculative tokens" self.context.num_speculative_tokens = num_speculative_tokens self.controller.num_speculative_tokens = num_speculative_tokens diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 4ed3ab4d77b..7f52029665a 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -100,7 +100,7 @@ def set_stop_word_finished_ids_callback(self, callback): def _init_dynamic_sampling_tensors(self): """Initialize tensors needed for dynamic sampling.""" - context: DynamicInferenceContext = self.inference_wrapped_model.inference_context + context = self.inference_wrapped_model.inference_context max_requests = context.max_requests # Callback to get request IDs that should be marked as finished due to stop words @@ -135,7 +135,7 @@ def _init_dynamic_sampling_tensors(self): def _init_mtp_sampling_tensor(self): """Initialize the MTP sampling tensor after num_speculative_tokens is set.""" if self.num_speculative_tokens is not None and self.num_speculative_tokens > 0: - context: DynamicInferenceContext = self.inference_wrapped_model.inference_context + context = self.inference_wrapped_model.inference_context max_requests = context.max_requests device = torch.cuda.current_device() self._sampled_mtp_tokens_cuda = torch.empty( @@ -530,7 +530,7 @@ def _dynamic_step_context_init( input_ids (Tensor): The active input IDs. position_ids (Tensor): The active position IDs. """ - context: DynamicInferenceContext = self.inference_wrapped_model.inference_context + context = self.inference_wrapped_model.inference_context inference_wrapper_config = self.inference_wrapped_model.inference_wrapper_config active_request_slice = slice(context.paused_request_count, context.total_request_count) @@ -597,7 +597,7 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) """ inference_wrapper_config = self.inference_wrapped_model.inference_wrapper_config - context: DynamicInferenceContext = self.inference_wrapped_model.inference_context + context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count with torch.inference_mode(): @@ -640,7 +640,7 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) def _dynamic_step_sample_bookkeeping(self): """Perform bookkeeping necessary to sample logits for dynamic batching.""" - context: DynamicInferenceContext = self.inference_wrapped_model.inference_context + context = self.inference_wrapped_model.inference_context active_request_slice = slice(context.paused_request_count, context.total_request_count) if self._sampling_backend == "torch": @@ -653,9 +653,9 @@ def _dynamic_step_sample_bookkeeping(self): top_k = self._request_metadata["top_k"][active_request_slice].tolist() top_p = self._request_metadata["top_p"][active_request_slice].tolist() - for i, (t, k, p) in enumerate(zip(temp, top_k, top_p)): + for request_index, (t, k, p) in enumerate(zip(temp, top_k, top_p)): sampling_params = (t, k, p) - bucket_map[sampling_params].append(i) + bucket_map[sampling_params].append(request_index) # Just unpack the key directly! self._torch_sampling_buckets = [ @@ -677,7 +677,7 @@ def _rewind_kv_cache(self): - Clear the entry in request_to_kv_block_ids for the released block - Release the block back to the allocator """ - context: DynamicInferenceContext = self.inference_wrapped_model.inference_context + context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count active_request_slice = slice(context.paused_request_count, context.total_request_count) @@ -747,7 +747,7 @@ def _rewind_kv_cache(self): def _dynamic_step_sample_logits_and_verify_tokens(self, logits: Tensor, mtp_logits: Tensor, input_ids: Tensor): f"""Sample tokens from logits for dynamic batching with speculative tokens and verify the tokens. """ - context: DynamicInferenceContext = self.inference_wrapped_model.inference_context + context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count @@ -917,7 +917,7 @@ def _dynamic_step_sample_logits(self, logits: Tensor): # and then broadcast the sampled tokens rather than broadcasting the raw logits. # Last token logits. - context: DynamicInferenceContext = self.inference_wrapped_model.inference_context + context = self.inference_wrapped_model.inference_context if context.materialize_only_last_token_logits: # When materialize_only_last_token_logits is true, last_token_logits is # already called in the forward pass of GPT. @@ -953,7 +953,7 @@ def _dynamic_step_log_probs_bookkeeping(self) -> Tuple[bool, bool]: Returns: return_log_probs (bool): Whether to return the sampled log_probs. """ - context: DynamicInferenceContext = self.inference_wrapped_model.inference_context + context = self.inference_wrapped_model.inference_context active_request_slice = slice(context.paused_request_count, context.total_request_count) return_log_probs = self._request_metadata["return_log_probs"][active_request_slice] @@ -963,7 +963,7 @@ def _dynamic_step_log_probs_bookkeeping(self) -> Tuple[bool, bool]: def _dynamic_step_calculate_log_probs(self, logits: Tensor) -> Optional[Tensor]: """Calculate log probs from logits.""" - context: DynamicInferenceContext = self.inference_wrapped_model.inference_context + context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count return context.calculate_log_probs( @@ -992,7 +992,7 @@ def _dynamic_step_calculate_top_n_logprobs( "computing log_probs when return_top_n_logprobs is True." ) - context: DynamicInferenceContext = self.inference_wrapped_model.inference_context + context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count active_request_slice = slice(context.paused_request_count, context.total_request_count) @@ -1060,7 +1060,7 @@ def dummy_forward(self): """Perform a dummy forward pass. This is used in expert model parallelism on ranks that do not have any real requests.""" - context: DynamicInferenceContext = self.inference_wrapped_model.inference_context + context = self.inference_wrapped_model.inference_context # if no cuda graphs, directly use dummy forward if not context.cuda_graph_batch_dimensions_list: return self.inference_wrapped_model.dummy_forward() @@ -1101,7 +1101,7 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: newly_paused_request_ids (Tensor): Newly paused request IDs. finished_request_ids (Tensor): Finished request IDs. """ - context: DynamicInferenceContext = self.inference_wrapped_model.inference_context + context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count active_request_slice = slice(context.paused_request_count, context.total_request_count) @@ -1110,8 +1110,7 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: active_sequence_lengths = context.get_active_sequence_lengths() if self.num_speculative_tokens > 0: - accepted_token_counts_per_request = self._accepted_token_counts_per_request[:active_request_count] - active_sequence_lengths += accepted_token_counts_per_request + 1 + active_sequence_lengths += self._accepted_token_counts_per_request[:active_request_count] + 1 else: active_sequence_lengths += 1 max_sequence_lengths = context.get_max_sequence_lengths() @@ -1174,7 +1173,7 @@ async def async_generate_output_tokens_dynamic_batch( log_probs (Optional[Tensor]): Log probabilities of the new sample, if requested. cuda_graph_request_count (Optional[int]): Size of cuda graph used for this step. """ - context: DynamicInferenceContext = self.inference_wrapped_model.inference_context + context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count # No tokens? @@ -1191,8 +1190,7 @@ async def async_generate_output_tokens_dynamic_batch( mtp_logits = None if logits_and_mtp_logits.shape[0] > 1: logits = logits_and_mtp_logits[:1] # [1, seq_len, vocab_size] - mtp_logits = logits_and_mtp_logits[1:] # [num_speculative_tokens, seq_len, vocab_size] - print(f"mtp_logits: {mtp_logits.shape}",f"logits: {logits.shape}") + mtp_logits = logits_and_mtp_logits[1:] # [num_speculative_tokens, seq_len, vocab_size]\ else: logits = logits_and_mtp_logits @@ -1204,8 +1202,9 @@ async def async_generate_output_tokens_dynamic_batch( # Todo [Siddharth]: Can we condition the sleep on a cuda event? # NOTE [TDE]: This will be moved once CPU and GPU methods are separated. await asyncio.sleep(0) - # For now lets not care about log probs and top n logprobs return_log_probs, return_top_n_logprobs = self._dynamic_step_log_probs_bookkeeping() + if self.num_speculative_tokens > 0: + assert return_log_probs == False and return_top_n_logprobs == False, "Log probs and top n log probs are not supported with speculative tokens" self._dynamic_step_sample_bookkeeping() @@ -1228,11 +1227,10 @@ async def async_generate_output_tokens_dynamic_batch( if skip_bookkeeping: request_bookkeeping = {} else: - request_bookkeeping = self._dynamic_step_context_bookkeeping() - sample = self._sampled_tokens_cuda[:active_request_count] + request_bookkeeping = self._dynamic_step_context_bookkeeping() ret = { - "sample": sample, + "sample": self._sampled_tokens_cuda[:active_request_count], "accepted_tokens": self._accepted_tokens_per_request, "log_probs": log_probs, "top_n_logprobs": top_n_logprobs, @@ -1503,7 +1501,7 @@ def generate_all_output_tokens_static_batch( self.inference_wrapped_model.inference_context.is_decode_only() or not (sampling_params.return_log_probs or sampling_params.top_n_logprobs > 0) ) - inference_context: DynamicInferenceContext = self.inference_wrapped_model.inference_context + inference_context = self.inference_wrapped_model.inference_context inference_context.materialize_only_last_token_logits = ( materialize_only_last_token_logits ) From 6aab08f54a263cff2f76bbddb90be222f43bb3cd Mon Sep 17 00:00:00 2001 From: Shanmugam Ramasamy Date: Fri, 6 Feb 2026 13:11:13 -0800 Subject: [PATCH 04/76] Bug fix --- .../text_generation_controllers/text_generation_controller.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 7f52029665a..dce6475298c 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -1208,7 +1208,7 @@ async def async_generate_output_tokens_dynamic_batch( self._dynamic_step_sample_bookkeeping() - if self.num_speculative_tokens > 1: + if self.num_speculative_tokens > 0: self._dynamic_step_sample_logits_and_verify_tokens(logits, mtp_logits, input_ids) self._rewind_kv_cache() else: From fceb983a344c12d9a6c8b1306134c89c14d3f7aa Mon Sep 17 00:00:00 2001 From: Shanmugam Ramasamy Date: Fri, 6 Feb 2026 14:04:39 -0800 Subject: [PATCH 05/76] Bug fix --- megatron/core/inference/contexts/dynamic_context.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 7969e5c5076..b02dd782779 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -2081,7 +2081,7 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor, new_ self.paused_request_count : (active_request_count + self.paused_request_count) ] active_requests_requiring_new_block = ( - num_tokens_in_last_block > self.block_size_tokens - 1 - self.num_speculative_tokens + num_tokens_in_last_block >= self.block_size_tokens - 1 - self.num_speculative_tokens ).byte() if self.chunked_prefill_request_id != -1: @@ -2143,7 +2143,7 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor, new_ # After resume_paused_requests, request_last_kv_block_id will be updated to the NEW block # for resumed requests, but we need the OLD block for tokens that don't cross. prev_last_block_ids = None - if self.num_speculative_tokens > 1: + if self.num_speculative_tokens > 0: prev_last_block_ids = self.request_last_kv_block_id.clone() From 6aeaced25540d95b6a0eab68154a034212859d39 Mon Sep 17 00:00:00 2001 From: Shanmugam Ramasamy Date: Fri, 6 Feb 2026 15:47:47 -0800 Subject: [PATCH 06/76] Bug fix --- examples/inference/gpt/gpt_dynamic_inference.py | 2 +- .../text_generation_controller.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/inference/gpt/gpt_dynamic_inference.py b/examples/inference/gpt/gpt_dynamic_inference.py index d45f0398fb6..29d410b295f 100644 --- a/examples/inference/gpt/gpt_dynamic_inference.py +++ b/examples/inference/gpt/gpt_dynamic_inference.py @@ -443,7 +443,7 @@ def main(): # Reset peak memory stats so functional tests measure this run and not # whatever happened earlier during initialization. torch.cuda.reset_peak_memory_stats() - + # Sampling params. sampling_params = SamplingParams( temperature=args.temperature, diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index dce6475298c..6f9406f5ba3 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -686,7 +686,12 @@ def _rewind_kv_cache(self): accepted_tokens_per_request = self._accepted_token_counts_per_request[:active_request_count] # Number of tokens to rewind (rejected speculative tokens) - num_tokens_to_rewind = accepted_tokens_per_request - self.num_speculative_tokens + num_tokens_to_rewind = self.num_speculative_tokens - accepted_tokens_per_request + + # For prefill requests, no speculative tokens were forwarded through the model, + # so there is nothing to rewind. + request_in_prefill_status = context.request_in_prefill_status_tensor[active_request_slice] + num_tokens_to_rewind[request_in_prefill_status == 1] = 0 # Save the original offset BEFORE modifying to correctly detect block boundary crossing original_offset = context.request_last_kv_block_offset[active_request_slice].clone() @@ -1236,6 +1241,9 @@ async def async_generate_output_tokens_dynamic_batch( "top_n_logprobs": top_n_logprobs, "cuda_graph_request_count": cuda_graph_request_count, } + if self.num_speculative_tokens > 0: + self._accepted_tokens_per_request.fill_(-1) + self._accepted_token_counts_per_request.fill_(0) ret.update(request_bookkeeping) return ret From b2718b8e555968bc8af5a47e14666204bce4022d Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 23 Feb 2026 11:41:35 -0800 Subject: [PATCH 07/76] WIP MTP for mamba Signed-off-by: Keshav Santhanam --- .../inference/gpt/gpt_dynamic_inference.py | 1 - megatron/core/inference/config.py | 3 + .../attention_context/mamba_metadata.py | 7 + .../inference/contexts/dynamic_context.py | 220 ++++++++++----- .../core/inference/engines/dynamic_engine.py | 39 ++- .../text_generation_controller.py | 260 +++++++++++------- megatron/core/models/gpt/gpt_model.py | 4 +- megatron/core/ssm/mamba_mixer.py | 60 +++- megatron/training/arguments.py | 3 +- 9 files changed, 399 insertions(+), 198 deletions(-) diff --git a/examples/inference/gpt/gpt_dynamic_inference.py b/examples/inference/gpt/gpt_dynamic_inference.py index 74c5e47f741..8729af3c94c 100644 --- a/examples/inference/gpt/gpt_dynamic_inference.py +++ b/examples/inference/gpt/gpt_dynamic_inference.py @@ -310,7 +310,6 @@ def main(): engine = DynamicInferenceEngine( controller, context, - num_speculative_tokens=args.num_speculative_tokens, ) setup_prefix = build_dynamic_engine_setup_prefix(args, model, context, requests) diff --git a/megatron/core/inference/config.py b/megatron/core/inference/config.py index 54800a3fbff..5a13b6aaf5d 100644 --- a/megatron/core/inference/config.py +++ b/megatron/core/inference/config.py @@ -185,6 +185,9 @@ class InferenceConfig: enable_chunked_prefill: bool = False """Whether to enable chunked prefill.""" + num_speculative_tokens: int = 0 + """The number of speculative tokens to generate for decode steps.""" + # ================================= # Logging config # ================================= diff --git a/megatron/core/inference/contexts/attention_context/mamba_metadata.py b/megatron/core/inference/contexts/attention_context/mamba_metadata.py index d7fcf7436a2..fbc5d2145ac 100644 --- a/megatron/core/inference/contexts/attention_context/mamba_metadata.py +++ b/megatron/core/inference/contexts/attention_context/mamba_metadata.py @@ -64,6 +64,11 @@ def __init__(self, max_requests: int, max_tokens: int): (2,), dtype=torch.int32, device=self.device ) + # Map from requests to accepted tokens in speculative decoding + self._num_accepted_tokens_buffer = torch.zeros( + (self.max_requests,), dtype=torch.int32, device=self.device + ) + # Allocator for Mamba state slots self.mamba_state_free_slots = torch.arange( self.max_requests, dtype=torch.int32, device=torch.cuda.current_device() @@ -95,6 +100,7 @@ def reset_varlen_metadata(self) -> None: self.seq_idx = None self.device_decode_prefill = None self.device_chunked_prefill = None + self.num_accepted_tokens = None def update( self, @@ -175,6 +181,7 @@ def update( if padded_decode_count > real_decode_count: self._batch_indices_decode_buffer[real_decode_count:padded_decode_count] = -1 self.batch_indices_decode = self._batch_indices_decode_buffer[:padded_decode_count] + self.num_accepted_tokens = self._num_accepted_tokens_buffer[:padded_decode_count] # Determine if we have a chunked prefill request and adjust counts for regular prefill regular_prefill_count = real_prefill_count diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 08e0a8ee29e..e2ceb3b3b68 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -622,6 +622,27 @@ def _allocate_mamba_states(self): dtype=self.params_dtype, device=torch.cuda.current_device(), ) + if self.num_speculative_tokens > 0: + self.mamba_intermediate_conv_states = torch.empty( + ( + self.num_mamba_layers, + self.max_requests, + self.num_speculative_tokens, + *self.mamba_conv_states_shape, + ), + dtype=self.params_dtype, + device=torch.cuda.current_device(), + ) + self.mamba_intermediate_ssm_states = torch.empty( + ( + self.num_mamba_layers, + self.max_requests, + self.num_speculative_tokens, + *self.mamba_ssm_states_shape, + ), + dtype=self.params_dtype, + device=torch.cuda.current_device(), + ) if ( self.kv_cache_management_mode == KVCacheManagementMode.OFFLOAD and not self._uses_torch_memory_saver @@ -635,6 +656,19 @@ def _allocate_mamba_states(self): self._offloadable_cpu_backups["mamba_ssm_states"] = torch.empty_like( self.mamba_ssm_states, device="cpu" ).pin_memory() + if self.num_speculative_tokens > 0: + self._offloadable_tensor_names.add("mamba_intermediate_conv_states") + self._offloadable_cpu_backups["mamba_intermediate_conv_states"] = ( + torch.empty_like( + self.mamba_intermediate_conv_states, device="cpu" + ).pin_memory() + ) + self._offloadable_tensor_names.add("mamba_intermediate_ssm_states") + self._offloadable_cpu_backups["mamba_intermediate_ssm_states"] = ( + torch.empty_like( + self.mamba_intermediate_ssm_states, device="cpu" + ).pin_memory() + ) else: self.mamba_metadata = None @@ -931,13 +965,19 @@ def key_value_cache(self, layer_number: int) -> Tuple[Tensor, Optional[Tensor], self.active_attn_metadata["mha_metadata"].state_data["block_table"], ) - def mamba_states_cache(self, layer_number: int) -> Tuple[Tensor, Tensor]: + def mamba_states_cache( + self, layer_number: int, intermediate: bool = False + ) -> Tuple[Tensor, Tensor]: """Returns the Mamba state tensors for the given layer.""" assert self.is_hybrid_model, "Only hybrid models have Mamba state tensors" mamba_layer_number = self.layer_map[layer_number - 1] - conv_state = self.mamba_conv_states[mamba_layer_number] - ssm_state = self.mamba_ssm_states[mamba_layer_number] + if intermediate: + conv_state = self.mamba_intermediate_conv_states[mamba_layer_number] + ssm_state = self.mamba_intermediate_ssm_states[mamba_layer_number] + else: + conv_state = self.mamba_conv_states[mamba_layer_number] + ssm_state = self.mamba_ssm_states[mamba_layer_number] return (conv_state, ssm_state) @@ -1470,7 +1510,7 @@ def current_input_and_position_ids( self.token_to_input_ids[:num_tokens].unsqueeze(0), self.token_to_pos_ids[:num_tokens].unsqueeze(0), ) - + def last_token_logits(self, logits: Tensor) -> Tensor: """Last tokens of logits. @@ -1488,7 +1528,7 @@ def last_token_logits(self, logits: Tensor) -> Tensor: ) # Last token logits. - logits = logits.squeeze(0) + logits = logits.squeeze(0) last_token_idxs = ( torch.cumsum( self.request_query_lengths[self.paused_request_count : self.total_request_count], @@ -1665,16 +1705,20 @@ def add_request(self, req: DynamicInferenceRequest, chunk_length: Optional[int] self.total_request_count += 0 if req.finished_chunk_token_count > 0 else 1 self.num_prefill_requests += 1 - def _move_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens, new_speculative_tokens = None): + def _move_book_keeping_tensors( + self, src_idxs, dst_idxs, next_tokens, new_speculative_tokens=None + ): """ Move all the relevent booking tensors with src idxs to dst idxs """ self.request_kv_length_offsets[dst_idxs] = self.request_kv_length_offsets[src_idxs] - self.request_in_prefill_status_tensor[dst_idxs] = self.request_in_prefill_status_tensor[src_idxs] + self.request_in_prefill_status_tensor[dst_idxs] = self.request_in_prefill_status_tensor[ + src_idxs + ] self.request_query_lengths[dst_idxs] = self.request_query_lengths[src_idxs] self.request_output_lengths[dst_idxs] = self.request_output_lengths[src_idxs] self.request_ids[dst_idxs] = self.request_ids[src_idxs] - next_tokens[dst_idxs] = next_tokens[src_idxs] # num tokens sames as num samples + next_tokens[dst_idxs] = next_tokens[src_idxs] # num tokens sames as num samples if new_speculative_tokens is not None: new_speculative_tokens[:, dst_idxs] = new_speculative_tokens[:, src_idxs] self.request_to_kv_block_ids[dst_idxs] = self.request_to_kv_block_ids[src_idxs] @@ -1936,8 +1980,12 @@ def evict_overflow_paused_requests( return evict_request_ids - - def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor, new_speculative_tokens: Tensor = None) -> Tensor: + def update_requests( + self, + active_requests_mask: Tensor, + new_tokens: Tensor, + new_speculative_tokens: Tensor = None, + ) -> Tensor: """Update context state after calling engine.step(). This method is responsible for: @@ -1984,13 +2032,14 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor, new_ self.num_prefill_requests = 0 # all turns to decode # All request that were in prefill become decode requests - self.request_in_prefill_status_tensor[self.request_in_prefill_status_tensor == 1] = 0 # TODO : Check how this works with chunked prefill + self.request_in_prefill_status_tensor[self.request_in_prefill_status_tensor == 1] = ( + 0 # TODO : Check how this works with chunked prefill + ) if self.chunked_prefill_request_id != -1: active_requests_mask[-1] = ( 1 # must keep this, next iteration will add a new chunk to it ) - active_request_count = (active_requests_mask == 1).sum().item() finished_request_count = (active_requests_mask == 0).sum().item() assert ( @@ -2026,7 +2075,9 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor, new_ if self.paused_request_count != 0: assert self.paused_tokens is not None next_tokens = torch.cat((self.paused_tokens, new_tokens)) - new_speculative_tokens = torch.cat((self.paused_speculative_tokens, new_speculative_tokens), dim=1) + new_speculative_tokens = torch.cat( + (self.paused_speculative_tokens, new_speculative_tokens), dim=1 + ) else: next_tokens = new_tokens @@ -2123,7 +2174,10 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor, new_ dst_idxs = torch.cat((active_request_ids_on_left, paused_requests_idxs_on_right)) src_idxs = torch.cat((paused_requests_idxs_on_right, active_request_ids_on_left)) self._move_book_keeping_tensors( - src_idxs=src_idxs, dst_idxs=dst_idxs, next_tokens=next_tokens, new_speculative_tokens=new_speculative_tokens + src_idxs=src_idxs, + dst_idxs=dst_idxs, + next_tokens=next_tokens, + new_speculative_tokens=new_speculative_tokens, ) self.paused_request_count += active_requests_requiring_new_block_count @@ -2140,7 +2194,6 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor, new_ if self.num_speculative_tokens > 0: prev_last_block_ids = self.request_last_kv_block_id.clone() - # 6.a. First, resume temporarily paused requests. active_request_count, newly_paused_request_ids = self.resume_paused_requests( active_request_count, newly_paused_request_ids, next_tokens @@ -2179,19 +2232,21 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor, new_ ) num_generated_tokens = 1 + self.num_speculative_tokens - self.request_query_lengths[self.paused_request_count : self.total_request_count].fill_(num_generated_tokens) + self.request_query_lengths[self.paused_request_count : self.total_request_count].fill_( + num_generated_tokens + ) - old_offsets = self.request_last_kv_block_offset[self.paused_request_count : self.total_request_count].clone() + old_offsets = self.request_last_kv_block_offset[ + self.paused_request_count : self.total_request_count + ].clone() self.request_last_kv_block_offset[self.paused_request_count : self.total_request_count] = ( old_offsets + num_generated_tokens ) % self.block_size_tokens # ================================================================ - self.active_token_count = active_request_count * num_generated_tokens - sampled_tokens = next_tokens[ - self.paused_request_count : self.total_request_count - ] + self.active_token_count = active_request_count * num_generated_tokens + sampled_tokens = next_tokens[self.paused_request_count : self.total_request_count] if self.num_speculative_tokens > 0: # new_speculative_tokens has shape [num_spec_tokens, num_requests], slice the request dimension (dim 1) @@ -2199,86 +2254,105 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor, new_ :, self.paused_request_count : self.total_request_count ] # This will become [sampled, spec1, spec2, sampled, spec1, spec2 ...] # For every request we will have the sampled token followed by the speculative tokens (i.e next indices) - next_tokens = torch.vstack([sampled_tokens.unsqueeze(0), sampled_speculative_tokens]).T.reshape(-1) + next_tokens = torch.vstack( + [sampled_tokens.unsqueeze(0), sampled_speculative_tokens] + ).T.reshape(-1) else: next_tokens = sampled_tokens - + self.token_to_input_ids[: self.active_token_count] = next_tokens # Req kv length offsets : [0, 5, 10 ... ] - # For num spec tokens = 2 , this will become [0, 1, 2, 5, 6, 7 10, 11, 12 ...] + # For num spec tokens = 2 , this will become [0, 1, 2, 5, 6, 7 10, 11, 12 ...] self.token_to_pos_ids[: self.active_token_count] = self.request_kv_length_offsets[ self.paused_request_count : self.total_request_count - ].repeat_interleave(num_generated_tokens) + torch.arange(num_generated_tokens, device=torch.cuda.current_device()).repeat(active_request_count) + ].repeat_interleave(num_generated_tokens) + torch.arange( + num_generated_tokens, device=torch.cuda.current_device() + ).repeat( + active_request_count + ) # # Token to request idx : [0, 0, 0, 1, 1, 1, 2, 2, 2 ...] self.token_to_request_idx[: self.active_token_count] = torch.arange( self.paused_request_count, self.total_request_count, device=torch.cuda.current_device() ).repeat_interleave(num_generated_tokens) - # shan : Same as token_to_pos_ids ? - self.token_to_position_in_request[: self.active_token_count] = self.token_to_pos_ids[: self.active_token_count] + # shan : Same as token_to_pos_ids ? + self.token_to_position_in_request[: self.active_token_count] = self.token_to_pos_ids[ + : self.active_token_count + ] - self.token_to_local_position_within_kv_block[: self.active_token_count] = self.token_to_pos_ids[: self.active_token_count] % self.block_size_tokens + self.token_to_local_position_within_kv_block[: self.active_token_count] = ( + self.token_to_pos_ids[: self.active_token_count] % self.block_size_tokens + ) - current_block_ids = self.request_last_kv_block_id[self.paused_request_count : self.total_request_count] + current_block_ids = self.request_last_kv_block_id[ + self.paused_request_count : self.total_request_count + ] # raw positions shape : [active_request_count, num_generated_tokens] - # e.g block size 6, old_offsets = [1,5,2] , num_generated_tokens = 3 + # e.g block size 6, old_offsets = [1,5,2] , num_generated_tokens = 3 # raw_positions = [[1, 2, 3], [5, 6, 7], [2, 3, 4]] # crosses_boundary = [[False, False, False], [False, True, True], [False, False, False]] - raw_positions = old_offsets[:, None] + torch.arange(num_generated_tokens, device=torch.cuda.current_device())[None, :] + raw_positions = ( + old_offsets[:, None] + + torch.arange(num_generated_tokens, device=torch.cuda.current_device())[None, :] + ) # # A token crosses to the next block if its raw_position >= block_size crosses_boundary = raw_positions >= self.block_size_tokens if not crosses_boundary.any() or self.num_speculative_tokens == 0: - # Fast path: no tokens cross block boundary, all use current block + # Fast path: no tokens cross block boundary, all use current block self.token_to_block_idx[: self.active_token_count] = self.request_last_kv_block_id[ self.paused_request_count : self.total_request_count - ].repeat_interleave(num_generated_tokens) + ].repeat_interleave(num_generated_tokens) else: - # Some tokens cross to the next block (this happens for resumed requests) - # - # When a request is paused and resumed: - # 1. It was paused because remaining_space < num_tokens_per_step - # 2. A NEW block is allocated in resume_paused_requests - # 3. request_last_kv_block_id is updated to the NEW block - # 4. The old offset is preserved (wasn't reset) - # - # So for resumed requests: - # - Tokens before the boundary (raw_pos < block_size): go to PREVIOUS block - # - Tokens at/after the boundary (raw_pos >= block_size): go to CURRENT (new) block - # - # For non-resumed requests (no boundary crossing): all go to current block - # - # We use prev_last_block_ids which was stored BEFORE resume_paused_requests - # was called, so it contains the OLD block IDs before new blocks were allocated. - - # Get previous block IDs (stored before resume_paused_requests) - prev_block_ids = prev_last_block_ids[self.paused_request_count : self.total_request_count] # [active_count] - - # For each request, check if ANY token crosses (i.e., request was resumed) - request_has_crossing = crosses_boundary.any(dim=1) # [active_count] - - # Build block_idx: [active_count, N] - # Start with current (new) block for all - # Lets say current block ids is [a1, a2 , a3] and num generated_tokens is 3 - # This will be [[a1, a1, a1], [a2, a2, a2], [a3, a3, a3]] - block_idx = current_block_ids[:, None].expand(-1, num_generated_tokens).clone() # [active_count, N] - - # For requests that have crossing, tokens BEFORE boundary use prev block - # crosses_boundary is False for tokens before boundary - # So: where request_has_crossing AND NOT crosses_boundary, use prev_block - use_prev_block = request_has_crossing[:, None] & ~crosses_boundary # [active_count, N] - - # Apply previous block IDs where needed - prev_block_ids_expanded = prev_block_ids[:, None].expand(-1, num_generated_tokens) - block_idx = torch.where(use_prev_block, prev_block_ids_expanded, block_idx) - - # Convert back to 1d tensor - self.token_to_block_idx[: self.active_token_count] = block_idx.flatten() + # Some tokens cross to the next block (this happens for resumed requests) + # + # When a request is paused and resumed: + # 1. It was paused because remaining_space < num_tokens_per_step + # 2. A NEW block is allocated in resume_paused_requests + # 3. request_last_kv_block_id is updated to the NEW block + # 4. The old offset is preserved (wasn't reset) + # + # So for resumed requests: + # - Tokens before the boundary (raw_pos < block_size): go to PREVIOUS block + # - Tokens at/after the boundary (raw_pos >= block_size): go to CURRENT (new) block + # + # For non-resumed requests (no boundary crossing): all go to current block + # + # We use prev_last_block_ids which was stored BEFORE resume_paused_requests + # was called, so it contains the OLD block IDs before new blocks were allocated. + + # Get previous block IDs (stored before resume_paused_requests) + prev_block_ids = prev_last_block_ids[ + self.paused_request_count : self.total_request_count + ] # [active_count] + + # For each request, check if ANY token crosses (i.e., request was resumed) + request_has_crossing = crosses_boundary.any(dim=1) # [active_count] + + # Build block_idx: [active_count, N] + # Start with current (new) block for all + # Lets say current block ids is [a1, a2 , a3] and num generated_tokens is 3 + # This will be [[a1, a1, a1], [a2, a2, a2], [a3, a3, a3]] + block_idx = ( + current_block_ids[:, None].expand(-1, num_generated_tokens).clone() + ) # [active_count, N] + + # For requests that have crossing, tokens BEFORE boundary use prev block + # crosses_boundary is False for tokens before boundary + # So: where request_has_crossing AND NOT crosses_boundary, use prev_block + use_prev_block = request_has_crossing[:, None] & ~crosses_boundary # [active_count, N] + + # Apply previous block IDs where needed + prev_block_ids_expanded = prev_block_ids[:, None].expand(-1, num_generated_tokens) + block_idx = torch.where(use_prev_block, prev_block_ids_expanded, block_idx) + + # Convert back to 1d tensor + self.token_to_block_idx[: self.active_token_count] = block_idx.flatten() return { "newly_paused_request_ids": newly_paused_request_ids, diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index db0d2e48a0b..6e59c34c98a 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -152,7 +152,7 @@ class DynamicInferenceEngine(AbstractEngine): *DEPRECATED_ARGS, message="Argument `{name}` has been deprecated. Only pass `controller` and `context`", ) - def __init__(self, controller: TextGenerationController, context: DynamicInferenceContext, num_speculative_tokens: Optional[int] = 0,): + def __init__(self, controller: TextGenerationController, context: DynamicInferenceContext): assert isinstance( controller, TextGenerationController @@ -172,17 +172,24 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen # Initialization options. self.controller = controller self.context = context - self.num_speculative_tokens = num_speculative_tokens + + self.num_speculative_tokens = inference_config.num_speculative_tokens assert self.num_speculative_tokens >= 0, "Number of speculative tokens must be non-negative" if self.num_speculative_tokens > 0: - assert not self.context.materialize_only_last_token_logits, "Speculative decoding requires materialize_only_last_token_logits to be False" - assert self.num_speculative_tokens <= self.controller.num_mtp_heads, f"Number of speculative tokens {self.num_speculative_tokens} must be less than or equal to number of MTP heads {self.controller.num_mtp_heads}" - assert not self.enable_chunked_prefill, "Chunked prefill is not supported with speculative tokens" - - self.context.num_speculative_tokens = num_speculative_tokens - self.controller.num_speculative_tokens = num_speculative_tokens + assert ( + not self.context.materialize_only_last_token_logits + ), "Speculative decoding requires materialize_only_last_token_logits to be False" + assert ( + self.num_speculative_tokens <= self.controller.num_mtp_heads + ), f"Number of speculative tokens {self.num_speculative_tokens} must be less than or equal to number of MTP heads {self.controller.num_mtp_heads}" + assert ( + not self.enable_chunked_prefill + ), "Chunked prefill is not supported with speculative tokens" + + self.context.num_speculative_tokens = self.num_speculative_tokens + self.controller.num_speculative_tokens = self.num_speculative_tokens # Initialize MTP sampling tensor now that num_speculative_tokens is set self.controller._init_mtp_sampling_tensor() @@ -917,16 +924,22 @@ def post_process_requests( if self.num_speculative_tokens > 0: accepted_tokens = list(filter(lambda tok: tok != -1, accepted_tokens_list)) tokens = tokens + accepted_tokens - + request: DynamicInferenceRequest = self.get_request(request_id) if request_id != self.context.chunked_prefill_request_id: # Skip appending token for requests being finished due to stop words # (they already have their final token from the previous step) # If the request already has more tokens, then we only append as much as is necessary - if len(request.generated_tokens) + len(tokens) >= request.sampling_params.num_tokens_to_generate: - tokens = tokens[:request.sampling_params.num_tokens_to_generate - len(request.generated_tokens)] + if ( + len(request.generated_tokens) + len(tokens) + >= request.sampling_params.num_tokens_to_generate + ): + tokens = tokens[ + : request.sampling_params.num_tokens_to_generate + - len(request.generated_tokens) + ] if request_id not in self.stop_word_being_finished_ids: - + is_first_token = len(request.generated_tokens) == 0 request.generated_tokens += tokens # TODO : SHAN Should check and change the following for speculative tokens @@ -1145,7 +1158,7 @@ def _check_stop_words_for_request_post_append(self, request: DynamicInferenceReq # Need to check the last stop len tokens shifting by 1 up to num_speculative_tokens # Check logic and vecotrize this if possible for i in range(self.num_speculative_tokens): - if list(generated_tokens[-stop_len - i: -i]) == stop_word_ids: + if list(generated_tokens[-stop_len - i : -i]) == stop_word_ids: return True return False diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 089a9b5d28f..60149e57a9a 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -6,7 +6,7 @@ import functools import inspect from collections import defaultdict -from typing import Any, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union +from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union import torch import torch.nn.functional as F @@ -32,7 +32,6 @@ from megatron.core.transformer.moe.moe_layer import BaseMoELayer from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction from megatron.core.transformer.utils import set_model_to_sequence_parallel -from megatron.core.inference.contexts.dynamic_context import DynamicInferenceContext from megatron.core.utils import get_asyncio_loop, get_model_config, get_pg_size, unwrap_model try: @@ -147,9 +146,12 @@ def _init_mtp_sampling_tensor(self): self._sampled_mtp_tokens_cuda = torch.empty( [self.num_speculative_tokens, max_requests], dtype=torch.int64, device=device ) - self._accepted_tokens_per_request = torch.ones( - [max_requests, self.num_speculative_tokens], dtype=torch.int64, device=device - ) * -1 + self._accepted_tokens_per_request = ( + torch.ones( + [max_requests, self.num_speculative_tokens], dtype=torch.int64, device=device + ) + * -1 + ) def tokenize_prompt(self, prompt: str, add_BOS: bool = False) -> List[int]: """Utility to tokenize the input prompts. @@ -608,17 +610,20 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) {"tokens": input_ids, "position_ids": position_ids, "attention_mask": None} ) # [1, seq_len, vocab_size] (logits) - # [num_speculative_tokens, seq_len, vocab_size] (mtp_logits) - - if self.num_speculative_tokens > 0: - unwrapped_model = unwrap_model(self.inference_wrapped_model.model) - assert hasattr(unwrapped_model, '_mtp_logits_cache'), "MTP logits cache not found" - mtp_logits = unwrapped_model._mtp_logits_cache - expected_mtp_logits_length, _, vocab_size = mtp_logits.shape - assert expected_mtp_logits_length == self.num_mtp_heads, f"MTP logits length mismatch. Expected mtp logits length {self.num_mtp_heads}, got {expected_mtp_logits_length}" - mtp_logits = mtp_logits[:self.num_speculative_tokens] - logits = torch.cat([logits, mtp_logits], dim = 0) # [num_speculative_tokens + 1, seq_len, vocab_size] + # [num_speculative_tokens, seq_len, vocab_size] (mtp_logits) + if self.num_speculative_tokens > 0: + unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + assert hasattr(unwrapped_model, '_mtp_logits_cache'), "MTP logits cache not found" + mtp_logits = unwrapped_model._mtp_logits_cache + expected_mtp_logits_length, _, vocab_size = mtp_logits.shape + assert ( + expected_mtp_logits_length == self.num_mtp_heads + ), f"MTP logits length mismatch. Expected mtp logits length {self.num_mtp_heads}, got {expected_mtp_logits_length}" + mtp_logits = mtp_logits[: self.num_speculative_tokens] + logits = torch.cat( + [logits, mtp_logits], dim=0 + ) # [num_speculative_tokens + 1, seq_len, vocab_size] if self.model_is_pipeline_parallel: logits_seq_len = ( @@ -666,10 +671,10 @@ def _dynamic_step_sample_bookkeeping(self): def _rewind_kv_cache(self): """Update the KV cache bookkeeping for speculative decoding. - + After forward pass with speculative tokens, some tokens may be rejected. This function "rewinds" the KV cache bookkeeping to reflect only the accepted tokens. - + When speculative tokens are rejected, we need to: 1. Update request_kv_length_offsets (total sequence length) 2. Update request_last_kv_block_offset (position within last block) @@ -691,7 +696,7 @@ def _rewind_kv_cache(self): num_tokens_to_rewind = self.num_speculative_tokens - accepted_tokens_per_request # For prefill requests, no speculative tokens were forwarded through the model, - # so there is nothing to rewind. + # so there is nothing to rewind. request_in_prefill_status = context.request_in_prefill_status_tensor[active_request_slice] num_tokens_to_rewind[request_in_prefill_status == 1] = 0 @@ -719,7 +724,7 @@ def _rewind_kv_cache(self): # 3. Update request_last_kv_block_id to point to the previous block # 4. Clear the entry in request_to_kv_block_ids for the released block # 5. Release the block back to the allocator - if remove_allocated_blocks_mask.any(): + if remove_allocated_blocks_mask.any(): # Get indices of requests that need to release a block (relative to active requests) requests_needing_release = torch.nonzero(remove_allocated_blocks_mask, as_tuple=True)[0] # Convert to absolute indices in the context tensors @@ -739,7 +744,7 @@ def _rewind_kv_cache(self): # Vectorized implementation using advanced indexing: # Note: new_block_counts is guaranteed to be > 0 for all requests here, since # crossing back to a previous block implies the request had at least 2 blocks. - + # Update request_last_kv_block_id to point to the previous block (at index new_count - 1) context.request_last_kv_block_id[absolute_indices] = context.request_to_kv_block_ids[ absolute_indices, new_block_counts - 1 @@ -751,63 +756,81 @@ def _rewind_kv_cache(self): # Release the blocks back to the allocator context.block_allocator.release_memory_blocks(blocks_to_release) - def _dynamic_step_sample_logits_and_verify_tokens(self, logits: Tensor, mtp_logits: Tensor, input_ids: Tensor): + # Mamba speculative rewind state update + if context.is_hybrid_model: + # TODO(ksanthanam): Maybe reset interemdiate states + pass + + def _dynamic_step_sample_logits_and_verify_tokens( + self, logits: Tensor, mtp_logits: Tensor, input_ids: Tensor + ): f"""Sample tokens from logits for dynamic batching with speculative tokens and verify the tokens. - """ + """ context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count - - # ================ PART 1 The following part of the code is to get all the relevant logit indices alone ========= - # i.e For prefill requests just the last token logits are enough. - # i.e For decode requests we will need all tokens - # Decode request will always be on the left, followed by prefill requests - # In non speculative case, it was simple in the other function, we just always get the last token logits using query lengths. + # ================ PART 1 The following part of the code is to get all the relevant logit indices alone ========= + # i.e For prefill requests just the last token logits are enough. + # i.e For decode requests we will need all tokens + # Decode request will always be on the left, followed by prefill requests + # In non speculative case, it was simple in the other function, we just always get the last token logits using query lengths. # 5 requests # Input ids shape : [1, 15] - # Assume input ids : [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d1 d2 | e1 e2 e3 e4] + # Assume input ids : [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d1 d2 | e1 e2 e3 e4] # Request to prefill [ 0 | 0 | 0 | 1 | 1 ] # Request query lengths [ 3 | 3 | 3 | 2 | 4 ] # OUTPUT : required_logit_indices [ 0 1 2 | 3 4 5 | 6 7 8 | 10 | 14 ] - - request_in_prefill_status_tensor = context.request_in_prefill_status_tensor[context.paused_request_count : context.total_request_count] - request_query_lengths = context.request_query_lengths[context.paused_request_count : context.total_request_count] + + request_in_prefill_status_tensor = context.request_in_prefill_status_tensor[ + context.paused_request_count : context.total_request_count + ] + request_query_lengths = context.request_query_lengths[ + context.paused_request_count : context.total_request_count + ] num_prefill_requests = request_in_prefill_status_tensor.sum().item() num_decode_requests = active_request_count - num_prefill_requests - - decode_request_indices = torch.arange(num_decode_requests * (self.num_speculative_tokens + 1), device=logits.device) - prefill_request_indices = request_query_lengths.cumsum(dim=0)[request_in_prefill_status_tensor == 1] -1 # Last token indices for prefill requests - required_logit_indices = torch.cat([decode_request_indices, prefill_request_indices]) - assert len(required_logit_indices) == num_decode_requests * (self.num_speculative_tokens + 1) + num_prefill_requests, f"Expected length of required_logit_indices to be num_decode_requests * (self.num_speculative_tokens + 1) + num_prefill_requests, but got {len(required_logit_indices)} for num_decode_requests {num_decode_requests} and num_prefill_requests {num_prefill_requests}" + decode_request_indices = torch.arange( + num_decode_requests * (self.num_speculative_tokens + 1), device=logits.device + ) + prefill_request_indices = ( + request_query_lengths.cumsum(dim=0)[request_in_prefill_status_tensor == 1] - 1 + ) # Last token indices for prefill requests + required_logit_indices = torch.cat([decode_request_indices, prefill_request_indices]) + assert ( + len(required_logit_indices) + == num_decode_requests * (self.num_speculative_tokens + 1) + num_prefill_requests + ), f"Expected length of required_logit_indices to be num_decode_requests * (self.num_speculative_tokens + 1) + num_prefill_requests, but got {len(required_logit_indices)} for num_decode_requests {num_decode_requests} and num_prefill_requests {num_prefill_requests}" required_logits = logits.squeeze(0)[required_logit_indices, :] # Shape [1, 11, vocab_size] - required_mtp_logits = mtp_logits[:, required_logit_indices, :] # Shape [num_speculative_tokens, 11, vocab_size] + required_mtp_logits = mtp_logits[ + :, required_logit_indices, : + ] # Shape [num_speculative_tokens, 11, vocab_size] - # ================ PART 1 The following part of the code is to sample the logits and mtp logits based on the sampling parameters ========= + # ================ PART 1 The following part of the code is to sample the logits and mtp logits based on the sampling parameters ========= # request_indices will be 0, 1, 2, 3, 4 (since we have only 5 requests) # For torch sampling buckets :-[request_indices, temp, top_k, top_p] - # [ - # [[0,2], temp1, top_k1, top_p1], + # [ + # [[0,2], temp1, top_k1, top_p1], # [1], temp3, top_k3, top_p3] - # [3, 4], temp2, top_k2, top_p2], + # [3, 4], temp2, top_k2, top_p2], # ] # Token to request idx : [ 0 0 0 | 1 1 1 | 2 2 2 | 3 | 4 ] # required_logits : [ a5l a6l a7l | b3l b4l b5l | c6l c7l c8l | d2l | e4l ] # Shape [11, vocab_size] - # For first iteration : - # sampling buckets : [0,2], temp1, top_k1, top_p1 + # For first iteration : + # sampling buckets : [0,2], temp1, top_k1, top_p1 # output_tokens_jumbled_list = [a5s a6s a7s c6s c7s c8s] #s->sampled tokens # # request_order_list = [0, 2] # token_order_list = [0, 1, 2, 6, 7, 8] - # For second iteration : + # For second iteration : # sampling buckets : [1], temp3, top_k3, top_p3 # output_tokens_jumbled_list = [b3s b4s b5s] # request_order_list = [1] # token_order_list = [3, 4, 5] - # For third iteration : + # For third iteration : # sampling buckets : [3, 4], temp2, top_k2, top_p2 # output_tokens_jumbled_list = [d2s e4s] #s->sampled tokens # # request_order_list = [3, 4] @@ -816,9 +839,16 @@ def _dynamic_step_sample_logits_and_verify_tokens(self, logits: Tensor, mtp_logi # Final request order list : [0, 2, 1, 3, 4] # Final token order list : [0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10] - - repeats = torch.where(request_in_prefill_status_tensor == 0, 1 + self.num_speculative_tokens, 1) - token_to_request_index = torch.repeat_interleave(torch.arange(len(request_in_prefill_status_tensor), device=request_in_prefill_status_tensor.device), repeats) + repeats = torch.where( + request_in_prefill_status_tensor == 0, 1 + self.num_speculative_tokens, 1 + ) + token_to_request_index = torch.repeat_interleave( + torch.arange( + len(request_in_prefill_status_tensor), + device=request_in_prefill_status_tensor.device, + ), + repeats, + ) output_tokens_jumbled_list = [] mtp_output_tokens_jumbled_list = [] @@ -826,93 +856,120 @@ def _dynamic_step_sample_logits_and_verify_tokens(self, logits: Tensor, mtp_logi # TODO : Maybe its okay to have a loop with num spec tokens ? (Since it will only be max 3 , so might be faster) for request_indices, temp, top_k, top_p in self._torch_sampling_buckets: - request_indices_tensor = torch.tensor(request_indices, device=token_to_request_index.device) - required_indices = torch.where(torch.isin(token_to_request_index, request_indices_tensor))[0] + request_indices_tensor = torch.tensor( + request_indices, device=token_to_request_index.device + ) + required_indices = torch.where( + torch.isin(token_to_request_index, request_indices_tensor) + )[0] # TODO : Can maybe club the following two and then split later ? # TODO : Can directly initzlie output tokens as a tensor and put the logits in the right place - output_tokens_jumbled_list.append(self._torch_sampling_func(required_logits[required_indices, :], temp, top_k, top_p)) + output_tokens_jumbled_list.append( + self._torch_sampling_func(required_logits[required_indices, :], temp, top_k, top_p) + ) mtp_output_tokens_jumbled_list.append( - self._torch_sampling_func(required_mtp_logits[:, required_indices, :], temp, top_k, top_p) + self._torch_sampling_func( + required_mtp_logits[:, required_indices, :], temp, top_k, top_p + ) ) - token_order_list.append(required_indices) - - + token_order_list.append(required_indices) output_tokens_jumbled = torch.cat(output_tokens_jumbled_list, dim=0) - output_tokens = torch.empty(len(output_tokens_jumbled), device=output_tokens_jumbled.device, dtype=output_tokens_jumbled.dtype) + output_tokens = torch.empty( + len(output_tokens_jumbled), + device=output_tokens_jumbled.device, + dtype=output_tokens_jumbled.dtype, + ) token_order = torch.cat(token_order_list, dim=0) # Rearrange output tokens because previously it will be in the order of the sampling_bucket request indices, but now we want to put them according to their corresponding input ids output_tokens[token_order] = output_tokens_jumbled - mtp_output_tokens_jumbled = torch.cat(mtp_output_tokens_jumbled_list, dim=1) # Shape [num_speculative_tokens, total_tokens] + mtp_output_tokens_jumbled = torch.cat( + mtp_output_tokens_jumbled_list, dim=1 + ) # Shape [num_speculative_tokens, total_tokens] mtp_output_tokens = torch.empty_like(mtp_output_tokens_jumbled) mtp_output_tokens[:, token_order] = mtp_output_tokens_jumbled ### ================ PART 3 This part is to do the fowlling : ================ - # Create the accepted tokens tensor + # Create the accepted tokens tensor # For prefill it is always set to 1 # For decode, the first token is always accepted, then we compare with input tokens and accept the next tokens if its a match # Then find the index of the last 1 in every request of the accepted tokens tensor # Then these are the index of the tokens that will be sent to the next forward pass # In the example (assume 1 spec token, 2 spec tokens and 0 sepc tokens are accepted in teh first 3 requests - - # Assume input ids : [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d1 d2 | e1 e2 e3 e4] + # Assume input ids : [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d1 d2 | e1 e2 e3 e4] # input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] # Size 11 # Output tokens [ a6o a7o a8o | b40 b5o b6o | c7o c8o c9o | d3o | e5o ] # At every index we get next positions sample - # Output tokens right shift [ d3o a6o a7o | a8o b40 b5o | b6o c7o c8o | c9o | d3o ] - # Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] + # Output tokens right shift [ d3o a6o a7o | a8o b40 b5o | b6o c7o c8o | c9o | d3o ] + # Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] # Last one indices [ 1 | 5 | 6 | 9 | 10 ] - input_tokens_required = input_ids[0, required_logit_indices] if input_tokens_required.ndim == 2: - assert input_tokens_required.shape[0] == 1, f"Expected input_tokens_required to have 1 row, but got {input_tokens_required.shape}" - input_tokens_required = input_tokens_required.squeeze(0) + assert ( + input_tokens_required.shape[0] == 1 + ), f"Expected input_tokens_required to have 1 row, but got {input_tokens_required.shape}" + input_tokens_required = input_tokens_required.squeeze(0) # This is to get the place where the output sampled speculative token is equal to input token - output_right_shifted = output_tokens.roll(1) - accepted_tokens_mask = input_tokens_required == output_right_shifted + output_right_shifted = output_tokens.roll(1) + accepted_tokens_mask = input_tokens_required == output_right_shifted # This is to make all prefill tokens accepted - token_to_prefill_idx = torch.repeat_interleave(request_in_prefill_status_tensor, repeats) - accepted_tokens_mask[token_to_prefill_idx == 1] = 1 + token_to_prefill_idx = torch.repeat_interleave(request_in_prefill_status_tensor, repeats) + accepted_tokens_mask[token_to_prefill_idx == 1] = 1 # This is to make first decode token in all requests accepted - deocde_query_starts = torch.arange(num_decode_requests) * (1 + self.num_speculative_tokens) - accepted_tokens_mask[deocde_query_starts] = 1 + deocde_query_starts = torch.arange(num_decode_requests) * (1 + self.num_speculative_tokens) + accepted_tokens_mask[deocde_query_starts] = 1 # This is to find the index of the last 1 in every request - last_one_indices = torch.full((active_request_count,), -1, device=token_to_request_index.device) - last_one_indices[token_to_request_index[accepted_tokens_mask == 1]] = torch.where(accepted_tokens_mask == 1)[0] # [1, 5, 6] + last_one_indices = torch.full( + (active_request_count,), -1, device=token_to_request_index.device + ) + last_one_indices[token_to_request_index[accepted_tokens_mask == 1]] = torch.where( + accepted_tokens_mask == 1 + )[ + 0 + ] # [1, 5, 6] # These are the tokens (output + speculative tokens) that will be going to the next forward pass final_sampled_tokens = output_tokens[last_one_indices] - self._sampled_tokens_cuda[:len(final_sampled_tokens)] = final_sampled_tokens - self._sampled_mtp_tokens_cuda[:, :len(final_sampled_tokens)] = mtp_output_tokens[:, last_one_indices] + self._sampled_tokens_cuda[: len(final_sampled_tokens)] = final_sampled_tokens + self._sampled_mtp_tokens_cuda[:, : len(final_sampled_tokens)] = mtp_output_tokens[ + :, last_one_indices + ] ### ================ PART 4 This part is to do the fowlling : ================ - # To fill the speculative otkens and accepted_token counts + # To fill the speculative otkens and accepted_token counts # For prefill it is always set to 1 # For decode, the first token is always accepted, then we compare with input tokens and accept the next tokens if its a match # Then find the index of the last 1 in every request of the accepted tokens tensor # Then these are the index of the tokens that will be sent to the next forward pass # In the example (assume 1 spec token, 2 spec tokens and 0 sepc tokens are accepted in teh first 3 requests - # input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] # Size 11 - # Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] + # Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] # Accepted tokens [ [a6s -1] | [b4s b5s] | [-1 -1] ] # Only handle decod requests, (Prefill already defaults to -1s) # Accepted token counts [ 1 | 2 | 0 ] # Prefill defaults to 0 # This part tis to extract the accepted tokens - input_tokens_required[accepted_tokens_mask == 0 ] = -1 # Masks out non accepted tokens - input_tokens_decode_mode = input_tokens_required[:num_decode_requests * (self.num_speculative_tokens + 1)] - input_tokens_reshaped = input_tokens_decode_mode.reshape(-1, self.num_speculative_tokens + 1) # shape : [num_decode_requests, num_speculative_tokens + 1] - - accepted_tokens = input_tokens_reshaped[: , 1:] # Skip the first token of every decode request (i.e a5, b3, c6) - self._accepted_tokens_per_request[:accepted_tokens.shape[0],:] = accepted_tokens - self._accepted_token_counts_per_request = (self._accepted_tokens_per_request != -1).sum(dim=1) + input_tokens_required[accepted_tokens_mask == 0] = -1 # Masks out non accepted tokens + input_tokens_decode_mode = input_tokens_required[ + : num_decode_requests * (self.num_speculative_tokens + 1) + ] + input_tokens_reshaped = input_tokens_decode_mode.reshape( + -1, self.num_speculative_tokens + 1 + ) # shape : [num_decode_requests, num_speculative_tokens + 1] + + accepted_tokens = input_tokens_reshaped[ + :, 1: + ] # Skip the first token of every decode request (i.e a5, b3, c6) + self._accepted_tokens_per_request[: accepted_tokens.shape[0], :] = accepted_tokens + self._accepted_token_counts_per_request = (self._accepted_tokens_per_request != -1).sum( + dim=1 + ) def _dynamic_step_sample_logits(self, logits: Tensor): """Sample tokens from logits for dynamic batching. @@ -938,13 +995,15 @@ def _dynamic_step_sample_logits(self, logits: Tensor): token_list = [] indices_list = [] - # e.g torch sample buckets will be + # e.g torch sample buckets will be # i.e (for all unique comibnation of t, topk, topk what are the associated requests indices (based on the active slices) # [ [req at index 0, req at index 2], t1, topk1, topp1 ]] # [ [req at index 1, req at index 3, req at index 4] , t2, topk2, topp2] for indices, temp, top_k, top_p in self._torch_sampling_buckets: token_list.append( - self._torch_sampling_func(required_token_indices[indices, :], temp, top_k, top_p) + self._torch_sampling_func( + required_token_indices[indices, :], temp, top_k, top_p + ) ) indices_list.append(torch.tensor(indices)) @@ -1177,7 +1236,9 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: active_sequence_lengths = context.get_active_sequence_lengths() if self.num_speculative_tokens > 0: - active_sequence_lengths += self._accepted_token_counts_per_request[:active_request_count] + 1 + active_sequence_lengths += ( + self._accepted_token_counts_per_request[:active_request_count] + 1 + ) else: active_sequence_lengths += 1 max_sequence_lengths = context.get_max_sequence_lengths() @@ -1213,7 +1274,9 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: sampled_mtp_tokens_cuda = self._sampled_mtp_tokens_cuda[:, :active_request_count] else: sampled_mtp_tokens_cuda = None - update_result = context.update_requests(active_request_mask, new_sample_copy, sampled_mtp_tokens_cuda) + update_result = context.update_requests( + active_request_mask, new_sample_copy, sampled_mtp_tokens_cuda + ) return { "active_request_ids": active_request_ids, @@ -1221,7 +1284,6 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: **(update_result or {}), } - @torch.inference_mode() async def async_generate_output_tokens_dynamic_batch( self, skip_bookkeeping: Optional[bool] = False @@ -1253,7 +1315,6 @@ async def async_generate_output_tokens_dynamic_batch( context.padded_active_request_count if context.is_decode_only() else None ) - # Enable routing recording before forward pass if routing replay is enabled config = self.inference_wrapped_model.model.config if config.moe_enable_routing_replay: @@ -1262,11 +1323,11 @@ async def async_generate_output_tokens_dynamic_batch( logits_and_mtp_logits = self._dynamic_step_forward_logits(input_ids, position_ids) mtp_logits = None if logits_and_mtp_logits.shape[0] > 1: - logits = logits_and_mtp_logits[:1] # [1, seq_len, vocab_size] - mtp_logits = logits_and_mtp_logits[1:] # [num_speculative_tokens, seq_len, vocab_size]\ + logits = logits_and_mtp_logits[:1] # [1, seq_len, vocab_size] + mtp_logits = logits_and_mtp_logits[1:] # [num_speculative_tokens, seq_len, vocab_size]\ else: logits = logits_and_mtp_logits - + # Collect routing indices per request (must be done before context transitions) routing_indices_per_request = self._router_record_bookkeeping() @@ -1280,7 +1341,9 @@ async def async_generate_output_tokens_dynamic_batch( await asyncio.sleep(0) return_log_probs, return_top_n_logprobs = self._dynamic_step_log_probs_bookkeeping() if self.num_speculative_tokens > 0: - assert return_log_probs == False and return_top_n_logprobs == False, "Log probs and top n log probs are not supported with speculative tokens" + assert ( + return_log_probs == False and return_top_n_logprobs == False + ), "Log probs and top n log probs are not supported with speculative tokens" self._dynamic_step_sample_bookkeeping() @@ -1290,7 +1353,6 @@ async def async_generate_output_tokens_dynamic_batch( else: self._dynamic_step_sample_logits(logits) - log_probs = None top_n_logprobs = None if return_log_probs or return_top_n_logprobs: @@ -1303,11 +1365,11 @@ async def async_generate_output_tokens_dynamic_batch( if skip_bookkeeping: request_bookkeeping = {} else: - request_bookkeeping = self._dynamic_step_context_bookkeeping() + request_bookkeeping = self._dynamic_step_context_bookkeeping() ret = { "sample": self._sampled_tokens_cuda[:active_request_count], - "accepted_tokens": self._accepted_tokens_per_request, + "accepted_tokens": self._accepted_tokens_per_request, "log_probs": log_probs, "top_n_logprobs": top_n_logprobs, "routing_indices_per_request": routing_indices_per_request, diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 8852437d89f..55ae6d79ff9 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -612,7 +612,9 @@ def _postprocess( # The new process_mtp_loss function doesn't handle mtp_logits_cache, # so we manually generate and cache MTP logits when in inference mode. if in_inference_mode: - hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) + hidden_states_list = torch.chunk( + hidden_states, 1 + self.config.mtp_num_layers, dim=0 + ) hidden_states = hidden_states_list[0] self._mtp_logits_cache = None mtp_inference_logits = [] diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index d8ed441880f..6ee50b36619 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -477,11 +477,19 @@ def _dynamic_inference(self, hidden_states: torch.Tensor, context: DynamicInfere # For mixed batch, the decode tokens are at the start of zxBCdt zxBCdt_decode = zxBCdt[:decode_req_count] if prefill_req_count > 0 else zxBCdt + num_accepted_tokens = context.mamba_metadata.num_accepted_tokens + intermediate_conv_state, intermediate_ssm_state = context.mamba_states_cache( + self.layer_number - self.pp_layer_offset, intermediate=True + ) + y_decode = self._ssm_decode( zxBCdt_decode.transpose(0, 1), conv_state, ssm_state, context.mamba_metadata.batch_indices_decode, + num_accepted_tokens=num_accepted_tokens, + intermediate_conv_window=intermediate_conv_state, + intermediate_ssm_state=intermediate_ssm_state, ).transpose(0, 1) # Prefill @@ -903,6 +911,9 @@ def _ssm_decode( conv_state: torch.Tensor, ssm_state: torch.Tensor, batch_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + intermediate_conv_state: Optional[torch.Tensor] = None, + intermediate_ssm_state: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Performs SSM computation for inference decode step. @@ -920,10 +931,19 @@ def _ssm_decode( """ seq_len, batch_size, _ = zxBCdt.shape dtype = zxBCdt.dtype - assert seq_len == 1, "Only support decoding with 1 token at a time for now" + if seq_len > 1: + assert ( + num_accepted_tokens is not None + and intermediate_conv_window is not None + and intermediate_ssm_state is not None + ), "Decoding with > 1 token per request requires speculative decoding state" + is_speculative_decoding = True + else: + is_speculative_decoding = False - # Remove sequence dimension - zxBCdt = zxBCdt.squeeze(0) + if not is_speculative_decoding: + # Remove sequence dimension + zxBCdt = zxBCdt.squeeze(0) z, xBC, dt = torch.split( zxBCdt, @@ -953,6 +973,10 @@ def _ssm_decode( self.conv1d.bias, self.activation, conv_state_indices=batch_indices, + num_accepted_tokens=num_accepted_tokens, + intermediate_conv_window=intermediate_conv_window, + intermediate_state_indices=batch_indices, + pad_slot_id=-1, ) x, B, C = torch.split( @@ -968,6 +992,7 @@ def _ssm_decode( # SSM step if selective_state_update is None: + assert not is_speculative_decode if self.ngroups_local_tp > 1: B = rearrange(B, "b (g n) -> b g n", n=self.d_state) C = rearrange(C, "b (g n) -> b g n", n=self.d_state) @@ -1014,14 +1039,22 @@ def _ssm_decode( y = y * self.act(z) # (B D) else: A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32) - dt = repeat(dt, "b h -> b h p", p=self.headdim) dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim) D = repeat(self.D, "h -> h p", p=self.headdim) - B = rearrange(B, "b (g n) -> b g n", g=self.ngroups_local_tp) - C = rearrange(C, "b (g n) -> b g n", g=self.ngroups_local_tp) - x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim) - if not self.rmsnorm: - z = rearrange(z, "b (h p) -> b h p", p=self.headdim) + if is_speculative_deocode: + dt = repeat(dt, "s b h -> s b h p", p=self.headdim) + B = rearrange(B, "s b (g n) -> s b g n", g=self.ngroups_local_tp) + C = rearrange(C, "s b (g n) -> s b g n", g=self.ngroups_local_tp) + x_reshaped = rearrange(x, "s b (h p) -> s b h p", p=self.headdim) + if not self.rmsnorm: + z = rearrange(z, "s b (h p) -> s b h p", p=self.headdim) + else: + dt = repeat(dt, "b h -> b h p", p=self.headdim) + B = rearrange(B, "b (g n) -> b g n", g=self.ngroups_local_tp) + C = rearrange(C, "b (g n) -> b g n", g=self.ngroups_local_tp) + x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim) + if not self.rmsnorm: + z = rearrange(z, "b (h p) -> b h p", p=self.headdim) # Upcast the batch_indices to prevent integer overflow errors in the case of # large max request counts. @@ -1040,8 +1073,15 @@ def _ssm_decode( dt_bias=dt_bias, dt_softplus=True, state_batch_indices=batch_indices, + disable_state_update=True, + intermediate_states_buffer=intermediate_ssm_state, + cache_steps=seq_len, + intermediate_state_indices=batch_indices, ) - y = rearrange(y, "b h p -> b (h p)") + if is_speculative_decode: + y = rearrange(y, "s b h p -> s b (h p)") + else: + y = rearrange(y, "b h p -> b (h p)") if self.rmsnorm: y = self.norm(y, z) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index bd6143409e7..e98b3fb8616 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1676,10 +1676,11 @@ def _add_inference_args(parser): '1) allocate `memory_buffer` in unified memory. ' 'Eventually, additional levels will be included to ' 'control other tensors within the context.') - # TODO(ksanthanam): Clean this up in future PR group.add_argument('--enable-chunked-prefill', dest='enable_chunked_prefill', action='store_true', default=False, help="Enable chunked prefill (disabled by default)") + group.add_argument('--num-speculative-tokens', type=int, default=0, + help='Number of speculative tokens generated during decode') group.add_argument('--inference-dynamic-batching-cuda-graph-max-tokens', type=int, default=16384, help='Maximum number of tokens to capture in a cuda graph.') From 397fd542be6e32b74fb153fb4175ce96224f488c Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Tue, 24 Feb 2026 10:03:59 -0800 Subject: [PATCH 08/76] WIP debugging Signed-off-by: Keshav Santhanam --- megatron/core/inference/contexts/dynamic_context.py | 2 +- megatron/core/inference/engines/dynamic_engine.py | 2 -- .../text_generation_controller.py | 2 +- megatron/core/ssm/mamba_mixer.py | 10 ++++++++-- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index e2ceb3b3b68..ce66cd443c7 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -262,7 +262,7 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC pp_size = model_config.pipeline_model_parallel_size self.hidden_size_per_attention_head = core_divide(projection_size, num_attention_heads) self.num_attention_heads_per_partition = core_divide(num_attention_heads, tp_size) - self.num_speculative_tokens = 0 + self.num_speculative_tokens = inference_config.num_speculative_tokens # Cache the PP group we should use for PP collectives inside the context. # If the model provides a pg_collection with a pp group, prefer it. diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 6e59c34c98a..852558dc396 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -188,8 +188,6 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen not self.enable_chunked_prefill ), "Chunked prefill is not supported with speculative tokens" - self.context.num_speculative_tokens = self.num_speculative_tokens - self.controller.num_speculative_tokens = self.num_speculative_tokens # Initialize MTP sampling tensor now that num_speculative_tokens is set self.controller._init_mtp_sampling_tensor() diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 60149e57a9a..a0d22fc844d 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -62,7 +62,7 @@ def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, token self.model_config = self.inference_wrapped_model.model.config inference_config = self.inference_wrapped_model.inference_context.config self.tokenizer = tokenizer - self.num_speculative_tokens = None + self.num_speculative_tokens = inference_config.num_speculative_tokens pg_collection = inference_config.pg_collection if pg_collection is not None: diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 14a2b6af8c7..d30eb45de7d 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -44,13 +44,16 @@ from .mamba_context_parallel import MambaContextParallel try: - from mamba_ssm.ops.triton.selective_state_update import selective_state_update + #from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from megatron.core.ssm.ops.mamba_ssm import selective_state_update except ImportError: selective_state_update = None try: - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update + # from causal_conv1d import causal_conv1d_fn, causal_conv1d_update + from causal_conv1d import causal_conv1d_fn from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states + from megatron.core.ssm.ops.causal_conv1d_triton import causal_conv1d_update except ImportError: causal_conv1d_fn = None causal_conv1d_update = None @@ -439,6 +442,9 @@ def _dynamic_inference(self, hidden_states: torch.Tensor, context: DynamicInfere y_decode = None y_prefill = None + + if self.layer_number == 1: + torch.distributed.breakpoint(0) # Decode if decode_req_count > 0: From af604024657dc180565cf424bc294c573d05ae2f Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Tue, 24 Feb 2026 12:55:36 -0800 Subject: [PATCH 09/76] Add SGLang kernels Signed-off-by: Keshav Santhanam --- megatron/core/ssm/ops/__init__.py | 0 megatron/core/ssm/ops/causal_conv1d_triton.py | 1187 +++++++++++++++++ megatron/core/ssm/ops/mamba_ssm.py | 494 +++++++ 3 files changed, 1681 insertions(+) create mode 100644 megatron/core/ssm/ops/__init__.py create mode 100644 megatron/core/ssm/ops/causal_conv1d_triton.py create mode 100644 megatron/core/ssm/ops/mamba_ssm.py diff --git a/megatron/core/ssm/ops/__init__.py b/megatron/core/ssm/ops/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/megatron/core/ssm/ops/causal_conv1d_triton.py b/megatron/core/ssm/ops/causal_conv1d_triton.py new file mode 100644 index 00000000000..c82f4d730fa --- /dev/null +++ b/megatron/core/ssm/ops/causal_conv1d_triton.py @@ -0,0 +1,1187 @@ +# Copyright (c) 2024, Tri Dao. +# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py +# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py + +from typing import List, Optional, Union + +import torch +import triton +import triton.language as tl + +PAD_SLOT_ID = -1 + + +@triton.jit() +def _causal_conv1d_fwd_kernel( # continuous batching + # Pointers to matrices + x_ptr, # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences + w_ptr, # (dim, width) + bias_ptr, + initial_states_ptr, # conv_states_ptr + cache_indices_ptr, # conv_state_indices_ptr + has_initial_states_ptr, + query_start_loc_ptr, + o_ptr, # (dim, seqlen) - actually pointing to x_ptr + # Matrix dimensions + dim: tl.constexpr, + seqlen: tl.int32, # cu_seqlen + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, # stride to get to next sequence, + stride_x_dim: tl.constexpr, # stride to get to next feature-value, + stride_x_token: tl.constexpr, # stride to get to next token (same feature-index, same sequence-index) + stride_w_dim: tl.constexpr, # stride to get to next dim-axis value + stride_w_width: tl.constexpr, # stride to get to next width-axis value + stride_istate_seq: tl.constexpr, + stride_istate_dim: tl.constexpr, + stride_istate_token: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + HAS_INITIAL_STATES: tl.constexpr, + HAS_CACHE: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + NP2_STATELEN: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + conv_states_ptr = initial_states_ptr + conv_state_indices_ptr = cache_indices_ptr + stride_conv_state_seq = stride_istate_seq + stride_conv_state_dim = stride_istate_dim + stride_conv_state_tok = stride_istate_token + state_len = ( + KERNEL_WIDTH - 1 + ) # can be passed via argument if it's not the same as this value + + # one program handles one chunk in a single sequence + # rather than mixing sequences - to make updating initial_states across sequences efficiently + + # single-sequence id + idx_seq = tl.program_id(0) + chunk_offset = tl.program_id(1) + + # BLOCK_N elements along the feature-dimension (channel) + idx_feats = tl.program_id(2) * BLOCK_N + tl.arange(0, BLOCK_N) + + if idx_seq == pad_slot_id: + return + + sequence_start_index = tl.load(query_start_loc_ptr + idx_seq) + sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1) + # find the actual sequence length + seqlen = sequence_end_index - sequence_start_index + + token_offset = BLOCK_M * chunk_offset + segment_len = min(BLOCK_M, seqlen - token_offset) + + if segment_len <= 0: + return + + # base of the sequence + x_base = ( + x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim + ) # [BLOCK_N,] + + if IS_CONTINUOUS_BATCHING: + # cache_idx + conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(tl.int64) + else: + # cache_idx + conv_state_batch_coord = idx_seq + if USE_PAD_SLOT: # noqa + if conv_state_batch_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + conv_states_base = ( + conv_states_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim) + ) # [BLOCK_N,] + + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + + # Does 2 things: + # 1. READ prior-block init-state data - [done by every Triton programs] + # 2. update conv_state with new data [only by the Triton program handles chunk_offset=0] + if chunk_offset == 0: + # read from conv_states + load_init_state = False + if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES + load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1) + if load_init_state: + # load from conv_states + prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok + mask_w = idx_feats < dim + if KERNEL_WIDTH == 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 3: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 4: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 5: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 3 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + else: + # prior-tokens are zeros + if KERNEL_WIDTH >= 2: # STRATEGY1 + # first chunk and does not have prior-token, so just set to 0 + col0 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 3: # STRATEGY1 + col1 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 4: # STRATEGY1 + col2 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 5: # STRATEGY1 + col3 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + + # STEP 2: + # here prepare data for updating conv_state + if ( + state_len <= seqlen + ): # SMALL_CACHE=True (only move part of 'x' into conv_state cache) + # just read from 'x' + # copy 'x' data to conv_state + # load only 'x' data (and set 0 before 'x' if seqlen < state_len) + idx_tokens_last = (seqlen - state_len) + tl.arange( + 0, NP2_STATELEN + ) # [BLOCK_M] + x_ptrs = ( + x_ptr + + ((sequence_start_index + idx_tokens_last) * stride_x_token)[:, None] + + (idx_feats * stride_x_dim)[None, :] + ) # [BLOCK_M,BLOCK_N,] + mask_x = ( + (idx_tokens_last >= 0)[:, None] + & (idx_tokens_last < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + new_conv_state = tl.load(x_ptrs, mask_x, 0.0) + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + conv_states_ptrs_target = ( + conv_states_base[None, :] + + (idx_tokens_conv * stride_conv_state_tok)[:, None] + ) + + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.debug_barrier() # NOTE: use this due to bug in Triton compiler + tl.store(conv_states_ptrs_target, new_conv_state, mask) + + else: + if load_init_state: + # update conv_state by shifting left, i.e. take last few cols from conv_state + cols from 'x' + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + conv_states_ptrs_source = ( + conv_states_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ( + (conv_state_batch_coord < num_cache_lines) + & ((idx_tokens_conv + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) + conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + + x_ptrs = ( + x_base[None, :] + + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens_conv - VAL >= 0)[:, None] + & (idx_tokens_conv - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + + tl.debug_barrier() # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load + new_conv_state = tl.where( + mask, conv_state, loaded_x + ) # BUG in 'tl.where' which requires a barrier before this + conv_states_ptrs_target = ( + conv_states_base + + (idx_tokens_conv * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[ + None, : + ] + tl.store(conv_states_ptrs_target, new_conv_state, mask) + else: # load_init_state == False + # update conv_state by shifting left, BUT + # set cols prior to 'x' as zeros + cols from 'x' + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + VAL = state_len - seqlen + + x_ptrs = ( + x_base[None, :] + + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens_conv - VAL >= 0)[:, None] + & (idx_tokens_conv - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + new_conv_state = tl.load(x_ptrs, mask_x, 0.0) + + conv_states_ptrs_target = ( + conv_states_base + + (idx_tokens_conv * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[ + None, : + ] + tl.store(conv_states_ptrs_target, new_conv_state, mask) + + else: # chunk_offset > 0 + # read prior-token data from `x` + load_init_state = True + prior_tokens = x_base + (token_offset - 1) * stride_x_token + mask_w = idx_feats < dim + if KERNEL_WIDTH == 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + if KERNEL_WIDTH == 3: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + if KERNEL_WIDTH == 4: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + if KERNEL_WIDTH == 5: + # ruff: noqa: F841 + conv_states_ptrs = prior_tokens # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to( + tl.float32 + ) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) + + x_base_1d = x_base + token_offset * stride_x_token # starting of chunk + + # PRE-LOAD WEIGHTS + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + mask_x_1d = idx_feats < dim + for idx_token in range(segment_len): + acc = acc_preload + + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = (idx_token < segment_len) & ( + idx_feats < dim + ) # token-index # feature-index + o_ptrs = ( + o_ptr + + (sequence_start_index + token_offset + idx_token) * stride_o_token + + (idx_feats * stride_o_dim) + ) + + tl.store(o_ptrs, acc, mask=mask_1d) + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Union[torch.Tensor, None], + conv_states: torch.Tensor, + query_start_loc: torch.Tensor, + seq_lens_cpu: List[int], + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID, + validate_data=False, + **kwargs, +): + """support varlen + continuous batching when x is 2D tensor + + x: (dim,cu_seq_len) + cu_seq_len = total tokens of all seqs in that batch + sequences are concatenated from left to right for varlen + weight: (dim, width) + conv_states: (...,dim,width - 1) itype + updated inplace if provided + [it use `cache_indices` to get the index to the cache of conv_state for that sequence + + conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True + and after that conv_state[cache_indices[i]] need to be shift-left and updated with values from 'x' + ] + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. prepended by 0. + if + x = [5, 1, 1, 1] <- continuous batching (batch=4) + then + query_start_loc = [0, 5, 6, 7, 8] <- the starting index of the next sequence; while the last value is + the ending index of the last sequence + [length(query_start_loc)-1 == batch] + for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + seq_lens_cpu: (batch) int32 + The sequence lengths of the sequences in the batch + cache_indices: (batch) int32 + indicates the corresponding state index, + like so: conv_state = conv_states[cache_indices[batch_id]] + has_initial_state: (batch) bool + indicates whether should the kernel take the current state as initial + state for the calculations + [single boolean for each sequence in the batch: True or False] + bias: (dim,) + activation: either None or "silu" or "swish" or True + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + + out: same shape as `x` + """ + if isinstance(activation, bool) and activation: + activation = "silu" + + out = torch.empty_like(x) + + is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1) + dim, cu_seqlen = x.shape + _, width = weight.shape + state_len = width - 1 + np2_statelen = triton.next_power_of_2(state_len) + + stride_x_seq = 0 + stride_x_dim = x.stride(0) + stride_x_token = x.stride(1) + stride_w_dim = weight.stride(0) + stride_w_width = weight.stride(1) + stride_istate_seq = 0 + stride_istate_dim = 0 + stride_istate_token = 0 + num_cache_lines = 0 + if conv_states is not None: + # extensions to support vLLM: + # 1. conv_states is used to replaced initial_states + # 2. conv_states serve as a cache with num cache lines can be larger than batch size + # 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx] + # 4. computation can be skipped if cache_indices[idx] == pad_slot_id + num_cache_lines = conv_states.size(0) + assert ( + num_cache_lines == conv_states.shape[0] + and dim == conv_states.shape[1] + and width - 1 <= conv_states.shape[2] + ) + stride_istate_seq = conv_states.stride(0) + stride_istate_dim = conv_states.stride(1) + stride_istate_token = conv_states.stride(2) + # assert stride_istate_dim == 1 + if out.dim() == 2: + stride_o_seq = 0 + stride_o_dim = out.stride(0) + stride_o_token = out.stride(1) + else: + stride_o_seq = out.stride(0) + stride_o_dim = out.stride(1) + stride_o_token = out.stride(2) + + if validate_data: + assert x.dim() == 2 + assert query_start_loc is not None + assert query_start_loc.dim() == 1 + assert x.stride(0) == 1 or x.stride(1) == 1 + padded_batch = query_start_loc.size(0) - 1 + if bias is not None: + assert bias.dim() == 1 + assert dim == bias.size(0) + if cache_indices is not None: + assert cache_indices.dim() == 1 + assert padded_batch == cache_indices.size(0) + if has_initial_state is not None: + assert has_initial_state.size() == (padded_batch,) + assert ( + conv_states is not None + ), "ERROR: `has_initial_state` is used, which needs also `conv_states`" + assert weight.stride(1) == 1 + assert (dim, width) == weight.shape + assert is_channel_last, "Need to run in channel-last layout" + + def grid(META): + max_seq_len = max(seq_lens_cpu) + return ( + len(seq_lens_cpu), # batch_size + (max_seq_len + META["BLOCK_M"] - 1) // META["BLOCK_M"], + triton.cdiv(dim, META["BLOCK_N"]), + ) + + _causal_conv1d_fwd_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_states, + cache_indices, + has_initial_state, + query_start_loc, + out, + # Matrix dimensions + dim, + cu_seqlen, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + HAS_INITIAL_STATES=has_initial_state is not None, + HAS_CACHE=conv_states is not None, + IS_CONTINUOUS_BATCHING=cache_indices is not None, + USE_PAD_SLOT=pad_slot_id is not None, + NP2_STATELEN=np2_statelen, + # launch_cooperative_grid=True + BLOCK_M=8, + BLOCK_N=256, + num_stages=2, + ) + return out + + +# HAS_EAGLE_TREE_CUSTOM_ATTN_MASK is added to support eagle tree attention mask +# retrieve_next_token_ptr: [N, NP2_T], retrieve_next_sibling_ptr: [N, NP2_T] +# e.g. for a sequence of length 4, the eagle tree attention structure is: +# retrieve_next_token=[1, 3, -1, -1] -> retrieve_next_token[i]: the 1st child token of token i +# retrieve_next_sibling=[-1, 2, -1, -1] -> retrieve_next_sibling[i]: the 1st tree sibling token of token i +# retrieve_parent_token=[n/a, 0, 0, 1] -> retrieve_parent_token[i]: the parent token of token i +# Tree: +# 0 +# / \ +# 1 2 +# / +# 3 +# When calculating token 3's convolution, it should conv to token 1 (parent) and token 0 (grand-parent) +# When calculating token 2's convolution, it should conv to token 0 (parent) +# This kernel is a fused kernel which will also produce retrieve_parent_token based on retrieve_next_token & retrieve_next_sibling +@triton.jit() +def _causal_conv1d_update_kernel( + # Pointers to matrices + x_ptr, # (batch, dim, seqlen) + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, + cache_seqlens_ptr, # circular buffer + conv_state_indices_ptr, + num_accepted_tokens_ptr, + intermediate_conv_window_ptr, + intermediate_state_indices_ptr, + retrieve_next_token_ptr, + retrieve_next_sibling_ptr, + retrieve_parent_token_ptr, + o_ptr, # (batch, dim, seqlen) + # Matrix dimensions + batch: int, + dim: tl.constexpr, + seqlen: tl.constexpr, + state_len: tl.constexpr, + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, + stride_inter_seq: tl.constexpr, + stride_inter_step: tl.constexpr, + stride_inter_dim: tl.constexpr, + stride_inter_win: tl.constexpr, + stride_intermediate_state_indices: tl.constexpr, + stride_retrieve_next_token_seq: tl.constexpr, + stride_retrieve_next_token_token: tl.constexpr, + stride_retrieve_next_sibling_seq: tl.constexpr, + stride_retrieve_next_sibling_token: tl.constexpr, + stride_retrieve_parent_token_seq: tl.constexpr, + stride_retrieve_parent_token_token: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + NP2_STATELEN: tl.constexpr, + NP2_SEQLEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + BLOCK_N: tl.constexpr, + SAVE_INTERMEDIATE: tl.constexpr, + HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr, +): + # ruff: noqa: E501 + idx_seq = tl.program_id(0) + if idx_seq >= batch: + return + + # [BLOCK_N,] elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + if IS_CONTINUOUS_BATCHING: + # mask = idx_seq < batch + conv_state_batch_coord = tl.load( + conv_state_indices_ptr + idx_seq * stride_state_indices + ).to(tl.int64) + if SAVE_INTERMEDIATE: + intermediate_state_batch_coord = tl.load( + intermediate_state_indices_ptr + + idx_seq * stride_intermediate_state_indices + ).to(tl.int64) + else: + conv_state_batch_coord = idx_seq + if USE_PAD_SLOT: # noqa + if conv_state_batch_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + + if IS_SPEC_DECODING: + # The rolling of conv state: + # + # Before forward, the conv_state is: + # [history1, history2, ..., historyM]. + # + # After forward, the conv_state becomes: + # [history2, ..., historyM, draft1, draft2, ..., draftN]. + # + # After acceptance, it becomes: + # + # - accept 1 tokens: [history2, ..., historyM, draft1] + # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] + # - and so on. + conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq) - 1 + else: + conv_state_token_offset = 0 + + # STEP 1: READ init_state data + conv_states_base = ( + conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim) + ) + mask_w = idx_feats < dim + + prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok + if KERNEL_WIDTH >= 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 3: + conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 4: + conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 5: + conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + + # STEP 2: assume state_len > seqlen + idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + # The conv_state updates works in a sliding window manner, + # at each forward pass, the tokens are shift by 1, so we + # load since idx_tokens + 1. + conv_state_ptrs_source = ( + conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[ + :, None + ] + ) # [BLOCK_M, BLOCK_N] + mask = ( + (conv_state_batch_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) + conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim) # [BLOCK_N] + + x_ptrs = ( + x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens - VAL >= 0)[:, None] + & (idx_tokens - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + tl.debug_barrier() + + new_conv_state = tl.where(mask, conv_state, loaded_x) + + conv_state_base = ( + conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim) + ) # [BLOCK_N,] + conv_state_ptrs_target = ( + conv_state_base + (idx_tokens * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_state_ptrs_target, new_conv_state, mask) + + # STEP 3: init accumulator + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to( + tl.float32 + ) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) + + # STEP 4: + # PRE-LOAD WEIGHTS + # first kernel column, configured for weights to handle BLOCK_N features in range + if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: + idx_tokens = tl.arange(0, NP2_SEQLEN) # [BLOCK_M] + # Update parent mapping for all tokens at once using vectorized operations + mask_retrieve = idx_tokens < seqlen + retrieve_next_token_base = ( + retrieve_next_token_ptr + + (idx_seq * stride_retrieve_next_token_seq) + + idx_tokens * stride_retrieve_next_token_token + ) + retrieve_next_tokens = tl.load(retrieve_next_token_base, mask_retrieve) + retrieve_next_sibling_base = ( + retrieve_next_sibling_ptr + + (idx_seq * stride_retrieve_next_sibling_seq) + + idx_tokens * stride_retrieve_next_sibling_token + ) + retrieve_next_siblings = tl.load(retrieve_next_sibling_base, mask_retrieve) + parent_idx_tokens = tl.zeros((NP2_SEQLEN,), dtype=tl.int32) + + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + + x_base_1d = x_base # starting of chunk [BLOCK_N] + mask_x_1d = idx_feats < dim + + # STEP 5: compute each token + for idx_token in tl.static_range(seqlen): + acc = acc_preload + + if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: + # set the parent index of the next token in the eagle tree + # next token's parent is the current token + retrieve_next_token_idx = tl.sum( + tl.where(idx_tokens == idx_token, retrieve_next_tokens, 0) + ) + if retrieve_next_token_idx != -1: # pad slot id + parent_idx_tokens = tl.where( + idx_tokens == retrieve_next_token_idx, + idx_token, + parent_idx_tokens, + ) + # next token's parent is the parent of the current token + retrieve_sibling_token_idx = tl.sum( + tl.where(idx_tokens == idx_token, retrieve_next_siblings, 0) + ) + if retrieve_sibling_token_idx != -1: # pad slot id + parent_idx_token = tl.sum( + tl.where(idx_tokens == idx_token, parent_idx_tokens, 0) + ) + parent_idx_tokens = tl.where( + idx_tokens == retrieve_sibling_token_idx, + parent_idx_token, + parent_idx_tokens, + ) + # tl.device_print("am", parent_idx_tokens) + + _idx_token = idx_token + x_ptrs_1d = x_base_1d + _idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + # convolution operation: itself * wcol[-1] + parent * wcol[-2] + grand-parent * wcol[-3] + ... + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 2: + if j == 0: + matrix_w = w_col1 + else: + matrix_w = w_col0 + elif KERNEL_WIDTH == 3: + if j == 0: + matrix_w = w_col2 + elif j == 1: + matrix_w = w_col1 + else: + matrix_w = w_col0 + elif KERNEL_WIDTH == 4: + if j == 0: + matrix_w = w_col3 + elif j == 1: + matrix_w = w_col2 + elif j == 2: + matrix_w = w_col1 + else: + matrix_w = w_col0 + + if SAVE_INTERMEDIATE: + # Save the window state after consuming this token + # Layout: [seq(cache line), step, dim, win(K-1)] + base_ptr = ( + intermediate_conv_window_ptr + + intermediate_state_batch_coord * stride_inter_seq + + idx_token * stride_inter_step + + idx_feats * stride_inter_dim + ) + + # store itself in KERNEL_WIDTH-2 slot, parent in KERNEL_WIDTH-3 slot, grand-parent in KERNEL_WIDTH-4 slot, ... + if KERNEL_WIDTH - j - 2 >= 0: + tl.store( + base_ptr + (KERNEL_WIDTH - j - 2) * stride_inter_win, + matrix_x, + mask=mask_w, + ) + + acc += matrix_x * matrix_w + + # move to parent for next iteration + if _idx_token > 0: + _idx_token = tl.sum( + tl.where(idx_tokens == _idx_token, parent_idx_tokens, 0) + ) + x_ptrs_1d = x_base_1d + _idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + else: + # no parent within the current chunk, load from prev conv state: col[-1] (idx 0's parent), col[-2] (idx 0's grand parent), ... + if KERNEL_WIDTH == 2: + if _idx_token == 0: + matrix_x = col0 + elif KERNEL_WIDTH == 3: + if _idx_token == 0: + matrix_x = col1 + else: + matrix_x = col0 + elif KERNEL_WIDTH == 4: + if _idx_token == 0: + matrix_x = col2 + elif _idx_token == -1: + matrix_x = col1 + else: + matrix_x = col0 + _idx_token = _idx_token - 1 + else: + matrix_w = w_col0 + matrix_x = col0 + + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + + if SAVE_INTERMEDIATE: + # Save the window state after consuming this token + # Layout: [seq(cache line), step, dim, win(K-1)] + base_ptr = ( + intermediate_conv_window_ptr + + intermediate_state_batch_coord * stride_inter_seq + + idx_token * stride_inter_step + + idx_feats * stride_inter_dim + ) + if KERNEL_WIDTH >= 2: + tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w) + if KERNEL_WIDTH >= 3: + tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w) + if KERNEL_WIDTH >= 4: + tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w) + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = (idx_token < seqlen) & ( + idx_feats < dim + ) # token-index # feature-index + o_ptrs = ( + o_ptr + + (idx_seq) * stride_o_seq + + idx_token * stride_o_token + + (idx_feats * stride_o_dim) + ) + + tl.store(o_ptrs, acc, mask=mask_1d) + + # fuse: store calculated retrieve_parent_token to tensor + if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: + tl.store( + retrieve_parent_token_ptr + + idx_seq * stride_retrieve_parent_token_seq + + idx_tokens * stride_retrieve_parent_token_token, + parent_idx_tokens, + mask=mask_retrieve, + ) + + +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Union[bool, str, None] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + intermediate_conv_window: Optional[torch.Tensor] = None, + intermediate_state_indices: Optional[torch.Tensor] = None, + retrieve_next_token: Optional[torch.Tensor] = None, + retrieve_next_sibling: Optional[torch.Tensor] = None, + retrieve_parent_token: Optional[torch.Tensor] = None, + pad_slot_id: int = PAD_SLOT_ID, + metadata=None, + validate_data=False, +): + """ + x: (batch, dim) or (batch, dim, seqlen) + [shape=2: single token prediction] + [shape=3: single or multiple tokens prediction] + conv_state: (..., dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state + starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim) or (batch, dim, seqlen) + """ + if validate_data: + assert cache_seqlens is None # not implemented yet - ok for vLLM + assert pad_slot_id is not None + assert x.stride(1) == 1 + if isinstance(activation, bool): + activation = "silu" if activation is True else None + elif activation is not None: + assert activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + # make it (batch, dim, seqlen) with seqlen == 1 + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + _, width = weight.shape + # conv_state: (..., dim, state_len), where state_len >= width - 1 + num_cache_lines, _, state_len = conv_state.size() + + if validate_data: + assert dim == weight.size(0) + assert ( + conv_state.stride(-2) == 1 + ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" + assert state_len >= width - 1 + # when above happens, we don't shift-left to keep any records in conv_state + assert dim == conv_state.size(1) + if conv_state_indices is None: + assert conv_state.size(0) >= batch + else: + assert (batch,) == conv_state_indices.shape + assert intermediate_state_indices is not None + assert (batch,) == intermediate_state_indices.shape + + assert num_cache_lines >= batch + assert weight.stride(1) == 1 # Need this + assert cache_seqlens is None # not needed for vLLM - circular buffer + + # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' + out = torch.empty_like(x) + stride_w_dim, stride_w_width = weight.stride() + + stride_x_seq, stride_x_dim, stride_x_token = x.stride() # X (batch, dim, seqlen) + + stride_o_seq, stride_o_dim, stride_o_token = out.stride() + stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride() + stride_state_indices = ( + conv_state_indices.stride(0) if conv_state_indices is not None else 0 + ) + stride_intermediate_state_indices = ( + intermediate_state_indices.stride(0) + if intermediate_state_indices is not None + else 0 + ) + if num_accepted_tokens is not None: + state_len = width - 1 + (seqlen - 1) # effective state_len needed + else: + state_len = width - 1 + np2_statelen = triton.next_power_of_2(state_len) + np2_seqlen = triton.next_power_of_2(seqlen) + + def grid(META): + return ( + batch, + triton.cdiv(dim, META["BLOCK_N"]), + ) + + # prepare intermediate buffer strides if provided + if intermediate_conv_window is not None: + stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = ( + intermediate_conv_window.stride(0), + intermediate_conv_window.stride(1), + intermediate_conv_window.stride(2), + intermediate_conv_window.stride(3), + ) + else: + stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0 + + # prepare retrieve next token buffer strides if provided + if retrieve_next_token is not None: + stride_retrieve_next_token_seq, stride_retrieve_next_token_token = ( + retrieve_next_token.stride(0), + retrieve_next_token.stride(1), + ) + else: + stride_retrieve_next_token_seq = stride_retrieve_next_token_token = 0 + + # prepare retrieve next sibling buffer strides if provided + if retrieve_next_sibling is not None: + stride_retrieve_next_sibling_seq, stride_retrieve_next_sibling_token = ( + retrieve_next_sibling.stride(0), + retrieve_next_sibling.stride(1), + ) + else: + stride_retrieve_next_sibling_seq = stride_retrieve_next_sibling_token = 0 + + # prepare retrieve parent token buffer strides if provided + if retrieve_parent_token is not None: + stride_retrieve_parent_token_seq, stride_retrieve_parent_token_token = ( + retrieve_parent_token.stride(0), + retrieve_parent_token.stride(1), + ) + else: + stride_retrieve_parent_token_seq = stride_retrieve_parent_token_token = 0 + + _causal_conv1d_update_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_state, + cache_seqlens, + conv_state_indices, + num_accepted_tokens, + intermediate_conv_window if intermediate_conv_window is not None else x, + intermediate_state_indices, + retrieve_next_token, + retrieve_next_sibling, + retrieve_parent_token, + out, + # Matrix dimensions + batch, + dim, + seqlen, + state_len, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_state_indices, + stride_inter_seq, + stride_inter_step, + stride_inter_dim, + stride_inter_win, + stride_intermediate_state_indices, + stride_retrieve_next_token_seq, + stride_retrieve_next_token_token, + stride_retrieve_next_sibling_seq, + stride_retrieve_next_sibling_token, + stride_retrieve_parent_token_seq, + stride_retrieve_parent_token_token, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + IS_CONTINUOUS_BATCHING=conv_state_indices is not None, + IS_SPEC_DECODING=num_accepted_tokens is not None, + NP2_STATELEN=np2_statelen, + NP2_SEQLEN=np2_seqlen, + USE_PAD_SLOT=pad_slot_id is not None, + BLOCK_N=256, + SAVE_INTERMEDIATE=intermediate_conv_window is not None, + HAS_EAGLE_TREE_CUSTOM_ATTN_MASK=retrieve_next_token is not None, + ) + if unsqueeze: + out = out.squeeze(-1) + return out diff --git a/megatron/core/ssm/ops/mamba_ssm.py b/megatron/core/ssm/ops/mamba_ssm.py new file mode 100644 index 00000000000..f238d51b47e --- /dev/null +++ b/megatron/core/ssm/ops/mamba_ssm.py @@ -0,0 +1,494 @@ +# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/mamba_ssm.py + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py + +import torch +import triton +import triton.language as tl +from packaging import version + +PAD_SLOT_ID = -1 + +TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") + +if TRITON3: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) + return dt + +else: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + return dt + + +@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) +@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) +@triton.heuristics( + { + "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"] + is not None + } +) +@triton.heuristics( + {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])} +) +@triton.heuristics( + { + "CACHE_INTERMEDIATE_STATES": lambda args: args["intermediate_states_buffer"] + is not None + } +) +@triton.heuristics( + { + "HAS_EAGLE_TREE_CUSTOM_ATTN_MASK": lambda args: args[ + "retrieve_parent_token_ptr" + ] + is not None + } +) +@triton.heuristics( + { + "HAS_INTERMEDIATE_STATE_INDICES": lambda args: args[ + "intermediate_state_indices_ptr" + ] + is not None + } +) +@triton.jit(do_not_specialize=["T"]) +def _selective_scan_update_kernel( + # Pointers to matrices + state_ptr, + x_ptr, + dt_ptr, + dt_bias_ptr, + A_ptr, + B_ptr, + C_ptr, + D_ptr, + z_ptr, + out_ptr, + state_batch_indices_ptr, + pad_slot_id, + intermediate_states_buffer, + cache_steps, + retrieve_parent_token_ptr, + intermediate_state_indices_ptr, + # Matrix dimensions + batch, + T, + nheads, + dim, + dstate, + nheads_ngroups_ratio, + # Strides + stride_state_batch, + stride_state_head, + stride_state_dim, + stride_state_dstate, + stride_x_batch, + stride_x_T, + stride_x_head, + stride_x_dim, + stride_dt_batch, + stride_dt_T, + stride_dt_head, + stride_dt_dim, + stride_dt_bias_head, + stride_dt_bias_dim, + stride_A_head, + stride_A_dim, + stride_A_dstate, + stride_B_batch, + stride_B_T, + stride_B_group, + stride_B_dstate, + stride_C_batch, + stride_C_T, + stride_C_group, + stride_C_dstate, + stride_D_head, + stride_D_dim, + stride_z_batch, + stride_z_T, + stride_z_head, + stride_z_dim, + stride_out_batch, + stride_out_T, + stride_out_head, + stride_out_dim, + stride_retrieve_parent_token_batch, + stride_retrieve_parent_token_T, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + TIE_HDIM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + HAS_D: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_STATE_BATCH_INDICES: tl.constexpr, + DISABLE_STATE_UPDATE: tl.constexpr, + CACHE_INTERMEDIATE_STATES: tl.constexpr, + HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr, + HAS_INTERMEDIATE_STATE_INDICES: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + + # If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate + # is taken from the state_batch_indices_ptr Otherwise, the state coordinate + # is the same as the batch id. + if HAS_STATE_BATCH_INDICES: + state_batch_indices_ptr += pid_b + state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64) + state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head + else: + state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head + + x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head + if HAS_DT_BIAS: + dt_bias_ptr += pid_h * stride_dt_bias_head + A_ptr += pid_h * stride_A_head + B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group + C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group + if HAS_Z: + z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) + state_ptrs = state_ptr + ( + offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate + ) + + mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) + if HAS_STATE_BATCH_INDICES: + mask &= state_batch_idx != pad_slot_id + state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32) + + if HAS_DT_BIAS: + dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim + if HAS_D: + D_ptr += pid_h * stride_D_head + D_ptrs = D_ptr + offs_m * stride_D_dim + A_ptrs = A_ptr + offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate + + cache_idx = -1 + if CACHE_INTERMEDIATE_STATES: + if HAS_INTERMEDIATE_STATE_INDICES: + intermediate_state_idx = tl.load(intermediate_state_indices_ptr + pid_b).to( + tl.int64 + ) + cache_idx = intermediate_state_idx + elif HAS_STATE_BATCH_INDICES: + cache_idx = state_batch_idx + else: + cache_idx = pid_b + + current_step_idx = 0 + for _ in range(T): + if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: + if current_step_idx != 0 and cache_idx >= 0: + parent_ptr = ( + retrieve_parent_token_ptr + + pid_b * stride_retrieve_parent_token_batch + + current_step_idx * stride_retrieve_parent_token_T + ) + parent_step_idx = tl.load(parent_ptr).to(tl.int32) + + if parent_step_idx >= 0 and parent_step_idx < T: + step_offset = parent_step_idx * nheads * dim * dstate + cache_ptr = ( + intermediate_states_buffer + + cache_idx * cache_steps * nheads * dim * dstate + + step_offset + + pid_h * dim * dstate + + offs_m[:, None] * dstate + + offs_n[None, :] + ) + state = tl.load(cache_ptr, mask=mask, other=0.0).to(tl.float32) + + x_ptrs = x_ptr + offs_m * stride_x_dim + dt_ptrs = dt_ptr + offs_m * stride_dt_dim + B_ptrs = B_ptr + offs_n * stride_B_dstate + C_ptrs = C_ptr + offs_n * stride_C_dstate + if HAS_Z: + z_ptrs = z_ptr + offs_m * stride_z_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + + x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if not TIE_HDIM: + dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load( + A_ptrs, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), + other=0.0, + ).to(tl.float32) + dA = tl.exp(A * dt[:, None]) + else: + dt = tl.load(dt_ptr).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptr).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load(A_ptr).to(tl.float32) + dA = tl.exp(A * dt) # scalar, not a matrix + + B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + if HAS_D: + D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_Z: + z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt + state = state * dA + dB * x[:, None] + + if CACHE_INTERMEDIATE_STATES: + if HAS_STATE_BATCH_INDICES: + if state_batch_idx != pad_slot_id: + cache_ptr_base = ( + intermediate_states_buffer + + cache_idx * cache_steps * nheads * dim * dstate + + current_step_idx * nheads * dim * dstate + + pid_h * dim * dstate + ) + cache_ptrs = cache_ptr_base + ( + offs_m[:, None] * dstate + offs_n[None, :] + ) + tl.store( + cache_ptrs, state.to(cache_ptrs.dtype.element_ty), mask=mask + ) + + out = tl.sum(state * C[None, :], axis=1) + if HAS_D: + out += x * D + if HAS_Z: + out *= z * tl.sigmoid(z) + tl.store(out_ptrs, out, mask=offs_m < dim) + + current_step_idx += 1 + + x_ptr += stride_x_T + dt_ptr += stride_dt_T + B_ptr += stride_B_T + C_ptr += stride_C_T + out_ptr += stride_out_T + if HAS_Z: + z_ptr += stride_z_T + + if not DISABLE_STATE_UPDATE: + tl.store(state_ptrs, state.to(state_ptrs.dtype.element_ty), mask=mask) + + +def selective_state_update( + state, + x, + dt, + A, + B, + C, + D=None, + z=None, + dt_bias=None, + dt_softplus=False, + state_batch_indices=None, + pad_slot_id=PAD_SLOT_ID, + out=None, + disable_state_update=False, + intermediate_states_buffer=None, + cache_steps=None, + retrieve_parent_token=None, + intermediate_state_indices=None, +): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) for single-token or (batch, T, nheads, dim) for multi-token + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) for single-token or (batch, T, ngroups, dstate) for multi-token + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: Preallocated ssm output tensor. Assume same shape as x. + In-place updated. + disable_state_update: If True, don't write back to state (for speculative verify) + intermediate_states_buffer: Buffer to cache intermediate states + cache_steps: Total number of steps in the buffer + retrieve_parent_token: (batch, T) tensor of parent token indices for EAGLE tree attention + intermediate_state_indices: (batch,) tensor of indices for intermediate_states_buffer operations. + If provided, uses these indices instead of state_batch_indices for the buffer. + """ + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if x.dim() == 3: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if dt.dim() == 3: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if B.dim() == 3: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if C.dim() == 3: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None: + if z.dim() == 2: + z = z.unsqueeze(1) + if z.dim() == 3: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + if out.dim() == 2: + out = out.unsqueeze(1) + if out.dim() == 3: + out = out.unsqueeze(1) + + _, nheads, dim, dstate = state.shape + batch, T, _, _ = x.shape + + assert x.shape == (batch, T, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, dstate) + ngroups = B.shape[2] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, T, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + if state_batch_indices is not None: + assert state_batch_indices.shape == (batch,) + assert out.shape == x.shape + + grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads) + z_strides = ( + (z.stride(0), z.stride(1), z.stride(2), z.stride(3)) + if z is not None + else (0, 0, 0, 0) + ) + # We don't want autotune since it will overwrite the state + # We instead tune by hand. + BLOCK_SIZE_M, num_warps = ( + (32, 4) + if dstate <= 16 + else ( + (16, 4) + if dstate <= 32 + else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8)))) + ) + ) + tie_hdim = ( + A.stride(-1) == 0 + and A.stride(-2) == 0 + and dt.stride(-1) == 0 + and dt_bias.stride(-1) == 0 + ) + + retrieve_parent_token_strides = ( + (retrieve_parent_token.stride(0), retrieve_parent_token.stride(1)) + if retrieve_parent_token is not None + else (0, 0) + ) + + with torch.cuda.device(x.device.index): + _selective_scan_update_kernel[grid]( + state, + x, + dt, + dt_bias, + A, + B, + C, + D, + z, + out, + state_batch_indices, + pad_slot_id, + intermediate_states_buffer, + cache_steps if cache_steps is not None else 0, + retrieve_parent_token, + intermediate_state_indices, + batch, + T, + nheads, + dim, + dstate, + nheads // ngroups, + state.stride(0), + state.stride(1), + state.stride(2), + state.stride(3), + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + dt.stride(0), + dt.stride(1), + dt.stride(2), + dt.stride(3), + *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0, + A.stride(0), + A.stride(1), + A.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + B.stride(3), + C.stride(0), + C.stride(1), + C.stride(2), + C.stride(3), + *(D.stride(0), D.stride(1)) if D is not None else 0, + z_strides[0], + z_strides[1], + z_strides[2], + z_strides[3], + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + retrieve_parent_token_strides[0], + retrieve_parent_token_strides[1], + dt_softplus, + tie_hdim, + BLOCK_SIZE_M, + DISABLE_STATE_UPDATE=disable_state_update, + num_warps=num_warps, + ) From 77957376d0542ee250b29d523d806c32fb9313f7 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Tue, 24 Feb 2026 13:10:22 -0800 Subject: [PATCH 10/76] More debugging Signed-off-by: Keshav Santhanam --- megatron/core/ssm/mamba_mixer.py | 38 ++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index d30eb45de7d..4dedbd6d9cd 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -443,18 +443,25 @@ def _dynamic_inference(self, hidden_states: torch.Tensor, context: DynamicInfere y_decode = None y_prefill = None + """ if self.layer_number == 1: torch.distributed.breakpoint(0) + """ # Decode if decode_req_count > 0: # For mixed batch, the decode tokens are at the start of zxBCdt zxBCdt_decode = zxBCdt[:decode_req_count] if prefill_req_count > 0 else zxBCdt - num_accepted_tokens = context.mamba_metadata.num_accepted_tokens - intermediate_conv_state, intermediate_ssm_state = context.mamba_states_cache( - self.layer_number - self.pp_layer_offset, intermediate=True - ) + if context.num_speculative_tokens > 0: + num_accepted_tokens = context.mamba_metadata.num_accepted_tokens + intermediate_conv_state, intermediate_ssm_state = context.mamba_states_cache( + self.layer_number - self.pp_layer_offset, intermediate=True + ) + else: + num_accepted_tokens = None + intermediate_conv_state = None + intermediate_ssm_state = None y_decode = self._ssm_decode( zxBCdt_decode.transpose(0, 1), @@ -462,7 +469,7 @@ def _dynamic_inference(self, hidden_states: torch.Tensor, context: DynamicInfere ssm_state, context.mamba_metadata.batch_indices_decode, num_accepted_tokens=num_accepted_tokens, - intermediate_conv_window=intermediate_conv_state, + intermediate_conv_state=intermediate_conv_state, intermediate_ssm_state=intermediate_ssm_state, ).transpose(0, 1) @@ -908,7 +915,7 @@ def _ssm_decode( if seq_len > 1: assert ( num_accepted_tokens is not None - and intermediate_conv_window is not None + and intermediate_conv_state is not None and intermediate_ssm_state is not None ), "Decoding with > 1 token per request requires speculative decoding state" is_speculative_decoding = True @@ -948,7 +955,7 @@ def _ssm_decode( self.activation, conv_state_indices=batch_indices, num_accepted_tokens=num_accepted_tokens, - intermediate_conv_window=intermediate_conv_window, + intermediate_conv_window=intermediate_conv_state, intermediate_state_indices=batch_indices, pad_slot_id=-1, ) @@ -1015,7 +1022,7 @@ def _ssm_decode( A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32) dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim) D = repeat(self.D, "h -> h p", p=self.headdim) - if is_speculative_deocode: + if is_speculative_decoding: dt = repeat(dt, "s b h -> s b h p", p=self.headdim) B = rearrange(B, "s b (g n) -> s b g n", g=self.ngroups_local_tp) C = rearrange(C, "s b (g n) -> s b g n", g=self.ngroups_local_tp) @@ -1035,7 +1042,8 @@ def _ssm_decode( if batch_indices is not None: batch_indices = batch_indices.to(torch.int64) - y = selective_state_update( + y = torch.empty_like(x_reshaped) + selective_state_update( ssm_state, x_reshaped, dt, @@ -1047,12 +1055,14 @@ def _ssm_decode( dt_bias=dt_bias, dt_softplus=True, state_batch_indices=batch_indices, - disable_state_update=True, - intermediate_states_buffer=intermediate_ssm_state, - cache_steps=seq_len, - intermediate_state_indices=batch_indices, + pad_slot_id=-1, + out=y, + #disable_state_update=True, + #intermediate_states_buffer=intermediate_ssm_state, + #cache_steps=seq_len, + #intermediate_state_indices=batch_indices, ) - if is_speculative_decode: + if is_speculative_decoding: y = rearrange(y, "s b h p -> s b (h p)") else: y = rearrange(y, "b h p -> b (h p)") From ae037471e9943317694c411a5af9e9a7ccf77f08 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Tue, 24 Feb 2026 21:19:40 -0800 Subject: [PATCH 11/76] Working causal_conv1d_update triton kernel Signed-off-by: Keshav Santhanam --- megatron/core/ssm/ops/causal_conv1d_triton.py | 1321 ++--------------- 1 file changed, 161 insertions(+), 1160 deletions(-) diff --git a/megatron/core/ssm/ops/causal_conv1d_triton.py b/megatron/core/ssm/ops/causal_conv1d_triton.py index c82f4d730fa..fc27a8df500 100644 --- a/megatron/core/ssm/ops/causal_conv1d_triton.py +++ b/megatron/core/ssm/ops/causal_conv1d_triton.py @@ -1,1187 +1,188 @@ -# Copyright (c) 2024, Tri Dao. -# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py -# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py - -from typing import List, Optional, Union - import torch import triton import triton.language as tl -PAD_SLOT_ID = -1 - - -@triton.jit() -def _causal_conv1d_fwd_kernel( # continuous batching - # Pointers to matrices - x_ptr, # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences - w_ptr, # (dim, width) - bias_ptr, - initial_states_ptr, # conv_states_ptr - cache_indices_ptr, # conv_state_indices_ptr - has_initial_states_ptr, - query_start_loc_ptr, - o_ptr, # (dim, seqlen) - actually pointing to x_ptr - # Matrix dimensions - dim: tl.constexpr, - seqlen: tl.int32, # cu_seqlen - num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines - # Strides - stride_x_seq: tl.constexpr, # stride to get to next sequence, - stride_x_dim: tl.constexpr, # stride to get to next feature-value, - stride_x_token: tl.constexpr, # stride to get to next token (same feature-index, same sequence-index) - stride_w_dim: tl.constexpr, # stride to get to next dim-axis value - stride_w_width: tl.constexpr, # stride to get to next width-axis value - stride_istate_seq: tl.constexpr, - stride_istate_dim: tl.constexpr, - stride_istate_token: tl.constexpr, - stride_o_seq: tl.constexpr, - stride_o_dim: tl.constexpr, - stride_o_token: tl.constexpr, - # others - pad_slot_id: tl.constexpr, - # Meta-parameters - HAS_BIAS: tl.constexpr, - KERNEL_WIDTH: tl.constexpr, - SILU_ACTIVATION: tl.constexpr, - HAS_INITIAL_STATES: tl.constexpr, - HAS_CACHE: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, - USE_PAD_SLOT: tl.constexpr, - NP2_STATELEN: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - conv_states_ptr = initial_states_ptr - conv_state_indices_ptr = cache_indices_ptr - stride_conv_state_seq = stride_istate_seq - stride_conv_state_dim = stride_istate_dim - stride_conv_state_tok = stride_istate_token - state_len = ( - KERNEL_WIDTH - 1 - ) # can be passed via argument if it's not the same as this value - - # one program handles one chunk in a single sequence - # rather than mixing sequences - to make updating initial_states across sequences efficiently - - # single-sequence id - idx_seq = tl.program_id(0) - chunk_offset = tl.program_id(1) - - # BLOCK_N elements along the feature-dimension (channel) - idx_feats = tl.program_id(2) * BLOCK_N + tl.arange(0, BLOCK_N) - - if idx_seq == pad_slot_id: - return - - sequence_start_index = tl.load(query_start_loc_ptr + idx_seq) - sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1) - # find the actual sequence length - seqlen = sequence_end_index - sequence_start_index - - token_offset = BLOCK_M * chunk_offset - segment_len = min(BLOCK_M, seqlen - token_offset) - - if segment_len <= 0: - return - - # base of the sequence - x_base = ( - x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim - ) # [BLOCK_N,] - - if IS_CONTINUOUS_BATCHING: - # cache_idx - conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(tl.int64) - else: - # cache_idx - conv_state_batch_coord = idx_seq - if USE_PAD_SLOT: # noqa - if conv_state_batch_coord == pad_slot_id: - # not processing as this is not the actual sequence - return - conv_states_base = ( - conv_states_ptr - + (conv_state_batch_coord * stride_conv_state_seq) - + (idx_feats * stride_conv_state_dim) - ) # [BLOCK_N,] - - w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] - - # Does 2 things: - # 1. READ prior-block init-state data - [done by every Triton programs] - # 2. update conv_state with new data [only by the Triton program handles chunk_offset=0] - if chunk_offset == 0: - # read from conv_states - load_init_state = False - if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES - load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1) - if load_init_state: - # load from conv_states - prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok - mask_w = idx_feats < dim - if KERNEL_WIDTH == 2: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH == 3: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0) - conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH == 4: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0) - conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0) - conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH == 5: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col3 = tl.load(conv_states_ptrs, mask_w, 0.0) - conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0) - conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0) - conv_states_ptrs = prior_tokens - 3 * stride_conv_state_tok # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0) - else: - # prior-tokens are zeros - if KERNEL_WIDTH >= 2: # STRATEGY1 - # first chunk and does not have prior-token, so just set to 0 - col0 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) - if KERNEL_WIDTH >= 3: # STRATEGY1 - col1 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) - if KERNEL_WIDTH >= 4: # STRATEGY1 - col2 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) - if KERNEL_WIDTH >= 5: # STRATEGY1 - col3 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) - - # STEP 2: - # here prepare data for updating conv_state - if ( - state_len <= seqlen - ): # SMALL_CACHE=True (only move part of 'x' into conv_state cache) - # just read from 'x' - # copy 'x' data to conv_state - # load only 'x' data (and set 0 before 'x' if seqlen < state_len) - idx_tokens_last = (seqlen - state_len) + tl.arange( - 0, NP2_STATELEN - ) # [BLOCK_M] - x_ptrs = ( - x_ptr - + ((sequence_start_index + idx_tokens_last) * stride_x_token)[:, None] - + (idx_feats * stride_x_dim)[None, :] - ) # [BLOCK_M,BLOCK_N,] - mask_x = ( - (idx_tokens_last >= 0)[:, None] - & (idx_tokens_last < seqlen)[:, None] - & (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index - loaded_x = tl.load(x_ptrs, mask_x, 0.0) - new_conv_state = tl.load(x_ptrs, mask_x, 0.0) - idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - conv_states_ptrs_target = ( - conv_states_base[None, :] - + (idx_tokens_conv * stride_conv_state_tok)[:, None] - ) - - mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] - tl.debug_barrier() # NOTE: use this due to bug in Triton compiler - tl.store(conv_states_ptrs_target, new_conv_state, mask) - - else: - if load_init_state: - # update conv_state by shifting left, i.e. take last few cols from conv_state + cols from 'x' - idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - - conv_states_ptrs_source = ( - conv_states_ptr - + (conv_state_batch_coord * stride_conv_state_seq) - + (idx_feats * stride_conv_state_dim)[None, :] - + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None] - ) # [BLOCK_M, BLOCK_N] - mask = ( - (conv_state_batch_coord < num_cache_lines) - & ((idx_tokens_conv + seqlen) < state_len)[:, None] - & (idx_feats < dim)[None, :] - ) - conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0) - - VAL = state_len - seqlen - - x_ptrs = ( - x_base[None, :] - + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] - ) # [BLOCK_M, BLOCK_N] - - mask_x = ( - (idx_tokens_conv - VAL >= 0)[:, None] - & (idx_tokens_conv - VAL < seqlen)[:, None] - & (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index - loaded_x = tl.load(x_ptrs, mask_x, 0.0) - - tl.debug_barrier() # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load - new_conv_state = tl.where( - mask, conv_state, loaded_x - ) # BUG in 'tl.where' which requires a barrier before this - conv_states_ptrs_target = ( - conv_states_base - + (idx_tokens_conv * stride_conv_state_tok)[:, None] - ) # [BLOCK_M, BLOCK_N] - mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[ - None, : - ] - tl.store(conv_states_ptrs_target, new_conv_state, mask) - else: # load_init_state == False - # update conv_state by shifting left, BUT - # set cols prior to 'x' as zeros + cols from 'x' - idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - - VAL = state_len - seqlen - - x_ptrs = ( - x_base[None, :] - + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] - ) # [BLOCK_M, BLOCK_N] - - mask_x = ( - (idx_tokens_conv - VAL >= 0)[:, None] - & (idx_tokens_conv - VAL < seqlen)[:, None] - & (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index - new_conv_state = tl.load(x_ptrs, mask_x, 0.0) - - conv_states_ptrs_target = ( - conv_states_base - + (idx_tokens_conv * stride_conv_state_tok)[:, None] - ) # [BLOCK_M, BLOCK_N] - mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[ - None, : - ] - tl.store(conv_states_ptrs_target, new_conv_state, mask) - - else: # chunk_offset > 0 - # read prior-token data from `x` - load_init_state = True - prior_tokens = x_base + (token_offset - 1) * stride_x_token - mask_w = idx_feats < dim - if KERNEL_WIDTH == 2: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") - if KERNEL_WIDTH == 3: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") - conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") - if KERNEL_WIDTH == 4: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") - conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") - conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") - if KERNEL_WIDTH == 5: - # ruff: noqa: F841 - conv_states_ptrs = prior_tokens # [BLOCK_N] - col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") - conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") - conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") - conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") - - if HAS_BIAS: - bias = bias_ptr + idx_feats - mask_bias = idx_feats < dim - acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to( - tl.float32 - ) # [BLOCK_N] - else: - acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) - - x_base_1d = x_base + token_offset * stride_x_token # starting of chunk - - # PRE-LOAD WEIGHTS - mask_w = idx_feats < dim - if KERNEL_WIDTH >= 2: - w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor - w_col0 = tl.load(w_ptrs, mask_w, other=0.0) - w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor - w_col1 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 3: - w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor - w_col2 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 4: - w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor - w_col3 = tl.load(w_ptrs, mask_w, other=0.0) - mask_x_1d = idx_feats < dim - for idx_token in range(segment_len): - acc = acc_preload - - matrix_w = w_col0 - matrix_x = col0 - for j in tl.static_range(KERNEL_WIDTH): - - if KERNEL_WIDTH == 2: - if j == 1: # KERNEL_WIDTH-1: - matrix_w = w_col1 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - elif KERNEL_WIDTH == 3: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - elif KERNEL_WIDTH == 4: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - matrix_x = col2 - elif j == 3: - matrix_w = w_col3 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - - acc += matrix_x * matrix_w # [BLOCK_N] - - if KERNEL_WIDTH == 2: - col0 = matrix_x - elif KERNEL_WIDTH == 3: - col0 = col1 - col1 = matrix_x - elif KERNEL_WIDTH == 4: - col0 = col1 - col1 = col2 - col2 = matrix_x - - if SILU_ACTIVATION: - acc = acc / (1 + tl.exp(-acc)) - mask_1d = (idx_token < segment_len) & ( - idx_feats < dim - ) # token-index # feature-index - o_ptrs = ( - o_ptr - + (sequence_start_index + token_offset + idx_token) * stride_o_token - + (idx_feats * stride_o_dim) - ) - - tl.store(o_ptrs, acc, mask=mask_1d) - - -def causal_conv1d_fn( - x: torch.Tensor, - weight: torch.Tensor, - bias: Union[torch.Tensor, None], - conv_states: torch.Tensor, - query_start_loc: torch.Tensor, - seq_lens_cpu: List[int], - cache_indices: Optional[torch.Tensor] = None, - has_initial_state: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", - pad_slot_id: int = PAD_SLOT_ID, - validate_data=False, - **kwargs, -): - """support varlen + continuous batching when x is 2D tensor - - x: (dim,cu_seq_len) - cu_seq_len = total tokens of all seqs in that batch - sequences are concatenated from left to right for varlen - weight: (dim, width) - conv_states: (...,dim,width - 1) itype - updated inplace if provided - [it use `cache_indices` to get the index to the cache of conv_state for that sequence - - conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True - and after that conv_state[cache_indices[i]] need to be shift-left and updated with values from 'x' - ] - query_start_loc: (batch + 1) int32 - The cumulative sequence lengths of the sequences in - the batch, used to index into sequence. prepended by 0. - if - x = [5, 1, 1, 1] <- continuous batching (batch=4) - then - query_start_loc = [0, 5, 6, 7, 8] <- the starting index of the next sequence; while the last value is - the ending index of the last sequence - [length(query_start_loc)-1 == batch] - for example: query_start_loc = torch.Tensor([0,10,16,17]), - x.shape=(dim,17) - seq_lens_cpu: (batch) int32 - The sequence lengths of the sequences in the batch - cache_indices: (batch) int32 - indicates the corresponding state index, - like so: conv_state = conv_states[cache_indices[batch_id]] - has_initial_state: (batch) bool - indicates whether should the kernel take the current state as initial - state for the calculations - [single boolean for each sequence in the batch: True or False] - bias: (dim,) - activation: either None or "silu" or "swish" or True - pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] - in this case, the kernel will not process entries at - indices 0 and 3 - - out: same shape as `x` - """ - if isinstance(activation, bool) and activation: - activation = "silu" - - out = torch.empty_like(x) - - is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1) - dim, cu_seqlen = x.shape - _, width = weight.shape - state_len = width - 1 - np2_statelen = triton.next_power_of_2(state_len) - - stride_x_seq = 0 - stride_x_dim = x.stride(0) - stride_x_token = x.stride(1) - stride_w_dim = weight.stride(0) - stride_w_width = weight.stride(1) - stride_istate_seq = 0 - stride_istate_dim = 0 - stride_istate_token = 0 - num_cache_lines = 0 - if conv_states is not None: - # extensions to support vLLM: - # 1. conv_states is used to replaced initial_states - # 2. conv_states serve as a cache with num cache lines can be larger than batch size - # 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx] - # 4. computation can be skipped if cache_indices[idx] == pad_slot_id - num_cache_lines = conv_states.size(0) - assert ( - num_cache_lines == conv_states.shape[0] - and dim == conv_states.shape[1] - and width - 1 <= conv_states.shape[2] - ) - stride_istate_seq = conv_states.stride(0) - stride_istate_dim = conv_states.stride(1) - stride_istate_token = conv_states.stride(2) - # assert stride_istate_dim == 1 - if out.dim() == 2: - stride_o_seq = 0 - stride_o_dim = out.stride(0) - stride_o_token = out.stride(1) - else: - stride_o_seq = out.stride(0) - stride_o_dim = out.stride(1) - stride_o_token = out.stride(2) - - if validate_data: - assert x.dim() == 2 - assert query_start_loc is not None - assert query_start_loc.dim() == 1 - assert x.stride(0) == 1 or x.stride(1) == 1 - padded_batch = query_start_loc.size(0) - 1 - if bias is not None: - assert bias.dim() == 1 - assert dim == bias.size(0) - if cache_indices is not None: - assert cache_indices.dim() == 1 - assert padded_batch == cache_indices.size(0) - if has_initial_state is not None: - assert has_initial_state.size() == (padded_batch,) - assert ( - conv_states is not None - ), "ERROR: `has_initial_state` is used, which needs also `conv_states`" - assert weight.stride(1) == 1 - assert (dim, width) == weight.shape - assert is_channel_last, "Need to run in channel-last layout" - - def grid(META): - max_seq_len = max(seq_lens_cpu) - return ( - len(seq_lens_cpu), # batch_size - (max_seq_len + META["BLOCK_M"] - 1) // META["BLOCK_M"], - triton.cdiv(dim, META["BLOCK_N"]), - ) - - _causal_conv1d_fwd_kernel[grid]( - # Pointers to matrices - x, - weight, - bias, - conv_states, - cache_indices, - has_initial_state, - query_start_loc, - out, - # Matrix dimensions - dim, - cu_seqlen, - num_cache_lines, - # stride - stride_x_seq, - stride_x_dim, - stride_x_token, - stride_w_dim, - stride_w_width, - stride_istate_seq, - stride_istate_dim, - stride_istate_token, - stride_o_seq, - stride_o_dim, - stride_o_token, - # others - pad_slot_id, - # META - HAS_BIAS=bias is not None, - KERNEL_WIDTH=width, - SILU_ACTIVATION=activation in ["silu", "swish"], - HAS_INITIAL_STATES=has_initial_state is not None, - HAS_CACHE=conv_states is not None, - IS_CONTINUOUS_BATCHING=cache_indices is not None, - USE_PAD_SLOT=pad_slot_id is not None, - NP2_STATELEN=np2_statelen, - # launch_cooperative_grid=True - BLOCK_M=8, - BLOCK_N=256, - num_stages=2, - ) - return out - - -# HAS_EAGLE_TREE_CUSTOM_ATTN_MASK is added to support eagle tree attention mask -# retrieve_next_token_ptr: [N, NP2_T], retrieve_next_sibling_ptr: [N, NP2_T] -# e.g. for a sequence of length 4, the eagle tree attention structure is: -# retrieve_next_token=[1, 3, -1, -1] -> retrieve_next_token[i]: the 1st child token of token i -# retrieve_next_sibling=[-1, 2, -1, -1] -> retrieve_next_sibling[i]: the 1st tree sibling token of token i -# retrieve_parent_token=[n/a, 0, 0, 1] -> retrieve_parent_token[i]: the parent token of token i -# Tree: -# 0 -# / \ -# 1 2 -# / -# 3 -# When calculating token 3's convolution, it should conv to token 1 (parent) and token 0 (grand-parent) -# When calculating token 2's convolution, it should conv to token 0 (parent) -# This kernel is a fused kernel which will also produce retrieve_parent_token based on retrieve_next_token & retrieve_next_sibling -@triton.jit() -def _causal_conv1d_update_kernel( - # Pointers to matrices - x_ptr, # (batch, dim, seqlen) - w_ptr, # (dim, width) - bias_ptr, - conv_state_ptr, - cache_seqlens_ptr, # circular buffer +@triton.jit +def causal_conv1d_update_kernel( +x_ptr, x_b_stride, x_c_stride, + conv_state_ptr, conv_state_b_stride, conv_state_c_stride, conv_state_l_stride, + weight_ptr, weight_c_stride, weight_width_stride, + bias_ptr, bias_stride, + out_ptr, out_b_stride, out_c_stride, conv_state_indices_ptr, - num_accepted_tokens_ptr, - intermediate_conv_window_ptr, - intermediate_state_indices_ptr, - retrieve_next_token_ptr, - retrieve_next_sibling_ptr, - retrieve_parent_token_ptr, - o_ptr, # (batch, dim, seqlen) - # Matrix dimensions - batch: int, - dim: tl.constexpr, - seqlen: tl.constexpr, - state_len: tl.constexpr, - num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines - # Strides - stride_x_seq: tl.constexpr, - stride_x_dim: tl.constexpr, - stride_x_token: tl.constexpr, - stride_w_dim: tl.constexpr, - stride_w_width: tl.constexpr, - stride_conv_state_seq: tl.constexpr, - stride_conv_state_dim: tl.constexpr, - stride_conv_state_tok: tl.constexpr, - stride_state_indices: tl.constexpr, - stride_inter_seq: tl.constexpr, - stride_inter_step: tl.constexpr, - stride_inter_dim: tl.constexpr, - stride_inter_win: tl.constexpr, - stride_intermediate_state_indices: tl.constexpr, - stride_retrieve_next_token_seq: tl.constexpr, - stride_retrieve_next_token_token: tl.constexpr, - stride_retrieve_next_sibling_seq: tl.constexpr, - stride_retrieve_next_sibling_token: tl.constexpr, - stride_retrieve_parent_token_seq: tl.constexpr, - stride_retrieve_parent_token_token: tl.constexpr, - stride_o_seq: tl.constexpr, - stride_o_dim: tl.constexpr, - stride_o_token: tl.constexpr, - # others - pad_slot_id: tl.constexpr, - # Meta-parameters + cache_seqlens_ptr, + batch, dim, state_len, + WIDTH: tl.constexpr, + BLOCK_DIM: tl.constexpr, HAS_BIAS: tl.constexpr, - KERNEL_WIDTH: tl.constexpr, + IS_CIRCULAR: tl.constexpr, + HAS_STATE_INDICES: tl.constexpr, SILU_ACTIVATION: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, - IS_SPEC_DECODING: tl.constexpr, - NP2_STATELEN: tl.constexpr, - NP2_SEQLEN: tl.constexpr, - USE_PAD_SLOT: tl.constexpr, - BLOCK_N: tl.constexpr, - SAVE_INTERMEDIATE: tl.constexpr, - HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr, ): - # ruff: noqa: E501 - idx_seq = tl.program_id(0) - if idx_seq >= batch: + batch_id = tl.program_id(0) + channel_block_id = tl.program_id(1) + + channel_offsets = channel_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM) + mask = channel_offsets < dim + + # State batch coordinate mapping + if HAS_STATE_INDICES: + state_batch_coord = tl.load(conv_state_indices_ptr + batch_id) + else: + state_batch_coord = batch_id + + # Pointers + x_ptrs = x_ptr + batch_id * x_b_stride + channel_offsets * x_c_stride + out_ptrs = out_ptr + batch_id * out_b_stride + channel_offsets * out_c_stride + conv_state_ptrs = conv_state_ptr + state_batch_coord * conv_state_b_stride + channel_offsets * conv_state_c_stride + weight_ptrs = weight_ptr + channel_offsets * weight_c_stride + + # Skip padding tokens (block-level uniform condition) + if state_batch_coord < 0: + tl.store(out_ptrs, 0.0, mask=mask) return - # [BLOCK_N,] elements along the feature-dimension (channel) - idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) - - if IS_CONTINUOUS_BATCHING: - # mask = idx_seq < batch - conv_state_batch_coord = tl.load( - conv_state_indices_ptr + idx_seq * stride_state_indices - ).to(tl.int64) - if SAVE_INTERMEDIATE: - intermediate_state_batch_coord = tl.load( - intermediate_state_indices_ptr - + idx_seq * stride_intermediate_state_indices - ).to(tl.int64) + # Load Bias + if HAS_BIAS: + bias_val = tl.load(bias_ptr + channel_offsets * bias_stride, mask=mask).to(tl.float32) else: - conv_state_batch_coord = idx_seq - if USE_PAD_SLOT: # noqa - if conv_state_batch_coord == pad_slot_id: - # not processing as this is not the actual sequence - return - - if IS_SPEC_DECODING: - # The rolling of conv state: - # - # Before forward, the conv_state is: - # [history1, history2, ..., historyM]. - # - # After forward, the conv_state becomes: - # [history2, ..., historyM, draft1, draft2, ..., draftN]. - # - # After acceptance, it becomes: - # - # - accept 1 tokens: [history2, ..., historyM, draft1] - # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] - # - and so on. - conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq) - 1 + bias_val = tl.zeros([BLOCK_DIM], dtype=tl.float32) + + # Load Weights + if WIDTH == 2: + w0 = tl.load(weight_ptrs + 0 * weight_width_stride, mask=mask).to(tl.float32) + w1 = tl.load(weight_ptrs + 1 * weight_width_stride, mask=mask).to(tl.float32) + elif WIDTH == 3: + w0 = tl.load(weight_ptrs + 0 * weight_width_stride, mask=mask).to(tl.float32) + w1 = tl.load(weight_ptrs + 1 * weight_width_stride, mask=mask).to(tl.float32) + w2 = tl.load(weight_ptrs + 2 * weight_width_stride, mask=mask).to(tl.float32) + elif WIDTH == 4: + w0 = tl.load(weight_ptrs + 0 * weight_width_stride, mask=mask).to(tl.float32) + w1 = tl.load(weight_ptrs + 1 * weight_width_stride, mask=mask).to(tl.float32) + w2 = tl.load(weight_ptrs + 2 * weight_width_stride, mask=mask).to(tl.float32) + w3 = tl.load(weight_ptrs + 3 * weight_width_stride, mask=mask).to(tl.float32) + + # Initialize independent x_vals to match unrolled float array + x_val_0 = tl.zeros([BLOCK_DIM], dtype=tl.float32) + x_val_1 = tl.zeros([BLOCK_DIM], dtype=tl.float32) + x_val_2 = tl.zeros([BLOCK_DIM], dtype=tl.float32) + x_val_3 = tl.zeros([BLOCK_DIM], dtype=tl.float32) + + if not IS_CIRCULAR: + # Load the last (WIDTH - 1) elements to use them BEFORE they are overwritten by the shift + if WIDTH >= 2: + x_val_0 = tl.load(conv_state_ptrs + (state_len - WIDTH + 1) * conv_state_l_stride, mask=mask).to(tl.float32) + if WIDTH >= 3: + x_val_1 = tl.load(conv_state_ptrs + (state_len - WIDTH + 2) * conv_state_l_stride, mask=mask).to(tl.float32) + if WIDTH >= 4: + x_val_2 = tl.load(conv_state_ptrs + (state_len - WIDTH + 3) * conv_state_l_stride, mask=mask).to(tl.float32) + + # Shift the linear state buffer left by 1 since advance_len is exactly 1 + i = 0 + while i < state_len - 1: + val = tl.load(conv_state_ptrs + (i + 1) * conv_state_l_stride, mask=mask) + tl.store(conv_state_ptrs + i * conv_state_l_stride, val, mask=mask) + i += 1 + else: - conv_state_token_offset = 0 - - # STEP 1: READ init_state data - conv_states_base = ( - conv_state_ptr - + (conv_state_batch_coord * stride_conv_state_seq) - + (idx_feats * stride_conv_state_dim) - ) - mask_w = idx_feats < dim - - prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok - if KERNEL_WIDTH >= 2: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH >= 3: - conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH >= 4: - conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH == 5: - conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] - col3 = tl.load(conv_states_ptrs, mask_w, 0.0) - - # STEP 2: assume state_len > seqlen - idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - - # The conv_state updates works in a sliding window manner, - # at each forward pass, the tokens are shift by 1, so we - # load since idx_tokens + 1. - conv_state_ptrs_source = ( - conv_state_ptr - + (conv_state_batch_coord * stride_conv_state_seq) - + conv_state_token_offset * stride_conv_state_tok - + (idx_feats * stride_conv_state_dim)[None, :] - + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[ - :, None - ] - ) # [BLOCK_M, BLOCK_N] - mask = ( - (conv_state_batch_coord < num_cache_lines) - & ((idx_tokens + seqlen) < state_len)[:, None] - & (idx_feats < dim)[None, :] - ) - conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) - - VAL = state_len - seqlen - x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim) # [BLOCK_N] - - x_ptrs = ( - x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] - ) # [BLOCK_M, BLOCK_N] - - mask_x = ( - (idx_tokens - VAL >= 0)[:, None] - & (idx_tokens - VAL < seqlen)[:, None] - & (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index - loaded_x = tl.load(x_ptrs, mask_x, 0.0) - tl.debug_barrier() - - new_conv_state = tl.where(mask, conv_state, loaded_x) - - conv_state_base = ( - conv_state_ptr - + (conv_state_batch_coord * stride_conv_state_seq) - + (idx_feats * stride_conv_state_dim) - ) # [BLOCK_N,] - conv_state_ptrs_target = ( - conv_state_base + (idx_tokens * stride_conv_state_tok)[:, None] - ) # [BLOCK_M, BLOCK_N] - mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] - tl.store(conv_state_ptrs_target, new_conv_state, mask) - - # STEP 3: init accumulator - if HAS_BIAS: - bias = bias_ptr + idx_feats - mask_bias = idx_feats < dim - acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to( - tl.float32 - ) # [BLOCK_N] + cache_seqlen = tl.load(cache_seqlens_ptr + batch_id) % state_len + update_idx = cache_seqlen - (WIDTH - 1) + update_idx = tl.where(update_idx < 0, update_idx + state_len, update_idx) + + if WIDTH >= 2: + state_val = tl.load(conv_state_ptrs + update_idx * conv_state_l_stride, mask=mask) + x_val_0 = state_val.to(tl.float32) + update_idx = tl.where(update_idx + 1 >= state_len, update_idx + 1 - state_len, update_idx + 1) + if WIDTH >= 3: + state_val = tl.load(conv_state_ptrs + update_idx * conv_state_l_stride, mask=mask) + x_val_1 = state_val.to(tl.float32) + update_idx = tl.where(update_idx + 1 >= state_len, update_idx + 1 - state_len, update_idx + 1) + if WIDTH >= 4: + state_val = tl.load(conv_state_ptrs + update_idx * conv_state_l_stride, mask=mask) + x_val_2 = state_val.to(tl.float32) + update_idx = tl.where(update_idx + 1 >= state_len, update_idx + 1 - state_len, update_idx + 1) + + # Process the single token per request + x_val = tl.load(x_ptrs, mask=mask) + + if not IS_CIRCULAR: + tl.store(conv_state_ptrs + (state_len - 1) * conv_state_l_stride, x_val, mask=mask) else: - acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) - - # STEP 4: - # PRE-LOAD WEIGHTS - # first kernel column, configured for weights to handle BLOCK_N features in range - if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: - idx_tokens = tl.arange(0, NP2_SEQLEN) # [BLOCK_M] - # Update parent mapping for all tokens at once using vectorized operations - mask_retrieve = idx_tokens < seqlen - retrieve_next_token_base = ( - retrieve_next_token_ptr - + (idx_seq * stride_retrieve_next_token_seq) - + idx_tokens * stride_retrieve_next_token_token - ) - retrieve_next_tokens = tl.load(retrieve_next_token_base, mask_retrieve) - retrieve_next_sibling_base = ( - retrieve_next_sibling_ptr - + (idx_seq * stride_retrieve_next_sibling_seq) - + idx_tokens * stride_retrieve_next_sibling_token - ) - retrieve_next_siblings = tl.load(retrieve_next_sibling_base, mask_retrieve) - parent_idx_tokens = tl.zeros((NP2_SEQLEN,), dtype=tl.int32) - - w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] - mask_w = idx_feats < dim - if KERNEL_WIDTH >= 2: - w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor - w_col0 = tl.load(w_ptrs, mask_w, other=0.0) - w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor - w_col1 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 3: - w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor - w_col2 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 4: - w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor - w_col3 = tl.load(w_ptrs, mask_w, other=0.0) - - x_base_1d = x_base # starting of chunk [BLOCK_N] - mask_x_1d = idx_feats < dim - - # STEP 5: compute each token - for idx_token in tl.static_range(seqlen): - acc = acc_preload - - if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: - # set the parent index of the next token in the eagle tree - # next token's parent is the current token - retrieve_next_token_idx = tl.sum( - tl.where(idx_tokens == idx_token, retrieve_next_tokens, 0) - ) - if retrieve_next_token_idx != -1: # pad slot id - parent_idx_tokens = tl.where( - idx_tokens == retrieve_next_token_idx, - idx_token, - parent_idx_tokens, - ) - # next token's parent is the parent of the current token - retrieve_sibling_token_idx = tl.sum( - tl.where(idx_tokens == idx_token, retrieve_next_siblings, 0) - ) - if retrieve_sibling_token_idx != -1: # pad slot id - parent_idx_token = tl.sum( - tl.where(idx_tokens == idx_token, parent_idx_tokens, 0) - ) - parent_idx_tokens = tl.where( - idx_tokens == retrieve_sibling_token_idx, - parent_idx_token, - parent_idx_tokens, - ) - # tl.device_print("am", parent_idx_tokens) - - _idx_token = idx_token - x_ptrs_1d = x_base_1d + _idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - # convolution operation: itself * wcol[-1] + parent * wcol[-2] + grand-parent * wcol[-3] + ... - for j in tl.static_range(KERNEL_WIDTH): - if KERNEL_WIDTH == 2: - if j == 0: - matrix_w = w_col1 - else: - matrix_w = w_col0 - elif KERNEL_WIDTH == 3: - if j == 0: - matrix_w = w_col2 - elif j == 1: - matrix_w = w_col1 - else: - matrix_w = w_col0 - elif KERNEL_WIDTH == 4: - if j == 0: - matrix_w = w_col3 - elif j == 1: - matrix_w = w_col2 - elif j == 2: - matrix_w = w_col1 - else: - matrix_w = w_col0 - - if SAVE_INTERMEDIATE: - # Save the window state after consuming this token - # Layout: [seq(cache line), step, dim, win(K-1)] - base_ptr = ( - intermediate_conv_window_ptr - + intermediate_state_batch_coord * stride_inter_seq - + idx_token * stride_inter_step - + idx_feats * stride_inter_dim - ) - - # store itself in KERNEL_WIDTH-2 slot, parent in KERNEL_WIDTH-3 slot, grand-parent in KERNEL_WIDTH-4 slot, ... - if KERNEL_WIDTH - j - 2 >= 0: - tl.store( - base_ptr + (KERNEL_WIDTH - j - 2) * stride_inter_win, - matrix_x, - mask=mask_w, - ) - - acc += matrix_x * matrix_w - - # move to parent for next iteration - if _idx_token > 0: - _idx_token = tl.sum( - tl.where(idx_tokens == _idx_token, parent_idx_tokens, 0) - ) - x_ptrs_1d = x_base_1d + _idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - else: - # no parent within the current chunk, load from prev conv state: col[-1] (idx 0's parent), col[-2] (idx 0's grand parent), ... - if KERNEL_WIDTH == 2: - if _idx_token == 0: - matrix_x = col0 - elif KERNEL_WIDTH == 3: - if _idx_token == 0: - matrix_x = col1 - else: - matrix_x = col0 - elif KERNEL_WIDTH == 4: - if _idx_token == 0: - matrix_x = col2 - elif _idx_token == -1: - matrix_x = col1 - else: - matrix_x = col0 - _idx_token = _idx_token - 1 - else: - matrix_w = w_col0 - matrix_x = col0 - - for j in tl.static_range(KERNEL_WIDTH): - if KERNEL_WIDTH == 2: - if j == 1: # KERNEL_WIDTH-1: - matrix_w = w_col1 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - elif KERNEL_WIDTH == 3: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - elif KERNEL_WIDTH == 4: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - matrix_x = col2 - elif j == 3: - matrix_w = w_col3 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - - acc += matrix_x * matrix_w # [BLOCK_N] - - if KERNEL_WIDTH == 2: - col0 = matrix_x - elif KERNEL_WIDTH == 3: - col0 = col1 - col1 = matrix_x - elif KERNEL_WIDTH == 4: - col0 = col1 - col1 = col2 - col2 = matrix_x - - if SAVE_INTERMEDIATE: - # Save the window state after consuming this token - # Layout: [seq(cache line), step, dim, win(K-1)] - base_ptr = ( - intermediate_conv_window_ptr - + intermediate_state_batch_coord * stride_inter_seq - + idx_token * stride_inter_step - + idx_feats * stride_inter_dim - ) - if KERNEL_WIDTH >= 2: - tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w) - if KERNEL_WIDTH >= 3: - tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w) - if KERNEL_WIDTH >= 4: - tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w) - - if SILU_ACTIVATION: - acc = acc / (1 + tl.exp(-acc)) - mask_1d = (idx_token < seqlen) & ( - idx_feats < dim - ) # token-index # feature-index - o_ptrs = ( - o_ptr - + (idx_seq) * stride_o_seq - + idx_token * stride_o_token - + (idx_feats * stride_o_dim) - ) - - tl.store(o_ptrs, acc, mask=mask_1d) - - # fuse: store calculated retrieve_parent_token to tensor - if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: - tl.store( - retrieve_parent_token_ptr - + idx_seq * stride_retrieve_parent_token_seq - + idx_tokens * stride_retrieve_parent_token_token, - parent_idx_tokens, - mask=mask_retrieve, - ) - + tl.store(conv_state_ptrs + update_idx * conv_state_l_stride, x_val, mask=mask) + # No need to advance update_idx further since the sequence loop is removed + + x_val_f32 = x_val.to(tl.float32) + if WIDTH == 2: + x_val_1 = x_val_f32 + elif WIDTH == 3: + x_val_2 = x_val_f32 + elif WIDTH == 4: + x_val_3 = x_val_f32 + + out_val = bias_val + if WIDTH == 2: + out_val += w0 * x_val_0 + w1 * x_val_1 + elif WIDTH == 3: + out_val += w0 * x_val_0 + w1 * x_val_1 + w2 * x_val_2 + elif WIDTH == 4: + out_val += w0 * x_val_0 + w1 * x_val_1 + w2 * x_val_2 + w3 * x_val_3 + + if SILU_ACTIVATION: + out_val = out_val * tl.sigmoid(out_val) + + tl.store(out_ptrs, out_val.to(out_ptrs.dtype.element_ty), mask=mask) def causal_conv1d_update( x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - activation: Union[bool, str, None] = None, - cache_seqlens: Optional[torch.Tensor] = None, - conv_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, - intermediate_conv_window: Optional[torch.Tensor] = None, - intermediate_state_indices: Optional[torch.Tensor] = None, - retrieve_next_token: Optional[torch.Tensor] = None, - retrieve_next_sibling: Optional[torch.Tensor] = None, - retrieve_parent_token: Optional[torch.Tensor] = None, - pad_slot_id: int = PAD_SLOT_ID, - metadata=None, - validate_data=False, -): - """ - x: (batch, dim) or (batch, dim, seqlen) - [shape=2: single token prediction] - [shape=3: single or multiple tokens prediction] - conv_state: (..., dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state - starting at the index - @cache_seqlens % state_len. - conv_state_indices: (batch,), dtype int32 - If not None, the conv_state is a larger tensor along the batch dim, - and we are selecting the batch coords specified by conv_state_indices. - Useful for a continuous batching scenario. - pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] - in this case, the kernel will not process entries at - indices 0 and 3 - out: (batch, dim) or (batch, dim, seqlen) - """ - if validate_data: - assert cache_seqlens is None # not implemented yet - ok for vLLM - assert pad_slot_id is not None - assert x.stride(1) == 1 - if isinstance(activation, bool): - activation = "silu" if activation is True else None - elif activation is not None: - assert activation in ["silu", "swish"] - unsqueeze = x.dim() == 2 - if unsqueeze: - # make it (batch, dim, seqlen) with seqlen == 1 - x = x.unsqueeze(-1) - batch, dim, seqlen = x.shape - _, width = weight.shape - # conv_state: (..., dim, state_len), where state_len >= width - 1 - num_cache_lines, _, state_len = conv_state.size() - - if validate_data: - assert dim == weight.size(0) - assert ( - conv_state.stride(-2) == 1 - ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" - assert state_len >= width - 1 - # when above happens, we don't shift-left to keep any records in conv_state - assert dim == conv_state.size(1) - if conv_state_indices is None: - assert conv_state.size(0) >= batch - else: - assert (batch,) == conv_state_indices.shape - assert intermediate_state_indices is not None - assert (batch,) == intermediate_state_indices.shape - - assert num_cache_lines >= batch - assert weight.stride(1) == 1 # Need this - assert cache_seqlens is None # not needed for vLLM - circular buffer - - # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: out = torch.empty_like(x) - stride_w_dim, stride_w_width = weight.stride() - - stride_x_seq, stride_x_dim, stride_x_token = x.stride() # X (batch, dim, seqlen) - - stride_o_seq, stride_o_dim, stride_o_token = out.stride() - stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride() - stride_state_indices = ( - conv_state_indices.stride(0) if conv_state_indices is not None else 0 - ) - stride_intermediate_state_indices = ( - intermediate_state_indices.stride(0) - if intermediate_state_indices is not None - else 0 - ) - if num_accepted_tokens is not None: - state_len = width - 1 + (seqlen - 1) # effective state_len needed - else: - state_len = width - 1 - np2_statelen = triton.next_power_of_2(state_len) - np2_seqlen = triton.next_power_of_2(seqlen) + batch, dim = x.shape + state_len = conv_state.shape[-1] + width = weight.shape[-1] - def grid(META): - return ( - batch, - triton.cdiv(dim, META["BLOCK_N"]), - ) - - # prepare intermediate buffer strides if provided - if intermediate_conv_window is not None: - stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = ( - intermediate_conv_window.stride(0), - intermediate_conv_window.stride(1), - intermediate_conv_window.stride(2), - intermediate_conv_window.stride(3), - ) + if bias is not None: + bias_stride = bias.stride(0) + has_bias = True else: - stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0 + bias = x # Dummy pointer + bias_stride = 0 + has_bias = False - # prepare retrieve next token buffer strides if provided - if retrieve_next_token is not None: - stride_retrieve_next_token_seq, stride_retrieve_next_token_token = ( - retrieve_next_token.stride(0), - retrieve_next_token.stride(1), - ) + if cache_seqlens is not None: + is_circular = True else: - stride_retrieve_next_token_seq = stride_retrieve_next_token_token = 0 + cache_seqlens = x # Dummy pointer + is_circular = False - # prepare retrieve next sibling buffer strides if provided - if retrieve_next_sibling is not None: - stride_retrieve_next_sibling_seq, stride_retrieve_next_sibling_token = ( - retrieve_next_sibling.stride(0), - retrieve_next_sibling.stride(1), - ) + if conv_state_indices is not None: + has_state_indices = True else: - stride_retrieve_next_sibling_seq = stride_retrieve_next_sibling_token = 0 - - # prepare retrieve parent token buffer strides if provided - if retrieve_parent_token is not None: - stride_retrieve_parent_token_seq, stride_retrieve_parent_token_token = ( - retrieve_parent_token.stride(0), - retrieve_parent_token.stride(1), - ) - else: - stride_retrieve_parent_token_seq = stride_retrieve_parent_token_token = 0 - - _causal_conv1d_update_kernel[grid]( - # Pointers to matrices - x, - weight, - bias, - conv_state, - cache_seqlens, - conv_state_indices, - num_accepted_tokens, - intermediate_conv_window if intermediate_conv_window is not None else x, - intermediate_state_indices, - retrieve_next_token, - retrieve_next_sibling, - retrieve_parent_token, - out, - # Matrix dimensions - batch, - dim, - seqlen, - state_len, - num_cache_lines, - # stride - stride_x_seq, - stride_x_dim, - stride_x_token, - stride_w_dim, - stride_w_width, - stride_istate_seq, - stride_istate_dim, - stride_istate_token, - stride_state_indices, - stride_inter_seq, - stride_inter_step, - stride_inter_dim, - stride_inter_win, - stride_intermediate_state_indices, - stride_retrieve_next_token_seq, - stride_retrieve_next_token_token, - stride_retrieve_next_sibling_seq, - stride_retrieve_next_sibling_token, - stride_retrieve_parent_token_seq, - stride_retrieve_parent_token_token, - stride_o_seq, - stride_o_dim, - stride_o_token, - # others - pad_slot_id, - # META - HAS_BIAS=bias is not None, - KERNEL_WIDTH=width, - SILU_ACTIVATION=activation in ["silu", "swish"], - IS_CONTINUOUS_BATCHING=conv_state_indices is not None, - IS_SPEC_DECODING=num_accepted_tokens is not None, - NP2_STATELEN=np2_statelen, - NP2_SEQLEN=np2_seqlen, - USE_PAD_SLOT=pad_slot_id is not None, - BLOCK_N=256, - SAVE_INTERMEDIATE=intermediate_conv_window is not None, - HAS_EAGLE_TREE_CUSTOM_ATTN_MASK=retrieve_next_token is not None, + conv_state_indices = x # Dummy pointer + has_state_indices = False + + BLOCK_DIM = 64 + grid = (batch, triton.cdiv(dim, BLOCK_DIM)) + + causal_conv1d_update_kernel[grid]( + x, x.stride(0), x.stride(1), + conv_state, conv_state.stride(0), conv_state.stride(1), conv_state.stride(2), + weight, weight.stride(0), weight.stride(1), + bias, bias_stride, + out, out.stride(0), out.stride(1), + conv_state_indices, cache_seqlens, + batch, dim, state_len, + WIDTH=width, + BLOCK_DIM=BLOCK_DIM, + HAS_BIAS=has_bias, + IS_CIRCULAR=is_circular, + HAS_STATE_INDICES=has_state_indices, + SILU_ACTIVATION=silu_activation == "silu", ) - if unsqueeze: - out = out.squeeze(-1) + return out From 5915cc210f2c943396dcb24605d9857823a9dca6 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Wed, 25 Feb 2026 15:12:09 -0800 Subject: [PATCH 12/76] Mamba almost working Signed-off-by: Keshav Santhanam --- .../attention_context/mamba_metadata.py | 11 +- .../inference/contexts/dynamic_context.py | 102 ++-- .../core/inference/engines/dynamic_engine.py | 9 +- .../text_generation_controller.py | 26 +- megatron/core/models/mamba/mamba_model.py | 46 +- megatron/core/ssm/mamba_mixer.py | 182 +++---- megatron/core/ssm/ops/causal_conv1d_triton.py | 299 +++++++---- megatron/core/ssm/ops/mamba_ssm.py | 464 ++++++++---------- megatron/core/transformer/attention.py | 2 +- megatron/inference/utils.py | 3 +- 10 files changed, 620 insertions(+), 524 deletions(-) diff --git a/megatron/core/inference/contexts/attention_context/mamba_metadata.py b/megatron/core/inference/contexts/attention_context/mamba_metadata.py index fbc5d2145ac..e492a1b32e0 100644 --- a/megatron/core/inference/contexts/attention_context/mamba_metadata.py +++ b/megatron/core/inference/contexts/attention_context/mamba_metadata.py @@ -266,12 +266,17 @@ def update( self.cu_seqlens = self._cu_seqlens_buffer[: padded_prefill_count + 1] if padded_decode_count > 0 and padded_prefill_count > 0: - self._device_decode_prefill_buffer[0] = real_decode_count + self._device_decode_prefill_buffer[0] = cu_seqlens[real_decode_count] # This describes the number of items in the prefill tensor relative to the # decode tensor. If chunked prefill is present, it is included in the # "prefill" part of the main split. - self._device_decode_prefill_buffer[1] = regular_prefill_count + ( - 1 if has_chunked_prefill_req else 0 + self._device_decode_prefill_buffer[1] = ( + cu_seqlens[ + real_decode_count + + regular_prefill_count + + (1 if has_chunked_prefill_req else 0) + ] + - cu_seqlens[real_decode_count] ) self.device_decode_prefill = self._device_decode_prefill_buffer diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index ce66cd443c7..ba0d15ee513 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -348,6 +348,7 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC mamba_states_memory_per_request += math.prod(self.mamba_ssm_states_shape) mamba_states_memory_per_request *= self.num_mamba_layers mamba_states_memory_per_request *= dtype_size_bytes + mamba_states_memory_per_request *= self.num_speculative_tokens + 1 # Unified memory and general tensor management. self.unified_memory_level = inference_config.unified_memory_level @@ -612,8 +613,10 @@ def _allocate_mamba_states(self): self.mamba_metadata = MambaMetadata( max_requests=self.max_requests, max_tokens=self.max_tokens ) + expanded_conv_shape = list(self.mamba_conv_states_shape) + expanded_conv_shape[-1] += self.num_speculative_tokens self.mamba_conv_states = torch.empty( - (self.num_mamba_layers, self.max_requests) + self.mamba_conv_states_shape, + (self.num_mamba_layers, self.max_requests, *expanded_conv_shape), dtype=self.params_dtype, device=torch.cuda.current_device(), ) @@ -623,21 +626,11 @@ def _allocate_mamba_states(self): device=torch.cuda.current_device(), ) if self.num_speculative_tokens > 0: - self.mamba_intermediate_conv_states = torch.empty( - ( - self.num_mamba_layers, - self.max_requests, - self.num_speculative_tokens, - *self.mamba_conv_states_shape, - ), - dtype=self.params_dtype, - device=torch.cuda.current_device(), - ) self.mamba_intermediate_ssm_states = torch.empty( ( self.num_mamba_layers, self.max_requests, - self.num_speculative_tokens, + self.num_speculative_tokens + 1, *self.mamba_ssm_states_shape, ), dtype=self.params_dtype, @@ -657,12 +650,6 @@ def _allocate_mamba_states(self): self.mamba_ssm_states, device="cpu" ).pin_memory() if self.num_speculative_tokens > 0: - self._offloadable_tensor_names.add("mamba_intermediate_conv_states") - self._offloadable_cpu_backups["mamba_intermediate_conv_states"] = ( - torch.empty_like( - self.mamba_intermediate_conv_states, device="cpu" - ).pin_memory() - ) self._offloadable_tensor_names.add("mamba_intermediate_ssm_states") self._offloadable_cpu_backups["mamba_intermediate_ssm_states"] = ( torch.empty_like( @@ -973,7 +960,7 @@ def mamba_states_cache( mamba_layer_number = self.layer_map[layer_number - 1] if intermediate: - conv_state = self.mamba_intermediate_conv_states[mamba_layer_number] + conv_state = None ssm_state = self.mamba_intermediate_ssm_states[mamba_layer_number] else: conv_state = self.mamba_conv_states[mamba_layer_number] @@ -1346,7 +1333,7 @@ def initialize_attention_state( self.max_requests, self.round_up_tokens(self.active_token_count), ) - padded_decode_req_count = padded_token_count + padded_decode_req_count = padded_token_count // (self.num_speculative_tokens + 1) padded_prefill_req_count = 0 else: target_padding_req_count = min( @@ -1506,6 +1493,9 @@ def current_input_and_position_ids( (Tuple[Tensor, Tensor]) Flattened active input and position IDs. """ num_tokens = num_warmup_tokens or self.padded_active_token_count + assert num_tokens >= self.batch_dimensions.decode_req_count * ( + self.num_speculative_tokens + 1 + ) return ( self.token_to_input_ids[:num_tokens].unsqueeze(0), self.token_to_pos_ids[:num_tokens].unsqueeze(0), @@ -1813,55 +1803,57 @@ def resume_paused_requests( resume_request_count = 0 if self.paused_request_count > 0: active_block_count_avail = self.block_allocator.get_active_avail() - paused_block_counts = self.request_kv_block_counts[: self.paused_request_count] + paused_block_counts = self.request_kv_block_counts[: self.paused_request_count].clone() # Flip counts before cumsum, since paused requests are resumed from # the right-most index, so we must count resumed blocks starting from # the right side. paused_block_counts = paused_block_counts.flip(dims=[0]) - # Add +1 to all block counts, since any time a paused request is - # resumed, it will be starting a new memory block. For background, - # pausing happens after a request has generated the final token of a - # memory block (i.e., token 256 of that block), which means the very - # next token (whenever that request gets unpaused) will be in a new - # block. So, when we resume a paused request, we have to account for - # the fact that it will need an extra block beyond the ones that it - # has already used. - paused_block_counts += 1 # +1 for newly added block + + # Check which paused requests will actually need a new block upon resuming + offsets = self.request_last_kv_block_offset[: self.paused_request_count] + needs_new_block = ( + offsets >= self.block_size_tokens - 1 - self.num_speculative_tokens + ).to(paused_block_counts.dtype) + needs_new_block = needs_new_block.flip(dims=[0]) + + # Add +1 ONLY to the block counts of requests that finished their previous memory block + paused_block_counts += needs_new_block paused_block_counts_cumsum = paused_block_counts.cumsum(dim=0) resume_request_count = min( torch.nonzero(paused_block_counts_cumsum <= active_block_count_avail).numel(), self.block_allocator.total_avail, ) + # Constrain resumptions by the maximum allowed active requests + max_allowed_active = self.max_requests // (self.num_speculative_tokens + 1) + allowed_to_resume = max(0, max_allowed_active - active_request_count) + resume_request_count = min(resume_request_count, allowed_to_resume) + self.paused_request_count -= resume_request_count active_request_count += resume_request_count # Resume requests by assigning blocks and updating bookkeeping tensors. if resume_request_count > 0: - assert torch.all( - self.request_last_kv_block_offset[ - self.paused_request_count : (self.paused_request_count + resume_request_count) - ] - >= self.block_size_tokens - 1 - self.num_speculative_tokens - ), "The request_last_kv_block_offset should be greater than or equal to the block size tokens - 1 - num_speculative_tokens for the requests that just got resumed this step. (Currently its {self.request_last_kv_block_offset[self.paused_request_count : (self.paused_request_count + resume_request_count)]}), block size tokens: {self.block_size_tokens}, num_speculative_tokens: {self.num_speculative_tokens}" + resume_start = self.paused_request_count + resume_end = self.paused_request_count + resume_request_count - assert resume_request_count <= self.block_allocator.total_avail - block_ids = self.block_allocator.allocate_memory_blocks(resume_request_count) - row_idx = torch.arange( - self.paused_request_count, - self.paused_request_count + resume_request_count, - device=torch.cuda.current_device(), - ) - col_idx = self.request_kv_block_counts[ - self.paused_request_count : (self.paused_request_count + resume_request_count) - ] - self.request_to_kv_block_ids[row_idx, col_idx] = block_ids - self.request_kv_block_counts[ - self.paused_request_count : (self.paused_request_count + resume_request_count) - ] += 1 - self.request_last_kv_block_id[ - self.paused_request_count : (self.paused_request_count + resume_request_count) - ] = block_ids + # Check which resumed requests actually need a new block + offsets = self.request_last_kv_block_offset[resume_start:resume_end] + needs_new_block = offsets >= (self.block_size_tokens - 1 - self.num_speculative_tokens) + num_new_blocks = needs_new_block.sum().item() + + if num_new_blocks > 0: + assert num_new_blocks <= self.block_allocator.total_avail + block_ids = self.block_allocator.allocate_memory_blocks(num_new_blocks) + + # Apply updates only to the requests that required a new block + relative_row_idx = torch.nonzero(needs_new_block).squeeze(1) + row_idx = resume_start + relative_row_idx + col_idx = self.request_kv_block_counts[row_idx] + + self.request_to_kv_block_ids[row_idx, col_idx] = block_ids + self.request_kv_block_counts[row_idx] += 1 + self.request_last_kv_block_id[row_idx] = block_ids # Remove resumed requests from newly_paused_request_ids. We do this by # truncating the end of newly_paused_request_ids, which works because we @@ -2134,6 +2126,10 @@ def update_requests( active_requests_requiring_new_block[ self.get_index_of_chunked_prefill_request() - self.paused_request_count ] = 0 # chunked prefill should not be paused + elif active_request_count * (self.num_speculative_tokens + 1) > self.max_requests: + # Force-pause excess requests in a decode-only batch + max_allowed_active = self.max_requests // (self.num_speculative_tokens + 1) + active_requests_requiring_new_block[max_allowed_active:] = 1 active_requests_requiring_new_block_count = ( (active_requests_requiring_new_block == 1).sum().item() diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 852558dc396..9228f02b032 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -179,14 +179,15 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen if self.num_speculative_tokens > 0: assert ( - not self.context.materialize_only_last_token_logits + not inference_config.materialize_only_last_token_logits ), "Speculative decoding requires materialize_only_last_token_logits to be False" assert ( self.num_speculative_tokens <= self.controller.num_mtp_heads ), f"Number of speculative tokens {self.num_speculative_tokens} must be less than or equal to number of MTP heads {self.controller.num_mtp_heads}" - assert ( - not self.enable_chunked_prefill - ), "Chunked prefill is not supported with speculative tokens" + + # assert ( + # not self.enable_chunked_prefill + # ), "Chunked prefill is not supported with speculative tokens" # Initialize MTP sampling tensor now that num_speculative_tokens is set self.controller._init_mtp_sampling_tensor() diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index a0d22fc844d..6aaf8831eab 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -758,8 +758,19 @@ def _rewind_kv_cache(self): # Mamba speculative rewind state update if context.is_hybrid_model: - # TODO(ksanthanam): Maybe reset interemdiate states - pass + active_mamba_indices = context.mamba_metadata.request_to_mamba_state_idx[ + active_request_slice + ] + is_decode_mask = context.request_in_prefill_status_tensor[active_request_slice] == 0 + decode_mamba_indices = active_mamba_indices[is_decode_mask] + accepted_tokens_per_decode_request = accepted_tokens_per_request[is_decode_mask] + + if decode_mamba_indices.numel() > 0: + context.mamba_ssm_states[:, decode_mamba_indices] = ( + context.mamba_intermediate_ssm_states[ + :, decode_mamba_indices, accepted_tokens_per_decode_request + ] + ) def _dynamic_step_sample_logits_and_verify_tokens( self, logits: Tensor, mtp_logits: Tensor, input_ids: Tensor @@ -801,7 +812,12 @@ def _dynamic_step_sample_logits_and_verify_tokens( assert ( len(required_logit_indices) == num_decode_requests * (self.num_speculative_tokens + 1) + num_prefill_requests - ), f"Expected length of required_logit_indices to be num_decode_requests * (self.num_speculative_tokens + 1) + num_prefill_requests, but got {len(required_logit_indices)} for num_decode_requests {num_decode_requests} and num_prefill_requests {num_prefill_requests}" + ), ( + f"Expected length of required_logit_indices to be " + f"num_decode_requests * (self.num_speculative_tokens + 1) + num_prefill_requests, " + f"but got {len(required_logit_indices)} for num_decode_requests {num_decode_requests} " + f"and num_prefill_requests {num_prefill_requests}" + ) required_logits = logits.squeeze(0)[required_logit_indices, :] # Shape [1, 11, vocab_size] required_mtp_logits = mtp_logits[ @@ -863,7 +879,7 @@ def _dynamic_step_sample_logits_and_verify_tokens( torch.isin(token_to_request_index, request_indices_tensor) )[0] # TODO : Can maybe club the following two and then split later ? - # TODO : Can directly initzlie output tokens as a tensor and put the logits in the right place + # TODO : Can directly initialize output tokens as a tensor and put the logits in the right place output_tokens_jumbled_list.append( self._torch_sampling_func(required_logits[required_indices, :], temp, top_k, top_p) ) @@ -954,7 +970,7 @@ def _dynamic_step_sample_logits_and_verify_tokens( # Accepted tokens [ [a6s -1] | [b4s b5s] | [-1 -1] ] # Only handle decod requests, (Prefill already defaults to -1s) # Accepted token counts [ 1 | 2 | 0 ] # Prefill defaults to 0 - # This part tis to extract the accepted tokens + # This part is to extract the accepted tokens input_tokens_required[accepted_tokens_mask == 0] = -1 # Masks out non accepted tokens input_tokens_decode_mode = input_tokens_required[ : num_decode_requests * (self.num_speculative_tokens + 1) diff --git a/megatron/core/models/mamba/mamba_model.py b/megatron/core/models/mamba/mamba_model.py index 8dd614fdaaa..c1cd3113c47 100644 --- a/megatron/core/models/mamba/mamba_model.py +++ b/megatron/core/models/mamba/mamba_model.py @@ -2,6 +2,7 @@ from typing import Literal, Optional +import torch from torch import Tensor from megatron.core import tensor_parallel @@ -323,19 +324,38 @@ def forward( return hidden_states if self.config.mtp_num_layers is not None: - hidden_states = process_mtp_loss( - hidden_states=hidden_states, - labels=labels, - loss_mask=loss_mask, - output_layer=self.output_layer, - output_weight=output_weight, - runtime_gather_output=runtime_gather_output, - is_training=self.training, - compute_language_model_loss=self.compute_language_model_loss, - config=self.config, - cp_group=self.pg_collection.cp, - packed_seq_params=packed_seq_params, - ) + # The new process_mtp_loss function doesn't handle mtp_logits_cache, + # so we manually generate and cache MTP logits when in inference mode. + if in_inference_mode: + hidden_states_list = torch.chunk( + hidden_states, 1 + self.config.mtp_num_layers, dim=0 + ) + hidden_states = hidden_states_list[0] + self._mtp_logits_cache = None + mtp_inference_logits = [] + for mtp_layer_number in range(self.config.mtp_num_layers): + mtp_logits, _ = self.output_layer( + hidden_states_list[mtp_layer_number + 1], + weight=output_weight, + runtime_gather_output=runtime_gather_output, + ) + # mtp logits shape [b, 1, vocab size] + mtp_inference_logits.append(mtp_logits.squeeze(1).unsqueeze(0)) + self._mtp_logits_cache = torch.cat(mtp_inference_logits, dim=0) + else: + hidden_states = process_mtp_loss( + hidden_states=hidden_states, + labels=labels, + loss_mask=loss_mask, + output_layer=self.output_layer, + output_weight=output_weight, + runtime_gather_output=runtime_gather_output, + is_training=self.training, + compute_language_model_loss=self.compute_language_model_loss, + config=self.config, + cp_group=self.pg_collection.cp, + packed_seq_params=packed_seq_params, + ) sequence_parallel_override = False if in_inference_mode and inference_context.config.materialize_only_last_token_logits: diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 4dedbd6d9cd..745f36bcc3d 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -44,7 +44,7 @@ from .mamba_context_parallel import MambaContextParallel try: - #from mamba_ssm.ops.triton.selective_state_update import selective_state_update + # from mamba_ssm.ops.triton.selective_state_update import selective_state_update from megatron.core.ssm.ops.mamba_ssm import selective_state_update except ImportError: selective_state_update = None @@ -53,6 +53,7 @@ # from causal_conv1d import causal_conv1d_fn, causal_conv1d_update from causal_conv1d import causal_conv1d_fn from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states + from megatron.core.ssm.ops.causal_conv1d_triton import causal_conv1d_update except ImportError: causal_conv1d_fn = None @@ -430,8 +431,17 @@ def _dynamic_inference(self, hidden_states: torch.Tensor, context: DynamicInfere ) assert sequence_packing_available, reason_for_no_sequence_packing + # Grab standard states conv_state, ssm_state = context.mamba_states_cache(self.layer_number - self.pp_layer_offset) + # Only fetch intermediate SSM state for speculative decoding + int_ssm_state = None + if context.num_speculative_tokens > 0: + # We ignore the conv intermediate state since we use the expanded circular buffer + _, int_ssm_state = context.mamba_states_cache( + self.layer_number - self.pp_layer_offset, intermediate=True + ) + padded_dims = context.padded_batch_dimensions token_count = padded_dims.token_count decode_req_count = padded_dims.decode_req_count @@ -442,36 +452,37 @@ def _dynamic_inference(self, hidden_states: torch.Tensor, context: DynamicInfere y_decode = None y_prefill = None - - """ - if self.layer_number == 1: - torch.distributed.breakpoint(0) - """ # Decode if decode_req_count > 0: # For mixed batch, the decode tokens are at the start of zxBCdt - zxBCdt_decode = zxBCdt[:decode_req_count] if prefill_req_count > 0 else zxBCdt + seq_len = 1 + context.num_speculative_tokens + decode_token_count = decode_req_count * seq_len - if context.num_speculative_tokens > 0: - num_accepted_tokens = context.mamba_metadata.num_accepted_tokens - intermediate_conv_state, intermediate_ssm_state = context.mamba_states_cache( - self.layer_number - self.pp_layer_offset, intermediate=True - ) - else: - num_accepted_tokens = None - intermediate_conv_state = None - intermediate_ssm_state = None + zxBCdt_decode = zxBCdt[:decode_token_count] if prefill_req_count > 0 else zxBCdt + + # Reshape from [N*S, 1, d] to [N, S, d] for the 3D Triton kernels + # if self.layer_number == 1: + # torch.distributed.breakpoint(0) + zxBCdt_decode = zxBCdt_decode.squeeze(1).view(decode_req_count, seq_len, -1) + + # Get sequence lengths for the circular buffer calculation + req_start = context.paused_request_count + cache_seqlens = context.request_kv_length_offsets[ + req_start : req_start + decode_req_count + ] y_decode = self._ssm_decode( - zxBCdt_decode.transpose(0, 1), + zxBCdt_decode, conv_state, ssm_state, - context.mamba_metadata.batch_indices_decode, - num_accepted_tokens=num_accepted_tokens, - intermediate_conv_state=intermediate_conv_state, - intermediate_ssm_state=intermediate_ssm_state, - ).transpose(0, 1) + batch_indices=context.mamba_metadata.batch_indices_decode, + intermediate_ssm_state=int_ssm_state, + cache_seqlens=cache_seqlens, + ) + + # Flatten back to [N*S, 1, d] to match merge logic + y_decode = y_decode.view(decode_token_count, 1, -1) # Prefill if prefill_req_count > 0: @@ -755,12 +766,26 @@ def _ssm_prefill( # Compute short convolution initial_conv_state = None if conv_state is not None and is_dynamic_batching: - # xBC should have shape (b l d) for causal_conv1d_varlen_states - assert batch_indices is not None + # Extract linear states (newest token is at state_len - 1) + state_len = conv_state.shape[-1] conv_varlen_states = causal_conv1d_varlen_states( - xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1] + xBC.squeeze(0), cu_seqlens, state_len=state_len ) - tensor_masked_update(conv_state, batch_indices, conv_varlen_states) + + # Roll into circular buffer layout expected by decode + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + # Shift required to move token from `state_len - 1` to `(seqlen - 1) % state_len` + shifts = (seqlens % state_len).view(-1, 1, 1) + + B, D, W = conv_varlen_states.shape + base_idx = torch.arange(W, device=conv_state.device).view(1, 1, W) + gather_idx = (base_idx - shifts) % W + gather_idx = gather_idx.expand(B, D, W) + + conv_varlen_states_circular = torch.gather(conv_varlen_states, dim=2, index=gather_idx) + + # Update state + tensor_masked_update(conv_state, batch_indices, conv_varlen_states_circular) # Maintain channels-last memory layout to use seq_idx for causal_conv1d_fn # See https://github.com/Dao-AILab/causal-conv1d/blob/69e6dadc28b169a4c49cb86b586f64ee90242c70/csrc/causal_conv1d.cpp#L174 # pylint: disable=line-too-long @@ -769,12 +794,16 @@ def _ssm_prefill( # Maintain channels-last memory layout to use initial_states for causal_conv1d_fn # See https://github.com/Dao-AILab/causal-conv1d/blob/69e6dadc28b169a4c49cb86b586f64ee90242c70/csrc/causal_conv1d.cpp#L200 # pylint: disable=line-too-long assert batch_indices is not None + state_len = conv_state.shape[-1] initial_conv_state = ( - conv_state[batch_indices, :, 1:].permute(0, 2, 1).contiguous().transpose(1, 2) + conv_state[batch_indices, :, -self.d_conv + 1 :] + .permute(0, 2, 1) + .contiguous() + .transpose(1, 2) ) xBC = xBC.transpose(1, 2) tensor_masked_update( - conv_state, batch_indices, F.pad(xBC, (self.d_conv - xBC.shape[-1], 0)) + conv_state, batch_indices, F.pad(xBC, (state_len - xBC.shape[-1], 0)) ) else: # transpose: b l pd --> b pd l @@ -892,39 +921,26 @@ def _ssm_decode( conv_state: torch.Tensor, ssm_state: torch.Tensor, batch_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, - intermediate_conv_state: Optional[torch.Tensor] = None, intermediate_ssm_state: Optional[torch.Tensor] = None, + cache_seqlens: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Performs SSM computation for inference decode step. Args: - zxBCdt: The input tensor of shape (l, b, d), which is a concatenation of - z, x, B, C, and dt projections. For decoding, l must be 1. + zxBCdt: The input tensor of shape (b, s, d), which is a concatenation of + z, x, B, C, and dt projections. s is the sequence length (1 + num_speculative_tokens). conv_state: The convolution state tensor for inference. ssm_state: The selective scan state tensor for inference. - batch_indices: A map from batch id to position in the Mamba state tensors for - dynamic inference. + batch_indices: A map from batch id to position in the Mamba state tensors. + intermediate_ssm_state: Optional buffer for storing sequence steps in SSM state. + cache_seqlens: Optional tensor representing cache sequence length for circular buffering. Returns: - The output tensor of shape (l, b, d). + The output tensor of shape (b, s, d). """ - seq_len, batch_size, _ = zxBCdt.shape + batch_size, seq_len, _ = zxBCdt.shape dtype = zxBCdt.dtype - if seq_len > 1: - assert ( - num_accepted_tokens is not None - and intermediate_conv_state is not None - and intermediate_ssm_state is not None - ), "Decoding with > 1 token per request requires speculative decoding state" - is_speculative_decoding = True - else: - is_speculative_decoding = False - - if not is_speculative_decoding: - # Remove sequence dimension - zxBCdt = zxBCdt.squeeze(0) z, xBC, dt = torch.split( zxBCdt, @@ -938,26 +954,26 @@ def _ssm_decode( # Conv step if causal_conv1d_update is None: + assert seq_len == 1, "Native PyTorch fallback only supports 1 token at a time" + xBC_squeeze = xBC.squeeze(1) conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = xBC - xBC = torch.sum( + conv_state[:, :, -1] = xBC_squeeze + xBC_squeeze = torch.sum( conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1 ) # (B D) if self.conv1d.bias is not None: - xBC = xBC + self.conv1d.bias - xBC = self.act(xBC).to(dtype=xBC.dtype) + xBC_squeeze = xBC_squeeze + self.conv1d.bias + xBC = self.act(xBC_squeeze).to(dtype=xBC.dtype).unsqueeze(1) else: + # We completely omit the intermediate_conv_states parameter here xBC = causal_conv1d_update( xBC, conv_state, rearrange(self.conv1d.weight, "d 1 w -> d w"), self.conv1d.bias, self.activation, + cache_seqlens=cache_seqlens, conv_state_indices=batch_indices, - num_accepted_tokens=num_accepted_tokens, - intermediate_conv_window=intermediate_conv_state, - intermediate_state_indices=batch_indices, - pad_slot_id=-1, ) x, B, C = torch.split( @@ -973,7 +989,15 @@ def _ssm_decode( # SSM step if selective_state_update is None: - assert not is_speculative_decode + assert seq_len == 1, "Native PyTorch fallback only supports 1 token at a time" + + x = x.squeeze(1) + B = B.squeeze(1) + C = C.squeeze(1) + dt = dt.squeeze(1) + if z is not None: + z = z.squeeze(1) + if self.ngroups_local_tp > 1: B = rearrange(B, "b (g n) -> b g n", n=self.d_state) C = rearrange(C, "b (g n) -> b g n", n=self.d_state) @@ -1018,32 +1042,27 @@ def _ssm_decode( y = rearrange(y, "b h p -> b (h p)") if not self.rmsnorm: y = y * self.act(z) # (B D) + + y = y.unsqueeze(1) # Restore seq dimension else: A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32) + + # Incorporate sequence dimension in einops rearrengements + dt = repeat(dt, "b s h -> b s h p", p=self.headdim) dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim) D = repeat(self.D, "h -> h p", p=self.headdim) - if is_speculative_decoding: - dt = repeat(dt, "s b h -> s b h p", p=self.headdim) - B = rearrange(B, "s b (g n) -> s b g n", g=self.ngroups_local_tp) - C = rearrange(C, "s b (g n) -> s b g n", g=self.ngroups_local_tp) - x_reshaped = rearrange(x, "s b (h p) -> s b h p", p=self.headdim) - if not self.rmsnorm: - z = rearrange(z, "s b (h p) -> s b h p", p=self.headdim) - else: - dt = repeat(dt, "b h -> b h p", p=self.headdim) - B = rearrange(B, "b (g n) -> b g n", g=self.ngroups_local_tp) - C = rearrange(C, "b (g n) -> b g n", g=self.ngroups_local_tp) - x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim) - if not self.rmsnorm: - z = rearrange(z, "b (h p) -> b h p", p=self.headdim) + B = rearrange(B, "b s (g n) -> b s g n", g=self.ngroups_local_tp) + C = rearrange(C, "b s (g n) -> b s g n", g=self.ngroups_local_tp) + x_reshaped = rearrange(x, "b s (h p) -> b s h p", p=self.headdim) + if not self.rmsnorm: + z = rearrange(z, "b s (h p) -> b s h p", p=self.headdim) # Upcast the batch_indices to prevent integer overflow errors in the case of # large max request counts. if batch_indices is not None: batch_indices = batch_indices.to(torch.int64) - y = torch.empty_like(x_reshaped) - selective_state_update( + y = selective_state_update( ssm_state, x_reshaped, dt, @@ -1055,23 +1074,14 @@ def _ssm_decode( dt_bias=dt_bias, dt_softplus=True, state_batch_indices=batch_indices, - pad_slot_id=-1, - out=y, - #disable_state_update=True, - #intermediate_states_buffer=intermediate_ssm_state, - #cache_steps=seq_len, - #intermediate_state_indices=batch_indices, + intermediate_ssm_states=intermediate_ssm_state, # SSM only ) - if is_speculative_decoding: - y = rearrange(y, "s b h p -> s b (h p)") - else: - y = rearrange(y, "b h p -> b (h p)") + y = rearrange(y, "b s h p -> b s (h p)") if self.rmsnorm: y = self.norm(y, z) - # Restore sequence dimension - return y.unsqueeze(0) + return y def mamba_state_shapes_per_request(self) -> Tuple[Tuple[int], Tuple[int]]: """Returns the Mamba conv and ssm states shapes per request.""" diff --git a/megatron/core/ssm/ops/causal_conv1d_triton.py b/megatron/core/ssm/ops/causal_conv1d_triton.py index fc27a8df500..1e399fd9650 100644 --- a/megatron/core/ssm/ops/causal_conv1d_triton.py +++ b/megatron/core/ssm/ops/causal_conv1d_triton.py @@ -2,44 +2,75 @@ import triton import triton.language as tl + @triton.jit def causal_conv1d_update_kernel( -x_ptr, x_b_stride, x_c_stride, - conv_state_ptr, conv_state_b_stride, conv_state_c_stride, conv_state_l_stride, - weight_ptr, weight_c_stride, weight_width_stride, - bias_ptr, bias_stride, - out_ptr, out_b_stride, out_c_stride, + x_ptr, + x_b_stride, + x_s_stride, + x_c_stride, + conv_state_ptr, + conv_state_b_stride, + conv_state_c_stride, + conv_state_l_stride, + int_state_ptr, + int_state_b_stride, + int_state_s_stride, + int_state_c_stride, + int_state_l_stride, + weight_ptr, + weight_c_stride, + weight_width_stride, + bias_ptr, + bias_stride, + out_ptr, + out_b_stride, + out_s_stride, + out_c_stride, conv_state_indices_ptr, cache_seqlens_ptr, - batch, dim, state_len, + batch, + seq_len, + dim, + state_len, WIDTH: tl.constexpr, BLOCK_DIM: tl.constexpr, HAS_BIAS: tl.constexpr, IS_CIRCULAR: tl.constexpr, HAS_STATE_INDICES: tl.constexpr, + HAS_INT_STATE: tl.constexpr, SILU_ACTIVATION: tl.constexpr, ): batch_id = tl.program_id(0) channel_block_id = tl.program_id(1) - + channel_offsets = channel_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM) mask = channel_offsets < dim - + # State batch coordinate mapping if HAS_STATE_INDICES: state_batch_coord = tl.load(conv_state_indices_ptr + batch_id) else: state_batch_coord = batch_id - - # Pointers - x_ptrs = x_ptr + batch_id * x_b_stride + channel_offsets * x_c_stride - out_ptrs = out_ptr + batch_id * out_b_stride + channel_offsets * out_c_stride - conv_state_ptrs = conv_state_ptr + state_batch_coord * conv_state_b_stride + channel_offsets * conv_state_c_stride + + # Base Pointers + conv_state_ptrs = ( + conv_state_ptr + + state_batch_coord * conv_state_b_stride + + channel_offsets * conv_state_c_stride + ) weight_ptrs = weight_ptr + channel_offsets * weight_c_stride - + # Skip padding tokens (block-level uniform condition) if state_batch_coord < 0: - tl.store(out_ptrs, 0.0, mask=mask) + for s in range(seq_len): + out_ptrs = ( + out_ptr + + batch_id * out_b_stride + + s * out_s_stride + + channel_offsets * out_c_stride + ) + tl.store(out_ptrs, 0.0, mask=mask) return # Load Bias @@ -47,7 +78,7 @@ def causal_conv1d_update_kernel( bias_val = tl.load(bias_ptr + channel_offsets * bias_stride, mask=mask).to(tl.float32) else: bias_val = tl.zeros([BLOCK_DIM], dtype=tl.float32) - + # Load Weights if WIDTH == 2: w0 = tl.load(weight_ptrs + 0 * weight_width_stride, mask=mask).to(tl.float32) @@ -67,70 +98,110 @@ def causal_conv1d_update_kernel( x_val_1 = tl.zeros([BLOCK_DIM], dtype=tl.float32) x_val_2 = tl.zeros([BLOCK_DIM], dtype=tl.float32) x_val_3 = tl.zeros([BLOCK_DIM], dtype=tl.float32) - - if not IS_CIRCULAR: - # Load the last (WIDTH - 1) elements to use them BEFORE they are overwritten by the shift - if WIDTH >= 2: - x_val_0 = tl.load(conv_state_ptrs + (state_len - WIDTH + 1) * conv_state_l_stride, mask=mask).to(tl.float32) - if WIDTH >= 3: - x_val_1 = tl.load(conv_state_ptrs + (state_len - WIDTH + 2) * conv_state_l_stride, mask=mask).to(tl.float32) - if WIDTH >= 4: - x_val_2 = tl.load(conv_state_ptrs + (state_len - WIDTH + 3) * conv_state_l_stride, mask=mask).to(tl.float32) - - # Shift the linear state buffer left by 1 since advance_len is exactly 1 - i = 0 - while i < state_len - 1: - val = tl.load(conv_state_ptrs + (i + 1) * conv_state_l_stride, mask=mask) - tl.store(conv_state_ptrs + i * conv_state_l_stride, val, mask=mask) - i += 1 - - else: - cache_seqlen = tl.load(cache_seqlens_ptr + batch_id) % state_len - update_idx = cache_seqlen - (WIDTH - 1) - update_idx = tl.where(update_idx < 0, update_idx + state_len, update_idx) - - if WIDTH >= 2: - state_val = tl.load(conv_state_ptrs + update_idx * conv_state_l_stride, mask=mask) - x_val_0 = state_val.to(tl.float32) - update_idx = tl.where(update_idx + 1 >= state_len, update_idx + 1 - state_len, update_idx + 1) - if WIDTH >= 3: - state_val = tl.load(conv_state_ptrs + update_idx * conv_state_l_stride, mask=mask) - x_val_1 = state_val.to(tl.float32) - update_idx = tl.where(update_idx + 1 >= state_len, update_idx + 1 - state_len, update_idx + 1) - if WIDTH >= 4: - state_val = tl.load(conv_state_ptrs + update_idx * conv_state_l_stride, mask=mask) - x_val_2 = state_val.to(tl.float32) - update_idx = tl.where(update_idx + 1 >= state_len, update_idx + 1 - state_len, update_idx + 1) - - # Process the single token per request - x_val = tl.load(x_ptrs, mask=mask) - - if not IS_CIRCULAR: - tl.store(conv_state_ptrs + (state_len - 1) * conv_state_l_stride, x_val, mask=mask) - else: - tl.store(conv_state_ptrs + update_idx * conv_state_l_stride, x_val, mask=mask) - # No need to advance update_idx further since the sequence loop is removed - - x_val_f32 = x_val.to(tl.float32) - if WIDTH == 2: - x_val_1 = x_val_f32 - elif WIDTH == 3: - x_val_2 = x_val_f32 - elif WIDTH == 4: - x_val_3 = x_val_f32 - out_val = bias_val - if WIDTH == 2: - out_val += w0 * x_val_0 + w1 * x_val_1 - elif WIDTH == 3: - out_val += w0 * x_val_0 + w1 * x_val_1 + w2 * x_val_2 - elif WIDTH == 4: - out_val += w0 * x_val_0 + w1 * x_val_1 + w2 * x_val_2 + w3 * x_val_3 - - if SILU_ACTIVATION: - out_val = out_val * tl.sigmoid(out_val) - - tl.store(out_ptrs, out_val.to(out_ptrs.dtype.element_ty), mask=mask) + # If circular, we only need to read the base cache sequence length once + if IS_CIRCULAR: + base_cache_seqlen = tl.load(cache_seqlens_ptr + batch_id) + + # Loop over the sequence dimension (e.g., speculative tokens) + for s in range(seq_len): + x_ptrs = x_ptr + batch_id * x_b_stride + s * x_s_stride + channel_offsets * x_c_stride + out_ptrs = ( + out_ptr + batch_id * out_b_stride + s * out_s_stride + channel_offsets * out_c_stride + ) + + if not IS_CIRCULAR: + # Load the last (WIDTH - 1) elements to use them BEFORE they are overwritten by the shift + if WIDTH >= 2: + x_val_0 = tl.load( + conv_state_ptrs + (state_len - WIDTH + 1) * conv_state_l_stride, mask=mask + ).to(tl.float32) + if WIDTH >= 3: + x_val_1 = tl.load( + conv_state_ptrs + (state_len - WIDTH + 2) * conv_state_l_stride, mask=mask + ).to(tl.float32) + if WIDTH >= 4: + x_val_2 = tl.load( + conv_state_ptrs + (state_len - WIDTH + 3) * conv_state_l_stride, mask=mask + ).to(tl.float32) + + # Shift the linear state buffer left by 1 + i = 0 + while i < state_len - 1: + val = tl.load(conv_state_ptrs + (i + 1) * conv_state_l_stride, mask=mask) + tl.store(conv_state_ptrs + i * conv_state_l_stride, val, mask=mask) + i += 1 + else: + cache_seqlen = base_cache_seqlen + s + update_idx = cache_seqlen % state_len + read_idx = update_idx - (WIDTH - 1) + read_idx = tl.where(read_idx < 0, read_idx + state_len, read_idx) + + if WIDTH >= 2: + state_val = tl.load(conv_state_ptrs + read_idx * conv_state_l_stride, mask=mask) + x_val_0 = state_val.to(tl.float32) + read_idx = tl.where( + read_idx + 1 >= state_len, read_idx + 1 - state_len, read_idx + 1 + ) + if WIDTH >= 3: + state_val = tl.load(conv_state_ptrs + read_idx * conv_state_l_stride, mask=mask) + x_val_1 = state_val.to(tl.float32) + read_idx = tl.where( + read_idx + 1 >= state_len, read_idx + 1 - state_len, read_idx + 1 + ) + if WIDTH >= 4: + state_val = tl.load(conv_state_ptrs + read_idx * conv_state_l_stride, mask=mask) + x_val_2 = state_val.to(tl.float32) + + # Process the single token for the current sequence step + x_val = tl.load(x_ptrs, mask=mask) + + # Store the new token in the state buffer + if not IS_CIRCULAR: + tl.store(conv_state_ptrs + (state_len - 1) * conv_state_l_stride, x_val, mask=mask) + else: + cache_seqlen = base_cache_seqlen + s + update_idx = cache_seqlen % state_len + tl.store(conv_state_ptrs + update_idx * conv_state_l_stride, x_val, mask=mask) + + # Write out to the intermediate state buffer if requested + if HAS_INT_STATE: + i = 0 + while i < state_len: + val = tl.load(conv_state_ptrs + i * conv_state_l_stride, mask=mask) + int_ptr = ( + int_state_ptr + + state_batch_coord * int_state_b_stride + + s * int_state_s_stride + + channel_offsets * int_state_c_stride + + i * int_state_l_stride + ) + tl.store(int_ptr, val, mask=mask) + i += 1 + + # Advance registers for calculation + x_val_f32 = x_val.to(tl.float32) + if WIDTH == 2: + x_val_1 = x_val_f32 + elif WIDTH == 3: + x_val_2 = x_val_f32 + elif WIDTH == 4: + x_val_3 = x_val_f32 + + # Compute output + out_val = bias_val + if WIDTH == 2: + out_val += w0 * x_val_0 + w1 * x_val_1 + elif WIDTH == 3: + out_val += w0 * x_val_0 + w1 * x_val_1 + w2 * x_val_2 + elif WIDTH == 4: + out_val += w0 * x_val_0 + w1 * x_val_1 + w2 * x_val_2 + w3 * x_val_3 + + if SILU_ACTIVATION: + out_val = out_val * tl.sigmoid(out_val) + + tl.store(out_ptrs, out_val.to(out_ptrs.dtype.element_ty), mask=mask) + def causal_conv1d_update( x: torch.Tensor, @@ -140,9 +211,16 @@ def causal_conv1d_update( silu_activation: bool, cache_seqlens: torch.Tensor | None, conv_state_indices: torch.Tensor | None, + intermediate_conv_states: torch.Tensor | None = None, ) -> torch.Tensor: + + # Check if input is 2D, temporarily treat as 3D for uniform processing + is_2d = x.dim() == 2 + if is_2d: + x = x.unsqueeze(1) + + batch, seq_len, dim = x.shape out = torch.empty_like(x) - batch, dim = x.shape state_len = conv_state.shape[-1] width = weight.shape[-1] @@ -150,39 +228,80 @@ def causal_conv1d_update( bias_stride = bias.stride(0) has_bias = True else: - bias = x # Dummy pointer + bias = x # Dummy pointer bias_stride = 0 has_bias = False if cache_seqlens is not None: is_circular = True else: - cache_seqlens = x # Dummy pointer + cache_seqlens = x # Dummy pointer is_circular = False if conv_state_indices is not None: has_state_indices = True else: - conv_state_indices = x # Dummy pointer + conv_state_indices = x # Dummy pointer has_state_indices = False + # Extract intermediate state strides if provided + if intermediate_conv_states is not None: + has_int_state = True + int_state_ptr = intermediate_conv_states + int_state_b_stride = intermediate_conv_states.stride(0) + int_state_s_stride = intermediate_conv_states.stride(1) + int_state_c_stride = intermediate_conv_states.stride(2) + int_state_l_stride = intermediate_conv_states.stride(3) + else: + has_int_state = False + int_state_ptr = x # Dummy pointer + int_state_b_stride = 0 + int_state_s_stride = 0 + int_state_c_stride = 0 + int_state_l_stride = 0 + BLOCK_DIM = 64 grid = (batch, triton.cdiv(dim, BLOCK_DIM)) causal_conv1d_update_kernel[grid]( - x, x.stride(0), x.stride(1), - conv_state, conv_state.stride(0), conv_state.stride(1), conv_state.stride(2), - weight, weight.stride(0), weight.stride(1), - bias, bias_stride, - out, out.stride(0), out.stride(1), - conv_state_indices, cache_seqlens, - batch, dim, state_len, + x, + x.stride(0), + x.stride(1), + x.stride(2), + conv_state, + conv_state.stride(0), + conv_state.stride(1), + conv_state.stride(2), + int_state_ptr, + int_state_b_stride, + int_state_s_stride, + int_state_c_stride, + int_state_l_stride, + weight, + weight.stride(0), + weight.stride(1), + bias, + bias_stride, + out, + out.stride(0), + out.stride(1), + out.stride(2), + conv_state_indices, + cache_seqlens, + batch, + seq_len, + dim, + state_len, WIDTH=width, BLOCK_DIM=BLOCK_DIM, HAS_BIAS=has_bias, IS_CIRCULAR=is_circular, HAS_STATE_INDICES=has_state_indices, + HAS_INT_STATE=has_int_state, SILU_ACTIVATION=silu_activation == "silu", ) + if is_2d: + out = out.squeeze(1) + return out diff --git a/megatron/core/ssm/ops/mamba_ssm.py b/megatron/core/ssm/ops/mamba_ssm.py index f238d51b47e..e3d3dde0a01 100644 --- a/megatron/core/ssm/ops/mamba_ssm.py +++ b/megatron/core/ssm/ops/mamba_ssm.py @@ -1,70 +1,24 @@ -# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/mamba_ssm.py - -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - # Copyright (c) 2024, Tri Dao, Albert Gu. -# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py + +import math import torch +import torch.nn.functional as F import triton import triton.language as tl -from packaging import version - -PAD_SLOT_ID = -1 - -TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") - -if TRITON3: - - @triton.jit - def softplus(dt): - dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) - return dt - -else: - - @triton.jit - def softplus(dt): - dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) - return dt +from einops import rearrange, repeat +from mamba_ssm.ops.triton.softplus import softplus @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) @triton.heuristics( - { - "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"] - is not None - } -) -@triton.heuristics( - {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])} -) -@triton.heuristics( - { - "CACHE_INTERMEDIATE_STATES": lambda args: args["intermediate_states_buffer"] - is not None - } -) -@triton.heuristics( - { - "HAS_EAGLE_TREE_CUSTOM_ATTN_MASK": lambda args: args[ - "retrieve_parent_token_ptr" - ] - is not None - } + {"HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"] is not None} ) -@triton.heuristics( - { - "HAS_INTERMEDIATE_STATE_INDICES": lambda args: args[ - "intermediate_state_indices_ptr" - ] - is not None - } -) -@triton.jit(do_not_specialize=["T"]) +@triton.heuristics({"HAS_INT_STATE": lambda args: args["int_state_ptr"] is not None}) +@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) +@triton.jit def _selective_scan_update_kernel( # Pointers to matrices state_ptr, @@ -78,14 +32,10 @@ def _selective_scan_update_kernel( z_ptr, out_ptr, state_batch_indices_ptr, - pad_slot_id, - intermediate_states_buffer, - cache_steps, - retrieve_parent_token_ptr, - intermediate_state_indices_ptr, + int_state_ptr, # Matrix dimensions batch, - T, + seq_len, nheads, dim, dstate, @@ -96,11 +46,11 @@ def _selective_scan_update_kernel( stride_state_dim, stride_state_dstate, stride_x_batch, - stride_x_T, + stride_x_seq, stride_x_head, stride_x_dim, stride_dt_batch, - stride_dt_T, + stride_dt_seq, stride_dt_head, stride_dt_dim, stride_dt_bias_head, @@ -109,25 +59,28 @@ def _selective_scan_update_kernel( stride_A_dim, stride_A_dstate, stride_B_batch, - stride_B_T, + stride_B_seq, stride_B_group, stride_B_dstate, stride_C_batch, - stride_C_T, + stride_C_seq, stride_C_group, stride_C_dstate, stride_D_head, stride_D_dim, stride_z_batch, - stride_z_T, + stride_z_seq, stride_z_head, stride_z_dim, stride_out_batch, - stride_out_T, + stride_out_seq, stride_out_head, stride_out_dim, - stride_retrieve_parent_token_batch, - stride_retrieve_parent_token_T, + stride_int_batch, + stride_int_seq, + stride_int_head, + stride_int_dim, + stride_int_dstate, # Meta-parameters DT_SOFTPLUS: tl.constexpr, TIE_HDIM: tl.constexpr, @@ -136,165 +89,155 @@ def _selective_scan_update_kernel( HAS_D: tl.constexpr, HAS_Z: tl.constexpr, HAS_STATE_BATCH_INDICES: tl.constexpr, - DISABLE_STATE_UPDATE: tl.constexpr, - CACHE_INTERMEDIATE_STATES: tl.constexpr, - HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr, - HAS_INTERMEDIATE_STATE_INDICES: tl.constexpr, + HAS_INT_STATE: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, ): pid_m = tl.program_id(axis=0) pid_b = tl.program_id(axis=1) pid_h = tl.program_id(axis=2) - # If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate - # is taken from the state_batch_indices_ptr Otherwise, the state coordinate - # is the same as the batch id. + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) + + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + out_ptrs = out_ptr + offs_m * stride_out_dim + + # 1. State Mapping (handles dynamic batching slot allocation) if HAS_STATE_BATCH_INDICES: state_batch_indices_ptr += pid_b - state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64) + state_batch_idx = tl.load(state_batch_indices_ptr) + # Skip padding tokens (e.g. from graph capture or inactive slots) + if state_batch_idx < 0: + for s in range(seq_len): + out_s_ptrs = out_ptrs + s * stride_out_seq + tl.store(out_s_ptrs, 0.0, mask=offs_m < dim) + return state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head + if HAS_INT_STATE: + int_state_ptr += state_batch_idx * stride_int_batch + pid_h * stride_int_head else: state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head + if HAS_INT_STATE: + int_state_ptr += pid_b * stride_int_batch + pid_h * stride_int_head + # Base Pointers for Sequence iteration x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head if HAS_DT_BIAS: dt_bias_ptr += pid_h * stride_dt_bias_head + A_ptr += pid_h * stride_A_head B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group if HAS_Z: z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head - out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) + # Constant offsets (A, D, and bias do not have a sequence dimension) state_ptrs = state_ptr + ( offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate ) + if HAS_INT_STATE: + int_state_ptrs = int_state_ptr + ( + offs_m[:, None] * stride_int_dim + offs_n[None, :] * stride_int_dstate + ) - mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) - if HAS_STATE_BATCH_INDICES: - mask &= state_batch_idx != pad_slot_id - state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32) + x_ptrs = x_ptr + offs_m * stride_x_dim + dt_ptrs = dt_ptr + offs_m * stride_dt_dim if HAS_DT_BIAS: dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim if HAS_D: D_ptr += pid_h * stride_D_head D_ptrs = D_ptr + offs_m * stride_D_dim - A_ptrs = A_ptr + offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate - cache_idx = -1 - if CACHE_INTERMEDIATE_STATES: - if HAS_INTERMEDIATE_STATE_INDICES: - intermediate_state_idx = tl.load(intermediate_state_indices_ptr + pid_b).to( - tl.int64 - ) - cache_idx = intermediate_state_idx - elif HAS_STATE_BATCH_INDICES: - cache_idx = state_batch_idx - else: - cache_idx = pid_b - - current_step_idx = 0 - for _ in range(T): - if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: - if current_step_idx != 0 and cache_idx >= 0: - parent_ptr = ( - retrieve_parent_token_ptr - + pid_b * stride_retrieve_parent_token_batch - + current_step_idx * stride_retrieve_parent_token_T - ) - parent_step_idx = tl.load(parent_ptr).to(tl.int32) - - if parent_step_idx >= 0 and parent_step_idx < T: - step_offset = parent_step_idx * nheads * dim * dstate - cache_ptr = ( - intermediate_states_buffer - + cache_idx * cache_steps * nheads * dim * dstate - + step_offset - + pid_h * dim * dstate - + offs_m[:, None] * dstate - + offs_n[None, :] - ) - state = tl.load(cache_ptr, mask=mask, other=0.0).to(tl.float32) - - x_ptrs = x_ptr + offs_m * stride_x_dim - dt_ptrs = dt_ptr + offs_m * stride_dt_dim - B_ptrs = B_ptr + offs_n * stride_B_dstate - C_ptrs = C_ptr + offs_n * stride_C_dstate + A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate) + B_ptrs = B_ptr + offs_n * stride_B_dstate + C_ptrs = C_ptr + offs_n * stride_C_dstate + + if HAS_Z: + z_ptrs = z_ptr + offs_m * stride_z_dim + + # Load initial historical state and constant parameters + state = tl.load( + state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0 + ).to(tl.float32) + + if not TIE_HDIM: + A = tl.load( + A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0 + ).to(tl.float32) + else: + A = tl.load(A_ptr).to(tl.float32) + + if HAS_D: + D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + # ---------------------------------------------------- + # Sequence Loop (Processes Main Token + Speculative Drafts) + # ---------------------------------------------------- + for s in range(seq_len): + x_s_ptrs = x_ptrs + s * stride_x_seq + dt_s_ptrs = dt_ptrs + s * stride_dt_seq + B_s_ptrs = B_ptrs + s * stride_B_seq + C_s_ptrs = C_ptrs + s * stride_C_seq if HAS_Z: - z_ptrs = z_ptr + offs_m * stride_z_dim - out_ptrs = out_ptr + offs_m * stride_out_dim + z_s_ptrs = z_ptrs + s * stride_z_seq + + x = tl.load(x_s_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + # Calculate dt and dA if not TIE_HDIM: - dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + dt = tl.load(dt_s_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if HAS_DT_BIAS: dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if DT_SOFTPLUS: - dt = softplus(dt) - A = tl.load( - A_ptrs, - mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), - other=0.0, - ).to(tl.float32) + dt = tl.where(dt <= 20.0, softplus(dt), dt) dA = tl.exp(A * dt[:, None]) else: - dt = tl.load(dt_ptr).to(tl.float32) + dt = tl.load(dt_ptr + s * stride_dt_seq).to(tl.float32) if HAS_DT_BIAS: dt += tl.load(dt_bias_ptr).to(tl.float32) if DT_SOFTPLUS: - dt = softplus(dt) - A = tl.load(A_ptr).to(tl.float32) - dA = tl.exp(A * dt) # scalar, not a matrix + dt = tl.where(dt <= 20.0, softplus(dt), dt) + dA = tl.exp(A * dt) - B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) - C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) - if HAS_D: - D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + # Load B and C + B = tl.load(B_s_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + C = tl.load(C_s_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) if HAS_Z: - z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + z = tl.load(z_s_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + if not TIE_HDIM: + dB = B[None, :] * dt[:, None] + else: + dB = B * dt - dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt + # ---------------------------------------------------- + # The Core State Recurrence (h_t = dA * h_{t-1} + dB * x_t) + # ---------------------------------------------------- state = state * dA + dB * x[:, None] - if CACHE_INTERMEDIATE_STATES: - if HAS_STATE_BATCH_INDICES: - if state_batch_idx != pad_slot_id: - cache_ptr_base = ( - intermediate_states_buffer - + cache_idx * cache_steps * nheads * dim * dstate - + current_step_idx * nheads * dim * dstate - + pid_h * dim * dstate - ) - cache_ptrs = cache_ptr_base + ( - offs_m[:, None] * dstate + offs_n[None, :] - ) - tl.store( - cache_ptrs, state.to(cache_ptrs.dtype.element_ty), mask=mask - ) + # ---------------------------------------------------- + # Dump Intermediate Speculative State Snapshot + # ---------------------------------------------------- + if HAS_INT_STATE: + int_state_s_ptrs = int_state_ptrs + s * stride_int_seq + tl.store( + int_state_s_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate) + ) + # Calculate Output out = tl.sum(state * C[None, :], axis=1) if HAS_D: out += x * D if HAS_Z: out *= z * tl.sigmoid(z) - tl.store(out_ptrs, out, mask=offs_m < dim) - - current_step_idx += 1 - x_ptr += stride_x_T - dt_ptr += stride_dt_T - B_ptr += stride_B_T - C_ptr += stride_C_T - out_ptr += stride_out_T - if HAS_Z: - z_ptr += stride_z_T + out_s_ptrs = out_ptrs + s * stride_out_seq + tl.store(out_s_ptrs, out, mask=offs_m < dim) - if not DISABLE_STATE_UPDATE: - tl.store(state_ptrs, state.to(state_ptrs.dtype.element_ty), mask=mask) + # After processing all sequence steps, persist the final state back to HBM + tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) def selective_state_update( @@ -309,102 +252,90 @@ def selective_state_update( dt_bias=None, dt_softplus=False, state_batch_indices=None, - pad_slot_id=PAD_SLOT_ID, - out=None, - disable_state_update=False, - intermediate_states_buffer=None, - cache_steps=None, - retrieve_parent_token=None, - intermediate_state_indices=None, + intermediate_ssm_states=None, ): """ Argument: state: (batch, dim, dstate) or (batch, nheads, dim, dstate) - x: (batch, dim) or (batch, nheads, dim) for single-token or (batch, T, nheads, dim) for multi-token - dt: (batch, dim) or (batch, nheads, dim) + x: (batch, dim), (batch, seqlen, dim), (batch, nheads, dim) or (batch, seqlen, nheads, dim) + dt: Matches x A: (dim, dstate) or (nheads, dim, dstate) - B: (batch, dstate) or (batch, ngroups, dstate) for single-token or (batch, T, ngroups, dstate) for multi-token - C: (batch, dstate) or (batch, ngroups, dstate) + B: (batch, dstate), (batch, seqlen, dstate), (batch, ngroups, dstate) or (batch, seqlen, ngroups, dstate) + C: Matches B D: (dim,) or (nheads, dim) - z: (batch, dim) or (batch, nheads, dim) + z: Matches x dt_bias: (dim,) or (nheads, dim) - pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] - in this case, the kernel will not process entries at - indices 0 and 3 - out: Preallocated ssm output tensor. Assume same shape as x. - In-place updated. - disable_state_update: If True, don't write back to state (for speculative verify) - intermediate_states_buffer: Buffer to cache intermediate states - cache_steps: Total number of steps in the buffer - retrieve_parent_token: (batch, T) tensor of parent token indices for EAGLE tree attention - intermediate_state_indices: (batch,) tensor of indices for intermediate_states_buffer operations. - If provided, uses these indices instead of state_batch_indices for the buffer. + intermediate_ssm_states: Optional buffer of shape (batch, seqlen, nheads, dim, dstate) + or (batch, seqlen, dim, dstate) + Return: + out: shape matches x """ - if state.dim() == 3: + has_heads = state.dim() > 3 + if not has_heads: state = state.unsqueeze(1) - if x.dim() == 2: - x = x.unsqueeze(1) - if x.dim() == 3: - x = x.unsqueeze(1) - if dt.dim() == 2: - dt = dt.unsqueeze(1) - if dt.dim() == 3: - dt = dt.unsqueeze(1) + + # Standardize inputs to explicit sequence and head dimensions: (batch, seq_len, nheads, dim) + is_seq_unsq = False + if has_heads: + if x.dim() == 3: # (batch, nheads, dim) -> (batch, 1, nheads, dim) + x = x.unsqueeze(1) + dt = dt.unsqueeze(1) + B = B.unsqueeze(1) + C = C.unsqueeze(1) + if z is not None: + z = z.unsqueeze(1) + is_seq_unsq = True + else: + if x.dim() == 2: # (batch, dim) -> (batch, 1, 1, dim) + x = x.unsqueeze(1).unsqueeze(2) + dt = dt.unsqueeze(1).unsqueeze(2) + B = B.unsqueeze(1).unsqueeze(2) + C = C.unsqueeze(1).unsqueeze(2) + if z is not None: + z = z.unsqueeze(1).unsqueeze(2) + is_seq_unsq = True + elif x.dim() == 3: # (batch, seqlen, dim) -> (batch, seqlen, 1, dim) + x = x.unsqueeze(2) + dt = dt.unsqueeze(2) + B = B.unsqueeze(2) + C = C.unsqueeze(2) + if z is not None: + z = z.unsqueeze(2) + if A.dim() == 2: A = A.unsqueeze(0) - if B.dim() == 2: - B = B.unsqueeze(1) - if B.dim() == 3: - B = B.unsqueeze(1) - if C.dim() == 2: - C = C.unsqueeze(1) - if C.dim() == 3: - C = C.unsqueeze(1) if D is not None and D.dim() == 1: D = D.unsqueeze(0) - if z is not None: - if z.dim() == 2: - z = z.unsqueeze(1) - if z.dim() == 3: - z = z.unsqueeze(1) if dt_bias is not None and dt_bias.dim() == 1: dt_bias = dt_bias.unsqueeze(0) - if out.dim() == 2: - out = out.unsqueeze(1) - if out.dim() == 3: - out = out.unsqueeze(1) - - _, nheads, dim, dstate = state.shape - batch, T, _, _ = x.shape - - assert x.shape == (batch, T, nheads, dim) - assert dt.shape == x.shape - assert A.shape == (nheads, dim, dstate) - ngroups = B.shape[2] - assert nheads % ngroups == 0, "nheads must be divisible by ngroups" - assert B.shape == (batch, T, ngroups, dstate) - assert C.shape == B.shape - if D is not None: - assert D.shape == (nheads, dim) - if z is not None: - assert z.shape == x.shape - if dt_bias is not None: - assert dt_bias.shape == (nheads, dim) - if state_batch_indices is not None: - assert state_batch_indices.shape == (batch,) - assert out.shape == x.shape - - grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads) + + # Set up Intermediate State standardization + if intermediate_ssm_states is not None: + if not has_heads and intermediate_ssm_states.dim() == 4: + intermediate_ssm_states = intermediate_ssm_states.unsqueeze( + 2 + ) # (batch, seqlen, 1, dim, dstate) + int_state_strides = ( + intermediate_ssm_states.stride(0), + intermediate_ssm_states.stride(1), + intermediate_ssm_states.stride(2), + intermediate_ssm_states.stride(3), + intermediate_ssm_states.stride(4), + ) + else: + intermediate_ssm_states = x # Dummy pointer + int_state_strides = (0, 0, 0, 0, 0) + + batch, seq_len, nheads, dim = x.shape + dstate = state.shape[-1] + ngroups = B.shape[-2] + + out = torch.empty_like(x) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) z_strides = ( - (z.stride(0), z.stride(1), z.stride(2), z.stride(3)) - if z is not None - else (0, 0, 0, 0) + (z.stride(0), z.stride(1), z.stride(2), z.stride(3)) if z is not None else (0, 0, 0, 0) ) - # We don't want autotune since it will overwrite the state - # We instead tune by hand. + BLOCK_SIZE_M, num_warps = ( (32, 4) if dstate <= 16 @@ -414,17 +345,12 @@ def selective_state_update( else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8)))) ) ) + tie_hdim = ( A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 - and dt_bias.stride(-1) == 0 - ) - - retrieve_parent_token_strides = ( - (retrieve_parent_token.stride(0), retrieve_parent_token.stride(1)) - if retrieve_parent_token is not None - else (0, 0) + and (dt_bias is None or dt_bias.stride(-1) == 0) ) with torch.cuda.device(x.device.index): @@ -440,13 +366,9 @@ def selective_state_update( z, out, state_batch_indices, - pad_slot_id, - intermediate_states_buffer, - cache_steps if cache_steps is not None else 0, - retrieve_parent_token, - intermediate_state_indices, + intermediate_ssm_states, batch, - T, + seq_len, nheads, dim, dstate, @@ -463,7 +385,7 @@ def selective_state_update( dt.stride(1), dt.stride(2), dt.stride(3), - *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0, + *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else (0, 0), A.stride(0), A.stride(1), A.stride(2), @@ -475,7 +397,7 @@ def selective_state_update( C.stride(1), C.stride(2), C.stride(3), - *(D.stride(0), D.stride(1)) if D is not None else 0, + *(D.stride(0), D.stride(1)) if D is not None else (0, 0), z_strides[0], z_strides[1], z_strides[2], @@ -484,11 +406,17 @@ def selective_state_update( out.stride(1), out.stride(2), out.stride(3), - retrieve_parent_token_strides[0], - retrieve_parent_token_strides[1], + *int_state_strides, dt_softplus, tie_hdim, BLOCK_SIZE_M, - DISABLE_STATE_UPDATE=disable_state_update, num_warps=num_warps, ) + + # Revert dimensions back to match original x format + if not has_heads: + out = out.squeeze(2) + if is_seq_unsq: + out = out.squeeze(1) + + return out diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 019c6fef396..1edbdc760b8 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -788,7 +788,7 @@ def flash_decode_and_prefill( assert block_table is not None # Flash attn kernel. - if not is_decode_only: + if max_seqlen_q > 1: q = q.squeeze(1) if getattr(self, "softmax_scale", None) is not None: softmax_scale = self.softmax_scale diff --git a/megatron/inference/utils.py b/megatron/inference/utils.py index 92d153755fe..25fbfcd6cba 100644 --- a/megatron/inference/utils.py +++ b/megatron/inference/utils.py @@ -299,11 +299,12 @@ def get_inference_config_from_model_and_args(model: MegatronModule, args): mamba_inference_state_config=mamba_inference_state_config, pg_collection=pg_collection, use_flashinfer_fused_rope=args.use_flashinfer_fused_rope, - materialize_only_last_token_logits=not args.return_log_probs, + materialize_only_last_token_logits=(not args.return_log_probs and not args.num_speculative_tokens > 0), track_paused_request_events=args.inference_dynamic_batching_track_paused_request_events, enable_chunked_prefill=args.enable_chunked_prefill, metrics_writer=metrics_writer, logging_step_interval=args.inference_logging_step_interval, + num_speculative_tokens=args.num_speculative_tokens, ) From a011c73cd0029ab666abb6cb12b40be69bbf851a Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Wed, 25 Feb 2026 15:41:35 -0800 Subject: [PATCH 13/76] Fix non-consecutive acceptance bug Signed-off-by: Keshav Santhanam --- .../text_generation_controller.py | 51 +++++++++++++------ 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 6aaf8831eab..bc1d06f54c0 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -906,7 +906,7 @@ def _dynamic_step_sample_logits_and_verify_tokens( mtp_output_tokens = torch.empty_like(mtp_output_tokens_jumbled) mtp_output_tokens[:, token_order] = mtp_output_tokens_jumbled - ### ================ PART 3 This part is to do the fowlling : ================ + ### ================ PART 3 This part is to do the following : ================ # Create the accepted tokens tensor # For prefill it is always set to 1 # For decode, the first token is always accepted, then we compare with input tokens and accept the next tokens if its a match @@ -928,27 +928,46 @@ def _dynamic_step_sample_logits_and_verify_tokens( ), f"Expected input_tokens_required to have 1 row, but got {input_tokens_required.shape}" input_tokens_required = input_tokens_required.squeeze(0) - # This is to get the place where the output sampled speculative token is equal to input token - output_right_shifted = output_tokens.roll(1) - accepted_tokens_mask = input_tokens_required == output_right_shifted + # Initialize mask with False to prevent boundary bleed + accepted_tokens_mask = torch.zeros_like(input_tokens_required, dtype=torch.bool) # This is to make all prefill tokens accepted token_to_prefill_idx = torch.repeat_interleave(request_in_prefill_status_tensor, repeats) - accepted_tokens_mask[token_to_prefill_idx == 1] = 1 + accepted_tokens_mask[token_to_prefill_idx == 1] = True - # This is to make first decode token in all requests accepted - deocde_query_starts = torch.arange(num_decode_requests) * (1 + self.num_speculative_tokens) - accepted_tokens_mask[deocde_query_starts] = 1 + # Safe decode token verification without cross-batch boundary contamination + if num_decode_requests > 0: + decode_len = num_decode_requests * (self.num_speculative_tokens + 1) + + decode_inputs = input_tokens_required[:decode_len].reshape( + num_decode_requests, self.num_speculative_tokens + 1 + ) + decode_outputs = output_tokens[:decode_len].reshape( + num_decode_requests, self.num_speculative_tokens + 1 + ) + + # Shift outputs right by 1 *within* each request to align sampled tokens with input targets + decode_outputs_shifted = decode_outputs.roll(1, dims=1) + + decode_mask_2d = decode_inputs == decode_outputs_shifted + + # The first token (base token) is always accepted + decode_mask_2d[:, 0] = True + + # ENFORCE CONSECUTIVE ACCEPTANCE: + # cummin() on booleans propagates False (0) to the right, invalidating all subsequent mismatches + decode_mask_2d = decode_mask_2d.cummin(dim=1).values + + # Put the consecutive-enforced mask back into the flattened 1D tensor + accepted_tokens_mask[:decode_len] = decode_mask_2d.flatten() # This is to find the index of the last 1 in every request + # (Now mathematically guaranteed to be the final consecutive match) last_one_indices = torch.full( (active_request_count,), -1, device=token_to_request_index.device ) - last_one_indices[token_to_request_index[accepted_tokens_mask == 1]] = torch.where( - accepted_tokens_mask == 1 - )[ - 0 - ] # [1, 5, 6] + valid_indices = torch.nonzero(accepted_tokens_mask).squeeze(-1) + last_one_indices[token_to_request_index[valid_indices]] = valid_indices # These are the tokens (output + speculative tokens) that will be going to the next forward pass final_sampled_tokens = output_tokens[last_one_indices] @@ -957,8 +976,8 @@ def _dynamic_step_sample_logits_and_verify_tokens( :, last_one_indices ] - ### ================ PART 4 This part is to do the fowlling : ================ - # To fill the speculative otkens and accepted_token counts + ### ================ PART 4 This part is to do the following : ================ + # To fill the speculative tokens and accepted_token counts # For prefill it is always set to 1 # For decode, the first token is always accepted, then we compare with input tokens and accept the next tokens if its a match # Then find the index of the last 1 in every request of the accepted tokens tensor @@ -968,7 +987,7 @@ def _dynamic_step_sample_logits_and_verify_tokens( # input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] # Size 11 # Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] # Accepted tokens [ [a6s -1] | [b4s b5s] | [-1 -1] ] # Only handle decod requests, (Prefill already defaults to -1s) - # Accepted token counts [ 1 | 2 | 0 ] # Prefill defaults to 0 + # Accepted token counts [ 1 | 2 | 0 ] # Prefill defaults to 0 # This part is to extract the accepted tokens input_tokens_required[accepted_tokens_mask == 0] = -1 # Masks out non accepted tokens From 06b08d749149c04d52f09a2c5a336e0bf57ab4bf Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Wed, 25 Feb 2026 22:52:30 -0800 Subject: [PATCH 14/76] More progress Signed-off-by: Keshav Santhanam --- .../core/inference/contexts/dynamic_context.py | 7 ++++--- .../core/inference/engines/dynamic_engine.py | 10 +++++----- .../text_generation_controller.py | 18 ++++++++++++++---- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index ba0d15ee513..b5c6971a588 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1330,10 +1330,11 @@ def initialize_attention_state( if self.is_decode_only(): padded_token_count = min( self.max_tokens, - self.max_requests, + self.max_requests * (self.num_speculative_tokens + 1), self.round_up_tokens(self.active_token_count), ) padded_decode_req_count = padded_token_count // (self.num_speculative_tokens + 1) + #print(f"self.max_tokens={self.max_tokens}, self.max_requests={self.max_requests}, self.active_token_count={self.active_token_count}, padded_decode_req_count={padded_decode_req_count}, padded_token_count={padded_token_count}") padded_prefill_req_count = 0 else: target_padding_req_count = min( @@ -2218,8 +2219,8 @@ def update_requests( assert self.total_request_count == active_request_count + self.paused_request_count if self.paused_request_count > 0: - self.paused_tokens = next_tokens[: self.paused_request_count] - self.paused_speculative_tokens = new_speculative_tokens[:, : self.paused_request_count] + self.paused_tokens = next_tokens[: self.paused_request_count].clone() + self.paused_speculative_tokens = new_speculative_tokens[:, : self.paused_request_count].clone() # add_ and fill_ calls seems to work as intended with sliced indexing (i.e. x[3:5].add(...) or x[3:5].fill_) # but when another tensor is used for indexing, it does not work as expected (i.e. x[y] if x and y are torch tensors) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 9228f02b032..efd2a5411a0 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -922,7 +922,7 @@ def post_process_requests( if self.num_speculative_tokens > 0: accepted_tokens = list(filter(lambda tok: tok != -1, accepted_tokens_list)) - tokens = tokens + accepted_tokens + tokens = accepted_tokens + tokens request: DynamicInferenceRequest = self.get_request(request_id) if request_id != self.context.chunked_prefill_request_id: @@ -1154,10 +1154,10 @@ def _check_stop_words_for_request_post_append(self, request: DynamicInferenceReq if list(generated_tokens[-stop_len:]) == stop_word_ids: return True else: - # Need to check the last stop len tokens shifting by 1 up to num_speculative_tokens - # Check logic and vecotrize this if possible - for i in range(self.num_speculative_tokens): - if list(generated_tokens[-stop_len - i : -i]) == stop_word_ids: + # Check the last stop len tokens shifting by 1 up to num_speculative_tokens + for i in range(self.num_speculative_tokens + 1): + end_idx = -i if i > 0 else None + if list(generated_tokens[-stop_len - i : end_idx]) == stop_word_ids: return True return False diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index bc1d06f54c0..8b85c688c43 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -966,8 +966,18 @@ def _dynamic_step_sample_logits_and_verify_tokens( last_one_indices = torch.full( (active_request_count,), -1, device=token_to_request_index.device ) - valid_indices = torch.nonzero(accepted_tokens_mask).squeeze(-1) - last_one_indices[token_to_request_index[valid_indices]] = valid_indices + + if num_decode_requests > 0: + # Summing the consecutive mask gives the count; subtract 1 for the local index + local_last_indices = decode_mask_2d.sum(dim=1) - 1 + row_offsets = torch.arange(num_decode_requests, device=last_one_indices.device) * (self.num_speculative_tokens + 1) + last_one_indices[:num_decode_requests] = row_offsets + local_last_indices + + if num_prefill_requests > 0: + # Prefill requests only have 1 token evaluated, so nonzero() is perfectly safe here + decode_len = num_decode_requests * (self.num_speculative_tokens + 1) + prefill_valid = torch.nonzero(accepted_tokens_mask[decode_len:]).squeeze(-1) + decode_len + last_one_indices[num_decode_requests:] = prefill_valid # These are the tokens (output + speculative tokens) that will be going to the next forward pass final_sampled_tokens = output_tokens[last_one_indices] @@ -1403,8 +1413,8 @@ async def async_generate_output_tokens_dynamic_batch( request_bookkeeping = self._dynamic_step_context_bookkeeping() ret = { - "sample": self._sampled_tokens_cuda[:active_request_count], - "accepted_tokens": self._accepted_tokens_per_request, + "sample": self._sampled_tokens_cuda[:active_request_count].clone(), + "accepted_tokens": self._accepted_tokens_per_request.clone() if self.num_speculative_tokens > 0 else None, "log_probs": log_probs, "top_n_logprobs": top_n_logprobs, "routing_indices_per_request": routing_indices_per_request, From 6b8835ad27fb5a7a554d3908d98f7c2ff2f35681 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Thu, 26 Feb 2026 22:41:21 -0800 Subject: [PATCH 15/76] Working with cuda graphs Signed-off-by: Keshav Santhanam --- .../core/inference/batch_dimensions_utils.py | 50 ++++++++++++++----- .../attention_context/mha_metadata.py | 12 ++++- .../inference/contexts/dynamic_context.py | 49 +++++++++++------- .../text_generation_controller.py | 18 +++++-- megatron/core/ssm/mamba_mixer.py | 6 ++- megatron/core/ssm/ops/causal_conv1d_triton.py | 3 +- megatron/core/ssm/ops/mamba_ssm.py | 7 +-- 7 files changed, 99 insertions(+), 46 deletions(-) diff --git a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py index 1a202c35af5..b98332f8681 100644 --- a/megatron/core/inference/batch_dimensions_utils.py +++ b/megatron/core/inference/batch_dimensions_utils.py @@ -73,7 +73,9 @@ def is_applicable_for_batch_dim( >= real_batch_dim.prefill_req_count + real_batch_dim.decode_req_count ) - def is_valid(self, max_requests: int, max_sequence_length: int) -> bool: + def is_valid( + self, max_requests: int, max_sequence_length: int, num_speculative_tokens: int + ) -> bool: """ Checks if the batch dimension is valid based on resource constraints. @@ -92,11 +94,17 @@ def is_valid(self, max_requests: int, max_sequence_length: int) -> bool: return False # Check if token count is sufficient for requests - if self.token_count < self.prefill_req_count + self.decode_req_count: + if self.token_count < self.prefill_req_count + self.decode_req_count * ( + num_speculative_tokens + 1 + ): return False # Check if the prefill requests are shorter than the max sequence length - if self.token_count > self.prefill_req_count * max_sequence_length + self.decode_req_count: + if ( + self.token_count + > self.prefill_req_count * max_sequence_length + + self.decode_req_count * (num_speculative_tokens + 1) + ): return False return True @@ -296,6 +304,7 @@ def generate_cuda_graph_batch_dimensions_list( max_tokens: int, max_sequence_length: int, use_cuda_graphs_for_non_decode_steps: bool, + num_speculative_tokens: int = 0, ) -> Tuple[List[InferenceBatchDimensions], Optional[List[int]]]: """ Generate CUDA graph batch dimensions. @@ -332,6 +341,7 @@ def generate_cuda_graph_batch_dimensions_list( max_tokens: Maximum total tokens max_sequence_length: Maximum sequence length use_cuda_graphs_for_non_decode_steps: Whether to use CUDA graphs for non-decode steps + num_speculative_tokens: Number of speculative tokens Returns: Tuple containing: @@ -343,7 +353,7 @@ def generate_cuda_graph_batch_dimensions_list( def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int) -> None: """Helper to create and append batch dimension to list only if it's valid.""" batch_dim = InferenceBatchDimensions(token_count, prefill_req_count, decode_req_count) - if batch_dim.is_valid(max_requests, max_sequence_length): + if batch_dim.is_valid(max_requests, max_sequence_length, num_speculative_tokens): cuda_graph_batch_dimensions_list.append(batch_dim) # Cuda graph token-counts @@ -377,8 +387,9 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int ) # Calculate separate token counts for decode-only graphs. - # Decode graphs can be more conservative since each request uses exactly 1 token. - cuda_graph_max_tokens_decode = min(cuda_graph_max_tokens, max_requests) + cuda_graph_max_tokens_decode = min( + cuda_graph_max_tokens, max_requests * (num_speculative_tokens + 1) + ) cuda_graph_decode_token_counts = ( CUDAGraphBatchDimensionBuilder._calculate_cuda_graph_token_counts( tp_size=tp_size, @@ -397,20 +408,28 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int ): # decode only # Use decode-specific token counts for decode-only graphs for size in cuda_graph_decode_token_counts: + decode_req_count = min(size // (num_speculative_tokens + 1), max_requests) add_if_valid( - token_count=min(size, max_requests), + token_count=decode_req_count * (num_speculative_tokens + 1), prefill_req_count=0, - decode_req_count=min(size, max_requests), + decode_req_count=decode_req_count, ) else: # Mixed prefill and decode mode # Create prefill and mixed dimensions with full token counts for size in cuda_graph_prefill_token_counts: + prefill_req_count = min(cuda_graph_mixed_prefill_count, max_requests) + decode_req_count = max( + 0, + min( + (size - prefill_req_count) // (num_speculative_tokens + 1), + max_requests - prefill_req_count, + ), + ) add_if_valid( token_count=size, - prefill_req_count=min(cuda_graph_mixed_prefill_count, max_requests), - decode_req_count=min(size, max_requests) - - min(cuda_graph_mixed_prefill_count, max_requests), + prefill_req_count=prefill_req_count, + decode_req_count=decode_req_count, ) # We need to ensure the prefill requests are shorter than the max sequence length, # considering the one decode token is used for prefill request construction @@ -427,16 +446,21 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int # Create decode-only dimensions with optimized token counts for size in cuda_graph_decode_token_counts: + decode_req_count = min(size // (num_speculative_tokens + 1), max_requests) add_if_valid( token_count=min(size, max_requests), prefill_req_count=0, - decode_req_count=min(size, max_requests), + decode_req_count=decode_req_count, ) # Remove duplicates and sort by prefill token count cuda_graph_batch_dimensions_list = list(set(cuda_graph_batch_dimensions_list)) cuda_graph_batch_dimensions_list.sort( - key=lambda x: ((x.token_count - x.decode_req_count), x.decode_req_count), reverse=True + key=lambda x: ( + (x.token_count - x.decode_req_count * (num_speculative_tokens + 1)), + x.decode_req_count, + ), + reverse=True, ) # Collect actual token counts from batch dimensions, then unique and sort diff --git a/megatron/core/inference/contexts/attention_context/mha_metadata.py b/megatron/core/inference/contexts/attention_context/mha_metadata.py index 1b6e8020275..07f8a349b51 100644 --- a/megatron/core/inference/contexts/attention_context/mha_metadata.py +++ b/megatron/core/inference/contexts/attention_context/mha_metadata.py @@ -41,6 +41,7 @@ def update( request_to_kv_block_ids: torch.Tensor, batch_dimensions: InferenceBatchDimensions, padded_batch_dimensions: InferenceBatchDimensions, + num_speculative_tokens: int = 0, ): """ Args: @@ -49,6 +50,7 @@ def update( request_to_kv_block_ids: (>real_batch_size, max_kv_blocks) batch_dimensions: Configuration object containing real batch settings padded_batch_dimensions: Configuration object containing padded batch settings + num_speculative_tokens: Number of speculative tokens """ # Extract values from configs real_batch_size = batch_dimensions.req_count @@ -99,7 +101,7 @@ def update( ) if padded_batch_dimensions.prefill_req_count == 0: - self._max_seqlen_q = 1 + self._max_seqlen_q = num_speculative_tokens + 1 else: # Make sure we will launch the prefill kernel for prefill graphs self._max_seqlen_q = max(2, padded_batch_dimensions.token_count) @@ -150,6 +152,7 @@ def update( request_to_kv_block_ids: torch.Tensor, batch_dimensions: InferenceBatchDimensions, padded_batch_dimensions: InferenceBatchDimensions, + num_speculative_tokens: int = 0, ): """ Args: @@ -158,6 +161,7 @@ def update( request_to_kv_block_ids: (>real_batch_size, max_kv_blocks) batch_dimensions: Configuration object containing real batch settings padded_batch_dimensions: Configuration object containing padded batch settings + num_speculative_tokens: Number of speculative tokens """ super().update( request_query_lengths, @@ -165,6 +169,7 @@ def update( request_to_kv_block_ids, batch_dimensions, padded_batch_dimensions, + num_speculative_tokens, ) def reset(self): @@ -183,6 +188,7 @@ def update( request_to_kv_block_ids: torch.Tensor, batch_dimensions: InferenceBatchDimensions, padded_batch_dimensions: InferenceBatchDimensions, + num_speculative_tokens: int = 0, ): """ Args: @@ -191,6 +197,7 @@ def update( request_to_kv_block_ids: (>real_batch_size, max_kv_blocks) batch_dimensions: Configuration object containing real batch settings padded_batch_dimensions: Configuration object containing padded batch settings + num_speculative_tokens: Number of speculative tokens """ super().update( request_query_lengths, @@ -198,10 +205,11 @@ def update( request_to_kv_block_ids, batch_dimensions, padded_batch_dimensions, + num_speculative_tokens, ) if len(self.state_data["query_lengths"]) > 0: self.state_data["max_seqlen_q"] = torch.max(self.state_data["query_lengths"]).item() self.state_data["max_seqlen_k"] = torch.max(self.state_data["kv_seq_lengths"]).item() else: - self.state_data["max_seqlen_q"] = 1 + self.state_data["max_seqlen_q"] = num_speculative_tokens + 1 self.state_data["max_seqlen_k"] = 1 diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index b5c6971a588..ca5858443ab 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -261,7 +261,10 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC tp_size = model_config.tensor_model_parallel_size pp_size = model_config.pipeline_model_parallel_size self.hidden_size_per_attention_head = core_divide(projection_size, num_attention_heads) - self.num_attention_heads_per_partition = core_divide(num_attention_heads, tp_size) + if num_attention_heads >= tp_size: + self.num_attention_heads_per_partition = core_divide(num_attention_heads, tp_size) + else: + self.num_attention_heads_per_partition = 1 self.num_speculative_tokens = inference_config.num_speculative_tokens # Cache the PP group we should use for PP collectives inside the context. @@ -536,12 +539,13 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC CUDAGraphBatchDimensionBuilder.generate_cuda_graph_batch_dimensions_list( tp_size=tp_size, num_cuda_graphs=inference_config.num_cuda_graphs, - cuda_graph_max_tokens=self.max_requests, + cuda_graph_max_tokens=self.max_requests * (self.num_speculative_tokens + 1), cuda_graph_mixed_prefill_count=inference_config.cuda_graph_mixed_prefill_count, max_requests=self.max_requests, max_tokens=self.max_tokens, max_sequence_length=self.max_sequence_length, use_cuda_graphs_for_non_decode_steps=self.use_cuda_graphs_for_non_decode_steps, + num_speculative_tokens=self.num_speculative_tokens, ) ) @@ -1236,11 +1240,15 @@ def add_dummy_requests_for_cudagraph_capture( Adds dummy requests to reflect the number of prefill and decode requests in the graph config. These are using during cuda graph captures. """ - prefill_tokens = graph_dimensions.token_count - graph_dimensions.decode_req_count + prefill_tokens = graph_dimensions.token_count - ( + graph_dimensions.decode_req_count * (self.num_speculative_tokens + 1) + ) # Pre-construct shared objects (safe due to deep copy in DynamicInferenceRequest.__post_init__) shared_sampling_params = SamplingParams(num_tokens_to_generate=1, termination_id=-1) - shared_decode_tokens = torch.zeros(1, dtype=torch.long, device=torch.cuda.current_device()) + shared_decode_tokens = torch.zeros( + self.num_speculative_tokens + 1, dtype=torch.long, device=torch.cuda.current_device() + ) decode_requests = [ DynamicInferenceRequest( @@ -1326,17 +1334,14 @@ def initialize_attention_state( if self.using_cuda_graph_this_step(): self.padded_batch_dimensions = best_graph else: - padded_token_count = self.round_up_tokens(self.active_token_count) if self.is_decode_only(): - padded_token_count = min( - self.max_tokens, - self.max_requests * (self.num_speculative_tokens + 1), - self.round_up_tokens(self.active_token_count), + padded_decode_req_count = min( + self.max_requests, self.round_up_requests(self.num_decode_requests) ) - padded_decode_req_count = padded_token_count // (self.num_speculative_tokens + 1) - #print(f"self.max_tokens={self.max_tokens}, self.max_requests={self.max_requests}, self.active_token_count={self.active_token_count}, padded_decode_req_count={padded_decode_req_count}, padded_token_count={padded_token_count}") + padded_token_count = padded_decode_req_count * (self.num_speculative_tokens + 1) padded_prefill_req_count = 0 else: + padded_token_count = self.round_up_tokens(self.active_token_count) target_padding_req_count = min( self.max_requests, self.round_up_requests(self.total_request_count - self.paused_request_count), @@ -1395,6 +1400,7 @@ def initialize_attention_state( request_to_kv_block_ids=request_to_kv_block_ids_view, batch_dimensions=attn_dimensions, padded_batch_dimensions=self.padded_batch_dimensions, + num_speculative_tokens=self.num_speculative_tokens, ) if self.is_hybrid_model: @@ -1825,8 +1831,10 @@ def resume_paused_requests( self.block_allocator.total_avail, ) - # Constrain resumptions by the maximum allowed active requests - max_allowed_active = self.max_requests // (self.num_speculative_tokens + 1) + # Constrain resumptions by the maximum allowed active requests and tokens + max_allowed_active = min( + self.max_requests, self.max_tokens // (self.num_speculative_tokens + 1) + ) allowed_to_resume = max(0, max_allowed_active - active_request_count) resume_request_count = min(resume_request_count, allowed_to_resume) @@ -2127,10 +2135,13 @@ def update_requests( active_requests_requiring_new_block[ self.get_index_of_chunked_prefill_request() - self.paused_request_count ] = 0 # chunked prefill should not be paused - elif active_request_count * (self.num_speculative_tokens + 1) > self.max_requests: - # Force-pause excess requests in a decode-only batch - max_allowed_active = self.max_requests // (self.num_speculative_tokens + 1) - active_requests_requiring_new_block[max_allowed_active:] = 1 + else: + max_allowed_active = min( + self.max_requests, self.max_tokens // (self.num_speculative_tokens + 1) + ) + if active_request_count > max_allowed_active: + # Force-pause excess requests in a decode-only batch + active_requests_requiring_new_block[max_allowed_active:] = 1 active_requests_requiring_new_block_count = ( (active_requests_requiring_new_block == 1).sum().item() @@ -2220,7 +2231,9 @@ def update_requests( if self.paused_request_count > 0: self.paused_tokens = next_tokens[: self.paused_request_count].clone() - self.paused_speculative_tokens = new_speculative_tokens[:, : self.paused_request_count].clone() + self.paused_speculative_tokens = new_speculative_tokens[ + :, : self.paused_request_count + ].clone() # add_ and fill_ calls seems to work as intended with sliced indexing (i.e. x[3:5].add(...) or x[3:5].fill_) # but when another tensor is used for indexing, it does not work as expected (i.e. x[y] if x and y are torch tensors) diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 8b85c688c43..d7b945ca48c 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -966,17 +966,21 @@ def _dynamic_step_sample_logits_and_verify_tokens( last_one_indices = torch.full( (active_request_count,), -1, device=token_to_request_index.device ) - + if num_decode_requests > 0: # Summing the consecutive mask gives the count; subtract 1 for the local index local_last_indices = decode_mask_2d.sum(dim=1) - 1 - row_offsets = torch.arange(num_decode_requests, device=last_one_indices.device) * (self.num_speculative_tokens + 1) + row_offsets = torch.arange(num_decode_requests, device=last_one_indices.device) * ( + self.num_speculative_tokens + 1 + ) last_one_indices[:num_decode_requests] = row_offsets + local_last_indices - + if num_prefill_requests > 0: # Prefill requests only have 1 token evaluated, so nonzero() is perfectly safe here decode_len = num_decode_requests * (self.num_speculative_tokens + 1) - prefill_valid = torch.nonzero(accepted_tokens_mask[decode_len:]).squeeze(-1) + decode_len + prefill_valid = ( + torch.nonzero(accepted_tokens_mask[decode_len:]).squeeze(-1) + decode_len + ) last_one_indices[num_decode_requests:] = prefill_valid # These are the tokens (output + speculative tokens) that will be going to the next forward pass @@ -1414,7 +1418,11 @@ async def async_generate_output_tokens_dynamic_batch( ret = { "sample": self._sampled_tokens_cuda[:active_request_count].clone(), - "accepted_tokens": self._accepted_tokens_per_request.clone() if self.num_speculative_tokens > 0 else None, + "accepted_tokens": ( + self._accepted_tokens_per_request.clone() + if self.num_speculative_tokens > 0 + else None + ), "log_probs": log_probs, "top_n_logprobs": top_n_logprobs, "routing_indices_per_request": routing_indices_per_request, diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 745f36bcc3d..0adf95cc166 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -929,12 +929,14 @@ def _ssm_decode( Args: zxBCdt: The input tensor of shape (b, s, d), which is a concatenation of - z, x, B, C, and dt projections. s is the sequence length (1 + num_speculative_tokens). + z, x, B, C, and dt projections. + s is the sequence length (1 + num_speculative_tokens). conv_state: The convolution state tensor for inference. ssm_state: The selective scan state tensor for inference. batch_indices: A map from batch id to position in the Mamba state tensors. intermediate_ssm_state: Optional buffer for storing sequence steps in SSM state. - cache_seqlens: Optional tensor representing cache sequence length for circular buffering. + cache_seqlens: Optional tensor representing cache sequence length for circular + buffering. Returns: The output tensor of shape (b, s, d). diff --git a/megatron/core/ssm/ops/causal_conv1d_triton.py b/megatron/core/ssm/ops/causal_conv1d_triton.py index 1e399fd9650..3a412be3da8 100644 --- a/megatron/core/ssm/ops/causal_conv1d_triton.py +++ b/megatron/core/ssm/ops/causal_conv1d_triton.py @@ -111,7 +111,8 @@ def causal_conv1d_update_kernel( ) if not IS_CIRCULAR: - # Load the last (WIDTH - 1) elements to use them BEFORE they are overwritten by the shift + # Load the last (WIDTH - 1) elements to use them BEFORE they are overwritten + # by the shift if WIDTH >= 2: x_val_0 = tl.load( conv_state_ptrs + (state_len - WIDTH + 1) * conv_state_l_stride, mask=mask diff --git a/megatron/core/ssm/ops/mamba_ssm.py b/megatron/core/ssm/ops/mamba_ssm.py index e3d3dde0a01..949141871b1 100644 --- a/megatron/core/ssm/ops/mamba_ssm.py +++ b/megatron/core/ssm/ops/mamba_ssm.py @@ -1,12 +1,8 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. -import math - import torch -import torch.nn.functional as F import triton import triton.language as tl -from einops import rearrange, repeat from mamba_ssm.ops.triton.softplus import softplus @@ -260,7 +256,8 @@ def selective_state_update( x: (batch, dim), (batch, seqlen, dim), (batch, nheads, dim) or (batch, seqlen, nheads, dim) dt: Matches x A: (dim, dstate) or (nheads, dim, dstate) - B: (batch, dstate), (batch, seqlen, dstate), (batch, ngroups, dstate) or (batch, seqlen, ngroups, dstate) + B: (batch, dstate), (batch, seqlen, dstate), (batch, ngroups, dstate) or + (batch, seqlen, ngroups, dstate) C: Matches B D: (dim,) or (nheads, dim) z: Matches x From 9057442e6d23de7699e11e60ca4e08b16326858b Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Fri, 27 Feb 2026 01:19:45 -0800 Subject: [PATCH 16/76] Fix cuda graphs and chunked prefill Signed-off-by: Keshav Santhanam --- megatron/core/inference/batch_dimensions_utils.py | 2 +- megatron/core/inference/contexts/dynamic_context.py | 6 +++--- megatron/core/inference/engines/dynamic_engine.py | 4 ++++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py index b98332f8681..2bd3e6d9931 100644 --- a/megatron/core/inference/batch_dimensions_utils.py +++ b/megatron/core/inference/batch_dimensions_utils.py @@ -448,7 +448,7 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int for size in cuda_graph_decode_token_counts: decode_req_count = min(size // (num_speculative_tokens + 1), max_requests) add_if_valid( - token_count=min(size, max_requests), + token_count=decode_req_count * (num_speculative_tokens + 1), prefill_req_count=0, decode_req_count=decode_req_count, ) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index ca5858443ab..1fda5455100 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1608,9 +1608,9 @@ def add_request(self, req: DynamicInferenceRequest, chunk_length: Optional[int] # no need to update count, as it is already here if is_chunked_prefill: current_id = self.total_request_count - 1 - self.active_token_count -= ( - 1 # Overwrite the last token, which is the useless token from chunked prefill - ) + # Overwrite the last token, which is the useless token from chunked prefill + chunked_prefill_offset = 1 + self.num_speculative_tokens + self.active_token_count -= chunked_prefill_offset assert ( self.request_ids[current_id] == req.request_id ), "Continuation current_id mismatch" diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index efd2a5411a0..865ef8bc686 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -788,10 +788,14 @@ def _add_request( len(request.prompt_tokens) + request.sampling_params.num_tokens_to_generate > self.context.max_sequence_length ) or (request.sampling_params.num_tokens_to_generate < 0): + if torch.distributed.get_rank() == 0: + print(f"REQUEST {request_id} FAILED! MaxSequenceLengthOverflowError") request.status = Status.FAILED request.add_event_error_nontransient(MaxSequenceLengthOverflowError(request_id)) if len(request.prompt_tokens) > self.context.max_tokens and not self.enable_chunked_prefill: + if torch.distributed.get_rank() == 0: + print(f"REQUEST {request_id} FAILED! TokenOverflowError") request.status = Status.FAILED request.add_event_error_nontransient(TokenOverflowError(request_id)) From 4d5fe5d94f2d5c7804ca43407a5a6c45c7905b8d Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Sun, 1 Mar 2026 22:11:20 -0800 Subject: [PATCH 17/76] Add speculative decode unit tests Signed-off-by: Keshav Santhanam --- .../inference/contexts/dynamic_context.py | 7 +- .../text_generation_controller.py | 9 +- .../attention_metadata/test_mamba_metadata.py | 8 +- .../contexts/test_dynamic_context.py | 258 ++++++++++++++++++ .../inference/engines/test_dynamic_engine.py | 129 ++++++++- .../test_text_generation_controller.py | 190 +++++++++++++ 6 files changed, 581 insertions(+), 20 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index fa22e2f4d3f..fb120c2d415 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -2140,9 +2140,10 @@ def update_requests( if self.paused_request_count != 0: assert self.paused_tokens is not None next_tokens = torch.cat((self.paused_tokens, new_tokens)) - new_speculative_tokens = torch.cat( - (self.paused_speculative_tokens, new_speculative_tokens), dim=1 - ) + if new_speculative_tokens is not None: + new_speculative_tokens = torch.cat( + (self.paused_speculative_tokens, new_speculative_tokens), dim=1 + ) else: next_tokens = new_tokens diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index f624e5de07b..411ec1a448a 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -886,11 +886,12 @@ def _dynamic_step_sample_logits_and_verify_tokens( output_tokens_jumbled_list.append( self._torch_sampling_func(required_logits[required_indices, :], temp, top_k, top_p) ) - mtp_output_tokens_jumbled_list.append( - self._torch_sampling_func( - required_mtp_logits[:, required_indices, :], temp, top_k, top_p - ) + mtp_logits_slice = required_mtp_logits[:, required_indices, :] + num_spec, num_reqs, vocab = mtp_logits_slice.shape + sampled_mtp = self._torch_sampling_func( + mtp_logits_slice.reshape(num_spec * num_reqs, vocab), temp, top_k, top_p ) + mtp_output_tokens_jumbled_list.append(sampled_mtp.reshape(num_spec, num_reqs)) token_order_list.append(required_indices) output_tokens_jumbled = torch.cat(output_tokens_jumbled_list, dim=0) diff --git a/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py b/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py index dd34061888e..c793348233d 100644 --- a/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py +++ b/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py @@ -242,7 +242,7 @@ def test_update_mixed_batch_exact(self, metadata_context): assert torch.equal(metadata_context.batch_indices_prefill, expected_prefill) expected_device_counts = torch.tensor( - [2, 2], dtype=torch.int32, device=metadata_context.device + [2, 30], dtype=torch.int32, device=metadata_context.device ) assert torch.equal(metadata_context.device_decode_prefill, expected_device_counts) @@ -288,7 +288,7 @@ def test_update_padded_prefill_and_decode(self, metadata_context): assert torch.equal(metadata_context.batch_indices_prefill, expected_prefill) expected_device_counts = torch.tensor( - [1, 1], dtype=torch.int32, device=metadata_context.device + [1, 10], dtype=torch.int32, device=metadata_context.device ) assert torch.equal(metadata_context.device_decode_prefill, expected_device_counts) @@ -334,7 +334,7 @@ def test_update_chunked_prefill_mixed_exact(self, metadata_context): assert torch.equal(metadata_context.batch_indices_prefill, expected_prefill) expected_device_counts = torch.tensor( - [1, 2], dtype=torch.int32, device=metadata_context.device + [1, 60], dtype=torch.int32, device=metadata_context.device ) assert torch.equal(metadata_context.device_decode_prefill, expected_device_counts) @@ -375,7 +375,7 @@ def test_update_chunked_prefill_mixed_padded(self, metadata_context): assert torch.equal(metadata_context.batch_indices_prefill, expected_prefill) expected_device_counts = torch.tensor( - [2, 2], dtype=torch.int32, device=metadata_context.device + [2, 60], dtype=torch.int32, device=metadata_context.device ) assert torch.equal(metadata_context.device_decode_prefill, expected_device_counts) diff --git a/tests/unit_tests/inference/contexts/test_dynamic_context.py b/tests/unit_tests/inference/contexts/test_dynamic_context.py index 26636cadd39..01ea08b44fd 100644 --- a/tests/unit_tests/inference/contexts/test_dynamic_context.py +++ b/tests/unit_tests/inference/contexts/test_dynamic_context.py @@ -2,6 +2,7 @@ import contextlib import math +from unittest import mock import pytest import torch @@ -1493,3 +1494,260 @@ def test_gqa_high_tp_partition_heads(self): # With TP=8 and GQA=2, num_attention_heads_per_partition should be clamped to 1 assert dynamic_context.num_attention_heads_per_partition == 1 + + @pytest.mark.internal + @rounder_override(64) + def test_update_requests_speculative(self): + """Test update_requests correctly interleaves sampled and speculative tokens.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=128, + buffer_size_gb=0.01, + block_size_tokens=32, + num_speculative_tokens=2, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + # Setup 2 active decode requests + ctx.total_request_count = 2 + ctx.paused_request_count = 0 + ctx.active_token_count = 2 + ctx.request_ids[:2] = torch.tensor([10, 11]) + ctx.request_query_lengths[:2] = 1 + ctx.request_kv_length_offsets[:2] = torch.tensor([5, 8]) + ctx.request_last_kv_block_offset[:2] = torch.tensor([5, 8]) + ctx.request_to_kv_block_ids[:2, 0] = torch.tensor([0, 1]) + ctx.request_last_kv_block_id[:2] = torch.tensor([0, 1]) + + active_requests_mask = torch.tensor([1, 1], device='cuda') + new_tokens = torch.tensor([99, 100], device='cuda') # Sampled tokens + new_speculative_tokens = torch.tensor( + [[991, 1001], [992, 1002]], device='cuda' + ) # Spec tokens + + ctx.update_requests( + active_requests_mask=active_requests_mask, + new_tokens=new_tokens, + new_speculative_tokens=new_speculative_tokens, + ) + + # Each request generates 1 (sampled) + 2 (speculative) = 3 tokens. + assert ctx.active_token_count == 6 + assert torch.equal( + ctx.request_query_lengths[:2], torch.tensor([3, 3], dtype=torch.int32, device='cuda') + ) + assert torch.equal( + ctx.request_kv_length_offsets[:2], + torch.tensor([6, 9], dtype=torch.int32, device='cuda'), + ) + + # Check interleaving: [sampled_1, spec1_1, spec2_1, sampled_2, spec1_2, spec2_2] + expected_tokens = torch.tensor([99, 991, 992, 100, 1001, 1002], device='cuda') + assert torch.equal(ctx.token_to_input_ids[:6], expected_tokens) + + @pytest.mark.internal + @rounder_override(64) + def test_speculative_boundary_crossing(self): + """Test token block assignment when speculative tokens cross a KV block boundary.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=128, + buffer_size_gb=0.01, + block_size_tokens=4, # Small block size to force boundary crossing + num_speculative_tokens=2, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + # Setup 1 active decode request + ctx.total_request_count = 1 + ctx.paused_request_count = 0 + ctx.active_token_count = 1 + + ctx.request_ids[0] = 10 + ctx.request_query_lengths[0] = 1 + ctx.request_kv_block_counts[0] = 1 + + # Request is at offset 2. Adding 3 tokens (1 sampled + 2 spec) will cross boundary (2+3 = 5 > 4). + ctx.request_kv_length_offsets[0] = 2 + ctx.request_last_kv_block_offset[0] = 2 + + # Allocate one initial block manually + blocks = ctx.block_allocator.allocate_memory_blocks(1) + first_block = blocks[0] + ctx.request_to_kv_block_ids[0, 0] = first_block + ctx.request_last_kv_block_id[0] = first_block + + active_requests_mask = torch.tensor([1], device='cuda') + new_tokens = torch.tensor([50], device='cuda') + new_speculative_tokens = torch.tensor([[51], [52]], device='cuda') + + # Run update_requests natively. It will automatically: + # 1. Detect the boundary crossing and pause the request. + # 2. Clone the prev_last_block_ids internally. + # 3. Resume the request, allocating the new block. + # 4. Map the 3 new tokens across the boundary. + ctx.update_requests( + active_requests_mask=active_requests_mask, + new_tokens=new_tokens, + new_speculative_tokens=new_speculative_tokens, + ) + + # Verify a new block was natively allocated by the resume logic + assert ctx.request_kv_block_counts[0] == 2 + second_block = ctx.request_to_kv_block_ids[0, 1] + assert second_block != -1 + assert second_block != first_block + + # Expected token mapping for the 3 generated tokens (sampled, spec1, spec2) + # Token 0 (offset 2) -> first_block + # Token 1 (offset 3) -> first_block + # Token 2 (offset 4) -> second_block + expected_blocks = torch.tensor( + [first_block, first_block, second_block], dtype=torch.int, device='cuda' + ) + + assert torch.equal(ctx.token_to_block_idx[:3], expected_blocks) + + @pytest.mark.internal + @rounder_override(64) + def test_paused_speculative_tokens_tracking(self): + """ + Test that speculative tokens are correctly saved and concatenated + when requests are temporarily paused. + """ + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=128, + buffer_size_gb=0.01, + block_size_tokens=16, + num_speculative_tokens=2, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + # Setup 2 active requests. Request 0 is about to overflow its block. + ctx.total_request_count = 2 + ctx.paused_request_count = 0 + ctx.active_token_count = 2 + ctx.request_ids[:2] = torch.tensor([10, 11]) + ctx.request_query_lengths[:2] = 1 + + # Request 0 is at offset 14. Adding 1 sampled + 2 spec = 3 tokens will push it to 17, + # which is >= block_size_tokens (16). It will require a new block. + # Request 1 is at offset 5. It will not require a new block. + ctx.request_kv_length_offsets[:2] = torch.tensor([14, 5]) + ctx.request_last_kv_block_offset[:2] = torch.tensor([14, 5]) + ctx.request_kv_block_counts[:2] = 1 + + # Allocate blocks + blocks = ctx.block_allocator.allocate_memory_blocks(2) + ctx.request_to_kv_block_ids[0, 0] = blocks[0] + ctx.request_to_kv_block_ids[1, 0] = blocks[1] + ctx.request_last_kv_block_id[:2] = blocks + + # Force the allocator to have no available blocks. + # This guarantees request 0 stays paused and cannot immediately resume. + ctx.block_allocator.total_avail = 0 + ctx.block_allocator.paused_count = 100 # Ensure it doesn't get completely evicted either + + active_requests_mask = torch.tensor([1, 1], device='cuda') + new_tokens = torch.tensor([99, 100], device='cuda') # Sampled + new_speculative_tokens = torch.tensor( + [[991, 1001], [992, 1002]], device='cuda' + ) # Speculative + + # In update_requests, request 0 will be paused to allocate a new block. + # Since total_avail is 0, it will stay paused and its tokens will be cached. + ctx.update_requests( + active_requests_mask=active_requests_mask, + new_tokens=new_tokens, + new_speculative_tokens=new_speculative_tokens, + ) + + # Verify paused state was populated correctly + assert ctx.paused_tokens is not None + assert ctx.paused_speculative_tokens is not None + + # Request 0 was the one paused, so its tokens should be shifted to + # index 0 of the paused tensors. + assert ctx.paused_request_count == 1 + assert ctx.total_request_count == 2 + + assert ctx.paused_tokens[0].item() == 99 + assert torch.equal( + ctx.paused_speculative_tokens[:, 0], torch.tensor([991, 992], device='cuda') + ) + + @pytest.mark.internal + @rounder_override(64) + def test_chunked_prefill_speculative_offset_math(self): + """ + Test that the active_token_count is correctly adjusted by chunked_prefill_offset + when a chunked prefill request continues in a speculative decoding setup. + """ + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.05, + block_size_tokens=128, + max_requests=256, + max_tokens=256, + num_speculative_tokens=3, # 3 spec tokens -> offset = 4 + enable_chunked_prefill=True, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + # Setup a request that is already mid-chunked-prefill + ctx.total_request_count = 1 + ctx.chunked_prefill_request_id = 42 + ctx.request_ids[0] = 42 + + # Simulate active tokens from the previous step. + # Normally, the previous step generated a dummy token + spec tokens that + # need to be overwritten. + initial_active_tokens = 100 + ctx.active_token_count = initial_active_tokens + + req = DynamicInferenceRequest( + request_id=42, + prompt_tokens=torch.arange(0, 50, device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=10), + ) + # Mark as continuing chunked prefill + req.finished_chunk_token_count = 100 + + # Add the next chunk + chunk_length = 50 + ctx.add_request(req, chunk_length=chunk_length) + + # The new active token count should be: + # initial (100) - chunked_prefill_offset (1 + 3 = 4) + chunk_length (50) = 146 + expected_active_tokens = ( + initial_active_tokens - (1 + ctx.num_speculative_tokens) + chunk_length + ) + + assert ctx.active_token_count == expected_active_tokens + assert ( + ctx.request_output_lengths[0].item() + == req.finished_chunk_token_count + + chunk_length + + req.sampling_params.num_tokens_to_generate + ) diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index 3cfcfc1c894..bfeeb3b1ddf 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -8,6 +8,7 @@ from dataclasses import dataclass, field from functools import partial from typing import Dict, List, Optional, Tuple +from unittest import mock import pytest import torch @@ -138,19 +139,18 @@ class DynamicEngineTestConfig: kv_cache_management_mode: str = "persist" static_kv_memory_pointers: bool = True track_generated_token_events: bool = False - - fp8: bool = False + num_speculative_tokens: int = 0 def __post_init__(self): # Compute max_sequence_length. - assert self.max_sequence_length is None - assert self.num_tokens_to_generate is None or self.num_tokens_total is None - if self.num_tokens_to_generate is not None: - self.max_sequence_length = self.max_prompt_length + self.num_tokens_to_generate - else: - assert self.num_tokens_total is not None - self.max_sequence_length = self.num_tokens_total + if self.max_sequence_length is None: + assert self.num_tokens_to_generate is None or self.num_tokens_total is None + if self.num_tokens_to_generate is not None: + self.max_sequence_length = self.max_prompt_length + self.num_tokens_to_generate + else: + assert self.num_tokens_total is not None + self.max_sequence_length = self.num_tokens_total # Default paused buffer size. if self.context_paused_buffer_size_gb is None: @@ -262,6 +262,7 @@ def _build_inference_context( # this is for compatibility with the LTS environment unified_memory_level=0, # unit tests currently broken with UVM track_generated_token_events=test_config.track_generated_token_events, + num_speculative_tokens=test_config.num_speculative_tokens, ), ) @@ -295,6 +296,7 @@ def _build_test_env(cls, test_config): transformer_config = TransformerConfig( params_dtype=torch.bfloat16, num_layers=4, + mtp_num_layers=test_config.num_speculative_tokens, hidden_size=128 if test_config.fp8 else 32, num_attention_heads=4, use_cpu_initialization=True, @@ -354,6 +356,7 @@ def _build_test_env(cls, test_config): num_layers=( 3 if pp_size == 1 else 6 ), # 1 Mamba layer, 1 attention layer, 1 MLP layer + mtp_num_layers=test_config.num_speculative_tokens, hidden_size=256, # The Mamba layer places several constraints on this mamba_num_heads=16, num_attention_heads=16, @@ -1998,3 +2001,111 @@ def test_staleness_tracking(self, use_checkpoint): assert (record[-1].policy_staleness == pre_ps + 1).all() assert (record[-1].kv_cache_staleness == 0).all() + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.inference_mode() + def test_speculative_decoding_with_early_termination(self): + """Test that speculative decoding handles premature request termination safely + (e.g. hitting max_sequence_length mid-speculative-batch).""" + + # Set max_sequence_length tight so it terminates during a speculative step + test_config = DynamicEngineTestConfig( + num_requests=1, + min_prompt_length=4, + max_prompt_length=4, + num_tokens_to_generate=3, # Prompt (4) + Gen (3) = 7 + max_sequence_length=7, # Will force termination after 3 tokens + model_provider="gpt", + num_speculative_tokens=3, + materialize_only_last_token_logits=False, + ) + + env = self._build_test_env(test_config) + unwrapped_model = env.engine.controller.inference_wrapped_model.model + + # Mock forward to return deterministic data so speculative tokens are always accepted + def mock_mtp_forward(*args, **kwargs): + tokens = kwargs.get("tokens", args[0] if args else kwargs.get("input_ids")) + + base_logits = torch.zeros( + tokens.size(0), + tokens.size(1), + test_config.vocab_size, + device=tokens.device, + dtype=torch.bfloat16, + ) + base_logits[:, :, 0] = 100.0 # High probability for token 0 + + unwrapped_model._mtp_logits_cache = torch.zeros( + 3, + tokens.size(1), + test_config.vocab_size, + device=tokens.device, + dtype=torch.bfloat16, + ) + unwrapped_model._mtp_logits_cache[:, :, 0] = 100.0 # High probability for token 0 + return base_logits + + unwrapped_model.forward = mock_mtp_forward + + env.engine._add_request(env.requests[0]) + env.engine.schedule_waiting_requests() + + # Step engine until finished naturally + # This allows the bookkeeping logic to gracefully truncate the + # speculative tokens to the max_sequence_length boundary. + while env.engine.has_unfinished_requests(): + env.engine.step_modern() + + assert env.requests[0].status == Status.COMPLETED + + # It should trim the output to the max_sequence_length boundary + # Prompt was 4, Max was 7, so it should have generated exactly 3 tokens. + assert len(env.requests[0].generated_tokens) == 3 + + # Validate the engine's tracking state is clean + assert env.engine.context.active_token_count == 0 + assert env.engine.context.total_request_count == 0 + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.inference_mode() + def test_speculative_stop_word_hit(self): + """Test that if an accepted speculative token completes a stop word, + the request correctly triggers the stop logic without crashing.""" + + test_config = DynamicEngineTestConfig( + num_requests=0, num_speculative_tokens=2, materialize_only_last_token_logits=False + ) + env = self._build_test_env(test_config) + + # Mock request with a stop word + req = DynamicInferenceRequest( + request_id=0, + prompt_tokens=torch.tensor([1, 2, 3], device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=10), + ) + # Let's say the stop word is [99, 100] + req.stop_word_ids = [[99, 100]] + + # Fast-forward state: The base token was 99 + req.generated_tokens = [99] + tokens_to_append = [100, 101] # 1 accepted spec token, 1 rejected + + # Check before appending speculative tokens + stop_hit = env.engine._check_stop_words_for_request_post_append(req) + assert stop_hit is False # Only 99 is in generated_tokens initially + + # Now append the tokens as `post_process_requests` would + req.generated_tokens += tokens_to_append + + # Check again. It should detect the stop word [99, 100] inside [99, 100, 101] + # Specifically, it shifts backwards due to the speculative tokens. + stop_hit = env.engine._check_stop_words_for_request_post_append(req) + + assert stop_hit is True diff --git a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py index bdf95c2d9bf..eb0a6dfffdd 100644 --- a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py +++ b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py @@ -1007,3 +1007,193 @@ def test_sampled_tokens_match_with_parallelism(self, static, tp_size, pp_size): assert ( expected == actual ), f"Rank {i} tokens differ from rank {local_rank} tokens for request {j}" + + @pytest.mark.internal + def test_speculative_verify_tokens(self): + """Test consecutive token acceptance logic for speculative decoding.""" + self.setup_model(torch.float32, static=False) + + # Enable speculative decoding + self.text_generation_controller.num_speculative_tokens = 2 + ctx = self.text_generation_controller.inference_wrapped_model.inference_context + ctx.num_speculative_tokens = 2 + ctx.max_requests = 2 + ctx.total_request_count = 2 + ctx.paused_request_count = 0 + ctx.request_in_prefill_status_tensor = torch.tensor( + [0, 0], device='cuda' + ) # Decode requests + ctx.request_query_lengths = torch.tensor( + [3, 3], dtype=torch.int32, device='cuda' + ) # 1 sampled + 2 spec + + # Init accepted tokens tensors + self.text_generation_controller._init_mtp_sampling_tensor() + + # Mock inputs: [Req 1 sampled, Req 1 spec1, Req 1 spec2, Req 2 sampled, Req 2 spec1, Req 2 spec2] + # Target tokens (what the model was fed): [T0, T1, T2, T3, T4, T5] + input_ids = torch.tensor([[10, 11, 12, 20, 21, 22]], device='cuda') + + # We need the sampling function to return a 1D tensor for base logits, + # and a 2D tensor for the MTP logits to satisfy torch.cat(dim=1). + def mock_sampling_func(logits, *args, **kwargs): + if logits.dim() == 2: + # Base logits -> return 1D tensor of shape [6] + # Req 1: Predicts [11, 12, 99]. Matches T1, T2. Rejects T3. -> Accepts 2 spec tokens. + # Req 2: Predicts [99, 22, 23]. Fails at first spec token (99 != 21). -> Accepts 0 spec tokens. + return torch.tensor([11, 12, 99, 99, 22, 23], dtype=torch.long, device='cuda') + else: + # MTP logits -> return 2D tensor of shape [num_speculative_tokens, 6] + # The verification logic only uses base tokens, so we can return zeros here. + return torch.zeros((2, 6), dtype=torch.long, device='cuda') + + # Override sampling to return our predictable mock outputs + self.text_generation_controller._torch_sampling_buckets = [([0, 1], 1.0, 1, 0.0)] + self.text_generation_controller._torch_sampling_func = mock.MagicMock( + side_effect=mock_sampling_func + ) + + # Mock logits matching input shape + logits = torch.randn(1, 6, self.vocab_size, device='cuda') + mtp_logits = torch.randn(2, 6, self.vocab_size, device='cuda') + + self.text_generation_controller._dynamic_step_sample_logits_and_verify_tokens( + logits, mtp_logits, input_ids + ) + + # Verify acceptance counts + accepted_counts = self.text_generation_controller._accepted_token_counts_per_request[:2] + assert torch.equal(accepted_counts, torch.tensor([2, 0], device='cuda')) + + # Verify accepted tokens tensor + accepted_tokens = self.text_generation_controller._accepted_tokens_per_request[:2] + # Req 1 accepted 2 tokens: 11, 12 + assert torch.equal(accepted_tokens[0], torch.tensor([11, 12], device='cuda')) + # Req 2 accepted 0 tokens, should remain -1 + assert torch.equal(accepted_tokens[1], torch.tensor([-1, -1], device='cuda')) + + @pytest.mark.internal + @pytest.mark.parametrize("is_hybrid_model", [False, True]) + def test_rewind_kv_cache(self, is_hybrid_model): + """Test KV cache state is properly rewound for rejected speculative tokens.""" + self.setup_model(torch.float32, static=False) + self.text_generation_controller.num_speculative_tokens = 3 + ctx = self.text_generation_controller.inference_wrapped_model.inference_context + ctx.num_speculative_tokens = 3 + ctx.block_size_tokens = 4 + ctx.total_request_count = 2 + ctx.paused_request_count = 0 + ctx.request_in_prefill_status_tensor = torch.tensor([0, 0], device='cuda') + + # Initialize allocator and states + ctx.block_allocator.total_avail = 100 + ctx.request_kv_length_offsets[:2] = torch.tensor([10, 15], device='cuda') + ctx.request_kv_block_counts[:2] = torch.tensor([3, 4], device='cuda') + + # Req 0: offset 2. Rewinding 2 tokens -> offset 0. No block released. + # Req 1: offset 1. Rewinding 3 tokens -> offset 2 (prev block). 1 block released. + ctx.request_last_kv_block_offset[:2] = torch.tensor([2, 1], device='cuda') + ctx.request_last_kv_block_id[:2] = torch.tensor([50, 60], device='cuda') + ctx.request_to_kv_block_ids[:2, :4] = torch.tensor( + [[48, 49, 50, -1], [57, 58, 59, 60]], dtype=torch.int, device='cuda' + ) + + if is_hybrid_model: + ctx.is_hybrid_model = True + ctx.mamba_metadata = mock.MagicMock() + ctx.mamba_metadata.request_to_mamba_state_idx = torch.tensor([0, 1], device='cuda') + ctx.mamba_ssm_states = torch.zeros((1, 2, 16), device='cuda') + ctx.mamba_intermediate_ssm_states = torch.ones((1, 2, 4, 16), device='cuda') * 99 + + # Mock accepted token counts: Req 0 accepts 1 (rejects 2), Req 1 accepts 0 (rejects 3) + self.text_generation_controller._init_mtp_sampling_tensor() + self.text_generation_controller._accepted_token_counts_per_request = torch.tensor( + [1, 0], device='cuda' + ) + + self.text_generation_controller._rewind_kv_cache() + + # Assert offsets updated + assert torch.equal( + ctx.request_last_kv_block_offset[:2], + torch.tensor([0, 2], dtype=torch.int, device='cuda'), + ) + assert torch.equal( + ctx.request_kv_length_offsets[:2], torch.tensor([8, 12], dtype=torch.int, device='cuda') + ) + + # Assert block counts and IDs updated for boundary crossing + assert torch.equal( + ctx.request_kv_block_counts[:2], torch.tensor([3, 3], dtype=torch.int, device='cuda') + ) + assert torch.equal( + ctx.request_last_kv_block_id[:2], torch.tensor([50, 59], dtype=torch.int, device='cuda') + ) + + # Assert released block is cleared + assert ctx.request_to_kv_block_ids[1, 3].item() == -1 + assert ctx.block_allocator.total_avail == 101 # 1 block released + + if is_hybrid_model: + # Check Mamba state was restored from intermediate cache based on accepted counts + assert torch.all(ctx.mamba_ssm_states[:, 0] == 99) # Req 0 accepted 1, loaded index 1 + assert torch.all(ctx.mamba_ssm_states[:, 1] == 99) # Req 1 accepted 0, loaded index 0 + + @pytest.mark.internal + def test_speculative_multinomial_sampling(self): + """Test that speculative decoding can successfully use non-greedy sampling + (top_k > 1, top_p > 0) by flattening 3D MTP logits for torch.multinomial.""" + self.setup_model(torch.float32, static=False) + + # Enable speculative decoding + num_spec = 3 + self.text_generation_controller.num_speculative_tokens = num_spec + ctx = self.text_generation_controller.inference_wrapped_model.inference_context + ctx.num_speculative_tokens = num_spec + ctx.max_requests = 2 + ctx.total_request_count = 2 + ctx.paused_request_count = 0 + ctx.request_in_prefill_status_tensor = torch.tensor( + [0, 0], device='cuda' + ) # Decode requests + # query lengths for decode with spec tokens is (1 + num_spec) = 4 + ctx.request_query_lengths = torch.tensor([4, 4], dtype=torch.int32, device='cuda') + + # Init accepted tokens tensors + self.text_generation_controller._init_mtp_sampling_tensor() + + # Setup inputs + input_ids = torch.randint(0, self.vocab_size, (1, 8), device='cuda') + + # Create random logits + # Base logits shape: [1, 8, vocab_size] + logits = torch.randn(1, 8, self.vocab_size, device='cuda') + # MTP logits shape: [num_spec, 8, vocab_size] + mtp_logits = torch.randn(num_spec, 8, self.vocab_size, device='cuda') + + # Set up a bucket that forces multinomial sampling (top_p = 0.9, top_k = 0) + # _torch_sampling_buckets format: (indices, temp, top_k, top_p) + self.text_generation_controller._torch_sampling_buckets = [([0, 1], 1.0, 0, 0.9)] + + # Since we are actually testing the internal math of `_torch_sampling_func` handling the shapes, + # we DO NOT mock `_torch_sampling_func` here. We want it to run natively to prove it doesn't crash. + + try: + self.text_generation_controller._dynamic_step_sample_logits_and_verify_tokens( + logits, mtp_logits, input_ids + ) + except RuntimeError as e: + if "prob_dist must be 1 or 2 dim" in str(e): + pytest.fail("MTP logits were not flattened before calling multinomial sampling.") + else: + raise e + + # Validate that sampling produced output arrays of the correct sizes + active_request_count = ctx.total_request_count + sampled_tokens = self.text_generation_controller._sampled_tokens_cuda[:active_request_count] + sampled_mtp_tokens = self.text_generation_controller._sampled_mtp_tokens_cuda[ + :, :active_request_count + ] + + assert sampled_tokens.shape == (2,) + assert sampled_mtp_tokens.shape == (num_spec, 2) From 8e3710f6ec5e8613d3fa9d1daa845e5d5c7a6975 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Sun, 1 Mar 2026 22:29:41 -0800 Subject: [PATCH 18/76] Minor fix Signed-off-by: Keshav Santhanam --- megatron/core/models/gpt/gpt_model.py | 1 + megatron/core/models/mamba/mamba_model.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 46e5418b67c..3cd6bfae2df 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -613,6 +613,7 @@ def _postprocess( return hidden_states if self.config.mtp_num_layers: + assert self.config.mtp_num_layers > 0 # The new process_mtp_loss function doesn't handle mtp_logits_cache, # so we manually generate and cache MTP logits when in inference mode. if in_inference_mode: diff --git a/megatron/core/models/mamba/mamba_model.py b/megatron/core/models/mamba/mamba_model.py index 07ae2233334..a021b424734 100644 --- a/megatron/core/models/mamba/mamba_model.py +++ b/megatron/core/models/mamba/mamba_model.py @@ -402,7 +402,8 @@ def forward( if not self.post_process: return hidden_states - if self.config.mtp_num_layers is not None: + if self.config.mtp_num_layers: + assert self.config.mtp_num_layers > 0 # The new process_mtp_loss function doesn't handle mtp_logits_cache, # so we manually generate and cache MTP logits when in inference mode. if in_inference_mode: From 867a13781c40e92dc9d4a5f08810690aba5a76a5 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Sun, 1 Mar 2026 22:33:07 -0800 Subject: [PATCH 19/76] Minimize diff Signed-off-by: Keshav Santhanam --- .gitignore | 5 +- create_deepseek_dummy_ckpt.sh | 76 ------------------ .../checkpoints/iter_0000001/.metadata | Bin 212642 -> 0 bytes .../checkpoints/iter_0000001/common.pt | Bin 25127 -> 0 bytes .../checkpoints/iter_0000001/metadata.json | 1 - .../latest_checkpointed_iteration.txt | 1 - .../inference/gpt/gpt_dynamic_inference.py | 9 +-- 7 files changed, 4 insertions(+), 88 deletions(-) delete mode 100644 create_deepseek_dummy_ckpt.sh delete mode 100644 deepseek_mtp_dummy_ckpt/checkpoints/iter_0000001/.metadata delete mode 100644 deepseek_mtp_dummy_ckpt/checkpoints/iter_0000001/common.pt delete mode 100644 deepseek_mtp_dummy_ckpt/checkpoints/iter_0000001/metadata.json delete mode 100644 deepseek_mtp_dummy_ckpt/checkpoints/latest_checkpointed_iteration.txt diff --git a/.gitignore b/.gitignore index bf4b473c583..a9ce4aa0a93 100644 --- a/.gitignore +++ b/.gitignore @@ -19,7 +19,4 @@ runs/ # Sphinx documentation docs/_build -docs/apidocs -# Large checkpoint files -deepseek_mtp_dummy_ckpt/checkpoints/iter_0000001/__0_0.distcp -deepseek_mtp_dummy_ckpt/checkpoints/iter_0000001/__0_1.distcp +docs/apidocs \ No newline at end of file diff --git a/create_deepseek_dummy_ckpt.sh b/create_deepseek_dummy_ckpt.sh deleted file mode 100644 index 347f992ec69..00000000000 --- a/create_deepseek_dummy_ckpt.sh +++ /dev/null @@ -1,76 +0,0 @@ -#!/bin/bash - -# --- User Configuration --- -# Path to your Megatron-LM repository -MEGATRON_PATH="/lustre/fsw/portfolios/coreai/users/shanmugamr/Megatron-LM/" -# Path for saving the checkpoint and logs -OUTPUT_PATH="/lustre/fsw/portfolios/coreai/users/shanmugamr/Megatron-LM/deepseek_mtp_dummy_ckpt" - -# Path to a dummy data file (can be a simple text file) -# Example: echo "hello world" > dummy_data.txt - -# --- Script --- -mkdir -p ${OUTPUT_PATH}/checkpoints -mkdir -p ${OUTPUT_PATH}/tensorboard - -# These arguments define a very small DeepSeek-like model with MTP heads. -# Model size is reduced for quick checkpoint creation. -PRETRAIN_ARGS=( - # Parallelism - --tensor-model-parallel-size 1 - --pipeline-model-parallel-size 1 - --use-mcore-models - - # Model Architecture (Small) - --num-layers 16 - --hidden-size 256 - --ffn-hidden-size 1024 - --num-attention-heads 8 - --seq-length 1024 - --max-position-embeddings 1024 - --position-embedding-type rope - --normalization RMSNorm - --swiglu - --untie-embeddings-and-output-weights - - # MTP Head Configuration - # These arguments are taken from the deepseek example script - --mtp-num-layers 3 - --mtp-loss-scaling-factor 0.1 - - # Training Configuration (Minimal) - --micro-batch-size 1 - --global-batch-size 1 - --train-iters 1 # Run for only 1 iteration to create the checkpoint - --lr 1e-4 - --lr-decay-style cosine - - # Data and Tokenizer - --tokenizer-type HuggingFaceTokenizer - --tokenizer-model deepseek-ai/deepseek-coder-6.7b-base # or another HF tokenizer - --mock-data - --split 100,0,0 - --no-create-attention-mask-in-dataloader - - # Checkpointing - --save ${OUTPUT_PATH}/checkpoints - --save-interval 1 # Save after the first iteration - --eval-interval 1 - - # Other settings - --use-flash-attn - --disable-bias-linear - --bf16 - --log-interval 1 - --tensorboard-dir ${OUTPUT_PATH}/tensorboard -) - -# --- Execution --- -cd ${MEGATRON_PATH} -export PYTHONPATH=${MEGATRON_PATH}:${PYTHONPATH} - -python ${MEGATRON_PATH}/pretrain_gpt.py ${PRETRAIN_ARGS[@]} - -echo "---" -echo "Dummy checkpoint created in: ${OUTPUT_PATH}/checkpoints" -echo "---" diff --git a/deepseek_mtp_dummy_ckpt/checkpoints/iter_0000001/.metadata b/deepseek_mtp_dummy_ckpt/checkpoints/iter_0000001/.metadata deleted file mode 100644 index a9acd55675cee0f3b24570e1e5aa1e2fbfc23746..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 212642 zcmeHQ33wF6)(u1A@01=o!y`(G< z%IX_g9tpRjt*Q)G6bDL*&kux(%St2R=KdMMkiR4_x2!5;)@o5*i?qBjSQuFp86R0p zs1EYyj>z?&GIo6Ku!;T=c_)t@J}U22Qh#)cRU;J14>znVK0g?-xf@O}|GGSqyI@{! zO1;Z0q;P6MQB~>8%E(kx{y`_3f8jQ^aoP0gmBA2cX_wI{qf@F88ksESzjNYNO{Qq)Ge9=);NsaE+zc`Mn8l6)8A2Z3{&AN8=s{FXR^~^Q^ z>()oqZGfoT5K*@gqHZdpZev8NZEzO+(aefvDROQMZ-2ZjkuNFu0%i zgL6U^0lzha|;1*t~-2~v%?5~LdOB2G2afil6}8{VfI*b@h{TPMyPvFh%=4X(?WXSleGd4^k; zG0$-8GUgd>UB*1at;?8axOExx47V<0p5fMI%ro4&jCqDzmod+9>oVpUZe7Ma!>!Ah zXPk9SHyH|c6Yey#hD2K~cbipG9u4NE7i85n7+8cWE_K=f-%=|fysY>12ZN@ zE^s3ZFk=?w0;ght8B;D7xG@HpF+X#Gi9f(?%$Ur%z)dm0jG3Pc+zbQE2ot!#%`w1? zn1TzOh5=>-AY9-U7+^-U!Ub-L0cM0YT;NtU0Gn}8AMsc(4t|6aMK&&x)ex7+s;-_e zQ831SlJsKNWen%sx{MK;TbD8Da_ch2QEpwv(8#UJ811-q83P!%E@MpM)@2Mc+`5dB zgIkv|7;x({{F_^s;j-Mi3@>!ntu5R;rOwd zDY88>g#!m~3h0lSrf~dN%@n;nGKB*NZwly-nWk|3Sj`l@Ju-y@2X6}KkC~=${8-Ht zeLOOS0|##k=#QDEaQs-!6n#B1g#!m~3h0lSrf~dN%@qASGKB*NZwly-nWk|3Sj`mu zJu-y@2X6}KkC~=${8-Ht13WT?0|##k=#QDEaQyhBgoTe_L!J?*JWbLa7ttenMWFDy^JOI%KedcFx3?5ev#Hf{u3+h0sHB(D?TpEdx}) zqG`T;QiG|gYf@8#SHsB0hg~&(pTyc1;#D-0*TZqr zg^vbh2I!^{X&7ll%nUV*3e`1f7eqCVIorpk(5`WYO5TLuOUNsH9?n@mE2*=7lno&- znX`UY26EOX?KvQ`evDUWN$0Gem7#6c2S6XVSsy7*Xiw*?563+qvp!B9bmoc8jR_d)5a?7d{&2tY0HmvS)pOAYtR2 zvwrpgn)ThhLdosO!TItk>j*tLdJgP$jx+H!c)X(~jr04Lm4}LFkwfqm8P>7fjNqJd ze_-~E4Esg_?gayhCu>`W8(a4S^vWjZ^y5yxI}g`m1)jhW=&w8%uhH0&nv&v~w2>TX z_TEnQPL1G21#Vyxz}`s{UW~wg=eqjNh4`H6X!bkT?7HlCuGw|j?_9I%vfsI8 z*JZzR&92LS=bBxY{mwPJF8iHpc3t*6*LGbqZ9SUfLZ){|cS!5Jh-A)ftHtHb*(Vp+ zb=fBu*mdi?sC$ulCCfj=i%#Wu5zK|ml@DqjGS?S~vp=!w3F9y0><_mt<8#4j|u8r$1D6nl@cR_(| z3djWoLpaj_BbKV038I5AV|~MQ&AO{8L=m~E+ZMmt;<+PbL%paQQW$W zWE8h9BN@f5%Sc8!>zc_Z*#N4}bE1dss<#N)J)E)9bUuN&Y*n4-h&Dpl_C0`}(~Gn8 z$cQ_9QN!DS?s>*KyCQk`uP;ZExf(ifswQ((#Y?L31OniNL5?*Y2cqH}&i*t*^_f{> zdU|GvzOFKbquXW8+9&-Cj&wPbqxKtAj{lv-ku7cUq$wZ6QC^e5Ws?%QDam;GZ!D)( zfag)W8|KO5LWDrppe}hJ(W{4$B)O1jF?K`^ydRT5yp=t~+Uy3vgPjg~}o0EK7Fno;`pg@;|R(o;AQODK?Pq#D_dfm9|8>skxffTHL_m` zsYbRcA=St(C8QeJW`R^Ady|lAWLuI`&D@ck#Oe8_AA*i))pdS*8zas!>pFkJ>j$;4 z^FvQpB}9UCV?DX~HsR^kb@v2=)WYtLI0{wQ@1qMZBlpHKy*WZ$Ti;300yB#6TC^k^B?(&cIS(klr$ArL zkzF&$I9-sV+|{*eqj*TJxF_RcIS;098CwIlPHT~3-W^2o{W__G@?dC(O{yx$s7PMv z5PcXTsiR_2w7`s#gy9)De@1WB#%y?^qa4Uo!P>0%Hn%D>d;zD5+9@6~Rj`!T?x~^{ zDTYiH?u&r6x*Bk*NU|!(s7PLOs5MhXQnbK~QhTO~=vB;G`!8gwVC|j$W2TDQDIPLa zu$0&CsiGDshD;Ui+heu58gQyevMQIMrkD51_Lm1L0<#YE6|k=RD;?568>8hFWd`o{ zJ#tnNa|2SWdO{+War+*(F5`9zZe7Ohd)&H=+l07vNhW=CO8g~b+`8;zo9w$T*$1B4 zb=e1=*>%|mp4oNT2cFq=*$1B4b=e1=*>%~+HraLA2cFq=*$1BOy3E`6uwI0EztMc5 zCU!lY`AU-Ootngs7OP&f6LIS@oQPYO;Y8fJ3@75&WjGPHF2jkqbs0{?t;=vCZe4~G zaqBXih+CK8MBKUzC*sy+I1#rl!-=?c8BXM^%e*y!??ee98aSjdJ*VyhKIbI^+&h{W z4$QBMxxE=!3tV^Yl6}`u z-L*@$DGunhORoE4fHMK&{qpM1QB+?50*awXu+F|OguM`9*JZ>|TwKNw%&p59Rk?K; z10}aEW9;MBWejKBx{MKuTbB_-aqBW-C~jRw48^U>h@rT388H;ME+dBG)@8&{&bnp{ z<$4py;W-EHvX&`{d%ZkW(3cZ%Ne7Z|Apfs^I~@5vch_FOs&v#!Pvzbd*P$J_a@p|u z09_y=*)YUxVf7%%t2TfxndFNifG)|}m}GBD)H;`hORsG%$x{CR!X;~LSnYDjBp*@- zx+H4@=pW*eaOt(pC0WY=ZCiUbWlyw~!bf^LCikbbl?)+5U`KcO@ z*ZdRCPw_i{!Vp;3K<8jB!rAM`TId`s@wMAI zlKfUWutO&0WpJ>T+B0*omc#!T=U|Di-OiEZ_i)=+l+N&1h615rW@S;JqR`(nqo+SB zBP&w0CVgR~Dx6kc5-2SVR`?48p+F=roK_hss|d^p(!ZNuI!g(+$PI=9=6@emS{R%Y zSso$3B>yGcVEVbGk?=53cEHM&`-5{xloBt;pVc=KZd7a`=aKEaV+emiQAS~LWvHS! zzbX_g%qS=d7R)R!D=rOXOfN18R?e*q1!qOVY5oaR>!{M{WieGPDuN|}Q1R>_dD}Qt zWLER{B&{|93(6zm)UxT*D}$lP=#+43Nw9PVsgnD1qpEykONW#c+A!l=@{i#@3GGVg zR=l%Z28sG1QPHqhtke16ws=*d;qZJ`G!COzU5Hb(d7c#lM(_b_M(Cp7-o?$ z6%@+!o-NSOH$g+qK^&O(hzkY0ifxC+Q>hmVG)S}^%Deq?0k2})p%K-mO9UDu+79J? zX}N$`vF*^9>fRLs4H9jK@-AN~;8kopG^)Dr4uJ-VwnKS?_X&6v+YXJZ@*WXrkZ3!U zx8G9&Uhb89>?JZ}dD{5}%9ozXSRtp!Z3odrnmpC0mjzT4;;3reVJRr92>Jt|P_Q&q zTvnPPFRJ2AfxXA#L{)$9(adbi3!1V)AZU0Eg5oCKYy}m=HwsiZyG3*(M83g2z7TYa z;~Wi+y~HX{cKX+Xc8!Z8tFaF{Kkq58XvmKOMb769*qxA9(S55xMXxx!ByPvaleXC@ zc=FCp7QIMCo@(G80hM#?VWqhf!fDUOo`U@XqdUhQfG97ZJ%IIeQ~Zl2@Zo zPM_vN9o$TN&FR=2Uz;fOnS&TJUjM|yB*8i5{=n=RHQ%H`xvjJ_`QE(!p&PJn?hu0fA@b&_X>z1{z_pyib1@dAzt@ZpB{HfiXJ=`_K(A+Z%U2YqBTR5vAZML z>9Dpzq8-n1Tu=x-ykr{|+Hty22tBJ}9~Rni7*Pm){hXGKkp^*`T@*sUc##ha?KtWv zg#PK$3@o(cWTX(f%~gZ3(2j$YLg<0lr6@ zWmssaQ*11nkA;Rh#lF&ISZJtI98wh*ab|-?(|1(C$-=jmY!+*D-O1uuw>UERnq~g6e7jhtC#K{NHuNSQ;R^7G8~3YA z1Z>BsJnd6%ZPiRPC5C}dTeHlvJ3EPGLY{Vw8kL^*zwXaemssO=NBrU+>)S6qIz+4y z__Q_4{Lj;4#WEpJyGD&lPrKcVez8PXcE@q-P-7~8(5q!aUB+@(cmqx42j7~Hg@#5> zOW$9Hg@$H#p88}Z78)9qZTadUEHpH`)AYL+vCzmqv7GQYuQ~r7GR;F!=|tI)mUig z*mml`Td~m40jPZF!&qqOD6!&{wODAVE3X^#9u^wvZ@-M+B7}CY92`!OymjJ1XYNRj zOujDzmWn6(Q_^tDFt#ykT;e3GAL>=V&tcum4$wJ0*xD1FH3sJiI*~byWivawWnNb= zrCy5pKj)cX*Mf&_*BbRIJ^iQ8?dg{5RCy}!rJ#g22puo$IdcSg(U*9j;rZ9+4Yl>9 zvEB~~y55Wl zt`I_pb2vSi@v;N56GL~yQf`k3B^o!lIx*Bg((kzy3k~&;@ekdPg@*da1y4MUg@*da z>gQg;LPP!IotNIjLPP!I&)2`iLPPzd)4M+jp~H~xJhC$CcxM63@*_X`LtQN^6w#vg z+|N^5L}Mq`>;$qcYl7FJ_Qow~VzsW3(h-p)Bo6TcLCeaQf9N2TX#6BSQvUjDH!L)C zq&#?MPb@Taq#V3w5EdFbQqI^v5(^C-DX(rc2@4G!DW7Q`z(PYu%5U1t5<-U+ZXm?_ z3t+Hp-eImMf|E7^_4X|i3+|RM%z)lLb$r=NxG6fd+ch?O6V_TH>xjghmHB-b(7<)s zu`96H^`Y1g^<0I;Cf^2x>Hf6uJy>k=B{Udzy%Qh7VmE?fAA9mMSZwE~<%IoY!l;+9 z*v_xZ39&C6`z99K`GGkh_8pVn$6_~y+WXxppJB0^L9urSzQ$rZzceSbN#~-UvDj%) z-J?o>!(z99VpoLrV6mN_oDGjzdRZGG zxpA9ofQRno9-78<&BdCkf!EHrd~@!2~DV4!9}5kg?Q@zuiiL*G_LVK4#zI49` zLx+SOZO#-zhld^zzdybG6tS%FoE4CL^|Tu+IP3Pe@o;=l}9|BrL4+H`;Ly0g@lFm{|_>eu(1CB*8n6etp9f#j)aBv|J*zz zEUf>}o-Be5yI$moeh?v=A4vQh2z#tw6j>A*9}&KIxiViUB#RS}CA{xr&NrNUm;oBz zDDgl;+y}=K);<6Yzs(kE$OqX0`tcf2lnM}i=t7}r*;91bAr+wX465tPQT_Jp$15&hbK?H>hwQ8CKL_n z^eoXDI(_RkLZy&SpXFVrAMl1y4W!ePmu1u(FaHA%G$?lZtG*O!fOLBJCZJcH{_!7# zq9L80C0awL|7wR&DWub9d)Mhx{}QTk<(nKkMA1+yNvYS&Mt;GH$oNHa$uYSV5*D`P zSlAH>3tMvB+Z725TXKBZ0|^UTa{S#72@6|t9CZ>B7PjQb8zX{sFS+@yfE_hJnc?$J z7s)fab&*?g-+B%b7S{jYG?1{c{{MRj2@C80hs{UA!utR4#Yk9K|DUxS2@C80*WWCH z4G&0M;W7HZXA6)Q?-FY1EzlCTB2o;6etJYGL*Z@#><|H1LbiEcC>Cfl<#pS@*Mu^H zut&#QfhcV50?b$No>2E-sBX`EApE&d5b$|MKT<2b4u9f1p$wo8Bq_Mwv`r{>V3K3j z#(QizO&G5W+$=d~_pEnLG!%xx!shHVn<8OhbN0n;kg%{h`@W7ySlFEXQ5PgEY|h@B zj)aBH*+*xKV8e;0`mnPOC@>v6P^7{5ShADPISC01>*T9)kg%{$zJ5Fs7S_pko`r;k zb@Iatkg%{$K4K;k7S_p2D@CwjUvyCSZ05b;0wG~XFGI|=mk7C>M-o{V`e8j95D5Kz zl~9qh^~%9w+lY`W$6@K@M}#NbBox;*PMl|*uW+@H)oFg^*lx-FLM~_LOHu&$1?asjVe)1h(Hr%MD9B$^K8 zeQCLXSFP#LNb24d0tpgLhw?68Dd1IWIy9EL@D70liKat&gZBw|)tU~Crt%&UNRVhc zl(*kg0^W2-VuV~+WxWx{7|Y2}&RvZ=M)J-tP^$D)#v=T%~_X`v`pEpZU z-1Al#dvekm)f44mouiNY#aQXbk+Z8=l$JU3@ZQrkPhtmWFsp62tx_a~l@BhHs+8un+3H4~t#HH&J5PCk|@a7-^FlzKIgUo_>-Ki(SJv zQDWH3MrL5KYxpKg4EyQ4!C34XzKIgU{(52_7Q2RTqQtP9O__$puHl;~G3;J}QY>~2 z-$aRFpFQIuEOrgwM2TT9nza;*UBfp~V%QH>-h#!h;hQKi>`&%Ch{dkqnn9`I7PGg~wo_QkT8=T@AySuAoaN2Hm^Orp+z zpRqkTIyobb=wAPrHD8&!U#t@H$ZOmkiC-h4ANkgTR?S2cV;K0zb?ZFD=#(6tN{@U< zd8Sw;&z)~~!eWDy zKN~*C#A1V!KmYx72o@Wh{OPiJEEXG_{K@&=FT{2)b%BFFXG((Re=2_|6Y4Swj*`v) z-1OUgEHref`|9pxSZL@{ciY~TSZL@{w|&EhvCz<^?$BmyvCz<^ZgK1Pu+Y$@?zM+( z5kk9{6b`3I;wtjF!*_V16T9@GCpU|_n4cHoe&r#yjQsA%lr-E5jcv{vWpc(Gn93IE ztvykxdr{Fhv$I$y^iu)sm*~tZDxHh7*q*KMV^XoF?eEjm)|bY3i^zE1HF85Ig)2@N zgoTD~3Oq7wBo-PvDSUqFL@YFPQfQEO4i*|ZDI7ns1PcwF6egWD2MZ0I6fQdFQYM2X`p1!TAICyN{bTg}wODAV ze}opjjfIB#$IVxLiiL*y$7|PohlPgv$Mzd`2%*ESyVD&0QA5y#^cUEK-{F?OJrT?b zOSA|cwz^4+XdK0wS)mqYMJQSX8}}Y07MugmEFuL$hs@;XFmVr@mOf+ zkoo=d{jt!{A+zPn!?DoNA+z6`z=g(jbP5`0@Xwj4VBkDIX2 z(3oug|5jt6$*qAP)pw^nfQ5#RUGJwpfrTcQ5`t9!FYN^^G;~<(+U`{>G`aH-q--U%H7eazm zpEP(c78*JQ%^cpSWi+N_4XR<`h|W*Af!YVyzJq7gu4Sp>zHluik}) zhR*rFz4ZVV8an46`oUvZXy}}O@~6*Wp`mkr$)*>D&|&8t(5!dW@lF94L9hGnRk$kF zET=`)n*Y2dR&|&oKJsMn{&C0qHe*A^&=4M(m@cl~K}fuRf^84$dKlht4jQ{!@7^dB zXuK~yb`9A385SBkcI7wNgoTEVT~{^z1`7=xyB=@#0~Q)Oc71j5f3VQdu`BiPZCGgN z*p+$Y??UMCXpVQ7L0>{zR+ocJ?*6AxqtUFp?Bu>M<9}FasDs?wdmk1W>L4EutWUpP zL|&C;4QUL6Kpo`op{ZDCsDm7JYBMY})IsucTVkQ14szavwnFIeWR8Q>bnr{UTL82E zt!EyhE|;|c*CY2$|Doz~;lYb=vIbc9_n=QK*EL!?)?P^(Ee|U>QmE0GB0X9TuQ(bD z4IM3K%{dke4IM47zbFF>4IM3CyfhmN4IM3iy0R}88ai6Gxn>|18ai4IykW2qIy@zD zv~)O#Bl3r@l2l!~;O3J&QJlO%D17@VV#PfbM;d5mWVZ#orcTk`$`0Cs0PBrJJB0RVfr z4+%@&AOOIQ=q`c{PvZCiI49wqE0ku7lp3AV6EhPqSGZvi5*FqPYfnMK!d&6!aY$I0 zE3}=0goU}n2|*+*%oPgDkgzaUSTa`x8-^asO4zdYW*&I*5|K7zvSjDqyaWjg>-3 z0rUZQj?pV0m~yXBH_!)Ix;6N~rH>260)4>!VJqp*_rMyVAXk3MvB?G6n}IxKSgQ8p zH$+N}oMYwkQ~N(a!orp)-M>P@!j>qf{fLBxEm6+@4G9ZdqTK!$5*D^ZdApHcwDNKH z1lST~S1S>$dx_3>1?0d1>;Jw(Me>YGB>VrV-H@=b{$HMfgoXA075$O0u>Qa9BqS`X z|Nk-u2@C802c3b0h4udv10vY4D;#q4e=nEC(`O20xZ)6=jAXAzmsJUMyQX%YZqLHo zr{@a=0iS2|Yr4`A#@CC5GFM#R6@nyd>y#i%3~v`bDiky@$uaDM&)agEFkUzKTXGKn zVV!(rOC&6;lb?GC5*F6UZ}cHyVV(Ttqmi(%PX22K5*F6U5AG*|4R`c18XY`D$dwbv zmF$pn#;HPaUE{=g7FS+9R>)DBFmrn4d70&96`|>6CBMqpM&W^QmsAXHIS+G9*EnL~r+mBHXlKXEUAVb!czbNvK5lo_L_%q$KCEBrmJ z|7AtOjVlAQgBEYM-dWlCfu6x&kL<#J**$t^4H(d)f1t2ekE~vOdj|@859r@7zZdyI zNm-!K{6Sv0@vO4Ks*+%3WNM@;Kin!%QsOTuD+rYM%S!^K=D!VhBMmMn$|x)*4JyvB z3Iz)@Vp1~9YNf%7NVv%aB4Ql*Tg)nrh|F0gMt9>`GAoCho>4*m%_#CO?K;Ny8NKtxGT2ZSaUs1QJd=g8q zBMqpQ$dug6CE0PLdA%8vyleikf&7tYyl4L72jq`x#Es-1iv|(CzJza}3?IqyVfls= zzLrEl_oB8%9TpcIzBm-gm!*$C2mxyD7!T>=Z|nonPxO#J{;URw{%jBFNg#@S_JtnO z#~&vFwZ6nddJ->VTOaa}KK{fDsPz|mNZ%CS`b#~ek3W$(g0jsHRlPP<212f!k%;|Z6 zKw+&bP>9Ah6u71^;Z4$rm&rW+Iw(v?U~_H-f}Ys>Y?DcmlOO@t)XzL5cb+Bzwft)j z$w`2KYx&O}k~{NMpqBsUAvp;Sa4p~CA-OY^25GrbuO5lRnNhJ+hbOTCs_{lsZ~2>I z`P+EQPoe`{`#XBe?@UgDo_`l_`AK|$YkxZ7zb-X}{7L@jVMlbPQ$dEv^41U}RKOWx zfVce4yf3Kz!_@eVQRMfoc^(jXu1D`iADsf>Yjx-#aRhekx>1J? ze73>3O@|JA+QGP2hYo!H!FW`M4tyfQcvgoFd`7}pr$YxmHDRn**YN>?h|f_hK0a>8 z0g#Zwj7?S;-&@SAphZCiHfv#Q)u95LzA$#_ zP=U>37<+Z7z$P?|M)f(XDV5%d&1@Jgbf`e5I1bjK0-fkMLWc?xrr=hYN9j<3PI~mv zp#q)y=%YggIstN`4i)I^#|RxN&^eG?9V*b7kV)z)rV4kQfT9Wp zcWw%q#x9STTQ||ooEr9Aj8co76<8<;k;ux-)}lkgAQ$NnB9R`tOp6W)gDlk{L?Tah zy%rr32DwFt5Q&7@omzBA800}6LL@SGPioO2VURUCgh-?iU(=#P!XWSJ5F(Mw{8)<) z34?6XAw(jn`hyl75(e3(Lx@Cn_fIW4Bn+}ohY*Q0b7})#hCv}fY$jolmO6w;(UJ{>|N67NTA(IH`w3>`uwb^!Wn(IH`w!8(LUY(<=+MTdkz@^lE1*b_Njiw+5c zOj8#!oj@pW&J@`oj?EctF5_IQF~&@@@EK*K03}k~+!FbVxfVAokD{PMB5klxhYpF= z`eGeABo@in=+GgtHeI1Zhs2U_l@1*eE3tcY=#W?dJ)%R0#Jb}d9Xce!{g-s;kchJ1 z)S*KnxP4!T4vCoZGaWi4LbAnm*WIoJ!3_FjN=^aA8C zFF=m)0;H1{AYHrw>FNc@QC@%?;{`~%7a+%Z0n)<@kW4Q?vb+H46q`(W1pcf!ByZ|Zo0%WEQVsCQT-`oP7KM4_lqBNt_EPO@< zDL@+>7uvkhw?ZY3&s=1KNQ%GQRGnrlGz*`xh!i04%WPf+@z>ZOlHzYPRi_ydv+x=J zA_YkNZ8ooh_|-OuEgpLu-~m&2nsJv|_>B8V0h0fO&8#5*DK9{tu|cG5u#P|w8$53o zK4UE@KsH!!Gb`BOZ5zavkMsPErs_0fgIV~D4@m(M|CO5;+jV}R%v&h)cTnbSHnW1+ zcGw`2X4_+`PBV6zh0oYc3Xs|AHLA||jZU$Nv36@lFbH)+v+x;BNP(oTS50m0>f=uT zw6(Yyo6$vGa#Nb*rqqkVk!Q0mc2^_o_DADTf~A}#Kime9w&zh+n~Y9o;WN6D0%Xq~ zHnT!M%(Ovl`8fCOZ>ml+vdzM0^d$vI{7{=$LHtQJh@|+@rs_0fgjx8EQKSHgpJ4MU zh@WJGNQ$3ks!lT|n}yFfn-n1N(`{Y_@kKU>Egn0HoNMY%GiI6vd8LIEAo+7`W(D~d zcmXoc29dVGWdwrQV1Zfqj7v!YvcXcDS-}QZ+aR`loae7FRi_!t&BABgKnjrfRc>Bv z*SVK6ucpj*Lzy47nH9|TxD6s{wiis*X~t7#;WM5i1;}i#MtKz`?sqBm8fL1O#i&sLj^Z_UDI{74GWqx{WgR_KR2Z4g^N&VBzjRi_z$ znuX8!A1OfM8>V_-`&1i5QhX}{K*TpS3!l+~6d>{KZC(Y-ceFty#UE*^PBRWS3!l-M z6d>`(+Pn(lkGDZ=@z`G9$JCu>WSWJ~=tT;U{1a_v1^Gj~02yY3NZTNXKoA=YHw&LJ zk`y2tjJKH;Y%tLVvE}1Df2yfE%{ap>e8v<~fW!ygyx6W&N|}o&a|x7rw#}?ywmCM4 zq}djls?&@M%))2PCk4oCi=(^>6Zf^0I!vjr##1Xy+&5a>-b~!L5iI42d$kQ>KS*pp ze86gxahF+;q#-FlkMaqdS)m_3WrNuAaqhd;RGnr#YZg9Z4Jkn4->`WV#ILtOZ1Gsz ze`xAXGu|-^pYc8^K=Qv(Xa0dQZ=%fKL7BIyGw-I%zftC0Q0DzCX6(_CCXFq7*Eb8F zkxB}Xz1`drM@L#)+;r?PwwSGteX~SGip1%W4w|G$93bheNs7c7lJ1(MNE{>SsY!~& zNs|7Wq(~el8LCN&#CejDnxse^DH*3pio~gs$(p1{94t9UlN5=wB}JN~NE|OI*Ca*a zgvlIDQX~$UEbv6i6$FLG^Nw&38=zccNinW63!kx!6rcdb%`Fk2tgyH_0g8eYi2!Al zCMgmD$~~H-NCYU4Xp$lkpgf~VibR0&k|rq<0m_@2q(}rP?`x7G5ukjgNs2^(^0g)@ z5&_E3nxsesD8FfvA`zhM(IiD8K&jWn%oTV)9ZCc!O+As)hM*`1D2G~7jDyU=XLKM1 zC_r&@O9UuKTHKr`q98>gKsi>E6o~*OTay%t0A-*iDG~w7Nt&cc1SmP0q(}rP<26Z< z2vE+_Bt;@X$=4)BB0w=TNs$OpDl|!v2vE+~Bt;@Xxmc4Fi2&tFO;RKRlxsbaaw9=e z4p45lq!>4wh0j<;3Q&OJ=9UOh?z6Z#0g8eYi2&s>O;RKRl;<=_kqA&;)+9wDKzU1( z6o~-kLrqd70+cT_Ns$OpzSSf}B0%{?lN5;nWv3=75&_CznxsesC=Hqlas>(lzeIr2 zT$2=u0HvKKDG~w7VV+3oN>G#ol;bTa#?fZsGma$%C_rIwOI!ro+al*gC<(#;bs#}gFg1WO-Fijip+GMq)3=0SCbZr1j{5{QY6eWRg)Hp1WQnt6bZA;(xgQq!BVA5iiBC_Y0@H* zU|FO~iiBB~Xwo8)V7X406bZB3tVxSRg5?ffQY6gsfF>;x36>{xNs%zi3!1b@Bv@Y6 zB}Kw4?`YB@kzm=VONxY9zS5*cBEj;#E-4ab*{Vs4M1p0PE-4ab*{ey5M1rMJb3sx= zVR_aCB=BjEPTeRqyPmdZf=PHWrM}d2~ZTINCYUKXp$lkplsG8MIu1?QIiyj0A;%- zDG~w7ZcS1o0+julq(}rPjnf3_0tIK52vAySk|GhHwAUm>B0%Y+Ns2^(a*QS^5&=r4 zCMgmDNBptu5*=!ce!^~aI_`mXuM2J%Oq@t*mQACN!N7e?kqx*HqK zzmKzo7&&I)GxA6Q3Qs1hGY2U16w2&}G8d~eS5W2>%3Ka*K3|>r63RS}GB1QOU!~3* zq0CDu^R-aso79=_pv)^N^X*XP2cpah*RwxKsUN1)kK?Hou4i9kakFD@iEpC4M&K-0 zS1`v0OOo-XS@?{1NCEQo7`1{qKC!qxGsk8Er#SNcWGOPfF$s?K~oWj>lR9}8vftdb!1d?sa{3T2)VWmXv8$|zDu%KRH; z-UVgeug=`0r8TS9Hw*G55mJEqc56~m(W^RA=6001J(Rgilv!c4KbBH=qtxB;)C$8} zw#Dt)@HUXZsgCw1Tat_u&BA94BL%1r#Hba_G1}tx%p4O4oMOK`+frnlZWcb{EK-21 zQK-&5i!#rk%rl|PRqD(aQRX?6`9di3W$Mh!DDz^Ot*v>UCqL09776Fpy_5t2OwD%GpBbd z2tfxR12hRi2Oz^V2|))SqcjOY2Oy_u5`qpu&eS9X9e@Nh2|))S#hQem1CVny2|))S zb2SM;2OtYI2|))Si!}*B2O!sI5`qpuR%jA}4nS6UBIF)|q8NZYY6&s!Hw&NfFeyL* zh?^N5fIMq4a{>?rA?N^PohBjZ0A#%;A?N_)15HBE0m$c?grEbEZ!`%(2O$5^Bm^CR z{H{p|Iso~fCL!nmq<(9`dS78cLkA$uGzmcmAZ;}XK?fj*Y7&ADK#tTT1Ra1J>xqzT zf}$9JoL~tt`j~~!7(fb80ODpw2OuX~%$xv3K?phk8LdeOIsloVNeDUsnW9MuIshrq zBm^CR%+w?V9e`A75`qpuF3=A z{YsVh2g1e^cI-P~Lk~c^{>` z_fy`7p}fzk^1e)YU!c4%L3!U&<^72AzDs#OfbxE+%KJU#-AsAEh4OAy<^7ZL?x4JX zKzaA6@-}X3&4cyK!e=xh1!x{@MJg)IgB>VuTgrO~l((}gZ#w1eN_mfg@@A>>4y3%j zDQ|x$?@6k>qbctQ$~y|mJ3*EAY|48&l=nu;dkd8JZdKk#DDQog_aP|nGpfAnDDU%>cP*56y(;gA zl=mIV`#zNS3pX$JUC8e!^Crst4V3v;7W3$o%TcTEU6gw}<=!dJO_Kju?!6W_C;hJ= zg{1$nQW~{0Cui@akPQH=lopz#kSzeLl!G-%A)5eLDMx6MLbd_0QjXFjg=_?1rS#Aw zg=__2rS#Dxg=_|3rJSfq3fT_8N*SR^3fU0AO3BqEg=`66rA*Q!g=`97rA*Z%g=`C8 zr35ueAsYi&DYHD0QbkY_macdy7gmA`zh6sY!}NfbyUwDG~w7lbWPR1So4XNs$OpUehE+B0zaplN5;n znaAPSpPHD#qPoUz5|(Pwva!n@5TJq{JVPwkcd<^K{QE7k%&}&&?H47QrV_S zibSOHrzcYO5ftS}rSU;D!EY8M!A}ZMq{85q_@GZKi<}dxC}@$0n%e7;_#@B+Rl>lNN~;;$6C=NSNgz zOhF~6bZ9z)}%!uk?^A~DH3Mcu1SkT zQe(F+DH3Mc?}?Vi2YZlUX-z=LNrGl(;WJv20+e8Jb4w&OI#}GC{EC7Ui9|wYO;RKl z@7*;?kytSI)Fee>k=kFA6p4l5P)$-K7GondNs(9pjngDWV$m^KlN5 z_fqE7l=*Hb^P}p_&r#+lDDyK==9i<)3ZIX7n^M0cX_a%W- z?0nx_ij2)>;WNG^1*r3FRcHQ_GVh?we?Xb{sWUe|#99y3GYhieN(xY$Tak*2*6u)= z+fwF3pv;}4%nIE;ol7^X0Cj;Jb><0_ zIhQh@24$Y2&Rj^Dr%~p7D07KAa}{MSqs*01=6O+Og>HWtrCva(FU3zWRE5x zB;SJ*Qm?%kCVKC_B-?`%(o~ZWlIy_6(O) z91l)NmL?%2!-Eqtz!M?E2#R9Rm17AphMR@Y7)c6H(B)=E2O#4uX3hwvAOszNoTW(! zIsnPnBm^CR7@CBj1CR<$LeK%o`I>~F1CWa~2|))SS85W14nVHeBm^CR+^9(iIsmy% zlMr+Oa<3*K=m6wVO+wHC$g`S+paYO~o(Ne_P!a+V+&!ruS~84x%))29PYO`@@r5ey zca(P%<^2ZA`>QJNF3P)|^6rH4?p5VYC5H#eTSh5n;WHYN0%Yfw=9xqDAYz`e)%@(!oG zBcZ(GRe7gS-bs}AOek-GDsKtpolbcTC~ruWcOK;l=o&R?_H|A4^!TIDer?&-ltW0U!uIvQQjA!yl<-Ven5HOro8V#c|W&#qZgAe vKAw!68^|g0JmWp{&#L^O@fE3+XM9QioEHhFR#gV_OM;Ors`3L>`5FHQf>um` diff --git a/deepseek_mtp_dummy_ckpt/checkpoints/iter_0000001/common.pt b/deepseek_mtp_dummy_ckpt/checkpoints/iter_0000001/common.pt deleted file mode 100644 index 650ff74769adcc5a4c629023626910850ffadcd1..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 25127 zcmcJ22b^4G)%S*OL-~azV&n>e_c;D~+e&6~d$v)>i{q%F5>xDCi9yT>~)KOFW{Tw@W z{8W9=>kayw8qp})IGpcZaQ?AV_nG{;?2t>7snaSLV=A&%nyjhguVG|qymmI~#py7r z$7_>APJd`}=%vYFr&ZPhu|MuryU|i?X>$0?y2`cD6!+_cMr_p*=!gJn#QTAcoUL3F zo!Xs7Bkotz&Lwej)Sjs|m1EJ^Z1$^Vz%}+v9acF$%o>eGaeveq040tZX|i_asLFAG z745aCJ|FiR$u;Y}xD}0T(BD`eSiEs$Q9o@CtQXtH7^_Uz-gY$VpYPaflcQ%2zdX6t zMaeN**!)7Z-j4eHxSJ-&&aAIo2UE0cFdkNuacoP~e7a=aN@ZPGn}dXwCdXl6?3H6y z2jkIjJgQ<%adPdI%FzI(gXX9g?W>NKhH-Ly;8zm=b9gCfQW3cE4be&lZ z7@9CH4sO0PtTru9k`reu*F|p*HXNiKolxAX#f?U%k4wES&yWcF1AX%Ai2!VXb(~Z= z1&9`RqfuueF1D(xt&MTIqck}ghjOC>qj3$5PCAPE^*Fiy%w-yGoW@lfj3S)f{!2)c z8?0!%0izbBadN{w>uymwhEs%YvDI<)$tlyje~VW9i0%Tk(;vmwjiZIQk=$sdaua|9 z1VdH3ATnVL2iz5?ry12p154J=>h^qf!?>E(qb?3CIdw%8$MeYAvMx8CK5so|$Nap) z{s8*|mTuuDvy~Ik7Bw0N*ky9l<;t}H?fIy2oeN2BHe0y?8>9KSx-h6mHC;9iaG{fS zYTdY6TS{(@LuAiBs2A(IL|aI1A$qxJ&)sLyoSYU`W4&rTOgmi+efspSAO7{N%U893 zr<00?!|sx^EeEP@878+}shr033Sx2EX^f+8wLSphYjxbksOXa1%C8Cg?MIzzEAEby zGXfRz_96Jr-k1(F6(ZSai>0yK5W{wQn^0o#Uu;iKF0lewcd?- z!|FoR9mmP$S@AWntGJOkvg)AOOyg0qWg0#8Q@IBI!dmxL)2JCITm3F;v5l&=PM@Sk z8If#T7CDR|HsWf*6V<5S$ar!w?zGyV>g~(Azuvez>R{UxH3j`AXRe4Id!u1l%I%mr z)L9SB-9ehJrkAs(cfa$Q{CSq7TW=6&BTMY<^nDg$q5<-Ns7TH(`jHEQM9iS)6i`y3 zZ=ecE8!aWbp5Fb4OMWo@x)UB)xgKDCyNiwLXi!b&OSgIMicYZJg%D{OPUD|Bi%jdp%;-hX=p$ZZN1elyRz3UE3kH2#B zW6yC(*J@OIgB}ia+)Hj-j3$N{gCTR^`2~>s?h%*{yPXjhzmVK+S!AJBa&13O!T5uP zxF)%?&au`V)aMsFkT)|<5L`;U5A%v7AR=7m+@(0>ldwSu zkn-)3`*2lM2zzG5;X@lVJKZ>$T@f3HU$q|m5!ooY>&#k-u?Y}$;Ds|sS5CkjJgMOl z1YW(3y=pj1?zSSKl68W5oKKUB_9l0qnTm8wIFy0m+Tz}TyabVszmtnsG-PN^`_V9M z4@T8tG-@aJ&>duL8;9U{Ec~8YT(48NLA4f*um(~7y=K+~79Z3gbxM#*fEdI9RBj1! z@7c-?u^-U6#ZVF@P_@$pF~kkX|NE>|c(8GAI9jTX``CQGiEg|Qca!_hOjR}jqu!1o zlEUe?qqJR(utukY`H8Apr~`sbllx;2 zB<_HYopj?gsm`h~-Uo8ddq2PxyBZr~bAgCCdYyinL_Vv4X{Xi4Nx?8_bYn|uE{#i| zq+WNNwRab4x;AOd9D8{ZUz9W}H|Cje;>lsEu!*X)54(d123*o|snCo1)CH}2(ssJb z8;XT@aE{{947VEdH0G8?s5}SlC8gw{sglg^O}aU0;ciDU1y6t6jjS_TN_w-h75rMb ztbRnj?P8VUT2iXNtRV_4^loV&*n@)=sZ-h?)~RuVlr{k_Hd8feGIW!Tuux`h3}SP+ zMoPxZmGwY`9)|t_0nR~aVdFxlr;&8p^Vz1d2ZW_4$)K~XEY#BXyo1SI~zW$BmwA=QR1>OoF8 z8N4>`&Iib)lVb|9s6G!~)^X;227?i9GNKT0v;HA6RQvUMH@Qs8nzD_j-|55ZEv%!5 z`i2@YP8*B_dg#<$T0d+>%Z84qgYZBhLmcjC0H;BOmOR`on(m-gQ2Hal&Gb!R5u^jF z4I;?JM#qxNEB8n1rtTPvj5jsY#ZAM3jhcgQXOM29+Yxm(VS$!zO2M_gaVzTCrai$g z*>IOVn;J2+N*vEuX;n1Fz1~u_K0h37DovC;(v?<}&Vw2lt&y8PqNWgG6%PAeK97#-ONk8vFwtVS8XzL>`@OA*Sx zl?HoUU=9gY=zc(8fl|VTWj@|TwM*R$!C}=7L}u~?Ddj%o$U>a_m!yqW-j%e550pI7 zX`ME=7VFkx@+7}SHp6ZZD0%XV6e^4$+C0fqTmgn(t5J8@j*_QN@4f<8iXSPb3^;0o zE!%@`BYE0P!wt{jsMCYvYRS{xd4{!=JVRGC$Uq{@AQqBmu83_5U;w~%pmf4zKC7?{ z=)iEJuCt3il!_%;lRSGx3NdiFgYwxMTiez4)<*IiM*?{bp{Jo$DUkgi?$q`m9G>f{ zX$XeK~3JGN)#?cfpwTr3~!tRGb#3}S*@!Lw*#`fGobo|F$a0k%VhYg_a?8&1stSQp~1jB z{d)4+6>0P)r@|TFud-KwTut5*+-h3f(7mHnvR6jK zTk|N(Q0gp)((X>)1}Nok%K&wn)dJr>b4bx%7#ztv&`!fw`0iFh&O-DoVsY;*{G$Wl zjE7it@~#4h{*}g=hmi_zD0%nv?h9*szVy3S{!@xXuuOJQ+XRW4yr=L~$LZ>9}b(-!*{T9xw({IKmqQ{v+hdZ0SPZGb=A9fL$gLUpfuH^kP zX%E&DUIhk7ojFD@jj8#P4|u#Hbb_xE8x2gKE%~6|ZU*fmCN}6dJFVnH-VG~k#LbAt z3P+^1H%2V#!-W)MWUP!efc!~50)yxPfcfO3cm5B4Bsb70BiBOq!8WFCm7hxTu@LAf zkQuj&U>QP{3Rl)Q13DjfZo#+)cLrgtgJOQd{TpE--cYTN8xa|7*iMsA7W50h3)Ti= zKiw{DxitAyzLtD_Ah#Ciyapd4#nBl=ki`6_S1LF6gMo^JxxhdfsQ?ZKwl^|xx8GBJ z5I2+0%z88nahcp=f$IRTGUJ0@C)m%QU6JF7rBj-fBxZe{XYjd|%IQAuY7qQj1W|&7 zb*4lFCHef!R#7+x%zzNT8Vu6AHAuc7&DZQS2gw(+sk0-WIJhsZRBjQroYVAywSw6C z%d^=^i^DvaTvdm!%okKF38DbegGH#=Z#s7aCsSgu+4{jts*Ygb zrOCJUOdTSH$3tg$w>F^v1nV^U_KIeMx!6d)v*$~PNi{<=n5afhLh{`eSs=x3%E0fr zwMD&MtN{L3W6+cNoP6KKHcjcUL&)t$a>exSr-Xo#JJunttWIUbAGl{y?!PZu=s4@thDSnuPzpfHP$<5uNJv>lQBU4+^b8B{0-KO8!_mL|aS# zG`;)e9asM12DjVb0W~`Cmi)QcBo!)Ff^13`j0xCZJaSvK(IP?nUuT`*;lU?=`~T%> z6~~w+e^+#d5d)Amlp&)T{{T7BOi-&q|6^@A@Skop!*Uf^x+?jXM}x9PaA5Lp-wDG5 z$$#=b{FQgt>DknpsVQe3HTF1*Tic*FuqT~DTO%YU&M>y{Ft#0{Rz`Xm^QglFOA)_n z1o28VI}|9(l23iaB!iJnn|WfZc33{rDm2b6?2K4A!ke zM3l6QcjUnaEa6et*ik?f_cX^{IuSN0;ib8LBP%&ExQh|oDs@>PeiLy2=3S_ zPIq0lF~KWDw^TvJKH2p&4CUzrl!R+133kiXu<;NYF8AiNMvBrf;K}IZwwkYNf|{Hm zO1rqeuz9Kn!?qg$W!cjr@^-bf8)|6>>nEl>MMIO7G<1}vV>kktJcW?8ah(RW-AIU( zLB)AHt0Vi7+IoN`p~47*d1zZ`NktFsR6M~!U6s0JJlx^H%4&BMUBuBB6R z*|a|2^!t?D^@T zy0i8Y6G%vXrgked!*ArJDLKT}RS8i$12`T{@$i@D{O9A4Sk01uwn2EbmWs??7P`ZN zIlgTa4#V*)d3~yjKx+w=;n}tEhN}i@@K3QFjQH?354WmVBlu7*v9*W@ibJCsj3l-l{TQ9yiHcwM#tOyf6fNbL?}3 zOe-vl;0?EW(B5th&}48g17+rB6Jh7Johv+w&%m`vdLR_);c~kTpzX^%2$ttW5aj9f zVw?q2Oh8X<7toMIq+41ZB|}SbQGF0&)MW62OliB(!&6h)%1i8R=V?)%szVH;NODva z=4V=h*&n;D@F?269q|OHAlyM0vDJ~C&tl(;oXGd6uUhj zZksm%@mA#=?GA!48bins@+kBX*()5!)B;-5+M-z|qZ)Yj$_FP*P-~GF|Gyp@;E|OKy=(6o0Xcg(!VMlR9KTCs6Jo6xxzp z0d?un?x{9%q<*h@qK?^1xIA_*Ks}kr9$ENKM!Qwila>{EESQ8MG%PQIb+yo`$98XF zvHEJ#lb0r^Rj3WbttGn;Fl8Il7oncF`?7_SoZPCgc_3>P95HrF622dBJVEE4#9*vC zq;ZqHxmS~)%FT^iEmmK^wnmZ>4t|JQ;a|)#x?DT)i}hosiLMxrRvNkJyUbtI_|k@Gd1Y4_f zNH1Am<53WmI)us>8vqCci_j+P-6pcuHq@w;EKCSg6-=>&T^e3Aq1ND*n-Q=^CY-`# z^7Aa!7-UCiHD@3jsZEuoq%Dln0pD;RV?oH@s`x@Uq69_AbClXbK49QAoY^AUJb#0P zA?b@W?Io97w8l!9H2INpSs@%3&1UsemsY@3gsa}p-dnOfed?R0biInq?5~XBqGp%0dAAF zn8(}Y(L5ALE6OgHkfhS9eI$Nie8>B$J3*LKB(DnSL2fA= zU(npxhOlYvDZrWDJqjUbQYO42ltD+mTBkK0j7gKQVHp%p?WtOu2Y{d&ur_XCBwAlg)YxS&E@P5*eAt&{&F4M{abuLPLoGgNm%WtDcZ+Kdxg4i98~GU;8f!))hKbO`e@9Eebd9x-rwd}P7(9_GxayFoTC;zvc)eSo~ zo>kkx3UzxOu_JD;Ddm})h`k*tfW=dhu zY2V(YUS!2l1|mGYrn+Z?{9=4TxVx4JNP093XfpqG-4@5PyFUP(Vt& z`vQkXA?Vg)Jei55pn|!H-)RMLL$(=6m7m7uHb{c7VKk2)ii3raq@M*V<5^b&l0 zNL*DAyGr;Vphbx{PmA`MDhMfc8iRUj9|DdidEgh|9Sv?w+EjY>~m~CRFV$sLm@J+{_|{HC#f5z$Wwgw1;7fmPHysDV<*BbFuyo~ zNg5Pp=9Af%0AH5C7hQt7X>;0_C%W;Z`0fKl35_6=mM8xeAbDzkf1syjz6uOa$MKMm zezvc%g=ZVOFd$ST{YXB3olrU_3y89OgTHWU6bga}ClJ3$7(=_GPJhY1#ouz{Mk9Tj zO?>SEBN%yP`woBSmC20ol6U(q8jE8jq0ms4O9PUux9xjC*mI>n^FjeFvs?6!c)fD9 zz7P1!14N!BFe0ER>cI_z;p_^ctX1|~C_3iSJ2LwL=a@Q5ha!yHU<97TQ zp-$lZSU7L|#pS2G<8HrngPMnOfZ_1i2j`dCPq?sCL|^4M9`;i$h#Hm6TUAnLKT{*^ zDy9h_5mp6Cb>u(ih-(x%3|Jxfo5xdy_6tYLnAA~I`z4xZJwC4oFL)~mUWpU?mE%+H z@hQxz$A0!}^`kB!Gzj1@za%~1XV6~NM?onzIR=au*i9IJG_*enLJG)6 z9}nwGC`|>+|5-39XTlITBEL|jfIm@?RR@l9@?U@`&X30X+_e8ITe9tho(=Kp-vDuo ztTbo{`utrJ(snPwTmi7$Q2z2iI5Hz?A(LVMWJ9*KFcr0n{R;s2EZNNj5M)|bzij{3 zd^9Rb^HEsqXjA(S;QlV61Q$&D%cc$?eANL-vju&X1s&E1k)i0)Kq1%^2#a;Yb_nqN zSy~BC6~!T+C`aK?U=`(mjKP(#5jdjNG;vFIn8wX0Iy(ti#JocEfO&!PiSVe;4i`3A zvotR4=L{P`G&@2_l-NYdE*>Tna&mSgR|t@Payqy~v=LB_0*d>%?2k%9WCBpTNfJYq zq+LVc>}nVkaGHj8rY&u2HGM`ZZk{3}CyfeZk=iwZBTWI%Erhj?W>ZG(;R0w|BV~Lo zjhOLSvx_i6X}~O(aSR8bstLW0PfX$pBSD7E8(%AjF#*=fXjSgI3TqX`K!WGnBDV?0xb1f!36Jtg)M>Rasw8kb50VI0H^Jj{7Rw9ZPWp<1(OK7|vkmyrRbCOuC(?nZ3B zOoaq!Mv#PArvkj5;Pc$$ELug!R!#CDSjfC5$mFoP?* z@}@)pzJ5gb9tP=kBQO-y0K$tm1I~)1m}m!Cxjv$6rC>}bx3-%D(bH{>6wWUrX<)VR z3f^uZD6PFRG(VZs1g1(TgC(d>7l?AY3{twq#mOy&K-MfHP!)pV@sR+@6}uI%yyS+Z z?|7_Rgp;K8&&a0*l0?^6PkXSLC<T`CZ6`qRa%m2-k+rAD5}mCMJi1zh%0tkV zQcic-qBOCc1F*l5;DeVi9@8$^)vW=`s{D|vuygr4C70WZ{!ofjECpfgdqQ(Pqw~Til@Rlp5hOrb-_u^{2d8dpn z-NH+r??2$sYbt_5ed5k z+AC}V}?JqdjSKn3o3Ft5qhXm;2NuMy?VQ7 z)Dt?pHz0E1QHUAaefS%NLkSILJ`AAy0sjS+`FaxR6$p3R3k>&hi}U9- zmO?xC2MWG_!eJ(HN|}@@aP~~CRWiCjYAk__F_#9TvH1YacKXCftC3Ky!Q8aON3~V6 z2z{o{yFL2gr?MG;9DD+ZV(gT4AQMB>FvP5$qFhZ^kAq2Q)wP>7dWCQ`J$O#=g2zZs zrq=kM%)uT`t`-abmRG+W>W|`TQ){9V6yJ`k?*(>86eJu)SRfvA%SLHo1dl#B9fR#* z9hT>43uDvOPjR4!7}nwH+$$(*Fkn8&hnNG<^Q_8Y)2<4|3bZ4YGC{H2ye6k$>2;BX zZ;$aA0~wkbLYH%_39%zDjT06wJwxefjQ@sEiXBUwhfC?B>z=9eDraB)?#rD!`--W9 z@EdTA>lEoJs3#N9@fST*vIhdeqmulf z4VDbbwg<71(M*mhx5zFuwVx!H!`G$JCa}1c6$VV z&8}&Cc1@R~QJ!Xhp*h@ckL0g@Y}*fTNsrPHeni`SW44So7ys~=!9by`G?>TA*sZ7= z6&-`vdd{hIFl~@W3q~#GFh0=EGZglid~k;<99CzA1vcpwA&=Dnv@{$gOX&E~+2a64 zU~01Z#ZU_sea z@vB!Ag`%NiPK37b2Ki|~C>&Ojb3mkEb(E)1pwL_iC=4Cw(GsFSG*0XpnoY4mr0wwu z9M}+ime8|j0s>!79mDc^Mu4y-cr;6?dbN6S%rXc>ql&}Va6+c{R3y1BF{9>c*`MznQ8wFI+4TNl?8VYfWrM(&G$mr0pDexG0&x*J0EyOr?6$Yz3 zbBwnV<37sa6k`-~h4OBDn-4_oQlMqAfDGyFMA|38S|Cx$c4KAHGrV^=-sBZY?ICQ| z#ks)seJ63&RMt;S2(bZ&(4BK|5+dmw(%$8`R8qyf<#^+}fe`!~R3p8IzfqkQ$og_Z zOph}7>h@kA(=DXze*Jq6>>~s)w1X!@ zU{>bz?4xYsDnsCd9~$7-g;M(%A)G`H<766=!ymlTSw5-K7mHpHuD`2 zH+<}}(*Go&p0LTFZW(6z*-zzg%G3-e*FpAaO-UNu4+d{gAOEq>0LPz%Heh9amcP;3 zNt+}4oVxSk@tb>0Dv!tQ;!_gbAD^y$o?U3f^;va-Z-Z`VZ&1ZN4Aw6YCJ~Ak#(7dy zXZ}Tks3E2=ym*f^%a_Th132~_AyAnX<7No{sLt=^yTYT2$X;?* zQDwv`3KTus_x6W+fdxMeXx|4^3PhHoUBO>VS6Rw^{{SGbQ4OnMj^cnhmMeuRKJ+W% zrveZF3AJKZ3Cxq!aUdX?75du`1)_lyxTggmt1f>e0Lv-hL(JQc@he!JJX;b@KVhqX zJdT{gVt0T3`%?gx6?HDVAo=?VzZH5PMTz>)d@NGD7p_b|f9_Cfu|YuZo0 z)c%Eb{{3(KEmu1@tjH@M%=jO)?zJh1)zdDuHT+BC!?FgyuvR1^FgpZ5|C&Sg6__9q z4h3LYu>gE@R8>xgp`|dhxV^0Y>F@~%-383VI8*j<06Rj{DDeh6uWNRsx`a;*w2SWw zwEt8%FC0j@`{J`Jb`<(8OGxl5K1g$GobXxG{41u9w6$oLT4B`nnncRJWZEu$S(Qwz zr&mVjIKXVgXozu4F{s()R)4J!l+Bq(5ZuOG#3BpQ{ z19IaXLdUN*|9pfUg8?qJW3|$Wxn)t{QbLK5hn&{|4`fdNpJeWh$FXru`>OU2<-(y4 zu1jC?;CWr!p~bC;z&)6;*s$XPF1{|oOV!sW8g{}2KuHk>HnNXR*mVGJ|53G_AgTc1 zgePjk@f=xf`uvN(u!F}T1qW(GatFmtSxxSBfqJQ34}YE}$;21*FkX?x@M<&F>ZNuP z5Zq}*-5}(BcuWR|Y=DY&g$(!E$v|2jk2hX8bL3Gk`2L-~v1{+t6#u*MpkHG@acWog zHTI2ejbxm9!TIN`{wVw8C)%&}w=G=4+kf*NKAf21e+OT}NmD!XC1j=fvVNF$zTo`r z|EpE(*vYsc2Jw{F|AdHd$gTQ=|9w(ZO_x9vRZtnJ%2pSgAWPOkDJ$Gq?J zr{Q$?--R=4xZwQ%IiO6%6rYnnj0czge*{){F4G@P!Q$ufV2aRr71^&IqfA0{1B##H z?q7_tAKk7Vrc4F2&WfKW6~pAzx4L)v9;tJSpA)P3A_Cp4?pr=xn!Cl%lZw6vzV{1f yjyU3sfPc!N8}gncW&UJ;^DPj6c0NFV;S!IaHl6xt Date: Sun, 1 Mar 2026 22:49:33 -0800 Subject: [PATCH 20/76] Formatting Signed-off-by: Keshav Santhanam --- .../inference/contexts/dynamic_context.py | 26 +++++++++++++------ .../text_generation_controller.py | 21 ++++++++++----- 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index fb120c2d415..8ced7e89a5b 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1353,8 +1353,10 @@ def initialize_attention_state( """Initialize attention state so that every layer can use it. Args: - construct_graph_dimensions (Optional[InferenceBatchDimensions]): The graph config to use for constructing the cuda graphs. - is_expert_parallel_dummy_cuda_graph_step (bool): Whether this is a dummy expert model parallel step. + construct_graph_dimensions (Optional[InferenceBatchDimensions]): + The graph config to use for constructing the cuda graphs. + is_expert_parallel_dummy_cuda_graph_step (bool): + Whether this is a dummy expert model parallel step. Return: None. """ @@ -1621,7 +1623,9 @@ def check_availability(self, req: DynamicInferenceRequest) -> Tuple[bool, bool, return request_can_be_added, request_tokens_can_be_added, kv_cache_available def add_request(self, req: DynamicInferenceRequest, chunk_length: Optional[int] = None) -> None: - """Add request to context. At this stage, we assume that the request is valid and can be added, as the checks are done in the schedule function. + """ + Add request to context. At this stage, we assume that the request is valid and can be added, + as the checks are done in the schedule function. Args: req (DynamicInferenceRequest): Request to add. @@ -2085,7 +2089,9 @@ def update_requests( Args: active_requests_mask (Tensor): 1D Mask tensor marking active requests. (Active request length) new_tokens (Tensor): Newly sampled tokens, with one token per active request. (Active request length) - new_speculative_tokens (Tensor): Newly sampled speculative tokens, with num_speculative tokens per active request. (num_speculative_tokens, active_request_length) + new_speculative_tokens (Tensor): Newly sampled speculative tokens, + with num_speculative tokens per active request. + (num_speculative_tokens, active_request_length) Return: (Tensor) Newly paused request IDs. @@ -2300,8 +2306,9 @@ def update_requests( :, : self.paused_request_count ].clone() - # add_ and fill_ calls seems to work as intended with sliced indexing (i.e. x[3:5].add(...) or x[3:5].fill_) - # but when another tensor is used for indexing, it does not work as expected (i.e. x[y] if x and y are torch tensors) + # add_ and fill_ calls seems to work as intended with sliced indexing + # (i.e. x[3:5].add(...) or x[3:5].fill_) but when another tensor is used + # for indexing, it does not work as expected (i.e. x[y] if x and y are torch tensors) self.request_kv_length_offsets[self.paused_request_count : self.total_request_count].add_( self.request_query_lengths[self.paused_request_count : self.total_request_count] ) @@ -2324,11 +2331,14 @@ def update_requests( sampled_tokens = next_tokens[self.paused_request_count : self.total_request_count] if self.num_speculative_tokens > 0: - # new_speculative_tokens has shape [num_spec_tokens, num_requests], slice the request dimension (dim 1) + # new_speculative_tokens has shape [num_spec_tokens, num_requests], + # slice the request dimension (dim 1) sampled_speculative_tokens = new_speculative_tokens[ :, self.paused_request_count : self.total_request_count ] - # This will become [sampled, spec1, spec2, sampled, spec1, spec2 ...] # For every request we will have the sampled token followed by the speculative tokens (i.e next indices) + # This will become [sampled, spec1, spec2, sampled, spec1, spec2 ...] + # For every request we will have the sampled token followed by the + # speculative tokens (i.e next indices) next_tokens = torch.vstack( [sampled_tokens.unsqueeze(0), sampled_speculative_tokens] ).T.reshape(-1) diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 411ec1a448a..4275c4baa81 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -901,7 +901,9 @@ def _dynamic_step_sample_logits_and_verify_tokens( dtype=output_tokens_jumbled.dtype, ) token_order = torch.cat(token_order_list, dim=0) - # Rearrange output tokens because previously it will be in the order of the sampling_bucket request indices, but now we want to put them according to their corresponding input ids + # Rearrange output tokens because previously it will be in the order of the + # sampling_bucket request indices, but now we want to put them according to + # their corresponding input ids output_tokens[token_order] = output_tokens_jumbled mtp_output_tokens_jumbled = torch.cat( @@ -913,10 +915,12 @@ def _dynamic_step_sample_logits_and_verify_tokens( ### ================ PART 3 This part is to do the following : ================ # Create the accepted tokens tensor # For prefill it is always set to 1 - # For decode, the first token is always accepted, then we compare with input tokens and accept the next tokens if its a match + # For decode, the first token is always accepted, then we compare with input tokens + # and accept the next tokens if its a match # Then find the index of the last 1 in every request of the accepted tokens tensor # Then these are the index of the tokens that will be sent to the next forward pass - # In the example (assume 1 spec token, 2 spec tokens and 0 sepc tokens are accepted in teh first 3 requests + # In the example (assume 1 spec token, 2 spec tokens and 0 sepc tokens are accepted + # in the first 3 requests # Assume input ids : [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d1 d2 | e1 e2 e3 e4] # input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] # Size 11 @@ -997,10 +1001,12 @@ def _dynamic_step_sample_logits_and_verify_tokens( ### ================ PART 4 This part is to do the following : ================ # To fill the speculative tokens and accepted_token counts # For prefill it is always set to 1 - # For decode, the first token is always accepted, then we compare with input tokens and accept the next tokens if its a match + # For decode, the first token is always accepted, then we compare with input tokens and + # accept the next tokens if its a match # Then find the index of the last 1 in every request of the accepted tokens tensor # Then these are the index of the tokens that will be sent to the next forward pass - # In the example (assume 1 spec token, 2 spec tokens and 0 sepc tokens are accepted in teh first 3 requests + # In the example (assume 1 spec token, 2 spec tokens and 0 sepc tokens are accepted in + # the first 3 requests # input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] # Size 11 # Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] @@ -1049,7 +1055,8 @@ def _dynamic_step_sample_logits(self, logits: Tensor): indices_list = [] # e.g torch sample buckets will be - # i.e (for all unique comibnation of t, topk, topk what are the associated requests indices (based on the active slices) + # i.e (for all unique comibnation of t, topk, topk what are the associated + # requests indices (based on the active slices) # [ [req at index 0, req at index 2], t1, topk1, topp1 ]] # [ [req at index 1, req at index 3, req at index 4] , t2, topk2, topp2] for indices, temp, top_k, top_p in self._torch_sampling_buckets: @@ -1373,7 +1380,7 @@ async def async_generate_output_tokens_dynamic_batch( mtp_logits = None if logits_and_mtp_logits.shape[0] > 1: logits = logits_and_mtp_logits[:1] # [1, seq_len, vocab_size] - mtp_logits = logits_and_mtp_logits[1:] # [num_speculative_tokens, seq_len, vocab_size]\ + mtp_logits = logits_and_mtp_logits[1:] # [num_speculative_tokens, seq_len, vocab_size] else: logits = logits_and_mtp_logits From 9917e353bd07d9ac0cbbbf8929b549733f8144b3 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Sun, 1 Mar 2026 22:52:50 -0800 Subject: [PATCH 21/76] Linting Signed-off-by: Keshav Santhanam --- megatron/core/transformer/attention.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 0ac7e78fae3..310a59bde35 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -60,9 +60,7 @@ rearrange = None try: - from flash_attn_3.flash_attn_interface import ( - _flash_attn_forward, - ) + from flash_attn_3.flash_attn_interface import _flash_attn_forward from flash_attn_3.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) @@ -73,9 +71,7 @@ if not HAVE_FA3: try: - from flashattn_hopper.flash_attn_interface import ( - _flash_attn_forward, - ) + from flashattn_hopper.flash_attn_interface import _flash_attn_forward from flashattn_hopper.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) From eff0fa1b9434516d945f70cd509e8f98a1d4fa85 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Sun, 1 Mar 2026 23:07:40 -0800 Subject: [PATCH 22/76] Linting / copyright Signed-off-by: Keshav Santhanam --- megatron/core/inference/engines/dynamic_engine.py | 4 ---- megatron/core/ssm/ops/__init__.py | 1 + megatron/core/ssm/ops/causal_conv1d_triton.py | 9 +++++++++ megatron/core/ssm/ops/mamba_ssm.py | 5 +++++ 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 865ef8bc686..efd2a5411a0 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -788,14 +788,10 @@ def _add_request( len(request.prompt_tokens) + request.sampling_params.num_tokens_to_generate > self.context.max_sequence_length ) or (request.sampling_params.num_tokens_to_generate < 0): - if torch.distributed.get_rank() == 0: - print(f"REQUEST {request_id} FAILED! MaxSequenceLengthOverflowError") request.status = Status.FAILED request.add_event_error_nontransient(MaxSequenceLengthOverflowError(request_id)) if len(request.prompt_tokens) > self.context.max_tokens and not self.enable_chunked_prefill: - if torch.distributed.get_rank() == 0: - print(f"REQUEST {request_id} FAILED! TokenOverflowError") request.status = Status.FAILED request.add_event_error_nontransient(TokenOverflowError(request_id)) diff --git a/megatron/core/ssm/ops/__init__.py b/megatron/core/ssm/ops/__init__.py index e69de29bb2d..3e4afde2e29 100644 --- a/megatron/core/ssm/ops/__init__.py +++ b/megatron/core/ssm/ops/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. diff --git a/megatron/core/ssm/ops/causal_conv1d_triton.py b/megatron/core/ssm/ops/causal_conv1d_triton.py index 3a412be3da8..b7ce54684da 100644 --- a/megatron/core/ssm/ops/causal_conv1d_triton.py +++ b/megatron/core/ssm/ops/causal_conv1d_triton.py @@ -1,3 +1,9 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +# Some of this code was adopted from https://github.com/Dao-AILab/causal-conv1d/ +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + import torch import triton import triton.language as tl @@ -102,6 +108,8 @@ def causal_conv1d_update_kernel( # If circular, we only need to read the base cache sequence length once if IS_CIRCULAR: base_cache_seqlen = tl.load(cache_seqlens_ptr + batch_id) + else: + base_cache_seqlen = None # Loop over the sequence dimension (e.g., speculative tokens) for s in range(seq_len): @@ -214,6 +222,7 @@ def causal_conv1d_update( conv_state_indices: torch.Tensor | None, intermediate_conv_states: torch.Tensor | None = None, ) -> torch.Tensor: + """Triton implementation of causal_conv1d_update.""" # Check if input is 2D, temporarily treat as 3D for uniform processing is_2d = x.dim() == 2 diff --git a/megatron/core/ssm/ops/mamba_ssm.py b/megatron/core/ssm/ops/mamba_ssm.py index 949141871b1..7427440915b 100644 --- a/megatron/core/ssm/ops/mamba_ssm.py +++ b/megatron/core/ssm/ops/mamba_ssm.py @@ -1,5 +1,10 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2024, Tri Dao, Albert Gu. +# Some of this code was adopted from https://github.com/state-spaces/mamba/ +# This source code is licensed under the Apache license found in the +# LICENSE file in the root directory of this source tree. + import torch import triton import triton.language as tl From 0d05f8bfbd655e807653dcee62439b794b06a2bd Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Sun, 1 Mar 2026 23:12:02 -0800 Subject: [PATCH 23/76] Linting Signed-off-by: Keshav Santhanam --- megatron/core/ssm/ops/causal_conv1d_triton.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/megatron/core/ssm/ops/causal_conv1d_triton.py b/megatron/core/ssm/ops/causal_conv1d_triton.py index b7ce54684da..dbe585a107e 100644 --- a/megatron/core/ssm/ops/causal_conv1d_triton.py +++ b/megatron/core/ssm/ops/causal_conv1d_triton.py @@ -47,6 +47,7 @@ def causal_conv1d_update_kernel( HAS_INT_STATE: tl.constexpr, SILU_ACTIVATION: tl.constexpr, ): + """Triton implementation of causal_conv1d_update (kernel).""" batch_id = tl.program_id(0) channel_block_id = tl.program_id(1) @@ -222,7 +223,7 @@ def causal_conv1d_update( conv_state_indices: torch.Tensor | None, intermediate_conv_states: torch.Tensor | None = None, ) -> torch.Tensor: - """Triton implementation of causal_conv1d_update.""" + """Triton implementation of causal_conv1d_update (entrypoint).""" # Check if input is 2D, temporarily treat as 3D for uniform processing is_2d = x.dim() == 2 From 5947e3a3cdf824f649b59689e2d184378051e05a Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Thu, 5 Mar 2026 11:24:08 -0800 Subject: [PATCH 24/76] Bug fixes Signed-off-by: Keshav Santhanam --- megatron/core/inference/batch_dimensions_utils.py | 15 ++++++++------- .../core/inference/contexts/dynamic_context.py | 6 ++++-- .../data_parallel_inference_coordinator.py | 6 ++---- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py index 37ef76fbbee..e7969bf5e88 100644 --- a/megatron/core/inference/batch_dimensions_utils.py +++ b/megatron/core/inference/batch_dimensions_utils.py @@ -382,7 +382,7 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int ): cuda_graph_max_tokens = max_tokens - assert cuda_graph_max_tokens == max_requests, ( + assert cuda_graph_max_tokens == max_requests * (num_speculative_tokens + 1), ( f"cuda_graph_max_tokens ({cuda_graph_max_tokens}) must equal max_requests " f"({max_requests}). This is required for correctly syncing EP ranks: " f"prefill and decode graph pools must have the same token count granularity." @@ -427,15 +427,16 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int # Use decode-specific token counts for decode-only graphs for size in cuda_graph_decode_token_counts: decode_req_count = min(size // (num_speculative_tokens + 1), max_requests) + token_count = decode_req_count * (num_speculative_tokens + 1) + token_count = token_count // tp_size * tp_size add_if_valid( - token_count=decode_req_count * (num_speculative_tokens + 1), - prefill_req_count=0, - decode_req_count=decode_req_count, + token_count=token_count, prefill_req_count=0, decode_req_count=decode_req_count ) else: # Mixed prefill and decode mode # Create prefill and mixed dimensions with full token counts for size in cuda_graph_prefill_token_counts: + assert size % tp_size == 0 prefill_req_count = min(cuda_graph_mixed_prefill_request_count, max_requests) decode_req_count = max( 0, @@ -465,10 +466,10 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int # Create decode-only dimensions with optimized token counts for size in cuda_graph_decode_token_counts: decode_req_count = min(size // (num_speculative_tokens + 1), max_requests) + token_count = decode_req_count * (num_speculative_tokens + 1) + token_count = token_count // tp_size * tp_size add_if_valid( - token_count=decode_req_count * (num_speculative_tokens + 1), - prefill_req_count=0, - decode_req_count=decode_req_count, + token_count=token_count, prefill_req_count=0, decode_req_count=decode_req_count ) # Remove duplicates and sort by prefill token count diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 1100f125014..f3dcd4d326f 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1332,7 +1332,9 @@ def add_dummy_requests_for_expert_parallel_step(self) -> None: pass can run without error. """ - smallest_cuda_graph_dimensions = min(self.cuda_graph_batch_dimensions_list) + smallest_cuda_graph_dimensions = min( + [x for x in self.cuda_graph_batch_dimensions_list if x.prefill_req_count == 0] + ) # the smallest cuda graph is decode only. assert smallest_cuda_graph_dimensions.prefill_req_count == 0 @@ -1608,7 +1610,7 @@ def current_input_and_position_ids( (Tuple[Tensor, Tensor]) Flattened active input and position IDs. """ num_tokens = num_warmup_tokens or self.padded_active_token_count - assert num_tokens >= self.batch_dimensions.decode_req_count * ( + assert num_tokens >= self.padded_batch_dimensions.decode_req_count * ( self.num_speculative_tokens + 1 ) return ( diff --git a/megatron/core/inference/data_parallel_inference_coordinator.py b/megatron/core/inference/data_parallel_inference_coordinator.py index 60ca06819e7..a9b2445b5e5 100644 --- a/megatron/core/inference/data_parallel_inference_coordinator.py +++ b/megatron/core/inference/data_parallel_inference_coordinator.py @@ -389,14 +389,12 @@ def start(self): return if request_hashes: - self._update_rank_hashes(next_data_parallel_rank_identity, request_hashes) + self._update_rank_hashes(next_identity, request_hashes) if self.schedule_records is not None: self.schedule_records.append( { "request_id": request_id, - "rank_index": self.identity_to_rank_index[ - next_data_parallel_rank_identity - ], + "rank_index": self.identity_to_rank_index[next_identity], "num_hashes": len(request_hashes), } ) From 789f6e8887678f6e5cdb0c1ba4c26139b7dc37ea Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Thu, 5 Mar 2026 11:45:56 -0800 Subject: [PATCH 25/76] More fixes Signed-off-by: Keshav Santhanam --- .../inference/contexts/attention_context/mamba_metadata.py | 7 ------- tests/unit_tests/inference/test_batch_dimension_utils.py | 2 +- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/megatron/core/inference/contexts/attention_context/mamba_metadata.py b/megatron/core/inference/contexts/attention_context/mamba_metadata.py index caa8bc6a2cd..34a19cf0394 100644 --- a/megatron/core/inference/contexts/attention_context/mamba_metadata.py +++ b/megatron/core/inference/contexts/attention_context/mamba_metadata.py @@ -64,11 +64,6 @@ def __init__(self, max_requests: int, max_tokens: int): (2,), dtype=torch.int32, device=self.device ) - # Map from requests to accepted tokens in speculative decoding - self._num_accepted_tokens_buffer = torch.zeros( - (self.max_requests,), dtype=torch.int32, device=self.device - ) - # Allocator for Mamba state slots self.mamba_state_free_slots = torch.arange( self.max_requests, dtype=torch.int32, device=torch.cuda.current_device() @@ -100,7 +95,6 @@ def reset_varlen_metadata(self) -> None: self.seq_idx = None self.device_decode_prefill = None self.device_chunked_prefill = None - self.num_accepted_tokens = None def update( self, @@ -181,7 +175,6 @@ def update( if padded_decode_count > real_decode_count: self._batch_indices_decode_buffer[real_decode_count:padded_decode_count] = -1 self.batch_indices_decode = self._batch_indices_decode_buffer[:padded_decode_count] - self.num_accepted_tokens = self._num_accepted_tokens_buffer[:padded_decode_count] # Determine if we have a chunked prefill request and adjust counts for regular prefill regular_prefill_count = real_prefill_count diff --git a/tests/unit_tests/inference/test_batch_dimension_utils.py b/tests/unit_tests/inference/test_batch_dimension_utils.py index 42613810724..d92d4eb4c81 100644 --- a/tests/unit_tests/inference/test_batch_dimension_utils.py +++ b/tests/unit_tests/inference/test_batch_dimension_utils.py @@ -395,7 +395,7 @@ def test_generate_graphs_with_speculative_tokens(self, num_speculative_tokens): graph_list, _ = CUDAGraphBatchDimensionBuilder.generate_cuda_graph_batch_dimensions_list( tp_size=TP_SIZE, num_cuda_graphs=4, - cuda_graph_max_tokens=MAX_REQUESTS, + cuda_graph_max_tokens=MAX_REQUESTS * (num_speculative_tokens + 1), cuda_graph_mixed_prefill_request_count=MIXED_PREFILL_COUNT, max_requests=MAX_REQUESTS, max_tokens=MAX_TOKENS, From 56b84f532dee6c0a4ec2b39dff4bad80e6d81e0b Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Thu, 5 Mar 2026 12:59:12 -0800 Subject: [PATCH 26/76] Minor fixes Signed-off-by: Keshav Santhanam --- megatron/core/inference/engines/dynamic_engine.py | 5 ----- tests/unit_tests/inference/test_batch_dimension_utils.py | 2 +- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index e9ccb2ed260..06e3f1e09aa 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -215,11 +215,6 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen self.num_speculative_tokens <= self.controller.num_mtp_heads ), f"Number of speculative tokens {self.num_speculative_tokens} must be less than or equal to number of MTP heads {self.controller.num_mtp_heads}" - # assert ( - # not self.enable_chunked_prefill - # ), "Chunked prefill is not supported with speculative tokens" - - # Initialize MTP sampling tensor now that num_speculative_tokens is set self.controller._init_mtp_sampling_tensor() self.track_paused_request_events = inference_config.track_paused_request_events diff --git a/tests/unit_tests/inference/test_batch_dimension_utils.py b/tests/unit_tests/inference/test_batch_dimension_utils.py index d92d4eb4c81..f899f8b1c97 100644 --- a/tests/unit_tests/inference/test_batch_dimension_utils.py +++ b/tests/unit_tests/inference/test_batch_dimension_utils.py @@ -446,7 +446,7 @@ def test_ep_sync_with_speculative_tokens(self, num_cuda_graphs): graph_list, _ = CUDAGraphBatchDimensionBuilder.generate_cuda_graph_batch_dimensions_list( tp_size=TP_SIZE, num_cuda_graphs=num_cuda_graphs, - cuda_graph_max_tokens=MAX_REQUESTS, + cuda_graph_max_tokens=MAX_REQUESTS * (num_speculative_tokens + 1), cuda_graph_mixed_prefill_request_count=MIXED_PREFILL_COUNT, max_requests=MAX_REQUESTS, max_tokens=MAX_TOKENS, From 1a852c673cdeb6a1ea110502174500911381da0f Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Thu, 5 Mar 2026 13:59:05 -0800 Subject: [PATCH 27/76] Add softplus Signed-off-by: Keshav Santhanam --- megatron/core/ssm/ops/mamba_ssm.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/megatron/core/ssm/ops/mamba_ssm.py b/megatron/core/ssm/ops/mamba_ssm.py index 7427440915b..cd2041eb084 100644 --- a/megatron/core/ssm/ops/mamba_ssm.py +++ b/megatron/core/ssm/ops/mamba_ssm.py @@ -8,7 +8,24 @@ import torch import triton import triton.language as tl -from mamba_ssm.ops.triton.softplus import softplus +from packaging import version + +TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") + + +if TRITON3: + + @triton.jit + def softplus(dt): + """Optimized softplus.""" + return tl.math.log(tl.math.exp(dt) + 1) + +else: + + @triton.jit + def softplus(dt): + """Optimized softplus.""" + return tl.math.log1p(tl.exp(dt)) @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) From 3549711a641b84bd2aac8ca416c9540fac0b275c Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Thu, 5 Mar 2026 14:10:29 -0800 Subject: [PATCH 28/76] Remove dead code Signed-off-by: Keshav Santhanam --- megatron/core/ssm/mamba_mixer.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 161fd801cb9..7eb75d43b74 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -25,6 +25,8 @@ ) from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.ssm.ops.causal_conv1d_triton import causal_conv1d_update +from megatron.core.ssm.ops.mamba_ssm import selective_state_update from megatron.core.tensor_parallel import get_cuda_rng_tracker from megatron.core.transformer import TransformerConfig from megatron.core.transformer.module import MegatronModule @@ -45,20 +47,11 @@ from .mamba_context_parallel import MambaContextParallel try: - # from mamba_ssm.ops.triton.selective_state_update import selective_state_update - from megatron.core.ssm.ops.mamba_ssm import selective_state_update -except ImportError: - selective_state_update = None - -try: - # from causal_conv1d import causal_conv1d_fn, causal_conv1d_update from causal_conv1d import causal_conv1d_fn from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states - from megatron.core.ssm.ops.causal_conv1d_triton import causal_conv1d_update except ImportError: causal_conv1d_fn = None - causal_conv1d_update = None try: from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated @@ -463,8 +456,6 @@ def _dynamic_inference(self, hidden_states: torch.Tensor, context: DynamicInfere zxBCdt_decode = zxBCdt[:decode_token_count] if prefill_req_count > 0 else zxBCdt # Reshape from [N*S, 1, d] to [N, S, d] for the 3D Triton kernels - # if self.layer_number == 1: - # torch.distributed.breakpoint(0) zxBCdt_decode = zxBCdt_decode.squeeze(1).view(decode_req_count, seq_len, -1) # Get sequence lengths for the circular buffer calculation @@ -968,6 +959,7 @@ def _ssm_decode( # Conv step if causal_conv1d_update is None: + # TODO(ksanthanam): Consider deprecating this path assert seq_len == 1, "Native PyTorch fallback only supports 1 token at a time" xBC_squeeze = xBC.squeeze(1) conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) @@ -1007,6 +999,7 @@ def _ssm_decode( # SSM step if selective_state_update is None: + # TODO(ksanthanam): Consider deprecating this path assert seq_len == 1, "Native PyTorch fallback only supports 1 token at a time" x = x.squeeze(1) From 7faee8349dfd9ff20647f39018b22b43cb824884 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Thu, 5 Mar 2026 14:36:33 -0800 Subject: [PATCH 29/76] Fix flaky test Signed-off-by: Keshav Santhanam --- tests/unit_tests/inference/contexts/test_dynamic_context.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unit_tests/inference/contexts/test_dynamic_context.py b/tests/unit_tests/inference/contexts/test_dynamic_context.py index 98be2452cb7..4c4997929dd 100644 --- a/tests/unit_tests/inference/contexts/test_dynamic_context.py +++ b/tests/unit_tests/inference/contexts/test_dynamic_context.py @@ -827,6 +827,10 @@ def test_release_memory_blocks_for_finished_requests(self, is_hybrid_model): dynamic_context.request_to_kv_block_ids[i, 0] = initial_blocks[i] dynamic_context.request_query_lengths[i] = 1 dynamic_context.request_ids[i] = i + dynamic_context.request_last_kv_block_id[i] = initial_blocks[i] + dynamic_context.request_last_kv_block_offset[i] = 0 + dynamic_context.request_kv_block_counts[i] = 1 + dynamic_context.request_in_prefill_status_tensor[i] = 0 if is_hybrid_model: dynamic_context.mamba_conv_states[:, i, :, :].fill_( float(i + 1) From fc806ef188255e28a106c8802c3b132380b25faa Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Thu, 5 Mar 2026 14:41:39 -0800 Subject: [PATCH 30/76] More flaky test fixes Signed-off-by: Keshav Santhanam --- tests/unit_tests/inference/contexts/test_dynamic_context.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/unit_tests/inference/contexts/test_dynamic_context.py b/tests/unit_tests/inference/contexts/test_dynamic_context.py index 4c4997929dd..5983e3b464b 100644 --- a/tests/unit_tests/inference/contexts/test_dynamic_context.py +++ b/tests/unit_tests/inference/contexts/test_dynamic_context.py @@ -919,6 +919,11 @@ def test_finished_requests_with_multiple_blocks(self, is_hybrid_model): for i in range(3): dynamic_context.request_query_lengths[i] = 1 dynamic_context.request_ids[i] = i + dynamic_context.request_last_kv_block_id[i] = dynamic_context.request_to_kv_block_ids[ + i, dynamic_context.request_kv_block_counts[i] - 1 + ] + dynamic_context.request_last_kv_block_offset[i] = 0 + dynamic_context.request_in_prefill_status_tensor[i] = 0 if is_hybrid_model: dynamic_context.mamba_conv_states[:, i, :, :].fill_(float(i + 1)) dynamic_context.mamba_ssm_states[:, i, :, :, :].fill_(float(i + 1)) @@ -1733,6 +1738,7 @@ def test_chunked_prefill_speculative_offset_math(self): enable_chunked_prefill=True, ) ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + ctx.reset_tensors() # Setup a request that is already mid-chunked-prefill ctx.total_request_count = 1 From fadbc0cd7d362c3d65c81315fd321c00706f0c3f Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Thu, 5 Mar 2026 15:49:16 -0800 Subject: [PATCH 31/76] Address claude's comments Signed-off-by: Keshav Santhanam --- .../inference/contexts/dynamic_context.py | 1 - .../core/inference/engines/dynamic_engine.py | 16 ++++------ .../text_generation_controller.py | 3 +- .../inference/engines/test_dynamic_engine.py | 31 +++++++++++++++++++ .../test_text_generation_controller.py | 8 ++--- 5 files changed, 43 insertions(+), 16 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index f3dcd4d326f..e9f1fb8dd12 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -2564,7 +2564,6 @@ def update_requests( self.paused_request_count, self.total_request_count, device=torch.cuda.current_device() ).repeat_interleave(num_generated_tokens) - # shan : Same as token_to_pos_ids ? self.token_to_position_in_request[: self.active_token_count] = self.token_to_pos_ids[ : self.active_token_count ] diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 06e3f1e09aa..746adbc6590 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -1242,17 +1242,13 @@ def _check_stop_words_for_request_post_append(self, request: DynamicInferenceReq for stop_word_ids in request.stop_word_ids: stop_len = len(stop_word_ids) if len(generated_tokens) >= stop_len: - # Check if the last stop_len tokens match the stop word - if stop_len > self.num_speculative_tokens: - if list(generated_tokens[-stop_len:]) == stop_word_ids: + # Check the last stop_len tokens shifting by 1 up to num_speculative_tokens. + # We do this regardless of stop_len because speculative decoding can append + # multiple tokens at once, meaning the stop word might end at any of those positions. + for i in range(self.num_speculative_tokens + 1): + end_idx = -i if i > 0 else None + if list(generated_tokens[-stop_len - i : end_idx]) == stop_word_ids: return True - else: - # Check the last stop len tokens shifting by 1 up to num_speculative_tokens - for i in range(self.num_speculative_tokens + 1): - end_idx = -i if i > 0 else None - if list(generated_tokens[-stop_len - i : end_idx]) == stop_word_ids: - return True - return False def get_prefix_coordination_metrics(self) -> dict: diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 920d13dc48d..55df5a95334 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -784,7 +784,8 @@ def _rewind_kv_cache(self): def _dynamic_step_sample_logits_and_verify_tokens( self, logits: Tensor, mtp_logits: Tensor, input_ids: Tensor ): - f"""Sample tokens from logits for dynamic batching with speculative tokens and verify the tokens. + """ + Sample tokens from logits for dynamic batching with speculative tokens and verify the tokens. """ context = self.inference_wrapped_model.inference_context active_request_count = context.total_request_count - context.paused_request_count diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index 94e21d4c6e4..a2cbed87d5e 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -2110,3 +2110,34 @@ def test_speculative_stop_word_hit(self): stop_hit = env.engine._check_stop_words_for_request_post_append(req) assert stop_hit is True + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.inference_mode() + def test_speculative_long_stop_word_hit(self): + """Test that if an accepted speculative token completes a long stop word + (length > num_speculative_tokens), it is correctly detected.""" + + test_config = DynamicEngineTestConfig( + num_requests=0, num_speculative_tokens=2, materialize_only_last_token_logits=False + ) + env = self._build_test_env(test_config) + + # Mock request with a stop word + req = DynamicInferenceRequest( + request_id=0, + prompt_tokens=torch.tensor([1, 2, 3], device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=10), + ) + # Stop word length 3 > num_speculative_tokens (2) + req.stop_word_ids = [[98, 99, 100]] + + # Fast-forward state: base tokens were generated up to 99 + req.generated_tokens = [98, 99] + tokens_to_append = [100, 101] # Completes stop word at index -2 + req.generated_tokens += tokens_to_append + + stop_hit = env.engine._check_stop_words_for_request_post_append(req) + assert stop_hit is True diff --git a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py index eb0a6dfffdd..833416a54f9 100644 --- a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py +++ b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py @@ -1035,17 +1035,17 @@ def test_speculative_verify_tokens(self): input_ids = torch.tensor([[10, 11, 12, 20, 21, 22]], device='cuda') # We need the sampling function to return a 1D tensor for base logits, - # and a 2D tensor for the MTP logits to satisfy torch.cat(dim=1). + # and a 1D tensor for the flattened MTP logits. def mock_sampling_func(logits, *args, **kwargs): - if logits.dim() == 2: + if logits.shape[0] == 6: # Base logits -> return 1D tensor of shape [6] # Req 1: Predicts [11, 12, 99]. Matches T1, T2. Rejects T3. -> Accepts 2 spec tokens. # Req 2: Predicts [99, 22, 23]. Fails at first spec token (99 != 21). -> Accepts 0 spec tokens. return torch.tensor([11, 12, 99, 99, 22, 23], dtype=torch.long, device='cuda') else: - # MTP logits -> return 2D tensor of shape [num_speculative_tokens, 6] + # MTP logits -> return 1D tensor of shape [12] # The verification logic only uses base tokens, so we can return zeros here. - return torch.zeros((2, 6), dtype=torch.long, device='cuda') + return torch.zeros((12,), dtype=torch.long, device='cuda') # Override sampling to return our predictable mock outputs self.text_generation_controller._torch_sampling_buckets = [([0, 1], 1.0, 1, 0.0)] From 5f3c141bf02a963741b158d7987d06518f913860 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Thu, 5 Mar 2026 21:15:16 -0800 Subject: [PATCH 32/76] Chunked prefill fix Signed-off-by: Keshav Santhanam --- .../attention_context/mamba_metadata.py | 11 ++++++++ .../inference/contexts/dynamic_context.py | 26 ++++++++++++----- megatron/core/ssm/mamba_mixer.py | 28 +++++++++++++++---- 3 files changed, 53 insertions(+), 12 deletions(-) diff --git a/megatron/core/inference/contexts/attention_context/mamba_metadata.py b/megatron/core/inference/contexts/attention_context/mamba_metadata.py index 34a19cf0394..baeb5782b30 100644 --- a/megatron/core/inference/contexts/attention_context/mamba_metadata.py +++ b/megatron/core/inference/contexts/attention_context/mamba_metadata.py @@ -41,6 +41,11 @@ def __init__(self, max_requests: int, max_tokens: int): (1,), -1, dtype=torch.int32, device=self.device ) + # Cache sequence lengths for the chunked prefill request + self._chunked_prefill_cache_seqlens_buffer = torch.zeros( + (1,), dtype=torch.int32, device=self.device + ) + # Map from token id to request id for active prefill requests self._seq_idx_buffer = torch.full( (1, self.max_tokens), -1, dtype=torch.int32, device=self.device @@ -95,12 +100,14 @@ def reset_varlen_metadata(self) -> None: self.seq_idx = None self.device_decode_prefill = None self.device_chunked_prefill = None + self.chunked_prefill_cache_seqlens = None def update( self, active_mamba_indices: torch.Tensor, token_to_request_idx: torch.Tensor, cu_seqlens: torch.Tensor, + request_kv_length_offsets: torch.Tensor, batch_dimensions: InferenceBatchDimensions, padded_batch_dimensions: InferenceBatchDimensions, enable_chunked_prefill: bool, @@ -188,6 +195,10 @@ def update( # Update chunked prefill indices self._batch_indices_chunked_prefill_buffer[0] = active_mamba_indices[chunked_req_idx] self.batch_indices_chunked_prefill = self._batch_indices_chunked_prefill_buffer + + # Update chunked prefill cache seqlen + self._chunked_prefill_cache_seqlens_buffer[0] = request_kv_length_offsets[chunked_req_idx] + self.chunked_prefill_cache_seqlens = self._chunked_prefill_cache_seqlens_buffer else: self.batch_indices_chunked_prefill = None diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index e9f1fb8dd12..97341a93cca 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1426,10 +1426,20 @@ def initialize_attention_state( self.padded_batch_dimensions = best_graph else: if self.is_decode_only(): - padded_decode_req_count = min( - self.max_requests, self.round_up_requests(self.num_decode_requests) - ) - padded_token_count = padded_decode_req_count * (self.num_speculative_tokens + 1) + if self.num_speculative_tokens > 0: + padded_decode_req_count = min( + self.max_requests, self.round_up_requests(self.num_decode_requests) + ) + padded_token_count = padded_decode_req_count * ( + self.num_speculative_tokens + 1 + ) + else: + padded_token_count = min( + self.max_tokens, + self.max_requests, + self.round_up_tokens(self.active_token_count), + ) + padded_decode_req_count = padded_token_count padded_prefill_req_count = 0 else: padded_token_count = self.round_up_tokens(self.active_token_count) @@ -1504,6 +1514,7 @@ def initialize_attention_state( active_mamba_indices_view, token_to_request_idx_view, cu_seqlens, + request_kv_length_offsets_view, batch_dimensions=attn_dimensions, padded_batch_dimensions=self.padded_batch_dimensions, enable_chunked_prefill=self.is_chunked_prefill_enabled(), @@ -2504,9 +2515,10 @@ def update_requests( if self.paused_request_count > 0: self.paused_tokens = next_tokens[: self.paused_request_count].clone() - self.paused_speculative_tokens = new_speculative_tokens[ - :, : self.paused_request_count - ].clone() + if new_speculative_tokens is not None: + self.paused_speculative_tokens = new_speculative_tokens[ + :, : self.paused_request_count + ].clone() # add_ and fill_ calls seems to work as intended with sliced indexing # (i.e. x[3:5].add(...) or x[3:5].fill_) but when another tensor is used diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 7eb75d43b74..3a3fb83e22f 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -551,6 +551,7 @@ def _dynamic_inference_prefill( ssm_state=ssm_state, batch_indices=metadata.batch_indices_chunked_prefill, is_chunked_prefill=True, + cache_seqlens=metadata.chunked_prefill_cache_seqlens, ) # Update zxBCdt to contain the remaining slice for regular prefill @@ -720,6 +721,7 @@ def _ssm_prefill( return_varlen_states: bool = False, batch_indices: Optional[torch.Tensor] = None, is_chunked_prefill: bool = False, + cache_seqlens: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Performs SSM computation for inference prefill step. @@ -791,17 +793,33 @@ def _ssm_prefill( # Maintain channels-last memory layout to use initial_states for causal_conv1d_fn # See https://github.com/Dao-AILab/causal-conv1d/blob/69e6dadc28b169a4c49cb86b586f64ee90242c70/csrc/causal_conv1d.cpp#L200 # pylint: disable=line-too-long assert batch_indices is not None + assert cache_seqlens is not None state_len = conv_state.shape[-1] + + # Read last (d_conv - 1) tokens from the circular buffer + seq_len = cache_seqlens.view(-1, 1, 1) + gather_indices = (seq_len - self.d_conv + 1 + torch.arange(self.d_conv - 1, device=conv_state.device).view(1, 1, -1)) % state_len + gather_indices = gather_indices.expand(len(batch_indices), conv_state.shape[1], self.d_conv - 1) + initial_conv_state = torch.gather(conv_state[batch_indices], 2, gather_indices) + initial_conv_state = ( - conv_state[batch_indices, :, -self.d_conv + 1 :] - .permute(0, 2, 1) + initial_conv_state.permute(0, 2, 1) .contiguous() .transpose(1, 2) ) xBC = xBC.transpose(1, 2) - tensor_masked_update( - conv_state, batch_indices, F.pad(xBC, (state_len - xBC.shape[-1], 0)) - ) + + # We only need to retain at most the last `state_len` tokens of the chunk + chunk_len = xBC.shape[-1] + copy_len = min(chunk_len, state_len) + xBC_tail = xBC[..., -copy_len:] + + update_indices = (seq_len + chunk_len - copy_len + torch.arange(copy_len, device=conv_state.device).view(1, 1, -1)) % state_len + update_indices = update_indices.expand(len(batch_indices), conv_state.shape[1], copy_len) + + conv_state_slice = conv_state[batch_indices] + conv_state_slice.scatter_(2, update_indices, xBC_tail) + tensor_masked_update(conv_state, batch_indices, conv_state_slice) else: # transpose: b l pd --> b pd l xBC = rearrange(xBC, "b l d -> b d l").contiguous() From a51d979f14dbcde5a9805a0b3cb4841f8338c325 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Thu, 5 Mar 2026 22:32:56 -0800 Subject: [PATCH 33/76] Move cache_seqlens_decode into mamba_metadata.py Signed-off-by: Keshav Santhanam --- .../attention_context/mamba_metadata.py | 19 +++++++++++++++---- megatron/core/ssm/mamba_mixer.py | 10 ++-------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/megatron/core/inference/contexts/attention_context/mamba_metadata.py b/megatron/core/inference/contexts/attention_context/mamba_metadata.py index baeb5782b30..42dcb00ebd4 100644 --- a/megatron/core/inference/contexts/attention_context/mamba_metadata.py +++ b/megatron/core/inference/contexts/attention_context/mamba_metadata.py @@ -41,8 +41,13 @@ def __init__(self, max_requests: int, max_tokens: int): (1,), -1, dtype=torch.int32, device=self.device ) + # Cache sequence lengths for decode requests + self._cache_seqlens_decode_buffer = torch.zeros( + (self.max_requests,), dtype=torch.int32, device=self.device + ) + # Cache sequence lengths for the chunked prefill request - self._chunked_prefill_cache_seqlens_buffer = torch.zeros( + self._cache_seqlens_chunked_prefill_buffer = torch.zeros( (1,), dtype=torch.int32, device=self.device ) @@ -100,7 +105,8 @@ def reset_varlen_metadata(self) -> None: self.seq_idx = None self.device_decode_prefill = None self.device_chunked_prefill = None - self.chunked_prefill_cache_seqlens = None + self.cache_seqlens_decode = None + self.cache_seqlens_chunked_prefill = None def update( self, @@ -179,9 +185,14 @@ def update( self._batch_indices_decode_buffer[:real_decode_count].copy_( active_mamba_indices[:real_decode_count] ) + self._cache_seqlens_decode_buffer[:real_decode_count].copy_( + request_kv_length_offsets[:real_decode_count] + ) if padded_decode_count > real_decode_count: self._batch_indices_decode_buffer[real_decode_count:padded_decode_count] = -1 + self._cache_seqlens_decode_buffer[real_decode_count:padded_decode_count] = 0 self.batch_indices_decode = self._batch_indices_decode_buffer[:padded_decode_count] + self.cache_seqlens_decode = self._cache_seqlens_decode_buffer[:padded_decode_count] # Determine if we have a chunked prefill request and adjust counts for regular prefill regular_prefill_count = real_prefill_count @@ -197,8 +208,8 @@ def update( self.batch_indices_chunked_prefill = self._batch_indices_chunked_prefill_buffer # Update chunked prefill cache seqlen - self._chunked_prefill_cache_seqlens_buffer[0] = request_kv_length_offsets[chunked_req_idx] - self.chunked_prefill_cache_seqlens = self._chunked_prefill_cache_seqlens_buffer + self._cache_seqlens_chunked_prefill_buffer[0] = request_kv_length_offsets[chunked_req_idx] + self.cache_seqlens_chunked_prefill = self._cache_seqlens_chunked_prefill_buffer else: self.batch_indices_chunked_prefill = None diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 3a3fb83e22f..2062d23f16d 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -458,19 +458,13 @@ def _dynamic_inference(self, hidden_states: torch.Tensor, context: DynamicInfere # Reshape from [N*S, 1, d] to [N, S, d] for the 3D Triton kernels zxBCdt_decode = zxBCdt_decode.squeeze(1).view(decode_req_count, seq_len, -1) - # Get sequence lengths for the circular buffer calculation - req_start = context.paused_request_count - cache_seqlens = context.request_kv_length_offsets[ - req_start : req_start + decode_req_count - ] - y_decode = self._ssm_decode( zxBCdt_decode, conv_state, ssm_state, batch_indices=context.mamba_metadata.batch_indices_decode, intermediate_ssm_state=int_ssm_state, - cache_seqlens=cache_seqlens, + cache_seqlens=context.mamba_metadata.cache_seqlens_decode, ) # Flatten back to [N*S, 1, d] to match merge logic @@ -551,7 +545,7 @@ def _dynamic_inference_prefill( ssm_state=ssm_state, batch_indices=metadata.batch_indices_chunked_prefill, is_chunked_prefill=True, - cache_seqlens=metadata.chunked_prefill_cache_seqlens, + cache_seqlens=metadata.cache_seqlens_chunked_prefill, ) # Update zxBCdt to contain the remaining slice for regular prefill From 307fad567d4d8c32201ec64ee0145c64026f115c Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Thu, 5 Mar 2026 23:29:40 -0800 Subject: [PATCH 34/76] AAdd triton kernels (possibly revert) Signed-off-by: Keshav Santhanam --- .../attention_context/mamba_metadata.py | 6 +- .../inference/contexts/dynamic_context.py | 4 +- megatron/core/ssm/mamba_mixer.py | 49 ++-- megatron/core/ssm/ops/causal_conv1d_triton.py | 256 ++++++++++++++++++ megatron/core/transformer/attention.py | 8 +- 5 files changed, 286 insertions(+), 37 deletions(-) diff --git a/megatron/core/inference/contexts/attention_context/mamba_metadata.py b/megatron/core/inference/contexts/attention_context/mamba_metadata.py index 42dcb00ebd4..64b9ef35f13 100644 --- a/megatron/core/inference/contexts/attention_context/mamba_metadata.py +++ b/megatron/core/inference/contexts/attention_context/mamba_metadata.py @@ -206,9 +206,11 @@ def update( # Update chunked prefill indices self._batch_indices_chunked_prefill_buffer[0] = active_mamba_indices[chunked_req_idx] self.batch_indices_chunked_prefill = self._batch_indices_chunked_prefill_buffer - + # Update chunked prefill cache seqlen - self._cache_seqlens_chunked_prefill_buffer[0] = request_kv_length_offsets[chunked_req_idx] + self._cache_seqlens_chunked_prefill_buffer[0] = request_kv_length_offsets[ + chunked_req_idx + ] self.cache_seqlens_chunked_prefill = self._cache_seqlens_chunked_prefill_buffer else: self.batch_indices_chunked_prefill = None diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 97341a93cca..d04359aacdf 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1430,9 +1430,7 @@ def initialize_attention_state( padded_decode_req_count = min( self.max_requests, self.round_up_requests(self.num_decode_requests) ) - padded_token_count = padded_decode_req_count * ( - self.num_speculative_tokens + 1 - ) + padded_token_count = padded_decode_req_count * (self.num_speculative_tokens + 1) else: padded_token_count = min( self.max_tokens, diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 2062d23f16d..0eba5d0f7eb 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -25,7 +25,12 @@ ) from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.ssm.ops.causal_conv1d_triton import causal_conv1d_update +from megatron.core.ssm.ops.causal_conv1d_triton import ( + causal_conv1d_update, + gather_conv_state, + roll_conv_varlen_states, + scatter_conv_state, +) from megatron.core.ssm.ops.mamba_ssm import selective_state_update from megatron.core.tensor_parallel import get_cuda_rng_tracker from megatron.core.transformer import TransformerConfig @@ -765,17 +770,8 @@ def _ssm_prefill( xBC.squeeze(0), cu_seqlens, state_len=state_len ) - # Roll into circular buffer layout expected by decode - seqlens = cu_seqlens[1:] - cu_seqlens[:-1] - # Shift required to move token from `state_len - 1` to `(seqlen - 1) % state_len` - shifts = (seqlens % state_len).view(-1, 1, 1) - - B, D, W = conv_varlen_states.shape - base_idx = torch.arange(W, device=conv_state.device).view(1, 1, W) - gather_idx = (base_idx - shifts) % W - gather_idx = gather_idx.expand(B, D, W) - - conv_varlen_states_circular = torch.gather(conv_varlen_states, dim=2, index=gather_idx) + # Roll into circular buffer layout expected by decode using fused Triton kernel + conv_varlen_states_circular = roll_conv_varlen_states(conv_varlen_states, cu_seqlens) # Update state tensor_masked_update(conv_state, batch_indices, conv_varlen_states_circular) @@ -789,31 +785,24 @@ def _ssm_prefill( assert batch_indices is not None assert cache_seqlens is not None state_len = conv_state.shape[-1] - + + xBC = xBC.transpose(1, 2) + # Read last (d_conv - 1) tokens from the circular buffer - seq_len = cache_seqlens.view(-1, 1, 1) - gather_indices = (seq_len - self.d_conv + 1 + torch.arange(self.d_conv - 1, device=conv_state.device).view(1, 1, -1)) % state_len - gather_indices = gather_indices.expand(len(batch_indices), conv_state.shape[1], self.d_conv - 1) - initial_conv_state = torch.gather(conv_state[batch_indices], 2, gather_indices) - - initial_conv_state = ( - initial_conv_state.permute(0, 2, 1) - .contiguous() - .transpose(1, 2) + initial_conv_state = gather_conv_state( + conv_state, batch_indices, cache_seqlens, self.d_conv ) - xBC = xBC.transpose(1, 2) - + initial_conv_state = initial_conv_state.permute(0, 2, 1).contiguous().transpose(1, 2) + # We only need to retain at most the last `state_len` tokens of the chunk chunk_len = xBC.shape[-1] copy_len = min(chunk_len, state_len) xBC_tail = xBC[..., -copy_len:] - - update_indices = (seq_len + chunk_len - copy_len + torch.arange(copy_len, device=conv_state.device).view(1, 1, -1)) % state_len - update_indices = update_indices.expand(len(batch_indices), conv_state.shape[1], copy_len) - conv_state_slice = conv_state[batch_indices] - conv_state_slice.scatter_(2, update_indices, xBC_tail) - tensor_masked_update(conv_state, batch_indices, conv_state_slice) + # Scatter tail back into the main buffer using fused Triton kernel + scatter_conv_state( + conv_state, xBC_tail, batch_indices, cache_seqlens, chunk_len, copy_len + ) else: # transpose: b l pd --> b pd l xBC = rearrange(xBC, "b l d -> b d l").contiguous() diff --git a/megatron/core/ssm/ops/causal_conv1d_triton.py b/megatron/core/ssm/ops/causal_conv1d_triton.py index dbe585a107e..4be7ce324f3 100644 --- a/megatron/core/ssm/ops/causal_conv1d_triton.py +++ b/megatron/core/ssm/ops/causal_conv1d_triton.py @@ -9,6 +9,262 @@ import triton.language as tl +@triton.jit +def _roll_circular_buffer_kernel( + in_ptr, + out_ptr, + cu_seqlens_ptr, + B: tl.constexpr, + D: tl.constexpr, + W: tl.constexpr, + stride_in_b, + stride_in_d, + stride_in_w, + stride_out_b, + stride_out_d, + stride_out_w, + BLOCK_W: tl.constexpr, +): + # We map a 1D grid over B * D + pid = tl.program_id(0) + b = pid // D + d = pid % D + + # 1. Load sequence lengths to calculate shift + seqlen_start = tl.load(cu_seqlens_ptr + b) + seqlen_end = tl.load(cu_seqlens_ptr + b + 1) + seqlen = seqlen_end - seqlen_start + + shift = seqlen % W + + # 2. Setup standard W offsets + w_offsets = tl.arange(0, BLOCK_W) + mask = w_offsets < W + + # 3. Calculate gathered indices + # NOTE: Triton/C++ modulo operator truncates towards zero for negative numbers. + # Because shift < W, (w_offsets - shift) is at least -W + 1. + # Adding W ensures the dividend is strictly positive, giving the correct wrapping behavior. + src_w_offsets = (w_offsets - shift + W) % W + + # 4. Compute memory pointers + in_offsets = in_ptr + (b * stride_in_b) + (d * stride_in_d) + (src_w_offsets * stride_in_w) + out_offsets = out_ptr + (b * stride_out_b) + (d * stride_out_d) + (w_offsets * stride_out_w) + + # 5. Load and Store + vals = tl.load(in_offsets, mask=mask) + tl.store(out_offsets, vals, mask=mask) + + +def roll_conv_varlen_states( + conv_varlen_states: torch.Tensor, cu_seqlens: torch.Tensor +) -> torch.Tensor: + """ + Rolls the convolution states into a circular buffer layout based on sequence lengths. + """ + B, D, W = conv_varlen_states.shape + out = torch.empty_like(conv_varlen_states) + + # Next power of 2 for block size (e.g. W=4 -> BLOCK_W=4) + BLOCK_W = triton.next_power_of_2(W) + + # Grid of size B * D + grid = lambda meta: (B * D,) + + _roll_circular_buffer_kernel[grid]( + conv_varlen_states, + out, + cu_seqlens, + B, + D, + W, + conv_varlen_states.stride(0), + conv_varlen_states.stride(1), + conv_varlen_states.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_W=BLOCK_W, + ) + + return out + + +@triton.jit +def _gather_conv_state_kernel( + conv_state_ptr, + batch_indices_ptr, + cache_seqlens_ptr, + out_ptr, + stride_cs_b, + stride_cs_d, + stride_cs_w, + stride_out_b, + stride_out_d, + stride_out_w, + D: tl.constexpr, + state_len: tl.constexpr, + d_conv: tl.constexpr, + BLOCK_W: tl.constexpr, +): + pid = tl.program_id(0) + b = pid // D + d = pid % D + + # Load batch map + req_idx = tl.load(batch_indices_ptr + b) + + # Check for padding/invalid batch index + if req_idx < 0: + w_offsets = tl.arange(0, BLOCK_W) + mask = w_offsets < (d_conv - 1) + out_offsets = out_ptr + (b * stride_out_b) + (d * stride_out_d) + (w_offsets * stride_out_w) + # Store 0.0 to prevent NaNs/garbage data from propagating + tl.store(out_offsets, 0.0, mask=mask) + return + + # Load sequence length + seq_len = tl.load(cache_seqlens_ptr + b) + + w_offsets = tl.arange(0, BLOCK_W) + mask = w_offsets < (d_conv - 1) + + # Calculate circular buffer index. + # We add state_len before modulo to prevent negative values in C++ modulo + # when seq_len < d_conv - 1. + val = seq_len - d_conv + 1 + w_offsets + gather_indices = (val + state_len) % state_len + + # Calculate memory offsets + cs_offsets = ( + conv_state_ptr + + (req_idx * stride_cs_b) + + (d * stride_cs_d) + + (gather_indices * stride_cs_w) + ) + out_offsets = out_ptr + (b * stride_out_b) + (d * stride_out_d) + (w_offsets * stride_out_w) + + data = tl.load(cs_offsets, mask=mask) + tl.store(out_offsets, data, mask=mask) + + +def gather_conv_state( + conv_state: torch.Tensor, batch_indices: torch.Tensor, cache_seqlens: torch.Tensor, d_conv: int +) -> torch.Tensor: + """Reads the last (d_conv - 1) tokens from the circular convolution state.""" + B = batch_indices.shape[0] + D = conv_state.shape[1] + state_len = conv_state.shape[2] + + out = torch.empty((B, D, d_conv - 1), device=conv_state.device, dtype=conv_state.dtype) + BLOCK_W = triton.next_power_of_2(d_conv - 1) + + grid = lambda meta: (B * D,) + _gather_conv_state_kernel[grid]( + conv_state, + batch_indices, + cache_seqlens, + out, + conv_state.stride(0), + conv_state.stride(1), + conv_state.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + D, + state_len, + d_conv, + BLOCK_W=BLOCK_W, + ) + return out + + +@triton.jit +def _scatter_conv_state_kernel( + conv_state_ptr, + batch_indices_ptr, + cache_seqlens_ptr, + xbc_tail_ptr, + stride_cs_b, + stride_cs_d, + stride_cs_w, + stride_xbc_b, + stride_xbc_d, + stride_xbc_w, + D: tl.constexpr, + state_len: tl.constexpr, + chunk_len: tl.constexpr, + copy_len: tl.constexpr, + BLOCK_W: tl.constexpr, +): + pid = tl.program_id(0) + b = pid // D + d = pid % D + + # Load batch map + req_idx = tl.load(batch_indices_ptr + b) + + # Check for padding/invalid batch index and safely return + if req_idx < 0: + return + + # Load sequence length + seq_len = tl.load(cache_seqlens_ptr + b) + + w_offsets = tl.arange(0, BLOCK_W) + mask = w_offsets < copy_len + + # seq_len >= 0 and chunk_len >= copy_len, so this is guaranteed to be >= 0. + update_indices = (seq_len + chunk_len - copy_len + w_offsets) % state_len + + # Calculate memory offsets + xbc_offsets = ( + xbc_tail_ptr + (b * stride_xbc_b) + (d * stride_xbc_d) + (w_offsets * stride_xbc_w) + ) + cs_offsets = ( + conv_state_ptr + + (req_idx * stride_cs_b) + + (d * stride_cs_d) + + (update_indices * stride_cs_w) + ) + + data = tl.load(xbc_offsets, mask=mask) + tl.store(cs_offsets, data, mask=mask) + + +def scatter_conv_state( + conv_state: torch.Tensor, + xbc_tail: torch.Tensor, + batch_indices: torch.Tensor, + cache_seqlens: torch.Tensor, + chunk_len: int, + copy_len: int, +): + """Writes the newest chunk of tokens into the circular convolution state.""" + B, D, _ = xbc_tail.shape + state_len = conv_state.shape[2] + BLOCK_W = triton.next_power_of_2(copy_len) + + grid = lambda meta: (B * D,) + _scatter_conv_state_kernel[grid]( + conv_state, + batch_indices, + cache_seqlens, + xbc_tail, + conv_state.stride(0), + conv_state.stride(1), + conv_state.stride(2), + xbc_tail.stride(0), + xbc_tail.stride(1), + xbc_tail.stride(2), + D, + state_len, + chunk_len, + copy_len, + BLOCK_W=BLOCK_W, + ) + + @triton.jit def causal_conv1d_update_kernel( x_ptr, diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 310a59bde35..0ac7e78fae3 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -60,7 +60,9 @@ rearrange = None try: - from flash_attn_3.flash_attn_interface import _flash_attn_forward + from flash_attn_3.flash_attn_interface import ( + _flash_attn_forward, + ) from flash_attn_3.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) @@ -71,7 +73,9 @@ if not HAVE_FA3: try: - from flashattn_hopper.flash_attn_interface import _flash_attn_forward + from flashattn_hopper.flash_attn_interface import ( + _flash_attn_forward, + ) from flashattn_hopper.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) From e0b0a8cfb199235d059f17545302d549cf401874 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Fri, 6 Mar 2026 01:00:08 -0800 Subject: [PATCH 35/76] Fix bug Signed-off-by: Keshav Santhanam --- megatron/core/ssm/ops/causal_conv1d_triton.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/megatron/core/ssm/ops/causal_conv1d_triton.py b/megatron/core/ssm/ops/causal_conv1d_triton.py index 4be7ce324f3..b0c630d7eb0 100644 --- a/megatron/core/ssm/ops/causal_conv1d_triton.py +++ b/megatron/core/ssm/ops/causal_conv1d_triton.py @@ -144,7 +144,8 @@ def _gather_conv_state_kernel( ) out_offsets = out_ptr + (b * stride_out_b) + (d * stride_out_d) + (w_offsets * stride_out_w) - data = tl.load(cs_offsets, mask=mask) + valid_mask = mask & (val >= 0) + data = tl.load(cs_offsets, mask=valid_mask, other=0.0) tl.store(out_offsets, data, mask=mask) From f4112ea207bf45eedfe0071731a5b63db70ae3d1 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Fri, 6 Mar 2026 01:01:11 -0800 Subject: [PATCH 36/76] Undo formatting changes Signed-off-by: Keshav Santhanam --- megatron/core/transformer/attention.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 0ac7e78fae3..310a59bde35 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -60,9 +60,7 @@ rearrange = None try: - from flash_attn_3.flash_attn_interface import ( - _flash_attn_forward, - ) + from flash_attn_3.flash_attn_interface import _flash_attn_forward from flash_attn_3.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) @@ -73,9 +71,7 @@ if not HAVE_FA3: try: - from flashattn_hopper.flash_attn_interface import ( - _flash_attn_forward, - ) + from flashattn_hopper.flash_attn_interface import _flash_attn_forward from flashattn_hopper.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) From 33e19107453d6adb9d6f0221fd5e4ec32c40f535 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Fri, 6 Mar 2026 01:28:23 -0800 Subject: [PATCH 37/76] Cleanup Signed-off-by: Keshav Santhanam --- megatron/core/ssm/mamba_mixer.py | 9 +---- megatron/core/ssm/ops/causal_conv1d_triton.py | 35 +++++++++++-------- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 0eba5d0f7eb..ae2e37d1b33 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -794,15 +794,8 @@ def _ssm_prefill( ) initial_conv_state = initial_conv_state.permute(0, 2, 1).contiguous().transpose(1, 2) - # We only need to retain at most the last `state_len` tokens of the chunk - chunk_len = xBC.shape[-1] - copy_len = min(chunk_len, state_len) - xBC_tail = xBC[..., -copy_len:] - # Scatter tail back into the main buffer using fused Triton kernel - scatter_conv_state( - conv_state, xBC_tail, batch_indices, cache_seqlens, chunk_len, copy_len - ) + scatter_conv_state(conv_state, xBC, batch_indices, cache_seqlens) else: # transpose: b l pd --> b pd l xBC = rearrange(xBC, "b l d -> b d l").contiguous() diff --git a/megatron/core/ssm/ops/causal_conv1d_triton.py b/megatron/core/ssm/ops/causal_conv1d_triton.py index b0c630d7eb0..7a04a218c35 100644 --- a/megatron/core/ssm/ops/causal_conv1d_triton.py +++ b/megatron/core/ssm/ops/causal_conv1d_triton.py @@ -185,13 +185,13 @@ def _scatter_conv_state_kernel( conv_state_ptr, batch_indices_ptr, cache_seqlens_ptr, - xbc_tail_ptr, + xBC_tail_ptr, stride_cs_b, stride_cs_d, stride_cs_w, - stride_xbc_b, - stride_xbc_d, - stride_xbc_w, + stride_xBC_b, + stride_xBC_d, + stride_xBC_w, D: tl.constexpr, state_len: tl.constexpr, chunk_len: tl.constexpr, @@ -219,8 +219,8 @@ def _scatter_conv_state_kernel( update_indices = (seq_len + chunk_len - copy_len + w_offsets) % state_len # Calculate memory offsets - xbc_offsets = ( - xbc_tail_ptr + (b * stride_xbc_b) + (d * stride_xbc_d) + (w_offsets * stride_xbc_w) + xBC_offsets = ( + xBC_tail_ptr + (b * stride_xBC_b) + (d * stride_xBC_d) + (w_offsets * stride_xBC_w) ) cs_offsets = ( conv_state_ptr @@ -229,20 +229,25 @@ def _scatter_conv_state_kernel( + (update_indices * stride_cs_w) ) - data = tl.load(xbc_offsets, mask=mask) + data = tl.load(xBC_offsets, mask=mask) tl.store(cs_offsets, data, mask=mask) def scatter_conv_state( conv_state: torch.Tensor, - xbc_tail: torch.Tensor, + xBC: torch.Tensor, batch_indices: torch.Tensor, cache_seqlens: torch.Tensor, - chunk_len: int, - copy_len: int, ): """Writes the newest chunk of tokens into the circular convolution state.""" - B, D, _ = xbc_tail.shape + state_len = conv_state.shape[2] + chunk_len = xBC.shape[-1] + + # We only need to retain at most the last `state_len` tokens of the chunk + copy_len = min(chunk_len, state_len) + xBC_tail = xBC[..., -copy_len:] + + B, D, _ = xBC_tail.shape state_len = conv_state.shape[2] BLOCK_W = triton.next_power_of_2(copy_len) @@ -251,13 +256,13 @@ def scatter_conv_state( conv_state, batch_indices, cache_seqlens, - xbc_tail, + xBC_tail, conv_state.stride(0), conv_state.stride(1), conv_state.stride(2), - xbc_tail.stride(0), - xbc_tail.stride(1), - xbc_tail.stride(2), + xBC_tail.stride(0), + xBC_tail.stride(1), + xBC_tail.stride(2), D, state_len, chunk_len, From 9b4dadf8bdba99c87d50e53630c9d7f87b717380 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Fri, 6 Mar 2026 01:39:28 -0800 Subject: [PATCH 38/76] Add unit tests Signed-off-by: Keshav Santhanam --- .../ssm/test_causal_conv1d_triton.py | 418 ++++++++++++++++++ 1 file changed, 418 insertions(+) create mode 100644 tests/unit_tests/ssm/test_causal_conv1d_triton.py diff --git a/tests/unit_tests/ssm/test_causal_conv1d_triton.py b/tests/unit_tests/ssm/test_causal_conv1d_triton.py new file mode 100644 index 00000000000..1d58897840c --- /dev/null +++ b/tests/unit_tests/ssm/test_causal_conv1d_triton.py @@ -0,0 +1,418 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +import pytest +import torch + +from megatron.core.ssm.ops.causal_conv1d_triton import ( + causal_conv1d_update, + gather_conv_state, + roll_conv_varlen_states, + scatter_conv_state, +) + + +def _requires_cuda(): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + +# ---------------------- Reference Implementations ---------------------- # + + +def roll_conv_varlen_states_ref(conv_varlen_states, cu_seqlens): + """Reference: roll each [D, W] slice by (seqlen % W) positions.""" + B, D, W = conv_varlen_states.shape + out = torch.empty_like(conv_varlen_states) + for b in range(B): + seqlen = (cu_seqlens[b + 1] - cu_seqlens[b]).item() + shift = seqlen % W + for d in range(D): + for w in range(W): + src = (w - shift + W) % W + out[b, d, w] = conv_varlen_states[b, d, src] + return out + + +def gather_conv_state_ref(conv_state, batch_indices, cache_seqlens, d_conv): + """Reference: read last (d_conv-1) elements from circular buffer.""" + B = batch_indices.shape[0] + D = conv_state.shape[1] + state_len = conv_state.shape[2] + out = torch.zeros((B, D, d_conv - 1), device=conv_state.device, dtype=conv_state.dtype) + for b in range(B): + req_idx = batch_indices[b].item() + if req_idx < 0: + continue + seq_len = cache_seqlens[b].item() + for d in range(D): + for w in range(d_conv - 1): + val = seq_len - d_conv + 1 + w + if val < 0: + continue + idx = (val + state_len) % state_len + out[b, d, w] = conv_state[req_idx, d, idx] + return out + + +def scatter_conv_state_ref(conv_state, xBC, batch_indices, cache_seqlens): + """Reference: write newest chunk into circular buffer.""" + state_len = conv_state.shape[2] + chunk_len = xBC.shape[-1] + copy_len = min(chunk_len, state_len) + xBC_tail = xBC[..., -copy_len:] + B, D, _ = xBC_tail.shape + for b in range(B): + req_idx = batch_indices[b].item() + if req_idx < 0: + continue + seq_len = cache_seqlens[b].item() + for d in range(D): + for w in range(copy_len): + idx = (seq_len + chunk_len - copy_len + w) % state_len + conv_state[req_idx, d, idx] = xBC_tail[b, d, w] + + +def causal_conv1d_update_ref(x, conv_state, weight, bias, silu_activation): + """Reference: linear (non-circular) causal conv1d update.""" + batch, seq_len, dim = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + out = torch.empty_like(x) + for b in range(batch): + for s in range(seq_len): + # Shift state left by 1 + conv_state[b, :, :-1] = conv_state[b, :, 1:].clone() + conv_state[b, :, -1] = x[b, s, :] + # Convolution over the last `width` elements + window = conv_state[b, :, state_len - width : state_len].float() + w = weight.float() + val = (window * w).sum(dim=1) + if bias is not None: + val = val + bias.float() + if silu_activation: + val = val * torch.sigmoid(val) + out[b, s, :] = val.to(x.dtype) + return out + + +# ---------------------- Tests ---------------------- # + + +@pytest.mark.internal +class TestRollConvVarlenStates: + + def setup_method(self, method): + _requires_cuda() + + @pytest.mark.parametrize("B,D,W", [(1, 4, 4), (3, 8, 4), (2, 16, 3)]) + def test_matches_reference(self, B, D, W): + torch.manual_seed(42) + conv_states = torch.randn(B, D, W, device="cuda", dtype=torch.float32) + seqlens = torch.randint(1, 20, (B,), device="cuda", dtype=torch.int32) + cu_seqlens = torch.zeros(B + 1, device="cuda", dtype=torch.int32) + cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) + + result = roll_conv_varlen_states(conv_states, cu_seqlens) + expected = roll_conv_varlen_states_ref(conv_states, cu_seqlens) + + torch.testing.assert_close(result, expected) + + def test_zero_shift(self): + """When all seqlens are multiples of W, no rolling should occur.""" + B, D, W = 2, 4, 4 + conv_states = torch.randn(B, D, W, device="cuda", dtype=torch.float32) + cu_seqlens = torch.tensor([0, W, 2 * W], device="cuda", dtype=torch.int32) + + result = roll_conv_varlen_states(conv_states, cu_seqlens) + torch.testing.assert_close(result, conv_states) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) + def test_dtypes(self, dtype): + B, D, W = 2, 8, 4 + conv_states = torch.randn(B, D, W, device="cuda", dtype=dtype) + cu_seqlens = torch.tensor([0, 3, 7], device="cuda", dtype=torch.int32) + + result = roll_conv_varlen_states(conv_states, cu_seqlens) + expected = roll_conv_varlen_states_ref(conv_states, cu_seqlens) + + torch.testing.assert_close(result, expected) + + +@pytest.mark.internal +class TestGatherConvState: + + def setup_method(self, method): + _requires_cuda() + + @pytest.mark.parametrize("d_conv", [2, 3, 4]) + def test_matches_reference(self, d_conv): + torch.manual_seed(42) + B, D, state_len = 3, 8, 16 + conv_state = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + batch_indices = torch.arange(B, device="cuda", dtype=torch.int32) + cache_seqlens = torch.randint( + d_conv, state_len + 10, (B,), device="cuda", dtype=torch.int32 + ) + + result = gather_conv_state(conv_state, batch_indices, cache_seqlens, d_conv) + expected = gather_conv_state_ref(conv_state, batch_indices, cache_seqlens, d_conv) + + torch.testing.assert_close(result, expected) + + def test_negative_batch_index_zeros_output(self): + B, D, state_len, d_conv = 2, 4, 8, 4 + conv_state = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + batch_indices = torch.tensor([-1, 0], device="cuda", dtype=torch.int32) + cache_seqlens = torch.tensor([5, 5], device="cuda", dtype=torch.int32) + + result = gather_conv_state(conv_state, batch_indices, cache_seqlens, d_conv) + + # First batch should be all zeros due to negative index + torch.testing.assert_close( + result[0], torch.zeros(D, d_conv - 1, device="cuda", dtype=torch.float32) + ) + + def test_small_seqlen(self): + """When seq_len < d_conv - 1, early positions should be zero-padded.""" + B, D, state_len, d_conv = 1, 4, 8, 4 + conv_state = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + batch_indices = torch.tensor([0], device="cuda", dtype=torch.int32) + cache_seqlens = torch.tensor([1], device="cuda", dtype=torch.int32) + + result = gather_conv_state(conv_state, batch_indices, cache_seqlens, d_conv) + expected = gather_conv_state_ref(conv_state, batch_indices, cache_seqlens, d_conv) + + torch.testing.assert_close(result, expected) + + +@pytest.mark.internal +class TestScatterConvState: + + def setup_method(self, method): + _requires_cuda() + + @pytest.mark.parametrize("chunk_len", [4, 8, 20]) + def test_matches_reference(self, chunk_len): + torch.manual_seed(42) + B, D, state_len = 3, 8, 16 + conv_state_triton = torch.zeros(B, D, state_len, device="cuda", dtype=torch.float32) + conv_state_ref = conv_state_triton.clone() + xBC = torch.randn(B, D, chunk_len, device="cuda", dtype=torch.float32) + batch_indices = torch.arange(B, device="cuda", dtype=torch.int32) + cache_seqlens = torch.randint(0, 20, (B,), device="cuda", dtype=torch.int32) + + scatter_conv_state(conv_state_triton, xBC, batch_indices, cache_seqlens) + scatter_conv_state_ref(conv_state_ref, xBC, batch_indices, cache_seqlens) + + torch.testing.assert_close(conv_state_triton, conv_state_ref) + + def test_negative_batch_index_noop(self): + B, D, state_len, chunk_len = 2, 4, 8, 4 + conv_state = torch.zeros(B, D, state_len, device="cuda", dtype=torch.float32) + conv_state_orig = conv_state.clone() + xBC = torch.randn(2, D, chunk_len, device="cuda", dtype=torch.float32) + batch_indices = torch.tensor([-1, -1], device="cuda", dtype=torch.int32) + cache_seqlens = torch.tensor([0, 0], device="cuda", dtype=torch.int32) + + scatter_conv_state(conv_state, xBC, batch_indices, cache_seqlens) + + torch.testing.assert_close(conv_state, conv_state_orig) + + def test_chunk_larger_than_state(self): + """When chunk_len > state_len, only last state_len tokens should be written.""" + B, D, state_len = 1, 4, 4 + chunk_len = 10 + conv_state = torch.zeros(B, D, state_len, device="cuda", dtype=torch.float32) + conv_state_ref = conv_state.clone() + xBC = torch.randn(B, D, chunk_len, device="cuda", dtype=torch.float32) + batch_indices = torch.tensor([0], device="cuda", dtype=torch.int32) + cache_seqlens = torch.tensor([0], device="cuda", dtype=torch.int32) + + scatter_conv_state(conv_state, xBC, batch_indices, cache_seqlens) + scatter_conv_state_ref(conv_state_ref, xBC, batch_indices, cache_seqlens) + + torch.testing.assert_close(conv_state, conv_state_ref) + + +@pytest.mark.internal +class TestCausalConv1dUpdate: + + def setup_method(self, method): + _requires_cuda() + + @pytest.mark.parametrize("width", [2, 3, 4]) + def test_linear_no_bias(self, width): + torch.manual_seed(42) + B, seq_len, D, state_len = 2, 3, 64, 8 + x = torch.randn(B, seq_len, D, device="cuda", dtype=torch.float32) + conv_state_triton = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + conv_state_ref = conv_state_triton.clone() + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + + result = causal_conv1d_update( + x, + conv_state_triton, + weight, + bias=None, + silu_activation=False, + cache_seqlens=None, + conv_state_indices=None, + ) + expected = causal_conv1d_update_ref( + x, conv_state_ref, weight, bias=None, silu_activation=False + ) + + torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(conv_state_triton, conv_state_ref, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("width", [2, 3, 4]) + def test_linear_with_bias(self, width): + torch.manual_seed(42) + B, seq_len, D, state_len = 2, 3, 64, 8 + x = torch.randn(B, seq_len, D, device="cuda", dtype=torch.float32) + conv_state_triton = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + conv_state_ref = conv_state_triton.clone() + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + bias = torch.randn(D, device="cuda", dtype=torch.float32) + + result = causal_conv1d_update( + x, + conv_state_triton, + weight, + bias=bias, + silu_activation=False, + cache_seqlens=None, + conv_state_indices=None, + ) + expected = causal_conv1d_update_ref( + x, conv_state_ref, weight, bias=bias, silu_activation=False + ) + + torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("width", [2, 3, 4]) + def test_linear_with_silu(self, width): + torch.manual_seed(42) + B, seq_len, D, state_len = 2, 1, 64, 8 + x = torch.randn(B, seq_len, D, device="cuda", dtype=torch.float32) + conv_state_triton = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + conv_state_ref = conv_state_triton.clone() + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + bias = torch.randn(D, device="cuda", dtype=torch.float32) + + result = causal_conv1d_update( + x, + conv_state_triton, + weight, + bias=bias, + silu_activation="silu", + cache_seqlens=None, + conv_state_indices=None, + ) + expected = causal_conv1d_update_ref( + x, conv_state_ref, weight, bias=bias, silu_activation=True + ) + + torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4) + + def test_2d_input(self): + """Test that 2D input (B, D) is handled correctly and returns 2D output.""" + torch.manual_seed(42) + B, D, state_len, width = 2, 64, 8, 4 + x = torch.randn(B, D, device="cuda", dtype=torch.float32) + conv_state = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + + result = causal_conv1d_update( + x, + conv_state, + weight, + bias=None, + silu_activation=False, + cache_seqlens=None, + conv_state_indices=None, + ) + + assert result.dim() == 2 + assert result.shape == (B, D) + + def test_conv_state_indices(self): + """Test that conv_state_indices correctly maps batch to state entries.""" + torch.manual_seed(42) + B, D, state_len, width = 2, 64, 8, 4 + num_states = 4 + x = torch.randn(B, 1, D, device="cuda", dtype=torch.float32) + conv_state = torch.randn(num_states, D, state_len, device="cuda", dtype=torch.float32) + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + # Map batch 0 -> state 2, batch 1 -> state 0 + state_indices = torch.tensor([2, 0], device="cuda", dtype=torch.int32) + + # Run with indices + conv_state_copy = conv_state.clone() + result = causal_conv1d_update( + x, + conv_state_copy, + weight, + bias=None, + silu_activation=False, + cache_seqlens=None, + conv_state_indices=state_indices, + ) + + # Run without indices by manually reordering + conv_state_reordered = conv_state[state_indices.long()].clone() + expected = causal_conv1d_update( + x, + conv_state_reordered, + weight, + bias=None, + silu_activation=False, + cache_seqlens=None, + conv_state_indices=None, + ) + + torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5) + + def test_negative_state_index_zeros_output(self): + """Padding batch entries (index < 0) should produce zero output.""" + torch.manual_seed(42) + B, D, state_len, width = 2, 64, 8, 4 + x = torch.randn(B, 1, D, device="cuda", dtype=torch.float32) + conv_state = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + state_indices = torch.tensor([-1, 0], device="cuda", dtype=torch.int32) + + result = causal_conv1d_update( + x, + conv_state, + weight, + bias=None, + silu_activation=False, + cache_seqlens=None, + conv_state_indices=state_indices, + ) + + # Batch 0 (padded) should be all zeros + torch.testing.assert_close(result[0], torch.zeros(1, D, device="cuda", dtype=torch.float32)) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_half_precision(self, dtype): + torch.manual_seed(42) + B, seq_len, D, state_len, width = 2, 1, 64, 8, 4 + x = torch.randn(B, seq_len, D, device="cuda", dtype=dtype) + conv_state = torch.randn(B, D, state_len, device="cuda", dtype=dtype) + weight = torch.randn(D, width, device="cuda", dtype=dtype) + + result = causal_conv1d_update( + x, + conv_state, + weight, + bias=None, + silu_activation=False, + cache_seqlens=None, + conv_state_indices=None, + ) + + assert result.dtype == dtype + assert result.shape == (B, seq_len, D) + assert torch.isfinite(result).all() From e59d6e955be502960f9c478ea7b149fb0e5629a3 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Fri, 6 Mar 2026 10:22:36 -0800 Subject: [PATCH 39/76] Test cache_seqlens in mamba_metadata.py Signed-off-by: Keshav Santhanam --- .../attention_metadata/test_mamba_metadata.py | 142 +++++++++++++++++- 1 file changed, 134 insertions(+), 8 deletions(-) diff --git a/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py b/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py index c793348233d..fce9518caeb 100644 --- a/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py +++ b/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py @@ -31,6 +31,7 @@ def _run_update_test( num_decode_requests: int, padded_dims: InferenceBatchDimensions, enable_chunked_prefill: bool, + request_kv_length_offsets: list[int] | None = None, ): """ Helper to construct inputs and run update(). @@ -42,6 +43,7 @@ def _run_update_test( num_decode_requests: Number of requests in req_seq_lengths that are in the decode phase. padded_dims: The padded batch dimensions to test against. enable_chunked_prefill: Whether chunked prefill is enabled. + request_kv_length_offsets: KV cache length offsets per request. Defaults to zeros. """ num_active_requests = len(req_seq_lengths) total_tokens = sum(req_seq_lengths) @@ -53,6 +55,12 @@ def _run_update_test( decode_req_count=num_decode_requests, ) + if request_kv_length_offsets is None: + request_kv_length_offsets = [0] * num_active_requests + kv_length_offsets_tensor = torch.tensor( + request_kv_length_offsets, dtype=torch.int32, device=metadata.device + ) + # Assuming 1:1 mapping (req_id i -> slot i) active_mamba_indices = torch.arange( num_active_requests, dtype=torch.int32, device=metadata.device @@ -74,6 +82,7 @@ def _run_update_test( active_mamba_indices=active_mamba_indices, token_to_request_idx=token_to_req_tensor, cu_seqlens=cu_seqlens_tensor, + request_kv_length_offsets=kv_length_offsets_tensor, batch_dimensions=real_dims, padded_batch_dimensions=padded_dims, enable_chunked_prefill=enable_chunked_prefill, @@ -90,18 +99,31 @@ def test_update_decode_only_exact_match(self, metadata_context): """Test simple decode only case where real dims match padded dims.""" seq_lengths = [1, 1, 1, 1] # 4 requests num_decode = 4 + kv_offsets = [5, 10, 15, 20] padded_dims = InferenceBatchDimensions( token_count=4, prefill_req_count=0, decode_req_count=4 ) self._run_update_test( - metadata_context, seq_lengths, num_decode, padded_dims, enable_chunked_prefill=False + metadata_context, + seq_lengths, + num_decode, + padded_dims, + enable_chunked_prefill=False, + request_kv_length_offsets=kv_offsets, ) expected_decode = torch.arange(4, dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_decode, expected_decode) + + expected_cache_seqlens = torch.tensor( + kv_offsets, dtype=torch.int32, device=metadata_context.device + ) + assert torch.equal(metadata_context.cache_seqlens_decode, expected_cache_seqlens) + assert metadata_context.batch_indices_prefill is None assert metadata_context.batch_indices_chunked_prefill is None + assert metadata_context.cache_seqlens_chunked_prefill is None assert metadata_context.device_decode_prefill is None assert metadata_context.cu_seqlens is None assert metadata_context.seq_idx is None @@ -111,21 +133,34 @@ def test_update_decode_only_padded(self, metadata_context): """Test decode only with padding (e.g. using CUDA graphs bucket).""" seq_lengths = [1, 1] # 2 requests num_decode = 2 + kv_offsets = [7, 12] # Padding to 4 requests padded_dims = InferenceBatchDimensions( token_count=4, prefill_req_count=0, decode_req_count=4 ) self._run_update_test( - metadata_context, seq_lengths, num_decode, padded_dims, enable_chunked_prefill=False + metadata_context, + seq_lengths, + num_decode, + padded_dims, + enable_chunked_prefill=False, + request_kv_length_offsets=kv_offsets, ) expected_decode = torch.tensor( [0, 1, -1, -1], dtype=torch.int32, device=metadata_context.device ) assert torch.equal(metadata_context.batch_indices_decode, expected_decode) + + expected_cache_seqlens = torch.tensor( + [7, 12, 0, 0], dtype=torch.int32, device=metadata_context.device + ) + assert torch.equal(metadata_context.cache_seqlens_decode, expected_cache_seqlens) + assert metadata_context.batch_indices_prefill is None assert metadata_context.batch_indices_chunked_prefill is None + assert metadata_context.cache_seqlens_chunked_prefill is None assert metadata_context.device_decode_prefill is None @pytest.mark.internal @@ -133,18 +168,31 @@ def test_update_chunked_enabled_no_prefill_reqs(self, metadata_context): """Test edge case: Chunked prefill enabled, but only decode requests exist.""" seq_lengths = [1, 1] num_decode = 2 + kv_offsets = [3, 8] padded_dims = InferenceBatchDimensions( token_count=2, prefill_req_count=0, decode_req_count=2 ) self._run_update_test( - metadata_context, seq_lengths, num_decode, padded_dims, enable_chunked_prefill=True + metadata_context, + seq_lengths, + num_decode, + padded_dims, + enable_chunked_prefill=True, + request_kv_length_offsets=kv_offsets, ) # Should behave exactly like decode-only (chunked logic skipped if real_prefill == 0) expected_decode = torch.tensor([0, 1], dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_decode, expected_decode) + + expected_cache_seqlens = torch.tensor( + [3, 8], dtype=torch.int32, device=metadata_context.device + ) + assert torch.equal(metadata_context.cache_seqlens_decode, expected_cache_seqlens) + assert metadata_context.batch_indices_chunked_prefill is None + assert metadata_context.cache_seqlens_chunked_prefill is None assert metadata_context.batch_indices_prefill is None assert metadata_context.cu_seqlens is None assert metadata_context.seq_idx is None @@ -180,7 +228,9 @@ def test_update_prefill_only_exact(self, metadata_context): assert torch.equal(metadata_context.seq_idx, expected_seq_idx) assert metadata_context.batch_indices_decode is None + assert metadata_context.cache_seqlens_decode is None assert metadata_context.batch_indices_chunked_prefill is None + assert metadata_context.cache_seqlens_chunked_prefill is None assert metadata_context.device_decode_prefill is None @pytest.mark.internal @@ -214,7 +264,9 @@ def test_update_prefill_only_padded(self, metadata_context): assert torch.equal(metadata_context.seq_idx, expected_seq_idx) assert metadata_context.batch_indices_decode is None + assert metadata_context.cache_seqlens_decode is None assert metadata_context.batch_indices_chunked_prefill is None + assert metadata_context.cache_seqlens_chunked_prefill is None assert metadata_context.device_decode_prefill is None # ------------------------------------------------------------------------- @@ -227,17 +279,29 @@ def test_update_mixed_batch_exact(self, metadata_context): # 2 decode (len 1), 2 prefill (len 10, 20) seq_lengths = [1, 1, 10, 20] num_decode = 2 + kv_offsets = [5, 10, 0, 0] padded_dims = InferenceBatchDimensions( token_count=32, prefill_req_count=2, decode_req_count=2 ) self._run_update_test( - metadata_context, seq_lengths, num_decode, padded_dims, enable_chunked_prefill=False + metadata_context, + seq_lengths, + num_decode, + padded_dims, + enable_chunked_prefill=False, + request_kv_length_offsets=kv_offsets, ) expected_decode = torch.tensor([0, 1], dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_decode, expected_decode) + expected_cache_seqlens = torch.tensor( + [5, 10], dtype=torch.int32, device=metadata_context.device + ) + assert torch.equal(metadata_context.cache_seqlens_decode, expected_cache_seqlens) + assert metadata_context.cache_seqlens_chunked_prefill is None + expected_prefill = torch.tensor([2, 3], dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_prefill, expected_prefill) @@ -267,6 +331,7 @@ def test_update_padded_prefill_and_decode(self, metadata_context): # Real: 1 decode, 1 prefill. seq_lengths = [1, 10] num_decode = 1 + kv_offsets = [25, 0] # Padded: 4 decode, 4 prefill. Total tokens 32. padded_dims = InferenceBatchDimensions( @@ -274,7 +339,12 @@ def test_update_padded_prefill_and_decode(self, metadata_context): ) self._run_update_test( - metadata_context, seq_lengths, num_decode, padded_dims, enable_chunked_prefill=False + metadata_context, + seq_lengths, + num_decode, + padded_dims, + enable_chunked_prefill=False, + request_kv_length_offsets=kv_offsets, ) expected_decode = torch.tensor( @@ -282,6 +352,12 @@ def test_update_padded_prefill_and_decode(self, metadata_context): ) assert torch.equal(metadata_context.batch_indices_decode, expected_decode) + expected_cache_seqlens = torch.tensor( + [25, 0, 0, 0], dtype=torch.int32, device=metadata_context.device + ) + assert torch.equal(metadata_context.cache_seqlens_decode, expected_cache_seqlens) + assert metadata_context.cache_seqlens_chunked_prefill is None + expected_prefill = torch.tensor( [1, -1, -1, -1], dtype=torch.int32, device=metadata_context.device ) @@ -313,6 +389,7 @@ def test_update_chunked_prefill_mixed_exact(self, metadata_context): # 1 decode, 1 chunked prefill (len 50), 1 regular prefill (len 10) seq_lengths = [1, 50, 10] num_decode = 1 + kv_offsets = [9, 100, 0] # Exact dimensions padded_dims = InferenceBatchDimensions( @@ -320,7 +397,12 @@ def test_update_chunked_prefill_mixed_exact(self, metadata_context): ) self._run_update_test( - metadata_context, seq_lengths, num_decode, padded_dims, enable_chunked_prefill=True + metadata_context, + seq_lengths, + num_decode, + padded_dims, + enable_chunked_prefill=True, + request_kv_length_offsets=kv_offsets, ) expected_device_chunked_prefill = torch.tensor( @@ -330,6 +412,18 @@ def test_update_chunked_prefill_mixed_exact(self, metadata_context): assert metadata_context.batch_indices_chunked_prefill[0] == 1 + expected_cache_seqlens_decode = torch.tensor( + [9], dtype=torch.int32, device=metadata_context.device + ) + assert torch.equal(metadata_context.cache_seqlens_decode, expected_cache_seqlens_decode) + + expected_cache_seqlens_chunked = torch.tensor( + [100], dtype=torch.int32, device=metadata_context.device + ) + assert torch.equal( + metadata_context.cache_seqlens_chunked_prefill, expected_cache_seqlens_chunked + ) + expected_prefill = torch.tensor([2, -1], dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_prefill, expected_prefill) @@ -353,17 +447,28 @@ def test_update_chunked_prefill_mixed_padded(self, metadata_context): # 2 decode, 1 chunked prefill (len 50), 1 regular prefill (len 10) seq_lengths = [1, 1, 50, 10] num_decode = 2 + kv_offsets = [4, 6, 200, 0] padded_dims = InferenceBatchDimensions( token_count=62, prefill_req_count=2, decode_req_count=2 ) self._run_update_test( - metadata_context, seq_lengths, num_decode, padded_dims, enable_chunked_prefill=True + metadata_context, + seq_lengths, + num_decode, + padded_dims, + enable_chunked_prefill=True, + request_kv_length_offsets=kv_offsets, ) expected_decode = torch.tensor([0, 1], dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_decode, expected_decode) + expected_cache_seqlens_decode = torch.tensor( + [4, 6], dtype=torch.int32, device=metadata_context.device + ) + assert torch.equal(metadata_context.cache_seqlens_decode, expected_cache_seqlens_decode) + expected_device_chunked_prefill = torch.tensor( [50, 10], dtype=torch.int32, device=metadata_context.device ) @@ -371,6 +476,13 @@ def test_update_chunked_prefill_mixed_padded(self, metadata_context): assert metadata_context.batch_indices_chunked_prefill[0] == 2 + expected_cache_seqlens_chunked = torch.tensor( + [200], dtype=torch.int32, device=metadata_context.device + ) + assert torch.equal( + metadata_context.cache_seqlens_chunked_prefill, expected_cache_seqlens_chunked + ) + expected_prefill = torch.tensor([3, -1], dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_prefill, expected_prefill) @@ -394,19 +506,33 @@ def test_update_chunked_only_padded(self, metadata_context): # 1 chunked prefill request. seq_lengths = [100] num_decode = 0 + kv_offsets = [50] padded_dims = InferenceBatchDimensions( token_count=128, prefill_req_count=2, decode_req_count=0 ) self._run_update_test( - metadata_context, seq_lengths, num_decode, padded_dims, enable_chunked_prefill=True + metadata_context, + seq_lengths, + num_decode, + padded_dims, + enable_chunked_prefill=True, + request_kv_length_offsets=kv_offsets, ) assert metadata_context.batch_indices_decode is None + assert metadata_context.cache_seqlens_decode is None assert metadata_context.batch_indices_chunked_prefill[0] == 0 + expected_cache_seqlens_chunked = torch.tensor( + [50], dtype=torch.int32, device=metadata_context.device + ) + assert torch.equal( + metadata_context.cache_seqlens_chunked_prefill, expected_cache_seqlens_chunked + ) + expected_prefill = torch.tensor([-1, -1], dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_prefill, expected_prefill) From 213e5d7522d7921a926d1cd4bf0eb52299d1c52d Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Fri, 6 Mar 2026 10:49:18 -0800 Subject: [PATCH 40/76] Add spec decode + prefix caching unit tests Signed-off-by: Keshav Santhanam --- .../contexts/test_dynamic_context.py | 354 ++++++++++++++++++ .../inference/engines/test_dynamic_engine.py | 133 +++++++ .../test_text_generation_controller.py | 94 +++++ 3 files changed, 581 insertions(+) diff --git a/tests/unit_tests/inference/contexts/test_dynamic_context.py b/tests/unit_tests/inference/contexts/test_dynamic_context.py index 5983e3b464b..bf7387cd658 100644 --- a/tests/unit_tests/inference/contexts/test_dynamic_context.py +++ b/tests/unit_tests/inference/contexts/test_dynamic_context.py @@ -1776,3 +1776,357 @@ def test_chunked_prefill_speculative_offset_math(self): + chunk_length + req.sampling_params.num_tokens_to_generate ) + + @pytest.mark.internal + @rounder_override(64) + def test_speculative_with_prefix_caching_shared_blocks(self): + """Test that prefix caching correctly shares blocks when speculative decoding is enabled.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.1, + block_size_tokens=32, + num_speculative_tokens=2, + enable_prefix_caching=True, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + bs = ctx.block_size_tokens + prompt = torch.arange(bs * 3, device='cuda') + + # First request registers blocks. + req1 = DynamicInferenceRequest( + request_id=1, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req1) + first_blocks = [ctx.request_to_kv_block_ids[0][i].item() for i in range(3)] + avail_after_first = ctx.block_allocator.total_avail + + # Second request with same prefix should share all blocks. + req2 = DynamicInferenceRequest( + request_id=2, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req2) + second_blocks = [ctx.request_to_kv_block_ids[1][i].item() for i in range(3)] + + # Blocks should be shared (same IDs, no pool consumption). + assert first_blocks == second_blocks + assert ctx.block_allocator.total_avail == avail_after_first + + # Ref counts should be 2. + for bid in first_blocks: + assert ctx.block_allocator.block_ref_counts[bid].item() == 2 + + # Second request should skip prefix tokens (query_length == 1 for full match). + assert ctx.request_query_lengths[1].item() == 1 + + @pytest.mark.internal + @rounder_override(64) + def test_speculative_with_prefix_caching_kv_offset(self): + """Test that KV offset accounts for prefix skip when spec decoding is enabled.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.1, + block_size_tokens=32, + num_speculative_tokens=3, + enable_prefix_caching=True, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + bs = ctx.block_size_tokens + prompt = torch.arange(bs * 2, device='cuda') + + # First request. + req1 = DynamicInferenceRequest( + request_id=1, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req1) + + # Second request with same prefix: should have kv_offset = prefix_skip_tokens. + req2 = DynamicInferenceRequest( + request_id=2, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req2) + + # Full match: prefix_skip = min(2 * bs, 2*bs - 1) = 2*bs - 1 + expected_skip = 2 * bs - 1 + assert ctx.request_kv_length_offsets[1].item() == expected_skip + assert ctx.request_query_lengths[1].item() == 1 + + @pytest.mark.internal + @rounder_override(64) + def test_speculative_update_then_release_with_prefix_caching(self): + """Test that update_requests with spec tokens + block release respects ref counts.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.1, + block_size_tokens=4, + num_speculative_tokens=2, + enable_prefix_caching=True, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + bs = ctx.block_size_tokens + prompt = torch.arange(bs * 2, device='cuda') + + # Two requests sharing the same prefix. + req1 = DynamicInferenceRequest( + request_id=1, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req1) + shared_blocks = [ctx.request_to_kv_block_ids[0][i].item() for i in range(2)] + + req2 = DynamicInferenceRequest( + request_id=2, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req2) + + # Verify initial ref counts are 2. + for bid in shared_blocks: + assert ctx.block_allocator.block_ref_counts[bid].item() == 2 + + # Release one request. Ref counts should decrement to 1. + ctx.release_memory_blocks_from_request_indexes(torch.tensor([0])) + for bid in shared_blocks: + assert ctx.block_allocator.block_ref_counts[bid].item() == 1 + + # Blocks should still be discoverable via hash map. + for bid in shared_blocks: + h = ctx.block_allocator.block_hashes[bid].item() + assert h in ctx.block_allocator.hash_to_block_id + + @pytest.mark.internal + @rounder_override(64) + def test_speculative_boundary_crossing_with_prefix_caching(self): + """Test block boundary crossing from speculative tokens does not corrupt shared blocks.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.1, + block_size_tokens=4, + num_speculative_tokens=2, + enable_prefix_caching=True, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + bs = ctx.block_size_tokens + prompt = torch.arange(bs * 2, device='cuda') + + # Request 1: adds prefix blocks. + req1 = DynamicInferenceRequest( + request_id=1, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req1) + shared_b0 = ctx.request_to_kv_block_ids[0][0].item() + shared_b1 = ctx.request_to_kv_block_ids[0][1].item() + + # Request 2: shares prefix, gets its own decode block. + req2 = DynamicInferenceRequest( + request_id=2, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req2) + + # Both requests share the same 2 blocks. + assert ctx.request_to_kv_block_ids[1][0].item() == shared_b0 + assert ctx.request_to_kv_block_ids[1][1].item() == shared_b1 + + # Set up request 0 for decode at offset that will cross block boundary. + # Place at offset (block_size - 1) in last block so adding 3 tokens crosses. + ctx.request_kv_length_offsets[0] = bs * 2 - 1 # one token from end of block 1 + ctx.request_last_kv_block_offset[0] = bs - 1 + ctx.request_query_lengths[0] = 1 + ctx.request_in_prefill_status_tensor[0] = 0 + ctx.active_token_count = 2 + + active_mask = torch.tensor([1, 0], device='cuda', dtype=torch.int32) + new_tokens = torch.tensor([50], device='cuda') + new_spec = torch.tensor([[51], [52]], device='cuda') + + ctx.update_requests( + active_requests_mask=active_mask, + new_tokens=new_tokens, + new_speculative_tokens=new_spec, + ) + + # A new block should have been allocated for the boundary crossing. + assert ctx.request_kv_block_counts[0] == 3 + new_block = ctx.request_to_kv_block_ids[0][2].item() + assert new_block != -1 + assert new_block != shared_b0 + assert new_block != shared_b1 + + # Shared blocks should remain intact with ref count 2. + assert ctx.block_allocator.block_ref_counts[shared_b0].item() == 2 + assert ctx.block_allocator.block_ref_counts[shared_b1].item() == 2 + + @pytest.mark.internal + @rounder_override(64) + def test_chunked_prefill_speculative_offset_with_prefix_caching(self): + """Test chunked prefill offset math combines correctly with prefix caching and spec decoding.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.1, + block_size_tokens=32, + max_requests=256, + max_tokens=256, + num_speculative_tokens=2, + enable_chunked_prefill=True, + enable_prefix_caching=True, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + ctx.reset_tensors() + + bs = ctx.block_size_tokens + + # First request: register prefix blocks (bs * 3 tokens = 3 complete blocks). + first_prompt = torch.arange(bs * 3, device='cuda') + req_first = DynamicInferenceRequest( + request_id=1, + prompt_tokens=first_prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req_first) + + # Second request: same prefix, continuing chunked prefill. + # Simulate that this request already processed bs tokens in a prior chunk. + ctx.chunked_prefill_request_id = 42 + ctx.request_ids[ctx.total_request_count] = 42 + + # Manually set up as if request is mid-chunked-prefill. + ctx.total_request_count += 1 + current_id = ctx.total_request_count - 1 + ctx.request_ids[current_id] = 42 + + initial_active_tokens = ctx.active_token_count + 1 + ctx.num_speculative_tokens + ctx.active_token_count = initial_active_tokens + + req2 = DynamicInferenceRequest( + request_id=42, + prompt_tokens=first_prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + req2.finished_chunk_token_count = bs # Already processed 1 block + + chunk_length = bs * 2 # Process 2 more blocks + ctx.add_request(req2, chunk_length=chunk_length) + + # Prefix match should find 2 matching blocks (blocks 1 and 2 from req_first). + # The chunked_prefill_offset (1 + num_speculative_tokens = 3) should be subtracted. + chunked_prefill_offset = 1 + ctx.num_speculative_tokens + # With prefix match: 2 blocks matched -> skip (2*bs - 1) tokens + # effective_chunk_length = chunk_length - prefix_skip_tokens + (_, _, _, _, prefix_skip, eff_chunk) = ctx._compute_prefix_match(req2, chunk_length) + expected_active = initial_active_tokens - chunked_prefill_offset + eff_chunk + assert ctx.active_token_count == expected_active + + @pytest.mark.internal + @rounder_override(64) + def test_prefix_caching_check_availability_with_speculative(self): + """Test check_availability accounts for prefix match when spec decoding is enabled.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.01, + block_size_tokens=32, + num_speculative_tokens=3, + enable_prefix_caching=True, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + bs = ctx.block_size_tokens + prompt = torch.arange(bs * 2, device='cuda') + + # First request registers blocks. + req1 = DynamicInferenceRequest( + request_id=1, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req1) + + # Exhaust the remaining pool. + while ctx.block_allocator.total_avail > 0: + ctx.block_allocator.allocate_memory_blocks(1) + + # A new request with the same prefix should still be schedulable + # because prefix matching means 0 new blocks are needed from pool. + req2 = DynamicInferenceRequest( + request_id=2, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + _, _, kv_available = ctx.check_availability(req2) + assert kv_available, "Matched blocks should not require pool allocation" diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index a2cbed87d5e..e730322a390 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -2141,3 +2141,136 @@ def test_speculative_long_stop_word_hit(self): stop_hit = env.engine._check_stop_words_for_request_post_append(req) assert stop_hit is True + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.inference_mode() + def test_speculative_decoding_with_prefix_caching(self): + """Test that speculative decoding works correctly when prefix caching is enabled. + + Two requests share the same prompt prefix. The second request should reuse + cached KV blocks from the first and still generate correctly with spec decoding. + """ + test_config = DynamicEngineTestConfig( + num_requests=4, + min_prompt_length=8, + max_prompt_length=8, + num_tokens_to_generate=4, + num_speculative_tokens=2, + materialize_only_last_token_logits=False, + model_provider="gpt", + ) + env = self._build_test_env(test_config) + + # Enable prefix caching on the context. + env.engine.context.enable_prefix_caching = True + + # Create two pairs of requests with shared prefixes. + shared_prompt_a = torch.randint( + 0, test_config.vocab_size - 1, (8,), dtype=torch.int64, device='cuda' + ) + shared_prompt_b = torch.randint( + 0, test_config.vocab_size - 1, (8,), dtype=torch.int64, device='cuda' + ) + + for i, prompt in enumerate([shared_prompt_a, shared_prompt_a, shared_prompt_b, shared_prompt_b]): + env.requests[i].prompt_tokens = prompt.clone() + + # Run all requests through the engine. + for request in env.requests: + env.engine._add_request(request) + + while env.engine.has_unfinished_requests(): + self._run_step(env) + + # All requests should complete. + for request in env.requests: + assert request.status in (Status.COMPLETED, Status.FAILED) + if request.status == Status.COMPLETED: + assert len(request.generated_tokens) > 0 + + # Context should be clean after all requests finish. + assert env.engine.context.active_token_count == 0 + assert env.engine.context.total_request_count == 0 + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.inference_mode() + def test_speculative_decoding_with_chunked_prefill(self): + """Test that speculative decoding combined with chunked prefill completes correctly.""" + test_config = DynamicEngineTestConfig( + num_requests=2, + min_prompt_length=16, + max_prompt_length=16, + num_tokens_to_generate=4, + num_speculative_tokens=2, + materialize_only_last_token_logits=False, + enable_chunked_prefill=True, + model_provider="gpt", + context_max_tokens=32, # Force chunking by limiting token budget + ) + env = self._build_test_env(test_config) + + for request in env.requests: + env.engine._add_request(request) + + while env.engine.has_unfinished_requests(): + self._run_step(env) + + for request in env.requests: + assert request.status in (Status.COMPLETED, Status.FAILED) + + assert env.engine.context.active_token_count == 0 + assert env.engine.context.total_request_count == 0 + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.inference_mode() + def test_speculative_decoding_chunked_prefill_and_prefix_caching(self): + """End-to-end test combining speculative decoding, chunked prefill, and prefix caching. + + Verifies that all three features interact correctly: + - Prefix caching shares KV blocks between requests with common prompts + - Chunked prefill processes long prompts in chunks + - Speculative decoding generates multiple tokens per step + """ + test_config = DynamicEngineTestConfig( + num_requests=4, + min_prompt_length=16, + max_prompt_length=16, + num_tokens_to_generate=4, + num_speculative_tokens=2, + materialize_only_last_token_logits=False, + enable_chunked_prefill=True, + model_provider="gpt", + context_max_tokens=48, # Force chunking + ) + env = self._build_test_env(test_config) + + # Enable prefix caching. + env.engine.context.enable_prefix_caching = True + + # Create pairs with shared prefixes to exercise prefix caching. + shared_prompt = torch.randint( + 0, test_config.vocab_size - 1, (16,), dtype=torch.int64, device='cuda' + ) + for i in range(len(env.requests)): + env.requests[i].prompt_tokens = shared_prompt.clone() + + for request in env.requests: + env.engine._add_request(request) + + while env.engine.has_unfinished_requests(): + self._run_step(env) + + for request in env.requests: + assert request.status in (Status.COMPLETED, Status.FAILED) + + assert env.engine.context.active_token_count == 0 + assert env.engine.context.total_request_count == 0 diff --git a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py index 833416a54f9..ba453670862 100644 --- a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py +++ b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py @@ -1197,3 +1197,97 @@ def test_speculative_multinomial_sampling(self): assert sampled_tokens.shape == (2,) assert sampled_mtp_tokens.shape == (num_spec, 2) + + @pytest.mark.internal + def test_rewind_kv_cache_with_prefix_caching_ref_counts(self): + """Test that _rewind_kv_cache correctly decrements ref counts on shared blocks + when speculative token rejection causes a block boundary crossing.""" + self.setup_model(torch.float32, static=False) + + ctx = self.text_generation_controller.inference_wrapped_model.inference_context + self.text_generation_controller.num_speculative_tokens = 2 + ctx.num_speculative_tokens = 2 + ctx.block_size_tokens = 4 + ctx.enable_prefix_caching = True + ctx.total_request_count = 2 + ctx.paused_request_count = 0 + ctx.request_in_prefill_status_tensor = torch.tensor([0, 0], device='cuda') + + # Initialize allocator ref count tracking. + ctx.block_allocator.enable_prefix_caching = True + if not hasattr(ctx.block_allocator, 'block_ref_counts'): + ctx.block_allocator.block_ref_counts = torch.zeros( + ctx.block_allocator.total_count, dtype=torch.int32, device='cuda' + ) + + # Req 0: 3 blocks, offset 1 in last block. Rewinding 1 token -> no block release. + # Req 1: 3 blocks, offset 0 in last block. Rewinding 2 tokens -> crosses back, release block. + ctx.request_kv_length_offsets[:2] = torch.tensor([9, 9], device='cuda') + ctx.request_kv_block_counts[:2] = torch.tensor([3, 3], device='cuda') + ctx.request_last_kv_block_offset[:2] = torch.tensor([1, 0], device='cuda') + ctx.request_last_kv_block_id[:2] = torch.tensor([10, 20], device='cuda') + ctx.request_to_kv_block_ids[:2, :3] = torch.tensor( + [[8, 9, 10], [18, 19, 20]], dtype=torch.int, device='cuda' + ) + + # Set ref counts: block 20 is shared (ref=2), block 10 is exclusive (ref=1). + ctx.block_allocator.block_ref_counts[20] = 2 + ctx.block_allocator.block_ref_counts[10] = 1 + + initial_avail = ctx.block_allocator.total_avail + + # Req 0 accepts 1 (rewinds 1), Req 1 accepts 0 (rewinds 2, crosses boundary). + self.text_generation_controller._init_mtp_sampling_tensor() + self.text_generation_controller._accepted_token_counts_per_request = torch.tensor( + [1, 0], device='cuda' + ) + + self.text_generation_controller._rewind_kv_cache() + + # Req 1 should have released block 20 (ref count decremented). + assert ctx.block_allocator.block_ref_counts[20].item() == 1 + # Block 10 should be untouched. + assert ctx.block_allocator.block_ref_counts[10].item() == 1 + + @pytest.mark.internal + def test_rewind_kv_cache_does_not_release_shared_prefix_blocks(self): + """Test that rewinding only releases the last block, never shared prefix blocks.""" + self.setup_model(torch.float32, static=False) + + ctx = self.text_generation_controller.inference_wrapped_model.inference_context + self.text_generation_controller.num_speculative_tokens = 3 + ctx.num_speculative_tokens = 3 + ctx.block_size_tokens = 4 + ctx.total_request_count = 1 + ctx.paused_request_count = 0 + ctx.request_in_prefill_status_tensor = torch.tensor([0], device='cuda') + + # 4 blocks. Offset 2 in last block. Rewinding 3 crosses into previous block. + ctx.request_kv_length_offsets[:1] = torch.tensor([14], device='cuda') + ctx.request_kv_block_counts[:1] = torch.tensor([4], device='cuda') + ctx.request_last_kv_block_offset[:1] = torch.tensor([2], device='cuda') + ctx.request_last_kv_block_id[:1] = torch.tensor([40], device='cuda') + ctx.request_to_kv_block_ids[0, :4] = torch.tensor( + [10, 20, 30, 40], dtype=torch.int, device='cuda' + ) + + # Blocks 10, 20 are shared prefix blocks. Block 30, 40 are exclusive. + ctx.block_allocator.total_avail = 50 + + self.text_generation_controller._init_mtp_sampling_tensor() + self.text_generation_controller._accepted_token_counts_per_request = torch.tensor( + [0], device='cuda' + ) + + self.text_generation_controller._rewind_kv_cache() + + # Only block 40 should be released, not blocks 10, 20, or 30. + assert ctx.request_kv_block_counts[0].item() == 3 + assert ctx.request_last_kv_block_id[0].item() == 30 + assert ctx.request_to_kv_block_ids[0, 3].item() == -1 + assert ctx.block_allocator.total_avail == 51 # exactly 1 block released + + # Prefix blocks remain in request_to_kv_block_ids. + assert ctx.request_to_kv_block_ids[0, 0].item() == 10 + assert ctx.request_to_kv_block_ids[0, 1].item() == 20 + assert ctx.request_to_kv_block_ids[0, 2].item() == 30 From d26450bbf73be4901398e7730700319f16ad2d93 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Fri, 6 Mar 2026 11:06:55 -0800 Subject: [PATCH 41/76] Fix speculative decode engine test Signed-off-by: Keshav Santhanam --- .../inference/engines/test_dynamic_engine.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index e730322a390..18059417cf8 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -42,6 +42,7 @@ get_gpt_layer_local_spec, get_gpt_layer_with_inference_spec, get_gpt_layer_with_transformer_engine_spec, + get_gpt_mtp_block_spec, ) from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec @@ -339,6 +340,14 @@ def _build_test_env(cls, test_config): elif test_config.transformer_impl == "inference_optimized": layer_spec = get_gpt_layer_with_inference_spec() + # MTP block spec (needed for speculative decoding). + mtp_block_spec = None + if test_config.num_speculative_tokens > 0: + use_te = test_config.fp8 or test_config.transformer_impl == "transformer_engine" + mtp_block_spec = get_gpt_mtp_block_spec( + config=transformer_config, spec=layer_spec, use_transformer_engine=use_te, + ) + # GPT model. model = GPTModel( config=transformer_config, @@ -348,6 +357,7 @@ def _build_test_env(cls, test_config): parallel_output=True, pre_process=parallel_state.is_pipeline_first_stage(), post_process=parallel_state.is_pipeline_last_stage(), + mtp_block_spec=mtp_block_spec, ).cuda() elif test_config.model_provider == "mamba": pp_size = test_config.pipeline_model_parallel_size From d942c0fbd7cfec7eab922df2808d44a484778076 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Fri, 6 Mar 2026 11:39:18 -0800 Subject: [PATCH 42/76] Enable prefix caching in the config Signed-off-by: Keshav Santhanam --- tests/unit_tests/inference/engines/test_dynamic_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index 18059417cf8..8f7279b3cd9 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -126,6 +126,7 @@ class DynamicEngineTestConfig: materialize_only_last_token_logits: bool = True skip_prompt_log_probs: bool = False enable_chunked_prefill: bool = False + enable_prefix_caching: bool = False cuda_graph_scope: List[CudaGraphScope] = field( default_factory=lambda: [CudaGraphScope.full_iteration_inference] ) @@ -260,6 +261,7 @@ def _build_inference_context( ), static_kv_memory_pointers=test_config.static_kv_memory_pointers, enable_chunked_prefill=test_config.enable_chunked_prefill, + enable_prefix_caching=test_config.enable_prefix_caching, use_flashinfer_fused_rope=None, # default to using flash-infer if available # this is for compatibility with the LTS environment unified_memory_level=0, # unit tests currently broken with UVM @@ -2169,14 +2171,12 @@ def test_speculative_decoding_with_prefix_caching(self): max_prompt_length=8, num_tokens_to_generate=4, num_speculative_tokens=2, + enable_prefix_caching=True, materialize_only_last_token_logits=False, model_provider="gpt", ) env = self._build_test_env(test_config) - # Enable prefix caching on the context. - env.engine.context.enable_prefix_caching = True - # Create two pairs of requests with shared prefixes. shared_prompt_a = torch.randint( 0, test_config.vocab_size - 1, (8,), dtype=torch.int64, device='cuda' From 8b3fd1058e65e5e4215481cda2695681f2303edb Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Sun, 8 Mar 2026 23:04:41 -0700 Subject: [PATCH 43/76] Fix tests Signed-off-by: Keshav Santhanam --- .../contexts/dynamic_block_allocator.py | 40 +- .../inference/contexts/dynamic_context.py | 1 + megatron/core/transformer/attention.py | 8 +- .../contexts/test_dynamic_context.py | 304 +++++++++++++- .../inference/engines/test_dynamic_engine.py | 382 +++++++++++++----- 5 files changed, 620 insertions(+), 115 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_block_allocator.py b/megatron/core/inference/contexts/dynamic_block_allocator.py index abfb7278b14..5bbf7001094 100644 --- a/megatron/core/inference/contexts/dynamic_block_allocator.py +++ b/megatron/core/inference/contexts/dynamic_block_allocator.py @@ -85,19 +85,39 @@ def get_total_used(self): def get_active_used(self): """Compute number of active blocks used.""" - return ( - self.context.request_kv_block_counts[ - self.context.paused_request_count : self.context.total_request_count - ] - .sum() - .item() - ) + if not self.enable_prefix_caching: + return ( + self.context.request_kv_block_counts[ + self.context.paused_request_count : self.context.total_request_count + ] + .sum() + .item() + ) + + active_start = self.context.paused_request_count + active_end = self.context.total_request_count + if active_end > active_start: + active_rows = self.context.request_to_kv_block_ids[active_start:active_end] + valid_ids = active_rows[active_rows >= 0] + if valid_ids.numel() > 0: + return int(torch.unique(valid_ids).numel()) + return 0 def get_paused_used(self): """Compute number of paused blocks used.""" - return ( - self.context.request_kv_block_counts[: self.context.paused_request_count].sum().item() - ) + if not self.enable_prefix_caching: + return ( + self.context.request_kv_block_counts[: self.context.paused_request_count] + .sum() + .item() + ) + + if self.context.paused_request_count > 0: + paused_rows = self.context.request_to_kv_block_ids[: self.context.paused_request_count] + valid_ids = paused_rows[paused_rows >= 0] + if valid_ids.numel() > 0: + return int(torch.unique(valid_ids).numel()) + return 0 def get_active_avail(self): """Compute number of active blocks available.""" diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index d04359aacdf..22bb07bd968 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -2592,6 +2592,7 @@ def update_requests( # crosses_boundary = [[False, False, False], [False, True, True], [False, False, False]] raw_positions = ( old_offsets[:, None] + + 1 # Offset by 1 because old_offsets points to the LAST token + torch.arange(num_generated_tokens, device=torch.cuda.current_device())[None, :] ) # diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 310a59bde35..0ac7e78fae3 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -60,7 +60,9 @@ rearrange = None try: - from flash_attn_3.flash_attn_interface import _flash_attn_forward + from flash_attn_3.flash_attn_interface import ( + _flash_attn_forward, + ) from flash_attn_3.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) @@ -71,7 +73,9 @@ if not HAVE_FA3: try: - from flashattn_hopper.flash_attn_interface import _flash_attn_forward + from flashattn_hopper.flash_attn_interface import ( + _flash_attn_forward, + ) from flashattn_hopper.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) diff --git a/tests/unit_tests/inference/contexts/test_dynamic_context.py b/tests/unit_tests/inference/contexts/test_dynamic_context.py index bf7387cd658..1b7a2cba5dc 100644 --- a/tests/unit_tests/inference/contexts/test_dynamic_context.py +++ b/tests/unit_tests/inference/contexts/test_dynamic_context.py @@ -1601,9 +1601,11 @@ def test_speculative_boundary_crossing(self): ctx.request_query_lengths[0] = 1 ctx.request_kv_block_counts[0] = 1 - # Request is at offset 2. Adding 3 tokens (1 sampled + 2 spec) will cross boundary (2+3 = 5 > 4). + # Length is 2, meaning existing tokens are at indices 0 and 1. + # The last inserted token was at offset 1. + # Adding 3 tokens places them at offsets 2, 3, and 4 (crosses block size of 4). ctx.request_kv_length_offsets[0] = 2 - ctx.request_last_kv_block_offset[0] = 2 + ctx.request_last_kv_block_offset[0] = 1 # Allocate one initial block manually blocks = ctx.block_allocator.allocate_memory_blocks(1) @@ -1896,6 +1898,8 @@ def test_speculative_update_then_release_with_prefix_caching(self): num_speculative_tokens=2, enable_prefix_caching=True, unified_memory_level=0, + max_requests=512, + max_tokens=512, ) ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) @@ -1952,6 +1956,8 @@ def test_speculative_boundary_crossing_with_prefix_caching(self): num_speculative_tokens=2, enable_prefix_caching=True, unified_memory_level=0, + max_tokens=512, + max_requests=512, ) ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) @@ -1987,19 +1993,18 @@ def test_speculative_boundary_crossing_with_prefix_caching(self): # Set up request 0 for decode at offset that will cross block boundary. # Place at offset (block_size - 1) in last block so adding 3 tokens crosses. ctx.request_kv_length_offsets[0] = bs * 2 - 1 # one token from end of block 1 - ctx.request_last_kv_block_offset[0] = bs - 1 + # The local offset of index 6 is (6 % bs) + ctx.request_last_kv_block_offset[0] = bs - 2 ctx.request_query_lengths[0] = 1 ctx.request_in_prefill_status_tensor[0] = 0 ctx.active_token_count = 2 - active_mask = torch.tensor([1, 0], device='cuda', dtype=torch.int32) - new_tokens = torch.tensor([50], device='cuda') - new_spec = torch.tensor([[51], [52]], device='cuda') + active_mask = torch.tensor([1, 1], device='cuda', dtype=torch.int32) + new_tokens = torch.tensor([50, 50], device='cuda') + new_spec = torch.tensor([[51, 51], [52, 52]], device='cuda') ctx.update_requests( - active_requests_mask=active_mask, - new_tokens=new_tokens, - new_speculative_tokens=new_spec, + active_requests_mask=active_mask, new_tokens=new_tokens, new_speculative_tokens=new_spec ) # A new block should have been allocated for the boundary crossing. @@ -2130,3 +2135,284 @@ def test_prefix_caching_check_availability_with_speculative(self): ) _, _, kv_available = ctx.check_availability(req2) assert kv_available, "Matched blocks should not require pool allocation" + + @pytest.mark.internal + @rounder_override(64) + def test_prefix_match_exact_block_boundary(self): + """Test prefix matching when the shared prefix is an exact multiple of the block size.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.1, + block_size_tokens=16, + enable_prefix_caching=True, + unified_memory_level=0, + max_tokens=512, + max_requests=512, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + bs = ctx.block_size_tokens + + # req1: 32 tokens (exactly 2 complete blocks) + prompt1 = torch.arange(bs * 2, device='cuda') + req1 = DynamicInferenceRequest( + request_id=1, + prompt_tokens=prompt1, + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req1) + + # req2: 35 tokens (first 32 tokens match req1) + prompt2 = torch.arange(bs * 2 + 3, device='cuda') + req2 = DynamicInferenceRequest( + request_id=2, + prompt_tokens=prompt2, + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req2) + + # req2 should have 3 blocks total + assert ctx.request_kv_block_counts[1].item() == 3 + + # The first 2 blocks should be shared + assert ctx.request_to_kv_block_ids[1, 0].item() == ctx.request_to_kv_block_ids[0, 0].item() + assert ctx.request_to_kv_block_ids[1, 1].item() == ctx.request_to_kv_block_ids[0, 1].item() + + # The 3rd block should be a newly allocated pool block + assert ctx.request_to_kv_block_ids[1, 2].item() != ctx.request_to_kv_block_ids[0, 1].item() + + # The offset points to the last token (index 34). In the 3rd block (indices 32-47), 34 is at offset 2. + assert ctx.request_last_kv_block_offset[1].item() == 2 + + # Effective query length should be 3 (35 total - 32 skipped) + assert ctx.request_query_lengths[1].item() == 3 + + @pytest.mark.internal + @rounder_override(64) + def test_eviction_with_shared_prefix_blocks(self): + """Test that evicting a request drops ref counts correctly without destroying shared blocks.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.1, + block_size_tokens=16, + enable_prefix_caching=True, + unified_memory_level=0, + paused_buffer_size_gb=0.0, # 0 paused capacity to force immediate eviction + max_tokens=512, + max_requests=512, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + bs = ctx.block_size_tokens + prompt = torch.arange(bs * 2, device='cuda') + + # Add req1 and req2 with identical prompts + req1 = DynamicInferenceRequest( + request_id=1, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req1) + + req2 = DynamicInferenceRequest( + request_id=2, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req2) + + shared_b0 = ctx.request_to_kv_block_ids[0, 0].item() + shared_b1 = ctx.request_to_kv_block_ids[0, 1].item() + + # Both blocks should be safely shared with ref count 2 + assert ctx.block_allocator.block_ref_counts[shared_b0].item() == 2 + + # Mock the state to make req1 paused and req2 active + ctx.paused_request_count = 1 + ctx.total_request_count = 2 + ctx.request_ids[0] = 1 + ctx.request_ids[1] = 2 + ctx.request_kv_block_counts[0] = 2 + ctx.request_kv_block_counts[1] = 2 + + # Exhaust the active block allocator + ctx.block_allocator.total_avail = 0 + + # Trigger the eviction logic + # next_tokens must be sized to total_request_count (1 paused + 1 active = 2) + next_tokens = torch.tensor([50, 51], device='cuda') + evicted_ids = ctx.evict_overflow_paused_requests( + active_request_count=1, next_tokens=next_tokens + ) + + # req1 should be successfully evicted + assert evicted_ids is not None + assert evicted_ids[0].item() == 1 + + # req2 remains active, so the shared blocks should drop to a ref count of 1 + assert ctx.block_allocator.block_ref_counts[shared_b0].item() == 1 + assert ctx.block_allocator.block_ref_counts[shared_b1].item() == 1 + + @pytest.mark.internal + @rounder_override(64) + def test_oom_during_speculative_boundary_crossing(self): + """Test boundary crossing with speculative tokens pauses the request gracefully when KV cache is full, keeping other requests active.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=128, + buffer_size_gb=0.1, + block_size_tokens=16, + num_speculative_tokens=2, + unified_memory_level=0, + max_tokens=512, + max_requests=512, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + bs = ctx.block_size_tokens + + # Setup 2 active requests. + # Request 0 is exactly 1 token away from its boundary (will OOM). + # Request 1 has plenty of space (will remain active). + ctx.total_request_count = 2 + ctx.paused_request_count = 0 + ctx.active_token_count = 2 + + ctx.request_ids[:2] = torch.tensor([10, 11], device='cuda') + ctx.request_query_lengths[:2] = 1 + ctx.request_kv_block_counts[:2] = 1 + + # Request 0 offset is 15. Adding 1 sampled + 2 spec = 3 tokens crosses the boundary (16). + # Request 1 offset is 5. Adding 3 tokens = 8 (does not cross). + ctx.request_kv_length_offsets[:2] = torch.tensor( + [bs - 1, 5], device='cuda', dtype=torch.int32 + ) + ctx.request_last_kv_block_offset[:2] = torch.tensor( + [bs - 1, 5], device='cuda', dtype=torch.int32 + ) + + blocks = ctx.block_allocator.allocate_memory_blocks(2) + ctx.request_to_kv_block_ids[0, 0] = blocks[0] + ctx.request_to_kv_block_ids[1, 0] = blocks[1] + ctx.request_last_kv_block_id[:2] = blocks + + # Force OOM condition (no blocks left in the active pool) + ctx.block_allocator.total_avail = 0 + ctx.block_allocator.paused_count = 100 # Prevent immediate eviction out of the system + + active_mask = torch.tensor([1, 1], device='cuda', dtype=torch.int32) + new_tokens = torch.tensor([99, 88], device='cuda') + new_spec = torch.tensor([[100, 200], [101, 201]], device='cuda') + + # Run update requests + ctx.update_requests( + active_requests_mask=active_mask, new_tokens=new_tokens, new_speculative_tokens=new_spec + ) + + # Request 0 should detect OOM, fail to allocate a new block, and pause. + # Request 1 remains active, so active_request_count goes 2 -> 1, avoiding the deadlock assert. + assert ctx.paused_request_count == 1 + assert ctx.total_request_count == 2 + + # Request 1 generated 3 tokens (1 sampled + 2 spec) + assert ctx.active_token_count == 3 + + # Tokens must be cached in the paused buffers so Request 0 can resume cleanly later + assert ctx.paused_tokens is not None + assert ctx.paused_tokens[0].item() == 99 + + assert ctx.paused_speculative_tokens is not None + assert ctx.paused_speculative_tokens[0, 0].item() == 100 + assert ctx.paused_speculative_tokens[1, 0].item() == 101 + + @pytest.mark.internal + @rounder_override(64) + def test_chunked_prefill_meets_prefix_caching(self): + """Test that chunks in a chunked-prefill pipeline properly hit the prefix cache mid-flight.""" + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=512, + buffer_size_gb=0.1, + block_size_tokens=32, + enable_chunked_prefill=True, + enable_prefix_caching=True, + unified_memory_level=0, + max_tokens=512, + max_requests=512, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + bs = ctx.block_size_tokens + prompt = torch.arange(128, device='cuda') + + # Cache req1 (fully processed) + req1 = DynamicInferenceRequest( + request_id=1, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + ctx.add_request(req1) + req1_blocks = [ctx.request_to_kv_block_ids[0, i].item() for i in range(4)] + + # Start chunked prefill for req2. + req2 = DynamicInferenceRequest( + request_id=2, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=10), + block_size_tokens=bs, + enable_prefix_caching=True, + ) + + # Add the first chunk (64 tokens) + req2.finished_chunk_token_count = 0 + ctx.chunked_prefill_request_id = 2 + ctx.add_request(req2, chunk_length=64) + + # Assert the first chunk perfectly matched the first 2 cached blocks + assert ctx.request_to_kv_block_ids[1, 0].item() == req1_blocks[0] + assert ctx.request_to_kv_block_ids[1, 1].item() == req1_blocks[1] + assert ctx.request_kv_block_counts[1].item() == 2 + + # Simulate update_requests completing the chunk + ctx.active_token_count += 1 + ctx.request_in_prefill_status_tensor[1] = 0 + + # Add the second chunk (64 tokens) + req2.finished_chunk_token_count = 64 + ctx.add_request(req2, chunk_length=64) + + # It should correctly discover the remaining prefix blocks despite being mid-prefill + assert ctx.request_to_kv_block_ids[1, 2].item() == req1_blocks[2] + assert ctx.request_to_kv_block_ids[1, 3].item() == req1_blocks[3] + assert ctx.request_kv_block_counts[1].item() == 4 + + # Verify block references updated appropriately + assert ctx.block_allocator.block_ref_counts[req1_blocks[2]].item() == 2 + assert ctx.block_allocator.block_ref_counts[req1_blocks[3]].item() == 2 diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index 8f7279b3cd9..bff603594c7 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -150,10 +150,14 @@ def __post_init__(self): if self.max_sequence_length is None: assert self.num_tokens_to_generate is None or self.num_tokens_total is None if self.num_tokens_to_generate is not None: - self.max_sequence_length = self.max_prompt_length + self.num_tokens_to_generate + self.max_sequence_length = ( + self.max_prompt_length + + self.num_tokens_to_generate + + self.num_speculative_tokens + ) else: assert self.num_tokens_total is not None - self.max_sequence_length = self.num_tokens_total + self.max_sequence_length = self.num_tokens_total + self.num_speculative_tokens # Default paused buffer size. if self.context_paused_buffer_size_gb is None: @@ -347,7 +351,7 @@ def _build_test_env(cls, test_config): if test_config.num_speculative_tokens > 0: use_te = test_config.fp8 or test_config.transformer_impl == "transformer_engine" mtp_block_spec = get_gpt_mtp_block_spec( - config=transformer_config, spec=layer_spec, use_transformer_engine=use_te, + config=transformer_config, spec=layer_spec, use_transformer_engine=use_te ) # GPT model. @@ -2083,6 +2087,59 @@ def mock_mtp_forward(*args, **kwargs): assert env.engine.context.active_token_count == 0 assert env.engine.context.total_request_count == 0 + @pytest.mark.internal + @torch.inference_mode() + def test_speculative_block_boundary_crossing(self): + """Test to verify KV cache block boundary crossing logic. + + When a request fills exactly one block and speculative decoding generates + multiple tokens, the first new token shouldn't incorrectly overwrite the old block. + """ + test_config = DynamicEngineTestConfig( + num_requests=1, + min_prompt_length=4, + max_prompt_length=4, + num_tokens_to_generate=3, + num_speculative_tokens=2, + context_block_size_tokens=4, # Exactly matches prompt length + context_max_requests=16, + model_provider="gpt", + materialize_only_last_token_logits=False, + use_fixed_output_lengths=True, + ) + env = self._build_test_env(test_config) + + req = env.requests[0] + req.sampling_params.num_tokens_to_generate = 3 + env.engine._add_request(req) + env.engine.schedule_waiting_requests() + + # Step 1: Prefill. Processes the 4 prompt tokens. + # At the end of this step, `update_requests` prepares the token indices for Step 2. + # It assigns block indices for the 3 upcoming tokens (1 base + 2 spec). + env.engine.step_modern() + + context = env.engine.context + + # The request has 2 blocks allocated now (1 for prompt, 1 for the new 3 tokens) + assigned_blocks = context.request_to_kv_block_ids[0] + first_block = assigned_blocks[0].item() + second_block = assigned_blocks[1].item() + + # The active_token_count for the next step should be 3 + assert context.active_token_count == 3 + + # Check which blocks the 3 new tokens are assigned to. + # Because the prompt exactly filled the first block, ALL 3 new tokens + # MUST go to the second block. + token_blocks = context.token_to_block_idx[: context.active_token_count].tolist() + + assert token_blocks == [ + second_block, + second_block, + second_block, + ], f"Expected all new tokens to go to block {second_block}, but got {token_blocks}." + @pytest.mark.internal @pytest.mark.skipif( not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" @@ -2093,35 +2150,71 @@ def test_speculative_stop_word_hit(self): the request correctly triggers the stop logic without crashing.""" test_config = DynamicEngineTestConfig( - num_requests=0, num_speculative_tokens=2, materialize_only_last_token_logits=False + num_requests=0, # We will manually add our request cleanly + min_prompt_length=4, + max_prompt_length=4, + num_tokens_to_generate=10, + num_speculative_tokens=2, + materialize_only_last_token_logits=False, + model_provider="gpt", ) env = self._build_test_env(test_config) - # Mock request with a stop word - req = DynamicInferenceRequest( - request_id=0, - prompt_tokens=torch.tensor([1, 2, 3], device='cuda'), - sampling_params=SamplingParams(num_tokens_to_generate=10), - ) - # Let's say the stop word is [99, 100] - req.stop_word_ids = [[99, 100]] + unwrapped_model = env.engine.controller.inference_wrapped_model.model - # Fast-forward state: The base token was 99 - req.generated_tokens = [99] - tokens_to_append = [100, 101] # 1 accepted spec token, 1 rejected + # Mock forward to deterministically output an ascending sequence (1->2->3...) + def mock_deterministic_forward(*args, **kwargs): + tokens = kwargs.get("tokens", args[0] if args else kwargs.get("input_ids")) + b, s = tokens.shape - # Check before appending speculative tokens - stop_hit = env.engine._check_stop_words_for_request_post_append(req) - assert stop_hit is False # Only 99 is in generated_tokens initially + base_logits = torch.zeros( + b, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16 + ) + next_toks = (tokens + 1).clamp(max=test_config.vocab_size - 1) + base_logits.scatter_(2, next_toks.unsqueeze(-1), 100.0) - # Now append the tokens as `post_process_requests` would - req.generated_tokens += tokens_to_append + mtp_logits = torch.zeros( + 2, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16 + ) + mtp1_toks = (tokens + 2).clamp(max=test_config.vocab_size - 1) + mtp_logits[0].scatter_(1, mtp1_toks.squeeze(0).unsqueeze(-1), 100.0) - # Check again. It should detect the stop word [99, 100] inside [99, 100, 101] - # Specifically, it shifts backwards due to the speculative tokens. - stop_hit = env.engine._check_stop_words_for_request_post_append(req) + mtp2_toks = (tokens + 3).clamp(max=test_config.vocab_size - 1) + mtp_logits[1].scatter_(1, mtp2_toks.squeeze(0).unsqueeze(-1), 100.0) - assert stop_hit is True + unwrapped_model._mtp_logits_cache = mtp_logits + return base_logits + + unwrapped_model.forward = mock_deterministic_forward + + # Add the request formally to ensure all internal state tensors align + env.engine.add_request( + request_id=0, + prompt=torch.tensor([1, 2, 3, 4], device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=10, termination_id=99), + ) + + # Inject the parsed stop word IDs + tracked_req = env.engine.get_request(0) + tracked_req.stop_word_ids = [[8, 9]] # The sequence will generate 5, 6, 7, 8, 9, ... + + finished_records = [] + while env.engine.has_unfinished_requests(): + res = env.engine.step_modern() + finished_records.extend(res["finished_request_records"]) + + # Retrieve the finalized request from the engine's output + finished_req = finished_records[0].merge() + + assert finished_req.status == Status.COMPLETED + # Since num_tokens_to_generate=10, output should stop early at ~7 tokens + assert len(finished_req.generated_tokens) < 10 + # Verify the stop word was actually generated and caused the termination + token_pairs = [ + finished_req.generated_tokens[i : i + 2] + for i in range(len(finished_req.generated_tokens) - 1) + ] + assert [8, 9] in token_pairs @pytest.mark.internal @pytest.mark.skipif( @@ -2133,26 +2226,141 @@ def test_speculative_long_stop_word_hit(self): (length > num_speculative_tokens), it is correctly detected.""" test_config = DynamicEngineTestConfig( - num_requests=0, num_speculative_tokens=2, materialize_only_last_token_logits=False + num_requests=0, + min_prompt_length=4, + max_prompt_length=4, + num_tokens_to_generate=10, + num_speculative_tokens=2, + materialize_only_last_token_logits=False, + model_provider="gpt", ) env = self._build_test_env(test_config) - # Mock request with a stop word - req = DynamicInferenceRequest( + unwrapped_model = env.engine.controller.inference_wrapped_model.model + + # Mock forward to deterministically output an ascending sequence + def mock_deterministic_forward(*args, **kwargs): + tokens = kwargs.get("tokens", args[0] if args else kwargs.get("input_ids")) + b, s = tokens.shape + + base_logits = torch.zeros( + b, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16 + ) + next_toks = (tokens + 1).clamp(max=test_config.vocab_size - 1) + base_logits.scatter_(2, next_toks.unsqueeze(-1), 100.0) + + mtp_logits = torch.zeros( + 2, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16 + ) + mtp1_toks = (tokens + 2).clamp(max=test_config.vocab_size - 1) + mtp_logits[0].scatter_(1, mtp1_toks.squeeze(0).unsqueeze(-1), 100.0) + + mtp2_toks = (tokens + 3).clamp(max=test_config.vocab_size - 1) + mtp_logits[1].scatter_(1, mtp2_toks.squeeze(0).unsqueeze(-1), 100.0) + + unwrapped_model._mtp_logits_cache = mtp_logits + return base_logits + + unwrapped_model.forward = mock_deterministic_forward + + env.engine.add_request( request_id=0, - prompt_tokens=torch.tensor([1, 2, 3], device='cuda'), - sampling_params=SamplingParams(num_tokens_to_generate=10), + prompt=torch.tensor([1, 2, 3, 4], device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=10, termination_id=99), ) + # Stop word length 3 > num_speculative_tokens (2) - req.stop_word_ids = [[98, 99, 100]] + tracked_req = env.engine.get_request(0) + tracked_req.stop_word_ids = [[7, 8, 9]] + + finished_records = [] + while env.engine.has_unfinished_requests(): + res = env.engine.step_modern() + finished_records.extend(res["finished_request_records"]) + + finished_req = finished_records[0].merge() + + assert finished_req.status == Status.COMPLETED + assert len(finished_req.generated_tokens) < 10 + token_triplets = [ + finished_req.generated_tokens[i : i + 3] + for i in range(len(finished_req.generated_tokens) - 2) + ] + assert [7, 8, 9] in token_triplets + + @pytest.mark.internal + @torch.inference_mode() + def test_speculative_sequence_length_double_counting(self): + """Test to verify active_sequence_lengths is not double-counted. - # Fast-forward state: base tokens were generated up to 99 - req.generated_tokens = [98, 99] - tokens_to_append = [100, 101] # Completes stop word at index -2 - req.generated_tokens += tokens_to_append + If active sequence length is double-counted during speculative decoding, + the request will terminate prematurely before generating the requested tokens. + """ + test_config = DynamicEngineTestConfig( + num_requests=0, + min_prompt_length=4, + max_prompt_length=4, + num_tokens_to_generate=6, + max_sequence_length=10, # Exactly prompt (4) + generate (6) + context_max_requests=16, + num_speculative_tokens=2, + model_provider="gpt", + materialize_only_last_token_logits=False, + use_fixed_output_lengths=False, + context_max_tokens=512, + ) + env = self._build_test_env(test_config) + + # Mock forward pass to return deterministic disparate logits so + # speculative tokens are completely rejected every time. + def mock_mtp_forward_reject(*args, **kwargs): + tokens = kwargs.get("tokens", args[0] if args else kwargs.get("input_ids")) + b, s = tokens.shape + + # Base model correctly predicts tokens + 1 + base_logits = torch.zeros( + b, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16 + ) + next_toks = (tokens + 1).clamp(max=test_config.vocab_size - 1) + base_logits.scatter_(2, next_toks.unsqueeze(-1), 100.0) + + # Speculative model consistently predicts wildly wrong tokens to guarantee rejection + model = env.engine.controller.inference_wrapped_model.model + mtp_logits = torch.zeros( + test_config.num_speculative_tokens, + s, + test_config.vocab_size, + device=tokens.device, + dtype=torch.bfloat16, + ) + wrong_toks = (tokens + 5).clamp(max=test_config.vocab_size - 1) + mtp_logits[0].scatter_(1, wrong_toks.squeeze(0).unsqueeze(-1), 100.0) + mtp_logits[1].scatter_(1, wrong_toks.squeeze(0).unsqueeze(-1), 100.0) + + model._mtp_logits_cache = mtp_logits + return base_logits + + env.engine.controller.inference_wrapped_model.model.forward = mock_mtp_forward_reject + + env.engine.add_request( + request_id=0, + prompt=torch.tensor([1, 2, 3, 4], device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=6, termination_id=99), + ) + + finished_records = [] + while env.engine.has_unfinished_requests(): + res = env.engine.step_modern() + finished_records.extend(res["finished_request_records"]) - stop_hit = env.engine._check_stop_words_for_request_post_append(req) - assert stop_hit is True + finished_req = finished_records[0].merge() + + # If there is double counting, the tracked active length will outpace the actual + # generated tokens, causing premature termination when it thinks it hit max_sequence_length. + assert finished_req.status == Status.COMPLETED + assert ( + len(finished_req.generated_tokens) == 6 + ), f"Expected 6 tokens, got {len(finished_req.generated_tokens)}. Double counting occurred." @pytest.mark.internal @pytest.mark.skipif( @@ -2166,18 +2374,21 @@ def test_speculative_decoding_with_prefix_caching(self): cached KV blocks from the first and still generate correctly with spec decoding. """ test_config = DynamicEngineTestConfig( - num_requests=4, + num_requests=0, # Added manually below min_prompt_length=8, max_prompt_length=8, num_tokens_to_generate=4, num_speculative_tokens=2, - enable_prefix_caching=True, + enable_prefix_caching=True, # Set at config level + context_block_size_tokens=8, # Ensure exact 1 block per prompt materialize_only_last_token_logits=False, model_provider="gpt", + context_max_tokens=512, + context_max_requests=512, ) env = self._build_test_env(test_config) - # Create two pairs of requests with shared prefixes. + # Create two pairs of requests with identical shared prefixes. shared_prompt_a = torch.randint( 0, test_config.vocab_size - 1, (8,), dtype=torch.int64, device='cuda' ) @@ -2185,55 +2396,39 @@ def test_speculative_decoding_with_prefix_caching(self): 0, test_config.vocab_size - 1, (8,), dtype=torch.int64, device='cuda' ) - for i, prompt in enumerate([shared_prompt_a, shared_prompt_a, shared_prompt_b, shared_prompt_b]): - env.requests[i].prompt_tokens = prompt.clone() - - # Run all requests through the engine. - for request in env.requests: - env.engine._add_request(request) - - while env.engine.has_unfinished_requests(): - self._run_step(env) + prompts = [shared_prompt_a, shared_prompt_a, shared_prompt_b, shared_prompt_b] - # All requests should complete. - for request in env.requests: - assert request.status in (Status.COMPLETED, Status.FAILED) - if request.status == Status.COMPLETED: - assert len(request.generated_tokens) > 0 + for i, prompt in enumerate(prompts): + # Using the clean public API guarantees correct hashing and dataclass creation + env.engine.add_request( + request_id=i, + prompt=prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=4, termination_id=99), + ) - # Context should be clean after all requests finish. - assert env.engine.context.active_token_count == 0 - assert env.engine.context.total_request_count == 0 + # First, run schedule_waiting_requests and ONE step to allocate the prefill blocks. + # Req 0 and 2 will schedule immediately. Req 1 and 3 will defer because their hashes + # are currently pending (being registered by 0 and 2). + env.engine.schedule_waiting_requests() + env.engine.step_modern() - @pytest.mark.internal - @pytest.mark.skipif( - not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" - ) - @torch.inference_mode() - def test_speculative_decoding_with_chunked_prefill(self): - """Test that speculative decoding combined with chunked prefill completes correctly.""" - test_config = DynamicEngineTestConfig( - num_requests=2, - min_prompt_length=16, - max_prompt_length=16, - num_tokens_to_generate=4, - num_speculative_tokens=2, - materialize_only_last_token_logits=False, - enable_chunked_prefill=True, - model_provider="gpt", - context_max_tokens=32, # Force chunking by limiting token budget - ) - env = self._build_test_env(test_config) + # After step 1, Req 0 and 2 have completely registered their cached blocks. + # Now, schedule the deferred ones (Req 1 and 3). They will find the registered blocks! + env.engine.schedule_waiting_requests() + env.engine.step_modern() - for request in env.requests: - env.engine._add_request(request) + # 4 requests. 2 unique prefixes (1 block each). + # Without sharing, we'd need 8 blocks + 1 dummy = 9 active_used. + # With sharing, we need 2 shared blocks + 4 generation blocks + 1 dummy = 7 active_used. + active_used = env.engine.context.block_allocator.get_active_used() + assert ( + active_used <= 7 + ), f"Prefix caching failed, expected <= 7 active blocks but got {active_used}" while env.engine.has_unfinished_requests(): - self._run_step(env) - - for request in env.requests: - assert request.status in (Status.COMPLETED, Status.FAILED) + env.engine.step_modern() + # Context should be clean after all requests finish. assert env.engine.context.active_token_count == 0 assert env.engine.context.total_request_count == 0 @@ -2251,36 +2446,35 @@ def test_speculative_decoding_chunked_prefill_and_prefix_caching(self): - Speculative decoding generates multiple tokens per step """ test_config = DynamicEngineTestConfig( - num_requests=4, + num_requests=0, min_prompt_length=16, max_prompt_length=16, num_tokens_to_generate=4, num_speculative_tokens=2, materialize_only_last_token_logits=False, enable_chunked_prefill=True, + enable_prefix_caching=True, # Set at config level + context_block_size_tokens=8, model_provider="gpt", context_max_tokens=48, # Force chunking + context_max_requests=48, ) env = self._build_test_env(test_config) - # Enable prefix caching. - env.engine.context.enable_prefix_caching = True - - # Create pairs with shared prefixes to exercise prefix caching. + # Create identical prompts for all 4 requests shared_prompt = torch.randint( 0, test_config.vocab_size - 1, (16,), dtype=torch.int64, device='cuda' ) - for i in range(len(env.requests)): - env.requests[i].prompt_tokens = shared_prompt.clone() - for request in env.requests: - env.engine._add_request(request) + for i in range(4): + env.engine.add_request( + request_id=i, + prompt=shared_prompt.clone(), + sampling_params=SamplingParams(num_tokens_to_generate=4, termination_id=99), + ) while env.engine.has_unfinished_requests(): - self._run_step(env) - - for request in env.requests: - assert request.status in (Status.COMPLETED, Status.FAILED) + env.engine.step_modern() assert env.engine.context.active_token_count == 0 assert env.engine.context.total_request_count == 0 From e6f61d6d3bcddb74a75dab0b6a77364aac6c3af1 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Sun, 8 Mar 2026 23:17:20 -0700 Subject: [PATCH 44/76] Linting Signed-off-by: Keshav Santhanam --- megatron/core/transformer/attention.py | 8 ++------ megatron/inference/utils.py | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 0ac7e78fae3..310a59bde35 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -60,9 +60,7 @@ rearrange = None try: - from flash_attn_3.flash_attn_interface import ( - _flash_attn_forward, - ) + from flash_attn_3.flash_attn_interface import _flash_attn_forward from flash_attn_3.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) @@ -73,9 +71,7 @@ if not HAVE_FA3: try: - from flashattn_hopper.flash_attn_interface import ( - _flash_attn_forward, - ) + from flashattn_hopper.flash_attn_interface import _flash_attn_forward from flashattn_hopper.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) diff --git a/megatron/inference/utils.py b/megatron/inference/utils.py index 0cf206b46fc..ec8f1088be1 100644 --- a/megatron/inference/utils.py +++ b/megatron/inference/utils.py @@ -341,7 +341,7 @@ def get_inference_config_from_model_and_args(model: MegatronModule, args): mamba_inference_state_config=mamba_inference_state_config, pg_collection=pg_collection, use_flashinfer_fused_rope=args.use_flashinfer_fused_rope, - materialize_only_last_token_logits=(not args.return_log_probs and not args.num_speculative_tokens > 0), + materialize_only_last_token_logits=(not args.return_log_probs and args.num_speculative_tokens == 0), track_generated_token_events=args.inference_dynamic_batching_track_generated_token_events, track_paused_request_events=args.inference_dynamic_batching_track_paused_request_events, enable_chunked_prefill=args.enable_chunked_prefill, From 5e02618b632287eed323f64f3be28d4c345d731d Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Sun, 8 Mar 2026 23:29:33 -0700 Subject: [PATCH 45/76] Address review comments Signed-off-by: Keshav Santhanam --- .../core/inference/engines/dynamic_engine.py | 2 - .../text_generation_controller.py | 275 ++++++++++-------- megatron/core/models/gpt/gpt_model.py | 21 +- megatron/core/models/mamba/mamba_model.py | 21 +- .../transformer/multi_token_prediction.py | 38 +++ 5 files changed, 204 insertions(+), 153 deletions(-) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 746adbc6590..422aa18f435 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -215,8 +215,6 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen self.num_speculative_tokens <= self.controller.num_mtp_heads ), f"Number of speculative tokens {self.num_speculative_tokens} must be less than or equal to number of MTP heads {self.controller.num_mtp_heads}" - self.controller._init_mtp_sampling_tensor() - self.track_paused_request_events = inference_config.track_paused_request_events self.track_generated_token_events = inference_config.track_generated_token_events self.enable_chunked_prefill = inference_config.enable_chunked_prefill diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 55df5a95334..af733384532 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -137,6 +137,8 @@ def _init_dynamic_sampling_tensors(self): if self._sampling_backend == "torch": self._torch_sampling_buckets: List[Tuple] = [] + self._init_mtp_sampling_tensor() + def _init_mtp_sampling_tensor(self): """Initialize the MTP sampling tensor after num_speculative_tokens is set.""" if self.num_speculative_tokens is not None and self.num_speculative_tokens > 0: @@ -781,39 +783,31 @@ def _rewind_kv_cache(self): ] ) - def _dynamic_step_sample_logits_and_verify_tokens( - self, logits: Tensor, mtp_logits: Tensor, input_ids: Tensor - ): - """ - Sample tokens from logits for dynamic batching with speculative tokens and verify the tokens. - """ - context = self.inference_wrapped_model.inference_context - active_request_count = context.total_request_count - context.paused_request_count - - # ================ PART 1 The following part of the code is to get all the relevant logit indices alone ========= - # i.e For prefill requests just the last token logits are enough. - # i.e For decode requests we will need all tokens - # Decode request will always be on the left, followed by prefill requests - # In non speculative case, it was simple in the other function, we just always get the last token logits using query lengths. - - # 5 requests # Input ids shape : [1, 15] - # Assume input ids : [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d1 d2 | e1 e2 e3 e4] - # Request to prefill [ 0 | 0 | 0 | 1 | 1 ] - # Request query lengths [ 3 | 3 | 3 | 2 | 4 ] - # OUTPUT : required_logit_indices [ 0 1 2 | 3 4 5 | 6 7 8 | 10 | 14 ] - - request_in_prefill_status_tensor = context.request_in_prefill_status_tensor[ - context.paused_request_count : context.total_request_count - ] - request_query_lengths = context.request_query_lengths[ - context.paused_request_count : context.total_request_count - ] - - num_prefill_requests = request_in_prefill_status_tensor.sum().item() - num_decode_requests = active_request_count - num_prefill_requests + def _get_required_logit_indices( + self, + request_in_prefill_status_tensor: Tensor, + request_query_lengths: Tensor, + num_decode_requests: int, + num_prefill_requests: int, + device: torch.device, + ) -> Tensor: + """Get indices into the logits tensor for tokens that need sampling. + + For decode requests, all tokens (base + speculative) are needed. + For prefill requests, only the last token logits are needed. + Decode requests will always be on the left, followed by prefill requests. + + Example with 5 requests (2 spec tokens): + Assume input ids : [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d1 d2 | e1 e2 e3 e4] + Request to prefill [ 0 | 0 | 0 | 1 | 1 ] + Request query lengths [ 3 | 3 | 3 | 2 | 4 ] + OUTPUT : required_logit_indices [ 0 1 2 | 3 4 5 | 6 7 8 | 10 | 14 ] + Returns: + Tensor: Indices into the sequence dimension of the logits tensor. + """ decode_request_indices = torch.arange( - num_decode_requests * (self.num_speculative_tokens + 1), device=logits.device + num_decode_requests * (self.num_speculative_tokens + 1), device=device ) prefill_request_indices = ( request_query_lengths.cumsum(dim=0)[request_in_prefill_status_tensor == 1] - 1 @@ -828,43 +822,32 @@ def _dynamic_step_sample_logits_and_verify_tokens( f"but got {len(required_logit_indices)} for num_decode_requests {num_decode_requests} " f"and num_prefill_requests {num_prefill_requests}" ) + return required_logit_indices - required_logits = logits.squeeze(0)[required_logit_indices, :] # Shape [1, 11, vocab_size] - required_mtp_logits = mtp_logits[ - :, required_logit_indices, : - ] # Shape [num_speculative_tokens, 11, vocab_size] - - # ================ PART 1 The following part of the code is to sample the logits and mtp logits based on the sampling parameters ========= - - # request_indices will be 0, 1, 2, 3, 4 (since we have only 5 requests) - # For torch sampling buckets :-[request_indices, temp, top_k, top_p] - # [ - # [[0,2], temp1, top_k1, top_p1], - # [1], temp3, top_k3, top_p3] - # [3, 4], temp2, top_k2, top_p2], - # ] - - # Token to request idx : [ 0 0 0 | 1 1 1 | 2 2 2 | 3 | 4 ] - # required_logits : [ a5l a6l a7l | b3l b4l b5l | c6l c7l c8l | d2l | e4l ] # Shape [11, vocab_size] - # For first iteration : - # sampling buckets : [0,2], temp1, top_k1, top_p1 - # output_tokens_jumbled_list = [a5s a6s a7s c6s c7s c8s] #s->sampled tokens # - # request_order_list = [0, 2] - # token_order_list = [0, 1, 2, 6, 7, 8] - # For second iteration : - # sampling buckets : [1], temp3, top_k3, top_p3 - # output_tokens_jumbled_list = [b3s b4s b5s] - # request_order_list = [1] - # token_order_list = [3, 4, 5] - # For third iteration : - # sampling buckets : [3, 4], temp2, top_k2, top_p2 - # output_tokens_jumbled_list = [d2s e4s] #s->sampled tokens # - # request_order_list = [3, 4] - # token_order_list = [9,10] - # Final output tokens : [a5s a6s a7s c6s c7s c8s b3s b4s b5s d2s e4s] # Shape [11] - # Final request order list : [0, 2, 1, 3, 4] - # Final token order list : [0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10] + def _sample_speculative_logits( + self, + required_logits: Tensor, + required_mtp_logits: Tensor, + request_in_prefill_status_tensor: Tensor, + ) -> tuple: + """Sample tokens from logits and MTP logits using sampling buckets. + + For torch sampling buckets: [request_indices, temp, top_k, top_p] + Example with 5 requests: + token_to_request_idx : [ 0 0 0 | 1 1 1 | 2 2 2 | 3 | 4 ] + required_logits : [ a5l a6l a7l | b3l b4l b5l | c6l c7l c8l | d2l | e4l ] # Shape [11, vocab_size] + + Sampling buckets: [[[0,2], temp1, top_k1, top_p1], [[1], temp3, top_k3, top_p3], [[3, 4], temp2, top_k2, top_p2]] + + Final output tokens : [a5s a6s a7s c6s c7s c8s b3s b4s b5s d2s e4s] # Shape [11] + (Rearranged from sampling bucket order back to input order using token_order) + + Returns: + tuple: (output_tokens, mtp_output_tokens, repeats) where output_tokens has shape + [total_required_tokens] and mtp_output_tokens has shape + [num_speculative_tokens, total_required_tokens]. + """ repeats = torch.where( request_in_prefill_status_tensor == 0, 1 + self.num_speculative_tokens, 1 ) @@ -880,7 +863,6 @@ def _dynamic_step_sample_logits_and_verify_tokens( mtp_output_tokens_jumbled_list = [] token_order_list = [] - # TODO : Maybe its okay to have a loop with num spec tokens ? (Since it will only be max 3 , so might be faster) for request_indices, temp, top_k, top_p in self._torch_sampling_buckets: request_indices_tensor = torch.tensor( request_indices, device=token_to_request_index.device @@ -888,8 +870,6 @@ def _dynamic_step_sample_logits_and_verify_tokens( required_indices = torch.where( torch.isin(token_to_request_index, request_indices_tensor) )[0] - # TODO : Can maybe club the following two and then split later ? - # TODO : Can directly initialize output tokens as a tensor and put the logits in the right place output_tokens_jumbled_list.append( self._torch_sampling_func(required_logits[required_indices, :], temp, top_k, top_p) ) @@ -908,9 +888,7 @@ def _dynamic_step_sample_logits_and_verify_tokens( dtype=output_tokens_jumbled.dtype, ) token_order = torch.cat(token_order_list, dim=0) - # Rearrange output tokens because previously it will be in the order of the - # sampling_bucket request indices, but now we want to put them according to - # their corresponding input ids + # Rearrange output tokens from sampling_bucket request order back to input ids order output_tokens[token_order] = output_tokens_jumbled mtp_output_tokens_jumbled = torch.cat( @@ -919,24 +897,37 @@ def _dynamic_step_sample_logits_and_verify_tokens( mtp_output_tokens = torch.empty_like(mtp_output_tokens_jumbled) mtp_output_tokens[:, token_order] = mtp_output_tokens_jumbled - ### ================ PART 3 This part is to do the following : ================ - # Create the accepted tokens tensor - # For prefill it is always set to 1 - # For decode, the first token is always accepted, then we compare with input tokens - # and accept the next tokens if its a match - # Then find the index of the last 1 in every request of the accepted tokens tensor - # Then these are the index of the tokens that will be sent to the next forward pass - # In the example (assume 1 spec token, 2 spec tokens and 0 sepc tokens are accepted - # in the first 3 requests - - # Assume input ids : [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d1 d2 | e1 e2 e3 e4] - # input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] # Size 11 - # Output tokens [ a6o a7o a8o | b40 b5o b6o | c7o c8o c9o | d3o | e5o ] # At every index we get next positions sample - # Output tokens right shift [ d3o a6o a7o | a8o b40 b5o | b6o c7o c8o | c9o | d3o ] - # Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] - # Last one indices [ 1 | 5 | 6 | 9 | 10 ] + return output_tokens, mtp_output_tokens, repeats - input_tokens_required = input_ids[0, required_logit_indices] + def _verify_speculative_tokens( + self, + output_tokens: Tensor, + input_tokens_required: Tensor, + request_in_prefill_status_tensor: Tensor, + repeats: Tensor, + num_decode_requests: int, + num_prefill_requests: int, + active_request_count: int, + ) -> tuple: + """Verify speculative tokens against input tokens and compute acceptance. + + Creates an accepted tokens mask where: + - For prefill requests, the token is always accepted. + - For decode requests, the first token (base token) is always accepted, then we compare + sampled tokens with input tokens and accept consecutive matches. + Then finds the index of the last accepted token per request. + + Example (assume 1, 2, and 0 spec tokens are accepted in the first 3 decode requests): + input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] # Size 11 + Output tokens [ a6o a7o a8o | b40 b5o b6o | c7o c8o c9o | d3o | e5o ] + Output tokens right shift [ d3o a6o a7o | a8o b40 b5o | b6o c7o c8o | c9o | d3o ] + Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] + Last one indices [ 1 | 5 | 6 | 9 | 10 ] + + Returns: + tuple: (last_one_indices, accepted_tokens_mask, input_tokens_required) where + last_one_indices contains the index of the last accepted token per request. + """ if input_tokens_required.ndim == 2: assert ( input_tokens_required.shape[0] == 1 @@ -946,11 +937,12 @@ def _dynamic_step_sample_logits_and_verify_tokens( # Initialize mask with False to prevent boundary bleed accepted_tokens_mask = torch.zeros_like(input_tokens_required, dtype=torch.bool) - # This is to make all prefill tokens accepted + # Make all prefill tokens accepted token_to_prefill_idx = torch.repeat_interleave(request_in_prefill_status_tensor, repeats) accepted_tokens_mask[token_to_prefill_idx == 1] = True # Safe decode token verification without cross-batch boundary contamination + decode_mask_2d = None if num_decode_requests > 0: decode_len = num_decode_requests * (self.num_speculative_tokens + 1) @@ -963,23 +955,15 @@ def _dynamic_step_sample_logits_and_verify_tokens( # Shift outputs right by 1 *within* each request to align sampled tokens with input targets decode_outputs_shifted = decode_outputs.roll(1, dims=1) - decode_mask_2d = decode_inputs == decode_outputs_shifted - # The first token (base token) is always accepted decode_mask_2d[:, 0] = True - - # ENFORCE CONSECUTIVE ACCEPTANCE: - # cummin() on booleans propagates False (0) to the right, invalidating all subsequent mismatches + # Enforce consecutive acceptance: cummin propagates False to the right decode_mask_2d = decode_mask_2d.cummin(dim=1).values - - # Put the consecutive-enforced mask back into the flattened 1D tensor accepted_tokens_mask[:decode_len] = decode_mask_2d.flatten() - # This is to find the index of the last 1 in every request - # (Now mathematically guaranteed to be the final consecutive match) last_one_indices = torch.full( - (active_request_count,), -1, device=token_to_request_index.device + (active_request_count,), -1, device=input_tokens_required.device ) if num_decode_requests > 0: @@ -991,47 +975,92 @@ def _dynamic_step_sample_logits_and_verify_tokens( last_one_indices[:num_decode_requests] = row_offsets + local_last_indices if num_prefill_requests > 0: - # Prefill requests only have 1 token evaluated, so nonzero() is perfectly safe here decode_len = num_decode_requests * (self.num_speculative_tokens + 1) prefill_valid = ( torch.nonzero(accepted_tokens_mask[decode_len:]).squeeze(-1) + decode_len ) last_one_indices[num_decode_requests:] = prefill_valid - # These are the tokens (output + speculative tokens) that will be going to the next forward pass + return last_one_indices, accepted_tokens_mask, input_tokens_required + + def _dynamic_step_sample_logits_and_verify_tokens( + self, logits: Tensor, mtp_logits: Tensor, input_ids: Tensor + ): + """ + Sample tokens from logits for dynamic batching with speculative tokens and verify the tokens. + """ + context = self.inference_wrapped_model.inference_context + active_request_count = context.total_request_count - context.paused_request_count + + request_in_prefill_status_tensor = context.request_in_prefill_status_tensor[ + context.paused_request_count : context.total_request_count + ] + request_query_lengths = context.request_query_lengths[ + context.paused_request_count : context.total_request_count + ] + + num_prefill_requests = request_in_prefill_status_tensor.sum().item() + num_decode_requests = active_request_count - num_prefill_requests + + # Get the logit indices for tokens that need sampling. + required_logit_indices = self._get_required_logit_indices( + request_in_prefill_status_tensor, + request_query_lengths, + num_decode_requests, + num_prefill_requests, + logits.device, + ) + + required_logits = logits.squeeze(0)[required_logit_indices, :] # Shape [num_required, vocab_size] + required_mtp_logits = mtp_logits[ + :, required_logit_indices, : + ] # Shape [num_speculative_tokens, num_required, vocab_size] + + # Sample tokens from logits and MTP logits. + output_tokens, mtp_output_tokens, repeats = self._sample_speculative_logits( + required_logits, required_mtp_logits, request_in_prefill_status_tensor + ) + + # Verify speculative tokens against input tokens. + input_tokens_required = input_ids[0, required_logit_indices] + last_one_indices, accepted_tokens_mask, input_tokens_required = ( + self._verify_speculative_tokens( + output_tokens, + input_tokens_required, + request_in_prefill_status_tensor, + repeats, + num_decode_requests, + num_prefill_requests, + active_request_count, + ) + ) + + # Store the final sampled tokens and MTP tokens for the next forward pass. final_sampled_tokens = output_tokens[last_one_indices] self._sampled_tokens_cuda[: len(final_sampled_tokens)] = final_sampled_tokens self._sampled_mtp_tokens_cuda[:, : len(final_sampled_tokens)] = mtp_output_tokens[ :, last_one_indices ] - ### ================ PART 4 This part is to do the following : ================ - # To fill the speculative tokens and accepted_token counts - # For prefill it is always set to 1 - # For decode, the first token is always accepted, then we compare with input tokens and - # accept the next tokens if its a match - # Then find the index of the last 1 in every request of the accepted tokens tensor - # Then these are the index of the tokens that will be sent to the next forward pass - # In the example (assume 1 spec token, 2 spec tokens and 0 sepc tokens are accepted in - # the first 3 requests - - # input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] # Size 11 - # Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] - # Accepted tokens [ [a6s -1] | [b4s b5s] | [-1 -1] ] # Only handle decod requests, (Prefill already defaults to -1s) - # Accepted token counts [ 1 | 2 | 0 ] # Prefill defaults to 0 - - # This part is to extract the accepted tokens - input_tokens_required[accepted_tokens_mask == 0] = -1 # Masks out non accepted tokens + # Extract accepted tokens and counts for decode requests. + # For prefill it is always set to 1. For decode, the first token is always accepted, + # then we compare with input tokens and accept the next tokens if its a match. + # + # Example (continuing from above): + # input_tokens_required: [ a5 a6s a7s | b3 b4s b5s | c6 c7s c8s | d2 | e4 ] + # Accepted tokens mask [ 1 1 0 | 1 1 1 | 1 0 0 | 1 | 1 ] + # Accepted tokens [ [a6s -1] | [b4s b5s] | [-1 -1] ] # Only decode requests (prefill defaults to -1) + # Accepted token counts [ 1 | 2 | 0 ] # Prefill defaults to 0 + input_tokens_required[accepted_tokens_mask == 0] = -1 # Mask out non-accepted tokens input_tokens_decode_mode = input_tokens_required[ : num_decode_requests * (self.num_speculative_tokens + 1) ] input_tokens_reshaped = input_tokens_decode_mode.reshape( -1, self.num_speculative_tokens + 1 - ) # shape : [num_decode_requests, num_speculative_tokens + 1] + ) # shape: [num_decode_requests, num_speculative_tokens + 1] - accepted_tokens = input_tokens_reshaped[ - :, 1: - ] # Skip the first token of every decode request (i.e a5, b3, c6) + # Skip the first token of every decode request (i.e a5, b3, c6) + accepted_tokens = input_tokens_reshaped[:, 1:] self._accepted_tokens_per_request[: accepted_tokens.shape[0], :] = accepted_tokens self._accepted_token_counts_per_request = (self._accepted_tokens_per_request != -1).sum( dim=1 diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 3cd6bfae2df..33a5f9c9bc4 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -27,6 +27,7 @@ from megatron.core.transformer.enums import CudaGraphScope, ModelType from megatron.core.transformer.multi_token_prediction import ( MultiTokenPredictionBlock, + compute_mtp_inference_logits, mtp_on_this_rank, process_mtp_loss, ) @@ -617,21 +618,13 @@ def _postprocess( # The new process_mtp_loss function doesn't handle mtp_logits_cache, # so we manually generate and cache MTP logits when in inference mode. if in_inference_mode: - hidden_states_list = torch.chunk( - hidden_states, 1 + self.config.mtp_num_layers, dim=0 + hidden_states, self._mtp_logits_cache = compute_mtp_inference_logits( + hidden_states=hidden_states, + mtp_num_layers=self.config.mtp_num_layers, + output_layer=self.output_layer, + output_weight=output_weight, + runtime_gather_output=runtime_gather_output, ) - hidden_states = hidden_states_list[0] - self._mtp_logits_cache = None - mtp_inference_logits = [] - for mtp_layer_number in range(self.config.mtp_num_layers): - mtp_logits, _ = self.output_layer( - hidden_states_list[mtp_layer_number + 1], - weight=output_weight, - runtime_gather_output=runtime_gather_output, - ) - # mtp logits shape [b, 1, vocab size] - mtp_inference_logits.append(mtp_logits.squeeze(1).unsqueeze(0)) - self._mtp_logits_cache = torch.cat(mtp_inference_logits, dim=0) else: # In training/eval, use the utility function for processing MTP loss/scaling. hidden_states = process_mtp_loss( diff --git a/megatron/core/models/mamba/mamba_model.py b/megatron/core/models/mamba/mamba_model.py index 8ba3c9d556a..71140b1c2af 100644 --- a/megatron/core/models/mamba/mamba_model.py +++ b/megatron/core/models/mamba/mamba_model.py @@ -20,6 +20,7 @@ from megatron.core.transformer.enums import ModelType from megatron.core.transformer.multi_token_prediction import ( MultiTokenPredictionBlock, + compute_mtp_inference_logits, mtp_on_this_rank, process_mtp_loss, ) @@ -408,21 +409,13 @@ def forward( # The new process_mtp_loss function doesn't handle mtp_logits_cache, # so we manually generate and cache MTP logits when in inference mode. if in_inference_mode: - hidden_states_list = torch.chunk( - hidden_states, 1 + self.config.mtp_num_layers, dim=0 + hidden_states, self._mtp_logits_cache = compute_mtp_inference_logits( + hidden_states=hidden_states, + mtp_num_layers=self.config.mtp_num_layers, + output_layer=self.output_layer, + output_weight=output_weight, + runtime_gather_output=runtime_gather_output, ) - hidden_states = hidden_states_list[0] - self._mtp_logits_cache = None - mtp_inference_logits = [] - for mtp_layer_number in range(self.config.mtp_num_layers): - mtp_logits, _ = self.output_layer( - hidden_states_list[mtp_layer_number + 1], - weight=output_weight, - runtime_gather_output=runtime_gather_output, - ) - # mtp logits shape [b, 1, vocab size] - mtp_inference_logits.append(mtp_logits.squeeze(1).unsqueeze(0)) - self._mtp_logits_cache = torch.cat(mtp_inference_logits, dim=0) else: hidden_states = process_mtp_loss( hidden_states=hidden_states, diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index 4ad2e517cfc..29dd8fef986 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -612,6 +612,44 @@ def set_loss_scale(scale: torch.Tensor): MTPLossAutoScaler.main_loss_backward_scale = scale +def compute_mtp_inference_logits( + hidden_states: Tensor, + mtp_num_layers: int, + output_layer: Callable, + output_weight: Optional[Tensor], + runtime_gather_output: Optional[bool], +) -> tuple: + """Compute MTP logits for inference mode. + + Splits the concatenated hidden states and generates logits for each MTP layer. + + Args: + hidden_states (Tensor): Concatenated hidden states from main + MTP layers. + mtp_num_layers (int): Number of MTP layers. + output_layer (Callable): Output layer method to compute logits. + output_weight (Optional[Tensor]): Optional output weight for shared embeddings. + runtime_gather_output (Optional[bool]): Whether to gather output at runtime. + + Returns: + tuple: (hidden_states, mtp_logits_cache) where hidden_states is the main hidden + states and mtp_logits_cache is a tensor of shape + [mtp_num_layers, batch_size, vocab_size]. + """ + hidden_states_list = torch.chunk(hidden_states, 1 + mtp_num_layers, dim=0) + hidden_states = hidden_states_list[0] + mtp_inference_logits = [] + for mtp_layer_number in range(mtp_num_layers): + mtp_logits, _ = output_layer( + hidden_states_list[mtp_layer_number + 1], + weight=output_weight, + runtime_gather_output=runtime_gather_output, + ) + # mtp logits shape [b, 1, vocab size] + mtp_inference_logits.append(mtp_logits.squeeze(1).unsqueeze(0)) + mtp_logits_cache = torch.cat(mtp_inference_logits, dim=0) + return hidden_states, mtp_logits_cache + + def process_mtp_loss( hidden_states: Tensor, labels: Tensor, From ff8721d7b809be189d27235dd2f51e44a9c4b309 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 00:03:59 -0700 Subject: [PATCH 46/76] Update clones Signed-off-by: Keshav Santhanam --- .../inference/contexts/dynamic_context.py | 20 +++++++++++++++++-- .../text_generation_controller.py | 14 ++++++++++--- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 22bb07bd968..bec468b67eb 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -2089,6 +2089,8 @@ def resume_paused_requests( resume_request_count = 0 if self.paused_request_count > 0: active_block_count_avail = self.block_allocator.get_active_avail() + # Clone needed: flip() returns a view, and subsequent += (line below) would + # write through to self.request_kv_block_counts without the clone. paused_block_counts = self.request_kv_block_counts[: self.paused_request_count].clone() # Flip counts before cumsum, since paused requests are resumed from # the right-most index, so we must count resumed blocks starting from @@ -2209,6 +2211,8 @@ def evict_overflow_paused_requests( evict_request_idxs = torch.arange( evict_start_idx, evict_end_idx, device=torch.cuda.current_device() ) + # Clone needed: subsequent release_memory_blocks_from_request_indexes and + # _swap_book_keeping_tensors calls mutate self.request_ids in place. evict_request_ids = self.request_ids[evict_start_idx:evict_end_idx].clone() # Release memory. @@ -2482,6 +2486,9 @@ def update_requests( # for resumed requests, but we need the OLD block for tokens that don't cross. prev_last_block_ids = None if self.num_speculative_tokens > 0: + # Clone needed: resume_paused_requests mutates request_last_kv_block_id + # (assigns new block IDs), but we need the old values later to determine + # which block tokens should go to when they don't cross a block boundary. prev_last_block_ids = self.request_last_kv_block_id.clone() # 6.a. First, resume temporarily paused requests. @@ -2512,8 +2519,11 @@ def update_requests( assert self.total_request_count == active_request_count + self.paused_request_count if self.paused_request_count > 0: + # Clone needed: next_tokens is a shared buffer that will be overwritten in + # the next iteration; paused_tokens must persist independently. self.paused_tokens = next_tokens[: self.paused_request_count].clone() if new_speculative_tokens is not None: + # Clone needed: same reason as paused_tokens above. self.paused_speculative_tokens = new_speculative_tokens[ :, : self.paused_request_count ].clone() @@ -2530,6 +2540,10 @@ def update_requests( num_generated_tokens ) + # Clone needed: old_offsets is reused later (line ~2606) to compute raw_positions + # for block-boundary detection. The write-back on the next line overwrites the + # underlying tensor, so without clone the boundary-crossing logic would see the + # new offsets instead of the pre-update values. old_offsets = self.request_last_kv_block_offset[ self.paused_request_count : self.total_request_count ].clone() @@ -2635,8 +2649,10 @@ def update_requests( # Start with current (new) block for all # Lets say current block ids is [a1, a2 , a3] and num generated_tokens is 3 # This will be [[a1, a1, a1], [a2, a2, a2], [a3, a3, a3]] - block_idx = ( - current_block_ids[:, None].expand(-1, num_generated_tokens).clone() + # No clone needed: expand() returns a read-only view, and downstream + # torch.where() and .flatten() both return new tensors without in-place mutation. + block_idx = current_block_ids[:, None].expand( + -1, num_generated_tokens ) # [active_count, N] # For requests that have crossing, tokens BEFORE boundary use prev block diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index af733384532..bbcf11be371 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -306,6 +306,8 @@ def modify_logits_for_top_p_filtering(logits, top_p): # in the original implementation: # https://github.com/ari-holtzman/degen/blob/master/gen.py # and I guess it is needed so keeping it for now. + # Clone needed: filter_[:, 1:] and filter_[:, :-1] are overlapping views; + # without clone, each write would corrupt the next read during the shift. filter_[:, 1:] = filter_[:, :-1].clone() # Make sure we at least have one token to select from. filter_[..., 0] = 0 @@ -318,6 +320,8 @@ def modify_logits_for_top_p_filtering(logits, top_p): if top_k == 1: sampled_logits = torch.argmax(last_token_logits, dim=-1) else: + # Clone needed: .div_() and masked_fill_() below modify in-place, + # which would mutate the caller's tensor without this clone. last_token_logits = last_token_logits.clone() if temperature != 1.0: last_token_logits.div_(temperature) @@ -741,8 +745,9 @@ def _rewind_kv_cache(self): # Convert to absolute indices in the context tensors absolute_indices = requests_needing_release + context.paused_request_count - # Get the block IDs to release (current last block for these requests) - blocks_to_release = context.request_last_kv_block_id[absolute_indices].clone() + # No clone needed: advanced (fancy) indexing with a tensor already returns + # a copy, not a view. + blocks_to_release = context.request_last_kv_block_id[absolute_indices] # Reduce block counts for requests that crossed back context.request_kv_block_counts[absolute_indices] -= 1 @@ -1362,7 +1367,8 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: ) finished_request_ids = context.request_ids[finished_idxs] - # New sample gets updated in update_requests, so we pass in a clone + # Clone needed: update_requests mutates next_tokens in-place via tensor_swap, + # which would corrupt the reused _sampled_tokens_cuda buffer. new_sample_copy = self._sampled_tokens_cuda[:active_request_count].clone() # Update requests. @@ -1465,8 +1471,10 @@ async def async_generate_output_tokens_dynamic_batch( request_bookkeeping = self._dynamic_step_context_bookkeeping() ret = { + # Clone needed: _sampled_tokens_cuda is a reused buffer overwritten each step. "sample": self._sampled_tokens_cuda[:active_request_count].clone(), "accepted_tokens": ( + # Clone needed: .fill_(-1) on line 1480 would corrupt the returned value. self._accepted_tokens_per_request.clone() if self.num_speculative_tokens > 0 else None From 83f526cc8cd6324ae1b277cfa204280213361509 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Sun, 8 Mar 2026 23:30:29 -0700 Subject: [PATCH 47/76] Linting Signed-off-by: Keshav Santhanam --- .../text_generation_controllers/text_generation_controller.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index bbcf11be371..ee03763b1be 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -1016,7 +1016,9 @@ def _dynamic_step_sample_logits_and_verify_tokens( logits.device, ) - required_logits = logits.squeeze(0)[required_logit_indices, :] # Shape [num_required, vocab_size] + required_logits = logits.squeeze(0)[ + required_logit_indices, : + ] # Shape [num_required, vocab_size] required_mtp_logits = mtp_logits[ :, required_logit_indices, : ] # Shape [num_speculative_tokens, num_required, vocab_size] From 119466313da9aa3c35e2cb081d7bab9ee5c30b1d Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 00:16:10 -0700 Subject: [PATCH 48/76] Update clones Signed-off-by: Keshav Santhanam --- megatron/core/inference/contexts/dynamic_context.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index bec468b67eb..9b605c79bbf 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -2089,9 +2089,8 @@ def resume_paused_requests( resume_request_count = 0 if self.paused_request_count > 0: active_block_count_avail = self.block_allocator.get_active_avail() - # Clone needed: flip() returns a view, and subsequent += (line below) would - # write through to self.request_kv_block_counts without the clone. - paused_block_counts = self.request_kv_block_counts[: self.paused_request_count].clone() + # Clone not needed: flip() makes a copy. + paused_block_counts = self.request_kv_block_counts[: self.paused_request_count] # Flip counts before cumsum, since paused requests are resumed from # the right-most index, so we must count resumed blocks starting from # the right side. @@ -2540,7 +2539,7 @@ def update_requests( num_generated_tokens ) - # Clone needed: old_offsets is reused later (line ~2606) to compute raw_positions + # Clone needed: old_offsets is reused later to compute raw_positions # for block-boundary detection. The write-back on the next line overwrites the # underlying tensor, so without clone the boundary-crossing logic would see the # new offsets instead of the pre-update values. From 72b1f6804e5f0d982a49777f6ae512f4a702e038 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 00:21:33 -0700 Subject: [PATCH 49/76] Delete extraneous tokens after stop sequence Signed-off-by: Keshav Santhanam --- .../core/inference/engines/dynamic_engine.py | 10 ++- tests/unit_tests/inference/test_stop_words.py | 74 ++++++++++++++++++- 2 files changed, 78 insertions(+), 6 deletions(-) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 422aa18f435..318060fb08f 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -1217,12 +1217,14 @@ def _get_and_clear_stop_word_finished_ids(self, active_request_ids: list[int]) - self.stop_word_finished_request_ids -= result return result - # TODO : We also might have to delete some tokens, if stop word hit in the middle (speculative case) def _check_stop_words_for_request_post_append(self, request: DynamicInferenceRequest) -> bool: """Check if a request should stop due to stop words (after token is appended). This method is called from post_process_requests after the token has already - been appended to request.generated_tokens. + been appended to request.generated_tokens. In the speculative decoding case, + multiple tokens may have been appended at once. If a stop word is found in the + middle of the speculative tokens, the trailing tokens after the stop word are + truncated from generated_tokens. Args: request: The request to check. @@ -1246,6 +1248,10 @@ def _check_stop_words_for_request_post_append(self, request: DynamicInferenceReq for i in range(self.num_speculative_tokens + 1): end_idx = -i if i > 0 else None if list(generated_tokens[-stop_len - i : end_idx]) == stop_word_ids: + # If the stop word was found in the middle of speculative tokens + # (i > 0), truncate the trailing tokens after the stop word. + if i > 0: + del request.generated_tokens[-i:] return True return False diff --git a/tests/unit_tests/inference/test_stop_words.py b/tests/unit_tests/inference/test_stop_words.py index 31665c0bb81..455e972c2d7 100644 --- a/tests/unit_tests/inference/test_stop_words.py +++ b/tests/unit_tests/inference/test_stop_words.py @@ -31,7 +31,7 @@ class TestStopWordDetection: """Test stop word detection logic.""" def _check_stop_words_for_request_post_append( - self, request: MockDynamicInferenceRequest + self, request: MockDynamicInferenceRequest, num_speculative_tokens: int = 0 ) -> bool: """ Check if a request should stop due to stop words (after token is appended). @@ -48,9 +48,12 @@ def _check_stop_words_for_request_post_append( for stop_word_ids in request.stop_word_ids: stop_len = len(stop_word_ids) if len(generated_tokens) >= stop_len: - # Check if the last stop_len tokens match the stop word - if list(generated_tokens[-stop_len:]) == stop_word_ids: - return True + for i in range(num_speculative_tokens + 1): + end_idx = -i if i > 0 else None + if list(generated_tokens[-stop_len - i : end_idx]) == stop_word_ids: + if i > 0: + del request.generated_tokens[-i:] + return True return False @@ -158,6 +161,69 @@ def test_stop_word_in_middle_not_end(self): ) assert self._check_stop_words_for_request_post_append(request) is False + def test_speculative_stop_word_at_end(self): + """Test stop word at end of speculative tokens (no truncation needed).""" + # Speculative tokens appended: [200, 300], stop word is [300] + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=[[300]] + ) + assert ( + self._check_stop_words_for_request_post_append(request, num_speculative_tokens=2) + is True + ) + assert request.generated_tokens == [100, 200, 300] + + def test_speculative_stop_word_in_middle_truncates(self): + """Test that stop word in middle of speculative tokens truncates trailing tokens.""" + # Speculative tokens appended: [200, 300, 400], stop word is [200] + # Token 200 is at position -3, so tokens [300, 400] should be truncated + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[100, 200, 300, 400], stop_word_ids=[[200]] + ) + assert ( + self._check_stop_words_for_request_post_append(request, num_speculative_tokens=3) + is True + ) + assert request.generated_tokens == [100, 200] + + def test_speculative_multi_token_stop_word_in_middle_truncates(self): + """Test multi-token stop word in middle of speculative tokens truncates.""" + # Generated: [100, 200, 300, 400, 500], stop word is [200, 300] + # Stop word ends at -2, so tokens [400, 500] should be truncated + request = MockDynamicInferenceRequest( + request_id=1, + generated_tokens=[100, 200, 300, 400, 500], + stop_word_ids=[[200, 300]], + ) + assert ( + self._check_stop_words_for_request_post_append(request, num_speculative_tokens=4) + is True + ) + assert request.generated_tokens == [100, 200, 300] + + def test_speculative_stop_word_not_found(self): + """Test no stop word found even with speculative scanning.""" + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[100, 200, 300, 400], stop_word_ids=[[999]] + ) + assert ( + self._check_stop_words_for_request_post_append(request, num_speculative_tokens=3) + is False + ) + assert request.generated_tokens == [100, 200, 300, 400] + + def test_speculative_stop_word_one_trailing_token(self): + """Test stop word with exactly one trailing token to truncate.""" + # Generated: [100, 200, 300], stop word is [200], one trailing token [300] + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=[[200]] + ) + assert ( + self._check_stop_words_for_request_post_append(request, num_speculative_tokens=2) + is True + ) + assert request.generated_tokens == [100, 200] + class TestStopWordTrackingFlow: """Test the stop word tracking flow between steps.""" From 22e8db38f96b7565f9c405eb45c57c2ec9b63dda Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 00:25:22 -0700 Subject: [PATCH 50/76] Add engine test for deleting speculative tokens after stop token Signed-off-by: Keshav Santhanam --- .../inference/engines/test_dynamic_engine.py | 84 +++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index bff603594c7..e67799059e0 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -2288,6 +2288,90 @@ def mock_deterministic_forward(*args, **kwargs): ] assert [7, 8, 9] in token_triplets + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.inference_mode() + def test_speculative_stop_word_truncates_trailing_tokens(self): + """Test that when a stop word lands in the middle of speculative tokens, + the extra tokens generated after the stop word are removed. + + With num_speculative_tokens=2, each step produces up to 3 tokens + (1 base + 2 speculative). If the stop word is [6] and the engine + generates [5, 6, 7] in one step, token 7 must be truncated so the + output ends with the stop word [6].""" + + test_config = DynamicEngineTestConfig( + num_requests=0, + min_prompt_length=4, + max_prompt_length=4, + num_tokens_to_generate=10, + num_speculative_tokens=2, + materialize_only_last_token_logits=False, + model_provider="gpt", + ) + env = self._build_test_env(test_config) + + unwrapped_model = env.engine.controller.inference_wrapped_model.model + + # Mock forward to deterministically output an ascending sequence (1->2->3...) + def mock_deterministic_forward(*args, **kwargs): + tokens = kwargs.get("tokens", args[0] if args else kwargs.get("input_ids")) + b, s = tokens.shape + + base_logits = torch.zeros( + b, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16 + ) + next_toks = (tokens + 1).clamp(max=test_config.vocab_size - 1) + base_logits.scatter_(2, next_toks.unsqueeze(-1), 100.0) + + mtp_logits = torch.zeros( + 2, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16 + ) + mtp1_toks = (tokens + 2).clamp(max=test_config.vocab_size - 1) + mtp_logits[0].scatter_(1, mtp1_toks.squeeze(0).unsqueeze(-1), 100.0) + + mtp2_toks = (tokens + 3).clamp(max=test_config.vocab_size - 1) + mtp_logits[1].scatter_(1, mtp2_toks.squeeze(0).unsqueeze(-1), 100.0) + + unwrapped_model._mtp_logits_cache = mtp_logits + return base_logits + + unwrapped_model.forward = mock_deterministic_forward + + env.engine.add_request( + request_id=0, + prompt=torch.tensor([1, 2, 3, 4], device='cuda'), + sampling_params=SamplingParams(num_tokens_to_generate=10, termination_id=99), + ) + + # Stop word [6] will land in the middle of a speculative batch [5, 6, 7]. + # Token 7 should be truncated from the output. + tracked_req = env.engine.get_request(0) + tracked_req.stop_word_ids = [[6]] + + finished_records = [] + while env.engine.has_unfinished_requests(): + res = env.engine.step_modern() + finished_records.extend(res["finished_request_records"]) + + finished_req = finished_records[0].merge() + + assert finished_req.status == Status.COMPLETED + # The output should end exactly at the stop word, with no trailing tokens. + assert finished_req.generated_tokens[-1] == 6, ( + f"Expected last token to be stop word 6, " + f"got {finished_req.generated_tokens[-1]}. " + f"Trailing tokens after stop word were not truncated. " + f"Full output: {finished_req.generated_tokens}" + ) + # Verify no tokens after the stop word exist + assert 7 not in finished_req.generated_tokens, ( + f"Token 7 should have been truncated after stop word 6. " + f"Full output: {finished_req.generated_tokens}" + ) + @pytest.mark.internal @torch.inference_mode() def test_speculative_sequence_length_double_counting(self): From 7ce9546ab7fb0ac8d048b077773cbfce7a881eec Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 00:36:28 -0700 Subject: [PATCH 51/76] Address review comments Signed-off-by: Keshav Santhanam --- megatron/core/inference/contexts/dynamic_context.py | 3 ++- megatron/core/inference/engines/dynamic_engine.py | 3 ++- tests/unit_tests/inference/test_stop_words.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 9b605c79bbf..ee40f38cf37 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1571,6 +1571,7 @@ def reset_metadata(self) -> None: self.padded_active_token_count = 0 self.padded_active_request_count = 0 self.paused_tokens = None + self.paused_speculative_tokens = None # Reset attention, mamba, and block allocator state. self.reset_attention_state() @@ -2360,7 +2361,7 @@ def update_requests( if self.paused_request_count != 0: assert self.paused_tokens is not None next_tokens = torch.cat((self.paused_tokens, new_tokens)) - if new_speculative_tokens is not None: + if new_speculative_tokens is not None and self.paused_speculative_tokens is not None: new_speculative_tokens = torch.cat( (self.paused_speculative_tokens, new_speculative_tokens), dim=1 ) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 318060fb08f..d87a3f65297 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -1245,7 +1245,8 @@ def _check_stop_words_for_request_post_append(self, request: DynamicInferenceReq # Check the last stop_len tokens shifting by 1 up to num_speculative_tokens. # We do this regardless of stop_len because speculative decoding can append # multiple tokens at once, meaning the stop word might end at any of those positions. - for i in range(self.num_speculative_tokens + 1): + max_shift = min(self.num_speculative_tokens, len(generated_tokens) - stop_len) + for i in range(max_shift + 1): end_idx = -i if i > 0 else None if list(generated_tokens[-stop_len - i : end_idx]) == stop_word_ids: # If the stop word was found in the middle of speculative tokens diff --git a/tests/unit_tests/inference/test_stop_words.py b/tests/unit_tests/inference/test_stop_words.py index 455e972c2d7..57a200fddc8 100644 --- a/tests/unit_tests/inference/test_stop_words.py +++ b/tests/unit_tests/inference/test_stop_words.py @@ -48,7 +48,8 @@ def _check_stop_words_for_request_post_append( for stop_word_ids in request.stop_word_ids: stop_len = len(stop_word_ids) if len(generated_tokens) >= stop_len: - for i in range(num_speculative_tokens + 1): + max_shift = min(num_speculative_tokens, len(generated_tokens) - stop_len) + for i in range(max_shift + 1): end_idx = -i if i > 0 else None if list(generated_tokens[-stop_len - i : end_idx]) == stop_word_ids: if i > 0: From 89e55c00db268a61ff559452ac55b3c526aea81d Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 01:01:36 -0700 Subject: [PATCH 52/76] Remove restriction on materialize_only_last_token_logits Signed-off-by: Keshav Santhanam --- .../inference/contexts/dynamic_context.py | 39 ++++++++++++++ .../core/inference/engines/dynamic_engine.py | 3 -- .../text_generation_controller.py | 53 ++++++++++++++----- megatron/core/models/gpt/gpt_model.py | 18 ++++--- megatron/core/models/mamba/mamba_model.py | 18 ++++--- megatron/inference/utils.py | 2 +- 6 files changed, 105 insertions(+), 28 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index ee40f38cf37..9ac865c9c28 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1657,6 +1657,45 @@ def last_token_logits(self, logits: Tensor) -> Tensor: return last_token_logits + def speculative_required_logits(self, logits: Tensor) -> Tensor: + """Extract logits at positions required for speculative decoding. + + For decode requests, all tokens (base + speculative) are needed. + For prefill requests, only the last token is needed. + + Args: + logits (Tensor): Hidden states of shape [1, padded_active_token_count, ...]. + + Return: + (Tensor) Logits at required positions, shape + [num_decode * (num_spec + 1) + num_prefill, ...]. + """ + assert logits.size(0) == 1, f"logits.size(0) ({tuple(logits.shape)}) != 1" + assert logits.size(1) == self.padded_active_token_count, ( + f"logits.size(1) ({tuple(logits.shape)}) != " + f"padded_active_token_count ({self.padded_active_token_count})." + ) + + logits = logits.squeeze(0) + active_slice = slice(self.paused_request_count, self.total_request_count) + request_in_prefill = self.request_in_prefill_status_tensor[active_slice] + query_lengths = self.request_query_lengths[active_slice] + + num_prefill = (request_in_prefill == 1).sum().item() + num_decode = len(request_in_prefill) - num_prefill + num_speculative_tokens = self.config.num_speculative_tokens + + # All tokens for decode requests (they come first in the packed sequence). + decode_indices = torch.arange( + num_decode * (num_speculative_tokens + 1), device=logits.device + ) + + # Last token index for each prefill request. + prefill_indices = query_lengths.cumsum(dim=0)[request_in_prefill == 1] - 1 + + required_indices = torch.cat([decode_indices, prefill_indices]) + return logits[required_indices, :] + def _compute_prefix_match( self, req: DynamicInferenceRequest, chunk_length: int ) -> Tuple[list, int, int, int, int, int]: diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index d87a3f65297..d0a47088502 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -208,9 +208,6 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen assert self.num_speculative_tokens >= 0, "Number of speculative tokens must be non-negative" if self.num_speculative_tokens > 0: - assert ( - not inference_config.materialize_only_last_token_logits - ), "Speculative decoding requires materialize_only_last_token_logits to be False" assert ( self.num_speculative_tokens <= self.controller.num_mtp_heads ), f"Number of speculative tokens {self.num_speculative_tokens} must be less than or equal to number of MTP heads {self.controller.num_mtp_heads}" diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index ee03763b1be..6a5ba5c0c73 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -636,16 +636,40 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) expected_mtp_logits_length == self.num_mtp_heads ), f"MTP logits length mismatch. Expected mtp logits length {self.num_mtp_heads}, got {expected_mtp_logits_length}" mtp_logits = mtp_logits[: self.num_speculative_tokens] + + if context.config.materialize_only_last_token_logits: + # Base logits are already filtered to required positions by the model. + # Filter MTP logits to match before concatenation. + active_slice = slice(context.paused_request_count, context.total_request_count) + request_in_prefill = context.request_in_prefill_status_tensor[active_slice] + query_lengths = context.request_query_lengths[active_slice] + num_prefill = (request_in_prefill == 1).sum().item() + num_decode = active_request_count - num_prefill + required_logit_indices = self._get_required_logit_indices( + request_in_prefill, query_lengths, num_decode, num_prefill, mtp_logits.device + ) + mtp_logits = mtp_logits[:, required_logit_indices, :] + logits = torch.cat( [logits, mtp_logits], dim=0 - ) # [num_speculative_tokens + 1, seq_len, vocab_size] + ) # [num_speculative_tokens + 1, seq_len_or_required, vocab_size] if self.model_is_pipeline_parallel: - logits_seq_len = ( - active_request_count - if context.config.materialize_only_last_token_logits - else input_ids.shape[1] - ) + if context.config.materialize_only_last_token_logits: + if self.num_speculative_tokens > 0: + active_slice = slice( + context.paused_request_count, context.total_request_count + ) + request_in_prefill = context.request_in_prefill_status_tensor[active_slice] + num_prefill = (request_in_prefill == 1).sum().item() + num_decode = active_request_count - num_prefill + logits_seq_len = ( + num_decode * (self.num_speculative_tokens + 1) + num_prefill + ) + else: + logits_seq_len = active_request_count + else: + logits_seq_len = input_ids.shape[1] logits_shape = [self.num_speculative_tokens + 1, logits_seq_len, self.vocab_size] if is_pipeline_last_stage(self.pp_group): @@ -1016,12 +1040,17 @@ def _dynamic_step_sample_logits_and_verify_tokens( logits.device, ) - required_logits = logits.squeeze(0)[ - required_logit_indices, : - ] # Shape [num_required, vocab_size] - required_mtp_logits = mtp_logits[ - :, required_logit_indices, : - ] # Shape [num_speculative_tokens, num_required, vocab_size] + if context.config.materialize_only_last_token_logits: + # Logits are already pre-filtered to required positions by the model forward. + required_logits = logits.squeeze(0) # Shape [num_required, vocab_size] + required_mtp_logits = mtp_logits # Shape [num_speculative_tokens, num_required, vocab_size] + else: + required_logits = logits.squeeze(0)[ + required_logit_indices, : + ] # Shape [num_required, vocab_size] + required_mtp_logits = mtp_logits[ + :, required_logit_indices, : + ] # Shape [num_speculative_tokens, num_required, vocab_size] # Sample tokens from logits and MTP logits. output_tokens, mtp_output_tokens, repeats = self._sample_speculative_logits( diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 33a5f9c9bc4..eedc2c25ade 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -657,12 +657,18 @@ def _postprocess( self.output_layer.sequence_parallel = False sequence_parallel_override = True - # Reshape [B, 1, H] to [1, B, H] → extract each sample’s true last‐token hidden - # state ([B, H]) → unsqueeze back to [B, 1, H] - # (so that the output layer, which expects S×B×H, receives only the final token) - hidden_states = inference_context.last_token_logits( - hidden_states.squeeze(1).unsqueeze(0) - ).unsqueeze(1) + # Reshape [S, B, H] (with B=1) to [1, S, H] for logit extraction, + # then back to [S’, B, H] for the output layer. + reshaped = hidden_states.squeeze(1).unsqueeze(0) + if inference_context.config.num_speculative_tokens > 0: + # For speculative decoding, keep all decode tokens + last prefill token. + hidden_states = inference_context.speculative_required_logits( + reshaped + ).unsqueeze(1) + else: + hidden_states = inference_context.last_token_logits( + reshaped + ).unsqueeze(1) logits, _ = self.output_layer( hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output diff --git a/megatron/core/models/mamba/mamba_model.py b/megatron/core/models/mamba/mamba_model.py index 71140b1c2af..4ef1aeb1695 100644 --- a/megatron/core/models/mamba/mamba_model.py +++ b/megatron/core/models/mamba/mamba_model.py @@ -446,12 +446,18 @@ def forward( self.output_layer.sequence_parallel = False sequence_parallel_override = True - # Reshape [B, 1, H] to [1, B, H] → extract each sample's true last‐token hidden - # state ([B, H]) → unsqueeze back to [B, 1, H] - # (so that the output layer, which expects S×B×H, receives only the final token) - hidden_states = inference_context.last_token_logits( - hidden_states.squeeze(1).unsqueeze(0) - ).unsqueeze(1) + # Reshape [S, B, H] (with B=1) to [1, S, H] for logit extraction, + # then back to [S', B, H] for the output layer. + reshaped = hidden_states.squeeze(1).unsqueeze(0) + if inference_context.config.num_speculative_tokens > 0: + # For speculative decoding, keep all decode tokens + last prefill token. + hidden_states = inference_context.speculative_required_logits( + reshaped + ).unsqueeze(1) + else: + hidden_states = inference_context.last_token_logits( + reshaped + ).unsqueeze(1) logits, _ = self.output_layer( hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output diff --git a/megatron/inference/utils.py b/megatron/inference/utils.py index ec8f1088be1..37898dc8c90 100644 --- a/megatron/inference/utils.py +++ b/megatron/inference/utils.py @@ -341,7 +341,7 @@ def get_inference_config_from_model_and_args(model: MegatronModule, args): mamba_inference_state_config=mamba_inference_state_config, pg_collection=pg_collection, use_flashinfer_fused_rope=args.use_flashinfer_fused_rope, - materialize_only_last_token_logits=(not args.return_log_probs and args.num_speculative_tokens == 0), + materialize_only_last_token_logits=(not args.return_log_probs), track_generated_token_events=args.inference_dynamic_batching_track_generated_token_events, track_paused_request_events=args.inference_dynamic_batching_track_paused_request_events, enable_chunked_prefill=args.enable_chunked_prefill, From a878759aed7b59180e132f08c80bede154058d62 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 01:21:56 -0700 Subject: [PATCH 53/76] Revert circular buffer logic for conv Signed-off-by: Keshav Santhanam --- .../inference/contexts/dynamic_context.py | 33 +- .../text_generation_controller.py | 5 + megatron/core/ssm/mamba_mixer.py | 57 +-- megatron/core/ssm/ops/causal_conv1d_triton.py | 352 ++---------------- .../ssm/test_causal_conv1d_triton.py | 273 ++++---------- 5 files changed, 146 insertions(+), 574 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 9ac865c9c28..b589e1db4f4 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -362,7 +362,16 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC math.prod(self.mamba_ssm_states_shape) * self.mamba_ssm_states_dtype.itemsize ) mamba_states_memory_per_request *= self.num_mamba_layers - mamba_states_memory_per_request *= self.num_speculative_tokens + 1 + if self.num_speculative_tokens > 0: + # Add memory for intermediate conv and SSM states + intermediate_memory_per_request = ( + math.prod(self.mamba_conv_states_shape) * self.mamba_conv_states_dtype.itemsize + + math.prod(self.mamba_ssm_states_shape) + * self.mamba_ssm_states_dtype.itemsize + ) + intermediate_memory_per_request *= self.num_mamba_layers + intermediate_memory_per_request *= self.num_speculative_tokens + 1 + mamba_states_memory_per_request += intermediate_memory_per_request # Unified memory and general tensor management. self.unified_memory_level = inference_config.unified_memory_level @@ -615,10 +624,8 @@ def _allocate_mamba_states(self): self.mamba_metadata = MambaMetadata( max_requests=self.max_requests, max_tokens=self.max_tokens ) - expanded_conv_shape = list(self.mamba_conv_states_shape) - expanded_conv_shape[-1] += self.num_speculative_tokens self.mamba_conv_states = torch.empty( - (self.num_mamba_layers, self.max_requests, *expanded_conv_shape), + (self.num_mamba_layers, self.max_requests) + self.mamba_conv_states_shape, dtype=self.mamba_conv_states_dtype, device=torch.cuda.current_device(), ) @@ -628,6 +635,16 @@ def _allocate_mamba_states(self): device=torch.cuda.current_device(), ) if self.num_speculative_tokens > 0: + self.mamba_intermediate_conv_states = torch.empty( + ( + self.num_mamba_layers, + self.max_requests, + self.num_speculative_tokens + 1, + *self.mamba_conv_states_shape, + ), + dtype=self.mamba_conv_states_dtype, + device=torch.cuda.current_device(), + ) self.mamba_intermediate_ssm_states = torch.empty( ( self.num_mamba_layers, @@ -652,6 +669,12 @@ def _allocate_mamba_states(self): self.mamba_ssm_states, device="cpu" ).pin_memory() if self.num_speculative_tokens > 0: + self._offloadable_tensor_names.add("mamba_intermediate_conv_states") + self._offloadable_cpu_backups["mamba_intermediate_conv_states"] = ( + torch.empty_like( + self.mamba_intermediate_conv_states, device="cpu" + ).pin_memory() + ) self._offloadable_tensor_names.add("mamba_intermediate_ssm_states") self._offloadable_cpu_backups["mamba_intermediate_ssm_states"] = ( torch.empty_like( @@ -986,7 +1009,7 @@ def mamba_states_cache( mamba_layer_number = self.layer_map[layer_number - 1] if intermediate: - conv_state = None + conv_state = self.mamba_intermediate_conv_states[mamba_layer_number] ssm_state = self.mamba_intermediate_ssm_states[mamba_layer_number] else: conv_state = self.mamba_conv_states[mamba_layer_number] diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 6a5ba5c0c73..45fca84d6da 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -806,6 +806,11 @@ def _rewind_kv_cache(self): accepted_tokens_per_decode_request = accepted_tokens_per_request[is_decode_mask] if decode_mamba_indices.numel() > 0: + context.mamba_conv_states[:, decode_mamba_indices] = ( + context.mamba_intermediate_conv_states[ + :, decode_mamba_indices, accepted_tokens_per_decode_request + ] + ) context.mamba_ssm_states[:, decode_mamba_indices] = ( context.mamba_intermediate_ssm_states[ :, decode_mamba_indices, accepted_tokens_per_decode_request diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 7b6b685dbda..4362216e744 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -25,12 +25,7 @@ ) from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.ssm.ops.causal_conv1d_triton import ( - causal_conv1d_update, - gather_conv_state, - roll_conv_varlen_states, - scatter_conv_state, -) +from megatron.core.ssm.ops.causal_conv1d_triton import causal_conv1d_update from megatron.core.ssm.ops.mamba_ssm import selective_state_update from megatron.core.tensor_parallel import get_cuda_rng_tracker from megatron.core.transformer import TransformerConfig @@ -433,11 +428,11 @@ def _dynamic_inference(self, hidden_states: torch.Tensor, context: DynamicInfere # Grab standard states conv_state, ssm_state = context.mamba_states_cache(self.layer_number - self.pp_layer_offset) - # Only fetch intermediate SSM state for speculative decoding + # Fetch intermediate states for speculative decoding + int_conv_state = None int_ssm_state = None if context.num_speculative_tokens > 0: - # We ignore the conv intermediate state since we use the expanded circular buffer - _, int_ssm_state = context.mamba_states_cache( + int_conv_state, int_ssm_state = context.mamba_states_cache( self.layer_number - self.pp_layer_offset, intermediate=True ) @@ -468,8 +463,8 @@ def _dynamic_inference(self, hidden_states: torch.Tensor, context: DynamicInfere conv_state, ssm_state, batch_indices=context.mamba_metadata.batch_indices_decode, + intermediate_conv_state=int_conv_state, intermediate_ssm_state=int_ssm_state, - cache_seqlens=context.mamba_metadata.cache_seqlens_decode, ) # Flatten back to [N*S, 1, d] to match merge logic @@ -550,7 +545,6 @@ def _dynamic_inference_prefill( ssm_state=ssm_state, batch_indices=metadata.batch_indices_chunked_prefill, is_chunked_prefill=True, - cache_seqlens=metadata.cache_seqlens_chunked_prefill, ) # Update zxBCdt to contain the remaining slice for regular prefill @@ -690,7 +684,6 @@ def _ssm_prefill( return_varlen_states: bool = False, batch_indices: Optional[torch.Tensor] = None, is_chunked_prefill: bool = False, - cache_seqlens: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Performs SSM computation for inference prefill step. @@ -734,17 +727,12 @@ def _ssm_prefill( # Compute short convolution initial_conv_state = None if conv_state is not None and is_dynamic_batching: - # Extract linear states (newest token is at state_len - 1) - state_len = conv_state.shape[-1] + # xBC should have shape (b l d) for causal_conv1d_varlen_states + assert batch_indices is not None conv_varlen_states = causal_conv1d_varlen_states( - xBC.squeeze(0), cu_seqlens, state_len=state_len + xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1] ) - - # Roll into circular buffer layout expected by decode using fused Triton kernel - conv_varlen_states_circular = roll_conv_varlen_states(conv_varlen_states, cu_seqlens) - - # Update state - tensor_masked_update(conv_state, batch_indices, conv_varlen_states_circular) + tensor_masked_update(conv_state, batch_indices, conv_varlen_states) # Maintain channels-last memory layout to use seq_idx for causal_conv1d_fn # See https://github.com/Dao-AILab/causal-conv1d/blob/69e6dadc28b169a4c49cb86b586f64ee90242c70/csrc/causal_conv1d.cpp#L174 # pylint: disable=line-too-long @@ -753,19 +741,13 @@ def _ssm_prefill( # Maintain channels-last memory layout to use initial_states for causal_conv1d_fn # See https://github.com/Dao-AILab/causal-conv1d/blob/69e6dadc28b169a4c49cb86b586f64ee90242c70/csrc/causal_conv1d.cpp#L200 # pylint: disable=line-too-long assert batch_indices is not None - assert cache_seqlens is not None - state_len = conv_state.shape[-1] - + initial_conv_state = ( + conv_state[batch_indices, :, 1:].permute(0, 2, 1).contiguous().transpose(1, 2) + ) xBC = xBC.transpose(1, 2) - - # Read last (d_conv - 1) tokens from the circular buffer - initial_conv_state = gather_conv_state( - conv_state, batch_indices, cache_seqlens, self.d_conv + tensor_masked_update( + conv_state, batch_indices, F.pad(xBC, (self.d_conv - xBC.shape[-1], 0)) ) - initial_conv_state = initial_conv_state.permute(0, 2, 1).contiguous().transpose(1, 2) - - # Scatter tail back into the main buffer using fused Triton kernel - scatter_conv_state(conv_state, xBC, batch_indices, cache_seqlens) else: # transpose: b l pd --> b pd l xBC = rearrange(xBC, "b l d -> b d l").contiguous() @@ -888,8 +870,8 @@ def _ssm_decode( conv_state: torch.Tensor, ssm_state: torch.Tensor, batch_indices: Optional[torch.Tensor] = None, + intermediate_conv_state: Optional[torch.Tensor] = None, intermediate_ssm_state: Optional[torch.Tensor] = None, - cache_seqlens: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Performs SSM computation for inference decode step. @@ -901,9 +883,10 @@ def _ssm_decode( conv_state: The convolution state tensor for inference. ssm_state: The selective scan state tensor for inference. batch_indices: A map from batch id to position in the Mamba state tensors. - intermediate_ssm_state: Optional buffer for storing sequence steps in SSM state. - cache_seqlens: Optional tensor representing cache sequence length for circular - buffering. + intermediate_conv_state: Optional buffer for storing conv state at each + sequence step (for speculative decoding rollback). + intermediate_ssm_state: Optional buffer for storing SSM state at each + sequence step (for speculative decoding rollback). Returns: The output tensor of shape (b, s, d). @@ -946,8 +929,8 @@ def _ssm_decode( weight.to(conv_state.dtype), self.conv1d.bias.to(conv_state.dtype), self.activation, - cache_seqlens=cache_seqlens, conv_state_indices=batch_indices, + intermediate_conv_states=intermediate_conv_state, ).to(xBC_dtype) x, B, C = torch.split( diff --git a/megatron/core/ssm/ops/causal_conv1d_triton.py b/megatron/core/ssm/ops/causal_conv1d_triton.py index 7a04a218c35..36d14a1d91b 100644 --- a/megatron/core/ssm/ops/causal_conv1d_triton.py +++ b/megatron/core/ssm/ops/causal_conv1d_triton.py @@ -9,268 +9,6 @@ import triton.language as tl -@triton.jit -def _roll_circular_buffer_kernel( - in_ptr, - out_ptr, - cu_seqlens_ptr, - B: tl.constexpr, - D: tl.constexpr, - W: tl.constexpr, - stride_in_b, - stride_in_d, - stride_in_w, - stride_out_b, - stride_out_d, - stride_out_w, - BLOCK_W: tl.constexpr, -): - # We map a 1D grid over B * D - pid = tl.program_id(0) - b = pid // D - d = pid % D - - # 1. Load sequence lengths to calculate shift - seqlen_start = tl.load(cu_seqlens_ptr + b) - seqlen_end = tl.load(cu_seqlens_ptr + b + 1) - seqlen = seqlen_end - seqlen_start - - shift = seqlen % W - - # 2. Setup standard W offsets - w_offsets = tl.arange(0, BLOCK_W) - mask = w_offsets < W - - # 3. Calculate gathered indices - # NOTE: Triton/C++ modulo operator truncates towards zero for negative numbers. - # Because shift < W, (w_offsets - shift) is at least -W + 1. - # Adding W ensures the dividend is strictly positive, giving the correct wrapping behavior. - src_w_offsets = (w_offsets - shift + W) % W - - # 4. Compute memory pointers - in_offsets = in_ptr + (b * stride_in_b) + (d * stride_in_d) + (src_w_offsets * stride_in_w) - out_offsets = out_ptr + (b * stride_out_b) + (d * stride_out_d) + (w_offsets * stride_out_w) - - # 5. Load and Store - vals = tl.load(in_offsets, mask=mask) - tl.store(out_offsets, vals, mask=mask) - - -def roll_conv_varlen_states( - conv_varlen_states: torch.Tensor, cu_seqlens: torch.Tensor -) -> torch.Tensor: - """ - Rolls the convolution states into a circular buffer layout based on sequence lengths. - """ - B, D, W = conv_varlen_states.shape - out = torch.empty_like(conv_varlen_states) - - # Next power of 2 for block size (e.g. W=4 -> BLOCK_W=4) - BLOCK_W = triton.next_power_of_2(W) - - # Grid of size B * D - grid = lambda meta: (B * D,) - - _roll_circular_buffer_kernel[grid]( - conv_varlen_states, - out, - cu_seqlens, - B, - D, - W, - conv_varlen_states.stride(0), - conv_varlen_states.stride(1), - conv_varlen_states.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - BLOCK_W=BLOCK_W, - ) - - return out - - -@triton.jit -def _gather_conv_state_kernel( - conv_state_ptr, - batch_indices_ptr, - cache_seqlens_ptr, - out_ptr, - stride_cs_b, - stride_cs_d, - stride_cs_w, - stride_out_b, - stride_out_d, - stride_out_w, - D: tl.constexpr, - state_len: tl.constexpr, - d_conv: tl.constexpr, - BLOCK_W: tl.constexpr, -): - pid = tl.program_id(0) - b = pid // D - d = pid % D - - # Load batch map - req_idx = tl.load(batch_indices_ptr + b) - - # Check for padding/invalid batch index - if req_idx < 0: - w_offsets = tl.arange(0, BLOCK_W) - mask = w_offsets < (d_conv - 1) - out_offsets = out_ptr + (b * stride_out_b) + (d * stride_out_d) + (w_offsets * stride_out_w) - # Store 0.0 to prevent NaNs/garbage data from propagating - tl.store(out_offsets, 0.0, mask=mask) - return - - # Load sequence length - seq_len = tl.load(cache_seqlens_ptr + b) - - w_offsets = tl.arange(0, BLOCK_W) - mask = w_offsets < (d_conv - 1) - - # Calculate circular buffer index. - # We add state_len before modulo to prevent negative values in C++ modulo - # when seq_len < d_conv - 1. - val = seq_len - d_conv + 1 + w_offsets - gather_indices = (val + state_len) % state_len - - # Calculate memory offsets - cs_offsets = ( - conv_state_ptr - + (req_idx * stride_cs_b) - + (d * stride_cs_d) - + (gather_indices * stride_cs_w) - ) - out_offsets = out_ptr + (b * stride_out_b) + (d * stride_out_d) + (w_offsets * stride_out_w) - - valid_mask = mask & (val >= 0) - data = tl.load(cs_offsets, mask=valid_mask, other=0.0) - tl.store(out_offsets, data, mask=mask) - - -def gather_conv_state( - conv_state: torch.Tensor, batch_indices: torch.Tensor, cache_seqlens: torch.Tensor, d_conv: int -) -> torch.Tensor: - """Reads the last (d_conv - 1) tokens from the circular convolution state.""" - B = batch_indices.shape[0] - D = conv_state.shape[1] - state_len = conv_state.shape[2] - - out = torch.empty((B, D, d_conv - 1), device=conv_state.device, dtype=conv_state.dtype) - BLOCK_W = triton.next_power_of_2(d_conv - 1) - - grid = lambda meta: (B * D,) - _gather_conv_state_kernel[grid]( - conv_state, - batch_indices, - cache_seqlens, - out, - conv_state.stride(0), - conv_state.stride(1), - conv_state.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - D, - state_len, - d_conv, - BLOCK_W=BLOCK_W, - ) - return out - - -@triton.jit -def _scatter_conv_state_kernel( - conv_state_ptr, - batch_indices_ptr, - cache_seqlens_ptr, - xBC_tail_ptr, - stride_cs_b, - stride_cs_d, - stride_cs_w, - stride_xBC_b, - stride_xBC_d, - stride_xBC_w, - D: tl.constexpr, - state_len: tl.constexpr, - chunk_len: tl.constexpr, - copy_len: tl.constexpr, - BLOCK_W: tl.constexpr, -): - pid = tl.program_id(0) - b = pid // D - d = pid % D - - # Load batch map - req_idx = tl.load(batch_indices_ptr + b) - - # Check for padding/invalid batch index and safely return - if req_idx < 0: - return - - # Load sequence length - seq_len = tl.load(cache_seqlens_ptr + b) - - w_offsets = tl.arange(0, BLOCK_W) - mask = w_offsets < copy_len - - # seq_len >= 0 and chunk_len >= copy_len, so this is guaranteed to be >= 0. - update_indices = (seq_len + chunk_len - copy_len + w_offsets) % state_len - - # Calculate memory offsets - xBC_offsets = ( - xBC_tail_ptr + (b * stride_xBC_b) + (d * stride_xBC_d) + (w_offsets * stride_xBC_w) - ) - cs_offsets = ( - conv_state_ptr - + (req_idx * stride_cs_b) - + (d * stride_cs_d) - + (update_indices * stride_cs_w) - ) - - data = tl.load(xBC_offsets, mask=mask) - tl.store(cs_offsets, data, mask=mask) - - -def scatter_conv_state( - conv_state: torch.Tensor, - xBC: torch.Tensor, - batch_indices: torch.Tensor, - cache_seqlens: torch.Tensor, -): - """Writes the newest chunk of tokens into the circular convolution state.""" - state_len = conv_state.shape[2] - chunk_len = xBC.shape[-1] - - # We only need to retain at most the last `state_len` tokens of the chunk - copy_len = min(chunk_len, state_len) - xBC_tail = xBC[..., -copy_len:] - - B, D, _ = xBC_tail.shape - state_len = conv_state.shape[2] - BLOCK_W = triton.next_power_of_2(copy_len) - - grid = lambda meta: (B * D,) - _scatter_conv_state_kernel[grid]( - conv_state, - batch_indices, - cache_seqlens, - xBC_tail, - conv_state.stride(0), - conv_state.stride(1), - conv_state.stride(2), - xBC_tail.stride(0), - xBC_tail.stride(1), - xBC_tail.stride(2), - D, - state_len, - chunk_len, - copy_len, - BLOCK_W=BLOCK_W, - ) - - @triton.jit def causal_conv1d_update_kernel( x_ptr, @@ -296,7 +34,6 @@ def causal_conv1d_update_kernel( out_s_stride, out_c_stride, conv_state_indices_ptr, - cache_seqlens_ptr, batch, seq_len, dim, @@ -304,7 +41,6 @@ def causal_conv1d_update_kernel( WIDTH: tl.constexpr, BLOCK_DIM: tl.constexpr, HAS_BIAS: tl.constexpr, - IS_CIRCULAR: tl.constexpr, HAS_STATE_INDICES: tl.constexpr, HAS_INT_STATE: tl.constexpr, SILU_ACTIVATION: tl.constexpr, @@ -368,12 +104,6 @@ def causal_conv1d_update_kernel( x_val_2 = tl.zeros([BLOCK_DIM], dtype=tl.float32) x_val_3 = tl.zeros([BLOCK_DIM], dtype=tl.float32) - # If circular, we only need to read the base cache sequence length once - if IS_CIRCULAR: - base_cache_seqlen = tl.load(cache_seqlens_ptr + batch_id) - else: - base_cache_seqlen = None - # Loop over the sequence dimension (e.g., speculative tokens) for s in range(seq_len): x_ptrs = x_ptr + batch_id * x_b_stride + s * x_s_stride + channel_offsets * x_c_stride @@ -381,60 +111,33 @@ def causal_conv1d_update_kernel( out_ptr + batch_id * out_b_stride + s * out_s_stride + channel_offsets * out_c_stride ) - if not IS_CIRCULAR: - # Load the last (WIDTH - 1) elements to use them BEFORE they are overwritten - # by the shift - if WIDTH >= 2: - x_val_0 = tl.load( - conv_state_ptrs + (state_len - WIDTH + 1) * conv_state_l_stride, mask=mask - ).to(tl.float32) - if WIDTH >= 3: - x_val_1 = tl.load( - conv_state_ptrs + (state_len - WIDTH + 2) * conv_state_l_stride, mask=mask - ).to(tl.float32) - if WIDTH >= 4: - x_val_2 = tl.load( - conv_state_ptrs + (state_len - WIDTH + 3) * conv_state_l_stride, mask=mask - ).to(tl.float32) - - # Shift the linear state buffer left by 1 - i = 0 - while i < state_len - 1: - val = tl.load(conv_state_ptrs + (i + 1) * conv_state_l_stride, mask=mask) - tl.store(conv_state_ptrs + i * conv_state_l_stride, val, mask=mask) - i += 1 - else: - cache_seqlen = base_cache_seqlen + s - update_idx = cache_seqlen % state_len - read_idx = update_idx - (WIDTH - 1) - read_idx = tl.where(read_idx < 0, read_idx + state_len, read_idx) - - if WIDTH >= 2: - state_val = tl.load(conv_state_ptrs + read_idx * conv_state_l_stride, mask=mask) - x_val_0 = state_val.to(tl.float32) - read_idx = tl.where( - read_idx + 1 >= state_len, read_idx + 1 - state_len, read_idx + 1 - ) - if WIDTH >= 3: - state_val = tl.load(conv_state_ptrs + read_idx * conv_state_l_stride, mask=mask) - x_val_1 = state_val.to(tl.float32) - read_idx = tl.where( - read_idx + 1 >= state_len, read_idx + 1 - state_len, read_idx + 1 - ) - if WIDTH >= 4: - state_val = tl.load(conv_state_ptrs + read_idx * conv_state_l_stride, mask=mask) - x_val_2 = state_val.to(tl.float32) + # Load the last (WIDTH - 1) elements to use them BEFORE they are overwritten + # by the shift + if WIDTH >= 2: + x_val_0 = tl.load( + conv_state_ptrs + (state_len - WIDTH + 1) * conv_state_l_stride, mask=mask + ).to(tl.float32) + if WIDTH >= 3: + x_val_1 = tl.load( + conv_state_ptrs + (state_len - WIDTH + 2) * conv_state_l_stride, mask=mask + ).to(tl.float32) + if WIDTH >= 4: + x_val_2 = tl.load( + conv_state_ptrs + (state_len - WIDTH + 3) * conv_state_l_stride, mask=mask + ).to(tl.float32) + + # Shift the linear state buffer left by 1 + i = 0 + while i < state_len - 1: + val = tl.load(conv_state_ptrs + (i + 1) * conv_state_l_stride, mask=mask) + tl.store(conv_state_ptrs + i * conv_state_l_stride, val, mask=mask) + i += 1 # Process the single token for the current sequence step x_val = tl.load(x_ptrs, mask=mask) - # Store the new token in the state buffer - if not IS_CIRCULAR: - tl.store(conv_state_ptrs + (state_len - 1) * conv_state_l_stride, x_val, mask=mask) - else: - cache_seqlen = base_cache_seqlen + s - update_idx = cache_seqlen % state_len - tl.store(conv_state_ptrs + update_idx * conv_state_l_stride, x_val, mask=mask) + # Store the new token at the end of the linear state buffer + tl.store(conv_state_ptrs + (state_len - 1) * conv_state_l_stride, x_val, mask=mask) # Write out to the intermediate state buffer if requested if HAS_INT_STATE: @@ -481,7 +184,6 @@ def causal_conv1d_update( weight: torch.Tensor, bias: torch.Tensor | None, silu_activation: bool, - cache_seqlens: torch.Tensor | None, conv_state_indices: torch.Tensor | None, intermediate_conv_states: torch.Tensor | None = None, ) -> torch.Tensor: @@ -505,12 +207,6 @@ def causal_conv1d_update( bias_stride = 0 has_bias = False - if cache_seqlens is not None: - is_circular = True - else: - cache_seqlens = x # Dummy pointer - is_circular = False - if conv_state_indices is not None: has_state_indices = True else: @@ -560,7 +256,6 @@ def causal_conv1d_update( out.stride(1), out.stride(2), conv_state_indices, - cache_seqlens, batch, seq_len, dim, @@ -568,7 +263,6 @@ def causal_conv1d_update( WIDTH=width, BLOCK_DIM=BLOCK_DIM, HAS_BIAS=has_bias, - IS_CIRCULAR=is_circular, HAS_STATE_INDICES=has_state_indices, HAS_INT_STATE=has_int_state, SILU_ACTIVATION=silu_activation == "silu", diff --git a/tests/unit_tests/ssm/test_causal_conv1d_triton.py b/tests/unit_tests/ssm/test_causal_conv1d_triton.py index 1d58897840c..f937554798a 100644 --- a/tests/unit_tests/ssm/test_causal_conv1d_triton.py +++ b/tests/unit_tests/ssm/test_causal_conv1d_triton.py @@ -3,12 +3,7 @@ import pytest import torch -from megatron.core.ssm.ops.causal_conv1d_triton import ( - causal_conv1d_update, - gather_conv_state, - roll_conv_varlen_states, - scatter_conv_state, -) +from megatron.core.ssm.ops.causal_conv1d_triton import causal_conv1d_update def _requires_cuda(): @@ -19,59 +14,6 @@ def _requires_cuda(): # ---------------------- Reference Implementations ---------------------- # -def roll_conv_varlen_states_ref(conv_varlen_states, cu_seqlens): - """Reference: roll each [D, W] slice by (seqlen % W) positions.""" - B, D, W = conv_varlen_states.shape - out = torch.empty_like(conv_varlen_states) - for b in range(B): - seqlen = (cu_seqlens[b + 1] - cu_seqlens[b]).item() - shift = seqlen % W - for d in range(D): - for w in range(W): - src = (w - shift + W) % W - out[b, d, w] = conv_varlen_states[b, d, src] - return out - - -def gather_conv_state_ref(conv_state, batch_indices, cache_seqlens, d_conv): - """Reference: read last (d_conv-1) elements from circular buffer.""" - B = batch_indices.shape[0] - D = conv_state.shape[1] - state_len = conv_state.shape[2] - out = torch.zeros((B, D, d_conv - 1), device=conv_state.device, dtype=conv_state.dtype) - for b in range(B): - req_idx = batch_indices[b].item() - if req_idx < 0: - continue - seq_len = cache_seqlens[b].item() - for d in range(D): - for w in range(d_conv - 1): - val = seq_len - d_conv + 1 + w - if val < 0: - continue - idx = (val + state_len) % state_len - out[b, d, w] = conv_state[req_idx, d, idx] - return out - - -def scatter_conv_state_ref(conv_state, xBC, batch_indices, cache_seqlens): - """Reference: write newest chunk into circular buffer.""" - state_len = conv_state.shape[2] - chunk_len = xBC.shape[-1] - copy_len = min(chunk_len, state_len) - xBC_tail = xBC[..., -copy_len:] - B, D, _ = xBC_tail.shape - for b in range(B): - req_idx = batch_indices[b].item() - if req_idx < 0: - continue - seq_len = cache_seqlens[b].item() - for d in range(D): - for w in range(copy_len): - idx = (seq_len + chunk_len - copy_len + w) % state_len - conv_state[req_idx, d, idx] = xBC_tail[b, d, w] - - def causal_conv1d_update_ref(x, conv_state, weight, bias, silu_activation): """Reference: linear (non-circular) causal conv1d update.""" batch, seq_len, dim = x.shape @@ -98,142 +40,6 @@ def causal_conv1d_update_ref(x, conv_state, weight, bias, silu_activation): # ---------------------- Tests ---------------------- # -@pytest.mark.internal -class TestRollConvVarlenStates: - - def setup_method(self, method): - _requires_cuda() - - @pytest.mark.parametrize("B,D,W", [(1, 4, 4), (3, 8, 4), (2, 16, 3)]) - def test_matches_reference(self, B, D, W): - torch.manual_seed(42) - conv_states = torch.randn(B, D, W, device="cuda", dtype=torch.float32) - seqlens = torch.randint(1, 20, (B,), device="cuda", dtype=torch.int32) - cu_seqlens = torch.zeros(B + 1, device="cuda", dtype=torch.int32) - cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) - - result = roll_conv_varlen_states(conv_states, cu_seqlens) - expected = roll_conv_varlen_states_ref(conv_states, cu_seqlens) - - torch.testing.assert_close(result, expected) - - def test_zero_shift(self): - """When all seqlens are multiples of W, no rolling should occur.""" - B, D, W = 2, 4, 4 - conv_states = torch.randn(B, D, W, device="cuda", dtype=torch.float32) - cu_seqlens = torch.tensor([0, W, 2 * W], device="cuda", dtype=torch.int32) - - result = roll_conv_varlen_states(conv_states, cu_seqlens) - torch.testing.assert_close(result, conv_states) - - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) - def test_dtypes(self, dtype): - B, D, W = 2, 8, 4 - conv_states = torch.randn(B, D, W, device="cuda", dtype=dtype) - cu_seqlens = torch.tensor([0, 3, 7], device="cuda", dtype=torch.int32) - - result = roll_conv_varlen_states(conv_states, cu_seqlens) - expected = roll_conv_varlen_states_ref(conv_states, cu_seqlens) - - torch.testing.assert_close(result, expected) - - -@pytest.mark.internal -class TestGatherConvState: - - def setup_method(self, method): - _requires_cuda() - - @pytest.mark.parametrize("d_conv", [2, 3, 4]) - def test_matches_reference(self, d_conv): - torch.manual_seed(42) - B, D, state_len = 3, 8, 16 - conv_state = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) - batch_indices = torch.arange(B, device="cuda", dtype=torch.int32) - cache_seqlens = torch.randint( - d_conv, state_len + 10, (B,), device="cuda", dtype=torch.int32 - ) - - result = gather_conv_state(conv_state, batch_indices, cache_seqlens, d_conv) - expected = gather_conv_state_ref(conv_state, batch_indices, cache_seqlens, d_conv) - - torch.testing.assert_close(result, expected) - - def test_negative_batch_index_zeros_output(self): - B, D, state_len, d_conv = 2, 4, 8, 4 - conv_state = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) - batch_indices = torch.tensor([-1, 0], device="cuda", dtype=torch.int32) - cache_seqlens = torch.tensor([5, 5], device="cuda", dtype=torch.int32) - - result = gather_conv_state(conv_state, batch_indices, cache_seqlens, d_conv) - - # First batch should be all zeros due to negative index - torch.testing.assert_close( - result[0], torch.zeros(D, d_conv - 1, device="cuda", dtype=torch.float32) - ) - - def test_small_seqlen(self): - """When seq_len < d_conv - 1, early positions should be zero-padded.""" - B, D, state_len, d_conv = 1, 4, 8, 4 - conv_state = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) - batch_indices = torch.tensor([0], device="cuda", dtype=torch.int32) - cache_seqlens = torch.tensor([1], device="cuda", dtype=torch.int32) - - result = gather_conv_state(conv_state, batch_indices, cache_seqlens, d_conv) - expected = gather_conv_state_ref(conv_state, batch_indices, cache_seqlens, d_conv) - - torch.testing.assert_close(result, expected) - - -@pytest.mark.internal -class TestScatterConvState: - - def setup_method(self, method): - _requires_cuda() - - @pytest.mark.parametrize("chunk_len", [4, 8, 20]) - def test_matches_reference(self, chunk_len): - torch.manual_seed(42) - B, D, state_len = 3, 8, 16 - conv_state_triton = torch.zeros(B, D, state_len, device="cuda", dtype=torch.float32) - conv_state_ref = conv_state_triton.clone() - xBC = torch.randn(B, D, chunk_len, device="cuda", dtype=torch.float32) - batch_indices = torch.arange(B, device="cuda", dtype=torch.int32) - cache_seqlens = torch.randint(0, 20, (B,), device="cuda", dtype=torch.int32) - - scatter_conv_state(conv_state_triton, xBC, batch_indices, cache_seqlens) - scatter_conv_state_ref(conv_state_ref, xBC, batch_indices, cache_seqlens) - - torch.testing.assert_close(conv_state_triton, conv_state_ref) - - def test_negative_batch_index_noop(self): - B, D, state_len, chunk_len = 2, 4, 8, 4 - conv_state = torch.zeros(B, D, state_len, device="cuda", dtype=torch.float32) - conv_state_orig = conv_state.clone() - xBC = torch.randn(2, D, chunk_len, device="cuda", dtype=torch.float32) - batch_indices = torch.tensor([-1, -1], device="cuda", dtype=torch.int32) - cache_seqlens = torch.tensor([0, 0], device="cuda", dtype=torch.int32) - - scatter_conv_state(conv_state, xBC, batch_indices, cache_seqlens) - - torch.testing.assert_close(conv_state, conv_state_orig) - - def test_chunk_larger_than_state(self): - """When chunk_len > state_len, only last state_len tokens should be written.""" - B, D, state_len = 1, 4, 4 - chunk_len = 10 - conv_state = torch.zeros(B, D, state_len, device="cuda", dtype=torch.float32) - conv_state_ref = conv_state.clone() - xBC = torch.randn(B, D, chunk_len, device="cuda", dtype=torch.float32) - batch_indices = torch.tensor([0], device="cuda", dtype=torch.int32) - cache_seqlens = torch.tensor([0], device="cuda", dtype=torch.int32) - - scatter_conv_state(conv_state, xBC, batch_indices, cache_seqlens) - scatter_conv_state_ref(conv_state_ref, xBC, batch_indices, cache_seqlens) - - torch.testing.assert_close(conv_state, conv_state_ref) - - @pytest.mark.internal class TestCausalConv1dUpdate: @@ -255,7 +61,6 @@ def test_linear_no_bias(self, width): weight, bias=None, silu_activation=False, - cache_seqlens=None, conv_state_indices=None, ) expected = causal_conv1d_update_ref( @@ -281,7 +86,6 @@ def test_linear_with_bias(self, width): weight, bias=bias, silu_activation=False, - cache_seqlens=None, conv_state_indices=None, ) expected = causal_conv1d_update_ref( @@ -306,7 +110,6 @@ def test_linear_with_silu(self, width): weight, bias=bias, silu_activation="silu", - cache_seqlens=None, conv_state_indices=None, ) expected = causal_conv1d_update_ref( @@ -329,7 +132,6 @@ def test_2d_input(self): weight, bias=None, silu_activation=False, - cache_seqlens=None, conv_state_indices=None, ) @@ -355,7 +157,6 @@ def test_conv_state_indices(self): weight, bias=None, silu_activation=False, - cache_seqlens=None, conv_state_indices=state_indices, ) @@ -367,7 +168,6 @@ def test_conv_state_indices(self): weight, bias=None, silu_activation=False, - cache_seqlens=None, conv_state_indices=None, ) @@ -388,7 +188,6 @@ def test_negative_state_index_zeros_output(self): weight, bias=None, silu_activation=False, - cache_seqlens=None, conv_state_indices=state_indices, ) @@ -409,10 +208,78 @@ def test_half_precision(self, dtype): weight, bias=None, silu_activation=False, - cache_seqlens=None, conv_state_indices=None, ) assert result.dtype == dtype assert result.shape == (B, seq_len, D) assert torch.isfinite(result).all() + + @pytest.mark.parametrize("width", [2, 3, 4]) + def test_intermediate_state(self, width): + """Test that intermediate conv states are correctly stored at each sequence step.""" + torch.manual_seed(42) + B, seq_len, D, state_len = 2, 4, 64, 8 + x = torch.randn(B, seq_len, D, device="cuda", dtype=torch.float32) + conv_state = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + + # Allocate intermediate state buffer: (B, seq_len, D, state_len) + int_states = torch.zeros(B, seq_len, D, state_len, device="cuda", dtype=torch.float32) + + # Run with intermediate state recording + conv_state_copy = conv_state.clone() + result = causal_conv1d_update( + x, + conv_state_copy, + weight, + bias=None, + silu_activation=False, + conv_state_indices=None, + intermediate_conv_states=int_states, + ) + + # Verify by running step-by-step and checking each intermediate + conv_state_ref = conv_state.clone() + for s in range(seq_len): + conv_state_ref[:, :, :-1] = conv_state_ref[:, :, 1:].clone() + conv_state_ref[:, :, -1] = x[:, s, :] + torch.testing.assert_close( + int_states[:, s, :, :], conv_state_ref, atol=1e-5, rtol=1e-5 + ) + + def test_intermediate_state_with_indices(self): + """Test intermediate states work correctly with conv_state_indices mapping.""" + torch.manual_seed(42) + B, seq_len, D, state_len, width = 2, 3, 64, 8, 4 + num_states = 4 + x = torch.randn(B, seq_len, D, device="cuda", dtype=torch.float32) + conv_state = torch.randn(num_states, D, state_len, device="cuda", dtype=torch.float32) + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + state_indices = torch.tensor([2, 0], device="cuda", dtype=torch.int32) + + # Intermediate states are indexed by state_batch_coord (i.e., req index, not batch index) + int_states = torch.zeros( + num_states, seq_len, D, state_len, device="cuda", dtype=torch.float32 + ) + + conv_state_copy = conv_state.clone() + result = causal_conv1d_update( + x, + conv_state_copy, + weight, + bias=None, + silu_activation=False, + conv_state_indices=state_indices, + intermediate_conv_states=int_states, + ) + + # The final intermediate state at last seq step should match the final conv_state + for b_idx in range(B): + req_idx = state_indices[b_idx].item() + torch.testing.assert_close( + int_states[req_idx, seq_len - 1, :, :], + conv_state_copy[req_idx, :, :], + atol=1e-5, + rtol=1e-5, + ) From 47195a1cc964eee48b251e5226dc7680d0ba9ac0 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 01:32:54 -0700 Subject: [PATCH 54/76] Linting Signed-off-by: Keshav Santhanam --- .../inference/contexts/dynamic_context.py | 3 +- .../text_generation_controller.py | 12 +++--- megatron/core/models/gpt/gpt_model.py | 4 +- megatron/core/models/mamba/mamba_model.py | 5 +-- tests/unit_tests/inference/test_stop_words.py | 4 +- .../ssm/test_causal_conv1d_triton.py | 39 +++---------------- 6 files changed, 15 insertions(+), 52 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index b589e1db4f4..7ee13494e40 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -366,8 +366,7 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC # Add memory for intermediate conv and SSM states intermediate_memory_per_request = ( math.prod(self.mamba_conv_states_shape) * self.mamba_conv_states_dtype.itemsize - + math.prod(self.mamba_ssm_states_shape) - * self.mamba_ssm_states_dtype.itemsize + + math.prod(self.mamba_ssm_states_shape) * self.mamba_ssm_states_dtype.itemsize ) intermediate_memory_per_request *= self.num_mamba_layers intermediate_memory_per_request *= self.num_speculative_tokens + 1 diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 45fca84d6da..24bd876fec8 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -657,15 +657,11 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) if self.model_is_pipeline_parallel: if context.config.materialize_only_last_token_logits: if self.num_speculative_tokens > 0: - active_slice = slice( - context.paused_request_count, context.total_request_count - ) + active_slice = slice(context.paused_request_count, context.total_request_count) request_in_prefill = context.request_in_prefill_status_tensor[active_slice] num_prefill = (request_in_prefill == 1).sum().item() num_decode = active_request_count - num_prefill - logits_seq_len = ( - num_decode * (self.num_speculative_tokens + 1) + num_prefill - ) + logits_seq_len = num_decode * (self.num_speculative_tokens + 1) + num_prefill else: logits_seq_len = active_request_count else: @@ -1048,7 +1044,9 @@ def _dynamic_step_sample_logits_and_verify_tokens( if context.config.materialize_only_last_token_logits: # Logits are already pre-filtered to required positions by the model forward. required_logits = logits.squeeze(0) # Shape [num_required, vocab_size] - required_mtp_logits = mtp_logits # Shape [num_speculative_tokens, num_required, vocab_size] + required_mtp_logits = ( + mtp_logits # Shape [num_speculative_tokens, num_required, vocab_size] + ) else: required_logits = logits.squeeze(0)[ required_logit_indices, : diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index eedc2c25ade..3f27f9e5792 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -666,9 +666,7 @@ def _postprocess( reshaped ).unsqueeze(1) else: - hidden_states = inference_context.last_token_logits( - reshaped - ).unsqueeze(1) + hidden_states = inference_context.last_token_logits(reshaped).unsqueeze(1) logits, _ = self.output_layer( hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output diff --git a/megatron/core/models/mamba/mamba_model.py b/megatron/core/models/mamba/mamba_model.py index 4ef1aeb1695..007001ea6b0 100644 --- a/megatron/core/models/mamba/mamba_model.py +++ b/megatron/core/models/mamba/mamba_model.py @@ -3,7 +3,6 @@ import logging from typing import Literal, Optional -import torch from torch import Tensor from megatron.core import tensor_parallel @@ -455,9 +454,7 @@ def forward( reshaped ).unsqueeze(1) else: - hidden_states = inference_context.last_token_logits( - reshaped - ).unsqueeze(1) + hidden_states = inference_context.last_token_logits(reshaped).unsqueeze(1) logits, _ = self.output_layer( hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output diff --git a/tests/unit_tests/inference/test_stop_words.py b/tests/unit_tests/inference/test_stop_words.py index 57a200fddc8..148083cf093 100644 --- a/tests/unit_tests/inference/test_stop_words.py +++ b/tests/unit_tests/inference/test_stop_words.py @@ -192,9 +192,7 @@ def test_speculative_multi_token_stop_word_in_middle_truncates(self): # Generated: [100, 200, 300, 400, 500], stop word is [200, 300] # Stop word ends at -2, so tokens [400, 500] should be truncated request = MockDynamicInferenceRequest( - request_id=1, - generated_tokens=[100, 200, 300, 400, 500], - stop_word_ids=[[200, 300]], + request_id=1, generated_tokens=[100, 200, 300, 400, 500], stop_word_ids=[[200, 300]] ) assert ( self._check_stop_words_for_request_post_append(request, num_speculative_tokens=4) diff --git a/tests/unit_tests/ssm/test_causal_conv1d_triton.py b/tests/unit_tests/ssm/test_causal_conv1d_triton.py index f937554798a..3015f5ed989 100644 --- a/tests/unit_tests/ssm/test_causal_conv1d_triton.py +++ b/tests/unit_tests/ssm/test_causal_conv1d_triton.py @@ -56,12 +56,7 @@ def test_linear_no_bias(self, width): weight = torch.randn(D, width, device="cuda", dtype=torch.float32) result = causal_conv1d_update( - x, - conv_state_triton, - weight, - bias=None, - silu_activation=False, - conv_state_indices=None, + x, conv_state_triton, weight, bias=None, silu_activation=False, conv_state_indices=None ) expected = causal_conv1d_update_ref( x, conv_state_ref, weight, bias=None, silu_activation=False @@ -81,12 +76,7 @@ def test_linear_with_bias(self, width): bias = torch.randn(D, device="cuda", dtype=torch.float32) result = causal_conv1d_update( - x, - conv_state_triton, - weight, - bias=bias, - silu_activation=False, - conv_state_indices=None, + x, conv_state_triton, weight, bias=bias, silu_activation=False, conv_state_indices=None ) expected = causal_conv1d_update_ref( x, conv_state_ref, weight, bias=bias, silu_activation=False @@ -105,12 +95,7 @@ def test_linear_with_silu(self, width): bias = torch.randn(D, device="cuda", dtype=torch.float32) result = causal_conv1d_update( - x, - conv_state_triton, - weight, - bias=bias, - silu_activation="silu", - conv_state_indices=None, + x, conv_state_triton, weight, bias=bias, silu_activation="silu", conv_state_indices=None ) expected = causal_conv1d_update_ref( x, conv_state_ref, weight, bias=bias, silu_activation=True @@ -127,12 +112,7 @@ def test_2d_input(self): weight = torch.randn(D, width, device="cuda", dtype=torch.float32) result = causal_conv1d_update( - x, - conv_state, - weight, - bias=None, - silu_activation=False, - conv_state_indices=None, + x, conv_state, weight, bias=None, silu_activation=False, conv_state_indices=None ) assert result.dim() == 2 @@ -203,12 +183,7 @@ def test_half_precision(self, dtype): weight = torch.randn(D, width, device="cuda", dtype=dtype) result = causal_conv1d_update( - x, - conv_state, - weight, - bias=None, - silu_activation=False, - conv_state_indices=None, + x, conv_state, weight, bias=None, silu_activation=False, conv_state_indices=None ) assert result.dtype == dtype @@ -244,9 +219,7 @@ def test_intermediate_state(self, width): for s in range(seq_len): conv_state_ref[:, :, :-1] = conv_state_ref[:, :, 1:].clone() conv_state_ref[:, :, -1] = x[:, s, :] - torch.testing.assert_close( - int_states[:, s, :, :], conv_state_ref, atol=1e-5, rtol=1e-5 - ) + torch.testing.assert_close(int_states[:, s, :, :], conv_state_ref, atol=1e-5, rtol=1e-5) def test_intermediate_state_with_indices(self): """Test intermediate states work correctly with conv_state_indices mapping.""" From 6d7da581cda1d682b7eaa6dfc9ddc3c684a1c532 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 01:39:18 -0700 Subject: [PATCH 55/76] Remove references to cache_seqlens Signed-off-by: Keshav Santhanam --- .../attention_context/mamba_metadata.py | 23 ----- .../inference/contexts/dynamic_context.py | 1 - .../attention_metadata/test_mamba_metadata.py | 91 ------------------- 3 files changed, 115 deletions(-) diff --git a/megatron/core/inference/contexts/attention_context/mamba_metadata.py b/megatron/core/inference/contexts/attention_context/mamba_metadata.py index 64b9ef35f13..75174db7ad1 100644 --- a/megatron/core/inference/contexts/attention_context/mamba_metadata.py +++ b/megatron/core/inference/contexts/attention_context/mamba_metadata.py @@ -41,16 +41,6 @@ def __init__(self, max_requests: int, max_tokens: int): (1,), -1, dtype=torch.int32, device=self.device ) - # Cache sequence lengths for decode requests - self._cache_seqlens_decode_buffer = torch.zeros( - (self.max_requests,), dtype=torch.int32, device=self.device - ) - - # Cache sequence lengths for the chunked prefill request - self._cache_seqlens_chunked_prefill_buffer = torch.zeros( - (1,), dtype=torch.int32, device=self.device - ) - # Map from token id to request id for active prefill requests self._seq_idx_buffer = torch.full( (1, self.max_tokens), -1, dtype=torch.int32, device=self.device @@ -105,15 +95,12 @@ def reset_varlen_metadata(self) -> None: self.seq_idx = None self.device_decode_prefill = None self.device_chunked_prefill = None - self.cache_seqlens_decode = None - self.cache_seqlens_chunked_prefill = None def update( self, active_mamba_indices: torch.Tensor, token_to_request_idx: torch.Tensor, cu_seqlens: torch.Tensor, - request_kv_length_offsets: torch.Tensor, batch_dimensions: InferenceBatchDimensions, padded_batch_dimensions: InferenceBatchDimensions, enable_chunked_prefill: bool, @@ -185,14 +172,9 @@ def update( self._batch_indices_decode_buffer[:real_decode_count].copy_( active_mamba_indices[:real_decode_count] ) - self._cache_seqlens_decode_buffer[:real_decode_count].copy_( - request_kv_length_offsets[:real_decode_count] - ) if padded_decode_count > real_decode_count: self._batch_indices_decode_buffer[real_decode_count:padded_decode_count] = -1 - self._cache_seqlens_decode_buffer[real_decode_count:padded_decode_count] = 0 self.batch_indices_decode = self._batch_indices_decode_buffer[:padded_decode_count] - self.cache_seqlens_decode = self._cache_seqlens_decode_buffer[:padded_decode_count] # Determine if we have a chunked prefill request and adjust counts for regular prefill regular_prefill_count = real_prefill_count @@ -207,11 +189,6 @@ def update( self._batch_indices_chunked_prefill_buffer[0] = active_mamba_indices[chunked_req_idx] self.batch_indices_chunked_prefill = self._batch_indices_chunked_prefill_buffer - # Update chunked prefill cache seqlen - self._cache_seqlens_chunked_prefill_buffer[0] = request_kv_length_offsets[ - chunked_req_idx - ] - self.cache_seqlens_chunked_prefill = self._cache_seqlens_chunked_prefill_buffer else: self.batch_indices_chunked_prefill = None diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 7ee13494e40..84648efdbf0 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1534,7 +1534,6 @@ def initialize_attention_state( active_mamba_indices_view, token_to_request_idx_view, cu_seqlens, - request_kv_length_offsets_view, batch_dimensions=attn_dimensions, padded_batch_dimensions=self.padded_batch_dimensions, enable_chunked_prefill=self.is_chunked_prefill_enabled(), diff --git a/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py b/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py index fce9518caeb..21459cece7d 100644 --- a/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py +++ b/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py @@ -31,7 +31,6 @@ def _run_update_test( num_decode_requests: int, padded_dims: InferenceBatchDimensions, enable_chunked_prefill: bool, - request_kv_length_offsets: list[int] | None = None, ): """ Helper to construct inputs and run update(). @@ -43,7 +42,6 @@ def _run_update_test( num_decode_requests: Number of requests in req_seq_lengths that are in the decode phase. padded_dims: The padded batch dimensions to test against. enable_chunked_prefill: Whether chunked prefill is enabled. - request_kv_length_offsets: KV cache length offsets per request. Defaults to zeros. """ num_active_requests = len(req_seq_lengths) total_tokens = sum(req_seq_lengths) @@ -55,12 +53,6 @@ def _run_update_test( decode_req_count=num_decode_requests, ) - if request_kv_length_offsets is None: - request_kv_length_offsets = [0] * num_active_requests - kv_length_offsets_tensor = torch.tensor( - request_kv_length_offsets, dtype=torch.int32, device=metadata.device - ) - # Assuming 1:1 mapping (req_id i -> slot i) active_mamba_indices = torch.arange( num_active_requests, dtype=torch.int32, device=metadata.device @@ -82,7 +74,6 @@ def _run_update_test( active_mamba_indices=active_mamba_indices, token_to_request_idx=token_to_req_tensor, cu_seqlens=cu_seqlens_tensor, - request_kv_length_offsets=kv_length_offsets_tensor, batch_dimensions=real_dims, padded_batch_dimensions=padded_dims, enable_chunked_prefill=enable_chunked_prefill, @@ -99,7 +90,6 @@ def test_update_decode_only_exact_match(self, metadata_context): """Test simple decode only case where real dims match padded dims.""" seq_lengths = [1, 1, 1, 1] # 4 requests num_decode = 4 - kv_offsets = [5, 10, 15, 20] padded_dims = InferenceBatchDimensions( token_count=4, prefill_req_count=0, decode_req_count=4 ) @@ -110,20 +100,13 @@ def test_update_decode_only_exact_match(self, metadata_context): num_decode, padded_dims, enable_chunked_prefill=False, - request_kv_length_offsets=kv_offsets, ) expected_decode = torch.arange(4, dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_decode, expected_decode) - expected_cache_seqlens = torch.tensor( - kv_offsets, dtype=torch.int32, device=metadata_context.device - ) - assert torch.equal(metadata_context.cache_seqlens_decode, expected_cache_seqlens) - assert metadata_context.batch_indices_prefill is None assert metadata_context.batch_indices_chunked_prefill is None - assert metadata_context.cache_seqlens_chunked_prefill is None assert metadata_context.device_decode_prefill is None assert metadata_context.cu_seqlens is None assert metadata_context.seq_idx is None @@ -133,7 +116,6 @@ def test_update_decode_only_padded(self, metadata_context): """Test decode only with padding (e.g. using CUDA graphs bucket).""" seq_lengths = [1, 1] # 2 requests num_decode = 2 - kv_offsets = [7, 12] # Padding to 4 requests padded_dims = InferenceBatchDimensions( token_count=4, prefill_req_count=0, decode_req_count=4 @@ -145,7 +127,6 @@ def test_update_decode_only_padded(self, metadata_context): num_decode, padded_dims, enable_chunked_prefill=False, - request_kv_length_offsets=kv_offsets, ) expected_decode = torch.tensor( @@ -153,14 +134,8 @@ def test_update_decode_only_padded(self, metadata_context): ) assert torch.equal(metadata_context.batch_indices_decode, expected_decode) - expected_cache_seqlens = torch.tensor( - [7, 12, 0, 0], dtype=torch.int32, device=metadata_context.device - ) - assert torch.equal(metadata_context.cache_seqlens_decode, expected_cache_seqlens) - assert metadata_context.batch_indices_prefill is None assert metadata_context.batch_indices_chunked_prefill is None - assert metadata_context.cache_seqlens_chunked_prefill is None assert metadata_context.device_decode_prefill is None @pytest.mark.internal @@ -168,7 +143,6 @@ def test_update_chunked_enabled_no_prefill_reqs(self, metadata_context): """Test edge case: Chunked prefill enabled, but only decode requests exist.""" seq_lengths = [1, 1] num_decode = 2 - kv_offsets = [3, 8] padded_dims = InferenceBatchDimensions( token_count=2, prefill_req_count=0, decode_req_count=2 ) @@ -179,20 +153,13 @@ def test_update_chunked_enabled_no_prefill_reqs(self, metadata_context): num_decode, padded_dims, enable_chunked_prefill=True, - request_kv_length_offsets=kv_offsets, ) # Should behave exactly like decode-only (chunked logic skipped if real_prefill == 0) expected_decode = torch.tensor([0, 1], dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_decode, expected_decode) - expected_cache_seqlens = torch.tensor( - [3, 8], dtype=torch.int32, device=metadata_context.device - ) - assert torch.equal(metadata_context.cache_seqlens_decode, expected_cache_seqlens) - assert metadata_context.batch_indices_chunked_prefill is None - assert metadata_context.cache_seqlens_chunked_prefill is None assert metadata_context.batch_indices_prefill is None assert metadata_context.cu_seqlens is None assert metadata_context.seq_idx is None @@ -228,9 +195,7 @@ def test_update_prefill_only_exact(self, metadata_context): assert torch.equal(metadata_context.seq_idx, expected_seq_idx) assert metadata_context.batch_indices_decode is None - assert metadata_context.cache_seqlens_decode is None assert metadata_context.batch_indices_chunked_prefill is None - assert metadata_context.cache_seqlens_chunked_prefill is None assert metadata_context.device_decode_prefill is None @pytest.mark.internal @@ -264,9 +229,7 @@ def test_update_prefill_only_padded(self, metadata_context): assert torch.equal(metadata_context.seq_idx, expected_seq_idx) assert metadata_context.batch_indices_decode is None - assert metadata_context.cache_seqlens_decode is None assert metadata_context.batch_indices_chunked_prefill is None - assert metadata_context.cache_seqlens_chunked_prefill is None assert metadata_context.device_decode_prefill is None # ------------------------------------------------------------------------- @@ -279,7 +242,6 @@ def test_update_mixed_batch_exact(self, metadata_context): # 2 decode (len 1), 2 prefill (len 10, 20) seq_lengths = [1, 1, 10, 20] num_decode = 2 - kv_offsets = [5, 10, 0, 0] padded_dims = InferenceBatchDimensions( token_count=32, prefill_req_count=2, decode_req_count=2 ) @@ -290,18 +252,11 @@ def test_update_mixed_batch_exact(self, metadata_context): num_decode, padded_dims, enable_chunked_prefill=False, - request_kv_length_offsets=kv_offsets, ) expected_decode = torch.tensor([0, 1], dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_decode, expected_decode) - expected_cache_seqlens = torch.tensor( - [5, 10], dtype=torch.int32, device=metadata_context.device - ) - assert torch.equal(metadata_context.cache_seqlens_decode, expected_cache_seqlens) - assert metadata_context.cache_seqlens_chunked_prefill is None - expected_prefill = torch.tensor([2, 3], dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_prefill, expected_prefill) @@ -331,7 +286,6 @@ def test_update_padded_prefill_and_decode(self, metadata_context): # Real: 1 decode, 1 prefill. seq_lengths = [1, 10] num_decode = 1 - kv_offsets = [25, 0] # Padded: 4 decode, 4 prefill. Total tokens 32. padded_dims = InferenceBatchDimensions( @@ -344,7 +298,6 @@ def test_update_padded_prefill_and_decode(self, metadata_context): num_decode, padded_dims, enable_chunked_prefill=False, - request_kv_length_offsets=kv_offsets, ) expected_decode = torch.tensor( @@ -352,12 +305,6 @@ def test_update_padded_prefill_and_decode(self, metadata_context): ) assert torch.equal(metadata_context.batch_indices_decode, expected_decode) - expected_cache_seqlens = torch.tensor( - [25, 0, 0, 0], dtype=torch.int32, device=metadata_context.device - ) - assert torch.equal(metadata_context.cache_seqlens_decode, expected_cache_seqlens) - assert metadata_context.cache_seqlens_chunked_prefill is None - expected_prefill = torch.tensor( [1, -1, -1, -1], dtype=torch.int32, device=metadata_context.device ) @@ -389,7 +336,6 @@ def test_update_chunked_prefill_mixed_exact(self, metadata_context): # 1 decode, 1 chunked prefill (len 50), 1 regular prefill (len 10) seq_lengths = [1, 50, 10] num_decode = 1 - kv_offsets = [9, 100, 0] # Exact dimensions padded_dims = InferenceBatchDimensions( @@ -402,7 +348,6 @@ def test_update_chunked_prefill_mixed_exact(self, metadata_context): num_decode, padded_dims, enable_chunked_prefill=True, - request_kv_length_offsets=kv_offsets, ) expected_device_chunked_prefill = torch.tensor( @@ -412,18 +357,6 @@ def test_update_chunked_prefill_mixed_exact(self, metadata_context): assert metadata_context.batch_indices_chunked_prefill[0] == 1 - expected_cache_seqlens_decode = torch.tensor( - [9], dtype=torch.int32, device=metadata_context.device - ) - assert torch.equal(metadata_context.cache_seqlens_decode, expected_cache_seqlens_decode) - - expected_cache_seqlens_chunked = torch.tensor( - [100], dtype=torch.int32, device=metadata_context.device - ) - assert torch.equal( - metadata_context.cache_seqlens_chunked_prefill, expected_cache_seqlens_chunked - ) - expected_prefill = torch.tensor([2, -1], dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_prefill, expected_prefill) @@ -447,7 +380,6 @@ def test_update_chunked_prefill_mixed_padded(self, metadata_context): # 2 decode, 1 chunked prefill (len 50), 1 regular prefill (len 10) seq_lengths = [1, 1, 50, 10] num_decode = 2 - kv_offsets = [4, 6, 200, 0] padded_dims = InferenceBatchDimensions( token_count=62, prefill_req_count=2, decode_req_count=2 ) @@ -458,17 +390,11 @@ def test_update_chunked_prefill_mixed_padded(self, metadata_context): num_decode, padded_dims, enable_chunked_prefill=True, - request_kv_length_offsets=kv_offsets, ) expected_decode = torch.tensor([0, 1], dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_decode, expected_decode) - expected_cache_seqlens_decode = torch.tensor( - [4, 6], dtype=torch.int32, device=metadata_context.device - ) - assert torch.equal(metadata_context.cache_seqlens_decode, expected_cache_seqlens_decode) - expected_device_chunked_prefill = torch.tensor( [50, 10], dtype=torch.int32, device=metadata_context.device ) @@ -476,13 +402,6 @@ def test_update_chunked_prefill_mixed_padded(self, metadata_context): assert metadata_context.batch_indices_chunked_prefill[0] == 2 - expected_cache_seqlens_chunked = torch.tensor( - [200], dtype=torch.int32, device=metadata_context.device - ) - assert torch.equal( - metadata_context.cache_seqlens_chunked_prefill, expected_cache_seqlens_chunked - ) - expected_prefill = torch.tensor([3, -1], dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_prefill, expected_prefill) @@ -506,7 +425,6 @@ def test_update_chunked_only_padded(self, metadata_context): # 1 chunked prefill request. seq_lengths = [100] num_decode = 0 - kv_offsets = [50] padded_dims = InferenceBatchDimensions( token_count=128, prefill_req_count=2, decode_req_count=0 @@ -518,21 +436,12 @@ def test_update_chunked_only_padded(self, metadata_context): num_decode, padded_dims, enable_chunked_prefill=True, - request_kv_length_offsets=kv_offsets, ) assert metadata_context.batch_indices_decode is None - assert metadata_context.cache_seqlens_decode is None assert metadata_context.batch_indices_chunked_prefill[0] == 0 - expected_cache_seqlens_chunked = torch.tensor( - [50], dtype=torch.int32, device=metadata_context.device - ) - assert torch.equal( - metadata_context.cache_seqlens_chunked_prefill, expected_cache_seqlens_chunked - ) - expected_prefill = torch.tensor([-1, -1], dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_prefill, expected_prefill) From b65dbbecaf4264fdd335d4797e248582291233f7 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 02:00:35 -0700 Subject: [PATCH 56/76] Linting and misc review comments Signed-off-by: Keshav Santhanam --- .../core/inference/batch_dimensions_utils.py | 5 +- .../attention_context/mamba_metadata.py | 1 - .../inference/contexts/dynamic_context.py | 1 - .../text_generation_controller.py | 6 +-- .../attention_metadata/test_mamba_metadata.py | 48 ++++--------------- 5 files changed, 14 insertions(+), 47 deletions(-) diff --git a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py index e7969bf5e88..c6bf8c79e78 100644 --- a/megatron/core/inference/batch_dimensions_utils.py +++ b/megatron/core/inference/batch_dimensions_utils.py @@ -383,8 +383,9 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int cuda_graph_max_tokens = max_tokens assert cuda_graph_max_tokens == max_requests * (num_speculative_tokens + 1), ( - f"cuda_graph_max_tokens ({cuda_graph_max_tokens}) must equal max_requests " - f"({max_requests}). This is required for correctly syncing EP ranks: " + f"cuda_graph_max_tokens ({cuda_graph_max_tokens}) must equal max_requests *" + f"(num_speculative_tokens + 1) ({max_requests * (num_speculative_tokens + 1)}). " + "This is required for correctly syncing EP ranks: " f"prefill and decode graph pools must have the same token count granularity." ) diff --git a/megatron/core/inference/contexts/attention_context/mamba_metadata.py b/megatron/core/inference/contexts/attention_context/mamba_metadata.py index 75174db7ad1..34a19cf0394 100644 --- a/megatron/core/inference/contexts/attention_context/mamba_metadata.py +++ b/megatron/core/inference/contexts/attention_context/mamba_metadata.py @@ -188,7 +188,6 @@ def update( # Update chunked prefill indices self._batch_indices_chunked_prefill_buffer[0] = active_mamba_indices[chunked_req_idx] self.batch_indices_chunked_prefill = self._batch_indices_chunked_prefill_buffer - else: self.batch_indices_chunked_prefill = None diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 84648efdbf0..bceec0d5ecf 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -2612,7 +2612,6 @@ def update_requests( old_offsets + num_generated_tokens ) % self.block_size_tokens - # ================================================================ self.active_token_count = active_request_count * num_generated_tokens sampled_tokens = next_tokens[self.paused_request_count : self.total_request_count] diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 24bd876fec8..adaf93fbdf1 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -1119,10 +1119,10 @@ def _dynamic_step_sample_logits(self, logits: Tensor): if context.config.materialize_only_last_token_logits: # When materialize_only_last_token_logits is true, last_token_logits is # already called in the forward pass of GPT. - required_token_indices = logits.squeeze(0) + required_token_logits = logits.squeeze(0) else: # todo : Should do verification here and get approrpiate las token logits - required_token_indices = context.last_token_logits(logits) + required_token_logits = context.last_token_logits(logits) if self._sampling_backend == "torch": # Concatenate the outputs once to prevent repeated small writes. @@ -1137,7 +1137,7 @@ def _dynamic_step_sample_logits(self, logits: Tensor): for indices, temp, top_k, top_p in self._torch_sampling_buckets: token_list.append( self._torch_sampling_func( - required_token_indices[indices, :], temp, top_k, top_p + required_token_logits[indices, :], temp, top_k, top_p ) ) indices_list.append(torch.tensor(indices)) diff --git a/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py b/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py index 21459cece7d..7e76ce4b7b0 100644 --- a/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py +++ b/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py @@ -95,11 +95,7 @@ def test_update_decode_only_exact_match(self, metadata_context): ) self._run_update_test( - metadata_context, - seq_lengths, - num_decode, - padded_dims, - enable_chunked_prefill=False, + metadata_context, seq_lengths, num_decode, padded_dims, enable_chunked_prefill=False ) expected_decode = torch.arange(4, dtype=torch.int32, device=metadata_context.device) @@ -122,11 +118,7 @@ def test_update_decode_only_padded(self, metadata_context): ) self._run_update_test( - metadata_context, - seq_lengths, - num_decode, - padded_dims, - enable_chunked_prefill=False, + metadata_context, seq_lengths, num_decode, padded_dims, enable_chunked_prefill=False ) expected_decode = torch.tensor( @@ -148,11 +140,7 @@ def test_update_chunked_enabled_no_prefill_reqs(self, metadata_context): ) self._run_update_test( - metadata_context, - seq_lengths, - num_decode, - padded_dims, - enable_chunked_prefill=True, + metadata_context, seq_lengths, num_decode, padded_dims, enable_chunked_prefill=True ) # Should behave exactly like decode-only (chunked logic skipped if real_prefill == 0) @@ -247,11 +235,7 @@ def test_update_mixed_batch_exact(self, metadata_context): ) self._run_update_test( - metadata_context, - seq_lengths, - num_decode, - padded_dims, - enable_chunked_prefill=False, + metadata_context, seq_lengths, num_decode, padded_dims, enable_chunked_prefill=False ) expected_decode = torch.tensor([0, 1], dtype=torch.int32, device=metadata_context.device) @@ -293,11 +277,7 @@ def test_update_padded_prefill_and_decode(self, metadata_context): ) self._run_update_test( - metadata_context, - seq_lengths, - num_decode, - padded_dims, - enable_chunked_prefill=False, + metadata_context, seq_lengths, num_decode, padded_dims, enable_chunked_prefill=False ) expected_decode = torch.tensor( @@ -343,11 +323,7 @@ def test_update_chunked_prefill_mixed_exact(self, metadata_context): ) self._run_update_test( - metadata_context, - seq_lengths, - num_decode, - padded_dims, - enable_chunked_prefill=True, + metadata_context, seq_lengths, num_decode, padded_dims, enable_chunked_prefill=True ) expected_device_chunked_prefill = torch.tensor( @@ -385,11 +361,7 @@ def test_update_chunked_prefill_mixed_padded(self, metadata_context): ) self._run_update_test( - metadata_context, - seq_lengths, - num_decode, - padded_dims, - enable_chunked_prefill=True, + metadata_context, seq_lengths, num_decode, padded_dims, enable_chunked_prefill=True ) expected_decode = torch.tensor([0, 1], dtype=torch.int32, device=metadata_context.device) @@ -431,11 +403,7 @@ def test_update_chunked_only_padded(self, metadata_context): ) self._run_update_test( - metadata_context, - seq_lengths, - num_decode, - padded_dims, - enable_chunked_prefill=True, + metadata_context, seq_lengths, num_decode, padded_dims, enable_chunked_prefill=True ) assert metadata_context.batch_indices_decode is None From 712824f37480cdde5ffe87f31a7cf6feed5b679b Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 02:05:25 -0700 Subject: [PATCH 57/76] Linting Signed-off-by: Keshav Santhanam --- .../text_generation_controllers/text_generation_controller.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index adaf93fbdf1..b53b9dbe8b6 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -1136,9 +1136,7 @@ def _dynamic_step_sample_logits(self, logits: Tensor): # [ [req at index 1, req at index 3, req at index 4] , t2, topk2, topp2] for indices, temp, top_k, top_p in self._torch_sampling_buckets: token_list.append( - self._torch_sampling_func( - required_token_logits[indices, :], temp, top_k, top_p - ) + self._torch_sampling_func(required_token_logits[indices, :], temp, top_k, top_p) ) indices_list.append(torch.tensor(indices)) From 61e16e5ca39fe89a8f2b905e81c93ea31e971c78 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 10:03:33 -0700 Subject: [PATCH 58/76] Revert materialize_only_last_token_logits changes Signed-off-by: Keshav Santhanam --- .../inference/contexts/dynamic_context.py | 39 ------------------ .../core/inference/engines/dynamic_engine.py | 9 ++++ .../text_generation_controller.py | 41 ++++--------------- megatron/core/models/gpt/gpt_model.py | 8 +--- megatron/core/models/mamba/mamba_model.py | 8 +--- megatron/inference/utils.py | 2 +- 6 files changed, 19 insertions(+), 88 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index bceec0d5ecf..e7a2ab3ed85 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1678,45 +1678,6 @@ def last_token_logits(self, logits: Tensor) -> Tensor: return last_token_logits - def speculative_required_logits(self, logits: Tensor) -> Tensor: - """Extract logits at positions required for speculative decoding. - - For decode requests, all tokens (base + speculative) are needed. - For prefill requests, only the last token is needed. - - Args: - logits (Tensor): Hidden states of shape [1, padded_active_token_count, ...]. - - Return: - (Tensor) Logits at required positions, shape - [num_decode * (num_spec + 1) + num_prefill, ...]. - """ - assert logits.size(0) == 1, f"logits.size(0) ({tuple(logits.shape)}) != 1" - assert logits.size(1) == self.padded_active_token_count, ( - f"logits.size(1) ({tuple(logits.shape)}) != " - f"padded_active_token_count ({self.padded_active_token_count})." - ) - - logits = logits.squeeze(0) - active_slice = slice(self.paused_request_count, self.total_request_count) - request_in_prefill = self.request_in_prefill_status_tensor[active_slice] - query_lengths = self.request_query_lengths[active_slice] - - num_prefill = (request_in_prefill == 1).sum().item() - num_decode = len(request_in_prefill) - num_prefill - num_speculative_tokens = self.config.num_speculative_tokens - - # All tokens for decode requests (they come first in the packed sequence). - decode_indices = torch.arange( - num_decode * (num_speculative_tokens + 1), device=logits.device - ) - - # Last token index for each prefill request. - prefill_indices = query_lengths.cumsum(dim=0)[request_in_prefill == 1] - 1 - - required_indices = torch.cat([decode_indices, prefill_indices]) - return logits[required_indices, :] - def _compute_prefix_match( self, req: DynamicInferenceRequest, chunk_length: int ) -> Tuple[list, int, int, int, int, int]: diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index d0a47088502..1c3bfe833c4 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -211,6 +211,9 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen assert ( self.num_speculative_tokens <= self.controller.num_mtp_heads ), f"Number of speculative tokens {self.num_speculative_tokens} must be less than or equal to number of MTP heads {self.controller.num_mtp_heads}" + assert not self.materialize_only_last_token_logits, ( + "materialize_only_last_token_logits must be False when num_speculative_tokens > 0" + ) self.track_paused_request_events = inference_config.track_paused_request_events self.track_generated_token_events = inference_config.track_generated_token_events @@ -996,6 +999,12 @@ def post_process_requests( if self.num_speculative_tokens > 0: accepted_tokens = list(filter(lambda tok: tok != -1, accepted_tokens_list)) + + # The order `accepted_tokens + tokens` is correct here. + # `accepted_tokens` contains the sequence of + # successfully verified draft tokens. `tokens` (from `sample`) is the + # brand new token generated by the target model based on that accepted prefix. + # Therefore, the newly sampled token must go at the end of the sequence. tokens = accepted_tokens + tokens request: DynamicInferenceRequest = self.get_request(request_id) diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index b53b9dbe8b6..4314d858d11 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -637,33 +637,13 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) ), f"MTP logits length mismatch. Expected mtp logits length {self.num_mtp_heads}, got {expected_mtp_logits_length}" mtp_logits = mtp_logits[: self.num_speculative_tokens] - if context.config.materialize_only_last_token_logits: - # Base logits are already filtered to required positions by the model. - # Filter MTP logits to match before concatenation. - active_slice = slice(context.paused_request_count, context.total_request_count) - request_in_prefill = context.request_in_prefill_status_tensor[active_slice] - query_lengths = context.request_query_lengths[active_slice] - num_prefill = (request_in_prefill == 1).sum().item() - num_decode = active_request_count - num_prefill - required_logit_indices = self._get_required_logit_indices( - request_in_prefill, query_lengths, num_decode, num_prefill, mtp_logits.device - ) - mtp_logits = mtp_logits[:, required_logit_indices, :] - logits = torch.cat( [logits, mtp_logits], dim=0 ) # [num_speculative_tokens + 1, seq_len_or_required, vocab_size] if self.model_is_pipeline_parallel: if context.config.materialize_only_last_token_logits: - if self.num_speculative_tokens > 0: - active_slice = slice(context.paused_request_count, context.total_request_count) - request_in_prefill = context.request_in_prefill_status_tensor[active_slice] - num_prefill = (request_in_prefill == 1).sum().item() - num_decode = active_request_count - num_prefill - logits_seq_len = num_decode * (self.num_speculative_tokens + 1) + num_prefill - else: - logits_seq_len = active_request_count + logits_seq_len = active_request_count else: logits_seq_len = input_ids.shape[1] logits_shape = [self.num_speculative_tokens + 1, logits_seq_len, self.vocab_size] @@ -1041,19 +1021,12 @@ def _dynamic_step_sample_logits_and_verify_tokens( logits.device, ) - if context.config.materialize_only_last_token_logits: - # Logits are already pre-filtered to required positions by the model forward. - required_logits = logits.squeeze(0) # Shape [num_required, vocab_size] - required_mtp_logits = ( - mtp_logits # Shape [num_speculative_tokens, num_required, vocab_size] - ) - else: - required_logits = logits.squeeze(0)[ - required_logit_indices, : - ] # Shape [num_required, vocab_size] - required_mtp_logits = mtp_logits[ - :, required_logit_indices, : - ] # Shape [num_speculative_tokens, num_required, vocab_size] + required_logits = logits.squeeze(0)[ + required_logit_indices, : + ] # Shape [num_required, vocab_size] + required_mtp_logits = mtp_logits[ + :, required_logit_indices, : + ] # Shape [num_speculative_tokens, num_required, vocab_size] # Sample tokens from logits and MTP logits. output_tokens, mtp_output_tokens, repeats = self._sample_speculative_logits( diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 3f27f9e5792..716362061c2 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -660,13 +660,7 @@ def _postprocess( # Reshape [S, B, H] (with B=1) to [1, S, H] for logit extraction, # then back to [S’, B, H] for the output layer. reshaped = hidden_states.squeeze(1).unsqueeze(0) - if inference_context.config.num_speculative_tokens > 0: - # For speculative decoding, keep all decode tokens + last prefill token. - hidden_states = inference_context.speculative_required_logits( - reshaped - ).unsqueeze(1) - else: - hidden_states = inference_context.last_token_logits(reshaped).unsqueeze(1) + hidden_states = inference_context.last_token_logits(reshaped).unsqueeze(1) logits, _ = self.output_layer( hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output diff --git a/megatron/core/models/mamba/mamba_model.py b/megatron/core/models/mamba/mamba_model.py index 007001ea6b0..2da8c31d14e 100644 --- a/megatron/core/models/mamba/mamba_model.py +++ b/megatron/core/models/mamba/mamba_model.py @@ -448,13 +448,7 @@ def forward( # Reshape [S, B, H] (with B=1) to [1, S, H] for logit extraction, # then back to [S', B, H] for the output layer. reshaped = hidden_states.squeeze(1).unsqueeze(0) - if inference_context.config.num_speculative_tokens > 0: - # For speculative decoding, keep all decode tokens + last prefill token. - hidden_states = inference_context.speculative_required_logits( - reshaped - ).unsqueeze(1) - else: - hidden_states = inference_context.last_token_logits(reshaped).unsqueeze(1) + hidden_states = inference_context.last_token_logits(reshaped).unsqueeze(1) logits, _ = self.output_layer( hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output diff --git a/megatron/inference/utils.py b/megatron/inference/utils.py index 37898dc8c90..ec8f1088be1 100644 --- a/megatron/inference/utils.py +++ b/megatron/inference/utils.py @@ -341,7 +341,7 @@ def get_inference_config_from_model_and_args(model: MegatronModule, args): mamba_inference_state_config=mamba_inference_state_config, pg_collection=pg_collection, use_flashinfer_fused_rope=args.use_flashinfer_fused_rope, - materialize_only_last_token_logits=(not args.return_log_probs), + materialize_only_last_token_logits=(not args.return_log_probs and args.num_speculative_tokens == 0), track_generated_token_events=args.inference_dynamic_batching_track_generated_token_events, track_paused_request_events=args.inference_dynamic_batching_track_paused_request_events, enable_chunked_prefill=args.enable_chunked_prefill, From 456853cf79bb3e815900a8cc5a43439552edd6b8 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 10:07:49 -0700 Subject: [PATCH 59/76] Formatting Signed-off-by: Keshav Santhanam --- megatron/core/inference/engines/dynamic_engine.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 1c3bfe833c4..500954caf03 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -211,9 +211,9 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen assert ( self.num_speculative_tokens <= self.controller.num_mtp_heads ), f"Number of speculative tokens {self.num_speculative_tokens} must be less than or equal to number of MTP heads {self.controller.num_mtp_heads}" - assert not self.materialize_only_last_token_logits, ( - "materialize_only_last_token_logits must be False when num_speculative_tokens > 0" - ) + assert ( + not self.materialize_only_last_token_logits + ), "materialize_only_last_token_logits must be False when num_speculative_tokens > 0" self.track_paused_request_events = inference_config.track_paused_request_events self.track_generated_token_events = inference_config.track_generated_token_events @@ -999,11 +999,11 @@ def post_process_requests( if self.num_speculative_tokens > 0: accepted_tokens = list(filter(lambda tok: tok != -1, accepted_tokens_list)) - + # The order `accepted_tokens + tokens` is correct here. - # `accepted_tokens` contains the sequence of - # successfully verified draft tokens. `tokens` (from `sample`) is the - # brand new token generated by the target model based on that accepted prefix. + # `accepted_tokens` contains the sequence of + # successfully verified draft tokens. `tokens` (from `sample`) is the + # brand new token generated by the target model based on that accepted prefix. # Therefore, the newly sampled token must go at the end of the sequence. tokens = accepted_tokens + tokens From c3e697fc30d4762f4ba3c84f413f46d50a0394d3 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 10:21:24 -0700 Subject: [PATCH 60/76] Minor fix Signed-off-by: Keshav Santhanam --- megatron/core/inference/engines/dynamic_engine.py | 6 +++--- megatron/core/ssm/mamba_mixer.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 500954caf03..54ed7ff0a44 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -204,6 +204,9 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen self.context = context self.num_speculative_tokens = inference_config.num_speculative_tokens + self.materialize_only_last_token_logits = ( + inference_config.materialize_only_last_token_logits + ) assert self.num_speculative_tokens >= 0, "Number of speculative tokens must be non-negative" @@ -221,9 +224,6 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen self.metrics_writer = inference_config.metrics_writer self.logging_step_interval = inference_config.logging_step_interval self.unified_memory_level = inference_config.unified_memory_level - self.materialize_only_last_token_logits = ( - inference_config.materialize_only_last_token_logits - ) self.cuda_graph_impl = model_config.cuda_graph_impl self.cuda_graph_scope = model_config.cuda_graph_scope # Initialize engine. diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 4362216e744..30b1a28bd71 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -429,6 +429,7 @@ def _dynamic_inference(self, hidden_states: torch.Tensor, context: DynamicInfere conv_state, ssm_state = context.mamba_states_cache(self.layer_number - self.pp_layer_offset) # Fetch intermediate states for speculative decoding + # (just buffers, existing data is overwritten) int_conv_state = None int_ssm_state = None if context.num_speculative_tokens > 0: From 955f4049ee114f6ec28096003754ea611eee8890 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 10:32:41 -0700 Subject: [PATCH 61/76] Remove outdated assertion on test Signed-off-by: Keshav Santhanam --- tests/unit_tests/inference/engines/test_dynamic_engine.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index e67799059e0..cb0351044be 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -1093,10 +1093,6 @@ def test_parallel_inference( "when tp_size > 1." ) ) - if model_provider == "mamba": - pytest.skip( - reason="Mamba model is not supported with the inference optimized transformer." - ) env = self._run_test( model_provider=model_provider, From 1a0584d1b1b8a405e52d4f3eaa61a922bb42672d Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 10:38:33 -0700 Subject: [PATCH 62/76] Nits Signed-off-by: Keshav Santhanam --- megatron/core/inference/engines/dynamic_engine.py | 1 - .../text_generation_controllers/text_generation_controller.py | 1 - 2 files changed, 2 deletions(-) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 54ed7ff0a44..e929d2b0aed 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -1021,7 +1021,6 @@ def post_process_requests( - len(request.generated_tokens) ] if request_id not in self.stop_word_being_finished_ids: - is_first_token = len(request.generated_tokens) == 0 request.generated_tokens += tokens # TODO : SHAN Should check and change the following for speculative tokens diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 4314d858d11..4f1f6356875 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -1358,7 +1358,6 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: ).byte() & torch.less(active_sequence_lengths, max_sequence_lengths).byte() # Mark requests as finished if they hit stop words (detected in previous step's post_process_requests) - # TODO : SHAN : Correclty implement this if self._get_stop_word_finished_ids_callback is not None: request_ids_list = active_request_ids.tolist() stop_word_finished_ids = self._get_stop_word_finished_ids_callback(request_ids_list) From 00a7dcc0c9c1db854a1f97772fbeb95a2a5d77d5 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 11:09:00 -0700 Subject: [PATCH 63/76] Fix event tracking for speculative tokens Signed-off-by: Keshav Santhanam --- .../core/inference/engines/dynamic_engine.py | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index e929d2b0aed..79f1469b7ce 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -1023,31 +1023,31 @@ def post_process_requests( if request_id not in self.stop_word_being_finished_ids: is_first_token = len(request.generated_tokens) == 0 request.generated_tokens += tokens - # TODO : SHAN Should check and change the following for speculative tokens - token = tokens[0] + first_token_event = None if self.track_generated_token_events: - if block_allocator.enable_prefix_caching: - event_generated_token = request.add_event_generated_token( - token, - blocks_total=block_allocator.total_count, - blocks_hashed_total=blocks_allocated, - blocks_hashed_active=blocks_hashed_active, - blocks_ref_count=blocks_ref_count, - ) - else: - event_generated_token = request.add_event_generated_token( - token, - blocks_total=block_allocator.total_count, - blocks_hashed_total=blocks_allocated, - blocks_hashed_active=blocks_hashed_active, - ) + for token in tokens: + if block_allocator.enable_prefix_caching: + event = request.add_event_generated_token( + token, + blocks_total=block_allocator.total_count, + blocks_hashed_total=blocks_allocated, + blocks_hashed_active=blocks_hashed_active, + blocks_ref_count=blocks_ref_count, + ) + else: + event = request.add_event_generated_token( + token, + blocks_total=block_allocator.total_count, + blocks_hashed_total=blocks_allocated, + blocks_hashed_active=blocks_hashed_active, + ) + if first_token_event is None: + first_token_event = event if is_first_token: - if self.track_generated_token_events: - first_token_event = event_generated_token - else: + if not self.track_generated_token_events: first_token_event = DynamicInferenceEvent( type=DynamicInferenceEventType.GENERATED_TOKEN, - payload={"token_id": token}, + payload={"token_id": tokens[0]}, ) request.ttft = ( first_token_event.timestamp - request.event_add_engine.timestamp From 9ca68d09084dd4915f88de1641b8e996c6a03f5d Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 12:27:23 -0700 Subject: [PATCH 64/76] Update text_generation_controller tests Signed-off-by: Keshav Santhanam --- .../test_text_generation_controller.py | 76 ++++++++++--------- 1 file changed, 42 insertions(+), 34 deletions(-) diff --git a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py index ba453670862..56958bfe953 100644 --- a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py +++ b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py @@ -52,6 +52,11 @@ def setup_model( batch_size: int = 4, static: bool = True, use_training_random_init: bool = False, + materialize_only_last_token_logits: bool = False, + num_speculative_tokens: int = 0, + block_size_tokens: int = 256, + enable_prefix_caching: bool = False, + max_requests: int = None, ): Utils.initialize_model_parallel( tensor_model_parallel_size=tensor_model_parallel_size, @@ -108,10 +113,14 @@ def setup_model( inference_config=InferenceConfig( max_sequence_length=2048, buffer_size_gb=0.2, - materialize_only_last_token_logits=False, + materialize_only_last_token_logits=materialize_only_last_token_logits, use_flashinfer_fused_rope=None, # default to using flash-infer if available # this is for compatibility with the LTS environment unified_memory_level=0, # unit tests currently broken with UVM + num_speculative_tokens=num_speculative_tokens, + block_size_tokens=block_size_tokens, + enable_prefix_caching=enable_prefix_caching, + max_requests=max_requests, ), ) @@ -224,11 +233,15 @@ def test_sample_from_dynamic_logits( self, backend: str, materialize_only_last_token_logits: bool ): batch_size = 12 - self.setup_model(torch.float32, batch_size=batch_size, static=False) + self.setup_model( + torch.float32, + batch_size=batch_size, + static=False, + materialize_only_last_token_logits=materialize_only_last_token_logits, + ) self.mock_tokenizer.eod = self.vocab_size context = self.text_generation_controller.inference_wrapped_model.inference_context - context.materialize_only_last_token_logits = materialize_only_last_token_logits # Prepare sampling params in human-readable format, to aid with test maintenance. sampling_test_cases: List[Tuple[SamplingParams, List[int]]] = [ @@ -743,11 +756,15 @@ def test_dynamic_top_n_logprobs_calculation( 3. Correct number of tokens are returned for each request """ batch_size = 4 - self.setup_model(torch.bfloat16, batch_size=batch_size, static=False) + self.setup_model( + torch.bfloat16, + batch_size=batch_size, + static=False, + materialize_only_last_token_logits=materialize_only_last_token_logits, + ) self.mock_tokenizer.eod = self.vocab_size context = self.text_generation_controller.inference_wrapped_model.inference_context - context.materialize_only_last_token_logits = materialize_only_last_token_logits # Prepare sampling params top_n = 5 @@ -1011,13 +1028,11 @@ def test_sampled_tokens_match_with_parallelism(self, static, tp_size, pp_size): @pytest.mark.internal def test_speculative_verify_tokens(self): """Test consecutive token acceptance logic for speculative decoding.""" - self.setup_model(torch.float32, static=False) + self.setup_model(torch.float32, static=False, num_speculative_tokens=2, max_requests=2) # Enable speculative decoding self.text_generation_controller.num_speculative_tokens = 2 ctx = self.text_generation_controller.inference_wrapped_model.inference_context - ctx.num_speculative_tokens = 2 - ctx.max_requests = 2 ctx.total_request_count = 2 ctx.paused_request_count = 0 ctx.request_in_prefill_status_tensor = torch.tensor( @@ -1076,11 +1091,9 @@ def mock_sampling_func(logits, *args, **kwargs): @pytest.mark.parametrize("is_hybrid_model", [False, True]) def test_rewind_kv_cache(self, is_hybrid_model): """Test KV cache state is properly rewound for rejected speculative tokens.""" - self.setup_model(torch.float32, static=False) + self.setup_model(torch.float32, static=False, num_speculative_tokens=3, block_size_tokens=4) self.text_generation_controller.num_speculative_tokens = 3 ctx = self.text_generation_controller.inference_wrapped_model.inference_context - ctx.num_speculative_tokens = 3 - ctx.block_size_tokens = 4 ctx.total_request_count = 2 ctx.paused_request_count = 0 ctx.request_in_prefill_status_tensor = torch.tensor([0, 0], device='cuda') @@ -1104,6 +1117,8 @@ def test_rewind_kv_cache(self, is_hybrid_model): ctx.mamba_metadata.request_to_mamba_state_idx = torch.tensor([0, 1], device='cuda') ctx.mamba_ssm_states = torch.zeros((1, 2, 16), device='cuda') ctx.mamba_intermediate_ssm_states = torch.ones((1, 2, 4, 16), device='cuda') * 99 + ctx.mamba_conv_states = torch.zeros((1, 2, 8), device='cuda') + ctx.mamba_intermediate_conv_states = torch.ones((1, 2, 4, 8), device='cuda') * 77 # Mock accepted token counts: Req 0 accepts 1 (rejects 2), Req 1 accepts 0 (rejects 3) self.text_generation_controller._init_mtp_sampling_tensor() @@ -1138,19 +1153,21 @@ def test_rewind_kv_cache(self, is_hybrid_model): # Check Mamba state was restored from intermediate cache based on accepted counts assert torch.all(ctx.mamba_ssm_states[:, 0] == 99) # Req 0 accepted 1, loaded index 1 assert torch.all(ctx.mamba_ssm_states[:, 1] == 99) # Req 1 accepted 0, loaded index 0 + assert torch.all(ctx.mamba_conv_states[:, 0] == 77) # Req 0 accepted 1, loaded index 1 + assert torch.all(ctx.mamba_conv_states[:, 1] == 77) # Req 1 accepted 0, loaded index 0 @pytest.mark.internal def test_speculative_multinomial_sampling(self): """Test that speculative decoding can successfully use non-greedy sampling (top_k > 1, top_p > 0) by flattening 3D MTP logits for torch.multinomial.""" - self.setup_model(torch.float32, static=False) + num_spec = 3 + self.setup_model( + torch.float32, static=False, num_speculative_tokens=num_spec, max_requests=2 + ) # Enable speculative decoding - num_spec = 3 self.text_generation_controller.num_speculative_tokens = num_spec ctx = self.text_generation_controller.inference_wrapped_model.inference_context - ctx.num_speculative_tokens = num_spec - ctx.max_requests = 2 ctx.total_request_count = 2 ctx.paused_request_count = 0 ctx.request_in_prefill_status_tensor = torch.tensor( @@ -1159,9 +1176,6 @@ def test_speculative_multinomial_sampling(self): # query lengths for decode with spec tokens is (1 + num_spec) = 4 ctx.request_query_lengths = torch.tensor([4, 4], dtype=torch.int32, device='cuda') - # Init accepted tokens tensors - self.text_generation_controller._init_mtp_sampling_tensor() - # Setup inputs input_ids = torch.randint(0, self.vocab_size, (1, 8), device='cuda') @@ -1202,24 +1216,19 @@ def test_speculative_multinomial_sampling(self): def test_rewind_kv_cache_with_prefix_caching_ref_counts(self): """Test that _rewind_kv_cache correctly decrements ref counts on shared blocks when speculative token rejection causes a block boundary crossing.""" - self.setup_model(torch.float32, static=False) + self.setup_model( + torch.float32, + static=False, + num_speculative_tokens=2, + block_size_tokens=4, + enable_prefix_caching=True, + ) ctx = self.text_generation_controller.inference_wrapped_model.inference_context - self.text_generation_controller.num_speculative_tokens = 2 - ctx.num_speculative_tokens = 2 - ctx.block_size_tokens = 4 - ctx.enable_prefix_caching = True ctx.total_request_count = 2 ctx.paused_request_count = 0 ctx.request_in_prefill_status_tensor = torch.tensor([0, 0], device='cuda') - # Initialize allocator ref count tracking. - ctx.block_allocator.enable_prefix_caching = True - if not hasattr(ctx.block_allocator, 'block_ref_counts'): - ctx.block_allocator.block_ref_counts = torch.zeros( - ctx.block_allocator.total_count, dtype=torch.int32, device='cuda' - ) - # Req 0: 3 blocks, offset 1 in last block. Rewinding 1 token -> no block release. # Req 1: 3 blocks, offset 0 in last block. Rewinding 2 tokens -> crosses back, release block. ctx.request_kv_length_offsets[:2] = torch.tensor([9, 9], device='cuda') @@ -1252,12 +1261,11 @@ def test_rewind_kv_cache_with_prefix_caching_ref_counts(self): @pytest.mark.internal def test_rewind_kv_cache_does_not_release_shared_prefix_blocks(self): """Test that rewinding only releases the last block, never shared prefix blocks.""" - self.setup_model(torch.float32, static=False) + self.setup_model( + torch.float32, static=False, num_speculative_tokens=3, block_size_tokens=4 + ) ctx = self.text_generation_controller.inference_wrapped_model.inference_context - self.text_generation_controller.num_speculative_tokens = 3 - ctx.num_speculative_tokens = 3 - ctx.block_size_tokens = 4 ctx.total_request_count = 1 ctx.paused_request_count = 0 ctx.request_in_prefill_status_tensor = torch.tensor([0], device='cuda') From ed7667c97cdf45993ec34f1587bb9ffaff40d4a9 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 12:50:38 -0700 Subject: [PATCH 65/76] Fix text generation controller tests Signed-off-by: Keshav Santhanam --- .../test_text_generation_controller.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py index 56958bfe953..ff296b68390 100644 --- a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py +++ b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py @@ -1091,7 +1091,13 @@ def mock_sampling_func(logits, *args, **kwargs): @pytest.mark.parametrize("is_hybrid_model", [False, True]) def test_rewind_kv_cache(self, is_hybrid_model): """Test KV cache state is properly rewound for rejected speculative tokens.""" - self.setup_model(torch.float32, static=False, num_speculative_tokens=3, block_size_tokens=4) + self.setup_model( + torch.float32, + static=False, + num_speculative_tokens=3, + block_size_tokens=4, + max_requests=16, + ) self.text_generation_controller.num_speculative_tokens = 3 ctx = self.text_generation_controller.inference_wrapped_model.inference_context ctx.total_request_count = 2 @@ -1222,6 +1228,7 @@ def test_rewind_kv_cache_with_prefix_caching_ref_counts(self): num_speculative_tokens=2, block_size_tokens=4, enable_prefix_caching=True, + max_requests=16, ) ctx = self.text_generation_controller.inference_wrapped_model.inference_context @@ -1262,7 +1269,11 @@ def test_rewind_kv_cache_with_prefix_caching_ref_counts(self): def test_rewind_kv_cache_does_not_release_shared_prefix_blocks(self): """Test that rewinding only releases the last block, never shared prefix blocks.""" self.setup_model( - torch.float32, static=False, num_speculative_tokens=3, block_size_tokens=4 + torch.float32, + static=False, + num_speculative_tokens=3, + block_size_tokens=4, + max_requests=16, ) ctx = self.text_generation_controller.inference_wrapped_model.inference_context From b92f79ca9752f0e4ee09b3f9c84f2bef71c5dc49 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 14:19:01 -0700 Subject: [PATCH 66/76] Fix new_speculative_tokens + eviction, add tests Signed-off-by: Keshav Santhanam --- .../inference/contexts/dynamic_context.py | 37 +++- .../contexts/test_dynamic_context.py | 177 ++++++++++++++++++ .../inference/engines/test_dynamic_engine.py | 101 ++++++++++ 3 files changed, 305 insertions(+), 10 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index e7a2ab3ed85..7975e32e677 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -280,6 +280,10 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC self.num_attention_heads_per_partition = 1 self.num_speculative_tokens = inference_config.num_speculative_tokens + assert self.num_speculative_tokens < inference_config.block_size_tokens, ( + f"num_speculative_tokens ({self.num_speculative_tokens}) must be < " + f"block_size_tokens ({inference_config.block_size_tokens})" + ) # Cache the PP group we should use for PP collectives inside the context. # If the model provides a pg_collection with a pp group, prefer it. @@ -2032,7 +2036,9 @@ def _move_book_keeping_tensors( self.mamba_metadata.request_to_mamba_state_idx[src_idxs] ) - def _swap_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens): + def _swap_book_keeping_tensors( + self, src_idxs, dst_idxs, next_tokens, new_speculative_tokens=None + ): """ Swaps all the relevent booking tensors with src idxs to dst idxs """ @@ -2047,6 +2053,11 @@ def _swap_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens): tensor_swap(self.request_last_kv_block_id, src_idxs, dst_idxs) tensor_swap(self.request_last_kv_block_offset, src_idxs, dst_idxs) + if new_speculative_tokens is not None: + # new_speculative_tokens has request dimension as second dimension, + # so swap on transposed view + tensor_swap(new_speculative_tokens.t(), src_idxs, dst_idxs) + for metadata_tensor in self.request_metadata.values(): tensor_swap(metadata_tensor, src_idxs, dst_idxs) @@ -2090,10 +2101,7 @@ def release_memory_blocks_from_request_indexes(self, request_indexes) -> None: self.mamba_metadata.free_slots(request_indexes) def resume_paused_requests( - self, - active_request_count: int, - newly_paused_request_ids: torch.Tensor, - next_tokens: torch.Tensor, + self, active_request_count: int, newly_paused_request_ids: torch.Tensor ) -> tuple[int, torch.Tensor]: """Resume as many paused requests as we have space for in the active buffer. @@ -2177,7 +2185,10 @@ def resume_paused_requests( return active_request_count, newly_paused_request_ids def evict_overflow_paused_requests( - self, active_request_count: int, next_tokens: torch.Tensor + self, + active_request_count: int, + next_tokens: torch.Tensor, + new_speculative_tokens: Optional[torch.Tensor] = None, ) -> Optional[tuple[torch.Tensor, torch.Tensor]]: """Evict requests that overflow the paused buffer. @@ -2268,7 +2279,10 @@ def evict_overflow_paused_requests( # Swap evicted and active requests. self._swap_book_keeping_tensors( - src_idxs=src_idxs, dst_idxs=dst_idxs, next_tokens=next_tokens + src_idxs=src_idxs, + dst_idxs=dst_idxs, + next_tokens=next_tokens, + new_speculative_tokens=new_speculative_tokens, ) # Update tracking vars. @@ -2514,15 +2528,17 @@ def update_requests( # 6.a. First, resume temporarily paused requests. active_request_count, newly_paused_request_ids = self.resume_paused_requests( - active_request_count, newly_paused_request_ids, next_tokens + active_request_count, newly_paused_request_ids ) # 6.b. Evict requests that overflow the paused buffer. - evict_request_ids = self.evict_overflow_paused_requests(active_request_count, next_tokens) + evict_request_ids = self.evict_overflow_paused_requests( + active_request_count, next_tokens, new_speculative_tokens + ) # 6.c. Resume any additional requests. active_request_count, newly_paused_request_ids = self.resume_paused_requests( - active_request_count, newly_paused_request_ids, next_tokens + active_request_count, newly_paused_request_ids ) assert active_request_count > 0, "active_request_count == %d." % active_request_count @@ -2534,6 +2550,7 @@ def update_requests( src_idxs=torch.tensor([self.get_index_of_chunked_prefill_request()]), dst_idxs=torch.tensor([self.total_request_count - 1]), next_tokens=next_tokens, + new_speculative_tokens=new_speculative_tokens, ) # 7. We make changes to the request book keeping tesnsors and setup the tokens for next iteration diff --git a/tests/unit_tests/inference/contexts/test_dynamic_context.py b/tests/unit_tests/inference/contexts/test_dynamic_context.py index 1b7a2cba5dc..40379d05163 100644 --- a/tests/unit_tests/inference/contexts/test_dynamic_context.py +++ b/tests/unit_tests/inference/contexts/test_dynamic_context.py @@ -1718,6 +1718,110 @@ def test_paused_speculative_tokens_tracking(self): ctx.paused_speculative_tokens[:, 0], torch.tensor([991, 992], device='cuda') ) + @pytest.mark.internal + @rounder_override(64) + def test_speculative_tokens_less_than_block_size_assert(self): + self._setup_model_parallel_group(1, 1) + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=128, + buffer_size_gb=0.01, + block_size_tokens=16, + num_speculative_tokens=16, + unified_memory_level=0, + ) + with pytest.raises( + AssertionError, match="num_speculative_tokens.*must be < block_size_tokens" + ): + DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + @pytest.mark.internal + @rounder_override(64) + def test_swap_book_keeping_tensors_with_speculative_tokens(self): + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=128, + buffer_size_gb=0.01, + block_size_tokens=32, + num_speculative_tokens=2, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + ctx.request_ids[:2] = torch.tensor([10, 11]) + next_tokens = torch.tensor([99, 100], device='cuda') + new_speculative_tokens = torch.tensor([[991, 1001], [992, 1002]], device='cuda') + + ctx._swap_book_keeping_tensors( + src_idxs=torch.tensor([0]), + dst_idxs=torch.tensor([1]), + next_tokens=next_tokens, + new_speculative_tokens=new_speculative_tokens, + ) + + assert torch.equal(ctx.request_ids[:2], torch.tensor([11, 10], device='cuda')) + assert torch.equal(next_tokens[:2], torch.tensor([100, 99], device='cuda')) + assert torch.equal( + new_speculative_tokens[:, :2], torch.tensor([[1001, 991], [1002, 992]], device='cuda') + ) + + @pytest.mark.internal + @rounder_override(64) + def test_update_requests_with_finished_requests_and_speculative_tokens(self): + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=128, + buffer_size_gb=0.01, + block_size_tokens=32, + num_speculative_tokens=2, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + # Setup 3 active requests: req0 (active), req1 (finished), req2 (active) + ctx.total_request_count = 3 + ctx.paused_request_count = 0 + ctx.active_token_count = 3 + ctx.request_ids[:3] = torch.tensor([10, 11, 12]) + ctx.request_query_lengths[:3] = 1 + ctx.request_kv_length_offsets[:3] = torch.tensor([5, 8, 12]) + ctx.request_last_kv_block_offset[:3] = torch.tensor([5, 8, 12]) + ctx.request_to_kv_block_ids[:3, 0] = torch.tensor([0, 1, 2]) + ctx.request_last_kv_block_id[:3] = torch.tensor([0, 1, 2]) + ctx.request_kv_block_counts[:3] = 1 + + active_requests_mask = torch.tensor([1, 0, 1], device='cuda') + new_tokens = torch.tensor([99, 100, 101], device='cuda') + new_speculative_tokens = torch.tensor([[991, 1001, 1011], [992, 1002, 1012]], device='cuda') + + ctx.update_requests( + active_requests_mask=active_requests_mask, + new_tokens=new_tokens, + new_speculative_tokens=new_speculative_tokens, + ) + + # req1 is finished. req2 moves to req1's position. + assert ctx.total_request_count == 2 + assert torch.equal( + ctx.request_ids[:2], torch.tensor([10, 12], device='cuda', dtype=torch.int32) + ) + + # Check interleaving for req0 and req2 + # req0: [99, 991, 992] + # req2: [101, 1011, 1012] + expected_tokens = torch.tensor([99, 991, 992, 101, 1011, 1012], device='cuda') + assert torch.equal(ctx.token_to_input_ids[:6], expected_tokens) + @pytest.mark.internal @rounder_override(64) def test_chunked_prefill_speculative_offset_math(self): @@ -1779,6 +1883,79 @@ def test_chunked_prefill_speculative_offset_math(self): + req.sampling_params.num_tokens_to_generate ) + @pytest.mark.internal + @rounder_override(64) + def test_chunked_prefill_swap_with_speculative_tokens(self): + """Test that swapping a chunked prefill request to the end of the buffer + correctly brings along the 2D speculative tokens for the other decode requests. + """ + self._setup_model_parallel_group(1, 1) + + model_config = TransformerConfig( + params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 + ) + inference_config = InferenceConfig( + max_sequence_length=128, + buffer_size_gb=0.01, + block_size_tokens=32, + num_speculative_tokens=2, + enable_chunked_prefill=True, + unified_memory_level=0, + ) + ctx = DynamicInferenceContext(model_config=model_config, inference_config=inference_config) + + # Setup 2 active requests in the WRONG order (violating the invariant) + # Index 0: Chunked Prefill Request (ID 42) + # Index 1: Standard Decode Request (ID 99) + ctx.total_request_count = 2 + ctx.paused_request_count = 0 + ctx.active_token_count = 2 + + ctx.chunked_prefill_request_id = 42 + ctx.request_ids[:2] = torch.tensor([42, 99]) + + # Status: 1 = Prefill, 0 = Decode + ctx.request_in_prefill_status_tensor[:2] = torch.tensor([1, 0]) + ctx.request_query_lengths[:2] = 1 + ctx.request_kv_length_offsets[:2] = torch.tensor([10, 20]) + ctx.request_last_kv_block_offset[:2] = torch.tensor([10, 20]) + ctx.request_to_kv_block_ids[:2, 0] = torch.tensor([0, 1]) + ctx.request_last_kv_block_id[:2] = torch.tensor([0, 1]) + ctx.request_kv_block_counts[:2] = 1 + + active_requests_mask = torch.tensor([1, 1], device='cuda') + + # New base tokens: [100 (for prefill), 200 (for decode)] + new_tokens = torch.tensor([100, 200], device='cuda') + + # New spec tokens: Col 0 for prefill (dummy), Col 1 for decode (real draft tokens) + new_speculative_tokens = torch.tensor([[101, 201], [102, 202]], device='cuda') + + # Trigger update_requests. + # It must detect ID 42 is at index 0, and swap it with index 1. + ctx.update_requests( + active_requests_mask=active_requests_mask, + new_tokens=new_tokens, + new_speculative_tokens=new_speculative_tokens, + ) + + # 1. Verify the IDs were swapped successfully + assert torch.equal( + ctx.request_ids[:2], torch.tensor([99, 42], dtype=torch.int32, device='cuda') + ) + + # 2. Verify the Decode request (now at Index 0) correctly flattened its + # base token (200) AND its specific speculative tokens (201, 202). + # 3. Verify the Prefill request (now at Index 1) flattened its tokens (100, 101, 102). + expected_flattened_tokens = torch.tensor( + [200, 201, 202, 100, 101, 102], # Decode request (ID 99) # Prefill request (ID 42) + device='cuda', + ) + + assert torch.equal( + ctx.token_to_input_ids[:6], expected_flattened_tokens + ), "Speculative tokens were not correctly swapped alongside the chunked prefill request!" + @pytest.mark.internal @rounder_override(64) def test_speculative_with_prefix_caching_shared_blocks(self): diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index cb0351044be..248ef643379 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -2442,6 +2442,107 @@ def mock_mtp_forward_reject(*args, **kwargs): len(finished_req.generated_tokens) == 6 ), f"Expected 6 tokens, got {len(finished_req.generated_tokens)}. Double counting occurred." + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.inference_mode() + def test_speculative_decoding_with_eviction_and_swapping(self): + """Test that speculative decoding works correctly when requests are paused and evicted. + + This exercises the `_swap_book_keeping_tensors` logic with the 2D `new_speculative_tokens` + tensor, ensuring no dimensional mismatch or index errors occur during tensor swapping. + """ + # Very constrained memory environment to force pausing and eviction + test_config = DynamicEngineTestConfig( + num_requests=3, + min_prompt_length=16, + max_prompt_length=16, + num_tokens_to_generate=32, + context_block_size_tokens=16, + num_speculative_tokens=2, + # 40 KB translates to 3 blocks. + # 3 requests * 3 blocks per request (1 prompt + 2 gen) = 9 blocks needed. + # This guarantees we will run out of active memory mid-generation. + context_buffer_size_gb=0.00004, + context_paused_buffer_size_gb=0.0, # 0 paused buffer forces immediate eviction + model_provider="gpt", + materialize_only_last_token_logits=False, + use_fixed_output_lengths=True, + ) + + env = self._build_test_env(test_config) + + print(f"total block count = {env.engine.context.block_allocator.total_count}") + + unwrapped_model = env.engine.controller.inference_wrapped_model.model + + # Mock forward pass to return safe, deterministic logits to avoid NaN/Inf crashes + # in torch.multinomial caused by randomly initialized weights. + def mock_safe_forward(*args, **kwargs): + tokens = kwargs.get("tokens", args[0] if args else kwargs.get("input_ids")) + b, s = tokens.shape + + base_logits = torch.zeros( + b, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16 + ) + base_logits[:, :, 0] = 100.0 # Force model to deterministically pick token 0 + + mtp_logits = torch.zeros( + test_config.num_speculative_tokens, + s, + test_config.vocab_size, + device=tokens.device, + dtype=torch.bfloat16, + ) + mtp_logits[:, :, 0] = 100.0 # Force speculative heads to also pick token 0 + + unwrapped_model._mtp_logits_cache = mtp_logits + return base_logits + + unwrapped_model.forward = mock_safe_forward + + # Add all requests at once. They will all start prefill, but as they generate + # and request more blocks, the engine will run out of active blocks. + # Since paused_buffer_size is 0, any request that pauses will immediately + # overflow the paused buffer and trigger an eviction. + for request in env.requests: + request.sampling_params.num_tokens_to_generate = 32 + env.engine._add_request(request) + + eviction_occurred = False + + # Step the engine manually until all requests finish. + while env.engine.has_unfinished_requests(): + # Record the number of evicted requests before the step + evicted_before = env.engine.evicted_request_count + + # Step the engine + env.engine.schedule_waiting_requests() + env.engine.step_modern() + + # Check if any request was evicted during this step + if env.engine.evicted_request_count > evicted_before: + eviction_occurred = True + + # Assert that our constrained memory actually caused an eviction, + # proving we exercised the evict_overflow_paused_requests path with spec tokens. + assert ( + eviction_occurred + ), "Test failed to trigger an eviction. The test environment memory wasn't tight enough." + + # Verify all requests successfully went back through the queue and finished cleanly. + # We MUST check the merged records from the engine, because eviction checkpoints + # the requests, leaving the original instances in env.requests permanently active. + for request_id, entry in env.engine.requests.items(): + merged_req = entry.record.merge() + assert ( + merged_req.status == Status.COMPLETED + ), f"Request {request_id} failed to complete." + assert ( + len(merged_req.generated_tokens) == 31 + ), f"Request {request_id} didn't generate expected tokens." + @pytest.mark.internal @pytest.mark.skipif( not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" From 9eb639a09d33809fca9eb5d531ab0486a4e955fa Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 14:26:38 -0700 Subject: [PATCH 67/76] Log speculative token acceptance rates Signed-off-by: Keshav Santhanam --- .../core/inference/engines/dynamic_engine.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 79f1469b7ce..9909dc3674a 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -296,6 +296,11 @@ def reset(self) -> None: self.resume_request_ids = None + # Speculative decoding acceptance tracking. + self._spec_tokens_proposed = 0 + self._spec_tokens_accepted = 0 + self._spec_steps = 0 + # Prefix caching coordination state. self._prefix_coordination_waits = 0 @@ -989,6 +994,9 @@ def post_process_requests( # empty lists for each request, so the zip produces the correct number of iterations accepted_tokens_iter = repeat([]) if accepted_tokens is None else accepted_tokens.tolist() + if self.num_speculative_tokens > 0 and accepted_tokens is not None: + self._spec_steps += 1 + for req_idx, (request_id, tokens, accepted_tokens_list, request_log_probs) in enumerate( zip(request_ids.tolist(), sample.tolist(), accepted_tokens_iter, log_probs_iter) ): @@ -1000,6 +1008,10 @@ def post_process_requests( if self.num_speculative_tokens > 0: accepted_tokens = list(filter(lambda tok: tok != -1, accepted_tokens_list)) + # Track acceptance statistics for logging. + self._spec_tokens_proposed += self.num_speculative_tokens + self._spec_tokens_accepted += len(accepted_tokens) + # The order `accepted_tokens + tokens` is correct here. # `accepted_tokens` contains the sequence of # successfully verified draft tokens. `tokens` (from `sample`) is the @@ -1611,6 +1623,14 @@ async def async_bookkeep( else: metrics[f'inference/{key}'] = value + # Add speculative decoding acceptance metrics. + if self.num_speculative_tokens > 0 and self._spec_tokens_proposed > 0: + acceptance_rate = self._spec_tokens_accepted / self._spec_tokens_proposed + metrics['inference/spec_decode_acceptance_rate'] = float(acceptance_rate * 100.0) + metrics['inference/spec_decode_tokens_proposed'] = int(self._spec_tokens_proposed) + metrics['inference/spec_decode_tokens_accepted'] = int(self._spec_tokens_accepted) + metrics['inference/spec_decode_num_steps'] = int(self._spec_steps) + if HAVE_WANDB and self.metrics_writer.__name__ == "wandb": self.metrics_writer.log(metrics, commit=True) else: @@ -1660,10 +1680,27 @@ async def async_bookkeep( mem["reserved_bytes.all.current"] / (1024**3), ) ) + if self.num_speculative_tokens > 0 and self._spec_tokens_proposed > 0: + spec_rate = self._spec_tokens_accepted / self._spec_tokens_proposed * 100.0 + output_str += ( + " ... spec: accept %.1f%% (%d/%d in %d steps)" + % ( + spec_rate, + self._spec_tokens_accepted, + self._spec_tokens_proposed, + self._spec_steps, + ) + ) if context_state["is_decode_only"]: output_str = f"\033[94m{output_str}\033[0m" logging.info(output_str) + # Reset speculative decoding accumulators after both wandb and console logging. + if self.num_speculative_tokens > 0: + self._spec_tokens_proposed = 0 + self._spec_tokens_accepted = 0 + self._spec_steps = 0 + return { "active_request_ids": active_request_ids, "finished_request_records": finished_request_records, From 675aa01a9d3325624756e8a02b1ce5010bf436a8 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 15:39:07 -0700 Subject: [PATCH 68/76] Don't overcount spec proposed tokens for prefill requests Signed-off-by: Keshav Santhanam --- megatron/core/inference/engines/dynamic_engine.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 9909dc3674a..8b122929e40 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -1005,12 +1005,19 @@ def post_process_requests( if not isinstance(tokens, list): tokens = [tokens] + request: DynamicInferenceRequest = self.get_request(request_id) + if self.num_speculative_tokens > 0: accepted_tokens = list(filter(lambda tok: tok != -1, accepted_tokens_list)) - # Track acceptance statistics for logging. - self._spec_tokens_proposed += self.num_speculative_tokens - self._spec_tokens_accepted += len(accepted_tokens) + # Track acceptance statistics for logging (decode requests only). + # Prefill requests don't propose speculative tokens, so including + # them would inflate the proposed count and deflate the rate. + # A request in its first generation step (empty generated_tokens) + # was in prefill this step. + if len(request.generated_tokens) > 0: + self._spec_tokens_proposed += self.num_speculative_tokens + self._spec_tokens_accepted += len(accepted_tokens) # The order `accepted_tokens + tokens` is correct here. # `accepted_tokens` contains the sequence of @@ -1018,8 +1025,6 @@ def post_process_requests( # brand new token generated by the target model based on that accepted prefix. # Therefore, the newly sampled token must go at the end of the sequence. tokens = accepted_tokens + tokens - - request: DynamicInferenceRequest = self.get_request(request_id) if request_id != self.context.chunked_prefill_request_id: # Skip appending token for requests being finished due to stop words # (they already have their final token from the previous step) From 4f119d2e3a4ebdec7eeab5b9304b61a02ec5cd61 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 16:54:45 -0700 Subject: [PATCH 69/76] Linting Signed-off-by: Keshav Santhanam --- megatron/core/inference/engines/dynamic_engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 8b122929e40..d1587328d20 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -1688,8 +1688,7 @@ async def async_bookkeep( if self.num_speculative_tokens > 0 and self._spec_tokens_proposed > 0: spec_rate = self._spec_tokens_accepted / self._spec_tokens_proposed * 100.0 output_str += ( - " ... spec: accept %.1f%% (%d/%d in %d steps)" - % ( + " ... spec: accept %.1f%% (%d/%d in %d steps)" % ( spec_rate, self._spec_tokens_accepted, self._spec_tokens_proposed, From fe1372c1614aeb7714f37535c13b1dcb955406d6 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 17:10:24 -0700 Subject: [PATCH 70/76] Linting Signed-off-by: Keshav Santhanam --- megatron/core/inference/engines/dynamic_engine.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index d1587328d20..e7338ad3141 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -1687,13 +1687,11 @@ async def async_bookkeep( ) if self.num_speculative_tokens > 0 and self._spec_tokens_proposed > 0: spec_rate = self._spec_tokens_accepted / self._spec_tokens_proposed * 100.0 - output_str += ( - " ... spec: accept %.1f%% (%d/%d in %d steps)" % ( - spec_rate, - self._spec_tokens_accepted, - self._spec_tokens_proposed, - self._spec_steps, - ) + output_str += " ... spec: accept %.1f%% (%d/%d in %d steps)" % ( + spec_rate, + self._spec_tokens_accepted, + self._spec_tokens_proposed, + self._spec_steps, ) if context_state["is_decode_only"]: output_str = f"\033[94m{output_str}\033[0m" From 3be6e4db06491996de6ac9b1f9022b8c6bf1ddd8 Mon Sep 17 00:00:00 2001 From: Shanmugam Ramasamy Date: Mon, 9 Mar 2026 19:01:50 -0700 Subject: [PATCH 71/76] Fixing logprobs, stop words adn track_generated_token_events --- .../inference/contexts/dynamic_context.py | 10 +- .../core/inference/engines/dynamic_engine.py | 106 ++++++--- .../text_generation_controller.py | 216 +++++++++++++++++- megatron/core/transformer/attention.py | 8 +- .../contexts/test_dynamic_context.py | 4 +- .../inference/engines/test_dynamic_engine.py | 6 +- tests/unit_tests/inference/test_stop_words.py | 211 ++++++++++++++--- 7 files changed, 480 insertions(+), 81 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index d04359aacdf..2de4c00a921 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -2010,7 +2010,9 @@ def _move_book_keeping_tensors( self.mamba_metadata.request_to_mamba_state_idx[src_idxs] ) - def _swap_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens): + def _swap_book_keeping_tensors( + self, src_idxs, dst_idxs, next_tokens, new_speculative_tokens=None + ): """ Swaps all the relevent booking tensors with src idxs to dst idxs """ @@ -2020,6 +2022,11 @@ def _swap_book_keeping_tensors(self, src_idxs, dst_idxs, next_tokens): tensor_swap(self.request_output_lengths, src_idxs, dst_idxs) tensor_swap(self.request_ids, src_idxs, dst_idxs) tensor_swap(next_tokens, src_idxs, dst_idxs) + if new_speculative_tokens is not None: + # new_speculative_tokens has shape [num_spec, num_requests]; swap columns. + temp = new_speculative_tokens[:, src_idxs].clone() + new_speculative_tokens[:, src_idxs] = new_speculative_tokens[:, dst_idxs] + new_speculative_tokens[:, dst_idxs] = temp tensor_swap(self.request_to_kv_block_ids, src_idxs, dst_idxs) tensor_swap(self.request_kv_block_counts, src_idxs, dst_idxs) tensor_swap(self.request_last_kv_block_id, src_idxs, dst_idxs) @@ -2506,6 +2513,7 @@ def update_requests( src_idxs=torch.tensor([self.get_index_of_chunked_prefill_request()]), dst_idxs=torch.tensor([self.total_request_count - 1]), next_tokens=next_tokens, + new_speculative_tokens=new_speculative_tokens, ) # 7. We make changes to the request book keeping tesnsors and setup the tokens for next iteration diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 746adbc6590..b6517589c3c 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -1004,6 +1004,7 @@ def post_process_requests( tokens = accepted_tokens + tokens request: DynamicInferenceRequest = self.get_request(request_id) + num_stop_word_trim = 0 if request_id != self.context.chunked_prefill_request_id: # Skip appending token for requests being finished due to stop words # (they already have their final token from the previous step) @@ -1020,41 +1021,50 @@ def post_process_requests( is_first_token = len(request.generated_tokens) == 0 request.generated_tokens += tokens - # TODO : SHAN Should check and change the following for speculative tokens - token = tokens[0] + first_event_in_step = None if self.track_generated_token_events: - if block_allocator.enable_prefix_caching: - event_generated_token = request.add_event_generated_token( - token, - blocks_total=block_allocator.total_count, - blocks_hashed_total=blocks_allocated, - blocks_hashed_active=blocks_hashed_active, - blocks_ref_count=blocks_ref_count, - ) - else: - event_generated_token = request.add_event_generated_token( - token, - blocks_total=block_allocator.total_count, - blocks_hashed_total=blocks_allocated, - blocks_hashed_active=blocks_hashed_active, - ) + for token in tokens: + if block_allocator.enable_prefix_caching: + evt = request.add_event_generated_token( + token, + blocks_total=block_allocator.total_count, + blocks_hashed_total=blocks_allocated, + blocks_hashed_active=blocks_hashed_active, + blocks_ref_count=blocks_ref_count, + ) + else: + evt = request.add_event_generated_token( + token, + blocks_total=block_allocator.total_count, + blocks_hashed_total=blocks_allocated, + blocks_hashed_active=blocks_hashed_active, + ) + if first_event_in_step is None: + first_event_in_step = evt if is_first_token: if self.track_generated_token_events: - first_token_event = event_generated_token + first_token_event = first_event_in_step else: first_token_event = DynamicInferenceEvent( type=DynamicInferenceEventType.GENERATED_TOKEN, - payload={"token_id": token}, + payload={"token_id": tokens[0]}, ) request.ttft = ( first_token_event.timestamp - request.event_add_engine.timestamp ) if request.tpot is None: request.tpot = [] - request.tpot.append(step_time) - - # Check for stop words (after token is appended) - stop_word_hit = self._check_stop_words_for_request_post_append(request) + per_token_step_time = step_time / len(tokens) + request.tpot.extend([per_token_step_time] * len(tokens)) + + # Check for stop words (after token is appended). + # With speculative decoding, a stop word may end before the last + # appended token. The check truncates generated_tokens in-place and + # returns how many trailing tokens were removed so we can also trim + # the corresponding log probs below. + stop_word_hit, num_stop_word_trim = self._check_stop_words_for_request_post_append( + request + ) if request_id in finished_request_ids: # Request finished by normal means (termination_id, max_length, or stop word from previous step) @@ -1079,6 +1089,14 @@ def post_process_requests( # Additionally, chunked prefill request do not finish. active_request_ids.append(request_id) + # When a stop word was found mid-speculative-batch, trim log probs + # and top_n_logprobs to match the truncated generated_tokens. + if num_stop_word_trim > 0: + if request_log_probs is not None: + request_log_probs = request_log_probs[:-num_stop_word_trim] + if top_n_logprobs is not None and req_idx in top_n_logprobs: + top_n_logprobs[req_idx] = top_n_logprobs[req_idx][:-num_stop_word_trim] + # Process log_probs if available (unified for both regular and chunked prefill) if request_log_probs is not None: # Initialize lists if they don't exist @@ -1102,8 +1120,16 @@ def post_process_requests( # Handle skip_prompt_log_probs during prefill # If skip_prompt_log_probs is True and we have multiple log probs (prefill), - # only process the last one (first generated token) - if request.sampling_params.skip_prompt_log_probs and len(request_log_probs) > 1: + # only process the last one (first generated token). + # With speculative decoding, decode steps also produce multiple log probs + # (one per accepted token + new sample), so we must check that this is + # actually a prefill step (no generated log probs accumulated yet). + is_prefill_log_probs = len(request.generated_log_probs) == 0 + if ( + request.sampling_params.skip_prompt_log_probs + and len(request_log_probs) > 1 + and is_prefill_log_probs + ): # Only append the last log prob (first generated token) to generated_log_probs request.generated_log_probs.append(request_log_probs[-1]) else: @@ -1219,37 +1245,47 @@ def _get_and_clear_stop_word_finished_ids(self, active_request_ids: list[int]) - self.stop_word_finished_request_ids -= result return result - # TODO : We also might have to delete some tokens, if stop word hit in the middle (speculative case) - def _check_stop_words_for_request_post_append(self, request: DynamicInferenceRequest) -> bool: + def _check_stop_words_for_request_post_append( + self, request: DynamicInferenceRequest + ) -> Tuple[bool, int]: """Check if a request should stop due to stop words (after token is appended). This method is called from post_process_requests after the token has already been appended to request.generated_tokens. + With speculative decoding, multiple tokens are appended at once. The stop word + may end before the last appended token, leaving extra tokens that must be + trimmed. When this happens, generated_tokens is truncated in-place and the + number of trimmed tokens is returned so the caller can also trim log probs. + Args: request: The request to check. Returns: - bool: True if the generated sequence ends with a stop word, False otherwise. + Tuple of (stop_word_hit, num_tokens_trimmed): + stop_word_hit: True if the generated sequence contains a stop word. + num_tokens_trimmed: Number of tokens removed from the end of + generated_tokens (0 when the stop word is at the very end + or when no stop word was found). """ - # Check if request has stop words configured if request.stop_word_ids is None or len(request.stop_word_ids) == 0: - return False + return False, 0 generated_tokens = request.generated_tokens - # Check if the sequence ends with any stop word for stop_word_ids in request.stop_word_ids: stop_len = len(stop_word_ids) if len(generated_tokens) >= stop_len: # Check the last stop_len tokens shifting by 1 up to num_speculative_tokens. - # We do this regardless of stop_len because speculative decoding can append - # multiple tokens at once, meaning the stop word might end at any of those positions. + # Speculative decoding can append multiple tokens at once, so the stop + # word might end at any position within the newly appended tokens. for i in range(self.num_speculative_tokens + 1): end_idx = -i if i > 0 else None if list(generated_tokens[-stop_len - i : end_idx]) == stop_word_ids: - return True - return False + if i > 0: + request.generated_tokens = request.generated_tokens[:-i] + return True, i + return False, 0 def get_prefix_coordination_metrics(self) -> dict: """Return prefix caching coordination metrics. diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 55df5a95334..f1e028b1ecf 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -1165,6 +1165,201 @@ def _dynamic_step_calculate_log_probs(self, logits: Tensor) -> Optional[Tensor]: only_last_token_logits=context.config.materialize_only_last_token_logits, ) + def _dynamic_step_calculate_log_probs_speculative( + self, logits: Tensor + ) -> Tuple[List[List[float]], Tensor]: + """Calculate log probs from logits for speculative decoding. + + For decode requests, computes log probs for each accepted speculative token + and the newly sampled token using the main model logits. For prefill requests, + handles prompt log probs the same way as non-speculative decoding. + + The main model logits at position j predict the token at position j+1. So: + - log_prob(accepted_token[j]) comes from logits at position j + - log_prob(newly_sampled_token) comes from logits at position accepted_count + + Args: + logits (Tensor): The main model logits [1, seq_len, vocab_size]. + + Returns: + Tuple of (log_probs_list, log_probs_tensor): + log_probs_list: List of lists, one per active request, containing + log probs for the tokens emitted in this step. + log_probs_tensor: Full log_softmax tensor for top-n computation. + """ + context = self.inference_wrapped_model.inference_context + active_request_count = context.total_request_count - context.paused_request_count + + request_in_prefill_status_tensor = context.request_in_prefill_status_tensor[ + context.paused_request_count : context.total_request_count + ] + request_query_lengths = context.request_query_lengths[ + context.paused_request_count : context.total_request_count + ] + + num_prefill_requests = request_in_prefill_status_tensor.sum().item() + num_decode_requests = active_request_count - num_prefill_requests + + logits_squeezed = logits.squeeze(0).float() + log_probs_tensor = F.log_softmax(logits_squeezed[: context.active_token_count], dim=-1) + + log_probs_list_decode = [] + + if num_decode_requests > 0: + decode_len = num_decode_requests * (self.num_speculative_tokens + 1) + decode_log_probs = log_probs_tensor[:decode_len].reshape( + num_decode_requests, self.num_speculative_tokens + 1, -1 + ) + accepted_counts = self._accepted_token_counts_per_request[:num_decode_requests] + + # Build a [num_decode, num_spec+1] token ID matrix for gathering. + # Columns 0..num_spec-1 hold accepted speculative tokens (clamped to 0 + # where rejected, since those positions will be masked out). + # At column accepted_count[i], place the newly sampled token. + gather_tokens = torch.zeros( + num_decode_requests, + self.num_speculative_tokens + 1, + device=logits.device, + dtype=torch.long, + ) + gather_tokens[:, : self.num_speculative_tokens] = self._accepted_tokens_per_request[ + :num_decode_requests + ].clamp(min=0) + gather_tokens[ + torch.arange(num_decode_requests, device=logits.device), accepted_counts + ] = self._sampled_tokens_cuda[:num_decode_requests] + + # Gather: [num_decode, num_spec+1] + gathered_log_probs = decode_log_probs.gather(2, gather_tokens.unsqueeze(-1)).squeeze(-1) + + log_probs_list_decode = [ + gathered_log_probs[i, : accepted_counts[i].item() + 1].tolist() + for i in range(num_decode_requests) + ] + + log_probs_list_prefill = [] + if num_prefill_requests > 0: + decode_len = num_decode_requests * (self.num_speculative_tokens + 1) + prefill_log_probs = log_probs_tensor[decode_len:] + + prefill_token_ids = context.token_to_input_ids[ + decode_len : context.active_token_count + ].roll(-1, 0) + prefill_query_lengths = request_query_lengths[request_in_prefill_status_tensor == 1] + new_token_idx = prefill_query_lengths.cumsum(0) - 1 + prefill_new_tokens = self._sampled_tokens_cuda[num_decode_requests:active_request_count] + prefill_token_ids[new_token_idx] = prefill_new_tokens + + prefill_token_count = context.active_token_count - decode_len + seq_idx = torch.arange(prefill_token_count, device=logits.device) + selected_log_probs = prefill_log_probs[seq_idx, prefill_token_ids] + + prefill_log_probs_split = selected_log_probs.cpu().split( + prefill_query_lengths.tolist(), dim=0 + ) + log_probs_list_prefill = [lp.tolist() for lp in prefill_log_probs_split] + + log_probs_list = log_probs_list_decode + log_probs_list_prefill + + return log_probs_list, log_probs_tensor + + def _dynamic_step_calculate_top_n_logprobs_speculative( + self, log_probs_tensor: Tensor + ) -> Optional[Dict[int, List[Tuple[Tensor, Tensor]]]]: + """Calculate top-n log probs for speculative decoding. + + For decode requests, computes top-n at each position that produced an + emitted token (accepted speculative positions + the newly sampled position). + For prefill requests, behaves identically to the non-speculative path. + + Args: + log_probs_tensor (Tensor): Pre-computed log_softmax tensor from + _dynamic_step_calculate_log_probs_speculative. + + Returns: + A dictionary mapping request_idx to list of (top_n_values, top_n_indices) + tuples, one per emitted token position. + """ + context = self.inference_wrapped_model.inference_context + active_request_count = context.total_request_count - context.paused_request_count + active_request_slice = slice(context.paused_request_count, context.total_request_count) + + request_in_prefill_status_tensor = context.request_in_prefill_status_tensor[ + context.paused_request_count : context.total_request_count + ] + request_query_lengths = context.request_query_lengths[ + context.paused_request_count : context.total_request_count + ] + + num_prefill_requests = request_in_prefill_status_tensor.sum().item() + num_decode_requests = active_request_count - num_prefill_requests + + top_n_results = {} + + if num_decode_requests > 0: + decode_len = num_decode_requests * (self.num_speculative_tokens + 1) + decode_log_probs = log_probs_tensor[:decode_len].reshape( + num_decode_requests, self.num_speculative_tokens + 1, -1 + ) + accepted_counts = self._accepted_token_counts_per_request[:num_decode_requests] + top_n_per_request = self._request_metadata["top_n_logprobs"][active_request_slice][ + :num_decode_requests + ] + max_top_n = int(top_n_per_request.max().item()) + + if max_top_n > 0: + + # Single batched topk on GPU: [num_decode, num_spec+1, max_top_n] + topk_results = torch.topk(decode_log_probs, k=max_top_n, dim=-1) + + # Single CPU transfer instead of O(num_decode * num_spec) transfers + topk_values_cpu = topk_results.values.cpu() + topk_indices_cpu = topk_results.indices.cpu() + + for i in range(num_decode_requests): + top_n = int(top_n_per_request[i].item()) + if top_n > 0: + num_valid = accepted_counts[i].item() + 1 + top_n_results[i] = [ + (topk_values_cpu[i, j, :top_n], topk_indices_cpu[i, j, :top_n]) + for j in range(num_valid) + ] + + if num_prefill_requests > 0: + decode_len = num_decode_requests * (self.num_speculative_tokens + 1) + prefill_log_probs = log_probs_tensor[decode_len:] + prefill_query_lengths = request_query_lengths[request_in_prefill_status_tensor == 1] + prefill_log_probs_per_request = prefill_log_probs.split( + prefill_query_lengths.tolist(), dim=0 + ) + + for i in range(num_prefill_requests): + req_idx = num_decode_requests + i + top_n = int( + self._request_metadata["top_n_logprobs"][active_request_slice][req_idx].item() + ) + if top_n > 0: + request_lp = prefill_log_probs_per_request[i] + skip_prompt = bool( + self._request_metadata["skip_prompt_log_probs"][req_idx].item() + ) + + if skip_prompt and request_lp.size(0) > 1: + top_n_logits = torch.topk(request_lp[-1], k=top_n) + top_n_results[req_idx] = [ + (top_n_logits.values.cpu(), top_n_logits.indices.cpu()) + ] + else: + top_n_logits = torch.topk(request_lp, k=top_n, dim=-1) + top_n_values_cpu = top_n_logits.values.cpu() + top_n_indices_cpu = top_n_logits.indices.cpu() + top_n_results[req_idx] = [ + (top_n_values_cpu[t], top_n_indices_cpu[t]) + for t in range(request_lp.size(0)) + ] + + return top_n_results if top_n_results else None + def _dynamic_step_calculate_top_n_logprobs( self, logits: Tensor, log_probs_tensor: Optional[Tensor] = None ) -> Optional[Dict[int, List[Tuple[Tensor, Tensor]]]]: @@ -1408,10 +1603,6 @@ async def async_generate_output_tokens_dynamic_batch( # NOTE [TDE]: This will be moved once CPU and GPU methods are separated. await asyncio.sleep(0) return_log_probs, return_top_n_logprobs = self._dynamic_step_log_probs_bookkeeping() - if self.num_speculative_tokens > 0: - assert ( - return_log_probs == False and return_top_n_logprobs == False - ), "Log probs and top n log probs are not supported with speculative tokens" self._dynamic_step_sample_bookkeeping() @@ -1424,11 +1615,20 @@ async def async_generate_output_tokens_dynamic_batch( log_probs = None top_n_logprobs = None if return_log_probs or return_top_n_logprobs: - log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs(logits) - if return_top_n_logprobs: - top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs( - logits, log_probs_tensor + if self.num_speculative_tokens > 0: + log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs_speculative( + logits ) + if return_top_n_logprobs: + top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs_speculative( + log_probs_tensor + ) + else: + log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs(logits) + if return_top_n_logprobs: + top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs( + logits, log_probs_tensor + ) if skip_bookkeeping: request_bookkeeping = {} diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 310a59bde35..0ac7e78fae3 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -60,7 +60,9 @@ rearrange = None try: - from flash_attn_3.flash_attn_interface import _flash_attn_forward + from flash_attn_3.flash_attn_interface import ( + _flash_attn_forward, + ) from flash_attn_3.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) @@ -71,7 +73,9 @@ if not HAVE_FA3: try: - from flashattn_hopper.flash_attn_interface import _flash_attn_forward + from flashattn_hopper.flash_attn_interface import ( + _flash_attn_forward, + ) from flashattn_hopper.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) diff --git a/tests/unit_tests/inference/contexts/test_dynamic_context.py b/tests/unit_tests/inference/contexts/test_dynamic_context.py index bf7387cd658..e727b2fd81a 100644 --- a/tests/unit_tests/inference/contexts/test_dynamic_context.py +++ b/tests/unit_tests/inference/contexts/test_dynamic_context.py @@ -1997,9 +1997,7 @@ def test_speculative_boundary_crossing_with_prefix_caching(self): new_spec = torch.tensor([[51], [52]], device='cuda') ctx.update_requests( - active_requests_mask=active_mask, - new_tokens=new_tokens, - new_speculative_tokens=new_spec, + active_requests_mask=active_mask, new_tokens=new_tokens, new_speculative_tokens=new_spec ) # A new block should have been allocated for the boundary crossing. diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index 18059417cf8..4ad4380b03b 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -345,7 +345,7 @@ def _build_test_env(cls, test_config): if test_config.num_speculative_tokens > 0: use_te = test_config.fp8 or test_config.transformer_impl == "transformer_engine" mtp_block_spec = get_gpt_mtp_block_spec( - config=transformer_config, spec=layer_spec, use_transformer_engine=use_te, + config=transformer_config, spec=layer_spec, use_transformer_engine=use_te ) # GPT model. @@ -2185,7 +2185,9 @@ def test_speculative_decoding_with_prefix_caching(self): 0, test_config.vocab_size - 1, (8,), dtype=torch.int64, device='cuda' ) - for i, prompt in enumerate([shared_prompt_a, shared_prompt_a, shared_prompt_b, shared_prompt_b]): + for i, prompt in enumerate( + [shared_prompt_a, shared_prompt_a, shared_prompt_b, shared_prompt_b] + ): env.requests[i].prompt_tokens = prompt.clone() # Run all requests through the engine. diff --git a/tests/unit_tests/inference/test_stop_words.py b/tests/unit_tests/inference/test_stop_words.py index 31665c0bb81..95a39e748d5 100644 --- a/tests/unit_tests/inference/test_stop_words.py +++ b/tests/unit_tests/inference/test_stop_words.py @@ -31,132 +31,283 @@ class TestStopWordDetection: """Test stop word detection logic.""" def _check_stop_words_for_request_post_append( - self, request: MockDynamicInferenceRequest - ) -> bool: + self, request: MockDynamicInferenceRequest, num_speculative_tokens: int = 0 + ) -> tuple: """ Check if a request should stop due to stop words (after token is appended). - This mirrors the logic in DynamicInferenceEngine._check_stop_words_for_request_post_append + This mirrors the logic in DynamicInferenceEngine._check_stop_words_for_request_post_append. + Returns (stop_word_hit, num_tokens_trimmed). """ - # Check if request has stop words configured if request.stop_word_ids is None or len(request.stop_word_ids) == 0: - return False + return False, 0 generated_tokens = request.generated_tokens - # Check if the sequence ends with any stop word for stop_word_ids in request.stop_word_ids: stop_len = len(stop_word_ids) if len(generated_tokens) >= stop_len: - # Check if the last stop_len tokens match the stop word - if list(generated_tokens[-stop_len:]) == stop_word_ids: - return True + for i in range(num_speculative_tokens + 1): + end_idx = -i if i > 0 else None + if list(generated_tokens[-stop_len - i : end_idx]) == stop_word_ids: + if i > 0: + request.generated_tokens = request.generated_tokens[:-i] + return True, i - return False + return False, 0 def test_no_stop_words_configured(self): """Test that requests without stop words configured don't trigger stop.""" request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=None ) - assert self._check_stop_words_for_request_post_append(request) is False + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is False + assert trim == 0 def test_empty_stop_words_list(self): """Test that empty stop words list doesn't trigger stop.""" request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=[] ) - assert self._check_stop_words_for_request_post_append(request) is False + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is False def test_single_token_stop_word_match(self): """Test detection of single-token stop word.""" - # Stop word is token 300 request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=[[300]] ) - assert self._check_stop_words_for_request_post_append(request) is True + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is True + assert trim == 0 + assert request.generated_tokens == [100, 200, 300] def test_single_token_stop_word_no_match(self): """Test no detection when single-token stop word doesn't match.""" request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=[[400]] ) - assert self._check_stop_words_for_request_post_append(request) is False + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is False def test_multi_token_stop_word_match(self): """Test detection of multi-token stop word.""" - # Stop word is tokens [200, 300] request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=[[200, 300]] ) - assert self._check_stop_words_for_request_post_append(request) is True + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is True + assert trim == 0 def test_multi_token_stop_word_no_match_partial(self): """Test no detection when only partial stop word matches.""" - # Stop word is [200, 300], but generated ends with [100, 200] request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200], stop_word_ids=[[200, 300]] ) - assert self._check_stop_words_for_request_post_append(request) is False + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is False def test_multi_token_stop_word_no_match_wrong_order(self): """Test no detection when tokens are present but in wrong order.""" - # Stop word is [200, 300], but generated ends with [300, 200] request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 300, 200], stop_word_ids=[[200, 300]] ) - assert self._check_stop_words_for_request_post_append(request) is False + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is False def test_multiple_stop_words_first_matches(self): """Test with multiple stop words where first one matches.""" request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=[[300], [400], [500]] ) - assert self._check_stop_words_for_request_post_append(request) is True + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is True def test_multiple_stop_words_second_matches(self): """Test with multiple stop words where second one matches.""" request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 400], stop_word_ids=[[300], [400], [500]] ) - assert self._check_stop_words_for_request_post_append(request) is True + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is True def test_multiple_stop_words_none_match(self): """Test with multiple stop words where none match.""" request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 600], stop_word_ids=[[300], [400], [500]] ) - assert self._check_stop_words_for_request_post_append(request) is False + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is False def test_stop_word_longer_than_generated(self): """Test that stop word longer than generated tokens doesn't crash.""" - # Stop word is 5 tokens, but only 3 tokens generated request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=[[1, 2, 3, 4, 5]] ) - assert self._check_stop_words_for_request_post_append(request) is False + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is False def test_stop_word_exact_length_match(self): """Test stop word that matches entire generated sequence.""" request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=[[100, 200, 300]] ) - assert self._check_stop_words_for_request_post_append(request) is True + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is True def test_empty_generated_tokens(self): """Test with no generated tokens.""" request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[], stop_word_ids=[[300]] ) - assert self._check_stop_words_for_request_post_append(request) is False + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is False def test_stop_word_in_middle_not_end(self): """Test that stop word in middle of sequence doesn't trigger (only end matters).""" - # Stop word is [200], which is in middle but not at end request = MockDynamicInferenceRequest( request_id=1, generated_tokens=[100, 200, 300], stop_word_ids=[[200]] ) - assert self._check_stop_words_for_request_post_append(request) is False + hit, trim = self._check_stop_words_for_request_post_append(request) + assert hit is False + + +class TestStopWordSpeculativeDecoding: + """Test stop word detection and truncation with speculative decoding.""" + + def _check_stop_words_for_request_post_append( + self, request: MockDynamicInferenceRequest, num_speculative_tokens: int = 0 + ) -> tuple: + """Mirror of DynamicInferenceEngine._check_stop_words_for_request_post_append.""" + if request.stop_word_ids is None or len(request.stop_word_ids) == 0: + return False, 0 + + generated_tokens = request.generated_tokens + + for stop_word_ids in request.stop_word_ids: + stop_len = len(stop_word_ids) + if len(generated_tokens) >= stop_len: + for i in range(num_speculative_tokens + 1): + end_idx = -i if i > 0 else None + if list(generated_tokens[-stop_len - i : end_idx]) == stop_word_ids: + if i > 0: + request.generated_tokens = request.generated_tokens[:-i] + return True, i + + return False, 0 + + def test_stop_word_at_end_no_trim(self): + """Stop word is the last token — no trimming needed.""" + # Speculative tokens: [tok1, STOP, tok3] appended, stop word at end of accepted + # But here STOP is at the very end after all tokens + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[10, 20, 42], stop_word_ids=[[42]] + ) + hit, trim = self._check_stop_words_for_request_post_append( + request, num_speculative_tokens=2 + ) + assert hit is True + assert trim == 0 + assert request.generated_tokens == [10, 20, 42] + + def test_stop_word_with_one_extra_token(self): + """Stop word is second-to-last — one extra token should be trimmed.""" + # Speculative appended [tok1, STOP, tok3], STOP=42 at position -2 + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[10, 20, 42, 99], stop_word_ids=[[42]] + ) + hit, trim = self._check_stop_words_for_request_post_append( + request, num_speculative_tokens=2 + ) + assert hit is True + assert trim == 1 + assert request.generated_tokens == [10, 20, 42] + + def test_stop_word_with_two_extra_tokens(self): + """Stop word is third-to-last — two extra tokens should be trimmed.""" + # Speculative appended [STOP, tok2, tok3], STOP=42 at position -3 + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[10, 42, 77, 88], stop_word_ids=[[42]] + ) + hit, trim = self._check_stop_words_for_request_post_append( + request, num_speculative_tokens=2 + ) + assert hit is True + assert trim == 2 + assert request.generated_tokens == [10, 42] + + def test_multi_token_stop_word_with_extra_tokens(self): + """Multi-token stop word found mid-speculative-batch.""" + # Speculative appended [tok1, STOP_A, STOP_B, tok4], stop word is [STOP_A, STOP_B] + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[10, 20, 42, 43, 99], stop_word_ids=[[42, 43]] + ) + hit, trim = self._check_stop_words_for_request_post_append( + request, num_speculative_tokens=2 + ) + assert hit is True + assert trim == 1 + assert request.generated_tokens == [10, 20, 42, 43] + + def test_multi_token_stop_word_with_two_extra(self): + """Multi-token stop word with two extra tokens after.""" + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[10, 42, 43, 77, 88], stop_word_ids=[[42, 43]] + ) + hit, trim = self._check_stop_words_for_request_post_append( + request, num_speculative_tokens=2 + ) + assert hit is True + assert trim == 2 + assert request.generated_tokens == [10, 42, 43] + + def test_no_stop_word_speculative(self): + """No stop word in speculative batch — nothing happens.""" + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[10, 20, 30, 40], stop_word_ids=[[42]] + ) + hit, trim = self._check_stop_words_for_request_post_append( + request, num_speculative_tokens=2 + ) + assert hit is False + assert trim == 0 + assert request.generated_tokens == [10, 20, 30, 40] + + def test_stop_word_outside_speculative_window(self): + """Stop word exists but is outside the speculative search window.""" + # Stop word [42] is at position -4, but num_speculative_tokens=2 + # so we only check positions -1, -2, -3 (i=0,1,2) + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[42, 10, 20, 30], stop_word_ids=[[42]] + ) + hit, trim = self._check_stop_words_for_request_post_append( + request, num_speculative_tokens=2 + ) + assert hit is False + assert trim == 0 + + def test_log_probs_trimming_scenario(self): + """Verify that the trim count can be used to trim log probs correctly.""" + # Simulate: speculative batch appended [tok1, STOP, tok3] + # Log probs: [lp1, lp2, lp3] + request = MockDynamicInferenceRequest( + request_id=1, generated_tokens=[10, 20, 42, 99], stop_word_ids=[[42]] + ) + log_probs = [-1.5, -0.3, -2.1] + + hit, trim = self._check_stop_words_for_request_post_append( + request, num_speculative_tokens=2 + ) + assert hit is True + assert trim == 1 + + # Trim log probs the same way the engine does + if trim > 0: + log_probs = log_probs[:-trim] + + assert log_probs == [-1.5, -0.3] + assert request.generated_tokens == [10, 20, 42] class TestStopWordTrackingFlow: From e3e5ca264b1ebbf41d79c2f76c9ddd8d676c4dae Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Mon, 9 Mar 2026 22:47:07 -0700 Subject: [PATCH 72/76] Fix dynamic_engine unit tests Signed-off-by: Keshav Santhanam --- .../inference/engines/test_dynamic_engine.py | 104 +++++++++++++----- 1 file changed, 77 insertions(+), 27 deletions(-) diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index 248ef643379..2cbed75a02d 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -1072,7 +1072,7 @@ def test_parallel_inference( if tp_size == 1 and pp_size == 1 and ep_size == 1: pytest.skip(reason="Test requires tp_size > 1 or pp_size > 1 or ep_size > 1") elif not torch.distributed.is_initialized(): - pytest.skip("Distributed not initialized") + Utils.initialize_distributed() world_size = torch.distributed.get_world_size() min_world_size = tp_size * pp_size * ep_size if world_size < min_world_size: @@ -2093,11 +2093,11 @@ def test_speculative_block_boundary_crossing(self): """ test_config = DynamicEngineTestConfig( num_requests=1, - min_prompt_length=4, - max_prompt_length=4, + min_prompt_length=256, + max_prompt_length=256, num_tokens_to_generate=3, num_speculative_tokens=2, - context_block_size_tokens=4, # Exactly matches prompt length + context_block_size_tokens=256, # Exactly matches prompt length context_max_requests=16, model_provider="gpt", materialize_only_last_token_logits=False, @@ -2456,15 +2456,15 @@ def test_speculative_decoding_with_eviction_and_swapping(self): # Very constrained memory environment to force pausing and eviction test_config = DynamicEngineTestConfig( num_requests=3, - min_prompt_length=16, - max_prompt_length=16, - num_tokens_to_generate=32, - context_block_size_tokens=16, + min_prompt_length=256, + max_prompt_length=256, + num_tokens_to_generate=512, + context_block_size_tokens=256, num_speculative_tokens=2, - # 40 KB translates to 3 blocks. + # 640 KB translates to 3 blocks. # 3 requests * 3 blocks per request (1 prompt + 2 gen) = 9 blocks needed. # This guarantees we will run out of active memory mid-generation. - context_buffer_size_gb=0.00004, + context_buffer_size_gb=0.00064, context_paused_buffer_size_gb=0.0, # 0 paused buffer forces immediate eviction model_provider="gpt", materialize_only_last_token_logits=False, @@ -2507,7 +2507,7 @@ def mock_safe_forward(*args, **kwargs): # Since paused_buffer_size is 0, any request that pauses will immediately # overflow the paused buffer and trigger an eviction. for request in env.requests: - request.sampling_params.num_tokens_to_generate = 32 + request.sampling_params.num_tokens_to_generate = 512 env.engine._add_request(request) eviction_occurred = False @@ -2540,7 +2540,7 @@ def mock_safe_forward(*args, **kwargs): merged_req.status == Status.COMPLETED ), f"Request {request_id} failed to complete." assert ( - len(merged_req.generated_tokens) == 31 + len(merged_req.generated_tokens) == 511 ), f"Request {request_id} didn't generate expected tokens." @pytest.mark.internal @@ -2556,25 +2556,50 @@ def test_speculative_decoding_with_prefix_caching(self): """ test_config = DynamicEngineTestConfig( num_requests=0, # Added manually below - min_prompt_length=8, - max_prompt_length=8, - num_tokens_to_generate=4, + min_prompt_length=256, + max_prompt_length=256, + num_tokens_to_generate=128, num_speculative_tokens=2, enable_prefix_caching=True, # Set at config level - context_block_size_tokens=8, # Ensure exact 1 block per prompt + context_block_size_tokens=256, # Ensure exact 1 block per prompt materialize_only_last_token_logits=False, model_provider="gpt", - context_max_tokens=512, + context_max_tokens=4096, context_max_requests=512, ) env = self._build_test_env(test_config) + unwrapped_model = env.engine.controller.inference_wrapped_model.model + + # Mock forward pass to return safe, deterministic logits to avoid NaN/Inf crashes + def mock_safe_forward(*args, **kwargs): + tokens = kwargs.get("tokens", args[0] if args else kwargs.get("input_ids")) + b, s = tokens.shape + + base_logits = torch.zeros( + b, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16 + ) + base_logits[:, :, 0] = 100.0 + + mtp_logits = torch.zeros( + test_config.num_speculative_tokens, + s, + test_config.vocab_size, + device=tokens.device, + dtype=torch.bfloat16, + ) + mtp_logits[:, :, 0] = 100.0 + unwrapped_model._mtp_logits_cache = mtp_logits + return base_logits + + unwrapped_model.forward = mock_safe_forward + # Create two pairs of requests with identical shared prefixes. shared_prompt_a = torch.randint( - 0, test_config.vocab_size - 1, (8,), dtype=torch.int64, device='cuda' + 0, test_config.vocab_size - 1, (256,), dtype=torch.int64, device='cuda' ) shared_prompt_b = torch.randint( - 0, test_config.vocab_size - 1, (8,), dtype=torch.int64, device='cuda' + 0, test_config.vocab_size - 1, (256,), dtype=torch.int64, device='cuda' ) prompts = [shared_prompt_a, shared_prompt_a, shared_prompt_b, shared_prompt_b] @@ -2584,7 +2609,7 @@ def test_speculative_decoding_with_prefix_caching(self): env.engine.add_request( request_id=i, prompt=prompt.clone(), - sampling_params=SamplingParams(num_tokens_to_generate=4, termination_id=99), + sampling_params=SamplingParams(num_tokens_to_generate=128, termination_id=99), ) # First, run schedule_waiting_requests and ONE step to allocate the prefill blocks. @@ -2628,30 +2653,55 @@ def test_speculative_decoding_chunked_prefill_and_prefix_caching(self): """ test_config = DynamicEngineTestConfig( num_requests=0, - min_prompt_length=16, - max_prompt_length=16, - num_tokens_to_generate=4, + min_prompt_length=512, + max_prompt_length=512, + num_tokens_to_generate=128, num_speculative_tokens=2, materialize_only_last_token_logits=False, enable_chunked_prefill=True, enable_prefix_caching=True, # Set at config level - context_block_size_tokens=8, + context_block_size_tokens=256, model_provider="gpt", - context_max_tokens=48, # Force chunking + context_max_tokens=1536, # Force chunking context_max_requests=48, ) env = self._build_test_env(test_config) + unwrapped_model = env.engine.controller.inference_wrapped_model.model + + # Mock forward pass to return safe, deterministic logits to avoid NaN/Inf crashes + def mock_safe_forward(*args, **kwargs): + tokens = kwargs.get("tokens", args[0] if args else kwargs.get("input_ids")) + b, s = tokens.shape + + base_logits = torch.zeros( + b, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16 + ) + base_logits[:, :, 0] = 100.0 + + mtp_logits = torch.zeros( + test_config.num_speculative_tokens, + s, + test_config.vocab_size, + device=tokens.device, + dtype=torch.bfloat16, + ) + mtp_logits[:, :, 0] = 100.0 + unwrapped_model._mtp_logits_cache = mtp_logits + return base_logits + + unwrapped_model.forward = mock_safe_forward + # Create identical prompts for all 4 requests shared_prompt = torch.randint( - 0, test_config.vocab_size - 1, (16,), dtype=torch.int64, device='cuda' + 0, test_config.vocab_size - 1, (512,), dtype=torch.int64, device='cuda' ) for i in range(4): env.engine.add_request( request_id=i, prompt=shared_prompt.clone(), - sampling_params=SamplingParams(num_tokens_to_generate=4, termination_id=99), + sampling_params=SamplingParams(num_tokens_to_generate=128, termination_id=99), ) while env.engine.has_unfinished_requests(): From 0ecfb4ed6f26aa46715f1069f948b2d93e344d81 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Tue, 10 Mar 2026 02:20:43 -0700 Subject: [PATCH 73/76] Fix speculative decoding Signed-off-by: Keshav Santhanam --- .../inference/contexts/dynamic_context.py | 23 +- .../text_generation_controller.py | 271 +++++++++++++++--- megatron/core/models/gpt/gpt_model.py | 73 ++++- megatron/core/models/mamba/mamba_model.py | 76 ++++- .../transformer/multi_token_prediction.py | 47 +++ .../contexts/test_dynamic_context.py | 21 +- .../inference/engines/test_dynamic_engine.py | 268 +++++++++-------- 7 files changed, 567 insertions(+), 212 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 7975e32e677..98996b0cbfd 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1365,26 +1365,35 @@ def add_dummy_requests_for_expert_parallel_step(self) -> None: assert smallest_cuda_graph_dimensions.prefill_req_count == 0 N = smallest_cuda_graph_dimensions.decode_req_count + tokens_per_request = self.num_speculative_tokens + 1 + T = smallest_cuda_graph_dimensions.token_count # N * tokens_per_request dummy_block_idx = self.block_allocator.dummy_block_idx - # 1. Request counts and token count (decode-only: 1 token per request). + # 1. Request counts and token count. + # With speculative decoding each decode request has (num_speculative_tokens + 1) tokens. self.total_request_count = N - self.active_token_count = N + self.active_token_count = T self.num_prefill_requests = 0 # 2. Per-request state consumed by mha_metadata.update(). - self.request_query_lengths[0:N].fill_(1) + self.request_query_lengths[0:N].fill_(tokens_per_request) self.request_kv_length_offsets[0:N].fill_(0) self.request_to_kv_block_ids[0:N, 0] = dummy_block_idx # 3. Token-level state consumed by the triton KV append kernel. - self.token_to_block_idx[0:N] = dummy_block_idx - self.token_to_local_position_within_kv_block[0:N] = 0 + self.token_to_block_idx[0:T] = dummy_block_idx + self.token_to_local_position_within_kv_block[0:T] = 0 if self.is_hybrid_model: # 4. token_to_request_idx: needed by mamba_metadata.update() for hybrid models. - self.token_to_request_idx[0:N] = torch.arange( - 0, N, device=self.token_to_request_idx.device, dtype=self.token_to_request_idx.dtype + self.token_to_request_idx[0:T] = torch.repeat_interleave( + torch.arange( + 0, + N, + device=self.token_to_request_idx.device, + dtype=self.token_to_request_idx.dtype, + ), + tokens_per_request, ) # 5. Mamba state: allocate slots for dummy requests. diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 4f1f6356875..479a0030c76 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -122,6 +122,8 @@ def _init_dynamic_sampling_tensors(self): self._accepted_tokens_per_request = None # MTP tensor will be allocated later when num_speculative_tokens is set by the engine self._sampled_mtp_tokens_cuda = None + # Last accepted sequence indices for serial MTP computation + self._last_accepted_seq_indices = None # Keep track of request metadata. self._request_metadata: Dict[str, Tensor] = {} @@ -624,29 +626,19 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) logits = self.inference_wrapped_model.run_one_forward_step( {"tokens": input_ids, "position_ids": position_ids, "attention_mask": None} ) - # [1, seq_len, vocab_size] (logits) - # [num_speculative_tokens, seq_len, vocab_size] (mtp_logits) + # logits shape: [1, seq_len, vocab_size] - if self.num_speculative_tokens > 0: - unwrapped_model = unwrap_model(self.inference_wrapped_model.model) - assert hasattr(unwrapped_model, '_mtp_logits_cache'), "MTP logits cache not found" - mtp_logits = unwrapped_model._mtp_logits_cache - expected_mtp_logits_length, _, vocab_size = mtp_logits.shape - assert ( - expected_mtp_logits_length == self.num_mtp_heads - ), f"MTP logits length mismatch. Expected mtp logits length {self.num_mtp_heads}, got {expected_mtp_logits_length}" - mtp_logits = mtp_logits[: self.num_speculative_tokens] - - logits = torch.cat( - [logits, mtp_logits], dim=0 - ) # [num_speculative_tokens + 1, seq_len_or_required, vocab_size] + # Note: When speculative decoding is active (num_speculative_tokens > 0), + # the model skips MTP computation during the forward pass. MTP logits + # will be computed serially after verification to ensure they are + # conditioned on verified tokens only. if self.model_is_pipeline_parallel: if context.config.materialize_only_last_token_logits: logits_seq_len = active_request_count else: logits_seq_len = input_ids.shape[1] - logits_shape = [self.num_speculative_tokens + 1, logits_seq_len, self.vocab_size] + logits_shape = [1, logits_seq_len, self.vocab_size] if is_pipeline_last_stage(self.pp_group): assert logits is not None and torch.Size(logits_shape) == logits.shape @@ -793,6 +785,105 @@ def _rewind_kv_cache(self): ] ) + def _sample_from_logits_2d(self, logits_2d: Tensor) -> Tensor: + """Sample tokens from 2D logits using existing sampling parameters. + + Args: + logits_2d (Tensor): Logits of shape [num_requests, vocab_size]. + + Returns: + Tensor: Sampled tokens of shape [num_requests]. + """ + spec_token_list = [] + indices_list = [] + for request_indices, temp, top_k, top_p in self._torch_sampling_buckets: + request_indices_tensor = torch.tensor( + request_indices, device=logits_2d.device, dtype=torch.long + ) + spec_token_list.append( + self._torch_sampling_func(logits_2d[request_indices_tensor, :], temp, top_k, top_p) + ) + indices_list.append(request_indices_tensor) + + spec_tokens = torch.empty(logits_2d.shape[0], device=logits_2d.device, dtype=torch.int64) + for tokens, indices in zip(spec_token_list, indices_list): + spec_tokens[indices] = tokens + return spec_tokens + + def _compute_serial_mtp_and_sample(self): + """Compute MTP logits serially after verification and sample speculative tokens. + + This ensures that MTP predictions are always conditioned on verified tokens. + Each MTP depth receives the correctly sampled token from the previous depth + (or the base token for depth 0) rather than stale speculative tokens from + the previous step. + """ + context = self.inference_wrapped_model.inference_context + active_request_count = context.total_request_count - context.paused_request_count + active_slice = slice(context.paused_request_count, context.total_request_count) + + unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + + # On non-last pipeline stages, the model won't have decoder hidden states. + has_mtp = is_pipeline_last_stage(self.pp_group) and hasattr( + unwrapped_model, '_decoder_hidden_states_cache' + ) + + if has_mtp: + # Get decoder hidden states at last accepted positions. + hidden_states = unwrapped_model._decoder_hidden_states_cache + last_accepted_hidden = hidden_states[self._last_accepted_seq_indices, :, :] + # Shape: [active_request_count, 1, hidden_size] + else: + last_accepted_hidden = None + + # Compute position IDs for the next tokens. + # After rewind, request_kv_length_offsets has been adjusted. The actual + # KV cache length is: adjusted_offset + (1 + num_speculative_tokens). + # The next position to predict starts at that cache length. + adjusted_offsets = context.request_kv_length_offsets[active_slice] + base_position = adjusted_offsets + (1 + self.num_speculative_tokens) + + # Start with the freshly sampled base token. + next_token_ids = self._sampled_tokens_cuda[:active_request_count].clone() + current_hidden = last_accepted_hidden if has_mtp else None + + num_depths = min(self.num_speculative_tokens, self.num_mtp_heads) + for depth in range(num_depths): + position_ids = (base_position + depth).unsqueeze(0) # [1, active_request_count] + token_ids = next_token_ids.unsqueeze(0) # [1, active_request_count] + + mtp_logits_2d = None + if has_mtp: + current_hidden, mtp_logits = unwrapped_model.compute_mtp_single_step( + hidden_states=current_hidden, + next_token_ids=token_ids, + position_ids=position_ids, + depth=depth, + ) + # mtp_logits: [active_request_count, 1, vocab_size] + mtp_logits_2d = mtp_logits.squeeze(1) # [active_request_count, vocab_size] + + # Broadcast MTP logits across pipeline stages. + if self.model_is_pipeline_parallel: + mtp_logits_2d = broadcast_from_last_pipeline_stage( + [active_request_count, self.vocab_size], + dtype=self.model_config.params_dtype, + tensor=mtp_logits_2d, + pp_group=self.pp_group, + ) + + # Sample speculative token using the same sampling parameters. + spec_tokens = self._sample_from_logits_2d(mtp_logits_2d) + self._sampled_mtp_tokens_cuda[depth, :active_request_count] = spec_tokens + + # Use sampled token as input for the next depth. + next_token_ids = spec_tokens + + # Clean up cached hidden states. + if has_mtp: + del unwrapped_model._decoder_hidden_states_cache + def _get_required_logit_indices( self, request_in_prefill_status_tensor: Tensor, @@ -873,6 +964,8 @@ def _sample_speculative_logits( mtp_output_tokens_jumbled_list = [] token_order_list = [] + has_mtp_logits = required_mtp_logits is not None + for request_indices, temp, top_k, top_p in self._torch_sampling_buckets: request_indices_tensor = torch.tensor( request_indices, device=token_to_request_index.device @@ -883,12 +976,13 @@ def _sample_speculative_logits( output_tokens_jumbled_list.append( self._torch_sampling_func(required_logits[required_indices, :], temp, top_k, top_p) ) - mtp_logits_slice = required_mtp_logits[:, required_indices, :] - num_spec, num_reqs, vocab = mtp_logits_slice.shape - sampled_mtp = self._torch_sampling_func( - mtp_logits_slice.reshape(num_spec * num_reqs, vocab), temp, top_k, top_p - ) - mtp_output_tokens_jumbled_list.append(sampled_mtp.reshape(num_spec, num_reqs)) + if has_mtp_logits: + mtp_logits_slice = required_mtp_logits[:, required_indices, :] + num_spec, num_reqs, vocab = mtp_logits_slice.shape + sampled_mtp = self._torch_sampling_func( + mtp_logits_slice.reshape(num_spec * num_reqs, vocab), temp, top_k, top_p + ) + mtp_output_tokens_jumbled_list.append(sampled_mtp.reshape(num_spec, num_reqs)) token_order_list.append(required_indices) output_tokens_jumbled = torch.cat(output_tokens_jumbled_list, dim=0) @@ -901,11 +995,13 @@ def _sample_speculative_logits( # Rearrange output tokens from sampling_bucket request order back to input ids order output_tokens[token_order] = output_tokens_jumbled - mtp_output_tokens_jumbled = torch.cat( - mtp_output_tokens_jumbled_list, dim=1 - ) # Shape [num_speculative_tokens, total_tokens] - mtp_output_tokens = torch.empty_like(mtp_output_tokens_jumbled) - mtp_output_tokens[:, token_order] = mtp_output_tokens_jumbled + mtp_output_tokens = None + if has_mtp_logits: + mtp_output_tokens_jumbled = torch.cat( + mtp_output_tokens_jumbled_list, dim=1 + ) # Shape [num_speculative_tokens, total_tokens] + mtp_output_tokens = torch.empty_like(mtp_output_tokens_jumbled) + mtp_output_tokens[:, token_order] = mtp_output_tokens_jumbled return output_tokens, mtp_output_tokens, repeats @@ -1024,11 +1120,13 @@ def _dynamic_step_sample_logits_and_verify_tokens( required_logits = logits.squeeze(0)[ required_logit_indices, : ] # Shape [num_required, vocab_size] - required_mtp_logits = mtp_logits[ - :, required_logit_indices, : - ] # Shape [num_speculative_tokens, num_required, vocab_size] + required_mtp_logits = None + if mtp_logits is not None: + required_mtp_logits = mtp_logits[ + :, required_logit_indices, : + ] # Shape [num_speculative_tokens, num_required, vocab_size] - # Sample tokens from logits and MTP logits. + # Sample tokens from logits (and MTP logits if provided). output_tokens, mtp_output_tokens, repeats = self._sample_speculative_logits( required_logits, required_mtp_logits, request_in_prefill_status_tensor ) @@ -1047,12 +1145,19 @@ def _dynamic_step_sample_logits_and_verify_tokens( ) ) - # Store the final sampled tokens and MTP tokens for the next forward pass. + # Store the final sampled tokens for the next forward pass. final_sampled_tokens = output_tokens[last_one_indices] self._sampled_tokens_cuda[: len(final_sampled_tokens)] = final_sampled_tokens - self._sampled_mtp_tokens_cuda[:, : len(final_sampled_tokens)] = mtp_output_tokens[ - :, last_one_indices - ] + + # Store MTP tokens if they were computed inline (non-serial path). + if mtp_output_tokens is not None: + self._sampled_mtp_tokens_cuda[:, : len(final_sampled_tokens)] = mtp_output_tokens[ + :, last_one_indices + ] + + # Store the last accepted positions in the packed sequence for serial + # MTP computation after verification. + self._last_accepted_seq_indices = required_logit_indices[last_one_indices] # Extract accepted tokens and counts for decode requests. # For prefill it is always set to 1. For decode, the first token is always accepted, @@ -1300,7 +1405,9 @@ def dummy_forward(self): model_config = get_model_config(unwrapped_model) if model_config.transformer_impl == "inference_optimized": context.maybe_initialize_symmetric_memory() - return self.inference_wrapped_model.dummy_forward() + self.inference_wrapped_model.dummy_forward() + self._dummy_serial_mtp_forward() + return # attempt to use cuda-graph if possible input_ids, position_ids = self._dynamic_step_context_init(is_dummy_forward=True) @@ -1316,9 +1423,78 @@ def dummy_forward(self): # fallback to eager dummy forward self.inference_wrapped_model.dummy_forward() + # Disable MoE padding for MTP computation + if self.model_config.moe_pad_experts_for_cuda_graph_inference: + unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + set_decode_expert_padding(unwrapped_model, False) + + # When speculative decoding is active, the real EP ranks perform serial + # MTP forward passes after the main forward pass. MTP layers may contain + # MoE sublayers (inherited from the decoder spec), which require EP + # all-to-all collectives. The dummy rank must participate in these + # collectives to avoid a hang. + self._dummy_serial_mtp_forward() + # clear the context of any temporary state from the dummy forward context.reset() + def _dummy_serial_mtp_forward(self): + """Run dummy MTP forward passes to participate in EP collectives. + + When speculative decoding is active and MTP layers contain MoE sublayers + (inherited from the decoder layer spec), each serial MTP step triggers + EP all-to-all collectives. The dummy EP rank must issue matching + collective calls so the real ranks do not hang. + + This mirrors the structure of ``_compute_serial_mtp_and_sample``: + - On the last PP stage (where MTP resides): run ``compute_mtp_single_step`` + with dummy tensors so the MoE all-to-all is executed. + - When PP > 1: participate in the ``broadcast_from_last_pipeline_stage`` + that the real ranks also perform. + """ + if self.num_speculative_tokens == 0 or self.num_mtp_heads == 0: + return + if self.model_config.expert_model_parallel_size <= 1: + return + + unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + + is_last_stage = is_pipeline_last_stage(self.pp_group) + has_mtp = is_last_stage and hasattr(unwrapped_model, 'mtp') + if not has_mtp and not self.model_is_pipeline_parallel: + # No MTP on this rank and no PP broadcast to participate in. + return + + device = torch.cuda.current_device() + dtype = self.model_config.params_dtype + hidden_size = self.model_config.hidden_size + num_depths = min(self.num_speculative_tokens, self.num_mtp_heads) + + dummy_hidden = None + if has_mtp: + # Minimal dummy tensors — just enough to drive the MTP layer forward + # so that the MoE all-to-all collectives are issued. + dummy_hidden = torch.zeros((1, 1, hidden_size), device=device, dtype=dtype) + dummy_token_ids = torch.zeros((1, 1), device=device, dtype=torch.long) + dummy_position_ids = torch.zeros((1, 1), device=device, dtype=torch.long) + + for depth in range(num_depths): + mtp_logits_2d = None + if has_mtp: + dummy_hidden, mtp_logits = unwrapped_model.compute_mtp_single_step( + hidden_states=dummy_hidden, + next_token_ids=dummy_token_ids, + position_ids=dummy_position_ids, + depth=depth, + ) + mtp_logits_2d = mtp_logits.squeeze(1) # [1, vocab_size] + + # Match the PP broadcast that real ranks do in _compute_serial_mtp_and_sample. + if self.model_is_pipeline_parallel: + broadcast_from_last_pipeline_stage( + [1, self.vocab_size], dtype=dtype, tensor=mtp_logits_2d, pp_group=self.pp_group + ) + def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]: """Update the dynamic inference context after sampling. @@ -1427,13 +1603,9 @@ async def async_generate_output_tokens_dynamic_batch( if config.moe_enable_routing_replay: RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD) - logits_and_mtp_logits = self._dynamic_step_forward_logits(input_ids, position_ids) - mtp_logits = None - if logits_and_mtp_logits.shape[0] > 1: - logits = logits_and_mtp_logits[:1] # [1, seq_len, vocab_size] - mtp_logits = logits_and_mtp_logits[1:] # [num_speculative_tokens, seq_len, vocab_size] - else: - logits = logits_and_mtp_logits + # Forward pass produces only base logits. When speculative decoding is + # active, MTP logits are computed serially after verification. + logits = self._dynamic_step_forward_logits(input_ids, position_ids) # Collect routing indices per request (must be done before context transitions) routing_indices_per_request = self._router_record_bookkeeping() @@ -1455,8 +1627,19 @@ async def async_generate_output_tokens_dynamic_batch( self._dynamic_step_sample_bookkeeping() if self.num_speculative_tokens > 0: - self._dynamic_step_sample_logits_and_verify_tokens(logits, mtp_logits, input_ids) + # Phase 1: Verify speculative tokens using base logits only. + # MTP logits are NOT passed here; they will be computed serially. + self._dynamic_step_sample_logits_and_verify_tokens(logits, None, input_ids) + # Phase 2: Rewind KV cache for rejected tokens. self._rewind_kv_cache() + + # Disable MoE padding for MTP computation + if self.model_config.moe_pad_experts_for_cuda_graph_inference: + unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + set_decode_expert_padding(unwrapped_model, False) + + # Phase 3: Compute MTP serially with correct (verified) inputs. + self._compute_serial_mtp_and_sample() else: self._dynamic_step_sample_logits(logits) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 716362061c2..2de628f1f8e 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -590,11 +590,20 @@ def _postprocess( if in_inference_mode: assert runtime_gather_output, "Inference must always gather TP logits" + # Check if speculative decoding is active. When it is, MTP must be + # computed *after* verification so that it is conditioned on verified + # tokens rather than stale speculative tokens from the previous step. + is_spec_decode = ( + in_inference_mode + and hasattr(inference_context, 'num_speculative_tokens') + and inference_context.num_speculative_tokens > 0 + ) + # logits and loss output_weight = None if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() - if mtp_in_postprocess: + if mtp_in_postprocess and not is_spec_decode: hidden_states = self.mtp( input_ids=input_ids, position_ids=position_ids, @@ -618,13 +627,18 @@ def _postprocess( # The new process_mtp_loss function doesn't handle mtp_logits_cache, # so we manually generate and cache MTP logits when in inference mode. if in_inference_mode: - hidden_states, self._mtp_logits_cache = compute_mtp_inference_logits( - hidden_states=hidden_states, - mtp_num_layers=self.config.mtp_num_layers, - output_layer=self.output_layer, - output_weight=output_weight, - runtime_gather_output=runtime_gather_output, - ) + if is_spec_decode: + # Cache decoder hidden states for serial MTP computation + # after speculative token verification. + self._decoder_hidden_states_cache = hidden_states + else: + hidden_states, self._mtp_logits_cache = compute_mtp_inference_logits( + hidden_states=hidden_states, + mtp_num_layers=self.config.mtp_num_layers, + output_layer=self.output_layer, + output_weight=output_weight, + runtime_gather_output=runtime_gather_output, + ) else: # In training/eval, use the utility function for processing MTP loss/scaling. hidden_states = process_mtp_loss( @@ -698,6 +712,49 @@ def _postprocess( return loss + @torch.inference_mode() + def compute_mtp_single_step( + self, + hidden_states: Tensor, + next_token_ids: Tensor, + position_ids: Tensor, + depth: int, + runtime_gather_output: bool = True, + ) -> tuple: + """Compute a single MTP depth for speculative decoding. + + This is called after speculative token verification to compute MTP + predictions conditioned on verified tokens only. + + Args: + hidden_states (Tensor): Hidden states at last accepted positions [N, 1, H]. + next_token_ids (Tensor): Correct next token IDs [1, N]. + position_ids (Tensor): Position IDs for the next tokens [1, N]. + depth (int): MTP depth index (0-indexed). + runtime_gather_output (bool): Whether to gather output across TP. + + Returns: + tuple: (new_hidden_states [N, 1, H], logits [N, 1, vocab_size]). + """ + layer_idx = 0 if self.mtp.mtp_use_repeated_layer else depth + mtp_hidden = self.mtp.layers[layer_idx].forward_single_position( + hidden_states=hidden_states, + next_token_ids=next_token_ids, + position_ids=position_ids, + embedding=self.embedding, + ) + + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + + logits, _ = self.output_layer( + mtp_hidden, weight=output_weight, runtime_gather_output=runtime_gather_output + ) + logits = self._scale_logits(logits) + + return mtp_hidden, logits + def build_schedule_plan( self, input_ids: Tensor, diff --git a/megatron/core/models/mamba/mamba_model.py b/megatron/core/models/mamba/mamba_model.py index 2da8c31d14e..65c371ef0fd 100644 --- a/megatron/core/models/mamba/mamba_model.py +++ b/megatron/core/models/mamba/mamba_model.py @@ -3,6 +3,7 @@ import logging from typing import Literal, Optional +import torch from torch import Tensor from megatron.core import tensor_parallel @@ -387,7 +388,16 @@ def forward( if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() - mtp_forward_ran = self.mtp_process + # Check if speculative decoding is active. When it is, MTP must be + # computed *after* verification so that it is conditioned on verified + # tokens rather than stale speculative tokens from the previous step. + is_spec_decode = ( + in_inference_mode + and hasattr(inference_context, 'num_speculative_tokens') + and inference_context.num_speculative_tokens > 0 + ) + + mtp_forward_ran = self.mtp_process and not is_spec_decode if mtp_forward_ran: hidden_states = self.mtp( input_ids=input_ids, @@ -403,18 +413,23 @@ def forward( if not self.post_process: return hidden_states - if self.config.mtp_num_layers is not None and mtp_forward_ran: + if self.config.mtp_num_layers is not None and (mtp_forward_ran or is_spec_decode): assert self.config.mtp_num_layers > 0 # The new process_mtp_loss function doesn't handle mtp_logits_cache, # so we manually generate and cache MTP logits when in inference mode. if in_inference_mode: - hidden_states, self._mtp_logits_cache = compute_mtp_inference_logits( - hidden_states=hidden_states, - mtp_num_layers=self.config.mtp_num_layers, - output_layer=self.output_layer, - output_weight=output_weight, - runtime_gather_output=runtime_gather_output, - ) + if is_spec_decode: + # Cache decoder hidden states for serial MTP computation + # after speculative token verification. + self._decoder_hidden_states_cache = hidden_states + else: + hidden_states, self._mtp_logits_cache = compute_mtp_inference_logits( + hidden_states=hidden_states, + mtp_num_layers=self.config.mtp_num_layers, + output_layer=self.output_layer, + output_weight=output_weight, + runtime_gather_output=runtime_gather_output, + ) else: hidden_states = process_mtp_loss( hidden_states=hidden_states, @@ -471,3 +486,46 @@ def forward( loss = self.compute_language_model_loss(labels, logits) return loss + + @torch.inference_mode() + def compute_mtp_single_step( + self, + hidden_states: Tensor, + next_token_ids: Tensor, + position_ids: Tensor, + depth: int, + runtime_gather_output: bool = True, + ) -> tuple: + """Compute a single MTP depth for speculative decoding. + + This is called after speculative token verification to compute MTP + predictions conditioned on verified tokens only. + + Args: + hidden_states (Tensor): Hidden states at last accepted positions [N, 1, H]. + next_token_ids (Tensor): Correct next token IDs [1, N]. + position_ids (Tensor): Position IDs for the next tokens [1, N]. + depth (int): MTP depth index (0-indexed). + runtime_gather_output (bool): Whether to gather output across TP. + + Returns: + tuple: (new_hidden_states [N, 1, H], logits [N, 1, vocab_size]). + """ + layer_idx = 0 if self.mtp.mtp_use_repeated_layer else depth + mtp_hidden = self.mtp.layers[layer_idx].forward_single_position( + hidden_states=hidden_states, + next_token_ids=next_token_ids, + position_ids=position_ids, + embedding=self.embedding, + ) + + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + + logits, _ = self.output_layer( + mtp_hidden, weight=output_weight, runtime_gather_output=runtime_gather_output + ) + logits = self._scale_logits(logits) + + return mtp_hidden, logits diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index 29dd8fef986..3426d83b7b2 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -1036,6 +1036,53 @@ def _postprocess(self, hidden_states: torch.Tensor): return hidden_states + def forward_single_position( + self, + hidden_states: Tensor, + next_token_ids: Tensor, + position_ids: Tensor, + embedding: Callable, + attention_mask: Optional[Tensor] = None, + rotary_pos_emb: Optional[Tensor] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + inference_params=None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[Tensor] = None, + ) -> Tensor: + """Forward for single positions without roll_tensor (speculative decoding). + + Unlike the regular forward which rolls input_ids to get the next token's + embedding, this method directly takes the correct next_token_ids. This is + used in speculative decoding where the correct next token is known after + verification. + + Args: + hidden_states (Tensor): Hidden states at positions of interest [N, B, H]. + next_token_ids (Tensor): The correct next token IDs [B, N]. + position_ids (Tensor): Position IDs for the next tokens [B, N]. + embedding (Callable): The embedding module. + + Returns: + Tensor: MTP hidden states [N, B, H]. + """ + decoder_input = embedding(input_ids=next_token_ids, position_ids=position_ids) + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=False, keep_graph=False + ) + hidden_states = self._proj_and_transformer_layer( + hidden_states=hidden_states, + decoder_input=decoder_input, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) + return hidden_states + def _checkpointed_forward(self, forward_func, *args, **kwargs): def checkpoint_handler(): """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" diff --git a/tests/unit_tests/inference/contexts/test_dynamic_context.py b/tests/unit_tests/inference/contexts/test_dynamic_context.py index 40379d05163..72ce2878c01 100644 --- a/tests/unit_tests/inference/contexts/test_dynamic_context.py +++ b/tests/unit_tests/inference/contexts/test_dynamic_context.py @@ -61,6 +61,7 @@ def _get_dynamic_context( layer_type_list=None, paused_buffer_size_gb=None, num_cuda_graphs=None, + num_speculative_tokens=0, ): if is_hybrid_model: if layer_type_list is None: @@ -94,6 +95,7 @@ def _get_dynamic_context( ), block_size_tokens=block_size_tokens, max_tokens=max_tokens, + num_speculative_tokens=num_speculative_tokens, mamba_inference_state_config=mamba_inference_state_config, use_flashinfer_fused_rope=None, # default to using flash-infer if available # this is for compatibility with the LTS environment @@ -1416,8 +1418,9 @@ def test_max_requests_less_than_tp_size(self): @rounder_override(64) @pytest.mark.parametrize("is_hybrid_model", [False, True]) @pytest.mark.parametrize("num_cuda_graphs", [-1, 16, 32]) + @pytest.mark.parametrize("num_speculative_tokens", [0, 3]) def test_add_dummy_requests_for_expert_parallel_step_matches_slow_path( - self, is_hybrid_model: bool, num_cuda_graphs: int + self, is_hybrid_model: bool, num_cuda_graphs: int, num_speculative_tokens: int ): """The fast path (add_dummy_requests_for_expert_parallel_step) must leave the same observable state as the slow path @@ -1441,10 +1444,12 @@ def test_add_dummy_requests_for_expert_parallel_step_matches_slow_path( else None ), num_cuda_graphs=num_cuda_graphs, + num_speculative_tokens=num_speculative_tokens, ) smallest = min(ctx.cuda_graph_batch_dimensions_list) N = smallest.decode_req_count + T = smallest.token_count # N * (num_speculative_tokens + 1) assert smallest.prefill_req_count == 0, "smallest graph must be decode-only" # --- slow path (reference) --- @@ -1456,10 +1461,10 @@ def test_add_dummy_requests_for_expert_parallel_step_matches_slow_path( slow_request_query_lengths = ctx.request_query_lengths[:N].clone() slow_request_kv_length_offsets = ctx.request_kv_length_offsets[:N].clone() slow_request_to_kv_block_ids_col0 = ctx.request_to_kv_block_ids[:N, 0].clone() - slow_token_to_block_idx = ctx.token_to_block_idx[:N].clone() - slow_token_to_local_pos = ctx.token_to_local_position_within_kv_block[:N].clone() + slow_token_to_block_idx = ctx.token_to_block_idx[:T].clone() + slow_token_to_local_pos = ctx.token_to_local_position_within_kv_block[:T].clone() if is_hybrid_model: - slow_token_to_request_idx = ctx.token_to_request_idx[:N].clone() + slow_token_to_request_idx = ctx.token_to_request_idx[:T].clone() slow_mamba = ctx.mamba_metadata.request_to_mamba_state_idx[:N].clone() # --- reset and run fast path --- @@ -1478,13 +1483,13 @@ def test_add_dummy_requests_for_expert_parallel_step_matches_slow_path( # 3. Token-level state dummy_block_idx = ctx.block_allocator.dummy_block_idx - assert torch.all(ctx.token_to_block_idx[:N] == dummy_block_idx) - assert torch.equal(ctx.token_to_block_idx[:N], slow_token_to_block_idx) - assert torch.equal(ctx.token_to_local_position_within_kv_block[:N], slow_token_to_local_pos) + assert torch.all(ctx.token_to_block_idx[:T] == dummy_block_idx) + assert torch.equal(ctx.token_to_block_idx[:T], slow_token_to_block_idx) + assert torch.equal(ctx.token_to_local_position_within_kv_block[:T], slow_token_to_local_pos) if is_hybrid_model: # 4. token_to_request_idx - assert torch.equal(ctx.token_to_request_idx[:N], slow_token_to_request_idx) + assert torch.equal(ctx.token_to_request_idx[:T], slow_token_to_request_idx) # 5. Mamba state slots allocated (indices may differ, but must be valid and unique) fast_mamba = ctx.mamba_metadata.request_to_mamba_state_idx[:N] diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index 2cbed75a02d..c0ecd7d2107 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -1072,7 +1072,7 @@ def test_parallel_inference( if tp_size == 1 and pp_size == 1 and ep_size == 1: pytest.skip(reason="Test requires tp_size > 1 or pp_size > 1 or ep_size > 1") elif not torch.distributed.is_initialized(): - Utils.initialize_distributed() + pytest.skip("Distributed not initialized") world_size = torch.distributed.get_world_size() min_world_size = tp_size * pp_size * ep_size if world_size < min_world_size: @@ -2040,6 +2040,8 @@ def test_speculative_decoding_with_early_termination(self): unwrapped_model = env.engine.controller.inference_wrapped_model.model # Mock forward to return deterministic data so speculative tokens are always accepted + hidden_size = unwrapped_model.config.hidden_size + def mock_mtp_forward(*args, **kwargs): tokens = kwargs.get("tokens", args[0] if args else kwargs.get("input_ids")) @@ -2052,17 +2054,24 @@ def mock_mtp_forward(*args, **kwargs): ) base_logits[:, :, 0] = 100.0 # High probability for token 0 - unwrapped_model._mtp_logits_cache = torch.zeros( - 3, - tokens.size(1), - test_config.vocab_size, - device=tokens.device, - dtype=torch.bfloat16, + # Cache hidden states for serial MTP computation + unwrapped_model._decoder_hidden_states_cache = torch.zeros( + tokens.size(1), 1, hidden_size, device=tokens.device, dtype=torch.bfloat16 ) - unwrapped_model._mtp_logits_cache[:, :, 0] = 100.0 # High probability for token 0 return base_logits + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth, runtime_gather_output=True + ): + n = hidden_states.size(0) + logits = torch.zeros( + n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16 + ) + logits[:, :, 0] = 100.0 # High probability for token 0 + return hidden_states, logits + unwrapped_model.forward = mock_mtp_forward + unwrapped_model.compute_mtp_single_step = mock_compute_mtp_single_step env.engine._add_request(env.requests[0]) env.engine.schedule_waiting_requests() @@ -2093,11 +2102,11 @@ def test_speculative_block_boundary_crossing(self): """ test_config = DynamicEngineTestConfig( num_requests=1, - min_prompt_length=256, - max_prompt_length=256, + min_prompt_length=4, + max_prompt_length=4, num_tokens_to_generate=3, num_speculative_tokens=2, - context_block_size_tokens=256, # Exactly matches prompt length + context_block_size_tokens=4, # Exactly matches prompt length context_max_requests=16, model_provider="gpt", materialize_only_last_token_logits=False, @@ -2157,6 +2166,7 @@ def test_speculative_stop_word_hit(self): env = self._build_test_env(test_config) unwrapped_model = env.engine.controller.inference_wrapped_model.model + hidden_size = unwrapped_model.config.hidden_size # Mock forward to deterministically output an ascending sequence (1->2->3...) def mock_deterministic_forward(*args, **kwargs): @@ -2169,19 +2179,26 @@ def mock_deterministic_forward(*args, **kwargs): next_toks = (tokens + 1).clamp(max=test_config.vocab_size - 1) base_logits.scatter_(2, next_toks.unsqueeze(-1), 100.0) - mtp_logits = torch.zeros( - 2, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16 + # Cache hidden states for serial MTP computation + unwrapped_model._decoder_hidden_states_cache = torch.zeros( + s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16 ) - mtp1_toks = (tokens + 2).clamp(max=test_config.vocab_size - 1) - mtp_logits[0].scatter_(1, mtp1_toks.squeeze(0).unsqueeze(-1), 100.0) - - mtp2_toks = (tokens + 3).clamp(max=test_config.vocab_size - 1) - mtp_logits[1].scatter_(1, mtp2_toks.squeeze(0).unsqueeze(-1), 100.0) - - unwrapped_model._mtp_logits_cache = mtp_logits return base_logits + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth, runtime_gather_output=True + ): + n = hidden_states.size(0) + # Predict next_token_ids + 1 (continuing the ascending sequence) + pred_toks = (next_token_ids + 1).clamp(max=test_config.vocab_size - 1) + logits = torch.zeros( + n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16 + ) + logits.scatter_(2, pred_toks.transpose(0, 1).unsqueeze(-1), 100.0) + return hidden_states, logits + unwrapped_model.forward = mock_deterministic_forward + unwrapped_model.compute_mtp_single_step = mock_compute_mtp_single_step # Add the request formally to ensure all internal state tensors align env.engine.add_request( @@ -2233,6 +2250,7 @@ def test_speculative_long_stop_word_hit(self): env = self._build_test_env(test_config) unwrapped_model = env.engine.controller.inference_wrapped_model.model + hidden_size = unwrapped_model.config.hidden_size # Mock forward to deterministically output an ascending sequence def mock_deterministic_forward(*args, **kwargs): @@ -2245,19 +2263,26 @@ def mock_deterministic_forward(*args, **kwargs): next_toks = (tokens + 1).clamp(max=test_config.vocab_size - 1) base_logits.scatter_(2, next_toks.unsqueeze(-1), 100.0) - mtp_logits = torch.zeros( - 2, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16 + # Cache hidden states for serial MTP computation + unwrapped_model._decoder_hidden_states_cache = torch.zeros( + s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16 ) - mtp1_toks = (tokens + 2).clamp(max=test_config.vocab_size - 1) - mtp_logits[0].scatter_(1, mtp1_toks.squeeze(0).unsqueeze(-1), 100.0) - - mtp2_toks = (tokens + 3).clamp(max=test_config.vocab_size - 1) - mtp_logits[1].scatter_(1, mtp2_toks.squeeze(0).unsqueeze(-1), 100.0) - - unwrapped_model._mtp_logits_cache = mtp_logits return base_logits + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth, runtime_gather_output=True + ): + n = hidden_states.size(0) + # Predict next_token_ids + 1 (continuing the ascending sequence) + pred_toks = (next_token_ids + 1).clamp(max=test_config.vocab_size - 1) + logits = torch.zeros( + n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16 + ) + logits.scatter_(2, pred_toks.transpose(0, 1).unsqueeze(-1), 100.0) + return hidden_states, logits + unwrapped_model.forward = mock_deterministic_forward + unwrapped_model.compute_mtp_single_step = mock_compute_mtp_single_step env.engine.add_request( request_id=0, @@ -2310,6 +2335,7 @@ def test_speculative_stop_word_truncates_trailing_tokens(self): env = self._build_test_env(test_config) unwrapped_model = env.engine.controller.inference_wrapped_model.model + hidden_size = unwrapped_model.config.hidden_size # Mock forward to deterministically output an ascending sequence (1->2->3...) def mock_deterministic_forward(*args, **kwargs): @@ -2322,19 +2348,26 @@ def mock_deterministic_forward(*args, **kwargs): next_toks = (tokens + 1).clamp(max=test_config.vocab_size - 1) base_logits.scatter_(2, next_toks.unsqueeze(-1), 100.0) - mtp_logits = torch.zeros( - 2, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16 + # Cache hidden states for serial MTP computation + unwrapped_model._decoder_hidden_states_cache = torch.zeros( + s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16 ) - mtp1_toks = (tokens + 2).clamp(max=test_config.vocab_size - 1) - mtp_logits[0].scatter_(1, mtp1_toks.squeeze(0).unsqueeze(-1), 100.0) - - mtp2_toks = (tokens + 3).clamp(max=test_config.vocab_size - 1) - mtp_logits[1].scatter_(1, mtp2_toks.squeeze(0).unsqueeze(-1), 100.0) - - unwrapped_model._mtp_logits_cache = mtp_logits return base_logits + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth, runtime_gather_output=True + ): + n = hidden_states.size(0) + # Predict next_token_ids + 1 (continuing the ascending sequence) + pred_toks = (next_token_ids + 1).clamp(max=test_config.vocab_size - 1) + logits = torch.zeros( + n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16 + ) + logits.scatter_(2, pred_toks.transpose(0, 1).unsqueeze(-1), 100.0) + return hidden_states, logits + unwrapped_model.forward = mock_deterministic_forward + unwrapped_model.compute_mtp_single_step = mock_compute_mtp_single_step env.engine.add_request( request_id=0, @@ -2391,8 +2424,12 @@ def test_speculative_sequence_length_double_counting(self): ) env = self._build_test_env(test_config) - # Mock forward pass to return deterministic disparate logits so - # speculative tokens are completely rejected every time. + # Mock forward pass to return deterministic base logits. + # Speculative tokens will be wrong (predicted by MTP as tokens + 5) + # to guarantee rejection every time. + model = env.engine.controller.inference_wrapped_model.model + hidden_size = model.config.hidden_size + def mock_mtp_forward_reject(*args, **kwargs): tokens = kwargs.get("tokens", args[0] if args else kwargs.get("input_ids")) b, s = tokens.shape @@ -2404,23 +2441,26 @@ def mock_mtp_forward_reject(*args, **kwargs): next_toks = (tokens + 1).clamp(max=test_config.vocab_size - 1) base_logits.scatter_(2, next_toks.unsqueeze(-1), 100.0) - # Speculative model consistently predicts wildly wrong tokens to guarantee rejection - model = env.engine.controller.inference_wrapped_model.model - mtp_logits = torch.zeros( - test_config.num_speculative_tokens, - s, - test_config.vocab_size, - device=tokens.device, - dtype=torch.bfloat16, + # Cache hidden states for serial MTP computation + model._decoder_hidden_states_cache = torch.zeros( + s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16 ) - wrong_toks = (tokens + 5).clamp(max=test_config.vocab_size - 1) - mtp_logits[0].scatter_(1, wrong_toks.squeeze(0).unsqueeze(-1), 100.0) - mtp_logits[1].scatter_(1, wrong_toks.squeeze(0).unsqueeze(-1), 100.0) - - model._mtp_logits_cache = mtp_logits return base_logits - env.engine.controller.inference_wrapped_model.model.forward = mock_mtp_forward_reject + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth, runtime_gather_output=True + ): + n = hidden_states.size(0) + # Predict wildly wrong tokens (+ 5) to guarantee rejection + wrong_toks = (next_token_ids + 5).clamp(max=test_config.vocab_size - 1) + logits = torch.zeros( + n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16 + ) + logits.scatter_(2, wrong_toks.transpose(0, 1).unsqueeze(-1), 100.0) + return hidden_states, logits + + model.forward = mock_mtp_forward_reject + model.compute_mtp_single_step = mock_compute_mtp_single_step env.engine.add_request( request_id=0, @@ -2456,15 +2496,15 @@ def test_speculative_decoding_with_eviction_and_swapping(self): # Very constrained memory environment to force pausing and eviction test_config = DynamicEngineTestConfig( num_requests=3, - min_prompt_length=256, - max_prompt_length=256, - num_tokens_to_generate=512, - context_block_size_tokens=256, + min_prompt_length=16, + max_prompt_length=16, + num_tokens_to_generate=32, + context_block_size_tokens=16, num_speculative_tokens=2, - # 640 KB translates to 3 blocks. + # 40 KB translates to 3 blocks. # 3 requests * 3 blocks per request (1 prompt + 2 gen) = 9 blocks needed. # This guarantees we will run out of active memory mid-generation. - context_buffer_size_gb=0.00064, + context_buffer_size_gb=0.00004, context_paused_buffer_size_gb=0.0, # 0 paused buffer forces immediate eviction model_provider="gpt", materialize_only_last_token_logits=False, @@ -2476,6 +2516,7 @@ def test_speculative_decoding_with_eviction_and_swapping(self): print(f"total block count = {env.engine.context.block_allocator.total_count}") unwrapped_model = env.engine.controller.inference_wrapped_model.model + hidden_size = unwrapped_model.config.hidden_size # Mock forward pass to return safe, deterministic logits to avoid NaN/Inf crashes # in torch.multinomial caused by randomly initialized weights. @@ -2488,26 +2529,31 @@ def mock_safe_forward(*args, **kwargs): ) base_logits[:, :, 0] = 100.0 # Force model to deterministically pick token 0 - mtp_logits = torch.zeros( - test_config.num_speculative_tokens, - s, - test_config.vocab_size, - device=tokens.device, - dtype=torch.bfloat16, + # Cache hidden states for serial MTP computation + unwrapped_model._decoder_hidden_states_cache = torch.zeros( + s, 1, hidden_size, device=tokens.device, dtype=torch.bfloat16 ) - mtp_logits[:, :, 0] = 100.0 # Force speculative heads to also pick token 0 - - unwrapped_model._mtp_logits_cache = mtp_logits return base_logits + def mock_compute_mtp_single_step( + hidden_states, next_token_ids, position_ids, depth, runtime_gather_output=True + ): + n = hidden_states.size(0) + logits = torch.zeros( + n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16 + ) + logits[:, :, 0] = 100.0 # Force speculative heads to also pick token 0 + return hidden_states, logits + unwrapped_model.forward = mock_safe_forward + unwrapped_model.compute_mtp_single_step = mock_compute_mtp_single_step # Add all requests at once. They will all start prefill, but as they generate # and request more blocks, the engine will run out of active blocks. # Since paused_buffer_size is 0, any request that pauses will immediately # overflow the paused buffer and trigger an eviction. for request in env.requests: - request.sampling_params.num_tokens_to_generate = 512 + request.sampling_params.num_tokens_to_generate = 32 env.engine._add_request(request) eviction_occurred = False @@ -2540,7 +2586,7 @@ def mock_safe_forward(*args, **kwargs): merged_req.status == Status.COMPLETED ), f"Request {request_id} failed to complete." assert ( - len(merged_req.generated_tokens) == 511 + len(merged_req.generated_tokens) == 31 ), f"Request {request_id} didn't generate expected tokens." @pytest.mark.internal @@ -2556,50 +2602,25 @@ def test_speculative_decoding_with_prefix_caching(self): """ test_config = DynamicEngineTestConfig( num_requests=0, # Added manually below - min_prompt_length=256, - max_prompt_length=256, - num_tokens_to_generate=128, + min_prompt_length=8, + max_prompt_length=8, + num_tokens_to_generate=4, num_speculative_tokens=2, enable_prefix_caching=True, # Set at config level - context_block_size_tokens=256, # Ensure exact 1 block per prompt + context_block_size_tokens=8, # Ensure exact 1 block per prompt materialize_only_last_token_logits=False, model_provider="gpt", - context_max_tokens=4096, + context_max_tokens=512, context_max_requests=512, ) env = self._build_test_env(test_config) - unwrapped_model = env.engine.controller.inference_wrapped_model.model - - # Mock forward pass to return safe, deterministic logits to avoid NaN/Inf crashes - def mock_safe_forward(*args, **kwargs): - tokens = kwargs.get("tokens", args[0] if args else kwargs.get("input_ids")) - b, s = tokens.shape - - base_logits = torch.zeros( - b, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16 - ) - base_logits[:, :, 0] = 100.0 - - mtp_logits = torch.zeros( - test_config.num_speculative_tokens, - s, - test_config.vocab_size, - device=tokens.device, - dtype=torch.bfloat16, - ) - mtp_logits[:, :, 0] = 100.0 - unwrapped_model._mtp_logits_cache = mtp_logits - return base_logits - - unwrapped_model.forward = mock_safe_forward - # Create two pairs of requests with identical shared prefixes. shared_prompt_a = torch.randint( - 0, test_config.vocab_size - 1, (256,), dtype=torch.int64, device='cuda' + 0, test_config.vocab_size - 1, (8,), dtype=torch.int64, device='cuda' ) shared_prompt_b = torch.randint( - 0, test_config.vocab_size - 1, (256,), dtype=torch.int64, device='cuda' + 0, test_config.vocab_size - 1, (8,), dtype=torch.int64, device='cuda' ) prompts = [shared_prompt_a, shared_prompt_a, shared_prompt_b, shared_prompt_b] @@ -2609,7 +2630,7 @@ def mock_safe_forward(*args, **kwargs): env.engine.add_request( request_id=i, prompt=prompt.clone(), - sampling_params=SamplingParams(num_tokens_to_generate=128, termination_id=99), + sampling_params=SamplingParams(num_tokens_to_generate=4, termination_id=99), ) # First, run schedule_waiting_requests and ONE step to allocate the prefill blocks. @@ -2653,55 +2674,30 @@ def test_speculative_decoding_chunked_prefill_and_prefix_caching(self): """ test_config = DynamicEngineTestConfig( num_requests=0, - min_prompt_length=512, - max_prompt_length=512, - num_tokens_to_generate=128, + min_prompt_length=16, + max_prompt_length=16, + num_tokens_to_generate=4, num_speculative_tokens=2, materialize_only_last_token_logits=False, enable_chunked_prefill=True, enable_prefix_caching=True, # Set at config level - context_block_size_tokens=256, + context_block_size_tokens=8, model_provider="gpt", - context_max_tokens=1536, # Force chunking + context_max_tokens=48, # Force chunking context_max_requests=48, ) env = self._build_test_env(test_config) - unwrapped_model = env.engine.controller.inference_wrapped_model.model - - # Mock forward pass to return safe, deterministic logits to avoid NaN/Inf crashes - def mock_safe_forward(*args, **kwargs): - tokens = kwargs.get("tokens", args[0] if args else kwargs.get("input_ids")) - b, s = tokens.shape - - base_logits = torch.zeros( - b, s, test_config.vocab_size, device=tokens.device, dtype=torch.bfloat16 - ) - base_logits[:, :, 0] = 100.0 - - mtp_logits = torch.zeros( - test_config.num_speculative_tokens, - s, - test_config.vocab_size, - device=tokens.device, - dtype=torch.bfloat16, - ) - mtp_logits[:, :, 0] = 100.0 - unwrapped_model._mtp_logits_cache = mtp_logits - return base_logits - - unwrapped_model.forward = mock_safe_forward - # Create identical prompts for all 4 requests shared_prompt = torch.randint( - 0, test_config.vocab_size - 1, (512,), dtype=torch.int64, device='cuda' + 0, test_config.vocab_size - 1, (16,), dtype=torch.int64, device='cuda' ) for i in range(4): env.engine.add_request( request_id=i, prompt=shared_prompt.clone(), - sampling_params=SamplingParams(num_tokens_to_generate=128, termination_id=99), + sampling_params=SamplingParams(num_tokens_to_generate=4, termination_id=99), ) while env.engine.has_unfinished_requests(): From a267a9c26b3694267bc98cc1534c62d1e5bff2a1 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Tue, 10 Mar 2026 03:02:07 -0700 Subject: [PATCH 74/76] Restore dynamic_engine unit test changes Signed-off-by: Keshav Santhanam --- .../inference/engines/test_dynamic_engine.py | 55 +++++++++---------- 1 file changed, 25 insertions(+), 30 deletions(-) diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index c0ecd7d2107..4117ef39b92 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -1072,7 +1072,7 @@ def test_parallel_inference( if tp_size == 1 and pp_size == 1 and ep_size == 1: pytest.skip(reason="Test requires tp_size > 1 or pp_size > 1 or ep_size > 1") elif not torch.distributed.is_initialized(): - pytest.skip("Distributed not initialized") + Utils.initialize_distributed() world_size = torch.distributed.get_world_size() min_world_size = tp_size * pp_size * ep_size if world_size < min_world_size: @@ -2102,11 +2102,11 @@ def test_speculative_block_boundary_crossing(self): """ test_config = DynamicEngineTestConfig( num_requests=1, - min_prompt_length=4, - max_prompt_length=4, + min_prompt_length=256, + max_prompt_length=256, num_tokens_to_generate=3, num_speculative_tokens=2, - context_block_size_tokens=4, # Exactly matches prompt length + context_block_size_tokens=256, # Exactly matches prompt length context_max_requests=16, model_provider="gpt", materialize_only_last_token_logits=False, @@ -2496,15 +2496,12 @@ def test_speculative_decoding_with_eviction_and_swapping(self): # Very constrained memory environment to force pausing and eviction test_config = DynamicEngineTestConfig( num_requests=3, - min_prompt_length=16, - max_prompt_length=16, - num_tokens_to_generate=32, - context_block_size_tokens=16, + min_prompt_length=256, + max_prompt_length=256, + num_tokens_to_generate=512, + context_block_size_tokens=256, num_speculative_tokens=2, - # 40 KB translates to 3 blocks. - # 3 requests * 3 blocks per request (1 prompt + 2 gen) = 9 blocks needed. - # This guarantees we will run out of active memory mid-generation. - context_buffer_size_gb=0.00004, + context_buffer_size_gb=0.00064, # 640 KB context_paused_buffer_size_gb=0.0, # 0 paused buffer forces immediate eviction model_provider="gpt", materialize_only_last_token_logits=False, @@ -2513,8 +2510,6 @@ def test_speculative_decoding_with_eviction_and_swapping(self): env = self._build_test_env(test_config) - print(f"total block count = {env.engine.context.block_allocator.total_count}") - unwrapped_model = env.engine.controller.inference_wrapped_model.model hidden_size = unwrapped_model.config.hidden_size @@ -2553,7 +2548,7 @@ def mock_compute_mtp_single_step( # Since paused_buffer_size is 0, any request that pauses will immediately # overflow the paused buffer and trigger an eviction. for request in env.requests: - request.sampling_params.num_tokens_to_generate = 32 + request.sampling_params.num_tokens_to_generate = 512 env.engine._add_request(request) eviction_occurred = False @@ -2586,7 +2581,7 @@ def mock_compute_mtp_single_step( merged_req.status == Status.COMPLETED ), f"Request {request_id} failed to complete." assert ( - len(merged_req.generated_tokens) == 31 + len(merged_req.generated_tokens) == 511 ), f"Request {request_id} didn't generate expected tokens." @pytest.mark.internal @@ -2602,25 +2597,25 @@ def test_speculative_decoding_with_prefix_caching(self): """ test_config = DynamicEngineTestConfig( num_requests=0, # Added manually below - min_prompt_length=8, - max_prompt_length=8, + min_prompt_length=256, + max_prompt_length=256, num_tokens_to_generate=4, num_speculative_tokens=2, enable_prefix_caching=True, # Set at config level - context_block_size_tokens=8, # Ensure exact 1 block per prompt + context_block_size_tokens=256, # Ensure exact 1 block per prompt materialize_only_last_token_logits=False, model_provider="gpt", - context_max_tokens=512, + context_max_tokens=4096, context_max_requests=512, ) env = self._build_test_env(test_config) # Create two pairs of requests with identical shared prefixes. shared_prompt_a = torch.randint( - 0, test_config.vocab_size - 1, (8,), dtype=torch.int64, device='cuda' + 0, test_config.vocab_size - 1, (256,), dtype=torch.int64, device='cuda' ) shared_prompt_b = torch.randint( - 0, test_config.vocab_size - 1, (8,), dtype=torch.int64, device='cuda' + 0, test_config.vocab_size - 1, (256,), dtype=torch.int64, device='cuda' ) prompts = [shared_prompt_a, shared_prompt_a, shared_prompt_b, shared_prompt_b] @@ -2630,7 +2625,7 @@ def test_speculative_decoding_with_prefix_caching(self): env.engine.add_request( request_id=i, prompt=prompt.clone(), - sampling_params=SamplingParams(num_tokens_to_generate=4, termination_id=99), + sampling_params=SamplingParams(num_tokens_to_generate=128, termination_id=99), ) # First, run schedule_waiting_requests and ONE step to allocate the prefill blocks. @@ -2674,30 +2669,30 @@ def test_speculative_decoding_chunked_prefill_and_prefix_caching(self): """ test_config = DynamicEngineTestConfig( num_requests=0, - min_prompt_length=16, - max_prompt_length=16, - num_tokens_to_generate=4, + min_prompt_length=512, + max_prompt_length=512, + num_tokens_to_generate=128, num_speculative_tokens=2, materialize_only_last_token_logits=False, enable_chunked_prefill=True, enable_prefix_caching=True, # Set at config level - context_block_size_tokens=8, + context_block_size_tokens=256, model_provider="gpt", - context_max_tokens=48, # Force chunking + context_max_tokens=1536, # Force chunking context_max_requests=48, ) env = self._build_test_env(test_config) # Create identical prompts for all 4 requests shared_prompt = torch.randint( - 0, test_config.vocab_size - 1, (16,), dtype=torch.int64, device='cuda' + 0, test_config.vocab_size - 1, (512,), dtype=torch.int64, device='cuda' ) for i in range(4): env.engine.add_request( request_id=i, prompt=shared_prompt.clone(), - sampling_params=SamplingParams(num_tokens_to_generate=4, termination_id=99), + sampling_params=SamplingParams(num_tokens_to_generate=128, termination_id=99), ) while env.engine.has_unfinished_requests(): From 9922180331b16bbec428bb168b35a23af313e5ae Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Tue, 10 Mar 2026 03:06:50 -0700 Subject: [PATCH 75/76] Bug fix Signed-off-by: Keshav Santhanam --- .../inference/contexts/dynamic_context.py | 4 +- .../contexts/test_dynamic_context.py | 38 +++++++++---------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 98996b0cbfd..28559943481 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1382,7 +1382,9 @@ def add_dummy_requests_for_expert_parallel_step(self) -> None: # 3. Token-level state consumed by the triton KV append kernel. self.token_to_block_idx[0:T] = dummy_block_idx - self.token_to_local_position_within_kv_block[0:T] = 0 + self.token_to_local_position_within_kv_block[0:T] = ( + torch.arange(T, device=self.token_to_block_idx.device) % tokens_per_request + ) if self.is_hybrid_model: # 4. token_to_request_idx: needed by mamba_metadata.update() for hybrid models. diff --git a/tests/unit_tests/inference/contexts/test_dynamic_context.py b/tests/unit_tests/inference/contexts/test_dynamic_context.py index 72ce2878c01..e16ebaf4353 100644 --- a/tests/unit_tests/inference/contexts/test_dynamic_context.py +++ b/tests/unit_tests/inference/contexts/test_dynamic_context.py @@ -1536,7 +1536,7 @@ def test_update_requests_speculative(self): inference_config = InferenceConfig( max_sequence_length=128, buffer_size_gb=0.01, - block_size_tokens=32, + block_size_tokens=256, num_speculative_tokens=2, unified_memory_level=0, ) @@ -1589,9 +1589,9 @@ def test_speculative_boundary_crossing(self): params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 ) inference_config = InferenceConfig( - max_sequence_length=128, - buffer_size_gb=0.01, - block_size_tokens=4, # Small block size to force boundary crossing + max_sequence_length=1024, + buffer_size_gb=0.1, + block_size_tokens=256, # FA2-compatible block size to force boundary crossing num_speculative_tokens=2, unified_memory_level=0, ) @@ -1606,11 +1606,11 @@ def test_speculative_boundary_crossing(self): ctx.request_query_lengths[0] = 1 ctx.request_kv_block_counts[0] = 1 - # Length is 2, meaning existing tokens are at indices 0 and 1. - # The last inserted token was at offset 1. - # Adding 3 tokens places them at offsets 2, 3, and 4 (crosses block size of 4). - ctx.request_kv_length_offsets[0] = 2 - ctx.request_last_kv_block_offset[0] = 1 + # Length is 254, meaning existing tokens are at indices 0..253. + # The last inserted token was at offset 253. + # Adding 3 tokens places them at offsets 254, 255, and 256 (crosses block size of 256). + ctx.request_kv_length_offsets[0] = 254 + ctx.request_last_kv_block_offset[0] = 253 # Allocate one initial block manually blocks = ctx.block_allocator.allocate_memory_blocks(1) @@ -1662,9 +1662,9 @@ def test_paused_speculative_tokens_tracking(self): params_dtype=torch.float32, num_layers=2, kv_channels=8, num_attention_heads=2 ) inference_config = InferenceConfig( - max_sequence_length=128, - buffer_size_gb=0.01, - block_size_tokens=16, + max_sequence_length=1024, + buffer_size_gb=0.1, + block_size_tokens=256, num_speculative_tokens=2, unified_memory_level=0, ) @@ -1677,11 +1677,11 @@ def test_paused_speculative_tokens_tracking(self): ctx.request_ids[:2] = torch.tensor([10, 11]) ctx.request_query_lengths[:2] = 1 - # Request 0 is at offset 14. Adding 1 sampled + 2 spec = 3 tokens will push it to 17, - # which is >= block_size_tokens (16). It will require a new block. + # Request 0 is at offset 254. Adding 1 sampled + 2 spec = 3 tokens will push it to 257, + # which is >= block_size_tokens (256). It will require a new block. # Request 1 is at offset 5. It will not require a new block. - ctx.request_kv_length_offsets[:2] = torch.tensor([14, 5]) - ctx.request_last_kv_block_offset[:2] = torch.tensor([14, 5]) + ctx.request_kv_length_offsets[:2] = torch.tensor([254, 5]) + ctx.request_last_kv_block_offset[:2] = torch.tensor([254, 5]) ctx.request_kv_block_counts[:2] = 1 # Allocate blocks @@ -1733,8 +1733,8 @@ def test_speculative_tokens_less_than_block_size_assert(self): inference_config = InferenceConfig( max_sequence_length=128, buffer_size_gb=0.01, - block_size_tokens=16, - num_speculative_tokens=16, + block_size_tokens=256, + num_speculative_tokens=256, unified_memory_level=0, ) with pytest.raises( @@ -1753,7 +1753,7 @@ def test_swap_book_keeping_tensors_with_speculative_tokens(self): inference_config = InferenceConfig( max_sequence_length=128, buffer_size_gb=0.01, - block_size_tokens=32, + block_size_tokens=256, num_speculative_tokens=2, unified_memory_level=0, ) From 4097bf12ba863e2a2f2f9bc2adccfb1b8b0e1047 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Tue, 10 Mar 2026 11:42:35 -0700 Subject: [PATCH 76/76] Minimize diff Signed-off-by: Keshav Santhanam --- .../inference/contexts/dynamic_context.py | 5 --- .../core/inference/engines/dynamic_engine.py | 43 ++++++++++--------- megatron/core/transformer/attention.py | 8 +--- 3 files changed, 24 insertions(+), 32 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index fc253af979b..28559943481 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -2059,11 +2059,6 @@ def _swap_book_keeping_tensors( tensor_swap(self.request_output_lengths, src_idxs, dst_idxs) tensor_swap(self.request_ids, src_idxs, dst_idxs) tensor_swap(next_tokens, src_idxs, dst_idxs) - if new_speculative_tokens is not None: - # new_speculative_tokens has shape [num_spec, num_requests]; swap columns. - temp = new_speculative_tokens[:, src_idxs].clone() - new_speculative_tokens[:, src_idxs] = new_speculative_tokens[:, dst_idxs] - new_speculative_tokens[:, dst_idxs] = temp tensor_swap(self.request_to_kv_block_ids, src_idxs, dst_idxs) tensor_swap(self.request_kv_block_counts, src_idxs, dst_idxs) tensor_swap(self.request_last_kv_block_id, src_idxs, dst_idxs) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 9c1bdf29c60..102aaa716a7 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -1009,22 +1009,25 @@ def post_process_requests( if self.num_speculative_tokens > 0: accepted_tokens = list(filter(lambda tok: tok != -1, accepted_tokens_list)) + + # Track acceptance statistics for logging (decode requests only). + # Prefill requests don't propose speculative tokens, so including + # them would inflate the proposed count and deflate the rate. + # A request in its first generation step (empty generated_tokens) + # was in prefill this step. + if len(request.generated_tokens) > 0: + self._spec_tokens_proposed += self.num_speculative_tokens + self._spec_tokens_accepted += len(accepted_tokens) + + # The order `accepted_tokens + tokens` is correct here. + # `accepted_tokens` contains the sequence of + # successfully verified draft tokens. `tokens` (from `sample`) is the + # brand new token generated by the target model based on that accepted prefix. + # Therefore, the newly sampled token must go at the end of the sequence. tokens = accepted_tokens + tokens - request: DynamicInferenceRequest = self.get_request(request_id) - # Track acceptance statistics for logging (decode requests only). - # Prefill requests don't propose speculative tokens, so including - # them would inflate the proposed count and deflate the rate. - # A request in its first generation step (empty generated_tokens) - # was in prefill this step. - if len(request.generated_tokens) > 0: - self._spec_tokens_proposed += self.num_speculative_tokens - self._spec_tokens_accepted += len(accepted_tokens) - - num_stop_word_trim = 0 + num_stop_word_trim = 0 if request_id != self.context.chunked_prefill_request_id: - - # Skip appending token for requests being finished due to stop words # (they already have their final token from the previous step) # If the request already has more tokens, then we only append as much as is necessary @@ -1039,11 +1042,11 @@ def post_process_requests( if request_id not in self.stop_word_being_finished_ids: is_first_token = len(request.generated_tokens) == 0 request.generated_tokens += tokens - first_event_in_step = None + first_token_event = None if self.track_generated_token_events: for token in tokens: if block_allocator.enable_prefix_caching: - evt = request.add_event_generated_token( + event = request.add_event_generated_token( token, blocks_total=block_allocator.total_count, blocks_hashed_total=blocks_allocated, @@ -1051,18 +1054,16 @@ def post_process_requests( blocks_ref_count=blocks_ref_count, ) else: - evt = request.add_event_generated_token( + event = request.add_event_generated_token( token, blocks_total=block_allocator.total_count, blocks_hashed_total=blocks_allocated, blocks_hashed_active=blocks_hashed_active, ) - if first_event_in_step is None: - first_event_in_step = evt + if first_token_event is None: + first_token_event = event if is_first_token: - if self.track_generated_token_events: - first_token_event = first_event_in_step - else: + if not self.track_generated_token_events: first_token_event = DynamicInferenceEvent( type=DynamicInferenceEventType.GENERATED_TOKEN, payload={"token_id": tokens[0]}, diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 0ac7e78fae3..310a59bde35 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -60,9 +60,7 @@ rearrange = None try: - from flash_attn_3.flash_attn_interface import ( - _flash_attn_forward, - ) + from flash_attn_3.flash_attn_interface import _flash_attn_forward from flash_attn_3.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, ) @@ -73,9 +71,7 @@ if not HAVE_FA3: try: - from flashattn_hopper.flash_attn_interface import ( - _flash_attn_forward, - ) + from flashattn_hopper.flash_attn_interface import _flash_attn_forward from flashattn_hopper.flash_attn_interface import ( flash_attn_with_kvcache as flash_attn3_with_kvcache, )