Fix ParallelStreamingDataset with resume=True not resuming after loading a state dict when breaking early #771
base: main
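For context, the bug described in the title shows up in roughly the following usage pattern. This is only an illustrative sketch based on the APIs exercised by the tests below; the dataset directories and the `ParallelStreamingDataset` import path are assumptions, not taken from this PR.

```python
# Minimal sketch of the scenario this PR addresses. Assumes two litdata datasets have
# already been written to dataset_dir_1 and dataset_dir_2 (placeholder paths) and that
# the import path for ParallelStreamingDataset matches your litdata version.
from litdata import StreamingDataLoader, StreamingDataset
from litdata.streaming.parallel import ParallelStreamingDataset  # assumed import path


def build_loader() -> StreamingDataLoader:
    dset1 = StreamingDataset("dataset_dir_1", shuffle=False)
    dset2 = StreamingDataset("dataset_dir_2", shuffle=False)
    # resume=True: the cycling parallel dataset should continue where it left off
    # instead of restarting at index 0.
    pardset = ParallelStreamingDataset(datasets=[dset1, dset2], length=float("inf"), resume=True)
    return StreamingDataLoader(pardset, batch_size=1, num_workers=2)


loader = build_loader()
for i, batch in enumerate(loader):
    if i == 4:  # break out early, mid "epoch"
        break
state = loader.state_dict()

# New session (e.g. after a crash): rebuild the loader over the same data and restore.
loader = build_loader()
loader.load_state_dict(state)
assert loader.restore
for batch in loader:
    # Expected: iteration continues from the saved position; before this fix it
    # would restart from the beginning of the epoch after loading the state dict.
    break
```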
```diff
@@ -423,9 +423,9 @@ def test_dataloader_shuffle(tmp_path, shuffle):
 def prepare_parallel_dataset_and_dataloder(
-    tmp_path_factory, parlen, len1=48, len2=56, num_workers=0, batch_size=4, shuffle=True, resume=True
+    tmp_path_factory, parlen, len1=48, len2=56, num_workers=0, batch_size=4, shuffle=True, resume=True, tmpdir=None
 ):
-    tmpdir = tmp_path_factory.mktemp("data")
+    tmpdir = tmp_path_factory.mktemp("data") if tmpdir is None else tmpdir
     datasets = [str(tmpdir / f"dataset_{i}") for i in range(2)]
     for dataset, num_items in zip(datasets, [len1, len2]):
         cache = Cache(input_dir=dataset, chunk_size=10)
@@ -437,12 +437,12 @@ def prepare_parallel_dataset_and_dataloder(
     dset2 = StreamingDataset(datasets[1], shuffle=shuffle)
     pardset = ParallelStreamingDataset(datasets=[dset1, dset2], length=parlen, resume=resume)
     dloader = StreamingDataLoader(pardset, num_workers=num_workers, batch_size=batch_size)
-    return dset1, dset2, pardset, dloader
+    return dset1, dset2, pardset, dloader, tmpdir


 @pytest.mark.parametrize("length", [None, 3, float("inf")])
 def test_parallel_dataset_dataloader_states_without_any_iterations(tmp_path_factory, length):
-    _, _, _, dataloader = prepare_parallel_dataset_and_dataloder(tmp_path_factory, length)
+    _, _, _, dataloader, _ = prepare_parallel_dataset_and_dataloder(tmp_path_factory, length)
     assert not dataloader.restore
     dataloader.load_state_dict(dataloader.state_dict())
     assert not dataloader.restore
```
```diff
@@ -458,7 +458,7 @@ def test_parallel_dataset_dataloader_states_without_any_iterations(tmp_path_factory, length):
 def test_parallel_dataset_dataloader_states_complete_iterations(tmp_path_factory, length, num_workers, batch_size):
     print(f"Testing with num_workers={num_workers}")

-    _, _, parallel_dataset, dataloader = prepare_parallel_dataset_and_dataloder(
+    _, _, parallel_dataset, dataloader, _ = prepare_parallel_dataset_and_dataloder(
         tmp_path_factory,
         length,
         batch_size=batch_size,
```
```diff
@@ -524,7 +524,7 @@ def test_parallel_dataset_dataloader_states_partial_iterations(
 ):
     print(f"Testing with num_workers={num_workers}, break_at={break_at}")

-    _, _, parallel_dataset, dataloader = prepare_parallel_dataset_and_dataloder(
+    _, _, parallel_dataset, dataloader, _ = prepare_parallel_dataset_and_dataloder(
         tmp_path_factory, length, batch_size=batch_size, num_workers=num_workers, shuffle=True
     )
```
```diff
@@ -701,7 +701,7 @@ def test_parallel_dataset_with_dataloader_2_epochs(
     expected_states_1: ExpectedStates,
     expected_states_2: ExpectedStates,
 ):
-    dataset1, dataset2, _, dataloader = prepare_parallel_dataset_and_dataloder(
+    dataset1, dataset2, _, dataloader, _ = prepare_parallel_dataset_and_dataloder(
         tmp_path_factory,
         length,
         len1,
```
```diff
@@ -837,7 +837,7 @@ def test_parallel_dataset_with_dataloader_2_epochs(
 @pytest.mark.parametrize("shuffle", [False, True])
 @pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="too slow in CI")
 def test_parallel_dataset_partial_iteration_resume(tmp_path_factory, length, resume, shuffle):
-    _, _, pardset, dloader = prepare_parallel_dataset_and_dataloder(
+    _, _, pardset, dloader, tmpdir = prepare_parallel_dataset_and_dataloder(
         tmp_path_factory, parlen=length, len1=10, len2=10, batch_size=1, num_workers=2, shuffle=shuffle, resume=resume
     )
     assert pardset.is_cycling() or length is None
```
```diff
@@ -871,6 +871,7 @@ def test_parallel_dataset_partial_iteration_resume(tmp_path_factory, length, resume, shuffle):
         assert all(torch.equal(x, y) for x, y in zip(batch, batches_1[i]))
         if i == break_at:
             break
+    state_dict_after_2 = dloader.state_dict()
     expected_3 = [
         [torch.tensor([4]), torch.tensor([4])],
         [torch.tensor([9]), torch.tensor([9])],
```
```diff
@@ -888,34 +889,98 @@ def test_parallel_dataset_partial_iteration_resume(tmp_path_factory, length, resume, shuffle):
         if i == break_at:
             break

+    # simulate training crash and manually resume by loading state dict created after epoch 2
+    _, _, pardset, dloader, _ = prepare_parallel_dataset_and_dataloder(
+        tmp_path_factory,
+        parlen=length,
+        len1=10,
+        len2=10,
+        batch_size=1,
+        num_workers=2,
+        shuffle=shuffle,
+        resume=resume,
+        tmpdir=tmpdir,
+    )
+    assert pardset.is_cycling() or length is None
+    dloader.load_state_dict(state_dict_after_2)
+    assert dloader.restore
+    # we should get same samples as in epoch 3 if resuming
+    # else we should get samples from epoch 2
+    batches_2 = []
+    for i, batch in enumerate(dloader):
+        if not shuffle:
+            assert all(
+                torch.equal(x, y)
+                for x, y in zip(batch, (expected_3 if resume and length is not None else expected_2)[i])
+            )
+        batches_2.append(batch)
+        if i == break_at:
+            break
+    # for some reason the workers are swapping their samples compared to the previous epoch when not resuming
+    # so we update expected_2 and batches_2 accordingly
+    expected_2 = [expected_2[i + 1] if i % 2 == 0 else expected_2[i - 1] for i in range(len(expected_2))]
+    batches_2 = [batches_2[i + 1] if i % 2 == 0 else batches_2[i - 1] for i in range(len(batches_2))]
```
Comment on lines +919 to +922

Suggested change:

```diff
-    # for some reason the workers are swapping their samples compared to the previous epoch when not resuming
-    # so we update expected_2 and batches_2 accordingly
-    expected_2 = [expected_2[i + 1] if i % 2 == 0 else expected_2[i - 1] for i in range(len(expected_2))]
-    batches_2 = [batches_2[i + 1] if i % 2 == 0 else batches_2[i - 1] for i in range(len(batches_2))]
+    # The order of samples delivered by the workers should be deterministic when shuffle=False.
+    # If this test fails, investigate the worker sample assignment logic in the DataLoader.
+    # Remove the workaround that swaps expected_2 and batches_2; the test should reflect the true, documented order.
```
The comment states "we do not want to restart at index 0 at every epoch" but could be more explicit about what this fix addresses. Consider expanding the comment to mention that this handles both automatic cycling (between epochs in the same session) and manual resume (after loading a state dict from a previous session/crash), as this is the key bug being fixed.
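A possible expansion of that code comment, along the lines the reviewer suggests (the exact file and location of the comment are not shown in this diff, so this is only an illustrative sketch):

```python
# We do not want to restart at index 0 at every epoch. This matters in two cases:
#   1. Automatic cycling: when the parallel dataset cycles, the next epoch in the same
#      session should continue from where the previous epoch stopped.
#   2. Manual resume: when a state dict saved mid-epoch in an earlier session (e.g.
#      before a crash) is loaded, iteration should continue from the saved position
#      rather than restarting the epoch from the beginning.
```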