百度Warp-CTC的pytorch绑定
torch-baidu-ctc的Python项目详细描述
火炬百度CTC
百度Warp CTC的Pythorch绑定。这些装束的灵感来自 SeanNaren's但是这些包括一些错误修复, 并提供一些附加功能。
importtorchfromtorch_baidu_ctcimportctc_loss,CTCLoss# Activations. Shape T x N x D.# T -> max number of frames/timesteps# N -> minibatch size# D -> number of output labels (including the CTC blank)x=torch.rand(10,3,6)# Target labelsy=torch.tensor([# 1st sample1,1,2,5,2,# 2nd1,5,2,# 3rd4,4,2,3,],dtype=torch.int,)# Activations lengthsxs=torch.tensor([10,6,9],dtype=torch.int)# Target lengthsys=torch.tensor([5,3,4],dtype=torch.int)# By default, the costs (negative log-likelihood) of all samples are summed.# This is equivalent to:# ctc_loss(x, y, xs, ys, average_frames=False, reduction="sum")loss1=ctc_loss(x,y,xs,ys)# You can also average the cost of each sample among the number of frames.# The averaged costs are then summed.loss2=ctc_loss(x,y,xs,ys,average_frames=True)# Instead of summing the costs of each sample, you can perform# other `reductions`: "none", "sum", or "mean"## Return an array with the loss of each individual samplelosses=ctc_loss(x,y,xs,ys,reduction="none")## Compute the mean of the individual lossesloss3=ctc_loss(x,y,xs,ys,reduction="mean")## First, normalize loss by number of frames, later average lossesloss4=ctc_loss(x,y,xs,ys,average_frames=True,reduction="mean")# Finally, there's also a nn.Module to use this loss.ctc=CTCLoss(average_frames=True,reduction="mean",blank=0)loss4_2=ctc(x,y,xs,ys)# Note: the `blank` option is also available for `ctc_loss`.# By default it is 0.
要求
- < C++ > 11编译器(用GCC 4.9测试)。
- python:2.7、3.5、3.6、3.7(使用版本2.7、3.5和3.6进行测试)。
- PyTorch>;=1.1.0(使用版本1.1.0测试)。
- 对于GPU支持:CUDA Toolkit。
安装
假设您 已正确安装所需的库和工具。
安装过程从源代码处编译包,并使用 如果Pythorch有CUDA支持。
来自PYPI(推荐)
pip install torch-baidu-ctc
来自github
git clone --recursive https://github.com/jpuigcerver/pytorch-baidu-ctc.git
cd pytorch-baidu-ctc
python setup.py build
python setup.py install
AVX512相关问题
使用CUDA和更新的主机编译器时可能会出现一些编译问题
使用AVX512指令。请安装gcc 4.9并将其用作主机
NVCC编译器。您可以简单地设置CC
和CXX
环境变量
在生成/安装命令之前:
CC=gcc-4.9 CXX=g++-4.9 pip install torch-baidu-ctc
或者(如果您使用的是github源代码):
CC=gcc-4.9 CXX=g++-4.9 python setup.py build
测试
安装库后,可以使用unittest
进行测试。特别地,
运行以下命令:
python -m unittest torch_baidu_ctc.test
所有测试都应通过(只有在支持时才执行CUDA测试)。