我想从模型中删除输入变量,以使其自动回归(即在其上应用展开)
下面有一个简单的例子。你知道吗
def auto_regressive(model):
@C.Function
def unfold(seq_start, dyn_axis: Sequence[Tensor[121]]):
unfoldfrom = UnfoldFrom(lambda x: model(x))
return unfoldfrom(initial_state=seq_start, dynamic_axes_like=dyn_axis)
return unfold
a = C.sequence.input_variable(121, name='input_tensor') # line 1
model = Recurrence(LSTM(121, name='LSTM'), name="Recur")(a) # line 2
# Do something to remove input_variable from model here
output_tensor = auto_regressive(model)(C.Constant(0, shape=121), a) # line 3
最后一行将引发异常,因为在第2行中,我将输入变量a输入到模型中。如果我把它排除在外,那么它就会执行得很好。你知道吗
我需要删除input\u变量的原因是我有一个经过预训练的模型,我想删除input\u变量。你知道吗
目前没有回答
相关问题 更多 >
编程相关推荐