一、问题:工具调用为什么需要中间件?

先看一个真实场景:

用户: "帮我分析这份 500 页 PDF 的内容"
Agent: 调用 read_pdf 工具
Tool:  返回全部 500 页文本 (2MB+)
LLM:  上下文窗口溢出 → 报错 / 乱答

这正是 LLM Agent 的典型故障模式:工具返回值不受控,撑爆上下文窗口

再看另一个:

用户: "把数据库里的用户数据导出来"
Agent: 调用 db_query 工具
Tool:  SQL 慢查询,耗时 120 秒
Agent: 一直等待,前端超时 → 用户以为系统挂了

这些问题不能靠修改每个工具来解决 —— 每个工具开发者都自己写一遍超时和截断逻辑,重复且易出错。中间件模式正是为此而生。


二、中间件模式设计

2.1 核心接口

类似 HTTP 中间件,我们定义工具调用的 Handler 和 Middleware:

// middleware/middleware.go
package middleware

import (
    "context"

    "github.com/example/mcp-server-go/internal/protocol"
)

// ToolHandler 工具执行函数(与上篇的 ToolHandler 一致)
type ToolHandler func(ctx context.Context, args map[string]interface{}) (*protocol.CallToolResult, error)

// Middleware 中间件 —— 包装一个 Handler,返回新的 Handler
type Middleware func(next ToolHandler) ToolHandler

// Chain 将多个中间件串联
// 执行顺序:middlewares[0] → middlewares[1] → ... → handler
func Chain(handler ToolHandler, middlewares ...Middleware) ToolHandler {
    // 洋葱模型:从外到内
    for i := len(middlewares) - 1; i >= 0; i-- {
        handler = middlewares[i](handler)
    }
    return handler
}

2.2 执行模型

请求 → [截断中间件] → [超时中间件] → [熔断中间件] → [实际工具]
                                        ↓
                            记录成功/失败 → 更新熔断状态

每个中间件只做一件事,组合起来形成完整的防御体系。


三、中间件一:输出截断(TruncateMiddleware)

防止工具返回过大的结果撑爆 LLM 上下文窗口。

// middleware/truncate.go
package middleware

import (
    "context"
    "fmt"
    "strings"
    "unicode/utf8"

    "github.com/example/mcp-server-go/internal/protocol"
)

// TruncateConfig 截断配置
type TruncateConfig struct {
    MaxChars     int    // 最大字符数(0 = 不限制)
    MaxTokens    int    // 最大 token 数(粗略估计:1 token ≈ 4 chars,0 = 不限制)
    TruncateMsg  string // 截断提示信息
}

func DefaultTruncateConfig() TruncateConfig {
    return TruncateConfig{
        MaxChars:    32000, // 约 8K tokens
        MaxTokens:    0,
        TruncateMsg: "\n\n[...输出过长已截断,请用更精确的查询...]",
    }
}

// WithTruncate 创建截断中间件
func WithTruncate(cfg TruncateConfig) Middleware {
    return func(next ToolHandler) ToolHandler {
        return func(ctx context.Context, args map[string]interface{}) (*protocol.CallToolResult, error) {
            result, err := next(ctx, args)
            if err != nil || result == nil {
                return result, err
            }

            result.Content = truncateContent(result.Content, cfg)
            return result, nil
        }
    }
}

func truncateContent(items []protocol.ContentItem, cfg TruncateConfig) []protocol.ContentItem {
    var totalChars int
    truncated := false

    for i := range items {
        if items[i].Type != "text" {
            continue
        }

        itemChars := utf8.RuneCountInString(items[i].Text)

        if cfg.MaxChars > 0 && totalChars+itemChars > cfg.MaxChars {
            // 截断文本
            remaining := cfg.MaxChars - totalChars - utf8.RuneCountInString(cfg.TruncateMsg)
            if remaining > 0 {
                // 找到安全的截断点(避免切断 UTF-8 多字节字符)
                runes := []rune(items[i].Text)
                if remaining < len(runes) {
                    items[i].Text = string(runes[:remaining]) + cfg.TruncateMsg
                }
            } else {
                items[i].Text = cfg.TruncateMsg
            }
            truncated = true
            break
        }

        totalChars += itemChars

        // Token 估算截断
        if cfg.MaxTokens > 0 {
            estimatedTokens := totalChars / 4
            if estimatedTokens > cfg.MaxTokens {
                items[i].Text = items[i].Text[:cfg.MaxTokens*4-len(items[i].Text)] + " [TRUNCATED]"
                truncated = true
                break
            }
        }
    }

    if truncated {
        // 追加截断标记到最后一个 content item
        lastText := &items[len(items)-1]
        if lastText.Type == "text" {
            lastText.Text += fmt.Sprintf("\n[截断统计: 原输出超过 %d 字符限制]", cfg.MaxChars)
        }
    }

    return items
}

为什么不在工具内部截断?

  • 工具开发者不需要关心 LLM 上下文窗口细节
  • 截断策略可以统一调整(比如为不同模型设不同的 token 上限)
  • 截断日志可以集中收集

四、中间件二:超时控制(TimeoutMiddleware)

// middleware/timeout.go
package middleware

import (
    "context"
    "fmt"
    "time"

    "github.com/example/mcp-server-go/internal/protocol"
)

// WithTimeout 创建超时中间件
func WithTimeout(timeout time.Duration) Middleware {
    return func(next ToolHandler) ToolHandler {
        return func(ctx context.Context, args map[string]interface{}) (*protocol.CallToolResult, error) {
            // 创建带超时的子 context
            toolCtx, cancel := context.WithTimeout(ctx, timeout)
            defer cancel()

            // 通过 channel 桥接同步调用
            type result struct {
                res *protocol.CallToolResult
                err error
            }
            done := make(chan result, 1)

            go func() {
                defer func() {
                    if r := recover(); r != nil {
                        done <- result{err: fmt.Errorf("tool panic: %v", r)}
                    }
                }()
                res, err := next(toolCtx, args)
                done <- result{res: res, err: err}
            }()

            select {
            case <-toolCtx.Done():
                // 超时 或 父 context 取消
                return nil, fmt.Errorf("tool call timeout after %v: %w", timeout, toolCtx.Err())
            case r := <-done:
                return r.res, r.err
            }
        }
    }
}

超时的坑:goroutine 泄漏

上面的实现在超时后,goroutine 中的 next() 仍在执行但结果被丢弃。如果工具内部有长时间运行的操作,会造成 goroutine 泄漏。

解决方案:让工具本身尊重 ctx.Done()

// 工具实现必须检查 ctx.Done()
func slowDBQuery(ctx context.Context, args map[string]interface{}) (*protocol.CallToolResult, error) {
    // 正确做法:在每次阻塞操作前检查 context
    select {
    case <-ctx.Done():
        return nil, ctx.Err()
    default:
    }

    // 使用支持 context 的 API
    rows, err := db.QueryContext(ctx, sql, args...)
    // ...
}

关键原则:中间件的超时只是最后一层防线。真正可靠的超时需要工具内部配合 ctx.Done()


五、中间件三:熔断器(CircuitBreakerMiddleware)

当某个工具连续失败时,快速失败避免雪崩。

// middleware/circuit_breaker.go
package middleware

import (
    "context"
    "fmt"
    "sync"
    "sync/atomic"
    "time"

    "github.com/example/mcp-server-go/internal/protocol"
)

// CircuitState 熔断器状态
type CircuitState int32

const (
    StateClosed   CircuitState = iota // 正常
    StateOpen                          // 熔断中,快速失败
    StateHalfOpen                      // 半开,试探性放行
)

// CircuitBreakerConfig 熔断器配置
type CircuitBreakerConfig struct {
    FailureThreshold  int           // 连续失败多少次后熔断
    SuccessThreshold  int           // 半开状态下多少次成功后恢复
    Timeout           time.Duration // 熔断打开后多久进入半开
    HalfOpenMaxCalls  int           // 半开状态最多放行几个请求
}

func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
    return CircuitBreakerConfig{
        FailureThreshold: 5,
        SuccessThreshold: 3,
        Timeout:          30 * time.Second,
        HalfOpenMaxCalls: 1,
    }
}

// CircuitBreaker 熔断器实现
type CircuitBreaker struct {
    cfg CircuitBreakerConfig

    state          atomic.Int32       // CircuitState
    consecutiveFail atomic.Int32      // 连续失败计数
    consecutiveSuccess atomic.Int32   // 连续成功计数(半开状态用)
    lastFailureTime  time.Time
    halfOpenCalls    atomic.Int32     // 半开状态已放行数

    mu sync.Mutex
}

func NewCircuitBreaker(cfg CircuitBreakerConfig) *CircuitBreaker {
    cb := &CircuitBreaker{cfg: cfg}
    cb.state.Store(int32(StateClosed))
    return cb
}

// Allow 检查是否允许调用
func (cb *CircuitBreaker) Allow() error {
    state := CircuitState(cb.state.Load())

    switch state {
    case StateClosed:
        return nil

    case StateOpen:
        cb.mu.Lock()
        defer cb.mu.Unlock()

        // 检查是否可以进入半开状态
        if time.Since(cb.lastFailureTime) > cb.cfg.Timeout {
            cb.state.Store(int32(StateHalfOpen))
            cb.consecutiveSuccess.Store(0)
            cb.halfOpenCalls.Store(0)
            return nil
        }

        return fmt.Errorf("circuit breaker is OPEN: too many failures, retry after %v",
            cb.cfg.Timeout-time.Since(cb.lastFailureTime).Round(time.Second))

    case StateHalfOpen:
        // 限制半开状态的并发试探请求
        current := cb.halfOpenCalls.Add(1)
        if current > int32(cb.cfg.HalfOpenMaxCalls) {
            cb.halfOpenCalls.Add(-1)
            return fmt.Errorf("circuit breaker is HALF-OPEN: max probe calls reached")
        }
        return nil

    default:
        return nil
    }
}

// RecordSuccess 记录成功
func (cb *CircuitBreaker) RecordSuccess() {
    state := CircuitState(cb.state.Load())

    switch state {
    case StateClosed:
        cb.consecutiveFail.Store(0)
        cb.consecutiveSuccess.Add(1)

    case StateHalfOpen:
        success := cb.consecutiveSuccess.Add(1)
        if int(success) >= cb.cfg.SuccessThreshold {
            cb.state.Store(int32(StateClosed))
            cb.consecutiveFail.Store(0)
        }
        cb.halfOpenCalls.Add(-1)
    }
}

// RecordFailure 记录失败
func (cb *CircuitBreaker) RecordFailure() {
    cb.mu.Lock()
    defer cb.mu.Unlock()

    state := CircuitState(cb.state.Load())
    cb.lastFailureTime = time.Now()

    switch state {
    case StateClosed:
        fails := cb.consecutiveFail.Add(1)
        if int(fails) >= cb.cfg.FailureThreshold {
            cb.state.Store(int32(StateOpen))
        }

    case StateHalfOpen:
        // 半开状态下的失败 → 立即重新熔断
        cb.state.Store(int32(StateOpen))
        cb.consecutiveFail.Store(int32(cb.cfg.FailureThreshold))
        cb.halfOpenCalls.Add(-1)
    }
}

// WithCircuitBreaker 创建熔断中间件
func WithCircuitBreaker(cb *CircuitBreaker) Middleware {
    return func(next ToolHandler) ToolHandler {
        return func(ctx context.Context, args map[string]interface{}) (*protocol.CallToolResult, error) {
            if err := cb.Allow(); err != nil {
                return nil, err
            }

            result, err := next(ctx, args)

            if err != nil {
                cb.RecordFailure()
            } else if result != nil && !result.IsError {
                cb.RecordSuccess()
            } else {
                // isError=true 也算失败
                cb.RecordFailure()
            }

            return result, err
        }
    }
}

六、组合使用:洋葱模型

// 实际用法
package main

import (
    "time"

    "github.com/example/mcp-server-go/internal/middleware"
    "github.com/example/mcp-server-go/internal/protocol"
)

func main() {
    // 原始工具处理器
    rawHandler := dbQueryHandler

    // 创建熔断器(全局共享,跨请求统计)
    cb := middleware.NewCircuitBreaker(middleware.DefaultCircuitBreakerConfig())

    // 洋葱模型组合中间件
    // 执行顺序(由外到内):截断 → 超时 → 熔断 → 实际工具
    protectedHandler := middleware.Chain(rawHandler,
        middleware.WithTruncate(middleware.TruncateConfig{
            MaxChars:    16000, // 约 4K tokens(针对小模型)
            TruncateMsg: "\n\n[输出已截断]",
        }),
        middleware.WithTimeout(15 * time.Second),
        middleware.WithCircuitBreaker(cb),
    )

    // 将 protectedHandler 注册到 MCP Registry
    // registry.Register(tool, protectedHandler)
}

请求执行时序

请求到达
  │
  ▼
[Truncate] ── 记录开始时间
  │
  ▼
[Timeout] ── ctx, cancel := context.WithTimeout(ctx, 15s)
  │
  ▼
[CircuitBreaker] ── cb.Allow() → 状态检查
  │                   ├─ OPEN → 直接返回 "circuit breaker open"
  │                   └─ CLOSED/HALF-OPEN → 继续
  ▼
[实际工具] ── 执行业务逻辑
  │
  ▼
[CircuitBreaker] ◄── 记录成功/失败,更新计数
  │
  ▼
[Timeout] ◄── cancel() 回收资源
  │
  ▼
[Truncate] ◄── 如果输出过长,截断
  │
  ▼
返回结果给 Client

七、进阶:可观测性中间件

除了防御性中间件,还可以加一层可观测性:

// middleware/observability.go
package middleware

import (
    "context"
    "log"
    "time"

    "github.com/example/mcp-server-go/internal/protocol"
)

// WithMetrics 指标记录中间件
func WithMetrics(toolName string) Middleware {
    return func(next ToolHandler) ToolHandler {
        return func(ctx context.Context, args map[string]interface{}) (*protocol.CallToolResult, error) {
            start := time.Now()

            result, err := next(ctx, args)

            elapsed := time.Since(start)
            status := "success"
            if err != nil {
                status = "error"
            }

            // 结构化日志(可接入 Prometheus / Loki)
            log.Printf("[metrics] tool=%s status=%s duration=%v args=%v",
                toolName, status, elapsed, truncateArgs(args, 200))

            // 生产环境:emit metrics
            // toolCallDuration.WithLabelValues(toolName, status).Observe(elapsed.Seconds())
            // toolCallTotal.WithLabelValues(toolName, status).Inc()

            return result, err
        }
    }
}

func truncateArgs(args map[string]interface{}, maxLen int) string {
    s := fmt.Sprintf("%v", args)
    if len(s) > maxLen {
        return s[:maxLen] + "..."
    }
    return s
}

八、中间件对比总结

中间件 解决的问题 适用场景 性能开销
Truncate 输出过大撑爆上下文 文件读取、数据库查询、API 调用 低(仅字符串操作)
Timeout 工具卡死不返回 网络调用、慢查询、外部 API 低(一个 goroutine + channel)
CircuitBreaker 连续失败雪崩 外部依赖不可靠时 极低(原子操作 + 锁)
Metrics 无感知,问题发现滞后 所有工具 低(日志 I/O 开销)
Logo

Agent 垂直技术社区,欢迎活跃、内容共建。

更多推荐