AI Gateway Design: Model Routing and Load Balancing

LLM API gateway architecture patterns, intelligent routing strategies, fallback chain design, cost tracking, and token metering

Introduction

When an organization uses several LLM providers at once (OpenAI, Anthropic, Google, open-source models), hard-coding API calls into application code leads to tight coupling, painful migrations, and runaway costs. An AI gateway is a unified access layer that absorbs the cross-cutting concerns: model routing, load balancing, cost control, observability, and fault tolerance.

This article walks from architecture design to engineering implementation, showing how to build a production-grade AI gateway.

Architecture

Overall Topology

┌─────────────────────────────────────────────────────────────┐
│                      Application Layer                      │
│  ┌─────────┐  ┌──────────┐  ┌──────────┐  ┌─────────────┐  │
│  │ ChatBot │  │ Code Gen │  │ RAG App  │  │ Agent System│  │
│  └────┬────┘  └─────┬────┘  └─────┬────┘  └──────┬──────┘  │
└───────┼─────────────┼─────────────┼──────────────┼──────────┘
        │             │             │              │
        ▼             ▼             ▼              ▼
┌─────────────────────────────────────────────────────────────┐
│                       AI Gateway Layer                      │
│                                                             │
│  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌────────────┐  │
│  │  Router  │  │ Limiter  │  │  Cache   │  │ Telemetry  │  │
│  └────┬─────┘  └──────────┘  └──────────┘  └────────────┘  │
│       │                                                     │
│  ┌────▼─────────────────────────────────────────────────┐   │
│  │               Fallback Chain Manager                  │   │
│  └────┬──────────┬──────────┬──────────┬────────────────┘   │
└───────┼──────────┼──────────┼──────────┼────────────────────┘
        │          │          │          │
   ┌────▼───┐ ┌───▼────┐ ┌───▼────┐ ┌───▼────┐
   │OpenAI  │ │Anthropic│ │Google  │ │vLLM    │
   │GPT-4o  │ │Claude   │ │Gemini  │ │(local) │
   └────────┘ └────────┘ └────────┘ └────────┘

Core Module Responsibilities

| Module           | Responsibility                                         | Key metric                             |
|------------------|--------------------------------------------------------|----------------------------------------|
| Router           | Picks a provider by scene/model/cost/latency           | Routing decision latency < 1 ms        |
| Limiter          | Enforces RPM/TPM limits within provider quotas         | Token-bucket / sliding-window accuracy |
| Cache            | Semantic cache plus exact-match cache                  | Cache hit rate > 30%                   |
| Fallback manager | Automatic failover, circuit-breaker recovery           | Failover latency < 100 ms              |
| Telemetry        | Token metering, latency tracing, quality scoring       | Data completeness > 99.9%              |
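On the hot path these modules compose into one request pipeline: check the cache first (a hit costs nothing), enforce rate limits before any tokens are spent, then route and call upstream. The following is a minimal sketch with injected dependencies; all names and shapes here (`GatewayDeps`, `handleRequest`) are illustrative, not from any specific framework:

```typescript
// Order matters: cache (free) -> rate limit (protect quotas before spending
// tokens) -> route -> upstream call. Every dependency shape is an assumption.
interface GatewayDeps {
  cache: { get(key: string): string | undefined; set(key: string, v: string): void };
  limiter: { tryConsume(provider: string): boolean };
  route(scene: string): { provider: string; model: string };
  call(provider: string, model: string, prompt: string): Promise<string>;
  record(event: { provider: string; model: string; latencyMs: number }): void;
}

async function handleRequest(
  deps: GatewayDeps,
  scene: string,
  prompt: string,
): Promise<string> {
  const key = `${scene}:${prompt}`;        // exact-match key; a semantic cache would embed instead
  const cached = deps.cache.get(key);
  if (cached !== undefined) return cached; // cache hit: zero provider cost

  const { provider, model } = deps.route(scene);
  if (!deps.limiter.tryConsume(provider)) {
    throw new Error(`Rate limited: ${provider}`);
  }

  const start = Date.now();
  const answer = await deps.call(provider, model, prompt); // fallback chain lives inside `call`
  deps.record({ provider, model, latencyMs: Date.now() - start });

  deps.cache.set(key, answer);             // populate cache for the next hit
  return answer;
}
```

The router, limiter, fallback chain, and metering pieces behind these interfaces are each developed in the sections below.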

Intelligent Routing Strategies

Scene-Based Routing

Different task scenes demand different model capabilities; the router picks the best-suited model for each scene type:

// src/router/scene-router.ts
interface RouteConfig {
  scene: string;
  tiers: {
    premium: ProviderModel[];
    balanced: ProviderModel[];
    fast: ProviderModel[];
  };
}

const routeTable: RouteConfig[] = [
  {
    scene: "chat",
    tiers: {
      premium:  [{ provider: "anthropic", model: "claude-sonnet-4-20250514" }],
      balanced: [{ provider: "openai",    model: "gpt-4o" }],
      fast:     [{ provider: "google",    model: "gemini-2.0-flash" }],
    },
  },
  {
    scene: "code",
    tiers: {
      premium:  [{ provider: "anthropic", model: "claude-sonnet-4-20250514" }],
      balanced: [{ provider: "openai",    model: "gpt-4o" }],
      fast:     [{ provider: "google",    model: "gemini-2.0-flash" }],
    },
  },
  {
    scene: "vision",
    tiers: {
      premium:  [{ provider: "google",    model: "gemini-2.5-pro" }],
      balanced: [{ provider: "openai",    model: "gpt-4o" }],
      fast:     [{ provider: "google",    model: "gemini-2.0-flash" }],
    },
  },
  {
    scene: "embedding",
    tiers: {
      premium:  [{ provider: "openai",    model: "text-embedding-3-large" }],
      balanced: [{ provider: "openai",    model: "text-embedding-3-small" }],
      fast:     [{ provider: "local",     model: "bge-m3" }],
    },
  },
];

type Tier = "premium" | "balanced" | "fast";

function routeRequest(scene: string, tier: Tier = "balanced"): ProviderModel {
  const config = routeTable.find(r => r.scene === scene);
  if (!config) throw new Error(`Unknown scene: ${scene}`);

  const candidates = config.tiers[tier];
  if (!candidates?.length) throw new Error(`No models for ${scene}/${tier}`);

  // Check availability before returning
  for (const candidate of candidates) {
    if (getCircuitBreaker(candidate.provider).isAvailable()) {
      return candidate;
    }
  }

  // All candidates down: degrade one tier, stopping at "fast" so the
  // recursion terminates instead of looping forever
  if (tier === "fast") throw new Error(`All providers unavailable for ${scene}`);
  return routeRequest(scene, tier === "premium" ? "balanced" : "fast");
}

Cost-Aware Routing

// src/router/cost-router.ts
interface ModelPricing {
  provider: string;
  model: string;
  inputPer1M: number;   // USD per 1M input tokens
  outputPer1M: number;  // USD per 1M output tokens
  cachedPer1M?: number; // USD per 1M cached input tokens
}

const pricing: ModelPricing[] = [
  { provider: "openai",    model: "gpt-4o",              inputPer1M: 2.50, outputPer1M: 10.00 },
  { provider: "openai",    model: "gpt-4o-mini",         inputPer1M: 0.15, outputPer1M: 0.60 },
  { provider: "anthropic", model: "claude-sonnet-4-20250514", inputPer1M: 3.00, outputPer1M: 15.00 },
  { provider: "anthropic", model: "claude-haiku-3.5",    inputPer1M: 0.80, outputPer1M: 4.00 },
  { provider: "google",    model: "gemini-2.0-flash",    inputPer1M: 0.10, outputPer1M: 0.40 },
  { provider: "google",    model: "gemini-2.5-pro",      inputPer1M: 1.25, outputPer1M: 10.00 },
];

function estimateCost(
  model: ModelPricing,
  estimatedInputTokens: number,
  estimatedOutputTokens: number,
): number {
  return (
    (estimatedInputTokens / 1_000_000) * model.inputPer1M +
    (estimatedOutputTokens / 1_000_000) * model.outputPer1M
  );
}

function routeByCost(
  scene: string,
  maxCostUsd: number,
  estimatedInputTokens: number,
  estimatedOutputTokens: number,
): ProviderModel {
  const candidates = getSceneCandidates(scene);

  // Filter by budget, sort by quality (premium first)
  const affordable = candidates
    .flatMap(c => {
      const price = pricing.find(
        p => p.provider === c.provider && p.model === c.model,
      );
      if (!price) return []; // no pricing data: skip rather than crash
      return [{
        ...c,
        cost: estimateCost(price, estimatedInputTokens, estimatedOutputTokens),
      }];
    })
    .filter(c => c.cost <= maxCostUsd)
    .sort((a, b) => b.cost - a.cost); // Higher cost = higher quality (heuristic)

  if (!affordable.length) {
    throw new Error(`No model within budget $${maxCostUsd} for ${scene}`);
  }

  return affordable[0];
}
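To make the formula concrete, here is the arithmetic for a single request against gpt-4o, using the prices from the table above (verify against current price sheets before relying on these numbers):

```typescript
// gpt-4o prices from the pricing table above
const inputPer1M = 2.5;    // USD per 1M input tokens
const outputPer1M = 10.0;  // USD per 1M output tokens

const inputTokens = 2_000;  // a typical prompt with some context
const outputTokens = 500;   // a short answer

const costUsd =
  (inputTokens / 1_000_000) * inputPer1M +
  (outputTokens / 1_000_000) * outputPer1M;

console.log(costUsd); // 0.005 + 0.005, roughly $0.01 per request
```

At this rate a service doing one million such requests a month spends about $10,000, which is why per-request budget caps and cheaper-tier routing pay off quickly.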

Fallback Chains and Fault Tolerance

Circuit Breaker Implementation

// src/resilience/circuit-breaker.ts
enum CircuitState {
  CLOSED = "closed",       // Normal operation
  OPEN = "open",           // Failing, reject requests
  HALF_OPEN = "half_open", // Testing recovery
}

interface CircuitConfig {
  failureThreshold: number;    // Failures before opening
  recoveryTimeout: number;     // ms before trying half-open
  successThreshold: number;    // Successes in half-open to close
  monitorWindow: number;       // ms window for failure counting
}

class CircuitBreaker {
  private state: CircuitState = CircuitState.CLOSED;
  private failures: number[] = [];
  private successes = 0;
  private lastStateChange = Date.now();

  constructor(
    private provider: string,
    private config: CircuitConfig = {
      failureThreshold: 5,
      recoveryTimeout: 30_000,
      successThreshold: 3,
      monitorWindow: 60_000,
    },
  ) {}

  isAvailable(): boolean {
    if (this.state === CircuitState.CLOSED) return true;
    if (this.state === CircuitState.OPEN) {
      // Check if recovery timeout has passed
      if (Date.now() - this.lastStateChange > this.config.recoveryTimeout) {
        this.transition(CircuitState.HALF_OPEN);
        return true;
      }
      return false;
    }
    // HALF_OPEN: allow limited traffic
    return true;
  }

  recordSuccess(): void {
    if (this.state === CircuitState.HALF_OPEN) {
      this.successes++;
      if (this.successes >= this.config.successThreshold) {
        this.transition(CircuitState.CLOSED);
      }
    }
    // Reset failure window
    this.failures = [];
  }

  recordFailure(error: Error): void {
    const now = Date.now();
    this.failures.push(now);

    // Clean old failures outside monitoring window
    this.failures = this.failures.filter(
      t => now - t < this.config.monitorWindow
    );

    if (this.state === CircuitState.HALF_OPEN) {
      this.transition(CircuitState.OPEN);
      return;
    }

    if (this.failures.length >= this.config.failureThreshold) {
      this.transition(CircuitState.OPEN);
    }
  }

  private transition(newState: CircuitState): void {
    console.log(
      `CircuitBreaker [${this.provider}]: ${this.state} -> ${newState}`
    );
    this.state = newState;
    this.lastStateChange = Date.now();
    this.successes = 0;
    if (newState === CircuitState.CLOSED) {
      this.failures = [];
    }
  }
}
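The fallback executor in the next section fetches breakers through a `getCircuitBreaker(provider)` helper that is not shown. A plausible sketch (an assumption, not part of the original code) is a lazy per-key registry, so every request for the same provider shares one breaker instance:

```typescript
// Lazy per-key registry: the first lookup for a key creates the item via the
// factory; later lookups return the same instance. Generic over the item
// type so it can be tested in isolation from CircuitBreaker.
class LazyRegistry<T> {
  private items = new Map<string, T>();

  constructor(private factory: (key: string) => T) {}

  get(key: string): T {
    let item = this.items.get(key);
    if (item === undefined) {
      item = this.factory(key);
      this.items.set(key, item);
    }
    return item;
  }
}

// Wiring sketch, with CircuitBreaker as defined above:
// const registry = new LazyRegistry(p => new CircuitBreaker(p));
// const getCircuitBreaker = (p: string) => registry.get(p);
```

Sharing one instance per provider is what makes the breaker useful: failure counts accumulate across all in-flight requests rather than per request.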

Fallback Chain Executor

// src/resilience/fallback-chain.ts
interface FallbackResult {
  response: LLMResponse;
  provider: string;
  model: string;
  attempts: AttemptRecord[];
}

interface AttemptRecord {
  provider: string;
  model: string;
  latencyMs: number;
  error?: string;
}

async function executeWithFallback(
  request: LLMRequest,
  chain: ProviderModel[],
): Promise<FallbackResult> {
  const attempts: AttemptRecord[] = [];

  for (const { provider, model } of chain) {
    const breaker = getCircuitBreaker(provider);

    if (!breaker.isAvailable()) {
      attempts.push({
        provider, model, latencyMs: 0,
        error: "Circuit breaker OPEN",
      });
      continue;
    }

    const start = Date.now();
    try {
      const response = await callProvider(provider, model, request, {
        timeout: 30_000,
        retries: 1,
      });

      breaker.recordSuccess();
      attempts.push({ provider, model, latencyMs: Date.now() - start });

      return { response, provider, model, attempts };
    } catch (error) {
      const latencyMs = Date.now() - start;
      // `error` is `unknown` in a catch clause; normalize before use
      const message = error instanceof Error ? error.message : String(error);
      breaker.recordFailure(error instanceof Error ? error : new Error(message));
      attempts.push({
        provider, model, latencyMs,
        error: message,
      });

      // Log for observability
      console.error(
        `Fallback: ${provider}/${model} failed (${latencyMs}ms): ${message}`
      );
    }
  }

  throw new FallbackExhaustedError(
    `All providers failed for request`,
    attempts,
  );
}

Token Metering and Cost Tracking

Accurate Token Counting

// src/metering/token-counter.ts
import { encode } from "gpt-tokenizer"; // tiktoken-compatible tokenizer

interface TokenUsage {
  inputTokens: number;
  outputTokens: number;
  cachedTokens?: number;
  totalTokens: number;
  estimatedCostUsd: number;
}

function countTokens(
  provider: string,
  model: string,
  messages: Message[],
  response: LLMResponse,
): TokenUsage {
  // Most providers return token counts in response
  const usage = response.usage;

  if (usage) {
    const modelPricing = getPricing(provider, model);
    return {
      inputTokens: usage.input_tokens ?? usage.prompt_tokens,
      outputTokens: usage.output_tokens ?? usage.completion_tokens,
      cachedTokens: usage.cache_read_input_tokens ?? 0,
      totalTokens: (usage.input_tokens ?? usage.prompt_tokens) +
                   (usage.output_tokens ?? usage.completion_tokens),
      estimatedCostUsd: calculateCost(modelPricing, usage),
    };
  }

  // Fallback: client-side estimation (approximate; provider tokenizers differ)
  const inputText = messages.map(m => m.content).join(" ");
  const inputTokens = encode(inputText).length;
  const outputTokens = encode(response.content).length;
  const modelPricing = getPricing(provider, model);

  return {
    inputTokens,
    outputTokens,
    totalTokens: inputTokens + outputTokens,
    estimatedCostUsd:
      (inputTokens / 1_000_000) * modelPricing.inputPer1M +
      (outputTokens / 1_000_000) * modelPricing.outputPer1M,
  };
}

Cost Dashboard Data Model

-- Cost tracking schema
CREATE TABLE llm_usage_log (
    id              BIGSERIAL PRIMARY KEY,
    request_id      UUID NOT NULL,
    timestamp       TIMESTAMPTZ DEFAULT NOW(),

    -- Routing info
    scene           TEXT NOT NULL,
    tier            TEXT NOT NULL,
    provider        TEXT NOT NULL,
    model           TEXT NOT NULL,

    -- Token usage
    input_tokens    INTEGER NOT NULL,
    output_tokens   INTEGER NOT NULL,
    cached_tokens   INTEGER DEFAULT 0,
    total_tokens    INTEGER GENERATED ALWAYS AS (input_tokens + output_tokens) STORED,

    -- Cost
    cost_usd        NUMERIC(10, 6) NOT NULL,

    -- Performance
    latency_ms      INTEGER NOT NULL,
    ttft_ms         INTEGER,           -- Time to First Token

    -- Context
    user_id         TEXT,
    team_id         TEXT,
    app_id          TEXT NOT NULL,
    was_fallback    BOOLEAN DEFAULT FALSE,
    fallback_chain  JSONB,             -- Full attempt history
    cache_hit       BOOLEAN DEFAULT FALSE
);

-- Daily cost aggregation view
CREATE MATERIALIZED VIEW daily_cost_by_team AS
SELECT
    DATE_TRUNC('day', timestamp) AS day,
    team_id,
    app_id,
    provider,
    model,
    COUNT(*) AS request_count,
    SUM(input_tokens) AS total_input_tokens,
    SUM(output_tokens) AS total_output_tokens,
    SUM(cost_usd) AS total_cost_usd,
    AVG(latency_ms) AS avg_latency_ms,
    PERCENTILE_CONT(0.99) WITHIN GROUP (ORDER BY latency_ms) AS p99_latency_ms,
    SUM(CASE WHEN was_fallback THEN 1 ELSE 0 END) AS fallback_count,
    SUM(CASE WHEN cache_hit THEN 1 ELSE 0 END) AS cache_hit_count
FROM llm_usage_log
GROUP BY 1, 2, 3, 4, 5;

-- Query: monthly cost by team with trend
SELECT
    team_id,
    SUM(total_cost_usd) AS monthly_cost,
    SUM(request_count) AS monthly_requests,
    ROUND(SUM(total_cost_usd) / SUM(request_count) * 1000, 2) AS cost_per_1k_requests,
    ROUND(SUM(cache_hit_count)::NUMERIC / SUM(request_count) * 100, 1) AS cache_hit_pct
FROM daily_cost_by_team
WHERE day >= DATE_TRUNC('month', NOW())
GROUP BY team_id
ORDER BY monthly_cost DESC;

Rate Limiting

Token Bucket Limiter

// src/rate-limiter/token-bucket.ts
class TokenBucketLimiter {
  private tokens: number;
  private lastRefill: number;

  constructor(
    private maxTokens: number,      // Bucket capacity
    private refillRate: number,     // Tokens added per second
  ) {
    this.tokens = maxTokens;
    this.lastRefill = Date.now();
  }

  tryConsume(cost: number = 1): boolean {
    this.refill();

    if (this.tokens >= cost) {
      this.tokens -= cost;
      return true;
    }

    return false;
  }

  private refill(): void {
    const now = Date.now();
    const elapsed = (now - this.lastRefill) / 1000;
    this.tokens = Math.min(
      this.maxTokens,
      this.tokens + elapsed * this.refillRate,
    );
    this.lastRefill = now;
  }

  getWaitTime(cost: number = 1): number {
    this.refill();
    if (this.tokens >= cost) return 0;
    return Math.ceil((cost - this.tokens) / this.refillRate * 1000);
  }
}

// Per-provider rate limits (matching their documented limits)
const providerLimits = {
  openai: {
    rpm: new TokenBucketLimiter(500, 500 / 60),          // 500 RPM
    tpm: new TokenBucketLimiter(200_000, 200_000 / 60),  // 200K TPM
  },
  anthropic: {
    rpm: new TokenBucketLimiter(1000, 1000 / 60),
    tpm: new TokenBucketLimiter(400_000, 400_000 / 60),
  },
  google: {
    rpm: new TokenBucketLimiter(1000, 1000 / 60),
    tpm: new TokenBucketLimiter(4_000_000, 4_000_000 / 60),
  },
};

Production Deployment Reference Architecture

                    ┌─────────────┐
                    │ CloudFlare  │
                    │   (CDN)     │
                    └──────┬──────┘
                           │
                    ┌──────▼──────┐
                    │   Nginx     │
                    │ (TLS term)  │
                    └──────┬──────┘
                           │
              ┌────────────┼────────────┐
              │            │            │
        ┌─────▼────┐ ┌────▼─────┐ ┌────▼─────┐
        │ Gateway  │ │ Gateway  │ │ Gateway  │
        │ Pod #1   │ │ Pod #2   │ │ Pod #3   │
        └─────┬────┘ └────┬─────┘ └────┬─────┘
              │            │            │
        ┌─────▼────────────▼────────────▼─────┐
        │              Redis                   │
        │  (Rate Limits + Semantic Cache)      │
        └─────┬────────────┬──────────────────┘
              │            │
        ┌─────▼────┐ ┌────▼──────┐
        │PostgreSQL│ │ClickHouse │
        │(Config)  │ │(Analytics)│
        └──────────┘ └───────────┘

Takeaways

  1. One interface, many backends: application code only needs to know the scene and quality tier, never the concrete provider.
  2. Fallbacks are not optional: every single provider goes down eventually; a fallback chain is what keeps the service available.
  3. Circuit breakers prevent cascades: identify a failing provider quickly instead of making users wait out timeouts.
  4. Cost must be visible to be controllable: record token usage and cost per request, and build dashboards by team and application.
  5. The cache is the cheapest inference: a semantic cache drives the cost of repeated queries to zero; even a 30% hit rate yields substantial savings.

Maurice | maurice_wen@proton.me