lternating Direction Method of Multiplier(ADMM) Algorithm

ADMM算法详解
本文深入介绍了交替方向乘子法(ADMM),一种用于解决凸优化问题的高效算法。ADMM通过将复杂问题分解为一系列简单子问题来加速计算,特别适用于大规模数据集。文中详细解释了ADMM的基本原理,包括其迭代计算过程,并提供了针对LASSO问题的简化版算法公式。此外,还给出了ADMM算法的Python实现示例,展示了如何通过软阈值运算和最小化步骤求解稀疏信号恢复问题。

Alternating Direction Method of Multipliers (ADMM) 是一种通过将凸优化问题分解为一系列的易解子问题进行求解的算法,目前它在很多领域得到了广泛的应用。 [2].

This is simplified version, specifically for the LASSO:

给定一个稀疏向量x∈Rnx\in R^nxRn和矩阵A∈Rm×nA\in R^{m\times n}ARm×n
y=Ax+ey=Ax+ey=Ax+e
其中eee是加性高斯白噪声。为了恢复信号xxx,我们求解如下最小化问题
x^=min⁡x∣∣y−Ax∣∣22+λ∣∣x∣∣1 \hat{x} = \min_x ||y-Ax||_2^2 + \lambda||x||_1 x^=xminyAx22+λx1


在求解过程中,迭代地计算如下两个式子,直到满足收敛条件。
xk+1=(ATA+ρI)−1(ATy+ρ(z−u)) x^{k+1} = (A^TA + \rho I )^{-1}(A^Ty + \rho (z - u))xk+1=(ATA+ρI)1(ATy+ρ(zu))
zk+1=sign(x^)max(0,∣x∣−λρ) z^{k+1} = \mathrm{sign}(\hat{x})\mathrm{max}\left(0, |x| - \frac{\lambda}{\rho}\right) zk+1=sign(x^)max(0,xρλ)

下面是ADMM算法的PYTHON实现方式。 (http://stanford.edu/~boyd/admm.html).

import numpy as np
import matplotlib.pyplot as plt
from math import sqrt, log

def Sthresh(x, gamma):
    return np.sign(x)*np.maximum(0, np.absolute(x)-gamma/2.0)

def ADMM(A, y):

    m, n = A.shape
    w, v = np.linalg.eig(A.T.dot(A))
    MAX_ITER = 10000

    # Function to caluculate min 1/2(y - Ax) + l||x||
    # via alternating direction methods
    xhat = np.zeros([n, 1])
    zhat = np.zeros([n, 1])
    u = np.zeros([n, 1])

    # Calculate regression co-efficient and stepsize
    lamb = sqrt(2*log(n, 10))
    rho = 1/(np.amax(np.absolute(w)))

    # Pre-compute to save some multiplications
    AtA = A.T.dot(A)
    Aty = A.T.dot(y)
    Q = AtA + rho*np.identity(n)
    Q = np.linalg.inv(Q)

    for i in np.arange(1, MAX_ITER + 1):

        # x minimisation step via posterier OLS
        xhat = Q.dot(Aty + rho*(zhat - u))

        # z minimisation via soft-thresholding
        zhat = Sthresh(xhat + u, lamb/rho)

        # mulitplier update
        u = u + xhat - zhat

    return zhat, rho, lamb

def test(m=50, n=200):
    """Test the ADMM method with randomly generated matrices and vectors"""
    A = np.random.randn(m, n)

    num_non_zeros = 10
    positions = np.random.randint(0, n, num_non_zeros)
    amplitudes = 100*np.random.randn(num_non_zeros, 1)
    x = np.zeros((n, 1))
    x[positions] = amplitudes

    y = A.dot(x) + np.random.randn(m, 1)

    xhat, rho, lamb = ADMM(A, y)

    plt.plot(x, label='Original')
    plt.plot(xhat, label = 'Estimate')

    plt.legend(loc = 'upper right')

    plt.show()


if __name__ == "__main__":
    test()

参考文献:
[1] https://codereview.stackexchange.com/questions/108263/alternating-direction-method-of-multipliers
[2] http://stanford.edu/~boyd/admm.html

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值