AI Agent 的工具调用优化策略
原创
灵阙教研团队
精选 · 进阶
约 9 分钟阅读
更新于 2026-02-28 · AI 导读
AI Agent 的工具调用优化策略 概述 工具调用是 Agent 与外部世界交互的基础能力。但在生产环境中,朴素的工具调用面临多个问题:延迟高(每次调用都是网络请求)、成本高(每次调用消耗 token)、不稳定(API 可能失败)、安全风险(Agent 可能调用不该调用的工具)。 本文从并行调用、缓存策略、错误恢复、成本控制和工具选择优化五个维度,给出工程级的解决方案。 一、并行工具调用 问题...
AI Agent 的工具调用优化策略
概述
工具调用是 Agent 与外部世界交互的基础能力。但在生产环境中,朴素的工具调用面临多个问题:延迟高(每次调用都是网络请求)、成本高(每次调用消耗 token)、不稳定(API 可能失败)、安全风险(Agent 可能调用不该调用的工具)。
本文从并行调用、缓存策略、错误恢复、成本控制和工具选择优化五个维度,给出工程级的解决方案。
一、并行工具调用
问题
默认的顺序调用模式:Tool A (200ms) -> Tool B (300ms) -> Tool C (150ms) = 650ms 总延迟。 如果三个工具无依赖关系,并行执行仅需 300ms。
实现
import asyncio
from dataclasses import dataclass
from typing import Any
@dataclass
class ToolCallRequest:
    """A single pending tool invocation."""
    id: str     # correlation id (e.g. the model's tool_call id)
    name: str   # tool name; key into the executor's tool registry
    args: dict  # keyword arguments passed to the tool callable
@dataclass
class ToolCallResult:
    """Outcome of one tool invocation (success or captured failure)."""
    id: str              # echoes ToolCallRequest.id for correlation
    name: str            # tool that was invoked
    result: Any          # tool return value; None when success is False
    duration_ms: float   # wall-clock duration of the call in milliseconds
    success: bool        # False when the call raised or timed out
    error: str = ""      # stringified exception when success is False
class ParallelToolExecutor:
"""并行工具执行器"""
def __init__(self, tool_registry: dict, max_concurrency: int = 10):
self.tools = tool_registry
self.semaphore = asyncio.Semaphore(max_concurrency)
async def execute_batch(
self, requests: list[ToolCallRequest]
) -> list[ToolCallResult]:
"""并行执行一批工具调用"""
# 分析依赖关系
independent, dependent = self._analyze_dependencies(requests)
# 第一阶段:并行执行无依赖的调用
tasks = [self._execute_single(req) for req in independent]
results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理结果
completed = []
for req, result in zip(independent, results):
if isinstance(result, Exception):
completed.append(ToolCallResult(
id=req.id, name=req.name, result=None,
duration_ms=0, success=False, error=str(result),
))
else:
completed.append(result)
# 第二阶段:执行依赖前一阶段结果的调用
for req in dependent:
result = await self._execute_single(req)
completed.append(result)
return completed
async def _execute_single(self, request: ToolCallRequest) -> ToolCallResult:
async with self.semaphore:
start = asyncio.get_event_loop().time()
try:
func = self.tools[request.name]
result = await func(**request.args)
duration = (asyncio.get_event_loop().time() - start) * 1000
return ToolCallResult(
id=request.id,
name=request.name,
result=result,
duration_ms=duration,
success=True,
)
except Exception as e:
duration = (asyncio.get_event_loop().time() - start) * 1000
return ToolCallResult(
id=request.id,
name=request.name,
result=None,
duration_ms=duration,
success=False,
error=str(e),
)
def _analyze_dependencies(self, requests):
"""简单的依赖分析:同名工具视为可能有依赖"""
seen_names = set()
independent = []
dependent = []
for req in requests:
if req.name in seen_names:
dependent.append(req)
else:
independent.append(req)
seen_names.add(req.name)
return independent, dependent
OpenAI 原生并行调用
# OpenAI gpt-4o natively supports parallel tool calls:
# the model may return several tool_calls in a single response.
response = client.chat.completions.create(
    model="gpt-4o",
    messages=messages,
    tools=tools,
    parallel_tool_calls=True,  # enabled by default
)
# response.choices[0].message.tool_calls may contain multiple calls, e.g.
# [
#   {"id": "call_1", "function": {"name": "get_weather", "arguments": '{"city": "北京"}'}},
#   {"id": "call_2", "function": {"name": "get_weather", "arguments": '{"city": "上海"}'}},
#   {"id": "call_3", "function": {"name": "get_stock", "arguments": '{"symbol": "TSLA"}'}},
# ]

# Fan all returned calls out to the parallel executor.
tool_calls = response.choices[0].message.tool_calls
executor = ParallelToolExecutor(tool_registry)
requests = [
    ToolCallRequest(id=tc.id, name=tc.function.name,
                    args=json.loads(tc.function.arguments))
    for tc in tool_calls
]
# NOTE(review): snippet assumes an async context ('await' at top level)
# and that client/tool_registry/json are in scope.
results = await executor.execute_batch(requests)
二、工具结果缓存
分层缓存策略
import hashlib
import json
import time
from enum import Enum
from typing import Any
class CacheTier(Enum):
    """Storage tier for cached tool results."""
    MEMORY = "memory"  # in-process memory cache (sub-ms access)
    REDIS = "redis"    # shared Redis cache (ms-level access)
    DISK = "disk"      # disk cache (~10 ms access); declared but not used below

class ToolCache:
    """Multi-tier cache for tool results.

    Tools with side effects (send_email, create_order, ...) are deliberately
    absent from CACHE_POLICIES and are therefore never cached.
    """

    # Per-tool cache policy: storage tier + TTL in seconds.
    CACHE_POLICIES = {
        "get_weather": {"tier": CacheTier.REDIS, "ttl": 1800},       # 30 min
        "search_web": {"tier": CacheTier.REDIS, "ttl": 3600},        # 1 hour
        "get_exchange_rate": {"tier": CacheTier.REDIS, "ttl": 300},  # 5 min
        "query_database": {"tier": CacheTier.MEMORY, "ttl": 60},     # 1 min
        "calculate": {"tier": CacheTier.MEMORY, "ttl": 86400},       # 24 h (deterministic)
    }

    def __init__(self, redis_client=None):
        self.memory_cache = {}
        self.redis = redis_client

    def _cache_key(self, tool_name: str, args: dict) -> str:
        """Stable key: tool name + canonical (sorted-keys) JSON of args."""
        args_str = json.dumps(args, sort_keys=True, ensure_ascii=False)
        content = f"{tool_name}:{args_str}"
        return hashlib.sha256(content.encode()).hexdigest()

    async def get(self, tool_name: str, args: dict) -> tuple[bool, Any]:
        """Return (hit, result); (False, None) for uncacheable/missing/expired."""
        policy = self.CACHE_POLICIES.get(tool_name)
        if not policy:
            return False, None  # tool is not cacheable
        key = self._cache_key(tool_name, args)
        if policy["tier"] == CacheTier.MEMORY:
            entry = self.memory_cache.get(key)
            if entry:
                if (time.time() - entry["time"]) < policy["ttl"]:
                    return True, entry["result"]
                # Fix: evict expired entries instead of letting stale data
                # accumulate in memory indefinitely.
                del self.memory_cache[key]
        elif policy["tier"] == CacheTier.REDIS and self.redis:
            cached = await self.redis.get(f"tool_cache:{key}")
            if cached:
                return True, json.loads(cached)
        return False, None

    async def set(self, tool_name: str, args: dict, result: Any):
        """Store a result per the tool's policy (no-op for uncacheable tools)."""
        policy = self.CACHE_POLICIES.get(tool_name)
        if not policy:
            return
        key = self._cache_key(tool_name, args)
        if policy["tier"] == CacheTier.MEMORY:
            self.memory_cache[key] = {"result": result, "time": time.time()}
        elif policy["tier"] == CacheTier.REDIS and self.redis:
            # Redis SETEX handles expiry; result must be JSON-serializable.
            await self.redis.setex(
                f"tool_cache:{key}",
                policy["ttl"],
                json.dumps(result, ensure_ascii=False),
            )
三、错误恢复
重试策略
import asyncio
import random
class RetryConfig:
    """Retry policy: attempt budget plus exponential-backoff timing."""
    def __init__(self, max_retries=3, base_delay=1.0,
                 max_delay=30.0, exponential_base=2):
        self.max_retries = max_retries            # retries after the first attempt
        self.base_delay = base_delay              # seconds before the first retry
        self.max_delay = max_delay                # backoff ceiling, in seconds
        self.exponential_base = exponential_base  # delay multiplier per attempt
class ResilientToolExecutor:
    """Tool executor with timeout, retry (exponential backoff + jitter),
    and per-tool fallback."""

    # Per-tool retry policies; unlisted tools get RetryConfig() defaults.
    RETRY_CONFIGS = {
        "search_web": RetryConfig(max_retries=3, base_delay=1.0),
        "query_database": RetryConfig(max_retries=2, base_delay=0.5),
        "call_api": RetryConfig(max_retries=3, base_delay=2.0),
    }

    # Primary tool -> degraded alternative to try after retries are exhausted.
    FALLBACK_TOOLS = {
        "search_web_google": "search_web_bing",
        "translate_deepl": "translate_google",
    }

    def __init__(self, tool_registry: dict):
        self.tools = tool_registry

    async def execute_with_retry(
        self, tool_name: str, args: dict
    ) -> dict:
        """Run a tool with retries; on exhaustion, try its fallback tool.

        Returns a dict whose "status" is "success", "fallback", or "failed".
        """
        cfg = self.RETRY_CONFIGS.get(tool_name, RetryConfig())
        failure = None
        for attempt in range(cfg.max_retries + 1):
            try:
                outcome = await asyncio.wait_for(
                    self.tools[tool_name](**args),
                    timeout=30.0,
                )
            except asyncio.TimeoutError:
                failure = "Timeout"
            except Exception as exc:
                failure = str(exc)
                if self._is_non_retryable(exc):
                    break  # retrying cannot help (bad input, missing perms, ...)
            else:
                return {"status": "success", "result": outcome, "attempts": attempt + 1}
            if attempt < cfg.max_retries:
                # Exponential backoff capped at max_delay, plus up to 10% jitter.
                pause = min(
                    cfg.base_delay * (cfg.exponential_base ** attempt),
                    cfg.max_delay,
                )
                await asyncio.sleep(pause + random.uniform(0, pause * 0.1))

        # All attempts failed: try the configured fallback tool, if any.
        alternative = self.FALLBACK_TOOLS.get(tool_name)
        if alternative and alternative in self.tools:
            try:
                outcome = await self.tools[alternative](**args)
                return {
                    "status": "fallback",
                    "result": outcome,
                    "original_tool": tool_name,
                    "fallback_tool": alternative,
                }
            except Exception:
                pass  # fall through to the failure report

        return {
            "status": "failed",
            "error": failure,
            "attempts": cfg.max_retries + 1,
            "hint": f"{tool_name} failed after {cfg.max_retries + 1} attempts. "
            "Consider alternative approaches.",
        }

    def _is_non_retryable(self, error: Exception) -> bool:
        """True for error types where a retry cannot possibly succeed."""
        return isinstance(error, (
            ValueError, TypeError, KeyError,
            PermissionError, FileNotFoundError,
        ))
四、成本控制
Token 预算管理
class CostController:
    """Budget guard for tool calls: caps both token usage and USD spend."""

    # Rough per-call token estimates (prompt + completion) per tool.
    TOOL_TOKEN_COSTS = {
        "search_web": {"input": 100, "output": 500},
        "query_database": {"input": 200, "output": 1000},
        "analyze_image": {"input": 1000, "output": 300},
    }

    def __init__(self, budget_usd: float = 1.0, token_budget: int = 100000):
        self.budget_usd = budget_usd
        self.token_budget = token_budget
        self.spent_usd = 0.0
        self.tokens_used = 0

    def can_afford(self, tool_name: str) -> tuple[bool, str]:
        """Check both budgets against this tool's estimated cost."""
        # Unknown tools fall back to a middling default estimate.
        estimate = self.TOOL_TOKEN_COSTS.get(tool_name, {"input": 200, "output": 500})
        projected_tokens = estimate["input"] + estimate["output"]
        if self.tokens_used + projected_tokens > self.token_budget:
            return False, f"Token budget exhausted ({self.tokens_used}/{self.token_budget})"
        projected_usd = projected_tokens * 0.00001  # rough flat per-token rate
        if self.spent_usd + projected_usd > self.budget_usd:
            return False, f"USD budget exhausted (${self.spent_usd:.4f}/${self.budget_usd})"
        return True, "ok"

    def record_usage(self, tool_name: str, tokens: int, cost_usd: float):
        """Record actual consumption after a call completes."""
        self.tokens_used += tokens
        self.spent_usd += cost_usd

    def get_budget_status(self) -> dict:
        """Snapshot of both budgets for logging and routing decisions."""
        return {
            "tokens_used": self.tokens_used,
            "tokens_remaining": self.token_budget - self.tokens_used,
            "usd_spent": round(self.spent_usd, 6),
            "usd_remaining": round(self.budget_usd - self.spent_usd, 6),
            "utilization": round(self.tokens_used / self.token_budget * 100, 1),
        }
工具调用频率限制
class ToolRateLimiter:
    """Caps how often the agent may invoke each tool (per minute and per session)."""

    def __init__(self):
        # tool name -> list of call timestamps (time.time() seconds).
        self.call_history: dict[str, list[float]] = {}
        self.limits = {
            "search_web": {"max_per_minute": 10, "max_per_session": 50},
            "send_email": {"max_per_minute": 2, "max_per_session": 10},
            "query_database": {"max_per_minute": 30, "max_per_session": 200},
        }

    def check(self, tool_name: str) -> tuple[bool, str]:
        """Return (allowed, reason); unknown tools get default limits."""
        now = time.time()
        history = self.call_history.get(tool_name, [])
        limits = self.limits.get(tool_name, {
            "max_per_minute": 20,
            "max_per_session": 100,
        })
        # Sliding one-minute window.
        in_last_minute = sum(1 for t in history if now - t < 60)
        if in_last_minute >= limits["max_per_minute"]:
            return False, f"{tool_name}: rate limit ({limits['max_per_minute']}/min)"
        # Whole-session cap.
        if len(history) >= limits["max_per_session"]:
            return False, f"{tool_name}: session limit ({limits['max_per_session']})"
        return True, "ok"

    def record(self, tool_name: str):
        """Record that the tool was just called."""
        self.call_history.setdefault(tool_name, []).append(time.time())
五、工具选择优化
动态工具集精简
class ToolSelector:
    """Dynamically narrows the tool list to a relevant subset, cutting prompt tokens."""

    def __init__(self, all_tools: list[dict]):
        self.all_tools = all_tools
        # name -> full OpenAI-style tool schema
        self.tool_descriptions = {t["function"]["name"]: t for t in all_tools}

    async def select_tools(
        self, user_message: str, context: dict, max_tools: int = 8
    ) -> list[dict]:
        """Pick the tools most relevant to the user message (plus essentials)."""
        message_lower = user_message.lower()
        words = message_lower.split()

        # Rule 1: score each tool by keyword overlap with its name/description.
        scores = {}
        for tool in self.all_tools:
            name = tool["function"]["name"]
            desc = tool["function"]["description"].lower()
            scores[name] = sum(1 for w in words if w in desc or w in name)

        # Rule 2: boost domain-specific tools hinted by the message.
        if "code" in message_lower or "编程" in user_message:
            for name in ["execute_code", "read_file", "write_file"]:
                scores[name] = scores.get(name, 0) + 5
        if "数据" in user_message or "分析" in user_message:
            for name in ["query_database", "create_chart", "calculate"]:
                scores[name] = scores.get(name, 0) + 5

        # Keep the top-N names; the stable sort preserves ties in insertion order.
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        selected_names = [name for name, _ in ranked[:max_tools]]

        # Always offer the essential baseline tools, even past max_tools.
        for name in ["search_web", "calculate"]:
            if name not in selected_names and name in self.tool_descriptions:
                selected_names.append(name)

        return [
            self.tool_descriptions[name]
            for name in selected_names
            if name in self.tool_descriptions
        ]
整合:生产级工具调用管道
class ProductionToolPipeline:
    """Production tool-call pipeline combining the optimizations above:
    budget checks, rate limiting, caching, and resilient (retrying) execution.
    """

    def __init__(self, tool_registry, redis_client=None):
        self.executor = ParallelToolExecutor(tool_registry)
        self.cache = ToolCache(redis_client)
        self.resilient = ResilientToolExecutor(tool_registry)
        self.cost = CostController(budget_usd=5.0)
        self.rate_limiter = ToolRateLimiter()
        # NOTE(review): build_tool_schemas is not defined anywhere in this
        # article -- presumably it derives OpenAI-style tool schemas from
        # the registry; confirm before use.
        self.selector = ToolSelector(build_tool_schemas(tool_registry))

    async def execute(self, tool_calls: list[dict]) -> list[dict]:
        """Run each call through budget -> rate-limit -> cache -> retry.

        Each input dict needs "id", "name", and "args" keys; one result dict
        is returned per call. Calls are processed sequentially here, so
        self.executor's parallelism is not exercised by this method.
        """
        results = []
        for tc in tool_calls:
            name = tc["name"]
            args = tc["args"]
            # 1. Budget check
            can_afford, reason = self.cost.can_afford(name)
            if not can_afford:
                results.append({"id": tc["id"], "error": reason})
                continue
            # 2. Rate-limit check
            allowed, reason = self.rate_limiter.check(name)
            if not allowed:
                results.append({"id": tc["id"], "error": reason})
                continue
            # 3. Cache lookup (hit short-circuits execution and accounting)
            hit, cached_result = await self.cache.get(name, args)
            if hit:
                results.append({"id": tc["id"], "result": cached_result, "cached": True})
                continue
            # 4. Execute with retry/fallback
            result = await self.resilient.execute_with_retry(name, args)
            # 5. Record usage (tokens/cost here are fixed placeholder estimates)
            self.rate_limiter.record(name)
            self.cost.record_usage(name, tokens=500, cost_usd=0.005)
            # 6. Cache successful results only
            if result["status"] == "success":
                await self.cache.set(name, args, result["result"])
            results.append({"id": tc["id"], **result})
        return results
总结
工具调用优化的核心策略:
- 并行化:无依赖的工具调用并行执行,延迟取决于最慢的一个
- 缓存:确定性结果缓存(按工具 + 参数哈希),有副作用的工具不缓存
- 弹性:指数退避重试 + 降级备选 + 超时保护
- 成本控制:token 预算 + 调用频率限制 + 按需选择工具子集
- 工具精简:根据上下文动态选择相关工具,减少 prompt token 开销
Maurice | maurice_wen@proton.me