系列目录:本文是「AI 应用开发进阶实战」系列的第 4 篇。前面我们构建了 RAG、MCP 工具链和知识图谱,本篇进入 Agent 的核心——如何设计一个可靠的、可观测的工作流引擎。


一、为什么 Agent 需要工作流引擎?

1.1 从简单链到复杂工作流

最简单的 Agent:
  user_input → LLM → output

加了 RAG 后:
  user_input → retrieve → LLM → output     (2步)

加了工具调用后:
  user_input → LLM → tool_call → tool_result → LLM → output   (循环)

加了多 Agent 后:
  user_input → planner → [worker1, worker2, worker3] → aggregator → output

随着复杂度增长,直接写 if/else + while 循环的代码会迅速失控。工作流引擎提供:

| 能力 | 无引擎(裸写) | 有引擎 |
| --- | --- | --- |
| 可视化 | 代码即文档,难以理解 | DAG 图一目了然 |
| 错误处理 | 到处 try/except,容易遗漏 | 声明式重试、降级、回退 |
| 并行执行 | asyncio.gather 自己管理 | 引擎自动拓扑排序并行 |
| 状态持久化 | 手动存 Redis/DB | 引擎内置检查点(本文示例引擎尚未实现,需自行扩展) |
| 可观测性 | 自己加日志 | 每个节点自动追踪 |

二、工作流引擎核心设计

2.1 DAG 工作流的数据结构

# workflow_engine.py
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Set
from enum import Enum
import asyncio
import time
import json


# Status of a single workflow node, built with the functional Enum API
# (same member names, values and order as the class-syntax form).
# The engine records RUNNING / SUCCESS / FAILED while executing a node and
# SKIPPED when the node's condition evaluates to False; PENDING is defined
# but not assigned by the engine shown in this article.
NodeStatus = Enum(
    "NodeStatus",
    [
        ("PENDING", "pending"),
        ("RUNNING", "running"),
        ("SUCCESS", "success"),
        ("FAILED", "failed"),
        ("SKIPPED", "skipped"),
    ],
    module=__name__,
    qualname="NodeStatus",
)


@dataclass
class WorkflowNode:
    """One node (task) of a workflow.

    Configuration is validated on construction so that mistakes the engine
    would otherwise mis-handle silently fail fast instead.
    """
    name: str
    handler: Callable  # async function(context) -> Any
    inputs: List[str] = field(default_factory=list)  # names of upstream (dependency) nodes
    retry_count: int = 0        # extra attempts after the first failure
    retry_delay: float = 1.0    # seconds to wait between attempts
    timeout: Optional[float] = 60.0  # seconds per attempt; None waits indefinitely
    condition: Optional[Callable] = None  # conditional execution: fn(context) -> bool
    on_failure: str = "fail"    # "fail" | "skip" | "continue"

    def __post_init__(self):
        """Validate the node configuration.

        Raises:
            ValueError: for an unknown ``on_failure`` value (the engine would
                treat it as "fail"), a negative ``retry_count`` (the retry loop
                would run zero attempts yet mark the node FAILED), or a
                non-positive ``timeout`` (every attempt would time out).
        """
        if self.on_failure not in ("fail", "skip", "continue"):
            raise ValueError(
                f"Node '{self.name}': on_failure must be 'fail', 'skip' or "
                f"'continue', got {self.on_failure!r}"
            )
        if self.retry_count < 0:
            raise ValueError(
                f"Node '{self.name}': retry_count must be >= 0, got {self.retry_count}"
            )
        if self.timeout is not None and self.timeout <= 0:
            raise ValueError(
                f"Node '{self.name}': timeout must be positive or None, got {self.timeout}"
            )


@dataclass
class Workflow:
    """DAG workflow definition: named nodes plus ``(from_node, to_node)`` edges.

    An edge ``(a, b)`` means ``b`` may only start after ``a`` has finished.
    """
    name: str
    nodes: Dict[str, WorkflowNode]
    edges: List[tuple]  # [(from_node, to_node), ...]

    def validate(self) -> bool:
        """Check that the workflow is a well-formed DAG.

        Returns:
            True when every edge endpoint is a known node and there is no cycle.

        Raises:
            ValueError: if an edge references an unknown node, or the edges
                contain a cycle (the message lists the unresolved nodes).
        """
        from collections import deque

        # 1. Every edge endpoint must be a declared node.
        all_nodes = set(self.nodes)
        for src, dst in self.edges:
            if src not in all_nodes:
                raise ValueError(f"Edge source '{src}' not found in nodes")
            if dst not in all_nodes:
                raise ValueError(f"Edge target '{dst}' not found in nodes")

        # 2. Kahn's algorithm: repeatedly remove nodes with no remaining
        #    incoming edges. Nodes never removed lie on, or downstream of, a cycle.
        in_degree = {name: 0 for name in all_nodes}
        adj = {name: [] for name in all_nodes}
        for src, dst in self.edges:
            adj[src].append(dst)
            in_degree[dst] += 1

        # deque.popleft() is O(1); list.pop(0) made the traversal quadratic.
        queue = deque(n for n, d in in_degree.items() if d == 0)
        visited = 0
        while queue:
            node = queue.popleft()
            visited += 1
            for neighbor in adj[node]:
                in_degree[neighbor] -= 1
                if in_degree[neighbor] == 0:
                    queue.append(neighbor)

        if visited != len(all_nodes):
            unresolved = sorted(str(n) for n, d in in_degree.items() if d > 0)
            raise ValueError(f"Workflow contains a cycle! Unresolved nodes: {unresolved}")

        return True


@dataclass
class WorkflowContext:
    """Execution context passed to every node; carries data between nodes."""
    inputs: Dict[str, Any] = field(default_factory=dict)        # user-supplied inputs
    node_outputs: Dict[str, Any] = field(default_factory=dict)  # node name -> output
    metadata: Dict[str, Any] = field(default_factory=dict)

    def get(self, key: str, default=None):
        """Look up ``key``, preferring node outputs over user inputs.

        Presence decides, not truthiness: a node output of ``0``, ``False``,
        ``""``, ``[]`` or ``None`` is returned as-is. Only a key that no node
        has written falls back to ``inputs`` (then ``default``).
        """
        if key in self.node_outputs:
            return self.node_outputs[key]
        return self.inputs.get(key, default)

2.2 工作流执行引擎

class WorkflowEngine:
    """DAG workflow execution engine.

    A node is scheduled as soon as all of its upstream nodes (per
    ``workflow.edges``) have finished, with at most ``max_concurrency``
    nodes running at once. Each run keeps a record in ``self.executions``.
    """

    def __init__(self, max_concurrency: int = 10):
        """
        Args:
            max_concurrency: maximum number of nodes executing at the same time.

        Raises:
            ValueError: if ``max_concurrency`` is smaller than 1.
        """
        if max_concurrency < 1:
            raise ValueError("max_concurrency must be >= 1")
        self.executions: Dict[str, dict] = {}  # exec_id -> execution record
        self.max_concurrency = max_concurrency

    async def execute(self, workflow: Workflow, context: WorkflowContext) -> dict:
        """Run every node of ``workflow`` and return ``context.node_outputs``.

        A node whose ``condition(context)`` is false is marked SKIPPED and
        produces no output; its downstream nodes are still scheduled.

        Raises:
            ValueError: from ``workflow.validate()`` (unknown node, cycle).
            RuntimeError: when a node with ``on_failure="fail"`` exhausts its
                retries. Nodes still in flight are cancelled before the error
                propagates.
        """
        workflow.validate()

        # The counter suffix keeps two runs started within the same second
        # from overwriting each other's record.
        exec_id = f"{workflow.name}_{int(time.time())}_{len(self.executions)}"
        record = {
            "workflow": workflow.name,
            "start_time": time.time(),
            "node_statuses": {},
            "results": {}
        }
        self.executions[exec_id] = record

        # in_degree: number of unfinished upstream nodes;
        # dependents: nodes waiting on this node.
        in_degree = {name: 0 for name in workflow.nodes}
        dependents = {name: [] for name in workflow.nodes}
        for src, dst in workflow.edges:
            in_degree[dst] += 1
            dependents[src].append(dst)

        print(f"\n{'='*60}")
        print(f"Workflow: {workflow.name}")
        print(f"Nodes: {len(workflow.nodes)}, Edges: {len(workflow.edges)}")
        print(f"{'='*60}\n")

        semaphore = asyncio.Semaphore(self.max_concurrency)

        async def run_node(node_name: str) -> str:
            """Evaluate the node's condition, execute it, and return its name."""
            async with semaphore:
                node = workflow.nodes[node_name]
                if node.condition and not node.condition(context):
                    print(f"  [{node_name}] SKIPPED (condition=False)")
                    record["node_statuses"][node_name] = NodeStatus.SKIPPED
                else:
                    result = await self._execute_node(node, context, exec_id)
                    context.node_outputs[node_name] = result
            return node_name

        # One task per ready node. asyncio.wait wakes up exactly when a node
        # finishes, so no idle polling delays the end of the run.
        pending = {
            asyncio.create_task(run_node(name))
            for name, degree in in_degree.items()
            if degree == 0
        }
        try:
            while pending:
                done, pending = await asyncio.wait(
                    pending, return_when=asyncio.FIRST_COMPLETED
                )
                # Retrieve every finished task's exception (so none is left
                # "never retrieved"); the first one aborts the run.
                failures = [exc for exc in (t.exception() for t in done) if exc is not None]
                if failures:
                    raise failures[0]

                # Unlock downstream nodes whose dependencies are now all done.
                for task in done:
                    for dependent in dependents[task.result()]:
                        in_degree[dependent] -= 1
                        if in_degree[dependent] == 0:
                            pending.add(asyncio.create_task(run_node(dependent)))
        except BaseException:
            # A node failed or this run was cancelled: stop the sibling tasks
            # instead of leaving them running in the background.
            for task in pending:
                task.cancel()
            await asyncio.gather(*pending, return_exceptions=True)
            raise
        finally:
            record["end_time"] = time.time()
            record["duration"] = record["end_time"] - record["start_time"]

        print(f"\nWorkflow completed in {record['duration']:.1f}s")

        return context.node_outputs

    async def _execute_node(
        self,
        node: WorkflowNode,
        context: WorkflowContext,
        exec_id: str
    ) -> Any:
        """Execute one node with per-attempt timeout and retries.

        Returns:
            The handler's result on success. After all attempts fail: ``None``
            for ``on_failure="skip"``, ``{"error": message}`` for
            ``on_failure="continue"``.

        Raises:
            RuntimeError: after all attempts fail and ``on_failure`` is
                anything else (``"fail"``).
        """
        last_error = None

        for attempt in range(node.retry_count + 1):
            try:
                print(f"  [{node.name}] Running... (attempt {attempt+1}/{node.retry_count+1})")

                self.executions[exec_id]["node_statuses"][node.name] = NodeStatus.RUNNING

                # Run the handler, cancelling it if it exceeds node.timeout.
                result = await asyncio.wait_for(
                    node.handler(context),
                    timeout=node.timeout
                )

                self.executions[exec_id]["node_statuses"][node.name] = NodeStatus.SUCCESS
                self.executions[exec_id]["results"][node.name] = {
                    "status": "success",
                    "attempts": attempt + 1
                }

                print(f"  [{node.name}] SUCCESS")
                return result

            except asyncio.TimeoutError:
                last_error = f"Timeout after {node.timeout}s"
                print(f"  [{node.name}] TIMEOUT")

            except Exception as e:
                last_error = str(e)
                print(f"  [{node.name}] ERROR: {e}")

            if attempt < node.retry_count:
                await asyncio.sleep(node.retry_delay)

        # Every attempt failed: record the failure, then apply on_failure.
        self.executions[exec_id]["node_statuses"][node.name] = NodeStatus.FAILED
        self.executions[exec_id]["results"][node.name] = {
            "status": "failed",
            "error": last_error,
            "attempts": node.retry_count + 1
        }

        if node.on_failure == "skip":
            print(f"  [{node.name}] FAILED → SKIPPING (on_failure=skip)")
            return None

        elif node.on_failure == "continue":
            print(f"  [{node.name}] FAILED → CONTINUING (on_failure=continue)")
            return {"error": last_error}

        else:  # "fail"
            raise RuntimeError(f"Node '{node.name}' failed: {last_error}")

2.3 构建示例:文档处理工作流

# example_workflow.py
from workflow_engine import Workflow, WorkflowNode, WorkflowContext, WorkflowEngine


# === 定义节点处理器 ===

async def load_documents(ctx: WorkflowContext) -> dict:
    """Load Markdown documents for the pipeline.

    Reads from ``ctx.inputs``:
        doc_path: directory searched recursively for ``*.md`` files
            (default ``"./docs"``).
        max_docs: maximum number of files to load (default 20; ``None``
            means no limit).

    Returns:
        ``{"documents": [{"path", "content", "size"}, ...], "count": int}``
        with documents in sorted path order.

    Note: files are read synchronously, which blocks the event loop while
    reading; fine for a small demo corpus.
    """
    import glob
    import os

    path = ctx.inputs.get("doc_path", "./docs")
    max_docs = ctx.inputs.get("max_docs", 20)

    # glob.escape: a directory name containing "[", "*" or "?" must be
    # matched literally, not as a pattern. sorted(): glob's order depends on
    # the filesystem, so without it the max_docs cut would be nondeterministic.
    pattern = os.path.join(glob.escape(path), "**/*.md")
    files = sorted(glob.glob(pattern, recursive=True))

    documents = []
    for f in files[:max_docs]:
        with open(f, "r", encoding="utf-8") as fp:
            documents.append({
                "path": f,
                "content": fp.read(),
                "size": os.path.getsize(f)
            })

    print(f"    Loaded {len(documents)} documents")
    return {"documents": documents, "count": len(documents)}


async def chunk_documents(ctx: WorkflowContext) -> dict:
    """Split every loaded document into paragraph chunks.

    Reads ``ctx.node_outputs["load_docs"]["documents"]`` and splits each
    document's content on blank lines. Paragraphs of 50 characters or fewer
    (after stripping) are dropped; kept ones are truncated to 1000
    characters. ``chunk_id`` is ``"<path>#<paragraph index>"``, where the
    index also counts the dropped paragraphs.
    """
    chunks = []
    for doc in ctx.node_outputs["load_docs"]["documents"]:
        source = doc["path"]
        for idx, raw in enumerate(doc["content"].split("\n\n")):
            text = raw.strip()
            if len(text) <= 50:
                continue  # too short to be a useful chunk
            chunks.append({
                "source": source,
                "chunk_id": f"{source}#{idx}",
                "content": text[:1000],
            })

    total = len(chunks)
    print(f"    Split into {total} chunks")
    return {"chunks": chunks, "count": total}


async def generate_embeddings(ctx: WorkflowContext) -> dict:
    """Produce one placeholder embedding per chunk from ``chunk_docs``.

    Every vector is the constant ``[0.1] * 128`` (a stand-in for a real
    embedding API call), and ``content`` keeps the first 100 characters of
    the chunk.
    """
    source_chunks = ctx.node_outputs["chunk_docs"]["chunks"]

    embeddings = [
        {
            "chunk_id": piece["chunk_id"],
            "embedding": [0.1] * 128,  # placeholder vector
            "content": piece["content"][:100],
        }
        for piece in source_chunks
    ]

    total = len(embeddings)
    print(f"    Generated {total} embeddings")
    return {"embeddings": embeddings, "count": total}


async def extract_entities(ctx: WorkflowContext) -> dict:
    """Count keyword occurrences across all loaded documents.

    A placeholder for LLM-based entity extraction: counts (case-sensitive
    substring) occurrences of a fixed keyword list and keeps only keywords
    that appear at least once. Runs in parallel with chunking, since both
    only need ``load_docs``.
    """
    keywords = ("AI", "LLM", "Agent", "RAG")
    texts = [doc["content"] for doc in ctx.node_outputs["load_docs"]["documents"]]

    entities = {}
    for text in texts:
        for keyword in keywords:
            hits = text.count(keyword)
            if hits:
                entities[keyword] = entities.get(keyword, 0) + hits

    print(f"    Extracted {len(entities)} entity types")
    return {"entities": entities}


async def merge_and_index(ctx: WorkflowContext) -> dict:
    """Merge the embedding and entity branches into a summary index.

    Returns the number of embedded chunks, the number of distinct entities,
    and the five most frequent ``(entity, count)`` pairs.
    """
    outputs = ctx.node_outputs
    embeddings = outputs["gen_embeddings"]["embeddings"]
    entities = outputs["extract_entities"]["entities"]

    # Highest count first; equal counts keep their original order
    # (Python's sort is stable).
    ranked = sorted(entities.items(), key=lambda item: -item[1])

    index = dict(
        total_chunks=len(embeddings),
        total_entities=len(entities),
        top_entities=ranked[:5],
    )

    print(f"    Built index: {index}")
    return index


async def quality_check(ctx: WorkflowContext) -> dict:
    """Check that the built index contains at least 5 chunks.

    Returns ``{"passed": bool, "reason": str, "stats": <merge_index output>}``
    and prints PASS or FAIL.
    """
    stats = ctx.node_outputs["merge_index"]
    total = stats["total_chunks"]

    if total >= 5:
        passed, reason, verdict = True, "OK", "PASS"
    else:
        passed, reason, verdict = False, f"Only {total} chunks (< 5)", "FAIL"

    print(f"    Quality Check: {verdict}")
    return {"passed": passed, "reason": reason, "stats": stats}


async def notify_result(ctx: WorkflowContext) -> dict:
    """Build and print the final notification for the pipeline run.

    ``quality_check`` is declared with ``on_failure="skip"``, so when it
    fails the engine stores ``None`` as its output. That case, and a missing
    output, are reported as "did not run" instead of crashing with
    ``TypeError`` on ``qc["passed"]``.
    """
    qc = ctx.node_outputs.get("quality_check")

    if qc is None:
        msg = "Workflow completed but quality check did not run successfully."
    elif qc["passed"]:
        msg = f"Workflow succeeded. Index built with {qc['stats']['total_chunks']} chunks."
    else:
        msg = f"Workflow completed but quality check failed: {qc['reason']}"

    print(f"    Notification: {msg}")
    return {"message": msg, "sent": True}


# === 定义工作流 ===

def build_document_workflow():
    """Build the document-processing DAG workflow.

    Each node declares its dependencies exactly once (``deps``). The edge
    list the engine schedules from is derived from those declarations, so
    ``WorkflowNode.inputs`` and ``Workflow.edges`` cannot drift apart.
    """

    def node(name, handler, deps=None, *, retry=1, timeout=60, condition=None, on_failure="fail"):
        """Return a ``(name, WorkflowNode)`` pair using this example's defaults.

        Options are explicit keyword-only parameters, so a misspelled option
        (e.g. ``retries=3``) raises ``TypeError`` instead of being ignored.
        Note that the default here is one retry per node.
        """
        return name, WorkflowNode(
            name=name,
            handler=handler,
            inputs=list(deps or []),
            retry_count=retry,
            timeout=timeout,
            condition=condition,
            on_failure=on_failure,
        )

    nodes = dict([
        node("load_docs", load_documents),
        node("chunk_docs", chunk_documents, deps=["load_docs"]),
        node("gen_embeddings", generate_embeddings, deps=["chunk_docs"], timeout=120),
        node("extract_entities", extract_entities, deps=["load_docs"]),
        node("merge_index", merge_and_index, deps=["gen_embeddings", "extract_entities"]),
        node("quality_check", quality_check, deps=["merge_index"], on_failure="skip"),
        node("notify", notify_result, deps=["quality_check"], retry=2),
    ])

    # One (dependency, node) edge per declared dependency, in declaration order:
    # load_docs -> chunk_docs / extract_entities -> ... -> quality_check -> notify.
    edges = [
        (dep, name)
        for name, wf_node in nodes.items()
        for dep in wf_node.inputs
    ]

    return Workflow(name="document_processing", nodes=nodes, edges=edges)


# === 运行 ===
async def main():
    """Build the example document workflow, run it, and print each node's output."""
    # This example module imports only the engine classes from
    # workflow_engine, so json must be imported explicitly.
    import json

    engine = WorkflowEngine()
    workflow = build_document_workflow()

    context = WorkflowContext(inputs={
        "doc_path": "./sample_docs",
        "user_id": "user_123"
    })

    try:
        results = await engine.execute(workflow, context)
    except Exception as e:
        # Top-level boundary of the demo: report the failure instead of a traceback.
        print(f"\nWorkflow failed: {e}")
        return

    # Kept outside the try so a printing/serialization problem is not
    # misreported as a workflow failure.
    print("\nFinal Results:")
    for node_name, output in results.items():
        print(f"  {node_name}: {json.dumps(output, indent=2, ensure_ascii=False)[:200]}")


if __name__ == "__main__":
    # asyncio is not imported at module level in this example file either.
    import asyncio

    asyncio.run(main())

执行流程可视化:

                load_docs
                /        \
               /          \
        chunk_docs    extract_entities    ← 并行执行!
             |              |
        gen_embeddings       |
             \              /
              \            /
              merge_index
                   |
              quality_check
                   |
                notify

三、动态路由:条件分支

3.1 条件节点

# 条件节点示例:根据文档类型分流
async def classify_document(ctx: WorkflowContext) -> str:
    """Classify ``ctx.inputs["document"]`` and return a routing label.

    Returns ``"contract"`` if the content mentions 合同/协议, otherwise
    ``"code"`` if it mentions 代码/function, otherwise ``"general"``.
    Contract markers are checked first.
    """
    text = ctx.inputs.get("document", {}).get("content", "")

    if any(marker in text for marker in ("合同", "协议")):
        return "contract"
    if any(marker in text for marker in ("代码", "function")):
        return "code"
    return "general"


# 条件函数:只有分类为 "contract" 时才执行
def is_contract(ctx: WorkflowContext) -> bool:
    """Condition: true only when the ``classify_doc`` node produced ``"contract"``."""
    label = ctx.node_outputs.get("classify_doc")
    return label == "contract"

def is_code(ctx: WorkflowContext) -> bool:
    """Condition: true only when the ``classify_doc`` node produced ``"code"``."""
    label = ctx.node_outputs.get("classify_doc")
    return label == "code"


# Conditional nodes in a workflow.
# Both conditions read classify_doc's output from ctx.node_outputs, so these
# nodes must be scheduled after classify_doc: declare the dependency here AND
# add the edges ("classify_doc", "parse_contract") and
# ("classify_doc", "analyze_code") to the Workflow. The engine schedules only
# from edges; without them a node can start before classify_doc has produced
# output, its condition sees None, and the node is SKIPPED.
# parse_contract_terms / analyze_code_structure are not defined in this
# article and must be provided elsewhere.
contract_parser = WorkflowNode(
    name="parse_contract",
    handler=parse_contract_terms,
    inputs=["classify_doc"],
    condition=is_contract,  # only when classify_doc returned "contract"
)

code_analyzer = WorkflowNode(
    name="analyze_code",
    handler=analyze_code_structure,
    inputs=["classify_doc"],
    condition=is_code,  # only when classify_doc returned "code"
)

3.2 LLM 驱动的动态路由

from openai import OpenAI

async def llm_router(ctx: WorkflowContext) -> str:
    """Ask an LLM which branch should handle the user's query.

    Reads ``api_key`` and ``query`` from ``ctx.inputs`` and returns one of
    ``SEARCH``, ``CODE``, ``CHART``, ``CHAT`` or ``COMPLEX``. Replies that
    differ only in case, surrounding whitespace/quotes or trailing
    punctuation are normalized. Any other reply falls back to ``CHAT``, so
    downstream routing never receives an unknown label.
    """
    import asyncio

    valid_routes = ("SEARCH", "CODE", "CHART", "CHAT", "COMPLEX")
    fallback_route = "CHAT"

    client = OpenAI(api_key=ctx.inputs.get("api_key"))
    # A missing query would otherwise fail at user_query[:50] below.
    user_query = ctx.inputs.get("query") or ""

    def request_route():
        return client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{
                "role": "system",
                "content": """分析用户意图,返回以下路由标签之一:
- SEARCH: 需要检索知识库
- CODE: 需要生成/分析代码
- CHART: 需要数据可视化
- CHAT: 简单对话即可回答
- COMPLEX: 需要多步推理

只回复一个标签。"""
            }, {
                "role": "user",
                "content": user_query
            }],
            temperature=0
        )

    # OpenAI is the synchronous client. Calling it directly inside this
    # coroutine would block the event loop, and every parallel node, for the
    # whole HTTP round trip, so run it in the default thread pool.
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(None, request_route)

    # message.content may be None (e.g. refusals), so guard before strip().
    raw = response.choices[0].message.content or ""
    route = raw.strip().strip("\"'`.。").upper()
    if route not in valid_routes:
        print(f"    LLM Router: unexpected reply {raw!r}, falling back to {fallback_route}")
        route = fallback_route

    print(f"    LLM Router: {user_query[:50]}... → {route}")
    return route


# 动态路由版本
class DynamicWorkflow(Workflow):
    """Workflow whose next node is chosen at run time (single-path execution).

    Starting from ``start``, each node is executed through a
    ``WorkflowEngine``, so retries, timeouts and ``on_failure`` apply. The
    next node is then one of:

    * the return value of the routing function registered for the node via
      ``add_dynamic_route`` (a plain function or a coroutine function);
    * the node's single static successor in ``edges``.

    The run ends when there is no successor, or the chosen name is falsy or
    not a node.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dynamic_edges: Dict[str, Callable] = {}  # node -> routing_fn(ctx) -> next node name

    def add_dynamic_route(self, from_node: str, routing_fn: Callable):
        """Route out of ``from_node`` with ``routing_fn(ctx) -> target_node_name``.

        ``routing_fn`` may also be a coroutine function such as an LLM router.
        """
        self.dynamic_edges[from_node] = routing_fn

    async def execute(
        self,
        context: WorkflowContext,
        engine: Optional["WorkflowEngine"] = None,
        start: str = "start",
        max_steps: Optional[int] = 100,
    ) -> dict:
        """Walk the workflow from ``start`` and return ``{node_name: output}``.

        Args:
            context: shared context; each output is also stored in
                ``context.node_outputs``.
            engine: engine whose ``_execute_node`` runs each node (retries,
                timeout, ``on_failure``); a fresh ``WorkflowEngine`` by default.
            start: name of the first node.
            max_steps: maximum number of node executions, guarding against
                routing loops; ``None`` disables the limit.

        Raises:
            ValueError: if a node has more than one static successor in ``edges``.
            RuntimeError: if ``max_steps`` is exceeded, or a node configured
                with ``on_failure="fail"`` fails.
        """
        import inspect
        import time

        if engine is None:
            engine = WorkflowEngine()

        # edges is a list of (src, dst) tuples, not a mapping: index it once.
        static_next: Dict[str, str] = {}
        for src, dst in self.edges:
            if static_next.setdefault(src, dst) != dst:
                raise ValueError(
                    f"Node '{src}' has more than one static successor; "
                    "DynamicWorkflow follows a single path (use add_dynamic_route)"
                )

        # _execute_node writes statuses into engine.executions[exec_id], so
        # the record has to exist before the first node runs.
        exec_id = f"{self.name}_dynamic_{int(time.time())}_{len(engine.executions)}"
        record = {
            "workflow": self.name,
            "start_time": time.time(),
            "node_statuses": {},
            "results": {},
        }
        engine.executions[exec_id] = record

        results = {}
        current = start
        steps = 0
        try:
            while current:
                node = self.nodes.get(current)
                if node is None:
                    break

                steps += 1
                if max_steps is not None and steps > max_steps:
                    raise RuntimeError(
                        f"DynamicWorkflow '{self.name}' exceeded max_steps={max_steps} "
                        f"(next node: '{current}')"
                    )

                result = await engine._execute_node(node, context, exec_id)
                results[current] = result
                context.node_outputs[current] = result

                if current in self.dynamic_edges:
                    target = self.dynamic_edges[current](context)
                    if inspect.isawaitable(target):  # async routers, e.g. llm_router
                        target = await target
                    current = target
                else:
                    current = static_next.get(current)
        finally:
            record["end_time"] = time.time()
            record["duration"] = record["end_time"] - record["start_time"]

        return results

四、状态机模式:复杂交互流程

对于需要多轮交互、状态转换的工作流(如审批流程),DAG 不适用:审批中的"退回修改 → 重新提交"会形成循环,而 DAG 不允许出现环。这类流程应改用有限状态机。

from enum import Enum, auto

class ApprovalState(Enum):
    """States of the approval workflow; APPROVED and REJECTED are terminal."""
    DRAFT = auto()
    SUBMITTED = auto()
    REVIEWING = auto()
    APPROVED = auto()
    REJECTED = auto()
    REVISION_NEEDED = auto()


class StateMachineWorkflow:
    """Approval workflow modelled as a finite state machine.

    ``_decide_action`` (to be provided by a subclass, e.g. backed by an LLM)
    picks an action in each state. ``transitions`` maps it to the next
    state, and an optional async handler registered with ``on_transition``
    runs for that transition.
    """

    def __init__(self):
        # state -> {action: next_state}. APPROVED and REJECTED have no entry,
        # i.e. no outgoing actions.
        self.transitions = {
            ApprovalState.DRAFT: {
                "submit": ApprovalState.SUBMITTED,
            },
            ApprovalState.SUBMITTED: {
                "start_review": ApprovalState.REVIEWING,
                "auto_approve": ApprovalState.APPROVED,  # approve without review
            },
            ApprovalState.REVIEWING: {
                "approve": ApprovalState.APPROVED,
                "reject": ApprovalState.REJECTED,
                "request_revision": ApprovalState.REVISION_NEEDED,
            },
            ApprovalState.REVISION_NEEDED: {
                "resubmit": ApprovalState.SUBMITTED,
            },
        }

        self.handlers: Dict[tuple, Callable] = {}  # (from_state, to_state) -> async handler(ctx)

    def on_transition(self, from_state, to_state, handler):
        """Register an async ``handler(ctx)`` for the ``from_state -> to_state`` transition."""
        self.handlers[(from_state, to_state)] = handler

    async def run(self, context: WorkflowContext, max_steps: int = 100) -> dict:
        """Drive the state machine from DRAFT until APPROVED or REJECTED.

        At each step ``_decide_action`` picks an action and the matching
        transition is taken. If a handler is registered for
        ``(from_state, to_state)``, it is awaited and its result is stored in
        ``context.node_outputs["<FROM>_<TO>"]``.

        Args:
            context: workflow context shared with the handlers.
            max_steps: maximum number of transitions. The
                REVIEWING -> REVISION_NEEDED -> SUBMITTED loop could
                otherwise run forever.

        Returns:
            ``{"final_state": name, "history": [{"from", "to", "action"}, ...]}``.

        Raises:
            ValueError: if the chosen action is not allowed in the current state.
            RuntimeError: if no final state is reached within ``max_steps`` transitions.
        """
        terminal_states = (ApprovalState.APPROVED, ApprovalState.REJECTED)
        state = ApprovalState.DRAFT
        history = []

        while state not in terminal_states:
            if len(history) >= max_steps:
                raise RuntimeError(
                    f"No final state reached within {max_steps} transitions "
                    f"(current state: {state.name})"
                )

            # Let the decision policy (e.g. an LLM) choose the next action.
            action = await self._decide_action(state, context)

            allowed = self.transitions.get(state, {})
            if action not in allowed:
                raise ValueError(
                    f"Invalid transition: {state.name} --{action!r}--> "
                    f"(allowed actions: {sorted(allowed)})"
                )

            next_state = allowed[action]

            # Run the transition handler, if one is registered.
            handler = self.handlers.get((state, next_state))
            if handler:
                result = await handler(context)
                context.node_outputs[f"{state.name}_{next_state.name}"] = result

            history.append({"from": state.name, "to": next_state.name, "action": action})
            state = next_state

        return {
            "final_state": state.name,
            "history": history
        }

    async def _decide_action(self, state: ApprovalState, ctx: WorkflowContext) -> str:
        """Return the action to take in ``state`` (a key of ``self.transitions[state]``).

        Intended to be backed by an LLM; subclasses must override it.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            "StateMachineWorkflow._decide_action must be overridden "
            "(e.g. ask an LLM to choose one of self.transitions[state])"
        )

五、可观测性

class ObservableWorkflowEngine(WorkflowEngine):
    """Workflow engine that records a trace with one span per executed node.

    Every ``execute`` call appends a trace dict to ``self.traces`` and prints
    a report, whether the run succeeds or fails.
    """

    def __init__(self):
        super().__init__()
        self.traces = []
        # id(context) -> trace of the run currently executing with that context.
        # The overridden _execute_node looks spans' trace up here instead of
        # monkey-patching self._execute_node per run. That patch was never
        # undone, so each later run wrapped the previous wrapper and kept
        # appending spans to every earlier trace.
        # NOTE: concurrent runs must use distinct context objects.
        self._active_traces: Dict[int, dict] = {}

    async def execute(self, workflow, context):
        """Execute ``workflow`` with the base engine while recording a trace."""
        trace_id = f"trace_{int(time.time()*1000)}"

        trace = {
            "trace_id": trace_id,
            "workflow": workflow.name,
            "start_time": time.time(),
            "spans": []
        }
        self._active_traces[id(context)] = trace

        try:
            result = await super().execute(workflow, context)
            trace["status"] = "success"
            return result
        except BaseException:
            # BaseException also covers cancellation, so the report below
            # always finds a status.
            trace["status"] = "failed"
            raise
        finally:
            self._active_traces.pop(id(context), None)
            trace["end_time"] = time.time()
            trace["total_duration"] = trace["end_time"] - trace["start_time"]
            self.traces.append(trace)

            # Print the trace report.
            self._print_trace_report(trace)

    async def _execute_node(self, node, context, exec_id):
        """Run one node with the base implementation and record its span."""
        trace = self._active_traces.get(id(context))
        if trace is None:
            # Not inside execute() (e.g. driven by another runner): nothing to record.
            return await super()._execute_node(node, context, exec_id)

        span = {"node": node.name, "status": "failed"}
        span_start = time.time()
        try:
            result = await super()._execute_node(node, context, exec_id)
        except BaseException as e:
            span["error"] = str(e)
            raise
        else:
            span["status"] = "success"
            return result
        finally:
            span["duration"] = time.time() - span_start
            trace["spans"].append(span)

    def _print_trace_report(self, trace: dict):
        """Print the trace header and one bar per span, slowest first."""
        print(f"\n{'='*60}")
        print(f"Trace: {trace['trace_id']}")
        print(f"Workflow: {trace['workflow']} | Status: {trace['status']}")
        print(f"Duration: {trace['total_duration']:.2f}s")
        print(f"{'='*60}")

        total = trace["total_duration"]
        spans = sorted(trace["spans"], key=lambda s: s["duration"], reverse=True)
        for span in spans:
            # Guard against a zero total (very fast runs on a coarse clock).
            share = span["duration"] / total if total > 0 else 0.0
            bar = "█" * int(share * 30)
            status_icon = "✓" if span["status"] == "success" else "✗"
            print(f"  {status_icon} {span['node']:<20s} {span['duration']:.1f}s {bar}")

六、总结

工作流引擎是 Agent 从"能跑"到"可靠"的关键:

DAG 编排    → 声明式定义任务依赖,自动并行
动态路由    → LLM 决策流程分支,灵活应对变化
重试机制    → 自动处理瞬时故障
超时控制    → 防止节点无限等待
条件执行    → 跳过不需要的分支
状态持久化  → 从失败节点恢复,不重做已完成工作(本文的示例引擎尚未实现这一能力,可基于 executions 执行记录自行扩展)
可观测性    → 追踪每个节点的耗时和状态

下一篇——系列最终篇:多 Agent 协作——任务分解、通信协议与并行编排。将一个复杂任务自动拆解给多个 Agent 并行执行。


本文完整代码已开源。下一篇:多 Agent 协作(最终篇,即将发布)

Logo

Agent 垂直技术社区,欢迎活跃、内容共建。

更多推荐