Tensorflow:此处无法访问:它是在另一个函数或代码块中定义的

2024-05-13 19:17:49 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在实现简单的RNN。在它中,我想返回列表中每个时间步的输出,稍后我可以将其提供给优化器。我在没有@tf.function的情况下构建了工作rnn。但在添加@tf.function之后,它会产生问题

def basic_rnn_cell(self,x,s):#Note :These function are defined in class
    s=self.U*x+self.W*s+self.b #U,W,b and all are tf.Variable
    y=self.V*s+self.c
    return y,s

@tf.function
def rnn(self,X):
    outputs=[]
    state=self.state
    for x in X:
        output,state=self.basic_rnn_cell(x,state)
        outputs.append(output)
    return outputs

这就是我所说的:

x=np.array([0.01,0.02,0.03],dtype=np.float32)
o.rnn(x)

我得到的错误是:

raise errors.InaccessibleTensorError(
tensorflow.python.framework.errors_impl.InaccessibleTensorError: The tensor 'Tensor("while/add_2:0", shape=(), dtype=float32)' cannot be accessed here: it is defined in another function or code block. 
Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=while_body_44, id=2538759416224); accessed from: FuncGraph(name=rnn, id=2538758824096). 

Tags: inselfoutputreturnbasictfdefnp
1条回答
网友
1楼 · 发布于 2024-05-13 19:17:49

这是因为使用python列表临时保存tensor对象。内存回收机制将删除跟踪此函数后保存的内容,因此无法实现。 如果要保存这些临时张量,必须使用tf.TensorArray作为替换。您可以参考以下内容:https://www.tensorflow.org/guide/function#loops

相关问题 更多 >