import torch
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
def apply_pca_on_channels(x, n_components=3):
"""
对输入张量的通道维度进行PCA处理。
参数:
x: torch.Tensor
输入张量,形状为 (B, C, H, W)。
n_components: int
要保留的主成分数量。
返回:
torch.Tensor
输出张量,形状为 (B, n_components, H, W)。
"""
# 获取输入的形状
B, C, H, W = x.size()
# 调整形状为 (B, H*W, C)
x_reshaped = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
# 存储每个样本的PCA结果
pca_results = []
# 对每个样本执行PCA
for i in range(B):
# 初始化PCA
pca = PCA(n_components=n_components)
# 取出第i个样本
x_image = x_reshaped[i].cpu().numpy()
# 执行PCA
transformed = pca.fit_transform(x_image)
# 将结果转换为tensor并加入列表
pca_results.append(torch.tensor(transformed))
# 将PCA结果拼接为一个Tensor
pca_results = torch.stack(pca_results)
# 调整为 (B, H, W, n_components) 并 permute 为 (B, n_components, H, W)
pca_results = pca_results.reshape(B, H, W, n_components).numpy()
return pca_results
pytorch对特征图使用PCA降维返回numpy数组
于 2024-10-17 10:23:56 首次发布