【人工智能】langchain + qwen2.5 使用示例
将本地知识库向量化并保存,然后通过 LangChain 构建检索问答(QA)流程。
·
1. qwen2.5 + langchain + text2vec-large-chinese
将本地知识库向量化并保存,然后通过 LangChain 构建检索问答(QA)流程。
from transformers import AutoModelForCausalLM, AutoTokenizer
from abc import ABC
from langchain.llms.base import LLM
from typing import Any, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
import os
import re
import torch
import numpy as np
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
# Local directory holding the Qwen2.5-7B-Instruct checkpoint.
model_name = "/home/sky/model_data/Qwen/Qwen2.5-7B-Instruct"

# Tokenizer and model are both loaded from the same checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# dtype and device placement are left to the library's "auto" settings.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
class Qwen(LLM, ABC):
    """LangChain ``LLM`` wrapper around the module-level Qwen chat model.

    Every call wraps the prompt in a fixed system message plus one user
    message, renders it with the tokenizer's chat template and generates a
    completion with the module-level ``model`` / ``tokenizer`` objects.
    """

    # NOTE(review): max_token, temperature and top_p are only reported through
    # _identifying_params; _call does not pass them to model.generate(), which
    # runs with a fixed max_new_tokens=512. Confirm whether they should be wired in.
    max_token: int = 10000
    temperature: float = 0.01
    # Annotated like its siblings: Pydantic v2 refuses non-annotated class
    # attributes on a model, and the annotation is harmless on Pydantic v1.
    top_p: float = 0.9
    history_len: int = 3

    @property
    def _llm_type(self) -> str:
        """Return the identifier LangChain uses for this LLM type."""
        return "Qwen"

    @property
    def _history_len(self) -> int:
        """Return the configured history length."""
        return self.history_len

    def set_history_len(self, history_len: int = 10) -> None:
        """Set the configured history length (stored only; ``_call`` does not read it)."""
        self.history_len = history_len

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Text sent as the single user message.
            stop: Optional stop sequences. The returned text is cut just
                before the earliest occurrence of any of them.
            run_manager: Accepted for interface compatibility; unused.
            **kwargs: Extra call-time arguments forwarded by LangChain;
                accepted so the call does not fail, otherwise ignored.

        Returns:
            The decoded completion with special tokens removed.
        """
        messages = [
            {"role": "system", "content": "你是小智,由智子引擎创建的AI智能助手."},
            {"role": "user", "content": prompt},
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=512,
        )
        # Drop the leading prompt tokens from each output sequence so that
        # only the newly generated tokens are decoded.
        completion_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        response = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)[0]

        # Honour the stop sequences requested by the caller instead of
        # silently ignoring them.
        if stop:
            cut_points = [
                index
                for index in (response.find(token) for token in stop if token)
                if index != -1
            ]
            if cut_points:
                response = response[:min(cut_points)]
        return response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "max_token": self.max_token,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "history_len": self.history_len,
        }
class ChineseTextSplitter(CharacterTextSplitter):
    """Text splitter that cuts text into sentences at sentence-ending punctuation.

    Args:
        pdf: When true, whitespace is normalised before splitting (runs of
            three or more newlines collapse to one, then every whitespace
            character becomes a space).
        **kwargs: Forwarded unchanged to ``CharacterTextSplitter``.
    """

    def __init__(self, pdf: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.pdf = pdf

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` into a list of sentence strings.

        Args:
            text: The raw text to split.

        Returns:
            The non-empty pieces produced by the separator pattern, with each
            separator match appended to the piece that precedes it.
        """
        if self.pdf:
            text = re.sub(r"\n{3,}", "\n", text)
            # Raw string: "\s" in a plain literal is an invalid escape sequence.
            # Every whitespace character, newlines included, becomes a single
            # space here, so no further newline clean-up is needed afterwards.
            text = re.sub(r"\s", " ", text)
        # Group 1 captures closing punctuation plus up to two closing quotes;
        # the look-ahead alternative matches the empty string in front of
        # opening quotes and at the end of the text.
        sent_sep_pattern = re.compile(
            '([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')
        sent_list = []
        for ele in sent_sep_pattern.split(text):
            if sent_sep_pattern.match(ele) and sent_list:
                # The piece begins with a separator match: glue it onto the
                # previous sentence instead of starting a new one.
                sent_list[-1] += ele
            elif ele:
                sent_list.append(ele)
        return sent_list
def load_file(filepath):
    """Load a text file and split it into documents.

    The file is read with ``TextLoader`` (encoding auto-detection enabled)
    and split with ``ChineseTextSplitter``. The resulting documents are also
    dumped through ``write_check_file`` before being returned.

    Args:
        filepath: Path of the text file to load.

    Returns:
        The list of split documents.
    """
    splitter = ChineseTextSplitter(pdf=False)
    documents = TextLoader(filepath, autodetect_encoding=True).load_and_split(splitter)
    write_check_file(filepath, documents)
    return documents
def write_check_file(filepath, docs):
    """Append a plain-text dump of ``docs`` to a check file next to ``filepath``.

    The dump goes to ``tmp_files/load_file.txt`` inside the directory that
    contains ``filepath``. Each call appends one header line
    (``filepath=<path>,len=<count>``) followed by ``str(doc)`` on its own
    line for every document.

    Args:
        filepath: Path of the source file the documents were loaded from.
        docs: Sized iterable of documents; each one is written via ``str()``.
    """
    folder_path = os.path.join(os.path.dirname(filepath), "tmp_files")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(folder_path, exist_ok=True)
    fp = os.path.join(folder_path, 'load_file.txt')
    # The context manager closes the file; no explicit close() is needed.
    with open(fp, 'a', encoding='utf-8') as fout:
        fout.write("filepath=%s,len=%s\n" % (filepath, len(docs)))
        fout.writelines(str(doc) + '\n' for doc in docs)
def save_faiss_db(docsearch, save_path):
    """Persist a vector store by calling its ``save_local`` method.

    Args:
        docsearch: Vector store object exposing ``save_local(path)``.
        save_path: Destination path handed to ``save_local``.
    """
    persist = docsearch.save_local
    persist(save_path)
if __name__ == '__main__':
    # End-to-end demo: embed a local text file into a FAISS index, save the
    # index, then answer one question through a RetrievalQA chain.

    # Load documents (pdf file or txt file)
    filepath = '/home/sky/model_data/code/pdf_dir/1.txt'
    # Embedding model name
    EMBEDDING_MODEL = 'text2vec'
    PROMPT_TEMPLATE = """已知信息:
{context_str}
基于以上已知信息,请简洁专业地回答用户的问题。如果无法从已有信息中得出答案,请说“根据提供的信息无法回答该问题”或者“提供的相关信息不足”,并且不要在答案中添加虚构的细节。请用中文回答。问题是:{question}"""
    # Embedding running device: use the GPU when one is available, otherwise
    # fall back to the CPU instead of failing on CUDA-less machines.
    EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    # Return top-k text chunks from vector store
    VECTOR_SEARCH_TOP_K = 3
    CHAIN_TYPE = 'stuff'
    embedding_model_dict = {
        "text2vec": "/home/sky/model_data/text2vec-large-chinese",
    }

    llm = Qwen()
    embeddings = HuggingFaceEmbeddings(
        model_name=embedding_model_dict[EMBEDDING_MODEL],
        model_kwargs={'device': EMBEDDING_DEVICE},
    )

    # Load and process the file
    docs = load_file(filepath)

    # Create FAISS wrapper and save it
    faiss_save_path = "/home/sky/model_data/faiss_index"
    docsearch = FAISS.from_documents(docs, embeddings)
    save_faiss_db(docsearch, faiss_save_path)

    # Set up the retrieval QA chain; the retrieved chunks are injected into
    # the prompt under the "context_str" variable.
    prompt = PromptTemplate(
        template=PROMPT_TEMPLATE, input_variables=["context_str", "question"]
    )
    chain_type_kwargs = {"prompt": prompt, "document_variable_name": "context_str"}
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type=CHAIN_TYPE,
        retriever=docsearch.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K}),
        chain_type_kwargs=chain_type_kwargs,
    )

    # Example query
    query = "先说下你是谁?然后再简述一下三体的故事."
    print(qa.run(query))
更多推荐



所有评论(0)