LangGraph 多 Agent 协作的"安全漏洞",差点把我们整崩

LangGraph 多 Agent 协作的"安全漏洞",差点把我们整崩

前言

我们团队用 LangGraph 搭了多 Agent 系统,结果上线第一天就发现有用户通过提示词注入,让 Agent 调用了不该调的工具。

还好发现得早,没酿成大祸。后来我们搞了一套完整的安全方案。今天说说 LangGraph 多 Agent 的安全漏洞怎么防。

一、底层原理

1.1 LangGraph 多 Agent 的安全风险

多 Agent 比单 Agent 更危险,因为攻击面更大:

graph TD
    A["用户输入"] --> B["输入 Agent"]
    B --> C["对话 Agent"]
    C --> D["工具 Agent"]
    D --> E["数据库 Agent"]
    F["安全防护"] --> G["输入过滤"]
    F --> H["权限校验"]
    F --> I["审计日志"]
    G --> B
    H --> D
    I --> E

主要攻击途径:

  • 通过用户输入注入恶意指令
  • 跨 Agent 传播攻击
  • 工具调用越权
  • 敏感数据泄露

1.2 常见安全方案对比

方案 防护能力 性能影响 实现难度
输入过滤
权限校验
输出审查
完整审计

二、快速上手

先看不安全的 LangGraph 写法:

from langgraph.graph import StateGraph, END
from typing import TypedDict, List

class AgentState(TypedDict):
    messages: List[str]
    next_agent: str

# 不安全的实现
def route_message(state: AgentState):
    # 直接把用户输入传给下一个 Agent
    return {"next_agent": state["messages"][-1]}

graph = StateGraph(AgentState)
# 没有安全检查

再看安全版:

from typing import TypedDict, List, Optional
import re

class SecureState(TypedDict):
    messages: List[str]
    safe_next: Optional[str]

class SecurityFilter:
    def __init__(self):
        self.dangerous_patterns = [
            r"忽略.*指令",
            r"绕过.*限制",
            r"删除.*数据",
        ]
    
    def filter_input(self, text: str) -> bool:
        for pat in self.dangerous_patterns:
            if re.search(pat, text, re.IGNORECASE):
                return False
        return True

def secure_route(state: SecureState):
    last_msg = state.get("messages", [""])[-1]
    filter = SecurityFilter()
    
    if not filter.filter_input(last_msg):
        return {"safe_next": "guard_agent"}
    
    return {"safe_next": "next_agent"}

三、核心 API / 深水区

3.1 LangGraph 安全配置速查

配置项 说明 建议值
输入验证 过滤恶意输入 开启
工具白名单 限制可调用的工具 最小权限
状态隔离 Agent 间状态不共享 开启
超时控制 防止死循环 30秒
审计日志 记录所有操作 开启

3.2 Agent 间的安全隔离

class SafeAgentWrapper:
    def __init__(self, agent_func, allowed_tools):
        self.agent_func = agent_func
        self.allowed_tools = allowed_tools
    
    def execute(self, state):
        # 检查输入
        if not self._validate_input(state):
            return {"error": "输入不合法"}
        
        # 执行代理
        result = self.agent_func(state)
        
        # 检查输出中的工具调用
        if "tool" in result:
            if result["tool"] not in self.allowed_tools:
                return {"error": f"不允许调用 {result['tool']}"}
        
        return result
    
    def _validate_input(self, state):
        text = str(state.get("messages", []))
        dangerous = ["drop", "delete", "exec", "system"]
        return not any(kw in text.lower() for kw in dangerous)

3.3 审计日志实现

import time
from typing import Dict, Any

class AuditLogger:
    def __init__(self):
        self.logs = []
    
    def log(self, agent: str, action: str, details: Dict[str, Any], risk: str = "low"):
        self.logs.append({
            "time": time.time(),
            "agent": agent,
            "action": action,
            "details": details,
            "risk": risk
        })
    
    def get_recent(self, n=10):
        return self.logs[-n:]
    
    def report_suspicious(self):
        return [log for log in self.logs if log["risk"] in ["high", "critical"]]

四、实战演练

安全加固的 LangGraph 多 Agent 系统:

from typing import TypedDict, List, Optional, Dict, Any
import re
import hashlib
import json

class SafeGraphState(TypedDict):
    messages: List[str]
    current_agent: str
    tool_calls: List[Dict[str, Any]]
    audit_trail: List[Dict[str, Any]]

class InputSanitizer:
    def __init__(self):
        self.blocked = [
            r"ignore\s+previous",
            r"disregard\s+all",
            r"system\s+prompt",
            r"root\s+access",
        ]
        self.compiled = [re.compile(p, re.IGNORECASE) for p in self.blocked]
    
    def check(self, text: str) -> bool:
        for pat in self.compiled:
            if pat.search(text):
                return False
        return True

class ToolGuard:
    def __init__(self):
        self.allowed_tools = {
            "search", "read", "calculate", "translate"
        }
        self.tool_call_count = {}
    
    def can_call(self, tool: str, args: Dict) -> bool:
        if tool not in self.allowed_tools:
            return False
        if "shell" in str(args) or "exec" in str(args):
            return False
        return True
    
    def record_call(self, user: str, tool: str):
        key = f"{user}_{tool}"
        self.tool_call_count[key] = self.tool_call_count.get(key, 0) + 1
        if self.tool_call_count[key] > 100:
            return False
        return True

class SecureLangGraph:
    def __init__(self):
        self.sanitizer = InputSanitizer()
        self.tool_guard = ToolGuard()
        self.audit = []
    
    def process(self, user_input: str) -> Dict:
        state: SafeGraphState = {
            "messages": [user_input],
            "current_agent": "input_check",
            "tool_calls": [],
            "audit_trail": []
        }
        
        # 1. 输入检查
        if not self.sanitizer.check(user_input):
            self._audit(state, "blocked_input", user_input, "high")
            return self._reject("输入不安全")
        
        # 2. 对话处理
        # 检查消息中的工具调用
        for msg in state["messages"]:
            if "call_tool" in msg:
                call_info = self._parse_tool_call(msg)
                if call_info:
                    tool, args = call_info
                    if not self.tool_guard.can_call(tool, args):
                        self._audit(state, "blocked_tool", f"{tool}", "high")
                        return self._reject(f"不允许调用 {tool}")
                    
                    self.tool_guard.record_call("user", tool)
                    self._audit(state, "tool_call", f"{tool}({args})", "low")
        
        result = {"status": "ok", "message": "处理完成"}
        self._audit(state, "completed", str(result), "low")
        return result
    
    def _audit(self, state: SafeGraphState, action: str, details: str, risk: str):
        entry = {
            "action": action,
            "details": details,
            "risk": risk
        }
        state["audit_trail"].append(entry)
        self.audit.append(entry)
    
    def _reject(self, reason: str) -> Dict:
        return {"status": "rejected", "reason": reason}
    
    def _parse_tool_call(self, msg: str):
        if "search" in msg.lower():
            return ("search", {"query": msg})
        return None

system = SecureLangGraph()
result = system.process("这个产品怎么样?")
print(result)

五、避坑指南与最佳实践

💡 **技巧:Agent 之间做好隔离
不共享内部状态,只通过消息协议通信。

⚠️ **警告:不要信任任何 Agent 的输出
每个 Agent 的输出都要做校验。

✅ **推荐:全部操作记录审计日志
出了问题能回溯,谁调了什么工具一目了然。

六、综合实战演示

企业级 LangGraph 安全方案:

import time
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, field

@dataclass
class SecurityEvent:
    event_type: str
    agent_name: str
    input_snapshot: str
    risk_level: str
    timestamp: float

class SecurityMonitor:
    def __init__(self):
        self.events: List[SecurityEvent] = []
        self.rate_limits: Dict[str, int] = {}
    
    def check_rate_limit(self, user: str) -> bool:
        now = time.time()
        window_start = now - 60
        recent = [e for e in self.events 
                  if e.event_type == "tool_call" 
                  and e.timestamp > window_start]
        return len(recent) < 100
    
    def log_event(self, event: SecurityEvent):
        self.events.append(event)
    
    def get_threats(self) -> List[SecurityEvent]:
        return [e for e in self.events if e.risk_level == "critical"]

class LangGraphSecuritySystem:
    def __init__(self):
        self.monitor = SecurityMonitor()
        self.sanitizer = InputSanitizer()
    
    def secure_execute(self, graph, user_input: str) -> Dict[str, Any]:
        if not self.monitor.check_rate_limit("user"):
            return {"error": "请求过于频繁"}
        
        if not self.sanitizer.check(user_input):
            self.monitor.log_event(SecurityEvent(
                event_type="injection_attempt",
                agent_name="input_gate",
                input_snapshot=user_input[:100],
                risk_level="critical",
                timestamp=time.time()
            ))
            return {"error": "检测到恶意输入"}
        
        # 执行时注入安全检查
        secure_graph = self._inject_security(graph)
        result = secure_graph.invoke({"messages": [user_input]})
        
        return result
    
    def _inject_security(self, graph):
        # 在实际项目中,这里会包装图的节点
        return graph

system = LangGraphSecuritySystem()
result = system.secure_execute(None, "帮我查一下订单")
print(result)

七、总结

LangGraph 多 Agent 安全要点:

  • 输入必须过滤
  • 工具调用必须校验
  • Agent 之间做好隔离
  • 所有操作记录日志
  • 加上频率限制

安全无小事,尤其多 Agent 系统,攻击面比单 Agent 大得多。

Logo

Agent 垂直技术社区,欢迎活跃、内容共建。

更多推荐