import argparse
import time

import cv2
import numpy as np
from ais_bench.infer.interface import InferSession
# Mapping from model output index to human-readable label.
# TODO: fill in the remaining labels for every output class of the model.
CLASSES = {0: 'class_0', 1: 'class_1', 2: 'class_2'}
def preprocess_image(image_path, target_size=(224, 224)):
    """Load an image from disk and resize it to the size the model expects.

    Args:
        image_path: Path of the image file, read with ``cv2.imread`` (color mode).
        target_size: ``dsize`` passed to ``cv2.resize``, i.e. (width, height).

    Returns:
        A float32 array of shape (1, height, width, 3) with pixel values scaled
        to [0, 1]. Channels are in OpenCV's BGR order.
        NOTE(review): confirm the OM model really expects BGR, NHWC input;
        many classification models expect RGB and/or NCHW plus mean/std
        normalization.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file (missing path,
            unsupported or corrupt image).
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with a clear message instead of deep inside cv2.resize.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
    image = (image / 255.0).astype(np.float32)
    # Add the leading batch dimension (batch size 1).
    return np.expand_dims(image, axis=0)
def classify_image(session, image_path):
    """Run one inference on an image and print the top-1 prediction.

    Args:
        session: An ais_bench ``InferSession`` holding the classification model.
        image_path: Path of the image to classify.

    Prints the inference wall-clock time and the predicted label together
    with the model's output value for that label.
    """
    image_data = preprocess_image(image_path)

    begin_time = time.perf_counter()
    # InferSession.infer expects a list with one array per model input.
    outputs = session.infer(feeds=[image_data], mode="static")
    end_time = time.perf_counter()
    print("OM infer time:", end_time - begin_time)

    # The first output is usually (1, num_classes). Flatten it so that
    # indexing by class id does not hit the batch axis.
    scores = np.asarray(outputs[0]).reshape(-1)
    predicted_class_id = int(np.argmax(scores))
    predicted_class = CLASSES.get(predicted_class_id, f"unknown ({predicted_class_id})")
    # NOTE(review): this is the raw output value. It is only a probability
    # if the model itself ends with a softmax.
    confidence = float(scores[predicted_class_id])
    print(f"Predicted Class: {predicted_class} with confidence {confidence:.2f}")
def main(om_model, input_image, device_id=0):
    """Load an OM model and classify a single image with it.

    Args:
        om_model: Path to the ``.om`` model file.
        input_image: Path to the image to classify.
        device_id: Device index passed to ``InferSession`` (default 0).
    """
    session = InferSession(device_id=device_id, model_path=om_model)
    try:
        classify_image(session, input_image)
    finally:
        # Release the device memory held by the session, even if inference
        # or preprocessing raised.
        session.free_resource()
if __name__ == "__main__":
    # Command-line entry point: choose the OM model and the image to classify.
    cli = argparse.ArgumentParser()
    cli_options = (
        ("--model", "classification_model.om", "Input your OM model for classification."),
        ("--img", "path_to_your_image.jpg", "Path to input image."),
    )
    for flag, fallback, description in cli_options:
        cli.add_argument(flag, default=fallback, help=description)
    parsed = cli.parse_args()
    main(parsed.model, parsed.img)