9. Creating and Using the Abstract Base Class num_sequence

This article explores in detail how to create and use the abstract base class num_sequence. Working through the example code in num_sequence.h and main.cc, it explains the key role an abstract base class plays in program design and how to use one.


num_sequence.h

#include <iostream>
#include <vector>

using namespace std;

class num_sequence
{
    public:
        virtual ~num_sequence(){};                          // the semicolon after the inline body is redundant but legal

        virtual int         elem( int pos ) const = 0;
        virtual const char *what_am_i() const = 0;
        static int          max_elems() { return _max_elems; }       // no semicolon needed after the braces of an inline definition
        virtual ostream&    print( ostream &os = cout ) const = 0;   // pure virtual; the default argument belongs here, in the declaration
     
        //virtual const char*         test(){ return "num_seq test"; }
    protected:
        virtual void        gen_elems( int pos ) const = 0;
        bool                check_integrity( int pos, int size) const;

        const static int    _max_elems = 1024;                 // const added so the static member can be initialized in-class
};
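
// Because elem, what_am_i, print and gen_elems are all pure virtual (= 0),
// num_sequence is an abstract class: it defines an interface but cannot be
// instantiated on its own. Illustrative lines, not part of the original header:
//
//     num_sequence ns;                    // error: num_sequence is abstract
//     num_sequence *ps = new Fibonacci;   // OK: base pointer to a concrete derived object
//     cout << num_sequence::max_elems();  // OK: a static member function needs no object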

class Fibonacci : public num_sequence
{
    public:
        Fibonacci( int len = 1, int beg = 1 ) 
            : _len( len ), _beg_pos( beg ){} 

        int                 elem( int pos) const;
        virtual const char  *what_am_i() const { return "Fibonacci"; }
        ostream&            print( ostream &os = cout ) const;

        int                 length(){ return _len; }
        int                 beg_pos(){ return _beg_pos; }

        //const char*         test(){ return "fib_seq test"; }
    protected:
        void                gen_elems( int pos ) const;

    private:
        int                 _len;
        int                 _beg_pos;
        static vector<int>  _elems;

};      // a class definition must end with a semicolon
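
// Fibonacci defines all four of the pure virtual functions, so it is a concrete
// class. Repeating "virtual" in the derived class (as on what_am_i) is optional;
// elem and print override the base versions just the same. In C++11 and later
// one would also add the override specifier, e.g. "int elem( int pos ) const override;".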

//vector<int> Fibonacci::_elems;  // _elems could be defined here
// or in main.cc; without exactly one such definition the linker reports _elems as undefined

// operator<<
ostream& operator<<( ostream &os, const num_sequence &ns)
{
    return ns.print(os);
}
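
// Because print() is virtual, this non-member operator<< reaches the derived
// class's implementation even through a base-class pointer or reference.
// An illustrative use, not in the original post:
//
//     num_sequence *ps = new Fibonacci( 8 );
//     cout << *ps << endl;   // dispatches to Fibonacci::print
//     delete ps;             // the virtual destructor ensures proper cleanup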

// check_integrity: this first version only roughly validates the position range; it is not precise
#if 0
inline bool num_sequence::check_integrity ( int pos ) const
{
    if (pos <= 0 || pos > _max_elems)
    {
        cerr << "invalid position: " << pos << endl;
        return false;
    }
    return true;
}
#endif
#if 0
// A more specific integrity check, but still imperfect: check_integrity is not
// virtual in the base class, so a call through a pointer or reference to the base
// class invokes only the base-class version. Declaring it virtual in the base class
// is imperfect as well, because an implementation written against incomplete
// information is itself likely to be incomplete; the details the check needs are
// type-dependent and belong to each derived class.
inline bool Fibonacci::
check_integrity ( int pos ) const
{
    if (! num_sequence::check_integrity(pos))  // the class scope operator is required; an unqualified call would invoke Fibonacci's own version recursively
        return false;

    if (pos > _elems.size())
        Fibonacci::gen_elems(pos);

    return true;
}
#else
// Therefore check_integrity is rewritten as a member of num_sequence
// that takes two parameters
inline bool num_sequence::
check_integrity(int pos, int size) const
{
    if (pos <= 0 || pos > _max_elems)
    {
        cerr << "invalid position: " << pos << endl;
        return false;
    }
    if (pos > size)
        gen_elems(pos);  // dispatched through the virtual mechanism; with generation folded in here, print() and elem() no longer need to call gen_elems themselves

    return true;
}
#endif
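
// With the two-parameter version, every derived class reuses the same check by
// passing in its own container size, and gen_elems() is dispatched virtually.
// A hypothetical second sequence class (Triangular is illustrative only, not
// defined in this post) would follow the identical pattern:
//
//     int Triangular::elem( int pos ) const
//     {
//         if ( !check_integrity( pos, _elems.size() ) )
//             return 0;
//         return _elems[pos - 1];
//     }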

#if 1
//ostream& Fibonacci::print( ostream &os = cout ) const   // error here: a default argument may appear only in the declaration, not in the definition
ostream& Fibonacci::
print( ostream &os) const
{
    unsigned int elem_pos = _beg_pos - 1;  // _beg_pos defaults to 1
    unsigned int end_pos = elem_pos + _len;

    if (end_pos > _elems.size())
        Fibonacci::gen_elems(end_pos);  // the :: scope operator resolves the call at compile time rather than run time
    while (elem_pos < end_pos)
        os << _elems[elem_pos++] << ' ';

    return os;
}
#else
ostream& Fibonacci::
print( ostream &os) const
{
    unsigned int elem_pos = _beg_pos - 1;  //_beg_pos = 1 initialize
    unsigned int end_pos = elem_pos + _len;
    
    if (!check_integrity(end_pos, _elems.size()))
        return os;  // on failure, return the stream unchanged (one reasonable choice)
    while (elem_pos < end_pos)
        os << _elems[elem_pos++] << ' ';

    return os;
}
#endif

#if 0
int Fibonacci::elem( int pos ) const
{
    if ( !check_integrity(pos) )
        return 0;
    if ( (unsigned int)pos > _elems.size())  // this check was later moved into check_integrity, so each derived class no longer has to repeat it
    {
        Fibonacci::gen_elems(pos);
    }
    return _elems[pos - 1];
}
#else
int Fibonacci::elem( int pos ) const
{
    if ( !check_integrity(pos, _elems.size()) )
        return 0;

    return _elems[pos - 1];
}
#endif

void Fibonacci::gen_elems( int pos ) const
{
    if ( _elems.empty() )
    {
        _elems.push_back(1);
        _elems.push_back(1);
    }
    if ( _elems.size() <= (unsigned int)pos )
    {
        int ix = _elems.size();
        int n_2 = _elems[ix - 2];
        int n_1 = _elems[ix - 1];

        for ( ; ix <= pos; ++ix )
        {
            int elem = n_2 + n_1;
            _elems.push_back(elem);
            n_2 = n_1; n_1 = elem;
        }
    }
}
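
// A quick trace of gen_elems(): called as gen_elems(6) on an empty vector, the
// two seed push_backs give { 1, 1 } and the loop runs ix = 2..6, leaving _elems
// as { 1, 1, 2, 3, 5, 8, 13 }. The vector thus ends with pos + 1 elements,
// enough for every 1-based position up to pos.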

main.cc

#include "num_sequence.h"



vector<int> Fibonacci::_elems;  // definition of the static data member (see the note in the header)

int main()
{
    //num_sequence *ps = new Fibonacci;  // a base-class pointer may be bound to a derived-class object only once the derived class has defined every pure virtual function
    //num_sequence *ps2 = new Fibonacci( 2, 2);
    
    cout << "init main.\n";
    
    Fibonacci fib;
    cout << "fib: beginning at element 1 for 1 element: "
         << fib << endl;    // fib 由于重载了 operator<< 运算发,并且调用print函数

    Fibonacci fib2(10);
    cout << "fib: beginning at element 1 for 10 element: "
         << fib2 << endl;    // fib 由于重载了 operator<< 运算发,并且调用print函数

    Fibonacci fib3(3, 8);  // _len = 3, _beg_pos = 8
    cout << "fib: beginning at element 8 for 3 elements: "
         << fib3 << endl;
    return 0;
}
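
Tracing the two files by hand, compiling and running them should print output along these lines:

init main.
fib: beginning at element 1 for 1 element: 1
fib: beginning at element 1 for 10 elements: 1 1 2 3 5 8 13 21 34 55
fib: beginning at element 8 for 3 elements: 21 34 55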
