-
Notifications
You must be signed in to change notification settings - Fork 8
Add spectrum codec #10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
2ce99b6
Add spectrum tokenizer
LTMeyer 9cd08ba
Rename FiniteScaleQuantizer->FiniteScalarQuantizer
LTMeyer d37fb30
Add channel mask as input
LTMeyer 1cf702b
Add test to ensure previous results consistency
LTMeyer 8c37cc4
Make the tokenizer a pytorch module
LTMeyer aa71c1e
Update test to load only one model checkpoint
LTMeyer 2034934
Merge branch 'add_tokenizers' into add-spectrum-tokenizer
LTMeyer 254b871
Merge branch 'main' into add-spectrum-tokenizer
LTMeyer a6c1574
Update SpectrumCodec
LTMeyer b8549c1
Add vector_quantize_pytorch to dependencies
LTMeyer 213fad9
Specify version of vector_quantize_pytorch
LTMeyer 0ebf3e1
Merge remote-tracking branch 'origin/main' into add-spectrum-tokenizer
EiffL 84250ef
Merge branch 'add-spectrum-tokenizer' of github.com:PolymathicAI/AION…
EiffL 5bba1fb
update spectrum
EiffL 740b3fc
fixed issues
EiffL 1730a01
fixing formatting
EiffL File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,6 @@ | ||
| tests/test_data/image_codec_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text | ||
| tests/test_data/image_codec_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text | ||
| tests/test_data/image_codec_input_batch.pt filter=lfs diff=lfs merge=lfs -text | ||
| tests/test_data/SPECTRUM_input_batch.pt filter=lfs diff=lfs merge=lfs -text | ||
| tests/test_data/SPECTRUM_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text | ||
| tests/test_data/SPECTRUM_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,172 @@ | ||
| import torch | ||
|
|
||
| from aion.codecs.modules.utils import LayerNorm, GRN | ||
|
|
||
|
|
||
| class ConvNextBlock1d(torch.nn.Module): | ||
| """ConvNeXtV2 Block. | ||
| Modified to 1D from the original 2D implementation from https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py | ||
|
|
||
| Args: | ||
| dim (int): Number of input channels. | ||
| drop_path (float): Stochastic depth rate. Default: 0.0 | ||
| """ | ||
|
|
||
| def __init__(self, dim: int): | ||
| super().__init__() | ||
| self.dwconv = torch.nn.Conv1d( | ||
| dim, dim, kernel_size=7, padding=3, groups=dim | ||
| ) # depthwise conv | ||
| self.norm = LayerNorm(dim, eps=1e-6) | ||
| self.pwconv1 = torch.nn.Linear( | ||
| dim, 4 * dim | ||
| ) # pointwise/1x1 convs, implemented with linear layers | ||
| self.act = torch.nn.GELU() | ||
| self.grn = GRN(4 * dim) | ||
| self.pwconv2 = torch.nn.Linear(4 * dim, dim) | ||
|
|
||
| def forward(self, x): | ||
| y = self.dwconv(x) | ||
| y = y.permute(0, 2, 1) # (B, C, N) -> (B, N, C) | ||
| y = self.norm(y) | ||
| y = self.pwconv1(y) | ||
| y = self.act(y) | ||
| y = self.grn(y) | ||
| y = self.pwconv2(y) | ||
| y = y.permute(0, 2, 1) # (B, N, C) -> (B, C, N) | ||
|
|
||
| y = x + y | ||
| return y | ||
|
|
||
|
|
||
| class ConvNextEncoder1d(torch.nn.Module): | ||
| r"""ConvNeXt encoder. | ||
|
|
||
| Modified from https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py | ||
|
|
||
| Args: | ||
| in_chans : Number of input image channels. Default: 3 | ||
| depths : Number of blocks at each stage. Default: [3, 3, 9, 3] | ||
| dims : Feature dimension at each stage. Default: [96, 192, 384, 768] | ||
| drop_path_rate : Stochastic depth rate. Default: 0. | ||
| layer_scale_init_value : Init value for Layer Scale. Default: 1e-6. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| in_chans: int = 2, | ||
| depths: tuple[int, ...] = (3, 3, 9, 3), | ||
| dims: tuple[int, ...] = (96, 192, 384, 768), | ||
| ): | ||
| super().__init__() | ||
| assert len(depths) == len(dims), "depths and dims should have the same length" | ||
| num_layers = len(depths) | ||
|
|
||
| self.downsample_layers = ( | ||
| torch.nn.ModuleList() | ||
| ) # stem and 3 intermediate downsampling conv layers | ||
| stem = torch.nn.Sequential( | ||
| torch.nn.Conv1d(in_chans, dims[0], kernel_size=4, stride=4), | ||
| LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), | ||
| ) | ||
| self.downsample_layers.append(stem) | ||
| for i in range(num_layers - 1): | ||
| downsample_layer = torch.nn.Sequential( | ||
| LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), | ||
| torch.nn.Conv1d(dims[i], dims[i + 1], kernel_size=2, stride=2), | ||
| ) | ||
| self.downsample_layers.append(downsample_layer) | ||
|
|
||
| self.stages = torch.nn.ModuleList() | ||
| for i in range(num_layers): | ||
| stage = torch.nn.Sequential( | ||
| *[ | ||
| ConvNextBlock1d( | ||
| dim=dims[i], | ||
| ) | ||
| for j in range(depths[i]) | ||
| ] | ||
| ) | ||
| self.stages.append(stage) | ||
|
|
||
| self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first") | ||
|
|
||
| self.apply(self._init_weights) | ||
|
|
||
| def _init_weights(self, m): | ||
| if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)): | ||
| torch.nn.init.trunc_normal_(m.weight, std=0.02) | ||
| torch.nn.init.constant_(m.bias, 0) | ||
|
|
||
| def forward(self, x): | ||
| for ds, st in zip(self.downsample_layers, self.stages): | ||
| x = ds(x) | ||
| x = st(x) | ||
| return self.norm(x) | ||
|
|
||
|
|
||
| class ConvNextDecoder1d(torch.nn.Module): | ||
| r"""ConvNeXt decoder. Essentially a mirrored version of the encoder. | ||
|
|
||
| Args: | ||
| in_chans (int): Number of input image channels. Default: 3 | ||
| depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] | ||
| dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] | ||
| drop_path_rate (float): Stochastic depth rate. Default: 0. | ||
| layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| in_chans=768, | ||
| depths=[3, 3, 9, 3], | ||
| dims=[384, 192, 96, 2], | ||
| ): | ||
| super().__init__() | ||
| assert len(depths) == len(dims), "depths and dims should have the same length" | ||
| num_layers = len(depths) | ||
|
|
||
| self.upsample_layers = torch.nn.ModuleList() | ||
|
|
||
| stem = torch.nn.Sequential( | ||
| torch.nn.ConvTranspose1d(in_chans, dims[0], kernel_size=2, stride=2), | ||
| LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), | ||
| ) | ||
| self.upsample_layers.append(stem) | ||
|
|
||
| for i in range(num_layers - 1): | ||
| upsample_layer = torch.nn.Sequential( | ||
| LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), | ||
| torch.nn.ConvTranspose1d( | ||
| dims[i], | ||
| dims[i + 1], | ||
| kernel_size=2 if i < (num_layers - 2) else 4, | ||
| stride=2 if i < (num_layers - 2) else 4, | ||
| ), | ||
| ) | ||
| self.upsample_layers.append(upsample_layer) | ||
|
|
||
| self.stages = torch.nn.ModuleList() | ||
| for i in range(num_layers): | ||
| stage = torch.nn.Sequential( | ||
| *[ | ||
| ConvNextBlock1d( | ||
| dim=dims[i], | ||
| ) | ||
| for j in range(depths[i]) | ||
| ] | ||
| ) | ||
| self.stages.append(stage) | ||
|
|
||
| self.apply(self._init_weights) | ||
|
|
||
| def _init_weights(self, m): | ||
| if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)): | ||
| torch.nn.init.trunc_normal_(m.weight, std=0.02) | ||
| torch.nn.init.constant_(m.bias, 0) | ||
|
|
||
| def forward(self, x): | ||
| for us, st in zip(self.upsample_layers, self.stages): | ||
| x = us(x) | ||
| x = st(x) | ||
| return x |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,110 @@ | ||
| import torch | ||
| from jaxtyping import Float | ||
|
|
||
|
|
||
| def interp1d( | ||
| x: Float[torch.Tensor, " b n"], | ||
| y: Float[torch.Tensor, " b n"], | ||
| xnew: Float[torch.Tensor, " b m"], | ||
| mask_value: float | None = 0.0, | ||
| ) -> Float[torch.Tensor, " b m"]: | ||
| """Linear interpolation of a 1-D tensor using torch.searchsorted. | ||
| Assumes that x and xnew are sorted in increasing order. | ||
|
|
||
| Args: | ||
| x: The x-coordinates of the data points, shape [batch, N]. | ||
| y: The y-coordinates of the data points, shape [batch, N]. | ||
| xnew: The x-coordinates of the interpolated points, shape [batch, M]. | ||
| mask_value: The value to use for xnew outside the range of x. | ||
| Returns: | ||
| The y-coordinates of the interpolated points, shape [batch, M]. | ||
| """ | ||
| # Find the indices where xnew should be inserted in sorted_x | ||
| # Given a point xnew[i] in xnew, return j where x[j] is the nearest point in x such that | ||
| # x[j] < xnew[i], except if the nearest point in x has x[j] = xnew[i] then return j - 1. | ||
| indices = torch.searchsorted(x, xnew) - 1 | ||
|
|
||
| # We can define a local linear approx of the grad in each interval | ||
| # between two points in x, and we would like to use this to interpolate | ||
| # y at those points in xnew which lie inside the range of x, otherwise | ||
| # interpolated_y is masked for points in xnew outside the range of x. | ||
| # There are len(x) - 1 such intervals between points in x, having indices | ||
| # ranging between 0 and len(x) - 2. Points with xnew < min(x) will be | ||
| # assigned indices of -1 and points with xnew > max(x) will be assigned | ||
| # indices equal to len(x). These are not valid segment indices, but we can | ||
| # clamp them to 0 and len(x) - 2 respectively to avoid breaking the | ||
| # calculation of the slope variable. The nonsense values we obtain outside | ||
| # the range of x will be discarded when masking. | ||
| indices = torch.clamp(indices, 0, x.shape[1] - 1 - 1) | ||
|
|
||
| slopes = (y[:, :-1] - y[:, 1:]) / (x[:, :-1] - x[:, 1:]) | ||
|
|
||
| # Interpolate the y-coordinates | ||
| ynew = torch.gather(y, 1, indices) + ( | ||
| xnew - torch.gather(x, 1, indices) | ||
| ) * torch.gather(slopes, 1, indices) | ||
|
|
||
| # Mask out the values that are outside the valid range | ||
| mask = (xnew < x[..., 0].reshape(-1, 1)) | (xnew > x[..., -1].reshape(-1, 1)) | ||
| ynew[mask] = mask_value | ||
|
|
||
| return ynew | ||
|
|
||
|
|
||
| class LatentSpectralGrid(torch.nn.Module): | ||
| def __init__(self, lambda_min: float, resolution: float, num_pixels: int): | ||
| """ | ||
| Initialize a latent grid to represent spectra from multiple resolutions. | ||
|
|
||
| Args: | ||
| lambda_min: The minimum wavelength value, in Angstrom. | ||
| resolution: The resolution of the spectra, in Angstrom per pixel. | ||
| num_pixels: The number of pixels in the spectra. | ||
|
|
||
| """ | ||
| super().__init__() | ||
| self.register_buffer("lambda_min", torch.tensor(lambda_min)) | ||
| self.register_buffer("resolution", torch.tensor(resolution)) | ||
| self.register_buffer("length", torch.tensor(num_pixels)) | ||
| self.register_buffer( | ||
| "_wavelength", | ||
| (torch.arange(0, num_pixels) * resolution + lambda_min).reshape( | ||
| 1, num_pixels | ||
| ), | ||
| ) | ||
|
|
||
| @property | ||
| def wavelength(self) -> Float[torch.Tensor, " n"]: | ||
| return self._wavelength.squeeze() | ||
|
|
||
| def to_observed( | ||
| self, | ||
| x_latent: Float[torch.Tensor, " b n"], | ||
| wavelength: Float[torch.Tensor, " b m"], | ||
| ) -> Float[torch.Tensor, " b m"]: | ||
| """Transforms the latent representation to the observed wavelength grid. | ||
|
|
||
| Args: | ||
| x_latent: The latent representation, [batch, self.num_pixels]. | ||
| wavelength: The observed wavelength grid, [batch, M]. | ||
|
|
||
| Returns: | ||
| The transformed representation on the observed wavelength grid. | ||
| """ | ||
| b = x_latent.shape[0] | ||
| return interp1d(self._wavelength.repeat([b, 1]), x_latent, wavelength) | ||
|
|
||
| def to_latent( | ||
| self, x_obs: Float[torch.Tensor, "b m"], wavelength: Float[torch.Tensor, "b m"] | ||
| ) -> Float[torch.Tensor, "b n"]: | ||
| """Transforms the observed representation to the latent wavelength grid. | ||
|
|
||
| Args: | ||
| x_obs: The observed representation, [batch, N]. | ||
| wavelength: The wavelength grid, [batch, N]. | ||
|
|
||
| Returns: | ||
| The transformed representation on the latent wavelength grid. | ||
| """ | ||
| b = x_obs.shape[0] | ||
| return interp1d(wavelength, x_obs, self._wavelength.repeat([b, 1])) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| import torch | ||
| import torch.nn.functional as F | ||
| from einops import rearrange | ||
|
|
||
|
|
||
| class LayerNorm(torch.nn.Module): | ||
| r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. | ||
| The ordering of the dimensions in the inputs. channels_last corresponds to inputs with | ||
| shape (batch_size, height, width, channels) while channels_first corresponds to inputs | ||
| with shape (batch_size, channels, height, width). | ||
| """ | ||
|
|
||
| def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): | ||
| super().__init__() | ||
| self.weight = torch.nn.Parameter(torch.ones(normalized_shape)) | ||
| self.bias = torch.nn.Parameter(torch.zeros(normalized_shape)) | ||
| self.eps = eps | ||
| self.data_format = data_format | ||
| if self.data_format not in ["channels_last", "channels_first"]: | ||
| raise NotImplementedError | ||
| self.normalized_shape = (normalized_shape,) | ||
|
|
||
| def forward(self, x): | ||
| if self.data_format == "channels_last": | ||
| return F.layer_norm( | ||
| x, self.normalized_shape, self.weight, self.bias, self.eps | ||
| ) | ||
| elif self.data_format == "channels_first": | ||
| x = rearrange(x, "b c ... -> b ... c") | ||
| x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) | ||
| return rearrange(x, "b ... c -> b c ...") | ||
|
|
||
|
|
||
| class GRN(torch.nn.Module): | ||
| """GRN (Global Response Normalization) layer""" | ||
|
|
||
| def __init__(self, dim): | ||
| super().__init__() | ||
| self.gamma = torch.nn.Parameter(torch.zeros(1, 1, dim)) | ||
| self.beta = torch.nn.Parameter(torch.zeros(1, 1, dim)) | ||
|
|
||
| def forward(self, x): | ||
| Gx = torch.norm(x, p=2, dim=(1,), keepdim=True) | ||
| Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) | ||
| return self.gamma * (x * Nx) + self.beta + x |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.