tensorflowkeras挂在predi上

2024-03-28 22:59:34 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在运行一个脚本,它使多个模型并行工作,但是在使用predict时所有模型都挂起(如print语句所示)。 它确实可以处理线程而不是多处理,但是我不能使用它,因为这是一个CPU限制的问题,而且我没有得到加速。如果我用tensorflow 1.14而不是2.0运行,也会出现多个值错误。在

from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import numpy as np
import multiprocessing
from time import time
import threading
import tensorflow as tf
import gym

counter = 0
TOTAL = 10_000
env = gym.make("CartPole-v0")
INPUT_SHAPE = [*env.observation_space.shape]
ACTION_SPACE_SIZE = env.action_space.n
ACTION_SPACE = np.array(np.identity(ACTION_SPACE_SIZE, dtype=int).tolist())
LEARNING_RATE =  0.0005

class Worker():
    def __init__(self, number, environment, input_ = INPUT_SHAPE, action_size_=ACTION_SPACE_SIZE, learning_rate=LEARNING_RATE):
        self.env = environment
        self.lr=learning_rate
        self.input_shape = input_
        self.action_size = action_size_
        self.graph = tf.compat.v1.get_default_graph()
        self.session = tf.compat.v1.Session()
        self.network = self.create_net()
        self.name = 'no_' +str(number)

    def create_net(self, dense_params=[256]):
        with self.graph.as_default():
            with self.session.as_default():
                inputs_ = Input(shape=self.input_shape)
                out = inputs_

                for units in dense_params:
                    out = Dense(units=units, activation='relu')(out)    #logits

                policy_ = Dense(units=self.action_size, activation='softmax')(out)

                model = Model(inputs=inputs_, outputs=policy_)
                opt = Adam(lr=self.lr)

                model.compile(optimizer=opt, loss='categorical_crossentropy')
        return model

    def get_action(self, state):
       with self.graph.as_default():
           with self.session.as_default():
               state = np.expand_dims(state, axis=0)
               print("RIGHT HERE")
               action_prob_dist = self.network.predict(state)[0]
               action_index = np.random.choice(self.action_size,p=action_prob_dist)
               action_vector = ACTION_SPACE[action_index]
       return action_index, action_vector

    def work(self, counter):
        state = self.env.reset()
        step = 0
        print("Into :", self.name)
        action_vector = []

        while counter.value<TOTAL:
            if step % 1000 == 0:
                print(f"work from {self.name}")
                print("Step no: " + str(counter.value))
                print(action_vector)

            action_index, action_vector = self.get_action(state)
            counter.value +=1
            step += 1

num_workers = 2
jobs = []
envs = [gym.make('CartPole-v0') for i in range(num_workers)]
counter = multiprocessing.Value('i',0)

workers = [Worker(number=i, environment=envs[i]) for i in range(num_workers)]
for worker in workers:
    work = lambda :worker.work(counter)
    job = multiprocessing.Process(target=work)     # uncomment for multiprocess execution
    #job = threading.Thread(target=work)   # uncomment for thread execution
    jobs.append(job)
    job.start()

try:
    [t.join() for t in jobs]
except KeyboardInterrupt:
    print("Exiting all threads...")

我希望它经过所有的迭代(print(“RIGHT HERE”注释掉)它将打印:

^{pr2}$

我得到的是:

Into : no_0
work from no_0
Step no: 0
[]
RIGHT HERE
Into : no_1
work from no_1
Step no: 0
[]
RIGHT HERE

它就挂在那里。在


Tags: nofromimportselfenvfortensorflowas