我以前用过发电机,但现在用起来有些困难,我不知道我做错了什么
这是我的发电机
def batch_gen():
# each sample has shape (2, 28, 28), so each batch_x has (batch_size, 2, 28, 28) and batch_y has (batch_size)
while True:
x1, x2, y = [],[],[]
if np.random.random() < .5:
# match
num = (int)(np.random.random() * 10)
index_1,index_2 = np.random.randint(0, high=len(organized[num])),np.random.randint(0, high=len(organized[num]))
el_1,el_2 = organized[num][index_1],organized[num][index_2]
x1.append(el_1)
x2.append(el_2)
y.append(1)
else:
num1 = (int)(np.random.random() * 10)
num2 = (int)(np.random.random() * 10)
while num2==num1:
num2 = (int)(np.random.random() * 10)
index_1,index_2 = np.random.randint(0, high=len(organized[num1])),np.random.randint(0, high=len(organized[num2]))
el_1, el_2 = organized[num1][index_1],organized[num2][index_2]
x1.append(el_1)
x2.append(el_2)
y.append(0)
yield [np.array(x1), np.array(x2)], np.array(y)
然后,当我调用model.fit(batch_gen, batch_size=64, epochs=20, callbacks=[tf.keras.callbacks.EarlyStopping()])
时,我得到以下错误消息:ValueError: Failed to find data adapter that can handle input: <class 'function'>, <class 'NoneType'>
我摆弄了一下yield语句,我试着把2个输入和1个输出放在一个元组中,我试着用numpy数组和不用numpy数组,我只是不知道我做错了什么,希望能得到任何帮助
目前没有回答
相关问题 更多 >
编程相关推荐