一、概述

本文档深入分析 AI客服系统中 RAG(Retrieval-Augmented Generation)检索系统的实现,基于 /backend/service/rag/ 目录源码展开讲解。

二、核心架构

2.1 系统组件

┌────────────────────┐
│   用户输入 Query   │
└─────────┬──────────┘
          │
┌─────────▼──────────┐
│  Embedding Provider│ 向量化
└─────────┬──────────┘
          │
┌─────────▼──────────┐
│ Vector Store       │ 相似度检索
│ (Milvus/本地)      │
└─────────┬──────────┘
          │
┌─────────▼──────────┐
│ Filter & Rerank    │ 结果优化
└─────────┬──────────┘
          │
┌─────────▼──────────┐
│ Return SearchResult│ 返回结果
└────────────────────┘

2.2 核心数据结构

检索服务结构体: (retrieval.go:13-21)

type RetrievalService struct {
    vectorStoreService *VectorStoreService
    embeddingProvider  embedding.EmbeddingProvider
    docRepo            *repository.DocumentRepository
    kbRepo             *repository.KnowledgeBaseRepository
    cache              *Cache
    reranker           *SimpleReranker
    metrics            *Metrics
}

搜索结果结构体: (types.go)

type SearchResult struct {
    ID           string  `json:"id"`
    DocumentID   string  `json:"document_id"`
    Content      string  `json:"content"`
    Score        float64 `json:"score"`
    Metadata     map[string]interface{} `json:"metadata"`
}

三、核心实现详解

3.1 检索主流程

Retrieve 方法实现: (retrieval.go:42-112)

func (s *RetrievalService) Retrieve(ctx context.Context, query string, topK int, knowledgeBaseID *uint) ([]SearchResult, error) {
    startTime := time.Now()
    
    // 1. 向量库可用性检查
    if !s.vectorStoreService.IsAvailable() {
        s.metrics.RecordQuery(true, time.Since(startTime), false)
        return []SearchResult{}, nil
    }
    
    // 2. 缓存检查
    var results []SearchResult
    cacheHit := false
    if s.cache != nil {
        if cached, ok := s.cache.Get(query, topK, knowledgeBaseID); ok {
            results = cached
            cacheHit = true
        }
    }
    
    // 3. 缓存未命中,执行检索
    if !cacheHit {
        // 3.1 获取向量模型
        svc, err := s.embeddingProvider.Get(ctx)
        if err != nil {
            return nil, fmt.Errorf("获取嵌入服务失败: %w", err)
        }
        
        // 3.2 查询向量化
        queryVectors, err := svc.EmbedTexts(ctx, []string{query})
        if err != nil {
            return nil, fmt.Errorf("查询向量化失败: %w", err)
        }
        
        // 3.3 向量检索(多取一些用于后续过滤)
        searchLimit := topK * 3
        if searchLimit < 10 {
            searchLimit = 10
        }
        results, err = s.vectorStoreService.SearchVectors(ctx, queryVectors[0], searchLimit, knowledgeBaseID)
        if err != nil {
            return nil, fmt.Errorf("向量检索失败: %w", err)
        }
        
        // 3.4 过滤未发布文档
        results = s.filterByPublished(ctx, results, topK)
        
        // 3.5 更新缓存
        if s.cache != nil {
            s.cache.Set(query, topK, knowledgeBaseID, results)
        }
    }
    
    s.metrics.RecordQuery(err == nil, time.Since(startTime), cacheHit)
    return results, err
}

关键点说明:

  • 1️⃣ 向量数据库可用性检查: 避免调用失败导致整个链路崩溃

  • 2️⃣ 缓存策略: 相同查询直接返回缓存结果

  • 3️⃣ 多取结果: 搜索 limit 设置为 topK * 3,给过滤留余地

3.2 结果过滤逻辑

文档发布状态过滤: (retrieval.go:135-209)

func (s *RetrievalService) filterByPublished(ctx context.Context, results []SearchResult, topK int) []SearchResult {
    // 1. 提取所有文档ID
    docIDs := make([]uint, 0, len(results))
    seen := make(map[uint]struct{})
    for _, r := range results {
        id, err := strconv.ParseUint(r.DocumentID, 10, 32)
        if err != nil {
            continue
        }
        uid := uint(id)
        if _, ok := seen[uid]; !ok {
            seen[uid] = struct{}{}
            docIDs = append(docIDs, uid)
        }
    }
    
    // 2. 查询文档状态
    docs, err := s.docRepo.GetByIDs(docIDs)
    if err != nil {
        return results
    }
    
    // 3. 标记未发布文档
    unpublished := make(map[uint]struct{})
    docIDToKBID := make(map[uint]uint)
    for _, d := range docs {
        if d.Status != "published" {
            unpublished[d.ID] = struct{}{}
        }
        docIDToKBID[d.ID] = d.KnowledgeBaseID
    }
    
    // 4. 标记 RAG 未启用的知识库
    disabledKBIDs := make(map[uint]struct{})
    if s.kbRepo != nil && len(docIDToKBID) > 0 {
        // 查询知识库状态
        kbIDs := make([]uint, 0, len(docIDToKBID))
        for _, kbID := range docIDToKBID {
            kbIDs = append(kbIDs, kbID)
        }
        if kbs, err := s.kbRepo.GetByIDs(kbIDs); err == nil {
            for _, kb := range kbs {
                if !kb.RAGEnabled {
                    disabledKBIDs[kb.ID] = struct{}{}
                }
            }
        }
    }
    
    // 5. 过滤结果
    filtered := make([]SearchResult, 0, len(results))
    for _, r := range results {
        id, err := strconv.ParseUint(r.DocumentID, 10, 32)
        if err != nil {
            // 非文档类型直接保留
            filtered = append(filtered, r)
            continue
        }
        uid := uint(id)
        
        // 过滤未发布文档
        if _, ok := unpublished[uid]; ok {
            continue
        }
        
        // 过滤 RAG 未启用的知识库文档
        if kbID, inDoc := docIDToKBID[uid]; inDoc {
            if _, disabled := disabledKBIDs[kbID]; disabled {
                continue
            }
        }
        
        filtered = append(filtered, r)
        if len(filtered) >= topK {
            break
        }
    }
    
    return filtered
}

设计亮点:

  1. 去重处理: 同一文档多个片段只保留一次

  2. 两级过滤: 文档发布状态 + 知识库 RAG 启用状态

  3. 非文档兼容: FAQ 等没有文档 ID 的结果直接保留

3.3 重排序增强

带重排序的检索: (retrieval.go:114-133)

func (s *RetrievalService) RetrieveWithRerank(ctx context.Context, query string, topK int, knowledgeBaseID *uint) ([]SearchResult, error) {
    // 1. 先执行基础检索
    results, err := s.Retrieve(ctx, query, topK, knowledgeBaseID)
    if err != nil {
        return nil, err
    }
    
    // 2. 重排序
    if s.reranker != nil {
        reranked, err := s.reranker.Rerank(ctx, query, results)
        if err != nil {
            // 重排序失败不影响主流程
            return results, nil
        }
        return reranked, nil
    }
    
    return results, nil
}

3.4 缓存实现

缓存结构与操作: (cache.go)

type Cache struct {
    data map[string]cacheItem
    mu   sync.RWMutex
    ttl  time.Duration
}

type cacheItem struct {
    results []SearchResult
    expires time.Time
}

func (c *Cache) Get(query string, topK int, kbID *uint) ([]SearchResult, bool) {
    key := c.buildKey(query, topK, kbID)
    
    c.mu.RLock()
    defer c.mu.RUnlock()
    
    item, ok := c.data[key]
    if !ok {
        return nil, false
    }
    
    if time.Now().After(item.expires) {
        delete(c.data, key)
        return nil, false
    }
    
    return item.results, true
}

func (c *Cache) Set(query string, topK int, kbID *uint, results []SearchResult) {
    key := c.buildKey(query, topK, kbID)
    
    c.mu.Lock()
    defer c.mu.Unlock()
    
    c.data[key] = cacheItem{
        results: results,
        expires: time.Now().Add(c.ttl),
    }
}

四、向量模型适配

4.1 嵌入提供者接口

EmbeddingProvider 接口: (embedding/interface.go)

type EmbeddingProvider interface {
    Get(ctx context.Context) (TextEmbedder, error)
}

type TextEmbedder interface {
    EmbedTexts(ctx context.Context, texts []string) ([][]float32, error)
}

4.2 OpenAI 嵌入实现

OpenAI 适配: (embedding/openai.go)

type OpenAIEmbedder struct {
    client *openai.Client
    model  string
}

func (e *OpenAIEmbedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) {
    req := openai.EmbeddingRequest{
        Input: texts,
        Model: e.model,
    }
    
    resp, err := e.client.CreateEmbeddings(ctx, req)
    if err != nil {
        return nil, err
    }
    
    embeddings := make([][]float32, len(resp.Data))
    for i, d := range resp.Data {
        embeddings[i] = d.Embedding
    }
    
    return embeddings, nil
}

4.3 BGE 本地嵌入实现

本地 BGE 模型适配: (embedding/bge.go)

type BGEEmbedder struct {
    modelPath string
    // 本地模型加载相关
}

func (e *BGEEmbedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) {
    // 调用本地推理服务
    // ...
}

五、向量存储服务

5.1 Milvus 集成

向量搜索实现: (infra/milvus.go)

func (ms *MilvusStore) SearchVectors(ctx context.Context, queryVector []float32, topK int, knowledgeBaseID *string) ([]SearchResult, error) {
    // 1. 构建搜索参数
    searchParams, err := milvus.NewSearchParamBuilder().
        WithVectorField("embedding").
        WithTopK(topK).
        WithMetricType("IP").
        WithParams(...)
        Build()
    
    // 2. 执行搜索
    results, err := ms.client.Search(ctx, ms.collectionName, nil, "", []string{"content", "document_id"}, []milvus.Vector{queryVector}, searchParams)
    
    // 3. 解析结果
    return parseSearchResults(results)
}

5.2 健康检查

服务可用性检测: (rag/health.go)

func (vs *VectorStoreService) IsAvailable() bool {
    if vs.milvusClient == nil {
        return false
    }
    
    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
    defer cancel()
    
    // 简单检查:列出集合
    collections, err := vs.milvusClient.ListCollections(ctx)
    return err == nil && collections != nil
}

六、性能指标监控

指标记录: (rag/metrics.go)

type Metrics struct {
    totalQueries   int64
    successQueries int64
    cacheHits      int64
    latencySum     time.Duration
    mu             sync.Mutex
}

func (m *Metrics) RecordQuery(success bool, latency time.Duration, cacheHit bool) {
    m.mu.Lock()
    defer m.mu.Unlock()
    
    m.totalQueries++
    if success {
        m.successQueries++
    }
    if cacheHit {
        m.cacheHits++
    }
    m.latencySum += latency
}

func (m *Metrics) GetStats() map[string]interface{} {
    m.mu.Lock()
    defer m.mu.Unlock()
    
    avgLatency := float64(0)
    if m.totalQueries > 0 {
        avgLatency = float64(m.latencySum.Milliseconds()) / float64(m.totalQueries)
    }
    
    return map[string]interface{}{
        "total_queries":   m.totalQueries,
        "success_rate":    float64(m.successQueries) / float64(m.totalQueries),
        "cache_hit_rate":  float64(m.cacheHits) / float64(m.totalQueries),
        "avg_latency_ms":  avgLatency,
    }
}

七、实战应用

7.1 在 AI 服务中集成

RAG 上下文获取: (ai_service.go:577-610)

func (s *AIService) retrieveRAGContext(ctx context.Context, query string, conversation *models.Conversation) (string, error) {
    var knowledgeBaseID *uint
    
    // 执行检索(使用重排序)
    topK := 5
    results, err := s.retrievalService.RetrieveWithRerank(ctx, query, topK, knowledgeBaseID)
    if err != nil {
        return "", fmt.Errorf("RAG 检索失败: %w", err)
    }
    
    if len(results) == 0 {
        return "", nil
    }
    
    // 格式化为字符串上下文
    var contextParts []string
    for i, result := range results {
        contextParts = append(contextParts, fmt.Sprintf("文档片段 %d:\n%s", i+1, result.Content))
    }
    
    return strings.Join(contextParts, "\n\n"), nil
}

7.2 降级策略

向量库不可用时的备选方案: (ai_service.go:612-629)

func (s *AIService) loadKBDocumentsDirect() (string, error) {
    if s.retrievalService == nil {
        return "", nil
    }
    
    // 直接从数据库加载已发布文档
    docs, err := s.retrievalService.LoadAllPublishedDocuments(20)
    if err != nil {
        log.Printf("⚠️ 直接加载知识库文档失败: %v", err)
        return "", err
    }
    
    if len(docs) == 0 {
        return "", nil
    }
    
    var parts []string
    for _, d := range docs {
        parts = append(parts, fmt.Sprintf("【%s】\n%s", d.Title, d.Content))
    }
    
    return strings.Join(parts, "\n\n---\n\n"), nil
}

八、总结

本系统 RAG 实现具有以下特点:

  1. 高可用: 向量库健康检查 + 降级方案

  2. 高性能: 多层缓存 + 指标监控 + 结果过滤优化

  3. 可扩展: 多种嵌入模型支持 + 向量存储抽象

  4. 精准性: 结果重排序 + 文档发布状态过滤

  5. 可观测: 完整的 metrics 指标体系

Logo

Agent 垂直技术社区,欢迎活跃、内容共建。

更多推荐