字典学习 (Dictionary Learning) —— K-SVD 算法

论文

M. Aharon, M. Elad and A. Bruckstein, “K-SVD: An algorithm for designing overcomplete dictionaries for sparse representation,” in IEEE Transactions on Signal Processing, vol. 54, no. 11, pp. 4311-4322, Nov. 2006.

问题描述

min ⁡ D , X ∣ ∣ Y − D X ∣ ∣ F s . t . ∣ ∣ x i ∣ ∣ 0 < T 0 , ∀ i \begin{array}{ll} \min_{D,X} & ||Y-DX||_F \\ s.t.& ||x_i||_0 < T_0, \forall i \end{array} minD,Xs.t.YDXFxi0<T0,i
其中 Y ∈ R M × L Y\in R^{M\times L} YRM×L为原始数据, D ∈ R M × N D\in R^{M\times N} DRM×N为字典, X ∈ R N × L X\in R^{N\times L} XRN×L为编码。

M M M 表示数据特征维度, L L L表示样本数, N N N 表示字典大小。

优化的目标是找到原始数据的稀疏表示,要求 X X X的每一列 x i x_i xi的非零元数目小于 T 0 T_0 T0
在这里插入图片描述

求解原理

交替优化:

  • 固定 D D D,优化 X X X,主要用到正交匹配跟踪 (OMP)
  • 固定 X X X,优化 D D D,主要用到奇异值分解 (SVD)

在这里插入图片描述

python 实现

KSVD 算法

from sklearn import linear_model

def KSVD(Y, dict_size,
         max_iter = 10,
         sparse_rate = 0.2,
         tolerance = 1e-6):
    
    assert(dict_size <= Y.shape[1])

    def dict_update(y, d, x):
        assert(d.shape[1] == x.shape[0])

        for i in range(x.shape[0]):
            index = np.where(np.abs(x[i, :]) > 1e-7)[0]

            if len(index) == 0:
                continue

            d[:, i] = 0
            r = (y - np.dot(d, x))[:, index]
            u, s, v = np.linalg.svd(r, full_matrices=False)
            d[:, i] = u[:, 0]
            for j,k in enumerate(index):
                x[i, k] = s[0] * v[0, j]
        return d, x


    # initialize dictionary
    if dict_size > Y.shape[0]:
        dic = Y[:, np.random.choice(Y.shape[1], dict_size, replace=False)]
    else:
        u, s, v = np.linalg.svd(Y)
        dic = u[:, :dict_size]
        
    print('dict shape:', dic.shape)
    
    n_nonzero_coefs_each_code = int(sparse_rate * dict_size) if int(sparse_rate * dict_size) > 0 else 1
    for i in range(max_iter):
        x = linear_model.orthogonal_mp(dic, Y, n_nonzero_coefs = n_nonzero_coefs_each_code)
        e = np.linalg.norm(Y - dic @ x)
        if e < tolerance:
            break
        dict_update(Y, dic, x)

    sparse_code = linear_model.orthogonal_mp(dic, Y, n_nonzero_coefs = n_nonzero_coefs_each_code)
    
    return dic, sparse_code

测试

Y = D X Y = D X Y=DX

import numpy as np
import scipy.sparse as ss

# 生成随机稀疏矩阵 X
num_col_X = 30
num_row_X = 10
num_ele_X = 40
a = [np.random.randint(0,num_row_X) for _ in range(num_ele_X)]
b = [np.random.randint(0,num_col_X) for _ in range(num_ele_X - num_col_X)] + [i for i in range(num_col_X)]
c = [np.random.rand()*10 for _ in range(num_ele_X)]
rows, cols, v = np.array(a), np.array(b), np.array(c)
sparseX = ss.coo_matrix((v,(rows,cols)))
X = sparseX.todense()

# 随机生成字典 D
num_row_D = 10
num_col_D = num_row_X
D = np.random.random((num_row_D,num_col_D))

# 生成 Y
Y = D @ X

原始数据
在这里插入图片描述
完备字典

dic, code = KSVD(Y, 10)
Y_reconstruct = dic @ code

在这里插入图片描述
欠完备字典

dic, code = KSVD(Y, 5)
Y_reconstruct = dic @ code

在这里插入图片描述
超完备字典

dic, code = KSVD(Y, 15)
Y_reconstruct = dic @ code

在这里插入图片描述

结果可视化函数

def showmat(X, cmap='Oranges'):
    fig = plt.figure(figsize=(10,5))
    ax = fig.add_subplot(111)
    X_abs = np.abs(X)
    ax.matshow(X_abs, vmin=np.min(X_abs), vmax=np.max(X_abs), cmap=cmap)
    ax.set_xticks([])
    ax.set_yticks([])

showmat(Y_reconstruct), showmat(Y)
showmat(code,'Greens'), showmat(X,'Greens')
showmat(dic,'Reds'), showmat(D, 'Reds')
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

颹蕭蕭

白嫖?

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值