TensorFlow 2.1:带有自定义通道运算的图像分类

2024-06-02 07:10:28 发布

您现在位置:Python中文网/ 问答频道 /正文

我一直在尝试创建一组手工设计的通道,它们是计算图的一部分;我希望在输入图像/张量进入网络的其余部分之前,把这些通道堆叠(拼接)到输入图像/张量上。

# Functional-API entry point. `input_image` is defined outside this snippet
# (presumably an (H, W, 3) shape tuple — TODO confirm).
input_tensor = KL.Input(shape=input_image, name="input")
# Derive the extra hand-crafted channels from the raw input.
# NOTE(review): `handcrafted` is declared below with a leading `self`
# parameter but is called here with a single argument — confirm how it is
# bound in the real code.
handcrafted_channels = handcrafted(input_tensor)
# Stack the derived channels onto the original ones along the last axis.
x = KL.concatenate([input_tensor, handcrafted_channels], axis=-1)
# Zero-pad the spatial dimensions by 3 before the 7x7 convolution.
x = KL.ZeroPadding2D((3, 3))(x)
# NOTE(review): a 1x1 pool with stride 1 should leave the tensor unchanged —
# confirm this layer is intentional.
x = KL.MaxPooling2D(pool_size=(1, 1), strides=(1,1), padding="same")(x)
# First convolution of the backbone (64 filters, 7x7 kernel, stride 2).
# NOTE(review): `input_shape` is passed even though the layer is applied
# directly to `x`; it looks redundant here — verify.
x = KL.Conv2D(64, (7, 7), strides=(2, 2), name='conv1', use_bias=True, input_shape=x.shape, data_format="channels_last")(x)
... continue with normal resnet

def handcrafted(self, input_tensor):
    """Build two derived channels from the three planes of ``input_tensor``.

    The input is split into three equal parts along axis 3; the parts are
    treated as red, green and blue (assumes a channels-last RGB tensor —
    TODO confirm against the caller).

    Returns:
        The concatenation, along the last axis, of ``red + green`` and
        ``green - blue``.
    """
    red, green, blue = tf.split(input_tensor, num_or_size_splits=3, axis=3)
    # Placeholder arithmetic: any per-pixel formula could be used instead of
    # this deliberately simple pair.
    red_plus_green = KL.add([red, green])
    green_minus_blue = KL.subtract([green, blue])
    return KL.concatenate([red_plus_green, green_minus_blue], axis=-1)

当我使用稀疏分类交叉熵(sparse categorical cross-entropy)损失函数和 SGD 优化器(学习率=0.01,动量=0.9,clipnorm=5.0)训练该模型时,损失变成了 NaN。为了确认网络的其余部分没有问题,我去掉了手工通道(handcrafted channels)以及第一个拼接操作,此时训练可以正常进行。

Training run with handcrafted enabled

Training run without handcrafted enabled

以下是第一个 epoch 结束时输出的错误信息:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-6-77a19ea1ade6> in <module>
     48           reduce_lr,
---> 49           early_stopping,
     50         ])

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    817         max_queue_size=max_queue_size,
    818         workers=workers,
--> 819         use_multiprocessing=use_multiprocessing)
    820 
    821   def evaluate(self,

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    395                       total_epochs=1)
    396                   cbks.make_logs(model, epoch_logs, eval_result, ModeKeys.TEST,
--> 397                                  prefix='val_')
    398 
    399     return model.history

/usr/lib/python3.6/contextlib.py in __exit__(self, type, value, traceback)
     86         if type is None:
     87             try:
---> 88                 next(self.gen)
     89             except StopIteration:
     90                 return False

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2.py in on_epoch(self, epoch, mode)
    769       if mode == ModeKeys.TRAIN:
    770         # Epochs only apply to `fit`.
--> 771         self.callbacks.on_epoch_end(epoch, epoch_logs)
    772       self.progbar.on_epoch_end(epoch, epoch_logs)
    773 

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/callbacks.py in on_epoch_end(self, epoch, logs)
    300     logs = logs or {}
    301     for callback in self.callbacks:
--> 302       callback.on_epoch_end(epoch, logs)
    303 
    304   def on_train_batch_begin(self, batch, logs=None):

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/callbacks.py in on_epoch_end(self, epoch, logs)
   1711 
   1712     if self.histogram_freq and epoch % self.histogram_freq == 0:
-> 1713       self._log_weights(epoch)
   1714 
   1715     if self.embeddings_freq and epoch % self.embeddings_freq == 0:

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/callbacks.py in _log_weights(self, epoch)
   1802           with ops.init_scope():
   1803             weight = K.get_value(weight)
-> 1804           summary_ops_v2.histogram(weight_name, weight, step=epoch)
   1805           if self.write_images:
   1806             self._log_weight_as_image(weight, weight_name, epoch)

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/summary_ops_v2.py in histogram(name, tensor, family, step)
    821         name=scope)
    822 
--> 823   return summary_writer_function(name, tensor, function, family=family)
    824 
    825 

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/summary_ops_v2.py in summary_writer_function(name, tensor, function, family)
    750   with ops.device("cpu:0"):
    751     op = smart_cond.smart_cond(
--> 752         should_record_summaries(), record, _nothing, name="")
    753     if not context.executing_eagerly():
    754       ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op)  # pylint: disable=protected-access

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
     52   if pred_value is not None:
     53     if pred_value:
---> 54       return true_fn()
     55     else:
     56       return false_fn()

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/summary_ops_v2.py in record()
    743     with ops.name_scope(name_scope), summary_op_util.summary_scope(
    744         name, family, values=[tensor]) as (tag, scope):
--> 745       with ops.control_dependencies([function(tag, scope)]):
    746         return constant_op.constant(True)
    747 

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/summary_ops_v2.py in function(tag, scope)
    819         tag,
    820         array_ops.identity(tensor),
--> 821         name=scope)
    822 
    823   return summary_writer_function(name, tensor, function, family=family)

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/gen_summary_ops.py in write_histogram_summary(writer, step, tag, values, name)
    467       try:
    468         return write_histogram_summary_eager_fallback(
--> 469             writer, step, tag, values, name=name, ctx=_ctx)
    470       except _core._SymbolicException:
    471         pass  # Add nodes to the TensorFlow graph.

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/gen_summary_ops.py in write_histogram_summary_eager_fallback(writer, step, tag, values, name, ctx)
    488   _attrs = ("T", _attr_T)
    489   _result = _execute.execute(b"WriteHistogramSummary", 0, inputs=_inputs_flat,
--> 490                              attrs=_attrs, ctx=ctx, name=name)
    491   _result = None
    492   return _result

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     65     else:
     66       message = e.message
---> 67     six.raise_from(core._status_to_exception(e.code, message), None)
     68   except TypeError as e:
     69     keras_symbolic_tensors = [

/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)

如果有人能就从哪里着手排查这个问题提供一些思路,我将不胜感激。


1条回答
网友
1楼 · 发布于 2024-06-02 07:10:28

我建议将handcrafted抽象为一个单独的模型:

def __init__(self, ...):
    self.handcrafted = create_handcrafted()
    ....

def create_handcrafted(self):
    """Build a small Keras model that computes the hand-crafted channels.

    The model input is declared with per-sample shape ``(None, None, 3)``.
    Its three last-axis planes are treated as red, green and blue.

    Returns:
        A ``Model`` whose single output concatenates ``red + green`` and
        ``green - blue`` along the last axis.
    """
    inp = Input((None, None, 3))
    # Slice with ranges (0:1, not 0) so every plane keeps its trailing channel
    # axis. Integer indexing would drop that axis, and the concatenate below
    # would then join along the width axis instead of the channel axis,
    # producing a rank-3 tensor that cannot be stacked onto the rank-4 input.
    _red = inp[:, :, :, 0:1]
    _green = inp[:, :, :, 1:2]
    _blue = inp[:, :, :, 2:3]
    # Placeholder arithmetic; any per-pixel formula could be substituted.
    handcrafted_channel_a = _red + _green
    handcrafted_channel_b = _green - _blue
    handcrafted_channels = KL.concatenate([handcrafted_channel_a, handcrafted_channel_b], axis=-1)
    return Model(inputs=[inp], outputs=[handcrafted_channels])

....

然后将其作为主模型的一部分:

# Main-model input; `input_image` is defined outside this snippet.
input_tensor = KL.Input(shape=input_image, name="input")
# Call the hand-crafted sub-model kept on `self.handcrafted` like a layer, so
# it becomes part of the main model's graph.
handcrafted_channels = self.handcrafted(input_tensor)
# Proceed as before

相关问题 更多 >