diff --git a/configs/disagg/wan_t2v_disagg_controller.json b/configs/disagg/wan_t2v_disagg_controller.json new file mode 100644 index 000000000..d55badf24 --- /dev/null +++ b/configs/disagg/wan_t2v_disagg_controller.json @@ -0,0 +1,26 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": false, + "fps": 16, + "disagg_mode": "controller", + "disagg_config": { + "bootstrap_addr": "127.0.0.1", + "bootstrap_room": 0, + "encoder_engine_rank": 0, + "transformer_engine_rank": 1, + "decoder_engine_rank": 2, + "protocol": "rdma", + "local_hostname": "localhost", + "metadata_server": "P2PHANDSHAKE" + } +} diff --git a/configs/disagg/wan_t2v_disagg_decoder.json b/configs/disagg/wan_t2v_disagg_decoder.json new file mode 100644 index 000000000..a778bcfaa --- /dev/null +++ b/configs/disagg/wan_t2v_disagg_decoder.json @@ -0,0 +1,26 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": false, + "fps": 16, + "disagg_mode": "decoder", + "disagg_config": { + "bootstrap_addr": "127.0.0.1", + "bootstrap_room": 0, + "encoder_engine_rank": 0, + "transformer_engine_rank": 1, + "decoder_engine_rank": 2, + "protocol": "rdma", + "local_hostname": "localhost", + "metadata_server": "P2PHANDSHAKE" + } +} diff --git a/configs/disagg/wan_t2v_disagg_encoder.json b/configs/disagg/wan_t2v_disagg_encoder.json new file mode 100644 index 000000000..57adbbd41 --- /dev/null +++ b/configs/disagg/wan_t2v_disagg_encoder.json @@ -0,0 +1,26 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": false, + "fps": 16, + "disagg_mode": "encoder", + "disagg_config": { + "bootstrap_addr": "127.0.0.1", + "bootstrap_room": 0, + "encoder_engine_rank": 0, + "transformer_engine_rank": 1, + "decoder_engine_rank": 2, + "protocol": "rdma", + "local_hostname": "localhost", + "metadata_server": "P2PHANDSHAKE" + } +} diff --git a/configs/disagg/wan_t2v_disagg_transformer.json b/configs/disagg/wan_t2v_disagg_transformer.json new file mode 100644 index 000000000..af6bf4c02 --- /dev/null +++ b/configs/disagg/wan_t2v_disagg_transformer.json @@ -0,0 +1,26 @@ +{ + "infer_steps": 50, + "target_video_length": 81, + "text_len": 512, + "target_height": 480, + "target_width": 832, + "self_attn_1_type": "sage_attn2", + "cross_attn_1_type": "sage_attn2", + "cross_attn_2_type": "sage_attn2", + "sample_guide_scale": 5, + "sample_shift": 5, + "enable_cfg": true, + "cpu_offload": false, + "fps": 16, + "disagg_mode": "transformer", + "disagg_config": { + "bootstrap_addr": "127.0.0.1", + "bootstrap_room": 0, + "encoder_engine_rank": 0, + "transformer_engine_rank": 1, + "decoder_engine_rank": 2, + "protocol": "rdma", + "local_hostname": "localhost", + "metadata_server": "P2PHANDSHAKE" + } +} diff --git a/lightx2v/disagg/conn.py b/lightx2v/disagg/conn.py index 2ae6ba863..cd592de7d 100644 --- a/lightx2v/disagg/conn.py +++ b/lightx2v/disagg/conn.py @@ -75,6 +75,7 @@ class DataPoll: RequestPoolType = Dict[int, List[int]] WaitingPoolType = Dict[int, Tuple[str, list[int]]] +MONITOR_POLLING_PORT = 7788 REQUEST_POLLING_PORT = 12788 DATASENDER_POLLING_PORT = 17788 DATARECEIVER_POLLING_PORT = 27788 diff --git a/lightx2v/disagg/examples/run_service.py b/lightx2v/disagg/examples/run_service.py new file mode 100644 index 000000000..3c1659284 --- /dev/null +++ b/lightx2v/disagg/examples/run_service.py @@ -0,0 +1,129 @@ +import argparse +import json +import logging + +from loguru import logger + +from lightx2v.disagg.services.controller import ControllerService +from lightx2v.disagg.services.decoder import DecoderService +from lightx2v.disagg.services.encoder import EncoderService +from lightx2v.disagg.services.transformer import TransformerService +from lightx2v.disagg.utils import set_config +from lightx2v.utils.utils import seed_all + +logging.basicConfig(level=logging.INFO) + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Run a disaggregated LightX2V service process") + parser.add_argument("--model_cls", type=str, default="wan2.1") + parser.add_argument("--task", type=str, default="t2v") + parser.add_argument("--model_path", type=str, required=True) + parser.add_argument("--config_json", type=str, required=True) + + parser.add_argument("--seed", type=int, default=42) + parser.add_argument( + "--prompt", + type=str, + default="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default=( + "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰," + "最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部," + "畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + ), + ) + parser.add_argument( + "--save_result_path", + type=str, + default="/root/zht/LightX2V/save_results/test_disagg.mp4", + ) + + parser.add_argument( + "--service", + type=str, + choices=["encoder", "transformer", "decoder", "controller", "auto"], + default="auto", + help="Service role. auto = infer from config_json.disagg_mode", + ) + return parser + + +def _normalize_disagg_config(config: dict) -> dict: + disagg_cfg = config.get("disagg_config") + if isinstance(disagg_cfg, dict): + mapping = { + "bootstrap_addr": "data_bootstrap_addr", + "bootstrap_room": "data_bootstrap_room", + "encoder_engine_rank": "encoder_engine_rank", + "transformer_engine_rank": "transformer_engine_rank", + "decoder_engine_rank": "decoder_engine_rank", + "protocol": "protocol", + "local_hostname": "local_hostname", + "metadata_server": "metadata_server", + } + for src_key, dst_key in mapping.items(): + if src_key in disagg_cfg: + config[dst_key] = disagg_cfg[src_key] + return config + + +def _load_raw_json(path: str) -> dict: + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + + +def _resolve_service_mode(args: argparse.Namespace, raw_cfg: dict) -> str: + if args.service != "auto": + return args.service + mode = raw_cfg.get("disagg_mode") + if mode in {"encoder", "transformer", "decoder", "controller"}: + return mode + raise ValueError("Cannot resolve service mode: use --service or set disagg_mode in config_json") + + +def _build_runtime_config(args: argparse.Namespace) -> tuple[dict, dict]: + raw_cfg = _load_raw_json(args.config_json) + + config = set_config( + model_path=args.model_path, + task=args.task, + model_cls=args.model_cls, + config_path=args.config_json, + ) + + config = _normalize_disagg_config(config) + raw_cfg = _normalize_disagg_config(raw_cfg) + + config["seed"] = args.seed + config["prompt"] = args.prompt + config["negative_prompt"] = args.negative_prompt + config["save_path"] = args.save_result_path + return config, raw_cfg + + +def main(): + args = _build_parser().parse_args() + config, raw_cfg = _build_runtime_config(args) + service_mode = _resolve_service_mode(args, raw_cfg) + + seed_all(args.seed) + logger.info("Starting disagg service mode={}", service_mode) + + if service_mode == "encoder": + EncoderService(config).run() + elif service_mode == "transformer": + TransformerService(config).run() + elif service_mode == "decoder": + DecoderService(config).run() + elif service_mode == "controller": + ControllerService().run(config) + else: + raise ValueError(f"Unsupported service mode: {service_mode}") + + +if __name__ == "__main__": + main() diff --git a/lightx2v/disagg/examples/wan_i2v_service.py b/lightx2v/disagg/examples/wan_i2v_service.py index 5e47c2407..c5030d669 100644 --- a/lightx2v/disagg/examples/wan_i2v_service.py +++ b/lightx2v/disagg/examples/wan_i2v_service.py @@ -84,21 +84,21 @@ def main(): # 3. Define service threads def run_encoder(): logger.info("Initializing Encoder Service...") - encoder_service = EncoderService() + encoder_service = EncoderService(config) logger.info("Running Encoder Service...") encoder_service.run(stop_event=encoder_stop_event) logger.info("Encoder Service completed.") def run_transformer(): logger.info("Initializing Transformer Service...") - transformer_service = TransformerService() + transformer_service = TransformerService(config) logger.info("Running Transformer Service...") transformer_service.run(stop_event=transformer_stop_event) logger.info("Transformer Service completed.") def run_decoder(): logger.info("Initializing Decoder Service...") - decoder_service = DecoderService() + decoder_service = DecoderService(config) logger.info("Running Decoder Service...") decoder_service.run(stop_event=decoder_stop_event) logger.info("Video generation completed.") diff --git a/lightx2v/disagg/examples/wan_t2v_service.py b/lightx2v/disagg/examples/wan_t2v_service.py index c0c5dd833..80833dd6e 100644 --- a/lightx2v/disagg/examples/wan_t2v_service.py +++ b/lightx2v/disagg/examples/wan_t2v_service.py @@ -67,21 +67,21 @@ def main(): # 2. Define service threads def run_encoder(): logger.info("Initializing Encoder Service...") - encoder_service = EncoderService() + encoder_service = EncoderService(config) logger.info("Running Encoder Service...") encoder_service.run(stop_event=encoder_stop_event) logger.info("Encoder Service completed.") def run_transformer(): logger.info("Initializing Transformer Service...") - transformer_service = TransformerService() + transformer_service = TransformerService(config) logger.info("Running Transformer Service...") transformer_service.run(stop_event=transformer_stop_event) logger.info("Transformer Service completed.") def run_decoder(): logger.info("Initializing Decoder Service...") - decoder_service = DecoderService() + decoder_service = DecoderService(config) logger.info("Running Decoder Service...") decoder_service.run(stop_event=decoder_stop_event) logger.info("Video generation completed.") diff --git a/lightx2v/disagg/monitor.py b/lightx2v/disagg/monitor.py new file mode 100644 index 000000000..d29c7136b --- /dev/null +++ b/lightx2v/disagg/monitor.py @@ -0,0 +1,139 @@ +import logging +import subprocess +import threading +import time +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional + +import zmq + +logger = logging.getLogger(__name__) + + +@dataclass +class ReporterConfig: + service_type: str + gpu_id: int + bind_address: str + + +class Reporter: + def __init__(self, service_type: str, gpu_id: int, bind_address: str): + self.config = ReporterConfig( + service_type=service_type, + gpu_id=gpu_id, + bind_address=bind_address, + ) + self._context = zmq.Context.instance() + self._stop_event = threading.Event() + + def _query_gpu_metrics(self) -> Dict[str, Any]: + cmd = [ + "nvidia-smi", + "--query-gpu=utilization.gpu,memory.used,memory.total", + "--format=csv,noheader,nounits", + "-i", + str(self.config.gpu_id), + ] + out = subprocess.check_output(cmd, stderr=subprocess.STDOUT, text=True).strip() + if not out: + raise RuntimeError("nvidia-smi returned empty output") + + util_str, mem_used_str, mem_total_str = [x.strip() for x in out.split(",")] + return { + "gpu_utilization": float(util_str), + "gpu_memory_used_mb": float(mem_used_str), + "gpu_memory_total_mb": float(mem_total_str), + } + + def get_metrics(self) -> Dict[str, Any]: + metrics: Dict[str, Any] = { + "service_type": self.config.service_type, + "gpu_id": self.config.gpu_id, + "timestamp": time.time(), + } + try: + metrics.update(self._query_gpu_metrics()) + metrics["status"] = "ok" + except Exception as exc: + metrics["status"] = "error" + metrics["error"] = str(exc) + return metrics + + def serve_forever(self): + socket = self._context.socket(zmq.REP) + socket.linger = 0 + socket.bind(self.config.bind_address) + logger.info("Reporter started: service=%s gpu=%s bind=%s", self.config.service_type, self.config.gpu_id, self.config.bind_address) + + try: + while not self._stop_event.is_set(): + if socket.poll(timeout=500) == 0: + continue + try: + req = socket.recv_json() + except Exception: + socket.send_json({"status": "error", "error": "invalid request"}) + continue + + cmd = req.get("cmd", "metrics") if isinstance(req, dict) else "metrics" + if cmd == "metrics": + socket.send_json(self.get_metrics()) + else: + socket.send_json({"status": "error", "error": f"unsupported cmd: {cmd}"}) + finally: + socket.close() + + def stop(self): + self._stop_event.set() + + +class Monitor: + def __init__(self, nodes: List[str], request_timeout_ms: int = 1000): + self.nodes = nodes + self.request_timeout_ms = request_timeout_ms + self._context = zmq.Context.instance() + + def _poll_one(self, address: str) -> Dict[str, Any]: + socket = self._context.socket(zmq.REQ) + socket.linger = 0 + socket.rcvtimeo = self.request_timeout_ms + socket.sndtimeo = self.request_timeout_ms + socket.connect(address) + try: + socket.send_json({"cmd": "metrics"}) + resp = socket.recv_json() + if not isinstance(resp, dict): + return { + "status": "error", + "address": address, + "error": "invalid response type", + } + result = dict(resp) + result["address"] = address + return result + except Exception as exc: + return { + "status": "error", + "address": address, + "error": str(exc), + } + finally: + socket.close() + + def poll_once(self) -> List[Dict[str, Any]]: + return [self._poll_one(address) for address in self.nodes] + + def run_forever( + self, + interval_seconds: float = 5.0, + callback: Optional[Callable[[List[Dict[str, Any]]], None]] = None, + stop_event: Optional[threading.Event] = None, + ): + while True: + if stop_event is not None and stop_event.is_set(): + break + results = self.poll_once() + if callback is not None: + callback(results) + time.sleep(interval_seconds) diff --git a/lightx2v/disagg/rdma_buffer.py b/lightx2v/disagg/rdma_buffer.py new file mode 100644 index 000000000..2a26a69e0 --- /dev/null +++ b/lightx2v/disagg/rdma_buffer.py @@ -0,0 +1,282 @@ +from __future__ import annotations + +import ctypes +import json +import logging +import threading +from dataclasses import dataclass +from typing import Any, Dict, Optional + +from lightx2v.disagg.rdma_client import RDMAClient +from lightx2v.disagg.rdma_server import RDMAServer + +logger = logging.getLogger(__name__) + + +@dataclass +class RDMABufferDescriptor: + slot_addr: int + slot_bytes: int + slot_size: int + buffer_size: int + head_addr: int + tail_addr: int + rkey: int = 0 + head_bytes: int = 8 + tail_bytes: int = 8 + + +class RDMABuffer: + """Ring buffer backed by RDMA-accessible memory. + + Role model: + - server: producer side, owns and registers memory regions. + - client: consumer side, reads slots remotely and updates head by rdma_faa. + + The ring stores serialized JSON configs in fixed-size slots. + """ + + def __init__( + self, + role: str, + buffer_size: int = 128, + slot_size: int = 4096, + *, + rdma_server: Optional[RDMAServer] = None, + rdma_client: Optional[RDMAClient] = None, + remote: Optional[RDMABufferDescriptor] = None, + ): + if role not in {"server", "client"}: + raise ValueError("role must be 'server' or 'client'") + if buffer_size <= 0: + raise ValueError("buffer_size must be positive") + if slot_size <= 8: + raise ValueError("slot_size must be greater than 8") + + self.role = role + self.buffer_size = int(buffer_size) + self.slot_size = int(slot_size) + + self.rdma_server: Optional[RDMAServer] = rdma_server + self.rdma_client: Optional[RDMAClient] = rdma_client + + self._lock = threading.Lock() + + # Local backing store (server side). Client can also allocate local scratch. + self._slot_mem = bytearray(self.buffer_size * self.slot_size) + self._head_mem = bytearray(8) + self._tail_mem = bytearray(8) + + self._slot_addr = ctypes.addressof(ctypes.c_char.from_buffer(self._slot_mem)) + self._head_addr = ctypes.addressof(ctypes.c_char.from_buffer(self._head_mem)) + self._tail_addr = ctypes.addressof(ctypes.c_char.from_buffer(self._tail_mem)) + + # Initialize head/tail to 0. + self._write_local_u64(self._head_mem, 0) + self._write_local_u64(self._tail_mem, 0) + + self._descriptor: Optional[RDMABufferDescriptor] = None + if self.role == "server": + if self.rdma_server is not None: + info = self.rdma_server.get_local_info() + base_addr = int(info["addr"]) + need_bytes = 16 + self.buffer_size * self.slot_size + self.rdma_server.register_memory(base_addr, need_bytes) + self.rdma_server.write_memory(base_addr, (0).to_bytes(8, byteorder="little", signed=False)) + self.rdma_server.write_memory(base_addr + 8, (0).to_bytes(8, byteorder="little", signed=False)) + self._descriptor = RDMABufferDescriptor( + slot_addr=base_addr + 16, + slot_bytes=self.buffer_size * self.slot_size, + slot_size=self.slot_size, + buffer_size=self.buffer_size, + head_addr=base_addr, + tail_addr=base_addr + 8, + rkey=int(info.get("rkey", 0)), + ) + else: + self._descriptor = RDMABufferDescriptor( + slot_addr=self._slot_addr, + slot_bytes=len(self._slot_mem), + slot_size=self.slot_size, + buffer_size=self.buffer_size, + head_addr=self._head_addr, + tail_addr=self._tail_addr, + ) + else: + if remote is None: + raise ValueError("client role requires remote descriptor") + self._descriptor = remote + self.buffer_size = int(remote.buffer_size) + self.slot_size = int(remote.slot_size) + + @property + def descriptor(self) -> RDMABufferDescriptor: + if self._descriptor is None: + raise RuntimeError("descriptor is not initialized") + return self._descriptor + + def _write_local_u64(self, buf: bytearray, value: int): + buf[:8] = int(value).to_bytes(8, byteorder="little", signed=False) + + def _read_local_u64(self, buf: bytearray) -> int: + return int.from_bytes(bytes(buf[:8]), byteorder="little", signed=False) + + def _rdma_faa(self, ptr_addr: int, add_value: int) -> int: + if self.rdma_client is not None: + return self.rdma_client.rdma_faa(ptr_addr, int(add_value), rkey=self.descriptor.rkey) + + if self.rdma_server is not None: + with self._lock: + old = self._read_remote_u64(ptr_addr) + new = (old + int(add_value)) & ((1 << 64) - 1) + self._rdma_write_bytes(ptr_addr, new.to_bytes(8, byteorder="little", signed=False)) + return old + + # Fallback: local atomic emulation (useful for single-process validation). + with self._lock: + if ptr_addr == self.descriptor.head_addr: + old = self._read_local_u64(self._head_mem) + self._write_local_u64(self._head_mem, old + int(add_value)) + return old + if ptr_addr == self.descriptor.tail_addr: + old = self._read_local_u64(self._tail_mem) + self._write_local_u64(self._tail_mem, old + int(add_value)) + return old + raise RuntimeError("rdma_faa failed and no local fallback for ptr") + + def _rdma_read_bytes(self, remote_addr: int, length: int) -> bytes: + if self.rdma_server is not None and self._descriptor is not None: + base = self._descriptor.head_addr + end = base + 16 + self.buffer_size * self.slot_size + if base <= remote_addr < end: + return self.rdma_server.read_memory(int(remote_addr), int(length)) + + if self.rdma_client is not None: + data = self.rdma_client.rdma_read_from(int(remote_addr), int(length), rkey=self.descriptor.rkey) + if isinstance(data, (bytes, bytearray)): + return bytes(data) + raise RuntimeError("rdma_read_from returned non-bytes payload") + + # Local fallback for single-process testing. + if remote_addr == self.descriptor.head_addr: + return bytes(self._head_mem[:length]) + if remote_addr == self.descriptor.tail_addr: + return bytes(self._tail_mem[:length]) + base = self.descriptor.slot_addr + end = base + self.descriptor.slot_bytes + if base <= remote_addr < end: + off = remote_addr - base + return bytes(self._slot_mem[off : off + length]) + raise RuntimeError("rdma_read failed and no local fallback for addr") + + def _rdma_write_bytes(self, remote_addr: int, payload: bytes): + if self.rdma_server is not None and self._descriptor is not None: + base = self._descriptor.head_addr + end = base + 16 + self.buffer_size * self.slot_size + if base <= remote_addr < end: + self.rdma_server.write_memory(int(remote_addr), payload) + return + + if self.rdma_client is not None: + self.rdma_client.rdma_write_to(int(remote_addr), payload, rkey=self.descriptor.rkey) + return + + # Local fallback for single-process testing. + if remote_addr == self.descriptor.head_addr: + self._head_mem[: len(payload)] = payload + return + if remote_addr == self.descriptor.tail_addr: + self._tail_mem[: len(payload)] = payload + return + base = self.descriptor.slot_addr + end = base + self.descriptor.slot_bytes + if base <= remote_addr < end: + off = remote_addr - base + self._slot_mem[off : off + len(payload)] = payload + return + raise RuntimeError("rdma_write failed and no local fallback for addr") + + def _read_remote_u64(self, remote_addr: int) -> int: + raw = self._rdma_read_bytes(remote_addr, 8) + return int.from_bytes(raw, byteorder="little", signed=False) + + def _slot_offset(self, index: int) -> int: + return (index % self.buffer_size) * self.slot_size + + def _serialize_config(self, config: Dict[str, Any]) -> bytes: + payload = json.dumps(config, ensure_ascii=True, separators=(",", ":")).encode("utf-8") + if len(payload) > self.slot_size - 4: + raise ValueError(f"config payload too large: {len(payload)} > {self.slot_size - 4}") + return len(payload).to_bytes(4, byteorder="little", signed=False) + payload + + def _deserialize_config(self, raw_slot: bytes) -> Dict[str, Any]: + if len(raw_slot) < 4: + raise ValueError("invalid slot payload") + plen = int.from_bytes(raw_slot[:4], byteorder="little", signed=False) + if plen == 0: + return {} + data = raw_slot[4 : 4 + plen] + return json.loads(data.decode("utf-8")) + + def produce(self, config: Dict[str, Any]) -> int: + """Produce one config into ring buffer and advance tail by rdma_faa.""" + if self.rdma_server is None and self.rdma_client is None: + raise RuntimeError("produce requires rdma_server or rdma_client") + + # Reserve one slot by atomically incrementing tail. + old_tail = self._rdma_faa(self.descriptor.tail_addr, 1) + cur_head = self._read_remote_u64(self.descriptor.head_addr) + if (old_tail + 1) - cur_head > self.buffer_size: + # Ring full, rollback reservation. + self._rdma_faa(self.descriptor.tail_addr, -1) + raise BufferError("ring buffer is full") + + slot_idx = old_tail % self.buffer_size + offset = self._slot_offset(slot_idx) + payload = self._serialize_config(config) + + # Write payload to the selected slot (works for both server-local and client-remote paths). + slot_addr = self.descriptor.slot_addr + offset + self._rdma_write_bytes(slot_addr, b"\x00" * self.slot_size) + self._rdma_write_bytes(slot_addr, payload) + logger.info("Produced config to RDMA buffer slot %d", slot_idx) + return slot_idx + + def consume(self) -> Optional[Dict[str, Any]]: + """Consume one config from ring buffer and advance head by rdma_faa.""" + if self.role != "client": + raise RuntimeError("consume is only allowed in client role") + + try: + cur_head = self._read_remote_u64(self.descriptor.head_addr) + cur_tail = self._read_remote_u64(self.descriptor.tail_addr) + except Exception as exc: + return None + + # Fast path: empty queue, do not touch head. + if cur_head >= cur_tail: + return None + + # Try to reserve one slot by advancing head atomically. + try: + old_head = self._rdma_faa(self.descriptor.head_addr, 1) + except Exception as exc: + return None + + if old_head >= cur_tail: + # Lost the race: rollback reservation. + try: + self._rdma_faa(self.descriptor.head_addr, -1) + except Exception as exc: + logger.warning("RDMA buffer rollback failed on empty consume: %s", exc) + return None + + slot_idx = old_head % self.buffer_size + slot_addr = self.descriptor.slot_addr + self._slot_offset(slot_idx) + try: + raw = self._rdma_read_bytes(slot_addr, self.slot_size) + except Exception as exc: + logger.warning("RDMA buffer slot read failed for slot %d: %s", slot_idx, exc) + return None + logger.info("Consumed config from RDMA buffer slot %d", slot_idx) + return self._deserialize_config(raw) diff --git a/lightx2v/disagg/rdma_client.py b/lightx2v/disagg/rdma_client.py new file mode 100644 index 000000000..0c615eb49 --- /dev/null +++ b/lightx2v/disagg/rdma_client.py @@ -0,0 +1,254 @@ +import json +import socket +import threading +import time + +import pyverbs.enums as e +from pyverbs.addr import GID, AHAttr, GlobalRoute +from pyverbs.cq import CQ +from pyverbs.device import Context, get_device_list +from pyverbs.mr import MR +from pyverbs.pd import PD +from pyverbs.qp import QP, QPAttr, QPCap, QPInitAttr +from pyverbs.wr import SGE +from pyverbs.wr import SendWR as WR + + +class IBDevice: + def __init__(self, name: str): + self.name = name + + def open(self): + return Context(name=self.name) + + +class QPType: + RC = e.IBV_QPT_RC + + +class WROpcode: + RDMA_WRITE = e.IBV_WR_RDMA_WRITE + RDMA_READ = e.IBV_WR_RDMA_READ + + +class AccessFlag: + LOCAL_WRITE = e.IBV_ACCESS_LOCAL_WRITE + REMOTE_WRITE = e.IBV_ACCESS_REMOTE_WRITE + REMOTE_READ = e.IBV_ACCESS_REMOTE_READ + + +class RDMAClient: + def __init__(self, iface_name=None, local_buffer_size=4096): + self.local_psn = 654321 + self.port_num = 1 + self.gid_index = 1 + if iface_name is None: + devices = get_device_list() + if not devices: + raise RuntimeError("No RDMA device found") + raw_name = devices[0].name + iface_name = raw_name.decode() if isinstance(raw_name, bytes) else raw_name + + self.ctx = IBDevice(iface_name).open() + self.pd = PD(self.ctx) + self.cq = CQ(self.ctx, 10) + + qp_init_attr = QPCap(max_send_wr=10, max_recv_wr=10, max_send_sge=1, max_recv_sge=1) + qia = QPInitAttr(qp_type=QPType.RC, scq=self.cq, rcq=self.cq, cap=qp_init_attr) + qa = QPAttr(port_num=self.port_num) + self.qp = QP(self.pd, qia, qa) + + # 客户端也需要注册内存,用于发送数据的源 (Write) 或接收数据的目标 (Read) + self.buffer_size = int(local_buffer_size) + if self.buffer_size <= 0: + raise ValueError("local_buffer_size must be positive") + self.local_mr = MR(self.pd, self.buffer_size, AccessFlag.LOCAL_WRITE) + self._io_lock = threading.RLock() + + def _ensure_local_mr_capacity(self, required_size: int): + required = int(required_size) + if required <= self.buffer_size: + return + self.buffer_size = required + self.local_mr = MR(self.pd, self.buffer_size, AccessFlag.LOCAL_WRITE) + + def connect_to_server(self, server_ip="127.0.0.1", port=5566): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect((server_ip, port)) + + # 1. 接收 Server 信息 (包含 rkey 和 addr) + data = sock.recv(4096) + self.remote_info = json.loads(data.decode()) + print(f"[Client] Got Server Info: Addr={hex(self.remote_info['addr'])}, RKey={self.remote_info['rkey']}") + + # 2. 发送我的信息给 Server + gid = self.ctx.query_gid(port_num=self.port_num, index=self.gid_index) + my_info = { + "lid": self.ctx.query_port(port_num=self.port_num).lid, + "qpn": self.qp.qp_num, + "psn": self.local_psn, + "gid": str(gid), + "gid_index": self.gid_index, + } + sock.sendall(json.dumps(my_info).encode()) + + # 3. 修改 QP 状态 + self._modify_qp_to_rts() + self.sock = sock + print("[Client] Connection established (RTS)") + + def _modify_qp_to_rts(self): + # Follow the standard RC flow: INIT -> RTR -> RTS. + init_attr = QPAttr(port_num=self.port_num) + init_attr.qp_access_flags = AccessFlag.LOCAL_WRITE | AccessFlag.REMOTE_WRITE | AccessFlag.REMOTE_READ + self.qp.to_init(init_attr) + + rtr_attr = QPAttr(port_num=self.port_num) + rtr_attr.path_mtu = e.IBV_MTU_1024 + rtr_attr.max_dest_rd_atomic = 1 + rtr_attr.min_rnr_timer = 12 + rtr_attr.dest_qp_num = int(self.remote_info["qpn"]) + rtr_attr.rq_psn = int(self.remote_info["psn"]) + + remote_lid = int(self.remote_info.get("lid", 0)) + remote_gid_index = int(self.remote_info.get("gid_index", self.gid_index)) + gr = GlobalRoute(dgid=GID(self.remote_info["gid"]), sgid_index=remote_gid_index) + rtr_attr.ah_attr = AHAttr(port_num=self.port_num, is_global=1, gr=gr, dlid=remote_lid) + self.qp.to_rtr(rtr_attr) + + rts_attr = QPAttr(port_num=self.port_num) + rts_attr.timeout = 14 + rts_attr.retry_cnt = 7 + rts_attr.rnr_retry = 7 + rts_attr.sq_psn = self.local_psn + rts_attr.max_rd_atomic = 1 + self.qp.to_rts(rts_attr) + + def rdma_write(self, data_bytes, notify_server: bool = False): + """执行单边写:将本地数据直接写入远程内存""" + self._ensure_local_mr_capacity(len(data_bytes)) + + # 1. 准备本地数据 + padded = data_bytes.ljust(self.buffer_size, b"\x00") + self.local_mr.write(padded, len(padded), 0) + + # 2. 构造 WR (Work Request) + sge = SGE(self.local_mr.buf, len(data_bytes), self.local_mr.lkey) + wr = WR( + wr_id=123, + opcode=WROpcode.RDMA_WRITE, + num_sge=1, + sg=[sge], + send_flags=e.IBV_SEND_SIGNALED, + ) + wr.set_wr_rdma(int(self.remote_info["rkey"]), int(self.remote_info["addr"])) + + # 3. 提交请求 + self.qp.post_send(wr) + + # 4. 轮询完成队列 (如果之前设置了 SIGNALED) + # 对于纯单边写,如果不要求确认,可以不用轮询,这就是"零拷贝零中断"的精髓 + # 但为了演示成功,我们这里简单轮询一下 + self._poll_cq() + # Optional demo-path notification channel; rdma_buffer path does not rely on it. + if notify_server and hasattr(self, "sock") and self.sock is not None: + try: + self.sock.sendall(b"WRITE_DONE") + except (BrokenPipeError, OSError): + self.sock = None + + def rdma_read(self, length): + """执行单边读:直接从远程内存读取数据到本地""" + self._ensure_local_mr_capacity(length) + sge = SGE(self.local_mr.buf, length, self.local_mr.lkey) + wr = WR( + wr_id=124, + opcode=WROpcode.RDMA_READ, + num_sge=1, + sg=[sge], + send_flags=e.IBV_SEND_SIGNALED, + ) + wr.set_wr_rdma(int(self.remote_info["rkey"]), int(self.remote_info["addr"])) + + self.qp.post_send(wr) + + self._poll_cq() + return self.local_mr.read(length, 0) + + def rdma_write_to(self, remote_addr, data_bytes, rkey=None): + """Write bytes to an explicit remote address. + + Keeps compatibility with existing rdma_write implementation by temporarily + overriding remote_info addr/rkey for this operation. + """ + with self._io_lock: + old_addr = self.remote_info["addr"] + old_rkey = self.remote_info["rkey"] + self.remote_info["addr"] = int(remote_addr) + if rkey is not None: + self.remote_info["rkey"] = int(rkey) + try: + self.rdma_write(data_bytes, notify_server=False) + finally: + self.remote_info["addr"] = old_addr + self.remote_info["rkey"] = old_rkey + + def rdma_read_from(self, remote_addr, length, rkey=None): + """Read bytes from an explicit remote address.""" + with self._io_lock: + old_addr = self.remote_info["addr"] + old_rkey = self.remote_info["rkey"] + self.remote_info["addr"] = int(remote_addr) + if rkey is not None: + self.remote_info["rkey"] = int(rkey) + try: + return self.rdma_read(int(length)) + finally: + self.remote_info["addr"] = old_addr + self.remote_info["rkey"] = old_rkey + + def rdma_faa(self, remote_addr, add_value, rkey=None): + """Best-effort FAA semantics via read-modify-write. + + NOTE: This is not a true remote atomic verb; it is a compatibility shim + until atomic WR support is implemented. + """ + with self._io_lock: + old = self.rdma_read_from(int(remote_addr), 8, rkey=rkey) + old_v = int.from_bytes(old, byteorder="little", signed=False) + new_v = (old_v + int(add_value)) & ((1 << 64) - 1) + self.rdma_write_to(int(remote_addr), new_v.to_bytes(8, byteorder="little", signed=False), rkey=rkey) + return old_v + + def _poll_cq(self): + """简单的轮询""" + while True: + poll_ret = self.cq.poll(1) + if not isinstance(poll_ret, tuple) or len(poll_ret) != 2: + raise RuntimeError(f"Unexpected CQ poll return: {poll_ret}") + num_wc, wc_list = poll_ret + if num_wc > 0 and wc_list: + wc = wc_list[0] + status = getattr(wc, "status", None) + if status is None: + raise RuntimeError(f"Unexpected WC object: {wc}") + if status != e.IBV_WC_SUCCESS: + vendor_err = getattr(wc, "vendor_err", None) + raise Exception(f"WC Error: {status}, vendor_err: {vendor_err}") + break + time.sleep(0.0001) + + +# 使用示例 +# if __name__ == "__main__": +# cli = RDMAClient() +# cli.connect_to_server('127.0.0.1') # 替换为服务器 IP + +# # 执行单边写 +# msg = b"Hello RDMA!" +# cli.rdma_write(msg) +# print("Write done.") + +# # 执行单边读 +# data = cli.rdma_read(len(msg)) +# print("Read data:", data) diff --git a/lightx2v/disagg/rdma_server.py b/lightx2v/disagg/rdma_server.py new file mode 100644 index 000000000..0111a4447 --- /dev/null +++ b/lightx2v/disagg/rdma_server.py @@ -0,0 +1,244 @@ +import json +import socket +import threading + +import pyverbs.enums as e +from pyverbs.addr import GID, AHAttr, GlobalRoute +from pyverbs.cq import CQ +from pyverbs.device import Context, get_device_list +from pyverbs.mr import MR +from pyverbs.pd import PD +from pyverbs.qp import QP, QPAttr, QPCap, QPInitAttr + + +class IBDevice: + def __init__(self, name: str): + self.name = name + + def open(self): + return Context(name=self.name) + + +class QPType: + RC = e.IBV_QPT_RC + + +class WROpcode: + RDMA_WRITE = e.IBV_WR_RDMA_WRITE + + +class AccessFlag: + LOCAL_WRITE = e.IBV_ACCESS_LOCAL_WRITE + REMOTE_WRITE = e.IBV_ACCESS_REMOTE_WRITE + REMOTE_READ = e.IBV_ACCESS_REMOTE_READ + + +class RDMAServer: + def __init__(self, iface_name=None, port_num=1, buffer_size=4096): + self.local_psn = 123456 + self._next_psn = int(self.local_psn) + self.port_num = port_num + self.gid_index = 1 + self.buffer_size = int(buffer_size) + if self.buffer_size <= 0: + raise ValueError("buffer_size must be positive") + if iface_name is None: + devices = get_device_list() + if not devices: + raise RuntimeError("No RDMA device found") + raw_name = devices[0].name + iface_name = raw_name.decode() if isinstance(raw_name, bytes) else raw_name + + self.ctx = IBDevice(iface_name).open() + if self.ctx is None: + available = [] + for dev in get_device_list(): + dev_name = dev.name.decode() if isinstance(dev.name, bytes) else dev.name + available.append(dev_name) + raise RuntimeError(f"Failed to open RDMA device '{iface_name}'. Available devices: {available}") + + self.pd = PD(self.ctx) + self.cq = CQ(self.ctx, 10) + + # 创建 QP (Queue Pair) + qp_init_attr = QPCap(max_send_wr=10, max_recv_wr=10, max_send_sge=1, max_recv_sge=1) + qia = QPInitAttr(qp_type=QPType.RC, scq=self.cq, rcq=self.cq, cap=qp_init_attr) + qa = QPAttr(port_num=self.port_num) + self.qp = QP(self.pd, qia, qa) # RC: Reliable Connected + self._qp_init_attr = qia + self._qp_attr = qa + self._conn_lock = threading.Lock() + self._active_qps = [self.qp] + self._active_conns = [] + self._listener_socket = None + + # 关键:注册一块内存用于被远程访问 + # buffer_size 可配置,允许远程写入 (REMOTE_WRITE) 和远程读取 (REMOTE_READ) + self.mr = MR(self.pd, self.buffer_size, AccessFlag.LOCAL_WRITE | AccessFlag.REMOTE_WRITE | AccessFlag.REMOTE_READ) + + # 初始化缓冲区数据 (例如全为 0) + zeros = b"\x00" * self.buffer_size + self.mr.write(zeros, len(zeros), 0) + + mr_addr = getattr(self.mr, "addr", None) + if mr_addr is None: + mr_addr = self.mr.buf + self._mr_addr = int(mr_addr) + print(f"[Server] MR Registered. Addr: {mr_addr}, RKey: {self.mr.rkey}") + + def register_memory(self, addr: int, length: int): + """Validate a requested sub-region against server MR and return registration metadata. + + This server uses one pre-registered MR, so sub-regions are slices of that MR. + """ + if length <= 0: + raise ValueError("length must be positive") + addr = int(addr) + length = int(length) + if addr < self._mr_addr: + raise ValueError("addr is below MR base") + off = addr - self._mr_addr + if off + length > self.buffer_size: + raise ValueError(f"region out of MR range: off={off}, length={length}, buffer_size={self.buffer_size}") + return { + "addr": addr, + "length": length, + "rkey": int(self.mr.rkey), + } + + def read_memory(self, addr: int, length: int) -> bytes: + """Read bytes from a validated sub-region within server MR.""" + region = self.register_memory(addr, length) + off = int(region["addr"]) - self._mr_addr + return self.mr.read(int(length), int(off)) + + def write_memory(self, addr: int, payload: bytes): + """Write bytes to a validated sub-region within server MR.""" + self.register_memory(addr, len(payload)) + off = int(addr) - self._mr_addr + self.mr.write(payload, len(payload), int(off)) + + def get_local_info(self, qp=None, psn=None): + """获取本机 QP 信息,用于交换""" + mr_addr = getattr(self.mr, "addr", None) + if mr_addr is None: + mr_addr = self.mr.buf + qp = self.qp if qp is None else qp + psn = self.local_psn if psn is None else int(psn) + gid = self.ctx.query_gid(port_num=self.port_num, index=self.gid_index) + return {"lid": self.ctx.query_port(port_num=self.port_num).lid, "qpn": qp.qp_num, "psn": psn, "gid": str(gid), "gid_index": self.gid_index, "rkey": self.mr.rkey, "addr": mr_addr} + + def _alloc_qp_with_psn(self): + with self._conn_lock: + self._next_psn = (self._next_psn + 1) & 0xFFFFFF + if self._next_psn == 0: + self._next_psn = 1 + psn = self._next_psn + qp = QP(self.pd, self._qp_init_attr, self._qp_attr) + return qp, psn + + def _accept_one_client(self, listen_sock): + conn, addr = listen_sock.accept() + print(f"[Server] Connected to {addr}") + + qp, local_psn = self._alloc_qp_with_psn() + + # 1. 发送我的信息给 Client + my_info = self.get_local_info(qp=qp, psn=local_psn) + conn.sendall(json.dumps(my_info).encode()) + + # 2. 接收 Client 的信息 + data = conn.recv(4096) + remote_info = json.loads(data.decode()) + print(f"[Server] Received remote info: QPN={remote_info['qpn']}") + + # 3. 修改 QP 状态到 RTS + self._modify_qp_to_rts(qp, remote_info, local_psn) + + with self._conn_lock: + self._active_qps.append(qp) + self._active_conns.append(conn) + return conn + + def handshake(self, host="0.0.0.0", port=5566, serve_forever=True): + """TCP handshake to exchange QP information. + + When serve_forever=True (default), accepts multiple clients on one port. + Each client gets its own QP so multiple services can connect concurrently. + """ + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((host, port)) + sock.listen(16) + self._listener_socket = sock + print(f"[Server] Waiting for connection on {host}:{port}...") + + if not serve_forever: + return self._accept_one_client(sock) + + while True: + try: + self._accept_one_client(sock) + except Exception as exc: + print(f"[Server] Handshake accept failed: {exc}") + + def _modify_qp_to_rts(self, qp, remote_info, local_psn): + # Follow the standard RC flow: INIT -> RTR -> RTS. + init_attr = QPAttr(port_num=self.port_num) + init_attr.qp_access_flags = AccessFlag.LOCAL_WRITE | AccessFlag.REMOTE_WRITE | AccessFlag.REMOTE_READ + qp.to_init(init_attr) + + rtr_attr = QPAttr(port_num=self.port_num) + rtr_attr.path_mtu = e.IBV_MTU_1024 + rtr_attr.max_dest_rd_atomic = 1 + rtr_attr.min_rnr_timer = 12 + rtr_attr.dest_qp_num = int(remote_info["qpn"]) + rtr_attr.rq_psn = int(remote_info["psn"]) + + remote_lid = int(remote_info.get("lid", 0)) + remote_gid_index = int(remote_info.get("gid_index", self.gid_index)) + gr = GlobalRoute(dgid=GID(remote_info["gid"]), sgid_index=remote_gid_index) + rtr_attr.ah_attr = AHAttr(port_num=self.port_num, is_global=1, gr=gr, dlid=remote_lid) + qp.to_rtr(rtr_attr) + + rts_attr = QPAttr(port_num=self.port_num) + rts_attr.timeout = 14 + rts_attr.retry_cnt = 7 + rts_attr.rnr_retry = 7 + rts_attr.sq_psn = int(local_psn) + rts_attr.max_rd_atomic = 1 + qp.to_rts(rts_attr) + print("[Server] QP State changed to RTS") + + def wait_for_completion(self, timeout_ms=5000): + """轮询 CQ 等待操作完成(如果是带响应的操作)""" + # 单边写通常不需要接收方做额外操作,除非使用了带立即数或原子操作需要确认 + # 这里仅作演示,实际单边写完成后,服务端内存已变化 + pass + + def read_local_memory(self): + """读取本地内存查看变化""" + return self.mr.read(self.buffer_size, 0) + + +# 使用示例 (需在单独进程运行) +# if __name__ == "__main__": +# parser = argparse.ArgumentParser(description="RDMA server demo") +# parser.add_argument("--iface", default=None, help="RDMA device name, auto-detect when omitted") +# parser.add_argument("--port-num", type=int, default=1, help="RDMA port number") +# parser.add_argument("--buffer-size", type=int, default=4096, help="registered memory size in bytes") +# parser.add_argument("--listen-host", default="0.0.0.0", help="TCP handshake listen host") +# parser.add_argument("--listen-port", type=int, default=5566, help="TCP handshake listen port") +# args = parser.parse_args() + +# srv = RDMAServer(iface_name=args.iface, port_num=args.port_num, buffer_size=args.buffer_size) +# conn = srv.handshake(host=args.listen_host, port=args.listen_port) +# conn.settimeout(10.0) +# try: +# marker = conn.recv(64) +# print(f"[Server] Completion marker: {marker!r}") +# except socket.timeout: +# print("[Server] No completion marker received before timeout") +# finally: +# conn.close() +# print("Data after operation:", srv.read_local_memory()) diff --git a/lightx2v/disagg/services/controller.py b/lightx2v/disagg/services/controller.py index 8d69f711c..8fdfc1696 100644 --- a/lightx2v/disagg/services/controller.py +++ b/lightx2v/disagg/services/controller.py @@ -1,10 +1,10 @@ -import time -from collections import deque from pathlib import Path -from threading import Lock -from typing import Any, Deque +from threading import Event, Lock, Thread -from lightx2v.disagg.conn import REQUEST_POLLING_PORT, ReqManager +from lightx2v.disagg.conn import MONITOR_POLLING_PORT, REQUEST_POLLING_PORT, ReqManager +from lightx2v.disagg.monitor import Monitor +from lightx2v.disagg.rdma_buffer import RDMABuffer +from lightx2v.disagg.rdma_server import RDMAServer from lightx2v.disagg.scheduler.round_robin import RoundRobinPolicy from lightx2v.disagg.services.base import BaseService @@ -12,12 +12,103 @@ class ControllerService(BaseService): def __init__(self): super().__init__() - self.request_queue: Deque[Any] = deque() + self.rdma_buffer_request: RDMABuffer | None = None + self.rdma_buffer_phase1: RDMABuffer | None = None + self.rdma_buffer_phase2: RDMABuffer | None = None self.encoder_policy = RoundRobinPolicy() self.transformer_policy = RoundRobinPolicy() self.decoder_policy = RoundRobinPolicy() self._lock = Lock() self.req_mgr = ReqManager() + self.monitor = Monitor(nodes=[]) + self._rdma_server_request: RDMAServer | None = None + self._rdma_server_phase1: RDMAServer | None = None + self._rdma_server_phase2: RDMAServer | None = None + self._rdma_handshake_thread_request: Thread | None = None + self._rdma_handshake_thread_phase1: Thread | None = None + self._rdma_handshake_thread_phase2: Thread | None = None + + def _init_request_rdma_buffer(self, bootstrap_addr: str, config: dict): + slots = int(config.get("rdma_buffer_slots", "128")) + slot_size = int(config.get("rdma_buffer_slot_size", "4096")) + handshake_port = int(config.get("rdma_request_handshake_port", "5566")) + phase1_slots = slots + phase1_slot_size = slot_size + phase1_handshake_port = int(config.get("rdma_phase1_handshake_port", "5567")) + phase2_slots = slots + phase2_slot_size = slot_size + phase2_handshake_port = int(config.get("rdma_phase2_handshake_port", "5568")) + + # Normalize RDMA request-buffer parameters so downstream services consume the same values. + config["rdma_request_host"] = bootstrap_addr + config["rdma_buffer_slots"] = slots + config["rdma_buffer_slot_size"] = slot_size + config["rdma_request_handshake_port"] = handshake_port + config["rdma_phase1_host"] = bootstrap_addr + config["rdma_phase1_handshake_port"] = phase1_handshake_port + config["rdma_phase2_host"] = bootstrap_addr + config["rdma_phase2_handshake_port"] = phase2_handshake_port + + need_bytes = 16 + slots * slot_size + self._rdma_server_request = RDMAServer(buffer_size=need_bytes) + self.rdma_buffer_request = RDMABuffer( + role="server", + buffer_size=slots, + slot_size=slot_size, + rdma_server=self._rdma_server_request, + ) + + self._rdma_handshake_thread_request = Thread( + target=self._rdma_server_request.handshake, + kwargs={"host": bootstrap_addr, "port": handshake_port}, + name="controller-rdma-handshake", + daemon=True, + ) + self._rdma_handshake_thread_request.start() + + need_bytes_phase1 = 16 + phase1_slots * phase1_slot_size + self._rdma_server_phase1 = RDMAServer(buffer_size=need_bytes_phase1) + self.rdma_buffer_phase1 = RDMABuffer( + role="server", + buffer_size=phase1_slots, + slot_size=phase1_slot_size, + rdma_server=self._rdma_server_phase1, + ) + self._rdma_handshake_thread_phase1 = Thread( + target=self._rdma_server_phase1.handshake, + kwargs={"host": bootstrap_addr, "port": phase1_handshake_port}, + name="controller-rdma-handshake-phase1", + daemon=True, + ) + self._rdma_handshake_thread_phase1.start() + + need_bytes_phase2 = 16 + phase2_slots * phase2_slot_size + self._rdma_server_phase2 = RDMAServer(buffer_size=need_bytes_phase2) + self.rdma_buffer_phase2 = RDMABuffer( + role="server", + buffer_size=phase2_slots, + slot_size=phase2_slot_size, + rdma_server=self._rdma_server_phase2, + ) + self._rdma_handshake_thread_phase2 = Thread( + target=self._rdma_server_phase2.handshake, + kwargs={"host": bootstrap_addr, "port": phase2_handshake_port}, + name="controller-rdma-handshake-phase2", + daemon=True, + ) + self._rdma_handshake_thread_phase2.start() + self.logger.info( + "Initialized RDMA buffers: request=(%s,%s,%s) phase1=(%s,%s,%s) phase2=(%s,%s,%s)", + slots, + slot_size, + need_bytes, + phase1_slots, + phase1_slot_size, + need_bytes_phase1, + phase2_slots, + phase2_slot_size, + need_bytes_phase2, + ) def add_instance(self, instance_type: str, instance_address: str): """Add instance address to the matching scheduling policy by type.""" @@ -52,21 +143,13 @@ def send_request(self, config): if config is None: raise ValueError("config cannot be None") - encoder_addr = self.encoder_policy.schedule() - transformer_addr = self.transformer_policy.schedule() - decoder_addr = self.decoder_policy.schedule() - - encoder_ip, encoder_port_str = encoder_addr.rsplit(":", 1) - transformer_ip, transformer_port_str = transformer_addr.rsplit(":", 1) - decoder_ip, decoder_port_str = decoder_addr.rsplit(":", 1) - - self.req_mgr.send(encoder_ip, int(encoder_port_str), config) - self.req_mgr.send(transformer_ip, int(transformer_port_str), config) - self.req_mgr.send(decoder_ip, int(decoder_port_str), config) - self.logger.info("Request added to controller queue and dispatched to services") + if self.rdma_buffer_request is None: + raise RuntimeError("RDMA request buffer is not initialized") + self.rdma_buffer_request.produce(config) + self.logger.info("Request enqueued to encoder request RDMA buffer") def run(self, config): - """Initialize instances from config and submit request multiple times.""" + """Initialize instances, send requests, wait for decoder save_path callbacks, then exit.""" if config is None: raise ValueError("config cannot be None") @@ -74,11 +157,15 @@ def run(self, config): encoder_engine_rank = config.get("encoder_engine_rank", 0) transformer_engine_rank = config.get("transformer_engine_rank", 1) decoder_engine_rank = config.get("decoder_engine_rank", 2) + request_count = int(config.get("request_count", 2)) + result_port = int(config.get("controller_result_port", REQUEST_POLLING_PORT - 1)) self.encoder_policy = RoundRobinPolicy() self.transformer_policy = RoundRobinPolicy() self.decoder_policy = RoundRobinPolicy() + self._init_request_rdma_buffer(bootstrap_addr, config) + self.add_instance("encoder", f"{bootstrap_addr}:{REQUEST_POLLING_PORT + encoder_engine_rank}") self.add_instance( "transformer", @@ -86,18 +173,95 @@ def run(self, config): ) self.add_instance("decoder", f"{bootstrap_addr}:{REQUEST_POLLING_PORT + decoder_engine_rank}") + monitor_nodes = [ + f"tcp://{bootstrap_addr}:{MONITOR_POLLING_PORT + encoder_engine_rank}", + f"tcp://{bootstrap_addr}:{MONITOR_POLLING_PORT + transformer_engine_rank}", + f"tcp://{bootstrap_addr}:{MONITOR_POLLING_PORT + decoder_engine_rank}", + ] + self.monitor.nodes = monitor_nodes + + monitor_stop_event = Event() + + def _monitor_callback(results): + for item in results: + self.logger.info("monitor: %s", item) + + monitor_thread = Thread( + target=self.monitor.run_forever, + kwargs={ + "interval_seconds": 5.0, + "callback": _monitor_callback, + "stop_event": monitor_stop_event, + }, + name="controller-monitor", + daemon=True, + ) + # monitor_thread.start() + base_save_path = config.get("save_path") + expected_rooms: set[int] = set() + received_rooms: set[int] = set() + received_results: list[dict] = [] + try: + for i in range(request_count): + request_config = dict(config) + request_config["data_bootstrap_room"] = i + request_config["controller_result_host"] = bootstrap_addr + request_config["controller_result_port"] = result_port + if base_save_path: + save_path = Path(base_save_path) + request_config["save_path"] = str(save_path.with_name(f"{save_path.stem}{i + 1}{save_path.suffix}")) + # TODO: use queue to receive request from client and dispatch, currently we just send the same request multiple times for testing + with self._lock: + current_request = request_config + self.send_request(current_request) + + expected_rooms.add(i) + + self.logger.info( + "Waiting for decoder results: expected=%s on port=%s", + sorted(expected_rooms), + result_port, + ) + while len(received_rooms) < len(expected_rooms): + result = self.req_mgr.receive(result_port) + if not isinstance(result, dict): + self.logger.warning("Ignored non-dict decoder result: %s", result) + continue + room = result.get("data_bootstrap_room") + if room is None: + self.logger.warning("Ignored decoder result without data_bootstrap_room: %s", result) + continue + room = int(room) + if room not in expected_rooms: + self.logger.warning("Ignored decoder result for unexpected room=%s: %s", room, result) + continue + if room in received_rooms: + self.logger.info("Duplicate decoder result for room=%s ignored", room) + continue + + received_rooms.add(room) + received_results.append(result) + + if result.get("ok", False): + self.logger.info( + "Decoder result received room=%s save_path=%s (%s/%s)", + room, + result.get("save_path"), + len(received_rooms), + len(expected_rooms), + ) + else: + self.logger.error( + "Decoder result failed room=%s error=%s (%s/%s)", + room, + result.get("error"), + len(received_rooms), + len(expected_rooms), + ) - for i in range(2): - request_config = dict(config) - request_config["data_bootstrap_room"] = i - if base_save_path: - save_path = Path(base_save_path) - request_config["save_path"] = str(save_path.with_name(f"{save_path.stem}{i + 1}{save_path.suffix}")) - # TODO: use queue to receive request from client and dispatch, currently we just send the same request multiple times for testing - with self._lock: - self.request_queue.append(request_config) - current_request = self.request_queue.popleft() - self.send_request(current_request) - - time.sleep(2) # Sleep briefly to allow services to process the request + self.logger.info("All decoder results received. Controller exiting.") + finally: + pass + # monitor_stop_event.set() + # monitor_thread.join(timeout=1.0) diff --git a/lightx2v/disagg/services/decoder.py b/lightx2v/disagg/services/decoder.py index 72f33e836..7513f1528 100644 --- a/lightx2v/disagg/services/decoder.py +++ b/lightx2v/disagg/services/decoder.py @@ -1,13 +1,17 @@ import hashlib import json +import threading import time from collections import deque from typing import Dict, List, Optional import torch -from lightx2v.disagg.conn import REQUEST_POLLING_PORT, DataArgs, DataManager, DataPoll, DataReceiver, DisaggregationMode, DisaggregationPhase, ReqManager +from lightx2v.disagg.conn import MONITOR_POLLING_PORT, DataArgs, DataManager, DataPoll, DataReceiver, DisaggregationMode, DisaggregationPhase, ReqManager +from lightx2v.disagg.monitor import Reporter from lightx2v.disagg.protocol import AllocationRequest, MemoryHandle, RemoteBuffer +from lightx2v.disagg.rdma_buffer import RDMABuffer, RDMABufferDescriptor +from lightx2v.disagg.rdma_client import RDMAClient from lightx2v.disagg.services.base import BaseService from lightx2v.disagg.utils import estimate_transformer_buffer_sizes, load_wan_vae_decoder from lightx2v.utils.envs import GET_DTYPE @@ -16,10 +20,21 @@ class DecoderService(BaseService): - def __init__(self): + def __init__(self, config: dict): super().__init__() - self.request_port = REQUEST_POLLING_PORT + 2 - self.req_mgr = ReqManager() + self.config = config + self.encoder_engine_rank = int(self.config.get("encoder_engine_rank", 0)) + self.transformer_engine_rank = int(self.config.get("transformer_engine_rank", 1)) + self.decoder_engine_rank = int(self.config.get("decoder_engine_rank", 2)) + self._phase2_rdma_client: Optional[RDMAClient] = None + self._phase2_rdma_buffer: Optional[RDMABuffer] = None + shared_slots = int(self.config.get("rdma_buffer_slots", "128")) + shared_slot_size = int(self.config.get("rdma_buffer_slot_size", "4096")) + self._phase2_server_ip = str(self.config.get("rdma_phase2_host", "127.0.0.1")) + self._phase2_handshake_port = int(self.config.get("rdma_phase2_handshake_port", "5568")) + self._phase2_slots = shared_slots + self._phase2_slot_size = shared_slot_size + self._last_phase2_connect_retry_ts = 0.0 self.vae_decoder = None self._rdma_buffers: Dict[int, List[torch.Tensor]] = {} self.data_mgr = DataManager( @@ -27,25 +42,75 @@ def __init__(self): DisaggregationMode.DECODE, ) self.data_receiver: Dict[int, DataReceiver] = {} + self.req_mgr = ReqManager() + self.reporter = Reporter( + service_type="decoder", + gpu_id=self.decoder_engine_rank, + bind_address=f"tcp://{self.config.get('data_bootstrap_addr', '127.0.0.1')}:{MONITOR_POLLING_PORT + self.decoder_engine_rank}", + ) + self._reporter_thread: Optional[threading.Thread] = threading.Thread( + target=self.reporter.serve_forever, + name="decoder-reporter", + daemon=True, + ) + self._reporter_thread.start() + self.load_models() + + def _ensure_phase2_request_buffer(self) -> bool: + if self._phase2_rdma_buffer is not None: + return True + now = time.time() + if now - self._last_phase2_connect_retry_ts < 1.0: + return False + self._last_phase2_connect_retry_ts = now + + if self._phase2_rdma_client is None: + self._phase2_rdma_client = RDMAClient(local_buffer_size=self._phase2_slot_size) + self._phase2_rdma_client.connect_to_server(self._phase2_server_ip, self._phase2_handshake_port) + remote_info = self._phase2_rdma_client.remote_info + base_addr = int(remote_info["addr"]) + self._phase2_rdma_buffer = RDMABuffer( + role="client", + rdma_client=self._phase2_rdma_client, + remote=RDMABufferDescriptor( + slot_addr=base_addr + 16, + slot_bytes=self._phase2_slots * self._phase2_slot_size, + slot_size=self._phase2_slot_size, + buffer_size=self._phase2_slots, + head_addr=base_addr, + tail_addr=base_addr + 8, + rkey=int(remote_info.get("rkey", 0)), + ), + ) + return True def init(self, config): self.config = config - self.vae_decoder = None + shared_slots = int(self.config.get("rdma_buffer_slots", self._phase2_slots)) + shared_slot_size = int(self.config.get("rdma_buffer_slot_size", 4096)) + self._phase2_server_ip = str(self.config.get("rdma_phase2_host", self._phase2_server_ip)) + self._phase2_handshake_port = int(self.config.get("rdma_phase2_handshake_port", self._phase2_handshake_port)) + self._phase2_slots = shared_slots + self._phase2_slot_size = shared_slot_size self.encoder_engine_rank = int(self.config.get("encoder_engine_rank", 0)) self.transformer_engine_rank = int(self.config.get("transformer_engine_rank", 1)) self.decoder_engine_rank = int(self.config.get("decoder_engine_rank", 2)) - self.load_models() - if "seed" in self.config: seed_all(self.config["seed"]) data_bootstrap_addr = self.config.get("data_bootstrap_addr", "127.0.0.1") data_bootstrap_room = self.config.get("data_bootstrap_room", 0) + if data_bootstrap_addr is None or data_bootstrap_room is None: return + try: + self._ensure_phase2_request_buffer() + except Exception: + self.logger.exception("Failed to connect phase2 RDMA buffer, will retry") + buffer_sizes = estimate_transformer_buffer_sizes(self.config) request = AllocationRequest( bootstrap_room=data_bootstrap_room, @@ -63,7 +128,8 @@ def init(self, config): ib_device=None, ) self.data_mgr.init(data_args, data_bootstrap_room) - self.data_receiver[data_bootstrap_room] = DataReceiver(self.data_mgr, data_bootstrap_addr, data_bootstrap_room) + phase2_bootstrap_addr = str(self.config.get("transformer_node_address", data_bootstrap_addr)) + self.data_receiver[data_bootstrap_room] = DataReceiver(self.data_mgr, phase2_bootstrap_addr, data_bootstrap_room) self.data_receiver[data_bootstrap_room].init() def load_models(self): @@ -176,6 +242,10 @@ def remove(self, room: int): def release(self): for room in list(self._rdma_buffers.keys()): self.remove(room) + self.reporter.stop() + if self._reporter_thread is not None and self._reporter_thread.is_alive(): + self._reporter_thread.join(timeout=1.0) + self._reporter_thread = None if self.data_mgr is not None: self.data_mgr.release() self.data_receiver.clear() @@ -187,11 +257,21 @@ def run(self, stop_event=None): exec_queue = deque() while True: - while True: - config = self.req_mgr.receive_non_block(self.request_port) - if config is None: - break - req_queue.append(config) + if self._phase2_rdma_buffer is None: + try: + self._ensure_phase2_request_buffer() + except Exception: + self.logger.exception("Failed to connect phase2 request RDMA buffer, will retry") + + if self._phase2_rdma_buffer is not None: + packet = self._phase2_rdma_buffer.consume() + if packet is not None: + if isinstance(packet, dict) and "request_config" in packet: + config = dict(packet.get("request_config") or {}) + config["transformer_node_address"] = packet.get("transformer_node_address", config.get("transformer_node_address", "127.0.0.1")) + else: + config = packet + req_queue.append(config) if req_queue: config = req_queue.popleft() @@ -229,9 +309,34 @@ def run(self, stop_event=None): if exec_queue: room, config = exec_queue.popleft() try: - self.process(config) + save_path = self.process(config) + callback_host = str(config.get("controller_result_host", config.get("data_bootstrap_addr", "127.0.0.1"))) + callback_port = int(config.get("controller_result_port")) if config.get("controller_result_port") is not None else None + if callback_port is not None: + self.req_mgr.send( + callback_host, + callback_port, + { + "ok": True, + "data_bootstrap_room": int(room), + "save_path": save_path, + }, + ) except Exception: self.logger.exception("Failed to process request for room=%s", room) + callback_host = str(config.get("controller_result_host", config.get("data_bootstrap_addr", "127.0.0.1"))) + callback_port = int(config.get("controller_result_port")) if config.get("controller_result_port") is not None else None + if callback_port is not None: + self.req_mgr.send( + callback_host, + callback_port, + { + "ok": False, + "data_bootstrap_room": int(room), + "save_path": None, + "error": "decoder process failed", + }, + ) finally: self.remove(room) diff --git a/lightx2v/disagg/services/encoder.py b/lightx2v/disagg/services/encoder.py index 31314cc08..396cfcab0 100644 --- a/lightx2v/disagg/services/encoder.py +++ b/lightx2v/disagg/services/encoder.py @@ -1,5 +1,6 @@ import hashlib import json +import threading import time from collections import deque from typing import Dict, List, Optional @@ -7,8 +8,11 @@ import numpy as np import torch -from lightx2v.disagg.conn import REQUEST_POLLING_PORT, DataArgs, DataManager, DataPoll, DataSender, DisaggregationMode, DisaggregationPhase, ReqManager +from lightx2v.disagg.conn import MONITOR_POLLING_PORT, DataArgs, DataManager, DataPoll, DataSender, DisaggregationMode, DisaggregationPhase +from lightx2v.disagg.monitor import Reporter from lightx2v.disagg.protocol import AllocationRequest, MemoryHandle, RemoteBuffer +from lightx2v.disagg.rdma_buffer import RDMABuffer, RDMABufferDescriptor +from lightx2v.disagg.rdma_client import RDMAClient from lightx2v.disagg.services.base import BaseService from lightx2v.disagg.utils import ( estimate_encoder_buffer_sizes, @@ -23,10 +27,28 @@ class EncoderService(BaseService): - def __init__(self): + def __init__(self, config: dict): super().__init__() - self.request_port = REQUEST_POLLING_PORT + 0 - self.req_mgr = ReqManager() + self.config = config + self.encoder_engine_rank = int(self.config.get("encoder_engine_rank", "0")) + self.transformer_engine_rank = int(self.config.get("transformer_engine_rank", "1")) + self.decoder_engine_rank = int(self.config.get("decoder_engine_rank", "2")) + self._request_rdma_client: Optional[RDMAClient] = None + self._request_rdma_buffer: Optional[RDMABuffer] = None + self._phase1_rdma_client: Optional[RDMAClient] = None + self._phase1_rdma_buffer: Optional[RDMABuffer] = None + shared_slots = int(self.config.get("rdma_buffer_slots", "128")) + shared_slot_size = int(self.config.get("rdma_buffer_slot_size", "4096")) + self._request_server_ip = str(self.config.get("rdma_request_host", "127.0.0.1")) + self._request_handshake_port = int(self.config.get("rdma_request_handshake_port", "5566")) + self._request_slots = shared_slots + self._request_slot_size = shared_slot_size + self._phase1_server_ip = str(self.config.get("rdma_phase1_host", "127.0.0.1")) + self._phase1_handshake_port = int(self.config.get("rdma_phase1_handshake_port", "5567")) + self._phase1_slots = shared_slots + self._phase1_slot_size = shared_slot_size + self._last_request_connect_retry_ts = 0.0 + self._last_phase1_connect_retry_ts = 0.0 self.text_encoder = None self.image_encoder = None self.vae_encoder = None @@ -36,19 +58,119 @@ def __init__(self): ) self.data_sender: Dict[int, DataSender] = {} self._rdma_buffers: Dict[int, List[torch.Tensor]] = {} + self.reporter = Reporter( + service_type="encoder", + gpu_id=self.encoder_engine_rank, + bind_address=f"tcp://{self.config.get('data_bootstrap_addr', '127.0.0.1')}:{MONITOR_POLLING_PORT + self.encoder_engine_rank}", + ) + self._reporter_thread: Optional[threading.Thread] = threading.Thread( + target=self.reporter.serve_forever, + name="encoder-reporter", + daemon=True, + ) + self._reporter_thread.start() + self.load_models() + + def _ensure_request_buffer(self) -> bool: + if self._request_rdma_buffer is not None: + return True + + now = time.time() + if now - self._last_request_connect_retry_ts < 1.0: + return False + self._last_request_connect_retry_ts = now + + if self._request_rdma_client is None: + self._request_rdma_client = RDMAClient(local_buffer_size=self._request_slot_size) + + self._request_rdma_client.connect_to_server( + server_ip=self._request_server_ip, + port=self._request_handshake_port, + ) + + remote_info = self._request_rdma_client.remote_info + base_addr = int(remote_info["addr"]) + descriptor = RDMABufferDescriptor( + slot_addr=base_addr + 16, + slot_bytes=self._request_slots * self._request_slot_size, + slot_size=self._request_slot_size, + buffer_size=self._request_slots, + head_addr=base_addr, + tail_addr=base_addr + 8, + rkey=int(remote_info.get("rkey", 0)), + ) + self._request_rdma_buffer = RDMABuffer( + role="client", + rdma_client=self._request_rdma_client, + remote=descriptor, + ) + self.logger.info( + "Connected request RDMA buffer: host=%s port=%s slots=%s slot_size=%s", + self._request_server_ip, + self._request_handshake_port, + self._request_slots, + self._request_slot_size, + ) + return True + + def _ensure_phase1_meta_buffer(self) -> bool: + if self._phase1_rdma_buffer is not None: + return True + + now = time.time() + if now - self._last_phase1_connect_retry_ts < 1.0: + return False + self._last_phase1_connect_retry_ts = now + + if self._phase1_rdma_client is None: + self._phase1_rdma_client = RDMAClient(local_buffer_size=self._phase1_slot_size) + + self._phase1_rdma_client.connect_to_server( + server_ip=self._phase1_server_ip, + port=self._phase1_handshake_port, + ) + + remote_info = self._phase1_rdma_client.remote_info + base_addr = int(remote_info["addr"]) + descriptor = RDMABufferDescriptor( + slot_addr=base_addr + 16, + slot_bytes=self._phase1_slots * self._phase1_slot_size, + slot_size=self._phase1_slot_size, + buffer_size=self._phase1_slots, + head_addr=base_addr, + tail_addr=base_addr + 8, + rkey=int(remote_info.get("rkey", 0)), + ) + self._phase1_rdma_buffer = RDMABuffer( + role="client", + rdma_client=self._phase1_rdma_client, + remote=descriptor, + ) + self.logger.info( + "Connected phase1 RDMA buffer: host=%s port=%s slots=%s slot_size=%s", + self._phase1_server_ip, + self._phase1_handshake_port, + self._phase1_slots, + self._phase1_slot_size, + ) + return True def init(self, config): self.config = config - self.text_encoder = None - self.image_encoder = None - self.vae_encoder = None + shared_slots = int(self.config.get("rdma_buffer_slots", self._request_slots)) + shared_slot_size = int(self.config.get("rdma_buffer_slot_size", 4096)) + self._request_server_ip = str(self.config.get("rdma_request_host", self._request_server_ip)) + self._request_handshake_port = int(self.config.get("rdma_request_handshake_port", self._request_handshake_port)) + self._request_slots = shared_slots + self._request_slot_size = shared_slot_size + self._phase1_server_ip = str(self.config.get("rdma_phase1_host", self._phase1_server_ip)) + self._phase1_handshake_port = int(self.config.get("rdma_phase1_handshake_port", self._phase1_handshake_port)) + self._phase1_slots = shared_slots + self._phase1_slot_size = shared_slot_size self.encoder_engine_rank = int(self.config.get("encoder_engine_rank", 0)) self.transformer_engine_rank = int(self.config.get("transformer_engine_rank", 1)) self.decoder_engine_rank = int(self.config.get("decoder_engine_rank", 2)) - # Load models based on config - self.load_models() - # Seed everything if seed is in config if "seed" in self.config: seed_all(self.config["seed"]) @@ -56,6 +178,18 @@ def init(self, config): data_bootstrap_addr = self.config.get("data_bootstrap_addr", "127.0.0.1") data_bootstrap_room = self.config.get("data_bootstrap_room", 0) + phase1_deadline = time.time() + 30.0 + while self._phase1_rdma_buffer is None and time.time() < phase1_deadline: + try: + self._ensure_phase1_meta_buffer() + except Exception: + self.logger.exception("Failed to connect phase1 RDMA buffer, will retry") + if self._phase1_rdma_buffer is None: + time.sleep(0.1) + + if self._phase1_rdma_buffer is None: + raise RuntimeError("phase1 RDMA buffer is not ready") + if data_bootstrap_addr is None or data_bootstrap_room is None: return @@ -78,6 +212,13 @@ def init(self, config): self.data_mgr.init(data_args, data_bootstrap_room) self.data_sender[data_bootstrap_room] = DataSender(self.data_mgr, data_bootstrap_addr, data_bootstrap_room) + phase1_meta = { + "request_config": dict(self.config), + "encoder_node_address": self.data_mgr.get_localhost(), + "encoder_session_id": self.data_mgr.get_session_id(), + } + self._phase1_rdma_buffer.produce(phase1_meta) + def load_models(self): self.logger.info("Loading Encoder Models...") @@ -355,6 +496,10 @@ def remove(self, room: int): def release(self): for room in list(self._rdma_buffers.keys()): self.remove(room) + self.reporter.stop() + if self._reporter_thread is not None and self._reporter_thread.is_alive(): + self._reporter_thread.join(timeout=1.0) + self._reporter_thread = None if self.data_mgr is not None: self.data_mgr.release() self.data_sender.clear() @@ -368,14 +513,17 @@ def run(self, stop_event=None): complete_queue: Dict[int, dict] = {} while True: - # config = self.req_mgr.receive(self.request_port) - # req_queue.append(config) - while True: - config = self.req_mgr.receive_non_block(self.request_port) - if config is None: - break - self.logger.info("Received request config: %s", {k: v for k, v in config.items() if not k.endswith("_path")}) - req_queue.append(config) + if self._request_rdma_buffer is None: + try: + self._ensure_request_buffer() + except Exception: + self.logger.exception("Failed to connect request RDMA buffer, will retry") + + if self._request_rdma_buffer is not None: + config = self._request_rdma_buffer.consume() + if config is not None: + self.logger.info("Received request config from RDMA buffer: %s", {k: v for k, v in config.items() if not k.endswith("_path")}) + req_queue.append(config) if req_queue: config = req_queue.popleft() diff --git a/lightx2v/disagg/services/transformer.py b/lightx2v/disagg/services/transformer.py index 15363b589..0446199b0 100644 --- a/lightx2v/disagg/services/transformer.py +++ b/lightx2v/disagg/services/transformer.py @@ -1,5 +1,6 @@ import hashlib import json +import threading import time from collections import deque from typing import List, Optional @@ -7,8 +8,11 @@ import numpy as np import torch -from lightx2v.disagg.conn import REQUEST_POLLING_PORT, DataArgs, DataManager, DataPoll, DataReceiver, DataSender, DisaggregationMode, DisaggregationPhase, ReqManager +from lightx2v.disagg.conn import MONITOR_POLLING_PORT, DataArgs, DataManager, DataPoll, DataReceiver, DataSender, DisaggregationMode, DisaggregationPhase +from lightx2v.disagg.monitor import Reporter from lightx2v.disagg.protocol import AllocationRequest, MemoryHandle, RemoteBuffer +from lightx2v.disagg.rdma_buffer import RDMABuffer, RDMABufferDescriptor +from lightx2v.disagg.rdma_client import RDMAClient from lightx2v.disagg.services.base import BaseService from lightx2v.disagg.utils import ( estimate_encoder_buffer_sizes, @@ -22,10 +26,28 @@ class TransformerService(BaseService): - def __init__(self): + def __init__(self, config: dict): super().__init__() - self.request_port = REQUEST_POLLING_PORT + 1 - self.req_mgr = ReqManager() + self.config = config + self.encoder_engine_rank = int(self.config.get("encoder_engine_rank", 0)) + self.transformer_engine_rank = int(self.config.get("transformer_engine_rank", 1)) + self.decoder_engine_rank = int(self.config.get("decoder_engine_rank", 2)) + self._phase1_rdma_client: Optional[RDMAClient] = None + self._phase1_rdma_buffer: Optional[RDMABuffer] = None + self._phase2_rdma_client: Optional[RDMAClient] = None + self._phase2_rdma_buffer: Optional[RDMABuffer] = None + shared_slots = int(self.config.get("rdma_buffer_slots", "128")) + shared_slot_size = int(self.config.get("rdma_buffer_slot_size", "4096")) + self._phase1_server_ip = str(self.config.get("rdma_phase1_host", "127.0.0.1")) + self._phase1_handshake_port = int(self.config.get("rdma_phase1_handshake_port", "5567")) + self._phase1_slots = shared_slots + self._phase1_slot_size = shared_slot_size + self._phase2_server_ip = str(self.config.get("rdma_phase2_host", "127.0.0.1")) + self._phase2_handshake_port = int(self.config.get("rdma_phase2_handshake_port", "5568")) + self._phase2_slots = shared_slots + self._phase2_slot_size = shared_slot_size + self._last_phase1_connect_retry_ts = 0.0 + self._last_phase2_connect_retry_ts = 0.0 self.transformer = None self.scheduler = None self.rdma_buffer1: dict[int, List[torch.Tensor]] = {} @@ -34,17 +56,91 @@ def __init__(self): self.data_mgr2 = DataManager(DisaggregationPhase.PHASE2, DisaggregationMode.TRANSFORMER) self.data_receiver: dict[int, DataReceiver] = {} self.data_sender: dict[int, DataSender] = {} + self.reporter = Reporter( + service_type="transformer", + gpu_id=self.transformer_engine_rank, + bind_address=f"tcp://{self.config.get('data_bootstrap_addr', '127.0.0.1')}:{MONITOR_POLLING_PORT + self.transformer_engine_rank}", + ) + self._reporter_thread: Optional[threading.Thread] = threading.Thread( + target=self.reporter.serve_forever, + name="transformer-reporter", + daemon=True, + ) + self._reporter_thread.start() + self.load_models() + + def _ensure_phase1_request_buffer(self) -> bool: + if self._phase1_rdma_buffer is not None: + return True + now = time.time() + if now - self._last_phase1_connect_retry_ts < 1.0: + return False + self._last_phase1_connect_retry_ts = now + + if self._phase1_rdma_client is None: + self._phase1_rdma_client = RDMAClient(local_buffer_size=self._phase1_slot_size) + self._phase1_rdma_client.connect_to_server(self._phase1_server_ip, self._phase1_handshake_port) + remote_info = self._phase1_rdma_client.remote_info + base_addr = int(remote_info["addr"]) + self._phase1_rdma_buffer = RDMABuffer( + role="client", + rdma_client=self._phase1_rdma_client, + remote=RDMABufferDescriptor( + slot_addr=base_addr + 16, + slot_bytes=self._phase1_slots * self._phase1_slot_size, + slot_size=self._phase1_slot_size, + buffer_size=self._phase1_slots, + head_addr=base_addr, + tail_addr=base_addr + 8, + rkey=int(remote_info.get("rkey", 0)), + ), + ) + return True + + def _ensure_phase2_meta_buffer(self) -> bool: + if self._phase2_rdma_buffer is not None: + return True + now = time.time() + if now - self._last_phase2_connect_retry_ts < 1.0: + return False + self._last_phase2_connect_retry_ts = now + + if self._phase2_rdma_client is None: + self._phase2_rdma_client = RDMAClient(local_buffer_size=self._phase2_slot_size) + self._phase2_rdma_client.connect_to_server(self._phase2_server_ip, self._phase2_handshake_port) + remote_info = self._phase2_rdma_client.remote_info + base_addr = int(remote_info["addr"]) + self._phase2_rdma_buffer = RDMABuffer( + role="client", + rdma_client=self._phase2_rdma_client, + remote=RDMABufferDescriptor( + slot_addr=base_addr + 16, + slot_bytes=self._phase2_slots * self._phase2_slot_size, + slot_size=self._phase2_slot_size, + buffer_size=self._phase2_slots, + head_addr=base_addr, + tail_addr=base_addr + 8, + rkey=int(remote_info.get("rkey", 0)), + ), + ) + return True def init(self, config): self.config = config - self.transformer = None - self.scheduler = None + shared_slots = int(self.config.get("rdma_buffer_slots", self._phase1_slots)) + shared_slot_size = int(self.config.get("rdma_buffer_slot_size", 4096)) + self._phase1_server_ip = str(self.config.get("rdma_phase1_host", self._phase1_server_ip)) + self._phase1_handshake_port = int(self.config.get("rdma_phase1_handshake_port", self._phase1_handshake_port)) + self._phase1_slots = shared_slots + self._phase1_slot_size = shared_slot_size + self._phase2_server_ip = str(self.config.get("rdma_phase2_host", self._phase2_server_ip)) + self._phase2_handshake_port = int(self.config.get("rdma_phase2_handshake_port", self._phase2_handshake_port)) + self._phase2_slots = shared_slots + self._phase2_slot_size = shared_slot_size self.encoder_engine_rank = int(self.config.get("encoder_engine_rank", 0)) self.transformer_engine_rank = int(self.config.get("transformer_engine_rank", 1)) self.decoder_engine_rank = int(self.config.get("decoder_engine_rank", 2)) - self.load_models() - # Set global seed if present in config, though specific process calls might reuse it if "seed" in self.config: seed_all(self.config["seed"]) @@ -55,6 +151,20 @@ def init(self, config): if data_bootstrap_addr is None or data_bootstrap_room is None: return + phase_deadline = time.time() + 30.0 + while time.time() < phase_deadline: + try: + self._ensure_phase1_request_buffer() + self._ensure_phase2_meta_buffer() + except Exception: + self.logger.exception("Failed to connect phase RDMA buffers, will retry") + if self._phase1_rdma_buffer is not None and self._phase2_rdma_buffer is not None: + break + time.sleep(0.1) + + if self._phase1_rdma_buffer is None or self._phase2_rdma_buffer is None: + raise RuntimeError("phase RDMA buffers are not ready") + buffer_sizes = estimate_encoder_buffer_sizes(self.config) request = AllocationRequest( bootstrap_room=data_bootstrap_room, @@ -72,7 +182,8 @@ def init(self, config): ib_device=None, ) self.data_mgr1.init(data_args, data_bootstrap_room) - self.data_receiver[data_bootstrap_room] = DataReceiver(self.data_mgr1, data_bootstrap_addr, data_bootstrap_room) + phase1_bootstrap_addr = str(self.config.get("encoder_node_address", data_bootstrap_addr)) + self.data_receiver[data_bootstrap_room] = DataReceiver(self.data_mgr1, phase1_bootstrap_addr, data_bootstrap_room) self.data_receiver[data_bootstrap_room].init() buffer_sizes = estimate_transformer_buffer_sizes(self.config) @@ -94,6 +205,14 @@ def init(self, config): self.data_mgr2.init(data_args, data_bootstrap_room) self.data_sender[data_bootstrap_room] = DataSender(self.data_mgr2, data_bootstrap_addr, data_bootstrap_room) + self._phase2_rdma_buffer.produce( + { + "request_config": dict(self.config), + "transformer_node_address": self.data_mgr2.get_localhost(), + "transformer_session_id": self.data_mgr2.get_session_id(), + } + ) + def load_models(self): self.logger.info("Loading Transformer Models...") @@ -377,6 +496,10 @@ def release(self): room_ids = set(self.rdma_buffer1.keys()) | set(self.rdma_buffer2.keys()) for room in list(room_ids): self.remove(room) + self.reporter.stop() + if self._reporter_thread is not None and self._reporter_thread.is_alive(): + self._reporter_thread.join(timeout=1.0) + self._reporter_thread = None if self.data_mgr1 is not None: self.data_mgr1.release() if self.data_mgr2 is not None: @@ -393,11 +516,21 @@ def run(self, stop_event=None): complete_queue: dict[int, dict] = {} while True: - while True: - config = self.req_mgr.receive_non_block(self.request_port) - if config is None: - break - req_queue.append(config) + if self._phase1_rdma_buffer is None: + try: + self._ensure_phase1_request_buffer() + except Exception: + self.logger.exception("Failed to connect phase1 request RDMA buffer, will retry") + + if self._phase1_rdma_buffer is not None: + packet = self._phase1_rdma_buffer.consume() + if packet is not None: + if isinstance(packet, dict) and "request_config" in packet: + config = dict(packet.get("request_config") or {}) + config["encoder_node_address"] = packet.get("encoder_node_address", config.get("encoder_node_address", "127.0.0.1")) + else: + config = packet + req_queue.append(config) if req_queue: config = req_queue.popleft() diff --git a/scripts/disagg/kill_service.sh b/scripts/disagg/kill_service.sh new file mode 100755 index 000000000..63e3601fc --- /dev/null +++ b/scripts/disagg/kill_service.sh @@ -0,0 +1,84 @@ +#!/bin/bash + +set -euo pipefail + +SCRIPT_NAME="run_wan_t2v_service.sh" +PORTS=(7788 7789 7790 12788 12789 12790 17788 17789 17790 27788 27789 27790) + +kill_pid_gracefully() { + local pid="$1" + if [[ -z "$pid" ]]; then + return + fi + if kill -0 "$pid" 2>/dev/null; then + kill "$pid" 2>/dev/null || true + sleep 1 + if kill -0 "$pid" 2>/dev/null; then + kill -9 "$pid" 2>/dev/null || true + fi + fi +} + +find_listen_pids_by_port() { + local port="$1" + + if command -v lsof >/dev/null 2>&1; then + lsof -nP -t -iTCP:"$port" -sTCP:LISTEN 2>/dev/null | sort -u || true + return + fi + + if command -v ss >/dev/null 2>&1; then + ss -ltnp 2>/dev/null | awk -v p=":$port" ' + index($4, p) > 0 { + while (match($0, /pid=[0-9]+/)) { + print substr($0, RSTART + 4, RLENGTH - 4) + $0 = substr($0, RSTART + RLENGTH) + } + } + ' | sort -u || true + return + fi + + if command -v fuser >/dev/null 2>&1; then + fuser -n tcp "$port" 2>/dev/null | tr ' ' '\n' | sed '/^$/d' | sort -u || true + return + fi + + echo "No supported tool found to query listening ports (need one of: lsof, ss, fuser)." >&2 +} + +echo "Stopping script process: ${SCRIPT_NAME}" +script_pids=$(pgrep -f "$SCRIPT_NAME" || true) +if [[ -n "${script_pids}" ]]; then + while read -r pid; do + [[ -z "$pid" ]] && continue + echo "Killing script pid=$pid" + kill_pid_gracefully "$pid" + done <<< "$script_pids" +else + echo "No running process found for ${SCRIPT_NAME}" +fi + +for port in "${PORTS[@]}"; do + echo "Stopping listeners on port ${port}" + port_pids=$(find_listen_pids_by_port "$port") + if [[ -z "${port_pids}" ]]; then + echo "No listener found on port ${port}" + continue + fi + + while read -r pid; do + [[ -z "$pid" ]] && continue + echo "Killing pid=$pid on port ${port}" + kill_pid_gracefully "$pid" + done <<< "$port_pids" + + remaining=$(find_listen_pids_by_port "$port") + if [[ -n "${remaining}" ]]; then + echo "Warning: port ${port} still has listeners: ${remaining}" + else + echo "Port ${port} is clear" + fi +done + +echo "kill_service.sh done." diff --git a/scripts/disagg/run_wan_t2v_service.sh b/scripts/disagg/run_wan_t2v_service.sh new file mode 100755 index 000000000..f7dd1ce8e --- /dev/null +++ b/scripts/disagg/run_wan_t2v_service.sh @@ -0,0 +1,147 @@ +#!/bin/bash + +# set path firstly +lightx2v_path=/root/zht/LightX2V +model_path=/root/zht/LightX2V/models/Wan-AI/Wan2.1-T2V-1.3B + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +# Keep flashinfer enabled while ensuring nvcc uses a supported host compiler. +export CC=/usr/bin/gcc-13 +export CXX=/usr/bin/g++-13 +export CUDAHOSTCXX=/usr/bin/g++-13 +if [[ -n "${NVCC_PREPEND_FLAGS:-}" ]]; then + export NVCC_PREPEND_FLAGS="${NVCC_PREPEND_FLAGS} -allow-unsupported-compiler" +else + export NVCC_PREPEND_FLAGS="-allow-unsupported-compiler" +fi + +controller_cfg=${lightx2v_path}/configs/disagg/wan_t2v_disagg_controller.json +encoder_cfg=${lightx2v_path}/configs/disagg/wan_t2v_disagg_encoder.json +transformer_cfg=${lightx2v_path}/configs/disagg/wan_t2v_disagg_transformer.json +decoder_cfg=${lightx2v_path}/configs/disagg/wan_t2v_disagg_decoder.json + +seed=42 +prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." +negative_prompt="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +save_result_path=${lightx2v_path}/save_results/test_disagg.mp4 +save_result_path_1=${save_result_path%.mp4}1.mp4 +save_result_path_2=${save_result_path%.mp4}2.mp4 + +# Remove old outputs so wait loop reflects current run status. +rm -f "${save_result_path_1}" "${save_result_path_2}" + +cleanup() { + local pids=("${encoder_pid:-}" "${transformer_pid:-}" "${decoder_pid:-}" "${controller_pid:-}") + for pid in "${pids[@]}"; do + if [[ -n "${pid}" ]] && kill -0 "${pid}" 2>/dev/null; then + kill "${pid}" 2>/dev/null || true + fi + done +} + +trap cleanup EXIT INT TERM + +wait_for_port() { + local host="$1" + local port="$2" + local timeout_secs="${3:-30}" + local waited=0 + + while true; do + if (echo > /dev/tcp/${host}/${port}) >/dev/null 2>&1; then + echo "Port ready: ${host}:${port}" + return 0 + fi + + if (( waited >= timeout_secs )); then + echo "Timeout waiting for port ${host}:${port} after ${timeout_secs}s" + return 1 + fi + + sleep 1 + waited=$((waited + 1)) + done +} + +rdma_request_port=5566 +rdma_phase1_port=5567 +rdma_phase2_port=5568 + +python -m lightx2v.disagg.examples.run_service \ + --service controller \ + --model_cls wan2.1 \ + --task t2v \ + --model_path ${model_path} \ + --config_json ${controller_cfg} \ + --seed ${seed} \ + --prompt "${prompt}" \ + --negative_prompt "${negative_prompt}" \ + --save_result_path ${save_result_path} \ + > ${lightx2v_path}/save_results/disagg_controller.log 2>&1 & +controller_pid=$! + +wait_for_port 127.0.0.1 ${rdma_request_port} 60 +wait_for_port 127.0.0.1 ${rdma_phase1_port} 60 +wait_for_port 127.0.0.1 ${rdma_phase2_port} 60 + +CUDA_VISIBLE_DEVICES=0 python -m lightx2v.disagg.examples.run_service \ + --service encoder \ + --model_cls wan2.1 \ + --task t2v \ + --model_path ${model_path} \ + --config_json ${encoder_cfg} \ + --seed ${seed} \ + --prompt "${prompt}" \ + --negative_prompt "${negative_prompt}" \ + --save_result_path ${save_result_path} \ + > ${lightx2v_path}/save_results/disagg_encoder.log 2>&1 & +encoder_pid=$! + +CUDA_VISIBLE_DEVICES=1 python -m lightx2v.disagg.examples.run_service \ + --service transformer \ + --model_cls wan2.1 \ + --task t2v \ + --model_path ${model_path} \ + --config_json ${transformer_cfg} \ + --seed ${seed} \ + --prompt "${prompt}" \ + --negative_prompt "${negative_prompt}" \ + --save_result_path ${save_result_path} \ + > ${lightx2v_path}/save_results/disagg_transformer.log 2>&1 & +transformer_pid=$! + +CUDA_VISIBLE_DEVICES=2 python -m lightx2v.disagg.examples.run_service \ + --service decoder \ + --model_cls wan2.1 \ + --task t2v \ + --model_path ${model_path} \ + --config_json ${decoder_cfg} \ + --seed ${seed} \ + --prompt "${prompt}" \ + --negative_prompt "${negative_prompt}" \ + --save_result_path ${save_result_path} \ + > ${lightx2v_path}/save_results/disagg_decoder.log 2>&1 & +decoder_pid=$! + +# Give background services time to flush and finish queued requests. + +echo "Waiting for output videos: ${save_result_path_1}, ${save_result_path_2}" +wait_seconds=0 +max_wait_seconds=1200 + +while true; do + if [[ -f "${save_result_path_1}" && -f "${save_result_path_2}" ]]; then + echo "Both output videos are generated." + break + fi + + if (( wait_seconds >= max_wait_seconds )); then + echo "Timeout waiting for output videos after ${max_wait_seconds}s" + exit 1 + fi + + sleep 5 + wait_seconds=$((wait_seconds + 5)) +done