class edge {
public:
edge(node* prev, const shape3d& shape, vector_type vtype)
: shape_(shape),
vtype_(vtype),
data_({vec_t(shape.size())}),
grad_({ vec_t(shape.size()) }),
prev_(prev)
{
}
void merge_grads(vec_t *dst) {
dst->resize(grad_[0].size());
std::fill(dst->begin(), dst->end(), static_cast<float_t>(0));
for (cnn_size_t sample = 0, sample_count = grad_.size(); sample < sample_count; ++sample) {
vectorize::reduce<float_t>(&grad_[sample][0], dst->size(), &(*dst)[0]);
}
}
void clear_grads() {
for (cnn_size_t sample = 0, sample_count = grad_.size(); sample < sample_count; ++sample) {
std::fill(grad_[sample].begin(), grad_[sample].end(), (float_t)0);
}
}
tensor_t* get_data() {
return &data_;
}
const tensor_t* get_data() const {
return &data_;
}
tensor_t* get_gradient() {
return &grad_;
}
const tensor_t* get_gradient() const {
return &grad_;
}
const std::vector<node*>& next() const { return next_; }
node* prev() { return prev_; }
const node* prev() const { return prev_; }
const shape3d& shape() const { return shape_; }
vector_type vtype() const { return vtype_; }
void add_next_node(node* next) { next_.push_back(next); }
private:
shape3d shape_;
vector_type vtype_;
tensor_t data_;
tensor_t grad_;
node* prev_;
std::vector<node*> next_;
};
/// Base class for a vertex of the computation graph. Owns (shared)
/// references to its incoming edges (prev_) and outgoing edges (next_).
class node : public std::enable_shared_from_this<node> {
public:
    /// @param in_size  number of input ports (prev_ is sized to this)
    /// @param out_size number of output ports (next_ is sized to this)
    node(cnn_size_t in_size, cnn_size_t out_size)
        : prev_(in_size), next_(out_size) {}
    virtual ~node() = default;

    const std::vector<edgeptr_t>& prev() const { return prev_; }
    const std::vector<edgeptr_t>& next() const { return next_; }

    /// Index of edge e among the input ports.
    /// Returns prev().size() if e is not connected as an input.
    cnn_size_t prev_port(const edge& e) const {
        return port_index(prev_, e);
    }

    /// Index of edge e among the output ports.
    /// Returns next().size() if e is not connected as an output.
    cnn_size_t next_port(const edge& e) const {
        return port_index(next_, e);
    }

    std::vector<node*> prev_nodes() const;
    std::vector<node*> next_nodes() const;
protected:
    node() = delete;
    friend void connect(layerptr_t head, layerptr_t tail,
                        cnn_size_t head_index, cnn_size_t tail_index);
    mutable std::vector<edgeptr_t> prev_;
    mutable std::vector<edgeptr_t> next_;
private:
    /// Position of the port holding edge e (compared by address), or
    /// ports.size() when absent. Takes each pointer by const reference so
    /// the scan does not touch the shared_ptr reference counts.
    static cnn_size_t port_index(const std::vector<edgeptr_t>& ports, const edge& e) {
        const auto it = std::find_if(ports.begin(), ports.end(),
            [&e](const edgeptr_t& ep) { return ep.get() == &e; });
        return static_cast<cnn_size_t>(std::distance(ports.begin(), it));
    }
};