Training-Time Quantization
The LinearQuantizer class implements training-time quantization, also known as quantization-aware training (QAT), as described in the paper Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. LinearQuantizer quantizes both weights and activations, whereas Post-Training Quantization quantizes only the weights.
PyTorch quantization APIs
You can use PyTorch's quantization APIs directly and then convert the model to Core ML. However, the converted model's performance may not be optimal: the PyTorch API default settings (such as symmetric versus asymmetric quantization modes, and which ops are quantized) are not tuned for the Core ML stack and Apple hardware. If you use the Core ML Tools coremltools.optimize.torch APIs, as described in this section, the correct default settings are applied automatically.
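As a minimal sketch (assuming a default-constructed LinearQuantizerConfig, with no overrides, is acceptable for your model), getting started can be as simple as:

    import torch.nn as nn
    from coremltools.optimize.torch.quantization import LinearQuantizer, LinearQuantizerConfig

    model = nn.Sequential(nn.Conv2d(3, 16, (3, 3)), nn.ReLU())

    # Assumption: the default-constructed config applies quantization settings
    # suited to the Core ML stack, so nothing needs to be tuned up front.
    quantizer = LinearQuantizer(model, LinearQuantizerConfig())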
Use LinearQuantizer
Follow these key steps:
- Define the LinearQuantizerConfig config to specify the quantization parameters.
- Initialize the LinearQuantizer object.
- Call the prepare API to insert fake quantization layers in the PyTorch model.
- Run the usual training loop, with the addition of the quantizer.step call.
- Once the model has converged, use the finalize API to prepare the model for conversion to Core ML.
The following code sample shows how you can use LinearQuantizer to perform training-time quantization on your PyTorch model.
from collections import OrderedDict

import torch
import torch.nn as nn

import coremltools as ct
from coremltools.optimize.torch.quantization import LinearQuantizer, LinearQuantizerConfig

model = nn.Sequential(
    OrderedDict(
        {
            "conv": nn.Conv2d(1, 20, (3, 3)),
            "relu1": nn.ReLU(),
            "conv2": nn.Conv2d(20, 20, (3, 3)),
            "relu2": nn.ReLU(),
        }
    )
)

loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

data = get_data()  # user-defined data loader yielding (inputs, labels) batches

# Initialize the quantizer
config = LinearQuantizerConfig.from_dict(
    {
        "global_config": {
            "quantization_scheme": "symmetric",
            "milestones": [0, 100, 400, 400],
        }
    }
)

quantizer = LinearQuantizer(model, config)

# Prepare the model to insert FakeQuantize layers for QAT
example_input = torch.rand(1, 1, 20, 20)
model = quantizer.prepare(example_inputs=example_input, inplace=True)

# Use quantizer in your PyTorch training loop
for inputs, labels in data:
    optimizer.zero_grad()
    output = model(inputs)
    loss = loss_fn(output, labels)
    loss.backward()
    optimizer.step()
    quantizer.step()

# Convert operations to their quantized counterparts using parameters learnt via QAT
model = quantizer.finalize(inplace=True)

# Convert the PyTorch model to Core ML format
traced_model = torch.jit.trace(model, example_input)
coreml_model = ct.convert(
    traced_model,
    convert_to="mlprogram",
    inputs=[ct.TensorType(shape=example_input.shape)],
    minimum_deployment_target=ct.target.iOS17,
)
coreml_model.save("~/quantized_model.mlpackage")
The two key parameters in ModuleLinearQuantizerConfig that need to be set for training-time quantization are quantization_scheme and milestones. The allowed values for quantization_scheme are symmetric and affine. In symmetric mode, zero_point is always set to zero, whereas affine mode is able to use any zero point in the quint8 or int8 range, depending on the dtype used.
The milestones parameter controls the flow of the quantization algorithm, and calling the step API on the quantizer object steps through these milestones. The milestones parameter is an array of size 4, and each element is an integer indicating the training step at which the stage corresponding to that element comes into effect. A detailed explanation of these various stages can be found in the API Reference for ModuleLinearQuantizerConfig.
How It Works
The LinearQuantizer class simulates the effects of quantization during training by quantizing and de-quantizing the weights and activations during the model’s forward pass. The forward and backward pass computations are conducted in float32 dtype. However, these float32 values follow the constraints imposed by int8 and quint8 dtypes, for weights and activations respectively. This allows the model weights to adjust and reduce the error introduced by quantization. Straight-Through Estimation is used for computing gradients of non-differentiable operations introduced by simulated quantization.
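The following is a simplified sketch of this quantize-dequantize round trip with a straight-through estimator; it is illustrative only, not the FakeQuantize implementation the library inserts:

    import torch

    class FakeQuantizeSTE(torch.autograd.Function):
        """Quantize-dequantize in the forward pass; pass gradients straight through."""

        @staticmethod
        def forward(ctx, x, scale, zero_point):
            qmin, qmax = -128, 127  # int8 range
            q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
            return (q - zero_point) * scale  # still float32, but restricted to the int8 grid

        @staticmethod
        def backward(ctx, grad_output):
            # Straight-through estimator: treat round/clamp as identity for gradients.
            return grad_output, None, None

    x = torch.randn(4, requires_grad=True)
    y = FakeQuantizeSTE.apply(x, torch.tensor(0.1), 0)
    y.sum().backward()
    print(x.grad)  # all ones: the gradient flows through the simulated quantization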
The LinearQuantizer algorithm is implemented as an extension of FX Graph Mode Quantization in PyTorch. It first traces the PyTorch model symbolically to obtain a torch.fx graph capturing all the operations in the model. It then analyzes this graph and inserts FakeQuantize layers. Insertion locations are chosen so that model inference on hardware is optimized, and only the weights and activations that benefit from quantization are quantized.
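For instance, continuing from the code sample above (right after the quantizer.prepare call, and before finalize), you can inspect the prepared model to see where fake-quantization modules were inserted; the exact module names depend on the traced graph:

    # "model" is the object returned by quantizer.prepare in the sample above.
    # It is a torch.fx.GraphModule, so the traced graph can be printed directly.
    print(model.graph)

    # List the inserted fake-quantization modules (names depend on the traced graph).
    for name, module in model.named_modules():
        if "FakeQuantize" in type(module).__name__:
            print(name, "->", type(module).__name__)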
Since the prepare method uses prepare_qat_fx to insert quantization layers, the model returned from the method is a torch.fx.GraphModule. As a result, custom methods defined on the original model class may not be available on the returned model. Some models, such as those with dynamic control flow, may not be traceable into a torch.fx.GraphModule. We recommend following the instructions in Limitations of Symbolic Tracing and the FX Graph Mode Quantization User Guide to update your model first, before using the LinearQuantizer algorithm.
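A quick pre-flight check is to try symbolic tracing the model yourself before handing it to LinearQuantizer. The helper below is our own illustrative sketch, not a coremltools API:

    import torch
    import torch.fx

    def is_symbolically_traceable(model: torch.nn.Module) -> bool:
        # Rough check only: returns True if torch.fx can trace the model as-is.
        try:
            torch.fx.symbolic_trace(model)
            return True
        except Exception as err:  # dynamic control flow typically raises a trace error here
            print(f"Not symbolically traceable: {err}")
            return False

If the check fails, the model usually needs the rewrites described in the linked guides (for example, removing data-dependent control flow) before quantization-aware training can be applied.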
Example
Linear Quantization Tutorial: Learn how to train a simple convolutional neural network using LinearQuantizer. This algorithm simulates the effects of quantization during training by quantizing and dequantizing the weights and/or activations during the model’s forward pass. You can download a Jupyter Notebook version and the source code from the tutorial.