如何将mnist数据转换为RGB格式?

2024-04-20 00:28:53 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试将MNIST数据集转换为RGB格式,每个图像的实际形状是(28,28),但我需要(28,28,3)。在

import numpy as np
import tensorflow as tf

mnist = tf.keras.datasets.mnist
(x_train, _), (x_test, _) = mnist.load_data()

X = np.concatenate([x_train, x_test])
X = X / 127.5 - 1

X.reshape((70000, 28, 28, 1))

tf.image.grayscale_to_rgb(
    X,
    name=None
)

但我得到了以下错误:

^{pr2}$

Tags: 数据test图像importnumpytftensorflowas
3条回答

如果你之前打印X的形状tf.image.grayscale_到\u rgb您将看到输出维度是(70000,28,28)。输入到tf.image.灰度尺寸必须为1作为最终尺寸。在

展开X的最后一个维度,使其与函数兼容

tf.image.grayscale_to_rgb(tf.expand_dims(X, axis=3))

应将重塑后的3D[28x28x1]图像存储在数组中:

X = X.reshape((70000, 28, 28, 1))

转换时,将另一个数组设置为tf.image.grayscale_to_rgb()函数的返回值:

^{pr2}$

最后,用matplotlibtf.session()从得到的张量图像中画出一个例子:

import matplotlib.pyplot as plt

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    image_to_plot = sess.run(image)
    plt.figure()
    plt.imshow(image_to_plot)
    plt.grid(False)

完整代码:


import numpy as np
import tensorflow as tf

mnist = tf.keras.datasets.mnist
(x_train, _), (x_test, _) = mnist.load_data()

X = np.concatenate([x_train, x_test])
X = X / 127.5 - 1

# Set reshaped array to X 
X = X.reshape((70000, 28, 28, 1))

# Convert images and store them in X3
X3 = tf.image.grayscale_to_rgb(
    X,
    name=None
)

# Get one image from the 3D image array to var. image
image = X3[0,:,:,:]

# Plot it out with matplotlib.pyplot
import matplotlib.pyplot as plt

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    image_to_plot = sess.run(image)
    plt.figure()
    plt.imshow(image_to_plot)
    plt.grid(False)

除了@DMolony和@Aqwis01答案外,另一个简单的解决方案是使用numpy.repeat方法将张量的最后一个维度复制几次:

X = X.reshape((70000, 28, 28, 1))
X = X.repeat(3, -1)  # repeat the last (-1) dimension three times
X_t = tf.convert_to_tensor(X)
assert X_t.shape == (70000, 28, 28, 3)

相关问题 更多 >