我是Python新手,我正在使用tensorflow在数据集中加载图像和标签
data_test = pathlib.Path(test_path)
all_test_paths = list(data_test.glob('*/*'))
all_test_paths = [str(path) for path in all_test_paths]
random.shuffle(all_test_paths)
label_names = sorted(item.name for item in data_test.glob('*/') if item.is_dir())
label_to_index = dict((name, index) for index, name in enumerate(label_names))
all_test_labels = [label_to_index[pathlib.Path(path).parent.name]
for path in all_test_paths]
all_test_paths=tf.convert_to_tensor(all_test_paths)
test_path_ds = tf.data.Dataset.from_tensor_slices(all_test_paths)
print(test_path_ds.output_shapes)
当我运行这部分代码时,它返回:()
目前没有回答
相关问题 更多 >
编程相关推荐