# ---------------- seaborn style patch START ----------------
# Newer matplotlib releases dropped the bundled 'seaborn-whitegrid' style;
# re-register it from seaborn (or an empty dict) so legacy code keeps working.
import matplotlib
import matplotlib.style  # load the style submodule explicitly before patching it
try:
import seaborn as sns
seaborn_style = sns.axes_style("whitegrid")
matplotlib.style.library['seaborn-whitegrid'] = seaborn_style
except Exception as e:
print("[WARN] Failed to add seaborn-whitegrid style to matplotlib: ", e)
try:
matplotlib.style.library['seaborn-whitegrid'] = {}
except Exception:
pass
_orig_style_use = matplotlib.style.use
def _safe_style_use(style_name):
    try:
        if isinstance(style_name, str) and style_name.lower() == 'seaborn-whitegrid':
            # Make sure the style is registered before delegating.
            if 'seaborn-whitegrid' not in matplotlib.style.library:
                matplotlib.style.library['seaborn-whitegrid'] = {}
            return _orig_style_use('seaborn-whitegrid')
        return _orig_style_use(style_name)
    except Exception as e:
        print(f"[WARN] Error setting style {style_name}: {e}")
        # style_name may be a list or dict here, so guard the .lower() call.
        if isinstance(style_name, str) and style_name.lower() == 'seaborn-whitegrid':
            return _orig_style_use('default')
        raise
matplotlib.style.use = _safe_style_use
# ---------------- seaborn style patch END ----------------
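# With the patch applied, legacy style requests degrade instead of raising:
# on newer matplotlib the bundled style was renamed to 'seaborn-v0_8-whitegrid',
# so a bare call like the following would otherwise fail:
#   matplotlib.style.use('seaborn-whitegrid')  # now falls back to {} or 'default'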
import os
import sys
import json
import logging
import traceback
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from pyquaternion import Quaternion
# try to import NuScenes + NuScenesMap + PredictHelper
MAP_API_AVAILABLE = True
ARCLINE_AVAILABLE = True
map_api_import_exception = None
arcline_import_exception = None
try:
from nuscenes import NuScenes
from nuscenes.prediction import PredictHelper
try:
from nuscenes.map_expansion.map_api import NuScenesMap
except Exception as e_map:
NuScenesMap = None
MAP_API_AVAILABLE = False
map_api_import_exception = e_map
try:
from nuscenes.map_expansion import arcline_path_utils
except Exception as e_arc:
arcline_path_utils = None
ARCLINE_AVAILABLE = False
arcline_import_exception = e_arc
except Exception as e_all:
# If entire import failed, keep going but mark unavailable
NuScenes = None
PredictHelper = None
NuScenesMap = None
arcline_path_utils = None
MAP_API_AVAILABLE = False
ARCLINE_AVAILABLE = False
map_api_import_exception = e_all
arcline_import_exception = e_all
# ---------------- Config ----------------
NUSCENES_ROOT = os.getenv('NUSCENES_ROOT', '/home/ljc/nuscenes')
OUTPUT_DIR = os.getenv('OUTPUT_DIR', '/home/ljc/RISG/RISG-nuscenes/datasets')
LOGS_DIR = os.path.join(OUTPUT_DIR, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
HISTORY_SECONDS = 2.0
FUTURE_SECONDS = 6.0
SAMPLE_INTERVAL = 0.5
MIN_HISTORY_STEPS = int(HISTORY_SECONDS / SAMPLE_INTERVAL)
MIN_FUTURE_STEPS = int(FUTURE_SECONDS / SAMPLE_INTERVAL)
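# With these defaults: MIN_HISTORY_STEPS = 2.0 / 0.5 = 4 and MIN_FUTURE_STEPS = 6.0 / 0.5 = 12,
# matching the 2 Hz keyframe ("sample") rate of nuScenes.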
DEFAULT_MAP_QUERY_RADIUS = 50.0 # meters
logfile = os.path.join(LOGS_DIR, f'preprocess_map_fallback_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s [%(levelname)s] %(message)s',
handlers=[logging.FileHandler(logfile), logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger('preprocess_map_fallback')
logger.info("Starting preprocess_nuscenes_map_fallback.py (debug mode)")
logger.info("NUSCENES_ROOT=%s OUTPUT_DIR=%s", NUSCENES_ROOT, OUTPUT_DIR)
logger.info("MAP_API_AVAILABLE=%s ARCLINE_AVAILABLE=%s", MAP_API_AVAILABLE, ARCLINE_AVAILABLE)
if not MAP_API_AVAILABLE:
logger.warning("NuScenesMap import failed: %s", repr(map_api_import_exception))
if not ARCLINE_AVAILABLE:
logger.warning("arcline_path_utils import failed: %s", repr(arcline_import_exception))
# ---------------- helpers ----------------
def sanitize_filename(name: str) -> str:
name = name.strip().replace(' ', '_')
    return re.sub(r'[^A-Za-z0-9_\-.]', '_', name)
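# Example: sanitize_filename('scene 0001/v1.0') -> 'scene_0001_v1.0'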
def global_to_ego_xy(global_xy: Tuple[float, float], ego_xy: Tuple[float, float], ego_yaw: float) -> List[float]:
x, y = float(global_xy[0]), float(global_xy[1])
x0, y0 = float(ego_xy[0]), float(ego_xy[1])
dx, dy = x - x0, y - y0
cos_t, sin_t = np.cos(ego_yaw), np.sin(ego_yaw)
x_ego = cos_t * dx + sin_t * dy
y_ego = -sin_t * dx + cos_t * dy
return [float(x_ego), float(y_ego)]
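# Hedged sanity check: with the ego at the origin and yaw = pi/2 (facing global +y),
# a point one metre ahead in global +y maps to +x in the ego frame:
#   global_to_ego_xy((0.0, 1.0), (0.0, 0.0), np.pi / 2)  -> ~[1.0, 0.0]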
class NumpyEncoder(json.JSONEncoder):
def default(self, obj: Any):
if isinstance(obj, (np.integer,)):
return int(obj)
if isinstance(obj, (np.floating,)):
return float(obj)
if isinstance(obj, (np.ndarray,)):
return obj.tolist()
if isinstance(obj, (np.bool_, bool)):
return bool(obj)
return super().default(obj)
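# Usage sketch: NumpyEncoder lets numpy scalars/arrays pass through json.dumps, e.g.
#   json.dumps({'v': np.float32(1.5), 'a': np.arange(3)}, cls=NumpyEncoder)
#   -> '{"v": 1.5, "a": [0, 1, 2]}'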
# ---------------- map-json parsing caches ----------------
_map_json_cache: Dict[str, Dict] = {}
_node_map_cache: Dict[str, Dict[str, Tuple[float, float]]] = {}
_line_map_cache: Dict[str, Dict[str, np.ndarray]] = {}
def find_map_json_path(dataroot: str, map_name: str) -> Optional[str]:
if not map_name:
return None
maps_dir = os.path.join(dataroot, 'maps')
cand = os.path.join(maps_dir, f"{map_name}.json")
if os.path.exists(cand):
return cand
if os.path.isdir(maps_dir):
for f in os.listdir(maps_dir):
if f.lower().endswith('.json') and f.lower().startswith(map_name.lower()):
return os.path.join(maps_dir, f)
return None
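# Example (hedged; depends on your dataset layout): with dataroot='/data/nuscenes',
#   find_map_json_path('/data/nuscenes', 'singapore-onenorth')
# returns '/data/nuscenes/maps/singapore-onenorth.json' if it exists, else the first
# prefix-matched maps/*.json, else None.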
def load_map_json_cached(path: Optional[str]) -> Optional[Dict]:
if not path:
return None
if path in _map_json_cache:
return _map_json_cache[path]
try:
with open(path, 'r', encoding='utf-8') as fh:
mj = json.load(fh)
_map_json_cache[path] = mj
logger.debug("Loaded map json: %s (keys=%d)", path, len(mj.keys()))
return mj
except Exception:
logger.exception("Failed to load map json: %s", path)
return None
def try_extract_coords_from_node(rec: Any) -> Optional[Tuple[float, float]]:
# node records often have x,y or xyz list
try:
if rec is None:
return None
if isinstance(rec, dict):
# common fields
if 'x' in rec and 'y' in rec:
return (float(rec['x']), float(rec['y']))
if 'xyz' in rec and isinstance(rec['xyz'], (list,tuple)) and len(rec['xyz']) >= 2:
return (float(rec['xyz'][0]), float(rec['xyz'][1]))
# sometimes 'translation' or 'center' etc
if 'translation' in rec and isinstance(rec['translation'], (list,tuple)) and len(rec['translation']) >= 2:
return (float(rec['translation'][0]), float(rec['translation'][1]))
# sometimes node stores 'node_token' or 'token' with coords in nested field
if 'node' in rec and isinstance(rec['node'], dict):
return try_extract_coords_from_node(rec['node'])
# if it's a list/tuple of coords
if isinstance(rec, (list,tuple)) and len(rec) >= 2:
return (float(rec[0]), float(rec[1]))
except Exception:
logger.debug("try_extract_coords_from_node failed for rec type %s", type(rec))
return None
return None
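# Node record shapes this accepts (hedged; field names vary across map versions):
#   {'x': 1.0, 'y': 2.0}        -> (1.0, 2.0)
#   {'xyz': [1.0, 2.0, 0.0]}    -> (1.0, 2.0)
#   [1.0, 2.0]                  -> (1.0, 2.0)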
def build_node_map(map_json: Dict, map_json_path: str) -> Dict[str, Tuple[float, float]]:
# build token -> (x,y)
if map_json_path in _node_map_cache:
return _node_map_cache[map_json_path]
nodes = {}
if not map_json:
_node_map_cache[map_json_path] = nodes
return nodes
node_obj = None
if 'node' in map_json:
node_obj = map_json['node']
else:
# find candidate
for k in map_json.keys():
if 'node' in k.lower():
node_obj = map_json[k]; break
if node_obj is None:
logger.warning("No 'node' section in map JSON")
_node_map_cache[map_json_path] = nodes
return nodes
if isinstance(node_obj, dict):
for token, rec in node_obj.items():
coord = try_extract_coords_from_node(rec)
if coord:
nodes[str(token)] = coord
elif isinstance(node_obj, list):
for rec in node_obj:
token = None
coord = None
if isinstance(rec, dict):
token = rec.get('token') or rec.get('id') or rec.get('uid') or rec.get('node_token')
coord = try_extract_coords_from_node(rec)
elif isinstance(rec, (list,tuple)):
coord = try_extract_coords_from_node(rec)
if token and coord:
nodes[str(token)] = coord
logger.info("Built node map with %d entries from %s", len(nodes), map_json_path)
_node_map_cache[map_json_path] = nodes
return nodes
def try_extract_coords(val: Any) -> Optional[np.ndarray]:
"""
Try to coerce val into an Nx2 numpy array of XY coordinates.
Handles lists, strings that encode arrays, dicts with 'xyz' etc.
"""
if val is None:
return None
try:
if isinstance(val, (list,tuple, np.ndarray)):
arr = np.asarray(val)
if arr.ndim >= 2 and arr.shape[1] >= 2:
return arr[:, :2].astype(float)
# list of dicts with x,y
if len(val) > 0 and isinstance(val[0], dict):
pts=[]
for p in val:
if 'x' in p and 'y' in p:
pts.append([float(p['x']), float(p['y'])])
elif 'xyz' in p and isinstance(p['xyz'], (list,tuple)) and len(p['xyz'])>=2:
pts.append([float(p['xyz'][0]), float(p['xyz'][1])])
else:
return None
if pts:
return np.asarray(pts)[:, :2]
if isinstance(val, dict):
for key in ['xyz','polyline','coords','points','geometry','points_2d','line','translation','polyline_coords']:
if key in val and val[key]:
return try_extract_coords(val[key])
# nested search
for k,v in val.items():
c = try_extract_coords(v)
if c is not None:
return c
if isinstance(val, str):
# try JSON parse
try:
parsed = json.loads(val)
return try_extract_coords(parsed)
except Exception:
# regex float extraction
                floats = re.findall(r'[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?', val)
if len(floats) >= 2:
arr = np.array([float(x) for x in floats], dtype=float)
if arr.size % 2 == 0:
pts = arr.reshape(-1,2)
return pts
if arr.size % 3 == 0:
pts = arr.reshape(-1,3)[:, :2]; return pts
except Exception:
logger.exception("try_extract_coords failed for val type %s", type(val))
return None
return None
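# Illustrative inputs that all coerce to an (N, 2) float array:
#   [[0, 0, 5], [1, 1, 5]]                 -> [[0., 0.], [1., 1.]]  (z dropped)
#   [{'x': 0, 'y': 0}, {'x': 1, 'y': 1}]   -> [[0., 0.], [1., 1.]]
#   '[[0, 0], [1, 1]]'                     -> [[0., 0.], [1., 1.]]  (JSON string)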
def build_line_map_from_map_json(map_json: Dict, map_json_path: str) -> Dict[str, np.ndarray]:
"""
Build mapping: line_token -> Nx2 coords.
For nuScenes map format where 'line' entries reference 'node_tokens', we reconstruct
coordinates using the 'node' section.
"""
if map_json_path in _line_map_cache:
return _line_map_cache[map_json_path]
line_map = {}
if not map_json:
_line_map_cache[map_json_path] = line_map
return line_map
# Build node map first
node_map = build_node_map(map_json, map_json_path)
lines_obj = None
if 'line' in map_json:
lines_obj = map_json['line']
else:
for k in map_json.keys():
if 'line' in k.lower():
lines_obj = map_json[k]; break
if lines_obj is None:
logger.warning("No 'line' key in map json")
_line_map_cache[map_json_path] = line_map
return line_map
# lines_obj usually a list of dicts
if isinstance(lines_obj, dict):
# token -> rec
for lid, rec in lines_obj.items():
# try several ways to get coords
coord = None
if isinstance(rec, dict):
# if rec has 'node_tokens' (nuScenes common)
if 'node_tokens' in rec and rec['node_tokens']:
pts=[]
for nt in rec['node_tokens']:
if nt in node_map:
pts.append(list(node_map[nt]))
else:
# not found, skip
pts=[]
break
if pts:
coord = np.asarray(pts)[:, :2].astype(float)
if coord is None:
# try embedded coords
for key in ['xyz','polyline','points','coords','geometry','line','polyline_coords']:
if key in rec and rec[key]:
coord = try_extract_coords(rec[key])
if coord is not None:
break
else:
coord = try_extract_coords(rec)
if coord is not None and coord.shape[0] >= 2:
line_map[str(lid)] = coord
elif isinstance(lines_obj, list):
for rec in lines_obj:
lid = None
coord = None
if isinstance(rec, dict):
lid = rec.get('token') or rec.get('id') or rec.get('uid')
# try node_tokens
if 'node_tokens' in rec and rec['node_tokens']:
pts=[]
for nt in rec['node_tokens']:
if nt in node_map:
pts.append(list(node_map[nt]))
else:
pts=[]
break
if pts:
coord = np.asarray(pts)[:, :2].astype(float)
if coord is None:
for key in ['xyz','polyline','points','coords','geometry','line','polyline_coords']:
if key in rec and rec[key]:
coord = try_extract_coords(rec[key]); break
if coord is None:
# fallback try entire rec
coord = try_extract_coords(rec)
else:
coord = try_extract_coords(rec)
if lid is None:
lid = f"idx_{len(line_map)}"
if coord is not None and coord.shape[0] >= 2:
line_map[str(lid)] = coord
logger.info("Built line_map entries: %d", len(line_map))
# log a few examples
for i, (k,v) in enumerate(list(line_map.items())[:3]):
logger.debug("line_map example %d: token=%s pts=%d first=%s", i, k, v.shape[0], v[0].tolist())
_line_map_cache[map_json_path] = line_map
return line_map
# ---------------- geometry helpers ----------------
def compute_centerline_segments(cl: np.ndarray):
if cl is None or cl.shape[0] < 2:
return np.zeros((0,2)), np.zeros((0,))
seg_mid = (cl[1:] + cl[:-1]) / 2.0
seg_vec = cl[1:] - cl[:-1]
seg_len = np.linalg.norm(seg_vec[:, :2], axis=1)
return seg_mid, seg_len
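# Example: cl = np.array([[0., 0.], [1., 0.], [3., 0.]])
#   seg_mid -> [[0.5, 0.], [2., 0.]]   (segment midpoints)
#   seg_len -> [1., 2.]                (segment lengths)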
def lane_distance_to_points(cl: np.ndarray, points: np.ndarray) -> float:
if cl is None or cl.shape[0] < 2 or points.size == 0:
return float('inf')
min_d = float('inf')
for i in range(cl.shape[0] - 1):
a = cl[i]; b = cl[i+1]
ab = b - a
ab2 = ab.dot(ab)
for pt in points:
ap = pt - a
if ab2 == 0:
d = np.linalg.norm(ap)
else:
t = np.clip(ap.dot(ab) / ab2, 0.0, 1.0)
proj = a + t * ab
d = np.linalg.norm(pt - proj)
if d < min_d:
min_d = d
if min_d == 0.0:
return 0.0
return min_d
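# Example: a single segment (0,0)-(2,0) and the query point (1,1):
#   lane_distance_to_points(np.array([[0., 0.], [2., 0.]]), np.array([[1., 1.]])) -> 1.0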
# ---------------- map extraction pipeline ----------------
def extract_map_summary(map_api_obj: Optional[Any], map_name: str, positions_xy: np.ndarray, radius: float = DEFAULT_MAP_QUERY_RADIUS) -> Dict:
"""
Return summary dict matching previous format:
lane_positions, lane_lengths, centerline_positions, centerline_lengths, centerline_to_lane,
lane_adjacent_edges, lane_predecessor_edges, lane_successor_edges
Strategy:
1) If NuScenesMap exists and provides useful methods, try those (preferred).
2) Else, fall back to parsing maps/<map_name>.json using node & line tables and select line segments near positions.
In fallback, also attempt to parse lane connectivity from connector-like sections (names vary across map versions).
"""
empty = {'lane_positions': [], 'lane_lengths': [], 'centerline_positions': [], 'centerline_lengths': [], 'centerline_to_lane': [[], []], 'lane_adjacent_edges': [], 'lane_predecessor_edges': [], 'lane_successor_edges': []}
if positions_xy is None or positions_xy.size == 0:
logger.debug("positions empty -> return empty map summary")
return empty
cx = float(np.mean(positions_xy[:,0])); cy = float(np.mean(positions_xy[:,1]))
logger.info("Query center for map extraction: (%.3f, %.3f) radius=%.1f", cx, cy, radius)
# 1) try official API if available
if map_api_obj is not None:
logger.info("Trying NuScenesMap API methods (if available)")
try:
has_get_ids = hasattr(map_api_obj, 'get_lane_ids_in_xy_bbox')
has_get_center = hasattr(map_api_obj, 'get_lane_segment_centerline')
logger.debug("NuScenesMap has get_lane_ids_in_xy_bbox=%s get_lane_segment_centerline=%s", has_get_ids, has_get_center)
            lane_ids = []
            if has_get_ids:
                try:
                    lane_ids = map_api_obj.get_lane_ids_in_xy_bbox(cx, cy, radius)
                    logger.info("get_lane_ids_in_xy_bbox returned %d ids", len(lane_ids) if lane_ids is not None else 0)
                except Exception:
                    logger.exception("get_lane_ids_in_xy_bbox failed")
                    lane_ids = []
            # NuScenesMap itself exposes get_records_in_radius, not get_lane_ids_in_xy_bbox.
            if (not lane_ids) and hasattr(map_api_obj, 'get_records_in_radius'):
                try:
                    recs = map_api_obj.get_records_in_radius(cx, cy, radius, ['lane', 'lane_connector'])
                    lane_ids = (recs.get('lane') or []) + (recs.get('lane_connector') or [])
                    logger.info("get_records_in_radius returned %d lane ids", len(lane_ids))
                except Exception:
                    logger.exception("get_records_in_radius failed")
# try closest lane if none found
if (not lane_ids) and hasattr(map_api_obj, 'get_closest_lane'):
try:
lid = map_api_obj.get_closest_lane(cx, cy, radius)
if lid:
lane_ids = [lid]
logger.info("get_closest_lane returned %s", lid)
except Exception:
logger.exception("get_closest_lane failed")
if lane_ids:
lp, ll, cp, cll, src, dst = [], [], [], [], [], []
lane_id_list = list(lane_ids)
for li, lid in enumerate(lane_id_list):
try:
                        lane_record = None
                        # Prefer get_arcline_path, the NuScenesMap accessor for lane geometry.
                        if hasattr(map_api_obj, 'get_arcline_path'):
                            lane_record = map_api_obj.get_arcline_path(lid)
                        elif hasattr(map_api_obj, 'get_lane'):
                            lane_record = map_api_obj.get_lane(lid)
cl = None
# try arcline discretize if available
if ARCLINE_AVAILABLE and lane_record is not None:
try:
poses = arcline_path_utils.discretize_lane(lane_record, resolution_meters=1)
if poses and len(poses) >= 2:
cl = np.array([[p[0], p[1]] for p in poses], dtype=float)
except Exception:
logger.debug("arcline discretize failed for %s", lid, exc_info=True)
# fallback to centerline API
if cl is None and hasattr(map_api_obj, 'get_lane_segment_centerline'):
try:
raw_cl = map_api_obj.get_lane_segment_centerline(lid)
arr = np.asarray(raw_cl)
if arr.ndim >= 2 and arr.shape[1] >= 2:
cl = arr[:, :2].astype(float)
logger.debug("get_lane_segment_centerline for %s: %s pts", lid, None if cl is None else cl.shape[0])
except Exception:
logger.debug("get_lane_segment_centerline exception", exc_info=True)
# fallback parse lane_record dict
if cl is None and isinstance(lane_record, dict):
for key in ['centerline','polyline','coords','points','xyz','geometry','line']:
if key in lane_record and lane_record[key]:
cl = try_extract_coords(lane_record[key])
if cl is not None:
break
if cl is None:
logger.warning("No centerline for lane %s, skipping", lid)
continue
seg_mid, seg_len = compute_centerline_segments(cl)
total_len = float(np.sum(seg_len)) if seg_len.size > 0 else 0.0
if total_len <= 0 or seg_mid.size == 0:
continue
rep = max(0, len(cl)//2)
lp.append([float(cl[rep,0]), float(cl[rep,1])])
ll.append(total_len)
for j in range(seg_mid.shape[0]):
cp.append([float(seg_mid[j,0]), float(seg_mid[j,1])])
cll.append(float(seg_len[j]))
src.append(len(cp)-1); dst.append(li)
except Exception:
logger.exception("Error processing lane id %s", lid)
# connectivity via API (if available)
lane_adj_src, lane_adj_dst = [], []
lane_pred_src, lane_pred_dst = [], []
lane_succ_src, lane_succ_dst = [], []
try:
# get_incoming/get_outgoing available in map_api_obj?
for li, lid in enumerate(lane_id_list):
try:
incoming = []
outgoing = []
if hasattr(map_api_obj, 'get_incoming_lane_ids'):
incoming = map_api_obj.get_incoming_lane_ids(lid) or []
if hasattr(map_api_obj, 'get_outgoing_lane_ids'):
outgoing = map_api_obj.get_outgoing_lane_ids(lid) or []
# predecessors
for pred in incoming:
if pred in lane_id_list:
pidx = lane_id_list.index(pred)
lane_pred_src.append(pidx); lane_pred_dst.append(li)
# successors
for succ in outgoing:
if succ in lane_id_list:
sidx = lane_id_list.index(succ)
lane_succ_src.append(li); lane_succ_dst.append(sidx)
except Exception:
logger.debug("Failed to get incoming/outgoing for %s", lid, exc_info=True)
except Exception:
logger.debug("Connectivity extraction via NuScenesMap failed", exc_info=True)
if lp:
logger.info("NuScenesMap API extraction success: lanes=%d centerline_pts=%d", len(lp), len(cp))
return {'lane_positions': lp, 'lane_lengths': ll, 'centerline_positions': cp, 'centerline_lengths': cll, 'centerline_to_lane': [src, dst], 'lane_adjacent_edges': [lane_adj_src, lane_adj_dst], 'lane_predecessor_edges': [lane_pred_src, lane_pred_dst], 'lane_successor_edges': [lane_succ_src, lane_succ_dst]}
else:
logger.info("NuScenesMap returned no lanes for this area; will fallback to JSON parsing")
except Exception:
logger.exception("NuScenesMap API path threw exception; falling back to JSON parsing")
# 2) fallback: parse maps/<map_name>.json directly
map_json_path = find_map_json_path(NUSCENES_ROOT, map_name) if map_name else None
logger.debug("Fallback map_json_path=%s", map_json_path)
mj = load_map_json_cached(map_json_path) if map_json_path else None
if mj is None:
logger.warning("No map JSON available; returning empty map summary")
return empty
# Build line_map then select lines near positions
line_map = build_line_map_from_map_json(mj, map_json_path)
logger.debug("Total lines available in map json: %d", len(line_map))
if not line_map:
logger.warning("line_map empty after parse; returning empty")
return empty
# select lines whose distance to positions <= radius
selected = []
for lid, cl in line_map.items():
try:
d = lane_distance_to_points(cl, positions_xy)
if d <= radius:
selected.append((lid, cl))
except Exception:
logger.exception("Error computing distance for line %s", lid)
logger.info("Selected %d line(s) near the query center", len(selected))
if not selected:
return empty
# Build lane-like output: treat each selected line as a lane representative
lane_positions, lane_lengths, centerline_positions, centerline_lengths, src, dst = [], [], [], [], [], []
selected_lane_tokens = []
for li, (lid, cl) in enumerate(selected):
seg_mid, seg_len = compute_centerline_segments(cl)
total_len = float(np.sum(seg_len)) if seg_len.size>0 else 0.0
if total_len <= 0 or seg_mid.size==0:
continue
rep = max(0, len(cl)//2)
lane_positions.append([float(cl[rep,0]), float(cl[rep,1])])
lane_lengths.append(total_len)
selected_lane_tokens.append(lid)
for j in range(seg_mid.shape[0]):
centerline_positions.append([float(seg_mid[j,0]), float(seg_mid[j,1])])
centerline_lengths.append(float(seg_len[j]))
src.append(len(centerline_positions)-1); dst.append(li)
logger.info("Built lane_recs map entries: %d", len(selected_lane_tokens))
# Try to extract connectivity from the map JSON.
# Map JSON field names vary across versions. We attempt several likely names:
connectors_candidates = []
for key in ['lane_connector', 'lane_connectors', 'lane_link', 'lane_links', 'connectors', 'connector', 'lane_connectivity', 'laneconnector']:
if key in mj and mj[key]:
logger.info("Found connector-like key in map json: %s", key)
connectors_candidates.append(mj[key])
# also some maps store connectors under other keys - broad scan
    for k, v in mj.items():
        if isinstance(v, list) and len(v) > 0 and isinstance(v[0], dict):
            sample_keys = set(v[0].keys())
            if any(x in sample_keys for x in ['from_lane_token','to_lane_token','in_lane_token','out_lane_token','start_lane_token','end_lane_token','from_lane','to_lane','in_lane','out_lane']):
                # Skip lists already captured by the explicit key scan above.
                if not any(v is c for c in connectors_candidates):
                    connectors_candidates.append(v)
# flatten connectors candidates
connectors = []
for c in connectors_candidates:
if isinstance(c, list):
connectors.extend(c)
elif isinstance(c, dict):
# dict of token->rec
for tt,rec in c.items():
connectors.append(rec)
# Additional: try to harvest incoming/outgoing stored inside lane records themselves
lane_records_obj = None
if 'lane' in mj:
lane_records_obj = mj['lane']
else:
for k in mj.keys():
if 'lane' in k.lower():
lane_records_obj = mj[k]; break
logger.info("Found %d connector-like records in map json", len(connectors))
# Build token->index map for selected lanes
token_to_idx = {tok: idx for idx, tok in enumerate(selected_lane_tokens)}
lane_adj_src, lane_adj_dst = [], []
lane_pred_src, lane_pred_dst = [], []
lane_succ_src, lane_succ_dst = [], []
# parse connectors
if connectors:
logger.debug("Parsing %d connector records for connectivity", len(connectors))
for rec in connectors:
try:
if not isinstance(rec, dict):
continue
# try many possible field names to extract from/to tokens
in_tok = (rec.get('in_lane_token') or rec.get('from_lane_token') or rec.get('from_lane') or
rec.get('start_lane_token') or rec.get('from') or rec.get('in_lane') or rec.get('lane_from') or rec.get('lane_start_token'))
out_tok = (rec.get('out_lane_token') or rec.get('to_lane_token') or rec.get('to_lane') or
rec.get('end_lane_token') or rec.get('to') or rec.get('out_lane') or rec.get('lane_to') or rec.get('lane_end_token'))
# sometimes fields are lists
if isinstance(in_tok, list) and in_tok:
in_tok = in_tok[0]
if isinstance(out_tok, list) and out_tok:
out_tok = out_tok[0]
# try nested string detections
if (not in_tok or not out_tok):
# inspect all string fields and pick first two tokens that are in token_to_idx
found = []
for v in rec.values():
if isinstance(v, str) and v in token_to_idx:
found.append(v)
if isinstance(v, list):
for it in v:
if isinstance(it, str) and it in token_to_idx:
found.append(it)
if len(found) >= 2:
break
if len(found) >= 2:
in_tok, out_tok = found[0], found[1]
if in_tok and out_tok:
if (in_tok in token_to_idx) and (out_tok in token_to_idx):
src_idx = token_to_idx[in_tok]
dst_idx = token_to_idx[out_tok]
lane_succ_src.append(src_idx); lane_succ_dst.append(dst_idx)
lane_pred_src.append(dst_idx); lane_pred_dst.append(src_idx)
except Exception:
logger.debug("Connector parse failed for rec", exc_info=True)
# parse incoming/outgoing fields inside lane records if present
if lane_records_obj:
logger.info("Inspecting lane records for incoming/outgoing/left/right adjacency fields")
if isinstance(lane_records_obj, dict):
for tok, rec in lane_records_obj.items():
try:
if tok not in token_to_idx:
continue
idx = token_to_idx[tok]
# incoming/outgoing lists (common names)
incs = rec.get('incoming') or rec.get('incoming_lane_tokens') or rec.get('incoming_lane') or rec.get('incoming_lanes') or rec.get('incomings')
outs = rec.get('outgoing') or rec.get('outgoing_lane_tokens') or rec.get('outgoing_lane') or rec.get('outgoing_lanes') or rec.get('outgoings')
if isinstance(incs, list):
for p in incs:
if p in token_to_idx:
lane_pred_src.append(token_to_idx[p]); lane_pred_dst.append(idx)
if isinstance(outs, list):
for s in outs:
if s in token_to_idx:
lane_succ_src.append(idx); lane_succ_dst.append(token_to_idx[s])
# left/right adjacency
left = rec.get('left_lane') or rec.get('adjacent_left') or rec.get('left_neighbor') or rec.get('left')
right = rec.get('right_lane') or rec.get('adjacent_right') or rec.get('right_neighbor') or rec.get('right')
if left and left in token_to_idx:
lane_adj_src.append(idx); lane_adj_dst.append(token_to_idx[left])
if right and right in token_to_idx:
lane_adj_src.append(idx); lane_adj_dst.append(token_to_idx[right])
except Exception:
logger.debug("lane record parse issue", exc_info=True)
elif isinstance(lane_records_obj, list):
for rec in lane_records_obj:
try:
tok = rec.get('token') or rec.get('id') or rec.get('uid')
if not tok or tok not in token_to_idx:
continue
idx = token_to_idx[tok]
incs = rec.get('incoming') or rec.get('incoming_lane_tokens') or rec.get('incoming_lanes')
outs = rec.get('outgoing') or rec.get('outgoing_lane_tokens') or rec.get('outgoing_lanes')
if isinstance(incs, list):
for p in incs:
if p in token_to_idx:
lane_pred_src.append(token_to_idx[p]); lane_pred_dst.append(idx)
if isinstance(outs, list):
for s in outs:
if s in token_to_idx:
lane_succ_src.append(idx); lane_succ_dst.append(token_to_idx[s])
left = rec.get('left_lane') or rec.get('adjacent_left') or rec.get('left_neighbor')
right = rec.get('right_lane') or rec.get('adjacent_right') or rec.get('right_neighbor')
if left and left in token_to_idx:
lane_adj_src.append(idx); lane_adj_dst.append(token_to_idx[left])
if right and right in token_to_idx:
lane_adj_src.append(idx); lane_adj_dst.append(token_to_idx[right])
except Exception:
logger.debug("lane record parse issue", exc_info=True)
logger.debug("Parsed connectivity: succ=%d pred=%d adj=%d", len(lane_succ_src), len(lane_pred_src), len(lane_adj_src))
logger.info("Built fallback map summary: lanes=%d centerline_pts=%d", len(lane_positions), len(centerline_positions))
return {'lane_positions': lane_positions, 'lane_lengths': lane_lengths, 'centerline_positions': centerline_positions, 'centerline_lengths': centerline_lengths, 'centerline_to_lane': [src, dst], 'lane_adjacent_edges': [lane_adj_src, lane_adj_dst], 'lane_predecessor_edges': [lane_pred_src, lane_pred_dst], 'lane_successor_edges': [lane_succ_src, lane_succ_dst]}
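# Usage sketch (the coordinates below are placeholders, not real map positions):
#   summary = extract_map_summary(None, 'singapore-onenorth', np.array([[600.0, 1600.0]]))
# Passing map_api_obj=None skips the API path and exercises the JSON fallback directly.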
# ---------------- Preprocessor (annotation processing) ----------------
class Preprocessor:
def __init__(self, dataroot: str, output_dir: str):
self.dataroot = dataroot
self.output_dir = output_dir
logger.info("Loading NuScenes...")
        self.nusc = NuScenes(version=os.getenv('NUSCENES_VERSION', 'v1.0-trainval'), dataroot=dataroot, verbose=False)
self.helper = PredictHelper(self.nusc)
self.scenes = list(self.nusc.scene)
        self.inst2cat = {inst['token']: inst.get('category_token') for inst in self.nusc.instance if 'token' in inst}
        # Cache NuScenesMap objects per map name; construction is too slow to repeat per annotation.
        self._map_api_cache: Dict[str, Any] = {}
def category_label(self, instance_token: str) -> int:
cat_token = self.inst2cat.get(instance_token)
if not cat_token:
return 5
try:
cat = self.nusc.get('category', cat_token)
name = cat.get('name', '').lower()
except Exception:
return 5
        # Check two-wheelers before the generic vehicle test: nuScenes names them
        # 'vehicle.bicycle' / 'vehicle.motorcycle', which would otherwise match 'vehicle'.
        if any(k in name for k in ['bicycle','motorcycle','bike','motorbike']): return 2
        if any(k in name for k in ['vehicle','car','truck','bus','motor_vehicle']): return 0
        if any(k in name for k in ['pedestrian','person']): return 1
        if 'animal' in name: return 3
        if any(k in name for k in ['barrier','obstacle','static','cone','debris','construction']): return 4
return 5
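    # Label mapping examples with the checks above:
    #   'vehicle.car' -> 0, 'human.pedestrian.adult' -> 1, 'vehicle.bicycle' -> 2,
    #   'animal' -> 3, 'movable_object.barrier' -> 4, anything else -> 5.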
def process_annotation(self, ann_token: str) -> Dict:
try:
ann = self.nusc.get('sample_annotation', ann_token)
sample = self.nusc.get('sample', ann['sample_token'])
ts_current = int(sample['timestamp'])
inst = ann['instance_token']
size = ann.get('size', [0.0,0.0,0.0])
area = float(size[0] * size[1]) if len(size) >= 2 else 0.0
trans = ann.get('translation', [0.0,0.0,0.0])
global_pos = [float(trans[0]), float(trans[1])]
            # Ego pose: read it from the ego_pose table via any sample_data record;
            # the first annotation in a sample is an arbitrary object, not the ego.
            ego_xy = [0.0, 0.0]; ego_yaw = 0.0
            try:
                sd_token = next(iter(sample.get('data', {}).values()), None)
                if sd_token:
                    sd = self.nusc.get('sample_data', sd_token)
                    ego_pose = self.nusc.get('ego_pose', sd['ego_pose_token'])
                    ego_t = ego_pose.get('translation', [0.0, 0.0, 0.0])
                    ego_xy = [float(ego_t[0]), float(ego_t[1])]
                    ego_rot = ego_pose.get('rotation', None)
                    ego_yaw = Quaternion(ego_rot).yaw_pitch_roll[0] if ego_rot is not None else 0.0
            except Exception:
                logger.exception("ego pose lookup failed; falling back to origin")
# get past and future full records
past = self.helper.get_past_for_agent(inst, ann['sample_token'], seconds=HISTORY_SECONDS, in_agent_frame=False, just_xy=False)
fut = self.helper.get_future_for_agent(inst, ann['sample_token'], seconds=FUTURE_SECONDS, in_agent_frame=False, just_xy=False)
past_list = past if isinstance(past, list) else (list(past) if past is not None else [])
fut_list = fut if isinstance(fut, list) else (list(fut) if fut is not None else [])
history_points = []
if len(past_list) > 0:
ordered = past_list[::-1]
for r in ordered:
try:
samp_r = self.nusc.get('sample', r['sample_token'])
t = int(samp_r['timestamp'])
xy = r.get('translation', [0.0,0.0,0.0])[:2]
ego_p = global_to_ego_xy((xy[0], xy[1]), tuple(ego_xy), ego_yaw)
history_points.append({'global':[float(xy[0]), float(xy[1])], 'ego':[float(ego_p[0]), float(ego_p[1])], 'timestamp': t})
except Exception:
logger.exception("history point parse failed")
# append current
history_points.append({'global':[float(global_pos[0]), float(global_pos[1])], 'ego': global_to_ego_xy((global_pos[0], global_pos[1]), tuple(ego_xy), ego_yaw), 'timestamp': ts_current})
future_points = []
if len(fut_list) > 0:
for r in fut_list:
try:
samp_r = self.nusc.get('sample', r['sample_token'])
t = int(samp_r['timestamp'])
xy = r.get('translation', [0.0,0.0,0.0])[:2]
ego_p = global_to_ego_xy((xy[0], xy[1]), tuple(ego_xy), ego_yaw)
future_points.append({'global':[float(xy[0]), float(xy[1])], 'ego':[float(ego_p[0]), float(ego_p[1])], 'timestamp': t})
except Exception:
logger.exception("future point parse failed")
complete = (len(history_points) >= MIN_HISTORY_STEPS) and (len(future_points) >= MIN_FUTURE_STEPS)
cat_label = self.category_label(inst)
positions_arr = np.asarray([p['global'] for p in history_points]) if len(history_points) > 0 else np.zeros((0,2))
# derive map name from scene->log
scene_rec = self.nusc.get('scene', sample.get('scene_token')) if sample else None
map_name = None
map_api_obj = None
try:
if scene_rec:
log = self.nusc.get('log', scene_rec['log_token'])
map_name = log.get('location', '').replace(' ', '-').lower()
                    if NuScenesMap is not None and map_name:
                        if map_name in self._map_api_cache:
                            map_api_obj = self._map_api_cache[map_name]
                        else:
                            try:
                                # Instantiate once per map; construction re-reads the whole map JSON.
                                map_api_obj = NuScenesMap(dataroot=self.nusc.dataroot, map_name=map_name)
                                logger.debug("NuScenesMap instantiated for %s", map_name)
                            except Exception:
                                logger.exception("NuScenesMap init failed for %s", map_name)
                                map_api_obj = None
                            self._map_api_cache[map_name] = map_api_obj
except Exception:
logger.exception("map loading exception")
map_api_obj = None
map_summary = extract_map_summary(map_api_obj, map_name if map_name else '', positions_arr, radius=DEFAULT_MAP_QUERY_RADIUS)
result = {
'sample_annotation_token': ann_token,
'scene_token': sample.get('scene_token'),
'timestamp': int(ts_current),
'category': int(cat_label),
'area': float(area),
                'act': 0 if any('moving' in self.nusc.get('attribute', at)['name'].lower() for at in ann.get('attribute_tokens', [])) else 1,  # 0 = moving, 1 = not moving
'global_position': [float(global_pos[0]), float(global_pos[1])],
                'ego_position': global_to_ego_xy((global_pos[0], global_pos[1]), tuple(ego_xy), ego_yaw),
'heading': float(Quaternion(ann.get('rotation', [1,0,0,0])).yaw_pitch_roll[0] if ann.get('rotation', None) else 0.0),
'trajectory': {'history': history_points, 'future': future_points},
'timestamps': {'history_start': int(history_points[0]['timestamp']) if len(history_points) > 0 else 0, 'future_end': int(future_points[-1]['timestamp']) if len(future_points) > 0 else int(ts_current)},
'complete': bool(complete),
'map': {
'lane_positions': map_summary.get('lane_positions', []),
'lane_lengths': map_summary.get('lane_lengths', []),
'centerline_positions': map_summary.get('centerline_positions', []),
'centerline_lengths': map_summary.get('centerline_lengths', []),
'centerline_to_lane': map_summary.get('centerline_to_lane', [[], []]),
'lane_adjacent_edges': map_summary.get('lane_adjacent_edges', []),
'lane_predecessor_edges': map_summary.get('lane_predecessor_edges', []),
'lane_successor_edges': map_summary.get('lane_successor_edges', [])
}
}
return result
except Exception:
logger.exception("process_annotation failed for %s", ann_token)
return {}
def process_scene(self, scene_idx: int, scene_rec: Dict):
scene_name = scene_rec.get('name', f"scene_{scene_idx+1:03d}")
safe = sanitize_filename(scene_name)
out_path = os.path.join(self.output_dir, f"{safe}.jsonl")
logger.info("Processing scene %d/%d name=%s -> %s", scene_idx+1, len(self.scenes), scene_name, out_path)
count = 0
with open(out_path, 'w', encoding='utf-8') as fout:
s_token = scene_rec.get('first_sample_token')
while s_token:
sample = self.nusc.get('sample', s_token)
for ann_token in sample.get('anns', []):
rec = self.process_annotation(ann_token)
if rec:
fout.write(json.dumps(rec, cls=NumpyEncoder, ensure_ascii=False) + '\n')
count += 1
s_token = sample.get('next', '')
logger.info("Wrote %d annotations to %s", count, out_path)
return out_path
# ---------------- main ----------------
def main():
try:
pre = Preprocessor(NUSCENES_ROOT, OUTPUT_DIR)
total = len(pre.scenes)
logger.info("Total scenes to process: %d", total)
for i, s in enumerate(pre.scenes):
pre.process_scene(i, s)
except KeyboardInterrupt:
logger.info("Interrupted by user")
except Exception:
logger.exception("Fatal error")
traceback.print_exc()
if __name__ == '__main__':
main()
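# Usage sketch (paths are placeholders; point them at your dataset/output dirs):
#   NUSCENES_ROOT=/data/sets/nuscenes OUTPUT_DIR=./datasets python preprocess_nuscenes_map_fallback.py
# One <scene_name>.jsonl file is written per scene, one JSON record per annotation.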
The code is as above; how should it be modified?