在类中处理tensorflow会话

2024-04-26 21:29:15 发布

您现在位置:Python中文网/ 问答频道 /正文

我用张量流来预测神经网络的输出。我有一节课,我描述了神经网络,我有一个主文件,其中预测正在进行,并根据结果,权重被更新。然而,这些预测似乎非常缓慢。我的代码是这样的:

class NNPredictor():
    def __init__(self):
        self.input = tf.placeholder(...)
        ...
        self.output = (...) #Neural network output
    def predict_output(self, sess, input):
        return sess.run(tf.squeeze(self.output), feed_dict = {self.input: input})

以下是主文件的外观:

^{pr2}$

但是,如果我在类中使用以下函数定义:

    def predict_output(self):
        return self.output

主文件如下:

sess = tf.Session()
predictor = NNPredictor()

input = #some initial value 
output_op = predictor.predict_value()
for i in range(iter):
    output = np.squeeze(sess.run(output_op, feed_dict = {predictor.input: input}))
    input = #some function of output

代码运行速度快了20-30倍。我似乎不明白这里的情况如何,我想知道最好的做法是什么。在


Tags: 文件runselfinputoutputreturntfdef
1条回答
网友
1楼 · 发布于 2024-04-26 21:29:15

这与Python屏蔽的底层内存访问有关。下面是一些示例代码来说明这个想法:

import time

runs = 10000000

class A:
    def __init__(self):
    self.val = 1

    def get_val(self):
    return self.val

# Using method to then call object attribute
obj = A()
start = time.time()
total = 0
for i in xrange(runs):
    total += obj.get_val()
end = time.time()
print end - start

# Using object attribute directly
start = time.time()
total = 0
for i in xrange(runs):
    total += obj.val
end = time.time()
print end - start

# Assign to local_var first
start = time.time()
total = 0
local_var = obj.get_val()
for i in xrange(runs):
    total += local_var
end = time.time()
print end - start

在我的电脑上,它按以下时间运行:

^{pr2}$

具体到您的情况,您在第一种情况下调用object方法,但在第二种情况下不调用它。如果以这种方式多次调用代码,那么性能会有很大的差异。在

相关问题 更多 >