add loop and checks

This commit is contained in:
VincentXiuyuanZhao 2024-06-04 00:04:12 +08:00
parent 1b3380865c
commit 486d691712

View File

@ -53,15 +53,24 @@ def format_full_prompt(df, introduction):
return df return df
def process_xlsx(xlsx_file, instruction=None): # 这里也使instruction参数变成可选 def process_xlsx(xlsx_file, instruction=None, loops=1): # 这里也使instruction参数变成可选
# 读取xlsx文件到pandas DataFrame # 读取xlsx文件到pandas DataFrame
df = pd.read_excel(xlsx_file) df = pd.read_excel(xlsx_file)
# 格式化prompts # 格式化prompts
formatted_df = format_full_prompt(df, instruction) formatted_df = format_full_prompt(df, instruction)
if loops >= 1:
df_list = [formatted_df.copy() for _ in range(loops)]
# 使用pd.concat一次性合并所有副本
formatted_df = pd.concat(df_list, ignore_index=True)
# 假设我们要处理的提示是DataFrame的'full_prompt'列 # 假设我们要处理的提示是DataFrame的'full_prompt'列
# 调用response时根据instruction是否为None自动处理 # 调用response时根据instruction是否为None自动处理
formatted_df['Response'] = formatted_df['full_prompt'].apply(lambda prompt: response(prompt, instruction)) formatted_df['Response'] = formatted_df['full_prompt'].apply(lambda prompt: response(prompt, instruction))
# check df with tags and english
formatted_df = check_df_tags(formatted_df)
formatted_df = check_df_english(formatted_df)
# 使用tempfile创建一个临时文件路径保存处理后的xlsx # 使用tempfile创建一个临时文件路径保存处理后的xlsx
tmp_path = tempfile.NamedTemporaryFile(delete=True, suffix='.xlsx').name tmp_path = tempfile.NamedTemporaryFile(delete=True, suffix='.xlsx').name
formatted_df.to_excel(tmp_path, index=False, engine='openpyxl') formatted_df.to_excel(tmp_path, index=False, engine='openpyxl')
@ -76,6 +85,7 @@ def main():
gr.Markdown("请上传一个xlsx文件文件应包含prompts。") gr.Markdown("请上传一个xlsx文件文件应包含prompts。")
system_instruction = gr.Textbox(label="System Instruction", lines=2, system_instruction = gr.Textbox(label="System Instruction", lines=2,
value=" ") value=" ")
slider = gr.Slider(minimum=1, maximum=10, step=1, label="循环次数", value=1)
file_input = gr.File(label="上传xlsx文件") file_input = gr.File(label="上传xlsx文件")
submit_button = gr.Button("处理xlsx") submit_button = gr.Button("处理xlsx")
@ -84,16 +94,13 @@ def main():
output_file = gr.File(label="下载处理后的文件") output_file = gr.File(label="下载处理后的文件")
clear_data = gr.ClearButton(components=[output_table, output_file], value="Clear processed data") clear_data = gr.ClearButton(components=[output_table, output_file], value="Clear processed data")
clear_all = gr.ClearButton(components=[file_input, output_table, output_file], value="Clear console") clear_all = gr.ClearButton(components=[file_input, output_table, output_file], value="Clear console")
check_tags_button = gr.Button("检查xlsx文件-tags") def update_output(xlsx_file, instruction, loops):
check_english_button = gr.Button("检查xlsx文件-英文")
def update_output(xlsx_file, instruction):
if xlsx_file is not None: if xlsx_file is not None:
formatted_df, tmp_path = process_xlsx(xlsx_file, instruction) formatted_df, tmp_path = process_xlsx(xlsx_file, instruction, loops=loops)
return formatted_df, tmp_path # 返回DataFrame和文件路径 return formatted_df, tmp_path # 返回DataFrame和文件路径
submit_button.click(fn=update_output, inputs=[file_input, system_instruction], submit_button.click(fn=update_output, inputs=[file_input, system_instruction, slider],
outputs=[output_table, output_file]) outputs=[output_table, output_file])
check_tags_button.click(fn=check_df_tags, inputs=[output_table], outputs=[output_table])
demo.launch() demo.launch()