Quantization is a core method for deploying large neural networks such as Llama 2 efficiently on constrained hardware, especially embedded systems and edge devices. The aim is to reduce computational and memory costs by converting high-precision floating-point representations (like float32) into lower-precision integer types (such as int8). This typically reduces inference time and energy usage with only a small loss in accuracy. It also makes it feasible to run large models on devices with limited memory or without native floating-point support.
Why Quantization?
- Model Size Reduction: Quantization compresses neural network weights/activations from 32-bit floats to 8-bit integers, reducing storage and memory requirements by up to 4x.
- Speed: Int8 matrix multiplications are typically faster than float32 ones on hardware with dedicated int8 instructions (many CPUs and embedded accelerators), which can speed up inference substantially.
- Hardware Support: Embedded systems (e.g., microcontrollers, NPUs) often lack native float math support, necessitating integer-only arithmetic.
- Energy Efficiency: Less computation and lower memory bandwidth result in lower energy consumption, which is critical for mobile and IoT devices.
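To make the 4x figure concrete, the short calculation below estimates weight storage for a model with 7 billion parameters (a round number in the spirit of Llama 2 7B, assumed here). It ignores the small overhead of scales, zero-points, and file metadata.

```python
# Rough weight-storage estimate for an assumed 7e9-parameter model.
# Ignores the extra bytes needed for scales, zero-points, and file metadata.
NUM_PARAMS = 7_000_000_000

BYTES_PER_PARAM = {"float32": 4, "int8": 1}

def weight_storage_gb(num_params: int, dtype: str) -> float:
    """Return the storage needed for the weights, in gigabytes (1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9

for dtype in ("float32", "int8"):
    print(f"{dtype:>8}: {weight_storage_gb(NUM_PARAMS, dtype):.1f} GB")
```

This gives 28.0 GB for float32 versus 7.0 GB for int8.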
How Quantization Works
1. Representation Mapping
- Original: Weights W and biases b are stored as 32-bit floats.
- Quantization: These values are mapped into lower-precision integer representations (usually int8 for weights, int32 for biases).
- Dequantization: Before feeding results to the next layer (often still expecting float values), results are mapped back to floating-point via the scale and zero-point parameters.
Example (as in Llama 2 7B or similar large models):
y = x W + b
Where:
- W: Quantized to 8-bit integer (int8)
- b: Quantized to 32-bit integer (int32, matching the accumulator width)
- Computation is performed in lower precision, then dequantized for subsequent operations.
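The arithmetic of y = xW + b with quantized operands can be sketched in plain Python. This minimal sketch assumes symmetric per-tensor quantization (zero-point 0) for x and W, a simplification of the general scheme with a zero-point. The bias is quantized to int32 with scale s_x · s_w so it can be added directly to the integer accumulator. The helper names symmetric_scale and quantize are hypothetical, and Python ints stand in for int8 operands and the int32 accumulator.

```python
# Minimal sketch of an integer linear layer y = xW + b, assuming symmetric
# per-tensor quantization (zero_point = 0) for both x and W.
# Python ints stand in for int8 operands and the int32 accumulator.

def symmetric_scale(values, qmax=127):
    """Scale that maps the largest |value| onto qmax."""
    return max(abs(v) for v in values) / qmax

def quantize(values, scale, qmin=-128, qmax=127):
    """Round each value to the nearest step and clamp to the int8 range."""
    return [max(qmin, min(qmax, round(v / scale))) for v in values]

x = [0.5, -1.2, 2.0]                              # input, length 3
W = [[0.1, -0.4], [0.3, 0.2], [-0.25, 0.05]]      # weights, 3 inputs x 2 outputs
b = [0.01, -0.02]                                 # bias, length 2

s_x = symmetric_scale(x)
s_w = symmetric_scale([w for row in W for w in row])
q_x = quantize(x, s_x)
q_W = [quantize(row, s_w) for row in W]
q_b = [round(v / (s_x * s_w)) for v in b]  # int32 bias with scale s_x * s_w

# Integer-only part: sum int8 x int8 products, then add the int32 bias.
acc = [sum(q_x[i] * q_W[i][j] for i in range(len(x))) + q_b[j]
       for j in range(len(b))]

# One rescale at the end recovers float outputs.
y_quant = [a * s_x * s_w for a in acc]
y_float = [sum(x[i] * W[i][j] for i in range(len(x))) + b[j]
           for j in range(len(b))]
print("int8 path :", [round(v, 4) for v in y_quant])
print("float path:", [round(v, 4) for v in y_float])
```

The two printed results agree closely (within about 0.01 here); the difference is the rounding error introduced by quantization.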
2. Formal Quantization Equation
Forward Quantization
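q_x = \mathrm{clamp}\left(\mathrm{round}\left(\frac{x}{\text{scale}}\right) + \text{zero\_point},\; q_{\min},\; q_{\max}\right)
Dequantization
x \approx \text{scale} \times (q_x - \text{zero\_point})
- scale: Determines the step size between adjacent integer values, in float units.
- zero_point: The integer that represents float 0.0 exactly (it need not be 0), which lets asymmetric value distributions such as non-negative ReLU outputs use the full integer range.
- q_min, q_max: The limits of the integer type, e.g., [-128, 127] for int8; values outside the range are clamped.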
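The two formulas can be exercised directly. The sketch below assumes a per-tensor range taken from the data and the signed int8 range [-128, 127]; the helper names are hypothetical. It widens the range to include 0.0 so that zero is exactly representable.

```python
# Affine (asymmetric) quantization following the forward/dequantization formulas.
# Assumes a per-tensor range [x_min, x_max] and int8 targets in [-128, 127].

def choose_qparams(x_min, x_max, qmin=-128, qmax=127):
    """Pick scale and zero_point so that [x_min, x_max] spans [qmin, qmax]."""
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)  # range must include 0.0
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:  # all-zero input: any positive scale works
        scale = 1.0
    zero_point = qmin - round(x_min / scale)
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point

def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """q = clamp(round(x / scale) + zero_point, qmin, qmax)."""
    return max(qmin, min(qmax, round(x / scale) + zero_point))

def dequantize(q, scale, zero_point):
    """x is approximately scale * (q - zero_point)."""
    return scale * (q - zero_point)

values = [-0.4, 0.0, 0.7, 1.5, 2.6]
scale, zp = choose_qparams(min(values), max(values))
for v in values:
    q = quantize(v, scale, zp)
    print(f"{v:+.3f} -> q={q:4d} -> {dequantize(q, scale, zp):+.4f}")
```

Every value comes back within scale/2 of the original, and 0.0 round-trips exactly because zero_point is an integer.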
Types of Quantization
| Method | Description |
|---|---|
| Dynamic | Weights are quantized post-training; activations are quantized on the fly during inference. Fast to apply, with minimal code changes. |
| Static (PTQ) | Post-training quantization that uses calibration data to fix activation ranges, so both weights and activations are quantized ahead of inference for the best runtime efficiency. |
| QAT | Quantization-Aware Training. Simulates quantization noise during training, typically giving the highest accuracy after quantization. |
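The QAT row rests on "fake quantization": during training, values are quantized and immediately dequantized, so the forward pass sees the same rounding and clamping error the int8 model will see. This is a forward-pass-only sketch with an assumed fixed scale and zero-point. In real QAT the rounding step is typically treated as the identity in the backward pass (a straight-through estimator) so that gradients can flow.

```python
# Forward-pass sketch of fake quantization as used in QAT.
# The scale and zero_point are assumed fixed here; QAT typically derives
# them from observers during training.

def fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """Quantize to the int8 grid, then immediately dequantize back to float."""
    q = max(qmin, min(qmax, round(x / scale) + zero_point))
    return scale * (q - zero_point)

scale, zero_point = 0.05, 0
weights = [0.012, -0.333, 0.5, 7.0]
fq = [fake_quantize(w, scale, zero_point) for w in weights]
print(fq)
```

Small values such as 0.012 collapse to 0.0, and 7.0 is clamped to 127 × 0.05 = 6.35, exactly the kind of error the network learns to tolerate during QAT.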
Implementation: PyTorch Workflow (Static Post-Training Quantization)
Step 1: Data Preparation
- Loads the MNIST handwritten digits dataset.
- Converts images to PyTorch tensors.
- Prepares DataLoaders for batching and iterating through data during training and testing.
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
transform = transforms.Compose([
transforms.ToTensor()
])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
Step 2: Define the CNN Model (with Quantization Stubs)
import torch
import torch.nn as nn
import torch.quantization

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Mark where tensors are quantized (input) and dequantized (output)
        self.quant = torch.quantization.QuantStub()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        # 28x28 input stays 28x28 through the padded convs, then 14x14 after pooling
        self.fc1 = nn.Linear(32 * 14 * 14, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)  # After convert(): float input -> quantized (quint8) tensor
        x = self.relu1(self.conv1(x))
        x = self.pool(self.relu2(self.conv2(x)))
        x = x.reshape(x.size(0), -1)  # Flatten to (batch, 32*14*14)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        x = self.dequant(x)  # After convert(): quantized output -> float32
        return x

model_fp32 = SimpleCNN()
Why QuantStub/DeQuantStub: They mark the input and output boundaries of the quantized region, so PyTorch knows where to convert float tensors to quantized ones and back. After convert(), they are replaced by real quantize and dequantize operations.
Step 3: (Optional) Quick Training
- The model is trained for two epochs; even a quantization demo needs reasonably trained weights.
- The code still runs if you skip training (the quantization steps do not depend on it), but accuracy will be poor.
import torch.optim as optim
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_fp32.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_fp32.parameters(), lr=0.001)
print("Training (just a few epochs for demo)...")
for epoch in range(2):  # Just 2 epochs for time; increase for better accuracy!
    model_fp32.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model_fp32(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
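Step 4: Fuse Layers
def fuse_model(model):
    # Merge each Conv/Linear layer with the ReLU that follows it into one module
    torch.quantization.fuse_modules(model,
        [['conv1', 'relu1'], ['conv2', 'relu2'], ['fc1', 'relu3']], inplace=True)

model_fp32.eval()  # Fuse in eval mode for post-training quantization
fuse_model(model_fp32)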
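Rationale:
- Fusing Conv + ReLU (or Conv + BN + ReLU) combines them into a single operation. The quantized kernel applies the ReLU before rounding the output to int8, which saves a separate quantize/dequantize step (faster) and an extra rounding error (more accurate). BatchNorm, when present, is folded into the convolution's weights and bias.
- Fusion must be done before prepare(), so that observers are attached to the fused modules.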
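Folding BatchNorm into the preceding convolution rests on a simple identity: for each output channel, BN is an affine function of the conv output, so it can be absorbed into that channel's weights and bias. The sketch below checks the identity for one output channel with made-up numbers, treating the convolution at a single output position as a dot product. It illustrates the math, not PyTorch's internal fusion code; fold_bn is a hypothetical helper.

```python
import math

EPS = 1e-5  # BatchNorm's numerical-stability constant (PyTorch's default value)

def fold_bn(w, b, gamma, beta, mean, var, eps=EPS):
    """Fold BN(y) = gamma * (y - mean) / sqrt(var + eps) + beta into w and b."""
    factor = gamma / math.sqrt(var + eps)
    w_folded = [wi * factor for wi in w]
    b_folded = (b - mean) * factor + beta
    return w_folded, b_folded

def dot(a, c):
    return sum(ai * ci for ai, ci in zip(a, c))

# One output channel: conv weights/bias plus that channel's BN statistics.
w, b = [0.2, -0.5, 0.1], 0.05
gamma, beta, mean, var = 1.5, -0.1, 0.3, 0.8
x = [1.0, 2.0, -0.5]  # input patch seen by this output position

y = dot(w, x) + b
bn_out = gamma * (y - mean) / math.sqrt(var + EPS) + beta  # conv, then BN
w_f, b_f = fold_bn(w, b, gamma, beta, mean, var)
fused_out = dot(w_f, x) + b_f                               # single fused conv
print(bn_out, fused_out)
```

Both paths print the same value up to floating-point rounding, so after folding only one layer needs to be quantized.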
Step 5: Set Quantization Configuration
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
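What is qconfig?
- It specifies which observers PyTorch attaches to activations and weights, and therefore how each is quantized.
- 'fbgemm' is preferred for x86 CPUs. For ARM CPUs, use 'qnnpack' and also set torch.backends.quantized.engine = 'qnnpack' so the runtime kernels match the configuration.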
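The backend choice can be automated. The helper below, pick_qengine, is hypothetical: it maps the CPU architecture reported by Python's platform.machine() to 'qnnpack' on ARM and 'fbgemm' otherwise. The commented lines show where the result would go in this workflow; torch.backends.quantized.engine and get_default_qconfig are existing PyTorch settings, but it is worth checking which engines your build lists in torch.backends.quantized.supported_engines.

```python
import platform

# Hypothetical helper: map the CPU architecture to a PyTorch quantized
# backend name ('qnnpack' for ARM, 'fbgemm' for x86 and anything else).
def pick_qengine(machine: str = "") -> str:
    arch = (machine or platform.machine()).lower()
    if arch.startswith(("arm", "aarch64")):
        return "qnnpack"
    return "fbgemm"

engine = pick_qengine()
print("Selected quantized backend:", engine)

# With PyTorch installed, the same name would feed both settings, e.g.:
#   torch.backends.quantized.engine = engine
#   model_fp32.qconfig = torch.quantization.get_default_qconfig(engine)
```

Keeping one variable for both settings avoids the easy mistake of configuring observers for one backend while the runtime uses another.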
Step 6: Prepare for Quantization
prepare() inserts observer modules that record the ranges of activations and weights during calibration.
model_fp32.cpu()  # Eager-mode quantized kernels (fbgemm/qnnpack) run on CPU
torch.quantization.prepare(model_fp32, inplace=True)
Step 7: Calibration
- This step feeds real data through the model.
- Observers collect statistics (min/max values or histograms) that determine how floats are mapped to int8, i.e., the scale and zero-point.
print("Calibrating...")
model_fp32.eval()
with torch.no_grad():
    for images, _ in train_loader:
        model_fp32(images)
        break  # Uses a single batch for the demo; in practice, calibrate on more representative data
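Conceptually, an observer tracks statistics across calibration batches and turns them into a scale and zero-point. The class below is a hypothetical, simplified min/max observer for unsigned 8-bit activations (range [0, 255]). PyTorch's own observers are more elaborate, and the default 'fbgemm' qconfig uses a histogram-based observer for activations.

```python
# Hypothetical, simplified min/max observer: track the running range over
# calibration batches, then derive quint8 parameters (range [0, 255]).

class MinMaxObserverSketch:
    def __init__(self, qmin=0, qmax=255):
        self.qmin, self.qmax = qmin, qmax
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, batch):
        """Update the running min/max with one batch of activation values."""
        self.min_val = min(self.min_val, min(batch))
        self.max_val = max(self.max_val, max(batch))

    def calculate_qparams(self):
        """Return (scale, zero_point); the range is widened to include 0.0."""
        lo, hi = min(self.min_val, 0.0), max(self.max_val, 0.0)
        scale = (hi - lo) / (self.qmax - self.qmin) or 1.0
        zero_point = self.qmin - round(lo / scale)
        return scale, max(self.qmin, min(self.qmax, zero_point))

obs = MinMaxObserverSketch()
for batch in ([0.0, 0.8, 1.3], [0.2, 2.55], [0.1, 0.4]):  # stand-in activations
    obs.observe(batch)
print(obs.min_val, obs.max_val, obs.calculate_qparams())
```

With these stand-in batches the observer sees the range [0.0, 2.55], which gives a scale of about 0.01 and a zero-point of 0. More calibration data simply means more calls to observe() before the parameters are computed.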
Step 8: Convert to Quantized Model
- All eligible layers (Conv, Linear, etc.) are replaced with quantized (int8) modules.
- Model is now ready for fast, int8 inference on CPU.
quantized_model = torch.quantization.convert(model_fp32, inplace=False)
Step 9: Evaluate Accuracy
- Runs the test data through each model.
- Compares float32 and quantized int8 accuracy; for a model like this, the drop is usually small.
def evaluate(model, test_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy:', 100.0 * correct / total, '%')
    return correct / total
print("\nEvaluating original (float32) model:")
evaluate(model_fp32, test_loader)  # Still the float model (fused, with observers), since convert() used inplace=False
print("\nEvaluating quantized model:")
evaluate(quantized_model, test_loader)
Step 10: Check Model File Sizes
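- The quantized model file should be roughly 1/4 the size, since int8 weights take 1 byte each instead of 4 (the float file is also slightly inflated by the observer buffers added in Step 6).
- The measured file sizes let you confirm the expected storage savings from quantization.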
import os
torch.save(model_fp32.state_dict(), "float_model.pth")
torch.save(quantized_model.state_dict(), "quantized_model.pth")
float_size = os.path.getsize("float_model.pth") / 1024
quant_size = os.path.getsize("quantized_model.pth") / 1024
print(f"\nModel size (float32): {float_size:.1f} KB")
print(f"Model size (quantized): {quant_size:.1f} KB")
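The expected ratio can be checked against the layer shapes from Step 2. This back-of-the-envelope estimate assumes int8 weights at 1 byte each and biases kept as 4-byte floats, and it ignores scales, zero-points, observer buffers, and serialization overhead, so the saved files will differ somewhat.

```python
# Back-of-the-envelope size estimate for SimpleCNN (layer shapes from Step 2).
# Assumes 1-byte int8 weights and 4-byte float biases after quantization;
# ignores scales, zero-points, observer buffers, and file-format overhead.

layers = {
    # name: (weight element count, bias element count)
    "conv1": (16 * 1 * 3 * 3, 16),
    "conv2": (32 * 16 * 3 * 3, 32),
    "fc1": (128 * 32 * 14 * 14, 128),
    "fc2": (10 * 128, 10),
}

num_weights = sum(w for w, _ in layers.values())
num_biases = sum(b for _, b in layers.values())

float_kb = (num_weights + num_biases) * 4 / 1024
int8_kb = (num_weights * 1 + num_biases * 4) / 1024
print(f"parameters: {num_weights + num_biases:,}")
print(f"float32 ~ {float_kb:.1f} KB, int8 weights ~ {int8_kb:.1f} KB")
```

fc1 alone holds over 99% of the roughly 809k parameters, so the near-4x saving comes almost entirely from its weight matrix.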
Key Technical Points
- Bias Quantization: In integer-only inference, biases are typically quantized to int32 with scale s_x · s_w (input scale times weight scale), so they can be added directly to the int32 accumulator that sums many int8 × int8 products; a narrower accumulator would overflow.
- Operations on Embedded Devices: Many embedded AI accelerators support only integer arithmetic, which makes quantization the standard deployment path on them.
- Dequantization: Between quantized layers, int32 accumulator results are usually requantized to int8; they are dequantized to float only where a float operation follows, e.g., softmax or loss calculation.
- Precision and Loss: Well-calibrated quantization (especially QAT, asymmetric scaling, and careful range selection such as histogram-based calibration or per-channel weight scales) can keep final accuracy very close to that of the original float network.
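The need for a wide accumulator follows from simple arithmetic: in the worst case each signed int8 × int8 product is (-128) × (-128) = 16384. The sketch below computes how many such terms a signed 16-bit or 32-bit accumulator can sum without overflowing.

```python
# Worst case for a signed int8 x int8 product: (-128) * (-128) = 16384.
MAX_PRODUCT = (-128) * (-128)

def max_safe_terms(acc_bits):
    """Largest K such that K worst-case products fit in a signed acc_bits integer."""
    return (2 ** (acc_bits - 1) - 1) // MAX_PRODUCT

for bits in (16, 32):
    print(f"int{bits} accumulator: max worst-case terms = {max_safe_terms(bits):,}")
```

A 16-bit accumulator can hold only a single worst-case product, while int32 allows 131,071 terms, comfortably more than the 6,272 inputs (32 × 14 × 14) to fc1 in the SimpleCNN above.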
Real-World Example: Llama 2 Quantization
- Llama 2 7B, a 7-billion-parameter model released with floating-point weights, is commonly quantized to 8-bit (or even lower-bit) integer weights with wide integer accumulators, enabling deployment on hardware-constrained servers, mobile, or edge devices with minimal accuracy drop.
- Quantization reduces model size drastically and speeds up inference, since y = xW + b (with W and b quantized) becomes an int8 × int8 → int32 computation, followed by a single rescaling step (one combined scale factor per layer, or per output channel) to dequantize or requantize the result.
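The per-layer rescaling can be sketched as follows: after the integer matmul, the int32 accumulator is multiplied by one combined factor (s_x · s_w) / s_y and rounded into the next layer's int8 range. This assumes symmetric quantization with zero-points of 0 and illustrative scale values; production integer kernels typically implement the multiplier as a fixed-point multiply and shift rather than a float.

```python
# Sketch of requantizing int32 accumulator values to the next layer's int8 input.
# Assumes symmetric quantization (zero_point = 0); s_x and s_w are the input and
# weight scales, s_y is the (calibrated) output scale. Values are illustrative.

def requantize(acc, s_x, s_w, s_y, qmin=-128, qmax=127):
    multiplier = (s_x * s_w) / s_y  # one combined factor per layer
    return [max(qmin, min(qmax, round(a * multiplier))) for a in acc]

acc = [-16027, -7299, 512, 40000]  # example int32 accumulator values
s_x, s_w, s_y = 2.0 / 127, 0.4 / 127, 0.8 / 127
print(requantize(acc, s_x, s_w, s_y))
```

The last value saturates at 127, which is why the output scale s_y has to be calibrated to cover the layer's real output range.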