Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions stlearn/_datasets/_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from .._settings import settings


# TODO - Add scanpy and convert this over.
def visium_sge(
sample_id="V1_Breast_Cancer_Block_A_Section_1",
*,
Expand Down Expand Up @@ -35,11 +34,11 @@ def visium_sge(


def xenium_sge(
base_url="https://cf.10xgenomics.com/samples/xenium/1.0.1",
image_filename="he_image.ome.tif",
alignment_filename="he_imagealignment.csv",
zip_filename="outs.zip",
library_id="Xenium_FFPE_Human_Breast_Cancer_Rep1",
base_url: str="https://cf.10xgenomics.com/samples/xenium/1.0.1",
library_id: str="Xenium_FFPE_Human_Breast_Cancer_Rep1",
zip_filename: str="outs.zip",
image_filename: str="he_image.ome.tif",
alignment_filename: str="he_imagealignment.csv",
include_hires_tiff: bool = False,
):
"""
Expand All @@ -48,17 +47,25 @@ def xenium_sge(

Args:
base_url: Base URL for downloads
library_id: Identifier for the library
zip_filename: Name of the zip file to download
image_filename: Name of the image file to download
alignment_filename: Name of the affine transformation file to download
zip_filename: Name of the zip file to download
library_id: Identifier for the library
include_hires_tiff: Whether to download the high-res TIFF image
"""
sc.settings.datasetdir = settings.datasetdir
library_dir = settings.datasetdir / library_id
library_dir.mkdir(parents=True, exist_ok=True)

files_to_extract = ["cell_feature_matrix.h5", "cells.csv.gz", "experiment.xenium"]
if "xe_outs.zip" in zip_filename:
files_to_extract = [
"cell_feature_matrix.zarr.zip", "cells.zarr.zip", "experiment.xenium"
]
else:
files_to_extract = [
"cell_feature_matrix.h5", "cells.csv.gz", "experiment.xenium"
]

all_sge_files_exist = all(
(library_dir / sge_file).exists() for sge_file in files_to_extract
)
Expand All @@ -79,11 +86,11 @@ def xenium_sge(
sc.readwrite._download(url=url, path=file_path)

if not all_sge_files_exist:
zip_file_path = library_dir / zip_filename
try:
zip_file_path = library_dir / zip_filename
with zf.ZipFile(zip_file_path, "r") as zip_ref:
for zip_filename in files_to_extract:
with open(library_dir / zip_filename, "wb") as file_name:
file_name.write(zip_ref.read(f"outs/{zip_filename}"))
members = {m.rsplit("/", 1)[-1]: m for m in zip_ref.namelist()}
for name in files_to_extract:
(library_dir / name).write_bytes(zip_ref.read(members[name]))
except zf.BadZipFile as b:
raise ValueError(f"Invalid zip file: {library_dir / zip_filename}") from b
raise ValueError(f"Invalid zip file: {zip_file_path}") from b
96 changes: 60 additions & 36 deletions stlearn/spatial/trajectory/pseudotime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import pandas as pd
import scanpy as sc
from anndata import AnnData
from networkx import Graph
from sklearn.neighbors import NearestCentroid

from stlearn.pp import neighbors
from stlearn.spatial.clustering import localization
from stlearn.spatial.morphology import adjust
from stlearn.types import _METHOD
Expand All @@ -15,8 +15,6 @@ def pseudotime(
adata: AnnData,
use_label: str = "leiden",
eps: float = 20,
n_neighbors: int = 25,
use_rep: str = "X_pca",
threshold: float = 0.01,
radius: int = 50,
method: _METHOD = "mean",
Expand All @@ -25,11 +23,10 @@ def pseudotime(
reverse: bool = False,
pseudotime_key: str = "dpt_pseudotime",
max_nodes: int = 4,
run_knn: bool = False,
copy: bool = False,
) -> AnnData | None:
"""\
Perform pseudotime analysis.
Perform pseudotime analysis. Requires having run knn neighbours beforehand.

Parameters
----------
Expand Down Expand Up @@ -71,6 +68,12 @@ def pseudotime(

"""

if "neighbors" not in adata.uns and "connectivities" not in adata.obsp:
raise ValueError(
"A neighbor graph is required - none found in uns or obsp. "
"Subsetting data requires re-running."
)

keys_obsm = ["X_diffmap", "X_draw_graph_fr", "X_diffmap_morphology"]
keys_uns = [
"split_node",
Expand All @@ -91,10 +94,6 @@ def pseudotime(

localization(adata, use_label=use_label, eps=eps)

# Running knn
if run_knn:
neighbors(adata, n_neighbors=n_neighbors, use_rep=use_rep, random_state=0)

# Running paga
sc.tl.paga(adata, groups=use_label)

Expand Down Expand Up @@ -138,7 +137,7 @@ def pseudotime(
replicate_list = np.array([])
for i in range(0, len(cnt_matrix)):
replicate_list = np.concatenate(
[replicate_list, np.array([i] * len(split_node[i]))]
[replicate_list, np.array([i] * len(split_node[i]))],
)

# Connection matrix for subcluster
Expand All @@ -155,15 +154,15 @@ def pseudotime(
]

# Create a connection graph of subclusters
G = nx.from_pandas_adjacency(cnt_matrix)
G_nodes = list(range(len(G.nodes)))
graph = nx.from_pandas_adjacency(cnt_matrix)
graph_nodes = list(range(len(graph.nodes)))

node_convert = {}
for pair in zip(list(G.nodes), G_nodes, strict=True):
for pair in zip(list(graph.nodes), graph_nodes, strict=True):
node_convert[pair[1]] = pair[0]

adata.uns["global_graph"] = {}
adata.uns["global_graph"]["graph"] = nx.to_scipy_sparse_array(G)
adata.uns["global_graph"]["graph"] = nx.to_scipy_sparse_array(graph)
adata.uns["global_graph"]["node_dict"] = node_convert

# Create centroid dict for subclusters
Expand Down Expand Up @@ -216,34 +215,27 @@ def selection_sort(x):


def store_available_paths(adata, threshold, use_label, max_nodes, pseudotime_key):
# Read original PAGA graph
G = nx.from_numpy_array(adata.uns["paga"]["connectivities"].toarray())
edge_weights = nx.get_edge_attributes(G, "weight")
G.remove_edges_from((e for e, w in edge_weights.items() if w < threshold))

H = G.to_directed()
# Recreate original PAGA graph.
graph = nx.from_numpy_array(adata.uns["paga"]["connectivities"].toarray())
edge_weights = nx.get_edge_attributes(graph, "weight")
graph.remove_edges_from((e for e, w in edge_weights.items() if w < threshold))

# Calculate pseudotime for each node
node_pseudotime = {}

for node in H.nodes:
node_pseudotime[node] = adata.obs.query(use_label + " == '" + str(node) + "'")[
pseudotime_key
].max()
node_pseudotime = node_pseudotime_summary(adata, graph, pseudotime_key, use_label)

# Force original PAGA to directed PAGA based on pseudotime
edge_to_remove = []
for edge in H.edges:
if node_pseudotime[edge[0]] - node_pseudotime[edge[1]] > 0:
edge_to_remove.append(edge)
H.remove_edges_from(edge_to_remove)
# Convert undirected graph to directed graph by pseudotime.
directed_graph = orient_by_pseudotime(graph, node_pseudotime)

# Extract all available paths
all_paths = {}

for source in H.nodes:
for target in H.nodes:
paths = nx.all_simple_paths(H, source=source, target=target)
for source in directed_graph.nodes:
for target in directed_graph.nodes:
paths = nx.all_simple_paths(
directed_graph,
source=source,
target=target,
)
for i, path in enumerate(paths):
if len(path) < max_nodes:
all_paths[str(i) + "_" + str(source) + "_" + str(target)] = path
Expand All @@ -253,5 +245,37 @@ def store_available_paths(adata, threshold, use_label, max_nodes, pseudotime_key
"All available trajectory paths are stored in adata.uns['available_paths'] "
+ "with length < "
+ str(max_nodes)
+ " nodes"
+ " nodes",
)


def node_pseudotime_summary(adata, graph: "Graph", pseudotime_key, use_label):
    """Summarise pseudotime per graph node as the median of its finite values.

    Parameters
    ----------
    adata
        Object exposing an ``obs`` DataFrame (e.g. an AnnData).
    graph
        Object whose ``nodes`` attribute iterates the node ids to summarise.
    pseudotime_key
        Column of ``adata.obs`` holding the per-observation pseudotime.
    use_label
        Column of ``adata.obs`` holding the cluster label of each observation.
        Labels are compared to ``str(node)``.

    Returns
    -------
    dict
        Maps each node to the median (as ``float``) of the finite pseudotime
        values of the observations labelled ``str(node)``, or ``nan`` when
        the node has no finite value (no observations, or only NaN/inf).
    """
    # Select rows with a boolean mask instead of an interpolated
    # ``DataFrame.query`` string: a query expression fails to parse when the
    # column name is not a valid identifier (e.g. "cell type") or when a
    # label contains a quote character.
    labels = adata.obs[use_label].astype(str)
    values = adata.obs[pseudotime_key]

    summary = {}
    for node in graph.nodes:
        node_values = values[(labels == str(node)).to_numpy()]
        # Non-finite entries (NaN / +-inf) are excluded so a single stray
        # value cannot decide the summary for the whole node.
        finite = node_values[np.isfinite(node_values)]
        summary[node] = float(finite.median()) if len(finite) else np.nan
    return summary


def orient_by_pseudotime(graph, node_pseudotime):
    """Turn an undirected graph into a directed one ordered by pseudotime.

    Every node of ``graph`` is kept. Each edge whose two endpoints both have
    a finite pseudotime becomes exactly one arc, pointing from the endpoint
    with the lower pseudotime to the one with the higher pseudotime. When
    the two values are equal, the arc goes from the smaller node id to the
    larger one, so a pair of nodes never gets arcs in both directions.
    Edges with a non-finite endpoint (NaN or infinite pseudotime) are
    omitted. Edge attributes of ``graph`` are not carried over.

    Parameters
    ----------
    graph
        Graph providing ``nodes`` and ``edges``.
    node_pseudotime
        Mapping from node id to its pseudotime value.

    Returns
    -------
    networkx.DiGraph
        The oriented graph.
    """
    oriented = nx.DiGraph()
    oriented.add_nodes_from(graph.nodes)

    for first, second in graph.edges:
        t_first = node_pseudotime[first]
        t_second = node_pseudotime[second]

        # An endpoint without a finite pseudotime cannot be ordered.
        if not np.isfinite(t_first) or not np.isfinite(t_second):
            continue

        if t_first < t_second:
            tail, head = first, second
        elif t_second < t_first:
            tail, head = second, first
        else:
            # Tie: fall back to node id order for a deterministic direction.
            tail, head = min(first, second), max(first, second)

        oriented.add_edge(tail, head)

    return oriented
87 changes: 87 additions & 0 deletions tests/spatial/test_pseudotime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import unittest
from types import SimpleNamespace

import networkx
import numpy
import pandas

from stlearn.spatial.trajectory.pseudotime import (
node_pseudotime_summary,
orient_by_pseudotime,
)


class TestPseudotime(unittest.TestCase):
    """Tests for ``orient_by_pseudotime`` and ``node_pseudotime_summary``."""

    @staticmethod
    def new_graph(edges, n):
        """Return an undirected graph on nodes ``0..n-1`` with ``edges``."""
        undirected = networkx.Graph()
        undirected.add_nodes_from(range(n))
        undirected.add_edges_from(edges)
        return undirected

    @staticmethod
    def make_adata(labels, pseudotime):
        """Return a minimal stand-in object exposing only an ``obs`` table."""
        label_column = pandas.Categorical([str(label) for label in labels])
        time_column = numpy.array(pseudotime, dtype=float)
        obs = pandas.DataFrame(
            {
                "leiden": label_column,
                "dpt_pseudotime": time_column,
            },
        )
        return SimpleNamespace(obs=obs)

    def test_inf_cluster_drops_its_edges(self):
        """An infinite pseudotime removes every edge touching that node."""
        chain = TestPseudotime.new_graph([(0, 1), (1, 2)], 3)
        times = {0: 0.0, 1: 0.5, 2: float("inf")}
        oriented = orient_by_pseudotime(chain, times)
        assert networkx.is_directed_acyclic_graph(oriented)
        assert set(oriented.edges) == {(0, 1)}

    def test_chain_orients_lower_to_higher(self):
        """A chain with increasing pseudotime keeps its natural direction."""
        chain = TestPseudotime.new_graph([(0, 1), (1, 2), (2, 3)], 4)
        times = {0: 0.0, 1: 0.25, 2: 0.5, 3: 1.0}
        oriented = orient_by_pseudotime(chain, times)
        assert networkx.is_directed_acyclic_graph(oriented)
        assert set(oriented.edges) == {(0, 1), (1, 2), (2, 3)}

    def test_tie_yields_single_arc_not_two_cycle(self):
        """Equal pseudotimes produce one arc, never a pair of opposite arcs."""
        pair = TestPseudotime.new_graph([(0, 1)], 2)
        oriented = orient_by_pseudotime(pair, {0: 0.5, 1: 0.5})
        assert networkx.is_directed_acyclic_graph(oriented)
        assert oriented.number_of_edges() == 1

    def test_nan_cluster_drops_its_edges(self):
        """A NaN pseudotime removes every edge touching that node."""
        chain = TestPseudotime.new_graph([(0, 1), (1, 2)], 3)
        times = {0: 0.0, 1: 0.5, 2: float("nan")}
        oriented = orient_by_pseudotime(chain, times)
        assert networkx.is_directed_acyclic_graph(oriented)
        assert set(oriented.edges) == {(0, 1)}

    def test_orientation_is_always_acyclic(self):
        """Property test: ties + NaN + arbitrary connectivity must never cycle."""
        rng = numpy.random.default_rng(0)
        candidate_times = [0.0, 0.0, 0.5, 1.0, float("nan")]
        for _ in range(500):
            n = int(rng.integers(3, 8))
            # Symmetric random adjacency with weak entries and the diagonal
            # zeroed out, so the graph is undirected without self-loops.
            adjacency = rng.random((n, n))
            adjacency = (adjacency + adjacency.T) / 2
            adjacency[adjacency < 0.5] = 0.0
            numpy.fill_diagonal(adjacency, 0.0)
            undirected = networkx.from_numpy_array(adjacency)
            sampled = rng.choice(candidate_times, size=n)
            times = {node: float(value) for node, value in enumerate(sampled)}
            oriented = orient_by_pseudotime(undirected, times)
            assert networkx.is_directed_acyclic_graph(oriented)

    def test_summary_keeps_cluster_that_broken_drops(self):
        """A cluster with one stray inf value still gets oriented edges."""
        adata = TestPseudotime.make_adata(
            ["0", "0", "1", "1", "2", "2"],
            [0.0, 0.1, 0.5, float("inf"), 0.9, 0.95],
        )
        chain = TestPseudotime.new_graph([(0, 1), (1, 2)], 3)

        summary = node_pseudotime_summary(adata, chain, "dpt_pseudotime", "leiden")
        # If cluster 1 were summarised as inf, both of its edges would be
        # dropped and the count would be zero.
        oriented = orient_by_pseudotime(chain, summary)
        assert oriented.number_of_edges() == 2
2 changes: 1 addition & 1 deletion tests/spatial/test_psts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_psts(self):
print("Done leiden!")
self.adata.uns["iroot"] = np.flatnonzero(self.adata.obs["leiden"] == "0")[0]
st.spatial.trajectory.pseudotime(
self.adata, eps=100, use_rep="X_pca", use_sme=False, use_label="leiden"
self.adata, eps=100, use_sme=False, use_label="leiden"
)
st.spatial.trajectory.pseudotimespace_global(
self.adata, use_label="leiden", list_clusters=[0, 1]
Expand Down
2 changes: 1 addition & 1 deletion tests/tl/test_lr_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_matches_reference_on_random_inputs(self):
got = lr_core(
spot_lr1,
spot_lr2,
neighbour_lists,
List(neighbour_lists),
min_expr,
spot_indices,
)
Expand Down