Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion torchrec/distributed/test_utils/test_model_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

# pyre-strict

import os
import unittest
from typing import Any, cast, Dict, List, Optional, Tuple, Type

Expand All @@ -28,7 +29,7 @@
)
from torchrec.distributed.types import ModuleSharder, ShardingStrategy, ShardingType
from torchrec.modules.embedding_configs import EmbeddingBagConfig, PoolingType
from torchrec.test_utils import seed_and_log, skip_if_asan_class
from torchrec.test_utils import get_free_port, seed_and_log, skip_if_asan_class
from torchrec.types import DataType


Expand Down Expand Up @@ -163,6 +164,10 @@ def _test_sharding(
lengths_dtype: torch.dtype = torch.int64,
sharding_strategy: Optional[ShardingStrategy] = None,
) -> None:
# Refresh MASTER_PORT for each Hypothesis example to avoid port conflicts
# that cause flaky test failures (socket errno 99: Cannot assign requested address)
os.environ["MASTER_PORT"] = str(get_free_port())

self._build_tables_and_groups(data_type=data_type)
# directly run the test with single process
if world_size == 1:
Expand Down Expand Up @@ -258,6 +263,10 @@ def _test_dynamic_sharding(
Tests the reshard API with dynamic_sharding_test, which creates 2 identical models
one of which is resharded, and then compares the predictions of the 2 models.
"""
# Refresh MASTER_PORT for each Hypothesis example to avoid port conflicts
# that cause flaky test failures (socket errno 99: Cannot assign requested address)
os.environ["MASTER_PORT"] = str(get_free_port())

self._build_tables_and_groups(data_type=data_type)
constraints = {}
if sharding_type is not None:
Expand Down
7 changes: 5 additions & 2 deletions torchrec/test_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,15 @@ def get_free_port() -> int:
raise Exception(
f"Binding failed with address {address} while getting free port {e}"
)
# OSS GHA: TODO remove when enable ipv6 on GHA @omkar
# OSS GHA: Use SO_REUSEADDR to handle TIME_WAIT port conflicts
# This is safe because the port is only used for gloo/nccl rendezvous coordination,
# not for actual data transfer. The distributed backends establish their own connections.
else:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind(("127.0.0.1", 0))
s.listen(0)
s.listen(1)
with closing(s):
return s.getsockname()[1]
except Exception as e:
Expand Down
Loading