diff --git a/stlearn/_datasets/_datasets.py b/stlearn/_datasets/_datasets.py index 3c7e38dd..948b51d8 100644 --- a/stlearn/_datasets/_datasets.py +++ b/stlearn/_datasets/_datasets.py @@ -6,7 +6,6 @@ from .._settings import settings -# TODO - Add scanpy and covert this over. def visium_sge( sample_id="V1_Breast_Cancer_Block_A_Section_1", *, @@ -35,11 +34,11 @@ def visium_sge( def xenium_sge( - base_url="https://cf.10xgenomics.com/samples/xenium/1.0.1", - image_filename="he_image.ome.tif", - alignment_filename="he_imagealignment.csv", - zip_filename="outs.zip", - library_id="Xenium_FFPE_Human_Breast_Cancer_Rep1", + base_url: str="https://cf.10xgenomics.com/samples/xenium/1.0.1", + library_id: str="Xenium_FFPE_Human_Breast_Cancer_Rep1", + zip_filename: str="outs.zip", + image_filename: str="he_image.ome.tif", + alignment_filename: str="he_imagealignment.csv", include_hires_tiff: bool = False, ): """ @@ -48,17 +47,25 @@ def xenium_sge( Args: base_url: Base URL for downloads + library_id: Identifier for the library + zip_filename: Name of the zip file to download image_filename: Name of the image file to download alignment_filename: Name of the affine transformation file to download - zip_filename: Name of the zip file to download - library_id: Identifier for the library include_hires_tiff: Whether to download the high-res TIFF image """ sc.settings.datasetdir = settings.datasetdir library_dir = settings.datasetdir / library_id library_dir.mkdir(parents=True, exist_ok=True) - files_to_extract = ["cell_feature_matrix.h5", "cells.csv.gz", "experiment.xenium"] + if "xe_outs.zip" in zip_filename: + files_to_extract = [ + "cell_feature_matrix.zarr.zip", "cells.zarr.zip", "experiment.xenium" + ] + else: + files_to_extract = [ + "cell_feature_matrix.h5", "cells.csv.gz", "experiment.xenium" + ] + all_sge_files_exist = all( (library_dir / sge_file).exists() for sge_file in files_to_extract ) @@ -79,11 +86,11 @@ def xenium_sge( sc.readwrite._download(url=url, path=file_path) if 
not all_sge_files_exist: + zip_file_path = library_dir / zip_filename try: - zip_file_path = library_dir / zip_filename with zf.ZipFile(zip_file_path, "r") as zip_ref: - for zip_filename in files_to_extract: - with open(library_dir / zip_filename, "wb") as file_name: - file_name.write(zip_ref.read(f"outs/{zip_filename}")) + members = {m.rsplit("/", 1)[-1]: m for m in zip_ref.namelist()} + for name in files_to_extract: + (library_dir / name).write_bytes(zip_ref.read(members[name])) except zf.BadZipFile as b: - raise ValueError(f"Invalid zip file: {library_dir / zip_filename}") from b + raise ValueError(f"Invalid zip file: {zip_file_path}") from b diff --git a/stlearn/spatial/trajectory/pseudotime.py b/stlearn/spatial/trajectory/pseudotime.py index c5a27a74..3c2bbf13 100644 --- a/stlearn/spatial/trajectory/pseudotime.py +++ b/stlearn/spatial/trajectory/pseudotime.py @@ -3,9 +3,9 @@ import pandas as pd import scanpy as sc from anndata import AnnData +from networkx import Graph from sklearn.neighbors import NearestCentroid -from stlearn.pp import neighbors from stlearn.spatial.clustering import localization from stlearn.spatial.morphology import adjust from stlearn.types import _METHOD @@ -15,8 +15,6 @@ def pseudotime( adata: AnnData, use_label: str = "leiden", eps: float = 20, - n_neighbors: int = 25, - use_rep: str = "X_pca", threshold: float = 0.01, radius: int = 50, method: _METHOD = "mean", @@ -25,11 +23,10 @@ def pseudotime( reverse: bool = False, pseudotime_key: str = "dpt_pseudotime", max_nodes: int = 4, - run_knn: bool = False, copy: bool = False, ) -> AnnData | None: """\ - Perform pseudotime analysis. + Perform pseudotime analysis. Requires having run knn neighbours beforehand. Parameters ---------- @@ -71,6 +68,12 @@ def pseudotime( """ + if "neighbors" not in adata.uns and "connectivities" not in adata.obsp: + raise ValueError( + "A neighbor graph is required - none found in uns or obsp. " + "Subsetting data requires re-running." 
+ ) + keys_obsm = ["X_diffmap", "X_draw_graph_fr", "X_diffmap_morphology"] keys_uns = [ "split_node", @@ -91,10 +94,6 @@ def pseudotime( localization(adata, use_label=use_label, eps=eps) - # Running knn - if run_knn: - neighbors(adata, n_neighbors=n_neighbors, use_rep=use_rep, random_state=0) - # Running paga sc.tl.paga(adata, groups=use_label) @@ -138,7 +137,7 @@ def pseudotime( replicate_list = np.array([]) for i in range(0, len(cnt_matrix)): replicate_list = np.concatenate( - [replicate_list, np.array([i] * len(split_node[i]))] + [replicate_list, np.array([i] * len(split_node[i]))], ) # Connection matrix for subcluster @@ -155,15 +154,15 @@ def pseudotime( ] # Create a connection graph of subclusters - G = nx.from_pandas_adjacency(cnt_matrix) - G_nodes = list(range(len(G.nodes))) + graph = nx.from_pandas_adjacency(cnt_matrix) + graph_nodes = list(range(len(graph.nodes))) node_convert = {} - for pair in zip(list(G.nodes), G_nodes, strict=True): + for pair in zip(list(graph.nodes), graph_nodes, strict=True): node_convert[pair[1]] = pair[0] adata.uns["global_graph"] = {} - adata.uns["global_graph"]["graph"] = nx.to_scipy_sparse_array(G) + adata.uns["global_graph"]["graph"] = nx.to_scipy_sparse_array(graph) adata.uns["global_graph"]["node_dict"] = node_convert # Create centroid dict for subclusters @@ -216,34 +215,27 @@ def selection_sort(x): def store_available_paths(adata, threshold, use_label, max_nodes, pseudotime_key): - # Read original PAGA graph - G = nx.from_numpy_array(adata.uns["paga"]["connectivities"].toarray()) - edge_weights = nx.get_edge_attributes(G, "weight") - G.remove_edges_from((e for e, w in edge_weights.items() if w < threshold)) - - H = G.to_directed() + # Recreate original PAGA graph. 
def node_pseudotime_summary(adata, graph: "Graph", pseudotime_key, use_label):
    """Summarise pseudotime per graph node as the median of its finite values.

    Parameters
    ----------
    adata
        Object exposing an ``obs`` DataFrame that holds both the cluster
        label column and the pseudotime column.
    graph
        Graph whose ``nodes`` are matched against the cluster labels by
        their string form (``str(node)``).
    pseudotime_key
        Name of the ``obs`` column holding per-observation pseudotime.
    use_label
        Name of the ``obs`` column holding the cluster labels.

    Returns
    -------
    dict
        ``{node: float}``. The value is the median of the node's finite
        pseudotime values, or NaN when the node has no observations or none
        of them are finite.
    """
    # Build the label/value vectors once. A boolean mask (instead of a
    # string-built ``DataFrame.query``) works for any column name, including
    # ones with spaces or punctuation, and for non-string label dtypes.
    labels = adata.obs[use_label].astype(str)
    values = adata.obs[pseudotime_key].to_numpy(dtype=float)

    summary = {}
    for node in graph.nodes:
        member = (labels == str(node)).to_numpy()
        node_values = values[member]
        # Non-finite entries (NaN / +-inf) are excluded so a single stray
        # value cannot dominate the per-node summary.
        finite = node_values[np.isfinite(node_values)]
        summary[node] = float(np.median(finite)) if finite.size else np.nan
    return summary


def orient_by_pseudotime(graph, node_pseudotime):
    """Orient an undirected graph into a DAG using per-node pseudotime.

    Each undirected edge becomes a single arc pointing from lower to higher
    pseudotime. Edges touching a node whose pseudotime is not finite (NaN or
    +-inf) cannot be ordered and are dropped. Ties are broken
    deterministically by node id so no 2-cycle can survive. Self-loops are
    skipped, because an arc ``u -> u`` would itself be a cycle.

    Parameters
    ----------
    graph
        Undirected graph exposing ``nodes`` and ``edges``.
    node_pseudotime
        Mapping from every node in ``graph`` to its pseudotime value.

    Returns
    -------
    networkx.DiGraph
        Directed graph containing all nodes of ``graph`` and at most one arc
        per orderable input edge.

    Raises
    ------
    KeyError
        If an edge endpoint is missing from ``node_pseudotime``.
    """
    directed_graph = nx.DiGraph()
    # Keep every node, even those that end up with no arcs.
    directed_graph.add_nodes_from(graph.nodes)
    for u, v in graph.edges:
        if u == v:
            # A self-loop always ties with itself; adding it would break
            # the acyclicity guarantee.
            continue
        pu, pv = node_pseudotime[u], node_pseudotime[v]
        if not (np.isfinite(pu) and np.isfinite(pv)):
            continue
        if pu < pv:
            tail, head = u, v
        elif pv < pu:
            tail, head = v, u
        else:
            # Equal pseudotime: order by node id so the choice is stable.
            tail, head = min(u, v), max(u, v)
        directed_graph.add_edge(tail, head)
    return directed_graph
networkx.is_directed_acyclic_graph(directed_graph) + assert set(directed_graph.edges) == {(0, 1), (1, 2), (2, 3)} + + def test_tie_yields_single_arc_not_two_cycle(self): + graph = TestPseudotime.new_graph([(0, 1)], 2) + directed_graph = orient_by_pseudotime(graph, {0: 0.5, 1: 0.5}) + assert networkx.is_directed_acyclic_graph(directed_graph) + assert directed_graph.number_of_edges() == 1 + + def test_nan_cluster_drops_its_edges(self): + graph = TestPseudotime.new_graph([(0, 1), (1, 2)], 3) + new_graph = orient_by_pseudotime(graph, {0: 0.0, 1: 0.5, 2: float("nan")}) + assert networkx.is_directed_acyclic_graph(new_graph) + assert set(new_graph.edges) == {(0, 1)} + + # Test never cycles with ties, NaN, etc. + def test_orientation_is_always_acyclic(self): + """Property test: ties + NaN + arbitrary connectivity must never cycle.""" + rng = numpy.random.default_rng(0) + for _ in range(500): + n = int(rng.integers(3, 8)) + random_node = rng.random((n, n)) + random_node = (random_node + random_node.T) / 2 + random_node[random_node < 0.5] = 0.0 + numpy.fill_diagonal(random_node, 0.0) + graph = networkx.from_numpy_array(random_node) + vals = rng.choice([0.0, 0.0, 0.5, 1.0, float("nan")], size=n) + directed_graph = orient_by_pseudotime( + graph, + {i: float(vals[i]) for i in range(n)}, + ) + assert networkx.is_directed_acyclic_graph(directed_graph) + + # a stray-inf cluster survives orientation under the summary, vanishes under .max() + def test_summary_keeps_cluster_that_broken_drops(self): + adata = TestPseudotime.make_adata( + ["0", "0", "1", "1", "2", "2"], + [0.0, 0.1, 0.5, float("inf"), 0.9, 0.95], + ) + chain = networkx.Graph() + chain.add_nodes_from([0, 1, 2]) + chain.add_edges_from([(0, 1), (1, 2)]) + + summary = node_pseudotime_summary(adata, chain, "dpt_pseudotime", "leiden") + # Will be zero if broken. 
+ assert orient_by_pseudotime(chain, summary).number_of_edges() == 2 diff --git a/tests/spatial/test_psts.py b/tests/spatial/test_psts.py index ffd6d2a5..84082369 100644 --- a/tests/spatial/test_psts.py +++ b/tests/spatial/test_psts.py @@ -27,7 +27,7 @@ def test_psts(self): print("Done leiden!") self.adata.uns["iroot"] = np.flatnonzero(self.adata.obs["leiden"] == "0")[0] st.spatial.trajectory.pseudotime( - self.adata, eps=100, use_rep="X_pca", use_sme=False, use_label="leiden" + self.adata, eps=100, use_sme=False, use_label="leiden" ) st.spatial.trajectory.pseudotimespace_global( self.adata, use_label="leiden", list_clusters=[0, 1] diff --git a/tests/tl/test_lr_core.py b/tests/tl/test_lr_core.py index a319c7b6..42032379 100644 --- a/tests/tl/test_lr_core.py +++ b/tests/tl/test_lr_core.py @@ -49,7 +49,7 @@ def test_matches_reference_on_random_inputs(self): got = lr_core( spot_lr1, spot_lr2, - neighbour_lists, + List(neighbour_lists), min_expr, spot_indices, )