
Commit 3233761

Martin Lindström (martinlsm) authored
Arm backend: Add example of partial quantization (#16298)
Update the minimal VGF example to mention partial quantization of models.

Signed-off-by: Martin Lindstroem <[email protected]>
Co-authored-by: Martin Lindström <[email protected]>
1 parent 2350cde commit 3233761

File tree

1 file changed (+20, -8)

examples/arm/vgf_minimal_example.ipynb

Lines changed: 20 additions & 8 deletions
@@ -48,13 +48,17 @@
  "source": [
   "import torch\n",
   "\n",
-  "class Add(torch.nn.Module):\n",
+  "class AddSigmoid(torch.nn.Module):\n",
+  "    def __init__(self):\n",
+  "        super().__init__()\n",
+  "        self.sigmoid = torch.nn.Sigmoid()\n",
+  "\n",
   "    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n",
-  "        return x + y\n",
+  "        return self.sigmoid(x + y)\n",
   "\n",
   "example_inputs = (torch.ones(1,1,1,1),torch.ones(1,1,1,1))\n",
   "\n",
-  "model = Add()\n",
+  "model = AddSigmoid()\n",
   "model = model.eval()\n",
   "exported_program = torch.export.export(model, example_inputs)\n",
   "graph_module = exported_program.graph_module\n",
@@ -84,8 +88,8 @@
  "source": [
   "from executorch.backends.arm.vgf import VgfCompileSpec\n",
   "\n",
-  "# Create a compilation spec describing the floating point target.\n",
-  "compile_spec = VgfCompileSpec(\"TOSA-1.0+FP\")\n",
+  "# Create a compilation spec describing the target\n",
+  "compile_spec = VgfCompileSpec()\n",
   "\n",
   "_ = graph_module.print_readable()\n",
   "\n",
@@ -99,7 +103,7 @@
  "source": [
   "To lower the graph_module for INT targets using the VGF backend, we apply the arm_quantizer. \n",
   "\n",
-  "Quantization can be performed in various ways and tailored to different subgraphs; the sequence shown here represents the recommended workflow for VGF. \n",
+  "Quantization can be performed in various ways and tailored to different subgraphs; it is even possible to opt out of quantization for selected layers and have them run in floating-point.\n",
   "\n",
   "This step also requires calibrating the module with representative inputs. \n",
   "\n",
@@ -120,13 +124,21 @@
   "from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e\n",
   "\n",
   "# Create a compilation spec describing the target for configuring the quantizer\n",
-  "compile_spec = VgfCompileSpec(\"TOSA-1.0+INT\")\n",
+  "compile_spec = VgfCompileSpec()\n",
   "\n",
   "# Create and configure quantizer to use a symmetric quantization config globally on all nodes\n",
   "quantizer = VgfQuantizer(compile_spec)\n",
   "operator_config = get_symmetric_quantization_config(is_per_channel=False)\n",
+  "\n",
+  "# Set the global (default) quantization config for the layers in the model.\n",
+  "# Can also be set to `None` to let layers run in FP by default.\n",
   "quantizer.set_global(operator_config)\n",
   "\n",
+  "# Skip quantizing all sigmoid ops (only one for this model); let it run in FP.\n",
+  "# This step is optional; selecting which layers to include/exclude for\n",
+  "# quantization is part of optimizing the model's performance.\n",
+  "quantizer.set_module_type(torch.nn.Sigmoid, None)\n",
+  "\n",
   "# Post training quantization\n",
   "quantized_graph_module = prepare_pt2e(graph_module, quantizer)\n",
   "quantized_graph_module(*example_inputs) # Calibrate the graph module with the example input\n",
@@ -142,7 +154,7 @@
  "cell_type": "markdown",
  "metadata": {},
  "source": [
-  "# In the example below, we will make use of the quantized graph module.\n",
+  "# In the example below, we will make use of the (partially) quantized graph module.\n",
   "\n",
   "The lowering in the VGFBackend happens in five steps:\n",
   "\n",

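A note on the compile spec: the removed lines in the diff above show that these cells previously passed an explicit TOSA profile string to VgfCompileSpec, while the new code uses the default constructor. A minimal sketch of both styles follows; whether the current constructor still accepts an explicit profile string is an assumption, not something this diff confirms.

    # Sketch: VgfCompileSpec construction styles seen in this diff.
    from executorch.backends.arm.vgf import VgfCompileSpec

    compile_spec = VgfCompileSpec()  # new style: default target
    # Pre-change style, taken from the removed lines; assumed to still be accepted:
    compile_spec_fp = VgfCompileSpec("TOSA-1.0+FP")    # floating-point target
    compile_spec_int = VgfCompileSpec("TOSA-1.0+INT")  # integer (quantized) target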

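The new comments in the quantizer cell point out that set_global also accepts None, which flips the default so layers stay in floating point unless explicitly opted in. A minimal sketch of that inverse configuration, reusing compile_spec and operator_config from the cell above; passing a config to set_module_type to opt a layer type in is an assumption, since the diff only shows the None (opt-out) form:

    # Sketch: floating point by default, quantize only selected module types.
    quantizer = VgfQuantizer(compile_spec)
    quantizer.set_global(None)  # default: leave layers in floating point
    # Assumption: passing a config opts a module type in, mirroring how
    # passing None opts one out in the diff above.
    quantizer.set_module_type(torch.nn.Linear, operator_config)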
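Taken together, the post-change cells amount to the end-to-end flow below. This is a consolidated sketch assembled from the added and context lines above; the import path for VgfQuantizer and get_symmetric_quantization_config is an assumption (the diff only shows those names in use), and the final convert_pt2e call is the standard PT2E step implied by the import in the quantizer cell.

    import torch
    from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
    from executorch.backends.arm.vgf import VgfCompileSpec
    # Assumed import path; the notebook imports these in a cell this diff does not touch.
    from executorch.backends.arm.quantizer.arm_quantizer import (
        VgfQuantizer,
        get_symmetric_quantization_config,
    )

    class AddSigmoid(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return self.sigmoid(x + y)

    example_inputs = (torch.ones(1, 1, 1, 1), torch.ones(1, 1, 1, 1))

    model = AddSigmoid().eval()
    exported_program = torch.export.export(model, example_inputs)
    graph_module = exported_program.graph_module

    compile_spec = VgfCompileSpec()
    quantizer = VgfQuantizer(compile_spec)

    # Quantize everything by default, but keep the sigmoid in floating point.
    quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))
    quantizer.set_module_type(torch.nn.Sigmoid, None)

    quantized_graph_module = prepare_pt2e(graph_module, quantizer)
    quantized_graph_module(*example_inputs)  # calibrate with representative inputs
    quantized_graph_module = convert_pt2e(quantized_graph_module)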