LLaVA批量推理

 可支持批量调用LLaVA,并将生成结果进行保存。

import os
import subprocess
import re

# 设置文件夹路径
folder_path = ""
# 设置模型路径
model_path = ""
# 要输入的文本
input_text = ""
# 输出结果保存路径
output_folder = ""

# 确保输出文件夹存在
os.makedirs(output_folder, exist_ok=True)

# 遍历文件夹中的所有文件
for image_file in os.listdir(folder_path):
    # 构建完整的文件路径
    image_path = os.path.join(folder_path, image_file)

    # 检查是否是图像文件
    if image_file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
        # 构建命令
        command = [
            "python", "-m", "llava.serve.cli",
            "--model-path", model_path,
            "--image-file", image_path,
            "--load-8bit"
        ]

        # 打印命令
        print(f"Running command: {' '.join(command)}")

        try:
            # 启动进程
            process = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                       text=True)

            # 输入文本并按回车
            stdout, stderr = process.communicate(input=input_text + '\n', timeout=60)

            # 提取模型回答部分
            match = re.search(r'ASSISTANT: (.*?)(?:\nUSER:|\Z)', stdout, re.S)
            if match:
                response = match.group(1).strip()
            else:
                response = "No valid response found."

            # 输出结果保存路径
            output_file = os.path.join(output_folder, f"{os.path.splitext(image_file)[0]}.txt")

            # 保存输出到文件
            with open(output_file, 'w') as file:
                file.write(response)

            print(f"Saved output to {output_file}")

        except subprocess.TimeoutExpired:
            print(f"Timeout expired for {image_file}")

        except Exception as e:
            print(f"An error occurred for {image_file}: {e}")

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值