Description
🐛 Bug
When an epoch ends early (e.g. because limit_<stage>_batches is set in the Trainer), a ParallelStreamingDataset wrapped in a StreamingDataLoader does not resume from where the previous epoch left off if load_state_dict was called externally beforehand.
This happens because ParallelStreamingDataset only resumes when self.restore is False in StreamingDataLoader. Outside of __init__, self.restore is set to False in only two places: at the end of StreamingDataLoader.__iter__, and at the beginning of __iter__ to undo the effect of the internal load_state_dict call introduced in #761. The first place is never reached when breaking early with limit_<stage>_batches. The second only executes if self.restore was already False. Calling load_state_dict externally sets self.restore to True, which prevents the internal call on entering StreamingDataLoader.__iter__ and keeps the epoch counter from being incremented. As a consequence, resuming is permanently disabled, because self.restore is never set back to False.
One solution is to always set self.restore to False when entering StreamingDataLoader.__iter__.
#761 fixed a similar issue caused by calling load_state_dict internally, but it did not account for external calls to load_state_dict.
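A minimal toy model of the flag logic described above (this is a deliberately simplified sketch, not litdata's actual implementation; the class name ToyLoader and its attributes are hypothetical) shows why clearing self.restore unconditionally at the top of __iter__ fixes the bug:

```python
class ToyLoader:
    """Toy model of StreamingDataLoader's restore flag (illustration only)."""

    def __init__(self, n):
        self.n = n           # length of the cycling dataset
        self.pos = 0         # where the dataset will continue on the next epoch
        self.loaded = None   # state applied by load_state_dict
        self.restore = False

    def state_dict(self):
        return {"pos": self.pos}

    def load_state_dict(self, sd):
        self.loaded = sd["pos"]
        self.restore = True  # flag: state was just loaded, skip the resume path once

    def __iter__(self):
        if self.restore:
            start = self.loaded  # state was just loaded: use it as-is
        else:
            start = self.pos     # normal resume: continue from the previous epoch
        self.restore = False     # proposed fix: clear the flag unconditionally;
                                 # without this line, every later epoch would re-read
                                 # self.loaded and resuming would stay disabled
        for i in range(4):       # epoch cut short, e.g. limit_train_batches=4
            yield (start + i) % self.n
            self.pos = (start + i + 1) % self.n


loader = ToyLoader(10)
loader.load_state_dict({"pos": 8})  # external call, e.g. after a crash
epoch3 = list(loader)  # resumes from the loaded state: [8, 9, 0, 1]
epoch4 = list(loader)  # with the fix, continues: [2, 3, 4, 5]
```

Without the marked line, the second list(loader) would take the restore branch again and replay [8, 9, 0, 1], which is the failure mode reproduced below.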
To Reproduce
See code sample below.
Code sample
from litdata import ParallelStreamingDataset, StreamingDataLoader, StreamingDataset
from litdata.streaming import Cache
cache = Cache(input_dir="temp/", chunk_size=1)
dset_len = 10
for i in range(dset_len):
    cache[i] = i
cache.done()
cache.merge()
dset = ParallelStreamingDataset([StreamingDataset("temp/")], length=999, resume=True)
assert dset.is_cycling()
dloader = StreamingDataLoader(dset)
expected = 0
# epoch 1
for i, (batch,) in enumerate(dloader):
    assert batch == expected, (batch, expected)  # succeeds
    expected = (expected + 1) % dset_len
    if i == 3:
        break
# epoch 2
for i, (batch,) in enumerate(dloader):
    assert batch == expected, (batch, expected)  # succeeds
    expected = (expected + 1) % dset_len
    if i == 3:
        break
state_dict = dloader.state_dict()
expected_on_load = expected
# epoch 3
for i, (batch,) in enumerate(dloader):
    assert batch == expected, (batch, expected)  # succeeds since #761
    expected = (expected + 1) % dset_len
    if i == 3:
        break
# simulate training crash and resume
dset = ParallelStreamingDataset([StreamingDataset("temp/")], length=999, resume=True)
assert dset.is_cycling()
dloader = StreamingDataLoader(dset)
dloader.load_state_dict(state_dict) # external call to load state dict
expected = expected_on_load
# epoch 3 resumed
for i, (batch,) in enumerate(dloader):
    assert batch == expected, (batch, expected)  # succeeds
    expected = (expected + 1) % dset_len
    if i == 3:
        break
# epoch 4
for i, (batch,) in enumerate(dloader):
    assert batch == expected, (batch, expected)  # fails because internal resuming got disabled
    expected = (expected + 1) % dset_len
    if i == 3:
        break
Expected behavior
The resuming behavior of ParallelStreamingDataset should persist after a state dict has been loaded externally.