Python实现的近端交替线性化最小化算法

2024-04-23 06:19:45 发布

您现在位置:Python中文网/ 问答频道 /正文

enter image description here 梯度的更新有点错误。你知道吗

我已经实现了下面给出的算法。我做错了什么

enter image description here

'''
Implementation of PALM- proximal alternating linearisation method
'''
def palm(X,m,niter,lamda):
    X = X.T
    l = X.shape[0]
    n = X.shape[1]
    W = np.random.rand(m,n)
    D = np.random.rand(l,m)
    for i in range(niter):
        '''
        Update dictionary D
        '''
        tau_d = np.linalg.norm(W,2)**-2
        D = D - tau_d * np.matmul((np.matmul(D,W)-X),W.T)

        for j in range(1,m):
            D[:,j] = D[:,j] - (np.ones((l,1)).T*D[:,j])/l    

        for j in range(m):
            D[:,j] = D[:,j]/max(1,np.linalg.norm(D[:,j],2))

        '''
        Update coefficients W
        '''
        tau_w = np.linalg.norm(D,2)**-2
        W = W - tau_w * np.matmul(D.T,(np.matmul(D,W)-X))        
        for j in range(m):
            W[j,:] = np.multiply(np.maximum(np.zeros(W[j,:].shape[0]),np.absolute(W[j,:])-lamda),np.sign(W[j,:]))  
    return D,W

我相信,第二行的W和第二列的D和W的更新是错误的


Tags: innormfor错误npupdaterangerandom
1条回答
网友
1楼 · 发布于 2024-04-23 06:19:45
import numpy as np
def palm(X,m,niter,lamda):
    X = X.T
    l = X.shape[0]
    n = X.shape[1]
    W = np.random.rand(m,n)
    D = np.random.rand(l,m)
    for i in range(niter):
        '''
        Update dictionary D
        '''
        tau_d = np.linalg.norm(W,2)**-2
        D = D - tau_d * np.matmul((np.matmul(D,W)-X),W.T)

        for j in range(1,m):
            D[:,j] = D[:,j] - (np.ones((l,1)).T*D[:,j])/l    

        for j in range(1,m):
            D[:,j] = D[:,j] - D[:,j]/max(1,np.linalg.norm(D[:,j],2))

        '''
        Update coefficients W
        '''
        tau_w = np.linalg.norm(D,2)**-2
        W = W - tau_w * np.matmul(D.T,(np.matmul(D,W)-X))        
        for j in range(1,m):
            W[j,:] = W[j,:] - np.multiply(np.maximum(np.zeros(W[j,:].shape[0]),np.absolute(W[j,:])-lamda),np.sign(W[j,:]))  
    return D,W

相关问题 更多 >