# 利用SVD函数只保存前k个主成分 — image compression via SVD: keep only the top-k principal components.
from scipy.sparse.linalg import svds
import numpy as np
def svd_image_reconstruction(image, k, return_USV=False, *, dtype=np.float16):
    """Approximate ``image`` by the rank-``k`` truncation of its SVD.

    The image is flattened to a ``C x (W*H)`` matrix, decomposed with a thin
    SVD, and only the ``k`` largest singular triplets are kept. The kept
    factors are cast to ``dtype`` (half precision by default, i.e. compact
    storage) and the image is rebuilt from those quantized factors.

    Parameters
    ----------
    image : array_like
        3-D array whose axes are unpacked as ``(C, W, H)``; the first axis
        forms the rows of the decomposed matrix. Integer dtypes (e.g.
        ``uint8``) are accepted; the decomposition is done in float64.
    k : int
        Number of singular components to keep, ``1 <= k <= min(C, W*H)``.
    return_USV : bool, optional
        If true, also return the truncated factors.
    dtype : numpy dtype, optional (keyword-only)
        Precision used to store the factors and rebuild the image. Defaults
        to ``np.float16``. float16 saturates at 65504, so the largest
        singular value of a large image with 0-255 pixel values can become
        ``inf``; pass ``np.float32`` (or wider) in that case.

    Returns
    -------
    reconstructed_image : ndarray
        Shape ``(C, W, H)``, dtype ``dtype``.
    (U_k, S_k, Vt_k) : tuple of ndarray
        Only when ``return_USV`` is true. ``U_k`` is ``(C, k)``, ``S_k`` is
        the ``(k, k)`` diagonal matrix of singular values in descending
        order, ``Vt_k`` is ``(k, W*H)``; all of dtype ``dtype``.

    Raises
    ------
    ValueError
        If ``k`` is outside ``[1, min(C, W*H)]``.
    """
    image = np.asarray(image)
    C, W, H = image.shape
    max_rank = min(C, W * H)
    if not 1 <= k <= max_rank:
        raise ValueError(f"k must satisfy 1 <= k <= {max_rank}, got {k}")

    # Flatten to a C x (W*H) matrix. Decompose in float64 so integer images
    # (e.g. uint8) work and the factors are accurate before quantization.
    image_reshaped = image.reshape(C, W * H).astype(np.float64, copy=False)

    # Thin LAPACK SVD instead of scipy's svds (ARPACK): it permits
    # k == min(C, W*H) (e.g. all 3 components of an RGB image) and returns
    # singular values in descending order, so the slices below really keep
    # the k largest components.
    U, S, Vt = np.linalg.svd(image_reshaped, full_matrices=False)

    # Keep only the top-k components, quantized to the storage dtype.
    U_k = U[:, :k].astype(dtype)
    S_k = np.diag(S[:k]).astype(dtype)
    Vt_k = Vt[:k, :].astype(dtype)

    # Rebuild from the quantized factors, as a decompressor would.
    reconstructed_image = np.dot(U_k, np.dot(S_k, Vt_k)).reshape(C, W, H)
    if return_USV:
        return reconstructed_image, (U_k, S_k, Vt_k)
    return reconstructed_image