I built a custom PyTorch binary-classification network that takes 10 feature maps as input. Feeding those 10 feature maps into the network from C++ (via ONNX Runtime) involved a lot of pitfalls, so I'm recording them here.
Pitfall 1:
PyTorch reads images with PIL's Image, which yields RGB; C++ reads images with OpenCV, which yields BGR. So the image must first be converted to RGB, using OpenCV's conversion function:
cvtColor(image, img, cv::COLOR_BGR2RGB);
Pitfall 2:
Normalize the loaded image using the mean and standard deviation from training:
array<float, num> input;  // num = channels * height * width, CHW layout
std::size_t counter = 0;
for (int k = 0; k < 3; k++) {          // channel (R, G, B after cvtColor)
    for (int i = 0; i < image.rows; i++) {
        for (int j = 0; j < image.cols; j++) {
            float v = image.at<cv::Vec3b>(i, j)[k] / 255.0f;
            if (k == 0)
                input[counter++] = (v - 0.485f) / 0.229f;
            else if (k == 1)
                input[counter++] = (v - 0.456f) / 0.224f;
            else
                input[counter++] = (v - 0.406f) / 0.225f;
        }
    }
}
Pitfall 3: Getting the input node names. At first I only got a single input; it turned out I was calling the wrong function. The correct way is:
size_t num_input_nodes = session.GetInputCount();
vector<const char*> input_node_names(num_input_nodes);
vector<int64_t> input_node_dims;
for (size_t i = 0; i < num_input_nodes; i++) {
    char* input_name = session.GetInputName(i, allocator);
    input_node_names[i] = input_name;
    Ort::TypeInfo type_info = session.GetInputTypeInfo(i);
    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
    ONNXTensorElementDataType type = tensor_info.GetElementType();
    input_node_dims = tensor_info.GetShape();
}
Pitfall 4: With multiple inputs, each array/Ort::Value must have its own variable name; otherwise every input ends up holding the values of the last feature map. I haven't pinned down the cause yet, so I just gave each one a distinct name.
const array<int64_t, 4> inputShape = { 1, numChannels, height, width };
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);
std::vector<Ort::Value> ort_inputs;

array<float, numInputElements> input0;
auto inputTensor0 = Ort::Value::CreateTensor<float>(memory_info, input0.data(), input0.size(), inputShape.data(), inputShape.size());
input0 = normalizeImg(img1);  // normalizeImg is the normalization routine shown above
ort_inputs.push_back(std::move(inputTensor0));

array<float, numInputElements> input1;
auto inputTensor1 = Ort::Value::CreateTensor<float>(memory_info, input1.data(), input1.size(), inputShape.data(), inputShape.size());
input1 = normalizeImg(img2);
ort_inputs.push_back(std::move(inputTensor1));
// ...the remaining 8 feature maps are handled the same way
Update: the large std::array caused a memory overflow (it lives on the stack), so I replaced it with std::vector; with a vector, the normalization has to happen before creating the tensor:
vector<float> input0 = normalizeImg_new(all_img[0]);
auto inputTensor0 = Ort::Value::CreateTensor<float>(memory_info, input0.data(), input0.size(), inputShape.data(), inputShape.size());
ort_inputs.push_back(std::move(inputTensor0));
where normalizeImg_new is:
vector<float> normalizeImg_new(cv::Mat& image) {
    vector<float> input;
    input.reserve(3 * image.rows * image.cols);
    for (int k = 0; k < 3; k++) {
        for (int i = 0; i < image.rows; i++) {
            for (int j = 0; j < image.cols; j++) {
                float v = image.at<cv::Vec3b>(i, j)[k] / 255.0f;
                if (k == 0)
                    input.push_back((v - 0.485f) / 0.229f);
                else if (k == 1)
                    input.push_back((v - 0.456f) / 0.224f);
                else
                    input.push_back((v - 0.406f) / 0.225f);
            }
        }
    }
    return input;
}
Pitfall 5: To verify that the feature maps fed to the network match what PyTorch receives, I read the values back out of vector<Ort::Value> ort_inputs and out of Ort::Value inputTensor0:
// Read values out of inputTensor0
const float* data = inputTensor0.GetTensorData<float>();
cout << "data: " << data[m] << endl;  // m is the flat index of the value
// Read values out of vector<Ort::Value> ort_inputs
cout << ort_inputs[5].At<float>({ 0, 0, 0, 0 }) << endl;  // the four dimensions are: batch, channel, h, w
Pitfall 6: Reading the output:
std::vector<const char*> output_node_names = { "output" };
auto output_tensors = session.Run(Ort::RunOptions{ nullptr }, input_node_names.data(), ort_inputs.data(), ort_inputs.size(), output_node_names.data(), 1);
float* floatarr = output_tensors.front().GetTensorMutableData<float>();
cout << "fin1: " << floatarr[0] << " fin2: " << floatarr[1] << endl;