As of commit 12372f4, device assignment in the PyTorch dataloader does not work correctly with multiple GPUs.
import os

import pandas as pd

from merlin.dataloader.torch import Loader
from merlin.io.dataset import Dataset

dataset = Dataset(pd.DataFrame({"a": list(range(10))}))
dataset = dataset.repartition(npartitions=2)

rank = int(os.environ["LOCAL_RANK"])
with Loader(
    dataset,
    batch_size=1,
    global_rank=rank,
    global_size=2,
    device=rank,
) as loader:
    for idx, batch in enumerate(loader):
        x, y = batch
        device = x["a"].device
        print(f"rank: {rank}, device: {device}")
When I run the above, I get:
root@ba87ba84045a:/dataloader# torchrun --nproc_per_node=2 test_torch_multi_gpu.py
WARNING:torch.distributed.run:
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'
warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}")
/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'
warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}")
rank: 0, device: cuda:0
rank: 0, device: cuda:0
rank: 0, device: cuda:0
rank: 0, device: cuda:0
rank: 0, device: cuda:0
rank: 1, device: cuda:0
rank: 1, device: cuda:0
rank: 1, device: cuda:0
rank: 1, device: cuda:0
rank: 1, device: cuda:0
But for rank 1, tensors are expected to be placed on cuda:1, not cuda:0.
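For reference, the expected behavior is a one-to-one mapping from each process's local rank to its CUDA device. A minimal sketch of that expectation (the `expected_device` helper is hypothetical, for illustration only, and not part of the Merlin API):

```python
def expected_device(local_rank: int) -> str:
    # Each torchrun process should place its tensors on the GPU
    # matching its local rank, so ranks never share a device.
    return f"cuda:{local_rank}"

# With torchrun --nproc_per_node=2, the two processes should see:
assert expected_device(0) == "cuda:0"
assert expected_device(1) == "cuda:1"
```

Under this mapping, the output above should show `device: cuda:1` on every line printed by rank 1.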