diff --git a/cap_naive_bayes/naive_bayes.py b/cap_naive_bayes/naive_bayes.py index a75f82b..441e2cf 100644 --- a/cap_naive_bayes/naive_bayes.py +++ b/cap_naive_bayes/naive_bayes.py @@ -267,9 +267,12 @@ def _normalize_log_probs(log_probs: np.ndarray) -> np.ndarray: """Perform P_i = P_i / sum_i(P_i) but in log space""" logger.debug("Start _normalize_log_probs...") - sum_log_probs = np.log(np.exp(log_probs).sum(axis=1)).reshape(-1, 1) # normalization coefficient - probs = np.exp(log_probs - sum_log_probs) - + # use log-sum-exp + k = np.max(log_probs, axis=1).reshape(-1,1) + exp = np.exp(log_probs - k) + log_sum_biased = np.log(exp.sum(axis=1).reshape(-1,1)) + log_sum = log_sum_biased + k + probs = np.exp(log_probs - log_sum) logger.debug("Finished _normalize_log_probs!") return probs diff --git a/pyproject.toml b/pyproject.toml index 5fd471d..09d35b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "cap-naive-bayes" -version = "0.1.2" +version = "0.1.3" description = "A lightweight implementation of a multinomial Naive Bayes classifier for Annotation Transfer of single-cell data." readme = "README.md" requires-python = ">=3.11"