Skip to content
Merged

V4 #7

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip
# No `cache: pip` on purpose: the resolved env (torch + lightning +
# scanpy + numba + cuda libs) is >2 GB of wheels, and the post-job
# cache-save tar reliably hits the 10-minute step timeout on the
# GitHub-hosted runners. A cold resolve from a warm PyPI mirror
# only takes ~25s, so caching is a net loss here.

- name: Install scsims with test extras
run: |
Expand Down
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,20 +116,21 @@ labeled fine-tuning set.
from scsims import SIMS
import scanpy as sc

# Stage 1: unsupervised pretraining on a large unlabeled corpus.
# No class_label needed -- the pretrainer never reads cell type labels.
unlabeled = sc.read_h5ad("big_unlabeled_corpus.h5ad")
sims = SIMS(data=unlabeled, class_label="cell_type") # label col is ignored at this stage

# Stage 1: unsupervised pretraining. Cell type labels are not used.
sims = SIMS(data=unlabeled)
sims.pretrain(
pretraining_ratio=0.2,
accelerator="gpu",
devices=1,
max_epochs=50,
)

# Stage 2: supervised fine-tuning. SIMS automatically detects the
# attached pretrainer and warm-starts the classifier from its encoder
# weights before fitting.
# Stage 2: supervised fine-tuning on a smaller labeled set. Build a
# fresh SIMS with the labels, load the pretrainer checkpoint, and
# train(). SIMS automatically warm-starts the classifier from the
# pretrained encoder before fitting.
labeled = sc.read_h5ad("smaller_labeled_set.h5ad")
sims = SIMS(data=labeled, class_label="cell_type")
sims.load_pretrainer("./sims_pretrain_checkpoints/best.ckpt")
Expand Down
21 changes: 16 additions & 5 deletions scsims/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import anndata as an
import numpy as np
import pandas as pd
import torch
from scipy.sparse import issparse
from sklearn.model_selection import train_test_split
Expand Down Expand Up @@ -186,7 +187,7 @@ def clean_sample(

def generate_dataloaders(
data: an.AnnData,
class_label: str,
class_label: Optional[str] = None,
test_prop=0.2,
stratify=True,
batch_size: int = 16,
Expand All @@ -199,11 +200,21 @@ def generate_dataloaders(
if isinstance(data, str):
data = an.read_h5ad(data)

current_labels = data.obs.loc[:, class_label]
if class_label is None:
# Unsupervised mode (e.g. pretraining): no real labels exist. Use a
# zero placeholder so AnnDatasetMatrix still yields (features, label)
# tuples — the pretrainer's training step discards the label
# component anyway. Stratification has no meaning here.
current_labels = pd.Series(
np.zeros(data.shape[0], dtype=np.int64),
index=data.obs.index,
)
stratify = False
else:
current_labels = data.obs.loc[:, class_label]

# make sure data can be stratified
if stratify:
if len(current_labels.unique()) < 3:
# make sure data can be stratified
if stratify and len(current_labels.unique()) < 3:
warnings.warn(
"One class has less than 3 samples, disabling stratification"
)
Expand Down
36 changes: 23 additions & 13 deletions scsims/lightning_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,14 @@ def __init__(
self.setup()

def prepare_data(self):
unique_targets = self.data.obs.loc[:, self.class_label].unique()
label_encoder = LabelEncoder().fit(unique_targets)
# `class_label` is optional: pretraining workflows pass an unlabeled
# AnnData and never set it. In that case there's nothing to encode.
if self.class_label is None:
self.label_encoder = None
return

self.label_encoder = label_encoder
unique_targets = self.data.obs.loc[:, self.class_label].unique()
self.label_encoder = LabelEncoder().fit(unique_targets)

if not pd.api.types.is_numeric_dtype(self.data.obs.loc[:, self.class_label]):
print("Numerically encoding class labels")
Expand All @@ -86,18 +90,22 @@ def setup(self, stage: Optional[str] = None):
self.trainloader = loaders[0]
self.valloader = None
self.testloader = None
else:
else:
self.trainloader, self.valloader, self.testloader = loaders

print("Calculating weights")
labels = self.data.obs.loc[:, self.class_label].values
self.weights = torch.from_numpy(
compute_class_weight(
y=labels,
classes=np.unique(labels),
class_weight="balanced",
)
).float()
if self.class_label is None:
# No labels => no class weights. Pretraining doesn't use them.
self.weights = None
else:
print("Calculating weights")
labels = self.data.obs.loc[:, self.class_label].values
self.weights = torch.from_numpy(
compute_class_weight(
y=labels,
classes=np.unique(labels),
class_weight="balanced",
)
).float()

self.setuped = True

Expand All @@ -112,6 +120,8 @@ def test_dataloader(self):

@cached_property
def num_labels(self):
    """Return the number of distinct classes, or None when unlabeled.

    When ``class_label`` is ``None`` there is no label column to count,
    so ``None`` is returned. Otherwise the result is the number of unique
    values in the ``class_label`` column of ``self.data.obs``.
    """
    label_column = self.class_label
    if label_column is not None:
        return self.data.obs.loc[:, label_column].nunique()
    return None

@cached_property
Expand Down
7 changes: 7 additions & 0 deletions scsims/scvi_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,13 @@ def load_pretrainer(self, weights_path: str, **kwargs) -> SIMSPretrainer:
# ------------------------------------------------------------------

def train(self, *args, **kwargs):
if self.datamodule.class_label is None:
raise ValueError(
"SIMS.train() is supervised and requires a class_label. Build "
"the SIMS instance with `SIMS(data=adata, class_label=...)` "
"before calling .train(). For unsupervised pretraining on an "
"unlabeled dataset, use SIMS(data=adata).pretrain(...) instead."
)
print("Beginning training")
if not hasattr(self, "_trainer"):
self.setup_trainer(*args, **kwargs)
Expand Down
36 changes: 36 additions & 0 deletions tests/test_pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,42 @@ def test_pretrain_warm_start_transfers_encoder_weights(synthetic_anndata, tmp_pa
raise AssertionError("no shape-matching encoder keys to spot-check")


def test_pretrain_on_truly_unlabeled_anndata(synthetic_anndata, tmp_path):
    """Pretraining must work on an AnnData that carries no labels at all.

    ``SIMS(data=adata)`` is built without a ``class_label``. The 3.x API
    forced users to pass a class_label even when pretraining ignored it;
    v4 cleaned that up.
    """
    # Remove the label column outright: any code path that still tries to
    # read it will crash instead of silently succeeding.
    del synthetic_anndata.obs["blobs"]
    assert "blobs" not in synthetic_anndata.obs.columns

    # Deliberately no class_label keyword here.
    sims = SIMS(data=synthetic_anndata)

    # Without labels the datamodule has nothing to encode or weight.
    for attribute in ("class_label", "label_encoder", "weights"):
        assert getattr(sims.datamodule, attribute) is None, attribute

    pretrain_options = {
        "accelerator": "cpu",
        "devices": 1,
        "max_epochs": 2,
        "enable_progress_bar": False,
        "logger": False,
        "checkpoint_dir": str(tmp_path / "pretrain"),
    }
    sims.pretrain(**pretrain_options)

    assert isinstance(sims.pretrainer, SIMSPretrainer)


def test_train_without_class_label_raises_clear_error(synthetic_anndata):
    """``.train()`` on a label-less SIMS must fail with a helpful ValueError.

    The error message is expected to mention ``class_label`` so users are
    pointed at the right API.
    """
    import pytest

    # Strip the labels so the SIMS instance is genuinely unsupervised.
    del synthetic_anndata.obs["blobs"]
    unlabeled_sims = SIMS(data=synthetic_anndata)

    with pytest.raises(ValueError, match="class_label"):
        unlabeled_sims.train()


def test_load_pretrainer_round_trip(synthetic_anndata, tmp_path):
"""Two-process workflow: pretrain in 'process A', save .ckpt, fresh
SIMS in 'process B' loads it via load_pretrainer() and warm-starts
Expand Down
Loading