我正在编写一个程序来构造一个GAN,其中generator类接受一个tf模块,需要进行所有初始化,但是当初始化并最终调用这个类时,我面临一个传递额外参数的错误(我在下面发布的完整错误)
class Generator(tf.Module):
def __init__(self, noise_size, condition_size, generator_latent_size, cell_type, mean=0, std=1):
super().__init__()
self.noise_size = noise_size
self.condition_size = condition_size
self.generator_latent_size = generator_latent_size
self.mean = mean
self.std = std
if cell_type == "lstm":
self.cond_to_latent = tf.keras.layers.LSTM(generator_latent_size)
else:
self.cond_to_latent = tf.keras.layers.GRU(generator_latent_size)
self.model = tf.keras.Sequential(
tf.keras.layers.InputLayer(input_shape=((generator_latent_size + self.noise_size),)),
tf.keras.layers.Dense(generator_latent_size + self.noise_size),
tf.keras.layers.ReLU(),
tf.keras.layers.Dense(1)
)
def forward(self, noise, condition):
condition = (condition - self.mean) / self.std
condition = condition.view(-1, self.condition_size, 1)
condition = condition.transpose(0, 1)
condition_latent, _ = self.cond_to_latent(condition)
condition_latent = condition_latent[-1]
g_input = tf.concat((condition_latent, noise), dim=1)
output = self.model(g_input)
output = output * self.std + self.mean
return output
def get_noise_size(self):
return self.noise_size
当调用is生成器对象时,我在内部方法包装器中得到一个错误
" Traceback (most recent call last):
File "forgan.py", line 185, in <module>
forgan = ForGAN(opt)
File "forgan.py", line 35, in __init__
std=opt.data_std)
File "C:\Users\mura_ab\PycharmProjects\ForGAN\components.py", line 23, in __init__
tf.keras.layers.Dense(1)
File "C:\Users\mura_ab\Anaconda3\envs\Plygrnd\lib\site-packages\tensorflow_core\python\training\tracking\base.py", line 457, in _method_wrapper
result = method(self, *args, **kwargs)
TypeError: __init__() takes from 1 to 3 positional arguments but 5 were given"
这是正在初始化的生成器对象:
class ForGAN:
def __init__(self, opt):
self.opt = opt
self.device = tf.device("cuda:0") if tf.test.is_gpu_available() else tf.device("cpu")
print("***** Hyper-parameters *****")
for k, v in vars(opt).items():
print("{}:\t{}".format(k, v))
print("************************")
# Making required directories for logging, plots and models' checkpoints
os.makedirs("./{}/".format(self.opt.dataset), exist_ok=True)
# Defining GAN components
self.generator = Generator(noise_size=opt.noise_size,
condition_size=opt.condition_size,
generator_latent_size=opt.generator_latent_size,
cell_type=opt.cell_type,
mean=opt.data_mean,
std=opt.data_std)
终于有人打电话来了
forgan = ForGAN(opt)
有人能告诉我是否有解决这个问题的方法吗
一般来说,您应该密切注意错误消息的内容。我不会马上回答解决你的错误,我会告诉你我是如何发现的,希望你能在下一个错误中自己解决
逐一读取错误的所有信息:
错误中的文本清楚地表明,您使用了5个参数来初始化对象。看到步骤3,很明显您没有使用5个参数初始化密集层。但是,您似乎正在用4个参数初始化序列模型。。。如果您添加Python对象初始化中始终存在的隐藏
self
参数,那么它总共会添加5个参数!在初始化顺序模型时,可能您做错了什么为了证实这一点,您应该看看官方API或其他一些官方指南。通过谷歌搜索,你可以很容易地找到this API和this guide。在API中,您可以看到初始化不需要额外的参数和2个额外的参数(1到3个,包括
self
)。这正是你的错误所说的!在指南中,您可以看到如何正确使用它的示例。看起来应该在容器(列表或元组)中传递层因此,这应该可以解决这个问题(注意额外的方括号将所有5个参数转换为一个列表):
相关问题 更多 >
编程相关推荐