Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,6 @@ docs/modlyn.*
lamin_sphinx
docs/conf.py
_docs_tmp*

docs/test-modlyn/
lightning_logs/
157 changes: 123 additions & 34 deletions docs/quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,8 @@
{
"cell_type": "code",
"execution_count": null,
"id": "453f6f89",
"metadata": {
"tags": [
"hide-output"
]
},
"id": "35122bdc",
"metadata": {},
"outputs": [],
"source": [
"import lamindb as ln\n",
Expand All @@ -47,42 +43,89 @@
{
"cell_type": "code",
"execution_count": null,
"id": "980a05b7",
"metadata": {
"tags": [
"hide-output"
]
},
"id": "9708b93e",
"metadata": {},
"outputs": [],
"source": [
"ln.track()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fffe8a48",
"metadata": {},
"outputs": [],
"source": [
"# Configuration: switch between in-memory and Dask loader\n",
"USE_DASK = True # set False to use in-memory path\n",
"ZARR_UID = \"1xSHIdfBjfUdxKHm0000\" # example UID; change as needed\n",
"LABEL_COL = \"cell_line\"\n",
"\n",
"# Dask runtime\n",
"DASK_DATASET_TYPE = \"arrayloaders-dasd\" # accepted alias (normalized internally)\n",
"BATCH_SIZE = 256\n",
"N_CHUNKS = 8\n",
"DASK_SCHEDULER = \"threads\""
]
},
{
"cell_type": "markdown",
"id": "c8ad0ac1",
"id": "5086e159",
"metadata": {},
"source": [
"## Prepare dataset"
"### Using a custom Dask data loader\n",
"Set `USE_DASK = True` and provide a zarr `ZARR_UID` from `laminlabs/arrayloader-benchmarks`.\n",
"The loader auto-detects whether the cached path is a single zarr store or a directory of shard stores (`*.zarr`) and selects the right reader. For quick runs, we cap steps with `max_steps` in the training call.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dfb07f4c",
"metadata": {
"tags": [
"hide-output"
]
},
"id": "30985561",
"metadata": {},
"outputs": [],
"source": [
"artifact = ln.Artifact.using(\"laminlabs/arrayloader-benchmarks\").get(\n",
" \"JNaxQe8zbljesdbK0000\"\n",
")\n",
"adata = artifact.load()\n",
"sc.pp.log1p(adata)\n",
"adata"
"from pathlib import Path\n",
"import lamindb as ln\n",
"\n",
"if USE_DASK:\n",
" artifact = ln.Artifact.using(\"laminlabs/arrayloader-benchmarks\").get(ZARR_UID)\n",
" store_path = Path(artifact.cache())\n",
" if not store_path.is_dir():\n",
" raise ValueError(f\"ZARR_UID must cache to a directory, got: {store_path}\")\n",
"\n",
" # Decide between a directory of shards (*.zarr) vs a single zarr store\n",
" has_shards = any(child.name.endswith(\".zarr\") for child in store_path.iterdir())\n",
"\n",
" try:\n",
" from arrayloaders.io.dask_loader import read_lazy_store\n",
" except Exception:\n",
" read_lazy_store = None\n",
" from arrayloaders.io import read_lazy as read_single_store\n",
"\n",
" if has_shards and read_lazy_store is not None:\n",
" adata = read_lazy_store(store_path, obs_columns=[LABEL_COL])\n",
" else:\n",
" # Single zarr store\n",
" adata = read_single_store(store_path, obs_columns=[LABEL_COL])\n",
"else:\n",
" # Example H5AD path (keep your current artifact if you prefer)\n",
" artifact = ln.Artifact.using(\"laminlabs/arrayloader-benchmarks\").get(\n",
" \"JNaxQe8zbljesdbK0000\"\n",
" )\n",
" adata = artifact.load()\n",
" sc.pp.log1p(adata)\n",
"\n",
"print(\"adata:\", adata.shape)"
]
},
{
"cell_type": "markdown",
"id": "c8ad0ac1",
"metadata": {},
"source": [
"## Prepare dataset"
]
},
{
Expand Down Expand Up @@ -136,16 +179,55 @@
"source": [
"logreg = mn.models.SimpleLogReg(\n",
" adata=adata,\n",
" label_column=\"cell_line\",\n",
" label_column=LABEL_COL,\n",
" learning_rate=1e-1,\n",
" weight_decay=1e-3,\n",
")\n",
"\n",
"fit_kwargs = {\n",
" \"adata_train\": adata,\n",
" \"adata_val\": None,\n",
" \"train_dataloader_kwargs\": {\n",
" \"batch_size\": BATCH_SIZE,\n",
" \"drop_last\": False,\n",
" \"num_workers\": 0,\n",
" },\n",
" \"max_epochs\": 1,\n",
" \"num_sanity_val_steps\": 0,\n",
" \"log_every_n_steps\": 1,\n",
" \"max_steps\": 50,\n",
"}\n",
"\n",
"if USE_DASK:\n",
" fit_kwargs.update(\n",
" {\n",
" \"dataset_type\": DASK_DATASET_TYPE,\n",
" \"n_chunks\": N_CHUNKS,\n",
" \"dask_scheduler\": DASK_SCHEDULER,\n",
" }\n",
" )\n",
"\n",
"# logreg.fit(**fit_kwargs)\n",
"logreg.fit(\n",
" adata_train=adata,\n",
" adata_val=adata[:20],\n",
" train_dataloader_kwargs={\"batch_size\": 128, \"drop_last\": True, \"num_workers\": 4},\n",
" max_epochs=5,\n",
")"
" adata_val=adata, # reuse the lazy dataset so val has batches\n",
" train_dataloader_kwargs={\n",
" \"batch_size\": BATCH_SIZE,\n",
" \"drop_last\": False,\n",
" \"num_workers\": 0,\n",
" },\n",
" dataset_type=DASK_DATASET_TYPE,\n",
" n_chunks=N_CHUNKS,\n",
" dask_scheduler=DASK_SCHEDULER,\n",
" max_epochs=1,\n",
" num_sanity_val_steps=0,\n",
" log_every_n_steps=1,\n",
" max_steps=50,\n",
")\n",
"\n",
"\n",
"print(\"dataset_type:\", getattr(logreg.datamodule, \"dataset_type\", \"in-memory\"))\n",
"print(\"train_dataset:\", type(logreg.datamodule.train_dataloader().dataset).__name__)"
]
},
{
Expand Down Expand Up @@ -174,7 +256,14 @@
},
"outputs": [],
"source": [
"logreg.plot_classification_report(adata)"
"# eval subset\n",
"adata_eval = adata[:10000]\n",
"adata_eval = adata_eval.to_memory() if hasattr(adata_eval, \"to_memory\") else adata_eval\n",
"\n",
"if hasattr(adata_eval.X, \"compute\"):\n",
" adata_eval.X = adata_eval.X.compute()\n",
"\n",
"logreg.plot_classification_report(adata_eval)"
]
},
{
Expand Down Expand Up @@ -313,7 +402,7 @@
"notebook_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "py312",
"display_name": "lamin_env",
"language": "python",
"name": "python3"
},
Expand All @@ -327,7 +416,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
"version": "3.12.10"
}
},
"nbformat": 4,
Expand Down
72 changes: 66 additions & 6 deletions modlyn/models/_simple_logreg_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import lightning as L
import torch
from arrayloaders.io.dask_loader import DaskDataset
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, TensorDataset

Expand Down Expand Up @@ -82,8 +81,9 @@ def __init__(
self.n_chunks = n_chunks
self.dask_scheduler = dask_scheduler

# Fit label encoder on training data (only needed for tensor datasets)
if self.dataset_type == "in-memory" and self.adata_train is not None:
# Fit label encoder on training data (used by both backends)
self.label_encoder = None
if self.adata_train is not None:
self.label_encoder = LabelEncoder()
self.label_encoder.fit(self.adata_train.obs[self.label_col])

Expand All @@ -107,6 +107,13 @@ def _create_tensor_dataset(self, adata):

def _create_dask_dataset(self, adata, shuffle=True):
"""Create a DaskDataset from AnnData."""
try:
from arrayloaders.io.dask_loader import DaskDataset # lazy import
except Exception as e:
raise ImportError(
"arrayloaders is required for dataset_type='dask-arrayloader'. Install with `pip install arrayloaders`."
) from e

return DaskDataset(
adata,
label_column=self.label_col,
Expand All @@ -115,28 +122,81 @@ def _create_dask_dataset(self, adata, shuffle=True):
dask_scheduler=self.dask_scheduler,
)

def _collate_dask_batch(self, batch):
"""Collate function for DaskDataset batches -> (x_tensor, y_tensor)."""
import numpy as np
import torch

try:
import scipy.sparse as sp
except Exception: # pragma: no cover - optional
sp = None

if not batch:
return torch.empty(0), torch.empty(0, dtype=torch.long)
first = batch[0]
if isinstance(first, tuple) and len(first) == 3:
xs, ys, _ = zip(*batch, strict=False)
else:
xs, ys = zip(*batch, strict=False)
if self.label_encoder is None:
raise RuntimeError("label_encoder not initialized")
# Encode labels; fallback to ints if encoder mismatch occurs
try:
y_enc = self.label_encoder.transform(list(ys))
except Exception:
y_enc = np.array([int(y) for y in ys], dtype=np.int64)
# ensure each row is a contiguous 1D float32 array; handle sparse and object types
xs_arr = []
for x in xs:
# densify sparse rows
if sp is not None and getattr(sp, "issparse", None) and sp.issparse(x):
arr = x.toarray()
else:
arr = np.asarray(x)
# flatten any 2D shapes (e.g., 1 x n_vars)
if arr.ndim > 1:
arr = arr.ravel()
# robust dtype conversion
if arr.dtype == object:
# last-resort element-wise float coercion
try:
arr = arr.astype(np.float32, copy=False)
except Exception:
arr = np.array([float(v) for v in arr], dtype=np.float32)
else:
arr = arr.astype(np.float32, copy=False)
xs_arr.append(arr)
x_tensor = torch.as_tensor(np.stack(xs_arr, axis=0), dtype=torch.float32)
y_tensor = torch.as_tensor(y_enc, dtype=torch.long)
return x_tensor, y_tensor

def train_dataloader(self):
    """Build the training ``DataLoader`` for the configured backend.

    Returns:
        A ``DataLoader`` over the training dataset, built with a copy of
        ``self.train_dataloader_kwargs``. For ``"dask-arrayloader"`` the
        copy gets ``self._collate_dask_batch`` as ``collate_fn`` unless the
        caller already supplied one.

    Raises:
        ValueError: If ``self.adata_train`` is ``None`` or
            ``self.dataset_type`` is not a known backend.
    """
    if self.adata_train is None:
        raise ValueError("adata_train is None")

    # Work on a copy so the default below never mutates the user's kwargs.
    kwargs = dict(self.train_dataloader_kwargs)
    if self.dataset_type == "in-memory":
        train_dataset = self._create_tensor_dataset(self.adata_train)
    elif self.dataset_type == "dask-arrayloader":
        train_dataset = self._create_dask_dataset(self.adata_train, shuffle=True)
        kwargs.setdefault("collate_fn", self._collate_dask_batch)
    else:
        raise ValueError(f"Unknown dataset_type: {self.dataset_type}")

    # Pass the local copy (not the raw attribute) so collate_fn takes effect.
    return DataLoader(train_dataset, **kwargs)

def val_dataloader(self):
    """Build the validation ``DataLoader`` for the configured backend.

    Returns:
        An empty list when ``self.adata_val`` is ``None``; otherwise a
        ``DataLoader`` over the validation dataset, built with a copy of
        ``self.val_dataloader_kwargs``. For ``"dask-arrayloader"`` the copy
        gets ``self._collate_dask_batch`` as ``collate_fn`` unless the
        caller already supplied one.

    Raises:
        ValueError: If ``self.dataset_type`` is not a known backend.
    """
    if self.adata_val is None:
        # NOTE(review): an empty list is expected to mean "no validation
        # loaders" to Lightning - confirm against the pinned version.
        return []

    # Work on a copy so the default below never mutates the user's kwargs.
    kwargs = dict(self.val_dataloader_kwargs)
    if self.dataset_type == "in-memory":
        val_dataset = self._create_tensor_dataset(self.adata_val)
    elif self.dataset_type == "dask-arrayloader":
        val_dataset = self._create_dask_dataset(self.adata_val, shuffle=False)
        kwargs.setdefault("collate_fn", self._collate_dask_batch)
    else:
        raise ValueError(f"Unknown dataset_type: {self.dataset_type}")

    # Pass the local copy (not the raw attribute) so collate_fn takes effect.
    return DataLoader(val_dataset, **kwargs)
Loading