From 7b846a8556d777732cef382bcf5fde12ee9f8883 Mon Sep 17 00:00:00 2001 From: Duhyeong Kim Date: Thu, 26 Feb 2026 13:58:33 -0800 Subject: [PATCH] Fix floating-point comparison bug in _get_tpr_index MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Fix `IndexError` in `_get_tpr_index` caused by exact float equality (`==`) against `np.linspace`-generated thresholds (e.g., `0.060000000000000005 ≠ 0.06`). Found via F1040508640. Replaced with `AnalysisNode.get_tpr_index()` static method using `np.isclose`, which raises `ValueError` on true misalignment. All call sites updated to use the shared method directly. Reviewed By: mgrange1998 Differential Revision: D94187721 --- privacy_guard/analysis/mia/analysis_node.py | 39 ++++++++-- .../analysis/mia/parallel_analysis_node.py | 4 +- .../analysis/tests/test_analysis_node.py | 74 ++++--------------- .../tests/test_parallel_analysis_node.py | 19 ----- 4 files changed, 49 insertions(+), 87 deletions(-) diff --git a/privacy_guard/analysis/mia/analysis_node.py b/privacy_guard/analysis/mia/analysis_node.py index 3b13076..8bdec24 100644 --- a/privacy_guard/analysis/mia/analysis_node.py +++ b/privacy_guard/analysis/mia/analysis_node.py @@ -175,11 +175,36 @@ def __init__( super().__init__(analysis_input=analysis_input) - def _get_tpr_index(self) -> int: - """Convert TPR target to array index.""" - if self._tpr_target is None: - return 0 # Legacy behavior: TPR = 1% is at index 0 - return int(np.where(self._error_thresholds == self._tpr_target)[0][0]) + @staticmethod + def get_tpr_index( + tpr_target: float | None, + tpr_threshold_width: float = 0.0025, + ) -> int: + """ + Convert TPR target to array index in the error_thresholds grid. + + Uses np.isclose to handle floating-point precision issues with np.linspace. + Raises ValueError if tpr_target does not align with the threshold grid. + + Args: + tpr_target: Target TPR value. If None, returns 0 (legacy behavior). + tpr_threshold_width: Width between TPR thresholds. + + Returns: + Index into the error_thresholds array. + """ + if tpr_target is None: + return 0 + num_thresholds = int((1.0 - 0.01) / tpr_threshold_width) + 1 + error_thresholds = np.linspace(0.01, 1.0, num_thresholds) + matches = np.where(np.isclose(error_thresholds, tpr_target))[0] + if len(matches) > 0: + return int(matches[0]) + raise ValueError( + f"tpr_target={tpr_target} does not align with the error_thresholds array. " + f"Nearest value is {error_thresholds[np.argmin(np.abs(error_thresholds - tpr_target))]}. " + f"Adjust tpr_target and tpr_threshold_width so that tpr_target falls on the threshold grid." + ) def _calculate_one_off_eps(self) -> float: df_train_user = self.analysis_input.df_train_user @@ -308,7 +333,9 @@ def run_analysis(self) -> BaseAnalysisOutput: eps_tpr_boundary = eps_tpr_ub if self._use_upper_bound else eps_tpr_lb - tpr_idx = self._get_tpr_index() + tpr_idx = AnalysisNode.get_tpr_index( + self._tpr_target, self._tpr_threshold_width + ) outputs = AnalysisNodeOutput( eps=eps_tpr_boundary[tpr_idx], # epsilon at specified TPR threshold eps_lb=eps_tpr_lb[tpr_idx], diff --git a/privacy_guard/analysis/mia/parallel_analysis_node.py b/privacy_guard/analysis/mia/parallel_analysis_node.py index f0bf688..c3980a5 100644 --- a/privacy_guard/analysis/mia/parallel_analysis_node.py +++ b/privacy_guard/analysis/mia/parallel_analysis_node.py @@ -226,7 +226,9 @@ def run_analysis(self) -> AnalysisNodeOutput: eps_tpr_boundary = eps_tpr_ub if self._use_upper_bound else eps_tpr_lb - tpr_idx = self._get_tpr_index() + tpr_idx = AnalysisNode.get_tpr_index( + self._tpr_target, self._tpr_threshold_width + ) outputs = AnalysisNodeOutput( eps=eps_tpr_boundary[tpr_idx], # epsilon at specified TPR threshold eps_lb=eps_tpr_lb[tpr_idx], diff --git a/privacy_guard/analysis/tests/test_analysis_node.py b/privacy_guard/analysis/tests/test_analysis_node.py index e35b2df..23ad7ae 100644 --- a/privacy_guard/analysis/tests/test_analysis_node.py +++ b/privacy_guard/analysis/tests/test_analysis_node.py @@ -450,43 +450,6 @@ def test_use_fnr_tnr_parameter_comparison(self) -> None: ) self.assertAlmostEqual(outputs_false["auc"], outputs_true["auc"], places=10) - def test_get_tpr_index_none_target(self) -> None: - """Test that _get_tpr_index returns 0 when tpr_target is None (legacy behavior).""" - analysis_node = AnalysisNode( - analysis_input=self.analysis_input, - delta=0.000001, - n_users_for_eval=100, - num_bootstrap_resampling_times=10, - tpr_target=None, - ) - self.assertEqual(analysis_node._get_tpr_index(), 0) - - def test_get_tpr_index_with_target(self) -> None: - """Test that _get_tpr_index returns correct index that points to tpr_target.""" - # Create error_thresholds array to get actual values - num_thresholds = int((1.0 - 0.01) / 0.0025) + 1 - error_thresholds = np.linspace(0.01, 1.0, num_thresholds) - - # Test with actual values from the array at various indices - test_indices = [0, 6, 36, 196, num_thresholds - 1] - - for idx in test_indices: - tpr_target = error_thresholds[idx] - analysis_node = AnalysisNode( - analysis_input=self.analysis_input, - delta=0.000001, - n_users_for_eval=100, - num_bootstrap_resampling_times=10, - tpr_target=tpr_target, - tpr_threshold_width=0.0025, - ) - tpr_idx = analysis_node._get_tpr_index() - self.assertEqual( - tpr_idx, - idx, - msg=f"tpr_target={tpr_target}: expected index {idx}, got {tpr_idx}", - ) - def test_tpr_threshold_width_positive_validation(self) -> None: """Test that tpr_threshold_width must be positive.""" with self.assertRaisesRegex(ValueError, "must be positive"): @@ -538,28 +501,17 @@ def test_tpr_target_range_validation(self) -> None: tpr_target=1.5, ) - def test_error_thresholds_array_creation(self) -> None: - """Test that _error_thresholds array is correctly created.""" - # Legacy mode: 100 thresholds - analysis_node_legacy = AnalysisNode( - analysis_input=self.analysis_input, - delta=0.000001, - n_users_for_eval=100, - num_bootstrap_resampling_times=10, - tpr_target=None, - ) - self.assertEqual(len(analysis_node_legacy._error_thresholds), 100) + def test_get_tpr_index(self) -> None: + with self.subTest("none_target"): + self.assertEqual(AnalysisNode.get_tpr_index(None), 0) - # Fine-grained mode - analysis_node_fine = AnalysisNode( - analysis_input=self.analysis_input, - delta=0.000001, - n_users_for_eval=100, - num_bootstrap_resampling_times=10, - tpr_target=0.01, - tpr_threshold_width=0.0025, - ) - expected_num_thresholds = int(0.99 / 0.0025) + 1 - self.assertEqual( - len(analysis_node_fine._error_thresholds), expected_num_thresholds - ) + with self.subTest("arbitrary_float"): + num_thresholds = int((1.0 - 0.01) / 0.0025) + 1 + error_thresholds = np.linspace(0.01, 1.0, num_thresholds) + for target in [0.03, 0.06, 0.07, 0.15, 0.5]: + idx = AnalysisNode.get_tpr_index(target) + self.assertAlmostEqual(error_thresholds[idx], target, places=5) + + with self.subTest("misaligned_target"): + with self.assertRaises(ValueError): + AnalysisNode.get_tpr_index(0.0115) diff --git a/privacy_guard/analysis/tests/test_parallel_analysis_node.py b/privacy_guard/analysis/tests/test_parallel_analysis_node.py index 8e810ef..fd8fc50 100644 --- a/privacy_guard/analysis/tests/test_parallel_analysis_node.py +++ b/privacy_guard/analysis/tests/test_parallel_analysis_node.py @@ -274,22 +274,3 @@ def test_use_fnr_tnr_parameter(self) -> None: self.assertGreater( len(outputs_false["eps_tpr_ub"]), len(outputs_true["eps_tpr_ub"]) ) - - def test_tpr_target_parameter(self) -> None: - """Test that tpr_target parameter works correctly in ParallelAnalysisNode.""" - parallel_node = ParallelAnalysisNode( - analysis_input=self.analysis_input, - delta=0.000001, - n_users_for_eval=100, - num_bootstrap_resampling_times=10, - eps_computation_tasks_num=2, - tpr_target=0.025, - tpr_threshold_width=0.0025, - ) - # Verify _get_tpr_index returns correct index - tpr_idx = parallel_node._get_tpr_index() - self.assertAlmostEqual( - parallel_node._error_thresholds[tpr_idx], - 0.025, - places=10, - )