利用Transformer、DPR、FAISS和BART对检索增强生成(RAG)进行深入技术探索
|文末点击阅读原文查看网页版|
更多专栏文章点击查看:
LLM 架构专栏
大模型架构专栏文章阅读指南
Agent系列
强化学习系列
欢迎加入大模型交流群:加群链接 https://docs.qq.com/doc/DS3VGS0NFVHNRR0Ru#
公众号【柏企阅文】
个人网站:https://www.chenbaiqi.com
简介
RAG代表检索增强生成(Retrieval-Augmented Generation)。这是一种巧妙的设置,在这种设置中,Transformer模型(所有GPT背后的核心)并非只是凭空编造内容,它实际上会去搜索真实的信息,并在回答问题之前将其检索回来。
在这篇文章中,将逐步为你介绍它的工作原理。密集段落检索(Dense Passage Retrieval,DPR)起着关键作用,它使用在问答数据集上训练的模型进行智能编码。DPR使用基于BERT的编码器,该编码器从标记化开始处理文本,然后应用嵌入、注意力机制和多个Transformer层来生成最终的向量表示(嵌入)。我们将这种编码应用于用户的问题以及内部文档或段落。这将产生两组嵌入。为了找到最相关的段落,我们使用由Facebook开发的FAISS,它通过相似性度量来比较这些嵌入。检索到的相关上下文随后会被传递给一个生成器模型,该模型会生成一个精确且有依据的回复。
用例:仓库运营助手
问题
有人问你的人工智能助手:“在仓库中应如何存储易碎物品?”
挑战
答案不在公共博客或教科书中,它深藏在你的内部仓库手册和处理程序中,而这些是人工智能模型在训练期间从未见过的内容。
解决方案
以下是我们如何使用简单的RAG(检索增强生成)来解决这个问题:
1. 索引步骤
- 分块:将长文档分解为更小的、易于管理的文本块,以便于处理。
- 编码(DPR)-编码器:使用密集通道检索编码器将每个块转换为密集向量。
- 矢量数据库(FAISS):矢量表示存储在矢量数据库中
2. 检索步骤
当用户提交查询时,RAG使用索引阶段的相同编码模型将查询转换为向量表示。然后它计算查询向量和之前创建的文档块向量之间的相似度分数。基于这些分数,系统检索与查询最相似的顶部K个块。
3. 答案生成(BART或GPT)步骤
接下来,使用GPT-2或BART等生成器模型。该模型将原始问题与顶部检索到的上下文块一起生成相关且准确的自然语言答案。
这篇文章为你提供了RAG工作原理的完整技术剖析,没有冗长的废话,只有扎实的解释。并且,是的,我们将使用简单的英语、实际数字和“基于单词的图表”来讲解每个矩阵、架构细节和真实世界的示例,这样即使你没有人工智能博士学位,也能跟上每一步。
详细的组件逐个流程
1. 索引步骤
我们将长文档分成可管理的块(段落或段落),然后使用DPR上下文编码器对每个块进行编码。这些密集矢量嵌入存储在FAISS索引中,以便以后高效检索。
第1步:标记化(用于上下文块)
- 输入:类似“易碎物品应单独存储”的段落。
- 输出:令牌化为:input_ids形状(1,256)-(批量大小为1,填充/截断为256个令牌)
- 此步骤1使用DPRContextEncoderTokenizer来完成此部分:
# 模型名称
DPR_CTX_ENCODER = "facebook/dpr-ctx_encoder-single-nq-base"
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(DPR_CTX_ENCODER)
从第2步开始由DPRContextEncoder执行。现在让我们探索DPRContextEncoder内部发生了什么。
第2步:嵌入层
- 每个令牌都映射到一个768维向量(用于基于BERT的DPR)。
- 输出形状:(1, 256, 768)(批次大小、序列长度、嵌入大小)
第3步:位置编码
- 位置编码被添加到令牌嵌入中。
- 变压器编码器的最终输入:(1, 256, 768)-仍然是相同的形状,但现在包含内容和位置信息。
第4步:变压器编码器内部(DPR / BERT)
- DPRContextEncoder将标记化文本转换为密集嵌入以进行语义检索。
- 它使用基于BERT的架构,包括嵌入、注意力和转换器层。
- 与一般BERT不同,DPR通过对比学习进行训练,以获得更好的通道检索。
- 关键区别:
- BERT = 一般NLP任务
- DPR = 针对基于相似度的检索进行了优化
现在让我们进一步探索第4步内部
Self-Attention(12头)-步骤4的一部分
对于每个令牌:
每个头使用单独的权重矩阵计算自己的Query(Q)、Key(K)和Value(V)投影。这就是为什么每个的形状是(1,256,64)(1, 256, 64)-64来自将768除以12个头。
并且每个头学会专注于输入的不同方面。有些头可能会学习语法,有些头专注于关系,有些头只是看起来在工作。但总的来说,它们捕获了不同的模式,然后将其组合起来,使模型有更丰富的理解。
缩放点积注意力(每个头)-步骤4的一部分:
其中:
- Q,K,V → (1,256,64)
- dk → 64是每个头的关键向量的维度
- 结果形状:(1,256,64)
多头注意力(合并所有头部)-步骤4的一部分:
如果你有12个头,就变成→(1,256,768)
每个头部独立应用注意力机制,输出沿最后一个维度连接。
前馈网络(FFN)-步骤4的一部分:
Linear(768 → 3072) → ReLU → Linear(3072 → 768)
对12个变压器层重复此过程,保持输出形状:(1,256,768)
池化-步骤4的一部分
- 通过转换器传递一个文本块后,我们得到一个形状
(1, 256, 768)
的输出-表示256个标记,每个标记都有一个768维向量。但是为了检索,我们不需要所有256个向量。相反,我们希望一个固定大小的向量使用池化来表示整个块。 - 我们使用**[CLS]令牌输出**来表示整个输入块:
output[:, 0, :] → (1, 768)
- 这会选择第一个标记的向量(位置0),它充当块级嵌入。
- 如果你有76个块,请为每个块重复编码步骤。最终结果是一个上下文矩阵:
context_matrix = (76, 768)
这意味着我们现在有76个向量嵌入,每个大小为768,可以添加到FAISS索引中进行检索。
我们不需要担心所有这些数学细节,所有这个索引步骤都可以使用以下代码(这只是一个快照,最终代码在最后):
import faiss
def build_faiss_index(embeddings):
"""从嵌入中创建并返回FAISS索引。"""
dim = embeddings.shape[1]
index = faiss.IndexFlatL2(dim)
index.add(embeddings.astype('float32'))
return index
DPR_CTX_ENCODER = "facebook/dpr-ctx_encoder-single-nq-base"
# 加载预训练的密集段落检索(DPR)上下文编码器和分词器
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(DPR_CTX_ENCODER)
ctx_encoder = DPRContextEncoder.from_pretrained(DPR_CTX_ENCODER)
# 对输入文本进行分词
inputs = ctx_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
# 对输入进行编码并收集[CLS]令牌输出
with torch.no_grad():
outputs = ctx_encoder(**inputs)
# 追加池化输出(形状:[1, 768])
embeddings.append(outputs.pooler_output)
index = build_faiss_index(embeddings)
2. 检索步骤
第一步:问题编码(同上)
- 输入:“我应该如何存储易碎物品?”
- 标记化、嵌入、编码→输出:
(1, 768)
DPR_Q_ENCODER = "facebook/dpr-question_encoder-single-nq-base"
# 加载预训练的DPR问题编码器和分词器
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(DPR_Q_ENCODER)
q_encoder = DPRQuestionEncoder.from_pretrained(DPR_Q_ENCODER)
# 对问题进行分词
inputs = q_tokenizer(question, return_tensors='pt')
# 推理时禁用梯度
with torch.no_grad():
outputs = q_encoder(**inputs)
# 返回[CLS]令牌嵌入
q_embedding = outputs.pooler_output.cpu().numpy()
第2步:相似度和排名前K
import faiss
def search(index, query_embedding, top_k=3):
"""返回前k个相似嵌入的距离和索引。"""
distances, indices = index.search(query_embedding, top_k)
return distances[0], indices[0]
# 在FAISS索引中执行相似度搜索以获取最匹配的段落
top_k = 3
distances, indices = index.search(q_embedding, top_k)
top_context = paragraphs[top_indices[0]]
3. 答案生成步骤
在进入步骤之前,只是关于解码器的一点细节:
- 一次生成一个令牌。
- 屏蔽自我注意力:解码器只关注以前生成的标记(而不是未来的标记),保持自回归行为。
- 交叉注意:解码器的查询(Q)涉及编码器的输出(K, V)-这是输入上下文的编码表示。
- 前馈网络→Softmax→Next Token:解码器通过前馈层处理参与输出,应用softmax获取概率,并选择下一个令牌。
第1步:从查询和返回的上下文构造BART Input或GPT2 Input
索引和检索步骤中的查询和返回上下文示例:
问题:我应该如何存放易碎物品?
上下文(从索引和检索步骤返回):
- “使用气泡膜。”
- “仅限顶部货架。”
- “永远不要堆叠标有易碎的物品”。
prompt = f"Context: {context}\n\nQuestion: {question}\nAnswer:"
tokenizer = AutoTokenizer.from_pretrained(GPT2)
model = AutoModelForCausalLM.from_pretrained(GPT2)
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
第2步:生成输出/答案
outputs = model.generate(**inputs, max_new_tokens=50)
矩阵形状摘要
每个组件的作用
这是检索增强生成的完整堆栈,变得简单而完整
代码逐步实现
我们将构建一个简单的RAG应用程序,执行以下操作:
- 从文本文件加载您的知识库(制造政策)。
- 使用DPR构建或加载段落嵌入的FAISS索引。(段落的编码步骤)
- 接受用户问题,将其编码成密集矢量。(问题编码)
- 在FAISS索引中搜索最相关的段落。(使用Faiss进行相似度匹配)
- 使用GPT-2或BART从该段落生成自然语言答案。(使用与问题一起生成的顶部上下文来生成答案)
- 在界面中显示检索到的上下文和AI生成的答案。
- 配置:config.py
# 文件路径
KB_PATH = "data/knowledge_base.txt"
FAISS_INDEX_PATH = "vector_store/faiss.index"
# 模型名称
DPR_CTX_ENCODER = "facebook/dpr-ctx_encoder-single-nq-base"
DPR_Q_ENCODER = "facebook/dpr-question_encoder-single-nq-base"
GPT2 = "gpt2"
BART = "facebook/bart-large"
- app.py
# 主应用程序
# app.py
import streamlit as st
from utils import load_paragraphs
from encoder import encode_paragraphs, encode_question
from retriever import build_faiss_index, search, load_faiss_index, save_faiss_index
from generator import generate_answer
from config import KB_PATH, FAISS_INDEX_PATH
import os
import numpy as np
# 加载知识库
paragraphs = load_paragraphs(KB_PATH)
if not paragraphs:
st.error("❗ knowledge_base.txt为空或缺失。请添加内容并重新启动。")
st.stop()
# 如果可用则加载现有的FAISS索引;否则使用DPR对段落进行编码,
# 从它们的嵌入中构建一个新的FAISS索引,并将其保存以供将来使用。
if os.path.exists(FAISS_INDEX_PATH):
index = load_faiss_index()
st.sidebar.success("已从文件加载FAISS索引。")
else:
st.sidebar.info("正在编码段落并构建FAISS索引...")
ctx_embeddings = encode_paragraphs(paragraphs)
index = build_faiss_index(ctx_embeddings)
save_faiss_index(index)
st.sidebar.success("FAISS索引已构建并保存。")
# Streamlit用户界面
st.title("RAG: 制造助手")
st.markdown("提出一个问题,并从您公司的政策知识库中获取答案。")
question = st.text_input("❓ 您的问题")
model_choice = st.selectbox("选择回答模型", ["gpt2", "bart"])
if question:
# 使用DPR问题编码器对用户的问题进行编码
q_embedding = encode_question(question)
# 在FAISS索引中执行相似度搜索以获取最匹配的段落
_, top_indices = search(index, q_embedding)
# 根据相似度检索最匹配的段落(上下文)
top_context = paragraphs[top_indices[0]]
st.subheader("检索到的上下文")
st.write(top_context)
# 使用选定的模型(GPT-2或BART)生成自然语言答案
st.subheader("答案")
answer = generate_answer(question, top_context, model_type=model_choice)
st.write(answer)
encoder.py
# DPR编码逻辑
# encoder.py
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
import torch
from config import DPR_CTX_ENCODER, DPR_Q_ENCODER
# 加载预训练的密集段落检索(DPR)上下文编码器和分词器
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(DPR_CTX_ENCODER)
ctx_encoder = DPRContextEncoder.from_pretrained(DPR_CTX_ENCODER)
# 加载预训练的DPR问题编码器和分词器
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(DPR_Q_ENCODER)
q_encoder = DPRQuestionEncoder.from_pretrained(DPR_Q_ENCODER)
def encode_paragraphs(paragraphs):
"""
使用DPR上下文编码器将段落列表编码为密集向量嵌入。
参数:
paragraphs (list of str): 要编码的文本块。
返回:
numpy.ndarray: 形状为 (num_paragraphs, 768) 的数组,包含密集嵌入。
"""
embeddings = []
for text in paragraphs:
# 对每个段落进行分词(带填充/截断)
inputs = ctx_tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=256)
# 推理时禁用梯度跟踪
with torch.no_grad():
outputs = ctx_encoder(**inputs)
# 从池化器中收集[CLS]令牌输出
embeddings.append(outputs.pooler_output)
# 将所有段落嵌入连接成一个单一的numpy数组
return torch.cat(embeddings).cpu().numpy()
def encode_question(question):
"""
使用DPR问题编码器将单个用户问题编码为密集向量。
参数:
question (str): 用户的自然语言查询。
返回:
numpy.ndarray: 形状为 (1, 768) 的数组,表示问题嵌入。
"""
# 对问题进行分词
inputs = q_tokenizer(question, return_tensors='pt')
# 推理时禁用梯度
with torch.no_grad():
outputs = q_encoder(**inputs)
# 返回[CLS]令牌嵌入
return outputs.pooler_output.cpu().numpy()
generator.py
# GPT-2 / BART生成
# generator.py
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BartTokenizer,
BartForConditionalGeneration
)
from config import GPT2, BART
import torch
def generate_answer(question, context, model_type="gpt2"):
"""使用指定的模型从上下文和问题生成答案。"""
if model_type == "gpt2":
prompt = f"Context: {context}\n\nQuestion: {question}\nAnswer:"
tokenizer = AutoTokenizer.from_pretrained(GPT2)
model = AutoModelForCausalLM.from_pretrained(GPT2)
elif model_type == "bart":
prompt = f"question: {question} context: {context}"
tokenizer = BartTokenizer.from_pretrained(BART)
model = BartForConditionalGeneration.from_pretrained(BART)
else:
raise ValueError("Model must be 'gpt2' or 'bart'.")
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
with torch.no_grad():
outputs = model.generate(**inputs, max_new_tokens=50)
return tokenizer.decode(outputs[0], skip_special_tokens=True)
retriever.py
# FAISS检索逻辑
# retriever.py
import faiss
import numpy as np
import os
from config import FAISS_INDEX_PATH
def build_faiss_index(embeddings):
"""从嵌入中创建并返回FAISS索引。"""
dim = embeddings.shape[1]
index = faiss.IndexFlatL2(dim)
index.add(embeddings.astype('float32'))
return index
def save_faiss_index(index, path=FAISS_INDEX_PATH):
faiss.write_index(index, path)
def load_faiss_index(path=FAISS_INDEX_PATH):
if os.path.exists(path):
return faiss.read_index(path)
else:
return None
def search(index, query_embedding, top_k=3):
"""返回前k个相似嵌入的距离和索引。"""
distances, indices = index.search(query_embedding, top_k)
return distances[0], indices[0]
参考文献
[1] 《检索增强生成在大语言模型中的应用:综述》https://arxiv.org/pdf/2312.10997
[2] 《检索增强生成(RAG)的全面综述:演变、当前格局和未来方向》https://arxiv.org/pdf/2410.12837
[3] 《什么是检索增强生成,即RAG?》https://blogs.nvidia.com/blog/what-is-retrieval-augmented-generation/
[4] 《Faiss》https://python.langchain.com/docs/integrations/vectorstores/faiss/
[5] 《DPR》https://huggingface.co/docs/transformers/en/model_doc/dpr
评论