Debug note: softmax outputs don't sum to 1. Root cause: `s` is 1-D, so when it divides torch.exp(x) (shape 64x10) the dimensions don't line up; `s` must be reshaped to 2-D, i.e. (64, 1), so the division broadcasts row-wise.

This post walks through implementing the softmax function correctly in PyTorch. By resolving a dimension-mismatch problem, the function can be applied to a network's output layer to produce a proper probability distribution.
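To see the broadcasting rule at work in isolation, here is a minimal sketch (illustrative only; the tensor values are random) of why dividing a (64, 10) tensor by a 1-D (64,) tensor fails, while a (64, 1) divisor broadcasts row-wise:

import torch

x = torch.randn(64, 10)                 # stand-in for the network output
s = torch.sum(torch.exp(x), dim=1)      # shape (64,), 1-D

# Broadcasting aligns trailing dimensions, so (64, 10) vs. (64,)
# compares 10 against 64 and raises a RuntimeError.
try:
    torch.exp(x) / s
except RuntimeError as e:
    print("broadcast error:", e)

# After reshaping to (64, 1), each row of exp(x) is divided by its
# own row sum, as intended.
probs = torch.exp(x) / s.view(-1, 1)
print(probs.sum(dim=1))                 # each entry is 1 (up to float precision)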

import torch

def softmax(x):
    ## TODO: Implement the softmax function here
    #print("torch.exp(x)=", torch.exp(x))
    s = torch.sum(torch.exp(x), dim=1)
    print("s.size=", s.size())
    print("s.view(-1,1)=", s.view(-1, 1).size())

    # Buggy version: return (torch.exp(x)/s).view(-1,1)
    # s is 1-D with shape (64,); it must be reshaped to (64, 1) so the
    # division broadcasts row-wise against the (64, 10) numerator.
    return torch.exp(x) / s.view(-1, 1)

# Equivalent one-liner:
# def softmax(x):
#     print("x.shape=", x.shape)
#     return torch.exp(x) / torch.sum(torch.exp(x), dim=1).view(-1, 1)
# Here, `out` should be the output of the network from the previous
# exercise, with shape (64, 10).
#print("out=", out)
print("out.shape=", out.shape)
probabilities = softmax(out)
print("probabilities=", probabilities)

# Does it have the right shape? Should be (64, 10).
#print(probabilities.shape)
# Does each row sum to 1?
#print(probabilities)
print(probabilities.sum(dim=1))

# Output

out.shape= torch.Size([64, 10])
s.size= torch.Size([64])
s.view(-1,1)= torch.Size([64, 1])
probabilities= tensor([[ 3.0988e-16,  1.0000e+00,  1.8007e-11,  1.4650e-11,  8.8916e-09,
          5.3950e-17,  4.7416e-09,  8.0687e-19,  3.1054e-09,  2.2082e-10],
        [ 7.2884e-18,  9.9896e-01,  4.5583e-08,  1.2075e-05,  2.1988e-09,
          1.5377e-13,  1.0279e-03,  2.0932e-10,  2.9357e-07,  4.3231e-11],
        ...
        [ 2.0768e-16,  9.9753e-01,  5.2618e-07,  2.0046e-06,  6.0347e-07,
          2.9689e-13,  2.4702e-03,  3.1447e-09,  5.1434e-09,  4.9742e-10]])
tensor([ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000])

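Two refinements worth noting (these go beyond the original fix and are offered as a hedged sketch): passing keepdim=True to torch.sum keeps the reduced dimension as size 1, which removes the need for view(-1, 1) entirely, and subtracting the per-row maximum before exponentiating keeps exp() from overflowing on large logits without changing the result:

import torch

def softmax_stable(x):
    # Subtracting the row-wise max is mathematically a no-op for
    # softmax but prevents exp() from overflowing on large inputs.
    x = x - x.max(dim=1, keepdim=True).values
    e = torch.exp(x)
    # keepdim=True makes the sum (64, 1) directly, so it broadcasts
    # row-wise without an explicit view().
    return e / e.sum(dim=1, keepdim=True)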
