Edge AI Deployment: Model Optimization from Cloud to Device

Overview

Edge AI moves inference from the cloud onto the device itself (phones, IoT devices, browsers, embedded systems). The core goals are lower latency, stronger privacy, and reduced dependence on bandwidth.

The key challenge: edge devices have limited compute and memory, so models must be heavily compressed and optimized. This article systematically covers four dimensions: model quantization, knowledge distillation, structural optimization, and inference engines.

The Model Compression Landscape

Model compression stack
    │
    ├── Quantization
    │   ├── Post-training quantization (PTQ)
    │   ├── Quantization-aware training (QAT)
    │   └── Dynamic quantization / mixed precision
    │
    ├── Pruning (see the sketch after this tree)
    │   ├── Unstructured pruning
    │   └── Structured pruning
    │
    ├── Knowledge distillation
    │   ├── Output (logit) distillation
    │   ├── Feature distillation
    │   └── Attention distillation
    │
    └── Structural optimization
        ├── Operator fusion
        ├── Graph optimization
        └── Neural architecture search (NAS)
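
Of these branches, pruning is the only one without a dedicated section below, so here is a minimal sketch using PyTorch's built-in torch.nn.utils.prune; the layer shape and the 30% sparsity level are illustrative:

import torch
import torch.nn.utils.prune as prune

# Unstructured magnitude pruning: zero out the 30% smallest-magnitude weights
layer = torch.nn.Linear(768, 768)
prune.l1_unstructured(layer, name="weight", amount=0.3)

# Pruning is applied through a mask at forward time; bake it in before export
prune.remove(layer, "weight")
print(f"sparsity: {(layer.weight == 0).float().mean().item():.2f}")  # ~0.30

Note that unstructured sparsity only saves compute on runtimes with sparse kernels; structured pruning (removing whole channels or attention heads) shrinks the dense model directly.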

Quantization

Basic Principle

Map model weights and activations from high precision (FP32/FP16) to low precision (INT8/INT4).

FP32 weight range: [-3.14, 2.78]
        |
        v  linear quantization
INT8 mapping: [-128, 127]

Quantize:   q = round(x / scale + zero_point)
Dequantize: x ≈ (q - zero_point) * scale
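
A minimal NumPy sketch of this affine (asymmetric) scheme, using the weight range above; the helper names are illustrative:

import numpy as np

def quantize_int8(x: np.ndarray):
    # Map [x_min, x_max] linearly onto the INT8 range [-128, 127]
    x_min, x_max = float(x.min()), float(x.max())
    scale = (x_max - x_min) / 255.0
    zero_point = round(-128 - x_min / scale)
    q = np.clip(np.round(x / scale + zero_point), -128, 127).astype(np.int8)
    return q, scale, zero_point

def dequantize(q: np.ndarray, scale: float, zero_point: int):
    # Lossy inverse: values are recovered only up to the quantization step
    return (q.astype(np.float32) - zero_point) * scale

x = np.array([-3.14, -1.0, 0.0, 1.5, 2.78], dtype=np.float32)
q, scale, zp = quantize_int8(x)
print(q, np.abs(dequantize(q, scale, zp) - x).max())  # worst-case error ~ scale/2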

Post-Training Quantization (PTQ)

Quantizes an existing model directly, with no retraining required. Well suited to quick deployment.

import torch
from torch.quantization import quantize_dynamic

# Option 1: dynamic quantization (activations quantized on the fly at inference)
model_fp32 = load_model()
model_int8 = quantize_dynamic(
    model_fp32,
    {torch.nn.Linear, torch.nn.LSTM},  # layer types to quantize
    dtype=torch.qint8,
)

# Compare model sizes
import os
torch.save(model_fp32.state_dict(), "model_fp32.pt")
torch.save(model_int8.state_dict(), "model_int8.pt")
print(f"FP32: {os.path.getsize('model_fp32.pt') / 1e6:.1f} MB")
print(f"INT8: {os.path.getsize('model_int8.pt') / 1e6:.1f} MB")
# INT8 is typically about 1/4 the size of FP32

# Option 2: static quantization (requires a calibration dataset)
from torch.quantization import prepare, convert

model_fp32.eval()
model_fp32.qconfig = torch.quantization.get_default_qconfig("x86")

# Insert observers to record activation ranges
model_prepared = prepare(model_fp32)

# Calibrate with representative data
with torch.no_grad():
    for batch in calibration_dataloader:
        model_prepared(batch)

# Convert to a quantized model
model_quantized = convert(model_prepared)

Quantization with ONNX Runtime

from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantFormat
import onnxruntime as ort

class MyCalibrationReader(CalibrationDataReader):
    def __init__(self, dataset, batch_size=32):
        self.dataset = dataset
        self.batch_size = batch_size
        self.index = 0

    def get_next(self):
        if self.index >= len(self.dataset):
            return None
        batch = self.dataset[self.index:self.index + self.batch_size]
        self.index += self.batch_size
        return {"input": batch}

# INT8 quantization
quantize_static(
    model_input="model.onnx",
    model_output="model_int8.onnx",
    calibration_data_reader=MyCalibrationReader(calibration_data),
    quant_format=QuantFormat.QDQ,  # QDQ (QuantizeLinear/DequantizeLinear) format
    per_channel=True,          # per-channel quantization (better accuracy)
    reduce_range=False,
)

# Sanity-check the quantized model
session = ort.InferenceSession("model_int8.onnx")
result = session.run(None, {"input": test_data})

Quantization-Aware Training (QAT)

Simulates quantization error during training so the model learns to keep its accuracy at low precision. The resulting accuracy is usually better than with PTQ.

import torch
from torch.quantization import prepare_qat, convert

model = load_model()
model.train()

# Set the QAT configuration
model.qconfig = torch.quantization.get_default_qat_qconfig("x86")
model_qat = prepare_qat(model)

# Train as usual (fake-quantization nodes are inserted automatically)
optimizer = torch.optim.Adam(model_qat.parameters(), lr=1e-5)

criterion = torch.nn.CrossEntropyLoss()  # assumes a classification model

for epoch in range(3):
    for inputs, labels in train_dataloader:
        loss = criterion(model_qat(inputs), labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

# Convert to the final quantized model
model_qat.eval()
model_quantized = convert(model_qat)

Knowledge Distillation

Uses the output distribution of a large Teacher model to guide the training of a small Student model.

import torch
import torch.nn.functional as F

class DistillationTrainer:
    def __init__(self, teacher, student, temperature=4.0, alpha=0.7):
        self.teacher = teacher
        self.student = student
        self.temperature = temperature
        self.alpha = alpha  # weight of the distillation (soft) loss

        self.teacher.eval()  # the teacher is frozen

    def distillation_loss(self, student_logits, teacher_logits, labels):
        """组合蒸馏损失和硬标签损失"""
        T = self.temperature

        # Soft-label loss (KL divergence between temperature-scaled distributions)
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / T, dim=-1),
            F.softmax(teacher_logits / T, dim=-1),
            reduction="batchmean",
        ) * (T * T)

        # Hard-label loss (cross-entropy against ground truth)
        hard_loss = F.cross_entropy(student_logits, labels)

        # Weighted combination
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

    def train_step(self, batch):
        inputs, labels = batch

        # Teacher forward pass (no gradients)
        with torch.no_grad():
            teacher_logits = self.teacher(inputs)

        # Student forward pass
        student_logits = self.student(inputs)

        # Combined loss
        loss = self.distillation_loss(student_logits, teacher_logits, labels)

        return loss

# Example usage
teacher = load_model("bert-large")       # 340M parameters
student = load_model("bert-tiny")        # 14M parameters
trainer = DistillationTrainer(teacher, student)
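
The trainer computes only the loss; here is a sketch of the surrounding loop, assuming train_dataloader yields (inputs, labels) batches:

optimizer = torch.optim.AdamW(student.parameters(), lr=5e-5)  # only the student is updated

for epoch in range(3):
    for batch in train_dataloader:
        loss = trainer.train_step(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()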

LLM Distillation in Practice

For large language models, distillation is usually done as data distillation:

# Generate high-quality training data with a large model, then train a small one
from openai import OpenAI

client = OpenAI()

def generate_training_data(prompts: list[str], teacher_model="gpt-4o"):
    """用 Teacher 模型生成训练数据"""
    training_pairs = []

    for prompt in prompts:
        response = client.chat.completions.create(
            model=teacher_model,
            messages=[
                {"role": "system", "content": "你是一个专业的客服助手..."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.3,
        )

        training_pairs.append({
            "instruction": prompt,
            "output": response.choices[0].message.content,
        })

    return training_pairs

# Fine-tune a small model (e.g. Llama-3.2-1B) on the generated data
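
Most fine-tuning toolchains consume instruction data as JSONL; a small helper sketch for persisting the pairs (the filename, and prompts as a pre-built list of seed prompts, are assumptions):

import json

def save_jsonl(pairs: list[dict], path: str = "distill_data.jsonl"):
    # One JSON object per line, the common format for instruction fine-tuning
    with open(path, "w", encoding="utf-8") as f:
        for pair in pairs:
            f.write(json.dumps(pair, ensure_ascii=False) + "\n")

save_jsonl(generate_training_data(prompts))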

Inference Engine Comparison

ONNX Runtime

import onnxruntime as ort
import numpy as np

# Create an inference session
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session_options.intra_op_num_threads = 4

# CPU inference
session = ort.InferenceSession(
    "model.onnx",
    sess_options=session_options,
    providers=["CPUExecutionProvider"],
)

# GPU inference (CUDA)
session_gpu = ort.InferenceSession(
    "model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

# Run inference
input_data = np.random.randn(1, 768).astype(np.float32)
outputs = session.run(None, {"input": input_data})

TensorRT (Peak Optimization for NVIDIA GPUs)

# Convert an ONNX model into a TensorRT engine
import tensorrt as trt

def build_engine(onnx_path, engine_path, fp16=True):
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, logger)

    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            raise RuntimeError(parser.get_error(0))

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1GB

    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    engine = builder.build_serialized_network(network, config)

    with open(engine_path, "wb") as f:
        f.write(engine)

    return engine

# Build an FP16 engine (typically 2-3x faster than FP32)
build_engine("model.onnx", "model.engine", fp16=True)

llama.cpp (LLM Inference on CPU)

# Build llama.cpp
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp && make -j

# Convert a locally downloaded HF checkpoint to GGUF format
# (the script takes the model directory as a positional argument)
python convert_hf_to_gguf.py ./Llama-3.2-1B-Instruct \
  --outfile llama-3.2-1b.gguf

# Q4_K_M quantization (recommended: the best accuracy/speed trade-off)
./llama-quantize llama-3.2-1b.gguf llama-3.2-1b-q4_k_m.gguf Q4_K_M

# Run inference
./llama-cli \
  -m llama-3.2-1b-q4_k_m.gguf \
  -p "What is machine learning?" \
  -n 256 \
  --threads 8

Engine Comparison Table

Engine         Target platform     Strengths                       Typical models
ONNX Runtime   All platforms       Cross-platform, multi-backend   General-purpose
TensorRT       NVIDIA GPU          Peak GPU performance            Classification/detection/NLP
llama.cpp      CPU/Metal           Runs LLMs on plain CPUs         LLMs
Core ML        Apple devices       ANE acceleration                iOS/macOS
TFLite         Android/embedded    Mobile-first optimization       Classification/detection
MediaPipe      Mobile              End-to-end pipelines            Vision/NLP

Mobile Deployment

iOS (Core ML)

# PyTorch -> Core ML conversion
import coremltools as ct
import torch

model = load_model()
model.eval()

# Example input
example_input = torch.randn(1, 3, 224, 224)

# Convert to Core ML
traced_model = torch.jit.trace(model, example_input)
coreml_model = ct.convert(
    traced_model,
    inputs=[ct.TensorType(name="image", shape=(1, 3, 224, 224))],
    compute_precision=ct.precision.FLOAT16,  # FP16
    compute_units=ct.ComputeUnit.ALL,        # CPU + GPU + ANE
)

coreml_model.save("model.mlpackage")

// Swift inference code
import CoreML
import Vision

let model = try! MyModel(configuration: .init())

let request = VNCoreMLRequest(model: try! VNCoreMLModel(for: model.model)) {
    request, error in
    guard let results = request.results as? [VNClassificationObservation] else { return }
    let topResult = results.first!
    print("\(topResult.identifier): \(topResult.confidence)")
}

let handler = VNImageRequestHandler(cgImage: image)
try! handler.perform([request])

Android (TFLite)

# PyTorch -> TFLite conversion
import torch
import ai_edge_torch

model = load_model()
model.eval()

sample_input = (torch.randn(1, 3, 224, 224),)

# Convert and export
edge_model = ai_edge_torch.convert(model, sample_input)
edge_model.export("model.tflite")

// Kotlin inference code
import org.tensorflow.lite.Interpreter
import java.nio.ByteBuffer
import java.nio.ByteOrder

val interpreter = Interpreter(loadModelFile("model.tflite"))

// TFLite expects direct buffers in native byte order
val inputBuffer = ByteBuffer.allocateDirect(4 * 3 * 224 * 224).order(ByteOrder.nativeOrder())
val outputBuffer = ByteBuffer.allocateDirect(4 * 1000).order(ByteOrder.nativeOrder()) // 1000 classes

interpreter.run(inputBuffer, outputBuffer)

In-Browser Inference (WebAssembly + WebGPU)

// Run a model in the browser with Transformers.js
// (WebGPU support requires v3, published as @huggingface/transformers)
import { pipeline } from "@huggingface/transformers";

// Downloads the ONNX model automatically and runs inference in the browser
const classifier = await pipeline(
  "sentiment-analysis",
  "Xenova/distilbert-base-uncased-finetuned-sst-2-english",
  { device: "webgpu" },  // 使用 WebGPU 加速
);

const result = await classifier("I love this product!");
console.log(result);
// [{ label: "POSITIVE", score: 0.9998 }]

// Run an LLM in the browser with WebLLM
import { CreateMLCEngine } from "@mlc-ai/web-llm";

const engine = await CreateMLCEngine("Llama-3.2-1B-Instruct-q4f16_1-MLC");

const reply = await engine.chat.completions.create({
  messages: [{ role: "user", content: "Hello!" }],
  stream: true,
});

// process.stdout is a Node.js API; in the browser, accumulate the stream instead
let output = "";
for await (const chunk of reply) {
  output += chunk.choices[0]?.delta?.content || "";
}
console.log(output);

Performance Benchmark Reference

Taking ResNet-50 image classification as an example:

Platform               Accuracy   Latency (ms)   Model size
A100 GPU (FP32)        76.1%      2              98 MB
A100 GPU (FP16)        76.1%      1.2            49 MB
A100 TensorRT (INT8)   75.8%      0.6            25 MB
CPU (FP32)             76.1%      45             98 MB
CPU ONNX RT (INT8)     75.6%      12             25 MB
iPhone 15 ANE (FP16)   76.0%      3              49 MB
Pixel 8 NPU (INT8)     75.5%      8              25 MB
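
Numbers like these vary with batch size, thread count, and input resolution, so measure on your own hardware. A rough latency harness with ONNX Runtime (the model path and input name are assumptions):

import time
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
x = np.random.randn(1, 3, 224, 224).astype(np.float32)

for _ in range(10):  # warm-up runs, excluded from timing
    session.run(None, {"input": x})

runs = 100
start = time.perf_counter()
for _ in range(runs):
    session.run(None, {"input": x})
print(f"mean latency: {(time.perf_counter() - start) / runs * 1000:.2f} ms")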

Decision Guide

Where will your inference run?
    │
    ├── Cloud GPU
    │   └── TensorRT > ONNX Runtime > vLLM (for LLMs)
    │
    ├── Cloud CPU
    │   └── ONNX Runtime > llama.cpp (for LLMs)
    │
    ├── iOS / macOS
    │   └── Core ML (ANE acceleration)
    │
    ├── Android
    │   └── TFLite > ONNX Runtime Mobile
    │
    ├── Browser
    │   └── Transformers.js / WebLLM (WebGPU)
    │
    └── Embedded / IoT
        └── TFLite Micro > ONNX Runtime Micro

Maurice | maurice_wen@proton.me