import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
# Load the input image and force 3-channel RGB: the first Conv2d of VGG19
# expects exactly 3 input channels, while PIL may open a file as grayscale
# ("L"), palette ("P") or RGBA depending on its contents.
image_path = '2.jpg'
with Image.open(image_path) as img:
    # convert() returns a new, fully loaded image, so it remains usable after
    # the context manager closes the underlying file handle.
    image = img.convert('RGB')
# Preprocess: resize to 224x224, then ToTensor converts the PIL image to a
# float tensor laid out as (C, H, W) with values scaled to [0, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
input_image = transform(image).unsqueeze(0)  # add a batch dimension -> (1, C, H, W)
class Mish(nn.Module):
    """Mish activation, computed elementwise as ``x * tanh(softplus(x))``.

    Not referenced by ``VGG19`` in this file; presumably kept as an
    alternative activation to experiment with.
    """

    def forward(self, x):
        gate = F.softplus(x).tanh()
        return x * gate
class VGG19(nn.Module):
    """First convolutional stage of a VGG19-style network with BatchNorm.

    Only the first stage is implemented: two 3x3 convolutions (3 -> 64 -> 64
    channels, padding 1), each followed by BatchNorm and ReLU, then a 2x2
    max-pool with stride 2. ``num_classes`` is accepted for interface
    compatibility but is not used.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        stage = []
        in_channels = 3
        for _ in range(2):
            stage.extend([
                nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
            ])
            in_channels = 64
        stage.append(nn.MaxPool2d(kernel_size=2, stride=2))
        # Submodule order (indices 0-6) matches the original hand-written
        # Sequential, so printed structure and state_dict keys are unchanged.
        self.features1 = nn.Sequential(*stage)

    def forward(self, x):
        """Apply the convolutional stage to ``x`` and return the feature maps."""
        return self.features1(x)
# Build the (first-stage) VGG19 feature extractor and show its layers.
vgg19 = VGG19()
vgg16 = vgg19  # legacy alias: the model is VGG19; old name kept for compatibility
print(vgg19)
# Inference mode: in train mode BatchNorm would normalize with this single
# batch's statistics and update its running buffers as a side effect of
# merely visualizing the features.
vgg19.eval()
# Gradients are not needed for visualization; skip building the autograd graph.
with torch.no_grad():
    output_feature_map = vgg19(input_image)
print(output_feature_map.shape)
# Take one channel of the first sample in the batch as a 2-D NumPy array.
CHANNEL_INDEX = 0
feature_map_data = output_feature_map[0, CHANNEL_INDEX].cpu().numpy()
print(feature_map_data.shape)
# Plot that channel as a heatmap with Matplotlib.
plt.imshow(feature_map_data, cmap="viridis")
plt.title("Feature Map Heatmap")
plt.colorbar()
plt.show()