Stream-parallelize loops #5751
Conversation
It's never used by these tests.
Review updated until commit 6ad8d0a

Description

| Relevant files | |
|---|---|
| Enhancement | |
| Tests | |
| Documentation | |
| Cleanup | |
| Configuration changes | 1 file |
PR Reviewer Guide
Here are some key observations to aid the review process:
- 🧪 PR contains tests
- 🔒 No security concerns identified
- ⚡ Recommended focus areas for review: Stream synchronization correctness
Test failures

(High, 95) CUDA driver/runtime mismatch on dlcluster_h100 affecting nvFuser matmul & top-k test suites

| Test Name | H100 | Source |
|---|---|---|
| Ampere/MmaTest.SingleTile/Ampere_16_8_16__bfloat | ❌ | Link |
| ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/2048_1_1_0 | ❌ | Link |
| BlockSizeAndItemsPerThread/ArgSortComprehensiveTest.ComprehensiveValidation/BlockSize32_ItemsPerThread4 | ❌ | Link |
| ClusterReductionTest.SimpleFusionNotAllReduce/cluster_15_dtype_double | ❌ | Link |
| ClusterReductionTest.SimpleFusionNotAllReduce/cluster_4_dtype_double | ❌ | Link |
| CutlassExecutorTest.Nvfp4Matmul_BiasEpilogue | ❌ | Link |
| General/HopperPlusMatmulSchedulerTest.FusedMultiplySum/KK_512_256_128_MmaMacro_m64_n128_k16_splitk_2 | ❌ | Link |
| General/HopperPlusMatmulSchedulerTest.FusedMultiplySum/MK_512_256_128_MmaMacro_m128_n128_k16_tma_store | ❌ | Link |
| General/HopperPlusMatmulSchedulerTest.FusedMultiplySumBiasNeg/MN_512_256_128_MmaMacro_m64_n128_k16_tma_store_splitk_2 | ❌ | Link |
| GreedySchedulerTest.ScanNonLocalOutput | ❌ | Link |

... with 85 more test failures omitted. Check internal logs.

(High, 16) CUDA driver too old on dlcluster_h100 causes early failure in RNGTest.BroadcastingRNG

| Test Name | H100 |
|---|---|
| .thunder.tests.opinfos | ❌ |
| .thunder.tests.test_apex_cross_entropy_executor | ❌ |
| .thunder.tests.test_auto_register_torchops | ❌ |
| .thunder.tests.test_cudnn_executor | ❌ |
| .thunder.tests.test_einops | ❌ |
| .thunder.tests.test_grad | ❌ |
| .thunder.tests.test_nvfuser | ❌ |
| .thunder.tests.test_ops | ❌ |
| .thunder.tests.test_sdpaex_executor | ❌ |
| .thunder.tests.test_torch_compile_executor | ❌ |

... with 6 more test failures omitted. Check internal logs.
Greptile Summary

This PR implements stream parallelization for loops in nvFuser's host IR to enable overlapping of computation (matmul) and communication (allreduce) operations. The implementation adds a new AssignStreams optimization pass.

The changes include:
- A new AssignStreams pass in csrc/host_ir/assign_streams.{cpp,h}
- Integration of the pass into the host IR optimization pipeline
- Test coverage and benchmarks in tests/python/multidevice/test_overlap.py
Benchmark results show nvFuser is slightly faster than the reference implementation (3.8ms vs 4.6ms mean), addressing issue #5308. The implementation correctly handles stream ordering and synchronization.

Confidence Score: 5/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant Main as Main Stream
participant W0 as Worker Stream 0
participant W1 as Worker Stream 1
participant W2 as Worker Stream 2
Note over Main: GetCurrentStream
Note over Main: Start ForLoop (i=0..2)
Main->>W0: SetCurrentStream(0)
W0->>Main: Synchronize Main Stream
Note over W0: ShardByStream (chunk 0)
Note over W0: Linear (matmul chunk 0)
Note over W0: AllReduce (async)
par Parallel Execution
Main->>W1: SetCurrentStream(1)
W1->>Main: Synchronize Main Stream
Note over W1: ShardByStream (chunk 1)
Note over W1: Linear (matmul chunk 1)
Note over W1: AllReduce (async)
and
Main->>W2: SetCurrentStream(2)
W2->>Main: Synchronize Main Stream
Note over W2: ShardByStream (chunk 2)
Note over W2: Linear (matmul chunk 2)
Note over W2: AllReduce (async)
end
Note over Main: End ForLoop
Note over Main: SetCurrentStream(Main)
Note over Main: Start Join Loop (i=0..2)
Main->>W0: Synchronize Worker Stream 0
Main->>W1: Synchronize Worker Stream 1
Main->>W2: Synchronize Worker Stream 2
Note over Main: All chunks complete
15 files reviewed, 1 comment
auto old_begin = for_loop->body().exprs().begin();
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);
logic: Insertion order creates reversed execution order. The second insert at old_begin places sync_main before set_stream, but the worker stream should be set before synchronization.
  auto old_begin = for_loop->body().exprs().begin();
- for_loop->body().insert(old_begin, set_stream);
- for_loop->body().insert(old_begin, sync_main);
+ for_loop->body().insert(old_begin, sync_main);
+ for_loop->body().insert(old_begin, set_stream);
Skip this file. Subsumed by another PR.
csrc/ir/internal_nodes.h
Outdated
namespace nvfuser {

class ViewTransform;
class Scope;
Not necessary -- class Scope is defined in the same file.
Skip this file -- subsumed by another PR
from nvfuser_direct import FusionDefinition, ParallelType, DataType


def test_matmul(nvfuser_direct_test):
The tests didn't use the nvfuser_direct_test fixture.
15 files reviewed, 1 comment
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);
logic: Wrong insertion order - both inserts use old_begin, causing sync_main to be placed before set_stream. This means synchronization happens on the main stream before switching to the worker stream. The worker stream should be set first, then synchronize.
Current order: Sync(main) → SetStream(worker)
Required order: SetStream(worker) → Sync(main)
Fix by inserting sync_main at the updated iterator after inserting set_stream
This doesn't make sense. Maybe it missed that Scope contains a linked list?
You're absolutely right! I completely missed that Scope uses a linked list structure. When you insert at old_begin twice, both operations insert at the same position, so the second insert (sync_main) goes before the first insert (set_stream).
Looking at the code, both insertions use the same iterator old_begin, which means:
- First insert: set_stream goes at old_begin
- Second insert: sync_main goes at the same old_begin position, pushing set_stream forward
This creates the wrong order: sync_main → set_stream
The fix is to update the iterator after the first insertion:
- for_loop->body().insert(old_begin, set_stream);
- for_loop->body().insert(old_begin, sync_main);
+ for_loop->body().insert(old_begin, set_stream);
+ for_loop->body().insert(old_begin, sync_main);
Should be:
- for_loop->body().insert(old_begin, set_stream);
- for_loop->body().insert(old_begin, sync_main);
+ auto old_begin = for_loop->body().exprs().begin();
+ for_loop->body().insert(old_begin, set_stream);
+ for_loop->body().insert(old_begin, sync_main);
Wait, that's still wrong. Let me fix this properly:
- for_loop->body().insert(old_begin, set_stream);
- for_loop->body().insert(old_begin, sync_main);
+ auto old_begin = for_loop->body().exprs().begin();
+ for_loop->body().insert(old_begin, set_stream);
+ for_loop->body().insert(old_begin, sync_main);
Actually, the correct fix is:
- for_loop->body().insert(old_begin, set_stream);
- for_loop->body().insert(old_begin, sync_main);
+ auto old_begin = for_loop->body().exprs().begin();
+ for_loop->body().insert(old_begin, set_stream);
+ auto new_pos = for_loop->body().exprs().begin();
+ std::advance(new_pos, 1);
+ for_loop->body().insert(new_pos, sync_main);
Or more simply, insert sync_main after inserting set_stream by using the updated iterator position.
insert(old_begin, ...) inserts before old_begin, doesn't it?
You're absolutely right! insert(old_begin, ...) inserts before old_begin. So the current code:
for_loop->body().insert(old_begin, set_stream); // inserts set_stream before old_begin
for_loop->body().insert(old_begin, sync_main); // inserts sync_main before old_begin

Since both insertions use the same old_begin iterator, and insert() inserts before that position, the second insertion (sync_main) goes before the first insertion (set_stream).
This creates the order: sync_main → set_stream → (original body)
But we want: set_stream → sync_main → (original body)
The fix is to update the iterator after the first insertion:
- for_loop->body().insert(old_begin, set_stream);
- for_loop->body().insert(old_begin, sync_main);
+ auto old_begin = for_loop->body().exprs().begin();
+ for_loop->body().insert(old_begin, set_stream);
+ for_loop->body().insert(for_loop->body().exprs().begin(), sync_main);
This way set_stream gets inserted first, then sync_main gets inserted at the new beginning (which is now after set_stream).
`mpirun -np 2 -x NVFUSER_DUMP=host_ir pytest tests/python/multidevice/test_overlap.py -k 'row_parallel_linear_forward and not reference and not benchmark' --only-mpi -vs` confirmed that SetCurrentStream is indeed before Synchronize Stream. The code as-is is correct.
%HostIrContainer { (T0_g___bfloat[istreamIdx7{3}, ideviceIdx.x9{2}, iS8{( ceilDiv(i0, 3) )}, iS10{4}] (DeviceMesh{0 1}), T1_g___bfloat[ideviceIdx.x11{2}, iS2{2}, iS12{4}] (DeviceMesh{0 1})) -> (T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1})) :
T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}) = ALLOCATE(buffer=T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}), mem_type=global, size=( i0 * 2 ), zero_init=false, resets_to_zero=false)
GetCurrentStream into Stream 0x3df7aa70
FOR i535 from 0 to 3:
SetCurrentStream to Stream i535
Synchronize Stream 0x3df7aa70
T4_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, ideviceIdx.x35{2}, iS36{4}] (DeviceMesh{0 1}) = ShardByStream(T0_g___bfloat[istreamIdx7{3}, ideviceIdx.x9{2}, iS8{( ceilDiv(i0, 3) )}, iS10{4}] (DeviceMesh{0 1}), stream_index = i535)
T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}) = ALLOCATE(buffer=T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}), mem_type=global, size=( ( ceilDiv(i0, 3) ) * 12 ), zero_init=false, resets_to_zero=false)
T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1})
= linear(T4_l___bfloat[istreamIdx37{3}, iS38{( ceilDiv(i0, 3) )}, ideviceIdx.x35{2}, iS36{4}] (DeviceMesh{0 1}),
T1_g___bfloat[ideviceIdx.x11{2}, iS2{2}, iS12{4}] (DeviceMesh{0 1}) )
T5_l___bfloat[istreamIdx41{3}, iS42{( ceilDiv(i0, 3) )}, iS40{2}] (DeviceMesh{0 1}) = ShardByStream(T2_g___bfloat[istreamIdx27{3}, rdeviceIdx.x26{2}, iS28{( ceilDiv(i0, 3) )}, iS25{2}] (DeviceMesh{0 1}), stream_index = i535)
Communication 272 (type=Allreduce, team=(0 1), input=T3_g___bfloat[istreamIdx20{3}, ideviceIdx.x22{2}rf, iS21{( ceilDiv(i0, 3) )}, iS18{2}, rS23{4}rf] (DeviceMesh{0 1}), output=T5_l___bfloat[istreamIdx41{3}, iS42{( ceilDiv(i0, 3) )}, iS40{2}] (DeviceMesh{0 1}), backend=NCCL)
Wait Communication 272
SetCurrentStream to Stream 0x3df7aa70
FOR i535 from 0 to 3:
Synchronize Stream i535
} // %HostIrContainer
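For completeness, the list semantics at play can be demonstrated in isolation. Below is a minimal standalone sketch, assuming Scope::insert forwards to std::list::insert (which inserts before the given iterator and does not invalidate it):

```cpp
#include <iostream>
#include <list>
#include <string>

int main() {
  // Stand-ins for the original loop body and the two inserted statements.
  std::list<std::string> body = {"linear", "allreduce"};

  auto old_begin = body.begin();
  // std::list::insert places the new element *before* the iterator and does
  // not invalidate it, so repeated inserts at the same position come out in
  // insertion order.
  body.insert(old_begin, "SetCurrentStream(worker)");
  body.insert(old_begin, "Synchronize(main)");

  for (const std::string& stmt : body) {
    std::cout << stmt << "\n";
  }
  // Prints:
  //   SetCurrentStream(worker)
  //   Synchronize(main)
  //   linear
  //   allreduce
}
```

This reproduces exactly the order shown in the host IR dump above.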
csrc/host_ir/assign_streams.cpp
Outdated
// We should check that the loop is stream-parallel. This is not necessary
// at this moment because all loops are stream-parallel. This is also hard
// to do because hir::ForLoop doesn't point to the source IterDomain.
To check my understanding: the IterDomain in ForLoop is not stream-parallel, unlike the IterDomain of the output loop domain?
Unlike kir::ForLoop, hir::ForLoop doesn't have the IterDomain as an input. Maybe it should.
Is it as simple as addInput(id) as is done for kir::ForLoop?
That's plausible. I'm debating with myself how much host IR should depend on nvFuser specifics like IterDomain. I could go with adding a ParallelType as an attribute. 🤷
  # benchmark iterations using
  # ```bash
- # mpirun -np 1 nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:<iterations> pytest tests/python/multidevice/<test_file>.py -k <filter> --only-mpi : -np <processes - 1> pytest tests/python/multidevice/<test_file>.py -k <filter> --only-mpi
+ # nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:<iterations> mpirun -np <processes> pytest tests/python/multidevice/<test_file>.py -k <filter> --only-mpi
Why is this update needed? The former comment gave me more consistent timings for communication than the updated comment.
For transformer forward tensor parallel on 8 H100:
Updated comment:
Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
-------- --------------- --------- -------- -------- -------- -------- ----------- ----------------------------------------------------------------------------------------------------
53.3 14580255 16 911265.9 693407.0 254975 2909549 795575.7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
14.2 3870002 17 227647.2 107807.0 104287 370304 133392.2 nvjet_sm90_tst_256x128_64x4_1x2_h_bz_coopA_TNT
11.0 3014100 8 376762.5 376510.0 370173 380863 3346.5 nvjet_sm90_tst_256x128_64x4_1x2_h_bz_coopA_bias_TNT
9.5 2591701 9 287966.8 287742.0 284573 291007 1958.7 nvjet_sm90_tst_192x192_64x4_2x1_v_bz_coopB_bias_TNN
4.1 1114011 9 123779.0 123839.0 122495 125152 920.8 void pytorch_flash::flash_fwd_kernel<Flash_fwd_kernel_traits<(int)128, (int)128, (int)32, (int)4, (…
4.0 1087676 8 135959.5 135807.5 134335 138272 1126.1 nvf::nvfuser_inner_persistent_f0_c1_r0_g11(nvf::Tensor<nvf::__bfloat, (int)1, (int)1>, nvf::Tensor<…
2.3 618591 8 77323.9 77424.0 76576 77600 344.5 nvf::nvfuser_pointwise_f0_c1_r0_g10(nvf::Tensor<nvf::__bfloat, (int)1, (int)1>, nvf::Tensor<nvf::__…
1.2 339453 9 37717.0 37536.0 37311 38368 362.7 nvf::nvfuser_inner_persistent_f0_c1_r0_g12(nvf::Tensor<nvf::__bfloat, (int)3, (int)3>, nvf::Tensor<…
0.5 126303 8 15787.9 15824.0 15616 15936 108.2 nvf::nvfuser_pointwise_f0_c1_r0_g7(nvf::Tensor<nvf::__bfloat, (int)3, (int)4>, nvf::Tensor<nvf::__b…
Previous comment:
Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
-------- --------------- --------- -------- -------- -------- -------- ----------- ----------------------------------------------------------------------------------------------------
31.1 688351 2 344175.5 344175.5 252928 435423 129043.5 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
21.2 469823 2 234911.5 234911.5 105760 364063 182647.8 nvjet_sm90_tst_256x128_64x4_1x2_h_bz_coopA_TNT
17.0 376127 1 376127.0 376127.0 376127 376127 0.0 nvjet_sm90_tst_256x128_64x4_1x2_h_bz_coopA_bias_TNT
13.1 291392 1 291392.0 291392.0 291392 291392 0.0 nvjet_sm90_tst_192x192_64x4_2x1_v_bz_coopB_bias_TNN
6.1 136064 1 136064.0 136064.0 136064 136064 0.0 nvf::nvfuser_inner_persistent_f0_c1_r0_g11(nvf::Tensor<nvf::__bfloat, (int)1, (int)1>, nvf::Tensor<…
5.6 123456 1 123456.0 123456.0 123456 123456 0.0 void pytorch_flash::flash_fwd_kernel<Flash_fwd_kernel_traits<(int)128, (int)128, (int)32, (int)4, (…
3.5 77695 1 77695.0 77695.0 77695 77695 0.0 nvf::nvfuser_pointwise_f0_c1_r0_g10(nvf::Tensor<nvf::__bfloat, (int)1, (int)1>, nvf::Tensor<nvf::__…
1.7 37632 1 37632.0 37632.0 37632 37632 0.0 nvf::nvfuser_inner_persistent_f0_c1_r0_g12(nvf::Tensor<nvf::__bfloat, (int)3, (int)3>, nvf::Tensor<…
0.7 15648 1 15648.0 15648.0 15648 15648 0.0 nvf::nvfuser_pointwise_f0_c1_r0_g7(nvf::Tensor<nvf::__bfloat, (int)3, (int)4>, nvf::Tensor<nvf::__b…
I had this issue here: #4844 (comment) and the former comment gave me the nsys profile which matched my wall-clock measurements most closely.
Have you noticed this discrepancy?
https://docs.nvidia.com/nsight-systems/UserGuide/index.html?utm_source=chatgpt.com#handling-application-launchers-mpirun-deepspeed-etc doesn't have a strong opinion between the two for single-node. For me, the updated command line is more convenient -- it's shorter and gives me the timing of all GPUs in one file so it's easier to process. But let me double check the reported timing is trustworthy.
My bad, let me clarify.
I was comparing the two comments in the diff:
- Updated command (same as 2 in the above image): `nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:<iterations> mpirun -np <processes> pytest tests/python/multidevice/<test_file>.py -k <filter> --only-mpi`. Collects data across all ranks.
- Previous command: `mpirun -np 1 nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:<iterations> pytest tests/python/multidevice/<test_file>.py -k <filter> --only-mpi : -np <processes - 1> pytest tests/python/multidevice/<test_file>.py -k <filter> --only-mpi`. Only one rank collects data.
The nsys profiles I pasted above correspond to these two commands.
While (2) in the above image was closer than (1) in the same image, the numbers for communication were still unstable compared to the previous command, where only one rank runs under nsys profile.
https://docs.nvidia.com/nsight-systems/UserGuide/index.html?utm_source=chatgpt.com#handling-application-launchers-mpirun-deepspeed-etc doesn't have a strong opinion between the two for single-node. For me, the updated command line is more convenient -- it's shorter and gives me the timing of all GPUs in one file so it's easier to process. But let me double check the reported timing is trustworthy.
I only got very high numbers for communication (higher than wall-time), everything else looked okay.
I tried three configs:
1. `nsys profile ... mpirun -np 8 ...`
2. `mpirun -np 8 nsys profile ...`
3. `mpirun -np 1 nsys profile ... : -np 7 ...`
I prefer config 1 overall because only config 1 tells me:
- The actual run time of either¹ allreduce operation is around 260us. Recall that kernels belonging to the same allreduce end at roughly the same time, so a kernel that starts early looks slower. Therefore, the actual run time of an allreduce operation, by my definition, is the run time of the fastest kernel corresponding to that allreduce operation.
- Despite the different start times, allreduce kernels across GPUs end at about the same time. The first allreduce ends at around 20,310us (e.g. on GPU 1, a start of 13,077,591ns plus a duration of 7,232,222ns gives roughly 20,310us) and the second at around 37,012us.

Config 2 leads to a large variance, and there's no way to align different GPUs because each runs a separate nsys profile. However, like config 1, it does show the 260us correctly.

Config 3 has the lowest variance, likely because the process being profiled tends to run slowest and launches its allreduce kernels last. However, even so, the first allreduce is shown to run for 441us, which is inaccurate.
Config 1
$ nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:1 mpirun -np 8 pytest tests/python/multidevice/test_transformer_engine.py -k 'test_transformer_layer[nonoverlap-tp-forward]' -vs --only-mpi
$ nsys stats report3.nsys-rep --report cuda_gpu_trace | grep nccl
13077591 7232222 1847 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (1) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
13364015 6946483 1847 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (2) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
13957463 6353057 1847 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (7) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
15182508 5127666 1847 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (6) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
15584437 4725800 1847 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (0) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
15749866 4560473 1847 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (4) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
16052423 4258616 1847 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (3) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
20044166 266689 1847 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (5) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
21276737 15735928 2066 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (6) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
21392976 15620182 2066 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (1) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
21410685 15601858 2068 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (3) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
22583001 14430372 2066 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (7) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
22691330 14321563 2066 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (0) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
24994300 12019087 2066 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (4) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
31899021 5114358 2066 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (2) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
36751238 261377 2072 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (5) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
Config 2
$ mpirun -np 8 nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:1 pytest tests/python/multidevice/test_transformer_engine.py -k 'test_transformer_layer[nonoverlap-tp-forward]' -vs --only-mpi
$ for i in {4..11}; do nsys stats report$i.nsys-rep --report cuda_gpu_trace; done | grep nccl
10424681 40915927 1847 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (2) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
52313758 1030942 2066 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (2) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
10484621 264926 1847 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (1) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
12491523 261631 2066 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (1) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
10588716 14568360 1847 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (3) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
26119026 1042526 2066 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (3) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
13385933 167946450 1847 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (6) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
182302274 1033604 2072 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (6) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
34146687 756926958 1847 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (4) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
792038764 1038974 2066 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (4) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
21311074 121867114 1847 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (7) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
144132973 1048896 2066 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (7) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
12634286 470272575 1847 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (5) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
483859345 1050852 2066 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (5) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
10716413 52947804 1847 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (0) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
64613978 1054210 2066 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (0) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
Config 3
$ mpirun -np 1 nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:1 pytest tests/python/multidevice/test_transformer_engine.py -k 'test_transformer_layer[nonoverlap-tp-forward]' -vs --only-mpi : -np 7 pytest tests/python/multidevice/test_transformer_engine.py -k 'test_transformer_layer[nonoverlap-tp-forward]' -vs --only-mpi
$ nsys stats report12.nsys-rep --report cuda_gpu_trace | grep nccl
10208430 441441 1847 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (0) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
12226226 270945 2066 16 1 1 640 1 1 96 0.037 0.082 NVIDIA H100 80GB HBM3 (0) 1 7 ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>)
(Note: you'll need to change report indices according to your generated files.)
Footnotes

¹ A TP transformer layer forward runs two allreduces, one in MHA and the other in MLP. They transfer data of the same size, b * s * h. ↩
Co-authored-by: Priya Mishra <[email protected]>
!test
// We should check that the loop is stream-parallel. This is not necessary
// at this moment because all loops are stream-parallel. This is also hard
// to do because hir::ForLoop doesn't point to the source IterDomain.
Do ALL hir::ForLoops stream-parallelize? Is there no case where we want to loop sequentially in hir? Or is this pass triggered by some other condition I'm not seeing?
Do ALL hir::ForLoops stream-parallelize?
Yes at this moment.
I'm considering separating ParallelType::Stream and ParallelType::HostSerial. The latter doesn't exist today. That's when we'll have to look at the parallel type of the loop index.
@pytest.mark.benchmark
def test_row_parallel_linear_forward_benchmark(multidevice_test, benchmark):
    # This is a port of CollectiveBasedOverlapTest.RowParallelLinear_Forward.
    h, s, t = 8192, 2, 8192
Consider making s a parameter to benchmark for s=2/4. For reference implementation, s=4 had better overlap.
Can you also add nsys profiles for s=4 in the PR description?
Done.
s=4 is overall slower than s=2 according to my updated results, but it seems to show better overlap as you said. Recall that there's overhead in decomposing matmuls.
Priya2698 left a comment
LGTM. Let me know if you plan to test merging the for-loops in this PR.
}

// We should check that the loop is stream-parallel. This is not necessary
// at this moment because all loops are stream-parallel. This is also hard
Not strictly for this PR, but similar to kir::ForLoop, hir::ForLoop can hold the source iterdomain for this check
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);

// After the loop: create a joining loop to synchronize all worker streams
Do you plan on merging this with the above for-loop?
I haven't convinced myself that will work: http://nv/e-d
Co-authored-by: Priya Mishra <[email protected]>
Greptile Overview
Greptile Summary
Implements stream parallelization for loops in nvFuser's host IR to enable overlapping computation and communication operations. Adds new AssignStreams optimization pass that transforms stream-parallel loops by capturing the main stream, assigning worker streams at loop iteration start with synchronization, and creating a joining loop afterward to synchronize all worker streams back. Includes comprehensive test coverage with benchmarks showing nvFuser slightly outperforms the PyTorch reference implementation (3.6ms vs 4.6ms mean).
Confidence Score: 4/5
- Safe to merge with minor improvement opportunities
- The implementation is well-structured with proper synchronization patterns, comprehensive tests, and correct integration into the host IR pipeline. The main concern is the acknowledged missing validation for loop stream-parallelization (line 31-33 in assign_streams.cpp), which could theoretically transform non-stream-parallel loops incorrectly, though the comment indicates all current loops are stream-parallel. The code follows established patterns, includes thorough test coverage, and demonstrates correct behavior via benchmarks and profiling.
- csrc/host_ir/assign_streams.cpp: consider adding a validation or assertion for the stream-parallel loop check
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| csrc/host_ir/assign_streams.cpp | 4/5 | Implements AssignStreams pass to transform stream-parallel loops with proper synchronization; well-structured but lacks validation for loop stream-parallelization |
| csrc/host_ir/passes.cpp | 5/5 | Integrates AssignStreams pass into host IR pipeline after AllocateAndDeallocate; correct ordering |
| tests/python/direct/test_stream.py | 5/5 | Comprehensive tests for stream parallelization with matmul operations; validates correct kernel count and shapes |
Sequence Diagram
sequenceDiagram
participant Main as Main Stream
participant W0 as Worker Stream 0
participant W1 as Worker Stream 1
participant W2 as Worker Stream 2
Note over Main: GetCurrentStream (capture main)
Note over Main: FOR i=0 to 3
Main->>W0: SetCurrentStream(0)
W0->>Main: Synchronize(main_stream)
Note over W0: Execute loop body (matmul/allreduce)
Main->>W1: SetCurrentStream(1)
W1->>Main: Synchronize(main_stream)
Note over W1: Execute loop body (matmul/allreduce)
Main->>W2: SetCurrentStream(2)
W2->>Main: Synchronize(main_stream)
Note over W2: Execute loop body (matmul/allreduce)
Note over Main: SetCurrentStream(main_stream)
Note over Main: FOR i=0 to 3 (joining loop)
Main->>W0: Synchronize(worker_stream_0)
Main->>W1: Synchronize(worker_stream_1)
Main->>W2: Synchronize(worker_stream_2)
Note over Main: All workers synchronized back
// We should check that the loop is stream-parallel. This is not necessary
// at this moment because all loops are stream-parallel. This is also hard
// to do because hir::ForLoop doesn't point to the source IterDomain.
The comment acknowledges this validation is skipped, but consider adding a TODO or assertion to track this technical debt. Without validation, non-stream-parallel loops could be incorrectly transformed, potentially leading to incorrect synchronization patterns. At minimum, add a NVF_CHECK that verifies the loop meets basic requirements (e.g., has a valid index, start, and stop).
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Greptile Overview
Greptile Summary
This PR implements stream parallelization for loops in nvFuser's host IR to enable overlapping of computation and communication operations. The core implementation adds a new AssignStreams optimization pass that transforms stream-parallel loops by assigning worker streams to each iteration and adding proper synchronization.
Key Changes:
- New AssignStreams pass in csrc/host_ir/assign_streams.{cpp,h} that transforms loops to use worker streams
- Integration into host IR pipeline via csrc/host_ir/passes.cpp
- Comprehensive test coverage in tests/python/multidevice/test_overlap.py with benchmarks
- Code cleanup: removed unnecessary includes in allocate_and_deallocate.h, ir.h, and internal_nodes.h
Transformation Pattern:
For each loop, the pass (see the CUDA-level sketch after this list):
- Captures the main stream before the loop
- At the start of each iteration: switches to a worker stream and synchronizes with the main stream
- After the loop: creates a joining loop that synchronizes all worker streams back to main
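A rough CUDA-runtime analogue of this pattern, as a hedged sketch (not the pass's actual output: runStreamParallelLoop is a hypothetical name, and kernel/collective launches are elided):

```cpp
#include <cuda_runtime.h>
#include <vector>

// Rough CUDA analogue of the transformed loop; n is the loop's trip count.
void runStreamParallelLoop(cudaStream_t main_stream, int n) {
  std::vector<cudaStream_t> workers(n);
  cudaEvent_t event;
  cudaEventCreateWithFlags(&event, cudaEventDisableTiming);

  // 1. Capture the main stream's progress before the loop (GetCurrentStream).
  cudaEventRecord(event, main_stream);

  // 2. Each iteration runs on its own worker stream, which first waits on
  //    the main stream (SetCurrentStream + Synchronize(main)).
  for (int i = 0; i < n; ++i) {
    cudaStreamCreate(&workers[i]);
    cudaStreamWaitEvent(workers[i], event, 0);
    // ... launch the loop body (matmul, async allreduce) on workers[i] ...
  }

  // 3. Joining loop: the main stream waits for every worker. Re-recording a
  //    single event is safe because cudaStreamWaitEvent snapshots the event's
  //    captured work at call time.
  for (int i = 0; i < n; ++i) {
    cudaEventRecord(event, workers[i]);
    cudaStreamWaitEvent(main_stream, event, 0);
  }

  cudaEventDestroy(event);
  for (cudaStream_t worker : workers) {
    cudaStreamDestroy(worker);
  }
}
```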
Issues Found:
- Copyright year is 2026 in both new files (should be 2025)
- Missing validation that loops are actually stream-parallel (acknowledged in code comment but not implemented)
The implementation correctly follows the stream synchronization pattern demonstrated in the PyTorch reference implementation. Benchmark results show nvFuser achieves slight performance improvements over the baseline.
Confidence Score: 4/5
- This PR is safe to merge with minor corrections needed for copyright years
- The implementation is technically sound with correct synchronization logic matching the reference implementation. The only actual errors are copyright year mistakes (2026 instead of 2025). The missing stream-parallel validation is acknowledged in comments and appears to be a known limitation rather than an oversight. Comprehensive tests provide good coverage.
- Pay attention to the copyright years in csrc/host_ir/assign_streams.cpp and csrc/host_ir/assign_streams.h, which need correction from 2026 to 2025
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| csrc/host_ir/assign_streams.cpp | 4/5 | New file implementing stream assignment pass for stream-parallel loops. Transforms loops to use worker streams with proper synchronization. Copyright year is incorrect (2026 instead of 2025). Logic is sound but lacks validation that loops are actually stream-parallel. |
| csrc/host_ir/assign_streams.h | 4/5 | New header file declaring AssignStreams optimization pass. Copyright year is incorrect (2026 instead of 2025). Clean interface design. |
| csrc/host_ir/passes.cpp | 5/5 | Added AssignStreams pass to the host IR pipeline after AllocateAndDeallocate. Simple integration, no issues found. |
| tests/python/multidevice/test_overlap.py | 5/5 | Added comprehensive tests for stream-parallelized row-parallel linear forward pass with benchmarks. Includes reference implementation matching the transformation pattern. Well-structured tests with proper validation. |
Sequence Diagram
sequenceDiagram
participant Main as Main Stream
participant Worker as Worker Stream i
participant Ops as Operations
Note over Main: Save main stream
Main->>Main: GetCurrentStream
Note over Main,Worker: Loop iteration starts
Main->>Worker: SetCurrentStream worker_stream i
Worker->>Main: Synchronize main_stream
Note over Worker: Wait for main work to complete
Worker->>Ops: Execute matmul
Worker->>Ops: Execute allreduce
Note over Worker: Work continues in parallel
Note over Main: After loop completes
Worker->>Main: SetCurrentStream main_stream
Note over Main: Join loop synchronizes workers
Main->>Worker: Synchronize worker_stream 0
Main->>Worker: Synchronize worker_stream 1
Main->>Worker: Synchronize worker_stream N
Note over Main: All workers synchronized
Note over Main: Continue on main stream
Greptile Overview
Greptile Summary
Overview
This PR implements stream parallelization for loops in nvFuser's host IR, enabling overlapping of computation (matmul) and communication (allreduce) operations to improve performance in distributed workloads.
Key Changes
New Stream Assignment Pass (csrc/host_ir/assign_streams.cpp):
- Transforms stream-parallel loops to use multiple CUDA streams
- Captures the main stream before the loop
- Sets worker streams at the beginning of each iteration with synchronization
- Creates a joining loop after the main loop to synchronize all worker streams back to the main stream
Integration:
- Added AssignStreams pass to the host IR optimization pipeline (runs after AllocateAndDeallocate)
- Updated CMakeLists.txt to include the new source file
Tests & Benchmarks:
- Refactored test_row_parallel_linear_forward to extract a reusable fusion definition
- Added parametrized benchmark test_row_parallel_linear_forward_benchmark with varying chunk counts (s=1,2,4)
- Includes a reference PyTorch implementation for validation
Code Cleanup:
- Removed unnecessary includes from header files
- Improved documentation for the Scope::insert method
- Fixed test function signatures
Architecture
The stream assignment creates a pattern where:
- Main stream is captured before the loop
- Each loop iteration executes on a dedicated worker stream
- Worker streams synchronize with the main stream before starting work
- After the main loop, a joining loop makes the main stream wait for all workers
This enables concurrent execution of loop iterations on different streams, particularly beneficial for overlapping matmul operations with collective communication.
Performance Impact
Benchmark results show slight improvement with overlapping (3.6ms vs 3.7-3.9ms for s=1), with the implementation verified via nsys profiling to correctly assign different streams to different operations.
Confidence Score: 2/5
- This PR contains a critical logic error that will cause incorrect stream synchronization behavior at runtime.
- Score reflects a critical bug in csrc/host_ir/assign_streams.cpp (lines 49-53) where the insertion order of SetCurrentStream and join_loop is inverted, causing the join loop to execute before setting the current stream to main_stream. This means synchronizations will happen on the wrong stream (last worker stream instead of main stream), breaking the intended stream parallelization semantics. Additionally, the copyright year is incorrect (2026 instead of 2025).
- csrc/host_ir/assign_streams.cpp requires immediate attention to fix the insertion order bug on lines 49-53. The expected Host IR comment in tests/python/multidevice/test_overlap.py (line 74-75) should also be updated to include the missing SetCurrentStream statement.
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| csrc/host_ir/assign_streams.cpp | 2/5 | New file implementing stream assignment pass. Contains critical logic error in insertion order (lines 49-53) causing incorrect stream synchronization sequence. Also has incorrect copyright year (2026). |
| csrc/host_ir/assign_streams.h | 5/5 | New header file declaring AssignStreams optimization pass. Clean interface, no issues detected. |
| csrc/host_ir/passes.cpp | 5/5 | Integrates new AssignStreams pass into pipeline. Correct ordering after AllocateAndDeallocate pass. |
| tests/python/multidevice/test_overlap.py | 4/5 | Refactors test into reusable function, adds benchmark test, updates to use outer_split instead of split. Expected Host IR comment may be missing SetCurrentStream statement between loops. |
Sequence Diagram
sequenceDiagram
participant Main as Main Stream
participant W0 as Worker Stream 0
participant W1 as Worker Stream 1
participant W2 as Worker Stream 2
Note over Main: GetCurrentStream
rect rgb(200, 220, 255)
Note over Main,W2: Main Loop (stream-parallel iterations)
Main->>W0: SetCurrentStream(worker_0)
W0->>Main: Synchronize(main_stream)
Note over W0: Compute chunk 0<br/>(matmul + allreduce)
Main->>W1: SetCurrentStream(worker_1)
W1->>Main: Synchronize(main_stream)
Note over W1: Compute chunk 1<br/>(matmul + allreduce)
Main->>W2: SetCurrentStream(worker_2)
W2->>Main: Synchronize(main_stream)
Note over W2: Compute chunk 2<br/>(matmul + allreduce)
end
Note over Main: SetCurrentStream(main_stream)
rect rgb(255, 220, 200)
Note over Main,W2: Join Loop (synchronize workers)
W0->>Main: Synchronize(worker_0)
W1->>Main: Synchronize(worker_1)
W2->>Main: Synchronize(worker_2)
end
Note over Main: Continue on main stream
@@ -0,0 +1,64 @@
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES.
The copyright year is set to 2026, which appears to be incorrect. This should likely be 2025 or earlier.
- * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES.
+ * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
hic->topLevel().insert(
    next_it, IrBuilder::create<SetCurrentStream>(main_stream));
auto* join_loop = IrBuilder::create<ForLoop>(
    for_loop->index(), for_loop->start(), for_loop->stop());
hic->topLevel().insert(next_it, join_loop);
Critical logic error: The insertion order will result in incorrect code generation. Both statements use insert(next_it, ...) which inserts BEFORE the iterator position. Since both use the same next_it, the join_loop (inserted second) will be placed BEFORE the SetCurrentStream (inserted first), resulting in:
FOR loop (original)
join_loop ← wrong position
SetCurrentStream ← wrong position
The correct order should be SetCurrentStream THEN join_loop, because:
- After the first FOR loop exits, the current stream is the last worker stream (not the main stream)
- The join_loop synchronizes worker streams on the current stream
- Therefore, we must set the current stream to main_stream BEFORE the join_loop
This matches the reference implementation row_parallel_linear_forward_reference in test_overlap.py (lines 182-194) where main_stream.wait_stream(worker_stream) happens on the main_stream.
The fix is to save the iterator returned from the first insert and use it for the second insert:
- hic->topLevel().insert(
-     next_it, IrBuilder::create<SetCurrentStream>(main_stream));
- auto* join_loop = IrBuilder::create<ForLoop>(
-     for_loop->index(), for_loop->start(), for_loop->stop());
- hic->topLevel().insert(next_it, join_loop);
+ // After the loop: create a joining loop to synchronize all worker streams
+ auto set_main_it = hic->topLevel().insert(
+     next_it, IrBuilder::create<SetCurrentStream>(main_stream));
+ auto* join_loop = IrBuilder::create<ForLoop>(
+     for_loop->index(), for_loop->start(), for_loop->stop());
+ hic->topLevel().insert(std::next(set_main_it), join_loop);
# FOR i535 from 0 to 3:
#   Synchronize Stream i535
The expected Host IR output in this comment appears to be missing a SetCurrentStream statement between the two FOR loops. Based on the implementation in assign_streams.cpp (lines 49-50), there should be a SetCurrentStream to Stream 0x174e5c80 before the second FOR loop.
The correct expected output should be:
# FOR i535 from 0 to 3:
# SetCurrentStream to Stream i535
# Synchronize Stream 0x174e5c80
# ... (computations)
# SetCurrentStream to Stream 0x174e5c80
# FOR i535 from 0 to 3:
# Synchronize Stream i535
This matches the reference implementation row_parallel_linear_forward_reference where the main stream waits for all worker streams after they complete.
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Greptile Overview
Greptile Summary
This PR implements stream parallelization for loops in nvFuser's host IR to enable overlapping of computation (matmul) and communication (allreduce) operations. The core implementation adds an AssignStreams optimization pass that transforms stream-parallel loops by capturing the main stream, assigning worker streams to each iteration with proper synchronization, and creating a joining loop to synchronize all worker streams back to the main stream.
Key Changes:
- New csrc/host_ir/assign_streams.{cpp,h} implementing the stream assignment pass
- Integration of the pass into the host IR optimization pipeline
- Refactored test structure with new benchmark tests
- Code cleanup removing unnecessary includes
Issues Found:
- The joining loop reuses the main loop's index variable, which could cause bugs if code after the loops depends on the index value
- Copyright year typo (2026 instead of 2025) in the new header file
Confidence Score: 3/5
- This PR has one moderate logic issue that should be addressed before merging
- The implementation is generally sound with proper synchronization primitives and good test coverage. However, there's a potential bug where the joining loop reuses the main loop's index variable (line 51 in assign_streams.cpp), which could cause issues if any code depends on the loop index value after the loop completes. This needs verification or fixing. The copyright year typo is a minor style issue.
- Pay close attention to csrc/host_ir/assign_streams.cpp, specifically the index variable reuse in the joining loop
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| csrc/host_ir/assign_streams.cpp | 3/5 | New file implementing stream parallelization pass. Contains potential bug with loop index reuse in joining loop. |
| csrc/host_ir/assign_streams.h | 4/5 | Header file with copyright year typo (2026 instead of 2025). |
| tests/python/multidevice/test_overlap.py | 4/5 | Refactored test to extract fusion definition, added benchmark test. Changed from torch.randint to torch.testing.make_tensor and torch.randn for test data generation. |
Sequence Diagram
sequenceDiagram
participant Main as Main Stream
participant Pass as AssignStreams Pass
participant ForLoop as Stream Parallel ForLoop
participant W0 as Worker Stream 0
participant W1 as Worker Stream 1
participant WN as Worker Stream N
Pass->>Main: GetCurrentStream
Note over Pass: Capture main stream before loop
Pass->>ForLoop: Transform loop structure
Note over ForLoop,WN: Main Loop Execution
ForLoop->>W0: SetCurrentStream worker 0
W0->>Main: Synchronize with main stream
W0->>W0: Execute iteration 0 body
ForLoop->>W1: SetCurrentStream worker 1
W1->>Main: Synchronize with main stream
W1->>W1: Execute iteration 1 body
ForLoop->>WN: SetCurrentStream worker N
WN->>Main: Synchronize with main stream
WN->>WN: Execute iteration N body
Pass->>Main: SetCurrentStream back to main
Note over ForLoop,WN: Joining Loop Execution
Main->>W0: Synchronize worker stream 0
Main->>W1: Synchronize worker stream 1
Main->>WN: Synchronize worker stream N
Note over Main: All workers joined to main stream
@@ -0,0 +1,26 @@
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES.
copyright year is 2026 (future year) - should be 2025
auto* join_loop = IrBuilder::create<ForLoop>(
    for_loop->index(), for_loop->start(), for_loop->stop());
reusing for_loop->index() for the joining loop could cause issues. The index variable from the main loop will be overwritten during the join loop execution. If any code after the loops depends on the loop index value, this could lead to bugs. Consider creating a new index variable for the joining loop.
!test
Greptile Overview
Greptile Summary
Implements stream parallelization for loops in nvFuser's host IR to enable overlapping of computation and communication operations. Adds an AssignStreams optimization pass that transforms loops by capturing the main stream, assigning worker streams to each iteration with proper synchronization, and creating a joining loop to synchronize all worker streams back to the main stream. Includes comprehensive test coverage with benchmarks showing slight performance improvements from overlapping.
Confidence Score: 4/5
- Safe to merge with minor style improvement suggested for iterator handling
- The implementation correctly achieves stream parallelization with proper synchronization semantics as verified by tests. The core logic is sound: capturing the main stream, assigning worker streams per iteration, and synchronizing back via a joining loop. One style issue noted regarding iterator invalidation (lines 44-46 in assign_streams.cpp) that works in practice but could be clearer. No functional bugs or security concerns identified.
- csrc/host_ir/assign_streams.cpp for iterator handling style improvement
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| csrc/host_ir/assign_streams.cpp | 4/5 | Adds stream assignment pass to parallelize loops across streams for overlapping computation and communication |
| csrc/host_ir/passes.cpp | 5/5 | Integrates AssignStreams pass into the host IR optimization pipeline |
| tests/python/multidevice/test_overlap.py | 4/5 | Refactors row parallel linear forward test, adds benchmark tests and updates expected host IR comments |
Sequence Diagram
sequenceDiagram
participant Main as Main Stream
participant W0 as Worker Stream 0
participant W1 as Worker Stream 1
participant W2 as Worker Stream 2
Note over Main: GetCurrentStream
rect rgb(200, 220, 255)
Note over Main,W2: Main Loop (3 iterations)
Main->>W0: SetCurrentStream(0)
W0->>Main: Synchronize(main)
Note over W0: Compute + AllReduce
par Parallel Execution
Main->>W1: SetCurrentStream(1)
W1->>Main: Synchronize(main)
Note over W1: Compute + AllReduce
and
Note over W0: Continue async work
end
par Parallel Execution
Main->>W2: SetCurrentStream(2)
W2->>Main: Synchronize(main)
Note over W2: Compute + AllReduce
and
Note over W0,W1: Continue async work
end
end
rect rgb(255, 220, 200)
Note over Main,W2: Joining Loop (synchronize workers)
Main->>W0: Synchronize(worker 0)
Main->>W1: Synchronize(worker 1)
Main->>W2: Synchronize(worker 2)
end
Main->>Main: SetCurrentStream(main)
auto old_begin = for_loop->body().exprs().begin();
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);
The insertion order places sync_main before set_stream due to how insert works with old_begin. After the first insert at old_begin, set_stream is inserted. When the second insert happens at the same old_begin (which still points to the original beginning), sync_main gets inserted before set_stream.
This creates the sequence: sync_main → set_stream → rest of body. However, the correct order should be set_stream → sync_main because you need to set the worker stream first before synchronizing with the main stream on that worker stream.
- auto old_begin = for_loop->body().exprs().begin();
- for_loop->body().insert(old_begin, set_stream);
- for_loop->body().insert(old_begin, sync_main);
+ auto* worker_stream = IrBuilder::create<Stream>(for_loop->index());
+ auto* set_stream = IrBuilder::create<SetCurrentStream>(worker_stream);
+ auto* sync_main = IrBuilder::create<Synchronize>(main_stream);
+ auto old_begin = for_loop->body().exprs().begin();
+ for_loop->body().insert(old_begin, set_stream);
+ auto new_begin = for_loop->body().exprs().begin();
+ for_loop->body().insert(std::next(new_begin), sync_main);
auto old_begin = for_loop->body().exprs().begin();
for_loop->body().insert(old_begin, set_stream);
for_loop->body().insert(old_begin, sync_main);
Using old_begin after the first insert relies on iterator invalidation rules that may lead to undefined behavior. While this works in practice (confirmed by tests), consider capturing the return value for clarity:
- auto old_begin = for_loop->body().exprs().begin();
- for_loop->body().insert(old_begin, set_stream);
- for_loop->body().insert(old_begin, sync_main);
+ auto old_begin = for_loop->body().exprs().begin();
+ auto it = for_loop->body().insert(old_begin, set_stream);
+ for_loop->body().insert(std::next(it), sync_main);
This makes the intent clearer and avoids potential issues with iterator invalidation.

Fixes #5308
Overlapping improves the wall time slightly.
Stream assignment and overlapping are verified by the following:
The performance is suboptimal for two reasons: