Online softmax treats the cumulative sum $l$ as a running quantity that is updated each time a new element arrives. A more efficient variant updates $l$ block by block instead of element by element; the advantage is that the per-block work can be computed in parallel.

We first compute each block's sum $l^{(t)}$ and max $m^{(t)}$ separately.

The global max $m$ is easy to update:
$$
\begin{align}
m=\max(x_{:2N})&=\max(\max(x_{:N}),\max(x_{N+1:2N}))\\
&=\max(m^{(1)},m^{(2)})
\end{align}
$$
but the global $l$ is NOT simply the sum of the block sums, because each block sum is shifted by its own max:
$$
l=l_{:2N}\neq l^{(1)}+l^{(2)}
$$
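To see concretely why the raw block sums cannot just be added, here is a small NumPy check (the variable names `x1`, `x2`, `m1`, `l1`, etc. are illustrative, chosen to mirror $x_{:N}$, $m^{(1)}$, $l^{(1)}$):

```python
import numpy as np

x = np.array([1.0, 3.0, 2.0, 5.0])
x1, x2 = x[:2], x[2:]

# per-block max and sum of exponentials (each block shifted by its OWN max)
m1, m2 = x1.max(), x2.max()
l1 = np.exp(x1 - m1).sum()
l2 = np.exp(x2 - m2).sum()

# global max and sum (shifted by the GLOBAL max)
m = x.max()
l = np.exp(x - m).sum()

print(np.isclose(l, l1 + l2))  # False: block sums use different shifts
# rescaling each block sum to the global max fixes it
print(np.isclose(l, l1 * np.exp(m1 - m) + l2 * np.exp(m2 - m)))  # True
```

The rescaling factors $e^{m^{(t)}-m}$ convert each block sum to the shared global max, which is exactly the correction the online update below applies.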
So we use the block sums $l^{(t)}$ and block maxes $m^{(t)}$ to update the global $l$ online.

Processing the blocks with a for-loop, the update rule is:
$$
l \leftarrow l_\text{old}\, e^{m_\text{old}-m} + l_\text{new}\, e^{m_\text{new}-m}
$$
where $m_\text{new},l_\text{new}$ denote the current block's max and sum, and $m=\max(m_\text{old},m_\text{new})$. After combining, we carry the state forward for the next block:
$$
m_\text{old} \leftarrow m,\qquad l_\text{old} \leftarrow l
$$
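The for-loop above can be sketched in NumPy as follows; `online_softmax` and `block_size` are hypothetical names for this sketch, and initializing the running max to $-\infty$ makes the first iteration a plain copy of the first block's statistics:

```python
import numpy as np

def online_softmax(x, block_size=4):
    """Softmax over a 1-D array, scanning it block by block
    while maintaining a running max m_old and running sum l_old."""
    m_old, l_old = -np.inf, 0.0
    for t in range(0, len(x), block_size):
        block = x[t:t + block_size]
        m_new = block.max()                      # block max m^(t)
        l_new = np.exp(block - m_new).sum()      # block sum l^(t)
        m = max(m_old, m_new)
        # rescale both sums to the shared max before adding
        l_old = l_old * np.exp(m_old - m) + l_new * np.exp(m_new - m)
        m_old = m
    # one final pass produces the normalized probabilities
    return np.exp(x - m_old) / l_old

x = np.random.randn(10)
ref = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(online_softmax(x), ref)
```

Note that `exp(-inf - m)` evaluates to `0.0`, so the $-\infty$ initialization needs no special-casing inside the loop.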
## Batch online softmax
In the attention mechanism, we apply softmax to the attention score matrix
$$
S=QK^T,S\in\mathbb{R}^{N\times N}
$$
where $Q\in\mathbb{R}^{N\times D}$ is the query matrix; softmax is then applied to each row of the score matrix:
$$
P_{i,:}=\text{softmax}(S_{i,:})
$$
With online softmax, we can update the running row-wise max $M^{(t)}$ and row-wise sum $L^{(t)}$ for a tile of $k$ rows in parallel:
$$
L = L^{(1)}(e^{M^{(1)}-M})+L^{(2)}(e^{M^{(2)}-M})
$$
where $L,M\in\mathbb{R}^{k\times 1}$, i.e. one running max and one running sum per row.
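The batched update is the scalar rule applied per row, with column blocks of $S$ playing the role of the blocks of $x$. A minimal NumPy sketch (the function name `batched_online_softmax` and the `block` parameter are assumptions of this sketch; broadcasting over the `(k, 1)` shapes of `M` and `L` does the per-row work):

```python
import numpy as np

def batched_online_softmax(S, block=4):
    """Row-wise softmax of S, scanning its columns block by block.

    M and L have shape (k, 1): one running max and sum per row,
    matching L, M in R^{k x 1} in the text.
    """
    k = S.shape[0]
    M = np.full((k, 1), -np.inf)
    L = np.zeros((k, 1))
    for j in range(0, S.shape[1], block):
        Sj = S[:, j:j + block]
        M_new = Sj.max(axis=1, keepdims=True)                 # block row maxes
        L_new = np.exp(Sj - M_new).sum(axis=1, keepdims=True) # block row sums
        M_next = np.maximum(M, M_new)
        # L = L^(old) * e^(M^(old)-M) + L^(new) * e^(M^(new)-M), per row
        L = L * np.exp(M - M_next) + L_new * np.exp(M_new - M_next)
        M = M_next
    return np.exp(S - M) / L

S = np.random.randn(5, 7)
ref = np.exp(S - S.max(axis=1, keepdims=True))
ref /= ref.sum(axis=1, keepdims=True)
assert np.allclose(batched_online_softmax(S, block=3), ref)
```

This is the same rescale-and-accumulate structure that FlashAttention-style kernels use to tile the softmax over $S=QK^T$ without materializing a full row at once.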