如何利用量化训练完成神经网络的4位量化

2024-06-16 09:32:12 发布

您现在位置:Python中文网/ 问答频道 /正文

我想用量化感知训练（quantization-aware training）把 LeNet 量化到 4 位。但是量化感知训练默认只对神经网络进行 8 位量化，我不知道该如何修改代码。

我尝试修改 tensorflow/contrib/quantize/python 目录下 quantize_graph.py 和 quantize.py 中的代码，但没有成功。

def Quantize(graph,
             is_training,
             weight_bits=8,
             activation_bits=8,
             ema_decay=0.999,
             quant_delay=None,
             vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
             scope=None):
  """Inserts quantization ops into `graph` for every matched layer.

  For each layer returned by `_FindLayersToQuantize`, quant ops are inserted
  at up to five places:
  * 'weights_quant': between the weight tensor's op and the layer op.
  * 'act_quant': between the activation op and its consumers.
  * 'conv_quant': between the bias add and the bypass op (bypass only).
  * 'add_quant': after the bypass op, unless an activation consumes it.
  * 'post_activation_bypass_quant': after a post-activation bypass op,
    unless an activation consumes it.

  `weight_bits` is forwarded as `bits` to the weights quant op and
  `activation_bits` as `bits` to all the others, so a width other than 8
  (e.g. 4) is selected by the caller through these two arguments.

  Args:
    graph: Graph to modify.
    is_training: Whether quantizing training graph or eval graph.
    weight_bits: Number of bits to use for quantizing weights.
    activation_bits: Number of bits to use for quantizing activations.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    scope: The scope to be transformed. If it's not None, only the ops which
      are in this scope will be transformed. A trailing '/' is appended when
      missing.

  Raises:
    ValueError: When quantization fails. Nothing in this function raises it
      directly; presumably it comes from the helpers it calls.
  """

  def _ParentScope(name):
    # Group 1 of the match (the text preceding the final '/<component>'), or
    # '' when `name` does not match the pattern at all.
    parent_match = re.search(r'^(.*)/([^/]+)', name)
    return parent_match.group(1) if parent_match is not None else ''

  def _FeedsActivation(consumer_list):
    # True when at least one consumer op is of an activation type.
    return any(op.type in _ACTIVATION_TYPES for op in consumer_list)

  if scope and not scope.endswith('/'):
    scope += '/'

  # Keyword arguments that every inserted quant op receives unchanged.
  common_kwargs = dict(
      ema_decay=ema_decay,
      quant_delay=quant_delay,
      vars_collection=vars_collection)

  # Maps a tensor's name to the ops that take it as input.
  consumer_lookup = input_to_ops.InputToOps(graph)

  # _FindLayersToQuantize yields the layers in `graph` that should be
  # quantized.
  for match in _FindLayersToQuantize(graph):
    # Root context name derived from the layer op's name.
    layer_context = _GetContextFromOp(match.layer_op)

    # --- Weights ---
    # The quant op goes between the weight producer and the layer op. If
    # `scope` is given, only quantize when the consumer of the weights (the
    # layer op) is in the right scope.
    _InsertQuantOp(
        layer_context,
        'weights_quant',
        match.weight_tensor.op,
        [match.layer_op],
        is_training,
        moving_avg=False,
        narrow_range=True,
        bits=weight_bits,
        consumer_scope=scope,
        **common_kwargs)

    # --- Activations ---
    # Ops that take the activation op's outputs as input.
    activation_consumers = consumer_lookup.ConsumerOperations(
        match.activation_op)
    # With a bypass present the activation quant op is placed one scope level
    # above the layer context.
    activation_context = layer_context
    if match.bypass_op:
      activation_context = _ParentScope(layer_context)
    # If `scope` is given, only quantize when the producer is in the right
    # scope.
    _InsertQuantOp(
        activation_context,
        'act_quant',
        match.activation_op,
        activation_consumers,
        is_training,
        moving_avg=True,
        bits=activation_bits,
        init_min=0.0,
        producer_scope=scope,
        **common_kwargs)

    # --- Bypass (if it exists) ---
    # The input to the bypass is the bias add, and the output is the
    # activation. The first insertion below quantizes the convolution
    # (bias add) result feeding the bypass.
    if match.bypass_op is not None:
      # If `scope` is given, only quantize when both the producer and the
      # consumer are in the right scope.
      _InsertQuantOp(
          layer_context,
          'conv_quant',
          match.bias_add_op,
          [match.bypass_op],
          is_training,
          moving_avg=True,
          bits=activation_bits,
          producer_scope=scope,
          consumer_scope=scope,
          **common_kwargs)
      # An activation directly after the bypass will be fused into the Add at
      # inference time, so the bypass output must not be quantized then.
      bypass_consumers = consumer_lookup.ConsumerOperations(match.bypass_op)
      if _FeedsActivation(bypass_consumers):
        logging.info('Skipping %s, because its followed by an activation.',
                     match.bypass_op.name)
      else:
        # Consumers are looked up afresh for the insertion call.
        _InsertQuantOp(
            activation_context,
            'add_quant',
            match.bypass_op,
            consumer_lookup.ConsumerOperations(match.bypass_op),
            is_training,
            moving_avg=True,
            bits=activation_bits,
            producer_scope=scope,
            consumer_scope=scope,
            **common_kwargs)

    # --- Bypass ops that occur after the activation ---
    if match.post_activation_bypass_op is not None:
      post_bypass_context = _ParentScope(match.post_activation_bypass_op.name)
      # Same fusion rule as above: skip when an activation follows. If `scope`
      # is given, only quantize when the producer is in the right scope.
      post_bypass_consumers = consumer_lookup.ConsumerOperations(
          match.post_activation_bypass_op)
      if _FeedsActivation(post_bypass_consumers):
        logging.info('Skipping %s, because its followed by an activation.',
                     match.post_activation_bypass_op.name)
      else:
        _InsertQuantOp(
            post_bypass_context,
            'post_activation_bypass_quant',
            match.post_activation_bypass_op,
            post_bypass_consumers,
            is_training,
            moving_avg=True,
            bits=activation_bits,
            producer_scope=scope,
            **common_kwargs)

Tags: the to layer if is match context activation