AI客服系统 - RAG知识库检索系统实现
一、概述
本文档深入分析 AI客服系统中 RAG(Retrieval-Augmented Generation)检索系统的实现,基于 /backend/service/rag/ 目录源码展开讲解。
二、核心架构
2.1 系统组件
┌────────────────────┐
│ 用户输入 Query │
└─────────┬──────────┘
│
┌─────────▼──────────┐
│ Embedding Provider│ 向量化
└─────────┬──────────┘
│
┌─────────▼──────────┐
│ Vector Store │ 相似度检索
│ (Milvus/本地) │
└─────────┬──────────┘
│
┌─────────▼──────────┐
│ Filter & Rerank │ 结果优化
└─────────┬──────────┘
│
┌─────────▼──────────┐
│ Return SearchResult│ 返回结果
└────────────────────┘
2.2 核心数据结构
检索服务结构体: (retrieval.go:13-21)
type RetrievalService struct {
vectorStoreService *VectorStoreService
embeddingProvider embedding.EmbeddingProvider
docRepo *repository.DocumentRepository
kbRepo *repository.KnowledgeBaseRepository
cache *Cache
reranker *SimpleReranker
metrics *Metrics
}
搜索结果结构体: (types.go)
type SearchResult struct {
ID string `json:"id"`
DocumentID string `json:"document_id"`
Content string `json:"content"`
Score float64 `json:"score"`
Metadata map[string]interface{} `json:"metadata"`
}
三、核心实现详解
3.1 检索主流程
Retrieve 方法实现: (retrieval.go:42-112)
func (s *RetrievalService) Retrieve(ctx context.Context, query string, topK int, knowledgeBaseID *uint) ([]SearchResult, error) {
startTime := time.Now()
// 1. 向量库可用性检查
if !s.vectorStoreService.IsAvailable() {
s.metrics.RecordQuery(true, time.Since(startTime), false)
return []SearchResult{}, nil
}
// 2. 缓存检查
var results []SearchResult
cacheHit := false
if s.cache != nil {
if cached, ok := s.cache.Get(query, topK, knowledgeBaseID); ok {
results = cached
cacheHit = true
}
}
// 3. 缓存未命中,执行检索
if !cacheHit {
// 3.1 获取向量模型
svc, err := s.embeddingProvider.Get(ctx)
if err != nil {
return nil, fmt.Errorf("获取嵌入服务失败: %w", err)
}
// 3.2 查询向量化
queryVectors, err := svc.EmbedTexts(ctx, []string{query})
if err != nil {
return nil, fmt.Errorf("查询向量化失败: %w", err)
}
// 3.3 向量检索(多取一些用于后续过滤)
searchLimit := topK * 3
if searchLimit < 10 {
searchLimit = 10
}
results, err = s.vectorStoreService.SearchVectors(ctx, queryVectors[0], searchLimit, knowledgeBaseID)
if err != nil {
return nil, fmt.Errorf("向量检索失败: %w", err)
}
// 3.4 过滤未发布文档
results = s.filterByPublished(ctx, results, topK)
// 3.5 更新缓存
if s.cache != nil {
s.cache.Set(query, topK, knowledgeBaseID, results)
}
}
s.metrics.RecordQuery(err == nil, time.Since(startTime), cacheHit)
return results, err
}
关键点说明:
-
1️⃣ 向量数据库可用性检查: 避免调用失败导致整个链路崩溃
-
2️⃣ 缓存策略: 相同查询直接返回缓存结果
-
3️⃣ 多取结果: 搜索 limit 设置为 topK * 3,给过滤留余地
3.2 结果过滤逻辑
文档发布状态过滤: (retrieval.go:135-209)
func (s *RetrievalService) filterByPublished(ctx context.Context, results []SearchResult, topK int) []SearchResult {
// 1. 提取所有文档ID
docIDs := make([]uint, 0, len(results))
seen := make(map[uint]struct{})
for _, r := range results {
id, err := strconv.ParseUint(r.DocumentID, 10, 32)
if err != nil {
continue
}
uid := uint(id)
if _, ok := seen[uid]; !ok {
seen[uid] = struct{}{}
docIDs = append(docIDs, uid)
}
}
// 2. 查询文档状态
docs, err := s.docRepo.GetByIDs(docIDs)
if err != nil {
return results
}
// 3. 标记未发布文档
unpublished := make(map[uint]struct{})
docIDToKBID := make(map[uint]uint)
for _, d := range docs {
if d.Status != "published" {
unpublished[d.ID] = struct{}{}
}
docIDToKBID[d.ID] = d.KnowledgeBaseID
}
// 4. 标记 RAG 未启用的知识库
disabledKBIDs := make(map[uint]struct{})
if s.kbRepo != nil && len(docIDToKBID) > 0 {
// 查询知识库状态
kbIDs := make([]uint, 0, len(docIDToKBID))
for _, kbID := range docIDToKBID {
kbIDs = append(kbIDs, kbID)
}
if kbs, err := s.kbRepo.GetByIDs(kbIDs); err == nil {
for _, kb := range kbs {
if !kb.RAGEnabled {
disabledKBIDs[kb.ID] = struct{}{}
}
}
}
}
// 5. 过滤结果
filtered := make([]SearchResult, 0, len(results))
for _, r := range results {
id, err := strconv.ParseUint(r.DocumentID, 10, 32)
if err != nil {
// 非文档类型直接保留
filtered = append(filtered, r)
continue
}
uid := uint(id)
// 过滤未发布文档
if _, ok := unpublished[uid]; ok {
continue
}
// 过滤 RAG 未启用的知识库文档
if kbID, inDoc := docIDToKBID[uid]; inDoc {
if _, disabled := disabledKBIDs[kbID]; disabled {
continue
}
}
filtered = append(filtered, r)
if len(filtered) >= topK {
break
}
}
return filtered
}
设计亮点:
-
去重处理: 同一文档多个片段只保留一次
-
两级过滤: 文档发布状态 + 知识库 RAG 启用状态
-
非文档兼容: FAQ 等没有文档 ID 的结果直接保留
3.3 重排序增强
带重排序的检索: (retrieval.go:114-133)
func (s *RetrievalService) RetrieveWithRerank(ctx context.Context, query string, topK int, knowledgeBaseID *uint) ([]SearchResult, error) {
// 1. 先执行基础检索
results, err := s.Retrieve(ctx, query, topK, knowledgeBaseID)
if err != nil {
return nil, err
}
// 2. 重排序
if s.reranker != nil {
reranked, err := s.reranker.Rerank(ctx, query, results)
if err != nil {
// 重排序失败不影响主流程
return results, nil
}
return reranked, nil
}
return results, nil
}
3.4 缓存实现
缓存结构与操作: (cache.go)
type Cache struct {
data map[string]cacheItem
mu sync.RWMutex
ttl time.Duration
}
type cacheItem struct {
results []SearchResult
expires time.Time
}
func (c *Cache) Get(query string, topK int, kbID *uint) ([]SearchResult, bool) {
key := c.buildKey(query, topK, kbID)
c.mu.RLock()
defer c.mu.RUnlock()
item, ok := c.data[key]
if !ok {
return nil, false
}
if time.Now().After(item.expires) {
delete(c.data, key)
return nil, false
}
return item.results, true
}
func (c *Cache) Set(query string, topK int, kbID *uint, results []SearchResult) {
key := c.buildKey(query, topK, kbID)
c.mu.Lock()
defer c.mu.Unlock()
c.data[key] = cacheItem{
results: results,
expires: time.Now().Add(c.ttl),
}
}
四、向量模型适配
4.1 嵌入提供者接口
EmbeddingProvider 接口: (embedding/interface.go)
type EmbeddingProvider interface {
Get(ctx context.Context) (TextEmbedder, error)
}
type TextEmbedder interface {
EmbedTexts(ctx context.Context, texts []string) ([][]float32, error)
}
4.2 OpenAI 嵌入实现
OpenAI 适配: (embedding/openai.go)
type OpenAIEmbedder struct {
client *openai.Client
model string
}
func (e *OpenAIEmbedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) {
req := openai.EmbeddingRequest{
Input: texts,
Model: e.model,
}
resp, err := e.client.CreateEmbeddings(ctx, req)
if err != nil {
return nil, err
}
embeddings := make([][]float32, len(resp.Data))
for i, d := range resp.Data {
embeddings[i] = d.Embedding
}
return embeddings, nil
}
4.3 BGE 本地嵌入实现
本地 BGE 模型适配: (embedding/bge.go)
type BGEEmbedder struct {
modelPath string
// 本地模型加载相关
}
func (e *BGEEmbedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) {
// 调用本地推理服务
// ...
}
五、向量存储服务
5.1 Milvus 集成
向量搜索实现: (infra/milvus.go)
func (ms *MilvusStore) SearchVectors(ctx context.Context, queryVector []float32, topK int, knowledgeBaseID *string) ([]SearchResult, error) {
// 1. 构建搜索参数
searchParams, err := milvus.NewSearchParamBuilder().
WithVectorField("embedding").
WithTopK(topK).
WithMetricType("IP").
WithParams(...)
Build()
// 2. 执行搜索
results, err := ms.client.Search(ctx, ms.collectionName, nil, "", []string{"content", "document_id"}, []milvus.Vector{queryVector}, searchParams)
// 3. 解析结果
return parseSearchResults(results)
}
5.2 健康检查
服务可用性检测: (rag/health.go)
func (vs *VectorStoreService) IsAvailable() bool {
if vs.milvusClient == nil {
return false
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
// 简单检查:列出集合
collections, err := vs.milvusClient.ListCollections(ctx)
return err == nil && collections != nil
}
六、性能指标监控
指标记录: (rag/metrics.go)
type Metrics struct {
totalQueries int64
successQueries int64
cacheHits int64
latencySum time.Duration
mu sync.Mutex
}
func (m *Metrics) RecordQuery(success bool, latency time.Duration, cacheHit bool) {
m.mu.Lock()
defer m.mu.Unlock()
m.totalQueries++
if success {
m.successQueries++
}
if cacheHit {
m.cacheHits++
}
m.latencySum += latency
}
func (m *Metrics) GetStats() map[string]interface{} {
m.mu.Lock()
defer m.mu.Unlock()
avgLatency := float64(0)
if m.totalQueries > 0 {
avgLatency = float64(m.latencySum.Milliseconds()) / float64(m.totalQueries)
}
return map[string]interface{}{
"total_queries": m.totalQueries,
"success_rate": float64(m.successQueries) / float64(m.totalQueries),
"cache_hit_rate": float64(m.cacheHits) / float64(m.totalQueries),
"avg_latency_ms": avgLatency,
}
}
七、实战应用
7.1 在 AI 服务中集成
RAG 上下文获取: (ai_service.go:577-610)
func (s *AIService) retrieveRAGContext(ctx context.Context, query string, conversation *models.Conversation) (string, error) {
var knowledgeBaseID *uint
// 执行检索(使用重排序)
topK := 5
results, err := s.retrievalService.RetrieveWithRerank(ctx, query, topK, knowledgeBaseID)
if err != nil {
return "", fmt.Errorf("RAG 检索失败: %w", err)
}
if len(results) == 0 {
return "", nil
}
// 格式化为字符串上下文
var contextParts []string
for i, result := range results {
contextParts = append(contextParts, fmt.Sprintf("文档片段 %d:\n%s", i+1, result.Content))
}
return strings.Join(contextParts, "\n\n"), nil
}
7.2 降级策略
向量库不可用时的备选方案: (ai_service.go:612-629)
func (s *AIService) loadKBDocumentsDirect() (string, error) {
if s.retrievalService == nil {
return "", nil
}
// 直接从数据库加载已发布文档
docs, err := s.retrievalService.LoadAllPublishedDocuments(20)
if err != nil {
log.Printf("⚠️ 直接加载知识库文档失败: %v", err)
return "", err
}
if len(docs) == 0 {
return "", nil
}
var parts []string
for _, d := range docs {
parts = append(parts, fmt.Sprintf("【%s】\n%s", d.Title, d.Content))
}
return strings.Join(parts, "\n\n---\n\n"), nil
}
八、总结
本系统 RAG 实现具有以下特点:
-
高可用: 向量库健康检查 + 降级方案
-
高性能: 多层缓存 + 指标监控 + 结果过滤优化
-
可扩展: 多种嵌入模型支持 + 向量存储抽象
-
精准性: 结果重排序 + 文档发布状态过滤
-
可观测: 完整的 metrics 指标体系
更多推荐

所有评论(0)