Merged
6 changes: 5 additions & 1 deletion examples/ensemble_attack/configs/experiment_config.yaml
@@ -25,7 +25,7 @@ target_model: # This is only used for testing the attack on a real target model.

# Data paths
data_paths:
midst_data_path: /projects/midst-experiments/all_tabddpms/ # Used to collect the data (input) as defined in data_processing_config

processed_base_data_dir: ${base_experiment_dir} # To save new processed data for training, or read from previously collected and processed data (testing phase).
population_path: ${data_paths.processed_base_data_dir}/population_data # Path where the collected population data will be stored (output/input)
processed_attack_data_path: ${data_paths.processed_base_data_dir}/attack_data # Path where the processed attack real train and evaluation data is stored (output/input)
@@ -38,6 +38,10 @@ model_paths:

# Dataset specific information used for processing in this example
data_processing_config:
midst_data_path: /projects/midst-experiments/all_tabddpms/ # Used to collect the data (input)
Collaborator Author: One step towards config modularization.

# The following directories should exist and contain `tabddpm_{i} for i in folder_ranges`:
# population data: midst_data_path/population_attack_data_types_to_collect/population_splits
# challenge data: midst_data_path/challenge_attack_data_types_to_collect/challenge_splits
population_attack_data_types_to_collect:
[
"tabddpm_trained_with_10k",
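As a rough illustration of the relocated key, here is a minimal sketch assuming the config is loaded with OmegaConf (which backs the DictConfig objects used in this example); the YAML excerpt is abbreviated and hypothetical, not the full experiment config:

from omegaconf import OmegaConf

# Abbreviated, hypothetical excerpt: midst_data_path now lives under
# data_processing_config rather than data_paths.
config = OmegaConf.create(
"""
data_paths:
  processed_base_data_dir: /tmp/ensemble_experiment
  population_path: ${data_paths.processed_base_data_dir}/population_data
data_processing_config:
  midst_data_path: /projects/midst-experiments/all_tabddpms/
"""
)

# Callers such as run_attack.py now resolve the path from the data
# processing section of the config.
print(config.data_processing_config.midst_data_path)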
2 changes: 1 addition & 1 deletion examples/ensemble_attack/run_attack.py
@@ -38,7 +38,7 @@ def run_data_processing(config: DictConfig) -> None:
log(INFO, "Running data processing pipeline...")
# Collect the real data from the MIDST challenge resources.
population_data = collect_population_data_ensemble(
midst_data_input_dir=Path(config.data_paths.midst_data_path),
midst_data_input_dir=Path(config.data_processing_config.midst_data_path),
data_processing_config=config.data_processing_config,
save_dir=Path(config.data_paths.population_path),
base_population=original_population_data,
35 changes: 19 additions & 16 deletions examples/ensemble_attack/test_attack_model.py
@@ -18,6 +18,7 @@
from midst_toolkit.attacks.ensemble.data_utils import load_dataframe
from midst_toolkit.common.logger import log
from midst_toolkit.common.random import set_all_random_seeds
from midst_toolkit.models.clavaddpm.train import get_df_without_id


class RmiaTrainingDataChoice(Enum):
@@ -51,37 +52,34 @@ def save_results(
f.write(f"TPR at FPR=0.1: {pred_score:.4f}\n")


def extract_and_drop_id_column(
def extract_primary_id_column(
data_frame: pd.DataFrame,
data_types_file_path: Path,
) -> tuple[pd.DataFrame, pd.Series]:
) -> pd.Series:
"""
Extracts IDs from the dataframe and drops the ID column. ID column is identified based on
Extracts and returns the primary IDs from the dataframe. The primary ID column is identified based on
the data types JSON file with "id_column_name" key.
Primary IDs are unique keys in the dataset.
For example, in the Berka dataset, "trans_id" is the primary ID column, while "account_id" is not.

Args:
data_frame: Input dataframe.
data_types_file_path: Path to the data types JSON file.

Returns:
A tuple containing:
- The modified dataframe with ID columns dropped.
- A Series containing the extracted data of ID columns.
A Series containing the extracted data of the main ID column.
"""
# Extract ID column from the dataframe
with open(data_types_file_path, "r") as f:
column_types = json.load(f)

assert "id_column_name" in column_types, f"{data_types_file_path} must contain 'id_column_name' key."
id_column_name = column_types["id_column_name"]

# Make sure we have one main id column
assert isinstance(id_column_name, str), "Only one main id column should be identified."
assert id_column_name in data_frame.columns, f"Dataframe must have {id_column_name} column"
data_trans_ids = data_frame[id_column_name]

# Drop ID column from data
data_frame = data_frame.drop(columns=id_column_name)

return data_frame, data_trans_ids
return data_frame[id_column_name]


def run_rmia_shadow_training(config: DictConfig, df_challenge: pd.DataFrame) -> list[dict[str, list[Any]]]:
@@ -183,7 +181,7 @@ def collect_challenge_and_train_data(
midst_data_input_dir=targets_data_path,
attack_types=challenge_attack_types,
# For ensemble experiments, change to ``test`` for 10k, and change to ``final`` for 20k
split_folders=["test"],
split_folders=["final"],
dataset="challenge",
data_processing_config=data_processing_config,
)
@@ -266,7 +264,7 @@ def train_rmia_shadows_for_test_phase(config: DictConfig) -> list[dict[str, list
df_challenge_experiment, df_master_train = collect_challenge_and_train_data(
config.data_processing_config,
processed_attack_data_path=Path(config.data_paths.processed_attack_data_path),
targets_data_path=Path(config.data_paths.midst_data_path),
targets_data_path=Path(config.data_processing_config.midst_data_path),
)
# Load the challenge dataframe for training RMIA shadow models.
rmia_training_choice = RmiaTrainingDataChoice(config.target_model.attack_rmia_shadow_training_data_choice)
@@ -357,8 +355,13 @@ def run_metaclassifier_testing(
else:
log(INFO, "All shadow models for testing phase found. Using existing RMIA shadow models...")

# Extract and drop id columns from the test data
test_data, test_trans_ids = extract_and_drop_id_column(test_data, Path(config.metaclassifier.data_types_file_path))
# Extract the main ID column's values from the test data
test_trans_ids = extract_primary_id_column(
data_frame=test_data,
data_types_file_path=Path(config.metaclassifier.data_types_file_path),
)
# Drop id columns from the test data. Berka has two id columns: "trans_id" and "account_id".
test_data = get_df_without_id(test_data)

# 4) Initialize the attacker object, and assign the loaded metaclassifier to it.
blending_attacker = BlendingPlusPlus(
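A rough usage sketch of the refactored helpers, mirroring run_metaclassifier_testing above. Everything here is hypothetical: the toy dataframe and the temporary data-types JSON are stand-ins for the Berka inputs, and the imports assume the example script and the toolkit are importable.

import json
import tempfile
from pathlib import Path

import pandas as pd

from midst_toolkit.models.clavaddpm.train import get_df_without_id
from test_attack_model import extract_primary_id_column  # defined in the diff above

# Hypothetical Berka-like frame with a primary ID ("trans_id"), a secondary
# ID ("account_id"), and one stand-in feature column.
frame = pd.DataFrame(
    {
        "trans_id": [101, 102, 103],
        "account_id": [1, 1, 2],
        "amount": [250.0, 70.5, 1200.0],
    }
)

# Hypothetical data-types file carrying the "id_column_name" key asserted on
# by extract_primary_id_column.
types_path = Path(tempfile.mkdtemp()) / "berka_data_types.json"
types_path.write_text(json.dumps({"id_column_name": "trans_id"}))

# The primary IDs are returned without mutating the frame ...
trans_ids = extract_primary_id_column(frame, types_path)
# ... and the ID columns are dropped in a separate step, as in
# run_metaclassifier_testing (get_df_without_id is assumed to strip both
# "trans_id" and "account_id").
features = get_df_without_id(frame)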
4 changes: 3 additions & 1 deletion src/midst_toolkit/attacks/ensemble/blending.py
@@ -251,6 +251,8 @@ def predict(
score = None

if y_test is not None:
score = TprAtFpr.get_tpr_at_fpr(true_membership=y_test, predictions=probabilities, max_fpr=0.1)
score = TprAtFpr.get_tpr_at_fpr(
true_membership=y_test, predicted_membership=probabilities, fpr_threshold=0.1
)

return probabilities, score
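For context on the renamed keyword arguments, a minimal sketch of what a TPR-at-fixed-FPR metric typically computes (an illustration only, assuming scikit-learn is available; it is not the toolkit's TprAtFpr implementation):

import numpy as np
from sklearn.metrics import roc_curve


def tpr_at_fpr(true_membership: np.ndarray, predicted_membership: np.ndarray, fpr_threshold: float = 0.1) -> float:
    # Highest true-positive rate achievable while the false-positive rate
    # stays at or below fpr_threshold.
    fpr, tpr, _ = roc_curve(true_membership, predicted_membership)
    return float(tpr[fpr <= fpr_threshold].max())


# Toy membership labels (1 = member of the training set) and attack scores.
labels = np.array([1, 0, 1, 1, 0, 0, 1, 0])
scores = np.array([0.9, 0.2, 0.8, 0.6, 0.4, 0.1, 0.7, 0.3])
print(tpr_at_fpr(labels, scores, fpr_threshold=0.1))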
4 changes: 2 additions & 2 deletions tests/unit/attacks/ensemble/test_meta_classifier.py
@@ -359,7 +359,7 @@ def test_predict_flow(
call_args = mock_get_tpr.call_args

np.testing.assert_array_equal(call_args.kwargs["true_membership"], sample_dataframes["y_test"])
np.testing.assert_array_almost_equal(call_args.kwargs["predictions"], expected_probabilities)
np.testing.assert_equal(call_args.kwargs["max_fpr"], 0.1)
np.testing.assert_array_almost_equal(call_args.kwargs["predicted_membership"], expected_probabilities)
np.testing.assert_equal(call_args.kwargs["fpr_threshold"], 0.1)

assert score == 0.99