关于n<<p数据集中的批大小与Keras中的批大小>n的混淆

2024-03-28 16:23:01 发布

您现在位置:Python中文网/ 问答频道 /正文

我对运行良好的代码有点困惑,但我不明白为什么。你知道吗

我在一个n<;<;p场景中,功能比示例多2倍。我不小心将批大小设置为一个值>;n

那么keras的默认行为是什么呢?只是退回到纯梯度下降,在一个时代结束时平均权重?你知道吗

我在有监督的二进制分类器设置以及基于LSTM/自动编码器的无监督异常检测器中使用这个设置

这增加了一个额外的混淆,因为我认为在LSTM-case中n%batch\u size应该是零。你知道吗


Tags: 代码ltgt功能示例分类器场景二进制
1条回答
网友
1楼 · 发布于 2024-03-28 16:23:01

我已经深入研究了源代码,我想我已经找到了这些问题的答案。你知道吗

  1. 如果批量大小>n,Keras是否“退化”为普通梯度下降?你知道吗

答案是肯定的。从第334行开始的方法batch_shuffle可以看出(注意:我链接到V2.2以保留行号),如果batch\u size>;n则返回整个批。以下是相关代码和输出:

import numpy as np
index_array = np.array([0,1,2,3,4,5])
batch_size = 72
batch_count = int(len(index_array) / batch_size)
#batch_count = 0

last_batch = index_array[batch_count * batch_size:]
# last_batch = array([0, 1, 2, 3, 4, 5])

index_array = index_array[:batch_count * batch_size]
#index_array = array([], dtype=int64)

index_array = index_array.reshape((batch_count, batch_size))
#index_array =array([], shape=(0, 72), dtype=int64)

np.random.shuffle(index_array)
index_array = index_array.flatten()
return np.append(index_array, last_batch)
# np.append(index_array, last_batch) = array([0, 1, 2, 3, 4, 5])
  1. 为什么即使n%batch\u size<;gt;0,LSTM也能工作?你知道吗

n%batch\u size=0的要求仅适用于有状态的LSTM,在下面的code,第817行中再次找到了证据,只有在有状态==True和n%batch\u size<;>;0时才会出现错误

相关问题 更多 >