
Add SleepWakeClassification task for DREAMT #892

Open
diegofariasc wants to merge 28 commits into sunlabuiuc:master from diegofariasc:diegof4/dreamt_sleep_tracking

@diegofariasc

Contributor: Diego Farias Castro (diegof4@illinois.edu)
Type of contribution: task
Link to original paper: https://proceedings.mlr.press/v248/wang24a.html

High-level description:
This PR adds a new standalone task, SleepWakeClassification, on top of the existing DREAMTDataset. The task supports epoch-level sleep-vs-wake prediction from multimodal wrist-worn wearable signals. It turns each DREAMT record into fixed-length epochs, extracts features from accelerometer, temperature, blood volume pulse, and electrodermal activity signals, adds temporal context features, and assigns a binary sleep/wake label to each epoch.
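
As a rough illustration of the epoching step described above, here is a minimal sketch. It is not the code in this PR: the function names, the 30-second epoch length, and the 64 Hz sampling rate in the usage note are assumptions chosen for the example.

```python
import numpy as np


def split_into_epochs(signal, sampling_rate_hz, epoch_seconds=30):
    """Cut a 1-D signal into fixed-length, non-overlapping epochs.

    Hypothetical helper: the epoch length and the choice to drop any
    trailing partial epoch are assumptions for this sketch.
    """
    samples_per_epoch = int(sampling_rate_hz * epoch_seconds)
    n_epochs = len(signal) // samples_per_epoch
    # Drop samples that do not fill a complete epoch.
    trimmed = np.asarray(signal, dtype=float)[: n_epochs * samples_per_epoch]
    return trimmed.reshape(n_epochs, samples_per_epoch)


def summarize_epochs(epochs):
    """Return one (mean, standard deviation) row per epoch."""
    return np.column_stack([epochs.mean(axis=1), epochs.std(axis=1)])
```

For example, `split_into_epochs(signal, sampling_rate_hz=64)` on a signal of a little over two minutes yields an array with four rows of 1920 samples each, and `summarize_epochs` reduces that to one small feature row per epoch.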

Implementation summary:

  • Adds SleepWakeClassification in pyhealth/tasks/sleep_wake_classification.py
  • Exports the task in pyhealth/tasks/__init__.py
  • Adds task API documentation in docs/api/tasks/pyhealth.tasks.sleep_wake_classification.rst
  • Registers the task page in docs/api/tasks.rst
  • Includes an example ablation script in examples/dreamt_sleep_wake_classification_lightgbm.py
  • Includes task tests in tests/core/test_sleep_wake_classification.py

Reproducibility scope:
This PR focuses on the task side of the paper. It makes the sleep-wake prediction setting available inside PyHealth so the generated samples can be used in new experiments and ablation studies.
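
A sketch of how the task might be applied through `set_task()`. The module path for `SleepWakeClassification` is the one added in this PR; the `pyhealth.datasets` import location, the `root` keyword, and the no-argument task constructor are assumptions here, and `build_sleep_wake_samples` is a hypothetical helper name.

```python
def build_sleep_wake_samples(dreamt_root):
    """Load DREAMT from a local directory and apply the sleep-wake task.

    Assumptions: DREAMTDataset accepts a `root` keyword pointing at the
    downloaded data, and SleepWakeClassification can be constructed
    without arguments.
    """
    # Imports are local so this sketch can be read without PyHealth installed.
    from pyhealth.datasets import DREAMTDataset
    from pyhealth.tasks.sleep_wake_classification import SleepWakeClassification

    dataset = DREAMTDataset(root=dreamt_root)
    # set_task() runs the task over every record and returns the samples.
    return dataset.set_task(SleepWakeClassification())
```

Calling `build_sleep_wake_samples("/path/to/dreamt")` would then return the epoch-level samples described under "Task behavior" below.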

Task behavior:

  • Input: wearable records from the existing DREAMTDataset
  • Output: epoch-level samples with patient_id, record_id, epoch_index,
    features, and binary label
  • Labels: wake maps to 1; sleep stages (REM, N1, N2, N3) map to 0
  • Features: accelerometer summaries, temperature summaries, BVP-based HRV
    features, and EDA-based SCR features
  • Temporal context: Gaussian smoothing, temporal derivative, and rolling
    variance for each base feature
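
The label mapping and temporal context features listed above can be sketched as follows. This is an illustration under stated assumptions, not the PR's implementation: the stage strings, the kernel width `sigma`, the rolling `window`, and the edge-padding choice are all hypothetical.

```python
import numpy as np

# Hypothetical stage strings; the labels used in the DREAMT files may differ.
WAKE_STAGES = {"W"}
SLEEP_STAGES = {"R", "N1", "N2", "N3"}


def stage_to_label(stage):
    """Map a sleep stage to the binary label: wake -> 1, sleep -> 0."""
    if stage in WAKE_STAGES:
        return 1
    if stage in SLEEP_STAGES:
        return 0
    return None  # Unknown stage; the caller decides whether to skip the epoch.


def add_temporal_context(feature, sigma=2.0, window=5):
    """Return columns [raw, Gaussian-smoothed, derivative, rolling variance].

    `feature` holds one base feature value per epoch, in time order.
    """
    feature = np.asarray(feature, dtype=float)

    # Gaussian smoothing with a normalized kernel and edge padding.
    radius = int(3 * sigma)
    offsets = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (offsets / sigma) ** 2)
    kernel /= kernel.sum()
    padded = np.pad(feature, radius, mode="edge")
    smoothed = np.convolve(padded, kernel, mode="valid")

    # Temporal derivative via finite differences between epochs.
    derivative = np.gradient(feature)

    # Rolling variance over a window centered on each epoch.
    half = window // 2
    padded_var = np.pad(feature, half, mode="edge")
    rolling_var = np.array(
        [padded_var[i : i + window].var() for i in range(len(feature))]
    )

    return np.column_stack([feature, smoothed, derivative, rolling_var])
```

Applied to each base feature in turn, this turns a single per-epoch column into four columns, so each epoch's feature vector also carries information about its neighbors.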

File guide:

  • pyhealth/tasks/sleep_wake_classification.py: task implementation
  • pyhealth/tasks/__init__.py: public task export
  • docs/api/tasks/pyhealth.tasks.sleep_wake_classification.rst: task docs
  • docs/api/tasks.rst: task index update
  • examples/dreamt_sleep_wake_classification_lightgbm.py: example and ablation workflow
  • tests/core/test_sleep_wake_classification.py: task unit tests

diegofariasc and others added 28 commits March 8, 2026 22:26
Collaborator

@jhnwu3 left a comment

Let me know if you have any more questions. Nice work!

from pyhealth.tasks.sleep_wake_classification import SleepWakeClassification


class FakeEvent:
Collaborator

I've been thinking a lot about standardizing where test cases live and what they typically look like here. I like how you reverse-engineered some aspects of what goes on within PyHealth, but I think it would be much better if we followed some of the other examples. (Still working on test case best practices here as we learn what seems to scale better and what doesn't.)

But, if possible, instead of constructing object-oriented test classes, can we just create fake tmp data and use the explicit datasets/data types in PyHealth?

See this example:
https://github.com/sunlabuiuc/PyHealth/blob/master/tests/core/test_chestxray14.py

@EricSchrock can probably give way better advice on this. But the tl;dr: we want the testing environment to mimic the real working environment as closely as possible here.

"""Binary sleep-wake classification task for DREAMT wearable recordings.
This task converts each DREAMT wearable recording into fixed-length epochs,
extracts physiological features from multiple sensor modalities,
Collaborator

Can you give a rough breakdown of what's happening here? It's a bit difficult to follow across something like 10 different private functions haha. It'd be good to have it flow: something like, from step 1 to step 9, this is what's going on, and this is the sequence of function calls.

Similarly, in the docstrings, can you add an example of how a user might use DREAMT with the set_task() call?

Collaborator

It's also not clear exactly what we mean by features here. What shape is it? What do its dimensions mean?

I think a key good practice is to explain what we're doing here and what the expected outputs are.

X_test = imputer.transform(X_test)

# Train a LightGBM model on the current feature subset.
train_data = lgb.Dataset(X_train, label=y_train)
Collaborator

Just curious, is it possible to create a ContraWR example with a signals-based model here? The LightGBM example isn't bad, and I see the inputs are in essence formatted like a table (a table of signals), which is still pretty cool to see as an example.

But this is a signals dataset after all:
https://physionet.org/content/dreamt/2.1.0/
