Soup's Blog

Back

Text2SQL(二)Vanna源码解读之train和askBlur image

本节源码基于官方

由上一章节可以知道Vanna主要由两个函数起作用,第一个是train函数,第二个是ask函数,这两个函数都封装在VannaBase类中。

  • VannaBase 是Vanna框架的核心基类,它提供了一个完整的文本到SQL生成系统的架构。
  • train 函数用于训练Vanna模型,使其能够更好地将自然语言问题转换为SQL查询。
  • ask 函数是Vanna的主要交互接口,用于回答用户提出的自然语言问题。

train函数#

源码如下:

首先是方法定义部分:

def train(
        self,
        question: str = None,
        sql: str = None,
        ddl: str = None,
        documentation: str = None,
        plan: TrainingPlan = None,
    ) -> str:
bash

表示方法接收questionsqlddldocumentationplan,并返回一个字符串结果。如果不带参数调用,它会检查是否连接到数据库,并尝试在该数据库的元数据上进行训练。 如果使用sql参数调用,它等同于[vn.add_question_sql()][vanna.base.base.VannaBase.add_question_sql]。 如果使用ddl参数调用,它等同于[vn.add_ddl()][vanna.base.base.VannaBase.add_ddl]。 如果使用documentation参数调用,它等同于[vn.add_documentation()][vanna.base.base.VannaBase.add_documentation]。 此外,您可以传递一个[TrainingPlan][vanna.types.TrainingPlan]对象。使用[vn.get_training_plan_generic()][vanna.base.base.VannaBase.get_training_plan_generic]获取训练计划。

结合源码可以知道:

if ddl:
   print("Adding ddl:", ddl)
   return self.add_ddl(ddl)
if question and not sql:
   raise ValidationError("Please also provide a SQL query")
if documentation:
     print("Adding documentation....")
     return self.add_documentation(documentation)
bash

这些都比较简单就是简单的方法调用,不过这些都是抽象方法需要具体实现:

下面来看看:

if sql:
   if question is None:
       question = self.generate_question(sql)
       print("Question generated with sql:", question, "\nAdding SQL...")
   return self.add_question_sql(question=question, sql=sql)
bash

当传入SQL语句,没有传入Question时,Vanna会自动根据SQL生成问题,效果如下: 在这里插入图片描述 深入看一下generate_question方法:

def generate_question(self, sql: str, **kwargs) -> str:
     response = self.submit_prompt(
         [
             self.system_message(
                 "The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question."
             ),
             self.user_message(sql),
         ],
         **kwargs,
     )

     return response
bash

方法很简单,就是构造一个系统提示词和用户提供的SQL语句,然后传给抽象方法submit_prompt,这个方法是由我们自己实现的,在QwenLLM类中:

def submit_prompt(self,prompt,**kwargs):
    resp=Generation.call(
      model=self.model,
      messages=prompt,
      seed=random.randint(1, 10000),
      result_format='message',
      api_key=self.api_key)
    answer=resp.output.choices[0].message.content
    global DEBUG_INFO
    DEBUG_INFO=(prompt,answer)
    return answer
bash

submit_prompt方法收到了prompt,然后调用模型进行推理解析并返回answer。此时,我们已经通过SQL生成了Question,然后调用add_question_sql(question=question, sql=sql)方法,这个方法是将Question-SQL对添加到训练数据中,返回唯一的ID,这个抽象方法需要我们具体实现,保存的训练数据格式是一个json列表,每个json是一个训练样本,例如:

{
      "question":"what are 5 most grossing movies in IMDB top 1000 ",
      "answer":"SELECT series_title,\n       gross\nFROM   imdb.public.movies\nORDER BY gross desc limit 5;"
    }
bash

接下来看train方法中比较复杂的一部分:

if plan:
   for item in plan._plan:
       if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
           self.add_ddl(item.item_value)
       elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
           self.add_documentation(item.item_value)
       elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL:
           self.add_question_sql(question=item.item_name, sql=item.item_value)
bash

传入的plan是一个TrainingPlan类的实例,TrainingPlan类是Vanna框架中用于管理和组织训练数据的核心类,它代表了一个结构化的训练计划。其中_plan是一个TrainingPlanItem的列表,TrainingPlanItem包括以下属性:

  • item_type: 训练项类型(SQL查询、DDL语句、信息模式)
  • item_group: 训练项分组(如数据库名.模式名)
  • item_name: 训练项名称(如表名)
  • item_value: 训练项具体内容

item.item_type==sql为例,将它的值执行抽象方法add_ddl,并返回唯一的ID。执行完train方法的这一段代码后,Vanna AI模型将会:

  • 学习到数据库的表结构(通过DDL)
  • 获得额外的上下文信息(通过文档)
  • 掌握更多问题与SQL查询的对应关系(通过问答对)
  • 提升将自然语言转换为SQL查询的准确率

这实际上是批量训练模型的过程,将训练计划中所有类型的训练数据都添加到模型的检索层中,以增强模型的性能。add_ddladd_documentationadd_question_sql需要在子类中具体实现,比如使用ChromaDB作为向量存储时,会在vanna.chromadb_vector.ChromaDB_VectorStore类中实现。

训练好的数据在生成SQL时会被检索和使用:

get_similar_question_sql():检索相似的问题-SQL对
get_related_ddl():检索相关的DDL语句
get_related_documentation():检索相关的文档
bash

这些检索到的信息会被组合成提示词,提供给大语言模型生成最终的SQL查询。

非常好,Vanna提供了该子类的实现代码,接下来我们需要在ChromaDB_VectorStore学习。

ChromaDB_VectorStore类#

ChromaDB_VectorStore继承VannaBase类,实现VannaBase类提供的一些抽象方法:

path = config.get("path", ".")
bash

path: ChromaDB 数据持久化存储的路径。默认值为当前目录(”.”),即在当前目录下创建和存储向量数据库。

self.embedding_function = config.get("embedding_function", default_ef)
bash

embedding_function: 用于将文本转换为向量的嵌入函数。默认使用 DefaultEmbeddingFunction。

curr_client = config.get("client", "persistent")
bash

client: ChromaDB 客户端类型。“persistent”: 持久化客户端,数据存储在磁盘上(默认)。“in-memory”: 内存客户端,数据仅在内存中,程序退出后丢失。chromadb.api.client.Client: 直接提供已配置的客户端实例。

collection_metadata = config.get("collection_metadata", None)
bash

collection_metadata: 集合的元数据信息。用于存储集合的额外信息,如版本号、创建时间等。

self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))
self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
bash

查询结果数量配置:

  • n_results_sql: 查询相似 SQL 问题时返回的结果数量,默认10个
  • n_results_documentation: 查询相关文档时返回的结果数量,默认10个
  • n_results_ddl: 查询相关 DDL 语句时返回的结果数量,默认10个
if curr_client == "persistent":
    self.chroma_client = chromadb.PersistentClient(
        path=path, settings=Settings(anonymized_telemetry=False)
    )
elif curr_client == "in-memory":
    self.chroma_client = chromadb.EphemeralClient(
        settings=Settings(anonymized_telemetry=False)
    )
bash

根据配置创建相应的 ChromaDB 客户端。

self.documentation_collection = self.chroma_client.get_or_create_collection(
    name="documentation",
    embedding_function=self.embedding_function,
    metadata=collection_metadata,
)
## 类似地创建 ddl_collection 和 sql_collection
bash

创建三个向量集合(Collections)用于存储不同类型的数据:

  • documentation collection: 存储文档信息
  • ddl collection: 存储数据定义语言(表结构等)
  • sql collection: 存储问题和 SQL 查询对

下面还是以add_ddl方法为例,由train函数中的self.add_ddl(item.item_value)将值传入进来,然后经过:

def add_ddl(self, ddl: str, **kwargs) -> str:
        id = deterministic_uuid(ddl) + "-ddl"
        self.ddl_collection.add(
            documents=ddl,
            embeddings=self.generate_embedding(ddl),
            ids=id,
        )
        return id
bash

首先生成唯一的ID,然后收集documents、对应的embedding结果和ID。同样的收集question_sqldocumentation

执行get_training_data方法的结果如下:

返回的内容是一个表格,可以看到包括四个属性idquestioncontent training_data_type ,值得注意的是,对于ddldocumentation类型的训练数据,是没有question值的。

ChromaDB_VectorStore类的方法里还有remove_training_dataremove_collection方法,可以删除数据。

具体的训练流程就是,当在VannaBase基类的train方法中设置plan为True时,它会将接收到新的ddldocumentationquestion-sql对,以及它们对应的embedding值存储到向量库中。如果提供json格式的训练集,就只需要对该训练集进行遍历存储到向量库中即可。

ask函数#

接下来分析另一个重要的函数ask函数,首先看它的初始化:

def ask(
        self,
        question: Union[str, None] = None,
        print_results: bool = True,
        auto_train: bool = True,
        visualize: bool = True,  ## if False, will not generate plotly code
        allow_llm_to_see_data: bool = False,
    ) -> Union[
        Tuple[
            Union[str, None],
            Union[pd.DataFrame, None],
            Union[plotly.graph_objs.Figure, None],
        ],
        None,
    ]:
bash
  • question (Union[str, None]) - 用户要询问的问题字符串,如果为 None 则会提示用户输入
  • print_results (bool) - 是否打印结果,默认为 True
  • auto_train (bool) - 是否自动训练,默认为 True,会将问题和 SQL 查询对添加到训练数据中
  • visualize (bool) - 是否生成图表,默认为 True,会根据数据生成 Plotly 图表
  • allow_llm_to_see_data (bool) - 是否允许 LLM 查看数据,默认为 False

具体返回一个三元组 (Tuple) 或 None:

  • SQL 查询字符串 (str) - 生成的 SQL 查询语句
  • 数据结果 (pd.DataFrame) - SQL 查询执行后的结果数据
  • 图表对象 (plotly.graph_objs.Figure) - 根据数据生成的可视化图表
try:
    sql = self.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data)
except Exception as e:
    print(e)
    return None, None, None
bash

根据用户的提问生成sql语句,并确定是否让LLM看到数据。查看generate_sql函数:

其中:

question_sql_list = self.get_similar_question_sql(question, **kwargs)
ddl_list = self.get_related_ddl(question, **kwargs)
doc_list = self.get_related_documentation(question, **kwargs)
bash

调用三个关键方法获取生成 SQL 所需的上下文:

  • get_similar_question_sql():获取相似的问题-SQL 对
  • get_related_ddl():获取相关的数据定义语言(表结构)
  • get_related_documentation():获取相关的文档说明

这三个方法在子类ChromaDB_VectorStore中实现,不同的向量数据库有不同的实现,这里简单看get_similar_question_sql方法的具体实现:

def get_similar_question_sql(self, question: str, **kwargs) -> list:
        return ChromaDB_VectorStore._extract_documents(
            self.sql_collection.query(
                query_texts=[question],
                n_results=self.n_results_sql,
            )
        )
bash

它是将sql_collection.query的结果传入_extract_documents方法。sql_collection.query方法是如何实现的?这里挖个坑。 _extract_documents方法将 ChromaDB 查询返回的原始数据结构转换为可以直接使用的文档列表。主要用于处理以下三种查询的返回结果:

## 1. 获取相似问题-SQL对
self.get_similar_question_sql(question) 
## 返回: [{"question": "...", "sql": "..."}, {...}]

## 2. 获取相关DDL语句
self.get_related_ddl(question)
## 返回: ["CREATE TABLE ...", "CREATE TABLE ..."]

## 3. 获取相关文档
self.get_related_documentation(question)
## 返回: ["文档内容1", "文档内容2"]
bash

保存为question_sql_list,同样的还有ddl_listdoc_list,然后一起组装成prompt:

prompt = self.get_sql_prompt(
         initial_prompt=initial_prompt,
         question=question,
         question_sql_list=question_sql_list,
         ddl_list=ddl_list,
         doc_list=doc_list,
         **kwargs,
     )
bash

具体实现如下:

然后将生成的prompt提交llm_response = self.submit_prompt(prompt, **kwargs)到大模型得到返回结果,继续执行代码:

由前面生成的promptResponse Guidelines的第2条可以看到,当提交上下文的内容不是很充分时,需要执行一个中间的查询操作,也就是提示词里会出现intermediate,然后执行extract_sql方法,这个方法就是当 LLM 生成响应时,通常不仅包含 SQL 查询,还可能包含解释、注释或其他文本。这个方法的作用就是从复杂的响应中准确提取出真正的 SQL 查询语句:

他是使用正则化找到LLM返回内容中的SQL语句,例如大模型返回内容:

根据您的问题,我建议使用以下 SQL 查询:

```sql
SELECT customer_name, SUM(sales) as total_sales 
FROM customers 
GROUP BY customer_name 
ORDER BY total_sales DESC 
LIMIT 10;
bash

经过extract_sql处理后:

SELECT customer_name, SUM(sales) as total_sales FROM customers GROUP BY customer_name ORDER BY total_sales DESC LIMIT 10;
bash

再执行中间SQLdf = self.run_sql(intermediate_sql)得到数据表然后执行:

prompt = self.get_sql_prompt(
        initial_prompt=initial_prompt,
        question=question,
        question_sql_list=question_sql_list,
        ddl_list=ddl_list,
        doc_list=doc_list+[f"The following is a pandas DataFrame with the results of the intermediate SQL query {intermediate_sql}: \n" + df.to_markdown()],
        **kwargs,
    )
bash

这个操作就是将表的信息放到doc_list中,生成prompt,这样就弥补了之前信息不充分的缺点。最后再提交一次给大模型,提取sql返回即可:

llm_response = self.submit_prompt(prompt, **kwargs)
return self.extract_sql(llm_response)
bash

上面就是ask方法中的generate_sql的逻辑,总结就是,根据用户额度提问,从DDLDocumentquestion_sql中分别检索返回top_k个相关的内容,然后组成初始的prompt,如果不需要生成中间SQL,则直接提取回答中的SQL并返回,若需要执行,则经过查询一些表格的内容信息并重新组成prompt传给大模型进行处理,最后经过提取回答中的SQL并返回。

接下来就是打印结果:

if print_results:
   try:
       Code = __import__("IPython.display", fromList=["Code"]).Code
       display(Code(sql))
   except Exception as e:
       print(sql)
bash

这段代码以语法高亮的方式显示生成的 SQL 查询,以美观的格式显示生成的 SQL 查询,而不是简单地打印纯文本。

if self.run_sql_is_set is False:
   print(
       "If you want to run the SQL query, connect to a database first."
   )

   if print_results:
       return None
   else:
       return sql, None, None
bash

最后就是执行SQL语句、绘制表格和数据图:

至此,Vanna的两大关键方法已经解读地差不多了。

Text2SQL(二)Vanna源码解读之train和ask
http://www.soupcola.top/blog/text2sql/text2sql_blogs-2
Author Soup Cola
Published at 2026年1月31日