Understanding nested functions in Python

Posted 2024-05-28 23:08:33


I have the following nested-function code:

def function1(req):
    def inner_func(username):
        if username == 'Admin':
            return "'{0}' can access to {1}.".format(username, req)
        else:
            return "'{0}' cannot access to {1}.".format(username, req)
    return inner_func

current_user = function1('Admin Area')
print(current_user('Admin'))

random_user = function1('Admin Area')
print(random_user('Not Admin'))

The output is:

'Admin' can access to Admin Area.

'Not Admin' cannot access to Admin Area.
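One way to see why this works: the returned inner_func keeps a reference to req in its closure. A minimal sketch of inspecting that, using function1 from above:

current_user = function1('Admin Area')

# The inner function carries a closure cell holding the captured req:
print(current_user.__closure__[0].cell_contents)  # Admin Area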

I understand that example. But I also have a snippet of code from a pre-trained model called BERT:

def model_fn_builder(bert_config, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps, use_tpu,
                     use_one_hot_embeddings):
  """Returns `model_fn` closure for TPUEstimator."""

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
      tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    unique_ids = features["unique_ids"]
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    (start_logits, end_logits) = create_model(
        bert_config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)

    tvars = tf.trainable_variables()

    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
      (assignment_map, initialized_variable_names
      ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
      if use_tpu:

        def tpu_scaffold():
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
          return tf.train.Scaffold()

        scaffold_fn = tpu_scaffold
      else:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      seq_length = modeling.get_shape_list(input_ids)[1]

      def compute_loss(logits, positions):
        one_hot_positions = tf.one_hot(
            positions, depth=seq_length, dtype=tf.float32)
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        loss = -tf.reduce_mean(
            tf.reduce_sum(one_hot_positions * log_probs, axis=-1))
        return loss

      start_positions = features["start_positions"]
      end_positions = features["end_positions"]

      start_loss = compute_loss(start_logits, start_positions)
      end_loss = compute_loss(end_logits, end_positions)

      total_loss = (start_loss + end_loss) / 2.0

      train_op = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)

      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op,
          scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.PREDICT:
      predictions = {
          "unique_ids": unique_ids,
          "start_logits": start_logits,
          "end_logits": end_logits,
      }
      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
    else:
      raise ValueError(
          "Only TRAIN and PREDICT modes are supported: %s" % (mode))

    return output_spec

  return model_fn

The algorithm then calls model_fn_builder like this:

model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu)

What I don't understand is how the arguments features, labels, mode and params get passed to model_fn.

Can someone help me understand this?


2 Answers

You pass those arguments when model_fn itself is called.

Using your own example:

current_user = function1('Admin Area')

is equivalent to:

model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu)

Then, just as you supply the username here:

print(current_user('Admin'))

you pass features, labels, mode, params when calling model_fn.
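In the BERT code you never call model_fn yourself, though: the builder's return value is handed to a TPUEstimator, and the Estimator framework supplies features, labels, mode and params each time it invokes model_fn internally. A minimal sketch of that wiring, based on run_squad.py (run_config and train_input_fn are built elsewhere in that script and assumed here):

estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=FLAGS.use_tpu,
    model_fn=model_fn,  # the closure returned by model_fn_builder
    config=run_config,
    train_batch_size=FLAGS.train_batch_size,
    predict_batch_size=FLAGS.predict_batch_size)

# The framework, not your code, calls
# model_fn(features, labels, mode, params) during train/predict:
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)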

As @MrFuppets said, model_fn is a function, and you would normally call it with the necessary arguments.

Here is another example, similar to the OP's:

from types import FunctionType

def function_builder(builder_param: str) -> FunctionType:
    def inner_function(inner_function_param: int) -> int:
        print(f"I'm inner_function, called with {inner_function_param}")
        print(f"I'm built from function_builder, built with {builder_param}")

        return inner_function_param

    return inner_function

# Call function builder
# build a function
built_function = function_builder('Hello world!')
print(f'Type of built_function is {type(built_function)}')
# Type of built_function is <class 'function'>
print('')

# Call built_function as a normal function, passing args as usually
result = built_function(42)
print(f'Result of calling built_function is {result}')
# I'm inner_function, called with 42
# I'm built from function_builder, built with Hello world!
# Result of calling built_function is 42

print('')

# Calling it again
result = built_function(27)
print(f'Result of calling built_function is {result}')
# I'm inner_function, called with 27
# I'm built from function_builder, built with Hello world!
# Result of calling built_function is 27
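
Mapping this back to the BERT snippet (a rough analogy; in practice it is the Estimator framework, not your code, that makes the second call):

model_fn = model_fn_builder(bert_config=bert_config, ...)  # like function_builder('Hello world!')
output_spec = model_fn(features, labels, mode, params)     # like built_function(42)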
