对每个输入通道应用不同的conv1d过滤器

2024-04-19 11:57:43 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在研究张量流模型,其中,每个输入通道都应该应用单独的1d卷积。我已经玩过各种convXd函数。到目前为止,我已经有了一些工作,每个滤波器都应用到每个通道,产生N x N输出,从中我可以选择一个对角线。但这似乎很低效。关于如何仅卷积输入通道i的滤波器i有什么想法吗?谢谢你的建议!

说明我最佳工作示例的代码:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
print(tf.__version__)

# [batch, in_height, in_width, in_channels]
X_size = [5, 109, 2, 1]

# [filter_height, filter_width, in_channels, out_channels]
W_size = [10, 1, 1, 2]

mX = np.zeros(X_size)
mX[0,10,0,0]=1
mX[0,40,1,0]=2

mW = np.zeros(W_size)
mW[1:3,0,0,0]=1
mW[3:6,0,0,1]=-1

X = tf.Variable(mX, dtype=tf.float32)
W = tf.Variable(mW, dtype=tf.float32)

# convolve everything
Y = tf.nn.conv2d(X, W, strides=[1, 1, 1, 1], padding='VALID')

# now only preserve the outputs for filter i + input i
Y_desired = tf.matrix_diag_part(Y)    

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(Y.shape)
    Yout = sess.run(fetches=Y)

# Yes=desired output, No=extraneous output
plt.figure()
plt.subplot(2,2,1)
plt.plot(Yout[0,:,0,0])
plt.title('Yes: W filter 0 * X channel 0')
plt.subplot(2,2,2)
plt.plot(Yout[0,:,1,0])
plt.title('No: W filter 0 * X channel 1')
plt.subplot(2,2,3)
plt.plot(Yout[0,:,0,1])
plt.title('No: W filter 1 * X channel 0')
plt.subplot(2,2,4)
plt.plot(Yout[0,:,1,1])
plt.title('Yes: W filter 1 * X channel 1')
plt.tight_layout()

以下是一个修订版,其中包含使用depthwise\u conv2d的建议:

^{pr2}$

Tags: inimportsizeplottitletfasnp
1条回答
网友
1楼 · 发布于 2024-04-19 11:57:43

听起来你在找depthwise convolution。这将为每个输入通道构建单独的过滤器。不幸的是,似乎没有内置的1D版本,但是大多数1D卷积实现只是在引擎盖下使用2D。你可以这样做:

inp = ...  # assume this is your input, shape batch x time (or width or whatever) x channels
inp_fake2d = inp[:, tf.newaxis, :, :]  # add a fake second spatial dimension
filters = tf.random_normal([1, w, channels, 1])
out_fake2d = tf.nn.depthwise_conv2d(inp_fake2d, filters, [1,1,1,1], "valid")
out = out_fake2d[:, 0, :, :]

这将添加一个大小为1的“假”第二个空间维度,然后卷积一个过滤器(在伪维度中也是大小1,没有任何东西在该方向卷积),最后再次移除伪维度。注意,滤波器张量中的第四个维度(也是尺寸1)是每个输入通道的滤波器数量。因为每个通道只需要一个单独的过滤器,所以这个数字是1。在

我希望我能正确地理解这个问题,因为我有点困惑于你的输入X一开始是4D(通常你会使用1D卷积来进行3D输入)。不过,你也许可以根据自己的需要来调整它。在

相关问题 更多 >