动态编辑Tensorflow对象检测的管道配置

2024-04-28 06:53:51 发布

您现在位置:Python中文网/ 问答频道 /正文

我使用的是tensorflow对象检测API,我希望能够在python中动态编辑配置文件,如下所示。我曾想过在python中使用protocol buffers库,但我不知道该怎么做。在

model {
ssd {
num_classes: 1
image_resizer {
  fixed_shape_resizer {
    height: 300
    width: 300
  }
}
feature_extractor {
  type: "ssd_inception_v2"
  depth_multiplier: 1.0
  min_depth: 16
  conv_hyperparams {
    regularizer {
      l2_regularizer {
        weight: 3.99999989895e-05
      }
    }
    initializer {
      truncated_normal_initializer {
        mean: 0.0
        stddev: 0.0299999993294
      }
    }
    activation: RELU_6
    batch_norm {
      decay: 0.999700009823
      center: true
      scale: true
      epsilon: 0.0010000000475
      train: true
    }
  }
 ...
 ...

}

有没有一种简单/简单的方法可以将“图像大小调整器”中的“高度”等字段的特定值从300更改为500?用修改过的值写回文件而不改变其他任何东西?在

编辑: 虽然@DmytroPrylipko提供的答案适用于配置中的大多数参数,但我在“复合字段”方面遇到了一些问题。。在

也就是说,如果我们有如下配置:

^{pr2}$

我添加这一行来编辑输入路径:

 pipeline_config.train_input_reader.tf_record_input_reader.input_path = "/tensorflow/models/data/train100.record"

它抛出错误:

TypeError: Can't set composite field

Tags: 对象apitrue编辑inputtensorflow动态train
2条回答

是的,使用Protobuf Python API非常简单:

编辑_管道.py

import argparse

import tensorflow as tf
from google.protobuf import text_format
from object_detection.protos import pipeline_pb2


def parse_arguments():                                                                                                                                                                                                                                                
    parser = argparse.ArgumentParser(description='')                                                                                                                                                                                                                  
    parser.add_argument('pipeline')                                                                                                                                                                                                                                   
    parser.add_argument('output')                                                                                                                                                                                                                                     
    return parser.parse_args()                                                                                                                                                                                                                                        


def main():                                                                                                                                                                                                                                                           
    args = parse_arguments()                                                                                                                                                                                                                                          
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()                                                                                                                                                                                                          

    with tf.gfile.GFile(args.pipeline, "r") as f:                                                                                                                                                                                                                     
        proto_str = f.read()                                                                                                                                                                                                                                          
        text_format.Merge(proto_str, pipeline_config)                                                                                                                                                                                                                 

    pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.height = 300                                                                                                                                                                                          
    pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.width = 300                                                                                                                                                                                           

    config_text = text_format.MessageToString(pipeline_config)                                                                                                                                                                                                        
    with tf.gfile.Open(args.output, "wb") as f:                                                                                                                                                                                                                       
        f.write(config_text)                                                                                                                                                                                                                                          


if __name__ == '__main__':                                                                                                                                                                                                                                            
    main() 

我对剧本的称呼是:

^{pr2}$

复合字段

对于重复字段,必须将其视为数组(例如使用extend()append()方法):

pipeline_config.train_input_reader.tf_record_input_reader.input_path[0] = '/tensorflow/models/data/train100.record'

评估输入读取器错误

这是试图编辑复合字段的常见错误。(“找不到属性tf_record_input_reader”在eval_input_reader的情况下)

下面是@latida的答案中提到的。 通过将其设置为数组字段来解决这个问题。在

pipeline_config.eval_input_reader[0].label_map_path  = label_map_full_path
pipeline_config.eval_input_reader[0].tf_record_input_reader.input_path[0] = val_record_path
pipeline_config.eval_input_reader[0].label_map_path  = label_map_full_path
pipeline_config.eval_input_reader[0].tf_record_input_reader.input_path[0] = val_record_path

相关问题 更多 >