Skip to content

Commit 56539b9

Browse files
committed
Add test registering >10 custom gradients and verifying registry presence
This adds a regression test for PR #632. The test dynamically discovers op types with no registered gradient (using TF_GetAllOpList + TensorFlow.hasGradient), registers 11 custom gradients, and verifies that all are present in the native gradient registry. This directly validates that registering more than 10 gradients works and that all entries are correctly stored in the native registry, without relying on Graph.addGradients() execution. Addresses reviewer comment about missing test for >10 gradients.
1 parent cdd943c commit 56539b9

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
Copyright 2026 The TensorFlow Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================
16+
*/
17+
package org.tensorflow;
18+
19+
import static org.junit.jupiter.api.Assertions.assertFalse;
20+
import static org.junit.jupiter.api.Assertions.assertTrue;
21+
22+
import java.util.ArrayList;
23+
import java.util.Collections;
24+
import java.util.List;
25+
import org.junit.jupiter.api.Test;
26+
import org.tensorflow.internal.c_api.TF_Buffer;
27+
import org.tensorflow.internal.c_api.global.tensorflow;
28+
import org.tensorflow.proto.OpDef;
29+
import org.tensorflow.proto.OpList;
30+
31+
/**
32+
* Regression test for PR #632: Registering >10 custom gradients must work; all registered opTypes
33+
* must be present in the native gradient registry (TensorFlow.hasGradient(opType) becomes true).
34+
*/
35+
public class CustomGradientRegistryCapacityTest {
36+
37+
private static List<String> listAllOpTypes() {
38+
TF_Buffer buf = tensorflow.TF_GetAllOpList();
39+
try {
40+
OpList opList = OpList.parseFrom(buf.dataAsByteBuffer());
41+
List<String> names = new ArrayList<>(opList.getOpCount());
42+
for (OpDef op : opList.getOpList()) {
43+
names.add(op.getName());
44+
}
45+
Collections.sort(names);
46+
return names;
47+
} catch (Exception e) {
48+
throw new RuntimeException("Failed to parse TF_GetAllOpList()", e);
49+
}
50+
}
51+
52+
@Test
53+
public void registerMoreThanTenGradients_thenHasGradientIsTrue() {
54+
55+
// 1) Discover op types that currently have NO gradient registered.
56+
List<String> ops = listAllOpTypes();
57+
58+
List<String> noGrad = new ArrayList<>();
59+
for (String opType : ops) {
60+
if (!TensorFlow.hasGradient(opType)) {
61+
// Avoid internal/private ops to reduce risk of weirdness (optional)
62+
if (!opType.startsWith("_")) {
63+
noGrad.add(opType);
64+
}
65+
}
66+
}
67+
68+
// 2) Pick 11 opTypes (stable order: alphabetical).
69+
// We intentionally pick from "noGrad" so the test is meaningful:
70+
// before: hasGradient=false, after register: true.
71+
assertTrue(noGrad.size() >= 11, "Need at least 11 ops with no gradient in this runtime.");
72+
73+
List<String> selected = noGrad.subList(0, 11);
74+
75+
// 3) Before: ensure hasGradient is false
76+
for (String opType : selected) {
77+
assertFalse(
78+
TensorFlow.hasGradient(opType),
79+
"Precondition failed: expected no gradient for " + opType);
80+
}
81+
82+
// 4) Register 11 custom gradients (simple zerosLike per input)
83+
for (String opType : selected) {
84+
TensorFlow.registerCustomGradient(
85+
opType,
86+
(tf, op, gradInputs) -> {
87+
int n = op.numInputs();
88+
ArrayList<org.tensorflow.Operand<?>> grads = new ArrayList<>(n);
89+
for (int i = 0; i < n; i++) {
90+
grads.add(tf.zerosLike(op.input(i)));
91+
}
92+
return grads;
93+
});
94+
}
95+
96+
// 5) After: ensure hasGradient is true for all
97+
for (String opType : selected) {
98+
assertTrue(
99+
TensorFlow.hasGradient(opType), "Expected gradient to be registered for " + opType);
100+
}
101+
}
102+
}

0 commit comments

Comments
 (0)