tiny-dnn源码Edge类

本文介绍了一个深度学习框架中的边缘类实现细节,包括构造函数、梯度合并及清除等功能,并展示了节点类的设计,用于连接不同边缘。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

class edge {
 public:
    edge(node* prev, const shape3d& shape, vector_type vtype)
        : shape_(shape),
          vtype_(vtype),
          data_({vec_t(shape.size())}),
          grad_({ vec_t(shape.size()) }),
          prev_(prev)
    {
    }

    void merge_grads(vec_t *dst) {
        dst->resize(grad_[0].size());
        std::fill(dst->begin(), dst->end(), static_cast<float_t>(0));

        // @todo consider adding parallelism
        for (cnn_size_t sample = 0, sample_count = grad_.size(); sample < sample_count; ++sample) {
            vectorize::reduce<float_t>(&grad_[sample][0], dst->size(), &(*dst)[0]);
        }
    }

    void clear_grads() {
        for (cnn_size_t sample = 0, sample_count = grad_.size(); sample < sample_count; ++sample) {
            std::fill(grad_[sample].begin(), grad_[sample].end(), (float_t)0);
        }
    }

    tensor_t* get_data() {
        return &data_;
    }

    const tensor_t* get_data() const {
        return &data_;
    }

    tensor_t* get_gradient() {
        return &grad_;
    }

    const tensor_t* get_gradient() const {
        return &grad_;
    }

    const std::vector<node*>& next() const { return next_; }
    node* prev() { return prev_; }
    const node* prev() const { return prev_; }

    const shape3d& shape() const { return shape_; }
    vector_type vtype() const { return vtype_; }
    void add_next_node(node* next) { next_.push_back(next); }

 private:
    shape3d shape_;
    vector_type vtype_;
    tensor_t data_;
    tensor_t grad_;
    node* prev_;               // previous node, "producer" of this tensor
    std::vector<node*> next_;  // next nodes, "consumers" of this tensor
};
class node : public std::enable_shared_from_this<node> {
public:
    node(cnn_size_t in_size, cnn_size_t out_size)
        : prev_(in_size), next_(out_size) {}
    virtual ~node() {}

    const std::vector<edgeptr_t>& prev() const { return prev_; }
    const std::vector<edgeptr_t>& next() const { return next_; }

    cnn_size_t prev_port(const edge& e) const {
        auto it = std::find_if(prev_.begin(), prev_.end(),
                               [&](edgeptr_t ep) { return ep.get() == &e; });
        return (cnn_size_t)std::distance(prev_.begin(), it);
    }

    cnn_size_t next_port(const edge& e) const {
        auto it = std::find_if(next_.begin(), next_.end(),
                               [&](edgeptr_t ep) { return ep.get() == &e; });
        return (cnn_size_t)std::distance(next_.begin(), it);
    }

    std::vector<node*> prev_nodes() const; // @todo refactor and remove this method
    std::vector<node*> next_nodes() const; // @todo refactor and remove this method
 protected:
    node() = delete;

    friend void connect(layerptr_t head, layerptr_t tail,
                        cnn_size_t head_index, cnn_size_t tail_index);

    mutable std::vector<edgeptr_t> prev_;
    mutable std::vector<edgeptr_t> next_;
};
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值