71. Simplify Path(M)

This post presents a C++ algorithm for simplifying a Unix-style absolute file path. The algorithm processes the path string with a double-ended queue (deque) and handles the common special cases, such as multiple consecutive slashes and parent-directory ("..") components.

71. Simplify Path

Given an absolute path for a file (Unix-style), simplify it.

For example,
path = "/home/", => "/home"
path = "/a/./b/../../c/", => "/c"

Corner Cases:
Did you consider the case where path = "/../"?
In this case, you should return "/".
Another corner case is that the path might contain multiple slashes '/' together, such as "/home//foo/".
In this case, you should ignore the redundant slashes and return "/home/foo".
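
To make the deque behaviour of the solution below concrete, here is a component-by-component trace of the second example (a walkthrough added for clarity, not part of the original problem statement):

"a"  -> push           deque: [a]
"."  -> skip           deque: [a]
"b"  -> push           deque: [a, b]
".." -> pop            deque: [a]
".." -> pop            deque: []
"c"  -> push           deque: [c]

Joining the remaining components with '/' yields "/c".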

 

class Solution {
public:
    string simplifyPath(string path) {
        deque<string> qs;
        string::size_type lastindex = 0;

        // Split the path on '/' and process one component at a time.
        while (lastindex < path.size()) {
            string::size_type curindex = path.find('/', lastindex);
            if (curindex == string::npos) curindex = path.size();
            string token = path.substr(lastindex, curindex - lastindex);
            lastindex = curindex + 1;

            if (token.empty() || token == ".") {
                // Redundant slash ("//") or current directory (".") -- skip.
                continue;
            } else if (token == "..") {
                // Go up one level, but never above the root ("/../" -> "/").
                if (!qs.empty()) qs.pop_back();
            } else {
                qs.push_back(token);
            }
        }

        // Rebuild the simplified path from the remaining components.
        string result;
        while (!qs.empty()) {
            result += "/" + qs.front();
            qs.pop_front();
        }
        return result.empty() ? "/" : result;
    }
};
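
A minimal test driver (added here as a sketch, not from the original post; it assumes the Solution class above is in the same translation unit) that exercises both examples and both corner cases:

#include <cassert>
#include <deque>
#include <iostream>
#include <string>
using namespace std;

// ... Solution class from above ...

int main() {
    Solution s;
    assert(s.simplifyPath("/home/") == "/home");
    assert(s.simplifyPath("/a/./b/../../c/") == "/c");
    assert(s.simplifyPath("/../") == "/");          // cannot go above root
    assert(s.simplifyPath("/home//foo/") == "/home/foo"); // redundant slashes
    cout << "all cases pass" << endl;
    return 0;
}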

 

Reprinted from: https://www.cnblogs.com/guxuanqing/p/7503144.html
