
4609: Add AUC-Margin Loss for AUROC optimization#8719

Open
shubham-61969 wants to merge 7 commits into Project-MONAI:dev from shubham-61969:4609-aucm-loss

Conversation

@shubham-61969
Contributor

Fixes #4609.

Description

This PR adds an implementation of AUC-Margin Loss (AUCM) for direct AUROC optimization in MONAI, based on:

Yuan et al., “Large-scale Robust Deep AUC Maximization: A New Surrogate Loss and Empirical Studies on Medical Image Classification”, ICCV 2021.

Implementation based on:
https://github.com/Optimization-AI/LibAUC/blob/1.4.0/libauc/losses/auc.py

The loss is designed for imbalanced classification problems, which are common in medical imaging, where AUROC is often the primary evaluation metric. The implementation follows MONAI’s loss conventions, is fully PyTorch-native, and does not introduce any new dependencies.
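As a rough reference for readers, the v2 form of the objective can be sketched in a few lines of plain Python. The function name `aucm_v2` and its zero-valued defaults are illustrative placeholders, not this PR's API; in the actual loss, a, b, and alpha are learnable parameters updated during training.

```python
# Illustrative sketch of the AUCM surrogate (v2 form) from Yuan et al., ICCV 2021.
# a, b, alpha are fixed here for clarity; the real loss learns them.
def aucm_v2(scores, labels, a=0.0, b=0.0, alpha=0.0, margin=1.0):
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]

    def mean(xs):
        return sum(xs) / len(xs) if xs else 0.0

    # class-conditional squared distances to the centroids a (positives) and b (negatives)
    term_pos = mean([(s - a) ** 2 for s in pos])
    term_neg = mean([(s - b) ** 2 for s in neg])
    # the margin term pushes the mean positive score above the mean negative score by `margin`
    return term_pos + term_neg + 2 * alpha * (margin + mean(neg) - mean(pos)) - alpha ** 2
```

For example, with scores [0.9, 0.1], labels [1, 0], and the zero-initialized parameters, this evaluates to 0.9² + 0.1² = 0.82.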

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.

@coderabbitai
Contributor

coderabbitai bot commented Jan 25, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds a new AUCMLoss class implementing a margin-based squared-hinge surrogate for AUROC optimization (monai/losses/aucm_loss.py) and exports it from monai/losses/__init__.py. AUCMLoss introduces learnable scalar parameters a, b, and alpha; supports two versions ("v1" with optional class-prior imratio, "v2" without); validates inputs (minimum 2 dims, single-channel binary targets, matching shapes, binary values); flattens tensors and builds positive/negative masks; computes global or class-conditional masked means via helpers _global_mean and _class_mean; and combines terms into the version-specific loss. Adds unit tests (tests/losses/test_aucm_loss.py) covering both versions, invalid constructor and runtime inputs, backward pass, and TorchScript serialization.
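The two masked-mean helpers mentioned in the walkthrough can be pictured roughly as follows; this is a sketch of the described semantics (global vs class-conditional normalization), not the PR's exact code, and the standalone function names are illustrative.

```python
import torch

def global_mean(t: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # "global" expectation: zeros outside the mask still count in the
    # denominator, i.e. divide by the total number of elements N
    return (t * mask).mean()

def class_mean(t: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # class-conditional expectation: divide by the number of masked-in elements
    denom = mask.sum()
    return (t * mask).sum() / denom if denom > 0 else torch.zeros((), dtype=t.dtype)
```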

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed Title clearly and specifically describes the main change: adding AUC-Margin Loss for AUROC optimization, matching the pull request's core contribution.
Description check ✅ Passed Description includes all essential elements: issue reference (#4609), clear explanation of the feature, research basis, implementation source, and checked boxes for testing and documentation updates.
Linked Issues check ✅ Passed PR fully addresses issue #4609 by implementing AUCMLoss for AUROC optimization in medical image classification, with PyTorch-native implementation, proper docstrings, and comprehensive tests.
Out of Scope Changes check ✅ Passed All changes are scoped to the AUCMLoss feature: new loss module, public API export, and corresponding unit tests. No extraneous modifications present.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@monai/losses/aucm_loss.py`:
- Around line 109-135: The masked-mean is wrong because _safe_mean uses
tensor.count_nonzero() which drops valid zero residuals; change _safe_mean to
accept an optional mask parameter (e.g., def _safe_mean(self, tensor:
torch.Tensor, mask: Optional[torch.Tensor]=None)) and use mask.sum() as the
denominator when a mask is provided (falling back to tensor.numel() or
tensor.sum() logic when mask is None), returning a zero tensor on denom == 0;
update all calls (the three places calling self._safe_mean((input - self.a) ** 2
* pos_mask), self._safe_mean((input - self.b) ** 2 * neg_mask),
self._safe_mean(p * input * neg_mask - (1 - p) * input * pos_mask), and the
other input*pos/neg calls) to pass the corresponding pos_mask or neg_mask as the
mask argument so the mean divides by mask.sum() instead of count_nonzero().
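The flagged issue, and the suggested fix, can be pictured with a small standalone sketch. The function names are illustrative; the point is that `count_nonzero()` silently drops masked-in elements whose residual happens to be exactly zero, while dividing by `mask.sum()` does not.

```python
import torch

def safe_mean_buggy(masked_vals: torch.Tensor) -> torch.Tensor:
    # divides by the number of NONZERO entries, so a valid zero residual
    # (e.g. input exactly equal to the learnable center a) is dropped
    n = masked_vals.count_nonzero()
    return masked_vals.sum() / n if n > 0 else torch.zeros(())

def safe_mean_fixed(masked_vals: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # divides by the mask size, as the review suggests
    n = mask.sum()
    return masked_vals.sum() / n if n > 0 else torch.zeros(())
```

With residuals [0.0, 2.0] for two masked-in samples, the buggy version divides by 1 and returns 2.0, while the fixed version divides by 2 and returns 1.0.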
🧹 Nitpick comments (2)
tests/losses/test_aucm_loss.py (1)

22-74: Add Google‑style docstrings for the test class/methods.
Docstrings are missing on TestAUCMLoss and its test methods; add brief Google‑style docstrings (Args/Returns as needed). As per coding guidelines, ...

monai/losses/aucm_loss.py (1)

90-98: Add a Returns section to the forward docstring.
Document the scalar loss output in Google style. As per coding guidelines, ...

✍️ Suggested docstring tweak
@@
         Args:
             input: the shape should be B1HW[D], where the channel dimension is 1 for binary classification.
             target: the shape should be B1HW[D], with values 0 or 1.
 
+        Returns:
+            torch.Tensor: scalar AUCM loss.
+
         Raises:
             ValueError: When input or target have incorrect shapes.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@monai/losses/aucm_loss.py`:
- Around line 101-112: Validate that target contains only binary values before
creating pos_mask/neg_mask: after the existing shape checks and after flattening
target (the variable target), assert all elements are either 0 or 1 (e.g., check
torch.logical_or(target == 0, target == 1).all()) and raise a ValueError with a
clear message if not; this prevents creation of incorrect pos_mask/neg_mask from
non-binary targets and keeps the subsequent code that builds pos_mask and
neg_mask unchanged.
- Around line 51-66: The constructor accepts a reduction parameter that is never
used and forward always returns a scalar; remove the unused reduction parameter
from the __init__ signature (and any stored attribute), update the docstring to
stop documenting reduction, and adjust any callers/tests that pass reduction;
alternatively, if you want to preserve the API, explicitly store reduction in
__init__ and add a clear docstring note (or raise if it's not the supported
scalar-only behavior) so users know forward returns a batch-level scalar —
locate references to reduction in the class constructor and the forward method
to implement one of these two fixes.
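A minimal sketch of the binary-target validation suggested in the first item above (illustrative helper, not the PR's code):

```python
import torch

def check_binary(target: torch.Tensor) -> None:
    # reject non-binary targets before building pos_mask/neg_mask,
    # as the review suggests
    if not torch.logical_or(target == 0, target == 1).all():
        raise ValueError("AUCMLoss expects binary targets containing only 0 and 1.")
```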

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@monai/losses/aucm_loss.py`:
- Around line 77-85: The constructor currently assigns imratio without
validation; add a bounds check in the __init__ of the AUCMLoss (or the class
that contains self.imratio) to ensure imratio is within a valid range (e.g., 0 <
imratio < 1) and raise a ValueError with a clear message if out of range,
placing the check near the existing version validation before setting
self.imratio so invalid values are rejected early.
- Around line 100-105: The guard currently accesses input.shape[1] and
target.shape[1] without verifying tensor rank, which will raise IndexError for
1D tensors; update the checks in the AUCM loss (where variables input and target
are validated) to first assert or check the rank/ndim (e.g., len(input.shape) or
input.ndim) before indexing shape[1], raising a clear ValueError if tensors are
1D or have unexpected rank, then validate that the channel dimension equals 1
and that input.shape matches target.shape; this ensures safe access to shape[1]
and clearer error messages.

In `@tests/losses/test_aucm_loss.py`:
- Around line 72-78: The test_non_binary_target currently fails due to a shape
mismatch before the non-binary check; update the test so the target tensor has
the same number of rows as the input (32) and includes non-binary values to
trigger AUCMLoss's non-binary branch. Concretely, in test_non_binary_target
ensure input = torch.randn(32, 1) remains and construct target with 32 elements
(e.g., torch.tensor([0.5, 1.0, 2.0] * 10 + [0.0]).view(32,1) or equivalent) so
the shape validation passes and the ValueError for non-binary values in AUCMLoss
is exercised.
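The first two fixes above (the imratio bounds check and the rank guard before indexing shape[1]) might look like the following sketch; the helper names are illustrative, not the PR's actual code.

```python
import torch

def validate_imratio(imratio) -> None:
    # imratio is a class prior, so it must lie strictly between 0 and 1
    if imratio is not None and not (0.0 < imratio < 1.0):
        raise ValueError(f"imratio must be in (0, 1), got {imratio}.")

def validate_shapes(pred: torch.Tensor, target: torch.Tensor) -> None:
    # check ndim first so shape[1] access cannot raise IndexError on 1-D tensors
    if pred.ndim < 2 or target.ndim < 2:
        raise ValueError("input and target must have at least 2 dimensions (B, 1, ...).")
    if pred.shape[1] != 1:
        raise ValueError("channel dimension must be 1 for binary classification.")
    if pred.shape != target.shape:
        raise ValueError("input and target must have matching shapes.")
```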
🧹 Nitpick comments (1)
monai/losses/aucm_loss.py (1)

139-140: Expand _safe_mean docstring to Google-style.
Include Args/Returns to match the project’s docstring rules. As per coding guidelines.

Proposed docstring
-    def _safe_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
-        """Compute mean safely over masked elements."""
+    def _safe_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+        """
+        Compute mean safely over masked elements.
+
+        Args:
+            tensor: values to average.
+            mask: binary mask selecting elements to include.
+
+        Returns:
+            torch.Tensor: scalar masked mean (0 if mask is empty).
+        """

@KumoLiu
Contributor

KumoLiu commented Jan 27, 2026

/build

@github-actions

👎 Promotion blocked, new vulnerability found

Vulnerability report

Component Vulnerability Description Severity
monai CVE-2026-21851 MONAI (Medical Open Network for AI) is an AI toolkit for health care imaging. In versions up to and including 1.5.1, a Path Traversal (Zip Slip) vulnerability exists in MONAI's _download_from_ngc_private() function. The function uses zipfile.ZipFile.extractall() without path validation, while other similar download functions in the same codebase properly use the existing safe_extract_member() function. Commit 4014c84 fixes this issue. MEDIUM
Project-MONAI/MONAI CVE-2026-21851 MONAI (Medical Open Network for AI) is an AI toolkit for health care imaging. In versions up to and including 1.5.1, a Path Traversal (Zip Slip) vulnerability exists in MONAI's _download_from_ngc_private() function. The function uses zipfile.ZipFile.extractall() without path validation, while other similar download functions in the same codebase properly use the existing safe_extract_member() function. Commit 4014c84 fixes this issue. MEDIUM

@shubham-61969
Contributor Author

Hi @KumoLiu , thanks for triggering the build.
I see the vulnerability report refers to an existing issue in MONAI unrelated to this PR.

Are there any changes needed from my side for this PR?

@KumoLiu
Contributor

KumoLiu commented Jan 29, 2026

Are there any changes needed from my side for this PR?

Hi @shubham-61969, no need, we are releasing a minor version to fix the issue.

Member

@ericspod ericspod left a comment


Hi @shubham-61969, thanks for this submission. I had some minor comments; specifically, I refactored the loss calculation to reduce space and hopefully make it more understandable. I haven't tested it, so please double-check what I've done; if you prefer your explicit version, we can leave that in. More commentary in the code to help explain the loss wouldn't hurt either. I had minor comments on the tests as well.

Comment on lines +122 to +141
if self.version == "v1":
    p = float(self.imratio) if self.imratio is not None else float(pos_mask.mean().item())
    loss = (
        (1 - p) * self._safe_mean((input - self.a) ** 2, pos_mask)
        + p * self._safe_mean((input - self.b) ** 2, neg_mask)
        + 2
        * self.alpha
        * (
            p * (1 - p) * self.margin
            + self._safe_mean(p * input * neg_mask - (1 - p) * input * pos_mask, pos_mask + neg_mask)
        )
        - p * (1 - p) * self.alpha**2
    )
else:
    loss = (
        self._safe_mean((input - self.a) ** 2, pos_mask)
        + self._safe_mean((input - self.b) ** 2, neg_mask)
        + 2 * self.alpha * (self.margin + self._safe_mean(input, neg_mask) - self._safe_mean(input, pos_mask))
        - self.alpha**2
    )
Member


Suggested change
if self.version == "v1":
    p = float(self.imratio if self.imratio is not None else pos_mask.mean().item())
    p1 = 1 - p
    safe_mean_pos_neg = self._safe_mean(p * input * neg_mask - p1 * input * pos_mask, pos_mask + neg_mask)
else:
    p = p1 = 1.0  # positive sample ratio discounted in this version by setting coefficients to 1
    safe_mean_pos_neg = self._safe_mean(input, neg_mask) - self._safe_mean(input, pos_mask)
mean_a = p1 * self._safe_mean((input - self.a) ** 2, pos_mask)
mean_b = p * self._safe_mean((input - self.b) ** 2, neg_mask)
loss = mean_a + mean_b + 2 * self.alpha * (p * p1 * self.margin + safe_mean_pos_neg) - p * p1 * self.alpha**2

If I haven't messed up the refactor of the loss equations, this would help simplify the code and reduce redundant calculations. With a bit more commentary this may be easier to understand.
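One way to sanity-check the refactor is a quick numeric equivalence test. In this sketch, `safe_mean` is assumed to compute a masked sum divided by `mask.sum()` (per the earlier review thread); all names and parameter values are illustrative, not the PR's code.

```python
import torch

def safe_mean(t, mask):
    # assumed semantics: masked sum divided by the mask count
    d = mask.sum()
    return (t * mask).sum() / d if d > 0 else torch.zeros(())

def loss_original(x, pos, neg, a, b, alpha, margin, version):
    # the explicit two-branch formulation from the PR, transcribed for comparison
    if version == "v1":
        p = pos.mean()
        return ((1 - p) * safe_mean((x - a) ** 2, pos)
                + p * safe_mean((x - b) ** 2, neg)
                + 2 * alpha * (p * (1 - p) * margin
                               + safe_mean(p * x * neg - (1 - p) * x * pos, pos + neg))
                - p * (1 - p) * alpha ** 2)
    return (safe_mean((x - a) ** 2, pos) + safe_mean((x - b) ** 2, neg)
            + 2 * alpha * (margin + safe_mean(x, neg) - safe_mean(x, pos))
            - alpha ** 2)

def loss_refactored(x, pos, neg, a, b, alpha, margin, version):
    # the unified form suggested above: v2 sets the prior coefficients to 1
    if version == "v1":
        p = pos.mean()
        p1 = 1 - p
        smpn = safe_mean(p * x * neg - p1 * x * pos, pos + neg)
    else:
        p = p1 = 1.0
        smpn = safe_mean(x, neg) - safe_mean(x, pos)
    mean_a = p1 * safe_mean((x - a) ** 2, pos)
    mean_b = p * safe_mean((x - b) ** 2, neg)
    return mean_a + mean_b + 2 * alpha * (p * p1 * margin + smpn) - p * p1 * alpha ** 2
```

On sample data, the two forms agree to floating-point precision for both versions, since the refactor is a pure algebraic rearrangement.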

Comment on lines +25 to +41
def test_v1(self):
    """Test AUCMLoss with version 'v1'."""
    loss_fn = AUCMLoss(version="v1")
    input = torch.randn(32, 1, requires_grad=True)
    target = torch.randint(0, 2, (32, 1)).float()
    loss = loss_fn(input, target)
    self.assertIsInstance(loss, torch.Tensor)
    self.assertEqual(loss.ndim, 0)

def test_v2(self):
    """Test AUCMLoss with version 'v2'."""
    loss_fn = AUCMLoss(version="v2")
    input = torch.randn(32, 1, requires_grad=True)
    target = torch.randint(0, 2, (32, 1)).float()
    loss = loss_fn(input, target)
    self.assertIsInstance(loss, torch.Tensor)
    self.assertEqual(loss.ndim, 0)
Member


These tests are a good start but I think we need a few which tests the output values themselves to ensure the calculation is what it should be. You can precompute some values and store them as globals at the top of this file, look at other test files to see how this is done with parameterized.

self.assertIsInstance(loss, torch.Tensor)
self.assertEqual(loss.ndim, 0)

def test_invalid_version(self):
Member


The following tests are good too but could be condensed using parameterized instead if you wanted to do that.

shubham-61969 and others added 6 commits March 29, 2026 02:56
Signed-off-by: Shubham Chandravanshi <shubham.chandravanshi378@gmail.com>
Signed-off-by: Shubham Chandravanshi <shubham.chandravanshi378@gmail.com>
Signed-off-by: Shubham Chandravanshi <shubham.chandravanshi378@gmail.com>
…-binary target test

Signed-off-by: Shubham Chandravanshi <shubham.chandravanshi378@gmail.com>
Signed-off-by: Shubham Chandravanshi <shubham.chandravanshi378@gmail.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
tests/losses/test_aucm_loss.py (1)

60-113: Add a regression test for zero-element tensors.

Given the new loss math, include one explicit empty-input test to lock down expected behavior (raise vs finite scalar).

As per coding guidelines "**/*.py: Ensure new or modified definitions will be covered by existing or new unit tests."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/losses/test_aucm_loss.py` around lines 60 - 113, Add a regression unit
test in tests/losses/test_aucm_loss.py that verifies AUCMLoss behavior on empty
tensors: call AUCMLoss() with input and target tensors that have zero batch size
but correct shape (e.g., torch.empty(0,1) and torch.empty(0,1)), then assert the
expected outcome (either raise ValueError or return a finite scalar) to lock
down the behavior; name the test function test_zero_element_tensors and place it
alongside test_invalid_input_shape/test_insufficient_dimensions so CI covers the
new edge case and clarifies intended behavior for AUCMLoss.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@monai/losses/aucm_loss.py`:
- Around line 132-138: Guard against empty tensors before calling .mean():
replace uses of pos_mask.mean().item() and other .mean() calls that can operate
on empty tensors by checking tensor.numel()>0 (or input.numel()>0) and returning
a safe default (e.g., 0.0) when empty; specifically update the p computation
(currently p = float(self.imratio) if self.imratio is not None else
float(pos_mask.mean().item())) to use pos_mask.mean().item() only when
pos_mask.numel()>0, and similarly protect the other .mean() uses and the calls
to _global_mean that depend on mean values (mean_pos, mean_neg, interaction) so
they receive 0.0 or an appropriate zero-tensor/scalar on empty input to avoid
NaNs during loss computation.
- Line 90: The forward method parameter named input in the class/method forward
should be renamed (e.g., to pred or logits) to avoid shadowing the builtin
input(); update the parameter name in the def forward signature and replace all
uses and the local reassignment (previously input = ...) inside forward (and any
internal variable references) to the new name, and update any tests that call
forward or construct tensors/kwargs using the old parameter name so all
references are consistent.

In `@tests/losses/test_aucm_loss.py`:
- Around line 36-38: Rename the local variable named `input` in
tests/losses/test_aucm_loss.py to avoid shadowing the builtin; update all
occurrences (e.g., in the tensor creation and any subsequent uses) to a
non-built-in name like `x`, `model_input`, or `data` (lines referenced by the
reviewer: 36, 63, 71, 79, 87, 95, 103); ensure you preserve attributes such as
requires_grad=True and the calls to loss_fn(target, ...) or loss_fn(...) by
replacing `input` with the new name everywhere in that test file so tests
continue to reference the same tensor variable.

---

Nitpick comments:
In `@tests/losses/test_aucm_loss.py`:
- Around line 60-113: Add a regression unit test in
tests/losses/test_aucm_loss.py that verifies AUCMLoss behavior on empty tensors:
call AUCMLoss() with input and target tensors that have zero batch size but
correct shape (e.g., torch.empty(0,1) and torch.empty(0,1)), then assert the
expected outcome (either raise ValueError or return a finite scalar) to lock
down the behavior; name the test function test_zero_element_tensors and place it
alongside test_invalid_input_shape/test_insufficient_dimensions so CI covers the
new edge case and clarifies intended behavior for AUCMLoss.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 17aec3c2-58d9-48fe-83aa-9b7c699b82a5

📥 Commits

Reviewing files that changed from the base of the PR and between 448c5df and 25c2702.

📒 Files selected for processing (3)
  • monai/losses/__init__.py
  • monai/losses/aucm_loss.py
  • tests/losses/test_aucm_loss.py
✅ Files skipped from review due to trivial changes (1)
  • monai/losses/__init__.py

self.b = nn.Parameter(torch.tensor(0.0))
self.alpha = nn.Parameter(torch.tensor(0.0))

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Contributor


⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "Checking Ruff config for A001/A002 handling..."
fd -HI 'pyproject.toml|ruff.toml|setup.cfg' | while read -r f; do
  echo "--- $f"
  rg -n '(^\[tool\.ruff|select|extend-select|ignore|A001|A002)' "$f" || true
done

echo
echo "Checking builtin-shadow occurrences in changed files..."
rg -nP --type=py '\binput\b' monai/losses/aucm_loss.py tests/losses/test_aucm_loss.py

Repository: Project-MONAI/MONAI

Length of output: 3826


🏁 Script executed:

# Get the full Ruff configuration content
echo "=== pyproject.toml ruff config ==="
rg -A 100 '\[tool\.ruff\.lint\.per-file-ignores\]' pyproject.toml

echo -e "\n=== Checking if A001/A002 is globally selected ==="
rg -B 5 -A 20 'select = \[' pyproject.toml | head -40

echo -e "\n=== ruff.toml full content ==="
cat ruff.toml

Repository: Project-MONAI/MONAI

Length of output: 3820


Rename parameter input to avoid shadowing Python builtin.

Parameter input at line 90 and reassignment at line 113 shadow the built-in input() function, violating PEP8 naming conventions. While not flagged by the default Ruff config, this would trigger A001/A002 in stricter configs (ruff.toml). Rename to pred or logits and update all references throughout the method and tests.

🧰 Tools
🪛 Ruff (0.15.7)

[error] 90-90: Function argument input is shadowing a Python builtin

(A002)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/losses/aucm_loss.py` at line 90, The forward method parameter named
input in the class/method forward should be renamed (e.g., to pred or logits) to
avoid shadowing the builtin input(); update the parameter name in the def
forward signature and replace all uses and the local reassignment (previously
input = ...) inside forward (and any internal variable references) to the new
name, and update any tests that call forward or construct tensors/kwargs using
the old parameter name so all references are consistent.

Comment on lines +132 to +138
p = float(self.imratio) if self.imratio is not None else float(pos_mask.mean().item())
p1 = 1.0 - p

mean_pos = self._global_mean(mean_pos_sq, pos_mask)
mean_neg = self._global_mean(mean_neg_sq, neg_mask)

interaction = self._global_mean(p * input * neg_mask - p1 * input * pos_mask, pos_mask + neg_mask)
Contributor


⚠️ Potential issue | 🟠 Major

Guard empty tensors to prevent NaN loss values.

Line 132 (pos_mask.mean()) and Line 178 (.mean()) return NaN when input.numel() == 0 (e.g., shape (0, 1, ...)). This silently corrupts training.

🐛 Proposed fix
@@
         input = input.flatten()
         target = target.flatten()
+        if input.numel() == 0:
+            raise ValueError("Input and target must contain at least one element.")
@@
     def _global_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
@@
-        return (tensor * mask).mean()
+        masked = tensor * mask
+        if masked.numel() == 0:
+            return torch.zeros((), dtype=tensor.dtype, device=tensor.device)
+        return masked.mean()

Also applies to: 178-178

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/losses/aucm_loss.py` around lines 132 - 138, Guard against empty
tensors before calling .mean(): replace uses of pos_mask.mean().item() and other
.mean() calls that can operate on empty tensors by checking tensor.numel()>0 (or
input.numel()>0) and returning a safe default (e.g., 0.0) when empty;
specifically update the p computation (currently p = float(self.imratio) if
self.imratio is not None else float(pos_mask.mean().item())) to use
pos_mask.mean().item() only when pos_mask.numel()>0, and similarly protect the
other .mean() uses and the calls to _global_mean that depend on mean values
(mean_pos, mean_neg, interaction) so they receive 0.0 or an appropriate
zero-tensor/scalar on empty input to avoid NaNs during loss computation.

Comment on lines +36 to +38
input = torch.randn(32, 1, requires_grad=True)
target = torch.randint(0, 2, (32, 1)).float()
loss = loss_fn(input, target)
Contributor


⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail
rg -nP --type=py '^\s*input\s*=' tests/losses/test_aucm_loss.py

Repository: Project-MONAI/MONAI

Length of output: 396


🏁 Script executed:

#!/bin/bash
# Check the file exists and get context around the flagged lines
wc -l tests/losses/test_aucm_loss.py

Repository: Project-MONAI/MONAI

Length of output: 97


🏁 Script executed:

#!/bin/bash
# Show context around the first flagged instance and a few others
sed -n '34,40p' tests/losses/test_aucm_loss.py
echo "---"
sed -n '61,68p' tests/losses/test_aucm_loss.py
echo "---"
sed -n '101,108p' tests/losses/test_aucm_loss.py

Repository: Project-MONAI/MONAI

Length of output: 1015


🏁 Script executed:

#!/bin/bash
# Verify Ruff can detect this issue
rg -nP --type=py '^\s*input\s*=' tests/losses/test_aucm_loss.py | wc -l

Repository: Project-MONAI/MONAI

Length of output: 64


Rename local input variables in tests to avoid builtin shadowing.

These lines shadow Python's builtin input, triggering Ruff A001. Rename to something like x, model_input, or data in lines 36, 63, 71, 79, 87, 95, 103 to comply with PEP8 style guidelines.

🧰 Tools
🪛 Ruff (0.15.7)

[error] 36-36: Variable input is shadowing a Python builtin

(A001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/losses/test_aucm_loss.py` around lines 36 - 38, Rename the local
variable named `input` in tests/losses/test_aucm_loss.py to avoid shadowing the
builtin; update all occurrences (e.g., in the tensor creation and any subsequent
uses) to a non-built-in name like `x`, `model_input`, or `data` (lines
referenced by the reviewer: 36, 63, 71, 79, 87, 95, 103); ensure you preserve
attributes such as requires_grad=True and the calls to loss_fn(target, ...) or
loss_fn(...) by replacing `input` with the new name everywhere in that test file
so tests continue to reference the same tensor variable.

Signed-off-by: Shubham Chandravanshi <shubham.chandravanshi378@gmail.com>
@shubham-61969
Contributor Author

Hi @ericspod , thanks for the detailed feedback and the suggested refactor! Also, apologies for the delayed response on this PR.

I carefully reviewed the formulation against both the paper and the original LibAUC implementation. One key point I found is that v1 and v2 use different types of expectations:

  • v1 uses global expectations (normalized by total number of samples)
  • v2 uses class-conditional expectations (normalized by the number of samples in each class)

This distinction is present both in the paper formulation and in the LibAUC reference implementation. Because of this, fully unifying the expressions for v1 and v2 would change the intended behavior of the loss. I’ve therefore kept the two paths semantically distinct while refactoring the code to improve readability and reduce repetition.

I also added comments in the implementation to clarify this difference so future readers don’t accidentally simplify it in a way that alters the behavior.
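To illustrate the distinction numerically (a sketch, not the PR's code): a global expectation divides the masked sum by the total sample count N, while a class-conditional expectation divides by the class count, and the two differ whenever the classes are imbalanced.

```python
# Two positive samples and one negative: the same masked sum normalized two ways.
scores = [0.9, 0.8, 0.1]
pos_mask = [1.0, 1.0, 0.0]
masked_sum = sum(s * m for s, m in zip(scores, pos_mask))
global_mean_pos = masked_sum / len(scores)    # v1-style: normalized by N = 3
class_mean_pos = masked_sum / sum(pos_mask)   # v2-style: normalized by |positives| = 2
```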

On the testing side:

  • I refactored the version tests using parameterized to reduce duplication.
  • I added deterministic test cases with manually verified expected values to ensure correctness.

At the moment, I’ve included two deterministic cases (one for v1 and one for v2) using small inputs where the expected values can be derived reliably. For more complex or higher-dimensional inputs, computing exact expected values becomes less straightforward.

Would you recommend adding more deterministic cases similar to other losses, or is this level of coverage sufficient given the complexity of the formulation?

@shubham-61969 shubham-61969 requested a review from ericspod March 29, 2026 19:45

Development

Successfully merging this pull request may close these issues.

Feature request: AUCMLoss

3 participants