创建矩阵运算器

2024-04-26 11:25:33 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图在TensorFlow中实现一种非线性滤波器,但是我在一步实现上遇到了困难。步骤基本上类似于:

x_update = x.assign(tf.matmul(A, x))

问题是矩阵A的结构类似于:

^{pr2}$

其中每个fn(x)都是我状态的一个非线性函数;类似于tf.sin(x[4])甚至{}。在

我不知道如何创建我的A矩阵来嵌入这些操作。我首先用一些值初始化它:

A_mat = np.eye(5)
A_mat[0, 1] = 0.1
A = tf.Variable(A_mat, dtype=tf.float32, trainable=False, name='A')

然后我试图用tf.scatter_update进行一些切片更新,类似于:

# Define my nonlinear operations.
f1 = tf.cos(...)
f2 = tf.sin(...)
# ...

# Define the part that I want to substitute.
new_part = tf.constant(tf.convert_to_tensor([[f1, f2, f3],
                                             [f4, f5, f6]]))

# Define slice indices and update the matrix.
inds = [vals for vals in zip(np.arange(1, 3), np.arange(2, 5))]
A_update = tf.scatter_update(A, tf.constant(inds), new_part, name='A_update')

这给了我一个错误说明:

ValueError: Shapes must be equal rank, but are 1 and 0

From merging shape 1 with other shapes. for 'packed/0' (op: 'Pack') with input shapes: [1], [1], [], [], [], [].

我也尝试过将我的矩阵new_part分配回numpy定义的A_mat,但是我得到了一个不同的错误,我认为这是由于意外的数据类型造成的,当一个数值数组突然被分配了张量元素。在

那么,有人知道如何定义一个操作矩阵,当矩阵被这样使用时,它会更新吗?在

理想情况下,我希望定义矩阵A,以便A内更新的所有操作都是对A调用的一部分,并自动发生。这样我就可以完全避免切片赋值,而且只会感觉到更张量的flow-y

谢谢你!在


更新:

我通过将ops包装在tf.reshape(op_name, [])中并将更新更改为:

new_part = tf.convert_to_tensor([[0, 0, f1, f2, f3],
                                 [0, 0, f4, f5, f6]]))
rows = np.arange(start_row, end_row)
A_update = tf.scatter_update(A, rows, new_part, name='A_update')

结果是^{}只能对变量的第一个维度进行操作,所以我必须向它提供完整的行,并在我想放置它们的位置提供行索引。这有帮助,但我还是要问:


我的问题:

定义这个A矩阵的最好、最张量的方法是什么,使那些常量元素保持不变,而那些作为我的图上其他张量运算的元素也同样嵌入到A中?我希望在我的图上调用A,这样就可以在不需要手动执行tf.scatter_update的情况下完成并运行这些更新。或者这是正确的方法吗?在


Tags: tonamenew定义tfnpupdate矩阵
1条回答
网友
1楼 · 发布于 2024-04-26 11:25:33

更新子矩阵最简单的方法是使用tensorflow的python切片操作。在

import numpy as np
import tensorflow as tf
A = tf.Variable(np.zeros((5, 5), dtype=np.float32), trainable=False)
new_part = tf.ones((2,3))

update_A = A[2:4,2:5].assign(new_part)

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
print(update_A.eval())
# array([[ 0.,  0.,  0.,  0.,  0.],
#        [ 0.,  0.,  0.,  0.,  0.],
#        [ 0.,  0.,  1.,  1.,  1.],
#        [ 0.,  0.,  1.,  1.,  1.],
#        [ 0.,  0.,  0.,  0.,  0.]], dtype=float32)

相关问题 更多 >