AI Agent 的工具调用优化策略

概述

工具调用是 Agent 与外部世界交互的基础能力。但在生产环境中,朴素的工具调用面临多个问题:延迟高(每次调用都是网络请求)、成本高(每次调用消耗 token)、不稳定(API 可能失败)、安全风险(Agent 可能调用不该调用的工具)。

本文从并行调用、缓存策略、错误恢复、成本控制和工具选择优化五个维度,给出工程级的解决方案。

一、并行工具调用

问题

默认的顺序调用模式:Tool A (200ms) -> Tool B (300ms) -> Tool C (150ms) = 650ms 总延迟。 如果三个工具无依赖关系,并行执行仅需 300ms。

实现

import asyncio
from dataclasses import dataclass
from typing import Any

@dataclass
class ToolCallRequest:
    """A single tool invocation to execute (id, tool name, kwargs)."""

    id: str  # caller-assigned identifier, echoed back in the matching result
    name: str  # key into the executor's tool registry
    args: dict  # keyword arguments forwarded to the tool function

@dataclass
class ToolCallResult:
    """Outcome of one tool call, paired to its request via `id`."""

    id: str  # echoes ToolCallRequest.id
    name: str  # name of the tool that was invoked
    result: Any  # tool return value; None when success is False
    duration_ms: float  # elapsed execution time in milliseconds
    success: bool  # False when the tool raised an exception
    error: str = ""  # stringified exception when success is False, else ""

class ParallelToolExecutor:
    """Execute batches of tool calls concurrently with a concurrency cap.

    Calls judged independent are run in parallel via ``asyncio.gather``;
    the remainder run sequentially afterwards.  A semaphore bounds the
    number of in-flight tool coroutines.
    """

    def __init__(self, tool_registry: dict, max_concurrency: int = 10):
        """
        Args:
            tool_registry: mapping of tool name -> async callable.
            max_concurrency: maximum number of concurrently running tools.
        """
        self.tools = tool_registry
        self.semaphore = asyncio.Semaphore(max_concurrency)

    async def execute_batch(
        self, requests: list[ToolCallRequest]
    ) -> list[ToolCallResult]:
        """Run a batch of tool calls, parallelizing the independent ones.

        Returns one ToolCallResult per request.  Results of the parallel
        phase come first, followed by the sequential (dependent) phase, so
        output order may differ from input order.
        """
        independent, dependent = self._analyze_dependencies(requests)

        # Phase 1: run all independent calls concurrently.  Exceptions are
        # captured (return_exceptions=True) so one failure cannot cancel
        # the rest of the batch.
        tasks = [self._execute_single(req) for req in independent]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        completed = []
        for req, result in zip(independent, results):
            if isinstance(result, Exception):
                # Defensive: _execute_single converts tool errors itself,
                # so this only catches exceptions that escaped it entirely.
                completed.append(ToolCallResult(
                    id=req.id, name=req.name, result=None,
                    duration_ms=0, success=False, error=str(result),
                ))
            else:
                completed.append(result)

        # Phase 2: run possibly-dependent calls one at a time, after the
        # parallel phase has finished.
        for req in dependent:
            completed.append(await self._execute_single(req))

        return completed

    async def _execute_single(self, request: ToolCallRequest) -> ToolCallResult:
        """Run one tool call under the semaphore and time it.

        Never raises for tool-level failures: exceptions from the tool are
        converted into a ToolCallResult with success=False.
        """
        async with self.semaphore:
            # get_running_loop() is the supported API inside a coroutine;
            # get_event_loop() here has been deprecated since Python 3.10.
            loop = asyncio.get_running_loop()
            start = loop.time()

            try:
                func = self.tools[request.name]
                result = await func(**request.args)
                return ToolCallResult(
                    id=request.id,
                    name=request.name,
                    result=result,
                    duration_ms=(loop.time() - start) * 1000,
                    success=True,
                )
            except Exception as e:
                return ToolCallResult(
                    id=request.id,
                    name=request.name,
                    result=None,
                    duration_ms=(loop.time() - start) * 1000,
                    success=False,
                    error=str(e),
                )

    def _analyze_dependencies(self, requests):
        """Split requests into (independent, dependent) lists.

        Heuristic: the first call to each tool name is treated as
        independent; repeat calls to the same tool are conservatively
        deferred to the sequential phase.
        """
        seen_names: set[str] = set()
        independent = []
        dependent = []

        for req in requests:
            if req.name in seen_names:
                dependent.append(req)
            else:
                independent.append(req)
                seen_names.add(req.name)

        return independent, dependent

OpenAI 原生并行调用

# OpenAI gpt-4o supports parallel tool calling natively:
# the model can return multiple tool_calls in a single response.
# NOTE(review): `client`, `messages`, `tools`, `json` and `tool_registry`
# are assumed to be defined elsewhere; `await` requires an async context.
response = client.chat.completions.create(
    model="gpt-4o",
    messages=messages,
    tools=tools,
    parallel_tool_calls=True,  # enabled by default
)

# response.choices[0].message.tool_calls may contain several calls, e.g.:
# [
#     {"id": "call_1", "function": {"name": "get_weather", "arguments": '{"city": "北京"}'}},
#     {"id": "call_2", "function": {"name": "get_weather", "arguments": '{"city": "上海"}'}},
#     {"id": "call_3", "function": {"name": "get_stock", "arguments": '{"symbol": "TSLA"}'}},
# ]

# Execute all returned calls in parallel.
tool_calls = response.choices[0].message.tool_calls
executor = ParallelToolExecutor(tool_registry)
requests = [
    ToolCallRequest(id=tc.id, name=tc.function.name,
                    args=json.loads(tc.function.arguments))
    for tc in tool_calls
]
results = await executor.execute_batch(requests)

二、工具结果缓存

分层缓存策略

import hashlib
import json
import time
from enum import Enum

class CacheTier(Enum):
    """Storage tier for cached tool results."""

    MEMORY = "memory"    # in-process dict (fastest; per-process, lost on restart)
    REDIS = "redis"      # shared Redis cache (ms-level lookups)
    DISK = "disk"        # disk cache (~10ms) -- declared but not handled by ToolCache yet

class ToolCache:
    """Tiered cache for tool results, keyed by tool name + argument hash.

    Tools with side effects (send_email, create_order, ...) are simply
    absent from CACHE_POLICIES and are therefore never cached.  The DISK
    tier exists in CacheTier but has no handler here yet.
    """

    # Per-tool cache policy: storage tier and TTL in seconds.
    CACHE_POLICIES = {
        "get_weather": {"tier": CacheTier.REDIS, "ttl": 1800},      # 30 min
        "search_web": {"tier": CacheTier.REDIS, "ttl": 3600},       # 1 hour
        "get_exchange_rate": {"tier": CacheTier.REDIS, "ttl": 300},  # 5 min
        "query_database": {"tier": CacheTier.MEMORY, "ttl": 60},    # 1 min
        "calculate": {"tier": CacheTier.MEMORY, "ttl": 86400},      # 24 h (deterministic)
        # Tools with side effects are intentionally not listed:
        # "send_email", "create_order" -> never cached
    }

    def __init__(self, redis_client=None):
        """redis_client: optional async Redis client; REDIS-tier lookups
        are skipped entirely when it is None."""
        self.memory_cache: dict = {}
        self.redis = redis_client

    def _cache_key(self, tool_name: str, args: dict) -> str:
        """Stable key: SHA-256 over tool name + canonical JSON of args."""
        args_str = json.dumps(args, sort_keys=True, ensure_ascii=False)
        content = f"{tool_name}:{args_str}"
        return hashlib.sha256(content.encode()).hexdigest()

    async def get(self, tool_name: str, args: dict) -> tuple[bool, Any]:
        """Return (hit, result); (False, None) for uncacheable tools or misses."""
        policy = self.CACHE_POLICIES.get(tool_name)
        if not policy:
            return False, None  # tool is not cacheable

        key = self._cache_key(tool_name, args)

        if policy["tier"] == CacheTier.MEMORY:
            entry = self.memory_cache.get(key)
            if entry:
                if (time.time() - entry["time"]) < policy["ttl"]:
                    return True, entry["result"]
                # Expired: evict eagerly so the in-memory dict does not grow
                # without bound (previously dead entries lingered forever).
                del self.memory_cache[key]

        elif policy["tier"] == CacheTier.REDIS and self.redis:
            # Redis handles expiry itself via SETEX; no manual TTL check.
            cached = await self.redis.get(f"tool_cache:{key}")
            if cached:
                return True, json.loads(cached)

        return False, None

    async def set(self, tool_name: str, args: dict, result: Any):
        """Store a tool result according to its cache policy (no-op if none)."""
        policy = self.CACHE_POLICIES.get(tool_name)
        if not policy:
            return

        key = self._cache_key(tool_name, args)

        if policy["tier"] == CacheTier.MEMORY:
            self.memory_cache[key] = {"result": result, "time": time.time()}

        elif policy["tier"] == CacheTier.REDIS and self.redis:
            await self.redis.setex(
                f"tool_cache:{key}",
                policy["ttl"],
                json.dumps(result, ensure_ascii=False),
            )

三、错误恢复

重试策略

import asyncio
import random

class RetryConfig:
    """Retry/backoff parameters for a single tool.

    Delay before retry attempt k (0-based) is
    min(base_delay * exponential_base ** k, max_delay), plus up to 10% jitter.
    """

    def __init__(self, max_retries=3, base_delay=1.0,
                 max_delay=30.0, exponential_base=2):
        self.max_retries = max_retries        # extra attempts after the first
        self.base_delay = base_delay          # seconds before the first retry
        self.max_delay = max_delay            # backoff ceiling in seconds
        self.exponential_base = exponential_base  # backoff growth factor

class ResilientToolExecutor:
    """Tool executor with timeout, exponential-backoff retry and fallback.

    execute_with_retry never raises: it always returns a status dict with
    "status" in {"success", "fallback", "failed"}.
    """

    # Per-tool retry configuration; unlisted tools use RetryConfig() defaults.
    RETRY_CONFIGS = {
        "search_web": RetryConfig(max_retries=3, base_delay=1.0),
        "query_database": RetryConfig(max_retries=2, base_delay=0.5),
        "call_api": RetryConfig(max_retries=3, base_delay=2.0),
    }

    # Degradation map: primary tool -> substitute tried once retries are exhausted.
    FALLBACK_TOOLS = {
        "search_web_google": "search_web_bing",
        "translate_deepl": "translate_google",
    }

    def __init__(self, tool_registry: dict):
        self.tools = tool_registry  # tool name -> async callable

    async def execute_with_retry(
        self, tool_name: str, args: dict
    ) -> dict:
        """Call `tool_name` with retries, then fallback, then report failure.

        Returns:
            {"status": "success", "result": ..., "attempts": n} on success;
            {"status": "fallback", ...} if the substitute tool succeeded;
            {"status": "failed", "error": ..., "attempts": n, "hint": ...}
            otherwise.  "attempts" is the number of attempts actually made
            (a non-retryable error stops after the first).
        """
        config = self.RETRY_CONFIGS.get(tool_name, RetryConfig())
        last_error = None
        attempts_made = 0

        for attempt in range(config.max_retries + 1):
            attempts_made = attempt + 1
            try:
                # Hard 30s timeout per attempt so a hung tool cannot stall the agent.
                result = await asyncio.wait_for(
                    self.tools[tool_name](**args),
                    timeout=30.0,
                )
                return {"status": "success", "result": result, "attempts": attempts_made}

            except asyncio.TimeoutError:
                last_error = "Timeout"  # timeouts are always considered retryable
            except Exception as e:
                last_error = str(e)

                # Programming/permission errors will not heal on retry: stop now.
                if self._is_non_retryable(e):
                    break

            # Exponential backoff + jitter before the next attempt.
            if attempt < config.max_retries:
                delay = min(
                    config.base_delay * (config.exponential_base ** attempt),
                    config.max_delay,
                )
                jitter = random.uniform(0, delay * 0.1)
                await asyncio.sleep(delay + jitter)

        # Retries exhausted (or aborted): try the degraded substitute, once.
        fallback_name = self.FALLBACK_TOOLS.get(tool_name)
        if fallback_name and fallback_name in self.tools:
            try:
                result = await self.tools[fallback_name](**args)
                return {
                    "status": "fallback",
                    "result": result,
                    "original_tool": tool_name,
                    "fallback_tool": fallback_name,
                }
            except Exception:
                pass  # best-effort: fall through to the failure report

        return {
            "status": "failed",
            "error": last_error,
            # Fixed: report attempts actually made; previously this always
            # claimed max_retries + 1 even when a non-retryable error
            # aborted after the first attempt.
            "attempts": attempts_made,
            "hint": f"{tool_name} failed after {attempts_made} attempts. "
                    "Consider alternative approaches.",
        }

    def _is_non_retryable(self, error: Exception) -> bool:
        """True for error types where retrying cannot help."""
        non_retryable = (
            ValueError, TypeError, KeyError,
            PermissionError, FileNotFoundError,
        )
        return isinstance(error, non_retryable)

四、成本控制

Token 预算管理

class CostController:
    """Budget guard for tool calls: tracks token and USD spend per session."""

    # Rough per-call token estimates (input + output) used by the
    # pre-call affordability check.
    TOOL_TOKEN_COSTS = {
        "search_web": {"input": 100, "output": 500},
        "query_database": {"input": 200, "output": 1000},
        "analyze_image": {"input": 1000, "output": 300},
    }

    # Estimated USD cost per token; previously an inline magic number.
    EST_USD_PER_TOKEN = 0.00001

    def __init__(self, budget_usd: float = 1.0, token_budget: int = 100000):
        self.budget_usd = budget_usd      # hard USD ceiling for the session
        self.token_budget = token_budget  # hard token ceiling for the session
        self.spent_usd = 0.0
        self.tokens_used = 0

    def can_afford(self, tool_name: str) -> tuple[bool, str]:
        """Return (affordable, reason) for one prospective call of `tool_name`.

        Unknown tools fall back to a conservative default estimate.
        """
        cost_estimate = self.TOOL_TOKEN_COSTS.get(tool_name, {"input": 200, "output": 500})
        estimated_tokens = cost_estimate["input"] + cost_estimate["output"]

        if self.tokens_used + estimated_tokens > self.token_budget:
            return False, f"Token budget exhausted ({self.tokens_used}/{self.token_budget})"

        estimated_cost = estimated_tokens * self.EST_USD_PER_TOKEN
        if self.spent_usd + estimated_cost > self.budget_usd:
            return False, f"USD budget exhausted (${self.spent_usd:.4f}/${self.budget_usd})"

        return True, "ok"

    def record_usage(self, tool_name: str, tokens: int, cost_usd: float):
        """Record actual usage after a call (tool_name kept for API symmetry)."""
        self.tokens_used += tokens
        self.spent_usd += cost_usd

    def get_budget_status(self) -> dict:
        """Snapshot of spend vs. budget; utilization is a token percentage."""
        if self.token_budget:
            utilization = round(self.tokens_used / self.token_budget * 100, 1)
        else:
            # Fixed: a zero token budget used to raise ZeroDivisionError here.
            utilization = 100.0
        return {
            "tokens_used": self.tokens_used,
            "tokens_remaining": self.token_budget - self.tokens_used,
            "usd_spent": round(self.spent_usd, 6),
            "usd_remaining": round(self.budget_usd - self.spent_usd, 6),
            "utilization": utilization,
        }

工具调用频率限制

class ToolRateLimiter:
    """Guards against an agent calling the same tool too often.

    Enforces a sliding one-minute window plus a per-session total for each
    tool; tools without an explicit entry fall back to a default limit.
    """

    def __init__(self):
        # tool name -> list of call timestamps (time.time() seconds)
        self.call_history: dict[str, list[float]] = {}
        self.limits = {
            "search_web": {"max_per_minute": 10, "max_per_session": 50},
            "send_email": {"max_per_minute": 2, "max_per_session": 10},
            "query_database": {"max_per_minute": 30, "max_per_session": 200},
        }

    def check(self, tool_name: str) -> tuple[bool, str]:
        """Return (allowed, reason); reason is "ok" when allowed."""
        now = time.time()
        history = self.call_history.get(tool_name, [])
        limits = self.limits.get(
            tool_name,
            {"max_per_minute": 20, "max_per_session": 100},
        )

        # Sliding window: count calls made within the last 60 seconds.
        within_minute = sum(1 for ts in history if now - ts < 60)
        if within_minute >= limits["max_per_minute"]:
            return False, f"{tool_name}: rate limit ({limits['max_per_minute']}/min)"

        # Session-wide total (the history list is never pruned).
        if len(history) >= limits["max_per_session"]:
            return False, f"{tool_name}: session limit ({limits['max_per_session']})"

        return True, "ok"

    def record(self, tool_name: str):
        """Register one call to `tool_name` at the current time."""
        self.call_history.setdefault(tool_name, []).append(time.time())

五、工具选择优化

动态工具集精简

class ToolSelector:
    """Dynamically narrows the tool set offered to the model.

    Sending every tool schema on each request wastes prompt tokens; this
    selector scores tools against the user message and keeps only the top
    N, plus a couple of always-available essentials.
    """

    def __init__(self, all_tools: list[dict]):
        self.all_tools = all_tools
        # tool name -> full OpenAI-style tool schema
        self.tool_descriptions = {t["function"]["name"]: t for t in all_tools}

    async def select_tools(
        self, user_message: str, context: dict, max_tools: int = 8
    ) -> list[dict]:
        """Return the schemas of the tools most relevant to this message."""
        message_lower = user_message.lower()
        tokens = message_lower.split()

        # Rule 1: substring keyword match against tool name / description.
        scores = {}
        for schema in self.all_tools:
            tool_name = schema["function"]["name"]
            description = schema["function"]["description"].lower()
            scores[tool_name] = sum(
                1 for tok in tokens if tok in description or tok in tool_name
            )

        # Rule 2: coarse intent boosts for coding / data-analysis requests.
        boost_groups = []
        if "code" in message_lower or "编程" in user_message:
            boost_groups.append(["execute_code", "read_file", "write_file"])
        if "数据" in user_message or "分析" in user_message:
            boost_groups.append(["query_database", "create_chart", "calculate"])
        for group in boost_groups:
            for tool_name in group:
                scores[tool_name] = scores.get(tool_name, 0) + 5

        # Keep the top-N names by score (stable sort preserves insertion
        # order among ties).
        ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
        chosen = [tool_name for tool_name, _ in ranked[:max_tools]]

        # Always offer the baseline tools, even beyond max_tools.
        for essential in ["search_web", "calculate"]:
            if essential not in chosen and essential in self.tool_descriptions:
                chosen.append(essential)

        # Boosted names without a registered schema are silently dropped here.
        return [
            self.tool_descriptions[tool_name]
            for tool_name in chosen
            if tool_name in self.tool_descriptions
        ]

整合:生产级工具调用管道

class ProductionToolPipeline:
    """End-to-end tool-call pipeline combining all of the optimizations.

    Per-call order: budget gate -> rate-limit gate -> cache lookup ->
    resilient execution -> usage accounting -> cache write-back.
    """

    def __init__(self, tool_registry, redis_client=None):
        self.executor = ParallelToolExecutor(tool_registry)
        self.cache = ToolCache(redis_client)
        self.resilient = ResilientToolExecutor(tool_registry)
        self.cost = CostController(budget_usd=5.0)
        self.rate_limiter = ToolRateLimiter()
        # NOTE(review): build_tool_schemas is not defined in this file —
        # presumably a project helper that derives schemas from the registry.
        self.selector = ToolSelector(build_tool_schemas(tool_registry))

    async def execute(self, tool_calls: list[dict]) -> list[dict]:
        """Run each call through the guarded pipeline; one dict per call."""
        outputs = []

        for call in tool_calls:
            tool_name = call["name"]
            tool_args = call["args"]

            # Gate 1: token/USD budget.
            affordable, reason = self.cost.can_afford(tool_name)
            if not affordable:
                outputs.append({"id": call["id"], "error": reason})
                continue

            # Gate 2: per-minute / per-session rate limits.
            permitted, reason = self.rate_limiter.check(tool_name)
            if not permitted:
                outputs.append({"id": call["id"], "error": reason})
                continue

            # Gate 3: serve from cache when possible.
            cache_hit, cached = await self.cache.get(tool_name, tool_args)
            if cache_hit:
                outputs.append({"id": call["id"], "result": cached, "cached": True})
                continue

            # Execute with retry/fallback protection.
            outcome = await self.resilient.execute_with_retry(tool_name, tool_args)

            # Account for the call (flat estimates; replace with real usage data).
            self.rate_limiter.record(tool_name)
            self.cost.record_usage(tool_name, tokens=500, cost_usd=0.005)

            # Write successful results back to the cache.
            if outcome["status"] == "success":
                await self.cache.set(tool_name, tool_args, outcome["result"])

            outputs.append({"id": call["id"], **outcome})

        return outputs

总结

工具调用优化的核心策略:

  1. 并行化:无依赖的工具调用并行执行,延迟取决于最慢的一个
  2. 缓存:确定性结果缓存(按工具 + 参数哈希),有副作用的工具不缓存
  3. 弹性:指数退避重试 + 降级备选 + 超时保护
  4. 成本控制:token 预算 + 调用频率限制 + 按需选择工具子集
  5. 工具精简:根据上下文动态选择相关工具,减少 prompt token 开销

Maurice | maurice_wen@proton.me