使用ScipyBoundedMinimize和Optax的JAX优化速度慢 - 寻求加速策略
我正在用 jax
优化一个模型,这个模型需要处理一个很大的观察数据集(有4800个数据点),而且模型本身也比较复杂,还涉及插值。现在使用 jaxopt.ScipyBoundedMinimize
进行优化的过程大约需要30秒来完成100次迭代,而且大部分时间似乎都花在了第一次迭代开始之前或进行时。下面是相关的代码片段,你可以在以下链接找到所需的数据。
import jax.numpy as jnp
import time as ela_time
from jaxopt import ScipyBoundedMinimize
import optax
import jax
import pickle
file1 = open('idc.pkl', 'rb')
idc = pickle.load(file1)
file1.close()
file2 = open('sg.pkl', 'rb')
sg = pickle.load(file2)
file2.close()
file3 = open('cpcs.pkl', 'rb')
cpcs = pickle.load(file3)
file3.close()
def model(fssc, fssh, time, rv, amp):
fssp = 1.0 - (fssc + fssh)
ivis = cpcs['common'][time]['ivis']
areas = cpcs['common'][time]['areas']
mus = cpcs['common'][time]['mus']
vels = idc['vels'].copy()
ldfs_phot = cpcs['line'][time]['ldfs_phot']
ldfs_cool = cpcs['line'][time]['ldfs_cool']
ldfs_hot = cpcs['line'][time]['ldfs_hot']
lps_phot = cpcs['line'][time]['lps_phot']
lps_cool = cpcs['line'][time]['lps_cool']
lps_hot = cpcs['line'][time]['lps_hot']
lis_phot = cpcs['line'][time]['lis_phot']
lis_cool = cpcs['line'][time]['lis_cool']
lis_hot = cpcs['line'][time]['lis_hot']
coeffs_phot = lis_phot * ldfs_phot * areas * mus
wgt_phot = coeffs_phot * fssp[ivis]
wgtn_phot = jnp.sum(wgt_phot)
coeffs_cool = lis_cool * ldfs_cool * areas * mus
wgt_cool = coeffs_cool * fssc[ivis]
wgtn_cool = jnp.sum(wgt_cool)
coeffs_hot = lis_hot * ldfs_hot * areas * mus
wgt_hot = coeffs_hot * fssh[ivis]
wgtn_hot = jnp.sum(wgt_hot)
prf = jnp.sum(wgt_phot[:, None] * lps_phot + wgt_cool[:, None] * lps_cool + wgt_hot[:, None] * lps_hot, axis=0)
prf /= wgtn_phot + wgtn_cool + wgtn_hot
prf = jnp.interp(vels, vels + rv, prf)
prf = prf + amp
avg = jnp.mean(prf)
prf = prf / avg
return prf
def loss(x0s, lmbd):
noes = sg['noes']
noo = len(idc['times'])
fssc = x0s[:noes]
fssh = x0s[noes: 2 * noes]
fssp = 1.0 - (fssc + fssh)
rv = x0s[2 * noes: 2 * noes + noo]
amp = x0s[2 * noes + noo: 2 * noes + 2 * noo]
chisq = 0
for i, itime in enumerate(idc['times']):
oprf = idc['data'][itime]['prf']
oprf_errs = idc['data'][itime]['errs']
nop = len(oprf)
sprf = model(fssc=fssc, fssh=fssh, time=itime, rv=rv[i], amp=amp[i])
chisq += jnp.sum(((oprf - sprf) / oprf_errs) ** 2) / (noo * nop)
wp = sg['grid_areas'] / jnp.max(sg['grid_areas'])
mem = jnp.sum(wp * (fssc * jnp.log(fssc / 1e-5) + fssh * jnp.log(fssh / 1e-5) +
(1.0 - fssp) * jnp.log((1.0 - fssp) / (1.0 - 1e-5)))) / sg['noes']
ftot = chisq + lmbd * mem
return ftot
if __name__ == '__main__':
# idc: a dictionary containing observational data (150 x 32)
# sg and cpcs: dictionaries with related coefficients
noes = sg['noes']
lmbd = 1.0
maxiter = 1000
tol = 1e-5
fss = jnp.ones(2 * noes) * 1e-5
x0s = jnp.hstack((fss, jnp.zeros(len(idc['times']) * 2)))
minx0s = [1e-5] * (2 * noes) + [-jnp.inf] * len(idc['times']) * 2
maxx0s = [1.0 - 1e-5] * (2 * noes) + [jnp.inf] * len(idc['times']) * 2
bounds = (minx0s, maxx0s)
start = ela_time.time()
optimizer = ScipyBoundedMinimize(fun=loss, maxiter=maxiter, tol=tol, method='L-BFGS-B',
options={'disp': True})
x0s, info = optimizer.run(x0s, bounds, lmbd)
# optimizer = optax.adam(learning_rate=0.1)
# optimizer_state = optimizer.init(x0s)
#
# for i in range(1, maxiter + 1):
#
# print('ITERATION -->', i)
#
# gradients = jax.grad(loss)(x0s, lmbd)
# updates, optimizer_state = optimizer.update(gradients, optimizer_state, x0s)
# x0s = optax.apply_updates(x0s, updates)
# x0s = jnp.clip(x0s, jnp.array(minx0s), jnp.array(maxx0s))
# print('Objective function: {:.3E}'.format(loss(x0s, lmbd)))
end = ela_time.time()
print(end - start) # total elapsed time: ~30 seconds
以下是相关方面的详细信息:
- 自由参数的数量(
x0s
): 5263 - 数据: 存储在
idc
字典中的观察数据(4800个数据点) - 模型: 在
model
函数中定义,也使用了插值 - 尝试过的优化方法:
jaxopt.ScipyBoundedMinimize
使用L-BFGS-B
方法(速度慢,大约30秒,大部分时间花在第一次迭代之前或进行时)- optax.adam(太慢,大约200秒)
- 尝试并行化: 我尝试对
optax.adam
进行并行化,但由于建模的固有特性,我没有成功,因为x0s
不能被分割。(假设我对并行化的理解是正确的)
问题:
- 在
ScipyBoundedMinimize
中,第一次迭代之前或进行时速度慢的潜在原因是什么? - 在
jax
中有没有其他可能更快的优化算法,适合我的情况(自由参数和数据点数量多,模型复杂且涉及插值)? - 我对
optax.adam
的并行化理解错了吗?在这种情况下有没有可能的并行化策略? - 在提供的代码片段中,有没有可以提高性能的代码优化(例如,向量化)?
附加信息:
- 硬件: Intel® Core™ i7-9750H CPU @ 2.60GHz × 12, 16 GiB RAM(笔记本电脑)
- 软件: 操作系统 Ubuntu 22.04,Python 3.10.12,JAX 0.4.25,optax 0.2.1
我非常感谢任何能够改善优化性能的见解或建议。
1 个回答
0
JAX代码是即时编译的(JIT),这意味着第一次运行时花费的时间比较长,主要是因为编译的开销。你的代码越长,编译所需的时间就越多。
一个常见导致编译时间长的问题是使用了Python的控制流,比如for
循环。JAX的追踪机制会把这些循环“压平”(详细信息可以查看JAX常见问题:控制流)。在你的例子中,你在数据结构中循环了4800个条目,这样就造成了一个非常长且效率低下的程序。
在这种情况下,通常的解决办法是使用jax.vmap
来重写你的程序。像大多数JAX的构造一样,这种方法在处理结构体数组模式时效果最好,而不是你数据中使用的数组结构模式。所以,使用vmap
的第一步是将你的数据重新组织成JAX可以使用的形式;可能看起来像这样:
itimes = jnp.arange(len(idc['times']))
prf = jnp.array([idc['data'][i]['prf'] for i in itimes])
errs = jnp.array([idc['data'][i]['errs'] for i in itimes])
sprf = jax.vmap(model, in_axes=[None, None, 0, 0, 0])(fssc, fssh, itimes, rv, amp)
chi2 = jnp.sum((oprf - sprf) / oprf_errs) ** 2) / len(times) / sprf.shape[1]
这段代码不能直接使用:你还需要将model
函数使用的数据重构为结构体数组的样式,但希望这能给你一个大概念。
另外要注意的是,这假设idc['data'][i]['prf']
和idc['data'][i]['errs']
的每个条目都有相同的形状。如果不是这样的话,那么你的问题可能不太适合JAX的SPMD编程模型,而且没有简单的方法来解决长时间编译的问题。