import json
import os
from math import trunc
import jsonlines
from collections import defaultdict
from typing import Dict, List, Any, Tuple
import logging
import time
from dataclasses import dataclass, field
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("collision_processing.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(“CollisionProcessor”)
# Data class definitions
@dataclass
class FrameData:
    timestamp: float
    road_points: List[Tuple[float, float]]
    record: Dict[str, Any]
    group_key: str = ""
    scene_id: str = ""
    processed: bool = False
@dataclass
class CameraStats:
    cam_tag: str
    total_frames: int = 0
    yes_count: int = 0
    no_count: int = 0
@dataclass
class SceneStats:
    scene_id: str
    total_cam_tags: int = 0
    total_frames: int = 0
    yes_count: int = 0
    no_count: int = 0
    cam_tags: Dict[str, CameraStats] = field(default_factory=dict)
@dataclass
class OverallStats:
    total_scenes: int = 0
    total_frames: int = 0
    yes_count: int = 0
    no_count: int = 0
def extract_group_key(image_path: str) -> str:
    """Extract the camera tag (grouping key) from an image path."""
    try:
        parts = image_path.replace("\\", "/").strip("/").split("/")
        return parts[-2] if len(parts) >= 2 else "default_group"
    except Exception as e:
        logger.error(f"Failed to extract group key: {e}, path: {image_path}")
        return "default_group"
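# Illustrative example (the path layout below is assumed, not taken from real data):
# for an image path like "scenes/scene_001/cam_front_120/000123.jpg", the
# second-to-last segment is the camera tag, so extract_group_key(...) returns
# "cam_front_120".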
def round_coordinate(value: float, decimals: int = 2) -> float:
    """Round a coordinate value to the given number of decimal places."""
    if value is None:
        return 0.0
    factor = 10 ** decimals
    return round(value * factor) / factor
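# Illustrative example: round_coordinate(3.14159) scales by 100, rounds to 314,
# and divides back, yielding 3.14; round_coordinate(None) falls back to 0.0.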
def process_road_points(road_points: List[Dict[str, float]]) -> List[Tuple[float, float]]:
    """Process road point data: extract and round the coordinates."""
    if not road_points:
        return []
    return [
        (round_coordinate(point.get('x', 0.0)),
         round_coordinate(point.get('y', 0.0)))
        for point in road_points
    ]
def calculate_overlap_ratio(
    points1: List[Tuple[float, float]],
    points2: List[Tuple[float, float]],
    tolerance: float = 0.01
) -> float:
    """Compute the overlap ratio between two sets of points."""
    if not points1 or not points2:
        return 0.0
    # Compute the overlap ratio using sets of truncated coordinates.
    # Note: `tolerance` is currently not applied; points are matched on their
    # integer-truncated coordinates instead.
    set1 = {(trunc(x), trunc(y)) for x, y in points1}
    set2 = {(trunc(x), trunc(y)) for x, y in points2}
    # set1 = {(round(x, 2), round(y, 2)) for x, y in points1}
    # set2 = {(round(x, 2), round(y, 2)) for x, y in points2}
    intersection = set1 & set2
    union = set1 | set2
    return len(intersection) / len(union) if union else 0.0
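# Worked example (illustrative values, not from the dataset): with
# points1 = [(1.23, 4.56), (2.00, 3.00)] and points2 = [(1.99, 3.99), (2.01, 3.02)],
# the truncated sets are {(1, 4), (2, 3)} and {(1, 3), (2, 3)}; the intersection
# has 1 element and the union has 3, so the returned ratio is 1/3 ≈ 0.33.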
def load_data(input_file: str) -> Dict[str, Dict[str, List[FrameData]]]:
    """Load data from a JSONL file and group it by scene and camera tag."""
    grouped_data = defaultdict(lambda: defaultdict(list))
    total_records = 0
    logger.info(f"Start reading input file: {input_file}")
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                try:
                    record = json.loads(line.strip())
                    scene_id = record.get('scene_id', 'unknown_scene')
                    image_path = record.get('image_path', '')
                    timestamp = record.get('timestamp', 0.0)
                    # Extract the road point info
                    mop_info = record.get('mop_pnc_info', {})
                    if not isinstance(mop_info, dict):
                        mop_info = {}
                    road_points = mop_info.get('roadPoints', [])
                    # Process the road points
                    processed_points = process_road_points(road_points)
                    # Extract the grouping key
                    group_key = extract_group_key(image_path)
                    # Build the FrameData object
                    frame_data = FrameData(
                        timestamp=timestamp,
                        road_points=processed_points,
                        record=record,
                        group_key=group_key,
                        scene_id=scene_id
                    )
                    # Add to the grouped data
                    grouped_data[scene_id][group_key].append(frame_data)
                    total_records += 1
                except json.JSONDecodeError as e:
                    logger.error(f"JSON parse error (line {line_num}): {e}")
                except Exception as e:
                    logger.error(f"Error processing record (line {line_num}): {e}")
    except IOError as e:
        logger.error(f"File read error: {e}")
        return {}
    logger.info(f"Successfully read {total_records} records, grouped into {len(grouped_data)} scenes")
    return grouped_data
def process_frame_group(
    frames: List[FrameData],
    threshold: float = 0.7,
    debug: bool = False
) -> CameraStats:
    """Process one group of frames using the updated labeling logic."""
    if not frames:
        return CameraStats(cam_tag="")
    # Get the camera tag for this group
    cam_tag = frames[0].group_key
    cam_stats = CameraStats(cam_tag=cam_tag)
    cam_stats.total_frames = len(frames)
    # Sort by timestamp in descending order (newest first)
    frames.sort(key=lambda x: x.timestamp, reverse=True)
    # Whether the first 'no' has already appeared
    no_occurred = False
    # Process all frames
    for i, frame in enumerate(frames):
        if i == 0:
            # First frame (newest)
            frame.record['mask'] = 'yes'
            cam_stats.yes_count += 1
            if debug:
                logger.debug(f"Camera {cam_tag}: frame 0 (newest) labeled 'yes'")
        elif i == 1:
            # Second frame: compare against the first frame
            overlap = calculate_overlap_ratio(
                frames[0].road_points,
                frame.road_points
            )
            if overlap >= threshold:
                frame.record['mask'] = 'yes'
                cam_stats.yes_count += 1
                if debug:
                    logger.debug(f"Camera {cam_tag}: frame 1 overlap with frame 0 {overlap:.2f} >= {threshold} -> labeled 'yes'")
            else:
                frame.record['mask'] = 'no'
                cam_stats.no_count += 1
                no_occurred = True  # first 'no' has appeared
                if debug:
                    logger.debug(
                        f"Camera {cam_tag}: frame 1 overlap with frame 0 {overlap:.2f} < {threshold} -> labeled 'no' (first occurrence)")
        else:
            # Third frame and beyond
            if no_occurred:
                # Once the first 'no' appears, all later frames are 'no'
                frame.record['mask'] = 'no'
                cam_stats.no_count += 1
                if debug:
                    logger.debug(f"Camera {cam_tag}: frame {i} after first 'no' -> labeled 'no'")
            else:
                # Compare against the previous frame
                overlap = calculate_overlap_ratio(
                    frames[i - 1].road_points,
                    frame.road_points
                )
                if overlap >= threshold:
                    frame.record['mask'] = frames[i - 1].record['mask']
                    if frame.record['mask'] == 'yes':
                        cam_stats.yes_count += 1
                    else:
                        cam_stats.no_count += 1
                    if debug:
                        inherited = frame.record['mask']
                        logger.debug(
                            f"Camera {cam_tag}: frame {i} overlap with frame {i - 1} {overlap:.2f} >= {threshold} -> inherited label '{inherited}'")
                else:
                    frame.record['mask'] = 'no'
                    cam_stats.no_count += 1
                    no_occurred = True  # first 'no' has appeared
                    if debug:
                        logger.debug(
                            f"Camera {cam_tag}: frame {i} overlap with frame {i - 1} {overlap:.2f} < {threshold} -> labeled 'no' (first occurrence)")
        frame.processed = True
    return cam_stats
def process_collision_data(
    input_file: str,
    output_file: str,
    stats_file: str,
    threshold: float = 0.7,
    debug: bool = False
) -> Dict[str, Any]:
    """Core function for processing collision data."""
    start_time = time.time()
    logger.info(f"Start processing collision data: {input_file}")
    # Load data
    grouped_data = load_data(input_file)
    if not grouped_data:
        logger.error("No valid data loaded, aborting")
        return {}
    # Initialize statistics structures
    overall_stats = OverallStats(total_scenes=len(grouped_data))
    scene_stats_dict = {}
    labeled_data = []
    # Process each scene
    for scene_id, cam_groups in grouped_data.items():
        scene_stats = SceneStats(
            scene_id=scene_id,
            total_cam_tags=len(cam_groups)
        )
        # Process each camera group
        for group_key, frames in cam_groups.items():
            cam_stats = process_frame_group(frames, threshold, debug)
            # Update scene statistics
            scene_stats.cam_tags[group_key] = cam_stats
            scene_stats.total_frames += cam_stats.total_frames
            scene_stats.yes_count += cam_stats.yes_count
            scene_stats.no_count += cam_stats.no_count
            # Collect processed records
            labeled_data.extend(frame.record for frame in frames if frame.processed)
        # Update overall statistics
        scene_stats_dict[scene_id] = scene_stats
        overall_stats.total_frames += scene_stats.total_frames
        overall_stats.yes_count += scene_stats.yes_count
        overall_stats.no_count += scene_stats.no_count
    # Sort by scene_id, cam_tag, and timestamp in descending order
    labeled_data.sort(key=lambda x: (
        x.get('scene_id', ''),
        x.get('image_path', '').split('/')[-2] if 'image_path' in x else '',
        -x.get('timestamp', 0)
    ))
    # Prepare the output statistics
    stats = {
        "stats_file": os.path.basename(stats_file),
        "total_scenes": overall_stats.total_scenes,
        "scenes": {},
        "overall": {
            "total_frames": overall_stats.total_frames,
            "yes_count": overall_stats.yes_count,
            "no_count": overall_stats.no_count
        }
    }
    # Add per-scene statistics
    for scene_id, scene_stats in scene_stats_dict.items():
        stats["scenes"][scene_id] = {
            "scene_id": scene_stats.scene_id,
            "total_cam_tags": scene_stats.total_cam_tags,
            "total_frames": scene_stats.total_frames,
            "yes_count": scene_stats.yes_count,
            "no_count": scene_stats.no_count,
            "cam_tags": {}
        }
        # Add per-camera statistics
        for cam_tag, cam_stats in scene_stats.cam_tags.items():
            stats["scenes"][scene_id]["cam_tags"][cam_tag] = {
                "cam_tag": cam_stats.cam_tag,
                "total_frames": cam_stats.total_frames,
                "yes_count": cam_stats.yes_count,
                "no_count": cam_stats.no_count
            }
    # Make sure the output directories exist
    os.makedirs(os.path.dirname(output_file) or '.', exist_ok=True)
    os.makedirs(os.path.dirname(stats_file) or '.', exist_ok=True)
    # Write the labeled output file
    logger.info(f"Writing processed data to: {output_file}")
    with jsonlines.open(output_file, 'w') as writer:
        for record in labeled_data:
            writer.write(record)
    # Write the statistics file
    logger.info(f"Writing statistics to: {stats_file}")
    with open(stats_file, 'w', encoding='utf-8') as f:
        json.dump(stats, f, ensure_ascii=False, indent=4)
    # Compute the processing time
    processing_time = time.time() - start_time
    logger.info(f"Processing finished! Total time: {processing_time:.2f}s")
    return stats
if __name__ == "__main__":
    # Invoke the processing function with the specified file paths
    stats = process_collision_data(
        # input_file="./basicData_historyDepth_v2_img106k_250723__huanshi_6ver_4fish_large_model_900_792874_10frames.jsonl",
        input_file="./basicData_realCollision_v2_img42k_250726__real_collision_dataset_pt1_to_8.jsonl",
        # input_file="./basicData_vggtHDepth_v2_img42k_250725__huanshi_6ver_6pinhole__large_model_900_792874_10frames__pt1__250722.jsonl",
        output_file="processed_42k_0726_tag03.jsonl",
        stats_file="collision_stats_0726_tag03.json",
        threshold=0.1,
        debug=False  # set to True for verbose processing logs
    )
# Modified processing logic: group records by scene_id, then by the second-to-last
# segment of image_path (split on '/'); within each group, sort by timestamp in
# descending order and process as follows. The first record in a group defaults to
# mask = "yes". For each subsequent record, take the previous frame's roadPoints
# coordinates and the current frame's roadPoints, compute their intersection and
# union with a value tolerance of 0.01; if intersection/union > 0.7, the record
# inherits the previous record's mask, otherwise it takes the negation of the
# previous record's mask. Once the first "no" appears, all later records in that
# group are marked "no".
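# A minimal, self-contained sketch of the labeling rules above, using synthetic
# FrameData values (the timestamps, points, and the helper name _demo_labeling are
# hypothetical, for illustration only; the function is never called in this script).
def _demo_labeling() -> None:
    # Three frames of one camera group: the two newest share the same road points,
    # the oldest drifts away, so the expected masks (newest -> oldest) are
    # 'yes', 'yes', 'no'.
    pts_a = [(1.0, 2.0), (3.0, 4.0)]
    pts_b = [(9.0, 9.0), (8.0, 8.0)]
    frames = [
        FrameData(timestamp=3.0, road_points=pts_a, record={}, group_key="cam_demo"),
        FrameData(timestamp=2.0, road_points=pts_a, record={}, group_key="cam_demo"),
        FrameData(timestamp=1.0, road_points=pts_b, record={}, group_key="cam_demo"),
    ]
    stats = process_frame_group(frames, threshold=0.7, debug=False)
    # After processing, frames are sorted newest-first and each record carries a 'mask'.
    print([f.record['mask'] for f in frames])  # expected: ['yes', 'yes', 'no']
    print(stats.yes_count, stats.no_count)     # expected: 2 1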