<p>Many thanks to Nikos Karampatziakis.</p>
<p>If you want a random-sampling decoder that generates sequences of the same length as the target sequence, the following code works.</p>
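<pre><code>import cntk as C

# Assumes InputSequence, LabelSequence, input_vocab_dim, label_vocab_dim and
# sentence_start are defined as in the CNTK sequence-to-sequence tutorial.

@C.Function
def sampling(x):
    # Gumbel-max trick: adding Gumbel noise to x (assumed to hold logits) and
    # taking the hardmax draws a one-hot sample from softmax(x)
    noisy_x = x + C.random.gumbel_like(x)
    return C.hardmax(noisy_x)

def create_model_sampling(s2smodel):
    @C.Function
    @C.layers.Signature(input=InputSequence[C.layers.Tensor[input_vocab_dim]],
                        labels=LabelSequence[C.layers.Tensor[label_vocab_dim]])
    def model_sampling(input, labels):  # (input*) -> (word_sequence*)
        # Each sampled word is fed back in as the next decoder input (history)
        unfold = C.layers.UnfoldFrom(lambda history: s2smodel(history, input) >> sampling,
                                     length_increase=1)
        # dynamic_axes_like=labels makes the output as long as the target sequence
        return unfold(initial_state=sentence_start, dynamic_axes_like=labels)
    return model_sampling
</code></pre>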
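<p>The following is a minimal pure-Python sketch (no CNTK) of the same idea, showing why <code>sampling</code> produces random words rather than the greedy choice. It assumes the decoder outputs unnormalized log-probabilities (logits), as is usual when training with <code>cross_entropy_with_softmax</code>. Adding standard Gumbel noise to the logits and taking the hardmax (the Gumbel-max trick) yields a one-hot token distributed as <code>softmax(logits)</code>. <code>sample_decode</code> then feeds each sampled token back in for exactly as many steps as the target, which is the role <code>UnfoldFrom</code> with <code>dynamic_axes_like=labels</code> plays in the code above. All helper names and the toy <code>step_logits</code> model are hypothetical illustrations, not CNTK APIs.</p>

```python
import math
import random
from typing import Callable, List


def gumbel_noise(rng: random.Random) -> float:
    """One draw from the standard Gumbel(0, 1) distribution: -log(-log(U))."""
    u = 0.0
    while u <= 0.0:  # rng.random() lies in [0, 1); reject an exact 0
        u = rng.random()
    return -math.log(-math.log(u))


def hardmax(values: List[float]) -> List[float]:
    """One-hot vector at the (first) maximum, mirroring what C.hardmax computes."""
    best = max(range(len(values)), key=values.__getitem__)
    return [1.0 if i == best else 0.0 for i in range(len(values))]


def gumbel_max_sample(logits: List[float], rng: random.Random) -> List[float]:
    """hardmax(logits + Gumbel noise) is a one-hot draw from softmax(logits)."""
    return hardmax([z + gumbel_noise(rng) for z in logits])


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax, used here only to check the sampler."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_decode(step_logits: Callable[[List[float]], List[float]],
                  start: List[float], target_len: int,
                  rng: random.Random) -> List[List[float]]:
    """Hypothetical stand-in for UnfoldFrom: feed each sampled one-hot token
    back into step_logits, for exactly target_len steps."""
    history, out = start, []
    for _ in range(target_len):
        history = gumbel_max_sample(step_logits(history), rng)
        out.append(history)
    return out


if __name__ == "__main__":
    rng = random.Random(0)
    logits = [2.0, 1.0, 0.1]
    n = 100_000
    counts = [0] * len(logits)
    for _ in range(n):
        counts[gumbel_max_sample(logits, rng).index(1.0)] += 1
    print("empirical:", [round(c / n, 3) for c in counts])
    print("softmax:  ", [round(p, 3) for p in softmax(logits)])

    # Hypothetical toy "model" that favours repeating the previous token.
    toy_model = lambda prev: [2.0 * p for p in prev]
    start = [1.0, 0.0, 0.0]  # plays the role of sentence_start
    seq = sample_decode(toy_model, start, target_len=5, rng=rng)
    print([tok.index(1.0) for tok in seq])
```

<p>Running it should show empirical frequencies close to the softmax probabilities. A real decoder such as <code>s2smodel</code> also carries recurrent state and conditions on the input sequence, which this one-argument toy model omits.</p>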