Fixing a custom batch_norm training error: TypeError: cannot assign 'torch.cuda.FloatTensor' as parameter 'running_mean'

When training a custom BatchNorm layer, the running mean and variance need to be stored as Parameters so that they are saved with the model, so the open-source code was modified to make running_mean and running_var nn.Parameter objects. The problem is that the batch_norm function returns running_mean and running_var as torch.cuda.FloatTensor, which cannot be assigned directly to a Parameter attribute. The fix is a type conversion at the assignment, which removes the error and lets the statistics be saved normally.

While experimenting with network quantization I needed a custom BN layer, and found the following open-source code online:

I made one change: running_mean and running_var are defined as Parameters. If they are stored as plain tensor attributes instead, they are not saved with the model, which is quite inconvenient, although that version also does not raise the error in the title.

import torch
import torch.nn as nn
# batch_norm() is the functional helper from the open-source code (a sketch is given below)

class BatchNorm(nn.Module):
    def __init__(self, num_features, num_dims, w_bit, in_bit, l_shift, out_bit):
        super(BatchNorm, self).__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))

        # stored as non-trainable Parameters so that they end up in state_dict
        self.running_mean = nn.Parameter(torch.zeros(shape), requires_grad=False)
        self.running_var = nn.Parameter(torch.ones(shape), requires_grad=False)
        self.eps = 1e-5
        self.w_bit = w_bit
        self.in_bit = in_bit
        self.l_shift = l_shift
        self.out_bit = out_bit

    def forward(self, X):
        if self.running_mean.device != X.device:
            self.running_mean = self.running_mean.to(X.device)
            self.running_var = self.running_var.to(X.device)
        Y, self.running_mean, self.running_var = batch_norm(
            self.w_bit, self.in_bit, self.l_shift, self.out_bit, self.training,
            X, self.gamma, self.beta, self.running_mean, self.running_var,
            eps=1e-5, momentum=0.9)
        return Y
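
For completeness (the open-source helper itself is not reproduced in this post), here is a minimal sketch of what a batch_norm() function with this signature typically looks like, modelled on the common "Dive into Deep Learning"-style implementation. It is an assumption, not the actual code used above: the quantization arguments w_bit, in_bit, l_shift and out_bit are accepted only to match the call and are ignored.

import torch

def batch_norm(w_bit, in_bit, l_shift, out_bit, is_training,
               X, gamma, beta, running_mean, running_var,
               eps=1e-5, momentum=0.9):
    # sketch only: the quantization arguments are ignored here
    if not is_training:
        # inference: normalise with the stored running statistics
        X_hat = (X - running_mean) / torch.sqrt(running_var + eps)
    else:
        if X.dim() == 2:
            mean = X.mean(dim=0, keepdim=True)
            var = ((X - mean) ** 2).mean(dim=0, keepdim=True)
        else:
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # the updated statistics are plain tensors (torch.cuda.FloatTensor on GPU),
        # which is exactly why assigning them back to a Parameter attribute fails
        running_mean = momentum * running_mean + (1.0 - momentum) * mean
        running_var = momentum * running_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta
    return Y, running_mean, running_var

The point to notice is the return type: whatever device X lives on, the returned statistics are ordinary tensors, not Parameters.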

During training, the error from the title is raised. The root cause is this line:

Y, self.running_mean, self.running_var = batch_norm(
    self.w_bit, self.in_bit, self.l_shift, self.out_bit, self.training,
    X, self.gamma, self.beta, self.running_mean, self.running_var,
    eps=1e-5, momentum=0.9)

As the error message says, batch_norm() returns running_mean (and running_var) as a torch.cuda.FloatTensor, and a plain tensor cannot be assigned directly to an attribute that is registered as a Parameter. Doing a type conversion at the assignment removes the error. Training went fine for me after this change; corrections are welcome if I have missed something.

Y, running_mean_out, running_var_out = batch_norm(
    self.w_bit, self.in_bit, self.l_shift, self.out_bit, self.training,
    X, self.gamma, self.beta, self.running_mean, self.running_var,
    eps=1e-5, momentum=0.9)
# wrap the returned tensors back into non-trainable Parameters before assigning
self.running_mean = nn.Parameter(running_mean_out, requires_grad=False)
self.running_var = nn.Parameter(running_var_out, requires_grad=False)
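
A quick sanity check (a sketch; the quantization argument values are placeholders, and it assumes a batch_norm() helper such as the one sketched earlier is available) confirms that the layer now runs in training mode and that the running statistics are part of state_dict, i.e. they will be saved with the model:

bn = BatchNorm(num_features=16, num_dims=4, w_bit=8, in_bit=8, l_shift=0, out_bit=8)
device = "cuda" if torch.cuda.is_available() else "cpu"
bn = bn.to(device)

x = torch.randn(8, 16, 32, 32, device=device)
bn.train()
y = bn(x)                                  # no TypeError any more

print(list(bn.state_dict().keys()))
# ['gamma', 'beta', 'running_mean', 'running_var']
torch.save(bn.state_dict(), "bn_checkpoint.pth")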

Some posts online suggest solving this with .cuda(), but after actually testing it, the statistics were no longer saved once that change was added, so I abandoned that approach.
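
As a side note on the design choice: PyTorch's built-in nn.BatchNorm layers handle the same persistence problem with register_buffer, which keeps running_mean/running_var as plain tensors while still including them in state_dict and moving them along with .to(device)/.cuda(). The following is only a sketch of that alternative, not what this post uses; updating the buffers in place also avoids the attribute rebinding that triggers the TypeError in the first place.

import torch
import torch.nn as nn

class BufferedBatchNorm(nn.Module):
    # sketch: persistence via register_buffer instead of non-trainable Parameters
    def __init__(self, num_features, momentum=0.9, eps=1e-5):
        super().__init__()
        shape = (1, num_features, 1, 1)
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # buffers are saved in state_dict and follow .to()/.cuda(),
        # but receive no gradients and are not listed by parameters()
        self.register_buffer("running_mean", torch.zeros(shape))
        self.register_buffer("running_var", torch.ones(shape))
        self.momentum = momentum
        self.eps = eps

    def forward(self, X):
        if self.training:
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
            # in-place copy: the attribute is never rebound, so no TypeError
            self.running_mean.copy_(self.momentum * self.running_mean
                                    + (1 - self.momentum) * mean.detach())
            self.running_var.copy_(self.momentum * self.running_var
                                   + (1 - self.momentum) * var.detach())
        else:
            mean, var = self.running_mean, self.running_var
        return self.gamma * (X - mean) / torch.sqrt(var + self.eps) + self.beta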


for t in self.memory])) next_states = torch.FloatTensor(np.array([t[4] for t in self.memory])) dones = torch.FloatTensor(np.array([t[5] for t in self.memory])).view(-1, 1) global_rewards = rewards.mean(dim=1, keepdim=True) with torch.no_grad(): values = self.value_net(states) next_values = self.value_net(next_states) deltas = global_rewards + self.gamma * next_values * (1 - dones) - values advantages = torch.zeros_like(deltas) advantage = 0.0 for t in reversed(range(len(deltas))): advantage = deltas[t] + self.gamma * self.gae_lambda * advantage * (1 - dones[t]) advantages[t] = advantage returns = advantages + values returns = (returns[:-2] + returns[1:-1] + returns[2:]) / 3 # 3步滑动平均 advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) total_policy_loss = 0.0 total_value_loss = 0.0 update_count = 0 for _ in range(self.epochs): permutation = torch.randperm(n_samples) num_batches = max(1, (n_samples + self.batch_size - 1) // self.batch_size) permutation = permutation.repeat(num_batches)[:num_batches * self.batch_size] for start in range(0, len(permutation), self.batch_size): end = start + self.batch_size batch_indices = permutation[start:end] if len(batch_indices) == 0: continue max_return_idx = len(returns) - 1 valid_indices = batch_indices[batch_indices <= max_return_idx] if len(valid_indices) == 0: continue b_states = states[valid_indices] b_actions = actions[valid_indices] b_old_log_probs = old_log_probs[valid_indices] b_advantages = advantages[valid_indices] b_returns = returns[valid_indices] # 策略网络更新 self.policy_optimizer.zero_grad() b_new_probs = self.policy_net(b_states) b_new_log_probs_list = [] b_entropies_list = [] for i in range(self.num_agents): dist_i = Categorical(b_new_probs[:, i, :]) log_prob_i = dist_i.log_prob(b_actions[:, i]) entropy_i = dist_i.entropy() b_new_log_probs_list.append(log_prob_i.unsqueeze(1)) b_entropies_list.append(entropy_i.unsqueeze(1)) b_new_log_probs = torch.cat(b_new_log_probs_list, dim=1) b_total_entropy = torch.cat(b_entropies_list, dim=1).mean(dim=1, keepdim=True) ratio = torch.exp(b_new_log_probs - b_old_log_probs) surr1 = ratio * b_advantages surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * b_advantages ppo_loss = -torch.min(surr1, surr2).mean() entropy_loss = -self.ent_coef * b_total_entropy.mean() policy_loss = ppo_loss + entropy_loss policy_loss.backward() torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=0.5) self.policy_optimizer.step() # 价值网络更新 self.value_optimizer.zero_grad() value_preds = self.value_net(b_states) huber_loss = nn.HuberLoss(delta=0.3) value_loss = huber_loss(value_preds, b_returns.detach()) l2_lambda = 2e-5 l2_reg = torch.tensor(0., requires_grad=True) for param in self.value_net.parameters(): l2_reg = l2_reg + torch.norm(param, p=2) ** 2 value_loss = value_loss + l2_lambda * l2_reg value_loss.backward() torch.nn.utils.clip_grad_norm_(self.value_net.parameters(), max_norm=0.5) self.value_optimizer.step() total_policy_loss += policy_loss.item() total_value_loss += value_loss.item() update_count += 1 self.memory.clear() avg_policy_loss = total_policy_loss / update_count if update_count > 0 else 0.0 avg_value_loss = total_value_loss / update_count if update_count > 0 else 0.0 return avg_policy_loss, avg_value_loss # 训练函数 def train_ppo(env, agent, episodes=800, max_steps_per_episode=60, update_every=6): reward_history = [] policy_loss_history = [] value_loss_history = [] conflict_history = [] boundary_violation_history = [] line_violation_history = [] initial_positions = 
env.get_initial_positions() final_positions = None agent_shapes = env.get_agent_shapes() # 获取智能体原始形状 plt.ion() fig, ax = plt.subplots(figsize=(8, 6)) render_interval = 10 policy_lr_scheduler = CosineAnnealingLR(agent.policy_optimizer, T_max=600, eta_min=8e-7) value_lr_scheduler = CosineAnnealingLR(agent.value_optimizer, T_max=600, eta_min=4e-7) for episode in range(episodes): state = env.reset(episode) done = False total_reward = 0.0 step = 0 episode_conflicts = [] episode_boundary_violations = [] episode_line_violations = [] agent.memory.clear() # 熵系数衰减 max_ent = 0.15 min_ent = 0.01 decay_rate = (max_ent - min_ent) / 400 agent.ent_coef = max(min_ent, max_ent - episode * decay_rate) while not done and step < max_steps_per_episode: actions, log_probs = agent.select_action(state) next_state, rewards, done, info = env.step(actions, step, max_steps_per_episode) total_reward += np.sum(rewards) episode_conflicts.append(info['conflicts']) episode_boundary_violations.append(info['boundary_violations']) episode_line_violations.append(info['line_violations']) agent.store_transition((state, actions, log_probs, rewards, next_state, float(done))) if step % render_interval == 0 or done: ax.clear() agents_pos = np.array(env.get_agents_pos()) agent_lines = env.agent_lines # 绘制折线(原始折线+局部运动线段) for i, line in enumerate(agent_lines): # 绘制原始折线(灰色细线条) polyline = line["original_points"] ax.plot(polyline[:, 0], polyline[:, 1], '-', color="gray", linewidth=1, alpha=0.5, label='原始折线' if i == 0 else "") # 绘制局部运动线段(黑色粗线条) local_poly = line["local_points"] ax.plot(local_poly[:, 0], local_poly[:, 1], '-', color="black", linewidth=2, alpha=0.8, label='局部运动轨道' if i == 0 else "") # 绘制智能体(使用原始形状) for i, (lon, lat) in enumerate(agents_pos): # 获取智能体原始形状的相对坐标 shape_coords = agent_shapes[i] # 计算中心到各顶点的偏移量 offsets = shape_coords - initial_positions[i] # 根据当前位置调整形状坐标 current_shape = np.array([[lon + dx, lat + dy] for dx, dy in offsets]) # 绘制原始形状 shape_patch = plt.Polygon( current_shape, facecolor='none', edgecolor='black', linewidth=2, alpha=0.7 ) ax.add_patch(shape_patch) ax.text(lon, lat, str(i + 1), ha='center', va='center', fontweight='bold') # 标记冲突(红色边框) current_conflict = info['conflicts'] if current_conflict > 0: conflict_pairs = [] rect_diag = 0.0013 # 适合原始形状的对角线长度 for i in range(env.num_agents): for j in range(i + 1, env.num_agents): if np.linalg.norm(agents_pos[i] - agents_pos[j]) < (env.min_agent_dist + rect_diag): conflict_pairs.append(i) conflict_pairs.append(j) for idx in set(conflict_pairs): lon, lat = agents_pos[idx] shape_coords = agent_shapes[idx] offsets = shape_coords - initial_positions[idx] current_shape = np.array([[lon + dx, lat + dy] for dx, dy in offsets]) conflict_patch = plt.Polygon( current_shape, facecolor='none', edgecolor='red', linewidth=2, alpha=0.7 ) ax.add_patch(conflict_patch) # 标记线违规(橙色边框) current_line_viol = info['line_violations'] if current_line_viol > 0: for i in range(env.num_agents): line = agent_lines[i] polyline = line["original_points"] min_dist_to_line = float('inf') for j in range(len(polyline) - 1): p1 = polyline[j] p2 = polyline[j + 1] dist = point_to_line_distance(agents_pos[i], p1, p2) if dist < min_dist_to_line: min_dist_to_line = dist if min_dist_to_line > env.line_tolerance: lon, lat = agents_pos[i] shape_coords = agent_shapes[i] offsets = shape_coords - initial_positions[i] current_shape = np.array([[lon + dx, lat + dy] for dx, dy in offsets]) line_viol_patch = plt.Polygon( current_shape, facecolor='none', edgecolor='orange', linewidth=2, alpha=0.7 ) 
ax.add_patch(line_viol_patch) ax.set_xlim(env.grid_left, env.grid_right) ax.set_ylim(env.grid_bottom, env.grid_top) ax.set_title( f"回合: {episode + 1}, 步数: {step + 1}, 冲突: {current_conflict}, " f"边界违规: {info['boundary_violations']}, 线违规: {current_line_viol}" ) ax.set_xlabel("经度") ax.set_ylabel("纬度") ax.legend() plt.pause(0.00001) # 定期更新模型 if (step + 1) % update_every == 0 or done: policy_loss, value_loss = agent.update() policy_loss_history.append(policy_loss) value_loss_history.append(value_loss) state = next_state step += 1 policy_lr_scheduler.step() value_lr_scheduler.step() reward_history.append(total_reward) conflict_history.append(np.mean(episode_conflicts) if episode_conflicts else 0) boundary_violation_history.append(np.mean(episode_boundary_violations) if episode_boundary_violations else 0) line_violation_history.append(np.mean(episode_line_violations) if episode_line_violations else 0) if episode % 10 == 0: latest_policy_loss = policy_loss_history[-1] if policy_loss_history else 0.0 latest_value_loss = value_loss_history[-1] if value_loss_history else 0.0 current_lr = agent.policy_optimizer.param_groups[0]['lr'] print(f"回合 {episode:4d}, 总奖励: {total_reward:6.2f}, " f"冲突: {np.mean(episode_conflicts):.2f}, 边界违规: {np.mean(episode_boundary_violations):.2f}, " f"线违规: {np.mean(episode_line_violations):.2f}, " f"策略损失: {latest_policy_loss:.4f}, 价值损失: {latest_value_loss:.4f}, 学习率: {current_lr:.1e}") final_positions = env.get_agents_pos() plt.ioff() plt.close() torch.save(agent.policy_net.state_dict(), "rect_ppo_policy_polyline_optimized.pth") torch.save(agent.value_net.state_dict(), "rect_ppo_value_polyline_optimized.pth") print("\n训练完成!优化后的模型已保存") return (reward_history, policy_loss_history, value_loss_history, conflict_history, boundary_violation_history, line_violation_history, initial_positions, final_positions, env.agent_lines, agent_shapes) # 初始/最终位置可视化(分开保存) def plot_initial_positions(env, initial_pos, agent_lines, agent_shapes): """绘制并保存初始位置图""" rect_diag = 0.0013 # 适合原始形状的对角线长度 plt.figure(figsize=(10, 8)) ax = plt.gca() for i, line in enumerate(agent_lines): polyline = line["original_points"] ax.plot(polyline[:, 0], polyline[:, 1], '-', color="gray", linewidth=1, alpha=0.5) local_poly = line["local_points"] ax.plot(local_poly[:, 0], local_poly[:, 1], '-', color="black", linewidth=2, alpha=0.8, label='局部运动轨道' if i == 0 else "") # 绘制智能体原始形状 for i, (lon, lat) in enumerate(initial_pos): shape_coords = agent_shapes[i] offsets = shape_coords - initial_pos[i] current_shape = np.array([[lon + dx, lat + dy] for dx, dy in offsets]) shape_patch = plt.Polygon( current_shape, facecolor='none', edgecolor='black', linewidth=2, alpha=0.7 ) ax.add_patch(shape_patch) ax.text(lon, lat, str(i + 1), ha='center', va='center', fontweight='bold', fontsize=10) # 标记冲突智能体 initial_conflict_idx = set() for i in range(env.num_agents): for j in range(i + 1, env.num_agents): if np.linalg.norm(initial_pos[i] - initial_pos[j]) < (env.min_agent_dist + rect_diag): initial_conflict_idx.add(i) initial_conflict_idx.add(j) for idx in initial_conflict_idx: lon, lat = initial_pos[idx] shape_coords = agent_shapes[idx] offsets = shape_coords - initial_pos[idx] current_shape = np.array([[lon + dx, lat + dy] for dx, dy in offsets]) conflict_patch = plt.Polygon( current_shape, facecolor='none', edgecolor='red', linewidth=2, alpha=0.7 ) ax.add_patch(conflict_patch) ax.set_xlim(env.grid_left, env.grid_right) ax.set_ylim(env.grid_bottom, env.grid_top) ax.set_title(f"初始位置分布(冲突智能体数:{len(initial_conflict_idx)})") ax.set_xlabel("经度") 
ax.set_ylabel("纬度") ax.legend() ax.grid(True, linestyle='--', alpha=0.7) plt.tight_layout() plt.savefig("initial_positions_polyline_optimized.png", dpi=300) plt.close() print("初始位置图已保存为 'initial_positions_polyline_optimized.png'") def plot_final_positions(env, final_pos, agent_lines, agent_shapes): """绘制并保存最终位置图""" rect_diag = 0.0013 # 适合原始形状的对角线长度 plt.figure(figsize=(10, 8)) ax = plt.gca() for i, line in enumerate(agent_lines): polyline = line["original_points"] ax.plot(polyline[:, 0], polyline[:, 1], '-', color="gray", linewidth=1, alpha=0.5) local_poly = line["local_points"] ax.plot(local_poly[:, 0], local_poly[:, 1], '-', color="black", linewidth=2, alpha=0.8, label='局部运动轨道' if i == 0 else "") # 绘制智能体原始形状 for i, (lon, lat) in enumerate(final_pos): shape_coords = agent_shapes[i] offsets = shape_coords - env.get_initial_positions()[i] # 基于初始位置计算偏移 current_shape = np.array([[lon + dx, lat + dy] for dx, dy in offsets]) shape_patch = plt.Polygon( current_shape, facecolor='none', edgecolor='black', linewidth=2, alpha=0.7 ) ax.add_patch(shape_patch) ax.text(lon, lat, str(i + 1), ha='center', va='center', fontweight='bold', fontsize=10) # 标记冲突智能体 final_conflict_idx = set() for i in range(env.num_agents): for j in range(i + 1, env.num_agents): if np.linalg.norm(final_pos[i] - final_pos[j]) < (env.min_agent_dist + rect_diag): final_conflict_idx.add(i) final_conflict_idx.add(j) for idx in final_conflict_idx: lon, lat = final_pos[idx] shape_coords = agent_shapes[idx] offsets = shape_coords - env.get_initial_positions()[idx] current_shape = np.array([[lon + dx, lat + dy] for dx, dy in offsets]) conflict_patch = plt.Polygon( current_shape, facecolor='none', edgecolor='red', linewidth=2, alpha=0.7 ) ax.add_patch(conflict_patch) # 标记线违规智能体 line_viol_idx = set() for i in range(env.num_agents): line = agent_lines[i] polyline = line["original_points"] min_dist_to_line = float('inf') for j in range(len(polyline) - 1): p1 = polyline[j] p2 = polyline[j + 1] dist = point_to_line_distance(final_pos[i], p1, p2) if dist < min_dist_to_line: min_dist_to_line = dist if min_dist_to_line > env.line_tolerance: line_viol_idx.add(i) for idx in line_viol_idx: lon, lat = final_pos[idx] shape_coords = agent_shapes[idx] offsets = shape_coords - env.get_initial_positions()[idx] current_shape = np.array([[lon + dx, lat + dy] for dx, dy in offsets]) line_viol_patch = plt.Polygon( current_shape, facecolor='none', edgecolor='orange', linewidth=2, alpha=0.7 ) ax.add_patch(line_viol_patch) ax.set_xlim(env.grid_left, env.grid_right) ax.set_ylim(env.grid_bottom, env.grid_top) ax.set_title(f"最终位置分布(冲突:{len(final_conflict_idx)},线违规:{len(line_viol_idx)})") ax.set_xlabel("经度") ax.set_ylabel("纬度") ax.legend() ax.grid(True, linestyle='--', alpha=0.7) plt.tight_layout() plt.savefig("final_positions_polyline_optimized.png", dpi=300) plt.close() print("最终位置图已保存为 'final_positions_polyline_optimized.png'") # 结果绘图函数 def plot_results(reward_history, policy_loss, value_loss, conflict_history, boundary_violation_history, line_violation_history): plt.figure(figsize=(10, 4)) plt.plot(reward_history, alpha=0.5, label="单回合奖励") if len(reward_history) >= 10: moving_avg = np.convolve(reward_history, np.ones(10) / 10, mode='valid') plt.plot(range(9, len(reward_history)), moving_avg, label="10回合平均") plt.title("奖励曲线(优化后)") plt.xlabel("回合") plt.ylabel("总奖励") plt.legend() plt.grid() plt.savefig("reward_polyline_optimized.png", dpi=300) plt.close() plt.figure(figsize=(10, 4)) plt.plot(conflict_history, label="平均冲突数", color='blue') 
plt.plot(boundary_violation_history, label="平均边界违规数", color='orange') plt.plot(line_violation_history, label="平均线违规数", color='red') plt.axhline(y=0, color='black', linestyle='--') plt.title("冲突与违规曲线(优化后)") plt.xlabel("回合") plt.ylabel("数量") plt.legend() plt.grid() plt.savefig("conflict_violation_polyline_optimized.png", dpi=300) plt.close() plt.figure(figsize=(10, 4)) if len(policy_loss) >= 10: policy_smoothed = np.convolve(policy_loss, np.ones(10) / 10, mode='valid') plt.plot(range(9, len(policy_loss)), policy_smoothed, label="策略损失(平滑)", color='green') plt.plot(policy_loss, alpha=0.3, color='green') plt.title("策略损失曲线(优化后)") plt.xlabel("更新步数") plt.ylabel("损失值") plt.legend() plt.grid() plt.savefig("policy_loss_polyline_optimized.png", dpi=300) plt.close() plt.figure(figsize=(10, 4)) if len(value_loss) >= 10: value_smoothed = np.convolve(value_loss, np.ones(10) / 10, mode='valid') plt.plot(range(9, len(value_loss)), value_smoothed, label="价值损失(平滑)", color='purple') plt.plot(value_loss, alpha=0.3, color='purple') plt.title("价值损失曲线(优化后)") plt.xlabel("更新步数") plt.ylabel("损失值") plt.legend() plt.grid() plt.savefig("value_loss_polyline_optimized.png", dpi=300) plt.close() plt.figure(figsize=(12, 6)) if len(reward_history) > 0 and len(conflict_history) > 0 and len(value_loss) > 0: norm_reward = (np.array(reward_history) - np.min(reward_history)) / ( np.max(reward_history) - np.min(reward_history) + 1e-8) norm_conflict = 1 - (np.array(conflict_history) - np.min(conflict_history)) / ( np.max(conflict_history) - np.min(conflict_history) + 1e-8) norm_line_viol = 1 - (np.array(line_violation_history) - np.min(line_violation_history)) / ( np.max(line_violation_history) - np.min(line_violation_history) + 1e-8) trunc_value_loss = value_loss[:len(reward_history)] norm_value_loss = (np.array(trunc_value_loss) - np.min(trunc_value_loss)) / ( np.max(trunc_value_loss) - np.min(trunc_value_loss) + 1e-8) plt.plot(norm_reward, label="归一化总奖励", color='blue') plt.plot(norm_conflict, label="归一化冲突值(反向)", color='green') plt.plot(norm_line_viol, label="归一化线违规值(反向)", color='orange') plt.plot(norm_value_loss, label="归一化价值损失", color='red', alpha=0.7) plt.title("总奖励-冲突-线违规-价值损失联动图(优化后)") plt.xlabel("回合") plt.ylabel("归一化值") plt.legend() plt.grid() plt.savefig("correlation_polyline_optimized.png", dpi=300) plt.close() # 主函数 def main(): # 智能体矩形坐标(每组4个顶点,定义智能体原始形状) your_coordinates = [ [121.44042, 31.323465], [121.440783, 31.323465], [121.440783, 31.323964], [121.44042, 31.323964], [121.439167, 31.31262], [121.440305, 31.31262], [121.440305, 31.313174], [121.439167, 31.313174], [121.45059, 31.311141], [121.451727, 31.311141], [121.451727, 31.311694], [121.45059, 31.311694], [121.442881, 31.31078], [121.44389, 31.31078], [121.44389, 31.311334], [121.442881, 31.311334], [121.443881, 31.312954], [121.445019, 31.312954], [121.445019, 31.313508], [121.443881, 31.313508], [121.446896, 31.311852], [121.448033, 31.311852], [121.448033, 31.312406], [121.446896, 31.312406], [121.444236, 31.31119], [121.445245, 31.31119], [121.445245, 31.311744], [121.444236, 31.311744], [121.441675, 31.316022], [121.442684, 31.316022], [121.442684, 31.316576], [121.441675, 31.316576], [121.442911, 31.312575], [121.44392, 31.312575], [121.44392, 31.313129], [121.442911, 31.313129], [121.44394, 31.315784], [121.444949, 31.315784], [121.444949, 31.316338], [121.44394, 31.316338], [121.451557, 31.313325], [121.452666, 31.313325], [121.452666, 31.313868], [121.451557, 31.313868], [121.452448, 31.315506], [121.453935, 31.315506], [121.453935, 31.316004], 
[121.452448, 31.316004], [121.447553, 31.315458], [121.44904, 31.315458], [121.44904, 31.316001], [121.447553, 31.316001], [121.450557, 31.313825], [121.451666, 31.313825], [121.451666, 31.314368], [121.450557, 31.314368], [121.440305, 31.322271], [121.441057, 31.322271], [121.441057, 31.322825], [121.440305, 31.322825], [121.44275, 31.313706], [121.443759, 31.313706], [121.443759, 31.314259], [121.44275, 31.314259], [121.44653, 31.315487], [121.448056, 31.315487], [121.448056, 31.316041], [121.44653, 31.316041], [121.45081, 31.316222], [121.451948, 31.316222], [121.451948, 31.316776], [121.45081, 31.316776], [121.45282, 31.31282], [121.453958, 31.31282], [121.453958, 31.313374], [121.45282, 31.313374] ] # 多坐标点折线 your_custom_polylines = [ [ [121.444092, 31.310223], [121.443857, 31.31533], [121.443778, 31.316861], [121.443699, 31.317496], [121.443428, 31.31852], [121.443264, 31.319012], [121.442922, 31.319814], [121.441951, 31.322021], [121.437934, 31.33081], [121.436298, 31.334379], [121.434728, 31.337951], [121.43302, 31.342159] ], [ [121.450669, 31.315322], [121.451167, 31.315325], [121.452653, 31.315332], [121.453431, 31.315334], [121.453678, 31.315331], [121.45395, 31.315338], [121.454566, 31.315341], [121.455416, 31.315344] ], [ [121.447862, 31.307364], [121.447802, 31.310141], [121.447802, 31.312613], [121.447862, 31.315308], [121.447814, 31.31773], [121.447802, 31.318823], [121.447905, 31.319138], [121.448575, 31.319629], [121.451242, 31.320819] ], [[121.450669, 31.312621], [121.450666, 31.314346], [121.450669, 31.315322]], [ [121.432436, 31.342711], [121.43314, 31.34107], [121.434084, 31.338915], [121.436022, 31.334623], [121.437832, 31.330747], [121.440288, 31.325478], [121.441313, 31.323083], [121.441547, 31.322548], [121.442606, 31.320254], [121.443153, 31.318998], [121.443348, 31.31841], [121.44357, 31.317523], [121.44365, 31.316862], [121.443714, 31.315821], [121.443842, 31.312865], [121.443951, 31.310217], [121.4439, 31.308523], [121.44387, 31.307221] ], [[121.443661, 31.315258], [121.443701, 31.314217], [121.443712, 31.31393], [121.443765, 31.312525]], [ [121.444031, 31.31526], [121.445081, 31.315273], [121.445704, 31.315281], [121.447862, 31.315308], [121.448595, 31.315312], [121.44911, 31.315314], [121.450669, 31.315322] ], [[121.450669, 31.315322], [121.450918, 31.316447], [121.450993, 31.316816], [121.451116, 31.317412]], [ [121.455355, 31.312634], [121.454421, 31.312631], [121.453857, 31.31263], [121.453396, 31.312628], [121.452636, 31.312626], [121.450669, 31.312621] ], [ [121.443765, 31.312525], [121.442313, 31.312481], [121.441637, 31.312458], [121.440438, 31.312447], [121.439588, 31.312426], [121.439197, 31.312417], [121.438451, 31.3124], [121.43793, 31.312394], [121.437464, 31.312385], [121.436576, 31.312355] ], [[121.450738, 31.310214], [121.450669, 31.312621]], [[121.443765, 31.312525], [121.443853, 31.310818], [121.443904, 31.309829]], [[121.444225, 31.312532], [121.450669, 31.312621]], [ [121.450669, 31.312621], [121.449561, 31.312623], [121.448839, 31.312611], [121.447802, 31.312613], [121.447433, 31.312612], [121.445848, 31.312568], [121.444225, 31.312532] ], [[121.444302, 31.309839], [121.444348, 31.311538], [121.444266, 31.31196], [121.444225, 31.312532]], [[121.443661, 31.315258], [121.444031, 31.31526]], [[121.444225, 31.312532], [121.444142, 31.31276], [121.444051, 31.314857], [121.444031, 31.31526]], [[121.444031, 31.31526], [121.443948, 31.316862]], [[121.443585, 31.316781], [121.443658, 31.315359], [121.443661, 31.315258]] ] # 坐标分组 try: grouped_coords 
= group_coordinates_into_rectangles(your_coordinates, group_size=4) print(f"坐标分组完成:共{len(grouped_coords)}个矩形 → 对应{len(grouped_coords)}个智能体") except ValueError as e: print(f"坐标错误:{e}") return # 创建环境 env = GeoEnv( grouped_initial_positions=grouped_coords, global_step=2e-5, # 滑动步长 min_agent_distance=0.0005, # 安全距离 max_agent_distance=0.0001, custom_polylines=your_custom_polylines ) # 可视化智能体-折线分配关系 agent_lines = env.agent_lines visualize_agent_line_assignment(env.initial_centers, agent_lines, your_custom_polylines) # 初始化PPO state_dim = env.state_dim action_dim = 3 # 0=向前,1=向后,2=不动 agent = PPO( state_dim, action_dim, num_agents=env.num_agents, lr=3e-6, gamma=0.93, epochs=10, batch_size=512, ent_coef=0.15, epsilon=0.08, gae_lambda=0.80 ) # 开始训练 print("\n开始训练(优化版:保留原始位置,折线旁移动)...") (reward_history, policy_loss, value_loss, conflict_history, boundary_violation_history, line_violation_history, initial_pos, final_pos, agent_lines, agent_shapes) = train_ppo( env, agent, episodes=1200, max_steps_per_episode=100, update_every=30 ) # 绘制结果 plot_results(reward_history, policy_loss, value_loss, conflict_history, boundary_violation_history, line_violation_history) # 分开保存初始和最终位置图 plot_initial_positions(env, initial_pos, agent_lines, agent_shapes) plot_final_positions(env, final_pos, agent_lines, agent_shapes) if __name__ == "__main__": main() end_time = time.time() print(f"总运行间:{end_time - start_time:.2f}秒")
Analyze the result of running the following code, which loads a local BGE-VL (CLIP-style) checkpoint and prints the device of every parameter:

```python
import torch
from transformers import AutoModel

model_path = r"D:\pycharm\pythonProject1\mms\models\BAAIbge-vl-large"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModel.from_pretrained(
    model_path,
    trust_remote_code=True,
    local_files_only=True,
).to(device)

for name, param in model.named_parameters():
    print(f"{name}: {param.device}")
```

The printed output is abridged below; the full listing repeats the same pattern for all 12 text-encoder layers and all 24 vision-encoder layers, and every parameter is reported on `cuda:0`:

```
logit_scale: cuda:0
text_model.embeddings.token_embedding.weight: cuda:0
text_model.embeddings.position_embedding.weight: cuda:0
text_model.encoder.layers.0.self_attn.k_proj.weight: cuda:0
...
text_model.final_layer_norm.weight: cuda:0
text_model.final_layer_norm.bias: cuda:0
vision_model.embeddings.class_embedding: cuda:0
vision_model.embeddings.patch_embedding.weight: cuda:0
...
vision_model.post_layernorm.weight: cuda:0
vision_model.post_layernorm.bias: cuda:0
visual_projection.weight: cuda:0
text_projection.weight: cuda:0
```
The output shows that every parameter of the model (weights and biases) was successfully loaded onto the `cuda:0` device. In other words, your PyTorch environment detected an available CUDA device (GPU) and all model parameters were moved onto it. Here is an analysis of the code and its output:

### Code analysis

1. **Model loading**:
   - `AutoModel.from_pretrained` loads the pretrained model from a local path.
   - `trust_remote_code=True` allows custom code shipped with the model repository (for custom architectures) to be loaded and executed.
   - `local_files_only=True` ensures the model is loaded only from the local path and no download is attempted.

2. **Device selection**:
   - `torch.device("cuda" if torch.cuda.is_available() else "cpu")` automatically selects the GPU when one is available and otherwise falls back to the CPU.
   - The `cuda:0` entries in the output show that a CUDA device is available in your environment, so the model was placed on the GPU.

3. **Parameter iteration**:
   - `model.named_parameters()` iterates over all model parameters and prints each parameter's name and device.
   - Every parameter is reported on `cuda:0`, so the model was fully loaded onto the GPU.

### Output analysis

The output lists every layer of the model together with the device of its parameters, for example:

- `logit_scale: cuda:0`
- `text_model.embeddings.token_embedding.weight: cuda:0`
- `vision_model.encoder.layers.0.self_attn.k_proj.weight: cuda:0`
- ... (the remaining parameters follow the same pattern)

This shows that:

1. All parts of the model (text encoder, vision encoder, and the two projection matrices) were correctly loaded onto the GPU.
2. No parameter was left on the CPU, so the device transfer (`.to(device)`) succeeded.

### Potential issues and suggestions

1. **CUDA environment**:
   - Make sure your PyTorch build supports CUDA (verify with `torch.cuda.is_available()`).
   - If the output had shown `cpu`, the CUDA driver might be missing or PyTorch might have been installed without CUDA support.

2. **GPU memory limits**:
   - A large model (such as a large vision-language model) can occupy a lot of GPU memory. You can check the remaining memory with `torch.cuda.mem_get_info()`.
   - If memory runs out, you can try:
     - a smaller batch size;
     - gradient checkpointing (`gradient_checkpointing`);
     - half precision (`model.half()`).

3. **Model compatibility**:
   - `trust_remote_code=True` requires that the model repository's code is trustworthy, because custom code is executed dynamically. Loading a model from an untrusted source is a security risk.

### Extended example

If you want to further verify that the model actually runs on the GPU, you can try a forward pass (the exact argument names depend on the model and may need adjusting):

```python
# Dummy inputs for a quick forward-pass check; shapes and argument names
# are illustrative and must match the model's actual signature.
input_text = torch.randint(0, 1000, (1, 10)).to(device)    # fake token ids
input_image = torch.randn(1, 3, 224, 224).to(device)       # fake image batch

output = model(input_ids=input_text, pixel_values=input_image)
print(output)
```

### Summary

Your code loaded the model onto the GPU successfully; the `cuda:0` entries indicate everything is in order. Next you can:

1. Verify that the forward pass works (as shown above).
2. Check GPU memory usage to make sure the model can handle real data.
3. If you run into problems (for example, running out of GPU memory), apply the optimisations above or adjust the model configuration.
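As a follow-up to the device and memory suggestions above, here is a minimal sketch (not part of the original question) that condenses the per-parameter printout into a single device set and reports free GPU memory. It assumes `model` has already been loaded as in the question; the `model.half()` line is left commented out because fp16 support depends on your downstream code.

```python
# Minimal sketch, assuming `model` was loaded as in the question above.
import torch

# Collect the set of devices actually used by the parameters; a single
# {device(type='cuda', index=0)} entry is equivalent to the long printout above.
devices = {p.device for p in model.parameters()}
print("parameter devices:", devices)

# Report free / total memory (in GiB) on the current CUDA device.
if torch.cuda.is_available():
    free_b, total_b = torch.cuda.mem_get_info()
    print(f"GPU memory: {free_b / 1024**3:.2f} GiB free of {total_b / 1024**3:.2f} GiB")

# If memory is tight, half precision roughly halves the parameter footprint;
# uncomment only if the rest of your pipeline tolerates fp16.
# model = model.half()
```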