模块出处
[link] [code] [ACM MM 23] Frequency Perception Network for Camouflaged Object Detection
模块名称
Frequency-Perception Module (FPM)
模块作用
提取频域信息，从而更好地识别伪装目标。
模块结构
模块代码
import torch
import torch.nn as nn
import torch.nn.functional as F
class FirstOctaveConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, alpha=0.5, stride=1, padding=1, dilation=1,
groups=1, bias=False):
super(FirstOctaveConv, self).__init__()
self.stride = stride
kernel_size = kernel_size[0]
self.h2g_pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
self.h2l = torch.nn.Conv2d(in_channels, int(alpha * in_channels),
kernel_size, 1, padding, dilation, groups, bias)
self.h2h = torch.nn.Conv2d(in_channels, in_channels - int(alpha * in_channels),
kernel_size, 1, padding, dilation, groups, bias)
def forward(self, x):
if self.stride ==2:
x = self.h2g_pool(x)
X_h2l = self.h2g_pool(x)
X_h