使用ScipyBoundedMinimize和Optax的JAX优化速度慢 - 寻求加速策略

1 投票
1 回答
69 浏览
提问于 2025-04-14 15:27

我正在用 jax 优化一个模型,这个模型需要处理一个很大的观察数据集(有4800个数据点),而且模型本身也比较复杂,还涉及插值。现在使用 jaxopt.ScipyBoundedMinimize 进行优化的过程大约需要30秒来完成100次迭代,而且大部分时间似乎都花在了第一次迭代开始之前或进行时。下面是相关的代码片段,你可以在以下链接找到所需的数据。

必要的数据 (idc, sg 和 cpcs)

import jax.numpy as jnp
import time as ela_time
from jaxopt import ScipyBoundedMinimize
import optax
import jax
import pickle


file1 = open('idc.pkl', 'rb')
idc = pickle.load(file1)
file1.close()

file2 = open('sg.pkl', 'rb')
sg = pickle.load(file2)
file2.close()

file3 = open('cpcs.pkl', 'rb')
cpcs = pickle.load(file3)
file3.close()


def model(fssc, fssh, time, rv, amp):

    fssp = 1.0 - (fssc + fssh)

    ivis = cpcs['common'][time]['ivis']
    areas = cpcs['common'][time]['areas']
    mus = cpcs['common'][time]['mus']

    vels = idc['vels'].copy()

    ldfs_phot = cpcs['line'][time]['ldfs_phot']
    ldfs_cool = cpcs['line'][time]['ldfs_cool']
    ldfs_hot = cpcs['line'][time]['ldfs_hot']

    lps_phot = cpcs['line'][time]['lps_phot']
    lps_cool = cpcs['line'][time]['lps_cool']
    lps_hot = cpcs['line'][time]['lps_hot']

    lis_phot = cpcs['line'][time]['lis_phot']
    lis_cool = cpcs['line'][time]['lis_cool']
    lis_hot = cpcs['line'][time]['lis_hot']

    coeffs_phot = lis_phot * ldfs_phot * areas * mus
    wgt_phot = coeffs_phot * fssp[ivis]
    wgtn_phot = jnp.sum(wgt_phot)

    coeffs_cool = lis_cool * ldfs_cool * areas * mus
    wgt_cool = coeffs_cool * fssc[ivis]
    wgtn_cool = jnp.sum(wgt_cool)

    coeffs_hot = lis_hot * ldfs_hot * areas * mus
    wgt_hot = coeffs_hot * fssh[ivis]
    wgtn_hot = jnp.sum(wgt_hot)

    prf = jnp.sum(wgt_phot[:, None] * lps_phot + wgt_cool[:, None] * lps_cool + wgt_hot[:, None] * lps_hot, axis=0)
    prf /= wgtn_phot + wgtn_cool + wgtn_hot

    prf = jnp.interp(vels, vels + rv, prf)

    prf = prf + amp

    avg = jnp.mean(prf)

    prf = prf / avg

    return prf


def loss(x0s, lmbd):

    noes = sg['noes']

    noo = len(idc['times'])

    fssc = x0s[:noes]
    fssh = x0s[noes: 2 * noes]
    fssp = 1.0 - (fssc + fssh)
    rv = x0s[2 * noes: 2 * noes + noo]
    amp = x0s[2 * noes + noo: 2 * noes + 2 * noo]

    chisq = 0
    for i, itime in enumerate(idc['times']):
        oprf = idc['data'][itime]['prf']
        oprf_errs = idc['data'][itime]['errs']

        nop = len(oprf)

        sprf = model(fssc=fssc, fssh=fssh, time=itime, rv=rv[i], amp=amp[i])

        chisq += jnp.sum(((oprf - sprf) / oprf_errs) ** 2) / (noo * nop)

    wp = sg['grid_areas'] / jnp.max(sg['grid_areas'])

    mem = jnp.sum(wp * (fssc * jnp.log(fssc / 1e-5) + fssh * jnp.log(fssh / 1e-5) +
                    (1.0 - fssp) * jnp.log((1.0 - fssp) / (1.0 - 1e-5)))) / sg['noes']

    ftot = chisq + lmbd * mem

    return ftot


if __name__ == '__main__':

    # idc: a dictionary containing observational data (150 x 32)
    # sg and cpcs: dictionaries with related coefficients

    noes = sg['noes']
    lmbd = 1.0
    maxiter = 1000
    tol = 1e-5

    fss = jnp.ones(2 * noes) * 1e-5
    x0s = jnp.hstack((fss, jnp.zeros(len(idc['times']) * 2)))

    minx0s = [1e-5] * (2 * noes) + [-jnp.inf] * len(idc['times']) * 2
    maxx0s = [1.0 - 1e-5] * (2 * noes) + [jnp.inf] * len(idc['times']) * 2

    bounds = (minx0s, maxx0s)

    start = ela_time.time()

    optimizer = ScipyBoundedMinimize(fun=loss, maxiter=maxiter, tol=tol, method='L-BFGS-B',
                                 options={'disp': True})
    x0s, info = optimizer.run(x0s, bounds,  lmbd)

    # optimizer = optax.adam(learning_rate=0.1)
    # optimizer_state = optimizer.init(x0s)
    #
    # for i in range(1, maxiter + 1):
    #
    #     print('ITERATION -->', i)
    #
    #     gradients = jax.grad(loss)(x0s, lmbd)
    #     updates, optimizer_state = optimizer.update(gradients, optimizer_state, x0s)
    #     x0s = optax.apply_updates(x0s, updates)
    #     x0s = jnp.clip(x0s, jnp.array(minx0s), jnp.array(maxx0s))
    #     print('Objective function: {:.3E}'.format(loss(x0s, lmbd)))

    end = ela_time.time()

    print(end - start)   # total elapsed time: ~30 seconds

以下是相关方面的详细信息:

  • 自由参数的数量(x0s): 5263
  • 数据: 存储在 idc 字典中的观察数据(4800个数据点)
  • 模型:model 函数中定义,也使用了插值
  • 尝试过的优化方法:
    • jaxopt.ScipyBoundedMinimize 使用 L-BFGS-B 方法(速度慢,大约30秒,大部分时间花在第一次迭代之前或进行时)
    • optax.adam(太慢,大约200秒)
  • 尝试并行化: 我尝试对 optax.adam 进行并行化,但由于建模的固有特性,我没有成功,因为 x0s 不能被分割。(假设我对并行化的理解是正确的)

问题:

  1. ScipyBoundedMinimize 中,第一次迭代之前或进行时速度慢的潜在原因是什么?
  2. jax 中有没有其他可能更快的优化算法,适合我的情况(自由参数和数据点数量多,模型复杂且涉及插值)?
  3. 我对 optax.adam 的并行化理解错了吗?在这种情况下有没有可能的并行化策略?
  4. 在提供的代码片段中,有没有可以提高性能的代码优化(例如,向量化)?

附加信息:

  • 硬件: Intel® Core™ i7-9750H CPU @ 2.60GHz × 12, 16 GiB RAM(笔记本电脑)
  • 软件: 操作系统 Ubuntu 22.04,Python 3.10.12,JAX 0.4.25,optax 0.2.1

我非常感谢任何能够改善优化性能的见解或建议。

1 个回答

0

JAX代码是即时编译的(JIT),这意味着第一次运行时花费的时间比较长,主要是因为编译的开销。你的代码越长,编译所需的时间就越多。

一个常见导致编译时间长的问题是使用了Python的控制流,比如for循环。JAX的追踪机制会把这些循环“压平”(详细信息可以查看JAX常见问题:控制流)。在你的例子中,你在数据结构中循环了4800个条目,这样就造成了一个非常长且效率低下的程序。

在这种情况下,通常的解决办法是使用jax.vmap来重写你的程序。像大多数JAX的构造一样,这种方法在处理结构体数组模式时效果最好,而不是你数据中使用的数组结构模式。所以,使用vmap的第一步是将你的数据重新组织成JAX可以使用的形式;可能看起来像这样:

itimes = jnp.arange(len(idc['times']))
prf = jnp.array([idc['data'][i]['prf'] for i in itimes])
errs = jnp.array([idc['data'][i]['errs'] for i in itimes])

sprf = jax.vmap(model, in_axes=[None, None, 0, 0, 0])(fssc, fssh, itimes, rv, amp)
chi2 = jnp.sum((oprf - sprf) / oprf_errs) ** 2) / len(times) / sprf.shape[1]

这段代码不能直接使用:你还需要将model函数使用的数据重构为结构体数组的样式,但希望这能给你一个大概念。

另外要注意的是,这假设idc['data'][i]['prf']idc['data'][i]['errs']的每个条目都有相同的形状。如果不是这样的话,那么你的问题可能不太适合JAX的SPMD编程模型,而且没有简单的方法来解决长时间编译的问题。

撰写回答