NCNN-Option_Paramdict

NCNN库中的Option类用于设定轻量级模式、线程数等基本配置,而ParamDict类管理网络层的参数,如层类型、数值等。ParamDict通过键值对形式存储参数,支持int、float及数组类型的配置,并提供了加载与解析参数的方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

NCNN-Option_Paramdict

路径:src/Option.h 和 src/Option.cpp, src/paramdict.h 和 src/paramdict.cpp

Option是一个简单的配置类,Paramdic是网络层的配置参数,不同层有不同配置参数

源码

仅包含一些bool或int类型的变量,构造函数配置了默认选项
class NCNN_EXPORT Option {
public:
    Option();  // 默认的配置项
public:
    bool lightmode;	  // light mode
    int num_threads;  // thread count

    Allocator* blob_allocator;
    Allocator* workspace_allocator;
    
    int openmp_blocktime;
    bool use_winograd_convolution;
    bool use_sgemm_convolution;
    bool use_int8_inference;   enable quantized int8 inference
	...
};
class NCNN_EXPORT ParamDict {
public:

    ParamDict();
    virtual ~ParamDict();
    ParamDict(const ParamDict&);
    ParamDict& operator=(const ParamDict&);

    // get type
    int type(int id) const;
    // get int
    int get(int id, int def) const;
    // get float
    float get(int id, float def) const;
    // get array
    Mat get(int id, const Mat& def) const;

    // set int
    void set(int id, int i);
    // set float
    void set(int id, float f);
    // set array
    void set(int id, const Mat& v);

protected:
    friend class Net;
    void clear();
    int load_param(const DataReader& dr);
    int load_param_bin(const DataReader& dr);

private:
    ParamDictPrivate* const d;
};


class ParamDictPrivate {
public:
    // 0 = null
    // 1 = int/float
    // 2 = int
    // 3 = float
    // 4 = array of int/float
    // 5 = array of int
    // 6 = array of float
    struct {
        int type;
        union {
            int i;
            float f;
        };
        Mat v;
    } params[NCNN_MAX_PARAM_COUNT];  // 每个对象包括一个长度为NCNN_MAX_PARAM_COUNT的结构体数组
};

int ParamDict::type(int id) const {
	// 获取对象中第i个变量的类型
    return d->params[id].type;
}

get / set 函数也只是对对象中的第i个变量获取对应数据或设置


void ParamDict::clear() {
    for (int i = 0; i < NCNN_MAX_PARAM_COUNT; i++) {
    	// 遍历,模式设为null
        d->params[i].type = 0;
        d->params[i].v = Mat();
    }
}

// 判断字符串是不是浮点数
static bool vstr_is_float(const char vstr[16]) {
    // look ahead for determine isfloat
    for (int j = 0; j < 16; j++) {
        if (vstr[j] == '\0')
            break;

        if (vstr[j] == '.' || tolower(vstr[j]) == 'e')
        	// 判断字符数组中是否有.(浮点数)和e(科学计数法)
            return true;
    }

    return false;
}

// 字符串转浮点数
static float vstr_to_float(const char vstr[16]);

// 加载层参数
int ParamDict::load_param(const DataReader& dr) {
	clear();
	// 配置参数有两种
	// 一种是k=v的类型;另一种是k=len,v1,v2,v3….(数组类型)
	// key即id,结构体数组中索引
    int id = 0;
    while (dr.scan("%d=", &id) == 1) {
        bool is_array = id <= -23300;  // 以-23300为标志位
        if (is_array) {
            id = -id - 23300;
        }

        if (id >= NCNN_MAX_PARAM_COUNT) {
        	// 结构体数组的长度为NCNN_MAX_PARAM_COUNT,不能超过这个长度
            NCNN_LOGE("id < NCNN_MAX_PARAM_COUNT failed (id=%d, NCNN_MAX_PARAM_COUNT=%d)", id, NCNN_MAX_PARAM_COUNT);
            return -1;
        }

        if (is_array) {
        	// key=len, xx, xx, ...
            int len = 0;
            int nscan = dr.scan("%d", &len);
            if (nscan != 1) {
                NCNN_LOGE("ParamDict read array length failed");
                return -1;
            }

            d->params[id].v.create(len);

            for (int j = 0; j < len; j++) {
                char vstr[16];
                nscan = dr.scan(",%15[^,\n ]", vstr);
                if (nscan != 1) {
                    NCNN_LOGE("ParamDict read array element failed");
                    return -1;
                }

                bool is_float = vstr_is_float(vstr);

                if (is_float) {  // 是不是浮点数
                    float* ptr = d->params[id].v;
                    ptr[j] = vstr_to_float(vstr);
                }else {  // int
                    int* ptr = d->params[id].v;
                    nscan = sscanf(vstr, "%d", &ptr[j]);
                    if (nscan != 1) {
                        NCNN_LOGE("ParamDict parse array element failed");
                        return -1;
                    }
                }

                d->params[id].type = is_float ? 6 : 5;
            }
        }else {  
        	// key=value
            char vstr[16];
            int nscan = dr.scan("%15s", vstr);
            if (nscan != 1) {
                NCNN_LOGE("ParamDict read value failed");
                return -1;
            }

            bool is_float = vstr_is_float(vstr);

            if (is_float) {
                d->params[id].f = vstr_to_float(vstr);
            }else {
                nscan = sscanf(vstr, "%d", &d->params[id].i);
                if (nscan != 1) {
                    NCNN_LOGE("ParamDict parse value failed");
                    return -1;
                }
            }

            d->params[id].type = is_float ? 3 : 2;
        }		
    }
}

int ParamDict::load_param_bin(const DataReader& dr) {
    clear();

    //     binary 0
    //     binary 100
    //     binary 1
    //     binary 1.250000
    //     binary 3 | array_bit
    //     binary 5
    //     binary 0.1
    //     binary 0.2
    //     binary 0.4
    //     binary 0.8
    //     binary 1.0
    //     binary -233(EOP)

    int id = 0;
    size_t nread;
    nread = dr.read(&id, sizeof(int));
    if (nread != sizeof(int)) {
        NCNN_LOGE("ParamDict read id failed %zd", nread);
        return -1;
    }

    while (id != -233) {
    	// 默认都为float
        bool is_array = id <= -23300;
        if (is_array) {
            id = -id - 23300;
        }

        if (id >= NCNN_MAX_PARAM_COUNT) {
            NCNN_LOGE("id < NCNN_MAX_PARAM_COUNT failed (id=%d, NCNN_MAX_PARAM_COUNT=%d)", id, NCNN_MAX_PARAM_COUNT);
            return -1;
        }

        if (is_array) {
            int len = 0;
            nread = dr.read(&len, sizeof(int));
            if (nread != sizeof(int)) {
                NCNN_LOGE("ParamDict read array length failed %zd", nread);
                return -1;
            }

            d->params[id].v.create(len);

            float* ptr = d->params[id].v;
            nread = dr.read(ptr, sizeof(float) * len);
            if (nread != sizeof(float) * len) {
                NCNN_LOGE("ParamDict read array element failed %zd", nread);
                return -1;
            }

            d->params[id].type = 4;
        }else {
            nread = dr.read(&d->params[id].f, sizeof(float));
            if (nread != sizeof(float)) {
                NCNN_LOGE("ParamDict read value failed %zd", nread);
                return -1;
            }

            d->params[id].type = 1;
        }

        nread = dr.read(&id, sizeof(int));
        if (nread != sizeof(int)) {
            NCNN_LOGE("ParamDict read EOP failed %zd", nread);
            return -1;
        }
    }

    return 0;
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值