Methods for adding attention (reference papers)

(1) Densely Connected CNN with Multi-scale Feature Attention for Text Classification

Add attention on the last layer before classification.

(2) Attention-Mechanism-Containing Neural Networks for High-Resolution Remote Sensing Image Classification

(3) Learning Multi-Attention Convolutional Neural Network for Fine-Grained Image Recognition

(4) LEARN TO PAY ATTENTION

(5) Weighted Channel Dropout for Regularization of Deep Convolutional Neural Network

class PointnetSAModuleMSG(_PointnetSAModuleBase):
    """Pointnet set abstraction layer with multiscale grouping and attention mechanism.

    For every scale (one entry of ``radii`` / ``nsamples`` / ``mlps``) the
    constructor builds a grouper, a shared MLP and an attention module.
    ``forward`` runs each scale in turn, pools the result and concatenates
    the per-scale outputs along dimension 1.
    """

    def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int],
                 mlps: List[List[int]], bn: bool = True, use_xyz: bool = True,
                 pool_method='max_pool', instance_norm=False):
        """
        :param npoint: int; when ``None`` a ``GroupAll`` grouper is used
            instead of ``QueryAndGroup``
        :param radii: list of float, list of radii to group with
        :param nsamples: list of int, number of samples in each ball query
        :param mlps: list of list of int, spec of the pointnet before the
            global pooling for each scale (not modified by this constructor)
        :param bn: whether to use batchnorm
        :param use_xyz: forwarded to the groupers; when true, 3 is added to
            the first entry of each MLP spec
        :param pool_method: max_pool / avg_pool (any value other than
            'max_pool' selects mean pooling)
        :param instance_norm: whether to use instance_norm
        """
        super().__init__()

        assert len(radii) == len(nsamples) == len(mlps)

        self.npoint = npoint
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        # One attention module per scale, applied after that scale's MLP.
        self.attentions = nn.ModuleList()

        for radius, nsample, scale_spec in zip(radii, nsamples, mlps):
            self.groupers.append(
                pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
                if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
            )

            # Work on a copy so the caller's ``mlps`` list is not mutated;
            # otherwise reusing the same spec for a second module would add
            # the xyz channels twice.
            mlp_spec = list(scale_spec)
            if use_xyz:
                mlp_spec[0] += 3

            # Attention is sized by the last channel count of this scale's MLP.
            self.attentions.append(Attention(mlp_spec[-1]))
            self.mlps.append(
                pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm)
            )

        self.pool_method = pool_method

    def forward(self, xyz, features):
        """
        :param xyz: (B, N, 3) xyz coordinates of the points
        :param features: (B, N, C) input features
        :return: (B, npoint, mlp[-1]) tensor

        Shapes above are as stated by the original author; they are not
        verifiable from this class alone -- TODO confirm against the groupers.
        """
        new_features_list = []

        for grouper, mlp, attention in zip(self.groupers, self.mlps, self.attentions):
            # Group points and features.
            # NOTE(review): the grouper is called with (xyz, features) and is
            # expected to return a pair. Confirm this matches the forward
            # signature and return value of QueryAndGroup / GroupAll; nothing
            # from the parent class is passed to the grouper automatically.
            _, grouped_features = grouper(xyz, features)

            # Apply MLP to each group
            grouped_features = mlp(grouped_features)

            # Apply attention mechanism to the features of each group
            grouped_features = attention(grouped_features)

            # Perform pooling over each group.
            # NOTE(review): pooling reduces dimension 2 -- confirm that this is
            # the per-group sample axis of the MLP output.
            if self.pool_method == 'max_pool':
                pooled_features = torch.max(grouped_features, dim=2)[0]
            else:
                pooled_features = torch.mean(grouped_features, dim=2)

            new_features_list.append(pooled_features)

        # Concatenate features from different scales
        new_features = torch.cat(new_features_list, dim=1)
        return new_features


# Reader question (translated from the original page): does the QueryAndGroup
# class used in this class automatically receive the return value of the parent
# class that this class inherits from as input to QueryAndGroup's forward
# function?
05-22
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值