Implementing filtering in a RetrievalQA chain
I have been following this tutorial, implementing RetrievalQA from Langchain with a large language model (LLM) from the Azure OpenAI API. My implementation is coming along well; here is the code snippet I am working on:
import os
# env variables
os.environ["OPENAI_API_TYPE"] = "azure"
os.environ["OPENAI_API_VERSION"] = "<YOUR_API_VERSION>"
os.environ["OPENAI_API_KEY"] = "<YOUR_API_KEY>"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<SPACE_NAME>.openai.azure.com/"
# library imports
import pandas as pd
from langchain.prompts import PromptTemplate
from langchain.chains.router.llm_router import LLMRouterChain,RouterOutputParser
from langchain.embeddings import GPT4AllEmbeddings
from langchain.llms import AzureOpenAI
from langchain.chat_models import AzureChatOpenAI
from langchain.chains import RetrievalQA
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.document_loaders import DataFrameLoader
from langchain.text_splitter import (RecursiveCharacterTextSplitter,
CharacterTextSplitter)
from langchain.vectorstores import Chroma
from langchain.vectorstores import utils as chromautils
from langchain.embeddings import (HuggingFaceEmbeddings, OpenAIEmbeddings,
SentenceTransformerEmbeddings)
from langchain.callbacks import get_openai_callback
#
# toy = 'Search in the documents and find a toy that teaches about color to kids'
toy = 'Search in the documents and find a toy with cards that has monsters'
all_docs = pd.read_csv(data) # data is the dataset from the tutorial (see above)
print('Model init \u2713')
print('----> Azure OpenAI \u2713')
llm_open = AzureChatOpenAI(
model="GPT3",
max_tokens = 100
)
print('Create docs \u2713')
loader = DataFrameLoader(all_docs,
page_content_column='description' # column description in data
)
my_docs = loader.load()
print('Create splits \u2713')
text_splitter = CharacterTextSplitter(chunk_size=512,
chunk_overlap=0
)
all_splits = text_splitter.split_documents(my_docs)
print('Init embeddings \u2713')
chroma_docs = chromautils.filter_complex_metadata(all_splits)
# embeddings = HuggingFaceEmbeddings()
my_model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
embeddings = SentenceTransformerEmbeddings(model_name=my_model_name)
print('Create Chromadb \u2713')
vectorstore = Chroma.from_documents(all_splits,
embeddings,
# metadatas=[{"source": f"{i}-pl"} for i in \
# range(len(all_splits))]
)
print('Create QA chain \u2713')
qa_chain = RetrievalQA.from_chain_type(
llm=llm_open,
chain_type="stuff",
retriever=vectorstore.as_retriever(search_kwargs={"k": 10}),
verbose=True,)
print('*** YOUR ANSWER: ***')
with get_openai_callback() as cb:
llm_res = qa_chain.run(toy)
plpy.notice(f'{llm_res}')
plpy.notice(f'Total Tokens: {cb.total_tokens}')
plpy.notice(f'Prompt Tokens: {cb.prompt_tokens}')
plpy.notice(f'Completion Tokens: {cb.completion_tokens}')
    plpy.notice(f'Total Cost (USD): ${cb.total_cost}')
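As background for the "sources" half of my question: I import `RetrievalQAWithSourcesChain` above but have not wired it in yet. My understanding (a sketch under assumptions, not verified) is that the chain expects each document to carry a `source` metadata key, which `DataFrameLoader` does not set by default. Something like the following is what I have in mind; the `Doc` class is a minimal stand-in for langchain's `Document` so the sketch is self-contained, and `product_name` is an assumed column name from my dataframe:

```python
from dataclasses import dataclass, field

# Minimal stand-in for langchain's Document so this sketch runs on its own;
# in the real code these objects come out of loader.load() / split_documents().
@dataclass
class Doc:
    page_content: str
    metadata: dict = field(default_factory=dict)

def tag_sources(docs, source_field="product_name"):
    """Copy an identifying metadata field into the "source" key that
    RetrievalQAWithSourcesChain looks for. The "product_name" field name
    is an assumption about my dataframe's columns."""
    for doc in docs:
        doc.metadata["source"] = doc.metadata.get(source_field, "unknown")
    return docs

docs = tag_sources([Doc("a monster card game", {"product_name": "MonsterCards"})])
print(docs[0].metadata["source"])  # -> MonsterCards
```

The chain itself would then presumably be built with `RetrievalQAWithSourcesChain.from_chain_type(...)` in place of `RetrievalQA.from_chain_type(...)`, but I have not gotten that far.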
In the tutorial, there is a part that filters products by minimum and maximum price through a SQL query. However, I am not sure how to implement something similar with RetrievalQA in Langchain while still being able to retrieve the relevant source information. The specific part of the tutorial I am referring to is:
results = await conn.fetch("""
WITH vector_matches AS (
SELECT product_id,
1 - (embedding <=> $1) AS similarity
FROM product_embeddings
WHERE 1 - (embedding <=> $1) > $2
ORDER BY similarity DESC
LIMIT $3
)
SELECT product_name,
list_price,
description
FROM products
WHERE product_id IN (SELECT product_id FROM vector_matches)
AND list_price >= $4 AND list_price <= $5
""",
qe, similarity_threshold, num_matches, min_price, max_price)
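For what it's worth, here is a hedged sketch of how I imagine the SQL price clause (`list_price >= $4 AND list_price <= $5`) might map onto a Chroma-style metadata filter passed through `search_kwargs`; the `list_price` metadata key and the `$and`/`$gte`/`$lte` operator syntax are my assumptions based on Chroma's `where`-filter format, not something shown in the tutorial:

```python
# Sketch: build a Chroma-style metadata filter mirroring the SQL clause
#   list_price >= $4 AND list_price <= $5
def price_filter(min_price, max_price):
    # Chroma "where" filters combine conditions with $and and use
    # $gte/$lte comparison operators on metadata fields (assumed syntax).
    return {"$and": [
        {"list_price": {"$gte": min_price}},
        {"list_price": {"$lte": max_price}},
    ]}

# The filter would then be handed to the retriever, e.g.:
# retriever = vectorstore.as_retriever(
#     search_kwargs={"k": 10, "filter": price_filter(10.0, 50.0)}
# )
print(price_filter(10.0, 50.0))
```

This would only work if `list_price` actually survives into each split's metadata (which also presupposes it is kept by `filter_complex_metadata`), and I am not sure it does, hence my question.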
How can I implement this filtering with the RetrievalQA chain in Langchain, while also getting the source information for the filtered products?