AI Agent 的长时任务工程实践:超时、中断与断点续跑

从一次凌晨告警说起,系统梳理生产环境中 AI Agent 长时任务的三大工程难题
0. 引子:一次真实的生产事故
凌晨两点,告警响了。
用户提交了一个"分析竞品并生成完整报告"的 Agent 任务,任务包含:搜索 10 个竞品官网、提取关键信息、调用代码执行工具做数据对比、最后生成一份 3000 字的分析报告。
任务跑了 4 分 23 秒后,用户看到了一个 504 Gateway Timeout。
然后是三个连环问题:
- 任务到底有没有完成? 后端日志显示 LLM 还在生成,但 HTTP 连接已经断了。
- 要不要重试? 如果任务已经完成了一半,重试会从头开始,浪费钱还浪费时间。
- 重试会不会重复执行? 其中一个工具调用会写数据库,重复执行会产生脏数据。
这三个问题,本质上是 AI Agent 长时任务的三大工程难题:超时控制、中断处理、断点续跑。
这篇文章就从这次事故出发,系统梳理这三个问题的工程解法。
1. 为什么长时任务在 LLM Agent 里特别难
在传统 Web 服务里,一个请求超过 30 秒就算慢请求了。但 AI Agent 任务动辄需要几分钟——这不是代码写得烂,而是任务本身的性质决定的。
1.1 LLM API 本身的超时限制
先看各主流 LLM provider 的超时配置:
| Provider | 默认超时 | 最大超时 | 备注 |
|---|---|---|---|
| OpenAI | 600s | 600s | 官方文档明确 |
| Anthropic | 600s | 600s | 流式模式下更长 |
| Google Gemini | 300s | 600s | 依 region 不同 |
| 国内主流(通义/文心/混元) | 120-300s | 300s | 各家不同 |
600 秒看起来很长,但一个复杂 Agent 任务可能需要:
- 10 次 LLM 调用,每次 30-60s
- 5 次工具调用,每次 10-30s
- 总计轻松超过 10 分钟
1.2 HTTP 链路上的多层超时
更麻烦的是,LLM API 超时只是最后一道关卡。在请求到达 LLM 之前,还有一堆超时在等着你:
用户浏览器 (30s)
→ CDN/反向代理 (60s)
→ 负载均衡器 (AWS ALB 默认 60s idle timeout)
→ 应用服务器 (Nginx proxy_read_timeout 默认 60s)
→ 你的 Agent 代码
→ LLM API (600s)
AWS ALB 的 idle timeout 默认是 60 秒——意思是如果连接上 60 秒没有数据传输,连接就会被强制关闭。对于一个需要等待 LLM 生成的长任务,这个默认值几乎必然触发。
很多团队在这里踩坑:明明 LLM 还在生成,但用户已经收到 504 了。
1.3 工具调用的不确定性
Agent 的工具调用时间极难预测:
- Web 搜索:快则 2s,慢则 30s(目标网站响应慢)
- 代码执行:快则 1s,慢则 120s(复杂计算)
- 数据库查询:快则 10ms,慢则 30s(全表扫描)
- 外部 API 调用:完全不可控
每次工具调用,Agent 都在同步等待结果,然后把结果塞进 context 再调用 LLM。10 个工具调用串行下来,时间累加是必然的。
1.4 上下文累积导致的延迟增长
这个问题更隐蔽。Transformer 的注意力机制是 O(n²) 的——context 越长,每次 LLM 调用越慢。
一个任务进行到第 8 步时,context 里可能已经有:
- 原始任务描述(500 tokens)
- 7 次工具调用的结果(每次 200-500 tokens)
- 7 次 LLM 响应(每次 300-800 tokens)
总计 5000-10000 tokens,比任务开始时慢 3-5 倍。
1.5 无状态 API 的本质
最根本的问题:LLM API 是无状态的。每次调用都是独立的,没有"续传"概念。
如果一次 LLM 调用在生成到一半时被中断,你只能重新发起调用——没有办法从第 500 个 token 继续生成。
这意味着所有的"断点续跑"逻辑,都必须在你的应用层实现,而不是依赖 LLM API。
2. 超时控制:建立分层超时体系
解决超时问题的核心思路是:不要用一个超时控制所有事情,而是建立分层超时体系。
2.1 超时层级设计
用户请求超时 (e.g., 300s)
└── Agent 循环超时 (e.g., 240s)
├── 单次 LLM 调用超时 (e.g., 60s,含重试)
└── 单次工具调用超时 (e.g., 30s)
每一层都有独立的超时,且内层超时 < 外层超时,留出足够的缓冲时间处理超时后的清理工作。
2.2 Python 实现:asyncio 分层超时
import asyncio
import logging
from dataclasses import dataclass
from typing import Any, Optional
logger = logging.getLogger(__name__)
@dataclass
class TimeoutConfig:
user_request: float = 300.0 # 用户请求总超时
agent_loop: float = 240.0 # Agent 循环超时
llm_call: float = 60.0 # 单次 LLM 调用超时
tool_call: float = 30.0 # 单次工具调用超时
llm_retry_count: int = 2 # LLM 调用重试次数
class TimeoutError(Exception):
def __init__(self, layer: str, timeout: float):
self.layer = layer
self.timeout = timeout
super().__init__(f"Timeout at layer '{layer}' after {timeout}s")
async def call_llm_with_timeout(
llm_client,
messages: list,
config: TimeoutConfig,
task_id: str
) -> str:
"""带超时和重试的 LLM 调用"""
last_error = None
for attempt in range(config.llm_retry_count + 1):
try:
response = await asyncio.wait_for(
llm_client.chat(messages),
timeout=config.llm_call
)
return response.content
except asyncio.TimeoutError:
last_error = TimeoutError("llm_call", config.llm_call)
logger.warning(
f"[{task_id}] LLM call timeout (attempt {attempt+1}/{config.llm_retry_count+1})"
)
if attempt < config.llm_retry_count:
await asyncio.sleep(2 ** attempt) # 指数退避
except Exception as e:
last_error = e
logger.error(f"[{task_id}] LLM call error: {e}")
if attempt < config.llm_retry_count:
await asyncio.sleep(2 ** attempt)
raise last_error
async def call_tool_with_timeout(
tool_fn,
args: dict,
config: TimeoutConfig,
task_id: str,
tool_name: str
) -> Any:
"""带超时的工具调用"""
try:
return await asyncio.wait_for(
tool_fn(**args),
timeout=config.tool_call
)
except asyncio.TimeoutError:
logger.error(f"[{task_id}] Tool '{tool_name}' timeout after {config.tool_call}s")
raise TimeoutError("tool_call", config.tool_call)
2.3 超时后的正确处理
超时发生后,有两种处理策略:
策略 A:保存检查点,返回部分结果
async def run_agent_with_timeout(task_id: str, task: dict, config: TimeoutConfig):
checkpoint = AgentCheckpoint(redis_client, task_id)
try:
async with asyncio.timeout(config.agent_loop):
return await run_agent_loop(task_id, task, checkpoint, config)
except asyncio.TimeoutError:
# 超时前保存当前状态
current_state = await checkpoint.load()
if current_state:
current_state["status"] = "timeout"
current_state["timeout_at"] = time.time()
await checkpoint.save(current_state)
raise TimeoutError("agent_loop", config.agent_loop)
策略 B:直接报错,让调用方决定是否重试
对于不支持断点续跑的简单任务,直接报错更清晰。关键是要在错误信息里说明任务执行到了哪一步。
2.4 常见错误
错误 1:只设一层超时
# 错误:只有最外层超时,内层 LLM 调用可能永远挂着
async with asyncio.timeout(300):
for step in steps:
result = await llm_client.chat(messages) # 没有超时!
错误 2:超时后不清理资源
# 错误:超时后没有取消正在运行的工具调用
try:
result = await asyncio.wait_for(run_tool(), timeout=30)
except asyncio.TimeoutError:
pass # 工具可能还在后台运行!
正确做法是在超时后显式取消所有子任务:
task = asyncio.create_task(run_tool())
try:
result = await asyncio.wait_for(asyncio.shield(task), timeout=30)
except asyncio.TimeoutError:
task.cancel() # 显式取消
try:
await task
except asyncio.CancelledError:
pass
3. 中断处理:让任务可以被安全停止
中断和超时不同:超时是被动的(时间到了),中断是主动的(用户取消、系统信号、资源耗尽)。
3.1 中断的来源
- 用户取消:用户点了"停止"按钮
- 系统信号:SIGTERM(容器被调度器终止)、SIGINT(Ctrl+C)
- 资源耗尽:OOM Killer、磁盘满
- 依赖服务故障:Redis 断连、数据库不可用
每种中断的处理方式不同,但有一个共同原则:中断后的状态必须是一致的。
3.2 优雅关闭 vs 强制终止
import signal
import asyncio
class AgentRunner:
def __init__(self):
self._shutdown_event = asyncio.Event()
self._current_task: Optional[asyncio.Task] = None
def setup_signal_handlers(self):
"""注册信号处理器,支持优雅关闭"""
loop = asyncio.get_event_loop()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(
sig,
lambda: asyncio.create_task(self._graceful_shutdown())
)
async def _graceful_shutdown(self):
"""优雅关闭:先保存状态,再停止"""
logger.info("Received shutdown signal, saving checkpoint...")
self._shutdown_event.set()
if self._current_task and not self._current_task.done():
# 给当前任务 10 秒时间完成清理
try:
await asyncio.wait_for(
asyncio.shield(self._current_task),
timeout=10.0
)
except asyncio.TimeoutError:
logger.warning("Graceful shutdown timeout, forcing cancel")
self._current_task.cancel()
async def run_task(self, task_id: str, task: dict):
"""在 Agent 循环中检查关闭信号"""
checkpoint = AgentCheckpoint(redis_client, task_id)
for step_idx, step in enumerate(generate_steps(task)):
# 每步开始前检查是否需要关闭
if self._shutdown_event.is_set():
logger.info(f"[{task_id}] Shutdown requested at step {step_idx}, saving checkpoint")
await checkpoint.save({
"step_idx": step_idx,
"status": "interrupted",
"interrupted_at": time.time()
})
return {"status": "interrupted", "step_idx": step_idx}
result = await self._execute_step(step, task_id)
await checkpoint.save({"step_idx": step_idx + 1, "last_result": result})
3.3 工具执行的幂等性设计
这是中断处理里最容易被忽视的问题:工具调用必须是幂等的。
如果一个工具调用(比如"向数据库写入分析结果")在执行到一半时被中断,续跑时会重新执行这个工具。如果工具不是幂等的,就会产生重复数据。
解决方案:基于内容 hash 的工具结果缓存。
import hashlib
import json
import time
from typing import Any, Optional
class IdempotentToolExecutor:
def __init__(self, redis_client, task_id: str, result_ttl: int = 3600):
self.redis = redis_client
self.task_id = task_id
self.ttl = result_ttl
def _make_cache_key(self, tool_name: str, args: dict) -> str:
"""基于工具名和参数生成确定性缓存 key"""
args_hash = hashlib.sha256(
json.dumps(args, sort_keys=True, ensure_ascii=False).encode()
).hexdigest()[:16]
return f"tool:result:{self.task_id}:{tool_name}:{args_hash}"
async def execute(
self,
tool_name: str,
tool_fn,
args: dict,
force_rerun: bool = False
) -> Any:
cache_key = self._make_cache_key(tool_name, args)
# 检查是否已有缓存结果
if not force_rerun:
cached = await self.redis.get(cache_key)
if cached:
logger.info(f"[{self.task_id}] Tool '{tool_name}' cache hit, skipping execution")
return json.loads(cached)
# 执行工具
logger.info(f"[{self.task_id}] Executing tool '{tool_name}'")
start_time = time.time()
try:
result = await tool_fn(**args)
except Exception as e:
logger.error(f"[{self.task_id}] Tool '{tool_name}' failed: {e}")
raise
elapsed = time.time() - start_time
logger.info(f"[{self.task_id}] Tool '{tool_name}' completed in {elapsed:.2f}s")
# 缓存结果
await self.redis.setex(
cache_key,
self.ttl,
json.dumps(result, ensure_ascii=False)
)
return result
使用示例:
executor = IdempotentToolExecutor(redis_client, task_id="task-123")
# 第一次执行:实际调用工具
result = await executor.execute(
"search_web",
search_web_fn,
{"query": "竞品A 定价策略", "max_results": 5}
)
# 续跑时再次执行:直接返回缓存结果,不重复调用
result = await executor.execute(
"search_web",
search_web_fn,
{"query": "竞品A 定价策略", "max_results": 5}
)
3.4 分布式锁 + 心跳租约防止僵尸任务
在多实例部署场景下,一个任务可能被多个 worker 同时认领。更糟糕的是,worker 崩溃后任务可能永远处于"running"状态——没有人知道它已经死了。
解决方案:心跳租约(Heartbeat Lease)。
import asyncio
import time
from typing import Optional
class TaskLease:
"""
分布式任务锁,基于 Redis SETNX + 心跳续期。
Worker 必须定期续期,否则锁自动过期,其他 worker 可以接管。
"""
def __init__(
self,
redis_client,
task_id: str,
lease_ttl: int = 30, # 锁的 TTL(秒)
heartbeat_interval: Optional[int] = None # 心跳间隔,默认 TTL/2
):
self.redis = redis_client
self.key = f"agent:lease:{task_id}"
self.ttl = lease_ttl
self.heartbeat_interval = heartbeat_interval or (lease_ttl // 2)
self._heartbeat_task: Optional[asyncio.Task] = None
self._worker_id = f"worker-{id(self)}"
async def acquire(self) -> bool:
"""尝试获取锁。返回 True 表示成功,False 表示已被其他 worker 持有"""
acquired = await self.redis.set(
self.key,
self._worker_id,
nx=True, # 只在 key 不存在时设置
ex=self.ttl
)
if acquired:
await self._start_heartbeat()
return bool(acquired)
async def _start_heartbeat(self):
"""启动后台心跳任务,定期续期锁"""
async def _beat():
while True:
await asyncio.sleep(self.heartbeat_interval)
try:
# 只有当前 worker 持有锁时才续期
current_holder = await self.redis.get(self.key)
if current_holder and current_holder.decode() == self._worker_id:
await self.redis.expire(self.key, self.ttl)
else:
logger.warning(f"Lost lease for {self.key}, stopping heartbeat")
break
except Exception as e:
logger.error(f"Heartbeat error: {e}")
self._heartbeat_task = asyncio.create_task(_beat())
async def release(self):
"""释放锁"""
if self._heartbeat_task:
self._heartbeat_task.cancel()
try:
await self._heartbeat_task
except asyncio.CancelledError:
pass
# 只删除自己持有的锁(Lua 脚本保证原子性)
lua_script = """
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("del", KEYS[1])
else
return 0
end
"""
await self.redis.eval(lua_script, 1, self.key, self._worker_id)
async def __aenter__(self):
if not await self.acquire():
raise RuntimeError(f"Failed to acquire lease for {self.key}")
return self
async def __aexit__(self, *args):
await self.release()
使用示例:
async def process_task(task_id: str, task: dict):
lease = TaskLease(redis_client, task_id, lease_ttl=30)
async with lease:
# 在锁保护下执行任务
# 如果 worker 崩溃,30 秒后锁自动过期,其他 worker 可以接管
result = await run_agent_loop(task_id, task)
return result
4. 断点续跑:从哪里跌倒从哪里爬起来
断点续跑的核心是:在任务执行过程中,定期把"足够恢复任务"的状态持久化到外部存储。
4.1 检查点设计原则
原则 1:粒度要合适
检查点太粗(只在任务结束时保存)= 失败后全部重来。
检查点太细(每个 token 都保存)= 存储和性能开销巨大。
实践中,以"一次完整的 LLM 调用 + 工具调用"为一个检查点粒度比较合理。
原则 2:检查点必须包含足够的恢复信息
最小检查点内容:
{
"task_id": "task-123",
"step_idx": 5, # 当前执行到第几步
"messages": [...], # 完整的对话历史(用于恢复 LLM 上下文)
"tool_results": {...}, # 已完成的工具调用结果(用于幂等性)
"partial_output": "...", # 已生成的部分输出
"status": "in_progress",
"created_at": 1748678400,
"updated_at": 1748678500
}
原则 3:TTL 要合理
检查点不应该永久保存。设置一个合理的 TTL(比如任务预期时长的 3-5 倍),过期后自动清理。
4.2 Redis 检查点实现
import json
import time
import redis.asyncio as aioredis
from typing import Optional, Any
class AgentCheckpoint:
"""
基于 Redis 的 Agent 检查点管理器。
支持保存、加载、清除检查点,以及检查点版本管理。
"""
def __init__(
self,
redis_client: aioredis.Redis,
task_id: str,
ttl: int = 3600, # 默认 1 小时
namespace: str = "agent:checkpoint"
):
self.redis = redis_client
self.task_id = task_id
self.ttl = ttl
self.key = f"{namespace}:{task_id}"
self.history_key = f"{namespace}:{task_id}:history"
async def save(self, state: dict) -> None:
"""保存检查点,自动添加时间戳和版本号"""
# 加载当前版本号
current = await self.load()
version = (current.get("_version", 0) + 1) if current else 1
state_with_meta = {
**state,
"_version": version,
"_saved_at": time.time(),
"_task_id": self.task_id
}
serialized = json.dumps(state_with_meta, ensure_ascii=False)
# 原子操作:保存当前状态 + 追加历史记录
pipe = self.redis.pipeline()
pipe.setex(self.key, self.ttl, serialized)
# 保留最近 10 个历史版本(用于调试)
pipe.lpush(self.history_key, serialized)
pipe.ltrim(self.history_key, 0, 9)
pipe.expire(self.history_key, self.ttl)
await pipe.execute()
async def load(self) -> Optional[dict]:
"""加载最新检查点"""
data = await self.redis.get(self.key)
if not data:
return None
return json.loads(data)
async def load_history(self, count: int = 5) -> list[dict]:
"""加载历史检查点(用于调试)"""
items = await self.redis.lrange(self.history_key, 0, count - 1)
return [json.loads(item) for item in items]
async def clear(self) -> None:
"""清除检查点(任务成功完成后调用)"""
pipe = self.redis.pipeline()
pipe.delete(self.key)
pipe.delete(self.history_key)
await pipe.execute()
async def exists(self) -> bool:
"""检查是否存在未完成的检查点"""
return bool(await self.redis.exists(self.key))
4.3 什么应该存入检查点
应该存:
- 完整的对话历史(
messages列表) - 已完成的工具调用结果(用于幂等性判断)
- 当前步骤索引
- 任务的输入参数(用于验证续跑时任务没有变化)
- 部分生成的输出
不应该存:
- 大型二进制数据(图片、文件)→ 存文件系统,检查点里只存路径
- 临时文件路径 → 续跑时路径可能已失效
- 函数对象、类实例 → 无法序列化
- 敏感信息(API Key、用户密码)→ 安全风险
4.4 续跑时的幂等性保证
续跑的关键问题:如何确保从检查点恢复后,不会重复执行已完成的步骤?
async def run_agent_with_resume(
task_id: str,
task: dict,
redis_client: aioredis.Redis,
config: TimeoutConfig
) -> dict:
checkpoint = AgentCheckpoint(redis_client, task_id)
tool_executor = IdempotentToolExecutor(redis_client, task_id)
# 尝试从检查点恢复
saved_state = await checkpoint.load()
if saved_state:
start_step = saved_state.get("step_idx", 0)
messages = saved_state.get("messages", [])
partial_output = saved_state.get("partial_output", "")
logger.info(f"[{task_id}] Resuming from step {start_step}")
else:
start_step = 0
messages = [{"role": "user", "content": task["prompt"]}]
partial_output = ""
logger.info(f"[{task_id}] Starting fresh")
# 执行 Agent 循环
for step_idx in range(start_step, task.get("max_steps", 20)):
# 调用 LLM
response = await call_llm_with_timeout(llm_client, messages, config, task_id)
messages.append({"role": "assistant", "content": response})
# 解析工具调用
tool_calls = parse_tool_calls(response)
for tool_call in tool_calls:
# 幂等工具执行:已执行过的直接返回缓存结果
tool_result = await tool_executor.execute(
tool_call["name"],
get_tool_fn(tool_call["name"]),
tool_call["args"]
)
messages.append({
"role": "tool",
"tool_call_id": tool_call["id"],
"content": json.dumps(tool_result)
})
# 检查是否完成
if is_task_complete(response):
final_output = extract_output(response)
await checkpoint.clear() # 成功完成,清除检查点
return {"status": "completed", "output": final_output}
# 保存检查点
await checkpoint.save({
"step_idx": step_idx + 1,
"messages": messages,
"partial_output": partial_output,
"status": "in_progress"
})
# 达到最大步数
return {"status": "max_steps_reached", "partial_output": partial_output}
5. 异步任务队列:彻底解耦提交与执行
前面三节解决的是"任务在执行中"的问题。但还有一个更根本的问题:用户不应该等着 HTTP 连接挂着。
5.1 同步 vs 异步执行对比
| 维度 | 同步(阻塞 HTTP) | 异步(Job Queue) |
|---|---|---|
| 实现复杂度 | 低 | 高 |
| 超时风险 | 高(受 HTTP 链路限制) | 低(后台执行,不受 HTTP 超时影响) |
| 用户体验 | 等待感强,浏览器可能超时 | 可轮询/推送,用户可以关闭页面 |
| 断点续跑 | 难(连接断了就没了) | 容易(状态在 DB/Redis 里) |
| 适用场景 | < 30s 的简单任务 | > 30s 的复杂任务 |
| 基础设施 | 无额外依赖 | 需要队列 + 持久化存储 |
| 可观测性 | 差(只有请求日志) | 好(任务状态、进度、历史) |
5.2 Job Queue 模式实现
from enum import Enum
from dataclasses import dataclass, field
import uuid
import time
class TaskStatus(Enum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
TIMEOUT = "timeout"
CANCELLED = "cancelled"
@dataclass
class AgentTask:
task_id: str = field(default_factory=lambda: str(uuid.uuid4()))
prompt: str = ""
status: TaskStatus = TaskStatus.PENDING
created_at: float = field(default_factory=time.time)
started_at: Optional[float] = None
completed_at: Optional[float] = None
result: Optional[dict] = None
error: Optional[str] = None
progress: dict = field(default_factory=dict)
# API 层:提交任务
async def submit_task(prompt: str) -> dict:
task = AgentTask(prompt=prompt)
# 保存任务到 Redis
await redis_client.setex(
f"task:{task.task_id}",
86400, # 24 小时 TTL
json.dumps(asdict(task), default=str)
)
# 推入任务队列
await redis_client.lpush("agent:task_queue", task.task_id)
return {"task_id": task.task_id, "status": "pending"}
# API 层:查询任务状态
async def get_task_status(task_id: str) -> dict:
data = await redis_client.get(f"task:{task_id}")
if not data:
raise ValueError(f"Task {task_id} not found")
return json.loads(data)
# Worker 层:消费任务
async def worker_loop():
while True:
# 阻塞等待任务(BRPOP 超时 5 秒)
result = await redis_client.brpop("agent:task_queue", timeout=5)
if not result:
continue
_, task_id = result
task_id = task_id.decode()
# 获取任务详情
task_data = await redis_client.get(f"task:{task_id}")
if not task_data:
continue
task = AgentTask(**json.loads(task_data))
# 尝试获取任务锁(防止重复执行)
lease = TaskLease(redis_client, task_id)
if not await lease.acquire():
logger.info(f"Task {task_id} already being processed, skipping")
continue
try:
# 更新状态为 running
task.status = TaskStatus.RUNNING
task.started_at = time.time()
await save_task(task)
# 执行任务
result = await run_agent_with_resume(task_id, {"prompt": task.prompt}, redis_client, TimeoutConfig())
# 更新状态为 completed
task.status = TaskStatus.COMPLETED
task.result = result
task.completed_at = time.time()
await save_task(task)
except TimeoutError as e:
task.status = TaskStatus.TIMEOUT
task.error = str(e)
task.completed_at = time.time()
await save_task(task)
except Exception as e:
task.status = TaskStatus.FAILED
task.error = str(e)
task.completed_at = time.time()
await save_task(task)
finally:
await lease.release()
5.3 与 OpenAI Assistants API 的对比
OpenAI Assistants API 的 Run 对象本质上就是这个模式的托管版本:
| 概念 | 自建 Job Queue | OpenAI Assistants API |
|---|---|---|
| 任务提交 | POST /tasks | POST /threads/{id}/runs |
| 任务 ID | task_id | run_id |
| 状态查询 | GET /tasks/{id} | GET /threads/{id}/runs/{id} |
| 状态值 | pending/running/completed | queued/in_progress/completed |
| 工具执行 | 自己实现 | requires_action 回调 |
| 超时处理 | 自己实现 | 自动(有限制) |
自建的优势:完全控制超时、重试、检查点策略;可以集成任意工具;不受 OpenAI 的速率限制和功能限制。
6. 实战:完整的 AgentTaskRunner
把前面所有模式组合起来,实现一个生产可用的 AgentTaskRunner:
import asyncio
import json
import logging
import time
import uuid
from dataclasses import dataclass, field, asdict
from enum import Enum
from typing import Any, Callable, Optional
import redis.asyncio as aioredis
logger = logging.getLogger(__name__)
# ============================================================
# 配置
# ============================================================
@dataclass
class TaskRunnerConfig:
# 超时配置
agent_loop_timeout: float = 240.0
llm_call_timeout: float = 60.0
tool_call_timeout: float = 30.0
llm_retry_count: int = 2
# 检查点配置
checkpoint_ttl: int = 3600
# 租约配置
lease_ttl: int = 30
# 任务配置
max_steps: int = 20
# ============================================================
# 核心 Runner
# ============================================================
class AgentTaskRunner:
"""
生产级 AI Agent 任务执行器。
集成:分层超时 + 检查点续跑 + 幂等工具 + 心跳租约
"""
def __init__(
self,
redis_client: aioredis.Redis,
llm_client,
tools: dict[str, Callable],
config: Optional[TaskRunnerConfig] = None
):
self.redis = redis_client
self.llm = llm_client
self.tools = tools
self.config = config or TaskRunnerConfig()
async def run(self, task_id: str, prompt: str) -> dict:
"""
执行任务,支持断点续跑。
如果存在检查点,从检查点恢复;否则从头开始。
"""
checkpoint = AgentCheckpoint(self.redis, task_id, self.config.checkpoint_ttl)
tool_executor = IdempotentToolExecutor(self.redis, task_id, self.config.checkpoint_ttl)
lease = TaskLease(self.redis, task_id, self.config.lease_ttl)
# 获取分布式锁
if not await lease.acquire():
raise RuntimeError(f"Task {task_id} is already running on another worker")
try:
return await asyncio.wait_for(
self._run_loop(task_id, prompt, checkpoint, tool_executor),
timeout=self.config.agent_loop_timeout
)
except asyncio.TimeoutError:
# 超时:保存检查点,等待续跑
state = await checkpoint.load() or {}
state["status"] = "timeout"
state["timeout_at"] = time.time()
await checkpoint.save(state)
raise TimeoutError("agent_loop", self.config.agent_loop_timeout)
finally:
await lease.release()
async def _run_loop(
self,
task_id: str,
prompt: str,
checkpoint: "AgentCheckpoint",
tool_executor: "IdempotentToolExecutor"
) -> dict:
# 从检查点恢复或初始化
saved = await checkpoint.load()
if saved and saved.get("status") not in ("completed", "failed"):
start_step = saved.get("step_idx", 0)
messages = saved.get("messages", [])
logger.info(f"[{task_id}] Resuming from step {start_step} (checkpoint version {saved.get('_version')})")
else:
start_step = 0
messages = [{"role": "user", "content": prompt}]
logger.info(f"[{task_id}] Starting fresh")
for step_idx in range(start_step, self.config.max_steps):
logger.info(f"[{task_id}] Step {step_idx + 1}/{self.config.max_steps}")
# LLM 调用(带超时和重试)
response_text = await self._call_llm(messages, task_id)
messages.append({"role": "assistant", "content": response_text})
# 解析并执行工具调用
tool_calls = self._parse_tool_calls(response_text)
for tc in tool_calls:
if tc["name"] not in self.tools:
logger.warning(f"[{task_id}] Unknown tool: {tc['name']}")
continue
tool_result = await self._call_tool(
tc["name"], self.tools[tc["name"]], tc["args"],
tool_executor, task_id
)
messages.append({
"role": "tool",
"tool_call_id": tc.get("id", str(uuid.uuid4())),
"content": json.dumps(tool_result, ensure_ascii=False)
})
# 检查是否完成
if self._is_complete(response_text):
output = self._extract_output(response_text)
await checkpoint.clear()
logger.info(f"[{task_id}] Completed at step {step_idx + 1}")
return {"status": "completed", "output": output, "steps": step_idx + 1}
# 保存检查点
await checkpoint.save({
"step_idx": step_idx + 1,
"messages": messages,
"status": "in_progress"
})
return {"status": "max_steps_reached", "steps": self.config.max_steps}
async def _call_llm(self, messages: list, task_id: str) -> str:
last_error = None
for attempt in range(self.config.llm_retry_count + 1):
try:
response = await asyncio.wait_for(
self.llm.chat(messages),
timeout=self.config.llm_call_timeout
)
return response.content
except asyncio.TimeoutError:
last_error = TimeoutError("llm_call", self.config.llm_call_timeout)
if attempt < self.config.llm_retry_count:
await asyncio.sleep(2 ** attempt)
except Exception as e:
last_error = e
if attempt < self.config.llm_retry_count:
await asyncio.sleep(2 ** attempt)
raise last_error
async def _call_tool(
self, name: str, fn: Callable, args: dict,
executor: "IdempotentToolExecutor", task_id: str
) -> Any:
try:
return await asyncio.wait_for(
executor.execute(name, fn, args),
timeout=self.config.tool_call_timeout
)
except asyncio.TimeoutError:
logger.error(f"[{task_id}] Tool '{name}' timeout")
return {"error": f"Tool '{name}' timed out after {self.config.tool_call_timeout}s"}
def _parse_tool_calls(self, response: str) -> list[dict]:
# 实际实现需要解析 LLM 返回的工具调用格式
# 这里简化处理
return []
def _is_complete(self, response: str) -> bool:
return "<TASK_COMPLETE>" in response
def _extract_output(self, response: str) -> str:
return response.replace("<TASK_COMPLETE>", "").strip()
6.1 测试:模拟中断后续跑
import asyncio
import pytest
@pytest.mark.asyncio
async def test_resume_after_interrupt():
"""测试任务中断后能从检查点续跑"""
redis = aioredis.from_url("redis://localhost:6379")
task_id = "test-task-001"
# 模拟一个会在第 3 步中断的 LLM
call_count = 0
async def mock_llm_chat(messages):
nonlocal call_count
call_count += 1
if call_count == 3:
raise asyncio.TimeoutError() # 模拟第 3 步超时
if call_count >= 5:
return MockResponse("<TASK_COMPLETE>分析完成")
return MockResponse(f"执行步骤 {call_count}")
runner = AgentTaskRunner(
redis_client=redis,
llm_client=MockLLMClient(mock_llm_chat),
tools=,
config=TaskRunnerConfig(llm_retry_count=0)
)
# 第一次运行:在第 3 步超时
with pytest.raises(TimeoutError):
await runner.run(task_id, "分析竞品并生成报告")
# 验证检查点已保存
checkpoint = AgentCheckpoint(redis, task_id)
state = await checkpoint.load()
assert state is not None
assert state["step_idx"] == 2 # 完成了 2 步
# 第二次运行:从检查点续跑
result = await runner.run(task_id, "分析竞品并生成报告")
assert result["status"] == "completed"
assert call_count == 5 # 总共调用了 5 次(2 + 3)
await redis.aclose()
7. 踩坑总结
经历了几次生产事故后,总结出以下 5 个最常见的坑:
坑 1:只设了 LLM 超时,忘了 HTTP 链路超时
症状:LLM 调用成功,但用户收到 504。
原因:AWS ALB idle timeout 默认 60s,Nginx proxy_read_timeout 默认 60s。
修复:调大 LB 和 Nginx 超时,或者改用异步任务队列。
坑 2:检查点里存了不可序列化的对象
症状:json.JSONDecodeError 或 TypeError: Object of type X is not JSON serializable。
原因:把 datetime、Enum、自定义类直接存入检查点。
修复:序列化前转换为基础类型;使用 default=str 作为 fallback。
坑 3:工具调用没有幂等性,续跑时重复写数据库
症状:数据库里出现重复记录,用户收到重复通知。
原因:续跑时重新执行了已完成的工具调用。
修复:基于内容 hash 的工具结果缓存(见第 3.3 节)。
坑 4:心跳间隔 > 租约 TTL,导致锁提前过期
症状:同一任务被两个 worker 同时执行。
原因:心跳间隔设为 30s,但租约 TTL 也是 30s,心跳来不及续期锁就过期了。
修复:心跳间隔 = TTL / 2(见第 3.4 节的实现)。
坑 5:上下文无限增长,越跑越慢
症状:任务前几步很快,后面越来越慢,最终超时。
原因:每步都把完整工具结果追加到 messages,context 越来越大。
修复:对旧的工具结果做摘要压缩,只保留最近 N 步的完整结果。
def compress_old_tool_results(messages: list, keep_recent: int = 5) -> list:
"""压缩旧的工具调用结果,只保留最近 N 个完整结果"""
tool_messages = [(i, m) for i, m in enumerate(messages) if m["role"] == "tool"]
if len(tool_messages) <= keep_recent:
return messages
# 需要压缩的工具消息
to_compress = tool_messages[:-keep_recent]
compress_indices = {i for i, _ in to_compress}
compressed = []
for i, msg in enumerate(messages):
if i in compress_indices:
# 压缩:只保留摘要
content = json.loads(msg["content"])
compressed.append({
**msg,
"content": json.dumps({
"_compressed": True,
"_summary": str(content)[:200] + "..."
})
})
else:
compressed.append(msg)
return compressed
决策树:什么时候用同步,什么时候用异步
任务预期时长 < 30s?
├── 是 → 同步执行(简单,无额外依赖)
└── 否 → 异步任务队列
│
├── 需要断点续跑?
│ ├── 是 → 检查点 + 幂等工具
│ └── 否 → 简单 Job Queue 即可
│
└── 多实例部署?
├── 是 → 心跳租约(防僵尸任务)
└── 否 → 简单锁即可
8. 结语
长时任务是 AI Agent 工程化的必经之路。随着 Agent 能力越来越强,任务越来越复杂,这些问题只会更突出,不会消失。
核心思路其实很简单:把"任务执行"和"HTTP 连接"解耦。一旦解耦,超时、中断、续跑都有了清晰的处理空间。
本文的代码都是可以直接用的生产级实现,不是玩具代码。如果你的 Agent 任务已经开始出现超时问题,从分层超时体系开始改起,然后逐步加入检查点和异步队列。
不需要一次性把所有模式都上,按需引入,够用就好。
如果这篇文章对你有帮助,欢迎点赞收藏。有问题或者踩过其他坑,欢迎在评论区交流。
更多推荐



所有评论(0)