Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions privacy_guard/analysis/mia/analysis_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,36 @@ def __init__(

super().__init__(analysis_input=analysis_input)

def _get_tpr_index(self) -> int:
"""Convert TPR target to array index."""
if self._tpr_target is None:
return 0 # Legacy behavior: TPR = 1% is at index 0
return int(np.where(self._error_thresholds == self._tpr_target)[0][0])
@staticmethod
def get_tpr_index(
tpr_target: float | None,
tpr_threshold_width: float = 0.0025,
) -> int:
"""
Convert TPR target to array index in the error_thresholds grid.

Uses np.isclose to handle floating-point precision issues with np.linspace.
Raises ValueError if tpr_target does not align with the threshold grid.

Args:
tpr_target: Target TPR value. If None, returns 0 (legacy behavior).
tpr_threshold_width: Width between TPR thresholds.

Returns:
Index into the error_thresholds array.
"""
if tpr_target is None:
return 0
num_thresholds = int((1.0 - 0.01) / tpr_threshold_width) + 1
error_thresholds = np.linspace(0.01, 1.0, num_thresholds)
matches = np.where(np.isclose(error_thresholds, tpr_target))[0]
if len(matches) > 0:
return int(matches[0])
raise ValueError(
f"tpr_target={tpr_target} does not align with the error_thresholds array. "
f"Nearest value is {error_thresholds[np.argmin(np.abs(error_thresholds - tpr_target))]}. "
f"Adjust tpr_target and tpr_threshold_width so that tpr_target falls on the threshold grid."
)

def _calculate_one_off_eps(self) -> float:
df_train_user = self.analysis_input.df_train_user
Expand Down Expand Up @@ -308,7 +333,9 @@ def run_analysis(self) -> BaseAnalysisOutput:

eps_tpr_boundary = eps_tpr_ub if self._use_upper_bound else eps_tpr_lb

tpr_idx = self._get_tpr_index()
tpr_idx = AnalysisNode.get_tpr_index(
self._tpr_target, self._tpr_threshold_width
)
outputs = AnalysisNodeOutput(
eps=eps_tpr_boundary[tpr_idx], # epsilon at specified TPR threshold
eps_lb=eps_tpr_lb[tpr_idx],
Expand Down
4 changes: 3 additions & 1 deletion privacy_guard/analysis/mia/parallel_analysis_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ def run_analysis(self) -> AnalysisNodeOutput:

eps_tpr_boundary = eps_tpr_ub if self._use_upper_bound else eps_tpr_lb

tpr_idx = self._get_tpr_index()
tpr_idx = AnalysisNode.get_tpr_index(
self._tpr_target, self._tpr_threshold_width
)
outputs = AnalysisNodeOutput(
eps=eps_tpr_boundary[tpr_idx], # epsilon at specified TPR threshold
eps_lb=eps_tpr_lb[tpr_idx],
Expand Down
74 changes: 13 additions & 61 deletions privacy_guard/analysis/tests/test_analysis_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,43 +450,6 @@ def test_use_fnr_tnr_parameter_comparison(self) -> None:
)
self.assertAlmostEqual(outputs_false["auc"], outputs_true["auc"], places=10)

def test_get_tpr_index_none_target(self) -> None:
"""Test that _get_tpr_index returns 0 when tpr_target is None (legacy behavior)."""
analysis_node = AnalysisNode(
analysis_input=self.analysis_input,
delta=0.000001,
n_users_for_eval=100,
num_bootstrap_resampling_times=10,
tpr_target=None,
)
self.assertEqual(analysis_node._get_tpr_index(), 0)

def test_get_tpr_index_with_target(self) -> None:
"""Test that _get_tpr_index returns correct index that points to tpr_target."""
# Create error_thresholds array to get actual values
num_thresholds = int((1.0 - 0.01) / 0.0025) + 1
error_thresholds = np.linspace(0.01, 1.0, num_thresholds)

# Test with actual values from the array at various indices
test_indices = [0, 6, 36, 196, num_thresholds - 1]

for idx in test_indices:
tpr_target = error_thresholds[idx]
analysis_node = AnalysisNode(
analysis_input=self.analysis_input,
delta=0.000001,
n_users_for_eval=100,
num_bootstrap_resampling_times=10,
tpr_target=tpr_target,
tpr_threshold_width=0.0025,
)
tpr_idx = analysis_node._get_tpr_index()
self.assertEqual(
tpr_idx,
idx,
msg=f"tpr_target={tpr_target}: expected index {idx}, got {tpr_idx}",
)

def test_tpr_threshold_width_positive_validation(self) -> None:
"""Test that tpr_threshold_width must be positive."""
with self.assertRaisesRegex(ValueError, "must be positive"):
Expand Down Expand Up @@ -538,28 +501,17 @@ def test_tpr_target_range_validation(self) -> None:
tpr_target=1.5,
)

def test_error_thresholds_array_creation(self) -> None:
"""Test that _error_thresholds array is correctly created."""
# Legacy mode: 100 thresholds
analysis_node_legacy = AnalysisNode(
analysis_input=self.analysis_input,
delta=0.000001,
n_users_for_eval=100,
num_bootstrap_resampling_times=10,
tpr_target=None,
)
self.assertEqual(len(analysis_node_legacy._error_thresholds), 100)
def test_get_tpr_index(self) -> None:
with self.subTest("none_target"):
self.assertEqual(AnalysisNode.get_tpr_index(None), 0)

# Fine-grained mode
analysis_node_fine = AnalysisNode(
analysis_input=self.analysis_input,
delta=0.000001,
n_users_for_eval=100,
num_bootstrap_resampling_times=10,
tpr_target=0.01,
tpr_threshold_width=0.0025,
)
expected_num_thresholds = int(0.99 / 0.0025) + 1
self.assertEqual(
len(analysis_node_fine._error_thresholds), expected_num_thresholds
)
with self.subTest("arbitrary_float"):
num_thresholds = int((1.0 - 0.01) / 0.0025) + 1
error_thresholds = np.linspace(0.01, 1.0, num_thresholds)
for target in [0.03, 0.06, 0.07, 0.15, 0.5]:
idx = AnalysisNode.get_tpr_index(target)
self.assertAlmostEqual(error_thresholds[idx], target, places=5)

with self.subTest("misaligned_target"):
with self.assertRaises(ValueError):
AnalysisNode.get_tpr_index(0.0115)
19 changes: 0 additions & 19 deletions privacy_guard/analysis/tests/test_parallel_analysis_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,22 +274,3 @@ def test_use_fnr_tnr_parameter(self) -> None:
self.assertGreater(
len(outputs_false["eps_tpr_ub"]), len(outputs_true["eps_tpr_ub"])
)

def test_tpr_target_parameter(self) -> None:
"""Test that tpr_target parameter works correctly in ParallelAnalysisNode."""
parallel_node = ParallelAnalysisNode(
analysis_input=self.analysis_input,
delta=0.000001,
n_users_for_eval=100,
num_bootstrap_resampling_times=10,
eps_computation_tasks_num=2,
tpr_target=0.025,
tpr_threshold_width=0.0025,
)
# Verify _get_tpr_index returns correct index
tpr_idx = parallel_node._get_tpr_index()
self.assertAlmostEqual(
parallel_node._error_thresholds[tpr_idx],
0.025,
places=10,
)
Loading