Support Neo++ Model #965

Merged

llmc-reviewer merged 32 commits into main from neo on Mar 30, 2026

Conversation

@helloyongyang (Contributor)

No description provided.

@gemini-code-assist (bot) left a comment


Code Review

This pull request implements the NeoPP model architecture, featuring both dense and Mixture-of-Experts (MoE) versions. Key additions include optimized Triton kernels for fused dual-RMSNorm and 3D Neox-RoPE, a new NeoPP runner and scheduler, and updated weight modules to support these operations. The review feedback highlights several areas for improvement, such as removing dead code and commented-out blocks, replacing hardcoded file paths with configurable parameters, fixing typos (e.g., "sle" for "self"), and ensuring better device-agnosticism by avoiding direct ".cuda()" calls. Additionally, suggestions were made to use "_get_actual_weight()" for consistent LoRA support and to refactor duplicated logic within the Triton kernels to improve maintainability.

Comment on lines +8 to +28
model_path="/data/nvme1/yongyang/FL/neo9b/neo9b",
model_cls="neopp",
task="t2i",
)

pipe.create_generator(config_json="../../configs/neopp/neopp_dense_t2i.json")
pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False})


# -------------------------------------------------
# Load KV cache and generate
# -------------------------------------------------

pipe.runner.load_kvcache(
"/data/nvme1/yongyang/FL/neo_test9b/vlm_tensor/to_x2v_cond_kv.pt",
"/data/nvme1/yongyang/FL/neo_test9b/vlm_tensor/to_x2v_uncond_kv.pt",
)

pipe.generate(
seed=200,
save_result_path="/data/nvme1/yongyang/FL/LightX2V/save_results/output_lightx2v_neopp_dense_t2i_1k.png",

medium

The file paths for the model, KV cache, and output are hardcoded. This makes the script less portable and harder to use in different environments. It's a good practice to use command-line arguments to pass these paths, which makes the script more flexible and reusable.

For example, you could use Python's argparse module to define arguments for these paths.
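To illustrate the reviewer's suggestion, here is a minimal argparse sketch; the argument names and defaults are hypothetical, not part of the PR:

```python
import argparse

# Hypothetical CLI front-end for the example script; flag names and the
# default seed are illustrative assumptions, not part of the actual PR.
def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="NeoPP dense t2i example")
    parser.add_argument("--model-path", required=True,
                        help="Path to the NeoPP model directory")
    parser.add_argument("--cond-kv-path", required=True,
                        help="Path to the conditional KV cache (.pt)")
    parser.add_argument("--uncond-kv-path", required=True,
                        help="Path to the unconditional KV cache (.pt)")
    parser.add_argument("--save-result-path", required=True,
                        help="Where to write the generated image")
    parser.add_argument("--seed", type=int, default=200)
    return parser.parse_args(argv)
```

The hardcoded strings in the snippet above would then become `args.model_path`, `args.cond_kv_path`, and so on.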

Comment on lines +8 to +28
model_path="/data/nvme1/yongyang/FL/neo_gen_30b_moe/neo_gen_30b_moe",
model_cls="neopp",
task="t2i",
)

pipe.create_generator(config_json="../../configs/neopp/neopp_moe_t2i.json")
pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False})


# -------------------------------------------------
# Load KV cache and generate
# -------------------------------------------------

pipe.runner.load_kvcache(
"/data/nvme1/yongyang/FL/neo_test/vlm_tensor/to_x2v_cond_kv.pt",
"/data/nvme1/yongyang/FL/neo_test/vlm_tensor/to_x2v_uncond_kv.pt",
)

pipe.generate(
seed=200,
save_result_path="/data/nvme1/yongyang/FL/LightX2V/save_results/output_lightx2v_neopp_moe_t2i_512.png",

medium

The file paths for the model, KV cache, and output are hardcoded. This makes the script less portable. Consider using command-line arguments (e.g., with argparse) to make the script more flexible.

Also, the filename neopp_meo_512.py seems to contain a typo and should likely be neopp_moe_512.py.

hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
return self.weight * hidden_states.to(input_dtype)

medium

To ensure support for LoRA, _get_actual_weight() should be used here instead of accessing self.weight directly. This is consistent with how other weight modules handle LoRA.

Suggested change
return self.weight * hidden_states.to(input_dtype)
return self._get_actual_weight() * hidden_states.to(input_dtype)

Comment on lines +434 to +441
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
eps=1e-6,
lora_prefix="diffusion_model.blocks",
lora_path="",

medium

The __init__ method of RMSWeightFusedQKNorm3DRope includes several parameters (create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, lora_prefix, lora_path) that are not used within the class. This appears to be a copy-paste artifact and adds unnecessary complexity. These unused parameters should be removed to clean up the code.

A similar issue exists in RMSWeightDualNorm3DRope.

        eps=1e-6,
    ):

Comment on lines +1076 to +1170
if pid < q_total:
    s = pid // q_num_heads
    h = pid % q_num_heads
    base = (s * q_num_heads + h) * HEAD_DIM

    xt1 = tl.load(q_ptr + base + offs_q).to(tl.float32)
    xt2 = tl.load(q_ptr + base + QUARTER + offs_q).to(tl.float32)
    var_t = (tl.sum(xt1 * xt1) + tl.sum(xt2 * xt2)) * (1.0 / HALF) + eps
    irms_t = tl.rsqrt(var_t)
    wt1 = tl.load(w_qt_ptr + offs_q).to(tl.float32)
    wt2 = tl.load(w_qt_ptr + QUARTER + offs_q).to(tl.float32)
    xt1 = xt1 * irms_t * wt1
    xt2 = xt2 * irms_t * wt2
    c_t = tl.load(cos_t_ptr + s * HALF + offs_q).to(tl.float32)
    s_t = tl.load(sin_t_ptr + s * HALF + offs_q).to(tl.float32)
    new_xt1 = xt1 * c_t - xt2 * s_t
    new_xt2 = xt2 * c_t + xt1 * s_t

    xh1 = tl.load(q_ptr + base + HALF + offs_e).to(tl.float32)
    xh2 = tl.load(q_ptr + base + HALF + EIGHTH + offs_e).to(tl.float32)
    xw1 = tl.load(q_ptr + base + HALF + QUARTER + offs_e).to(tl.float32)
    xw2 = tl.load(q_ptr + base + HALF + QUARTER + EIGHTH + offs_e).to(tl.float32)
    var_hw = (tl.sum(xh1 * xh1) + tl.sum(xh2 * xh2) + tl.sum(xw1 * xw1) + tl.sum(xw2 * xw2)) * (1.0 / HALF) + eps
    irms_hw = tl.rsqrt(var_hw)
    wh1 = tl.load(w_qhw_ptr + offs_e).to(tl.float32)
    wh2 = tl.load(w_qhw_ptr + EIGHTH + offs_e).to(tl.float32)
    ww1 = tl.load(w_qhw_ptr + QUARTER + offs_e).to(tl.float32)
    ww2 = tl.load(w_qhw_ptr + QUARTER + EIGHTH + offs_e).to(tl.float32)
    xh1 = xh1 * irms_hw * wh1
    xh2 = xh2 * irms_hw * wh2
    xw1 = xw1 * irms_hw * ww1
    xw2 = xw2 * irms_hw * ww2
    c_h = tl.load(cos_h_ptr + s * QUARTER + offs_e).to(tl.float32)
    s_h = tl.load(sin_h_ptr + s * QUARTER + offs_e).to(tl.float32)
    new_xh1 = xh1 * c_h - xh2 * s_h
    new_xh2 = xh2 * c_h + xh1 * s_h
    c_w = tl.load(cos_w_ptr + s * QUARTER + offs_e).to(tl.float32)
    s_w = tl.load(sin_w_ptr + s * QUARTER + offs_e).to(tl.float32)
    new_xw1 = xw1 * c_w - xw2 * s_w
    new_xw2 = xw2 * c_w + xw1 * s_w

    tl.store(q_ptr + base + offs_q, new_xt1.to(tl.bfloat16))
    tl.store(q_ptr + base + QUARTER + offs_q, new_xt2.to(tl.bfloat16))
    tl.store(q_ptr + base + HALF + offs_e, new_xh1.to(tl.bfloat16))
    tl.store(q_ptr + base + HALF + EIGHTH + offs_e, new_xh2.to(tl.bfloat16))
    tl.store(q_ptr + base + HALF + QUARTER + offs_e, new_xw1.to(tl.bfloat16))
    tl.store(q_ptr + base + HALF + QUARTER + EIGHTH + offs_e, new_xw2.to(tl.bfloat16))
else:
    k_pid = pid - q_total
    s = k_pid // k_num_heads
    h = k_pid % k_num_heads
    base = (s * k_num_heads + h) * HEAD_DIM

    xt1 = tl.load(k_ptr + base + offs_q).to(tl.float32)
    xt2 = tl.load(k_ptr + base + QUARTER + offs_q).to(tl.float32)
    var_t = (tl.sum(xt1 * xt1) + tl.sum(xt2 * xt2)) * (1.0 / HALF) + eps
    irms_t = tl.rsqrt(var_t)
    wt1 = tl.load(w_kt_ptr + offs_q).to(tl.float32)
    wt2 = tl.load(w_kt_ptr + QUARTER + offs_q).to(tl.float32)
    xt1 = xt1 * irms_t * wt1
    xt2 = xt2 * irms_t * wt2
    c_t = tl.load(cos_t_ptr + s * HALF + offs_q).to(tl.float32)
    s_t = tl.load(sin_t_ptr + s * HALF + offs_q).to(tl.float32)
    new_xt1 = xt1 * c_t - xt2 * s_t
    new_xt2 = xt2 * c_t + xt1 * s_t

    xh1 = tl.load(k_ptr + base + HALF + offs_e).to(tl.float32)
    xh2 = tl.load(k_ptr + base + HALF + EIGHTH + offs_e).to(tl.float32)
    xw1 = tl.load(k_ptr + base + HALF + QUARTER + offs_e).to(tl.float32)
    xw2 = tl.load(k_ptr + base + HALF + QUARTER + EIGHTH + offs_e).to(tl.float32)
    var_hw = (tl.sum(xh1 * xh1) + tl.sum(xh2 * xh2) + tl.sum(xw1 * xw1) + tl.sum(xw2 * xw2)) * (1.0 / HALF) + eps
    irms_hw = tl.rsqrt(var_hw)
    wh1 = tl.load(w_khw_ptr + offs_e).to(tl.float32)
    wh2 = tl.load(w_khw_ptr + EIGHTH + offs_e).to(tl.float32)
    ww1 = tl.load(w_khw_ptr + QUARTER + offs_e).to(tl.float32)
    ww2 = tl.load(w_khw_ptr + QUARTER + EIGHTH + offs_e).to(tl.float32)
    xh1 = xh1 * irms_hw * wh1
    xh2 = xh2 * irms_hw * wh2
    xw1 = xw1 * irms_hw * ww1
    xw2 = xw2 * irms_hw * ww2
    c_h = tl.load(cos_h_ptr + s * QUARTER + offs_e).to(tl.float32)
    s_h = tl.load(sin_h_ptr + s * QUARTER + offs_e).to(tl.float32)
    new_xh1 = xh1 * c_h - xh2 * s_h
    new_xh2 = xh2 * c_h + xh1 * s_h
    c_w = tl.load(cos_w_ptr + s * QUARTER + offs_e).to(tl.float32)
    s_w = tl.load(sin_w_ptr + s * QUARTER + offs_e).to(tl.float32)
    new_xw1 = xw1 * c_w - xw2 * s_w
    new_xw2 = xw2 * c_w + xw1 * s_w

    tl.store(k_ptr + base + offs_q, new_xt1.to(tl.bfloat16))
    tl.store(k_ptr + base + QUARTER + offs_q, new_xt2.to(tl.bfloat16))
    tl.store(k_ptr + base + HALF + offs_e, new_xh1.to(tl.bfloat16))
    tl.store(k_ptr + base + HALF + EIGHTH + offs_e, new_xh2.to(tl.bfloat16))
    tl.store(k_ptr + base + HALF + QUARTER + offs_e, new_xw1.to(tl.bfloat16))
    tl.store(k_ptr + base + HALF + QUARTER + EIGHTH + offs_e, new_xw2.to(tl.bfloat16))

medium

There is significant code duplication between the if pid < q_total: block (for Q) and the else: block (for K). The logic for RMSNorm, RoPE application, and storing results is nearly identical.

To improve maintainability and readability, consider refactoring the common logic into a separate tl.device_func that can be called for both Q and K with the appropriate pointers and parameters.
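To show the shape of the suggested refactor without claiming anything about the actual kernel, here is a plain-NumPy sketch of hoisting the shared RMSNorm-plus-rotation math into one helper that both the Q and K branches call; in the real Triton kernel the helper would instead be a separate @triton.jit function taking the relevant pointers, and the function name here is hypothetical:

```python
import numpy as np

# Illustrative refactor pattern (NOT the Triton kernel): the logic that is
# duplicated between the Q and K branches -- joint RMS normalization of a
# pair of vector halves, weight scaling, then a rotary transform -- lives in
# one shared helper instead of two copied blocks.
def rmsnorm_rope(x1, x2, w1, w2, cos, sin, eps=1e-6):
    # Variance over the concatenated pair, matching (sum sq) / total length.
    n = x1.size + x2.size
    var = (np.sum(x1 * x1) + np.sum(x2 * x2)) / n + eps
    irms = 1.0 / np.sqrt(var)
    x1n = x1 * irms * w1
    x2n = x2 * irms * w2
    # Standard rotary rotation of the normalized pair.
    return x1n * cos - x2n * sin, x2n * cos + x1n * sin

rng = np.random.default_rng(0)
d = 8
q1, q2 = rng.normal(size=d), rng.normal(size=d)
k1, k2 = rng.normal(size=d), rng.normal(size=d)
w = np.ones(d)
cos, sin = np.cos(np.arange(d) * 0.1), np.sin(np.arange(d) * 0.1)

# The same helper serves both branches; only the inputs differ.
new_q1, new_q2 = rmsnorm_rope(q1, q2, w, w, cos, sin)
new_k1, new_k2 = rmsnorm_rope(k1, k2, w, w, cos, sin)
```

The same parameterization (buffer pointer plus per-branch weight pointers) is what would distinguish the Q and K call sites in the kernel.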

Comment on lines +102 to +112
if self.config.get("load_kv_cache_in_pipeline_for_debug", False):
    if self.config.get("version", "moe") == "moe":
        self.load_kvcache(
            "/data/nvme1/yongyang/FL/neo_test/vlm_tensor/to_x2v_cond_kv.pt",
            "/data/nvme1/yongyang/FL/neo_test/vlm_tensor/to_x2v_uncond_kv.pt",
        )
    else:
        self.load_kvcache(
            "/data/nvme1/yongyang/FL/neo_test9b/vlm_tensor/to_x2v_cond_kv.pt",
            "/data/nvme1/yongyang/FL/neo_test9b/vlm_tensor/to_x2v_uncond_kv.pt",
        )

medium

The file paths for the KV cache are hardcoded within a debug block. While this is for debugging, hardcoding user-specific paths (/data/nvme1/yongyang/...) is not ideal. It would be better to pass these paths through the configuration, even for debugging purposes, to make the code more portable.
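As one way to realize this, the debug paths could come from the same JSON config the runner already reads; the two path keys below are hypothetical names, not keys that exist in the PR (only `load_kv_cache_in_pipeline_for_debug` appears in the diff):

```json
{
  "load_kv_cache_in_pipeline_for_debug": true,
  "debug_cond_kv_path": "/path/to/to_x2v_cond_kv.pt",
  "debug_uncond_kv_path": "/path/to/to_x2v_uncond_kv.pt"
}
```

The runner would then call `self.load_kvcache(self.config["debug_cond_kv_path"], self.config["debug_uncond_kv_path"])` instead of hardcoding user-specific paths.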

return inv_freq_full[::2]

def _compute_rope(self, position_ids, inv_freq):
inv_freq = inv_freq.cuda()

medium

Using .cuda() hardcodes the device. It's better practice to use .to(device) to make the code more device-agnostic. In this case, using the device of position_ids would be appropriate.

Suggested change
inv_freq = inv_freq.cuda()
inv_freq = inv_freq.to(position_ids.device)

logger.info(f"✅ Image saved successfully to: {self.input_info.save_result_path} ✅")
return grid_image

def _denorm(self, x: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):

medium

The _denorm method does not use any instance attributes (self). It can be defined as a static method using the @staticmethod decorator. This makes it clear that the method's behavior does not depend on the state of the object.

    @staticmethod
    def _denorm(x: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):

Comment on lines +69 to +81
# elif self.time_schedule == "dynamic":
# """
# not support yet
# """
# mu = self._calculate_dynamic_mu(image_seq_len)
# mu_t = t.new_tensor(mu)
# if self.time_shift_type == "exponential":
# shift = torch.exp(mu_t)
# sigma = shift * sigma / (1 + (shift - 1) * sigma)
# elif self.time_shift_type == "linear":
# sigma = mu_t / (mu_t + (1 / sigma - 1))
# else:
# raise ValueError(f"Unsupported time_shift_type: {self.time_shift_type}")

medium

This elif block for self.time_schedule == "dynamic" is commented out and marked as 'not support yet'. This dead code should be removed to clean up the implementation. If this feature is planned for the future, it should be tracked in an issue rather than left as commented-out code.

Comment on lines +4 to +5
lightx2v_path=/data/nvme1/yongyang/FL/LightX2V
model_path=/data/nvme1/yongyang/FL/neo9b/neo9b

medium

The script contains hardcoded paths specific to a user's environment (e.g., /data/nvme1/yongyang/...). This makes the script not portable. It's better to use environment variables that can be set externally or passed as arguments to the script.

Suggested change
lightx2v_path=/data/nvme1/yongyang/FL/LightX2V
model_path=/data/nvme1/yongyang/FL/neo9b/neo9b
lightx2v_path=${LIGHTX2V_PATH:-/path/to/LightX2V}
model_path=${MODEL_PATH:-/path/to/model}

@llmc-reviewer llmc-reviewer merged commit f6e7696 into main Mar 30, 2026
2 checks passed
@llmc-reviewer llmc-reviewer deleted the neo branch March 30, 2026 08:44
