Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,18 @@ class IPGBucket:
elements: int = 0
index: int = 0
has_moe_params: bool = False
# Streams that issued copies into buffer[index] for the current bucket fill.
# average_tensor must wait on all of them before reducing the bucket, since the
# copies can be produced on multiple streams (e.g. under torch.compile gradient
# hooks run on different autograd streams), not just the current one (#8061).
copy_streams: set = field(default_factory=set)

def clear(self):
self.params.clear()
self.grads.clear()
self.elements = 0
self.has_moe_params = False
self.copy_streams.clear()


class DeepSpeedZeroOptimizer(ZeROOptimizer):
Expand Down Expand Up @@ -1120,6 +1126,11 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):

if self.contiguous_gradients:
if param.numel() > self.reduce_bucket_size:
# Scope note (#8061): extra-large params are reduced directly and
# never copied into the contiguous IPG bucket, so no producer stream
# is recorded for them. average_tensor falls back to waiting on the
# current stream for this path; the producer-stream tracking only
# covers the bucketed path below.
self.extra_large_param_to_reduce[comm_dtype] = param
else:
# keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
Expand All @@ -1130,6 +1141,10 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
grad_reduc.data = new_grad_tensor.data.view_as(grad_reduc) if (
not self.zenflow or grad_reduc.dim() == 1) else new_grad_tensor.data.view_as(
grad_reduc.transpose(0, 1))
# Record the stream this copy ran on so average_tensor can wait on
# every producer of the bucket, not just the current stream (#8061).
if self.overlap_comm and not get_accelerator().resolves_data_dependency():
bucket.copy_streams.add(get_accelerator().current_stream())

bucket.elements += param.numel()

Expand Down Expand Up @@ -1238,7 +1253,15 @@ def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dt
if self.overlap_comm:
stream = self.reduction_stream
if not get_accelerator().resolves_data_dependency():
stream.wait_stream(get_accelerator().current_stream())
# The contiguous IPG bucket may have been filled by copies issued on
# several streams (e.g. under torch.compile, gradient hooks run on
# different autograd streams). Waiting only on the current stream lets
# the reduction read the bucket before the other producers finish
# (#8061), so wait on every stream that produced a copy into it.
bucket = self.ipg_buckets[communication_data_type]
producer_streams = bucket.copy_streams or {get_accelerator().current_stream()}
for producer_stream in producer_streams:
stream.wait_stream(producer_stream)
Comment on lines +1261 to +1264
get_accelerator().current_stream().wait_stream(stream)
else:
stream = get_accelerator().current_stream()
Expand Down
72 changes: 72 additions & 0 deletions tests/unit/v1/zero/test_overlap_comm_record_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,75 @@ def test_allreduce_and_copy_with_multiple_ranks_records_only_local_buffers(monke
assert bucket[0].recorded_streams == [optimizer.reduction_stream]
assert bucket[1].copied_from is None
assert bucket[1].recorded_streams == []


class _FakeWaitStream:
"""A stream stand-in that records which streams it was told to wait on."""

def __init__(self):
self.waited_on = []

def wait_stream(self, other):
self.waited_on.append(other)


class _FakeAcceleratorWithCurrentStream(_FakeAccelerator):

def __init__(self, resolves_data_dependency, current_stream):
super().__init__(resolves_data_dependency)
self._current_stream = current_stream

def current_stream(self):
return self._current_stream


def _build_average_tensor_optimizer(monkeypatch, *, copy_streams):
optimizer = DeepSpeedZeroOptimizer.__new__(DeepSpeedZeroOptimizer)
optimizer.overlap_comm = True
optimizer.reduce_scatter = False # take the early-return reduce path, isolating the wait logic
optimizer.reduction_stream = _FakeWaitStream()
comm_dtype = torch.float16
bucket = zero_stage12.IPGBucket()
bucket.copy_streams = set(copy_streams)
optimizer.ipg_buckets = {comm_dtype: bucket}
reduced = []
optimizer.gradient_reduction_w_predivide = lambda tensor, dt: reduced.append(dt)
current = _FakeWaitStream()
monkeypatch.setattr(
zero_stage12,
"get_accelerator",
lambda: _FakeAcceleratorWithCurrentStream(False, current),
)
return optimizer, comm_dtype, current, reduced


def test_average_tensor_waits_on_all_ipg_bucket_producer_streams(monkeypatch):
# #8061: the reduction stream must wait on every stream that produced a copy into
# the contiguous IPG bucket, not just the current stream, because under
# torch.compile those copies can be issued on multiple autograd streams.
s1, s2 = object(), object()
optimizer, comm_dtype, _, reduced = _build_average_tensor_optimizer(monkeypatch, copy_streams=[s1, s2])

optimizer.average_tensor(torch.zeros(4), comm_dtype)

assert set(optimizer.reduction_stream.waited_on) == {s1, s2}
assert reduced == [comm_dtype]
Comment on lines +144 to +150


def test_average_tensor_falls_back_to_current_stream_without_producers(monkeypatch):
# The extra-large-param path reduces without copying into the bucket, so
# copy_streams is empty: preserve the original behavior of waiting on the
# current stream.
optimizer, comm_dtype, current, _ = _build_average_tensor_optimizer(monkeypatch, copy_streams=[])

optimizer.average_tensor(torch.zeros(4), comm_dtype)

assert optimizer.reduction_stream.waited_on == [current]


def test_ipg_bucket_clear_resets_copy_streams():
bucket = zero_stage12.IPGBucket()
assert bucket.copy_streams == set()
bucket.copy_streams.add(object())
bucket.clear()
assert bucket.copy_streams == set()
Loading