采用可变输入张量列表的自定义Tensorflow操作

2024-06-16 14:14:03 发布

您现在位置:Python中文网/ 问答频道 /正文

<>我试图在C++中编程一个自定义的TysFooFrm操作。此操作应将张量列表作为输入并修改其内容。我想使用Assign操作的示例,它在Tensorflow代码中注册如下:

REGISTER_OP("Assign")
    .Input("ref: Ref(T)")
    .Input("value: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: type")
    ...

作为参考,赋值操作(ref)的input(0)是要赋值的张量,input(1)value)是它的新值。输出张量(output_ref)只是对传播的input(0)的引用。你知道吗

在其定义中,Assign操作还有以下代码来检查第一个输入是否是可变张量:

OP_REQUIRES(context, IsRefType(context->input_type(0)),
errors::InvalidArgument("lhs input needs to be a ref type"));

与赋值操作相反,我的自定义操作应该采用可变张量列表(而不是单个张量),这些张量的内容将由该操作修改。你知道吗

我尝试通过以下方式注册我的操作:

REGISTER_OP("MyCustomOperation")
    .Input("refs: list(Ref(T))")
    .Attr("T: type")
    ...

但在加载库时,Tensorflow会给我以下错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Reference to unknown attr 'list' from Input("refs: list(Ref(T))") for Op MyCustomOperation

我还尝试了refs: list(T)属性T: Ref(type),但这也不起作用(Tensorflow打印错误Trouble parsing type string at 'Ref(type)' from Attr("T: Ref(type)"))。你知道吗

所以我换成了以下的注册方式:

REGISTER_OP("MyCustomOperation")
    .Input("refs: list(Ref(T))")
    .Attr("T: type")
    ...

但是,有了这个定义,IsRefType断言就失败了。注意,我在Python级别传递了一个tf.RefVariable列表,我假设它是可变的。你知道吗

如何使我的操作正确地期望可变张量的列表?你知道吗


Tags: registerref列表inputtensorflowtypelistattr
1条回答
网友
1楼 · 发布于 2024-06-16 14:14:03

经过一些调查,我发现了一个这样做的例子。下面是传递可变张量列表的解决方案:

REGISTER_OP("MyCustomOperation")
    .Input("refs: Ref(N * T)")
    .Attr("T: type")
    .Attr("N: int")
    ...

相关问题 更多 >