1,下载代码
git clone -b pytorch_bindings https://github.com/SeanNaren/warp-ctc.git
2,编译
cd warp-ctc
mkdir build
cd build
cmake .. # 这里可以自定义pytorch cpp的编译环境
make
-DCMAKE_PREFIX_PATH=/path/to/libtorch/9.0/10.0
3,设置环境变量
vi ~/.bashrc
export WARP_CTC_PATH=/home/rose/software/warp-ctc/build
source ~/.bashrc # 使环境变量生效
4,将warp-ctc添加到conda
./conda env list
xxx
source activate xxx
cd /home/rose/software/warp-ctc/pytorch_binding
python setup.py install
5,调用
conda install pytest
import torch
from warpctc_pytorch import CTCLoss
ctc_loss = CTCLoss()
# expected shape of seqLength x batchSize x alphabet_size
probs = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).contiguous()
labels = torch.IntTensor([1, 2])
label_sizes = torch.IntTensor([2])
probs_sizes = torch.IntTensor([2])
probs.requires_grad_(True) # tells autograd to compute gradients for probs
cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
cost.backward()

本文详细介绍了如何从下载源代码开始,通过编译、设置环境变量、添加到conda环境,最终在PyTorch中使用Warp-CTC库的过程。包括了关键步骤如CMake配置、环境变量导出、激活conda环境及调用CTCLoss函数的示例。
1717





