Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ default_stages:
- pre-push
minimum_pre_commit_version: 2.12.0
repos:
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v4.0.0-alpha.4
- repo: https://github.com/rbubley/mirrors-prettier
rev: v3.6.2
hooks:
- id: prettier
exclude: |
Expand All @@ -24,13 +24,13 @@ repos:
docs/notes/
)
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.7
rev: v0.12.10
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes]
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v6.0.0
hooks:
- id: detect-private-key
- id: check-ast
Expand All @@ -44,7 +44,7 @@ repos:
- id: trailing-whitespace
- id: check-case-conflict
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.7.1
rev: v1.17.1
hooks:
- id: mypy
args: [--no-strict-optional, --ignore-missing-imports]
Expand Down
31 changes: 18 additions & 13 deletions docs/quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"import scanpy as sc\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"\n",
"sns.set_theme()\n",
"%config InlineBackend.figure_formats = ['svg']"
]
Expand Down Expand Up @@ -76,7 +77,9 @@
},
"outputs": [],
"source": [
"artifact = ln.Artifact.using(\"laminlabs/arrayloader-benchmarks\").get(\"JNaxQe8zbljesdbK0000\")\n",
"artifact = ln.Artifact.using(\"laminlabs/arrayloader-benchmarks\").get(\n",
" \"JNaxQe8zbljesdbK0000\"\n",
")\n",
"adata = artifact.load()\n",
"sc.pp.log1p(adata)\n",
"adata"
Expand All @@ -93,7 +96,7 @@
},
"outputs": [],
"source": [
"keep = adata.obs[\"cell_line\"].value_counts().loc[lambda x: x>3].index\n",
"keep = adata.obs[\"cell_line\"].value_counts().loc[lambda x: x > 3].index\n",
"adata = adata[adata.obs[\"cell_line\"].isin(keep)].copy()\n",
"adata"
]
Expand Down Expand Up @@ -133,18 +136,14 @@
"source": [
"logreg = mn.models.SimpleLogReg(\n",
" adata=adata,\n",
" label_column=\"cell_line\", \n",
" label_column=\"cell_line\",\n",
" learning_rate=1e-1,\n",
" weight_decay=1e-3,\n",
")\n",
"logreg.fit(\n",
" adata_train=adata,\n",
" adata_val=adata[:20],\n",
" train_dataloader_kwargs={\n",
" \"batch_size\": 128,\n",
" \"drop_last\": True,\n",
" \"num_workers\": 4\n",
" },\n",
" train_dataloader_kwargs={\"batch_size\": 128, \"drop_last\": True, \"num_workers\": 4},\n",
" max_epochs=5,\n",
")"
]
Expand Down Expand Up @@ -212,8 +211,10 @@
},
"outputs": [],
"source": [
"sc.tl.rank_genes_groups(adata, 'cell_line', method='logreg', key_added='sc_logreg')\n",
"df_scanpy_logreg = sc.get.rank_genes_groups_df(adata, group=None, key=\"sc_logreg\").pivot(index='group', columns='names', values='scores')\n",
"sc.tl.rank_genes_groups(adata, \"cell_line\", method=\"logreg\", key_added=\"sc_logreg\")\n",
"df_scanpy_logreg = sc.get.rank_genes_groups_df(\n",
" adata, group=None, key=\"sc_logreg\"\n",
").pivot(index=\"group\", columns=\"names\", values=\"scores\")\n",
"df_scanpy_logreg.attrs[\"method_name\"] = \"scanpy_logreg\"\n",
"df_scanpy_logreg.head()"
]
Expand All @@ -229,8 +230,10 @@
},
"outputs": [],
"source": [
"sc.tl.rank_genes_groups(adata, 'cell_line', method='wilcoxon', key_added='sc_wilcoxon')\n",
"df_scanpy_wilcoxon = sc.get.rank_genes_groups_df(adata, group=None, key=\"sc_wilcoxon\").pivot(index='group', columns='names', values='scores')\n",
"sc.tl.rank_genes_groups(adata, \"cell_line\", method=\"wilcoxon\", key_added=\"sc_wilcoxon\")\n",
"df_scanpy_wilcoxon = sc.get.rank_genes_groups_df(\n",
" adata, group=None, key=\"sc_wilcoxon\"\n",
").pivot(index=\"group\", columns=\"names\", values=\"scores\")\n",
"df_scanpy_wilcoxon.attrs[\"method_name\"] = \"scanpy_wilcoxon\"\n",
"df_scanpy_wilcoxon.head()"
]
Expand All @@ -254,7 +257,9 @@
},
"outputs": [],
"source": [
"compare = mn.eval.CompareScoresJaccard([df_modlyn_logreg, df_scanpy_logreg, df_scanpy_wilcoxon], n_top_values=[5, 10, 25])"
"compare = mn.eval.CompareScoresJaccard(\n",
" [df_modlyn_logreg, df_scanpy_logreg, df_scanpy_wilcoxon], n_top_values=[5, 10, 25]\n",
")"
]
},
{
Expand Down
1 change: 1 addition & 0 deletions modlyn/eval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
CompareScores

"""

from ._jaccard import CompareScores

CompareScoresJaccard = CompareScores # backward compat
17 changes: 8 additions & 9 deletions modlyn/eval/_jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@ class CompareScores:
def __init__(self, dataframes, n_top_values=None):
"""Initialize with dataframes and n_top values to compare.

Parameters:
-----------
dataframes : list of pd.DataFrame
List of dataframes with method results. Each should have df.attrs["method_name"]
n_top_values : list of int
List of top-N values to compare across
Args:
dataframes: List of dataframes with method results. Each should have df.attrs["method_name"]
n_top_values: List of top-N values to compare across
"""
if n_top_values is None:
n_top_values = [25, 50, 100, 200]
Expand Down Expand Up @@ -45,7 +42,7 @@ def compute_jaccard_comparison(self):
for cell_line in common_cells:
scores = {
name: df.loc[cell_line]
for df, name in zip(dfs_aligned, method_names)
for df, name in zip(dfs_aligned, method_names, strict=False)
}
top_features = {
name: set(scores[name].abs().nlargest(n_top).index)
Expand Down Expand Up @@ -121,7 +118,7 @@ def plot_jaccard_comparison(self):
)

# Add value labels
for bar, value in zip(bars, values):
for bar, value in zip(bars, values, strict=False):
if not np.isnan(value):
ax.text(
bar.get_x() + bar.get_width() / 2,
Expand Down Expand Up @@ -164,7 +161,9 @@ def plot_heatmaps(self):
axes = [axes]

# Plot heatmaps
for i, (df, method_name) in enumerate(zip(dfs_sorted, method_names)):
for i, (df, method_name) in enumerate(
zip(dfs_sorted, method_names, strict=False)
):
sns.heatmap(df, ax=axes[i], cmap="viridis", vmin=vmin, vmax=vmax, cbar=True)
axes[i].set_title(method_name)

Expand Down
1 change: 1 addition & 0 deletions modlyn/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
SimpleLogReg

"""

from ._simple_logreg_model import SimpleLogReg
4 changes: 2 additions & 2 deletions modlyn/models/_simple_logreg_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class SimpleLogRegDataModule(L.LightningDataModule):
"""A simple LightningDataModule for classification tasks using TensorDataset.

Args:
adata_train: anndata.AnnData object containing the training data.
adata_val: anndata.AnnData object containing the validation data.
adata_train: `AnnData` object containing the training data.
adata_val: `AnnData` object containing the validation data.
label_column: Name of the column in `obs` that contains the target values.
train_dataloader_kwargs: Additional keyword arguments passed to the torch DataLoader for the training dataset.
val_dataloader_kwargs: Additional keyword arguments passed to the torch DataLoader for the validation dataset.
Expand Down
1 change: 0 additions & 1 deletion modlyn/models/_simple_logreg_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class SimpleLogReg(L.LightningModule):
label_column: Name of the column in `obs` that contains the target values.
learning_rate: Learning rate for the optimizer.
weight_decay: Weight decay for the optimizer.

"""

def __init__(
Expand Down
1 change: 0 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,4 @@ def docs(session):
if not IS_PR:
subprocess.run(
"lndocs --strip-prefix --format text --error-on-index", # --strict back
shell=True, # noqa: S602
)
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "modlyn"
requires-python = ">=3.9,<3.13"
requires-python = ">=3.11,<3.14"
authors = [{name = "Lamin Labs", email = "open-source@lamin.ai"}]
readme = "README.md"
dynamic = ["version", "description"]
Expand Down Expand Up @@ -41,7 +41,7 @@ filterwarnings = [
[tool.ruff]
src = ["src"]
line-length = 88
select = [
lint.select = [
"F", # Errors detected by Pyflakes
"E", # Error detected by Pycodestyle
"W", # Warning detected by Pycodestyle
Expand All @@ -58,7 +58,7 @@ select = [
"PTH", # Use pathlib
"S" # Security
]
ignore = [
lint.ignore = [
# Do not catch blind exception: `Exception`
"BLE001",
# Errors from function calls in argument defaults. These are fine when the result is immutable.
Expand Down Expand Up @@ -129,10 +129,10 @@ ignore = [
"S607"
]

[tool.ruff.pydocstyle]
[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"docs/*" = ["I"]
"tests/**/*.py" = [
"D", # docstrings are allowed to look a bit off
Expand Down
2 changes: 0 additions & 2 deletions tests/test_base.py

This file was deleted.

1 change: 0 additions & 1 deletion tests/test_notebooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


def test_notebooks():
# assuming this is in the tests folder
docs_folder = Path(__file__).parents[1] / "docs/"

for check_folder in docs_folder.glob("./**"):
Expand Down