Code Review
This pull request implements the NeoPP model architecture, featuring both dense and Mixture-of-Experts (MoE) versions. Key additions include optimized Triton kernels for fused dual-RMSNorm and 3D Neox-RoPE, a new NeoPP runner and scheduler, and updated weight modules to support these operations. The review feedback highlights several areas for improvement, such as removing dead code and commented-out blocks, replacing hardcoded file paths with configurable parameters, fixing typos (e.g., "sle" for "self"), and ensuring better device-agnosticism by avoiding direct ".cuda()" calls. Additionally, suggestions were made to use "_get_actual_weight()" for consistent LoRA support and to refactor duplicated logic within the Triton kernels to improve maintainability.
```python
    model_path="/data/nvme1/yongyang/FL/neo9b/neo9b",
    model_cls="neopp",
    task="t2i",
)

pipe.create_generator(config_json="../../configs/neopp/neopp_dense_t2i.json")
pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False})

# -------------------------------------------------
# Load KV cache and generate
# -------------------------------------------------

pipe.runner.load_kvcache(
    "/data/nvme1/yongyang/FL/neo_test9b/vlm_tensor/to_x2v_cond_kv.pt",
    "/data/nvme1/yongyang/FL/neo_test9b/vlm_tensor/to_x2v_uncond_kv.pt",
)

pipe.generate(
    seed=200,
    save_result_path="/data/nvme1/yongyang/FL/LightX2V/save_results/output_lightx2v_neopp_dense_t2i_1k.png",
```
The file paths for the model, KV cache, and output are hardcoded. This makes the script less portable and harder to use in different environments. It's a good practice to use command-line arguments to pass these paths, which makes the script more flexible and reusable.
For example, you could use Python's argparse module to define arguments for these paths.
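A minimal sketch of that approach. The argument names below are illustrative, not taken from the PR:

```python
import argparse

def parse_args(argv=None):
    # Hypothetical CLI for the demo script; argument names are illustrative.
    parser = argparse.ArgumentParser(description="NeoPP t2i demo")
    parser.add_argument("--model_path", required=True, help="path to the NeoPP checkpoint")
    parser.add_argument("--cond_kv_path", required=True, help="conditional KV-cache .pt file")
    parser.add_argument("--uncond_kv_path", required=True, help="unconditional KV-cache .pt file")
    parser.add_argument("--save_result_path", default="output.png", help="where to write the image")
    return parser.parse_args(argv)

args = parse_args([
    "--model_path", "/models/neo9b",
    "--cond_kv_path", "cond_kv.pt",
    "--uncond_kv_path", "uncond_kv.pt",
])
print(args.model_path)  # → /models/neo9b
```

The parsed values would then replace the hardcoded strings in the pipeline calls.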
```python
    model_path="/data/nvme1/yongyang/FL/neo_gen_30b_moe/neo_gen_30b_moe",
    model_cls="neopp",
    task="t2i",
)

pipe.create_generator(config_json="../../configs/neopp/neopp_moe_t2i.json")
pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False})

# -------------------------------------------------
# Load KV cache and generate
# -------------------------------------------------

pipe.runner.load_kvcache(
    "/data/nvme1/yongyang/FL/neo_test/vlm_tensor/to_x2v_cond_kv.pt",
    "/data/nvme1/yongyang/FL/neo_test/vlm_tensor/to_x2v_uncond_kv.pt",
)

pipe.generate(
    seed=200,
    save_result_path="/data/nvme1/yongyang/FL/LightX2V/save_results/output_lightx2v_neopp_moe_t2i_512.png",
```
```python
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
return self.weight * hidden_states.to(input_dtype)
```
To ensure support for LoRA, _get_actual_weight() should be used here instead of accessing self.weight directly. This is consistent with how other weight modules handle LoRA.
```diff
- return self.weight * hidden_states.to(input_dtype)
+ return self._get_actual_weight() * hidden_states.to(input_dtype)
```
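For readers unfamiliar with the pattern, here is a toy sketch of why a `_get_actual_weight()`-style accessor matters for LoRA. This is an illustration only, not LightX2V's actual implementation:

```python
class NormWeight:
    """Toy weight holder illustrating why _get_actual_weight() matters for LoRA."""

    def __init__(self, weight):
        self.weight = weight      # base checkpoint weight
        self.lora_delta = None    # populated when a LoRA adapter is applied

    def _get_actual_weight(self):
        # Return the LoRA-adjusted weight when an adapter is active,
        # otherwise the base weight. Accessing self.weight directly
        # would silently ignore the adapter.
        if self.lora_delta is not None:
            return self.weight + self.lora_delta
        return self.weight

w = NormWeight(2.0)
print(w._get_actual_weight())  # → 2.0 (no adapter)
w.lora_delta = 0.5
print(w._get_actual_weight())  # → 2.5 (adapter applied)
```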
```python
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
eps=1e-6,
lora_prefix="diffusion_model.blocks",
lora_path="",
```
The __init__ method of RMSWeightFusedQKNorm3DRope includes several parameters (create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, lora_prefix, lora_path) that are not used within the class. This appears to be a copy-paste artifact and adds unnecessary complexity. These unused parameters should be removed to clean up the code.
A similar issue exists in RMSWeightDualNorm3DRope.
```python
    eps=1e-6,
):
```

```python
if pid < q_total:
    s = pid // q_num_heads
    h = pid % q_num_heads
    base = (s * q_num_heads + h) * HEAD_DIM

    xt1 = tl.load(q_ptr + base + offs_q).to(tl.float32)
    xt2 = tl.load(q_ptr + base + QUARTER + offs_q).to(tl.float32)
    var_t = (tl.sum(xt1 * xt1) + tl.sum(xt2 * xt2)) * (1.0 / HALF) + eps
    irms_t = tl.rsqrt(var_t)
    wt1 = tl.load(w_qt_ptr + offs_q).to(tl.float32)
    wt2 = tl.load(w_qt_ptr + QUARTER + offs_q).to(tl.float32)
    xt1 = xt1 * irms_t * wt1
    xt2 = xt2 * irms_t * wt2
    c_t = tl.load(cos_t_ptr + s * HALF + offs_q).to(tl.float32)
    s_t = tl.load(sin_t_ptr + s * HALF + offs_q).to(tl.float32)
    new_xt1 = xt1 * c_t - xt2 * s_t
    new_xt2 = xt2 * c_t + xt1 * s_t

    xh1 = tl.load(q_ptr + base + HALF + offs_e).to(tl.float32)
    xh2 = tl.load(q_ptr + base + HALF + EIGHTH + offs_e).to(tl.float32)
    xw1 = tl.load(q_ptr + base + HALF + QUARTER + offs_e).to(tl.float32)
    xw2 = tl.load(q_ptr + base + HALF + QUARTER + EIGHTH + offs_e).to(tl.float32)
    var_hw = (tl.sum(xh1 * xh1) + tl.sum(xh2 * xh2) + tl.sum(xw1 * xw1) + tl.sum(xw2 * xw2)) * (1.0 / HALF) + eps
    irms_hw = tl.rsqrt(var_hw)
    wh1 = tl.load(w_qhw_ptr + offs_e).to(tl.float32)
    wh2 = tl.load(w_qhw_ptr + EIGHTH + offs_e).to(tl.float32)
    ww1 = tl.load(w_qhw_ptr + QUARTER + offs_e).to(tl.float32)
    ww2 = tl.load(w_qhw_ptr + QUARTER + EIGHTH + offs_e).to(tl.float32)
    xh1 = xh1 * irms_hw * wh1
    xh2 = xh2 * irms_hw * wh2
    xw1 = xw1 * irms_hw * ww1
    xw2 = xw2 * irms_hw * ww2
    c_h = tl.load(cos_h_ptr + s * QUARTER + offs_e).to(tl.float32)
    s_h = tl.load(sin_h_ptr + s * QUARTER + offs_e).to(tl.float32)
    new_xh1 = xh1 * c_h - xh2 * s_h
    new_xh2 = xh2 * c_h + xh1 * s_h
    c_w = tl.load(cos_w_ptr + s * QUARTER + offs_e).to(tl.float32)
    s_w = tl.load(sin_w_ptr + s * QUARTER + offs_e).to(tl.float32)
    new_xw1 = xw1 * c_w - xw2 * s_w
    new_xw2 = xw2 * c_w + xw1 * s_w

    tl.store(q_ptr + base + offs_q, new_xt1.to(tl.bfloat16))
    tl.store(q_ptr + base + QUARTER + offs_q, new_xt2.to(tl.bfloat16))
    tl.store(q_ptr + base + HALF + offs_e, new_xh1.to(tl.bfloat16))
    tl.store(q_ptr + base + HALF + EIGHTH + offs_e, new_xh2.to(tl.bfloat16))
    tl.store(q_ptr + base + HALF + QUARTER + offs_e, new_xw1.to(tl.bfloat16))
    tl.store(q_ptr + base + HALF + QUARTER + EIGHTH + offs_e, new_xw2.to(tl.bfloat16))
else:
    k_pid = pid - q_total
    s = k_pid // k_num_heads
    h = k_pid % k_num_heads
    base = (s * k_num_heads + h) * HEAD_DIM

    xt1 = tl.load(k_ptr + base + offs_q).to(tl.float32)
    xt2 = tl.load(k_ptr + base + QUARTER + offs_q).to(tl.float32)
    var_t = (tl.sum(xt1 * xt1) + tl.sum(xt2 * xt2)) * (1.0 / HALF) + eps
    irms_t = tl.rsqrt(var_t)
    wt1 = tl.load(w_kt_ptr + offs_q).to(tl.float32)
    wt2 = tl.load(w_kt_ptr + QUARTER + offs_q).to(tl.float32)
    xt1 = xt1 * irms_t * wt1
    xt2 = xt2 * irms_t * wt2
    c_t = tl.load(cos_t_ptr + s * HALF + offs_q).to(tl.float32)
    s_t = tl.load(sin_t_ptr + s * HALF + offs_q).to(tl.float32)
    new_xt1 = xt1 * c_t - xt2 * s_t
    new_xt2 = xt2 * c_t + xt1 * s_t

    xh1 = tl.load(k_ptr + base + HALF + offs_e).to(tl.float32)
    xh2 = tl.load(k_ptr + base + HALF + EIGHTH + offs_e).to(tl.float32)
    xw1 = tl.load(k_ptr + base + HALF + QUARTER + offs_e).to(tl.float32)
    xw2 = tl.load(k_ptr + base + HALF + QUARTER + EIGHTH + offs_e).to(tl.float32)
    var_hw = (tl.sum(xh1 * xh1) + tl.sum(xh2 * xh2) + tl.sum(xw1 * xw1) + tl.sum(xw2 * xw2)) * (1.0 / HALF) + eps
    irms_hw = tl.rsqrt(var_hw)
    wh1 = tl.load(w_khw_ptr + offs_e).to(tl.float32)
    wh2 = tl.load(w_khw_ptr + EIGHTH + offs_e).to(tl.float32)
    ww1 = tl.load(w_khw_ptr + QUARTER + offs_e).to(tl.float32)
    ww2 = tl.load(w_khw_ptr + QUARTER + EIGHTH + offs_e).to(tl.float32)
    xh1 = xh1 * irms_hw * wh1
    xh2 = xh2 * irms_hw * wh2
    xw1 = xw1 * irms_hw * ww1
    xw2 = xw2 * irms_hw * ww2
    c_h = tl.load(cos_h_ptr + s * QUARTER + offs_e).to(tl.float32)
    s_h = tl.load(sin_h_ptr + s * QUARTER + offs_e).to(tl.float32)
    new_xh1 = xh1 * c_h - xh2 * s_h
    new_xh2 = xh2 * c_h + xh1 * s_h
    c_w = tl.load(cos_w_ptr + s * QUARTER + offs_e).to(tl.float32)
    s_w = tl.load(sin_w_ptr + s * QUARTER + offs_e).to(tl.float32)
    new_xw1 = xw1 * c_w - xw2 * s_w
    new_xw2 = xw2 * c_w + xw1 * s_w

    tl.store(k_ptr + base + offs_q, new_xt1.to(tl.bfloat16))
    tl.store(k_ptr + base + QUARTER + offs_q, new_xt2.to(tl.bfloat16))
    tl.store(k_ptr + base + HALF + offs_e, new_xh1.to(tl.bfloat16))
    tl.store(k_ptr + base + HALF + EIGHTH + offs_e, new_xh2.to(tl.bfloat16))
    tl.store(k_ptr + base + HALF + QUARTER + offs_e, new_xw1.to(tl.bfloat16))
    tl.store(k_ptr + base + HALF + QUARTER + EIGHTH + offs_e, new_xw2.to(tl.bfloat16))
```
There is significant code duplication between the `if pid < q_total:` block (for Q) and the `else:` block (for K): the RMSNorm, RoPE application, and store logic are nearly identical.
To improve maintainability and readability, consider factoring the common logic into a separate `@triton.jit` helper function (Triton inlines jitted functions called from a kernel) that can be invoked for both Q and K with the appropriate pointers and parameters.
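The refactoring idea can be illustrated outside Triton. In the real kernel the helper would be a second `@triton.jit` function; here NumPy stands in to show how one routine can serve both the Q and K branches:

```python
import numpy as np

def norm_and_rotate(x1, x2, w1, w2, cos, sin, eps=1e-6):
    # Shared Q/K logic from the kernel: RMS-normalize the pair over all
    # of its elements, scale by the norm weights, then apply the rotation.
    var = (np.sum(x1 * x1) + np.sum(x2 * x2)) / (x1.size + x2.size) + eps
    irms = 1.0 / np.sqrt(var)
    x1, x2 = x1 * irms * w1, x2 * irms * w2
    return x1 * cos - x2 * sin, x2 * cos + x1 * sin

rng = np.random.default_rng(0)
q1, q2 = rng.standard_normal(4), rng.standard_normal(4)
k1, k2 = rng.standard_normal(4), rng.standard_normal(4)
w = np.ones(4)
cos, sin = np.cos(0.3) * w, np.sin(0.3) * w

# One helper, two call sites -- instead of duplicating the body for Q and K.
new_q1, new_q2 = norm_and_rotate(q1, q2, w, w, cos, sin)
new_k1, new_k2 = norm_and_rotate(k1, k2, w, w, cos, sin)
```

With unit weights, the RMS normalization fixes the total energy of each pair at its element count, and the rotation preserves it; that invariant makes the refactored helper easy to unit-test once for both branches.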
```python
if self.config.get("load_kv_cache_in_pipeline_for_debug", False):
    if self.config.get("version", "moe") == "moe":
        self.load_kvcache(
            "/data/nvme1/yongyang/FL/neo_test/vlm_tensor/to_x2v_cond_kv.pt",
            "/data/nvme1/yongyang/FL/neo_test/vlm_tensor/to_x2v_uncond_kv.pt",
        )
    else:
        self.load_kvcache(
            "/data/nvme1/yongyang/FL/neo_test9b/vlm_tensor/to_x2v_cond_kv.pt",
            "/data/nvme1/yongyang/FL/neo_test9b/vlm_tensor/to_x2v_uncond_kv.pt",
        )
```
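These debug KV-cache paths are also user-specific. In the same spirit as the other path feedback, one hedged alternative is to read them from the runner config; the key names below are illustrative, not part of LightX2V's config schema:

```python
def resolve_debug_kv_paths(config):
    # Hypothetical helper: pull the debug KV-cache paths out of the runner
    # config instead of hardcoding user-specific locations.
    paths = config.get("debug_kv_cache_paths", {})
    if "cond" not in paths or "uncond" not in paths:
        raise KeyError("debug_kv_cache_paths must define 'cond' and 'uncond'")
    return paths["cond"], paths["uncond"]

cfg = {"debug_kv_cache_paths": {"cond": "cond_kv.pt", "uncond": "uncond_kv.pt"}}
print(resolve_debug_kv_paths(cfg))  # → ('cond_kv.pt', 'uncond_kv.pt')
```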
```python
    return inv_freq_full[::2]

def _compute_rope(self, position_ids, inv_freq):
    inv_freq = inv_freq.cuda()
```
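The review summary flags direct `.cuda()` calls like this one. A common device-agnostic pattern, shown here as a standalone sketch rather than the PR's actual `_compute_rope`, is to follow the device of an input tensor:

```python
import torch

def compute_rope(position_ids, inv_freq):
    # Move inv_freq to wherever position_ids lives instead of forcing CUDA,
    # so the same code runs on CPU, CUDA, or other accelerators.
    inv_freq = inv_freq.to(device=position_ids.device, dtype=torch.float32)
    freqs = torch.outer(position_ids.to(torch.float32), inv_freq)
    return torch.cos(freqs), torch.sin(freqs)

pos = torch.arange(4)            # CPU tensor; no .cuda() anywhere
cos, sin = compute_rope(pos, torch.tensor([1.0, 0.5]))
print(cos.shape)  # → torch.Size([4, 2])
```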
```python
    logger.info(f"✅ Image saved successfully to: {self.input_info.save_result_path} ✅")
    return grid_image

def _denorm(self, x: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
```
The _denorm method does not use any instance attributes (self). It can be defined as a static method using the @staticmethod decorator. This makes it clear that the method's behavior does not depend on the state of the object.
```python
@staticmethod
def _denorm(x: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
```

```python
# elif self.time_schedule == "dynamic":
#     """
#     not support yet
#     """
#     mu = self._calculate_dynamic_mu(image_seq_len)
#     mu_t = t.new_tensor(mu)
#     if self.time_shift_type == "exponential":
#         shift = torch.exp(mu_t)
#         sigma = shift * sigma / (1 + (shift - 1) * sigma)
#     elif self.time_shift_type == "linear":
#         sigma = mu_t / (mu_t + (1 / sigma - 1))
#     else:
#         raise ValueError(f"Unsupported time_shift_type: {self.time_shift_type}")
```
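The `@staticmethod` suggestion for `_denorm` above can be checked standalone. This sketch uses NumPy in place of torch, and tuple defaults (an optional extra tweak that avoids shared mutable list defaults):

```python
import numpy as np

class Sampler:
    # Minimal stand-in class; only _denorm is relevant here.
    @staticmethod
    def _denorm(x, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        # Undo (x - mean) / std channel-wise: maps [-1, 1] back to [0, 1].
        return x * np.asarray(std) + np.asarray(mean)

# Callable on the class itself -- no instance state involved.
out = Sampler._denorm(np.zeros((2, 2, 3)))
print(out[0, 0])  # → [0.5 0.5 0.5]
```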
```shell
lightx2v_path=/data/nvme1/yongyang/FL/LightX2V
model_path=/data/nvme1/yongyang/FL/neo9b/neo9b
```
The script contains hardcoded paths specific to a user's environment (e.g., /data/nvme1/yongyang/...). This makes the script not portable. It's better to use environment variables that can be set externally or passed as arguments to the script.
```diff
- lightx2v_path=/data/nvme1/yongyang/FL/LightX2V
- model_path=/data/nvme1/yongyang/FL/neo9b/neo9b
+ lightx2v_path=${LIGHTX2V_PATH:-/path/to/LightX2V}
+ model_path=${MODEL_PATH:-/path/to/model}
```
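With the `${VAR:-default}` form, callers can override the paths per environment without editing the script. A short sketch (the script name `run.sh` is illustrative):

```shell
# Defaults apply when the variables are unset...
lightx2v_path=${LIGHTX2V_PATH:-/path/to/LightX2V}
model_path=${MODEL_PATH:-/path/to/model}
echo "model: ${model_path}"

# ...and callers override them without touching the script:
#   MODEL_PATH=/mnt/models/neo9b LIGHTX2V_PATH=/opt/LightX2V bash run.sh
```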