如何实现一个自定义的 Keras 层,只保留最大的 n 个值(top-n),其余的全部置零?

2024-06-06 04:42:02 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试实现一个自定义 Keras 层,只保留输入中最大的 N 个值,并把其余所有值置为零。我已有一个版本在大多数情况下可用,但当存在并列(相等)的值时,会保留超过 N 个值。我想借助排序函数,确保始终只留下 N 个非零值。

下面是基本可用的层,但存在并列值时会保留超过 N 个值:

def top_n_filter_layer(input_data, n=2, tf_dtype=tf_dtype):
    """Keep exactly the ``n`` largest entries along the last axis; zero the rest.

    The earlier threshold version (``input >= min(top-k values)``) had two
    defects:

    * ties at the threshold kept more than ``n`` entries per row, and
    * ``tf.math.reduce_min`` was called without an axis, so the threshold was
      the minimum over the whole batch rather than per row.

    Selecting positions by the indices returned from ``tf.nn.top_k`` fixes
    both. ``top_k`` returns exactly ``n`` distinct indices per row, and among
    equal values it prefers the lower index.

    Args:
        input_data: Numeric tensor whose last axis is filtered. Any leading
            (batch) dimensions are handled row-wise.
        n: Number of entries to keep per row. Must not exceed the size of the
            last axis (``tf.nn.top_k`` raises otherwise).
        tf_dtype: No longer used. It is kept only so that existing callers that
            pass it keep working.

    Returns:
        A tensor with the same shape and dtype as ``input_data``. It holds the
        original values at the ``n`` selected positions and zero elsewhere.
    """
    topk = tf.nn.top_k(input_data, k=n, sorted=True)

    # Size of the last axis: use the static size when known, else the dynamic one.
    depth = input_data.get_shape().as_list()[-1]
    if depth is None:
        depth = tf.shape(input_data)[-1]

    # one_hot maps the [..., n] indices to [..., n, depth]. Collapsing the n axis
    # gives a 0/1 map with exactly n ones per row, because the indices in a row
    # are distinct.
    selected = tf.reduce_max(tf.one_hot(topk.indices, depth), axis=-2)
    mask = tf.cast(selected, tf.bool)

    # Use tf.where instead of multiplying by the mask. This keeps input_data's
    # dtype and gives exact zeros even where a dropped input is inf or NaN.
    return tf.where(mask, input_data, tf.zeros_like(input_data))

下面是我正在尝试的基于排序的方法,但 tf.scatter_update 函数一直报错,提示秩(rank)不匹配:

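(原帖中的这段代码在页面抓取时丢失;从下面的回溯可以看出,其中包含 `zeros_variable = tf.assign(zeros_variable, zeros, validate_shape=False)` 和 `output = tf.scatter_update(zeros_variable, indices_to_keep, values_to_keep)` 这两步。)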

回溯如下:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
   1658   try:
-> 1659     c_op = c_api.TF_FinishOperation(op_desc)
   1660   except errors.InvalidArgumentError as e:

InvalidArgumentError: Shapes must be equal rank, but are 2 and 3 for 'ScatterUpdate' (op: 'ScatterUpdate') with input shapes: [?,10], [?,2], [?,2].

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-10-598e009077f8> in <module>()
     27 
     28 input_layer = Input(shape=(10,))
---> 29 output_data = top_n_filter_layer(input_layer)
     30 
     31 with tf.Session() as sess:

<ipython-input-10-598e009077f8> in top_n_filter_layer(input_data, n, tf_dtype)
     18     zeros_variable = tf.assign(zeros_variable, zeros, validate_shape=False)
     19 
---> 20     output = tf.scatter_update(zeros_variable, indices_to_keep, values_to_keep)
     21 
     22     return output

/opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/state_ops.py in scatter_update(ref, indices, updates, use_locking, name)
    297   if ref.dtype._is_ref_dtype:
    298     return gen_state_ops.scatter_update(ref, indices, updates,
--> 299                                         use_locking=use_locking, name=name)
    300   return ref._lazy_read(gen_resource_variable_ops.resource_scatter_update(  # pylint: disable=protected-access
    301       ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),

/opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/gen_state_ops.py in scatter_update(ref, indices, updates, use_locking, name)
   1273   _, _, _op = _op_def_lib._apply_op_helper(
   1274         "ScatterUpdate", ref=ref, indices=indices, updates=updates,
-> 1275                          use_locking=use_locking, name=name)
   1276   _result = _op.outputs[:]
   1277   _inputs_flat = _op.inputs

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
    786         op = g.create_op(op_type_name, inputs, output_types, name=scope,
    787                          input_types=input_types, attrs=attr_protos,
--> 788                          op_def=op_def)
    789       return output_structure, op_def.is_stateful, op
    790 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    505                 'in a future version' if date is None else ('after %s' % date),
    506                 instructions)
--> 507       return func(*args, **kwargs)
    508 
    509     doc = _add_deprecated_arg_notice_to_docstring(

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in create_op(***failed resolving arguments***)
   3298           input_types=input_types,
   3299           original_op=self._default_original_op,
-> 3300           op_def=op_def)
   3301       self._create_op_helper(ret, compute_device=compute_device)
   3302     return ret

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in __init__(self, node_def, g, inputs, output_types, control_inputs, input_types, original_op, op_def)
   1821           op_def, inputs, node_def.attr)
   1822       self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
-> 1823                                 control_input_ops)
   1824 
   1825     # Initialize self._outputs.

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
   1660   except errors.InvalidArgumentError as e:
   1661     # Convert to ValueError for backwards compatibility.
-> 1662     raise ValueError(str(e))
   1663 
   1664   return c_op

ValueError: Shapes must be equal rank, but are 2 and 3 for 'ScatterUpdate' (op: 'ScatterUpdate') with input shapes: [?,10], [?,2], [?,2].

下面 @Vlad 的回答给出了一种基于独热(one-hot)编码的可行方法。以下示例说明它是有效的(即使存在并列值,也只保留 n 个值):

import tensorflow as tf
import numpy as np

# Start from an empty TF1 default graph, so that re-running this code
# (e.g. in a notebook) does not accumulate duplicate ops.
tf.reset_default_graph()

# Sequential model whose inputs have 10 features per sample; the top-n
# filter layer is appended to it further below.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.InputLayer((10,)))

def top_n_filter_layer(input_data, n=2):
    """Keep exactly the ``n`` largest entries along the last axis; zero the rest.

    Positions are chosen from the indices returned by ``tf.nn.top_k``. It
    yields exactly ``n`` distinct indices per row and prefers the lower index
    among equal values, so ties never leave more than ``n`` non-zero entries.

    Args:
        input_data: Numeric tensor of any rank >= 1. The last axis is
            filtered, and leading (batch) dimensions are handled row-wise.
        n: Number of entries to keep per row. Must not exceed the size of the
            last axis (``tf.nn.top_k`` raises otherwise).

    Returns:
        A tensor with the same shape and dtype as ``input_data``. It holds the
        original values at the ``n`` selected positions and zero elsewhere.
    """
    topk = tf.nn.top_k(input_data, k=n, sorted=False)

    # Size of the last axis: use the static size when known, else the dynamic
    # one. The static-only lookup failed when the last dimension was None.
    depth = input_data.get_shape().as_list()[-1]
    if depth is None:
        depth = tf.shape(input_data)[-1]

    # one_hot maps the [..., n] indices to [..., n, depth]. Reducing over the
    # n axis (axis=-2, so any input rank works, not just rank 2) gives a 0/1
    # map with exactly n ones per row.
    selected = tf.reduce_max(tf.one_hot(topk.indices, depth), axis=-2)
    mask = tf.cast(selected, tf.bool)

    # Prefer tf.where over multiplying by the float32 mask. It works for any
    # input dtype and gives exact zeros even where a dropped input is inf or NaN.
    return tf.where(mask, input_data, tf.zeros_like(input_data))

# Wrap the filter function as a Keras layer (it uses its default n=2).
model.add(tf.keras.layers.Lambda(top_n_filter_layer))

# One sample with a four-way tie at the maximum value 7.
x_train = [[1,2,3,4,5,6,7,7,7,7]]

# TF1 graph mode: initialize variables, then evaluate the model output.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(model.output.eval({model.inputs[0]:x_train}))

# Printed result: only n=2 values survive despite the tie. Among equal
# values tf.nn.top_k picks the lower indices (6 and 7 here):
# [[0. 0. 0. 0. 0. 0. 7. 7. 0. 0.]]

Tags: tonameinrefinputoutputdatareturn
1条回答
网友
1楼 · 发布于 2024-06-06 04:42:02

让我们一步一步来:

  1. 首先,取网络的 softmax 输出,找出其中最大的 k 个值及其索引。
  2. 为每个 top-k 索引创建一个独热(one-hot)编码向量,使其在该索引处为 1;然后把这 k 个向量求和,得到一个与原始输出形状相同、恰好含有 k 个 1 的二值张量。
  3. 得到这个在 top-k 位置为 1 的张量后,将其与网络原始的 softmax 输出逐元素相乘。

k=2 时的 TensorFlow 示例:

import tensorflow as tf
import numpy as np

# TF1 (graph-mode) demo of the one-hot top-k mask with k=2.
# A small Dense layer with softmax activation produces the tensor to filter;
# its last axis has size 5 (units=5).
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(
    units=5, input_shape=(2, ), activation=tf.nn.softmax,
    kernel_initializer=tf.initializers.random_normal))

softmaxed = model.output # the softmax output of the Dense layer
topk = tf.nn.top_k(softmaxed,    # top-k values and their indices along the last axis
                   k=2,
                   sorted=False)

res = tf.reduce_sum(                                 # one-hot encode each of the k
    tf.one_hot(topk.indices,                         # indices, then sum over the k
               softmaxed.get_shape().as_list()[-1]), # axis (axis=1 for this rank-2
    axis=1)                                          # output) -> 0/1 mask, k ones per row

res *= softmaxed # element-wise: keep the top-k values, zero everything else

x_train = [np.random.normal(size=(2, ))] # a single random input sample of shape (2,)

# Graph mode: variables must be initialized before the tensors are evaluated.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The values in the comments below are illustrative; real output depends on the random init.
    print(res.eval({model.inputs[0]:x_train})) # e.g. [[0.2 0.2 0.  0.  0. ]]
    print(softmaxed.eval({model.inputs[0]:x_train})) # e.g. [[0.2 0.2 0.2 0.2 0.2]]

相关问题 更多 >