我有一个形状为(32,5)
的numpy数组batch
。批处理的每个元素由一个numpy数组batch_elem = [s,_,_,_,_]
组成,其中s = [img,val1,val2]
是一个三维numpy数组,_
只是标量值。
img
是一个维度为(84,84,3)
的图像(numpy数组)
我想创建一个形状为(32,84,84,3)
的numpy数组。基本上,我想提取每个batch
中的图像信息,并将其转换为一个4维数组。在
我尝试了以下方法:
b = np.vstack(batch[:,0]) #this yields a b with shape (32,3), type: <class 'numpy.ndarray'>
现在我想访问图像(第二维度的第一个索引)
^{pr2}$如何才能最好地访问图像数据并获得形状(32,84,84,3)
?在
注:
s = b[0] #first s of the 32 in batch: shape (3,) , type: <class 'numpy.ndarray'>
编辑:
这应该是一个最小的例子:
img = np.zeros([5,5,3])
s = np.array([img,1,1])
batch_elem = np.array([s,1,1,1,1])
batch = np.array([batch_elem for _ in range(32)])
假设我正确地理解了这个问题,你可以在最后一个轴上叠加两次。在
res = np.stack(np.stack(batch[:,0])[...,0])
相关问题 更多 >
编程相关推荐