ParallelStreamingDataset with resume=True does not resume after manually loading state dict when breaking early #770

@philgzl

Description

🐛 Bug

When breaking out of an epoch early (e.g. when setting limit_<stage>_batches in Trainer), a ParallelStreamingDataset used with a StreamingDataLoader does not resume from where it left off in the previous epoch if load_state_dict was manually called externally.

This happens because ParallelStreamingDataset only resumes when self.restore is False in StreamingDataLoader. Other than in __init__, self.restore is set to False in only two places: at the end of StreamingDataLoader.__iter__, and at the beginning of __iter__ to undo the effect of the internal load_state_dict call introduced by #761. The first reset never runs when breaking early with limit_<stage>_batches. The second only runs if self.restore was already False beforehand. Calling load_state_dict externally sets self.restore to True, which correctly prevents load_state_dict from being called again internally when entering StreamingDataLoader.__iter__ and keeps the epoch counter from being incremented. But as a consequence, self.restore is never set back to False, and resuming is disabled entirely from then on.

One solution is to always set self.restore to False when entering StreamingDataLoader.__iter__.

#761 fixed a similar issue caused by calling load_state_dict internally, but it did not account for load_state_dict being called externally.
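The flag lifecycle above can be sketched with a minimal, self-contained toy model. All names here (ToyLoader, iter_epoch, resumed_epochs) are hypothetical and this is not litdata's actual implementation; in particular, the exact position of the proposed reset inside __iter__ is an assumption.

```python
class ToyLoader:
    """Toy model of the restore-flag lifecycle, not litdata's real code."""

    def __init__(self, fix=False):
        self.restore = False  # mirrors StreamingDataLoader's restore flag
        self.fix = fix
        self.resumed_epochs = 0

    def load_state_dict(self, state_dict):
        # An external call sets the flag, as described in the bug report.
        self.restore = True

    def iter_epoch(self, break_early=True):
        # Models one pass through StreamingDataLoader.__iter__.
        if not self.restore:
            self.resumed_epochs += 1  # the dataset's resume logic runs
        if self.fix:
            self.restore = False  # proposed fix: always clear the flag
        if not break_early:
            self.restore = False  # original reset, skipped when breaking early


# Without the fix, an external load_state_dict disables resuming for good:
buggy = ToyLoader(fix=False)
buggy.load_state_dict({})
buggy.iter_epoch()  # first epoch after loading: flag stays True
buggy.iter_epoch()  # later epochs never run the resume path
assert buggy.resumed_epochs == 0

# With the fix, resuming works again from the next epoch on:
fixed = ToyLoader(fix=True)
fixed.load_state_dict({})
fixed.iter_epoch()  # flag is cleared on the way through __iter__
fixed.iter_epoch()  # this epoch resumes as expected
assert fixed.resumed_epochs == 1
```

The toy model reproduces the shape of the failure: the first epoch after the external load behaves the same with or without the fix, but only the fixed version ever re-enables the resume path.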

To Reproduce

See code sample below.

Code sample
```python
from litdata import ParallelStreamingDataset, StreamingDataLoader, StreamingDataset
from litdata.streaming import Cache

# write a small dataset of 10 items to disk
cache = Cache(input_dir="temp/", chunk_size=1)
dset_len = 10
for i in range(dset_len):
    cache[i] = i
cache.done()
cache.merge()

dset = ParallelStreamingDataset([StreamingDataset("temp/")], length=999, resume=True)
assert dset.is_cycling()

dloader = StreamingDataLoader(dset)

expected = 0

# epoch 1
for i, (batch,) in enumerate(dloader):
    assert batch == expected, (batch, expected)  # succeeds
    expected = (expected + 1) % dset_len
    if i == 3:
        break

# epoch 2
for i, (batch,) in enumerate(dloader):
    assert batch == expected, (batch, expected)  # succeeds
    expected = (expected + 1) % dset_len
    if i == 3:
        break

state_dict = dloader.state_dict()
expected_on_load = expected

# epoch 3
for i, (batch,) in enumerate(dloader):
    assert batch == expected, (batch, expected)  # succeeds since #761
    expected = (expected + 1) % dset_len
    if i == 3:
        break

# simulate training crash and resume
dset = ParallelStreamingDataset([StreamingDataset("temp/")], length=999, resume=True)
assert dset.is_cycling()

dloader = StreamingDataLoader(dset)
dloader.load_state_dict(state_dict)  # external call to load state dict

expected = expected_on_load

# epoch 3 resumed
for i, (batch,) in enumerate(dloader):
    assert batch == expected, (batch, expected)  # succeeds
    expected = (expected + 1) % dset_len
    if i == 3:
        break

# epoch 4
for i, (batch,) in enumerate(dloader):
    assert batch == expected, (batch, expected)  # fails because internal resuming got disabled
    expected = (expected + 1) % dset_len
    if i == 3:
        break
```

Expected behavior

The resuming behavior of ParallelStreamingDataset should persist after manually loading a state dict externally.

Metadata

Labels: bug (Something isn't working), help wanted (Extra attention is needed)
