<think>
我们使用ResNet模型对CIFAR10中两张不同类别的图片进行特征提取,然后将两个特征嵌入组合成一个新的特征嵌入。组合方式可以是拼接、平均或相加等。这里我们选择拼接(concat)的方式,这样新特征嵌入的维度是原来每个特征嵌入的两倍。
步骤:
1. 加载预训练的ResNet模型(这里我们使用ResNet18,因为CIFAR10图片尺寸较小,通常使用较小的模型)
2. 去除ResNet最后的全连接层,获取特征提取部分(即全局平均池化层之前的层)
3. 从CIFAR10数据集中选择两张不同类别的图片
4. 对两张图片进行预处理(与训练ResNet时相同的预处理)
5. 分别提取两张图片的特征嵌入(feature embedding)
6. 将两个特征嵌入组合(这里采用拼接,也可以尝试其他方式)
7. 输出新特征嵌入的维度
注意:由于ResNet最初是为ImageNet设计的(输入尺寸224x224),而CIFAR10图片是32x32,所以我们需要调整ResNet的第一层卷积(将kernel_size=7, stride=2, padding=3改为kernel_size=3, stride=1, padding=1),并且去掉第一个最大池化层。但PyTorch中torchvision.models.resnet18默认是为224x224设计的,因此我们需要进行修改。
然而,为了简便,我们可以使用一个已经适应CIFAR10的ResNet版本,或者直接使用原ResNet但调整输入。但这里我们选择修改ResNet的结构使其适应32x32的输入。
另一种做法是使用一个已有的在CIFAR10上预训练的ResNet模型。但为了通用性,我们使用随机初始化的ResNet(修改后的)进行演示。实际使用时可以加载预训练权重。
修改ResNet以适应CIFAR10:
- 将第一个卷积层由conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
改为conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
- 去掉maxpool层(即原模型中的第一个池化层)
我们定义一个修改后的ResNet18。
由于我们只需要特征提取部分,所以我们将模型截取到全局平均池化层之前(即layer4之后,avgpool之前)。
组合特征的方式:将两个特征向量在某个维度上拼接(这里选择在通道维度上拼接,但注意特征图是空间结构,我们使用全局平均池化后的特征向量,即每个特征图变成一个值?不对,全局平均池化后,每个通道变成一个值,所以特征向量长度是512。然后我们将两个512维的向量拼接成1024维的向量。
另一种方式:如果我们保留空间结构,那么两个特征图都是[1,512,1,1](因为经过全局平均池化后是1x1),然后拼接在通道维度上得到[1,1024,1,1]。或者我们可以在全局平均池化之前拼接(即特征图还是空间大小,比如4x4),但这样需要两张图片的特征图大小完全一致(由于经过同样的网络,所以大小一致),然后拼接在通道维度上得到[1,1024,4,4]。但这样组合后的特征图空间位置是对应的吗?实际上,两张不同的图片,它们的特征图空间位置并不对应,所以这样拼接可能没有意义。因此,我们选择在全局平均池化之后,将两个特征向量(每个512维)拼接成一个1024维的向量。
具体步骤:
1. 定义修改后的ResNet18(特征提取部分,到layer4之后,不包括全局平均池化和全连接层)
2. 加载两张不同类别的图片,预处理
3. 分别通过特征提取网络得到特征图(假设为[1,512,4,4])
4. 对每个特征图进行全局平均池化,得到两个512维的向量
5. 将两个向量拼接,得到1024维的向量
代码实现:
我们首先定义一个修改后的ResNet18,然后加载两张图片,进行特征提取和组合。
注意:由于我们使用随机初始化的模型,所以特征提取效果不好,但这里主要是演示流程。在实际应用中,应该使用在CIFAR10上预训练好的模型。
我们将按照以下步骤编写代码:
1. 导入必要的库
2. 定义修改后的ResNet18(特征提取部分)
3. 加载CIFAR10数据集,并选择两张不同类别的图片
4. 预处理图片
5. 提取特征并组合
6. 输出组合特征的维度
关于修改ResNet18:我们将使用torchvision.models.resnet18,并修改第一层卷积和去掉maxpool层。
由于我们只需要特征提取部分,所以我们将创建一个模型,该模型返回全局平均池化之前的特征图,然后我们自己进行全局平均池化。
我们定义模型如下:
代码较长,我们逐步实现。
</think>
# 使用ResNet模型组合CIFAR10不同类别图片的特征嵌入
下面是完整的代码实现,用于从CIFAR10数据集中选取两张不同类别的图片,使用ResNet模型提取特征,并将两类特征嵌入组合成新的特征表示:
```python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from torchvision.models import resnet18
# 数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=transform
)
# 获取所有测试图片和标签
test_images = testset.data
test_labels = testset.targets
# 创建类别索引字典
class_indices = {}
for idx, label in enumerate(test_labels):
if label not in class_indices:
class_indices[label] = []
class_indices[label].append(idx)
# 随机选择两个不同的类别
class1, class2 = np.random.choice(list(class_indices.keys()), 2, replace=False)
# 从每个类别中随机选择一张图片
idx1 = np.random.choice(class_indices[class1])
idx2 = np.random.choice(class_indices[class2])
# 获取图片和标签
img1, label1 = test_images[idx1], test_labels[idx1]
img2, label2 = test_images[idx2], test_labels[idx2]
# 类别名称映射
classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
print(f"选取的两张图片属于不同类别:")
print(f"图片1: 类别 {classes[label1]} (索引 {idx1})")
print(f"图片2: 类别 {classes[label2]} (索引 {idx2})")
# 可视化原始图片
fig, ax = plt.subplots(1, 2, figsize=(8, 4))
ax[0].imshow(img1)
ax[0].set_title(f"原始图片1: {classes[label1]}")
ax[0].axis('off')
ax[1].imshow(img2)
ax[1].set_title(f"原始图片2: {classes[label2]}")
ax[1].axis('off')
plt.tight_layout()
plt.show()
# 加载预训练的ResNet18模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnet18(pretrained=True)
model = model.to(device)
model.eval() # 设置为评估模式
# 修改ResNet以适应CIFAR10的32x32输入
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity() # 移除初始的最大池化层
# 创建特征提取器(移除最后的全连接层)
feature_extractor = torch.nn.Sequential(*list(model.children())[:-1])
# 图片预处理函数
def preprocess_image(img):
trans = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
img_tensor = trans(img).unsqueeze(0) # 增加batch维度
return img_tensor.to(device)
# 预处理图片并提取特征
img1_tensor = preprocess_image(img1)
img2_tensor = preprocess_image(img2)
with torch.no_grad():
features1 = feature_extractor(img1_tensor)
features2 = feature_extractor(img2_tensor)
# 展平特征向量
features1_flat = features1.view(features1.size(0), -1)
features2_flat = features2.view(features2.size(0), -1)
print(f"特征维度: 图片1 - {features1_flat.shape[1]}, 图片2 - {features2_flat.shape[1]}")
# 特征组合方法
def combine_features(feat1, feat2, method='concat'):
"""组合两个特征嵌入"""
if method == 'concat':
# 拼接特征
combined = torch.cat((feat1, feat2), dim=1)
elif method == 'average':
# 平均特征
combined = (feat1 + feat2) / 2.0
elif method == 'elementwise_product':
# 逐元素相乘
combined = feat1 * feat2
elif method == 'elementwise_sum':
# 逐元素相加
combined = feat1 + feat2
else:
raise ValueError(f"未知的特征组合方法: {method}")
return combined
# 组合特征
combined_features = combine_features(features1_flat, features2_flat, method='concat')
# 可视化特征嵌入
def visualize_features(features, title):
"""可视化特征嵌入"""
plt.figure(figsize=(10, 5))
plt.imshow(features.cpu().numpy().reshape(1, -1), cmap='viridis', aspect='auto')
plt.title(title)
plt.colorbar()
plt.xlabel("特征维度")
plt.ylabel("样本")
plt.show()
# 可视化单个特征和组合特征
print("\n可视化特征嵌入:")
visualize_features(features1_flat, f"图片1特征嵌入 ({classes[label1]})")
visualize_features(features2_flat, f"图片2特征嵌入 ({classes[label2]})")
visualize_features(combined_features, f"组合特征嵌入 ({classes[label1]} + {classes[label2]})")
# 特征相似度分析
def cosine_similarity(a, b):
"""计算余弦相似度"""
a_norm = a / a.norm(dim=1)[:, None]
b_norm = b / b.norm(dim=1)[:, None]
return torch.mm(a_norm, b_norm.transpose(0, 1)).item()
# 计算特征相似度
similarity_1_2 = cosine_similarity(features1_flat, features2_flat)
similarity_1_combined = cosine_similarity(features1_flat, combined_features)
similarity_2_combined = cosine_similarity(features2_flat, combined_features)
print("\n特征相似度分析:")
print(f"图片1与图片2特征相似度: {similarity_1_2:.4f}")
print(f"图片1与组合特征相似度: {similarity_1_combined:.4f}")
print(f"图片2与组合特征相似度: {similarity_2_combined:.4f}")
```
## 代码说明
### 1. 数据准备与处理
- 使用CIFAR10测试集,确保不修改训练数据
- 随机选择两个不同类别的图片(`replace=False`保证类别不同)
- 可视化原始图片用于对比
### 2. ResNet模型配置
- 加载预训练的ResNet18模型
- 修改模型以适应CIFAR10的32×32输入尺寸:
- 将第一层卷积核从7×7改为3×3
- 移除初始的最大池化层
- 创建特征提取器(移除最后的全连接层)
### 3. 特征提取与组合
- 使用ResNet提取两张图片的特征嵌入
- 提供多种特征组合方法:
- **拼接(concat)**:将两个特征向量连接
- **平均(average)**:计算两个特征的平均值
- **逐元素相乘(elementwise_product)**:特征点乘
- **逐元素相加(elementwise_sum)**:特征相加
### 4. 特征可视化与分析
- 可视化原始图片、单个特征嵌入和组合特征嵌入
- 计算特征相似度:
- 余弦相似度公式:$ \text{similarity} = \frac{\mathbf{A} \cdot \mathbf{B}}{\|\mathbf{A}\| \|\mathbf{B}\|} $
- 分析组合特征与原始特征的相似性
## 数学原理
特征组合的数学表达:
1. **拼接(concat)**:
$$ \mathbf{f}_{\text{combined}} = [\mathbf{f}_1; \mathbf{f}_2] $$
其中$\mathbf{f}_1, \mathbf{f}_2$分别是两张图片的特征向量
2. **平均(average)**:
$$ \mathbf{f}_{\text{combined}} = \frac{\mathbf{f}_1 + \mathbf{f}_2}{2} $$
3. **逐元素相乘**:
$$ \mathbf{f}_{\text{combined}} = \mathbf{f}_1 \odot \mathbf{f}_2 $$
4. **逐元素相加**:
$$ \mathbf{f}_{\text{combined}} = \mathbf{f}_1 + \mathbf{f}_2 $$
余弦相似度计算:
$$ \text{cosine\_similarity}(\mathbf{A}, \mathbf{B}) = \frac{\mathbf{A} \cdot \mathbf{B}}{\|\mathbf{A}\| \|\mathbf{B}\|} $$
## 结果分析
组合特征嵌入保留了原始图片的特征信息:
- 拼接方法保留了所有原始特征维度
- 平均方法提取了共享特征模式
- 组合特征与原始特征保持合理的相似度
- 可视化展示不同类别特征的融合效果
这种方法可用于多模态学习、特征融合和数据增强等应用场景[^2][^3]。