When mapping tensor values with a dictionary, I get TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key

Posted 2024-04-25 23:07:51


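I'm trying to create a new tensor based on a dictionary that maps each value in the tensor one-to-one to a new value (the example below is just a simplified case). I get the error "TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key." even though I'm not using tensors as dictionary keys: I convert them to int first: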

import tensorflow as tf

tensor1 = tf.cast([1,2,3], dtype=tf.int32)
m = {1:101, 2:102, 3:103}
tf.map_fn(lambda x: m[int(x)], elems=tensor1, fn_output_signature=tf.int32)

The error I get when running the snippet above in Colab:

<ipython-input-23-ec0121562aaa> in <module>()
      1 tensor1 = tf.cast([1,2,3], dtype=tf.int32)
      2 m = {1:101, 2:102, 3:103}
----> 3 tf.map_fn(lambda  x: m[int(x)], elems=tensor1, fn_output_signature=tf.int32)

8 frames

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    600                   func.__module__, arg_name, arg_value, 'in a future version'
    601                   if date is None else ('after %s' % date), instructions)
--> 602       return func(*args, **kwargs)
    603 
    604     doc = _add_deprecated_arg_value_notice_to_docstring(

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    533                 'in a future version' if date is None else ('after %s' % date),
    534                 instructions)
--> 535       return func(*args, **kwargs)
    536 
    537     doc = _add_deprecated_arg_notice_to_docstring(

/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/map_fn.py in map_fn_v2(fn, elems, dtype, parallel_iterations, back_prop, swap_memory, infer_shape, name, fn_output_signature)
    649       swap_memory=swap_memory,
    650       infer_shape=infer_shape,
--> 651       name=name)
    652 
    653 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    533                 'in a future version' if date is None else ('after %s' % date),
    534                 instructions)
--> 535       return func(*args, **kwargs)
    536 
    537     doc = _add_deprecated_arg_notice_to_docstring(

/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/map_fn.py in map_fn(fn, elems, dtype, parallel_iterations, back_prop, swap_memory, infer_shape, name, fn_output_signature)
    505         back_prop=back_prop,
    506         swap_memory=swap_memory,
--> 507         maximum_iterations=n)
    508     result_batchable = [r.stack() for r in r_a]
    509 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
   2775                                               list(loop_vars))
   2776       while cond(*loop_vars):
-> 2777         loop_vars = body(*loop_vars)
   2778         if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
   2779           packed = True

/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/control_flow_ops.py in <lambda>(i, lv)
   2766         cond = lambda i, lv: (  # pylint: disable=g-long-lambda
   2767             math_ops.logical_and(i < maximum_iterations, orig_cond(*lv)))
-> 2768         body = lambda i, lv: (i + 1, orig_body(*lv))
   2769       try_to_pack = False
   2770 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/map_fn.py in compute(i, tas)
    489       ag_ctx = autograph_ctx.control_status_ctx()
    490       autographed_fn = autograph.tf_convert(fn, ag_ctx)
--> 491       result_value = autographed_fn(elems_value)
    492       nest.assert_same_structure(fn_output_signature or elems, result_value)
    493       result_value_flat = nest.flatten(result_value)

/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    693       except Exception as e:  # pylint:disable=broad-except
    694         if hasattr(e, 'ag_error_metadata'):
--> 695           raise e.ag_error_metadata.to_exception(e)
    696         else:
    697           raise

TypeError: in user code:

    <ipython-input-21-cf7a2c1a01a1>:3 None  *
        lambda  x: m[int(x)], elems=tk1, fn_output_signature=tf.int32, back_prop=False )
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:845 __hash__
        raise TypeError("Tensor is unhashable. "

    TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key.

1 answer
Anonymous user
#1 · Posted 2024-04-25 23:07:51

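The error occurs because even after you cast with int(x), the value is still a tensor: its type is tensorflow.python.framework.ops.EagerTensor, and tensors are unhashable, so they cannot be used as dictionary keys. Use x.numpy() instead, which returns a NumPy scalar.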
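
In ordinary eager code, int() on a scalar EagerTensor does return a Python int. The difference here is likely that tf.map_fn passes fn through AutoGraph (the autograph.tf_convert call visible in the traceback), which rewrites built-in int() calls on tensors into tensor casts. The sketch below contrasts the two cases; probe is a hypothetical helper name, and the exact type printed inside map_fn may vary between TensorFlow versions.

```python
import tensorflow as tf

t = tf.constant(2)
# Plain eager code: int() on a scalar EagerTensor returns a Python int.
print(type(int(t)))  # <class 'int'>

def probe(x):
    # Hypothetical helper: map_fn runs it through AutoGraph, which may
    # convert int(x) into a tensor cast, so this can print a tensor type.
    print("inside map_fn:", type(int(x)))
    return x

tf.map_fn(probe, elems=tf.constant([1, 2, 3]))
```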

So the changed code would be:

tf.map_fn(lambda  x: m[x.numpy()], elems=tensor1, fn_output_signature=tf.int32)
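
Why x.numpy() works as a key: for a scalar int32 tensor in eager mode it returns a NumPy scalar (numpy.int32), and NumPy integer scalars hash and compare equal to the matching Python int, so the dictionary finds the entry. A minimal NumPy-only sketch, where np.int32(2) stands in for what x.numpy() would return:

```python
import numpy as np

m = {1: 101, 2: 102, 3: 103}

# Stand-in for x.numpy() on a scalar int32 tensor.
key = np.int32(2)

# NumPy integer scalars are hashable and equal to the corresponding
# Python int, so they look up the same dict entry as the int key 2.
print(m[key])  # 102
```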

Note: one-liners can be hard to debug, so the following version makes debugging easier:

def my_debug_func(x):
    print(f"type of int(x) : {type(int(x))}")
    print(f"type of x.numpy() {type(x.numpy())}")
    return m[x.numpy()]
new_tensor=tf.map_fn(my_debug_func, elems=tensor1, fn_output_signature=tf.int32)

and for each element it prints:

type of int(x) : <class 'tensorflow.python.framework.ops.EagerTensor'>
type of x.numpy() <class 'numpy.int32'>
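
The x.numpy() fix relies on map_fn executing eagerly; inside a tf.function the elements are symbolic tensors with no concrete value, so the dictionary lookup would fail there. A graph-compatible alternative, swapped in here and not part of the original answer, is to load the dictionary into a tf.lookup.StaticHashTable and look up the whole tensor at once without map_fn. The int64 key/value dtypes, the remap helper name, and the default_value of -1 for missing keys are assumptions of this sketch.

```python
import tensorflow as tf

m = {1: 101, 2: 102, 3: 103}

# Build an immutable lookup table from the Python dict (int64 keys and values).
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(list(m.keys()), dtype=tf.int64),
        values=tf.constant(list(m.values()), dtype=tf.int64),
    ),
    default_value=-1,  # returned for keys not in m (arbitrary choice)
)

@tf.function
def remap(t):
    # Hypothetical helper: table.lookup works on symbolic tensors,
    # so this also runs inside tf.function (graph mode).
    return table.lookup(tf.cast(t, tf.int64))

tensor1 = tf.cast([1, 2, 3], dtype=tf.int32)
new_tensor = remap(tensor1)
print(new_tensor.numpy())  # [101 102 103]
```

Because the lookup is vectorized, this also avoids the per-element Python calls that map_fn makes.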
