diff --git a/src/hnoca/mapping/mapper.py b/src/hnoca/mapping/mapper.py index 8020cd5..57d0984 100644 --- a/src/hnoca/mapping/mapper.py +++ b/src/hnoca/mapping/mapper.py @@ -33,8 +33,8 @@ def __init__( ref_model: The reference model to map the query dataset to. """ # Check optional dependencies - check_deps("scvi-tools") - check_deps("scarches") + # check_deps("scvi-tools") # Comment as it is a function that fails to detect scvi-tools was installed + # check_deps("scarches") # Comment as it is a function that fails to detect scArches was installed # Import and store as attributes so other methods can use them import scarches import scvi # local import assured by previous check diff --git a/src/hnoca/mapping/wknn.py b/src/hnoca/mapping/wknn.py index 361329f..dc66210 100644 --- a/src/hnoca/mapping/wknn.py +++ b/src/hnoca/mapping/wknn.py @@ -53,13 +53,15 @@ def build_nn( # noqa: D103 ref, query=None, k=100, - use_rapids: bool = False, + weight: Literal["unweighted", "dist", "gaussian_kernel"] = "unweighted", + sigma=None, + use_rapids: bool = True, # Ensure that RAPIDS is used ): if query is None: query = ref if use_rapids: - check_deps("cuml") + # check_deps("cuml") # Comment check_deps() because the function is broken from cuml.neighbors import NearestNeighbors logger.info("Using cuML for neighborhood estimation on GPU.") @@ -223,9 +225,9 @@ def estimate_presence_score( ref = ref_adata.obsm[use_rep_ref_trans_prop] ref_trans_prop = get_transition_prob_mat(ref, k=k_ref_trans_prop) - if split_by and split_by in query_adata.obs.columns: + if split_by in query_adata.obs.columns: presence_split = [ - np.array(wknn[query_adata.obs[split_by] == x, :].sum(axis=0)).flatten() + np.array(wknn[query_adata.obs[split_by].to_numpy() == x, :].sum(axis=0)).flatten() # added to_numpy() for better compatibility for x in query_adata.obs[split_by].unique() ] else: @@ -270,13 +272,14 @@ def estimate_presence_score( } -def transfer_labels(ref_adata: sc.AnnData, query_adata: sc.AnnData, 
def transfer_labels(ref_adata: sc.AnnData, query_adata: sc.AnnData, wknn, label_key: str = "celltype"):
    """Transfer labels from reference to query data.

    Each reference observation is one-hot encoded by its label, and the
    encoding is multiplied by ``wknn`` to give one score per
    (query observation, label) pair.

    Args:
        ref_adata: Reference data; ``ref_adata.obs[label_key]`` supplies the labels.
        query_adata: Query data; ``query_adata.obs_names`` becomes the index of the result.
        wknn: Matrix multiplied with the one-hot label matrix. It needs one row
            per query observation and one column per reference observation.
            NOTE(review): presumably the weighted kNN graph built elsewhere in
            this module -- confirm against the caller.
        label_key: Column of ``ref_adata.obs`` holding the labels to transfer.

    Returns:
        A ``pandas.DataFrame`` indexed by ``query_adata.obs_names`` with one
        score column per label, followed by ``best_score`` (row-wise maximum
        score) and ``best_label`` (the label achieving that maximum).
    """
    # Encode once and reuse for both the product and the column names.
    one_hot = pd.get_dummies(ref_adata.obs[label_key])

    # Multiply against a plain float array so the product is a bare matrix
    # whether ``wknn`` is sparse or dense; a DataFrame-valued product would be
    # re-indexed (and filled with NaN) by the constructor below.
    scores = pd.DataFrame(
        wknn @ one_hot.to_numpy(dtype=float),
        columns=one_hot.columns,
        index=query_adata.obs_names,
    )

    # Derive both summaries from the label-only scores *before* adding either
    # column: ``max`` never sees the string labels, and ``idxmax`` never sees
    # the derived ``best_score`` column.
    best_score = scores.max(axis=1)
    best_label = scores.idxmax(axis=1)
    scores["best_score"] = best_score
    scores["best_label"] = best_label

    return scores