37 changes: 19 additions & 18 deletions src/litdata/streaming/dataloader.py
@@ -639,25 +639,26 @@ def __init__(
) # type: ignore

def __iter__(self) -> Any:
if not self.restore:
if (
isinstance(self.dataset, ParallelStreamingDataset)
and self.dataset.is_cycling()
and self.dataset.resume
and self.current_epoch != 0
):
# For ParallelStreamingDataset with _length != None we want to cycle the wrapped datasets i.e. we do not
# want to restart at index 0 at every epoch. So we set them in restore state.
if (
isinstance(self.dataset, ParallelStreamingDataset)
and self.dataset.is_cycling()
and self.dataset.resume
and self.current_epoch != 0
):
# For ParallelStreamingDataset with _length != None we want to cycle the wrapped datasets i.e. we do not
# want to restart at index 0 at every epoch. So we set them in restore state.
Comment on lines +648 to +649
Copilot AI Dec 16, 2025

The comment states "we do not want to restart at index 0 at every epoch" but could be more explicit about what this fix addresses. Consider expanding the comment to mention that this handles both automatic cycling (between epochs in the same session) and manual resume (after loading a state dict from a previous session/crash), as this is the key bug being fixed.

Suggested change
# For ParallelStreamingDataset with _length != None we want to cycle the wrapped datasets i.e. we do not
# want to restart at index 0 at every epoch. So we set them in restore state.
# For ParallelStreamingDataset with _length != None, we want to cycle the wrapped datasets and avoid
# restarting at index 0 at every epoch. This logic ensures that we correctly handle both automatic cycling
# between epochs in the same session and manual resume after loading a state dict from a previous session
# or crash. We set the datasets in restore state to maintain the correct position across both scenarios.

if not self.restore:
self.load_state_dict(self.state_dict())
self.restore = False
else:
self._latest_worker_idx = 0
self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1)))
self._worker_idx_iter = iter(self._worker_idx)
self._num_samples_yielded_wrapper = {}
self._num_samples_yielded_streaming = 0
self._num_cycles = {}
self.dataset.reset_state_dict()
self.current_epoch += 1
self.restore = False
elif not self.restore:
self._latest_worker_idx = 0
self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1)))
self._worker_idx_iter = iter(self._worker_idx)
self._num_samples_yielded_wrapper = {}
self._num_samples_yielded_streaming = 0
self._num_cycles = {}
self.dataset.reset_state_dict()
self.current_epoch += 1

self.dataset.set_epoch(self.current_epoch)
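
To make the scenario in the Copilot comment above concrete, here is a minimal sketch of the manual-resume path this hunk addresses, mirroring the new tests below. It assumes two optimized datasets already exist at the illustrative paths data/dataset_0 and data/dataset_1; the ParallelStreamingDataset import path and the small length/batch values are assumptions, not part of this PR.

# Illustrative resume sketch; paths, length value and import locations are assumptions.
from litdata import StreamingDataset, StreamingDataLoader
from litdata.streaming.parallel import ParallelStreamingDataset  # assumed import path

def make_loader():
    dset1 = StreamingDataset("data/dataset_0")
    dset2 = StreamingDataset("data/dataset_1")
    # length != None makes the parallel dataset cycle its wrapped datasets;
    # resume=True keeps cycling across epochs instead of restarting at index 0.
    pardset = ParallelStreamingDataset(datasets=[dset1, dset2], length=4, resume=True)
    return StreamingDataLoader(pardset, batch_size=1, num_workers=2)

loader = make_loader()
for _ in range(2):  # two full epochs in the same session (automatic cycling)
    for _batch in loader:
        pass
ckpt = loader.state_dict()  # snapshot taken after epoch 2

# Simulated crash: rebuild the loader in a new session and load the snapshot.
loader = make_loader()
loader.load_state_dict(ckpt)
for _batch in loader:  # with this fix, iteration keeps cycling rather than restarting at 0
    pass
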
160 changes: 139 additions & 21 deletions tests/streaming/test_parallel.py
@@ -423,9 +423,9 @@ def test_dataloader_shuffle(tmp_path, shuffle):


def prepare_parallel_dataset_and_dataloder(
tmp_path_factory, parlen, len1=48, len2=56, num_workers=0, batch_size=4, shuffle=True, resume=True
tmp_path_factory, parlen, len1=48, len2=56, num_workers=0, batch_size=4, shuffle=True, resume=True, tmpdir=None
):
tmpdir = tmp_path_factory.mktemp("data")
tmpdir = tmp_path_factory.mktemp("data") if tmpdir is None else tmpdir
datasets = [str(tmpdir / f"dataset_{i}") for i in range(2)]
for dataset, num_items in zip(datasets, [len1, len2]):
cache = Cache(input_dir=dataset, chunk_size=10)
@@ -437,12 +437,12 @@ def prepare_parallel_dataset_and_dataloder(
dset2 = StreamingDataset(datasets[1], shuffle=shuffle)
pardset = ParallelStreamingDataset(datasets=[dset1, dset2], length=parlen, resume=resume)
dloader = StreamingDataLoader(pardset, num_workers=num_workers, batch_size=batch_size)
return dset1, dset2, pardset, dloader
return dset1, dset2, pardset, dloader, tmpdir


@pytest.mark.parametrize("length", [None, 3, float("inf")])
def test_parallel_dataset_dataloader_states_without_any_iterations(tmp_path_factory, length):
_, _, _, dataloader = prepare_parallel_dataset_and_dataloder(tmp_path_factory, length)
_, _, _, dataloader, _ = prepare_parallel_dataset_and_dataloder(tmp_path_factory, length)
assert not dataloader.restore
dataloader.load_state_dict(dataloader.state_dict())
assert not dataloader.restore
@@ -458,7 +458,7 @@ def test_parallel_dataset_dataloader_states_without_any_iterations(tmp_path_fact
def test_parallel_dataset_dataloader_states_complete_iterations(tmp_path_factory, length, num_workers, batch_size):
print(f"Testing with num_workers={num_workers}")

_, _, parallel_dataset, dataloader = prepare_parallel_dataset_and_dataloder(
_, _, parallel_dataset, dataloader, _ = prepare_parallel_dataset_and_dataloder(
tmp_path_factory,
length,
batch_size=batch_size,
@@ -524,7 +524,7 @@ def test_parallel_dataset_dataloader_states_partial_iterations(
):
print(f"Testing with num_workers={num_workers}, break_at={break_at}")

_, _, parallel_dataset, dataloader = prepare_parallel_dataset_and_dataloder(
_, _, parallel_dataset, dataloader, _ = prepare_parallel_dataset_and_dataloder(
tmp_path_factory, length, batch_size=batch_size, num_workers=num_workers, shuffle=True
)

@@ -701,7 +701,7 @@ def test_parallel_dataset_with_dataloader_2_epochs(
expected_states_1: ExpectedStates,
expected_states_2: ExpectedStates,
):
dataset1, dataset2, _, dataloader = prepare_parallel_dataset_and_dataloder(
dataset1, dataset2, _, dataloader, _ = prepare_parallel_dataset_and_dataloder(
tmp_path_factory,
length,
len1,
@@ -837,7 +837,7 @@ def test_parallel_dataset_with_dataloader_2_epochs(
@pytest.mark.parametrize("shuffle", [False, True])
@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI")
def test_parallel_dataset_partial_iteration_resume(tmp_path_factory, length, resume, shuffle):
_, _, pardset, dloader = prepare_parallel_dataset_and_dataloder(
_, _, pardset, dloader, tmpdir = prepare_parallel_dataset_and_dataloder(
tmp_path_factory, parlen=length, len1=10, len2=10, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume
)
assert pardset.is_cycling() or length is None
@@ -871,6 +871,7 @@ def test_parallel_dataset_partial_iteration_resume(tmp_path_factory, length, res
assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i]))
if i == break_at:
break
state_dict_after_2 = dloader.state_dict()
expected_3 = [
[torch.tensor([4]), torch.tensor([4])],
[torch.tensor([9]), torch.tensor([9])],
@@ -888,34 +889,98 @@ def test_parallel_dataset_partial_iteration_resume(tmp_path_factory, length, res
if i == break_at:
break

# simulate training crash and manually resume by loading state dict created after epoch 2
_, _, pardset, dloader, _ = prepare_parallel_dataset_and_dataloder(
tmp_path_factory,
parlen=length,
len1=10,
len2=10,
batch_size=1,
num_workers=2,
shuffle=shuffle,
resume=resume,
tmpdir=tmpdir,
)
assert pardset.is_cycling() or length is None
dloader.load_state_dict(state_dict_after_2)
assert dloader.restore
# we should get same samples as in epoch 3 if resuming
# else we should get samples from epoch 2
batches_2 = []
for i, batch in enumerate(dloader):
if not shuffle:
assert all(
torch.equal(x, y)
for x, y in zip(batch, (expected_3 if resume and length is not None else expected_2)[i])
)
batches_2.append(batch)
if i == break_at:
break
# for some reason the workers are swapping their samples compared to the previous epoch when not resuming
# so we update expected_2 and batches_2 accordingly
expected_2 = [expected_2[i + 1] if i % 2 == 0 else expected_2[i - 1] for i in range(len(expected_2))]
batches_2 = [batches_2[i + 1] if i % 2 == 0 else batches_2[i - 1] for i in range(len(batches_2))]
Comment on lines +919 to +922
Copilot AI Dec 16, 2025

This comment suggests uncertainty about the behavior ("for some reason the workers are swapping their samples"). The test then compensates by swapping elements in expected_2 and batches_2 lists. This workaround indicates either: (1) the underlying behavior is not well understood, or (2) there's a non-deterministic or undocumented aspect of worker sample distribution. Consider investigating the root cause and documenting why this swapping occurs, or fixing the underlying issue if it's a bug.

Suggested change
# for some reason the workers are swapping their samples compared to the previous epoch when not resuming
# so we update expected_2 and batches_2 accordingly
expected_2 = [expected_2[i + 1] if i % 2 == 0 else expected_2[i - 1] for i in range(len(expected_2))]
batches_2 = [batches_2[i + 1] if i % 2 == 0 else batches_2[i - 1] for i in range(len(batches_2))]
# The order of samples delivered by the workers should be deterministic when shuffle=False.
# If this test fails, investigate the worker sample assignment logic in the DataLoader.
# Remove the workaround that swaps expected_2 and batches_2; the test should reflect the true, documented order.

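As a side note on the workaround discussed above: the swap the test applies simply exchanges adjacent entries, i.e. the two workers trade positions in the round-robin order between runs. A standalone illustration of that list manipulation (plain Python, not litdata code; the names are made up):

# With two workers, batches arrive round-robin: [w0_0, w1_0, w0_1, w1_1, ...].
# If the workers trade roles between runs, swapping adjacent pairs recovers the
# previous ordering for comparison (assumes an even-length list, as in the test).
def swap_adjacent_pairs(items):
    return [items[i + 1] if i % 2 == 0 else items[i - 1] for i in range(len(items))]

print(swap_adjacent_pairs(["w0_0", "w1_0", "w0_1", "w1_1"]))
# -> ['w1_0', 'w0_0', 'w1_1', 'w0_1']
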
expected_4 = [
[torch.tensor([1]), torch.tensor([1])],
[torch.tensor([6]), torch.tensor([6])],
[torch.tensor([2]), torch.tensor([2])],
[torch.tensor([7]), torch.tensor([7])],
]
for i, batch in enumerate(dloader):
if not shuffle:
assert all(
torch.equal(x, y)
for x, y in zip(batch, (expected_4 if resume and length is not None else expected_2)[i])
)
elif not resume and length is not None:
assert all(torch.equal(x, y) for x, y in zip(batch, batches_2[i]))
if i == break_at:
break
expected_5 = [
[torch.tensor([3]), torch.tensor([3])],
[torch.tensor([8]), torch.tensor([8])],
[torch.tensor([4]), torch.tensor([4])],
[torch.tensor([9]), torch.tensor([9])],
]
for i, batch in enumerate(dloader):
if not shuffle:
assert all(
torch.equal(x, y)
for x, y in zip(batch, (expected_5 if resume and length is not None else expected_2)[i])
)
elif not resume and length is not None:
assert all(torch.equal(x, y) for x, y in zip(batch, batches_2[i]))
if i == break_at:
break


@pytest.mark.parametrize("length", [None, 5])
@pytest.mark.parametrize("length", [None, 4])
@pytest.mark.parametrize("resume", [False, True])
@pytest.mark.parametrize("shuffle", [False, True])
@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI")
def test_parallel_dataset_complete_iteration_resume(tmp_path_factory, length, resume, shuffle):
_, _, pardset, dloader = prepare_parallel_dataset_and_dataloder(
tmp_path_factory, parlen=length, len1=4, len2=4, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume
_, _, pardset, dloader, tmpdir = prepare_parallel_dataset_and_dataloder(
tmp_path_factory, parlen=length, len1=6, len2=6, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume
)
assert pardset.is_cycling() or length is None
expected_1 = [
[torch.tensor([0]), torch.tensor([0])],
[torch.tensor([2]), torch.tensor([2])],
[torch.tensor([1]), torch.tensor([1])],
[torch.tensor([3]), torch.tensor([3])],
[torch.tensor([0]), torch.tensor([0])],
[torch.tensor([1]), torch.tensor([1])],
[torch.tensor([4]), torch.tensor([4])],
[torch.tensor([2]), torch.tensor([2])],
[torch.tensor([5]), torch.tensor([5])],
]
batches_1 = []
for i, batch in enumerate(dloader):
if not shuffle:
assert all(torch.equal(x, y) for x, y in zip(batch, expected_1[i]))
batches_1.append(batch)
expected_2 = [
[torch.tensor([1]), torch.tensor([1])],
[torch.tensor([2]), torch.tensor([2])],
[torch.tensor([5]), torch.tensor([5])],
[torch.tensor([0]), torch.tensor([0])],
[torch.tensor([3]), torch.tensor([3])],
[torch.tensor([1]), torch.tensor([1])],
]
for i, batch in enumerate(dloader):
if not shuffle:
@@ -925,12 +990,12 @@ def test_parallel_dataset_complete_iteration_resume(tmp_path_factory, length, re
)
elif not resume and length is not None:
assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i]))
state_dict_after_2 = dloader.state_dict()
expected_3 = [
[torch.tensor([1]), torch.tensor([1])],
[torch.tensor([3]), torch.tensor([3])],
[torch.tensor([0]), torch.tensor([0])],
[torch.tensor([4]), torch.tensor([4])],
[torch.tensor([2]), torch.tensor([2])],
[torch.tensor([1]), torch.tensor([1])],
[torch.tensor([5]), torch.tensor([5])],
]
for i, batch in enumerate(dloader):
if not shuffle:
@@ -940,14 +1005,67 @@ def test_parallel_dataset_complete_iteration_resume(tmp_path_factory, length, re
)
elif not resume and length is not None:
assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i]))
# simulate training crash and manually resume by loading state dict created after epoch 2
_, _, pardset, dloader, _ = prepare_parallel_dataset_and_dataloder(
tmp_path_factory,
parlen=length,
len1=6,
len2=6,
batch_size=1,
num_workers=2,
shuffle=shuffle,
resume=resume,
tmpdir=tmpdir,
)
assert pardset.is_cycling() or length is None
dloader.load_state_dict(state_dict_after_2)
assert not dloader.restore # iterations are complete so no restore
# we should get same samples as in epoch 3 if resuming
# else we should get samples from epoch 1
for i, batch in enumerate(dloader):
if not shuffle:
assert all(
torch.equal(x, y)
for x, y in zip(batch, (expected_3 if resume and length is not None else expected_1)[i])
)
elif not resume and length is not None:
assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i]))
expected_4 = [
[torch.tensor([0]), torch.tensor([0])],
[torch.tensor([3]), torch.tensor([3])],
[torch.tensor([1]), torch.tensor([1])],
[torch.tensor([4]), torch.tensor([4])],
]
for i, batch in enumerate(dloader):
if not shuffle:
assert all(
torch.equal(x, y)
for x, y in zip(batch, (expected_4 if resume and length is not None else expected_1)[i])
)
elif not resume and length is not None:
assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i]))
expected_5 = [
[torch.tensor([2]), torch.tensor([2])],
[torch.tensor([5]), torch.tensor([5])],
[torch.tensor([0]), torch.tensor([0])],
[torch.tensor([3]), torch.tensor([3])],
]
for i, batch in enumerate(dloader):
if not shuffle:
assert all(
torch.equal(x, y)
for x, y in zip(batch, (expected_5 if resume and length is not None else expected_1)[i])
)
elif not resume and length is not None:
assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i]))


@pytest.mark.parametrize("length", [None, 18])
@pytest.mark.parametrize("resume", [False, True])
@pytest.mark.parametrize("shuffle", [False, True])
@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI")
def test_parallel_dataset_partial_iteration_resume_without_dataloader(tmp_path_factory, length, resume, shuffle):
_, _, pardset, _ = prepare_parallel_dataset_and_dataloder(
_, _, pardset, _, _ = prepare_parallel_dataset_and_dataloder(
tmp_path_factory, parlen=length, len1=10, len2=10, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume
)
assert pardset.is_cycling() or length is None
@@ -979,7 +1097,7 @@ def test_parallel_dataset_partial_iteration_resume_without_dataloader(tmp_path_f
@pytest.mark.parametrize("shuffle", [False, True])
@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI")
def test_parallel_dataset_complete_iteration_resume_without_dataloader(tmp_path_factory, length, resume, shuffle):
_, _, pardset, _ = prepare_parallel_dataset_and_dataloder(
_, _, pardset, _, _ = prepare_parallel_dataset_and_dataloder(
tmp_path_factory, parlen=length, len1=4, len2=4, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume
)
assert pardset.is_cycling() or length is None