NCNN-Option_Paramdict
路径:src/option.h 和 src/option.cpp,src/paramdict.h 和 src/paramdict.cpp
Option 是一个简单的配置类;ParamDict 保存网络层的配置参数,不同类型的层有不同的参数。
源码
主要包含一些 bool / int 类型的开关以及内存分配器指针(Allocator*),构造函数中设置了默认选项。
// Option: plain configuration holder with public fields.
// The constructor (its definition is not part of this excerpt) fills in the
// default values.
class NCNN_EXPORT Option {
public:
    Option(); // sets the default options

public:
    bool lightmode;  // light mode
    int num_threads; // thread count

    Allocator* blob_allocator;
    Allocator* workspace_allocator;

    int openmp_blocktime;

    bool use_winograd_convolution;
    bool use_sgemm_convolution;
    bool use_int8_inference; // enable quantized int8 inference

    // ... (remaining options omitted from this excerpt)
};
// ParamDict: per-layer parameter table indexed by an integer id.
// Each id addresses one slot that carries a type tag plus either a scalar
// (int/float) or an array (Mat). The slot storage lives in a separate
// ParamDictPrivate object reached through the d pointer.
class NCNN_EXPORT ParamDict {
public:
ParamDict();
virtual ~ParamDict();
// copy constructor / copy assignment (definitions not shown in this excerpt)
ParamDict(const ParamDict&);
ParamDict& operator=(const ParamDict&);
// get the type tag of slot id (0 = null/unset; 1-6 described in ParamDictPrivate)
int type(int id) const;
// get int; NOTE(review): presumably returns def when slot id is unset — definition not shown
int get(int id, int def) const;
// get float; NOTE(review): presumably returns def when slot id is unset — definition not shown
float get(int id, float def) const;
// get array; NOTE(review): presumably returns def when slot id is unset — definition not shown
Mat get(int id, const Mat& def) const;
// set int scalar in slot id
void set(int id, int i);
// set float scalar in slot id
void set(int id, float f);
// set array in slot id
void set(int id, const Mat& v);
protected:
// NOTE(review): presumably friend so that Net can call clear()/load_param*()
// while loading a model — confirm against net.cpp
friend class Net;
// mark every slot as unset (type 0) and release array data
void clear();
// parse "id=value" / array items from the text param format; returns -1 on error
int load_param(const DataReader& dr);
// read id/value records from the binary param format until the -233 terminator;
// returns 0 on success, -1 on error
int load_param_bin(const DataReader& dr);
private:
// slot storage; const pointer, so it cannot be reseated after construction
ParamDictPrivate* const d;
};
// Backing storage for ParamDict: a fixed-size table of parameter slots.
class ParamDictPrivate {
public:
// Slot type tag:
// 0 = null
// 1 = int/float              (scalar from load_param_bin; int vs float not distinguished)
// 2 = int                    (scalar from load_param)
// 3 = float                  (scalar from load_param)
// 4 = array of int/float     (array from load_param_bin)
// 5 = array of int           (array from load_param)
// 6 = array of float         (array from load_param)
struct {
int type; // one of the type tags above
union {
int i; // scalar value when interpreted as int
float f; // scalar value when interpreted as float
};
Mat v; // array payload for types 4-6
} params[NCNN_MAX_PARAM_COUNT]; // one slot per param id; valid ids are 0 .. NCNN_MAX_PARAM_COUNT-1
};
int ParamDict::type(int id) const {
// 获取对象中第i个变量的类型
return d->params[id].type;
}
get / set 函数同样只是按 id 读取或写入 params[id] 中对应的数据。
// Reset the whole table: every slot is tagged as null (type 0) and any array
// payload is released. The scalar i/f fields are left as they are; the type
// tag alone is what marks a slot as empty.
void ParamDict::clear() {
    int slot = 0;
    while (slot < NCNN_MAX_PARAM_COUNT) {
        d->params[slot].v = Mat(); // drop array data
        d->params[slot].type = 0;  // 0 = null
        ++slot;
    }
}
// Decide whether a numeric token should be parsed as float rather than int.
//
// Looks at no more than 16 characters, stopping early at the NUL terminator.
// Returns true if the token contains '.' or an exponent marker 'e'/'E';
// otherwise the token is treated as an integer.
static bool vstr_is_float(const char vstr[16]) {
    for (int j = 0; j < 16; j++) {
        // tolower() is only defined for values representable as unsigned char
        // (or EOF); a negative plain char (e.g. a non-ASCII byte) would be UB.
        const unsigned char c = static_cast<unsigned char>(vstr[j]);
        if (c == '\0')
            break;
        // '.' => decimal notation, 'e'/'E' => scientific notation
        if (c == '.' || tolower(c) == 'e')
            return true;
    }
    return false;
}
// Convert a numeric token to float. load_param calls it for tokens that
// vstr_is_float classifies as float; only the declaration appears in this excerpt.
static float vstr_to_float(const char vstr[16]);
// Load layer parameters from the text param format.
//
// Input is a sequence of key=value items, e.g.
//   0=100 1=1.250000 -23303=5,0.1,0.2,0.4,0.8,1.0
// Two shapes are accepted:
//   id=value            scalar; stored as float (type 3) if the token contains
//                       '.' or 'e'/'E', otherwise as int (type 2)
//   K=len,v1,...,vlen   array, where K = -23300 - id; each element is stored
//                       as float or int by the same rule (type 6 or 5)
// id is the index into d->params and must lie in [0, NCNN_MAX_PARAM_COUNT).
//
// Parsing stops at the first input that does not match "%d=".
// Returns 0 on success and -1 on a read/parse error. Entries parsed before
// the error are left in place.
int ParamDict::load_param(const DataReader& dr) {
    clear();

    int id = 0;
    while (dr.scan("%d=", &id) == 1) {
        // Keys at or below -23300 mark arrays. Decode as -(id + 23300) rather
        // than -id - 23300 so that id == INT_MIN does not overflow on negation.
        const bool is_array = id <= -23300;
        if (is_array) {
            id = -(id + 23300);
        }

        // Negative ids that are not array keys would index params[] out of bounds.
        if (id < 0) {
            NCNN_LOGE("ParamDict invalid param id %d", id);
            return -1;
        }
        // params[] has exactly NCNN_MAX_PARAM_COUNT slots.
        if (id >= NCNN_MAX_PARAM_COUNT) {
            NCNN_LOGE("id < NCNN_MAX_PARAM_COUNT failed (id=%d, NCNN_MAX_PARAM_COUNT=%d)", id, NCNN_MAX_PARAM_COUNT);
            return -1;
        }

        if (is_array) {
            // key=len,v1,v2,...
            int len = 0;
            int nscan = dr.scan("%d", &len);
            if (nscan != 1) {
                NCNN_LOGE("ParamDict read array length failed");
                return -1;
            }
            // A negative element count is invalid; reject it before sizing the Mat.
            if (len < 0) {
                NCNN_LOGE("ParamDict invalid array length %d", len);
                return -1;
            }

            d->params[id].v.create(len);

            for (int j = 0; j < len; j++) {
                // each element is a comma-prefixed token of at most 15 chars
                char vstr[16];
                nscan = dr.scan(",%15[^,\n ]", vstr);
                if (nscan != 1) {
                    NCNN_LOGE("ParamDict read array element failed");
                    return -1;
                }

                const bool is_float = vstr_is_float(vstr);
                if (is_float) {
                    float* ptr = d->params[id].v;
                    ptr[j] = vstr_to_float(vstr);
                } else {
                    int* ptr = d->params[id].v;
                    nscan = sscanf(vstr, "%d", &ptr[j]);
                    if (nscan != 1) {
                        NCNN_LOGE("ParamDict parse array element failed");
                        return -1;
                    }
                }

                // NOTE(review): elements are classified one by one, so a mixed
                // int/float array stores both bit patterns while the slot type
                // reflects only the last element; len == 0 leaves the type at 0.
                d->params[id].type = is_float ? 6 : 5;
            }
        } else {
            // key=value
            char vstr[16];
            int nscan = dr.scan("%15s", vstr);
            if (nscan != 1) {
                NCNN_LOGE("ParamDict read value failed");
                return -1;
            }

            const bool is_float = vstr_is_float(vstr);
            if (is_float) {
                d->params[id].f = vstr_to_float(vstr);
            } else {
                nscan = sscanf(vstr, "%d", &d->params[id].i);
                if (nscan != 1) {
                    NCNN_LOGE("ParamDict parse value failed");
                    return -1;
                }
            }

            d->params[id].type = is_float ? 3 : 2;
        }
    }

    // Previously missing: falling off the end of a non-void function is UB.
    return 0;
}
// Load layer parameters from the binary param format.
//
// The stream is a sequence of int32 keys. Each key is followed by its
// payload, and the stream is terminated by the key -233 (EOP).
// Example, one 4-byte word per line:
//   0                      key 0
//   100                    scalar word
//   1                      key 1
//   1.250000               scalar word
//   -23303                 array key for id 3 (-23300 - 3)
//   5                      array length
//   0.1 0.2 0.4 0.8 1.0    five array words
//   -233                   EOP
// A scalar is read as one raw 4-byte word into the i/f union and tagged
// type 1 (int/float, not distinguished). An array is read as len raw 4-byte
// words and tagged type 4.
//
// Returns 0 on success, and -1 on a short read or an invalid id/length.
int ParamDict::load_param_bin(const DataReader& dr) {
    clear();

    int id = 0;
    size_t nread = dr.read(&id, sizeof(int));
    if (nread != sizeof(int)) {
        NCNN_LOGE("ParamDict read id failed %zd", nread);
        return -1;
    }

    while (id != -233) {
        // Keys at or below -23300 mark arrays. Decode as -(id + 23300) so that
        // id == INT_MIN does not overflow on negation.
        const bool is_array = id <= -23300;
        if (is_array) {
            id = -(id + 23300);
        }

        // Negative ids other than EOP / array keys would index params[] out of bounds.
        if (id < 0) {
            NCNN_LOGE("ParamDict invalid param id %d", id);
            return -1;
        }
        if (id >= NCNN_MAX_PARAM_COUNT) {
            NCNN_LOGE("id < NCNN_MAX_PARAM_COUNT failed (id=%d, NCNN_MAX_PARAM_COUNT=%d)", id, NCNN_MAX_PARAM_COUNT);
            return -1;
        }

        if (is_array) {
            int len = 0;
            nread = dr.read(&len, sizeof(int));
            if (nread != sizeof(int)) {
                NCNN_LOGE("ParamDict read array length failed %zd", nread);
                return -1;
            }
            // A negative len would wrap sizeof(float) * len to a huge size_t.
            if (len < 0) {
                NCNN_LOGE("ParamDict invalid array length %d", len);
                return -1;
            }

            d->params[id].v.create(len);

            float* ptr = d->params[id].v;
            nread = dr.read(ptr, sizeof(float) * len);
            if (nread != sizeof(float) * len) {
                NCNN_LOGE("ParamDict read array element failed %zd", nread);
                return -1;
            }

            d->params[id].type = 4;
        } else {
            // raw 4-byte word; int vs float is not recorded (type 1)
            nread = dr.read(&d->params[id].f, sizeof(float));
            if (nread != sizeof(float)) {
                NCNN_LOGE("ParamDict read value failed %zd", nread);
                return -1;
            }

            d->params[id].type = 1;
        }

        // next key, or the -233 terminator
        nread = dr.read(&id, sizeof(int));
        if (nread != sizeof(int)) {
            NCNN_LOGE("ParamDict read EOP failed %zd", nread);
            return -1;
        }
    }

    return 0;
}