Commit 1f9bfd7

Make FqnToConfig handle module swap configs (#3492)
**Summary:** `FqnToConfig` does not currently handle module-swap configs like `QATConfig`, because it never reassigns the replaced modules: it assumes reassignment is unnecessary since the underlying tensors are swapped in place. This breaks when a user applies `FqnToConfig` with `QATConfig`, which does not rely on tensor subclasses. This commit fixes the issue by making the replacement logic in `FqnToConfig` more general. Fixes #3490.

**Test Plan:**
```
python test/quantization/test_quant_api.py -k test_fqn_config_quantized_nested_module_module_swap
```
1 parent ac535b2 commit 1f9bfd7
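
For context, a minimal usage sketch of what this commit enables: routing a module-swapping `QATConfig` to specific submodules by fully qualified name through `FqnToConfig`. The structure mirrors the new test below; the import paths for `FqnToConfig`, `Int4WeightOnlyConfig`, and `quantize_` are assumptions about the installed torchao version, so adjust them if needed.

```python
# Minimal sketch (mirrors the test added in this commit). Assumes FqnToConfig,
# Int4WeightOnlyConfig, and quantize_ are importable from torchao.quantization;
# adjust the imports for your torchao version. The new test runs this on an
# accelerator, so CPU behavior may differ.
import torch

from torchao.quantization import FqnToConfig, Int4WeightOnlyConfig, quantize_
from torchao.quantization.qat import FakeQuantizedLinear, QATConfig


class Nested(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.nested = Nested()
        self.linear1 = torch.nn.Linear(16, 16)


m = Model()
qat_config = QATConfig(Int4WeightOnlyConfig(), step="prepare")

# QATConfig replaces nn.Linear with FakeQuantizedLinear (a module swap rather than
# a tensor-subclass swap); with this commit the swapped module is reattached to its parent.
quantize_(m, FqnToConfig({"nested.linear": qat_config, "linear1": qat_config}), filter_fn=None)

assert isinstance(m.nested.linear, FakeQuantizedLinear)
assert isinstance(m.linear1, FakeQuantizedLinear)
```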

File tree

2 files changed, +41 −8 lines changed

test/quantization/test_quant_api.py

Lines changed: 30 additions & 0 deletions

```diff
@@ -40,6 +40,10 @@
     LinearActivationQuantizedTensor,
     PerGroup,
 )
+from torchao.quantization.qat import (
+    FakeQuantizedLinear,
+    QATConfig,
+)
 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8StaticActivationFloat8WeightConfig,
@@ -1197,6 +1201,32 @@ def __init__(self):
         assert isinstance(m.nested.linear.weight, AffineQuantizedTensor)
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
 
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
+    def test_fqn_config_quantized_nested_module_module_swap(self):
+        class NestedModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(16, 16)
+
+        class TopLevelModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.nested = NestedModule()
+                self.linear1 = torch.nn.Linear(16, 16)
+
+        m = TopLevelModule()
+        config = QATConfig(Int4WeightOnlyConfig(), step="prepare")
+        quant_config = FqnToConfig(
+            {
+                "nested.linear": config,
+                "linear1": config,
+            }
+        )
+        quantize_(m, quant_config, filter_fn=None)
+
+        assert isinstance(m.nested.linear, FakeQuantizedLinear)
+        assert isinstance(m.linear1, FakeQuantizedLinear)
+
     @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_fqn_config_quantized_nested_module_param(self):
         class NestedModule(torch.nn.Module):
```

torchao/quantization/quant_api.py

Lines changed: 11 additions & 8 deletions

```diff
@@ -468,19 +468,23 @@ def quantize_(
             raise ValueError(
                 "Custom filter_fn and FqnToConfig were both specified. Only filter_fn=None is supported when FqnToConfig is specified."
             )
-
-        for module_fqn, module in model.named_modules():
+        named_modules = dict(model.named_modules())
+        for module_fqn, module in named_modules.items():
             if (
                 fqn_matches_fqn_config(module_fqn, config)
                 or _module_param_matches_fqn_config(module, module_fqn, config)
                 or ("_default" in config.fqn_to_config and _is_linear(module))
             ):
-                # this replaces inplace, so no need to reassign
-                _fqn_to_config_handler(module, module_fqn, config)
+                replacement = _fqn_to_config_handler(module, module_fqn, config)
                 if device is not None:
-                    module.to(device=device)
-        return
-    if isinstance(config, AOBaseConfig):
+                    replacement = replacement.to(device=device)
+                # handle module swap
+                if replacement is not module and module_fqn != "":
+                    child_name = module_fqn.split(".")[-1]
+                    parent_fqn = module_fqn.removesuffix(child_name).removesuffix(".")
+                    parent_module = named_modules[parent_fqn]
+                    setattr(parent_module, child_name, replacement)
+    elif isinstance(config, AOBaseConfig):
         filter_fn = _is_linear if filter_fn is None else filter_fn
         handler = _QUANTIZE_CONFIG_HANDLER[type(config)]
         # for each linear in the model, apply the transform if filtering passes
@@ -491,7 +495,6 @@ def quantize_(
                 device=device,
                 extra_args=(config,),
             )
-
     else:
         raise AssertionError(
             """Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. Please see https://github.com/pytorch/ao/issues/1690 for instructions on how to pass in workflow configuration instead."""
```
