Python与Matlab中FastICA的性能对比

2024-05-16 18:45:08 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试用Python从Matlab移植following ICA implementation。根据我的理解,它使用了以双曲余弦作为对比函数的通货紧缩正交化

使用FastICA via sklearn可以获得令人满意的结果,但是与Matlab相比,执行时间非常慢

作为比较,以下示例数据的执行时间如下:

  • Python(deflation算法):4.97秒
  • Python(parallel算法):0.04秒
  • Matlab:0.04秒

奇怪的是,Python中FastICA的deflation算法比Matlab实现或Python中FastICA的paralell算法慢100多倍

为什么会有如此巨大的差异,特别是在Matlab和Python版本之间? 我不是ICA方面的专家,因此可能缺少一个配置

这是用于生成示例数据和分析执行时间的Python代码:

import timeit
import numpy as np
from sklearn.decomposition import FastICA
from scipy.misc import electrocardiogram

# prepare signal
ecg = electrocardiogram().reshape(-1, 1)
np.random.seed(0)
ecg_noisy = ecg + np.random.randn(ecg.shape[0], 1)
x = np.hstack((ecg, ecg_noisy))

n = 10  # number of runs for profiling

# profile parallel FastICA algoritm
transformerParallel = FastICA(n_components=x.shape[1],
                              algorithm='parallel',  
                              random_state=0,
                              whiten=True,
                              max_iter=200,
                              fun='logcosh',
                              tol=1E-4)                            
tp = timeit.timeit(lambda: transformerParallel.fit_transform(x), number=n)
print('ICA Parallel takes {:.3f} seconds'.format(tp/n));

# profile deflational FastICA algorithm
transformerDeflation = FastICA(n_components=x.shape[1],
                               algorithm='deflation',  
                               random_state=0,
                               whiten=True,
                               max_iter=200,
                               fun='logcosh',
                               tol=1E-4)                               
td = timeit.timeit(lambda: transformerDeflation.fit_transform(x), number=n)
print('ICA Deflation takes {:.3f} seconds'.format(td/n));

# export data for profiling in Matlab
import pandas as pd
pd.DataFrame(x).to_csv('input.csv', header=False, index=False)

这是用于在Matlab中分析的代码(使用coshFpDeIca.m):

x = load('input.csv');
tm = timeit(@()coshFpDeIca(x'), 3);
fprintf('ICA Deflation (in Matlab) takes %.3f seconds\n', tm);

Tags: import算法numberparallelnp时间randomalgorithm