-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[Relax][Op] Fixed incorrect output shape of Pool op when ceil_mode = true #18641
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…true - Skip the last window as it would start in the bottom padded region
Summary of ChangesHello @locnd182644, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request resolves a critical bug affecting the output shape calculation of pooling operations when Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request fixes an issue with incorrect output shape calculation for pooling operations when ceil_mode=True. The change correctly implements the logic to handle edge cases for the last pooling window, aligning the behavior with frameworks like PyTorch. The fix is applied consistently across topi and relax pooling implementations. The logic appears sound and correctly addresses the reported bug. I have one suggestion to refactor the newly added logic in src/relax/op/nn/pooling.cc to reduce code duplication and improve maintainability.
| if (attrs->ceil_mode) { | ||
| PrimExpr invalid_last_w = (raw_out_w - 1) * attrs->strides[0] >= input_w + attrs->padding[0]; | ||
| out_NCW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w)); | ||
| } else { | ||
| out_NCW_shape[2] = analyzer->Simplify(raw_out_w); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic to handle ceil_mode is duplicated across InferStructInfoPool1D, InferStructInfoPool2D, and InferStructInfoPool3D. To improve code maintainability and reduce redundancy, consider extracting this logic into a common helper function. For example:
// Placed within the relax namespace
inline PrimExpr InferPoolOutputSize(arith::Analyzer* analyzer, PrimExpr input_size,
PrimExpr raw_out_size, PrimExpr stride,
PrimExpr padding_before, bool ceil_mode) {
if (ceil_mode) {
PrimExpr invalid_last = (raw_out_size - 1) * stride >= input_size + padding_before;
return analyzer->Simplify(if_then_else(invalid_last, raw_out_size - 1, raw_out_size));
} else {
return analyzer->Simplify(raw_out_size);
}
}Then, this block and similar blocks in InferStructInfoPool2D and InferStructInfoPool3D can be simplified to a single call, for instance:
out_NCW_shape[2] = InferPoolOutputSize(analyzer, input_w, raw_out_w, attrs->strides[0], attrs->padding[0], attrs->ceil_mode);
tlopex
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks!
Summary
Fixed incorrect output shape of Pool op when ceil_mode = true
Steps to Reproduce
Example: Create Pool Operator from PyTorch
Expected
Resolve