@@ -224,11 +224,12 @@ void sm100_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType
StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1));
StrideC stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(0, m, 1));

ElementAccumulator output_op_beta = (biases == nullptr) ? ElementAccumulator(0.f) : ElementAccumulator(1.f);
typename Gemm::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, {n, m, k, 1},
{reinterpret_cast<CutlassWeightType const*>(B), stride_B, reinterpret_cast<CutlassActivationType const*>(A),
stride_A, reinterpret_cast<ElementScale const*>(weight_scales), layout_S, group_size,
reinterpret_cast<ElementZero const*>(weight_zero_points)},
{{alpha}, reinterpret_cast<CutlassBiasType const*>(biases), stride_C,
{{alpha, output_op_beta}, reinterpret_cast<CutlassBiasType const*>(biases), stride_C,
reinterpret_cast<CutlassOutputType*>(C), stride_D}};

Gemm gemm;
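
Note on the new beta term (an inference from standard CUTLASS usage, not stated in the diff): the epilogue thread arguments {alpha, output_op_beta} feed the usual linear combination D = alpha * acc + beta * C, so beta is set to 0 when biases == nullptr to avoid reading the null bias operand, and to 1 when a bias row is supplied through stride_C.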
102 changes: 70 additions & 32 deletions examples/disaggregated/slurm/benchmark/submit.py
@@ -84,22 +84,26 @@ def assign_server(server_allocation: Dict[str, Any], world_size: int,
server_allocation["nodes"][hostname].append(gpu_id)
global_gpu_cursor += 1

port = base_port

def assign_servers(
server_allocations: Dict[str, Any],
server_type: str,
num_servers: int,
world_size: int,
gpus_per_node: int,
):
nonlocal port
if server_type not in server_allocations:
server_allocations[server_type] = {}
for i in range(num_servers):
server_allocation = {
"port": base_port + i,
"port": port,
"nodes": {},
}
assign_server(server_allocation, world_size, gpus_per_node)
server_allocations[server_type][i] = server_allocation
port += 1

assign_servers(allocations, "GEN", num_gen_servers, gen_world_size,
gpus_per_node)
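
A minimal sketch of the shared port counter (hypothetical values; the CTX allocation call sits outside this hunk, so its exact form below is an assumption):

# base_port = 8000, num_gen_servers = 2, one CTX server
# assign_servers(allocations, "GEN", 2, gen_world_size, gpus_per_node)  # GEN ports: 8000, 8001
# assign_servers(allocations, "CTX", 1, ctx_world_size, gpus_per_node)  # CTX port: 8002
# Each call advances the same nonlocal `port`, so servers of different types never collide.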
@@ -142,6 +146,18 @@ def convert_allocations_to_server_config(allocations, server_port=8333):
return server_config


def upsert_env_config(env_config, config_key, key_name, value_str):
"""Upsert env var into env_config key.

Replaces existing entry for the same key name, or prepends if not present.
"""
parts = [
part for part in env_config.get(config_key, '').split()
if not part.startswith(f"{key_name}=")
]
env_config[config_key] = " ".join([value_str, *parts]).strip()
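
A quick usage sketch of the helper (hypothetical values, run on a plain dict):

env_config = {"worker_env_var": "FOO=1 TLLM_NVTX_DEBUG=0"}
upsert_env_config(env_config, "worker_env_var", "TLLM_NVTX_DEBUG", "TLLM_NVTX_DEBUG=1")
# env_config["worker_env_var"] == "TLLM_NVTX_DEBUG=1 FOO=1"  (runtime value wins, user's FOO=1 kept)
upsert_env_config(env_config, "gen_worker_env_var", "TLLM_PROFILE_START_STOP", "TLLM_PROFILE_START_STOP=100-150")
# env_config["gen_worker_env_var"] == "TLLM_PROFILE_START_STOP=100-150"  (missing keys are created)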


def convert_envs_to_str(env_vars: Dict[str, str]) -> str:
return ','.join([f"{key}='{value}'" for key, value in env_vars.items()])

@@ -195,18 +211,52 @@ def build_worker_environment(worker_config, env_config, role, benchmark_mode,
"""
env = {}

# 1. Use gpu_ids to set CUDA_VISIBLE_DEVICES
# 1. Add mode-based env vars to env_config
if benchmark_mode == "gen_only_no_context":
upsert_env_config(env_config, 'worker_env_var',
'TRTLLM_DISAGG_BENCHMARK_GEN_ONLY',
'TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1')
if benchmark_mode == "gen_only":
upsert_env_config(env_config, 'worker_env_var',
'TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP',
'TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP=1')
if role == "GEN":
upsert_env_config(env_config, 'gen_worker_env_var',
'TLLM_BENCHMARK_REQ_QUEUES_SIZE',
f'TLLM_BENCHMARK_REQ_QUEUES_SIZE={concurrency}')

# 2. Add profiling env vars to env_config (conditional)
if nsys_on:
upsert_env_config(env_config, 'worker_env_var',
'TLLM_PROFILE_RECORD_GC', 'TLLM_PROFILE_RECORD_GC=1')
upsert_env_config(env_config, 'worker_env_var', 'TLLM_NVTX_DEBUG',
'TLLM_NVTX_DEBUG=1')
upsert_env_config(env_config, 'worker_env_var',
'NSYS_MPI_STORE_TEAMS_PER_RANK',
'NSYS_MPI_STORE_TEAMS_PER_RANK=1')
if role == "CTX":
upsert_env_config(env_config, 'ctx_worker_env_var',
'TLLM_PROFILE_START_STOP',
f'TLLM_PROFILE_START_STOP={profile_range}')
elif role == "GEN":
upsert_env_config(env_config, 'gen_worker_env_var',
'TLLM_PROFILE_START_STOP',
f'TLLM_PROFILE_START_STOP={profile_range}')

# 3. Set CUDA_VISIBLE_DEVICES from gpu_ids
cuda_devices = ','.join(map(str, gpu_ids))
env["CUDA_VISIBLE_DEVICES"] = cuda_devices

# 2. Parse user-defined worker env vars from config
# 4. Parse user-defined worker env vars from config
# (now includes mode-based and profiling vars from steps 1-2)
worker_env_var = env_config.get('worker_env_var', '')
for var_string in worker_env_var.split():
if '=' in var_string:
key, val = var_string.split('=', 1)
env[key] = val

# 3. Add role-specific env vars (CTX or GEN)
# 5. Add role-specific env vars (CTX or GEN)
# (now includes role-specific mode/profiling vars from steps 1-2)
role_env_vars = {
"CTX": env_config.get('ctx_worker_env_var', ''),
"GEN": env_config.get('gen_worker_env_var', '')
@@ -217,21 +267,6 @@ def build_worker_environment(worker_config, env_config, role, benchmark_mode,
key, val = var_string.split('=', 1)
env[key] = val

# 4. Add mode-based env vars
if benchmark_mode == "gen_only_no_context":
env["TRTLLM_DISAGG_BENCHMARK_GEN_ONLY"] = "1"
if benchmark_mode == "gen_only":
env["TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP"] = "1"
if role == "GEN":
env["TLLM_BENCHMARK_REQ_QUEUES_SIZE"] = str(concurrency)

# 5. Add profiling env vars (conditional)
if nsys_on:
env["TLLM_PROFILE_RECORD_GC"] = "1"
env["TLLM_NVTX_DEBUG"] = "1"
env["NSYS_MPI_STORE_TEAMS_PER_RANK"] = "1"
env["TLLM_PROFILE_START_STOP"] = profile_range

return env
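
An illustrative walk-through of the reordered steps (hypothetical inputs):

# role="GEN", benchmark_mode="gen_only", concurrency=64, nsys_on=False, gpu_ids=[0, 1],
# env_config = {"worker_env_var": "NCCL_DEBUG=WARN", "gen_worker_env_var": ""}
# Step 1 upserts into env_config, steps 3-5 then build env:
#   {"CUDA_VISIBLE_DEVICES": "0,1",
#    "TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP": "1", "NCCL_DEBUG": "WARN",
#    "TLLM_BENCHMARK_REQ_QUEUES_SIZE": "64"}
# Because the vars land in env_config first, the later save_env_file call also records them.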


@@ -247,17 +282,19 @@ def build_server_environment(env_config, benchmark_mode):
"""
env = {}

# Parse user-defined server env vars
# Add mode-based env vars to env_config
if benchmark_mode == "gen_only_no_context":
upsert_env_config(env_config, 'server_env_var',
'TRTLLM_DISAGG_BENCHMARK_GEN_ONLY',
'TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1')

# Parse user-defined server env vars (now includes mode-based vars)
server_env_var = env_config.get('server_env_var', '')
for var_string in server_env_var.split():
if '=' in var_string:
key, val = var_string.split('=', 1)
env[key] = val

# Add mode-based env vars
if benchmark_mode == "gen_only_no_context":
env["TRTLLM_DISAGG_BENCHMARK_GEN_ONLY"] = "1"

return env


@@ -429,14 +466,6 @@ def submit_job(config, log_dir, dry_run):
os.makedirs(log_dir, exist_ok=True)
print(f"Log will be saved to: {log_dir}")

# Save environment variables (for record-keeping only)
worker_env_var = env_config.get('worker_env_var', '')
ctx_worker_env_var = env_config.get('ctx_worker_env_var', '')
gen_worker_env_var = env_config.get('gen_worker_env_var', '')
server_env_var = env_config.get('server_env_var', '')
save_env_file(os.path.join(log_dir, "env_vars.json"), server_env_var,
worker_env_var, ctx_worker_env_var, gen_worker_env_var)

# Setup config file paths and save worker configs
ctx_config_path = os.path.join(log_dir, 'ctx_config.yaml')
gen_config_path = os.path.join(log_dir, 'gen_config.yaml')
@@ -546,6 +575,15 @@ def submit_job(config, log_dir, dry_run):
]
start_server_cmds.append(" ".join(cmd))

# Save env vars only after the worker/server env builds so env_vars.json includes runtime-added vars
save_env_file(
os.path.join(log_dir, "env_vars.json"),
env_config.get('server_env_var', ''),
env_config.get('worker_env_var', ''),
env_config.get('ctx_worker_env_var', ''),
env_config.get('gen_worker_env_var', ''),
)

# Generate wait server command (use script_dir for wait_server.sh)
cmd = [
"srun -l",
@@ -45,13 +45,14 @@ def __call__(
active_slots = [[]]
generation_steps = []
logits_vec = [[]]
for i, r in enumerate(scheduled_requests.context_requests):
if r.is_last_context_chunk:
active_slots[0].append(r.py_seq_slot)
generation_steps.append(r.decoding_iter)
logits_vec[0].append(
logits[num_context_logits_prefix_sum[i]:
num_context_logits_prefix_sum[i + 1]].unsqueeze(0))
for i, r in enumerate(
scheduled_requests.context_requests_last_chunk,
start=len(scheduled_requests.context_requests_chunking)):
active_slots[0].append(r.py_seq_slot)
generation_steps.append(r.decoding_iter)
logits_vec[0].append(
logits[num_context_logits_prefix_sum[i]:
num_context_logits_prefix_sum[i + 1]].unsqueeze(0))

logits_index = num_context_logits_prefix_sum[-1]
for i, r in enumerate(scheduled_requests.generation_requests):
11 changes: 4 additions & 7 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1535,7 +1535,7 @@ def _ring_broadcast_sample_state(
tag = PPCommTag.SAMPLE_STATE
microbatch_id = executed_batch.microbatch_id
sample_state = executed_batch.sample_state
requests = sample_state.scheduled_requests.all_requests()
requests = sample_state.requests

if not self.dist.is_last_pp_rank:
# Receive tokens from previous pp rank (w.r.t model forward direction)
@@ -1570,13 +1570,9 @@ def _handle_executed_batch(self, executed_batch: Optional[BatchStatePP]):
finished_requests = []
if executed_batch is not None:
with torch.cuda.nvtx.range("_handle_executed_batch_pp"):
scheduled_requests = executed_batch.scheduled_requests
sampling_requests = ScheduledRequests()
sampling_requests.context_requests_last_chunk = scheduled_requests.context_requests_last_chunk
sampling_requests.generation_requests = scheduled_requests.generation_requests
executed_batch.sample_state.scheduled_requests = sampling_requests
self._update_requests(executed_batch.sample_state)

scheduled_requests = executed_batch.scheduled_requests
if self.kv_cache_transceiver:
finished_ctx_reqs = scheduled_requests.context_requests_last_chunk
self._send_kv_async(finished_ctx_reqs)
@@ -2369,8 +2365,9 @@ def _forward_step_inter_pp(self, scheduled_batch) -> SampleState:
sampler_event = torch.cuda.Event()
sampler_event.record()
self._update_request_states(scheduled_batch)
sampling_requests = scheduled_batch.context_requests_last_chunk + scheduled_batch.generation_requests
return self.sampler.SampleState(
scheduled_requests=scheduled_batch,
requests=sampling_requests,
sampler_event=SamplerEvent(cuda_event=sampler_event),
runtime_draft_len=self.model_engine.runtime_draft_len,
)
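
A rough sketch of the data flow these hunks rely on (a reading of the diff; the surrounding SampleState definition is assumed rather than shown): the sample state carries the exact request list it was built from, so downstream consumers use it directly instead of re-deriving it:

# sampling_requests = scheduled_batch.context_requests_last_chunk + scheduled_batch.generation_requests
# sample_state = self.sampler.SampleState(requests=sampling_requests, ...)
# ...
# requests = sample_state.requests  # e.g. in _ring_broadcast_sample_state, no trimmed ScheduledRequests needed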