AI Agent 工具调用中间件:Go 实现截断、超时与熔断
·
一、问题:工具调用为什么需要中间件?
先看一个真实场景:
用户: "帮我分析这份 500 页 PDF 的内容"
Agent: 调用 read_pdf 工具
Tool: 返回全部 500 页文本 (2MB+)
LLM: 上下文窗口溢出 → 报错 / 乱答
这正是 LLM Agent 的典型故障模式:工具返回值不受控,撑爆上下文窗口。
再看另一个:
用户: "把数据库里的用户数据导出来"
Agent: 调用 db_query 工具
Tool: SQL 慢查询,耗时 120 秒
Agent: 一直等待,前端超时 → 用户以为系统挂了
这些问题不能靠修改每个工具来解决 —— 每个工具开发者都自己写一遍超时和截断逻辑,重复且易出错。中间件模式正是为此而生。
二、中间件模式设计
2.1 核心接口
类似 HTTP 中间件,我们定义工具调用的 Handler 和 Middleware:
// middleware/middleware.go
package middleware
import (
"context"
"github.com/example/mcp-server-go/internal/protocol"
)
// ToolHandler 工具执行函数(与上篇的 ToolHandler 一致)
type ToolHandler func(ctx context.Context, args map[string]interface{}) (*protocol.CallToolResult, error)
// Middleware 中间件 —— 包装一个 Handler,返回新的 Handler
type Middleware func(next ToolHandler) ToolHandler
// Chain 将多个中间件串联
// 执行顺序:middlewares[0] → middlewares[1] → ... → handler
func Chain(handler ToolHandler, middlewares ...Middleware) ToolHandler {
// 洋葱模型:从外到内
for i := len(middlewares) - 1; i >= 0; i-- {
handler = middlewares[i](handler)
}
return handler
}
2.2 执行模型
请求 → [截断中间件] → [超时中间件] → [熔断中间件] → [实际工具]
↓
记录成功/失败 → 更新熔断状态
每个中间件只做一件事,组合起来形成完整的防御体系。
三、中间件一:输出截断(TruncateMiddleware)
防止工具返回过大的结果撑爆 LLM 上下文窗口。
// middleware/truncate.go
package middleware
import (
"context"
"fmt"
"strings"
"unicode/utf8"
"github.com/example/mcp-server-go/internal/protocol"
)
// TruncateConfig 截断配置
type TruncateConfig struct {
MaxChars int // 最大字符数(0 = 不限制)
MaxTokens int // 最大 token 数(粗略估计:1 token ≈ 4 chars,0 = 不限制)
TruncateMsg string // 截断提示信息
}
func DefaultTruncateConfig() TruncateConfig {
return TruncateConfig{
MaxChars: 32000, // 约 8K tokens
MaxTokens: 0,
TruncateMsg: "\n\n[...输出过长已截断,请用更精确的查询...]",
}
}
// WithTruncate 创建截断中间件
func WithTruncate(cfg TruncateConfig) Middleware {
return func(next ToolHandler) ToolHandler {
return func(ctx context.Context, args map[string]interface{}) (*protocol.CallToolResult, error) {
result, err := next(ctx, args)
if err != nil || result == nil {
return result, err
}
result.Content = truncateContent(result.Content, cfg)
return result, nil
}
}
}
func truncateContent(items []protocol.ContentItem, cfg TruncateConfig) []protocol.ContentItem {
var totalChars int
truncated := false
for i := range items {
if items[i].Type != "text" {
continue
}
itemChars := utf8.RuneCountInString(items[i].Text)
if cfg.MaxChars > 0 && totalChars+itemChars > cfg.MaxChars {
// 截断文本
remaining := cfg.MaxChars - totalChars - utf8.RuneCountInString(cfg.TruncateMsg)
if remaining > 0 {
// 找到安全的截断点(避免切断 UTF-8 多字节字符)
runes := []rune(items[i].Text)
if remaining < len(runes) {
items[i].Text = string(runes[:remaining]) + cfg.TruncateMsg
}
} else {
items[i].Text = cfg.TruncateMsg
}
truncated = true
break
}
totalChars += itemChars
// Token 估算截断
if cfg.MaxTokens > 0 {
estimatedTokens := totalChars / 4
if estimatedTokens > cfg.MaxTokens {
items[i].Text = items[i].Text[:cfg.MaxTokens*4-len(items[i].Text)] + " [TRUNCATED]"
truncated = true
break
}
}
}
if truncated {
// 追加截断标记到最后一个 content item
lastText := &items[len(items)-1]
if lastText.Type == "text" {
lastText.Text += fmt.Sprintf("\n[截断统计: 原输出超过 %d 字符限制]", cfg.MaxChars)
}
}
return items
}
为什么不在工具内部截断?
- 工具开发者不需要关心 LLM 上下文窗口细节
- 截断策略可以统一调整(比如为不同模型设不同的 token 上限)
- 截断日志可以集中收集
四、中间件二:超时控制(TimeoutMiddleware)
// middleware/timeout.go
package middleware
import (
"context"
"fmt"
"time"
"github.com/example/mcp-server-go/internal/protocol"
)
// WithTimeout 创建超时中间件
func WithTimeout(timeout time.Duration) Middleware {
return func(next ToolHandler) ToolHandler {
return func(ctx context.Context, args map[string]interface{}) (*protocol.CallToolResult, error) {
// 创建带超时的子 context
toolCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// 通过 channel 桥接同步调用
type result struct {
res *protocol.CallToolResult
err error
}
done := make(chan result, 1)
go func() {
defer func() {
if r := recover(); r != nil {
done <- result{err: fmt.Errorf("tool panic: %v", r)}
}
}()
res, err := next(toolCtx, args)
done <- result{res: res, err: err}
}()
select {
case <-toolCtx.Done():
// 超时 或 父 context 取消
return nil, fmt.Errorf("tool call timeout after %v: %w", timeout, toolCtx.Err())
case r := <-done:
return r.res, r.err
}
}
}
}
超时的坑:goroutine 泄漏
上面的实现在超时后,goroutine 中的 next() 仍在执行但结果被丢弃。如果工具内部有长时间运行的操作,会造成 goroutine 泄漏。
解决方案:让工具本身尊重 ctx.Done():
// 工具实现必须检查 ctx.Done()
func slowDBQuery(ctx context.Context, args map[string]interface{}) (*protocol.CallToolResult, error) {
// 正确做法:在每次阻塞操作前检查 context
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
// 使用支持 context 的 API
rows, err := db.QueryContext(ctx, sql, args...)
// ...
}
关键原则:中间件的超时只是最后一层防线。真正可靠的超时需要工具内部配合 ctx.Done()。
五、中间件三:熔断器(CircuitBreakerMiddleware)
当某个工具连续失败时,快速失败避免雪崩。
// middleware/circuit_breaker.go
package middleware
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/example/mcp-server-go/internal/protocol"
)
// CircuitState 熔断器状态
type CircuitState int32
const (
StateClosed CircuitState = iota // 正常
StateOpen // 熔断中,快速失败
StateHalfOpen // 半开,试探性放行
)
// CircuitBreakerConfig 熔断器配置
type CircuitBreakerConfig struct {
FailureThreshold int // 连续失败多少次后熔断
SuccessThreshold int // 半开状态下多少次成功后恢复
Timeout time.Duration // 熔断打开后多久进入半开
HalfOpenMaxCalls int // 半开状态最多放行几个请求
}
func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
return CircuitBreakerConfig{
FailureThreshold: 5,
SuccessThreshold: 3,
Timeout: 30 * time.Second,
HalfOpenMaxCalls: 1,
}
}
// CircuitBreaker 熔断器实现
type CircuitBreaker struct {
cfg CircuitBreakerConfig
state atomic.Int32 // CircuitState
consecutiveFail atomic.Int32 // 连续失败计数
consecutiveSuccess atomic.Int32 // 连续成功计数(半开状态用)
lastFailureTime time.Time
halfOpenCalls atomic.Int32 // 半开状态已放行数
mu sync.Mutex
}
func NewCircuitBreaker(cfg CircuitBreakerConfig) *CircuitBreaker {
cb := &CircuitBreaker{cfg: cfg}
cb.state.Store(int32(StateClosed))
return cb
}
// Allow 检查是否允许调用
func (cb *CircuitBreaker) Allow() error {
state := CircuitState(cb.state.Load())
switch state {
case StateClosed:
return nil
case StateOpen:
cb.mu.Lock()
defer cb.mu.Unlock()
// 检查是否可以进入半开状态
if time.Since(cb.lastFailureTime) > cb.cfg.Timeout {
cb.state.Store(int32(StateHalfOpen))
cb.consecutiveSuccess.Store(0)
cb.halfOpenCalls.Store(0)
return nil
}
return fmt.Errorf("circuit breaker is OPEN: too many failures, retry after %v",
cb.cfg.Timeout-time.Since(cb.lastFailureTime).Round(time.Second))
case StateHalfOpen:
// 限制半开状态的并发试探请求
current := cb.halfOpenCalls.Add(1)
if current > int32(cb.cfg.HalfOpenMaxCalls) {
cb.halfOpenCalls.Add(-1)
return fmt.Errorf("circuit breaker is HALF-OPEN: max probe calls reached")
}
return nil
default:
return nil
}
}
// RecordSuccess 记录成功
func (cb *CircuitBreaker) RecordSuccess() {
state := CircuitState(cb.state.Load())
switch state {
case StateClosed:
cb.consecutiveFail.Store(0)
cb.consecutiveSuccess.Add(1)
case StateHalfOpen:
success := cb.consecutiveSuccess.Add(1)
if int(success) >= cb.cfg.SuccessThreshold {
cb.state.Store(int32(StateClosed))
cb.consecutiveFail.Store(0)
}
cb.halfOpenCalls.Add(-1)
}
}
// RecordFailure 记录失败
func (cb *CircuitBreaker) RecordFailure() {
cb.mu.Lock()
defer cb.mu.Unlock()
state := CircuitState(cb.state.Load())
cb.lastFailureTime = time.Now()
switch state {
case StateClosed:
fails := cb.consecutiveFail.Add(1)
if int(fails) >= cb.cfg.FailureThreshold {
cb.state.Store(int32(StateOpen))
}
case StateHalfOpen:
// 半开状态下的失败 → 立即重新熔断
cb.state.Store(int32(StateOpen))
cb.consecutiveFail.Store(int32(cb.cfg.FailureThreshold))
cb.halfOpenCalls.Add(-1)
}
}
// WithCircuitBreaker 创建熔断中间件
func WithCircuitBreaker(cb *CircuitBreaker) Middleware {
return func(next ToolHandler) ToolHandler {
return func(ctx context.Context, args map[string]interface{}) (*protocol.CallToolResult, error) {
if err := cb.Allow(); err != nil {
return nil, err
}
result, err := next(ctx, args)
if err != nil {
cb.RecordFailure()
} else if result != nil && !result.IsError {
cb.RecordSuccess()
} else {
// isError=true 也算失败
cb.RecordFailure()
}
return result, err
}
}
}
六、组合使用:洋葱模型
// 实际用法
package main
import (
"time"
"github.com/example/mcp-server-go/internal/middleware"
"github.com/example/mcp-server-go/internal/protocol"
)
func main() {
// 原始工具处理器
rawHandler := dbQueryHandler
// 创建熔断器(全局共享,跨请求统计)
cb := middleware.NewCircuitBreaker(middleware.DefaultCircuitBreakerConfig())
// 洋葱模型组合中间件
// 执行顺序(由外到内):截断 → 超时 → 熔断 → 实际工具
protectedHandler := middleware.Chain(rawHandler,
middleware.WithTruncate(middleware.TruncateConfig{
MaxChars: 16000, // 约 4K tokens(针对小模型)
TruncateMsg: "\n\n[输出已截断]",
}),
middleware.WithTimeout(15 * time.Second),
middleware.WithCircuitBreaker(cb),
)
// 将 protectedHandler 注册到 MCP Registry
// registry.Register(tool, protectedHandler)
}
请求执行时序
请求到达
│
▼
[Truncate] ── 记录开始时间
│
▼
[Timeout] ── ctx, cancel := context.WithTimeout(ctx, 15s)
│
▼
[CircuitBreaker] ── cb.Allow() → 状态检查
│ ├─ OPEN → 直接返回 "circuit breaker open"
│ └─ CLOSED/HALF-OPEN → 继续
▼
[实际工具] ── 执行业务逻辑
│
▼
[CircuitBreaker] ◄── 记录成功/失败,更新计数
│
▼
[Timeout] ◄── cancel() 回收资源
│
▼
[Truncate] ◄── 如果输出过长,截断
│
▼
返回结果给 Client
七、进阶:可观测性中间件
除了防御性中间件,还可以加一层可观测性:
// middleware/observability.go
package middleware
import (
"context"
"log"
"time"
"github.com/example/mcp-server-go/internal/protocol"
)
// WithMetrics 指标记录中间件
func WithMetrics(toolName string) Middleware {
return func(next ToolHandler) ToolHandler {
return func(ctx context.Context, args map[string]interface{}) (*protocol.CallToolResult, error) {
start := time.Now()
result, err := next(ctx, args)
elapsed := time.Since(start)
status := "success"
if err != nil {
status = "error"
}
// 结构化日志(可接入 Prometheus / Loki)
log.Printf("[metrics] tool=%s status=%s duration=%v args=%v",
toolName, status, elapsed, truncateArgs(args, 200))
// 生产环境:emit metrics
// toolCallDuration.WithLabelValues(toolName, status).Observe(elapsed.Seconds())
// toolCallTotal.WithLabelValues(toolName, status).Inc()
return result, err
}
}
}
func truncateArgs(args map[string]interface{}, maxLen int) string {
s := fmt.Sprintf("%v", args)
if len(s) > maxLen {
return s[:maxLen] + "..."
}
return s
}
八、中间件对比总结
| 中间件 | 解决的问题 | 适用场景 | 性能开销 |
|---|---|---|---|
| Truncate | 输出过大撑爆上下文 | 文件读取、数据库查询、API 调用 | 低(仅字符串操作) |
| Timeout | 工具卡死不返回 | 网络调用、慢查询、外部 API | 低(一个 goroutine + channel) |
| CircuitBreaker | 连续失败雪崩 | 外部依赖不可靠时 | 极低(原子操作 + 锁) |
| Metrics | 无感知,问题发现滞后 | 所有工具 | 低(日志 I/O 开销) |
更多推荐
所有评论(0)