PyTorch bindings for CUDA-Warp Recurrent Neural Aligner
def rna_loss(log_probs,  # type: torch.FloatTensor
             labels,  # type: torch.IntTensor
             frames_lengths,  # type: torch.IntTensor
             labels_lengths,  # type: torch.IntTensor
             average_frames=False,  # type: bool
             reduction=None,  # type: Optional[AnyStr]
             blank=0,  # type: int
             ):
    """The CUDA-Warp Recurrent Neural Aligner loss.

    Args:
        log_probs (torch.Tensor): Input tensor (float) with shape (T, N, U, V)
            where T is the maximum number of input frames, N is the minibatch
            size, U is the maximum number of output labels and V is the
            vocabulary of labels (including the blank).
        labels (torch.IntTensor): Tensor with shape (N, U-1) representing the
            reference labels for all samples in the minibatch.
        frames_lengths (torch.IntTensor): Tensor with shape (N,) representing
            the number of frames for each sample in the minibatch.
        labels_lengths (torch.IntTensor): Tensor with shape (N,) representing
            the length of the transcription for each sample in the minibatch.
        average_frames (bool, optional): Specifies whether the loss of each
            sample should be divided by its number of frames.
            Default: ``False``.
        reduction (string, optional): Specifies the type of reduction.
            Default: None.
        blank (int, optional): label used to represent the blank symbol.
            Default: 0.
    """
    # type: (...) -> torch.Tensor
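The shape contract in the docstring above can be sketched as a short example. This is only an illustration under stated assumptions: the helper names `rna_input_shapes` and `demo` are hypothetical, the import path `from warp_rna import rna_loss` is assumed to match the installed package, and the random inputs carry no meaning beyond having the documented shapes and integer/float types.

```python
def rna_input_shapes(T, N, U, V):
    """Return the tensor shapes listed in the rna_loss docstring.

    Hypothetical helper, not part of warp_rna.
    """
    return {
        "log_probs": (T, N, U, V),   # frames, batch, labels, vocabulary
        "labels": (N, U - 1),        # reference labels per sample
        "frames_lengths": (N,),      # number of frames per sample
        "labels_lengths": (N,),      # transcription length per sample
    }


def demo(T=10, N=2, U=4, V=5):
    """Call rna_loss on random data.

    Returns None when torch, warp_rna, or a CUDA device is unavailable.
    """
    try:
        import torch
        from warp_rna import rna_loss  # assumed import path
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None

    shapes = rna_input_shapes(T, N, U, V)
    device = torch.device("cuda")

    # Normalize random scores over the vocabulary axis to get log-probabilities.
    logits = torch.randn(*shapes["log_probs"], device=device)
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    # Draw labels from 1..V-1 so that index 0 stays reserved for the blank.
    labels = torch.randint(1, V, shapes["labels"],
                           dtype=torch.int32, device=device)
    frames_lengths = torch.full(shapes["frames_lengths"], T,
                                dtype=torch.int32, device=device)
    labels_lengths = torch.full(shapes["labels_lengths"], U - 1,
                                dtype=torch.int32, device=device)

    return rna_loss(log_probs, labels, frames_lengths, labels_lengths)
```

Calling `demo()` on a machine with the package and a GPU should return the loss tensor; elsewhere it returns `None`, so the shape helper can still be used on its own.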
Requirements
- Python: 3.5, 3.6, 3.7 (tested with version 3.6).
- PyTorch >= 1.0.0 (tested with version 1.1.0).
- CUDA Toolkit (tested with version 10.0).
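Install
Currently, there is no compiled version of the package. The following installation instructions compile the package from source on the local machine.
From PyPI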
pip install warp_rna
From GitHub
git clone https://github.com/1ytic/warp-rna
cd warp-rna/pytorch_binding
python setup.py install
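
After either install route, a quick check that the module can be found may be useful. The sketch below uses only the standard library; the helper name `is_installed` is hypothetical, and the module name `warp_rna` is assumed from the pip package name above.

```python
import importlib.util


def is_installed(module_name):
    """Return True if a top-level module can be located without importing it."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    # "warp_rna" is the assumed module name of the installed package.
    print("warp_rna found:", is_installed("warp_rna"))
```

This only confirms that the module is on the import path; it does not verify that the CUDA extension was compiled against a matching PyTorch and CUDA Toolkit.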