VTM调用libtorch

这段代码展示了如何加载并执行一个预先训练的PyTorch模型。首先,它导入必要的库,然后加载模型权重。接着,初始化输入张量并填充数据。之后,模型被转换到CPU上运行,并对输入数据进行前向传播以获取输出。最后,将输出结果应用于图像处理任务,调整范围并保存结果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

头文件:

#include <torch/script.h>
#include <torch/torch.h>
#include <iostream>
#include <typeinfo>

    torch::NoGradGuard no_grad;
    torch::jit::script::Module module;
    torch::Tensor inputC0, inputC1,inRefP0,inRefP1,inRefPutemp, input;

    module = torch::jit::load("D:/projects/PycharmProjects/BCW/Train_pth/68.pt");

    inputC0 = torch::empty({ swidth, swidth });
    inputC1 = torch::empty({ swidth, swidth });
    inRefP0 = torch::empty({ swidth, swidth });
    inRefP1 = torch::empty({ swidth, swidth });
    inRefPutemp = torch::empty({ swidth, swidth });
    torch::Tensor inOrg = torch::empty({ swidth, swidth });
    torch::fill_(inputC0, DisC0);
    torch::fill_(inputC1, DisC1);
    Picture* Orgpic = pu.cu->slice->getPic();
    CodingStructure& Orgcs = *Orgpic->cs;
    PelBuf orgBuf = Orgcs.getOrgBuf(Orgcs.area).get(COMPONENT_Y);
    //PelBuf orgBuf = pu.cs->getOrgBuf(COMPONENT_Y);
    for (int j = curry - 4, m = 0; j < curry + width; j++, m++)
    {
      for (int i = currx - 4, n = 0; i < currx + width; i++, n++)
      {
        inOrg[m][n] = orgBuf.at((i<0)?0:i, (j < 0) ? 0 : j)/4.0;
      }
    }
    module.to(at::kCPU);
    int t = 0;
    for (int i = 0; i < swidth; i++)
    {
      for (int j = 0; j < swidth; j++)
      {
        inRefP0[i][j] = RefP0[t];
        inRefP1[i][j] = RefP1[t];
        inRefPutemp[i][j] = RefPu[t];
        t++;
      }
    }
    t = 0;
    torch::Tensor inRefCen = inRefPutemp.index({             
    torch::indexing::Slice(4,swidth),torch::indexing::Slice(4,swidth)});
    inOrg.index({ torch::indexing::Slice(4,swidth),torch::indexing::Slice(4,swidth) }) = 
inRefCen;
    torch::Tensor inRefPu = inOrg;
    //cout << inRefPu << endl;
    input = torch::stack({ inputC0,inputC1,inRefP0,inRefP1,inRefPu }, 0);
    input = torch::unsqueeze(input, 0);
    //Execute the model and turn its output into a tensor.
    torch::Tensor output = module.forward({ input }).toTensor();
    //cout << output << endl;
    output = output.reshape({ 1,(int64_t(swidth) - 4) * (int64_t(swidth) - 4) });
    t = 0;
    Pel* BCW_dst = pcYuvDst.Y().buf;
    Pel* BCW_src = pcYuvDst.Y().buf;
    for (unsigned y = 0; y < height; y++)
    {
      for (unsigned x = 0; x < width; x++)
      {
        if (BCW_dst != nullptr && BCW_dst != nullptr)
        {
          float out = output[0][t].item().toFloat();
          if (out < 0)
            out = 0;
          if (out > 255)
            out = 255;
          //BCW_dst[x] = Pel(round(out));
          BCW_dst[x] = Pel(round(out) * 4);
          //if (BCW_dst[x] < 0)
          //  cout << BCW_dst[x] << endl;
          //BCW_dst[x] = m_BCWfwdLUT[BCW_src[x]];
          t++;
        }
      }
      //cout << endl;
      BCW_dst += pcYuvDst.Y().stride;
      BCW_src += pcYuvDst.Y().stride;
    }

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值