如何从字典值返回对象?

2024-05-23 23:03:05 发布

您现在位置:Python中文网/ 问答频道 /正文

我有几个类,每个类都有自己的神经网络结构。你知道吗

基于用户输入的标志rx_flag,我尝试在驱动程序文件中检索特定的体系结构。你知道吗

我有两个问题:

  1. 我不能用通常的方法编词典。以下格式无效:
    def build_model(rx_flag):
        switcher = {
            'xss': XSS().get_model(),
            'rss': RSS().get_model()
        }
        return switcher.get(rx_flag)
    
  2. 经过一些研究,我终于能够至少构造一个字典,其中这些类中的模型作为值存储,但是当我将它们返回到main()方法时,我得到了NoneType

这是我的课。其他类也有类似的模板。我已经注释掉了__hash__()__eq__()的实现,因为没有它,在字典中存储似乎也可以正常工作。你知道吗

from model import Model
from keras.layers import Dense
from keras.models import Sequential

class XSS(Model):

    def __init__(self):
        self.num_layers = 2
        self.input_dim = 3
        self.output_dim = 1
        self.architecture = [64, 32]
        self.model = Sequential()

    def get_model( self , arch=[64, 32]):
        # add input layer
        self.model.add(Dense(arch[0], activation='relu', input_shape=(self.input_dim, )))

        # add intermediate layers
        for i in range(1, self.num_layers):
            self.model.add(Dense(arch[i], activation='relu'))

        # add output layer
        self.model.add(Dense(self.output_dim, activation='linear'))
        return self.model

    def get_name( self ):
        return 'xss'

    def get_value( self ):
        return self.__value()

    def __value( self ):
        return (self.model, self.num_layers, self.input_dim, self.output_dim, self.architecture)

    # def __hash__(self):
    #     return (self.hash(self.__value()))
    #
    # def __eq__(self, other):
    #     if isinstance(other, XSS):
    #         return self.__value() == other.__value()
    #     return NotImplemented

这是驱动程序代码:

import sys

from model import Model
from xss import XSS

def build_model(rx_flag):
    switcher = {}
    obX = XSS()
    switcher[obX.get_name()] = obX.get_model()
    obR = RSS()
    switcher[obR.get_name()] = obR.get_model()
    print(switcher)
    return switcher.get(rx_flag)

if __name__ == '__main__':
    rx_flag = sys.argv[0]
    # create a model instance based on flag
    model = build_model(rx_flag)
    model.summary()

这就是我在尝试model.summary()时遇到的错误。你知道吗

Traceback (most recent call last):
File "C:/Users/path/driver.py", line 19, in <module>
    model.summary()
AttributeError: 'NoneType' object has no attribute 'summary'

如何以一种更为python的方式构建字典,并让它返回实际的模型?你知道吗


Tags: fromimportselfaddgetmodelreturnvalue
1条回答
网友
1楼 · 发布于 2024-05-23 23:03:05

@juanpa.arrivillaga建议rxèu标志不是我认为的那样。他们是对的。你知道吗

当我初始化rx_flag时,即使在PyCharm上,代码也可以正常工作,如下所示:

rx_flag = sys.argv[1]

当我在PyCharm的run配置中输入它时,我的印象是它是第一个参数。你知道吗

相关问题 更多 >