diff --git a/examples/ensemble_attack/configs/experiment_config.yaml b/examples/ensemble_attack/configs/experiment_config.yaml index fc50bec6..4216715c 100644 --- a/examples/ensemble_attack/configs/experiment_config.yaml +++ b/examples/ensemble_attack/configs/experiment_config.yaml @@ -25,7 +25,7 @@ target_model: # This is only used for testing the attack on a real target model. # Data paths data_paths: - midst_data_path: /projects/midst-experiments/all_tabddpms/ # Used to collect the data (input) as defined in data_processing_config + processed_base_data_dir: ${base_experiment_dir} # To save new processed data for training, or read from previously collected and processed data (testing phase). population_path: ${data_paths.processed_base_data_dir}/population_data # Path where the collected population data will be stored (output/input) processed_attack_data_path: ${data_paths.processed_base_data_dir}/attack_data # Path where the processed attack real train and evaluation data is stored (output/input) @@ -38,6 +38,10 @@ model_paths: # Dataset specific information used for processing in this example data_processing_config: + midst_data_path: /projects/midst-experiments/all_tabddpms/ # Used to collect the data (input) + # The following directories should exist and contain `tabddpm_{i} for i in folder_ranges`: + # population data: midst_data_path/population_attack_data_types_to_collect/population_splits + # challenge data: midst_data_path/challenge_attack_data_types_to_collect/challenge_splits population_attack_data_types_to_collect: [ "tabddpm_trained_with_10k", diff --git a/examples/ensemble_attack/run_attack.py b/examples/ensemble_attack/run_attack.py index c53a86dc..4e67fa50 100644 --- a/examples/ensemble_attack/run_attack.py +++ b/examples/ensemble_attack/run_attack.py @@ -38,7 +38,7 @@ def run_data_processing(config: DictConfig) -> None: log(INFO, "Running data processing pipeline...") # Collect the real data from the MIDST challenge resources. 
population_data = collect_population_data_ensemble( - midst_data_input_dir=Path(config.data_paths.midst_data_path), + midst_data_input_dir=Path(config.data_processing_config.midst_data_path), data_processing_config=config.data_processing_config, save_dir=Path(config.data_paths.population_path), base_population=original_population_data, diff --git a/examples/ensemble_attack/test_attack_model.py b/examples/ensemble_attack/test_attack_model.py index a4bd18cc..910189bc 100644 --- a/examples/ensemble_attack/test_attack_model.py +++ b/examples/ensemble_attack/test_attack_model.py @@ -18,6 +18,7 @@ from midst_toolkit.attacks.ensemble.data_utils import load_dataframe from midst_toolkit.common.logger import log from midst_toolkit.common.random import set_all_random_seeds +from midst_toolkit.models.clavaddpm.train import get_df_without_id class RmiaTrainingDataChoice(Enum): @@ -51,22 +52,22 @@ def save_results( f.write(f"TPR at FPR=0.1: {pred_score:.4f}\n") -def extract_and_drop_id_column( +def extract_primary_id_column( data_frame: pd.DataFrame, data_types_file_path: Path, -) -> tuple[pd.DataFrame, pd.Series]: +) -> pd.Series: """ - Extracts IDs from the dataframe and drops the ID column. ID column is identified based on + Extracts and returns the primary IDs from the dataframe. The primary ID column is identified based on the data types JSON file with "id_column_name" key. + Primary IDs are unique keys in the dataset. + For example, in the Berka dataset, "trans_id" is the primary ID column, while "account_id" is not. Args: data_frame: Input dataframe. data_types_file_path: Path to the data types JSON file. Returns: - A tuple containing: - - The modified dataframe with ID columns dropped. - - A Series containing the extracted data of ID columns. + A Series containing the extracted data of the main ID column. 
""" # Extract ID column from the dataframe with open(data_types_file_path, "r") as f: @@ -74,14 +75,11 @@ def extract_and_drop_id_column( assert "id_column_name" in column_types, f"{data_types_file_path} must contain 'id_column_name' key." id_column_name = column_types["id_column_name"] - + # Make sure we have one main id column + assert isinstance(id_column_name, str), "Only one main id column should be identified." assert id_column_name in data_frame.columns, f"Dataframe must have {id_column_name} column" - data_trans_ids = data_frame[id_column_name] - - # Drop ID column from data - data_frame = data_frame.drop(columns=id_column_name) - return data_frame, data_trans_ids + return data_frame[id_column_name] def run_rmia_shadow_training(config: DictConfig, df_challenge: pd.DataFrame) -> list[dict[str, list[Any]]]: @@ -183,7 +181,7 @@ def collect_challenge_and_train_data( midst_data_input_dir=targets_data_path, attack_types=challenge_attack_types, # For ensemble experiments, change to ``test`` for 10k, and change to ``final`` for 20k - split_folders=["test"], + split_folders=["final"], dataset="challenge", data_processing_config=data_processing_config, ) @@ -266,7 +264,7 @@ def train_rmia_shadows_for_test_phase(config: DictConfig) -> list[dict[str, list df_challenge_experiment, df_master_train = collect_challenge_and_train_data( config.data_processing_config, processed_attack_data_path=Path(config.data_paths.processed_attack_data_path), - targets_data_path=Path(config.data_paths.midst_data_path), + targets_data_path=Path(config.data_processing_config.midst_data_path), ) # Load the challenge dataframe for training RMIA shadow models. rmia_training_choice = RmiaTrainingDataChoice(config.target_model.attack_rmia_shadow_training_data_choice) @@ -357,8 +355,13 @@ def run_metaclassifier_testing( else: log(INFO, "All shadow models for testing phase found. 
Using existing RMIA shadow models...") - # Extract and drop id columns from the test data - test_data, test_trans_ids = extract_and_drop_id_column(test_data, Path(config.metaclassifier.data_types_file_path)) + # Extract the main ID column's values from the test data + test_trans_ids = extract_primary_id_column( + data_frame=test_data, + data_types_file_path=Path(config.metaclassifier.data_types_file_path), + ) + # Drop id columns from the test data. Berka has two id columns: "trans_id" and "account_id". + test_data = get_df_without_id(test_data) # 4) Initialize the attacker object, and assign the loaded metaclassifier to it. blending_attacker = BlendingPlusPlus( diff --git a/src/midst_toolkit/attacks/ensemble/blending.py b/src/midst_toolkit/attacks/ensemble/blending.py index 2b7c194d..24104cc1 100644 --- a/src/midst_toolkit/attacks/ensemble/blending.py +++ b/src/midst_toolkit/attacks/ensemble/blending.py @@ -251,6 +251,8 @@ def predict( score = None if y_test is not None: - score = TprAtFpr.get_tpr_at_fpr(true_membership=y_test, predictions=probabilities, max_fpr=0.1) + score = TprAtFpr.get_tpr_at_fpr( + true_membership=y_test, predicted_membership=probabilities, fpr_threshold=0.1 + ) return probabilities, score diff --git a/tests/unit/attacks/ensemble/test_meta_classifier.py b/tests/unit/attacks/ensemble/test_meta_classifier.py index c4dc0fe4..5e8154ec 100644 --- a/tests/unit/attacks/ensemble/test_meta_classifier.py +++ b/tests/unit/attacks/ensemble/test_meta_classifier.py @@ -359,7 +359,7 @@ def test_predict_flow( call_args = mock_get_tpr.call_args np.testing.assert_array_equal(call_args.kwargs["true_membership"], sample_dataframes["y_test"]) - np.testing.assert_array_almost_equal(call_args.kwargs["predictions"], expected_probabilities) - np.testing.assert_equal(call_args.kwargs["max_fpr"], 0.1) + np.testing.assert_array_almost_equal(call_args.kwargs["predicted_membership"], expected_probabilities) + np.testing.assert_equal(call_args.kwargs["fpr_threshold"], 0.1) 
assert score == 0.99