diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index 44cda068..0b168f54 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -639,25 +639,26 @@ def __init__( ) # type: ignore def __iter__(self) -> Any: - if not self.restore: - if ( - isinstance(self.dataset, ParallelStreamingDataset) - and self.dataset.is_cycling() - and self.dataset.resume - and self.current_epoch != 0 - ): - # For ParallelStreamingDataset with _length != None we want to cycle the wrapped datasets i.e. we do not - # want to restart at index 0 at every epoch. So we set them in restore state. + if ( + isinstance(self.dataset, ParallelStreamingDataset) + and self.dataset.is_cycling() + and self.dataset.resume + and self.current_epoch != 0 + ): + # For ParallelStreamingDataset with _length != None we want to cycle the wrapped datasets i.e. we do not + # want to restart at index 0 at every epoch. So we set them in restore state. + if not self.restore: self.load_state_dict(self.state_dict()) - self.restore = False - else: - self._latest_worker_idx = 0 - self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1))) - self._worker_idx_iter = iter(self._worker_idx) - self._num_samples_yielded_wrapper = {} - self._num_samples_yielded_streaming = 0 - self._num_cycles = {} - self.dataset.reset_state_dict() + self.current_epoch += 1 + self.restore = False + elif not self.restore: + self._latest_worker_idx = 0 + self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1))) + self._worker_idx_iter = iter(self._worker_idx) + self._num_samples_yielded_wrapper = {} + self._num_samples_yielded_streaming = 0 + self._num_cycles = {} + self.dataset.reset_state_dict() self.current_epoch += 1 self.dataset.set_epoch(self.current_epoch) diff --git a/tests/streaming/test_parallel.py b/tests/streaming/test_parallel.py index 380e4792..6ec71671 100644 --- a/tests/streaming/test_parallel.py +++ b/tests/streaming/test_parallel.py @@ -423,9 +423,9 @@ def test_dataloader_shuffle(tmp_path, shuffle): def prepare_parallel_dataset_and_dataloder( - tmp_path_factory, parlen, len1=48, len2=56, num_workers=0, batch_size=4, shuffle=True, resume=True + tmp_path_factory, parlen, len1=48, len2=56, num_workers=0, batch_size=4, shuffle=True, resume=True, tmpdir=None ): - tmpdir = tmp_path_factory.mktemp("data") + tmpdir = tmp_path_factory.mktemp("data") if tmpdir is None else tmpdir datasets = [str(tmpdir / f"dataset_{i}") for i in range(2)] for dataset, num_items in zip(datasets, [len1, len2]): cache = Cache(input_dir=dataset, chunk_size=10) @@ -437,12 +437,12 @@ def prepare_parallel_dataset_and_dataloder( dset2 = StreamingDataset(datasets[1], shuffle=shuffle) pardset = ParallelStreamingDataset(datasets=[dset1, dset2], length=parlen, resume=resume) dloader = StreamingDataLoader(pardset, num_workers=num_workers, batch_size=batch_size) - return dset1, dset2, pardset, dloader + return dset1, dset2, pardset, dloader, tmpdir @pytest.mark.parametrize("length", [None, 3, float("inf")]) def test_parallel_dataset_dataloader_states_without_any_iterations(tmp_path_factory, length): - _, _, _, dataloader = prepare_parallel_dataset_and_dataloder(tmp_path_factory, length) + _, _, _, dataloader, _ = prepare_parallel_dataset_and_dataloder(tmp_path_factory, length) assert not dataloader.restore dataloader.load_state_dict(dataloader.state_dict()) assert not dataloader.restore @@ -458,7 +458,7 @@ def 
test_parallel_dataset_dataloader_states_without_any_iterations(tmp_path_fact def test_parallel_dataset_dataloader_states_complete_iterations(tmp_path_factory, length, num_workers, batch_size): print(f"Testing with num_workers={num_workers}") - _, _, parallel_dataset, dataloader = prepare_parallel_dataset_and_dataloder( + _, _, parallel_dataset, dataloader, _ = prepare_parallel_dataset_and_dataloder( tmp_path_factory, length, batch_size=batch_size, @@ -524,7 +524,7 @@ def test_parallel_dataset_dataloader_states_partial_iterations( ): print(f"Testing with num_workers={num_workers}, break_at={break_at}") - _, _, parallel_dataset, dataloader = prepare_parallel_dataset_and_dataloder( + _, _, parallel_dataset, dataloader, _ = prepare_parallel_dataset_and_dataloder( tmp_path_factory, length, batch_size=batch_size, num_workers=num_workers, shuffle=True ) @@ -701,7 +701,7 @@ def test_parallel_dataset_with_dataloader_2_epochs( expected_states_1: ExpectedStates, expected_states_2: ExpectedStates, ): - dataset1, dataset2, _, dataloader = prepare_parallel_dataset_and_dataloder( + dataset1, dataset2, _, dataloader, _ = prepare_parallel_dataset_and_dataloder( tmp_path_factory, length, len1, @@ -837,7 +837,7 @@ def test_parallel_dataset_with_dataloader_2_epochs( @pytest.mark.parametrize("shuffle", [False, True]) @pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") def test_parallel_dataset_partial_iteration_resume(tmp_path_factory, length, resume, shuffle): - _, _, pardset, dloader = prepare_parallel_dataset_and_dataloder( + _, _, pardset, dloader, tmpdir = prepare_parallel_dataset_and_dataloder( tmp_path_factory, parlen=length, len1=10, len2=10, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume ) assert pardset.is_cycling() or length is None @@ -871,6 +871,7 @@ def test_parallel_dataset_partial_iteration_resume(tmp_path_factory, length, res assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i])) if i == break_at: break + state_dict_after_2 = dloader.state_dict() expected_3 = [ [torch.tensor([4]), torch.tensor([4])], [torch.tensor([9]), torch.tensor([9])], @@ -888,22 +889,87 @@ def test_parallel_dataset_partial_iteration_resume(tmp_path_factory, length, res if i == break_at: break + # simulate training crash and manually resume by loading state dict created after epoch 2 + _, _, pardset, dloader, _ = prepare_parallel_dataset_and_dataloder( + tmp_path_factory, + parlen=length, + len1=10, + len2=10, + batch_size=1, + num_workers=2, + shuffle=shuffle, + resume=resume, + tmpdir=tmpdir, + ) + assert pardset.is_cycling() or length is None + dloader.load_state_dict(state_dict_after_2) + assert dloader.restore + # we should get same samples as in epoch 3 if resuming + # else we should get samples from epoch 2 + batches_2 = [] + for i, batch in enumerate(dloader): + if not shuffle: + assert all( + torch.equal(x, y) + for x, y in zip(batch, (expected_3 if resume and length is not None else expected_2)[i]) + ) + batches_2.append(batch) + if i == break_at: + break + # for some reason the workers are swapping their samples compared to the previous epoch when not resuming + # so we update expected_2 and batches_2 accordingly + expected_2 = [expected_2[i + 1] if i % 2 == 0 else expected_2[i - 1] for i in range(len(expected_2))] + batches_2 = [batches_2[i + 1] if i % 2 == 0 else batches_2[i - 1] for i in range(len(batches_2))] + expected_4 = [ + [torch.tensor([1]), torch.tensor([1])], + [torch.tensor([6]), torch.tensor([6])], + [torch.tensor([2]), torch.tensor([2])], + 
[torch.tensor([7]), torch.tensor([7])], + ] + for i, batch in enumerate(dloader): + if not shuffle: + assert all( + torch.equal(x, y) + for x, y in zip(batch, (expected_4 if resume and length is not None else expected_2)[i]) + ) + elif not resume and length is not None: + assert all(torch.equal(x, y) for x, y in zip(batch, batches_2[i])) + if i == break_at: + break + expected_5 = [ + [torch.tensor([3]), torch.tensor([3])], + [torch.tensor([8]), torch.tensor([8])], + [torch.tensor([4]), torch.tensor([4])], + [torch.tensor([9]), torch.tensor([9])], + ] + for i, batch in enumerate(dloader): + if not shuffle: + assert all( + torch.equal(x, y) + for x, y in zip(batch, (expected_5 if resume and length is not None else expected_2)[i]) + ) + elif not resume and length is not None: + assert all(torch.equal(x, y) for x, y in zip(batch, batches_2[i])) + if i == break_at: + break + -@pytest.mark.parametrize("length", [None, 5]) +@pytest.mark.parametrize("length", [None, 4]) @pytest.mark.parametrize("resume", [False, True]) @pytest.mark.parametrize("shuffle", [False, True]) @pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") def test_parallel_dataset_complete_iteration_resume(tmp_path_factory, length, resume, shuffle): - _, _, pardset, dloader = prepare_parallel_dataset_and_dataloder( - tmp_path_factory, parlen=length, len1=4, len2=4, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume + _, _, pardset, dloader, tmpdir = prepare_parallel_dataset_and_dataloder( + tmp_path_factory, parlen=length, len1=6, len2=6, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume ) assert pardset.is_cycling() or length is None expected_1 = [ [torch.tensor([0]), torch.tensor([0])], - [torch.tensor([2]), torch.tensor([2])], - [torch.tensor([1]), torch.tensor([1])], [torch.tensor([3]), torch.tensor([3])], - [torch.tensor([0]), torch.tensor([0])], + [torch.tensor([1]), torch.tensor([1])], + [torch.tensor([4]), torch.tensor([4])], + [torch.tensor([2]), torch.tensor([2])], + [torch.tensor([5]), torch.tensor([5])], ] batches_1 = [] for i, batch in enumerate(dloader): @@ -911,11 +977,10 @@ def test_parallel_dataset_complete_iteration_resume(tmp_path_factory, length, re assert all(torch.equal(x, y) for x, y in zip(batch, expected_1[i])) batches_1.append(batch) expected_2 = [ - [torch.tensor([1]), torch.tensor([1])], [torch.tensor([2]), torch.tensor([2])], + [torch.tensor([5]), torch.tensor([5])], [torch.tensor([0]), torch.tensor([0])], [torch.tensor([3]), torch.tensor([3])], - [torch.tensor([1]), torch.tensor([1])], ] for i, batch in enumerate(dloader): if not shuffle: @@ -925,12 +990,12 @@ def test_parallel_dataset_complete_iteration_resume(tmp_path_factory, length, re ) elif not resume and length is not None: assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i])) + state_dict_after_2 = dloader.state_dict() expected_3 = [ [torch.tensor([1]), torch.tensor([1])], - [torch.tensor([3]), torch.tensor([3])], - [torch.tensor([0]), torch.tensor([0])], + [torch.tensor([4]), torch.tensor([4])], [torch.tensor([2]), torch.tensor([2])], - [torch.tensor([1]), torch.tensor([1])], + [torch.tensor([5]), torch.tensor([5])], ] for i, batch in enumerate(dloader): if not shuffle: @@ -940,6 +1005,59 @@ def test_parallel_dataset_complete_iteration_resume(tmp_path_factory, length, re ) elif not resume and length is not None: assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i])) + # simulate training crash and manually resume by loading state dict created after epoch 2 + _, _, 
pardset, dloader, _ = prepare_parallel_dataset_and_dataloder( + tmp_path_factory, + parlen=length, + len1=6, + len2=6, + batch_size=1, + num_workers=2, + shuffle=shuffle, + resume=resume, + tmpdir=tmpdir, + ) + assert pardset.is_cycling() or length is None + dloader.load_state_dict(state_dict_after_2) + assert not dloader.restore # iterations are complete so no restore + # we should get same samples as in epoch 3 if resuming + # else we should get samples from epoch 1 + for i, batch in enumerate(dloader): + if not shuffle: + assert all( + torch.equal(x, y) + for x, y in zip(batch, (expected_3 if resume and length is not None else expected_1)[i]) + ) + elif not resume and length is not None: + assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i])) + expected_4 = [ + [torch.tensor([0]), torch.tensor([0])], + [torch.tensor([3]), torch.tensor([3])], + [torch.tensor([1]), torch.tensor([1])], + [torch.tensor([4]), torch.tensor([4])], + ] + for i, batch in enumerate(dloader): + if not shuffle: + assert all( + torch.equal(x, y) + for x, y in zip(batch, (expected_4 if resume and length is not None else expected_1)[i]) + ) + elif not resume and length is not None: + assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i])) + expected_5 = [ + [torch.tensor([2]), torch.tensor([2])], + [torch.tensor([5]), torch.tensor([5])], + [torch.tensor([0]), torch.tensor([0])], + [torch.tensor([3]), torch.tensor([3])], + ] + for i, batch in enumerate(dloader): + if not shuffle: + assert all( + torch.equal(x, y) + for x, y in zip(batch, (expected_5 if resume and length is not None else expected_1)[i]) + ) + elif not resume and length is not None: + assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i])) @pytest.mark.parametrize("length", [None, 18]) @@ -947,7 +1065,7 @@ def test_parallel_dataset_complete_iteration_resume(tmp_path_factory, length, re @pytest.mark.parametrize("shuffle", [False, True]) @pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") def test_parallel_dataset_partial_iteration_resume_without_dataloader(tmp_path_factory, length, resume, shuffle): - _, _, pardset, _ = prepare_parallel_dataset_and_dataloder( + _, _, pardset, _, _ = prepare_parallel_dataset_and_dataloder( tmp_path_factory, parlen=length, len1=10, len2=10, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume ) assert pardset.is_cycling() or length is None @@ -979,7 +1097,7 @@ def test_parallel_dataset_partial_iteration_resume_without_dataloader(tmp_path_f @pytest.mark.parametrize("shuffle", [False, True]) @pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI") def test_parallel_dataset_complete_iteration_resume_without_dataloader(tmp_path_factory, length, resume, shuffle): - _, _, pardset, _ = prepare_parallel_dataset_and_dataloder( + _, _, pardset, _, _ = prepare_parallel_dataset_and_dataloder( tmp_path_factory, parlen=length, len1=4, len2=4, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume ) assert pardset.is_cycling() or length is None
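
For context, here is a minimal sketch (not part of the patch) of the crash-resume flow the new test assertions exercise. It only uses calls that appear in this diff (`Cache`, `StreamingDataset`, `ParallelStreamingDataset`, `StreamingDataLoader`, `state_dict`/`load_state_dict`, `restore`); the import paths, the `/tmp` directories, and the `make_loader` helper are illustrative assumptions, not documented API guarantees.

```python
# Illustrative sketch of the resume-after-crash flow exercised by the new tests.
# Import paths are assumed; adjust to wherever your litdata version exposes these classes.
from litdata.streaming.cache import Cache                        # assumed path
from litdata.streaming.dataloader import StreamingDataLoader     # path taken from this diff
from litdata.streaming.dataset import StreamingDataset           # assumed path
from litdata.streaming.parallel import ParallelStreamingDataset  # assumed path

# Hypothetical locations for two small optimized datasets (the tests use tmp_path_factory).
data_dirs = ["/tmp/dataset_0", "/tmp/dataset_1"]
for d in data_dirs:
    cache = Cache(input_dir=d, chunk_size=10)  # same signature as the test fixture
    for i in range(10):
        cache[i] = i
    cache.done()
    cache.merge()


def make_loader(dirs, length, resume=True):
    """Wrap one StreamingDataset per directory in a cycling ParallelStreamingDataset."""
    wrapped = [StreamingDataset(d, shuffle=False) for d in dirs]
    parallel = ParallelStreamingDataset(datasets=wrapped, length=length, resume=resume)
    return StreamingDataLoader(parallel, num_workers=2, batch_size=1)


# Run two full epochs, checkpoint the loader state, then pretend the process crashed.
loader = make_loader(data_dirs, length=4)
for _ in range(2):
    for _batch in loader:
        pass
state = loader.state_dict()

# After the "crash": rebuild everything over the same directories and restore the state.
# Because the previous epochs completed, `loader.restore` stays False and, with
# resume=True, the next epoch keeps cycling the wrapped datasets instead of
# restarting them at index 0 -- the behaviour the reworked __iter__ preserves.
loader = make_loader(data_dirs, length=4)
loader.load_state_dict(state)
for _batch in loader:
    pass
```

With `resume=False` or `length=None`, the restored loader is instead expected to repeat an earlier epoch's samples, which is what the non-resume branches of the new assertions check.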