Tensorflow v2生成器值错误:找不到可以处理输入的数据适配器:<class'function'>,<class'NoneType'>

2024-06-07 19:10:49 发布

您现在位置:Python中文网/ 问答频道 /正文

我以前用过发电机,但现在用起来有些困难,我不知道我做错了什么

这是我的发电机

def batch_gen():
    # each sample has shape (2, 28, 28), so each batch_x has (batch_size, 2, 28, 28) and batch_y has (batch_size)
    while True:
        x1, x2, y = [],[],[]
        if np.random.random() < .5:
            # match
            num = (int)(np.random.random() * 10)
            index_1,index_2 = np.random.randint(0, high=len(organized[num])),np.random.randint(0, high=len(organized[num]))
            el_1,el_2 = organized[num][index_1],organized[num][index_2]
            x1.append(el_1)
            x2.append(el_2)
            y.append(1)
        else:
            num1 = (int)(np.random.random() * 10)
            num2 = (int)(np.random.random() * 10)
            while num2==num1:
                num2 = (int)(np.random.random() * 10)
            index_1,index_2 = np.random.randint(0, high=len(organized[num1])),np.random.randint(0, high=len(organized[num2]))
            el_1, el_2 = organized[num1][index_1],organized[num2][index_2]
            x1.append(el_1)
            x2.append(el_2)
            y.append(0)
        yield [np.array(x1), np.array(x2)], np.array(y)

然后,当我调用model.fit(batch_gen, batch_size=64, epochs=20, callbacks=[tf.keras.callbacks.EarlyStopping()])时,我得到以下错误消息:ValueError: Failed to find data adapter that can handle input: <class 'function'>, <class 'NoneType'>

我摆弄了一下yield语句,我试着把2个输入和1个输出放在一个元组中,我试着用numpy数组和不用numpy数组,我只是不知道我做错了什么,希望能得到任何帮助


Tags: indexlennpbatchrandomelnumint

热门问题