streamlit-testGenius/newGrWithFile.py
VincentXiuyuanZhao 6a4d7d858d format code
2024-06-07 16:28:54 +08:00

129 lines
5.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import gradio as gr
import pandas as pd
# import tempfile
from http import HTTPStatus
import dashscope
from dashscope import Generation
import os
# import logging
from testAny import check_df_english, check_df_tags
import concurrent.futures
import datetime
# Set up logging (currently disabled)
# logging.basicConfig(level=logging.INFO)
# The DashScope API key is read from the DASHSCOPE_API_KEY environment
# variable; os.getenv returns None when the variable is not set.
dashscope.api_key = os.getenv("DASHSCOPE_API_KEY") # Vincent's API key
def response(prompt, instruction=None):
    """Send one prompt to the ``qwen-plus`` model and return the reply text.

    Args:
        prompt: Content of the user message.
        instruction: Optional system instruction. ``None`` and blank
            (empty or whitespace-only) values are not sent.

    Returns:
        The message content of the first choice when the call succeeds.
        On a non-OK status, or when the call or the parsing of its result
        raises, a string starting with ``"Error:"`` is returned instead of
        raising, so callers can store it directly in a result column.
    """
    messages = [{'role': 'user', 'content': prompt}]
    # A blank instruction (the UI's text box defaults to a single space)
    # carries no guidance, so do not send it as an empty system message.
    if instruction is not None and str(instruction).strip():
        messages.insert(0, {'role': 'system', 'content': instruction})
    try:
        # Named ``result`` so the local does not shadow this function.
        result = Generation.call(
            model='qwen-plus',
            messages=messages,
            seed=1234,
            result_format='message',
            stream=False,
            incremental_output=False,
            temperature=0.85,
            top_p=0.9,
            top_k=999,
        )
        if result.status_code == HTTPStatus.OK:
            return result.output.choices[0]['message']['content']
        print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
            result.request_id, result.status_code,
            result.code, result.message
        ))
        return f"Error: Could not generate response with Status code: {result.status_code}, error code: {result.code}"
    except Exception as e:
        # Best-effort boundary: one failed row must not abort a whole batch.
        print(f"Failed to generate response: {e}")
        return "Error: Failed to generate response due to an error."
def format_full_prompt(df, introduction):
    """Fill the placeholders of each row's prompt template with row values.

    Every column except ``full_prompt`` is available as a named
    ``str.format`` placeholder, plus a temporary ``context`` value built as
    ``"<RAG1>-<RAG2>"``.  The frame is modified in place and also returned.

    Args:
        df: DataFrame with ``RAG1`` and ``RAG2`` columns and, unless
            ``introduction`` supplies the template, a ``full_prompt`` column
            holding one template string per row.
        introduction: Optional text.  When it is at least 200 characters
            long it replaces ``full_prompt`` as the template for every row;
            ``None`` or shorter text leaves the per-row templates in use.

    Returns:
        The same DataFrame with ``full_prompt`` formatted and without the
        temporary ``context`` column.

    Raises:
        KeyError: If ``RAG1``/``RAG2`` are missing, if no template is
            available, or if a template names an unknown placeholder.
    """
    min_template_length = 200
    # ``introduction`` may be None (it is process_xlsx's default).
    use_introduction = (
        introduction is not None and len(introduction) >= min_template_length
    )
    if not use_introduction and 'full_prompt' not in df.columns:
        raise KeyError(
            "a 'full_prompt' column is required when the introduction is "
            "shorter than 200 characters"
        )

    # Per-row context: RAG1 and RAG2 joined with a dash.
    df['context'] = [
        f"{rag1}-{rag2}" for rag1, rag2 in zip(df['RAG1'], df['RAG2'])
    ]
    if use_introduction:
        df['full_prompt'] = introduction

    # Every column other than the template itself is a format argument.
    placeholder_columns = [col for col in df.columns if col != 'full_prompt']
    # Keyword arguments must be strings, so stringify the column labels.
    keys = [str(col) for col in placeholder_columns]
    # Pair templates and values positionally; this stays correct even when
    # the index contains duplicate labels.
    rows = df[placeholder_columns].itertuples(index=False, name=None)
    df['full_prompt'] = [
        template.format(**dict(zip(keys, values)))
        for template, values in zip(df['full_prompt'], rows)
    ]

    # Remove the temporary context column.
    df.drop(columns=['context'], inplace=True)
    return df
def process_xlsx(xlsx_file, instruction=None, loops=1):
    """Run every prompt of a workbook through the model and save the result.

    The workbook is read, its prompts are formatted with
    ``format_full_prompt``, the rows are repeated ``loops`` times, one
    response per row is generated, the ``check_df_tags`` and
    ``check_df_english`` passes from ``testAny`` are applied, and the frame
    is written to ``output/<MMDDHH>_<rows>times_output.xlsx``.

    Args:
        xlsx_file: Anything ``pandas.read_excel`` accepts.
        instruction: Optional text, passed both to ``format_full_prompt``
            (as the introduction) and to ``response`` (as the system
            instruction).
        loops: How many times each prompt is sent; integral floats such as
            ``2.0`` are accepted.

    Returns:
        ``(DataFrame, output_path)`` on success, ``(None, None)`` if any
        step fails (the error is printed).
    """
    try:
        # range()/list repetition need a real int, and a float would leak
        # into the file name as e.g. "40.0times".
        loops = int(loops)

        # Read the workbook into a DataFrame.
        df = pd.read_excel(xlsx_file)

        # Format the prompts.  format_full_prompt takes len() of its second
        # argument, so map the None default to an empty introduction.
        introduction = "" if instruction is None else instruction
        formatted_df = format_full_prompt(df, introduction)

        if loops > 1:
            # Repeat the whole frame ``loops`` times in a single concat.
            formatted_df = pd.concat([formatted_df] * loops, ignore_index=True)

        # One model call per row.  The calls are network-bound, so they run
        # on a thread pool; Executor.map keeps the results in row order.
        prompts = formatted_df['full_prompt'].tolist()
        with concurrent.futures.ThreadPoolExecutor() as executor:
            formatted_df['Response'] = list(
                executor.map(lambda prompt: response(prompt, instruction), prompts)
            )

        # Check the frame for tags and for English.
        formatted_df = check_df_tags(formatted_df)
        formatted_df = check_df_english(formatted_df)

        # Save the processed frame under ./output.
        date_str = datetime.datetime.now().strftime("%m%d%H")
        # Number of generated responses; identical to the former hard-coded
        # ``20 * loops`` for 20-row workbooks.
        times_str = str(len(prompts))
        output_path = 'output'
        os.makedirs(output_path, exist_ok=True)
        file_name = f"{date_str}_{times_str}times_output.xlsx"
        file_path = os.path.join(output_path, file_name)
        formatted_df.to_excel(file_path, index=False, engine='openpyxl')
        return formatted_df, file_path
    except Exception as e:
        # Best-effort boundary for the UI: report the failure and signal it
        # with a (None, None) result instead of raising.
        print(f"Failed to process xlsx: {e}")
        return None, None
def main():
    """Build the Gradio interface for the xlsx processing tool and launch it.

    The page offers a system-instruction text box, a loop-count slider and
    a file upload; pressing the button runs ``process_xlsx`` and shows the
    resulting table together with the saved file for download.
    """
    with gr.Blocks() as demo:
        gr.Markdown("### 大模型xlsx处理工具")
        with gr.Accordion("输入说明"):
            gr.Markdown("请上传一个xlsx文件文件应包含prompts。")
        system_instruction = gr.Textbox(label="System Instruction", lines=2,
                                        value=" ")
        slider = gr.Slider(minimum=1, maximum=10, step=1, label="循环次数", value=1)
        file_input = gr.File(label="上传xlsx文件")
        submit_button = gr.Button("处理xlsx")
        output_table = gr.Dataframe(label="处理后的数据")
        output_file = gr.File(label="下载处理后的文件")
        # Created for their side effect of adding the buttons to the page.
        gr.ClearButton(components=[output_table, output_file], value="Clear processed data")
        gr.ClearButton(components=[file_input, output_table, output_file], value="Clear console")

        def update_output(xlsx_file, instruction, loops):
            """Process the uploaded workbook; return (table, file path).

            Returns ``(None, None)`` when no file was uploaded or when
            processing failed, so both outputs always receive a value.
            """
            if xlsx_file is None:
                return None, None
            # process_xlsx already generates the responses, runs the checks
            # and saves the file.  Generating the responses a second time
            # here would double the API calls and make the displayed table
            # disagree with the saved file and its check columns.  On
            # failure it returns (None, None), which is passed through.
            return process_xlsx(xlsx_file, instruction, loops)

        submit_button.click(fn=update_output, inputs=[file_input, system_instruction, slider],
                            outputs=[output_table, output_file])
    demo.launch()


if __name__ == "__main__":
    main()