Code Review
This pull request introduces the NeoPP model (Mixture-of-Transformer-Experts) with dense and MoE variants, featuring optimized Triton kernels for fused RMSNorm and 3D RoPE, a KV cache manager, and a new runner. The pipeline and input info utilities were refactored to support dynamic task handling. Feedback highlights several issues, including a typo in unpatchify, potential TypeError in the runner's cleanup logic, usage of mutable default arguments, and hardcoded device calls. Additionally, the removal of global task validation should be addressed to maintain safety checks across models.
lightx2v/models/networks/neopp/model.py (96)
Typo in method parameter name: sle should be self.
def unpatchify(self, x, patch_size, h=None, w=None):
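For reference, this is the patch-reassembly pattern the corrected method implements. The sketch below is a NumPy illustration (the actual method operates on torch tensors inside the model, and its internals are not shown in this diff):

```python
import numpy as np

def unpatchify(x, patch_size, h, w):
    """Reassemble flattened patches into an image grid.

    Illustrative sketch: x has shape [num_patches, p*p*c] with
    num_patches == (h // p) * (w // p); returns an [h, w, c] array.
    """
    p = patch_size
    gh, gw = h // p, w // p
    c = x.shape[-1] // (p * p)
    x = x.reshape(gh, gw, p, p, c)       # grid of patches
    x = x.transpose(0, 2, 1, 3, 4)       # interleave patch rows/cols
    return x.reshape(gh * p, gw * p, c)  # full image
```

The round trip with the corresponding patchify reshape recovers the original image exactly.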
lightx2v/models/runners/default_runner.py (543)
The _empty_cache variable might be None if torch_device_module does not have an empty_cache attribute. Calling it directly will cause a TypeError.
if _empty_cache is not None:
_empty_cache()
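A minimal sketch of the guarded lookup-and-call pattern (the helper name is hypothetical; `torch_device_module` stands in for whatever backend module the runner resolved, e.g. `torch.cuda`):

```python
def empty_device_cache(torch_device_module):
    """Release cached allocator memory if the backend supports it.

    Not every backend (CPU, NPU, XPU, ...) exposes empty_cache, so
    resolve the attribute with a None default and guard before calling.
    """
    _empty_cache = getattr(torch_device_module, "empty_cache", None)
    if callable(_empty_cache):
        _empty_cache()
        return True
    return False
```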
lightx2v/utils/input_info.py (322-323)
Avoid using mutable default arguments like support_tasks=[]. This can lead to unexpected behavior as the list is shared across all calls to the function.
def init_empty_input_info(task, support_tasks=None):
if support_tasks is None or len(support_tasks) == 0:
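To illustrate why the `None` sentinel matters (simplified stand-in functions, not the actual body of `init_empty_input_info`): a mutable default is evaluated once at definition time, so state leaks across calls.

```python
def register_task_buggy(task, support_tasks=[]):
    # The default list is created once when the def executes,
    # so every call without the argument mutates the same object.
    support_tasks.append(task)
    return support_tasks

def register_task_fixed(task, support_tasks=None):
    # None sentinel: a fresh list is created on each call.
    if support_tasks is None:
        support_tasks = []
    support_tasks.append(task)
    return support_tasks
```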
lightx2v/models/runners/neopp/neopp_runner.py (55)
Hardcoding .cuda() prevents the model from running on other supported devices (e.g., CPU, NPU, XPU). Use the device of the input tensors instead.
inv_freq = inv_freq.to(position_ids.device)
lightx2v/models/runners/neopp/neopp_runner.py (76)
This line contains a 'Need Check' comment with multiple exclamation marks, suggesting that the logic for N might be incomplete or requires verification. Please resolve this before merging.
lightx2v/models/runners/neopp/neopp_runner.py (208)
Using np.uint8 requires numpy to be imported as np in this module. Alternatively, pass the string "uint8" to astype, which avoids depending on the alias at this call site.
image = (image.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy() * 255.0).round().astype("uint8")
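The two spellings resolve to the same dtype, as this small NumPy check shows:

```python
import numpy as np

arr = np.array([0.0, 0.25, 1.0, 2.0])
scaled = (np.clip(arr, 0, 1) * 255.0).round()
# Both spellings produce identical uint8 arrays:
as_str = scaled.astype("uint8")
as_alias = scaled.astype(np.uint8)
```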
lightx2v/models/runners/neopp/neopp_runner.py (214-219)
Avoid using mutable default arguments for mean and std. Use None as the default value and initialize them inside the method.
def _denorm(self, x: torch.Tensor, mean=None, std=None):
"""
x: [B,3,H,W] normalized ((img-mean)/std). returns [0,1] clamped.
"""
mean = torch.tensor(mean if mean is not None else [0.5, 0.5, 0.5], device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
std = torch.tensor(std if std is not None else [0.5, 0.5, 0.5], device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
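For completeness, the full suggested shape of the method, translated to NumPy so it is self-contained here (the model's method takes a torch tensor of shape [B, 3, H, W]; the [0.5, 0.5, 0.5] defaults are taken from the original signature):

```python
import numpy as np

def denorm(x, mean=None, std=None):
    """Undo (img - mean) / std normalization and clamp to [0, 1].

    None defaults avoid the shared-mutable-default pitfall; the
    concrete values are only materialized inside the call.
    """
    mean = np.asarray([0.5, 0.5, 0.5] if mean is None else mean,
                      dtype=x.dtype).reshape(1, 3, 1, 1)
    std = np.asarray([0.5, 0.5, 0.5] if std is None else std,
                     dtype=x.dtype).reshape(1, 3, 1, 1)
    return np.clip(x * std + mean, 0.0, 1.0)
```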
lightx2v/models/networks/neopp/infer/pre_infer.py (181-184)
Moving RoPE cache tensors to the device in every extract_feature call is inefficient. This should be done once during initialization or lazily.
if self.cos_cached_x.device != patch_embeds.device:
self.cos_cached_x = self.cos_cached_x.to(patch_embeds.device)
self.sin_cached_x = self.sin_cached_x.to(patch_embeds.device)
self.cos_cached_y = self.cos_cached_y.to(patch_embeds.device)
self.sin_cached_y = self.sin_cached_y.to(patch_embeds.device)
lightx2v/infer.py (173)
The task argument validation has been commented out and the underlying function validate_task_arguments was removed from utils.py. This removes important safety checks for all models. If the validation logic needs to change for Neo++, it should be updated rather than disabled globally.