如何在tensorflow中实现三对角矩阵与不同秩和外维张量的乘法

2024-05-29 11:27:41 发布

您现在位置:Python中文网/ 问答频道 /正文

下面的代码(修改的tensorflow示例)产生错误“所有输入张量必须具有相同的秩”。tf.linalg.LINEAROPERATORTRIIAG的多个运算也给出了类似的错误。我需要在Keras层中用一个三对角矩阵乘以一个输入,由于该层输入中的额外批次维度,张量的秩是不同的。有任何已知的实际解决方案吗

import tensorflow as tf

superdiag = tf.constant([-1, -1, 0], dtype=tf.float64)
maindiag = tf.constant([2, 2, 2], dtype=tf.float64)
subdiag = tf.constant([0, -1, -1], dtype=tf.float64)
diagonals = [superdiag, maindiag, subdiag]
rhs = tf.constant([[[1, 1], [1, 1], [1, 1]]], dtype=tf.float64)
x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence')

Tags: 代码示例tftensorflow错误dtypeconstantfloat64
1条回答
网友
1楼 · 发布于 2024-05-29 11:27:41

你必须展开第一个维度

superdiag = tf.constant([-1, -1, 0], dtype=tf.float64)
maindiag = tf.constant([2, 2, 2], dtype=tf.float64)
subdiag = tf.constant([0, -1, -1], dtype=tf.float64)
diagonals = [tf.expand_dims(superdiag,0), tf.expand_dims(maindiag,0), tf.expand_dims(subdiag,0)]
rhs = tf.constant([[[1, 1], [1, 1], [1, 1]]], dtype=tf.float64)
x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence')

相关问题 更多 >

    热门问题