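Contents
- Background
- Amazon Bedrock
- Multiple model choices
- Lab demo
- Listing the models supported by Amazon Bedrock
- Reading user comments and classifying them with an Amazon Bedrock model
- Using AWS embedding models for text processing and analysis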
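Background
In 2023, generative AI, large models, and ChatGPT seemed to be everywhere. But what exactly is generative AI? How does it differ from earlier AI, and how does it relate to large models and ChatGPT?
Generative artificial intelligence (gen AI) is a type of AI that can create new content and ideas, such as images and videos, and can reuse what it already knows to solve new problems.
Generative AI is mainly used for creative work, including but not limited to:
- Text generation: tools such as ChatGPT produce coherent text for writing, conversation, code generation, and more.
- Image generation: tools such as Midjourney and Stable Diffusion produce high-quality images from text descriptions.
- Music generation: AI can compose melodies, harmonies, and even complete pieces of music.
- Other applications: video generation, 3D model generation, design work, and so on.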
Amazon Bedrock
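Official documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html
Amazon Bedrock is a fully managed service that offers, through a single API, high-performing foundation models (FMs) from leading AI companies such as AI21 Labs, Anthropic, Cohere, Luma, Meta, Mistral AI, poolside (coming soon), Stability AI, and Amazon.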
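Multiple model choices
A range of leading foundation models (FMs) is available, and Amazon Bedrock Marketplace gives access to more than 100 further models.
Amazon Bedrock Marketplace is a new Amazon Bedrock capability that lets developers discover, test, and use more than 100 popular, emerging, and specialized foundation models (FMs) alongside the curated industry-leading models already in Amazon Bedrock. You can discover models in a single catalog, subscribe to them, and deploy them on managed endpoints. You can then access these models through Bedrock's unified API and use them natively with Bedrock tools such as Bedrock Agents, Bedrock Knowledge Bases, and Bedrock Guardrails.
The model catalog page in the Amazon Bedrock console is where users discover and choose models that fit their use case. Its main parts and functions are:
Model catalog:
The top of the page shows the total number of models currently available (140) and provides filtering and search.
Users can filter by "model collection" (serverless models and Bedrock Marketplace models), by provider (such as Anthropic, Hugging Face, and others), and by modality (audio, text, and so on).
For example, a user can choose to view only "serverless" models or only the models from a specific provider.
Because Amazon Bedrock is fully managed, you do not need to manage any infrastructure, and you can use the AWS services you already know to securely integrate and deploy generative AI capabilities in your applications.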
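Lab demo
The code below is the lab demo from an official AWS User Group event that I attended. Consider this free advertising for AWS; if anything here infringes on anyone's rights, contact me and I will remove it.
Official lab: https://dev.amazoncloud.cn/experience/cloudlab?id=6760dc6cac1c0261e65af73f
The code implements a customer feedback classification system built on AWS Bedrock and LangChain that automatically categorizes customer comments.
Key components
- AWS Bedrock: provides the embedding and chat models
- LangChain: builds the AI workflow
- Chroma: vector store
- Pandas: data processing
Listing the models supported by Amazon Bedrock
The following script uses the Boto3 library to interact with the Amazon Bedrock service.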
import boto3
import json
import copy
import pandas as pd
from termcolor import colored

# Create the Bedrock (control plane) and Bedrock runtime clients
bedrock = boto3.client(service_name='bedrock', region_name='us-east-1')
bedrock_runtime = boto3.client(service_name='bedrock-runtime', region_name='us-east-1')

# Show every row when the DataFrame is printed
pd.set_option('display.max_rows', None)
results = []
available_models = bedrock.list_foundation_models()
for model in available_models['modelSummaries']:
    # Keep only Amazon models that can produce text output
    if 'Amazon' in model['providerName'] and 'TEXT' in model['outputModalities']:
        results.append({
            'Model Name': model['modelName'],
            'Model ID': model['modelId'],  # Add Model ID column
            'Provider': model['providerName'],
            'Input Modalities': ', '.join(model['inputModalities']),
            'Output Modalities': ', '.join(model['outputModalities']),
            'Streaming': model.get('responseStreamingSupported', 'N/A'),
            'Status': model['modelLifecycle']['status']
        })
df = pd.DataFrame(results)
print(df)
# Restore the default row limit only after printing, so the full table is shown
pd.reset_option('display.max_rows')
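The script uses Boto3 to talk to Amazon Bedrock and stores the returned model information in a Pandas DataFrame.
- boto3.client creates two clients, one for the Bedrock service and one for the Bedrock runtime service, both in the us-east-1 region.
- pd.set_option('display.max_rows', None): tells Pandas to display all rows without a limit so the complete output is visible.
- df = pd.DataFrame(results): converts the result list into a Pandas DataFrame for later processing and analysis.
- print(df): prints the DataFrame to the console.
- pd.reset_option('display.max_rows'): restores the default row-display setting.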
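The filtering step can also be tried without AWS credentials. The sketch below is a minimal illustration, assuming the response entries have the same keys the script reads (providerName, outputModalities, and so on); filter_models and the sample entries are hypothetical and hand-written, not real API output.

```python
from typing import Any, Dict, List


def filter_models(
    model_summaries: List[Dict[str, Any]],
    provider: str = "Amazon",
    output_modality: str = "TEXT",
) -> List[Dict[str, Any]]:
    """Keep models from one provider that support the given output modality."""
    rows = []
    for model in model_summaries:
        if provider in model["providerName"] and output_modality in model["outputModalities"]:
            rows.append({
                "Model Name": model["modelName"],
                "Model ID": model["modelId"],
                "Provider": model["providerName"],
                # .get() tolerates entries that omit the streaming flag
                "Streaming": model.get("responseStreamingSupported", "N/A"),
                "Status": model["modelLifecycle"]["status"],
            })
    return rows


# Hand-written sample shaped like response["modelSummaries"]; the IDs are made up.
sample_summaries = [
    {"modelName": "Example Text Model", "modelId": "amazon.example-text-v1:0",
     "providerName": "Amazon", "inputModalities": ["TEXT"], "outputModalities": ["TEXT"],
     "responseStreamingSupported": True, "modelLifecycle": {"status": "ACTIVE"}},
    {"modelName": "Example Image Model", "modelId": "amazon.example-image-v1:0",
     "providerName": "Amazon", "inputModalities": ["TEXT"], "outputModalities": ["IMAGE"],
     "modelLifecycle": {"status": "ACTIVE"}},
    {"modelName": "Other Chat Model", "modelId": "othervendor.chat-v1:0",
     "providerName": "OtherVendor", "inputModalities": ["TEXT"], "outputModalities": ["TEXT"],
     "modelLifecycle": {"status": "LEGACY"}},
]

rows = filter_models(sample_summaries)
print(rows)
```

Keeping the filter in a plain function makes it easy to swap the provider or modality without touching the Boto3 call.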
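Reading user comments and classifying them with an Amazon Bedrock model
The code implements a complete workflow: it reads user comments, calls an Amazon Bedrock model to classify them, computes the classification accuracy, and saves the results. It shows how to use an AWS language model for a natural language processing task and how to use Pandas for the data processing and analysis.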
import boto3
import json
from botocore.exceptions import ClientError
import dotenv
import os
dotenv.load_dotenv()
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel, SimpleChatModel
from langchain_core.messages import AIMessageChunk, BaseMessage, HumanMessage,AIMessage,SystemMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import run_in_executor
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_core.output_parsers import StrOutputParser,XMLOutputParser
from langchain_core.prompts import ChatPromptTemplate,MessagesPlaceholder,HumanMessagePromptTemplate
class ChatModelNova(BaseChatModel):
    """Custom LangChain chat model that calls the Amazon Bedrock Converse API."""
    model_name: str
    br_runtime: Any = None
    ak: Optional[str] = None  # AWS access key ID (optional)
    sk: Optional[str] = None  # AWS secret access key (optional)
    region: str = "us-east-1"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # Create the Bedrock runtime client lazily on first use
        if not self.br_runtime:
            if self.ak and self.sk:
                self.br_runtime = boto3.client(
                    service_name='bedrock-runtime',
                    region_name=self.region,
                    aws_access_key_id=self.ak,
                    aws_secret_access_key=self.sk,
                )
            else:
                self.br_runtime = boto3.client(service_name='bedrock-runtime', region_name=self.region)

        # Convert LangChain messages into the Converse API message format
        new_messages = []
        system_message = ''
        for msg in messages:
            if isinstance(msg, SystemMessage):
                system_message = msg.content
            elif isinstance(msg, HumanMessage):
                new_messages.append({
                    "role": "user",
                    "content": [{"text": msg.content}]
                })
            elif isinstance(msg, AIMessage):
                new_messages.append({
                    "role": "assistant",
                    "content": [{"text": msg.content}]
                })

        temperature = kwargs.get('temperature', 0.1)
        maxTokens = kwargs.get('max_tokens', 3000)
        # Base inference parameters to use.
        inference_config = {"temperature": temperature, "maxTokens": maxTokens}

        # Send the message.
        response = self.br_runtime.converse(
            modelId=self.model_name,
            messages=new_messages,
            system=[{"text": system_message}] if system_message else [],
            inferenceConfig=inference_config
        )
        output_message = response['output']['message']
        message = AIMessage(
            content=output_message['content'][0]['text'],
            additional_kwargs={},  # Used to add additional payload (e.g., function calling request)
            response_metadata={  # Use for response metadata
                **response['usage']
            },
        )
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])
    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        # Create the Bedrock runtime client lazily on first use
        if not self.br_runtime:
            if self.ak and self.sk:
                self.br_runtime = boto3.client(
                    service_name='bedrock-runtime',
                    region_name=self.region,
                    aws_access_key_id=self.ak,
                    aws_secret_access_key=self.sk,
                )
            else:
                self.br_runtime = boto3.client(service_name='bedrock-runtime', region_name=self.region)

        # Convert LangChain messages into the Converse API message format
        new_messages = []
        system_message = ''
        for msg in messages:
            if isinstance(msg, SystemMessage):
                system_message = msg.content
            elif isinstance(msg, HumanMessage):
                new_messages.append({
                    "role": "user",
                    "content": [{"text": msg.content}]
                })
            elif isinstance(msg, AIMessage):
                new_messages.append({
                    "role": "assistant",
                    "content": [{"text": msg.content}]
                })

        temperature = kwargs.get('temperature', 0.1)
        maxTokens = kwargs.get('max_tokens', 3000)
        # Base inference parameters to use.
        inference_config = {"temperature": temperature, "maxTokens": maxTokens}

        # Send the message.
        streaming_response = self.br_runtime.converse_stream(
            modelId=self.model_name,
            messages=new_messages,
            system=[{"text": system_message}] if system_message else [],
            inferenceConfig=inference_config
        )

        # Turn each streamed event into a LangChain chunk as it arrives.
        for event in streaming_response["stream"]:
            if "contentBlockDelta" in event:
                text = event["contentBlockDelta"]["delta"]["text"]
                # print(text, end="")
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=[{"type": "text", "text": text}]))
                if run_manager:
                    # This is optional in newer versions of LangChain
                    # The on_llm_new_token will be called automatically
                    run_manager.on_llm_new_token(text, chunk=chunk)
                yield chunk
            if 'metadata' in event:
                metadata = event['metadata']
                # Let's add some other information (e.g., response metadata)
                chunk = ChatGenerationChunk(
                    message=AIMessageChunk(content="", response_metadata={**metadata})
                )
                if run_manager:
                    run_manager.on_llm_new_token('', chunk=chunk)
                yield chunk

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return "echoing-chat-model-advanced"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters.

        This information is used by the LangChain callback system for
        tracing purposes, which makes it possible to monitor LLMs.
        """
        return {
            "model_name": self.model_name,
        }
# The class field is named `region`, so pass it under that name
llm = ChatModelNova(region="us-east-1", model_name="amazon.nova-pro-v1:0")
messages = [
    (
        "system",
        "You are a helpful assistant that translates English to French. Translate the user sentence.",
    ),
    ("human", "I love programming."),
]
# Invoke once and print the result
print(llm.invoke(messages))
import pandas as pd
comments_filepath = "data/comments.csv"
comments = pd.read_csv(comments_filepath)
category_definition = "data/categories.csv"
categories = pd.read_csv(category_definition)
print(categories)
system = """You are a professional customer feedback analyst. Your daily task is to categorize user feedback.
You will be given an input in the form of a JSON array. Each object in the array contains a comment ID and a 'comment' field representing the user's comment content.
Your role is to analyze these comments and categorize them appropriately.
Please note:
1. Only output valid XML format data.
2. Do not include any explanations or additional text outside the XML structure.
3. Ensure your categorization is accurate and consistent.
4. If you encounter any ambiguous cases, use your best judgment based on the context provided.
5. Maintain a professional and neutral tone in your categorizations.
"""
user = """
Please categorize user comments according to the following category tags library:
<categories>
{tags}
</categories>
Please follow these instructions for categorization:
<instruction>
1. Categorize each comment using the tags above. If no tags apply, output "Others".
2. Summarize the comment content in no more than 50 words. Replace any double quotation marks with single quotation marks.
</instruction>
Below are the customer comments records to be categorized. The input is an array, where each element has an 'id' field representing the complaint ID and a 'comment' field containing the complaint content.
<comments>
{input}
</comments>
For each record, summarize the comment, categorize according to the category explanations, and return the ID, summary, reasons for tag matches, and category.
Output format example:
<output>
<item>
<id>xxx</id>
<summary>xxx</summary>
<reason>xxx</reason>
<category>xxx</category>
</item>
</output>
Skip the preamble and output only valid XML format data. Remember:
- Avoid double quotation marks within quotation marks. Use single quotation marks instead.
- Replace any double quotation marks in the content with single quotation marks.
"""
prompt = ChatPromptTemplate(
    [
        ('system', system),
        ('user', user),
    ],
    partial_variables={'tags': categories['mappings'].values}
)
chain = prompt | llm | XMLOutputParser()

# One string per comment; the first CSV column is assumed to hold the comment text
sample_data = [str({"id": 's_' + str(i), "comment": x[0]}) for i, x in enumerate(comments.values)]
print("\n".join(sample_data[:3]))

import math
import json
import time
from termcolor import colored

batch_size = 20
batch = math.ceil(comments.shape[0] / batch_size)
resps = []
for i in range(batch):
    print(colored(f"****[{i}]*****\n", "blue"))
    data = sample_data[i * batch_size:(i + 1) * batch_size]
    resp = chain.invoke(data)
    print(colored(f"****response*****\n{resp}", "green"))
    # Flatten each <item> (a list of single-key dicts) into one row
    for item in resp['output']:
        row = {}
        for it in item['item']:
            row[list(it.keys())[0]] = list(it.values())[0]
        resps.append(row)
    # Pause between batches
    time.sleep(10)
prediction_df = pd.DataFrame(resps).rename(columns={"category": "predict_label"}).drop_duplicates(['id']).reset_index(drop=True)
# Normalize the predicted labels: trim, lowercase, drop single quotes
prediction_df['predict_label'] = prediction_df['predict_label'].apply(lambda x: x.strip().lower().replace("'", ""))
ground_truth = comments.copy()
# Normalize the ground-truth labels: trim and lowercase
ground_truth['groundtruth'] = ground_truth['groundtruth'].apply(lambda x: x.strip().lower())
# Rows are joined by position, which assumes one prediction per comment, in the original order
merge_df = pd.concat([ground_truth, prediction_df], axis=1)

# Compute the accuracy
def check_contains(row):
    return str(row['groundtruth']) in str(row['predict_label'])

matches = merge_df.apply(check_contains, axis=1)
count = matches.sum()
print(colored(f"accuracy: {count/len(merge_df)*100:.2f}%", "green"))

# List all misclassified records
def check_not_contains(row):
    return str(row['groundtruth']) not in str(row['predict_label'])

print(merge_df[merge_df.apply(check_not_contains, axis=1)])

# Save the results
merge_df.to_csv('result_lab_1.csv', index=False)
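- Define the ChatModelNova class
  The class inherits from BaseChatModel and is used to interact with Amazon Bedrock chat models.
  model_name: the model name.
  br_runtime: the Bedrock runtime client.
  ak and sk: the AWS access key and secret key.
  region: the AWS region, us-east-1 by default.
- _generate method: sends messages and returns the response.
  Converts the input messages into the format Bedrock expects.
  Sends the request, processes the response, and returns the AI-generated message.
- _stream method: sends messages and receives the response as a stream.
  Similar to _generate, but supports real-time output.
- Use Pandas to read the CSV files containing the user comments and the category definitions.
- Define the system and user prompts
  The system prompt defines the role's task: analyze user feedback and categorize it.
  The user prompt contains the specific categorization requirements and the input format.
- Create the chat prompt template
  ChatPromptTemplate combines the system and user messages into one template.
- Process the comments in batches
  The comments are split into batches of 20.
  The model is called on each batch, and the results are collected in the resps list.
- Process the results
  The model predictions are merged with the ground-truth labels to compute the accuracy.
  check_contains checks whether the predicted label matches the ground-truth label and is used to compute the accuracy.
  check_not_contains is used to list all misclassified records.
- Save the results
  The merged results are saved to the CSV file result_lab_1.csv.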
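The batching, result flattening, and accuracy steps described above can be sketched with the standard library alone. This is a minimal illustration, not the lab's code: make_batches, flatten_items, and containment_accuracy are hypothetical helper names, and the parsed dictionary is hand-written in the nested shape that the script's parsing loop assumes the XML parser returns.

```python
import math
from typing import Any, Dict, List


def make_batches(items: List[Any], batch_size: int) -> List[List[Any]]:
    """Split items into consecutive batches of at most batch_size elements."""
    n_batches = math.ceil(len(items) / batch_size)
    return [items[i * batch_size:(i + 1) * batch_size] for i in range(n_batches)]


def flatten_items(parsed: Dict[str, Any]) -> List[Dict[str, str]]:
    """Turn {'output': [{'item': [{'id': ...}, {'category': ...}]}]} into flat rows."""
    rows = []
    for item in parsed["output"]:
        row: Dict[str, str] = {}
        for field in item["item"]:
            row.update(field)  # each field is a single-key dict
        rows.append(row)
    return rows


def containment_accuracy(truth: List[str], predicted: List[str]) -> float:
    """Share of rows whose normalized true label occurs inside the predicted label."""
    hits = sum(t.strip().lower() in p.strip().lower() for t, p in zip(truth, predicted))
    return hits / len(truth)


# Hand-written stand-in for one parsed model response (assumed shape, made-up content).
parsed = {"output": [
    {"item": [{"id": "s_0"}, {"summary": "Phone freezes"},
              {"reason": "mentions freezing"}, {"category": "Performance"}]},
    {"item": [{"id": "s_1"}, {"summary": "Late delivery"},
              {"reason": "no matching tag"}, {"category": "Others"}]},
]}

batches = make_batches(list(range(45)), 20)
rows = flatten_items(parsed)
accuracy = containment_accuracy(["performance", "shipping"], [r["category"] for r in rows])
print([len(b) for b in batches], rows[0]["id"], accuracy)
```

The containment check is lenient on purpose: a prediction such as "performance issue" still counts as a match for the label "performance", which mirrors the substring test in check_contains.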
Using AWS embedding models for text processing and analysis
import boto3
import json
import copy
import pandas as pd
from termcolor import colored
from langchain_aws.embeddings.bedrock import BedrockEmbeddings
bedrock_embedding = BedrockEmbeddings(model_id='amazon.titan-embed-text-v2:0',region_name = "us-east-1")
test_embedding = bedrock_embedding.embed_documents(['I love programing'])
print(f"The embedding dimension is {len(test_embedding[0])}, first 10 elements are: {test_embedding[0][:10]}")
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

category_definition = "data/categories.csv"
categories = pd.read_csv(category_definition)
# Load the customer comments
comments_filepath = "data/comments.csv"
comments = pd.read_csv(comments_filepath)
print(comments)
# Load the labeled examples
examples_filepath = "data/examples_with_label.csv"
examples_df = pd.read_csv(examples_filepath)

from langchain_chroma import Chroma

vector_store = Chroma(
    collection_name="example_collection",
    embedding_function=bedrock_embedding,
    persist_directory="./chroma_langchain_db",  # Where to save data locally, remove if not necessary
)
# Uncomment the line below to reset the vector DB if you want to start over
# vector_store.reset_collection()

from uuid import uuid4
import hashlib
from langchain_core.documents import Document

# Build the LangChain documents
documents = []
for comment, groundtruth in examples_df.values:
    documents.append(
        Document(
            page_content=comment,
            metadata={"groundtruth": groundtruth}
        )
    )
# Add the documents to the vector store; content hashes serve as stable IDs
hash_ids = [hashlib.md5(doc.page_content.encode()).hexdigest() for doc in documents]
vector_store.add_documents(documents=documents, ids=hash_ids)

# Pick one random comment and retrieve its most similar labeled examples
query = comments['comment'].sample(1).values[0]
print(colored(f"******query*****:\n{query}", "blue"))
results = vector_store.similarity_search_with_relevance_scores(query, k=4)
print(colored("\n\n******results*****", "green"))
for res, score in results:
    print(colored(f"* [SIM={score:.3f}] \n{res.page_content}\n{res.metadata}", "green"))
import boto3
import json
from botocore.exceptions import ClientError
import dotenv
import os
dotenv.load_dotenv()
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel, SimpleChatModel
from langchain_core.messages import AIMessageChunk, BaseMessage, HumanMessage,AIMessage,SystemMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import run_in_executor
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_core.output_parsers import StrOutputParser,XMLOutputParser
from langchain_core.prompts import ChatPromptTemplate,MessagesPlaceholder,HumanMessagePromptTemplate
class ChatModelNova(BaseChatModel):
    """Custom LangChain chat model that calls the Amazon Bedrock Converse API."""
    model_name: str
    br_runtime: Any = None
    ak: Optional[str] = None  # AWS access key ID (optional)
    sk: Optional[str] = None  # AWS secret access key (optional)
    region: str = "us-east-1"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # Create the Bedrock runtime client lazily on first use
        if not self.br_runtime:
            if self.ak and self.sk:
                self.br_runtime = boto3.client(
                    service_name='bedrock-runtime',
                    region_name=self.region,
                    aws_access_key_id=self.ak,
                    aws_secret_access_key=self.sk,
                )
            else:
                self.br_runtime = boto3.client(service_name='bedrock-runtime', region_name=self.region)

        # Convert LangChain messages into the Converse API message format
        new_messages = []
        system_message = ''
        for msg in messages:
            if isinstance(msg, SystemMessage):
                system_message = msg.content
            elif isinstance(msg, HumanMessage):
                new_messages.append({
                    "role": "user",
                    "content": [{"text": msg.content}]
                })
            elif isinstance(msg, AIMessage):
                new_messages.append({
                    "role": "assistant",
                    "content": [{"text": msg.content}]
                })

        temperature = kwargs.get('temperature', 0.5)
        maxTokens = kwargs.get('max_tokens', 3000)
        # Base inference parameters to use.
        inference_config = {"temperature": temperature, "maxTokens": maxTokens}

        # Send the message.
        response = self.br_runtime.converse(
            modelId=self.model_name,
            messages=new_messages,
            system=[{"text": system_message}] if system_message else [],
            inferenceConfig=inference_config
        )
        output_message = response['output']['message']
        message = AIMessage(
            content=output_message['content'][0]['text'],
            additional_kwargs={},  # Used to add additional payload (e.g., function calling request)
            response_metadata={  # Use for response metadata
                **response['usage']
            },
        )
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])
    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        # Create the Bedrock runtime client lazily on first use
        if not self.br_runtime:
            if self.ak and self.sk:
                self.br_runtime = boto3.client(
                    service_name='bedrock-runtime',
                    region_name=self.region,
                    aws_access_key_id=self.ak,
                    aws_secret_access_key=self.sk,
                )
            else:
                # Pass the region here too, so both code paths use the configured region
                self.br_runtime = boto3.client(service_name='bedrock-runtime', region_name=self.region)

        # Convert LangChain messages into the Converse API message format
        new_messages = []
        system_message = ''
        for msg in messages:
            if isinstance(msg, SystemMessage):
                system_message = msg.content
            elif isinstance(msg, HumanMessage):
                new_messages.append({
                    "role": "user",
                    "content": [{"text": msg.content}]
                })
            elif isinstance(msg, AIMessage):
                new_messages.append({
                    "role": "assistant",
                    "content": [{"text": msg.content}]
                })

        temperature = kwargs.get('temperature', 0.5)
        maxTokens = kwargs.get('max_tokens', 3000)
        # Base inference parameters to use.
        inference_config = {"temperature": temperature, "maxTokens": maxTokens}

        # Send the message.
        streaming_response = self.br_runtime.converse_stream(
            modelId=self.model_name,
            messages=new_messages,
            system=[{"text": system_message}] if system_message else [],
            inferenceConfig=inference_config
        )

        # Turn each streamed event into a LangChain chunk as it arrives.
        for event in streaming_response["stream"]:
            if "contentBlockDelta" in event:
                text = event["contentBlockDelta"]["delta"]["text"]
                # print(text, end="")
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=[{"type": "text", "text": text}]))
                if run_manager:
                    # This is optional in newer versions of LangChain
                    # The on_llm_new_token will be called automatically
                    run_manager.on_llm_new_token(text, chunk=chunk)
                yield chunk
            if 'metadata' in event:
                metadata = event['metadata']
                # Let's add some other information (e.g., response metadata)
                chunk = ChatGenerationChunk(
                    message=AIMessageChunk(content="", response_metadata={**metadata})
                )
                if run_manager:
                    run_manager.on_llm_new_token('', chunk=chunk)
                yield chunk

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return "echoing-chat-model-advanced"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters.

        This information is used by the LangChain callback system for
        tracing purposes, which makes it possible to monitor LLMs.
        """
        return {
            "model_name": self.model_name,
        }
# The class field is named `region`, so pass it under that name
llm = ChatModelNova(region="us-east-1", model_name="amazon.nova-pro-v1:0")
# Define the prompt templates
system = """You are a professional customer feedback analyst. Your daily task is to categorize user feedback.
You will be given an input in the form of a JSON array. Each object in the array contains a comment ID and a 'comment' field representing the user's comment content.
Your role is to analyze these comments and categorize them appropriately.
Please note:
1. Only output valid XML format data.
2. Do not include any explanations or additional text outside the XML structure.
3. Ensure your categorization is accurate and consistent.
4. If you encounter any ambiguous cases, use your best judgment based on the context provided.
5. Maintain a professional and neutral tone in your categorizations.
"""
user = """
Please categorize user comments according to the following category tags library:
<categories>
{tags}
</categories>
Here are labeled examples for your reference:
<examples>
{examples}
</examples>
Please follow these instructions for categorization:
<instruction>
1. Categorize each comment using the tags above. If no tags apply, output "Invalid Data".
2. Summarize the comment content in no more than 50 words. Replace any double quotation marks with single quotation marks.
</instruction>
Below are the customer comments records to be categorized. The input is an array, where each element has an 'id' field representing the complaint ID and a 'comment' field containing the complaint content.
<comments>
{input}
</comments>
For each record, summarize the comment, categorize according to the category explanations, and return the ID, summary, reasons for tag matches, and category.
Output format example:
<output>
<item>
<id>xxx</id>
<summary>xxx</summary>
<reason>xxx</reason>
<category>xxx</category>
</item>
</output>
Skip the preamble and output only valid XML format data. Remember:
- Avoid double quotation marks within quotation marks. Use single quotation marks instead.
- Replace any double quotation marks in the content with single quotation marks.
"""
from typing import List
from langchain_core.documents import Document
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables import RunnableParallel

# Retrieve the 5 most similar labeled examples for each query
retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 5})

def format_docs(docs):
    """Render retrieved examples as <content>/<category> pairs for the prompt."""
    formatted = "\n".join([f"<content>{doc.page_content}</content>\n<category>{doc.metadata['groundtruth']}</category>" for doc in docs])
    # print(colored(f"*****retrieved examples******:\n{formatted}","yellow"))
    return formatted

# Quick check that retrieval and formatting work together
retrieved_docs = (retriever | format_docs).invoke("my phone always freeze")

prompt = ChatPromptTemplate(
    [
        ('system', system),
        ('user', user),
    ],
    partial_variables={'tags': categories['mappings'].values}
)
# The input string is used both as the retrieval query and as the {input} variable
chain = (
    {"examples": retriever | format_docs, "input": RunnablePassthrough()}
    | prompt
    | llm
    | XMLOutputParser()
)
sample_data = [str({"id": 's_' + str(i), "comment": x[0]}) for i, x in enumerate(comments.values)]
print(sample_data[:3])
import math
import json

resps = []

def retry_call(chain, args: Any, times: int = 3):
    """
    Retry mechanism to improve the success rate of the final structured output
    """
    try:
        content = chain.invoke(args)
        if 'output' not in content:
            # JSONDecodeError needs (msg, doc, pos), so a plain ValueError is raised instead
            raise ValueError("parsed response has no 'output' key")
        return content
    except Exception as e:
        if times:
            print(f'Exception, retries left [{times}]')
            return retry_call(chain, args, times=times-1)
        else:
            raise RuntimeError(f'chain failed after all retries: {e}') from e

for i in range(len(sample_data)):
    data = sample_data[i]
    # print(colored(f"*****input[{i}]*****:\n{data}","blue"))
    # resp = chain.invoke(data)
    resp = retry_call(chain, data)
    print(colored(f"*****response*****\n{resp}", "green"))
    # Flatten each <item> (a list of single-key dicts) into one row
    for item in resp['output']:
        row = {}
        for it in item['item']:
            row[list(it.keys())[0]] = list(it.values())[0]
        resps.append(row)
# Check that every record has been processed
assert len(resps) == len(sample_data), "Due to the uncertainty of LLM generation, the XML output might occasionally fail to parse; if you hit this error, please re-run the cell above"
# Convert to a pandas DataFrame
prediction_df = pd.DataFrame(resps).rename(columns={"category": "predict_label"}).drop_duplicates(['id']).reset_index(drop=True)
# Normalize the predicted labels: trim, lowercase, drop single quotes
prediction_df['predict_label'] = prediction_df['predict_label'].apply(lambda x: x.strip().lower().replace("'", ""))
# Merge the data
ground_truth = comments.copy()
# Normalize the ground-truth labels: trim and lowercase
ground_truth['groundtruth'] = ground_truth['groundtruth'].apply(lambda x: x.strip().lower())
merge_df = pd.concat([ground_truth, prediction_df], axis=1)
print(merge_df)

# Compute the accuracy
def check_contains(row):
    return str(row['groundtruth']) in str(row['predict_label'])

matches = merge_df.apply(check_contains, axis=1)
count = matches.sum()
print(f"correct: {count}")
print(f"accuracy: {count/len(merge_df)*100:.2f}%")
# Save the results
# The saved CSV file is used as the input file for lab_4
merge_df.to_csv('result_lab_3.csv', index=False)
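The retry idea used in the script can also be written as a loop and tried without calling a model. The sketch below is an illustration under stated assumptions: call_with_retry is a hypothetical helper, and flaky_chain is a stand-in that returns a malformed result twice before succeeding, which imitates an LLM that occasionally produces output the parser cannot use.

```python
from typing import Any, Callable, Dict


def call_with_retry(fn: Callable[[Any], Dict[str, Any]], arg: Any, retries: int = 3) -> Dict[str, Any]:
    """Call fn(arg) until the result contains an 'output' key, at most retries + 1 times."""
    last_error: Exception = RuntimeError("no attempts made")
    for _ in range(retries + 1):
        try:
            result = fn(arg)
            if "output" not in result:
                raise ValueError("missing 'output' key in parsed response")
            return result
        except Exception as exc:  # any failure triggers another attempt
            last_error = exc
    raise RuntimeError(f"all {retries + 1} attempts failed") from last_error


# Stand-in for chain.invoke: fails on the first two calls, succeeds on the third.
attempts = {"count": 0}


def flaky_chain(text: str) -> Dict[str, Any]:
    attempts["count"] += 1
    if attempts["count"] < 3:
        return {"unexpected": text}
    return {"output": [{"item": [{"id": "s_0"}, {"category": "Others"}]}]}


result = call_with_retry(flaky_chain, "some comment")
print(attempts["count"], result["output"][0]["item"][0])
```

A loop keeps the call stack flat and preserves the last underlying error through exception chaining, which makes a persistent failure easier to diagnose.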
- Implements a custom ChatModelNova class for interacting with the AWS Bedrock chat model (amazon.nova-pro-v1:0).
  It supports both synchronous (_generate) and streaming (_stream) responses.
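The message conversion that both methods perform before calling Bedrock can be sketched in isolation. This is a simplified illustration, assuming plain (role, text) tuples instead of LangChain message objects; to_converse_request is a hypothetical helper, and the dictionary it returns uses the same keyword names the class passes to converse (messages, system, inferenceConfig).

```python
from typing import Any, Dict, List, Tuple


def to_converse_request(
    messages: List[Tuple[str, str]],
    temperature: float = 0.1,
    max_tokens: int = 3000,
) -> Dict[str, Any]:
    """Build the messages/system/inferenceConfig arguments from (role, text) pairs."""
    role_map = {"human": "user", "ai": "assistant"}
    system_text = ""
    converse_messages = []
    for role, text in messages:
        if role == "system":
            system_text = text  # as in the class, the last system message wins
        elif role in role_map:
            converse_messages.append({
                "role": role_map[role],
                "content": [{"text": text}],
            })
    return {
        "messages": converse_messages,
        "system": [{"text": system_text}] if system_text else [],
        "inferenceConfig": {"temperature": temperature, "maxTokens": max_tokens},
    }


request = to_converse_request([
    ("system", "You are a helpful assistant that translates English to French."),
    ("human", "I love programming."),
])
print(request)
```

Under these assumptions the result could be passed along the lines of client.converse(modelId=..., **request); the system prompt travels separately from the user and assistant turns, which is why the class pulls it out of the message list.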
Embedding preparation
# Use the Amazon Titan embedding model
bedrock_embedding = BedrockEmbeddings(
    model_id='amazon.titan-embed-text-v2:0',
    region_name="us-east-1"
)
Data loading
# Load the three datasets
categories = pd.read_csv("data/categories.csv")  # category definitions
comments = pd.read_csv("data/comments.csv")  # customer comments
examples_df = pd.read_csv("data/examples_with_label.csv")  # labeled examples
Vector store
# Create the Chroma vector store
vector_store = Chroma(
    collection_name="example_collection",
    embedding_function=bedrock_embedding,
    persist_directory="./chroma_langchain_db"
)
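To make the retrieval step less of a black box, the toy sketch below ranks labeled examples by cosine similarity to a query vector. It is only an illustration: the three-dimensional vectors are made up rather than produced by Titan, cosine and top_k are hypothetical helpers, and cosine similarity is chosen here for clarity, which is not necessarily the distance metric the Chroma collection is configured with.

```python
import math
from typing import List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either has zero length)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(
    query_vec: Sequence[float],
    examples: List[Tuple[str, str, Sequence[float]]],
    k: int = 2,
) -> List[Tuple[float, str, str]]:
    """Return the k examples most similar to the query as (score, text, label)."""
    scored = [(cosine(query_vec, vec), text, label) for text, label, vec in examples]
    scored.sort(key=lambda entry: entry[0], reverse=True)
    return scored[:k]


# Made-up embeddings; real Titan vectors have far more dimensions.
examples = [
    ("phone freezes often", "performance", [1.0, 0.0, 0.0]),
    ("battery drains fast", "battery", [0.0, 1.0, 0.0]),
    ("screen lags and freezes", "performance", [0.9, 0.1, 0.0]),
]
query_vec = [1.0, 0.05, 0.0]

hits = top_k(query_vec, examples, k=2)
for score, text, label in hits:
    print(f"{score:.3f}  {text}  [{label}]")
```

In the lab pipeline the retrieved examples and their labels are placed in the prompt as few-shot examples, so the model sees the labeled cases closest to each new comment.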