下面的代码(修改的tensorflow示例)产生错误“所有输入张量必须具有相同的秩”。tf.linalg.LINEAROPERATORTRIIAG的多个运算也给出了类似的错误。我需要在Keras层中用一个三对角矩阵乘以一个输入,由于该层输入中的额外批次维度,张量的秩是不同的。有任何已知的实际解决方案吗
import tensorflow as tf
superdiag = tf.constant([-1, -1, 0], dtype=tf.float64)
maindiag = tf.constant([2, 2, 2], dtype=tf.float64)
subdiag = tf.constant([0, -1, -1], dtype=tf.float64)
diagonals = [superdiag, maindiag, subdiag]
rhs = tf.constant([[[1, 1], [1, 1], [1, 1]]], dtype=tf.float64)
x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence')
你必须展开第一个维度
相关问题 更多 >
编程相关推荐