我有几个类,每个类都有自己的神经网络结构。你知道吗
基于用户输入的标志rx_flag
,我尝试在驱动程序文件中检索特定的体系结构。你知道吗
我有两个问题:
def build_model(rx_flag):
switcher = {
'xss': XSS().get_model(),
'rss': RSS().get_model()
}
return switcher.get(rx_flag)
main()
方法时,我得到了NoneType
。这是我的课。其他类也有类似的模板。我已经注释掉了__hash__()
和__eq__()
的实现,因为没有它,在字典中存储似乎也可以正常工作。你知道吗
from model import Model
from keras.layers import Dense
from keras.models import Sequential
class XSS(Model):
def __init__(self):
self.num_layers = 2
self.input_dim = 3
self.output_dim = 1
self.architecture = [64, 32]
self.model = Sequential()
def get_model( self , arch=[64, 32]):
# add input layer
self.model.add(Dense(arch[0], activation='relu', input_shape=(self.input_dim, )))
# add intermediate layers
for i in range(1, self.num_layers):
self.model.add(Dense(arch[i], activation='relu'))
# add output layer
self.model.add(Dense(self.output_dim, activation='linear'))
return self.model
def get_name( self ):
return 'xss'
def get_value( self ):
return self.__value()
def __value( self ):
return (self.model, self.num_layers, self.input_dim, self.output_dim, self.architecture)
# def __hash__(self):
# return (self.hash(self.__value()))
#
# def __eq__(self, other):
# if isinstance(other, XSS):
# return self.__value() == other.__value()
# return NotImplemented
这是驱动程序代码:
import sys
from model import Model
from xss import XSS
def build_model(rx_flag):
switcher = {}
obX = XSS()
switcher[obX.get_name()] = obX.get_model()
obR = RSS()
switcher[obR.get_name()] = obR.get_model()
print(switcher)
return switcher.get(rx_flag)
if __name__ == '__main__':
rx_flag = sys.argv[0]
# create a model instance based on flag
model = build_model(rx_flag)
model.summary()
这就是我在尝试model.summary()
时遇到的错误。你知道吗
Traceback (most recent call last):
File "C:/Users/path/driver.py", line 19, in <module>
model.summary()
AttributeError: 'NoneType' object has no attribute 'summary'
如何以一种更为python的方式构建字典,并让它返回实际的模型?你知道吗
@juanpa.arrivillaga建议rxèu标志不是我认为的那样。他们是对的。你知道吗
当我初始化
rx_flag
时,即使在PyCharm上,代码也可以正常工作,如下所示:当我在PyCharm的run配置中输入它时,我的印象是它是第一个参数。你知道吗
相关问题 更多 >
编程相关推荐