This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 251a644

test ci
1 parent e36c9f0 commit 251a644

File tree

4 files changed, +156 -18 lines changed


python/mxnet/_ctypes/cached_op.py

Lines changed: 5 additions & 1 deletion

@@ -77,6 +77,7 @@ def __call__(self, *args, **kwargs):
         if not default_device:
             default_device = kwargs.pop('default_ctx', None)
         out = kwargs.pop('out', None)
+        nleaf_vars = [container.data() for container in kwargs.pop('_nleaf_vars', [])]
         if kwargs:
             raise TypeError(
                 "CachedOp.__call__ got unexpected keyword argument(s): " + \
@@ -93,7 +94,10 @@ def __call__(self, *args, **kwargs):
                 *args,
                 type_id,
                 device_id,
-                *out_arg
+                len(out_arg),
+                *out_arg,
+                len(nleaf_vars),
+                *nleaf_vars
             )
         if out is not None:
             return out
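
For orientation, a minimal sketch (not the MXNet source) of the argument layout this hunk switches to: each variable-length group is now preceded by an explicit count, so the backend can split the tail into output buffers and marked non-leaf variables. The helper name build_invoke_args is hypothetical.

# Sketch only: illustrates the flattened argument layout used above, where
# every variadic group is prefixed by its length so the callee can split
# the tail unambiguously. `build_invoke_args` is a hypothetical helper.
def build_invoke_args(args, type_id, device_id, out_arg, nleaf_vars):
    return [len(args), *args, type_id, device_id,
            len(out_arg), *out_arg,
            len(nleaf_vars), *nleaf_vars]

# Two inputs, one output buffer, one marked intermediate:
flat = build_invoke_args(args=["a", "b"], type_id=1, device_id=0,
                         out_arg=["out0"], nleaf_vars=["mid0"])
print(flat)  # [2, 'a', 'b', 1, 0, 1, 'out0', 1, 'mid0']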

python/mxnet/gluon/block.py

Lines changed: 91 additions & 3 deletions

@@ -33,13 +33,14 @@
 import json
 import numpy as np
 
-from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB
+from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB, \
+    _as_list
 from .. import symbol, ndarray, initializer, autograd, _deferred_compute as dc, name as _name, \
     profiler as _profiler, device as _device
 from ..symbol.numpy import _symbol as np_symbol
 from ..symbol import Symbol, fromjson
 from ..ndarray import NDArray, get_dtype_name
-from .parameter import Parameter, DeferredInitializationError
+from .parameter import Parameter, DeferredInitializationError, Intermediate
 from .utils import _indent, _brief_print_list, HookHandle, shape_is_known
 from .utils import _check_same_symbol_type, _check_all_np_ndarrays, _check_block_input_np_ndarrays
 from .. import numpy_extension as _mx_npx
@@ -1091,6 +1092,7 @@ def __init__(self):
         self._backend_opts = {}
         self._partition_if_dynamic = True
         self._first_forward = True
+        self._nleaf_vars = OrderedDict()
 
     def __setattr__(self, name, value):
         """Registers parameters."""
@@ -1302,7 +1304,7 @@ def _call_cached_op(self, *args):
         args_without_none = [ele for ele in args if ele is not None]
         cargs = [args_without_none[i] if is_arg else i.data()
                  for is_arg, name, i in self._cached_op_args]
-        out = self._cached_op(*cargs)
+        out = self._cached_op(*cargs, _nleaf_vars=self._nleaf_vars.values())
         if isinstance(out, NDArray):
             out = [out]
         return _regroup(out, self._out_format)
@@ -1678,6 +1680,92 @@ def reset_ctx(self, ctx):
         self.reset_device(ctx)
 
 
+    def intermediate(self, names, var_arrays_inp, grad_req='write'):
+        """Mark the intermediate variables.
+
+        Parameters
+        ----------
+        names : str or tuple[str], name of the registered intermediate variable
+        var_arrays_inp : ndarray or tuple[ndarray], the output of the expression
+        grad_req : str, gradient request
+        """
+        if not self._active:
+            var_arrays = _as_list(var_arrays_inp)
+            names = _as_list(names)
+            self._nleaf_vars.update(
+                {name : Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)})
+        else:
+            prev_val = dc.set_deferred_compute(False)
+            var_arrays = _as_list(var_arrays_inp)
+            names = _as_list(names)
+            # Prepare ctypes array types
+            import ctypes
+            var_handles_type = ctypes.c_void_p * len(var_arrays)
+            # Convert handles
+            var_handles = var_handles_type(*[arr.handle for arr in var_arrays])
+            check_call(_LIB.MXNDArrayMarkDCVariables(var_handles, len(var_arrays), len(self._nleaf_vars)))
+            self._nleaf_vars.update(
+                {name : Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)})
+            dc.set_deferred_compute(prev_val)
+        return var_arrays_inp
+
+    def attach_grad_intermediate(self):
+        """Attach gradient to all the intermediate variables.
+        """
+        for val in self._nleaf_vars.values():
+            val.data().attach_grad(grad_req=val.grad_req)
+
+    def get_intermediate(self, names):
+        """Get the intermediate variables by names.
+        """
+        if isinstance(names, list):
+            return [self._nleaf_vars[n] for n in names]
+        else:
+            return self._nleaf_vars[names]
+
+    def intermediate(self, names, var_arrays_inp, grad_req='write'):
+        """Mark the intermediate variables.
+
+        Parameters
+        ----------
+        names : str or tuple[str], name of the registered intermediate variable
+        var_arrays_inp : ndarray or tuple[ndarray], the output of the expression
+        grad_req : str, gradient request
+        """
+        if not self._active:
+            var_arrays = _as_list(var_arrays_inp)
+            names = _as_list(names)
+            self._nleaf_vars.update(
+                {name : Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)})
+        else:
+            prev_val = dc.set_deferred_compute(False)
+            var_arrays = _as_list(var_arrays_inp)
+            names = _as_list(names)
+            # Prepare ctypes array types
+            import ctypes
+            var_handles_type = ctypes.c_void_p * len(var_arrays)
+            # Convert handles
+            var_handles = var_handles_type(*[arr.handle for arr in var_arrays])
+            check_call(_LIB.MXNDArrayMarkDCVariables(var_handles, len(var_arrays), len(self._nleaf_vars)))
+            self._nleaf_vars.update(
+                {name : Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)})
+            dc.set_deferred_compute(prev_val)
+        return var_arrays_inp
+
+    def attach_grad_intermediate(self):
+        """Attach gradient to all the intermediate variables.
+        """
+        for val in self._nleaf_vars.values():
+            val.data().attach_grad(grad_req=val.grad_req)
+
+    def get_intermediate(self, names):
+        """Get the intermediate variables by names.
+        """
+        if isinstance(names, list):
+            return [self._nleaf_vars[n] for n in names]
+        else:
+            return self._nleaf_vars[names]
+
 class SymbolBlock(HybridBlock):
     """Construct block from symbol. This is useful for using pre-trained models
     as feature extractors. For example, you may want to extract the output
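
For orientation, a minimal usage sketch of the three methods added above, modeled on the updated test in tests/python/unittest/test_autograd.py further down. It assumes a build that includes this commit; Net and the variable names are illustrative only.

import mxnet as mx

class Net(mx.gluon.HybridBlock):
    def forward(self, a, b):
        # Register non-leaf results under explicit names; grad_req defaults to 'write'.
        prod = self.intermediate('prod', a * b)
        return self.intermediate('out', prod * a)

x = mx.np.array([1., 2., 3., 4.])
y = mx.np.array([5., 6., 7., 8.])
x.attach_grad()
y.attach_grad()

net = Net()
net.initialize()
net.hybridize()  # optional: the test below exercises both the hybridized and non-hybridized paths

with mx.autograd.record():
    net(x, y)

net.attach_grad_intermediate()            # attach grad buffers to every marked variable
u = net.get_intermediate('prod').data()   # marked NDArray holding a * b
z = net.get_intermediate('out').data()    # marked NDArray holding (a * b) * a
z.backward(retain_graph=True)
print(u.grad)  # expected to equal x, as asserted in the updated test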

python/mxnet/gluon/parameter.py

Lines changed: 37 additions & 0 deletions

@@ -773,3 +773,40 @@ def grad_req(self, req):
         warnings.warn('Constant parameter "{}" does not support '
                       'grad_req other than "null", and new value "{}" '
                       'is ignored.'.format(self.name, req))
+
+class Intermediate:
+    """A container holding marked intermediate variables of Blocks.
+
+    Parameters
+    ----------
+    name : str
+        Name of this parameter. It is used to retrieve the marked variables.
+    grad_req : {'write', 'add', 'null'}, default 'write'
+        Specifies how to update gradient to grad arrays.
+
+        - ``'write'`` means every time gradient is written to grad :py:class:`NDArray`.
+        - ``'add'`` means every time gradient is added to the grad :py:class:`NDArray`. You need
+          to manually call ``zero_grad()`` to clear the gradient buffer before each
+          iteration when using this option.
+        - ``'null'`` means gradient is not requested for this parameter. Gradient arrays
+          will not be allocated.
+    """
+    def __init__(self, name, data=None, grad_req='write'):
+        self._name = name
+        self._data = data
+        self._grad_req = grad_req
+
+    def __repr__(self):
+        s = 'Intermediate name={name}'
+        return s.format(name=self._name)
+
+    def data(self):
+        return self._data
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def grad_req(self):
+        return self._grad_req
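
A small sketch of the container on its own, assuming a build that ships this class; the array contents and the name 'hidden' are arbitrary.

from mxnet import np as mx_np
from mxnet.gluon.parameter import Intermediate  # added by this commit

arr = mx_np.array([1., 2., 3.])
mid = Intermediate('hidden', data=arr, grad_req='add')

print(mid)            # Intermediate name=hidden
print(mid.name)       # hidden
print(mid.grad_req)   # add
assert mid.data() is arr  # data() hands back the wrapped NDArray unchanged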

tests/python/unittest/test_autograd.py

Lines changed: 23 additions & 14 deletions

@@ -533,7 +533,7 @@ def test_retain_grad_drop_grad():
     z.attach_grad()
     out_grad = nd.array([10, 10, 10, 10])
     z.backward(out_grad, retain_graph=True)
-
+
     assert (u.grad == out_grad * x).asnumpy().all()
     assert (z.grad == out_grad).asnumpy().all()
     assert (x.grad == out_grad * 2 * x * y).asnumpy().all()
@@ -548,39 +548,48 @@
     assert u.grad is None and z.grad is None and y.grad is None
     assert (x.grad == out_grad * 2 * x * y).asnumpy().all()
 
-def test_retain_grad_drop_grad_gluon():
-    class CompBlock(mx.gluon.HybridBlock):
+@pytest.fixture(scope="function", params=[True, False])
+def test_retain_grad_drop_grad_gluon(request):
+    class CompBlock(mx.HybridBlock):
         def __init__(self):
             super().__init__()
-            self.marked_var = None
-        def forward(self, a, b):
-            out1 = a*b
-            out2 = out1 * a
-            self.marked_var = out1
+
+        def forward(self, a, b, c):
+            out1 = self.intermediate(('out1_0', 'out1_1'), ((a+b)*c, a*b), grad_req='write')
+            out2 = self.intermediate('out2', out1[1] * a)
             return out2
+
     x = mx.np.array([1,2,3,4])
     y = mx.np.array([5,6,7,8])
+    w = mx.np.array([0.1, 0.1, 0.1, 0.1])
     x.attach_grad()
     y.attach_grad()
+    w.attach_grad()
     block2 = CompBlock()
     block2.initialize()
-    # block2.hybridize()
+    param = request.param
+    if param:
+        block2.hybridize()
     with mx.autograd.record():
-        z = block2(x, y)
-    u = block2.marked_var
-    u.attach_grad()
-    z.attach_grad()
+        z = block2(x, y, w)
+
+    block2.attach_grad_intermediate()
+    u0 = block2.get_intermediate('out1_0').data()
+    u = block2.get_intermediate('out1_1').data()
+    z = block2.get_intermediate('out2').data()
     z.backward(retain_graph=True)
 
     assert (u.grad == x).all()
+    assert (u0.grad == mx.np.array([0, 0, 0, 0])).all()
    assert (z.grad == mx.np.array([1,1,1,1])).all()
     assert (x.grad == 2 * x * y).all()
     assert (y.grad == x*x).all()
 
     u.drop_grad()
+    u0.drop_grad()
     z.drop_grad()
     y.drop_grad()
     z.backward()
 
-    assert u.grad is None and z.grad is None and y.grad is None
+    assert u.grad is None and u0.grad is None and y.grad is None and z.grad is None
     assert (x.grad == 2 * x * y).all()
0 commit comments