
Question about PositionalEmbedding used for DiscreteDiffusionSDE #54

@ShirongLiu


In the forward method of the PositionalEmbedding class, there is this line:
x = x.ger(freqs.to(x.dtype))

However, if x.dtype is torch.int64, then freqs.to(x.dtype) truncates any fractional frequencies toward zero, so those values are lost.

For example, in the sample method of DiscreteDiffusionSDE (line 528 of diffusion/diffusionsde.py), t is created with dtype torch.long:
t = torch.full((n_samples,), sample_step_schedule[i], dtype=torch.long, device=self.device)
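A minimal sketch of the truncation described above. It assumes the frequencies are computed EDM-style as (1 / max_positions) ** (k / half); that formula, num_channels = 8, and the timestep value 500 are illustrative assumptions, not taken from the repository.

```python
import torch

# Assumed EDM-style frequencies; the repository's PositionalEmbedding may differ.
num_channels = 8
max_positions = 10000
half = num_channels // 2
freqs = torch.arange(half, dtype=torch.float32) / half
freqs = (1 / max_positions) ** freqs  # approximately [1.0, 0.1, 0.01, 0.001]

# Integer timesteps, created with dtype=torch.long as in sample().
t = torch.full((2,), 500, dtype=torch.long)

# What freqs.to(x.dtype) does when x is int64: fractional values truncate to 0.
freqs_int = freqs.to(t.dtype)
print(freqs_int.tolist())  # [1, 0, 0, 0]

# torch.outer is the same operation as x.ger(...); ger is a deprecated alias.
x_int = torch.outer(t, freqs_int)                  # only the first column is nonzero
x_float = torch.outer(t.to(torch.float32), freqs)  # all frequencies preserved
print(x_int.tolist())  # [[500, 0, 0, 0], [500, 0, 0, 0]]
```

With the truncated frequencies, the cos/sin features for every column except the first collapse to cos(0) and sin(0), so different timesteps become hard to tell apart.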

Maybe my understanding is incorrect; could you please point out where this case is handled? Thank you!
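One possible workaround, sketched under assumptions: promote integer inputs to float before building the frequency outer product. The class name FloatSafePositionalEmbedding and its constructor arguments are hypothetical, and the frequency formula is the same EDM-style assumption as above; only the dtype handling in forward() is the point.

```python
import torch
import torch.nn as nn


class FloatSafePositionalEmbedding(nn.Module):
    """Hypothetical sinusoidal embedding that promotes integer inputs to float.

    The frequency formula is an assumption (EDM-style); the dtype handling
    in forward() is what this sketch illustrates.
    """

    def __init__(self, num_channels: int, max_positions: int = 10000):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Promote integer timesteps (e.g. torch.long) to float32 so the
        # fractional frequencies are not truncated by the cast to x.dtype.
        if not torch.is_floating_point(x):
            x = x.to(torch.float32)
        half = self.num_channels // 2
        freqs = torch.arange(half, dtype=torch.float32, device=x.device) / half
        freqs = (1 / self.max_positions) ** freqs
        # Outer product (batch,) x (half,) -> (batch, half); same as x.ger(...).
        x = torch.outer(x, freqs.to(x.dtype))
        # Concatenate cos and sin features -> (batch, num_channels).
        return torch.cat([x.cos(), x.sin()], dim=1)


emb = FloatSafePositionalEmbedding(num_channels=8)
t_long = torch.full((2,), 500, dtype=torch.long)
out_long = emb(t_long)            # integer input, promoted internally
out_float = emb(t_long.float())   # float input, used as-is
```

Casting at the call site (for example passing t.float() from sample) would have the same effect; doing it inside forward protects every caller, while still casting freqs to x.dtype keeps half-precision float inputs working.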
