import os

import numpy as np
from openai import OpenAI
import torch
import torch.nn as nn
import torchtext

torchtext.disable_torchtext_deprecation_warning()
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import GloVe


# Provider configuration, keyed by the model name that is also sent to the API.
# Each API key is read from an environment variable so secrets do not have to
# be committed to source; when the variable is unset the key falls back to ""
# (replace that literal only for throwaway local experiments).
models = {
    "moonshot-v1-32k": {
        "api_key": os.environ.get("MOONSHOT_API_KEY", ""),
        "base_url": "https://api.moonshot.cn/v1",
    },
    "yi-large-turbo": {
        "api_key": os.environ.get("YI_API_KEY", ""),
        "base_url": "https://api.lingyiwanwu.com/v1",
    },
    "generalv3.5": {
        "api_key": os.environ.get("SPARK_API_KEY", ""),
        "base_url": "https://spark-api-open.xf-yun.com/v1",
    },
    # Both GLM models share one endpoint and therefore one key variable.
    "glm-4-flash": {
        "api_key": os.environ.get("ZHIPUAI_API_KEY", ""),
        "base_url": "https://open.bigmodel.cn/api/paas/v4/",
    },
    "glm-4-plus": {
        "api_key": os.environ.get("ZHIPUAI_API_KEY", ""),
        "base_url": "https://open.bigmodel.cn/api/paas/v4/",
    },
    "Baichuan2-Turbo": {
        "api_key": os.environ.get("BAICHUAN_API_KEY", ""),
        "base_url": "https://api.baichuan-ai.com/v1/",
    },
    "qwen-max": {
        "api_key": os.environ.get("DASHSCOPE_API_KEY", ""),
        "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
    },
}


def get_client(model_name):
    """Build an OpenAI-compatible client for one configured model.

    Args:
        model_name: a key of the module-level ``models`` dict.

    Returns:
        OpenAI: a client using that entry's ``api_key`` and ``base_url``.

    Raises:
        ValueError: if ``model_name`` is not a key of ``models``.
    """
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    try:
        model = models[model_name]
    except KeyError:
        raise ValueError(
            f"model name {model_name!r} not found in models; "
            f"available: {', '.join(models)}"
        ) from None
    return OpenAI(
        api_key=model["api_key"],
        base_url=model["base_url"],
    )


# Load the GloVe "6B" vectors (300 dimensions) once, at import time, so every
# word lookup below reuses the same table.
glove = GloVe(dim=300, name="6B")
# Frozen lookup layer over the GloVe matrix: row i is the vector of the word
# whose index in glove.stoi is i.
embedding = nn.Embedding.from_pretrained(embeddings=glove.vectors, freeze=True)

# Does not depend on a KnowledgeBase instance, so it lives at module level
# rather than as a method.
def get_word_vector(word):
    """Return the GloVe vector of a single word as an array of shape (1, dim).

    Args:
        word: one token, looked up as-is in ``glove.stoi``.

    Returns:
        np.ndarray: shape (1, embedding.embedding_dim); all zeros when the
        word is not in the GloVe vocabulary.
    """
    index = glove.stoi.get(word)
    if index is None:
        return np.zeros((1, embedding.embedding_dim))
    # A 1-D index tensor yields a (1, dim) result, the same shape as the
    # out-of-vocabulary branch (a [[index]] tensor would give (1, 1, dim)).
    return embedding(torch.tensor([index])).detach().numpy()


class KnowledgeBase:
    """A small in-memory retrieval index over fixed-size chunks of a text file.

    Each chunk is embedded as the mean GloVe vector of its in-vocabulary
    tokens; ``search`` returns the chunk whose embedding has the highest
    cosine similarity with the query's embedding.
    """

    # Width of the vectors produced by encode_text; must match the `dim`
    # passed to GloVe at module level.
    EMBED_DIM = 300

    def __init__(self, file_path):
        """Load ``file_path``, split it into chunks and embed every chunk."""
        self.docs = self.load_docs(file_path)
        self.embeds = self.encode_docs(self.docs)

    @staticmethod
    def load_docs(file_path, chunk_size=150):
        """Read a UTF-8 text file and split it into consecutive chunks.

        Args:
            file_path: path of the knowledge file.
            chunk_size: characters per chunk; the last chunk may be shorter.
                Chunks are cut on character boundaries, not word boundaries.

        Returns:
            list[str]: chunks in file order; empty for an empty file.

        Raises:
            ValueError: if ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        with open(file_path, "r", encoding="utf-8") as file:
            content = file.read()
        return [content[i : i + chunk_size] for i in range(0, len(content), chunk_size)]

    @staticmethod
    def tokenize(text):
        """Split ``text`` into tokens with torchtext's "basic_english" tokenizer."""
        tokenizer = get_tokenizer("basic_english")
        return tokenizer(text)

    def encode_docs(self, texts):
        """Embed every text; returns an array of shape (len(texts), EMBED_DIM)."""
        if not texts:
            # np.concatenate rejects an empty list; an empty index is valid.
            return np.zeros((0, self.EMBED_DIM))
        return np.concatenate([self.encode_text(text) for text in texts], axis=0)

    def encode_text(self, text):
        """Embed ``text`` as the mean GloVe vector of its in-vocabulary tokens.

        Returns an array of shape (1, EMBED_DIM); all zeros when none of the
        tokens is in the GloVe vocabulary.
        """
        words = self.tokenize(text)
        word_vectors = [get_word_vector(word) for word in words if word in glove.stoi]
        if not word_vectors:
            return np.zeros((1, self.EMBED_DIM))
        # vstack + mean over axis 0 yields one value per dimension whether
        # each word vector is shaped (1, D) or (1, 1, D).
        return np.mean(np.vstack(word_vectors), axis=0).reshape(1, -1)

    @staticmethod
    def cosine_similarity(e1, e2):
        """Return the cosine similarity of two vectors.

        Returns 0.0 when either vector has zero norm. A 0/0 division would
        give NaN, and np.argmax treats NaN as the maximum, so an empty chunk
        would otherwise beat every real match.
        """
        denom = np.linalg.norm(e1) * np.linalg.norm(e2)
        if denom == 0:
            return 0.0
        return np.dot(e1, e2) / denom

    def search(self, text):
        """Return the stored chunk most similar to ``text``.

        Returns "" when the knowledge base holds no chunks. On ties, including
        the case where every similarity is 0 (e.g. the query has no
        in-vocabulary tokens), the earliest chunk is returned.
        """
        if not self.docs:
            return ""
        # Flatten (1, D) -> (D,) so each similarity is a scalar, not a (1,) array.
        query_vector = self.encode_text(text).ravel()
        similarities = [self.cosine_similarity(query_vector, doc_vector) for doc_vector in self.embeds]
        return self.docs[int(np.argmax(similarities))]


class RagModel:
    """Retrieval-augmented chat: prepend the best-matching knowledge chunk to each question."""

    def __init__(self, model_name, kb: "KnowledgeBase", temperature=1):
        """
        Args:
            model_name: key into the module-level ``models`` dict; also sent as
                the ``model`` field of each chat-completion request.
            kb: knowledge base searched for context on every message.
            temperature: sampling temperature passed to the chat-completion call.
        """
        self.model_name = model_name
        self.kb = kb
        self.temperature = temperature
        # Chinese prompt, roughly "Based on: <context>\nAnswer: <question>".
        self.prompt_template = "基于:%s\n回答:%s"
        # Created on the first chat() call and reused afterwards, so each
        # message does not build a fresh HTTP client. The client stays bound to
        # the model_name that was current at that first call.
        self._client = None

    def chat(self, message):
        """Answer ``message`` using the most relevant knowledge-base chunk as context.

        Prints the assembled prompt, sends it as a single user message, and
        returns the content of the first choice in the response.
        """
        context = self.kb.search(message)
        query = self.prompt_template % (context, message)
        print("query:", query)
        if self._client is None:
            self._client = get_client(self.model_name)
        completion = self._client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": query}],
            temperature=self.temperature,
        )
        return completion.choices[0].message.content


# Interactive entry point. Guarded so that importing this module does not
# build the knowledge base or block on input().
if __name__ == "__main__":
    # Knowledge base built from knowledge.txt (resolved against the current
    # working directory).
    kb = KnowledgeBase("knowledge.txt")

    # RAG model backed by the "yi-large-turbo" entry of models.
    rag_model = RagModel("yi-large-turbo", kb)

    # Chat until the user types "exit" (any case), sends EOF, or presses Ctrl-C.
    while True:
        try:
            user_input = input("human >>> ")
        except (EOFError, KeyboardInterrupt):
            # End the session cleanly instead of printing a traceback.
            print()
            break
        if user_input.strip().lower() == "exit":
            break
        if not user_input.strip():
            # Nothing was asked; skip the API call.
            continue
        bot_response = rag_model.chat(user_input)
        print("bot >>>", bot_response)

# Logo
#
# Agent 垂直技术社区,欢迎活跃、内容共建。
#
# 更多推荐