如何使用node\ u def在Tensorflow中复制图形操作?

2024-04-27 03:06:00 发布

您现在位置:Python中文网/ 问答频道 /正文

假设我有一个操作my_op,定义如下:

name: "Const"
op: "Const"
attr {
  key: "dtype"
  value {
    type: DT_INT32
  }
}
attr {
  key: "value"
  value {
    tensor {
      dtype: DT_INT32
      tensor_shape {
        dim {
          size: 2
        }
      }
      tensor_content: "\001\000\000\000\001\000\000\000"
    }
  }
}

我可以访问图形中的操作,但不能访问构造代码。我想复制操作并更改其某些属性:

name: "PrettyConst"
op: "Const"
attr {
  key: "dtype"
  value {
    type: DT_INT32
  }
}
attr {
  key: "value"
  value {
    tensor {
      dtype: DT_FLOAT32
      tensor_shape {
        dim {
          size: 8
        }
      }
      tensor_content: "\001\000\000\000\001\000\000\000\001\000\000\000\001\000\000\000\001\000\000\000\001\000\000\000\001\000\000\000\001\000\000\000"
    }
  }
}

虽然我可以通过将图形另存为txt、修改文件内容并将其还原回来来轻松地完成,但在python中找不到一种简单的方法。我认为应该有可能做到以下几点:

op_def_copy = op.node_def.copy()
op_def_copy.name = "PrettyConst"
op_def_copy.attr["dtype"] = 0
# and also do something with the content, whatever
graph.append(tf.Operation(op_def_copy))

不过,tf.contrig.graph_editor似乎能够做到这一点。你知道吗


Tags: keynamevaluedeftypedtcontentattr
2条回答

@jdehesa很好地回答了这个问题。我有更多的工具:

import tensorflow
import copy
import tensorflow.contrib.graph_editor as ge
from copy import deepcopy

a = tf.constant(1)
b = tf.constant(2)
c = a+b

def modify(t): 
    # illustrate operation copy&modification
    new_t = deepcopy(t.op.node_def)
    new_t.name = new_t.name+"_but_awesome"
    new_t = tf.Operation(new_t, tf.get_default_graph())
    # we got a tensor, let's return a tensor
    return new_t.outputs[0]

def update_existing(target, updated):
    # illustrate how to use new op
    related_ops = ge.get_backward_walk_ops(target, stop_at_ts=updated.keys(), inclusive=True)
    new_ops, mapping = ge.copy_with_input_replacements(related_ops, updated)
    new_op = mapping._transformed_ops[target.op]
    return new_op.outputs[0]

new_a = modify(a)
new_b = modify(b)
injection = new_a+39 # illustrate how to add another op to the graph
new_c = update_existing(c, {a:injection, b:new_b})

with tf.Session():
    print(c.eval()) # -> 3
    print(new_c.eval()) # -> 42

必须从文本表示中解析^{}消息,然后才能从中构建^{}。你可以这样做:

import tensorflow as tf
import google.protobuf

node_def_message = r"""name: "Const"
op: "Const"
attr {
  key: "dtype"
  value {
    type: DT_INT32
  }
}
attr {
  key: "value"
  value {
    tensor {
      dtype: DT_INT32
      tensor_shape {
        dim {
          size: 2
        }
      }
      tensor_content: "\001\000\000\000\001\000\000\000"
    }
  }
}"""

# Build NodeDef message
node_def = tf.NodeDef()
# Parse from the string
google.protobuf.text_format.Parse(node_def_message, node_def)
# Build the operation
op = tf.Operation(node_def, tf.get_default_graph())
# Get the output from the operation
c = op.outputs[0]
# Check value
with tf.Session() as sess:
    print(sess.run(c))
    # [1 1]

注意,如果您正在构建的操作需要一些输入,那么您可能需要向^{}传递额外的参数。你知道吗

相关问题 更多 >