Understand torch.einsum() in 5 Minutes: A Detailed Manual Derivation of torch.einsum()

This article analyzes in detail how torch.einsum computes its result in a complex case, with particular emphasis on how the Einstein summation convention is applied, and uses a worked example to show how to understand and apply tensor multiplication and broadcasting.

Introduction: torch.einsum() has already been covered by many blog posts, but most of them center on the Einstein summation convention itself. Much of their space goes to the convention's rules, while the actual case analysis is brief and limited to simple examples. In practice, however, the places where we use or encounter torch.einsum() are usually the most complicated computations.
This article therefore analyzes the computation of torch.einsum() through a realistic, complex case, deriving step by step how each element of the output relates to the elements of the inputs.

The Einstein Summation Convention

First, torch.einsum() is built on the Einstein summation convention. For completeness, this section gives a brief introduction; if you only care about the computation itself, skip to the next section. The Einstein summation convention is a notation invented to abbreviate calculations, much like we use × to denote multiplication, except that the operations it can express are far more complex and flexible. Its typical form is:
$$
i_1 i_2 \dots i_N,\; j_1 j_2 \dots j_M \;\rightarrow\; i_{k_1} i_{k_2} \dots j_{l_1} j_{l_2} \dots,
\qquad k_1, k_2, \dots \in \{1, \dots, N\},\quad l_1, l_2, \dots \in \{1, \dots, M\}
$$
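Read the notation as follows: the index strings to the left of the arrow label the dimensions of the two operands, the indices kept on the right are the dimensions of the output, and every index that does not appear on the right is summed over. As a concrete check, here is a minimal sketch of my own (not from the original derivation) showing that `ik,kj->ij` reproduces matrix multiplication, with the output derived element by element in explicit loops:

```python
import torch

# 'ik,kj->ij': i and j survive into the output, k does not,
# so k is summed over -- this is exactly matrix multiplication.
A = torch.randn(3, 4)   # indices i, k
B = torch.randn(4, 5)   # indices k, j
out = torch.einsum('ik,kj->ij', A, B)

# Derive each output element by hand: out[i, j] = sum_k A[i, k] * B[k, j]
manual = torch.zeros(3, 5)
for i in range(3):
    for j in range(5):
        for k in range(4):
            manual[i, j] += A[i, k] * B[k, j]

print(torch.allclose(out, manual, atol=1e-6))  # True
print(torch.allclose(out, A @ B, atol=1e-6))   # True
```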

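The same element-wise reading extends to higher-rank operands. As a warm-up for the real-world case below, whose attention module uses exactly this pattern, here is a hedged sketch with shapes chosen purely for illustration of the batched contraction `'bhqd,bhkd->bhqk'`:

```python
import torch

# b, h, q, k survive; d is summed -> for every (b, h) pair, a q x d
# matrix multiplies the transpose of a k x d matrix.
b, h, q, k, d = 2, 3, 4, 5, 6
Q = torch.randn(b, h, q, d)
K = torch.randn(b, h, k, d)
out = torch.einsum('bhqd,bhkd->bhqk', Q, K)

# Element by element: out[b, h, q, k] = sum_d Q[b, h, q, d] * K[b, h, k, d]
manual = torch.zeros(b, h, q, k)
for bi in range(b):
    for hi in range(h):
        for qi in range(q):
            for ki in range(k):
                manual[bi, hi, qi, ki] = (Q[bi, hi, qi, :] * K[bi, hi, ki, :]).sum()

print(torch.allclose(out, manual, atol=1e-5))                       # True
print(torch.allclose(out, Q @ K.transpose(-2, -1), atol=1e-5))      # True
```

With these rules in hand, we can turn to a full real-world script: the EEG Conformer model for EEG decoding, whose attention module relies on torch.einsum. The code and a module-by-module walkthrough follow.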
""" EEG Conformer Convolutional Transformer for EEG decoding Couple CNN and Transformer in a concise manner with amazing results """ # remember to change paths import argparse import os gpus = [0] os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus)) import numpy as np import math import glob import random import itertools import datetime import time import datetime import sys import scipy.io import torchvision.transforms as transforms from torchvision.utils import save_image, make_grid from torch.utils.data import DataLoader from torch.autograd import Variable from torchsummary import summary import torch.autograd as autograd from torchvision.models import vgg19 import torch.nn as nn import torch.nn.functional as F import torch import torch.nn.init as init from torch.utils.data import Dataset from PIL import Image import torchvision.transforms as transforms from sklearn.decomposition import PCA import torch import torch.nn.functional as F import matplotlib.pyplot as plt from torch import nn from torch import Tensor from PIL import Image from torchvision.transforms import Compose, Resize, ToTensor from einops import rearrange, reduce, repeat from einops.layers.torch import Rearrange, Reduce # from common_spatial_pattern import csp import matplotlib.pyplot as plt # from torch.utils.tensorboard import SummaryWriter from torch.backends import cudnn cudnn.benchmark = False cudnn.deterministic = True # writer = SummaryWriter('./TensorBoardX/') # Convolution module # use conv to capture local features, instead of postion embedding. class PatchEmbedding(nn.Module): def __init__(self, emb_size=40): # self.patch_size = patch_size super().__init__() self.shallownet = nn.Sequential( nn.Conv2d(1, 40, (1, 25), (1, 1)), nn.Conv2d(40, 40, (22, 1), (1, 1)), nn.BatchNorm2d(40), nn.ELU(), nn.AvgPool2d((1, 75), (1, 15)), # pooling acts as slicing to obtain 'patch' along the time dimension as in ViT nn.Dropout(0.5), ) self.projection = nn.Sequential( nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)), # transpose, conv could enhance fiting ability slightly Rearrange('b e (h) (w) -> b (h w) e'), ) def forward(self, x: Tensor) -> Tensor: b, _, _, _ = x.shape x = self.shallownet(x) x = self.projection(x) return x class MultiHeadAttention(nn.Module): def __init__(self, emb_size, num_heads, dropout): super().__init__() self.emb_size = emb_size self.num_heads = num_heads self.keys = nn.Linear(emb_size, emb_size) self.queries = nn.Linear(emb_size, emb_size) self.values = nn.Linear(emb_size, emb_size) self.att_drop = nn.Dropout(dropout) self.projection = nn.Linear(emb_size, emb_size) def forward(self, x: Tensor, mask: Tensor = None) -> Tensor: queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads) keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads) values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads) energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) if mask is not None: fill_value = torch.finfo(torch.float32).min energy.mask_fill(~mask, fill_value) scaling = self.emb_size ** (1 / 2) att = F.softmax(energy / scaling, dim=-1) att = self.att_drop(att) out = torch.einsum('bhal, bhlv -> bhav ', att, values) out = rearrange(out, "b h n d -> b n (h d)") out = self.projection(out) return out class ResidualAdd(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x, **kwargs): res = x x = self.fn(x, **kwargs) x += res return x class FeedForwardBlock(nn.Sequential): 
def __init__(self, emb_size, expansion, drop_p): super().__init__( nn.Linear(emb_size, expansion * emb_size), nn.GELU(), nn.Dropout(drop_p), nn.Linear(expansion * emb_size, emb_size), ) class GELU(nn.Module): def forward(self, input: Tensor) -> Tensor: return input*0.5*(1.0+torch.erf(input/math.sqrt(2.0))) class TransformerEncoderBlock(nn.Sequential): def __init__(self, emb_size, num_heads=10, drop_p=0.5, forward_expansion=4, forward_drop_p=0.5): super().__init__( ResidualAdd(nn.Sequential( nn.LayerNorm(emb_size), MultiHeadAttention(emb_size, num_heads, drop_p), nn.Dropout(drop_p) )), ResidualAdd(nn.Sequential( nn.LayerNorm(emb_size), FeedForwardBlock( emb_size, expansion=forward_expansion, drop_p=forward_drop_p), nn.Dropout(drop_p) ) )) class TransformerEncoder(nn.Sequential): def __init__(self, depth, emb_size): super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)]) class ClassificationHead(nn.Sequential): def __init__(self, emb_size, n_classes): super().__init__() # global average pooling self.clshead = nn.Sequential( Reduce('b n e -> b e', reduction='mean'), nn.LayerNorm(emb_size), nn.Linear(emb_size, n_classes) ) self.fc = nn.Sequential( nn.Linear(2440, 256), nn.ELU(), nn.Dropout(0.5), nn.Linear(256, 32), nn.ELU(), nn.Dropout(0.3), nn.Linear(32, 4) ) def forward(self, x): x = x.contiguous().view(x.size(0), -1) out = self.fc(x) return x, out class Conformer(nn.Sequential): def __init__(self, emb_size=40, depth=6, n_classes=4, **kwargs): super().__init__( PatchEmbedding(emb_size), TransformerEncoder(depth, emb_size), ClassificationHead(emb_size, n_classes) ) class ExP(): def __init__(self, nsub): super(ExP, self).__init__() self.batch_size = 72 self.n_epochs = 2000 self.c_dim = 4 self.lr = 0.0002 self.b1 = 0.5 self.b2 = 0.999 self.dimension = (190, 50) self.nSub = nsub self.start_epoch = 0 self.root = '/Data/strict_TE/' self.log_write = open("./results/log_subject%d.txt" % self.nSub, "w") self.Tensor = torch.cuda.FloatTensor self.LongTensor = torch.cuda.LongTensor self.criterion_l1 = torch.nn.L1Loss().cuda() self.criterion_l2 = torch.nn.MSELoss().cuda() self.criterion_cls = torch.nn.CrossEntropyLoss().cuda() self.model = Conformer().cuda() self.model = nn.DataParallel(self.model, device_ids=[i for i in range(len(gpus))]) self.model = self.model.cuda() # summary(self.model, (1, 22, 1000)) # Segmentation and Reconstruction (S&R) data augmentation def interaug(self, timg, label): aug_data = [] aug_label = [] for cls4aug in range(4): cls_idx = np.where(label == cls4aug + 1) tmp_data = timg[cls_idx] tmp_label = label[cls_idx] tmp_aug_data = np.zeros((int(self.batch_size / 4), 1, 22, 1000)) for ri in range(int(self.batch_size / 4)): for rj in range(8): rand_idx = np.random.randint(0, tmp_data.shape[0], 8) tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :, rj * 125:(rj + 1) * 125] aug_data.append(tmp_aug_data) aug_label.append(tmp_label[:int(self.batch_size / 4)]) aug_data = np.concatenate(aug_data) aug_label = np.concatenate(aug_label) aug_shuffle = np.random.permutation(len(aug_data)) aug_data = aug_data[aug_shuffle, :, :] aug_label = aug_label[aug_shuffle] aug_data = torch.from_numpy(aug_data).cuda() aug_data = aug_data.float() aug_label = torch.from_numpy(aug_label-1).cuda() aug_label = aug_label.long() return aug_data, aug_label def get_source_data(self): # ! please please recheck if you need validation set # ! 
and the data segement compared methods used # train data self.total_data = scipy.io.loadmat(self.root + 'A0%dT.mat' % self.nSub) self.train_data = self.total_data['data'] self.train_label = self.total_data['label'] self.train_data = np.transpose(self.train_data, (2, 1, 0)) self.train_data = np.expand_dims(self.train_data, axis=1) self.train_label = np.transpose(self.train_label) self.allData = self.train_data self.allLabel = self.train_label[0] shuffle_num = np.random.permutation(len(self.allData)) self.allData = self.allData[shuffle_num, :, :, :] self.allLabel = self.allLabel[shuffle_num] # test data self.test_tmp = scipy.io.loadmat(self.root + 'A0%dE.mat' % self.nSub) self.test_data = self.test_tmp['data'] self.test_label = self.test_tmp['label'] self.test_data = np.transpose(self.test_data, (2, 1, 0)) self.test_data = np.expand_dims(self.test_data, axis=1) self.test_label = np.transpose(self.test_label) self.testData = self.test_data self.testLabel = self.test_label[0] # standardize target_mean = np.mean(self.allData) target_std = np.std(self.allData) self.allData = (self.allData - target_mean) / target_std self.testData = (self.testData - target_mean) / target_std # data shape: (trial, conv channel, electrode channel, time samples) return self.allData, self.allLabel, self.testData, self.testLabel def train(self): img, label, test_data, test_label = self.get_source_data() img = torch.from_numpy(img) label = torch.from_numpy(label - 1) dataset = torch.utils.data.TensorDataset(img, label) self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True) test_data = torch.from_numpy(test_data) test_label = torch.from_numpy(test_label - 1) test_dataset = torch.utils.data.TensorDataset(test_data, test_label) self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True) # Optimizers self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2)) test_data = Variable(test_data.type(self.Tensor)) test_label = Variable(test_label.type(self.LongTensor)) bestAcc = 0 averAcc = 0 num = 0 Y_true = 0 Y_pred = 0 # Train the cnn model total_step = len(self.dataloader) curr_lr = self.lr for e in range(self.n_epochs): # in_epoch = time.time() self.model.train() for i, (img, label) in enumerate(self.dataloader): img = Variable(img.cuda().type(self.Tensor)) label = Variable(label.cuda().type(self.LongTensor)) # data augmentation aug_data, aug_label = self.interaug(self.allData, self.allLabel) img = torch.cat((img, aug_data)) label = torch.cat((label, aug_label)) tok, outputs = self.model(img) loss = self.criterion_cls(outputs, label) self.optimizer.zero_grad() loss.backward() self.optimizer.step() # out_epoch = time.time() # test process if (e + 1) % 1 == 0: self.model.eval() Tok, Cls = self.model(test_data) loss_test = self.criterion_cls(Cls, test_label) y_pred = torch.max(Cls, 1)[1] acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0)) train_pred = torch.max(outputs, 1)[1] train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0)) print('Epoch:', e, ' Train loss: %.6f' % loss.detach().cpu().numpy(), ' Test loss: %.6f' % loss_test.detach().cpu().numpy(), ' Train accuracy %.6f' % train_acc, ' Test accuracy is %.6f' % acc) self.log_write.write(str(e) + " " + str(acc) + "\n") num = num + 1 averAcc = averAcc + acc if acc > bestAcc: bestAcc = acc Y_true = test_label Y_pred = y_pred 
torch.save(self.model.module.state_dict(), 'model.pth') averAcc = averAcc / num print('The average accuracy is:', averAcc) print('The best accuracy is:', bestAcc) self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n") self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n") return bestAcc, averAcc, Y_true, Y_pred # writer.close() def main(): best = 0 aver = 0 result_write = open("./results/sub_result.txt", "w") for i in range(9): starttime = datetime.datetime.now() seed_n = np.random.randint(2021) print('seed is ' + str(seed_n)) random.seed(seed_n) np.random.seed(seed_n) torch.manual_seed(seed_n) torch.cuda.manual_seed(seed_n) torch.cuda.manual_seed_all(seed_n) print('Subject %d' % (i+1)) exp = ExP(i + 1) bestAcc, averAcc, Y_true, Y_pred = exp.train() print('THE BEST ACCURACY IS ' + str(bestAcc)) result_write.write('Subject ' + str(i + 1) + ' : ' + 'Seed is: ' + str(seed_n) + "\n") result_write.write('Subject ' + str(i + 1) + ' : ' + 'The best accuracy is: ' + str(bestAcc) + "\n") result_write.write('Subject ' + str(i + 1) + ' : ' + 'The average accuracy is: ' + str(averAcc) + "\n") endtime = datetime.datetime.now() print('subject %d duration: '%(i+1) + str(endtime - starttime)) best = best + bestAcc aver = aver + averAcc if i == 0: yt = Y_true yp = Y_pred else: yt = torch.cat((yt, Y_true)) yp = torch.cat((yp, Y_pred)) best = best / 9 aver = aver / 9 result_write.write('**The average Best accuracy is: ' + str(best) + "\n") result_write.write('The average Aver accuracy is: ' + str(aver) + "\n") result_write.close() if __name__ == "__main__": print(time.asctime(time.localtime(time.time()))) main() print(time.asctime(time.localtime(time.time()))) 解释上述代码
The code above implements a deep-learning model called **EEG Conformer** for decoding and classifying electroencephalography (EEG) signals. It couples a **convolutional neural network (CNN)** with a **Transformer** architecture and performs well on EEG classification tasks, particularly motor-imagery (MI) brain-computer-interface (BCI) tasks.

---

## 🌟 Model overview

### Name: EEG Conformer
- **Goal**: end-to-end classification of EEG data (e.g., a 4-class motor-imagery task)
- **Core ideas**:
  - Use a CNN to extract local spatio-temporal features (in place of the Transformer's positional encoding)
  - Use a Transformer encoder to capture long-range temporal dependencies
  - Keep the overall structure compact, efficient, and tailored to EEG

---

## 🔧 Main modules

### 1. `PatchEmbedding`: turning raw EEG into patch embeddings

```python
class PatchEmbedding(nn.Module):
    def __init__(self, emb_size=40):
        super().__init__()
        self.shallownet = nn.Sequential(
            nn.Conv2d(1, 40, (1, 25), (1, 1)),
            nn.Conv2d(40, 40, (22, 1), (1, 1)),
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.AvgPool2d((1, 75), (1, 15)),
            nn.Dropout(0.5),
        )
        self.projection = nn.Sequential(
            nn.Conv2d(40, emb_size, (1, 1)),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
```

#### What it does:
- Input shape: `(batch, 1, 22, 1000)` → one channel, 22 electrodes, 1000 time samples
- The first convolution `(1, 25)` extracts short-time temporal features per electrode (acting like a filter bank)
- The second convolution `(22, 1)` mixes across all electrodes (playing a role similar to CSP spatial filtering)
- `AvgPool2d((1, 75), (1, 15))` slides an average pool along the time axis, effectively "slicing" the signal into patches
- `Rearrange` finally maps every patch to an embedding vector → output `(b, n_patches, emb_size)`

> ✅ This is the key design choice: convolution replaces the positional embedding of ViT, which better suits the continuous nature of EEG signals.

---

### 2. `MultiHeadAttention`: multi-head self-attention

```python
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask=None) -> Tensor:
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out
```

#### Notes:
- Implements standard scaled dot-product attention.
- `einops.rearrange` reorganizes the tensors into the per-head layout.
- Attention weights are computed from query-key similarity.
- A mask is supported in the full script (the branch is omitted in this excerpt) but unused in this project.

---
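These two `torch.einsum` calls are exactly the kind of "complex real-world case" this post is about. As a cross-check of the output path, here is a short sketch of my own (shapes match the model defaults `emb_size=40`, `num_heads=10`, so each head has `d = 4`; the token count 61 is derived in section 7 below) confirming that `'bhal, bhlv -> bhav'` keeps `b, h, a, v` and sums over the key index `l`:

```python
import torch

# Model defaults: a 40-dim embedding split over 10 heads -> d = 4 per head.
b, h, n, d = 2, 10, 61, 4
att = torch.softmax(torch.randn(b, h, n, n), dim=-1)  # attention weights over keys
values = torch.randn(b, h, n, d)

out = torch.einsum('bhal, bhlv -> bhav', att, values)

# Element by element: out[b, h, a, v] = sum_l att[b, h, a, l] * values[b, h, l, v],
# i.e. a batched matrix product for every (b, h) pair.
print(torch.allclose(out, att @ values, atol=1e-5))  # True
print(out.shape)                                     # torch.Size([2, 10, 61, 4])
```

---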
### 3. `ResidualAdd`: a residual-connection wrapper

```python
class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x
```

- Adds the wrapped sub-layer's output back onto its input (an identity shortcut), which helps train deeper networks.

---

### 4. `FeedForwardBlock`: the feed-forward network (MLP)

```python
class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size, expansion, drop_p):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )
```

- Expand the dimension → GELU activation → Dropout → project back to the original dimension
- A standard component of every Transformer block

---

### 5. `TransformerEncoderBlock`: a single Transformer encoder block

```python
class TransformerEncoderBlock(nn.Sequential):
    def __init__(...):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, num_heads, drop_p),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(...),
                nn.Dropout(drop_p)
            ))
        )
```

- Two residual branches:
  1. LayerNorm + multi-head attention + Dropout
  2. LayerNorm + feed-forward network + Dropout
- Matches the standard pre-norm Transformer layout

---

### 6. `TransformerEncoder`: stacking the encoder blocks

```python
class TransformerEncoder(nn.Sequential):
    def __init__(self, depth, emb_size):
        super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)])
```

- Simply stacks `depth` copies of `TransformerEncoderBlock`
- Default `depth=6`

---

### 7. `ClassificationHead`: the classification head

```python
class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size, n_classes):
        super().__init__()
        self.clshead = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),  # global average pooling
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes)
        )
        self.fc = nn.Sequential(
            nn.Linear(2440, 256),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(256, 32),
            nn.ELU(),
            nn.Dropout(0.3),
            nn.Linear(32, 4)
        )

    def forward(self, x):
        x = x.contiguous().view(x.size(0), -1)
        out = self.fc(x)
        return x, out
```

⚠️ **Two things deserve attention here.**

First, `x.view(x.size(0), -1)` flattens the whole Transformer output into `(b, 2440)`. Where does 2440 come from? Trace the shapes through `PatchEmbedding` for an input of `(b, 1, 22, 1000)`:

- `Conv2d(1, 40, (1, 25))` shrinks the time axis to $1000 - 25 + 1 = 976$
- `Conv2d(40, 40, (22, 1))` collapses the 22 electrodes to 1 and leaves time at 976
- `AvgPool2d((1, 75), (1, 15))` yields

$$
\left\lfloor \frac{976 - 75}{15} \right\rfloor + 1 = 60 + 1 = 61
$$

time steps, i.e. 61 tokens of size 40, so the flattened vector has $61 \times 40 = 2440$ entries. The hard-coded `nn.Linear(2440, 256)` therefore matches, but only for this exact input geometry (22 electrodes, 1000 samples); any other montage or window length breaks it.

Second, the `fc` branch bypasses `clshead` entirely: the global-average-pooling head is defined but never used (dead code), so the model classifies from all flattened tokens rather than the pooled representation.

✅ If the pooled head is preferred, the forward pass would simply be:

```python
def forward(self, x):
    out = self.clshead(x)  # the Reduce inside clshead performs the pooling
    return x, out
```

As written, `ClassificationHead` works, but the `clshead` branch is redundant.

---
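To confirm the token count empirically, here is a quick shape probe (a sketch of mine, assuming the classes from the script above are defined in the current session) that prints the intermediate shapes on CPU:

```python
import torch

# Dummy batch: (batch, 1, electrodes, time), as expected by PatchEmbedding.
x = torch.randn(2, 1, 22, 1000)

patch = PatchEmbedding(emb_size=40).eval()
tokens = patch(x)
print(tokens.shape)              # torch.Size([2, 61, 40])

encoder = TransformerEncoder(depth=6, emb_size=40).eval()
encoded = encoder(tokens)
print(encoded.shape)             # torch.Size([2, 61, 40])
print(encoded.flatten(1).shape)  # torch.Size([2, 2440]) -- matches nn.Linear(2440, 256)
```

---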
### 8. `Conformer`: the complete architecture

```python
class Conformer(nn.Sequential):
    def __init__(self, emb_size=40, depth=6, n_classes=4, **kwargs):
        super().__init__(
            PatchEmbedding(emb_size),
            TransformerEncoder(depth, emb_size),
            ClassificationHead(emb_size, n_classes)
        )
```

- Chains the three main modules in order
- Note: as discussed above, the head classifies from flattened tokens, with an unused pooled branch

---

### 9. `ExP`: the experiment driver

Handles data loading, augmentation, training, and testing.

#### Key methods:

##### `get_source_data()`:
- Loads EEG data from `.mat` files (BCI Competition IV Dataset 2a format)
- Reshapes the data: `(channel, time, trial)` → `(trial, 1, channel, time)`
- Standardizes it (subtract the mean, divide by the standard deviation)
- Returns the training and test sets

##### `interaug()`: S&R (Segmentation & Reconstruction) data augmentation

For each of the 4 classes, it splits the 1000-sample trial into eight 125-sample windows and fills each window with the corresponding window from a randomly chosen training trial of the same class:

- Every new sample is spliced together from segments of 8 different trials
- A strong augmentation strategy that increases diversity and improves generalization

##### `train()`:
- Builds the DataLoaders
- Optimizes with Adam
- Appends augmented samples to every training batch
- Evaluates test accuracy every epoch
- Saves the best model and writes result logs

---

### 10. `main()`
- Loops over the 9 subjects (Subject 1-9)
- Sets random seeds for reproducibility
- Records each subject's best and average accuracy
- Reports the overall averages at the end

---

## 📈 Summary: strengths and issues

### ✅ Strengths

| Feature | Description |
|------|------|
| **CNN + Transformer fusion** | The CNN extracts local features; the Transformer captures long-range dependencies |
| **No positional encoding** | Convolution captures the temporal structure naturally, avoiding hand-designed position embeddings |
| **Lightweight and efficient** | Few parameters, suitable for small EEG datasets |
| **S&R data augmentation** | Increases data diversity and mitigates overfitting |

### ❌ Issues

| Issue | Description |
|------|------|
| **Classification-head design** | The pooled `clshead` is dead code, and `nn.Linear(2440, 256)` hard-codes the flattened size for one input shape |
| **Flattening all tokens** | A `[CLS]` token or global average pooling is usually preferable to flattening every token |
| **No validation set** | Hyperparameters are effectively tuned on the test set, which risks overfitting it |
| **Fixed learning rate** | No LR schedule, which can slow or hurt convergence |
| **Heavy dropout** | Rates such as 0.5 can cause underfitting on small datasets |

---

## ✅ Suggested fix for the classification head

```python
class ClassificationHead(nn.Module):
    def __init__(self, emb_size, n_classes):
        super().__init__()
        self.global_pool = Reduce('b n e -> b e', reduction='mean')
        self.norm = nn.LayerNorm(emb_size)
        self.classifier = nn.Linear(emb_size, n_classes)

    def forward(self, x):
        x = self.global_pool(x)
        x = self.norm(x)
        logits = self.classifier(x)
        return logits  # (b, n_classes)
```

Then adjust `Conformer` accordingly:

```python
class Conformer(nn.Module):  # no longer inherits from Sequential
    def __init__(self, emb_size=40, depth=6, n_classes=4):
        super().__init__()
        self.patch_emb = PatchEmbedding(emb_size)
        self.transformer = TransformerEncoder(depth, emb_size)
        self.cls_head = ClassificationHead(emb_size, n_classes)

    def forward(self, x):
        x = self.patch_emb(x)
        x = self.transformer(x)
        x = self.cls_head(x)
        return x
```

(The training loop currently unpacks `tok, outputs = self.model(img)`, so switching to a single return value also requires a small change in `train()`.)

---

## 💡 Further suggestions

1. **Visualize attention maps** to see which time segments matter most
2. **Add learning-rate decay**, e.g. `StepLR` or `ReduceLROnPlateau` (see the sketch below)
3. **Early stopping**
4. **Cross-validation** for a more reliable performance estimate
5. **Baselines**: compare against EEGNet, DeepConvNet, ShallowConvNet
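For item 2, a minimal sketch (mine; the schedule values are illustrative, not tuned) of attaching a step decay to the script's existing Adam optimizer:

```python
import torch

# Stand-ins for self.model and the hyperparameters used in ExP.
model = torch.nn.Linear(40, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0002, betas=(0.5, 0.999))
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)

for epoch in range(2000):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 40)).sum()  # placeholder for the real training loss
    loss.backward()
    optimizer.step()
    scheduler.step()  # halve the learning rate every 500 epochs
```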