我正在尝试将此方法矢量化,我正在使用此方法进行ML图像增强:
def random_erase_from_image(images, random_erasing, image_size):
#could probably be vectorized to speed up
to_return = images
for t in range(images.shape[0]):
if np.random.randint(0, 2) == 0:#do random erasing
x_erase_size = np.random.randint(0, random_erasing)
y_erase_size = np.random.randint(0, random_erasing)
x_erase_start = np.random.randint(0, image_size-x_erase_size)
y_erase_start = np.random.randint(0, image_size-y_erase_size)
shape = to_return[t, y_erase_start:y_erase_start+y_erase_size, x_erase_start:x_erase_start+x_erase_size, :].shape
print(shape)
to_return[t, y_erase_start:y_erase_start+y_erase_size, x_erase_start:x_erase_start+x_erase_size, :] = (np.random.random(shape) * 255).astype('uint8')
return images
这是我所能做到的,但我不知道如何正确地切割
def random_erase_vec(images, random_erasing, image_size):
#could probably be vectorized to speed up
to_return = images
mask = np.random.choice(a=[False, True], size=images.shape[0], p=[.5, .5])
x_erase_size = np.random.randint(0, random_erasing, size=images.shape[0])
y_erase_size = np.random.randint(0, random_erasing, size=images.shape[0])
x_erase_start = np.random.randint(0, image_size-x_erase_size, size=images.shape[0])
y_erase_start = np.random.randint(0, image_size-y_erase_size, size=images.shape[0])
random_values = (np.random.random((images.shape))* 255).astype('uint8')
to_return[:, [y_erase_start[:]]:[y_erase_start[:]+y_erase_size[:]], [x_erase_start[:]]:[x_erase_start[:]+x_erase_size[:]], :] = random_values[:, [y_erase_start[:]]:[y_erase_start[:]+y_erase_size[:]], [x_erase_start[:]]:[x_erase_start[:]+x_erase_size[:]], :]
return images
我试图避免重塑,但如果这是需要的,我想它会做的。让我知道你能想到的任何加速原始方法的方法
我在切片线上遇到以下错误: “切片索引必须为整数或无,或具有索引方法”
我还想遮罩,所以不是所有的图像都是随机擦除的,但我想在切片部分完成后这样做
谢谢你的帮助
编辑:示例输入:
图像:大小为[#的图像、高度(32)、宽度(32)、通道(3)的numpy阵列
random_Erasting(随机擦除):名称不好,但要擦除的任意维图像的最大大小。当前设置为20
image_size:现在我想可能是从images数组中得到的,但是清理还不是优先事项
我稍微整理了一下你的函数,并尝试对它进行部分矢量化,但由于你想改变随机补丁的大小,这有点复杂
速度并不惊人(大约20%),因为这是在ML中进行预处理,所以最好的办法是使用更多的工作人员来准备数据,以便充分利用GPU
编辑: 是的,我使用.copy()来确保参数不会在例程之外发生变异。如果你愿意,你可以忽略这个
我在[tensorflow文档]中使用术语worker:(https://www.tensorflow.org/api_docs/python/tf/keras/Model)
相关问题 更多 >
编程相关推荐