Understand How torch.einsum() Computes in 5 Minutes: A Detailed Manual Derivation of torch.einsum()

This article analyzes in detail how torch.einsum computes its result in complex cases, with particular emphasis on the Einstein summation convention, and uses concrete examples to show how to understand and apply tensor multiplication and broadcasting.

Introduction: torch.einsum() has already been analyzed and introduced in many blog posts, but most of them center on the Einstein summation convention itself: much of the text is spent on the convention's various rules, while the actual case analysis is only sketched briefly, and the cases covered tend to be simple. In practice, however, the situations in which we reach for torch.einsum(), or run into it in someone else's code, are usually exactly the ones where the computation is very complex.
This article therefore analyzes the computation performed by torch.einsum() from the perspective of a realistic, complex case, deriving step by step how each element of the final output relates to the elements of the inputs.
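To make concrete the kind of "complex case" this article has in mind, below is a minimal sketch of the batched multi-head attention-score computation that frequently appears in Transformer implementations, written with torch.einsum(). The shapes and variable names are illustrative assumptions, not taken from any particular codebase:

```python
import torch

# b = batch, h = heads, q / k = query / key positions, d = per-head dimension
b, h, q, k, d = 2, 4, 5, 5, 8                      # illustrative sizes
queries = torch.randn(b, h, q, d)
keys = torch.randn(b, h, k, d)

# 'bhqd,bhkd->bhqk': for every (b, h, q, k), multiply queries[b, h, q, :]
# with keys[b, h, k, :] element-wise and sum over d (the index that
# disappears from the output).
scores = torch.einsum('bhqd,bhkd->bhqk', queries, keys)

# Cross-check against an explicit batched matrix multiplication.
assert torch.allclose(scores, queries @ keys.transpose(-2, -1), atol=1e-5)
```

It is calls like this, with four or more indices per operand, that are hard to read at a glance and that the step-by-step derivation in this article is meant to untangle.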

The Einstein Summation Convention

First of all, the principle underlying torch.einsum() is the Einstein summation convention. For the sake of completeness, this section gives a brief introduction to it; if you only care about the computation itself, you can skip ahead to the next section. The Einstein summation convention is a notation invented to simplify how computations are written down, much like using $\times$ to denote multiplication; the difference is that the operations it can express are more complex and far more flexible. A typical Einstein-summation expression is written as:
$$
i_1 i_2 \dots i_N,\; j_1 j_2 \dots j_M \;\rightarrow\; i_{k_1} i_{k_2} \dots j_{l_1} j_{l_2} \dots, \qquad k_1, k_2, \dots \in \{1,\dots,N\},\; l_1, l_2, \dots \in \{1,\dots,M\}
$$
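To make the notation concrete, here is a minimal sketch (my own example, with arbitrary shapes) of how such an expression maps to an actual computation. The index string 'ik,kj->ij' describes ordinary matrix multiplication: the index k appears in both inputs but not in the output and is therefore summed over, while i and j survive to index the output:

```python
import torch

A = torch.randn(3, 4)   # indexed by i, k
B = torch.randn(4, 5)   # indexed by k, j

# 'ik,kj->ij': sum over the repeated index k that is absent from the output.
C = torch.einsum('ik,kj->ij', A, B)

# Element-by-element check of the rule C[i, j] = sum_k A[i, k] * B[k, j].
for i in range(A.shape[0]):
    for j in range(B.shape[1]):
        manual = sum(A[i, kk] * B[kk, j] for kk in range(A.shape[1]))
        assert torch.allclose(C[i, j], manual, atol=1e-5)
```

The same reading applies no matter how many indices are involved: every index that appears in the output string is kept as an output axis, and every input index that does not appear in the output is summed over.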

""" EEG Conformer Convolutional Transformer for EEG decoding Couple CNN and Transformer in a concise manner with amazing results """ # remember to change paths import argparse import os gpus = [0] os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus)) import numpy as np import math import glob import random import itertools import datetime import time import datetime import sys import scipy.io import torchvision.transforms as transforms from torchvision.utils import save_image, make_grid from torch.utils.data import DataLoader from torch.autograd import Variable from torchsummary import summary import torch.autograd as autograd from torchvision.models import vgg19 import torch.nn as nn import torch.nn.functional as F import torch import torch.nn.init as init from torch.utils.data import Dataset from PIL import Image import torchvision.transforms as transforms from sklearn.decomposition import PCA import torch import torch.nn.functional as F import matplotlib.pyplot as plt from torch import nn from torch import Tensor from PIL import Image from torchvision.transforms import Compose, Resize, ToTensor from einops import rearrange, reduce, repeat from einops.layers.torch import Rearrange, Reduce # from common_spatial_pattern import csp import matplotlib.pyplot as plt # from torch.utils.tensorboard import SummaryWriter from torch.backends import cudnn cudnn.benchmark = False cudnn.deterministic = True # writer = SummaryWriter('./TensorBoardX/') # Convolution module # use conv to capture local features, instead of postion embedding. class PatchEmbedding(nn.Module): def __init__(self, emb_size=40): # self.patch_size = patch_size super().__init__() self.shallownet = nn.Sequential( nn.Conv2d(1, 40, (1, 25), (1, 1)), nn.Conv2d(40, 40, (22, 1), (1, 1)), nn.BatchNorm2d(40), nn.ELU(), nn.AvgPool2d((1, 75), (1, 15)), # pooling acts as slicing to obtain 'patch' along the time dimension as in ViT nn.Dropout(0.5), ) self.projection = nn.Sequential( nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)), # transpose, conv could enhance fiting ability slightly Rearrange('b e (h) (w) -> b (h w) e'), ) def forward(self, x: Tensor) -> Tensor: b, _, _, _ = x.shape x = self.shallownet(x) x = self.projection(x) return x class MultiHeadAttention(nn.Module): def __init__(self, emb_size, num_heads, dropout): super().__init__() self.emb_size = emb_size self.num_heads = num_heads self.keys = nn.Linear(emb_size, emb_size) self.queries = nn.Linear(emb_size, emb_size) self.values = nn.Linear(emb_size, emb_size) self.att_drop = nn.Dropout(dropout) self.projection = nn.Linear(emb_size, emb_size) def forward(self, x: Tensor, mask: Tensor = None) -> Tensor: queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads) keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads) values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads) energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) if mask is not None: fill_value = torch.finfo(torch.float32).min energy.mask_fill(~mask, fill_value) scaling = self.emb_size ** (1 / 2) att = F.softmax(energy / scaling, dim=-1) att = self.att_drop(att) out = torch.einsum('bhal, bhlv -> bhav ', att, values) out = rearrange(out, "b h n d -> b n (h d)") out = self.projection(out) return out class ResidualAdd(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x, **kwargs): res = x x = self.fn(x, **kwargs) x += res return x class FeedForwardBlock(nn.Sequential): 
def __init__(self, emb_size, expansion, drop_p): super().__init__( nn.Linear(emb_size, expansion * emb_size), nn.GELU(), nn.Dropout(drop_p), nn.Linear(expansion * emb_size, emb_size), ) class GELU(nn.Module): def forward(self, input: Tensor) -> Tensor: return input*0.5*(1.0+torch.erf(input/math.sqrt(2.0))) class TransformerEncoderBlock(nn.Sequential): def __init__(self, emb_size, num_heads=10, drop_p=0.5, forward_expansion=4, forward_drop_p=0.5): super().__init__( ResidualAdd(nn.Sequential( nn.LayerNorm(emb_size), MultiHeadAttention(emb_size, num_heads, drop_p), nn.Dropout(drop_p) )), ResidualAdd(nn.Sequential( nn.LayerNorm(emb_size), FeedForwardBlock( emb_size, expansion=forward_expansion, drop_p=forward_drop_p), nn.Dropout(drop_p) ) )) class TransformerEncoder(nn.Sequential): def __init__(self, depth, emb_size): super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)]) class ClassificationHead(nn.Sequential): def __init__(self, emb_size, n_classes): super().__init__() # global average pooling self.clshead = nn.Sequential( Reduce('b n e -> b e', reduction='mean'), nn.LayerNorm(emb_size), nn.Linear(emb_size, n_classes) ) self.fc = nn.Sequential( nn.Linear(2440, 256), nn.ELU(), nn.Dropout(0.5), nn.Linear(256, 32), nn.ELU(), nn.Dropout(0.3), nn.Linear(32, 4) ) def forward(self, x): x = x.contiguous().view(x.size(0), -1) out = self.fc(x) return x, out class Conformer(nn.Sequential): def __init__(self, emb_size=40, depth=6, n_classes=4, **kwargs): super().__init__( PatchEmbedding(emb_size), TransformerEncoder(depth, emb_size), ClassificationHead(emb_size, n_classes) ) class ExP(): def __init__(self, nsub): super(ExP, self).__init__() self.batch_size = 72 self.n_epochs = 2000 self.c_dim = 4 self.lr = 0.0002 self.b1 = 0.5 self.b2 = 0.999 self.dimension = (190, 50) self.nSub = nsub self.start_epoch = 0 self.root = '/Data/strict_TE/' self.log_write = open("./results/log_subject%d.txt" % self.nSub, "w") self.Tensor = torch.cuda.FloatTensor self.LongTensor = torch.cuda.LongTensor self.criterion_l1 = torch.nn.L1Loss().cuda() self.criterion_l2 = torch.nn.MSELoss().cuda() self.criterion_cls = torch.nn.CrossEntropyLoss().cuda() self.model = Conformer().cuda() self.model = nn.DataParallel(self.model, device_ids=[i for i in range(len(gpus))]) self.model = self.model.cuda() # summary(self.model, (1, 22, 1000)) # Segmentation and Reconstruction (S&R) data augmentation def interaug(self, timg, label): aug_data = [] aug_label = [] for cls4aug in range(4): cls_idx = np.where(label == cls4aug + 1) tmp_data = timg[cls_idx] tmp_label = label[cls_idx] tmp_aug_data = np.zeros((int(self.batch_size / 4), 1, 22, 1000)) for ri in range(int(self.batch_size / 4)): for rj in range(8): rand_idx = np.random.randint(0, tmp_data.shape[0], 8) tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :, rj * 125:(rj + 1) * 125] aug_data.append(tmp_aug_data) aug_label.append(tmp_label[:int(self.batch_size / 4)]) aug_data = np.concatenate(aug_data) aug_label = np.concatenate(aug_label) aug_shuffle = np.random.permutation(len(aug_data)) aug_data = aug_data[aug_shuffle, :, :] aug_label = aug_label[aug_shuffle] aug_data = torch.from_numpy(aug_data).cuda() aug_data = aug_data.float() aug_label = torch.from_numpy(aug_label-1).cuda() aug_label = aug_label.long() return aug_data, aug_label def get_source_data(self): # ! please please recheck if you need validation set # ! 
and the data segement compared methods used # train data self.total_data = scipy.io.loadmat(self.root + 'A0%dT.mat' % self.nSub) self.train_data = self.total_data['data'] self.train_label = self.total_data['label'] self.train_data = np.transpose(self.train_data, (2, 1, 0)) self.train_data = np.expand_dims(self.train_data, axis=1) self.train_label = np.transpose(self.train_label) self.allData = self.train_data self.allLabel = self.train_label[0] shuffle_num = np.random.permutation(len(self.allData)) self.allData = self.allData[shuffle_num, :, :, :] self.allLabel = self.allLabel[shuffle_num] # test data self.test_tmp = scipy.io.loadmat(self.root + 'A0%dE.mat' % self.nSub) self.test_data = self.test_tmp['data'] self.test_label = self.test_tmp['label'] self.test_data = np.transpose(self.test_data, (2, 1, 0)) self.test_data = np.expand_dims(self.test_data, axis=1) self.test_label = np.transpose(self.test_label) self.testData = self.test_data self.testLabel = self.test_label[0] # standardize target_mean = np.mean(self.allData) target_std = np.std(self.allData) self.allData = (self.allData - target_mean) / target_std self.testData = (self.testData - target_mean) / target_std # data shape: (trial, conv channel, electrode channel, time samples) return self.allData, self.allLabel, self.testData, self.testLabel def train(self): img, label, test_data, test_label = self.get_source_data() img = torch.from_numpy(img) label = torch.from_numpy(label - 1) dataset = torch.utils.data.TensorDataset(img, label) self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True) test_data = torch.from_numpy(test_data) test_label = torch.from_numpy(test_label - 1) test_dataset = torch.utils.data.TensorDataset(test_data, test_label) self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True) # Optimizers self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2)) test_data = Variable(test_data.type(self.Tensor)) test_label = Variable(test_label.type(self.LongTensor)) bestAcc = 0 averAcc = 0 num = 0 Y_true = 0 Y_pred = 0 # Train the cnn model total_step = len(self.dataloader) curr_lr = self.lr for e in range(self.n_epochs): # in_epoch = time.time() self.model.train() for i, (img, label) in enumerate(self.dataloader): img = Variable(img.cuda().type(self.Tensor)) label = Variable(label.cuda().type(self.LongTensor)) # data augmentation aug_data, aug_label = self.interaug(self.allData, self.allLabel) img = torch.cat((img, aug_data)) label = torch.cat((label, aug_label)) tok, outputs = self.model(img) loss = self.criterion_cls(outputs, label) self.optimizer.zero_grad() loss.backward() self.optimizer.step() # out_epoch = time.time() # test process if (e + 1) % 1 == 0: self.model.eval() Tok, Cls = self.model(test_data) loss_test = self.criterion_cls(Cls, test_label) y_pred = torch.max(Cls, 1)[1] acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0)) train_pred = torch.max(outputs, 1)[1] train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0)) print('Epoch:', e, ' Train loss: %.6f' % loss.detach().cpu().numpy(), ' Test loss: %.6f' % loss_test.detach().cpu().numpy(), ' Train accuracy %.6f' % train_acc, ' Test accuracy is %.6f' % acc) self.log_write.write(str(e) + " " + str(acc) + "\n") num = num + 1 averAcc = averAcc + acc if acc > bestAcc: bestAcc = acc Y_true = test_label Y_pred = y_pred 
torch.save(self.model.module.state_dict(), 'model.pth') averAcc = averAcc / num print('The average accuracy is:', averAcc) print('The best accuracy is:', bestAcc) self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n") self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n") return bestAcc, averAcc, Y_true, Y_pred # writer.close() def main(): best = 0 aver = 0 result_write = open("./results/sub_result.txt", "w") for i in range(9): starttime = datetime.datetime.now() seed_n = np.random.randint(2021) print('seed is ' + str(seed_n)) random.seed(seed_n) np.random.seed(seed_n) torch.manual_seed(seed_n) torch.cuda.manual_seed(seed_n) torch.cuda.manual_seed_all(seed_n) print('Subject %d' % (i+1)) exp = ExP(i + 1) bestAcc, averAcc, Y_true, Y_pred = exp.train() print('THE BEST ACCURACY IS ' + str(bestAcc)) result_write.write('Subject ' + str(i + 1) + ' : ' + 'Seed is: ' + str(seed_n) + "\n") result_write.write('Subject ' + str(i + 1) + ' : ' + 'The best accuracy is: ' + str(bestAcc) + "\n") result_write.write('Subject ' + str(i + 1) + ' : ' + 'The average accuracy is: ' + str(averAcc) + "\n") endtime = datetime.datetime.now() print('subject %d duration: '%(i+1) + str(endtime - starttime)) best = best + bestAcc aver = aver + averAcc if i == 0: yt = Y_true yp = Y_pred else: yt = torch.cat((yt, Y_true)) yp = torch.cat((yp, Y_pred)) best = best / 9 aver = aver / 9 result_write.write('**The average Best accuracy is: ' + str(best) + "\n") result_write.write('The average Aver accuracy is: ' + str(aver) + "\n") result_write.close() if __name__ == "__main__": print(time.asctime(time.localtime(time.time()))) main() print(time.asctime(time.localtime(time.time()))) 解释上述代码