Visualizing KITTI Predictions

The script below loads the result.pkl produced by evaluation, keeps the predicted boxes above a score threshold, and visualizes the predictions as 2D boxes on the image, projected 3D boxes, and a lidar bird's-eye view (BEV).

import pickle
import random
import cv2
import numpy as np
import matplotlib.pyplot as plt


def threshold_data(data, threshold):
    data2 = {"name": [], "truncated": [], "occluded": [], "alpha": [], "bbox": [], "dimensions": [], "location": [], "rotation_y": [], "score": [], "boxes_lidar": [], "frame_id": data['frame_id']}
    scores = data['score']
    # print(len(scores))
    for i in range(len(scores)):
        if scores[i] > threshold:
            data2['name'].append(data['name'][i])
            data2['truncated'].append(data['truncated'][i])
            data2['occluded'].append(data['occluded'][i])
            data2['alpha'].append(data['alpha'][i])
            data2['bbox'].append(data['bbox'][i])
            data2['dimensions'].append(data['dimensions'][i])
            data2['location'].append(data['location'][i])
            data2['rotation_y'].append(data['rotation_y'][i])
            data2['score'].append(data['score'][i])
            data2['boxes_lidar'].append(data['boxes_lidar'][i])
    print("After thresholding, boxes num is ",len(data2['score']))
    return data2
def get_calib(calib_path):
    with open(calib_path, 'r') as f:
        lines = f.readlines()
    P0 = np.array([float(info) for info in lines[0].split(' ')[1:13]]).reshape([3, 4])
    P1 = np.array([float(info) for info in lines[1].split(' ')[1:13]]).reshape([3, 4])
    P2 = np.array([float(info) for info in lines[2].split(' ')[1:13]]).reshape([3, 4])
    P3 = np.array([float(info) for info in lines[3].split(' ')[1:13]]).reshape([3, 4])
    R0_rect = np.array([float(info) for info in lines[4].split(' ')[1:10]]).reshape([3, 3])
    Tr_velo_to_cam = np.array([float(info) for info in lines[5].split(' ')[1:13]]).reshape([3, 4])
    Tr_imu_to_velo = np.array([float(info) for info in lines[6].split(' ')[1:13]]).reshape([3, 4])
    return P0, P1, P2, P3, R0_rect, Tr_velo_to_cam, Tr_imu_to_velo
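
# For reference, a minimal sketch (not from the original post) of how the calibration
# matrices returned by get_calib are chained for KITTI: velodyne points go to the
# reference camera frame with Tr_velo_to_cam, are rectified with R0_rect, and are
# projected to pixels with the left color camera matrix P2. The helper name
# lidar_to_img is illustrative, not part of the original script.
def lidar_to_img(pts_lidar, P2, R0_rect, Tr_velo_to_cam):
    pts_hom = np.hstack((pts_lidar[:, :3], np.ones((pts_lidar.shape[0], 1))))  # (N, 4)
    pts_cam = pts_hom @ Tr_velo_to_cam.T   # (N, 3) reference camera frame
    pts_rect = pts_cam @ R0_rect.T         # (N, 3) rectified camera frame
    pts_rect_hom = np.hstack((pts_rect, np.ones((pts_rect.shape[0], 1))))
    pts_img = pts_rect_hom @ P2.T          # (N, 3) homogeneous pixel coordinates
    return pts_img[:, :2] / pts_img[:, 2:3], pts_rect[:, 2]  # (u, v) and depth
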
def visual_img_2d(img_path, data2):
    id=data2['frame_id']
    image = cv2.imread(img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    for i in range(len(data2['bbox'])):
        x_min, y_min, x_max, y_max = map(int, data2['bbox'][i])
        cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color=color_dt[data2['name'][i]], thickness=2)
    plt.imshow(image)
    plt.axis('off')  # Hide the axis
    plt.show()
    cv2.imwrite(f"pred/img_2d/{id}.png", image)
    print("img_2d saved to pred/img_2d/" + id + ".png")
    plt.close()
def calculate_iou(box1, box2):
    # A 3D box IoU needs to be implemented here.
    # Box format: [x, y, z, length, width, height, yaw]
    # Implement it according to whichever 3D overlap method you choose;
    # the return value should be the computed IoU.
    pass
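
# A minimal sketch (not from the original post) of one possible simplification: a BEV
# IoU that ignores yaw and treats boxes as axis-aligned rectangles, in the same spirit
# as the IoU inside the commented-out filter_boxes below. A rotation-aware 3D IoU
# (e.g. polygon intersection, or the IoU utilities shipped with OpenPCDet) is needed
# for accurate results; the name calculate_iou_bev_aabb is illustrative.
def calculate_iou_bev_aabb(box1, box2):
    x1, y1, _, l1, w1, _, _ = box1
    x2, y2, _, l2, w2, _, _ = box2
    inter_x = max(0.0, min(x1 + l1 / 2, x2 + l2 / 2) - max(x1 - l1 / 2, x2 - l2 / 2))
    inter_y = max(0.0, min(y1 + w1 / 2, y2 + w2 / 2) - max(y1 - w1 / 2, y2 - w2 / 2))
    inter = inter_x * inter_y
    union = l1 * w1 + l2 * w2 - inter
    return inter / union if union > 0 else 0.0
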

# def filter_boxes(data, threshold=0.3, iou_threshold=0.01):
#     def calculate_iou(box1, box2):
#         x1, y1, z1, l1, w1, h1, yaw1 = box1
#         x2, y2, z2, l2, w2, h2, yaw2 = box2
#         x1_min = x1 - l1 / 2
#         x1_max = x1 + l1 / 2
#         y1_min = y1 - w1 / 2
#         y1_max = y1 + w1 / 2
#         x2_min = x2 - l2 / 2
#         x2_max = x2 + l2 / 2
#         y2_min = y2 - w2 / 2
#         y2_max = y2 + w2 / 2
#         inter_x_min = max(x1_min, x2_min)
#         inter_x_max = min(x1_max, x2_max)
#         inter_y_min = max(y1_min, y2_min)
#         inter_y_max = min(y1_max, y2_max)
#         inter_width = max(0, inter_x_max - inter_x_min)
#         inter_height = max(0, inter_y_max - inter_y_min)
#         intersection_area = inter_width * inter_height
#         area1 = l1 * w1
#         area2 = l2 * w2
#         union_area = area1 + area2 - intersection_area
#         iou = intersection_area / union_area if union_area > 0 else 0
#         return iou
#     data2 = {"name": [], "truncated": [], "occluded": [], "alpha": [], "bbox": [], "dimensions": [], "location": [], "rotation_y": [], "score": [], "boxes_lidar": [], "frame_id": data['frame_id']}
#     keep_indices=list(range(len(data['score'])))
#     while len(keep_indices):
#         scores=[data['score'][i] for i in keep_indices]
#         if  max(scores)< threshold:
#             break
#         i = np.argmax(scores)
#         max_index=keep_indices[i]
#         max_score_box = data['boxes_lidar'][max_index]   
#         iou_values = [calculate_iou(max_score_box, data['boxes_lidar'][index]) for index in keep_indices ]
#         print("Before thresholding, boxes num is ",len(data2['score']))
#         keep_indices = [keep_indices[i] for i in  range(len(iou_values)) if iou_values[i] <= iou_threshold]   
#         data2['name'].append(data['name'][max_index])
#         data2['truncated'].append(data['truncated'][max_index])
#         data2['occluded'].append(data['occluded'][max_index])
#         data2['alpha'].append(data['alpha'][max_index])
#         data2['bbox'].append(data['bbox'][max_index])
#         data2['dimensions'].append(data['dimensions'][max_index])
#         data2['location'].append(data['location'][max_index])
#         data2['rotation_y'].append(data['rotation_y'][max_index])
#         data2['score'].append(data['score'][max_index])
#         data2['boxes_lidar'].append(data['boxes_lidar'][max_index])



        
#     print("After filtering, boxes num is ", len(data2['score']))
#     return data2
def visual_img_3d(img_path, data2, calib_path):
    """
        7 -------- 4
       /|         /|
      6 -------- 5 .
      | |        | |
      . 3 -------- 0
      |/         |/
      2 -------- 1
    Project the lidar 3D boxes into the image:
    lidar frame -> Tr_velo_to_cam -> R0_rect -> P2 -> pixel coordinates.
    """
    id = data2['frame_id']
    image = cv2.imread(img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    P0, P1, P2, P3, R0_rect, Tr_velo_to_cam, Tr_imu_to_velo = get_calib(calib_path)
    for i in range(len(data2['bbox'])):
        x_min, y_min, x_max, y_max = map(int, data2['bbox'][i])
        # cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color=color_dt[data2['name'][i]], thickness=2)
        box_3d = data2['boxes_lidar'][i]  # e.g. [39.689312  20.123287  -0.9484792  3.800696  1.5940567  1.478453  -3.1090808]
        x, y, z, l, w, h, yaw = box_3d
        # Corners in the box's local frame, centered at the origin
        corners = np.array([
            [-l / 2, -w / 2, -h / 2],
            [ l / 2, -w / 2, -h / 2],
            [ l / 2,  w / 2, -h / 2],
            [-l / 2,  w / 2, -h / 2],
            [-l / 2, -w / 2,  h / 2],
            [ l / 2, -w / 2,  h / 2],
            [ l / 2,  w / 2,  h / 2],
            [-l / 2,  w / 2,  h / 2],
        ])
        # Rotate around the z axis by the heading angle, then translate to the box center
        rot = np.array([
            [np.cos(yaw), -np.sin(yaw), 0],
            [np.sin(yaw),  np.cos(yaw), 0],
            [0,            0,           1],
        ])
        corners = corners @ rot.T + np.array([x, y, z])
        # lidar frame -> reference camera frame -> rectified camera frame
        corners_hom = np.hstack((corners, np.ones((8, 1))))  # homogeneous coordinates
        corners_cam = np.dot(Tr_velo_to_cam, corners_hom.T).T
        corners_rect = np.dot(R0_rect, corners_cam.T).T
        if np.any(corners_rect[:, 2] <= 0):
            continue  # skip boxes behind the camera
        # rectified camera frame -> image plane
        corners_rect_hom = np.hstack((corners_rect, np.ones((8, 1))))
        corners_img = np.dot(P2, corners_rect_hom.T).T
        corners_img = corners_img[:, :2] / corners_img[:, 2:3]  # perspective division
        for k1, k2 in zip([0, 0, 0, 1, 1, 2, 2, 3, 4, 4, 5, 6],
                          [1, 3, 4, 2, 5, 3, 6, 7, 5, 7, 6, 7]):
            pt1 = tuple(corners_img[k1].astype(int))
            pt2 = tuple(corners_img[k2].astype(int))
            color = color_dt[data2['name'][i]]
            cv2.line(image, pt1, pt2, color, 1)
    plt.imshow(image)
    plt.axis('off')  # hide the axes
    plt.show()
    cv2.imwrite(f"pred/img_3d/{id}.png", cv2.cvtColor(image, cv2.COLOR_RGB2BGR))  # convert back to BGR for saving
    print("img_3d saved to pred/img_3d/" + id + ".png")
    plt.close()

def visual_lidar_2d(lidar_path, data2):
    """
    绘制lidar点云的鸟瞰图BEV
    """
    color_dt={"Car":"green", "Pedestrian":"blue", "Cyclist":"pink"}
    def sample_lidar(lidar, percentage):
        # Randomly keep roughly `percentage` of the points (to thin out the plot).
        n = len(lidar)
        num = int(n * percentage)
        idx = random.sample(range(n), num)
        return lidar[idx]
    id=data2['frame_id']
    lidar = np.fromfile(lidar_path, dtype=np.float32).reshape(-1, 4)
    # lidar = sample_lidar(lidar,0.4)
    x = lidar[:, 0]
    y = lidar[:, 1]
    r = np.sqrt(x ** 2 + y ** 2)
    fig,ax = plt.subplots(figsize=(10, 10))
    ax.scatter(x, y, c='gray', s=0.5) 
    ax.set_xlim(0, 35)
    ax.set_ylim(-20,20)
    ax.set_aspect('equal')
    box_3d = data2['boxes_lidar']
    for i, box in enumerate(data2['boxes_lidar']):
        x_center, y_center, z, length, width, height, yaw = box
        corners = np.array([
            [-length / 2, -width / 2],
            [length / 2, -width / 2],
            [length / 2, width / 2],
            [-length / 2, width / 2],
            [-length / 2, -width / 2]  # back to the starting corner to close the polygon
        ])
        rotation_matrix = np.array([
            [np.cos(yaw), -np.sin(yaw)],
            [np.sin(yaw), np.cos(yaw)]
        ])
        rotated_corners = corners @ rotation_matrix.T
        
        rotated_corners[:, 0] += x_center
        rotated_corners[:, 1] += y_center
        ax.plot(rotated_corners[:, 0], rotated_corners[:, 1], c=color_dt[data2['name'][i]], linewidth=2)

    plt.show()
    fig.savefig(f"pred/lidar_2d/{id}.png", dpi=300)
    print("lidar_bev saved to pred/lidar_2d/" + id + ".png")
    plt.close()

def visual(datas):
    for data in datas:
        print(data['frame_id'])
        id=data['frame_id']
        img_path="/home/zbh/project/CoIn-main/data/kitti/training/image_2/"+id+".png"
        calib_path="/home/zbh/project/CoIn-main/data/kitti/training/calib/"+id+".txt"
        lidar_path="/home/zbh/project/CoIn-main/data/kitti/training/velodyne/"+id+".bin"
        data2=threshold_data(data, threshold=0.5)
        # data2=filter_boxes(data)
        # visual_img_2d(img_path, data2)
        # visual_img_3d(img_path, data2, calib_path)
        visual_lidar_2d(lidar_path, data2)
        print("\n")
        if int(data['frame_id'])>100:
            break

color_dt={"Car":(86,255,86), "Pedestrian":(255,255,0), "Cyclist":(255,170,212)}

if __name__ == '__main__':
    file_path = "/home/zbh/project/CoIn-main/output/kitti_models/aug_20_20/default/eval/eval_all_default/default/epoch_80/val/result.pkl"
    file_path = "/home/zbh/project/CoIn-main/output/kitti_models/aug_25/default/eval/eval_all_default/default/epoch_100/val/result.pkl"
    # file_path = "/home/zbh/project/CoIn-main/output/kitti_models/aug_25/default/ckpt/checkpoint_epoch_100.pth"
    with open(file_path, 'rb') as file:
        datas = pickle.load(file)
    "datas is list"
    "datas[0] is dict ,and its keys are ['name', 'truncated', 'occluded', 'alpha', 'bbox', 'dimensions', 'location', 'rotation_y', 'score', 'boxes_lidar', 'frame_id']"
    visual(datas)
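
# Note (not from the original post): the visualization functions write into
# pred/img_2d, pred/img_3d and pred/lidar_2d but do not create these directories;
# cv2.imwrite silently returns False and fig.savefig raises if they are missing.
# A small guard before calling visual(datas) avoids this, e.g.:
#
#     import os
#     for d in ("pred/img_2d", "pred/img_3d", "pred/lidar_2d"):
#         os.makedirs(d, exist_ok=True)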




