如何在Keras上格式化培训输入和输出数据

2024-05-23 23:44:53 发布

您现在位置:Python中文网/ 问答频道 /正文

我是一个新的深度学习和我的一些数据格式的Keras。我的CNN是基于A.Newell等人的Stacked Hourglass Networks for Human Pose Estimation。在

在这个网络上,输入是一个256x256 RGB图像,输出应该是一个64x64热图,突出显示身体关节(肩部、膝盖等)。我设法建立网络,我有所有的数据(图像)及其注释(身体关节的像素标签)。我想知道如何格式化训练集的输入和输出数据来训练我的模型。目前我使用一个numpy数组(256256,3)作为图像,但我不知道如何格式化输出。我应该创建一个表[n,64,64,7]?(n是训练集的大小,7是我用来获得7个关节热图的过滤器数量)

谢谢你抽出时间。在


Tags: 数据图像网络forcnnkerashourglass热图
2条回答

输出也可以是numpy数组。 考虑这个例子: 训练集:50张256x256x3的图像。这可以组合成一个单一的numpy数组(50256,256,3)。 类似的方法来格式化输出数据。 示例代码如下:

    #a, b and c are arrays of size 256x256x3
    import numpy as np

    temp = []
    temp.append(a)
    temp.append(b)
    temp.append(c)
    output_labels = []
    output_labels = np.stack(temp)

输出标签数组的形状为(3x256x256x3)。在

Keras建议创建数据生成器,以向网络提供训练数据和地面实况。 具体到堆叠沙漏网络案例,您可以参考我的实现来了解详细信息https://github.com/yuanyuanli85/Stacked_Hourglass_Network_Keras/tree/master/src/data_gen

相关问题 更多 >