
S4 #38

Merged

Benjamin-Walker merged 4 commits into main from S4 on Oct 15, 2025
Conversation

@Benjamin-Walker (Owner)

No description provided.

models/S4.py (Outdated)

def forward(self, u, **kwargs): # absorbs return_output and transformer src mask
"""Input and output shape (B, H, L)"""
u = u.transpose(-1, -2)
Collaborator

Just checking: this line seems to be
if not self.transposed: u = u.transpose(-1, -2)
in the original implementation. Do we always want to transpose here?
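For context, the conditional version from the original implementation amounts to the following. This is a minimal sketch with a hypothetical `S4Sketch` module (the real S4 kernel replaced by a placeholder) just to show how the `transposed` flag gates the layout conversion:

```python
import torch
import torch.nn as nn


class S4Sketch(nn.Module):
    """Hypothetical minimal module illustrating the `transposed` flag.

    transposed=True  -> inputs already arrive as (B, H, L)
    transposed=False -> inputs arrive as (B, L, H) and are converted
    """

    def __init__(self, transposed: bool = True):
        super().__init__()
        self.transposed = transposed

    def forward(self, u):
        # Move features to dim -2 only when the caller uses (B, L, H)
        if not self.transposed:
            u = u.transpose(-1, -2)
        # ... the S4 kernel would operate on (B, H, L) here ...
        # Restore the caller's layout before returning
        if not self.transposed:
            u = u.transpose(-1, -2)
        return u
```

With the flag, a `(B, L, H)` caller gets its layout back unchanged, while a `(B, H, L)` caller skips both transposes.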

Owner Author

Yeah, I was going to get rid of self.transposed, but I may as well keep it and use the flag as you previously did.

return y


class S4Block(nn.Module):
Collaborator

I am a bit confused by the transpose that happens here. My previous implementation of the S4D block (utilising some of your docstrings):

import torch
import torch.nn as nn
import torch.nn.functional as F


class S4DBlock(nn.Module):
    """
    A single S4D block that applies:
      1. S4D module
      2. (Optionally) a linear layer + GLU activation,
      3. Residual connection
      4. Layer Normalization
      5. Dropout

    Args:
        model_dim (int): Dimensionality of the model (d_model).
        dropout_rate (float): Probability of an element to be zeroed in Dropout.
        use_glu (bool): Whether to apply a Linear -> GLU stage after the residual.
    """

    def __init__(self, model_dim: int, dropout_rate: float = 0.1, use_glu: bool = False):
        super().__init__()

        self.s4d = S4D(
            d_model=model_dim,
            transposed=False,  # use (B, L, D) shape
        )

        self.norm = nn.LayerNorm(model_dim)
        self.drop = nn.Dropout(p=dropout_rate)

        self.use_glu = use_glu
        if self.use_glu:
            self.post_linear = nn.Linear(model_dim, 2 * model_dim)
        else:
            self.post_linear = None

        self.state = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the S4DBlock.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, model_dim).

        Returns:
            torch.Tensor: Output tensor of same shape (batch_size, seq_len, model_dim).
        """
        y, _ = self.s4d(x)
        y = y + x  # residual

        if self.use_glu:
            y_glu = self.post_linear(y)
            y_glu = F.glu(y_glu, dim=-1)
            y = y + y_glu

        y = self.norm(y)
        y = self.drop(y)
        return y

    @torch.no_grad()
    def step(self, x: torch.Tensor) -> torch.Tensor:
        """
        Stepwise inference: one token at a time.
        """
        if self.state is None:
            self.s4d.setup_step()
            batch_size = x.shape[0]
            self.state = self.s4d.default_state(batch_size)

        y, self.state = self.s4d.step(x, self.state)
        y = y + x  # residual

        if self.use_glu:
            y_glu = self.post_linear(y)
            y_glu = F.glu(y_glu, dim=-1)
            y = y + y_glu

        y = self.norm(y)
        return y
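One note on the Linear -> GLU stage in the block above: `post_linear` doubles the feature dimension and `F.glu` gates and halves it back, so the `y + y_glu` addition lines up. A quick self-contained shape sketch (the dimensions here are arbitrary, not from the PR):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

d = 8                                   # model_dim (arbitrary for the sketch)
post_linear = nn.Linear(d, 2 * d)       # doubles the feature dim, as in the block
y = torch.randn(2, 16, d)               # (B, L, D)
y_glu = F.glu(post_linear(y), dim=-1)   # GLU halves it back to (B, L, D)
```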

x = self.embedding(x)

for block in self.blocks:
    x = block.step(x)
Collaborator

Related to my previous comment: the implementation I showed there had a step method defined in S4DBlock, but you don't seem to have defined one here?

Owner Author

Yeah, we only used the step method for generating the snake game engine outputs, which we no longer include as a baseline, so I thought it best not to include it.
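For readers who do want stepwise inference, the removed path amounts to a one-token-at-a-time loop over blocks with cached recurrent state. A minimal self-contained sketch, using a hypothetical `StepBlock` stand-in (not the deleted S4 code):

```python
import torch
import torch.nn as nn


class StepBlock(nn.Module):
    """Hypothetical stand-in: caches a running state, consumes one token per call."""

    def __init__(self, d):
        super().__init__()
        self.lin = nn.Linear(d, d)
        self.state = None  # lazily initialised on the first step, as in the PR's step()

    @torch.no_grad()
    def step(self, x):
        # x: (B, D) -- a single token per batch element
        if self.state is None:
            self.state = torch.zeros_like(x)
        self.state = self.state + self.lin(x)  # update cached state
        return x + self.state


d = 4
block = StepBlock(d)
tokens = torch.randn(3, 5, d)  # (B, L, D)

# Feed tokens one at a time, then stack back into (B, L, D)
outs = [block.step(tokens[:, t]) for t in range(tokens.shape[1])]
y = torch.stack(outs, dim=1)
```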

@Benjamin-Walker Benjamin-Walker merged commit aecfae5 into main Oct 15, 2025
1 check passed
@Benjamin-Walker Benjamin-Walker deleted the S4 branch October 15, 2025 10:58