
Optimizing Model Size for Edge Deployment: Pruning, Distillation, and Quantization

Master the three core techniques for reducing AI model size for edge deployment — pruning, knowledge distillation, and quantization — with practical code examples and quality preservation strategies.

The Edge Deployment Challenge

A typical transformer model for an AI agent — say, a 110 million parameter BERT — weighs 440 MB in FP32. That is too large for many edge devices: too slow to load, too much memory to run, and too big to ship in a mobile app bundle.

The goal of model optimization is to make that model smaller and faster while preserving as much accuracy as possible. Three techniques are your primary tools: quantization (reducing numerical precision), pruning (removing unnecessary weights), and knowledge distillation (training a small model to mimic a large one).

Quantization: Reducing Numerical Precision

Quantization converts model weights from 32-bit floating point to lower precision formats. It is the single most impactful optimization for edge deployment.
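To make concrete what "reducing numerical precision" means, here is the affine quantization arithmetic in a framework-free sketch: each float is mapped onto an 8-bit integer grid via a scale and zero point (the weight values below are toy numbers for illustration):

```python
def quantize(weights, num_bits=8):
    """Map floats onto the integer grid 0 .. 2**num_bits - 1."""
    qmax = 2 ** num_bits - 1
    w_min, w_max = min(weights), max(weights)
    scale = (w_max - w_min) / qmax or 1.0  # avoid a zero scale for constant inputs
    zero_point = round(-w_min / scale)
    q = [max(0, min(qmax, round(w / scale) + zero_point)) for w in weights]
    return q, scale, zero_point

def dequantize(q, scale, zero_point):
    """Recover approximate floats from the integer codes."""
    return [(qi - zero_point) * scale for qi in q]

weights = [-0.42, -0.1, 0.0, 0.07, 0.35]  # toy weight values
q, scale, zp = quantize(weights)
restored = dequantize(q, scale, zp)
max_err = max(abs(w - r) for w, r in zip(weights, restored))
print(f"codes: {q}")
print(f"max round-trip error: {max_err:.5f} (at most scale/2 = {scale / 2:.5f})")
```

Each weight now costs one byte instead of four, and the worst-case error is half the grid spacing, which is why quantization error grows with the dynamic range of the weights.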

Post-Training Quantization

Apply quantization after training, without any retraining:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "distilbert-base-uncased"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=5)

# Dynamic quantization — quantizes weights to INT8
quantized_model = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},  # Which layers to quantize
    dtype=torch.qint8,
)

# Compare sizes by serializing to disk: dynamically quantized Linear layers
# store their INT8 weights in packed buffers, not in model.parameters(),
# so summing parameter sizes would miss them entirely
import os
import tempfile

def get_model_size_mb(model):
    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
        torch.save(model.state_dict(), f.name)
    size_mb = os.path.getsize(f.name) / 1e6
    os.remove(f.name)
    return size_mb

print(f"Original size:  {get_model_size_mb(model):.1f} MB")
print(f"Quantized size: {get_model_size_mb(quantized_model):.1f} MB")
# Linear weights shrink ~4x (FP32 -> INT8); embeddings stay FP32,
# so the overall reduction is somewhat less than 4x

Quantization-Aware Training (QAT)

For minimal accuracy loss, simulate quantization during training so the model adapts:

import torch
from torch.quantization import prepare_qat, convert

model.train()

# Specify quantization config: "fbgemm" targets x86 servers;
# use "qnnpack" for ARM edge devices
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")

# Insert fake quantization modules
model_prepared = prepare_qat(model)

# Fine-tune with quantization simulation
optimizer = torch.optim.AdamW(model_prepared.parameters(), lr=1e-5)

for epoch in range(3):
    for batch in train_dataloader:
        outputs = model_prepared(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

# Convert fake quantized model to actual quantized model
model_prepared.eval()
quantized_model = convert(model_prepared)
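What the fake-quantization modules inserted by prepare_qat do, in miniature: the forward pass snaps values onto the integer grid, while the backward pass bypasses the zero-gradient rounding step via the straight-through estimator, so the model can adapt its weights to the grid. A minimal sketch with an assumed fixed scale:

```python
import torch

def fake_quantize(x, scale=0.1, qmin=-128, qmax=127):
    """Simulate INT8 rounding in the forward pass, pass gradients straight through."""
    q = torch.clamp(torch.round(x / scale), qmin, qmax) * scale
    # x + (q - x).detach() equals q in the forward pass, but its gradient
    # with respect to x is 1: the rounding is invisible to backprop
    return x + (q - x).detach()

x = torch.tensor([0.123, -0.456, 0.789], requires_grad=True)
y = fake_quantize(x)
y.sum().backward()
print(y)       # values snapped to the 0.1 grid
print(x.grad)  # all ones: the gradient flowed straight through the rounding
```

Real QAT learns the scale and zero point per layer (and often per channel); the detach trick above is the core mechanism that makes rounding trainable at all.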

4-Bit Quantization with GPTQ

For aggressive size reduction of generative models:

from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

# Configure 4-bit quantization
quantization_config = GPTQConfig(
    bits=4,
    dataset="c4",
    tokenizer=AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B"),
    group_size=128,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    quantization_config=quantization_config,
    device_map="auto",
)

# FP16 checkpoint: ~2.5 GB; 4-bit: under 1 GB
model.save_pretrained("llama-1b-gptq-4bit")
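The group_size=128 argument gives each run of 128 weights its own quantization scale. A toy, framework-free sketch (4-bit, tiny groups, made-up weights) shows why per-group scales matter when the value ranges differ across a tensor:

```python
def quantize_dequantize(weights, bits, group_size=None):
    """Round each group onto a (2**bits)-level grid spanning that group's own min/max."""
    qmax = 2 ** bits - 1
    group_size = group_size or len(weights)
    out = []
    for i in range(0, len(weights), group_size):
        group = weights[i:i + group_size]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / qmax or 1.0  # avoid a zero scale for constant groups
        out.extend(round((w - lo) / scale) * scale + lo for w in group)
    return out

def max_error(weights, approx):
    return max(abs(w - a) for w, a in zip(weights, approx))

# Four small weights and four large ones: one shared scale must stretch
# across both ranges, while per-group scales can fit each range tightly
weights = [0.01, 0.02, 0.03, 0.04, 5.0, 5.1, 4.9, 5.2]
per_tensor = quantize_dequantize(weights, bits=4)
per_group = quantize_dequantize(weights, bits=4, group_size=4)

print(f"per-tensor 4-bit max error: {max_error(weights, per_tensor):.4f}")
print(f"per-group  4-bit max error: {max_error(weights, per_group):.4f}")
```

The per-group variant costs a little extra storage (one scale per group) but keeps 4-bit error manageable, which is the trade-off GPTQ's group_size controls.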

Pruning: Removing Unnecessary Weights

Pruning sets low-importance weights to zero, creating a sparse model. Structured pruning removes entire neurons or attention heads; unstructured pruning zeros out individual weights. Note that unstructured sparsity only reduces file size or latency when the runtime supports sparse storage and sparse kernels, whereas structured pruning shrinks the dense computation directly.


Magnitude-Based Unstructured Pruning

import torch
import torch.nn.utils.prune as prune
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=5)

# Prune 30% of weights in all Linear layers
parameters_to_prune = [
    (module, "weight")
    for module in model.modules()
    if isinstance(module, torch.nn.Linear)
]

for module, param_name in parameters_to_prune:
    prune.l1_unstructured(module, name=param_name, amount=0.3)

# Count remaining nonzero weights. Check module.weight (the masked tensor),
# not model.parameters(): while a pruning hook is active, parameters() still
# holds the dense weight_orig tensor, which would report ~0% sparsity
total = sum(getattr(m, name).nelement() for m, name in parameters_to_prune)
nonzero = sum(torch.count_nonzero(getattr(m, name)).item() for m, name in parameters_to_prune)
print(f"Sparsity: {(1 - nonzero / total) * 100:.1f}%")
# Sparsity: 30.0%

# Make pruning permanent
for module, param_name in parameters_to_prune:
    prune.remove(module, param_name)

Structured Pruning of Attention Heads

Remove entire attention heads that contribute least to model output:

import torch

def compute_head_importance(model, eval_dataloader) -> dict:
    """Score each attention head by how focused its attention is.

    Note: the plain mean of attention probabilities is useless as a score,
    because every attention row sums to 1, making the mean identical for all
    heads. Low attention entropy (focused heads) is used here as a simple
    proxy; gradient-based importance (Michel et al., 2019) is more principled.
    """
    head_importance = {}
    model.eval()

    with torch.no_grad():
        for batch in eval_dataloader:
            outputs = model(**batch, output_attentions=True)
            for layer_idx, attention in enumerate(outputs.attentions):
                # Shape: (batch, num_heads, seq_len, seq_len)
                probs = attention.clamp_min(1e-9)
                entropy = -(probs * probs.log()).sum(dim=-1)  # (batch, heads, seq_len)
                importance = -entropy.mean(dim=(0, 2))  # higher = more focused
                key = f"layer_{layer_idx}"
                head_importance.setdefault(key, []).append(importance.cpu())

    # Average across batches
    return {
        k: torch.stack(v).mean(dim=0) for k, v in head_importance.items()
    }

def prune_least_important_heads(model, head_importance, prune_ratio=0.25):
    """Remove the least important attention heads from the model."""
    all_scores = []
    for layer_name, scores in head_importance.items():
        for head_idx, score in enumerate(scores):
            all_scores.append((layer_name, head_idx, score.item()))

    all_scores.sort(key=lambda x: x[2])
    num_to_prune = int(len(all_scores) * prune_ratio)
    heads_to_prune = all_scores[:num_to_prune]

    print(f"Pruning {num_to_prune} attention heads out of {len(all_scores)}")
    # Group into {layer_index: [head indices]} and physically remove the heads;
    # transformers models support this via PreTrainedModel.prune_heads
    layer_to_heads = {}
    for layer_name, head_idx, score in heads_to_prune:
        layer_idx = int(layer_name.split("_")[1])
        layer_to_heads.setdefault(layer_idx, []).append(head_idx)
    model.prune_heads(layer_to_heads)
    return heads_to_prune

Knowledge Distillation: Training a Smaller Model

Distillation trains a small "student" model to replicate the behavior of a large "teacher" model. The student learns from the teacher's soft probability distributions, which contain more information than hard labels:

import torch
import torch.nn.functional as F

class DistillationTrainer:
    """Train a student model to mimic a teacher model."""

    def __init__(
        self,
        teacher,
        student,
        temperature: float = 4.0,
        alpha: float = 0.7,
    ):
        self.teacher = teacher.eval()
        self.student = student.train()
        self.temperature = temperature
        self.alpha = alpha  # Weight for distillation vs task loss

    def distillation_loss(self, student_logits, teacher_logits, labels):
        """Combine soft target loss with hard target loss."""
        # Soft targets — KL divergence between student and teacher distributions
        soft_student = F.log_softmax(student_logits / self.temperature, dim=-1)
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=-1)
        distill_loss = F.kl_div(
            soft_student, soft_teacher, reduction="batchmean"
        ) * (self.temperature ** 2)

        # Hard targets — standard cross-entropy with true labels
        hard_loss = F.cross_entropy(student_logits, labels)

        # Weighted combination
        return self.alpha * distill_loss + (1 - self.alpha) * hard_loss

    def train_step(self, batch, optimizer):
        optimizer.zero_grad()

        # Teacher forward pass (no gradients)
        with torch.no_grad():
            teacher_outputs = self.teacher(**batch)

        # Student forward pass
        student_outputs = self.student(**batch)

        loss = self.distillation_loss(
            student_outputs.logits,
            teacher_outputs.logits,
            batch["labels"],
        )
        loss.backward()
        optimizer.step()
        return loss.item()
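The temperature's effect is easy to see in isolation. With hypothetical teacher logits, T=1 produces a near-one-hot distribution, while a higher temperature exposes how the teacher ranks the wrong classes, which is exactly the extra signal the student learns from:

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax with a temperature that flattens the distribution as T grows."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

teacher_logits = [6.0, 2.5, 1.0, -1.0, -2.0]  # hypothetical 5-class teacher output
for t in (1.0, 4.0):
    probs = softmax(teacher_logits, temperature=t)
    print(f"T={t}: {[round(p, 3) for p in probs]}")
```

At T=1 nearly all mass sits on the top class, so the soft targets carry little more than the hard label; at T=4 the relative probabilities of the runner-up classes become visible, which is why the distillation loss scales by T squared to keep gradient magnitudes comparable.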

Combining Techniques

The most effective edge optimization combines all three approaches in sequence:

def optimize_for_edge(teacher_model, train_data, eval_data):
    """Full optimization pipeline: distill, prune, quantize."""

    # Step 1: Distill to a smaller architecture
    # (create_smaller_model, train_student, and fine_tune are placeholders
    # for your own model-construction and training loops)
    student = create_smaller_model(num_layers=3, hidden_size=256)
    trainer = DistillationTrainer(teacher_model, student)
    train_student(trainer, train_data, epochs=5)

    # Step 2: Prune redundant weights
    head_importance = compute_head_importance(student, eval_data)
    prune_least_important_heads(student, head_importance, prune_ratio=0.25)
    fine_tune(student, train_data, epochs=2)  # Recover accuracy

    # Step 3: Quantize to INT8
    quantized = torch.quantization.quantize_dynamic(
        student, {torch.nn.Linear}, dtype=torch.qint8
    )

    return quantized

# Result: 440 MB teacher -> ~8 MB optimized edge model

Quality Preservation Strategies

After aggressive optimization, validate that quality remains acceptable:

def evaluate_optimization(original, optimized, test_data) -> dict:
    """Compare original and optimized model quality.

    run_predictions, compute_accuracy, and compute_agreement are
    placeholders for task-specific evaluation code.
    """
    original_preds = run_predictions(original, test_data)
    optimized_preds = run_predictions(optimized, test_data)

    return {
        "original_accuracy": compute_accuracy(original_preds, test_data.labels),
        "optimized_accuracy": compute_accuracy(optimized_preds, test_data.labels),
        "agreement_rate": compute_agreement(original_preds, optimized_preds),
        "original_size_mb": get_model_size_mb(original),
        "optimized_size_mb": get_model_size_mb(optimized),
        "compression_ratio": get_model_size_mb(original) / get_model_size_mb(optimized),
    }
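The agreement rate above, the fraction of inputs where both models return the same label regardless of correctness, is worth tracking because accuracy alone can hide behavior drift. A sketch of that helper (the other helpers are similarly task-specific):

```python
def compute_agreement(preds_a, preds_b):
    """Fraction of examples where both models predict the same label."""
    if len(preds_a) != len(preds_b):
        raise ValueError("prediction lists must be the same length")
    matches = sum(a == b for a, b in zip(preds_a, preds_b))
    return matches / len(preds_a)

# Example: the optimized model flips 2 of 8 predictions
original_preds = [0, 1, 2, 2, 1, 0, 3, 4]
optimized_preds = [0, 1, 2, 1, 1, 0, 3, 0]
print(compute_agreement(original_preds, optimized_preds))  # 0.75
```

Two models can have identical accuracy while disagreeing on many individual inputs, so a low agreement rate is a signal to inspect which examples the optimized model now handles differently.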

FAQ

Which optimization technique should I apply first?

Start with quantization — it is the easiest to apply and gives the largest size reduction (typically 4x) with the least accuracy impact. If the model is still too large, add pruning. Use knowledge distillation when you need to move to a fundamentally smaller architecture (e.g., from BERT-large to a 3-layer student). The order for maximum compression is: distill first, then prune the student, then quantize the pruned student.

How much accuracy loss should I expect from aggressive optimization?

Dynamic INT8 quantization typically loses less than 0.5 percent accuracy. Adding 30 percent unstructured pruning adds another 0.5 to 1 percent loss. Knowledge distillation to a model one-quarter the size of the teacher typically retains 95 to 98 percent of accuracy. Combining all three, expect 2 to 5 percent total accuracy loss while achieving 20 to 50x size reduction. Always validate on your specific task — some tasks are more sensitive to compression than others.

Can I quantize a model to less than 4 bits?

Research on 2-bit and even 1-bit quantization exists, but practical deployment below 4 bits remains challenging. At 4-bit quantization, most models retain 95 percent or more of their accuracy. At 2 bits, accuracy drops significantly for general tasks, though specialized models trained with quantization awareness can still perform well on narrow tasks. For most edge agent deployments, 4-bit quantization (GPTQ or AWQ) is the practical floor for quality preservation.


#ModelOptimization #Quantization #KnowledgeDistillation #Pruning #EdgeDeployment #MLOps #AgenticAI #LearnAI #AIEngineering

CallSphere Team

Expert insights on AI voice agents and customer communication automation.
