diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index a8624075ef05..11bfa6a5ec24 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -116,12 +116,18 @@ class IPGBucket: elements: int = 0 index: int = 0 has_moe_params: bool = False + # Streams that issued copies into buffer[index] for the current bucket fill. + # average_tensor must wait on all of them before reducing the bucket, since the + # copies can be produced on multiple streams (e.g. under torch.compile gradient + # hooks run on different autograd streams), not just the current one (#8061). + copy_streams: set = field(default_factory=set) def clear(self): self.params.clear() self.grads.clear() self.elements = 0 self.has_moe_params = False + self.copy_streams.clear() class DeepSpeedZeroOptimizer(ZeROOptimizer): @@ -1120,6 +1126,11 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): if self.contiguous_gradients: if param.numel() > self.reduce_bucket_size: + # Scope note (#8061): extra-large params are reduced directly and + # never copied into the contiguous IPG bucket, so no producer stream + # is recorded for them. average_tensor falls back to waiting on the + # current stream for this path; the producer-stream tracking only + # covers the bucketed path below. self.extra_large_param_to_reduce[comm_dtype] = param else: # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening @@ -1130,6 +1141,10 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): grad_reduc.data = new_grad_tensor.data.view_as(grad_reduc) if ( not self.zenflow or grad_reduc.dim() == 1) else new_grad_tensor.data.view_as( grad_reduc.transpose(0, 1)) + # Record the stream this copy ran on so average_tensor can wait on + # every producer of the bucket, not just the current stream (#8061). + if self.overlap_comm and not get_accelerator().resolves_data_dependency(): + bucket.copy_streams.add(get_accelerator().current_stream()) bucket.elements += param.numel() @@ -1238,7 +1253,15 @@ def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dt if self.overlap_comm: stream = self.reduction_stream if not get_accelerator().resolves_data_dependency(): - stream.wait_stream(get_accelerator().current_stream()) + # The contiguous IPG bucket may have been filled by copies issued on + # several streams (e.g. under torch.compile, gradient hooks run on + # different autograd streams). Waiting only on the current stream lets + # the reduction read the bucket before the other producers finish + # (#8061), so wait on every stream that produced a copy into it. + bucket = self.ipg_buckets[communication_data_type] + producer_streams = bucket.copy_streams or {get_accelerator().current_stream()} + for producer_stream in producer_streams: + stream.wait_stream(producer_stream) get_accelerator().current_stream().wait_stream(stream) else: stream = get_accelerator().current_stream() diff --git a/tests/unit/v1/zero/test_overlap_comm_record_stream.py b/tests/unit/v1/zero/test_overlap_comm_record_stream.py index 431461703063..5dc85671830d 100644 --- a/tests/unit/v1/zero/test_overlap_comm_record_stream.py +++ b/tests/unit/v1/zero/test_overlap_comm_record_stream.py @@ -95,3 +95,75 @@ def test_allreduce_and_copy_with_multiple_ranks_records_only_local_buffers(monke assert bucket[0].recorded_streams == [optimizer.reduction_stream] assert bucket[1].copied_from is None assert bucket[1].recorded_streams == [] + + +class _FakeWaitStream: + """A stream stand-in that records which streams it was told to wait on.""" + + def __init__(self): + self.waited_on = [] + + def wait_stream(self, other): + self.waited_on.append(other) + + +class _FakeAcceleratorWithCurrentStream(_FakeAccelerator): + + def __init__(self, resolves_data_dependency, current_stream): + super().__init__(resolves_data_dependency) + self._current_stream = current_stream + + def current_stream(self): + return self._current_stream + + +def _build_average_tensor_optimizer(monkeypatch, *, copy_streams): + optimizer = DeepSpeedZeroOptimizer.__new__(DeepSpeedZeroOptimizer) + optimizer.overlap_comm = True + optimizer.reduce_scatter = False # take the early-return reduce path, isolating the wait logic + optimizer.reduction_stream = _FakeWaitStream() + comm_dtype = torch.float16 + bucket = zero_stage12.IPGBucket() + bucket.copy_streams = set(copy_streams) + optimizer.ipg_buckets = {comm_dtype: bucket} + reduced = [] + optimizer.gradient_reduction_w_predivide = lambda tensor, dt: reduced.append(dt) + current = _FakeWaitStream() + monkeypatch.setattr( + zero_stage12, + "get_accelerator", + lambda: _FakeAcceleratorWithCurrentStream(False, current), + ) + return optimizer, comm_dtype, current, reduced + + +def test_average_tensor_waits_on_all_ipg_bucket_producer_streams(monkeypatch): + # #8061: the reduction stream must wait on every stream that produced a copy into + # the contiguous IPG bucket, not just the current stream, because under + # torch.compile those copies can be issued on multiple autograd streams. + s1, s2 = object(), object() + optimizer, comm_dtype, _, reduced = _build_average_tensor_optimizer(monkeypatch, copy_streams=[s1, s2]) + + optimizer.average_tensor(torch.zeros(4), comm_dtype) + + assert set(optimizer.reduction_stream.waited_on) == {s1, s2} + assert reduced == [comm_dtype] + + +def test_average_tensor_falls_back_to_current_stream_without_producers(monkeypatch): + # The extra-large-param path reduces without copying into the bucket, so + # copy_streams is empty: preserve the original behavior of waiting on the + # current stream. + optimizer, comm_dtype, current, _ = _build_average_tensor_optimizer(monkeypatch, copy_streams=[]) + + optimizer.average_tensor(torch.zeros(4), comm_dtype) + + assert optimizer.reduction_stream.waited_on == [current] + + +def test_ipg_bucket_clear_resets_copy_streams(): + bucket = zero_stage12.IPGBucket() + assert bucket.copy_streams == set() + bucket.copy_streams.add(object()) + bucket.clear() + assert bucket.copy_streams == set()