Theano函数抛出ValueError与'givens'属性有关

2 投票
1 回答
1085 浏览
提问于 2025-04-18 06:51

我在使用 theano 函数,并想用 givens 来遍历所有的输入样本。下面是我的代码:

index = T.scalar('index')
train_set = np.array([[0.2, 0.5, 0.01], [0.3, 0.91, 0.4], [0.1, 0.7, 0.22], 
                      [0.7, 0.54, 0.2], [0.1, 0.12, 0.3], [0.2, 0.52, 0.1], 
                      [0.12, 0.08, 0.4], [0.02, 0.7, 0.22], [0.71, 0.5, 0.2], 
                      [0.1, 0.42, 0.63]])
train = function(inputs=[index], outputs=cost, updates=updates, 
                 givens={x: train_set[index]})

结果出现了一个错误:

ValueError: setting an array element with a sequence.

你能告诉我为什么会这样吗?还有怎么解决这个问题?

1 个回答

4

问题是这样的:train_set[index]

这里的 train_set 是一个 numpy 的数组,而 index 是一个 Theano 的变量。NumPy 不知道怎么处理 Theano 的变量。所以你需要把 train_set 转换成 Theano 的变量,比如一个共享变量:

train_set = theano.shared(train_set)

你还需要修改 index 的声明,因为 Theano 不支持用真实的值来作为索引:

index = T.iscalar()

撰写回答