TensorFlow incompatible input shape

Posted 2024-04-27 01:05:52


I'm trying to build a TensorFlow model where my data has 70 features. Here is the setup for my first layer:

tf.keras.layers.Dense(units=50, activation='relu', input_shape=(None,70)),

Setting the input shape to (None, 70) seemed best to me, since I'm using a feed-forward neural network where each "row" of the data is unique. I'm using a batch size of 10 (for now). Should my input shape be changed to (10, 70) instead?

I tried the original (None, 70) and got these errors:

WARNING:tensorflow:Model was constructed with shape (None, None, 70) for input Tensor("dense_33_input:0", shape=(None, None, 70), dtype=float32), but it was called on an input with incompatible shape (10, 70).

TypeError: Input 'y' of 'Mul' Op has type float64 that does not match type float32 of argument 'x'.

I'm confused about what exactly is wrong with input_shape, since (None, 70) seems like the best fit. Any help is much appreciated.
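For reference, the warning above can be reproduced with a minimal sketch (layer sizes follow the question; the data is random placeholder input):

```python
import tensorflow as tf

# input_shape=(None, 70) adds an extra axis on top of the implicit batch
# dimension, so Keras builds the model for rank-3 input (batch, None, 70)
# and warns when it is called on a rank-2 batch of shape (10, 70).
bad = tf.keras.Sequential([
    tf.keras.layers.Dense(units=50, activation='relu', input_shape=(None, 70)),
])
print(bad.input_shape)   # (None, None, 70) -- three dimensions, not two

# input_shape=(70,) leaves the batch dimension implicit, as intended.
good = tf.keras.Sequential([
    tf.keras.layers.Dense(units=50, activation='relu', input_shape=(70,)),
])
out = good(tf.random.uniform((10, 70)))  # a batch of 10 rows works fine
print(out.shape)         # (10, 50)
```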

Edit: adding a reproducible example for more context. Sorry it's so long. It's a reproduction of [this example][1], adapted to better fit my current (non-image) data.

Variational autoencoder model

import time

import numpy as np
import tensorflow as tf
from IPython import display  # used for clear_output in the training loop

class VAE(tf.keras.Model):
    
    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(units=50, activation='relu', input_shape=(70,)),
            tf.keras.layers.Dense(latent_dim + latent_dim), #No activation
        ])
        
        self.decoder = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(units=50, activation='relu', input_shape=(latent_dim,)),
            tf.keras.layers.Dense(units=70),
        ])
        
    @tf.function
    def sample(self, eps=None):
        if eps is None:
            eps = tf.random.normal(shape=(100, self.latent_dim))
        return self.decode(eps, apply_sigmoid=True)

    def encode(self, x):
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=mean.shape)
        return eps * tf.exp(logvar * .5) + mean

    def decode(self, z, apply_sigmoid=False):
        logits = self.decoder(z)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs
        return logits


  [1]: https://www.tensorflow.org/tutorials/generative/cvae

Optimizer & loss

optimizer = tf.keras.optimizers.Adam(1e-4)

def log_normal_pdf(sample, mean, logvar, raxis=1):
    log2pi = tf.math.log(2. * np.pi)
    return tf.reduce_sum(
        -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi), axis=raxis)

def compute_loss(model, x):
    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    x_logit = model.decode(z)
    cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)
    logpx_z = -tf.reduce_sum(cross_ent, axis=[1])
    logpz = log_normal_pdf(z, 0, 0)
    logqz_x = log_normal_pdf(z, mean, logvar)
    return -tf.reduce_mean(logpx_z + logpz - logqz_x)

@tf.function
def train_step(model, x, optimizer):
    with tf.GradientTape() as tape:
        loss = compute_loss(model, x)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

Training

X = tf.random.uniform((100,70))
y = tf.random.uniform((100,))

ds_train = tf.data.Dataset.from_tensor_slices((X, y))

tf.random.set_seed(1)

train = ds_train.shuffle(buffer_size=len(X))
train = train.batch(batch_size=10, drop_remainder=False)

epochs = 5
latent_dim = 2

model = VAE(2)

for epoch in range(1, epochs+1):
    start_time = time.time()
    for i, (train_x, train_y) in enumerate(train):
        train_step(model, train_x, optimizer)
    end_time = time.time()
    
    loss = tf.keras.metrics.Mean()
    # The original snippet iterates over an undefined ds_test here;
    # the batched training set is reused so the example runs end to end.
    for i, (test_x, test_y) in enumerate(train):
        loss(compute_loss(model, test_x))
    elbo = -loss.result()
    display.clear_output(wait=False)
    print('Epoch: {}, Test set ELBO: {}, time elapse for current epoch: {}'
         .format(epoch, elbo, end_time - start_time))

1 Answer

User
#1 · Posted 2024-04-27 01:05:52

input_shape should not include the batch dimension. Use input_shape=(70,):

tf.keras.layers.Dense(units=50, activation='relu', input_shape=(70,))

You can set the batch size when calling model.fit(..., batch_size=10). See the documentation for tf.keras.Model.fit.
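As a quick sketch with made-up random data (the 70-feature shape is from the question; the single-output head and mse loss here are just placeholders):

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(units=50, activation='relu', input_shape=(70,)),
    tf.keras.layers.Dense(units=1),
])
model.compile(optimizer='adam', loss='mse')

X = tf.random.uniform((100, 70))
y = tf.random.uniform((100,))

# The batch dimension is supplied at fit time, not in input_shape.
history = model.fit(X, y, batch_size=10, epochs=1, verbose=0)
print(model.input_shape)  # (None, 70) -- batch dimension left as None
```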

There is another error in the original post, caused by passing int32 values to tf.math.exp. That line should read

logpz = log_normal_pdf(z, 0., 0.)

to resolve that error. Note the 0. values, which evaluate as floats rather than integers.
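That point can be checked directly (a minimal sketch, independent of the model code above):

```python
import tensorflow as tf

# A float literal becomes a float32 tensor, which tf.math.exp accepts.
ok = tf.math.exp(0.)
print(float(ok))  # 1.0

# A Python int literal becomes an int32 tensor, and the Exp op has no
# integer kernel, so this raises an error instead of casting silently.
try:
    tf.math.exp(tf.constant(0))
    raised = False
except Exception as err:
    raised = True
    print(type(err).__name__)
print(raised)  # True
```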
