// Header files:
#include <cmath>
#include <iostream>
#include <typeinfo>

#include <torch/script.h>
#include <torch/torch.h>
// Inference only: disable autograd recording for the rest of the enclosing scope.
torch::NoGradGuard no_grad;
// Deserialize the TorchScript model once and reuse it on later executions of this
// code. Previously the .pt file was re-read from disk every time this code ran.
// A function-local static is initialised exactly once, even with concurrent callers.
// NOTE(review): absolute, machine-specific model path — consider making it configurable.
// NOTE(review): if this code can run on several threads at once, they now share one
// module instance — confirm that is acceptable for the threading model in use.
static torch::jit::script::Module module =
    torch::jit::load("D:/projects/PycharmProjects/BCW/Train_pth/68.pt");
torch::Tensor inputC0, inputC1, inRefP0, inRefP1, inRefPutemp, input;
// swidth x swidth CPU planes with the default dtype (float unless the global default
// dtype has been changed). Their contents are filled in below.
inputC0 = torch::empty({ swidth, swidth });
inputC1 = torch::empty({ swidth, swidth });
inRefP0 = torch::empty({ swidth, swidth });
inRefP1 = torch::empty({ swidth, swidth });
inRefPutemp = torch::empty({ swidth, swidth });
torch::Tensor inOrg = torch::empty({ swidth, swidth });
// Two constant planes broadcasting the scalar inputs DisC0 / DisC1 to every position.
torch::fill_(inputC0, DisC0);
torch::fill_(inputC1, DisC1);
Picture* Orgpic = pu.cu->slice->getPic();
CodingStructure& Orgcs = *Orgpic->cs;
PelBuf orgBuf = Orgcs.getOrgBuf(Orgcs.area).get(COMPONENT_Y);
//PelBuf orgBuf = pu.cs->getOrgBuf(COMPONENT_Y);
for (int j = curry - 4, m = 0; j < curry + width; j++, m++)
{
for (int i = currx - 4, n = 0; i < currx + width; i++, n++)
{
inOrg[m][n] = orgBuf.at((i<0)?0:i, (j < 0) ? 0 : j)/4.0;
}
}
module.to(at::kCPU);
int t = 0;
for (int i = 0; i < swidth; i++)
{
for (int j = 0; j < swidth; j++)
{
inRefP0[i][j] = RefP0[t];
inRefP1[i][j] = RefP1[t];
inRefPutemp[i][j] = RefPu[t];
t++;
}
}
t = 0;
// Splice the reference block into inOrg. Every position whose row AND column are in
// [4, swidth) takes the matching sample of inRefPutemp. The first 4 rows and 4 columns
// keep what inOrg already held: the original-picture samples gathered earlier.
// index() with slices yields a view, so copy_ writes into inOrg's storage.
torch::Tensor inRefCen =
    inRefPutemp.index({ torch::indexing::Slice(4, swidth), torch::indexing::Slice(4, swidth) });
inOrg.index({ torch::indexing::Slice(4, swidth), torch::indexing::Slice(4, swidth) })
    .copy_(inRefCen);
// A second handle to inOrg's storage, not a copy.
torch::Tensor inRefPu = inOrg;
// Model input: the five swidth x swidth planes stacked on a new leading axis, with a
// batch axis in front, giving shape [1, 5, swidth, swidth].
input = torch::unsqueeze(
    torch::stack({ inputC0, inputC1, inRefP0, inRefP1, inRefPu }, 0), 0);
// Run the model and flatten its result to a single row of (swidth - 4)^2 values.
// reshape throws if the model produced a different number of elements.
torch::Tensor output = module.forward({ input })
                           .toTensor()
                           .reshape({ 1, (int64_t(swidth) - 4) * (int64_t(swidth) - 4) });
t = 0;
Pel* BCW_dst = pcYuvDst.Y().buf;
Pel* BCW_src = pcYuvDst.Y().buf;
for (unsigned y = 0; y < height; y++)
{
for (unsigned x = 0; x < width; x++)
{
if (BCW_dst != nullptr && BCW_dst != nullptr)
{
float out = output[0][t].item().toFloat();
if (out < 0)
out = 0;
if (out > 255)
out = 255;
//BCW_dst[x] = Pel(round(out));
BCW_dst[x] = Pel(round(out) * 4);
//if (BCW_dst[x] < 0)
// cout << BCW_dst[x] << endl;
//BCW_dst[x] = m_BCWfwdLUT[BCW_src[x]];
t++;
}
}
//cout << endl;
BCW_dst += pcYuvDst.Y().stride;
BCW_src += pcYuvDst.Y().stride;
}