

本节源码基于官方 ↗。
由上一章节可以知道Vanna主要由两个函数起作用,第一个是train函数,第二个是ask函数,这两个函数都封装在VannaBase类中。
- VannaBase 是Vanna框架的核心基类,它提供了一个完整的文本到SQL生成系统的架构。
- train 函数用于训练Vanna模型,使其能够更好地将自然语言问题转换为SQL查询。
- ask 函数是Vanna的主要交互接口,用于回答用户提出的自然语言问题。
train函数#
源码如下:
def train(
self,
question: str = None,
sql: str = None,
ddl: str = None,
documentation: str = None,
plan: TrainingPlan = None,
) -> str:
"""
**Example:**
```python
vn.train()
```
Train Vanna.AI on a question and its corresponding SQL query.
If you call it with no arguments, it will check if you connected to a database and it will attempt to train on the metadata of that database.
If you call it with the sql argument, it's equivalent to [`vn.add_question_sql()`][vanna.base.base.VannaBase.add_question_sql].
If you call it with the ddl argument, it's equivalent to [`vn.add_ddl()`][vanna.base.base.VannaBase.add_ddl].
If you call it with the documentation argument, it's equivalent to [`vn.add_documentation()`][vanna.base.base.VannaBase.add_documentation].
Additionally, you can pass a [`TrainingPlan`][vanna.types.TrainingPlan] object. Get a training plan with [`vn.get_training_plan_generic()`][vanna.base.base.VannaBase.get_training_plan_generic].
Args:
question (str): The question to train on.
sql (str): The SQL query to train on.
ddl (str): The DDL statement.
documentation (str): The documentation to train on.
plan (TrainingPlan): The training plan to train on.
"""
if question and not sql:
raise ValidationError("Please also provide a SQL query")
if documentation:
print("Adding documentation....")
return self.add_documentation(documentation)
if sql:
if question is None:
question = self.generate_question(sql)
print("Question generated with sql:", question, "\nAdding SQL...")
return self.add_question_sql(question=question, sql=sql)
if ddl:
print("Adding ddl:", ddl)
return self.add_ddl(ddl)
if plan:
for item in plan._plan:
if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
self.add_ddl(item.item_value)
elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
self.add_documentation(item.item_value)
elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL:
self.add_question_sql(question=item.item_name, sql=item.item_value)bash首先是方法定义部分:
def train(
self,
question: str = None,
sql: str = None,
ddl: str = None,
documentation: str = None,
plan: TrainingPlan = None,
) -> str:bash表示方法接收question、sql、ddl、documentation和plan,并返回一个字符串结果。如果不带参数调用,它会检查是否连接到数据库,并尝试在该数据库的元数据上进行训练。 如果使用sql参数调用,它等同于[vn.add_question_sql()][vanna.base.base.VannaBase.add_question_sql]。 如果使用ddl参数调用,它等同于[vn.add_ddl()][vanna.base.base.VannaBase.add_ddl]。 如果使用documentation参数调用,它等同于[vn.add_documentation()][vanna.base.base.VannaBase.add_documentation]。 此外,您可以传递一个[TrainingPlan][vanna.types.TrainingPlan]对象。使用[vn.get_training_plan_generic()][vanna.base.base.VannaBase.get_training_plan_generic]获取训练计划。
结合源码可以知道:
if ddl:
print("Adding ddl:", ddl)
return self.add_ddl(ddl)
if question and not sql:
raise ValidationError("Please also provide a SQL query")
if documentation:
print("Adding documentation....")
return self.add_documentation(documentation)bash这些都比较简单就是简单的方法调用,不过这些都是抽象方法需要具体实现:
@abstractmethod
def add_ddl(self, ddl: str, **kwargs) -> str:
"""
This method is used to add a DDL statement to the training data.
Args:
ddl (str): The DDL statement to add.
Returns:
str: The ID of the training data that was added.
"""
pass
@abstractmethod
def add_documentation(self, documentation: str, **kwargs) -> str:
"""
This method is used to add documentation to the training data.
Args:
documentation (str): The documentation to add.
Returns:
str: The ID of the training data that was added.
"""
passbash下面来看看:
if sql:
if question is None:
question = self.generate_question(sql)
print("Question generated with sql:", question, "\nAdding SQL...")
return self.add_question_sql(question=question, sql=sql)bash当传入SQL语句,没有传入Question时,Vanna会自动根据SQL生成问题,效果如下:
深入看一下generate_question方法:
def generate_question(self, sql: str, **kwargs) -> str:
response = self.submit_prompt(
[
self.system_message(
"The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question."
),
self.user_message(sql),
],
**kwargs,
)
return responsebash方法很简单,就是构造一个系统提示词和用户提供的SQL语句,然后传给抽象方法submit_prompt,这个方法是由我们自己实现的,在QwenLLM类中:
def submit_prompt(self,prompt,**kwargs):
resp=Generation.call(
model=self.model,
messages=prompt,
seed=random.randint(1, 10000),
result_format='message',
api_key=self.api_key)
answer=resp.output.choices[0].message.content
global DEBUG_INFO
DEBUG_INFO=(prompt,answer)
return answerbashsubmit_prompt方法收到了prompt,然后调用模型进行推理解析并返回answer。此时,我们已经通过SQL生成了Question,然后调用add_question_sql(question=question, sql=sql)方法,这个方法是将Question-SQL对添加到训练数据中,返回唯一的ID,这个抽象方法需要我们具体实现,保存的训练数据格式是一个json列表,每个json是一个训练样本,例如:
{
"question":"what are 5 most grossing movies in IMDB top 1000 ",
"answer":"SELECT series_title,\n gross\nFROM imdb.public.movies\nORDER BY gross desc limit 5;"
}bash接下来看train方法中比较复杂的一部分:
if plan:
for item in plan._plan:
if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
self.add_ddl(item.item_value)
elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
self.add_documentation(item.item_value)
elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL:
self.add_question_sql(question=item.item_name, sql=item.item_value)bash传入的plan是一个TrainingPlan类的实例,TrainingPlan类是Vanna框架中用于管理和组织训练数据的核心类,它代表了一个结构化的训练计划。其中_plan是一个TrainingPlanItem的列表,TrainingPlanItem包括以下属性:
class TrainingPlanItem:
item_type: str
item_group: str
item_name: str
item_value: str
def __str__(self):
if self.item_type == self.ITEM_TYPE_SQL:
return f"Train on SQL: {self.item_group} {self.item_name}"
elif self.item_type == self.ITEM_TYPE_DDL:
return f"Train on DDL: {self.item_group} {self.item_name}"
elif self.item_type == self.ITEM_TYPE_IS:
return f"Train on Information Schema: {self.item_group} {self.item_name}"
ITEM_TYPE_SQL = "sql"
ITEM_TYPE_DDL = "ddl"
ITEM_TYPE_IS = "is"bash- item_type: 训练项类型(SQL查询、DDL语句、信息模式)
- item_group: 训练项分组(如数据库名.模式名)
- item_name: 训练项名称(如表名)
- item_value: 训练项具体内容
以item.item_type==sql为例,将它的值执行抽象方法add_ddl,并返回唯一的ID。执行完train方法的这一段代码后,Vanna AI模型将会:
- 学习到数据库的表结构(通过DDL)
- 获得额外的上下文信息(通过文档)
- 掌握更多问题与SQL查询的对应关系(通过问答对)
- 提升将自然语言转换为SQL查询的准确率
这实际上是批量训练模型的过程,将训练计划中所有类型的训练数据都添加到模型的检索层中,以增强模型的性能。add_ddl、add_documentation和add_question_sql需要在子类中具体实现,比如使用ChromaDB作为向量存储时,会在vanna.chromadb_vector.ChromaDB_VectorStore类中实现。
训练好的数据在生成SQL时会被检索和使用:
get_similar_question_sql():检索相似的问题-SQL对
get_related_ddl():检索相关的DDL语句
get_related_documentation():检索相关的文档bash这些检索到的信息会被组合成提示词,提供给大语言模型生成最终的SQL查询。
非常好,Vanna提供了该子类的实现代码,接下来我们需要在ChromaDB_VectorStore学习。
ChromaDB_VectorStore类#
class ChromaDB_VectorStore(VannaBase):
def __init__(self, config=None):
VannaBase.__init__(self, config=config)
if config is None:
config = {}
path = config.get("path", ".")
self.embedding_function = config.get("embedding_function", default_ef)
curr_client = config.get("client", "persistent")
collection_metadata = config.get("collection_metadata", None)
self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))
self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
if curr_client == "persistent":
self.chroma_client = chromadb.PersistentClient(
path=path, settings=Settings(anonymized_telemetry=False)
)
elif curr_client == "in-memory":
self.chroma_client = chromadb.EphemeralClient(
settings=Settings(anonymized_telemetry=False)
)
elif isinstance(curr_client, chromadb.api.client.Client):
## allow providing client directly
self.chroma_client = curr_client
else:
raise ValueError(f"Unsupported client was set in config: {curr_client}")
self.documentation_collection = self.chroma_client.get_or_create_collection(
name="documentation",
embedding_function=self.embedding_function,
metadata=collection_metadata,
)
self.ddl_collection = self.chroma_client.get_or_create_collection(
name="ddl",
embedding_function=self.embedding_function,
metadata=collection_metadata,
)
self.sql_collection = self.chroma_client.get_or_create_collection(
name="sql",
embedding_function=self.embedding_function,
metadata=collection_metadata,
)bashChromaDB_VectorStore继承VannaBase类,实现VannaBase类提供的一些抽象方法:
path = config.get("path", ".")bashpath: ChromaDB 数据持久化存储的路径。默认值为当前目录(”.”),即在当前目录下创建和存储向量数据库。
self.embedding_function = config.get("embedding_function", default_ef)bashembedding_function: 用于将文本转换为向量的嵌入函数。默认使用 DefaultEmbeddingFunction。
curr_client = config.get("client", "persistent")bashclient: ChromaDB 客户端类型。“persistent”: 持久化客户端,数据存储在磁盘上(默认)。“in-memory”: 内存客户端,数据仅在内存中,程序退出后丢失。chromadb.api.client.Client: 直接提供已配置的客户端实例。
collection_metadata = config.get("collection_metadata", None)bashcollection_metadata: 集合的元数据信息。用于存储集合的额外信息,如版本号、创建时间等。
self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))
self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))bash查询结果数量配置:
- n_results_sql: 查询相似 SQL 问题时返回的结果数量,默认10个
- n_results_documentation: 查询相关文档时返回的结果数量,默认10个
- n_results_ddl: 查询相关 DDL 语句时返回的结果数量,默认10个
if curr_client == "persistent":
self.chroma_client = chromadb.PersistentClient(
path=path, settings=Settings(anonymized_telemetry=False)
)
elif curr_client == "in-memory":
self.chroma_client = chromadb.EphemeralClient(
settings=Settings(anonymized_telemetry=False)
)bash根据配置创建相应的 ChromaDB 客户端。
self.documentation_collection = self.chroma_client.get_or_create_collection(
name="documentation",
embedding_function=self.embedding_function,
metadata=collection_metadata,
)
## 类似地创建 ddl_collection 和 sql_collectionbash创建三个向量集合(Collections)用于存储不同类型的数据:
- documentation collection: 存储文档信息
- ddl collection: 存储数据定义语言(表结构等)
- sql collection: 存储问题和 SQL 查询对
下面还是以add_ddl方法为例,由train函数中的self.add_ddl(item.item_value)将值传入进来,然后经过:
def add_ddl(self, ddl: str, **kwargs) -> str:
id = deterministic_uuid(ddl) + "-ddl"
self.ddl_collection.add(
documents=ddl,
embeddings=self.generate_embedding(ddl),
ids=id,
)
return idbash首先生成唯一的ID,然后收集documents、对应的embedding结果和ID。同样的收集question_sql和documentation。
执行get_training_data方法的结果如下:
id \
0 04a88b26-6984-5521-b897-73798ce0001f-sql
1 e5102160-2dbf-5300-98f5-24d762a12b59-sql
2 0189b3e3-c135-5bfe-a9f8-7faabd751813-sql
3 eb6bbff7-a89c-51bc-a58d-ebf6dc181ae3-sql
4 54db6ffd-201b-59a2-8568-cd05d82db461-sql
5 9ccf7bcd-5091-5b97-bf72-af9d41e526a5-sql
6 270bfd96-c340-5b21-afe9-0d14c23fd8bd-sql
0 ab0ac208-2f5e-50b0-9177-423427220940-ddl
0 8fc54ebe-8bb3-5fb7-88a6-5c98d817ed07-doc
1 82e9153e-0b3b-5aca-ac66-31e65eb61d36-doc
question \
0 Who are the users aged between 10 and 20?
1 小鱼儿的年龄
2 小猪猪的年龄
3 用户的平均年龄
4 打算给一批员工送福报,把他们的名字过滤出来
5 What are the names of users whose age is betwe...
6 各个年龄段的人数都是多少?
0 None
0 None
1 None
content training_data_type
0 select name from user where age between 10 and 20 sql
1 select age from user where name="小鱼儿" sql
2 select age from user where name="小猪猪" sql
3 select avg(age) from user sql
4 select name from user where age >= 35 sql
5 select name from user where age between 10 and 20 sql
6 SELECT \n CASE \n WHEN age BETWEEN 0... sql
0 CREATE TABLE IF NOT EXISTS user (\n id ... ddl
0 "福报"是指age>=35岁,也就是可以向社会输送的人才 documentation
1 用户年龄段划分逻辑:0-10,10-20,20-30,30-40,40-50,50-60,6... documentation bash返回的内容是一个表格,可以看到包括四个属性id 、question、content 和training_data_type ,值得注意的是,对于ddl和documentation类型的训练数据,是没有question值的。
在ChromaDB_VectorStore类的方法里还有remove_training_data和remove_collection方法,可以删除数据。
具体的训练流程就是,当在VannaBase基类的train方法中设置plan为True时,它会将接收到新的ddl、documentation和question-sql对,以及它们对应的embedding值存储到向量库中。如果提供json格式的训练集,就只需要对该训练集进行遍历存储到向量库中即可。
ask函数#
接下来分析另一个重要的函数ask函数,首先看它的初始化:
def ask(
self,
question: Union[str, None] = None,
print_results: bool = True,
auto_train: bool = True,
visualize: bool = True, ## if False, will not generate plotly code
allow_llm_to_see_data: bool = False,
) -> Union[
Tuple[
Union[str, None],
Union[pd.DataFrame, None],
Union[plotly.graph_objs.Figure, None],
],
None,
]:bash- question (Union[str, None]) - 用户要询问的问题字符串,如果为 None 则会提示用户输入
- print_results (bool) - 是否打印结果,默认为 True
- auto_train (bool) - 是否自动训练,默认为 True,会将问题和 SQL 查询对添加到训练数据中
- visualize (bool) - 是否生成图表,默认为 True,会根据数据生成 Plotly 图表
- allow_llm_to_see_data (bool) - 是否允许 LLM 查看数据,默认为 False
具体返回一个三元组 (Tuple) 或 None:
- SQL 查询字符串 (str) - 生成的 SQL 查询语句
- 数据结果 (pd.DataFrame) - SQL 查询执行后的结果数据
- 图表对象 (plotly.graph_objs.Figure) - 根据数据生成的可视化图表
try:
sql = self.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data)
except Exception as e:
print(e)
return None, None, Nonebash根据用户的提问生成sql语句,并确定是否让LLM看到数据。查看generate_sql函数:
if self.config is not None:
initial_prompt = self.config.get("initial_prompt", None)
else:
initial_prompt = None
question_sql_list = self.get_similar_question_sql(question, **kwargs)
ddl_list = self.get_related_ddl(question, **kwargs)
doc_list = self.get_related_documentation(question, **kwargs)
prompt = self.get_sql_prompt(
initial_prompt=initial_prompt,
question=question,
question_sql_list=question_sql_list,
ddl_list=ddl_list,
doc_list=doc_list,
**kwargs,
)
self.log(title="SQL Prompt", message=prompt)
llm_response = self.submit_prompt(prompt, **kwargs)
self.log(title="LLM Response", message=llm_response)
if 'intermediate_sql' in llm_response:
if not allow_llm_to_see_data:
return "The LLM is not allowed to see the data in your database. Your question requires database introspection to generate the necessary SQL. Please set allow_llm_to_see_data=True to enable this."
if allow_llm_to_see_data:
intermediate_sql = self.extract_sql(llm_response)
try:
self.log(title="Running Intermediate SQL", message=intermediate_sql)
df = self.run_sql(intermediate_sql)
prompt = self.get_sql_prompt(
initial_prompt=initial_prompt,
question=question,
question_sql_list=question_sql_list,
ddl_list=ddl_list,
doc_list=doc_list+[f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n" + df.to_markdown()],
**kwargs,
)
self.log(title="Final SQL Prompt", message=prompt)
llm_response = self.submit_prompt(prompt, **kwargs)
self.log(title="LLM Response", message=llm_response)
except Exception as e:
return f"Error running intermediate SQL: {e}"
return self.extract_sql(llm_response)bash其中:
question_sql_list = self.get_similar_question_sql(question, **kwargs)
ddl_list = self.get_related_ddl(question, **kwargs)
doc_list = self.get_related_documentation(question, **kwargs)bash调用三个关键方法获取生成 SQL 所需的上下文:
- get_similar_question_sql():获取相似的问题-SQL 对
- get_related_ddl():获取相关的数据定义语言(表结构)
- get_related_documentation():获取相关的文档说明
这三个方法在子类ChromaDB_VectorStore中实现,不同的向量数据库有不同的实现,这里简单看get_similar_question_sql方法的具体实现:
def get_similar_question_sql(self, question: str, **kwargs) -> list:
return ChromaDB_VectorStore._extract_documents(
self.sql_collection.query(
query_texts=[question],
n_results=self.n_results_sql,
)
)bash它是将sql_collection.query的结果传入_extract_documents方法。sql_collection.query方法是如何实现的?这里挖个坑。
_extract_documents方法将 ChromaDB 查询返回的原始数据结构转换为可以直接使用的文档列表。主要用于处理以下三种查询的返回结果:
## 1. 获取相似问题-SQL对
self.get_similar_question_sql(question)
## 返回: [{"question": "...", "sql": "..."}, {...}]
## 2. 获取相关DDL语句
self.get_related_ddl(question)
## 返回: ["CREATE TABLE ...", "CREATE TABLE ..."]
## 3. 获取相关文档
self.get_related_documentation(question)
## 返回: ["文档内容1", "文档内容2"]bash保存为question_sql_list,同样的还有ddl_list和doc_list,然后一起组装成prompt:
prompt = self.get_sql_prompt(
initial_prompt=initial_prompt,
question=question,
question_sql_list=question_sql_list,
ddl_list=ddl_list,
doc_list=doc_list,
**kwargs,
)bash具体实现如下:
def get_sql_prompt(
self,
initial_prompt : str,
question: str,
question_sql_list: list,
ddl_list: list,
doc_list: list,
**kwargs,
):
"""
Example:
```python
vn.get_sql_prompt(
question="What are the top 10 customers by sales?",
question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
doc_list=["The customers table contains information about customers and their sales."],
)
```
This method is used to generate a prompt for the LLM to generate SQL.
Args:
question (str): The question to generate SQL for.
question_sql_list (list): A list of questions and their corresponding SQL statements.
ddl_list (list): A list of DDL statements.
doc_list (list): A list of documentation.
Returns:
any: The prompt for the LLM to generate SQL.
"""
if initial_prompt is None:
initial_prompt = f"You are a {self.dialect} expert. " + \
"Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
initial_prompt = self.add_ddl_to_prompt(
initial_prompt, ddl_list, max_tokens=self.max_tokens
)
if self.static_documentation != "":
doc_list.append(self.static_documentation)
initial_prompt = self.add_documentation_to_prompt(
initial_prompt, doc_list, max_tokens=self.max_tokens
)
initial_prompt += (
"===Response Guidelines \n"
"1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
"2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
"3. If the provided context is insufficient, please explain why it can't be generated. \n"
"4. Please use the most relevant table(s). \n"
"5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
)
message_log = [self.system_message(initial_prompt)]
for example in question_sql_list:
if example is None:
print("example is None")
else:
if example is not None and "question" in example and "sql" in example:
message_log.append(self.user_message(example["question"]))
message_log.append(self.assistant_message(example["sql"]))
message_log.append(self.user_message(question))
return message_logpython然后将生成的prompt提交llm_response = self.submit_prompt(prompt, **kwargs)到大模型得到返回结果,继续执行代码:
if 'intermediate_sql' in llm_response:
if not allow_llm_to_see_data:
return "The LLM is not allowed to see the data in your database. Your question requires database introspection to generate the necessary SQL. Please set allow_llm_to_see_data=True to enable this."
if allow_llm_to_see_data:
intermediate_sql = self.extract_sql(llm_response)
try:
self.log(title="Running Intermediate SQL", message=intermediate_sql)
df = self.run_sql(intermediate_sql)
prompt = self.get_sql_prompt(
initial_prompt=initial_prompt,
question=question,
question_sql_list=question_sql_list,
ddl_list=ddl_list,
doc_list=doc_list+[f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n" + df.to_markdown()],
**kwargs,
)
self.log(title="Final SQL Prompt", message=prompt)
llm_response = self.submit_prompt(prompt, **kwargs)
self.log(title="LLM Response", message=llm_response)
except Exception as e:
return f"Error running intermediate SQL: {e}"bash由前面生成的prompt的Response Guidelines的第2条可以看到,当提交上下文的内容不是很充分时,需要执行一个中间的查询操作,也就是提示词里会出现intermediate,然后执行extract_sql方法,这个方法就是当 LLM 生成响应时,通常不仅包含 SQL 查询,还可能包含解释、注释或其他文本。这个方法的作用就是从复杂的响应中准确提取出真正的 SQL 查询语句:
def extract_sql(self, llm_response: str) -> str:
import re
## Match CREATE TABLE ... AS SELECT
sqls = re.findall(r"\bCREATE\s+TABLE\b.*?\bAS\b.*?;", llm_response, re.DOTALL | re.IGNORECASE)
if sqls:
sql = sqls[-1]
self.log(title="Extracted SQL", message=f"{sql}")
return sql
## Match WITH clause (CTEs)
sqls = re.findall(r"\bWITH\b .*?;", llm_response, re.DOTALL | re.IGNORECASE)
if sqls:
sql = sqls[-1]
self.log(title="Extracted SQL", message=f"{sql}")
return sql
## Match SELECT ... ;
sqls = re.findall(r"\bSELECT\b .*?;", llm_response, re.DOTALL | re.IGNORECASE)
if sqls:
sql = sqls[-1]
self.log(title="Extracted SQL", message=f"{sql}")
return sql
## Match ```sql ... ```blocks
sqls = re.findall(r"```sql\s*\n(.*?)```", llm_response, re.DOTALL | re.IGNORECASE)
if sqls:
sql = sqls[-1].strip()
self.log(title="Extracted SQL", message=f"{sql}")
return sql
## Match any ```... ```code blocks
sqls = re.findall(r"```(.*?)```", llm_response, re.DOTALL | re.IGNORECASE)
if sqls:
sql = sqls[-1].strip()
self.log(title="Extracted SQL", message=f"{sql}")
return sql
return llm_responsebash他是使用正则化找到LLM返回内容中的SQL语句,例如大模型返回内容:
根据您的问题,我建议使用以下 SQL 查询:
```sql
SELECT customer_name, SUM(sales) as total_sales
FROM customers
GROUP BY customer_name
ORDER BY total_sales DESC
LIMIT 10;bash经过extract_sql处理后:
SELECT customer_name, SUM(sales) as total_sales FROM customers GROUP BY customer_name ORDER BY total_sales DESC LIMIT 10;bash再执行中间SQLdf = self.run_sql(intermediate_sql)得到数据表然后执行:
prompt = self.get_sql_prompt(
initial_prompt=initial_prompt,
question=question,
question_sql_list=question_sql_list,
ddl_list=ddl_list,
doc_list=doc_list+[f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n" + df.to_markdown()],
**kwargs,
)bash这个操作就是将表的信息放到doc_list中,生成prompt,这样就弥补了之前信息不充分的缺点。最后再提交一次给大模型,提取sql返回即可:
llm_response = self.submit_prompt(prompt, **kwargs)
return self.extract_sql(llm_response)bash上面就是ask方法中的generate_sql的逻辑,总结就是,根据用户额度提问,从DDL、Document和question_sql中分别检索返回top_k个相关的内容,然后组成初始的prompt,如果不需要生成中间SQL,则直接提取回答中的SQL并返回,若需要执行,则经过查询一些表格的内容信息并重新组成prompt传给大模型进行处理,最后经过提取回答中的SQL并返回。
接下来就是打印结果:
if print_results:
try:
Code = __import__("IPython.display", fromList=["Code"]).Code
display(Code(sql))
except Exception as e:
print(sql)bash这段代码以语法高亮的方式显示生成的 SQL 查询,以美观的格式显示生成的 SQL 查询,而不是简单地打印纯文本。
if self.run_sql_is_set is False:
print(
"If you want to run the SQL query, connect to a database first."
)
if print_results:
return None
else:
return sql, None, Nonebash最后就是执行SQL语句、绘制表格和数据图:
if self.run_sql_is_set is False:
print(
"If you want to run the SQL query, connect to a database first."
)
if print_results:
return None
else:
return sql, None, None
try:
df = self.run_sql(sql)
if print_results:
try:
display = __import__(
"IPython.display", fromList=["display"]
).display
display(df)
except Exception as e:
print(df)
if len(df) > 0 and auto_train:
self.add_question_sql(question=question, sql=sql)
## Only generate plotly code if visualize is True
if visualize:
try:
plotly_code = self.generate_plotly_code(
question=question,
sql=sql,
df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
)
fig = self.get_plotly_figure(plotly_code=plotly_code, df=df)
if print_results:
try:
display = __import__(
"IPython.display", fromlist=["display"]
).display
Image = __import__(
"IPython.display", fromlist=["Image"]
).Image
img_bytes = fig.to_image(format="png", scale=2)
display(Image(img_bytes))
except Exception as e:
fig.show()
except Exception as e:
## Print stack trace
traceback.print_exc()
print("Couldn't run plotly code: ", e)
if print_results:
return None
else:
return sql, df, None
else:
return sql, df, None
except Exception as e:
print("Couldn't run sql: ", e)
if print_results:
return None
else:
return sql, None, None
return sql, df, figbash至此,Vanna的两大关键方法已经解读地差不多了。