googlejax一维卷积神经网络

2024-04-16 06:32:42 发布

您现在位置:Python中文网/ 问答频道 /正文

我正试图用stax.GeneralConv()https://jax.readthedocs.io/en/latest/_modules/jax/experimental/stax.html#GeneralConv)在googlejax中实现一个1D卷积神经网络。 我有一个一维输入数组,有18个条目,输出数组有6个条目。我想实现内核宽度为3的CNN,如下所示:

init_random_params, conv_net = stax.serial(
    GeneralConv(('NC','IO','NC'),1,(3,),padding='SAME'), # dimension_numbers = ('NC','IO','NC')
    LogSoftmax,
    Dense(6),
)

使用初始网络参数:

rng = jax.random.PRNGKey(0)
_, init_params = init_random_params(rng, (18,))

但我得到了以下错误:

stax.py", line 75, in <listcomp>
    next(filter_shape_iter) for c in rhs_spec]

IndexError: tuple index out of range

stax要求维度编号rhs_spec至少有2个字符长,但我使用的是一维过滤器。有人知道如何解决这个问题吗


Tags: inioinit条目random数组paramsnc
1条回答
网友
1楼 · 发布于 2024-04-16 06:32:42

我自己没有试过,但我认为一维卷积仍然需要一个方向来卷积,例如

Conv2d = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))
Conv1d = functools.partial(GeneralConv, ('NHC', 'HIO', 'NHC'))

换言之,将W轴从2d旋转到1d卷积

NHC对应的输入形状是(batch_size, sequence_length, num_channels)

注意,即使通道的数量可能是1,您仍然需要包括该轴,因为GeneralConv沿着num_channels = input_shape['NHC'.index('C')]行进行索引查找

相关问题 更多 >