models/S4.py
```python
def forward(self, u, **kwargs):  # absorbs return_output and transformer src mask
    """Input and output shape (B, H, L)"""
    u = u.transpose(-1, -2)
```
Just checking: this line seems to be

```python
if not self.transposed: u = u.transpose(-1, -2)
```

in the original implementation. Do we always want to transpose here?
Yeah, I was going to get rid of `self.transposed`, but I may as well keep it and use the flag as you previously did.
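For reference, the flag-guarded behaviour discussed above can be sketched as follows. `forward_with_flag` is a hypothetical standalone helper for illustration, not the actual module method; the real kernel computation is elided.

```python
import torch

def forward_with_flag(u: torch.Tensor, transposed: bool) -> torch.Tensor:
    # When transposed=False the caller passes (B, L, H); move to (B, H, L)
    # for the kernel, then transpose back before returning.
    if not transposed:
        u = u.transpose(-1, -2)
    # ... the S4 kernel would operate on (B, H, L) here ...
    if not transposed:
        u = u.transpose(-1, -2)
    return u
```

Either way the caller gets back a tensor of the same shape it passed in; the flag only controls which layout the kernel sees internally.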
```python
    return y
```
```python
class S4Block(nn.Module):
```
I am a bit confused by the transpose that happens here. My previous implementation of the S4D block (utilising some of your docstrings) was:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class S4DBlock(nn.Module):
    """
    A single S4 block that applies:
        1. S4D module
        2. (Optionally) a linear layer + GLU activation
        3. Residual connection
        4. Layer normalization
        5. Dropout

    Args:
        model_dim (int): Dimensionality of the model (d_model).
        dropout_rate (float): Probability of an element to be zeroed in Dropout.
        use_glu (bool): Whether to apply a Linear -> GLU stage after the residual.
    """

    def __init__(self, model_dim: int, dropout_rate: float = 0.1, use_glu: bool = False):
        super().__init__()
        self.s4d = S4D(
            d_model=model_dim,
            transposed=False,  # use (B, L, D) shape
        )
        self.norm = nn.LayerNorm(model_dim)
        self.drop = nn.Dropout(p=dropout_rate)
        self.use_glu = use_glu
        if self.use_glu:
            self.post_linear = nn.Linear(model_dim, 2 * model_dim)
        else:
            self.post_linear = None
        self.state = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the S4DBlock.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, model_dim).

        Returns:
            torch.Tensor: Output tensor of the same shape (batch_size, seq_len, model_dim).
        """
        y, _ = self.s4d(x)
        y = y + x  # residual
        if self.use_glu:
            y_glu = self.post_linear(y)
            y_glu = F.glu(y_glu, dim=-1)
            y = y + y_glu
        y = self.norm(y)
        y = self.drop(y)
        return y

    @torch.no_grad()
    def step(self, x: torch.Tensor) -> torch.Tensor:
        """
        Stepwise inference: one token at a time.
        """
        if self.state is None:
            self.s4d.setup_step()
            batch_size = x.shape[0]
            self.state = self.s4d.default_state(batch_size)
        y, self.state = self.s4d.step(x, self.state)
        y = y + x
        if self.use_glu:
            y_glu = self.post_linear(y)
            y_glu = F.glu(y_glu, dim=-1)
            y = y + y_glu
        y = self.norm(y)
        return y
```

```python
x = self.embedding(x)
```
```python
for block in self.blocks:
    x = block.step(x)
```
Related to my previous comment, where I showed the implementation I was previously using: that version had a step method defined in S4DBlock, but you don't seem to have defined one here?
Yeah, we only used the step method for generating the snake game engine outputs, which we no longer include as a baseline, so I thought it best not to include it.
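For the record, the stepwise loop only assumes each block exposes a `step()` method that maps one token's activations to the next layer's input. A runnable sketch with a stand-in block (`EchoBlock` is a hypothetical stub, not the real S4DBlock):

```python
import torch

class EchoBlock:
    """Hypothetical stub with the same step() interface as S4DBlock."""
    def step(self, x: torch.Tensor) -> torch.Tensor:
        # A real block would also update its recurrent state here.
        return x

blocks = [EchoBlock(), EchoBlock()]
x = torch.randn(4, 8)  # one token per batch element: (batch, model_dim)
for block in blocks:
    x = block.step(x)
```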