Skip to content
Open
131 changes: 1 addition & 130 deletions chebai/preprocessing/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,7 +940,7 @@ def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph":
def _graph_to_raw_dataset(self, graph: "nx.DiGraph") -> pd.DataFrame:
"""
Converts the graph to a raw dataset.
Uses the graph created by `_extract_class_hierarchy` method to extract the
Uses the graph created by chebi_utils to extract the
raw data in Dataframe format with additional columns corresponding to each multi-label class.

Args:
Expand All @@ -951,21 +951,6 @@ def _graph_to_raw_dataset(self, graph: "nx.DiGraph") -> pd.DataFrame:
"""
pass

@abstractmethod
def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List:
    """
    Choose the classes (graph nodes) that should be used as labels.

    This is an abstract hook: the selection criterion itself is defined by
    the concrete subclass that overrides this method.

    Args:
        g (nx.DiGraph): The graph representing the dataset.
        *args: Extra positional arguments forwarded to the implementation.
        **kwargs: Extra keyword arguments forwarded to the implementation.

    Returns:
        List: A sorted list of the node IDs that satisfy the criterion.
    """

def save_processed(self, data: pd.DataFrame, filename: str) -> None:
"""
Save the processed dataset to a pickle file.
Expand Down Expand Up @@ -1123,120 +1108,6 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""
pass

def get_test_split(
    self, df: pd.DataFrame, seed: Optional[int] = None
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Split the input DataFrame into training and testing sets using stratified sampling.

    The split is based on the "labels" column. If the first row's label
    vector has more than one entry, MultilabelStratifiedShuffleSplit is used
    so that the label distribution in both parts is approximately the same;
    otherwise scikit-learn's StratifiedShuffleSplit is used. The fraction of
    rows assigned to the test set is taken from ``self.test_split``.

    Args:
        df (pd.DataFrame): The input DataFrame containing the data to be split. It must contain a column
            named "labels" with the (multi-)label data and at least one row.
        seed (int, optional): The random seed to be used for reproducibility. Default is None.

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the training set and testing set DataFrames.

    Raises:
        ValueError: If the DataFrame does not contain a column named "labels",
            or if it contains no rows.
    """
    # Validate first so that bad input fails fast with the documented
    # exception type (a plain column lookup would raise KeyError instead,
    # and an empty frame would fail later with an opaque IndexError).
    if "labels" not in df.columns:
        raise ValueError('The DataFrame must contain a column named "labels".')
    if df.empty:
        raise ValueError("Cannot create a test split from an empty DataFrame.")

    # Imported lazily, as in the rest of this method's original implementation.
    from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
    from sklearn.model_selection import StratifiedShuffleSplit

    print("Get test data split")

    labels_list = df["labels"].tolist()

    # The first row decides whether the data is treated as multi-label.
    is_multilabel = len(labels_list[0]) > 1
    splitter_cls = (
        MultilabelStratifiedShuffleSplit if is_multilabel else StratifiedShuffleSplit
    )
    splitter = splitter_cls(
        n_splits=1, test_size=self.test_split, random_state=seed
    )

    # The labels serve both as the "X" placeholder and as the stratification
    # target; only the resulting positional indices are used.
    train_indices, test_indices = next(splitter.split(labels_list, labels_list))

    return df.iloc[train_indices], df.iloc[test_indices]

def get_train_val_splits_given_test(
    self, df: pd.DataFrame, test_df: pd.DataFrame, seed: Optional[int] = None
) -> Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]:
    """
    Split the dataset into train and validation sets, given a test set.

    Rows whose "ident" value appears in ``test_df`` are removed from ``df``
    first, so the resulting train/validation data does not overlap with the
    test set (which may have been loaded from another source or generated by
    ``get_test_split``).

    Args:
        df (pd.DataFrame): The original dataset. Must contain the columns "ident" and "labels".
        test_df (pd.DataFrame): The test dataset. Must contain the column "ident".
        seed (int, optional): The random seed to be used for reproducibility. Default is None.

    Returns:
        Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: If
        ``self.use_inner_cross_validation`` is True, a dictionary with one
        train and one validation DataFrame per fold, keyed by the entries of
        ``self.raw_file_names_dict`` for ``"fold_{i}_train"`` and
        ``"fold_{i}_validation"``. Otherwise a tuple ``(train, validation)``
        of DataFrames.
    """
    from iterstrat.ml_stratifiers import (
        MultilabelStratifiedKFold,
        MultilabelStratifiedShuffleSplit,
    )
    from sklearn.model_selection import StratifiedShuffleSplit

    print("Split dataset into train / val with given test set")

    # Drop every row that belongs to the test set before splitting further.
    test_ids = test_df["ident"].tolist()
    df_trainval = df[~df["ident"].isin(test_ids)]
    labels_list_trainval = df_trainval["labels"].tolist()

    if self.use_inner_cross_validation:
        # scikit-learn's k-fold base class raises a ValueError when a
        # random_state is supplied while shuffle is False, so shuffling is
        # enabled exactly when a seed is given. Without a seed the folds are
        # built without shuffling, as before.
        kfold = MultilabelStratifiedKFold(
            n_splits=self.inner_k_folds,
            shuffle=seed is not None,
            random_state=seed,
        )
        folds = {}
        fold_indices = kfold.split(labels_list_trainval, labels_list_trainval)
        for fold, (train_ids, val_ids) in enumerate(fold_indices):
            train_key = self.raw_file_names_dict[f"fold_{fold}_train"]
            validation_key = self.raw_file_names_dict[f"fold_{fold}_validation"]
            folds[train_key] = df_trainval.iloc[train_ids]
            folds[validation_key] = df_trainval.iloc[val_ids]
        return folds

    # validation_split is rescaled by the share of data left after removing
    # the test portion, because the split now runs on train+validation only.
    validation_fraction = self.validation_split / (1 - self.test_split)

    # The first remaining row decides whether the data is treated as multi-label.
    is_multilabel = len(labels_list_trainval[0]) > 1
    splitter_cls = (
        MultilabelStratifiedShuffleSplit if is_multilabel else StratifiedShuffleSplit
    )
    splitter = splitter_cls(
        n_splits=1,
        test_size=validation_fraction,
        random_state=seed,
    )

    train_indices, validation_indices = next(
        splitter.split(labels_list_trainval, labels_list_trainval)
    )

    df_train = df_trainval.iloc[train_indices]
    df_validation = df_trainval.iloc[validation_indices]
    return df_train, df_validation

def _retrieve_splits_from_csv(self) -> None:
"""
Retrieve previously saved data splits from splits.csv file or from provided file path.
Expand Down
Loading
Loading