Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 111 additions & 24 deletions privacy_guard/attacks/rmia_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

# pyre-strict
import logging
from typing import Optional

import numpy as np
Expand All @@ -24,6 +25,8 @@
from privacy_guard.attacks.base_attack import BaseAttack
from sklearn.metrics import auc, roc_curve

logger: logging.Logger = logging.getLogger(__name__)


class RmiaAttack(BaseAttack):
"""
Expand All @@ -35,6 +38,12 @@ class RmiaAttack(BaseAttack):
The attack estimates population probability distributions using reference models and compares
the target model's predictions against estimated population averages to determine membership.

The attack supports two modes:
- Offline (default): Only OUT-model signals are used; population probability is
approximated via alpha mixing.
- Online: Both IN and OUT model signals are available per sample; population
probability is estimated directly from all reference models.

The attack leverages:
- Multiple reference models trained without target samples
- Population samples to establish baseline probability distributions
Expand All @@ -45,6 +54,25 @@ class RmiaAttack(BaseAttack):
REF_SCORE_PREFIX = "score_ref_"
REF_MEMBER_PREFIX = "member_ref_"

@staticmethod
def _compute_alpha_mixing_estimate(
mean_out: np.ndarray, alpha: float
) -> np.ndarray:
"""
Compute population probability estimate using alpha mixing formula.

This implements the offline RMIA approximation:
``population_estimate = (1 + alpha) / 2 * mean_out + (1 - alpha) / 2``

Args:
mean_out: Mean of OUT-model signals
alpha: Approximation coefficient for population probability

Returns:
Population probability estimate
"""
return (1 + alpha) / 2 * mean_out + (1 - alpha) / 2

def __init__(
self,
df_train_merge: pd.DataFrame,
Expand All @@ -55,6 +83,7 @@ def __init__(
alpha_coefficient: float = 0.3,
enable_auto_tuning: bool = False,
user_id_key: str = "user_id",
online: bool = False,
) -> None:
"""
Initialize the RMIA attack.
Expand All @@ -67,12 +96,24 @@ def __init__(
alpha_coefficient: Approximation coefficient for population probability estimation
num_reference_models: Number of reference models to use in attack (default: half of available models)
enable_auto_tuning: Enable automatic tuning of alpha_coefficient
user_id_key: Column name for user identifier
online: If True, use online RMIA mode where both IN and OUT model
signals are used directly. If False (default), use offline mode
with alpha mixing approximation.
"""
self.df_train_merge: pd.DataFrame = df_train_merge.copy()
self.df_test_merge: pd.DataFrame = df_test_merge.copy()
self.df_population: pd.DataFrame = df_population.copy()
self.row_aggregation: AggregationType = row_aggregation
self.alpha_coefficient: float = alpha_coefficient
self.online: bool = online

if online and enable_auto_tuning:
logger.warning(
"Auto-tuning is not applicable in online mode (alpha is not used). "
"Disabling auto-tuning."
)
enable_auto_tuning = False

if num_reference_models is None:
# estimate number of reference models based on the number of columns in the dataframe
Expand Down Expand Up @@ -136,39 +177,57 @@ def compute_ref_signal_averages(
ref_memberships: np.ndarray,
num_models: Optional[int] = None,
alpha: float = 0.3,
online: bool = False,
) -> np.ndarray:
"""
Compute average prediction probabilities from reference models excluding target samples.
Compute average prediction probabilities from reference models.

In offline mode (default), excludes IN-model signals and uses only
OUT-model signals. In online mode, uses all reference model signals
directly (both IN and OUT).

Args:
ref_signals: Prediction scores from reference models
ref_memberships: Boolean membership matrix indicating training inclusion
num_models: Number of reference models to consider
alpha: Approximation coefficient for population probability
online: If True, use online mode (direct population estimate)

Returns:
Averaged prediction scores excluding membership samples
Averaged prediction scores
"""
non_member_mask = ~ref_memberships
out_ref_signals = ref_signals * non_member_mask
if online:
# Online: use all reference model signals directly (IN + OUT)
if num_models is None:
num_models = ref_signals.shape[1]
if num_models >= ref_signals.shape[1]:
return ref_signals
# Select top-k signals across all models
return np.sort(ref_signals, axis=1)[:, -num_models:]
else:
# Offline: mask out IN signals, use only OUT signals
non_member_mask = ~ref_memberships
out_ref_signals = ref_signals * non_member_mask

if num_models is None:
num_models = ref_signals.shape[1] // 2
if num_models is None:
num_models = ref_signals.shape[1] // 2

if num_models > 1:
# Select top non-zero signals for each sample
out_ref_signals = np.sort(out_ref_signals, axis=1)[:, -num_models:]
else:
# Apply single model approximation formula
if alpha != 0:
approximation = ((ref_signals + alpha - 1) / alpha) * ref_memberships
out_ref_signals += approximation
if num_models > 1:
# Select top non-zero signals for each sample
out_ref_signals = np.sort(out_ref_signals, axis=1)[:, -num_models:]
else:
# Default fallback approximation w/ alpha=0.3
fallback = ((ref_signals - 0.7) / 0.3) * ref_memberships
out_ref_signals += fallback

return out_ref_signals
# Apply single model approximation formula
if alpha != 0:
approximation = (
(ref_signals + alpha - 1) / alpha
) * ref_memberships
out_ref_signals += approximation
else:
# Default fallback approximation w/ alpha=0.3
fallback = ((ref_signals - 0.7) / 0.3) * ref_memberships
out_ref_signals += fallback

return out_ref_signals

def _auto_tune_alpha_coefficient(
self,
Expand All @@ -178,11 +237,13 @@ def _auto_tune_alpha_coefficient(
population_ref_scores: np.ndarray,
) -> float:
"""Auto-tune alpha coefficient using cross-validation on reference models.

Args:
target_model_idx: Index of target model in reference scores
ref_scores: Reference model prediction scores
ref_memberships: Reference model membership matrix
population_ref_scores: Population reference model scores

Returns:
Tuned alpha coefficient
"""
Expand Down Expand Up @@ -233,6 +294,7 @@ def _compute_membership_scores(
population_ref_scores: np.ndarray,
alpha: float,
num_models: int,
online: bool = False,
) -> np.ndarray:
"""
Execute core membership inference scoring algorithm.
Expand All @@ -245,27 +307,50 @@ def _compute_membership_scores(
population_ref_scores: Population reference model scores
alpha: Population probability approximation coefficient
num_models: Number of reference models to use
online: If True, use online mode (direct population estimate)

Returns:
Membership inference scores (higher indicates greater membership likelihood)
"""

# Compute reference signals for target dataset
target_ref_mean = self.compute_ref_signal_averages(
ref_scores, ref_memberships, num_models, alpha
ref_scores, ref_memberships, num_models, alpha, online=online
)
target_mean_out = np.mean(target_ref_mean, axis=1)
target_population_estimate = (1 + alpha) / 2 * target_mean_out + (1 - alpha) / 2

if online:
# Online: direct population estimate (no alpha mixing)
target_population_estimate = target_mean_out
else:
# Offline: alpha-based approximation
target_population_estimate = self._compute_alpha_mixing_estimate(
target_mean_out, alpha
)

target_probability_ratios = target_scores.ravel() / target_population_estimate

# Compute reference signals for population dataset
# Population samples are not used for training, so we set the membership to False
population_memberships = np.zeros_like(population_ref_scores).astype(bool)
population_ref_mean = self.compute_ref_signal_averages(
population_ref_scores, population_memberships, num_models, alpha
population_ref_scores,
population_memberships,
num_models,
alpha,
online=online,
)
population_mean_out = np.mean(population_ref_mean, axis=1)
population_estimate = (1 + alpha) / 2 * population_mean_out + (1 - alpha) / 2

if online:
# Online: direct population estimate
population_estimate = population_mean_out
else:
# Offline: alpha-based approximation
population_estimate = self._compute_alpha_mixing_estimate(
population_mean_out, alpha
)

population_probability_ratios = (
population_target_scores.ravel() / population_estimate
)
Expand Down Expand Up @@ -339,7 +424,7 @@ def run_attack(self) -> AggregateAnalysisInput:
self.df_population
)

# Auto-tune alpha coefficient if requested
# Auto-tune alpha coefficient if requested (only in offline mode)
current_alpha = self.alpha_coefficient
if self.enable_auto_tuning:
# Use reference data for alpha tuning
Expand All @@ -364,6 +449,7 @@ def run_attack(self) -> AggregateAnalysisInput:
population_ref_scores=population_ref_scores,
alpha=current_alpha,
num_models=self.num_reference_models,
online=self.online,
)

# Compute membership scores for test data
Expand All @@ -375,6 +461,7 @@ def run_attack(self) -> AggregateAnalysisInput:
population_ref_scores=population_ref_scores,
alpha=current_alpha,
num_models=self.num_reference_models,
online=self.online,
)

# Create output dataframes with computed scores
Expand Down
Loading
Loading