张量流,在多处理中更新权重

2024-03-29 06:00:51 发布

您现在位置:Python中文网/ 问答频道 /正文

我定义了一个网络,每个范围包含每个进程的权重,每个进程分配其相应的权重,下面是我的演示代码

from multiprocessing import Process

import tensorflow as tf


def init_network(name):
    with tf.name_scope(name):
        x = tf.Variable(int(name))
        return x


def f(name, sess):
    print('step into f()')
    vars = tf.trainable_variables(name)
    print(sess.run(vars[0]))
    sess.run(vars[0].assign(int(name)+10))


if __name__ == '__main__':
    sess = tf.Session()
    x1 = init_network('1')
    x2 = init_network('2')
    sess.run(tf.global_variables_initializer())
    p1 = Process(target=f, args=('1', sess))
    p2 = Process(target=f, args=('2', sess))

    p1.start()
    p2.start()

    p1.join()
    p2.join()
    print(sess.run([x1, x2]))

演示代码卡住了,sess似乎不能在不同的进程中共享,如何在多处理设置中更新权重?在


Tags: run代码nameimport进程inittfnetwork
1条回答
网友
1楼 · 发布于 2024-03-29 06:00:51

在google上搜索了一段时间后,我发现multiprocessing不适用于TensorFlow,因此,我改为使用threading。在

from threading import Thread

import tensorflow as tf

def init_network(name):
    with tf.name_scope(name):
        x = tf.Variable(int(name))
        return x

def f(name, sess):
    with sess.as_default(), sess.graph.as_default():
        print('step into f()')
        vars = tf.trainable_variables(name)
        print(vars)
        sess.run(vars[0].assign(int(name)+10))
        print(sess.run(vars[0]))


if __name__ == '__main__':
    sess = tf.Session()
    coord = tf.train.Coordinator()

    x1 = init_network('1')
    x2 = init_network('2')
    sess.run(tf.global_variables_initializer())
    print(sess.run([x1, x2]))

    p1 = Thread(target=f, args=('1', sess))
    p2 = Thread(target=f, args=('2', sess))
    p1.start()
    p2.start()
    coord.join([p1, p2])
    print(sess.run([x1, x2]))

它现在起作用了,默认会话是当前线程的属性。{3>如果您希望在会话中显式地添加一个线程,则必须在该线程中显式地添加一个线程。并且必须显式地输入一个with sess.graph.as_default():块,使sess.graph成为默认图形。在

tf.train.Coordinator连接线程非常方便。也可以使用thread.join()方法连接线程。在

相关问题 更多 >