
MMvec refactor #166

@mortonjt

We're going to go with PyTorch OR NumPyro. The framework will have the following skeleton:

model.py (mmvec.py)

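import torch
from torch import nn
from torch.distributions import Multinomial

class MMvec(nn.Module):
    def __init__(self, num_microbes, num_metabolites, latent_dim):
        super().__init__()  # must run before any submodule is assigned
        self.encoder = nn.Embedding(num_microbes, latent_dim)
        # nn.Sequential takes modules as positional arguments, not a list
        self.decoder = nn.Sequential(nn.Linear(latent_dim, num_metabolites), nn.Softmax(dim=-1))
        # TODO : may want a better (more numerically stable) softmax, e.g. return logits and use Multinomial(logits=...)

    def forward(self, X, Y):
        """X holds microbe indices (LongTensor of shape B), the integer form of one-hot
        encodings (B x num_microbes); nn.Embedding expects indices, not one-hot vectors.
        Y is metabolite abundances as float counts (B x num_metabolites). B is the batch size."""
        z = self.encoder(X)
        pred_y = self.decoder(z)
        # Multinomial's first positional argument is total_count, so pass probs by keyword.
        # validate_args=False: rows of Y have different totals, and log_prob does not use total_count.
        lp = Multinomial(probs=pred_y, validate_args=False).log_prob(Y).mean()
        return lp  # mean log-likelihood per sample; train by minimizing -lp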
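
For reference, the objective the skeleton computes, written out. The notation is assumed here rather than taken from the issue: $u_k$ is row $k$ of the encoder embedding, $W$ and $b$ are the decoder `Linear` weight and bias, $x_i$ is the microbe index of sample $i$, and $n_i$ is the total count in row $i$ of $Y$. `Multinomial.log_prob` includes the factorial terms, so this matches what `forward` returns:

```math
\pi_i = \operatorname{softmax}\left(W u_{x_i} + b\right), \qquad
n_i = \sum_{j} y_{ij}, \qquad
\mathcal{L} = \frac{1}{B} \sum_{i=1}^{B} \left[ \log n_i! - \sum_{j} \log y_{ij}! + \sum_{j} y_{ij} \log \pi_{ij} \right]
```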

train.py (could use PyTorch Lightning)
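
A minimal sketch of what train.py could contain, written in plain PyTorch rather than PyTorch Lightning so that no Lightning API is assumed. `ToyMMvec`, `fit`, and their parameters (`patience`, `val_frac`, and so on) are hypothetical names. `ToyMMvec` is a compact restatement of the skeleton that feeds logits straight into `Multinomial`, one possible answer to the softmax TODO. `fit` minimizes the negative mean log-likelihood with Adam and implements the early stopping from the wishlist as patience on a held-out split:

```python
import torch
from torch import nn
from torch.distributions import Multinomial


class ToyMMvec(nn.Module):
    """Hypothetical compact stand-in for the MMvec skeleton; uses logits instead of Softmax."""

    def __init__(self, num_microbes, num_metabolites, latent_dim):
        super().__init__()
        self.encoder = nn.Embedding(num_microbes, latent_dim)
        self.decoder = nn.Linear(latent_dim, num_metabolites)

    def forward(self, X, Y):
        logits = self.decoder(self.encoder(X))
        # validate_args=False: per-row totals vary; log_prob does not depend on total_count
        return Multinomial(logits=logits, validate_args=False).log_prob(Y).mean()


def fit(model, X, Y, epochs=200, batch_size=32, lr=1e-2, patience=10, val_frac=0.2, seed=0):
    """Minimize -log-likelihood with Adam; stop when validation loss stalls for `patience` epochs."""
    gen = torch.Generator().manual_seed(seed)
    perm = torch.randperm(len(X), generator=gen)
    n_val = max(1, int(val_frac * len(X)))
    val_idx, train_idx = perm[:n_val], perm[n_val:]
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    best_loss, best_state, bad_epochs = float("inf"), None, 0
    for epoch in range(epochs):
        model.train()
        # Reshuffle the training indices every epoch and step through minibatches
        order = train_idx[torch.randperm(len(train_idx), generator=gen)]
        for start in range(0, len(order), batch_size):
            batch = order[start:start + batch_size]
            loss = -model(X[batch], Y[batch])
            opt.zero_grad()
            loss.backward()
            opt.step()
        model.eval()
        with torch.no_grad():
            val_loss = -model(X[val_idx], Y[val_idx]).item()
        if val_loss < best_loss:
            # New best: remember a copy of the weights and reset the patience counter
            best_loss, bad_epochs = val_loss, 0
            best_state = {k: v.clone() for k, v in model.state_dict().items()}
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                break
    model.load_state_dict(best_state)  # restore the best weights seen on held-out data
    return best_loss, epoch + 1
```

Restoring the best `state_dict` instead of keeping the last one means the returned model is the one that scored best on held-out data; `patience` trades training time against the risk of stopping on a noisy plateau.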

The wishlist

  • Early stopping (see video for example)
  • ArviZ for diagnostics
  • Type annotations would be great; see torchtyping
  • Unit tests for the model could be useful too; see torchtest
  • Being Bayesian would be nice; SWAG (SWA-Gaussian) is the laziest approach
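
For the Bayesian item, a sketch of the diagonal variant of SWAG: keep running first and second moments of the flattened weights along the SGD trajectory after a burn-in, then draw weights from a Gaussian with that mean and a diagonal variance. Full SWAG also keeps a low-rank covariance term, which is omitted here. `DiagSWAG`, `collect`, and `sample` are hypothetical names; `parameters_to_vector` and `vector_to_parameters` are existing `torch.nn.utils` functions:

```python
import copy

import torch
from torch import nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters


class DiagSWAG:
    """Hypothetical helper: running moments of a model's flattened parameters (SWAG-Diagonal)."""

    def __init__(self, model, var_clamp=1e-30):
        self.model = model
        self.n = 0
        self.mean = torch.zeros_like(parameters_to_vector(model.parameters()))
        self.sq_mean = torch.zeros_like(self.mean)
        self.var_clamp = var_clamp  # floor for the variance, guards against tiny negative values

    @torch.no_grad()
    def collect(self):
        """Record the current weights, e.g. once per epoch after a burn-in period of SGD."""
        theta = parameters_to_vector(self.model.parameters())
        self.n += 1
        self.mean += (theta - self.mean) / self.n
        self.sq_mean += (theta ** 2 - self.sq_mean) / self.n

    @torch.no_grad()
    def sample(self, generator=None):
        """Return a copy of the model with weights drawn from N(mean, diag(var))."""
        var = (self.sq_mean - self.mean ** 2).clamp(min=self.var_clamp)
        z = torch.randn(self.mean.shape, generator=generator)
        sampled = copy.deepcopy(self.model)
        vector_to_parameters(self.mean + var.sqrt() * z, sampled.parameters())
        return sampled
```

In use, one would call `collect()` once per epoch after burn-in and then average predictions over several `sample()` draws to approximate the posterior predictive. PyTorch's `torch.optim.swa_utils.AveragedModel` already covers the SWA mean alone, if the variance is not needed.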
