支持批量调用 LLaVA，并将生成结果保存到文件。
import os
import subprocess
import re
# Batch driver: runs the LLaVA command-line interface once per image found in
# ``folder_path`` and stores the extracted answer in a text file per image.

# Directory that is scanned for input images.
folder_path = ""
# Value passed to the CLI's --model-path option.
model_path = ""
# Prompt text fed to the CLI on standard input for every image.
input_text = ""
# Directory where one ``<image name>.txt`` result file per image is written.
output_folder = ""

# Lower-case file extensions that are treated as images.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
# Captures the text following "ASSISTANT: " up to the next "\nUSER:" marker or
# the end of the captured output; re.S lets the answer span several lines.
ASSISTANT_PATTERN = re.compile(r'ASSISTANT: (.*?)(?:\nUSER:|\Z)', re.S)
# Maximum number of seconds a single CLI invocation may take.
TIMEOUT_SECONDS = 60


def is_image_file(file_name):
    """Return True when *file_name* ends with a known image extension.

    The comparison is case-insensitive.
    """
    return file_name.lower().endswith(IMAGE_EXTENSIONS)


def extract_response(stdout):
    """Return the assistant's answer found in the CLI's captured *stdout*.

    The answer is the stripped text after the first "ASSISTANT: " marker.
    When no marker is present, the placeholder "No valid response found."
    is returned instead.
    """
    match = ASSISTANT_PATTERN.search(stdout)
    if match:
        return match.group(1).strip()
    return "No valid response found."


def process_image(image_file):
    """Run the CLI on one image from ``folder_path`` and save the answer.

    Failures are reported on standard output and never propagate, so one
    bad image does not abort the rest of the batch.
    """
    image_path = os.path.join(folder_path, image_file)
    command = [
        "python", "-m", "llava.serve.cli",
        "--model-path", model_path,
        "--image-file", image_path,
        "--load-8bit",
    ]
    print(f"Running command: {' '.join(command)}")
    try:
        # subprocess.run kills and reaps the child when the timeout expires,
        # so a hung model process is not left running in the background.
        completed = subprocess.run(
            command,
            input=input_text + '\n',
            capture_output=True,
            text=True,
            timeout=TIMEOUT_SECONDS,
        )
        if completed.returncode != 0:
            # Surface the CLI's own diagnostics instead of discarding them.
            print(
                f"Command exited with code {completed.returncode} "
                f"for {image_file}: {completed.stderr.strip()}"
            )
        response = extract_response(completed.stdout)
        output_file = os.path.join(
            output_folder, f"{os.path.splitext(image_file)[0]}.txt"
        )
        # Explicit UTF-8 so non-ASCII answers are written identically on
        # every platform, independent of the locale's default encoding.
        with open(output_file, 'w', encoding='utf-8') as file:
            file.write(response)
        print(f"Saved output to {output_file}")
    except subprocess.TimeoutExpired:
        print(f"Timeout expired for {image_file}")
    except Exception as e:
        # Deliberate best-effort boundary: report and move on to the next image.
        print(f"An error occurred for {image_file}: {e}")


def main():
    """Process every image file found directly inside ``folder_path``."""
    # Make sure the output directory exists before any result is written.
    os.makedirs(output_folder, exist_ok=True)
    # Sorted so that runs are reproducible regardless of directory order.
    for image_file in sorted(os.listdir(folder_path)):
        if not is_image_file(image_file):
            continue
        # Skip directories that merely carry an image-like name.
        if not os.path.isfile(os.path.join(folder_path, image_file)):
            continue
        process_image(image_file)


if __name__ == "__main__":
    main()