【LangChain】langchain.chains.create_sql_query_chain() 函数:基于自然语言生成 SQL 查询的链(Chain)
create_sql_query_chain 是 LangChain 库中 langchain.chains.sql_database.query 模块的一个函数,用于创建一个 LCEL(LangChain Expression Language)链,将自然语言查询转换为 SQL 语句。它结合语言模型和数据库上下文(如表结构),生成符合数据库方言的 SQL 查询,适合需要从数据库中提取数据的场景。
langchain.chains.create_sql_query_chain 函数是 LangChain 库中的一个函数,用于创建基于自然语言生成 SQL 查询的链(Chain),结合语言模型(LLM)和数据库上下文生成可执行的 SQL 语句。
本文基于 LangChain 0.3.x,详细介绍 create_sql_query_chain 的定义、参数、方法和典型场景,并提供一个独立示例,展示如何使用 create_sql_query_chain 结合 ChatOpenAI 和 SQLDatabase 实现人工智能主题的数据库查询,示例突出该函数在自然语言到 SQL 转换中的作用。( LangChain Chains 文档)
langchain.chains.create_sql_query_chain 简介
create_sql_query_chain 是 LangChain 库中 langchain.chains.sql_database.query 模块的一个函数,用于创建一个 LCEL(LangChain Expression Language)链,将自然语言查询转换为 SQL 语句。它结合语言模型和数据库上下文(如表结构),生成符合数据库方言的 SQL 查询,适合需要从数据库中提取数据的场景。
核心功能:
- 将自然语言问题转换为结构化的 SQL 查询。
- 利用数据库元数据(如表名、列名)生成准确的查询。
- 支持多种数据库方言(如 SQLite、PostgreSQL、MySQL)。
- 与 LCEL 链无缝集成,可进一步处理查询结果。
适用场景:
- 构建自然语言数据库查询接口(如问答系统)。
- 自动化数据分析,允许用户用自然语言提取数据。
- 结合 RAG 或代理,从数据库中获取动态数据。
- 开发交互式数据查询工具,支持非技术用户。
与其他链对比:
create_sql_query_chain: 生成 SQL 查询,专注于数据库交互。RetrievalQA: 用于文档检索的 RAG 链。ConversationChain: 处理对话历史。SQLDatabaseChain: 执行 SQL 查询并返回结果(更高级,但已部分弃用)。
注意:
- 生成的 SQL 查询需验证,避免语法错误或注入风险。
- 依赖
langchain_community.utilities.SQLDatabase提供数据库上下文。
函数定义和参数
以下是 create_sql_query_chain 的定义,基于 LangChain 源码(langchain/chains/sql_database/query.py)和官方文档(create_sql_query_chain)。
函数签名
def create_sql_query_chain(
llm: BaseLanguageModel,
db: SQLDatabase,
prompt: Optional[BasePromptTemplate] = None,
k: int = 5
) -> Runnable[Union[str, Dict[str, Any]], str]
- 参数:
llm(BaseLanguageModel):语言模型,用于生成 SQL 查询(如ChatOpenAI)。db(SQLDatabase):数据库实例,提供表结构和方言信息。prompt(Optional[BasePromptTemplate]):自定义提示模板,默认为内置模板。k(int):生成查询时考虑的上下文行数,默认为 5。
- 返回值:
Runnable[Union[str, Dict[str, Any]], str]:一个 LCEL 链,接受字符串或字典输入,输出 SQL 查询字符串。
- 功能:
- 使用 LLM 和数据库上下文生成 SQL 查询。
- 支持自定义提示模板以优化查询生成。
- 输出为纯 SQL 语句,需进一步执行。
默认提示模板
默认模板(简化为中文):
你是一个 SQL 专家,基于以下数据库信息将用户问题转换为 SQL 查询。
数据库方言: {dialect}
表结构: {table_info}
问题: {input}
输出仅包含 SQL 查询语句。
输入格式
- 字符串:直接输入自然语言问题:
chain.invoke("有多少用户?") - 字典:指定
input键:chain.invoke({"input": "有多少用户?"})
输出格式
- 字符串:生成的 SQL 查询,如:
SELECT COUNT(*) FROM users;
工作原理
create_sql_query_chain 的运行逻辑如下:
- 输入:接受自然语言问题(字符串或字典)。
- 处理:
- 从
db获取数据库元数据(表名、列名、样本数据)。 - 结合
llm和prompt,生成 SQL 查询。 - 使用
k参数控制上下文数据量。
- 从
- 输出:返回 SQL 查询字符串。
- 执行:需手动或通过链执行查询(如
db.run(query))。
在 LCEL 中的作用:
- 作为链的一部分,生成 SQL 查询。
- 可与
SQLDatabase.run或自定义解析器组合,获取查询结果。
示例流程:
from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI
from langchain_community.utilities import SQLDatabase
llm = ChatOpenAI()
db = SQLDatabase.from_uri("sqlite:///example.db")
chain = create_sql_query_chain(llm, db)
query = chain.invoke("有多少用户?")
# 输出: SELECT COUNT(*) FROM users;
常用方法
create_sql_query_chain 返回一个 Runnable 对象,支持以下方法:
1. invoke
def invoke(self, input: Union[str, Dict[str, Any]], config: Optional[RunnableConfig] = None) -> str
- 功能:同步调用,生成 SQL 查询。
- 输入:
input(str | Dict):问题字符串或包含input键的字典。config(Optional[RunnableConfig]):运行配置(如超时)。
- 输出:SQL 查询字符串。
- 示例:
query = chain.invoke("列出所有用户的姓名") print(query) # 输出: SELECT name FROM users;
2. ainvoke
async def ainvoke(self, input: Union[str, Dict[str, Any]], config: Optional[RunnableConfig] = None) -> str
- 功能:异步调用,生成 SQL 查询。
- 示例:
query = await chain.ainvoke("有多少用户?") print(query) # 输出: SELECT COUNT(*) FROM users;
3. stream / astream
- 功能:支持流式输出,逐块返回查询。
- 示例:
for chunk in chain.stream("列出所有用户"): print(chunk, end="")
使用方式
以下是使用 create_sql_query_chain 的步骤。
1. 安装依赖
pip install --upgrade langchain langchain-openai sqlalchemy
2. 设置 OpenAI API 密钥
export OPENAI_API_KEY="your-api-key"
或在代码中:
import os
os.environ["OPENAI_API_KEY"] = "your-api-key"
3. 初始化数据库和 LLM
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
db = SQLDatabase.from_uri("sqlite:///example.db")
llm = ChatOpenAI(model="gpt-3.5-turbo")
4. 创建 SQL 查询链
from langchain.chains import create_sql_query_chain
chain = create_sql_query_chain(llm, db)
5. 调用链并执行查询
query = chain.invoke("有多少用户?")
result = db.run(query)
print(result)
使用 create_sql_query_chain 的示例
以下是一个独立示例,展示如何使用 create_sql_query_chain 结合 ChatOpenAI 和 SQLDatabase 实现人工智能主题的数据库查询。链生成 SQL 查询并执行,回答用户关于 AI 相关数据的自然语言问题。
准备环境:
- 获取 OpenAI API 密钥:OpenAI Platform。
- 设置环境变量:
export OPENAI_API_KEY="your-api-key" - 安装依赖:
pip install --upgrade langchain langchain-openai sqlalchemy - 创建 SQLite 数据库
ai_data.db:CREATE TABLE ai_projects ( id INTEGER PRIMARY KEY, name TEXT NOT NULL, category TEXT, year INTEGER ); INSERT INTO ai_projects (name, category, year) VALUES ('AlphaGo', 'Game AI', 2016), ('GPT-3', 'NLP', 2020), ('DALL-E', 'Generative AI', 2021);
代码:
from langchain_openai import ChatOpenAI
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_core.output_parsers import StrOutputParser
# 初始化 ChatOpenAI
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7)
# 初始化 SQLDatabase
db = SQLDatabase.from_uri("sqlite:///ai_data.db")
# 创建 SQL 查询链
sql_chain = create_sql_query_chain(llm, db)
# 定义提示模板,用于格式化查询结果
prompt = ChatPromptTemplate.from_template(
"你是一个数据库专家,基于以下 SQL 查询结果回答问题:\n结果:{result}\n问题:{input}\n回答:"
)
# 定义输出解析器
parser = StrOutputParser()
# 创建完整工作流
chain = (
{
"query": sql_chain,
"input": RunnableLambda(lambda x: x.strip())
}
| RunnableLambda(lambda x: {"result": db.run(x["query"]), "input": x["input"]})
| prompt
| llm
| parser
)
# 测试 SQL 查询工作流
print("测试 create_sql_query_chain 和数据库查询:")
try:
questions = [
"有多少 AI 项目?",
"列出 2020 年后的 AI 项目"
]
for question in questions:
result = chain.invoke(question)
print(f"\n输入问题: {question}")
print(f"回答: {result}")
except Exception as e:
print(f"错误: {e}")
输出示例(实际输出取决于模型和数据库):
测试 create_sql_query_chain 和数据库查询:
输入问题: 有多少 AI 项目?
回答: 数据库中有 3 个 AI 项目。
输入问题: 列出 2020 年后的 AI 项目
回答: 2020 年后的 AI 项目包括:
- GPT-3(类别:NLP,2020 年)
- DALL-E(类别:Generative AI,2021 年)
代码说明
- LLM 初始化:
- 使用
ChatOpenAI调用gpt-3.5-turbo,设置temperature=0.7。
- 使用
- 数据库初始化:
- 连接 SQLite 数据库
ai_data.db,包含ai_projects表。
- 连接 SQLite 数据库
- SQL 查询链:
- 使用
create_sql_query_chain创建 SQL 生成链。
- 使用
- 工作流:
- 使用
sql_chain生成 SQL 查询。 - 使用
RunnableLambda清理输入并执行查询(db.run)。 - 组合
prompt、llm和parser格式化结果。
- 使用
- 测试:
- 测试两个问题:统计项目数和过滤年份。
- 显示问题和回答,展示自然语言到 SQL 的转换。
- 错误处理:
- 使用
try-except捕获 API 或数据库错误。
- 使用
运行要求:
- 有效的 OpenAI API 密钥:
export OPENAI_API_KEY="your-api-key" - 安装依赖:
pip install --upgrade langchain langchain-openai sqlalchemy - SQLite 数据库
ai_data.db已创建。 - 网络连接:访问
https://api.openai.com.
注意事项
- API 密钥:
- 确保
OPENAI_API_KEY已设置:echo $OPENAI_API_KEY - 或在代码中设置:
llm = ChatOpenAI(api_key="your-api-key")
- 确保
- 数据库配置:
- 验证数据库 URI:
db = SQLDatabase.from_uri("sqlite:///ai_data.db") print(db.get_table_info()) - 限制表范围:
db = SQLDatabase.from_uri("sqlite:///ai_data.db", include_tables=["ai_projects"])
- 验证数据库 URI:
- SQL 安全:
- 检查生成的 SQL 查询,避免注入:
query = sql_chain.invoke("无效查询") print(query) - 限制 LLM 生成危险命令:
prompt = ChatPromptTemplate.from_template("仅生成 SELECT 查询: {input}") sql_chain = create_sql_query_chain(llm, db, prompt=prompt)
- 检查生成的 SQL 查询,避免注入:
- 性能优化:
- 异步调用:使用
ainvoke:query = await sql_chain.ainvoke("有多少用户?") - 缓存查询:结合
langchain.cache:from langchain.cache import InMemoryCache langchain.llm_cache = InMemoryCache() - 限制上下文:调整
k:sql_chain = create_sql_query_chain(llm, db, k=3)
- 异步调用:使用
- 错误调试:
- 查询错误:
- 检查 SQL 语法:
query = sql_chain.invoke("无效查询") print(query) - 验证执行:
print(db.run(query))
- 检查 SQL 语法:
- API 错误:
- 检查密钥:
print(os.environ.get("OPENAI_API_KEY")) - 增加超时:
llm = ChatOpenAI(timeout=30)
- 检查密钥:
- 数据库连接:
- 检查 URI:
print(db.get_usable_table_names()) - 测试连接:
print(db.run("SELECT 1"))
- 检查 URI:
- 查询错误:
常见问题
Q1:如何自定义提示模板?
A:创建自定义 ChatPromptTemplate:
from langchain_core.prompts import ChatPromptTemplate
prompt = ChatPromptTemplate.from_template(
"生成 {dialect} 的 SQL 查询,基于表结构 {table_info},回答问题:{input}\n仅输出 SQL 语句。"
)
sql_chain = create_sql_query_chain(llm, db, prompt=prompt)
Q2:如何执行和解析查询结果?
A:结合 db.run 和提示:
chain = sql_chain | RunnableLambda(lambda x: {"query": x, "result": db.run(x)})
result = chain.invoke("有多少用户?")
print(result["result"])
Q3:如何与代理结合?
A:使用 SQLDatabaseToolkit:
from langchain.agents import create_sql_agent
from langchain_community.agent_toolkits import SQLDatabaseToolkit
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
agent = create_sql_agent(llm=llm, toolkit=toolkit, verbose=True)
result = agent.run("有多少 AI 项目?")
Q4:如何支持开源模型?
A:使用 ChatOllama:
from langchain_ollama import ChatOllama
llm = ChatOllama(model="llama3")
sql_chain = create_sql_query_chain(llm, db)
query = sql_chain.invoke("有多少用户?")
总结
langchain.chains.create_sql_query_chain 是 LangChain 中用于自然语言到 SQL 查询转换的强大工具,核心功能包括:
- 定义:生成 SQL 查询,基于 LLM 和数据库上下文。
- 参数:
llm、db、prompt和k。 - 常用方法:
invoke(同步)、ainvoke(异步)、stream(流式)。 - 适用场景:数据库查询、数据分析、自然语言接口。
更多推荐


所有评论(0)