Add SleepWakeClassification task for DREAMT #892
Status: Open. diegofariasc wants to merge 28 commits into sunlabuiuc:master from diegofariasc:diegof4/dreamt_sleep_tracking.
+1,567 −0
Commits (28):
All commits authored by diegofariasc.

- `d168979` feat: add SleepWakeClassification task for binary sleep/wake labeling
- `8f8fe6c` feat: add SleepWakeClassification task with epoch segmentation and ba…
- `a511bcf` test: add runs unit test for SleepWakeClassification task
- `93246ab` feat: implement feature extraction pipeline for SleepWakeClassification
- `58d8366` feat: add record-level BVP features to SleepWakeClassification
- `a90af6c` feat: add record-level EDA features to SleepWakeClassification
- `30d0f90` feat: add temporal feature enhancement to SleepWakeClassification
- `4e91596` feat: add initial sleep-wake task and temporal feature ablation example
- `dbcc496` feat: add modality ablation experiments for DREAMT sleep-wake classif…
- `7ef252d` refactor: improve readability and reuse in sleep-wake task
- `20cf21f` refactor: reorder sleep-wake task methods by responsibility
- `4686a51` doc: add sleep_wake_classification.rst
- `df5720f` doc: document all methods in SleepWakeClassification
- `9098292` feat: add SleepWakeClassification to init.py
- `79efe27` feat: add Sleep-Wake Classification to tasks.rst
- `b72ae6b` refactor: use black+isort to autoformat task code following PEP88
- `df985b3` test: add tests covering new SleepWakeClassification task
- `9357506` doc: add docstrings to tests
- `f04f602` refactor: use black+isort to autoformat test code following PEP88
- `a98a1b4` refactor: use specific Exception types instead of general Exception
- `a2090a2` refactor: generalize sleep-wake classification example
- `1ba1795` doc: add file header to sleep_wake_classification.py
- `cbc98de` refactor: improve formatting of results in sleep_wake_classification …
- `e8c7024` refactor: use black+issort on example study
- `11447e4` refactor: rename sleep_wake_classification example to dreamt_sleep_wa…
- `489d443` refactor: improve typing in sleep_wake_classification task and example
- `80798a0` refactor: add support for synthetic data in example
- `18d2f80` Merge branch 'sunlabuiuc:master' into diegof4/dreamt_sleep_tracking
New file, sleep_wake_classification.rst (+7 lines):

```rst
pyhealth.tasks.sleep_wake_classification
========================================

.. autoclass:: pyhealth.tasks.sleep_wake_classification.SleepWakeClassification
    :members:
    :undoc-members:
    :show-inheritance:
```
New example script (+376 lines):

```python
import io
import logging
import warnings
from collections import Counter
from contextlib import redirect_stderr, redirect_stdout
from typing import Iterable

import lightgbm as lgb
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.exceptions import ConvergenceWarning
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    f1_score,
    roc_auc_score,
)

from pyhealth.datasets import DREAMTDataset
from pyhealth.tasks.sleep_wake_classification import SleepWakeClassification

# Configuration
DREAMT_ROOT = "REPLACE_WITH_DREAMT_ROOT"
TRAIN_PATIENT_IDS = ["S028", "S062", "S078"]
EVAL_PATIENT_IDS = ["S081", "S099"]
EPOCH_SECONDS = 30
SAMPLING_RATE = 64

# Console formatting codes
RESET = "\033[0m"
BOLD = "\033[1m"
CYAN = "\033[36m"
GREEN = "\033[32m"
YELLOW = "\033[33m"


def build_synthetic_benchmark_data() -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Builds synthetic sleep-wake samples for a runnable ablation example.

    Returns:
        Synthetic feature matrix, binary labels, and patient IDs.
    """
    rng = np.random.default_rng(42)
    patient_ids = TRAIN_PATIENT_IDS + EVAL_PATIENT_IDS
    samples_per_patient = 24
    num_base_features = 21
    num_temporal_features = num_base_features * 3
    num_features = num_base_features + num_temporal_features

    groups = np.repeat(patient_ids, samples_per_patient)
    y = rng.binomial(1, 0.35, size=len(groups))

    X = rng.normal(0.0, 1.0, size=(len(groups), num_features))
    X[y == 1, :10] += 0.9
    X[y == 1, 10:14] += 0.4
    X[y == 1, 14:17] += 0.3
    X[y == 1, 17:21] += 0.2
    X[y == 1, 21:] += 0.25

    return X.astype(float), y.astype(int), groups.astype(str)


def format_section(title: str) -> str:
    """Formats a section title for console output.

    Args:
        title: Section title to format.

    Returns:
        A colorized section title string.
    """
    return f"\n{BOLD}{CYAN}{title}{RESET}"


def format_patient_ids(patient_ids: Iterable[str]) -> str:
    """Formats patient IDs for readable console output.

    Args:
        patient_ids: Iterable of patient identifiers.

    Returns:
        A comma-separated string of patient IDs.
    """
    return ", ".join(sorted(str(patient_id) for patient_id in set(patient_ids)))


def print_metric(name: str, value: float) -> None:
    """Prints a metric with consistent console formatting.

    Args:
        name: Metric name.
        value: Metric value.
    """
    print(f"  {name:<16}{value:.4f}")


def summarize_label_counts(labels) -> str:
    """Builds a readable sleep/wake label summary.

    Args:
        labels: Iterable of binary labels.

    Returns:
        A formatted label count string.
    """
    counts = Counter(labels)
    return f"sleep (0): {counts.get(0, 0)}, wake (1): {counts.get(1, 0)}"


def configure_clean_output() -> None:
    """Suppresses noisy warnings and logs for a cleaner example run."""
    warnings.filterwarnings("ignore", category=ConvergenceWarning)
    logging.getLogger("pyhealth").setLevel(logging.ERROR)
    logging.getLogger("pyhealth.tasks.sleep_wake_classification").setLevel(
        logging.ERROR
    )


def split_samples_by_patient_ids(
    X: np.ndarray,
    y: np.ndarray,
    groups: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Splits samples into train and evaluation sets using patient IDs.

    Args:
        X: Feature matrix.
        y: Binary label vector.
        groups: Patient identifier for each sample.

    Returns:
        Train and evaluation features, labels, and patient groups.
    """
    train_mask = np.isin(groups, TRAIN_PATIENT_IDS)
    eval_mask = np.isin(groups, EVAL_PATIENT_IDS)

    if not np.any(train_mask):
        raise ValueError("No samples found for TRAIN_PATIENT_IDS.")
    if not np.any(eval_mask):
        raise ValueError("No samples found for EVAL_PATIENT_IDS.")

    return (
        X[train_mask],
        X[eval_mask],
        y[train_mask],
        y[eval_mask],
        groups[train_mask],
        groups[eval_mask],
    )


def run_experiment(
    X: np.ndarray,
    y: np.ndarray,
    groups: np.ndarray,
    name: str,
) -> None:
    """Runs one feature ablation experiment and prints evaluation metrics.

    Args:
        X: Feature matrix for the selected experiment.
        y: Binary label vector.
        groups: Patient identifier for each sample.
        name: Name of the ablation setting.
    """
    # Split samples into train and evaluation sets
    X_train, X_test, y_train, y_test, g_train, g_test = split_samples_by_patient_ids(
        X,
        y,
        groups,
    )

    # Report dataset statistics
    print(format_section(f"Ablation: {name}"))
    print(f"{BOLD}Train patients:{RESET} {format_patient_ids(g_train)}")
    print(f"{BOLD}Eval patients:{RESET} {format_patient_ids(g_test)}")
    print(f"{BOLD}Train samples:{RESET} {len(X_train)}")
    print(f"{BOLD}Eval samples:{RESET} {len(X_test)}")

    # Remove features that are all NaN in the training set
    non_all_nan_cols = ~np.isnan(X_train).all(axis=0)
    X_train = X_train[:, non_all_nan_cols]
    X_test = X_test[:, non_all_nan_cols]

    print(f"{BOLD}Feature count:{RESET} {X_train.shape[1]}")

    # Impute remaining missing values with training-set medians.
    imputer = SimpleImputer(strategy="median")
    X_train = imputer.fit_transform(X_train)
    X_test = imputer.transform(X_test)

    # Train a LightGBM model on the current feature subset.
    train_data = lgb.Dataset(X_train, label=y_train)
    test_data = lgb.Dataset(X_test, label=y_test, reference=train_data)

    params = {
        "objective": "binary",
        "metric": "binary_logloss",
        "boosting_type": "gbdt",
        "learning_rate": 0.05,
        "num_leaves": 31,
        "feature_fraction": 0.9,
        "bagging_fraction": 0.9,
        "bagging_freq": 5,
        "verbose": -1,
        "seed": 42,
    }

    model = lgb.train(
        params,
        train_data,
        num_boost_round=200,
        valid_sets=[test_data],
        callbacks=[lgb.early_stopping(stopping_rounds=20, verbose=False)],
    )

    y_prob = model.predict(X_test)
    y_pred = (y_prob >= 0.3).astype(int)

    # Report standard binary classification metrics.
    print_metric("Accuracy", accuracy_score(y_test, y_pred))
    print_metric("F1", f1_score(y_test, y_pred))
    print_metric("AUROC", roc_auc_score(y_test, y_prob))
    print_metric("AUPRC", average_precision_score(y_test, y_prob))


def run_model_comparison(
    X: np.ndarray,
    y: np.ndarray,
    groups: np.ndarray,
) -> None:
    """Runs a small model comparison on the full temporal feature set.

    Args:
        X: Full feature matrix.
        y: Binary label vector.
        groups: Patient identifier for each sample.
    """
    # Use the same predefined patient split to compare alternative models
    X_train, X_test, y_train, y_test, g_train, g_test = split_samples_by_patient_ids(
        X,
        y,
        groups,
    )

    print(format_section("Model Comparison: ALL modalities + temporal"))
    print(f"{BOLD}Train patients:{RESET} {format_patient_ids(g_train)}")
    print(f"{BOLD}Eval patients:{RESET} {format_patient_ids(g_test)}")

    # Drop all-NaN training columns, then impute with training-set medians.
    non_all_nan_cols = ~np.isnan(X_train).all(axis=0)
    X_train = X_train[:, non_all_nan_cols]
    X_test = X_test[:, non_all_nan_cols]

    imputer = SimpleImputer(strategy="median")
    X_train = imputer.fit_transform(X_train)
    X_test = imputer.transform(X_test)

    # Compare logistic regression and random forest on the full feature set.
    models = {
        "LogisticRegression": LogisticRegression(max_iter=1000),
        "RandomForest": RandomForestClassifier(
            n_estimators=200,
            random_state=42,
            n_jobs=-1,
        ),
    }

    for name, model in models.items():
        model.fit(X_train, y_train)

        if hasattr(model, "predict_proba"):
            y_prob = model.predict_proba(X_test)[:, 1]
        else:
            y_prob = model.decision_function(X_test)

        y_pred = (y_prob >= 0.3).astype(int)

        print(f"\n{YELLOW}{name}{RESET}")
        print_metric("Accuracy", accuracy_score(y_test, y_pred))
        print_metric("F1", f1_score(y_test, y_pred))
        print_metric("AUROC", roc_auc_score(y_test, y_prob))
        print_metric("AUPRC", average_precision_score(y_test, y_prob))


def main() -> None:
    """Runs the DREAMT sleep-wake classification example workflow."""
    configure_clean_output()

    if DREAMT_ROOT == "REPLACE_WITH_DREAMT_ROOT":
        print(format_section("DREAMT Sleep-Wake Classification Example"))
        print("DREAMT_ROOT not set. Running the ablation workflow on synthetic data...")
        print(
            f"{YELLOW}Warning:{RESET} synthetic samples are randomly generated to "
            "make the example runnable without DREAMT. The resulting metrics are "
            "not realistic and should not be interpreted as evidence for the "
            "task or paper claims.\n"
        )
        print(f"{BOLD}Train patients:{RESET} {', '.join(TRAIN_PATIENT_IDS)}")
        print(f"{BOLD}Eval patients:{RESET} {', '.join(EVAL_PATIENT_IDS)}")

        X_all, y, groups = build_synthetic_benchmark_data()
        print(f"{BOLD}Total epoch samples:{RESET} {len(X_all)}")
        print(f"{BOLD}Label counts:{RESET} {summarize_label_counts(y)}")
        print(
            f"{BOLD}Feature matrix:{RESET} "
            f"{X_all.shape[0]} samples x {X_all.shape[1]} features"
        )
    else:
        # Suppress verbose dataset initialization messages and print a cleaner summary.
        with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
            dataset = DREAMTDataset(root=DREAMT_ROOT)
            task = SleepWakeClassification(
                epoch_seconds=EPOCH_SECONDS,
                sampling_rate=SAMPLING_RATE,
            )

        print(format_section("DREAMT Sleep-Wake Classification Example"))
        print(f"{BOLD}Dataset root:{RESET} {DREAMT_ROOT}")
        print(f"{BOLD}Train patients:{RESET} {', '.join(TRAIN_PATIENT_IDS)}")
        print(f"{BOLD}Eval patients:{RESET} {', '.join(EVAL_PATIENT_IDS)}")

        # Convert the selected DREAMT patients into epoch-level sleep/wake samples.
        all_samples = []
        selected_patient_ids = TRAIN_PATIENT_IDS + EVAL_PATIENT_IDS
        for patient_id in selected_patient_ids:
            patient = dataset.get_patient(patient_id)
            samples = task(patient)
            print(f"  patient {patient_id:<4} -> {len(samples)} epoch samples")
            all_samples.extend(samples)

        print(f"{BOLD}Total epoch samples:{RESET} {len(all_samples)}")
        print(
            f"{BOLD}Label counts:{RESET} "
            f"{summarize_label_counts(sample['label'] for sample in all_samples)}"
        )

        # Turn the task samples into arrays for training and evaluation.
        X_all = np.array([s["features"] for s in all_samples], dtype=float)
        y = np.array([s["label"] for s in all_samples], dtype=int)
        groups = np.array([s["patient_id"] for s in all_samples])

    if DREAMT_ROOT != "REPLACE_WITH_DREAMT_ROOT":
        print(
            f"{BOLD}Feature matrix:{RESET} "
            f"{X_all.shape[0]} samples x {X_all.shape[1]} features"
        )

    # Keep only the base per-epoch features without temporal augmentation.
    X_base = X_all[:, :21]

    # Keep the full feature matrix, including temporal context features.
    X_temporal = X_all

    # Group feature indices by modality for the ablation experiments.
    acc_idx = list(range(0, 10))
    temp_idx = list(range(10, 14))
    bvp_idx = list(range(14, 17))
    eda_idx = list(range(17, 21))

    X_acc = X_base[:, acc_idx]
    X_acc_temp = X_base[:, acc_idx + temp_idx]
    X_acc_temp_bvp = X_base[:, acc_idx + temp_idx + bvp_idx]
    X_all_modalities = X_base[:, acc_idx + temp_idx + bvp_idx + eda_idx]

    # Run experiments using different feature groups.
    run_experiment(X_acc, y, groups, "ACC only")
    run_experiment(X_acc_temp, y, groups, "ACC + TEMP")
    run_experiment(X_acc_temp_bvp, y, groups, "ACC + TEMP + BVP")
    run_experiment(X_all_modalities, y, groups, "ALL modalities")
    run_experiment(X_temporal, y, groups, "ALL modalities + temporal")
    run_model_comparison(X_temporal, y, groups)


if __name__ == "__main__":
    main()
```
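The example's `split_samples_by_patient_ids` hard-codes which patients go where. For readers who want a randomized patient-level hold-out instead, scikit-learn's `GroupShuffleSplit` gives the same guarantee that no patient's epochs leak across the split. A minimal sketch with toy stand-ins (the IDs and shapes below are illustrative, not DREAMT output):

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

# Toy stand-ins for epoch-level samples: 5 patients x 24 epochs each.
groups = np.repeat(["S028", "S062", "S078", "S081", "S099"], 24)
X = np.random.default_rng(0).normal(size=(len(groups), 21))
y = np.random.default_rng(1).binomial(1, 0.35, size=len(groups))

# Hold out ~40% of *patients* (not epochs), so every patient's epochs
# land entirely in one side of the split.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.4, random_state=42)
train_idx, test_idx = next(splitter.split(X, y, groups=groups))

train_patients = set(groups[train_idx])
test_patients = set(groups[test_idx])
```

Because grouping is done on patient IDs, `train_patients` and `test_patients` are always disjoint, which is the property the hand-rolled split enforces via its fixed lists.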
Just curious, is it possible to create a ContraWR example with a signals-based model here? The LightGBM model isn't bad, and the inputs are in essence formatted like a table (a table of signal-derived features), which is still pretty cool to see as an example.
But this is a signals dataset after all:
https://physionet.org/content/dreamt/2.1.0/
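A signals-based variant along these lines would consume raw epochs rather than tabular features. As a minimal NumPy sketch of the input preparation only (the channel count, channel order, and recording length are illustrative assumptions, not the DREAMT loader's actual output), segmenting a 64 Hz recording into non-overlapping 30-second windows yields the `(epochs, channels, samples)` array a signals model such as ContraWR would take:

```python
import numpy as np

SAMPLING_RATE = 64   # Hz, matching the example's configuration
EPOCH_SECONDS = 30
EPOCH_SAMPLES = SAMPLING_RATE * EPOCH_SECONDS  # 1920 samples per epoch

# Hypothetical raw channels (e.g. BVP, EDA, TEMP, ACC magnitude)
# for a 10-minute recording; random data stands in for real signals.
duration_s = 600
raw = np.random.default_rng(0).normal(size=(4, duration_s * SAMPLING_RATE))

# Drop any trailing partial epoch, then split into 30 s windows.
num_epochs = raw.shape[1] // EPOCH_SAMPLES
trimmed = raw[:, : num_epochs * EPOCH_SAMPLES]
epochs = trimmed.reshape(4, num_epochs, EPOCH_SAMPLES).transpose(1, 0, 2)

print(epochs.shape)  # (20, 4, 1920): 20 epochs x 4 channels x 1920 samples
```

Each 30 s window would then pair with one sleep/wake label, exactly as the tabular epochs do now; only the per-epoch representation changes from hand-crafted features to raw samples.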