Paper Reading: Spatial Transformer Networks (with code explanation)

Reposted from: http://blog.youkuaiyun.com/l691899397/article/details/53641485


Original paper: Spatial Transformer Networks

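Preface: Convolutional neural networks (CNNs) can already be used to build powerful classification models, but they still lack the ability to cope with spatial transformations of the input data, such as translation, scaling and rotation. Convolution and pooling layers handle translation and scaling to some extent, but they are often helpless against rotation or large changes in scale. The paper proposes a model called the Spatial Transformer Network (STN), which learns the transformation parameters automatically and uses them to transform the image (or feature map) coming from the previous layer. In this way it adaptively applies these transformations to a certain degree and achieves invariance to them. When the inputs differ strongly in their spatial transformations, the STN can "rectify" them, which in turn improves classification accuracy.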

Prerequisites:

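1. 2D affine transformations of images

A 2D affine transformation of an image covers translation, scaling, rotation and similar transformations. Each of them can be expressed with a single 2×3 transformation matrix acting on the homogeneous coordinates (x, y, 1).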

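(1) Translation

  Translation performs:
    x' = x + T_x
    y' = y + T_y

  In matrix form:
$$\begin{bmatrix} x' \\ y' \end{bmatrix} = \begin{bmatrix} 1 & 0 & T_x \\ 0 & 1 & T_y \end{bmatrix} \begin{bmatrix} x \\ y \\ 1 \end{bmatrix}$$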

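(2) Scaling

   x' = x · S_x
   y' = y · S_y

  In matrix form:
$$\begin{bmatrix} x' \\ y' \end{bmatrix} = \begin{bmatrix} S_x & 0 & 0 \\ 0 & S_y & 0 \end{bmatrix} \begin{bmatrix} x \\ y \\ 1 \end{bmatrix}$$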

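(3) Rotation

  Rotation by an angle α, in matrix form:

$$\begin{bmatrix} x' \\ y' \end{bmatrix} = \begin{bmatrix} \cos\alpha & -\sin\alpha & 0 \\ \sin\alpha & \cos\alpha & 0 \end{bmatrix} \begin{bmatrix} x \\ y \\ 1 \end{bmatrix}$$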

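  Because the origin of the image coordinate system is at the top-left corner, the coordinates are first normalized to [-1, 1]. The origin then lies at the centre of the image, so the rotation is about the image centre; without this normalization, the image would be rotated about its top-left corner.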
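To make the three matrices concrete, here is a small C++ sketch. It is our own illustration, not part of the Caffe layer analysed later, and the names `Affine`, `translation`, `scaling`, `rotation` and `apply` are hypothetical. It builds each 2×3 matrix and applies it to (x, y, 1). Rotating the normalized centre (0, 0) leaves it fixed, whereas rotating the pixel-space centre (50, 50) of a 100×100 image whose origin is the top-left corner moves it, which is why the normalization matters.

```cpp
#include <array>
#include <cmath>

// Hypothetical helper type: a 2x3 affine matrix [a b c; d e f]
// acting on the homogeneous point (x, y, 1).
using Affine = std::array<std::array<double, 3>, 2>;

// Translation by (tx, ty).
Affine translation(double tx, double ty) {
    return {{{1, 0, tx}, {0, 1, ty}}};
}

// Scaling by sx along x and sy along y.
Affine scaling(double sx, double sy) {
    return {{{sx, 0, 0}, {0, sy, 0}}};
}

// Rotation by angle alpha (radians) about the origin (0, 0).
Affine rotation(double alpha) {
    double c = std::cos(alpha), s = std::sin(alpha);
    return {{{c, -s, 0}, {s, c, 0}}};
}

// Returns A * (x, y, 1)^T.
std::array<double, 2> apply(const Affine& A, double x, double y) {
    return {A[0][0] * x + A[0][1] * y + A[0][2],
            A[1][0] * x + A[1][1] * y + A[1][2]};
}
```

For example, `apply(rotation(std::acos(-1.0) / 2), 50, 50)` gives approximately (-50, 50): the point is swung around the corner origin instead of staying in place.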

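2. Bilinear interpolation

  When part of the input image is mapped onto the output image, a point of the output image V does not necessarily correspond to an integer (pixel) position in the input image U. For example, if a 5×5 region of the input corresponds to a 10×10 region of the output, many output points map to non-integer input coordinates, and their pixel values have to be filled in by interpolation.
  The paper discusses two ways to fill in these pixels (an integer sampling kernel and a bilinear sampling kernel); the code uses bilinear interpolation.
  In the figure below, the middle point P is the point to be interpolated, and f(x) denotes the pixel value at point x.
[Figure: the point P = (x, y) to be interpolated, surrounded by its four nearest integer pixels]
  For pixels on a unit grid, the interpolated value is
$$f(P) = \sum_{Q} f(Q)\,\bigl(1 - |x - x_Q|\bigr)\bigl(1 - |y - y_Q|\bigr)$$
  where Q = (x_Q, y_Q) runs over the four integer pixels surrounding P.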
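A minimal C++ sketch of bilinear interpolation on a row-major, single-channel image follows. Assumptions: `bilinear` is a hypothetical name, x is the row and y the column coordinate (the same convention the Caffe code below uses), and neighbours that fall outside the image contribute 0.

```cpp
#include <cmath>
#include <vector>

// Bilinear interpolation of a row-major H x W image at the real-valued
// location (x, y), where x is the row coordinate and y the column coordinate.
// Neighbours that fall outside the image contribute 0 (zero padding).
double bilinear(const std::vector<double>& img, int H, int W, double x, double y) {
    int x0 = static_cast<int>(std::floor(x));
    int y0 = static_cast<int>(std::floor(y));
    double res = 0.0;
    for (int m = x0; m <= x0 + 1; ++m) {
        for (int n = y0; n <= y0 + 1; ++n) {
            if (m < 0 || m >= H || n < 0 || n >= W) continue;
            // Each weight is the area of the sub-rectangle opposite to (m, n).
            double w = (1.0 - std::fabs(x - m)) * (1.0 - std::fabs(y - n));
            res += w * img[m * W + n];
        }
    }
    return res;
}
```

On the 2×2 image {0, 1, 2, 3} (value = 2·row + col), `bilinear(img, 2, 2, 0.5, 0.5)` returns 1.5, the average of the four pixels, because bilinear interpolation reproduces any function that is linear in x and y.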

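The STN in detail

1. Forward pass:

Network structure:
[Figure: the spatial transformer module from the paper: localisation net, grid generator and sampler]
  This figure, taken from the paper, shows the network structure. The STN forms a branch of the main network, and the STN branch consists of three parts: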

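1. Localisation Network
  This part uses a sub-network to generate the transformation parameters θ, i.e. the 2×3 transformation matrix. The sub-network is mainly made up of fully connected or convolutional layers.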
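The post does not specify the localisation network beyond "fully connected or convolutional layers", so the following is only an assumed minimal example: a single fully connected layer (`LocalisationFC`, a hypothetical name) that regresses the 6 entries of θ from a flattened feature vector. Initialising the weights to zero and the bias to the identity transform is a common choice in STN implementations, stated here as an assumption; it makes the module start out as a no-op.

```cpp
#include <array>
#include <cstddef>
#include <vector>

// Hypothetical minimal localisation head: one fully connected layer that maps
// a flattened feature vector to the 6 entries of theta (row-major 2x3 matrix).
struct LocalisationFC {
    std::vector<std::array<double, 6>> weights;  // one row of 6 weights per input feature
    std::array<double, 6> bias;

    // Zero weights plus an identity bias: before training, theta is the identity transform.
    explicit LocalisationFC(std::size_t in_features)
        : weights(in_features, std::array<double, 6>{}),
          bias{{1, 0, 0, 0, 1, 0}} {}

    // theta_k = bias_k + sum_i features_i * weights_i,k
    // (only the first min(features.size(), weights.size()) features are used).
    std::array<double, 6> forward(const std::vector<double>& features) const {
        std::array<double, 6> theta = bias;
        for (std::size_t i = 0; i < features.size() && i < weights.size(); ++i)
            for (int k = 0; k < 6; ++k)
                theta[k] += features[i] * weights[i][k];
        return theta;
    }
};
```

In a real network the weights would be learned by backpropagating ∂L/∂θ (derived below) into this layer.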

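2. Parameterised Sampling Grid
  This part computes, for every pixel of the output, the coordinate in the input to sample from. Let (x_i^s, y_i^s) be the coordinates of the pixels of the input U (the source) and (x_i^t, y_i^t) those of the pixels of the output V (the target), and let the spatial transformation T_θ be an affine transformation. Then (x_i^s, y_i^s) and (x_i^t, y_i^t) are related by:
$$\begin{pmatrix} x_i^s \\ y_i^s \end{pmatrix} = \mathcal{T}_\theta \begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix} = \begin{bmatrix} \theta_{11} & \theta_{12} & \theta_{13} \\ \theta_{21} & \theta_{22} & \theta_{23} \end{bmatrix} \begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix}$$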

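  Note that the paper expresses the input-image coordinates as a function of the output-image coordinates, (x_i^s, y_i^s) = θ · (x_i^t, y_i^t, 1). In other words, for each point in the output image V it finds the corresponding coordinate in the input image U, which is exactly the inverse of the affine transformations described earlier.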
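A sketch of the grid generator under these conventions (`Point` and `sampling_grid` are hypothetical names; θ is stored row-major as {θ11, θ12, θ13, θ21, θ22, θ23}). With θ = 0.5·I the output corners (±1, ±1) sample the input at (±0.5, ±0.5), so the output shows only the central region of the input, enlarged. Halving θ zooms in rather than shrinking the image, which illustrates the inverse mapping.

```cpp
#include <array>
#include <vector>

struct Point { double x, y; };

// Grid generator: for each target (output) location, compute the source
// (input) location to sample from: (x_s, y_s)^T = theta * (x_t, y_t, 1)^T.
// theta is row-major: {t11, t12, t13, t21, t22, t23}.
std::vector<Point> sampling_grid(const std::array<double, 6>& theta,
                                 const std::vector<Point>& target) {
    std::vector<Point> source;
    source.reserve(target.size());
    for (const Point& p : target) {
        source.push_back({theta[0] * p.x + theta[1] * p.y + theta[2],
                          theta[3] * p.x + theta[4] * p.y + theta[5]});
    }
    return source;
}
```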

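3. Differentiable Image Sampling

  Finally, the input image is mapped onto the output image with an interpolation method. In the general form below, k is the sampling kernel; different kernels give different interpolation methods. U_{nm}^c is the value at position (n, m) of channel c of the input, and V_i^c is the value of output pixel i in channel c:
$$V_i^c = \sum_{n=1}^{H}\sum_{m=1}^{W} U_{nm}^c\, k(x_i^s - m; \Phi_x)\, k(y_i^s - n; \Phi_y)$$
  With bilinear interpolation the formula becomes
$$V_i^c = \sum_{n=1}^{H}\sum_{m=1}^{W} U_{nm}^c \max(0, 1 - |x_i^s - m|)\max(0, 1 - |y_i^s - n|)$$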
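The kernel form above can be sketched directly in C++ (`sample`, `bilinear_kernel` and `nearest_kernel` are hypothetical names; x is treated as the row and y as the column coordinate). The nearest-neighbour kernel, δ(⌊x + 0.5⌋ − m), is our reading of the paper's integer sampling kernel. Although the double loop visits every input pixel, the bilinear kernel is nonzero for at most four of them, which is why practical implementations only look at the four neighbours.

```cpp
#include <cmath>
#include <functional>
#include <vector>

// Generic sampler in kernel form: every input pixel (m, n) contributes
// U[m][n] * k(x, m) * k(y, n). x is the row coordinate, y the column coordinate.
double sample(const std::vector<double>& U, int H, int W, double x, double y,
              const std::function<double(double, double)>& k) {
    double v = 0.0;
    for (int m = 0; m < H; ++m)
        for (int n = 0; n < W; ++n)
            v += U[m * W + n] * k(x, m) * k(y, n);
    return v;
}

// Bilinear kernel: max(0, 1 - |x - m|).
double bilinear_kernel(double x, double m) {
    return std::fmax(0.0, 1.0 - std::fabs(x - m));
}

// Integer (nearest-neighbour) kernel: 1 if m is the rounded coordinate, else 0.
double nearest_kernel(double x, double m) {
    return std::floor(x + 0.5) == m ? 1.0 : 0.0;
}
```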

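2. Backward pass:

  The gradient arriving at this layer is ∂L/∂V, where L is the loss. We need the derivative of the loss with respect to the input U and with respect to the affine matrix θ.
  The bilinear interpolation formula from above, which the derivation below relies on, is:

$$V_i^c = \sum_{n=1}^{H}\sum_{m=1}^{W} U_{nm}^c \max(0, 1 - |x_i^s - m|)\max(0, 1 - |y_i^s - n|)$$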

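1. Derivative of the loss with respect to U

  Differentiating the bilinear interpolation formula with respect to U gives:
$$\frac{\partial V_i^c}{\partial U_{nm}^c} = \max(0, 1 - |x_i^s - m|)\max(0, 1 - |y_i^s - n|)$$

  This derivative is exactly the bilinear interpolation coefficient. The chain rule then gives the derivative of the loss with respect to U by summing over all output pixels i:
$$\frac{\partial L}{\partial U_{nm}^c} = \sum_i \frac{\partial L}{\partial V_i^c}\,\frac{\partial V_i^c}{\partial U_{nm}^c}$$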
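In code, this chain rule becomes a scatter: each output pixel pushes its upstream gradient back to the input pixels it was interpolated from. A hedged C++ sketch follows; the function name is hypothetical, and x is the row and y the column coordinate, as in the Caffe code.

```cpp
#include <cmath>
#include <vector>

// dV/dU for bilinear sampling: the upstream gradient dV of one output pixel is
// added to the (at most) four input pixels around its sampling location (x, y),
// each weighted by its bilinear coefficient. dU is a row-major H x W buffer
// that accumulates the contributions of all output pixels.
void scatter_grad_to_input(double dV, double x, double y, int H, int W,
                           std::vector<double>& dU) {
    int x0 = static_cast<int>(std::floor(x));
    int y0 = static_cast<int>(std::floor(y));
    for (int m = x0; m <= x0 + 1; ++m) {
        for (int n = y0; n <= y0 + 1; ++n) {
            if (m < 0 || m >= H || n < 0 || n >= W) continue;  // outside U: no gradient
            double w = (1.0 - std::fabs(x - m)) * (1.0 - std::fabs(y - n));
            dU[m * W + n] += w * dV;
        }
    }
}
```

Because dU is accumulated with `+=`, an input pixel used by several output pixels receives the sum of their contributions, which is the Σ_i in the formula.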

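2. Derivative of the loss with respect to the affine matrix θ
  By the chain rule,
$$\frac{\partial L}{\partial \theta} = \sum_i \sum_c \frac{\partial L}{\partial V_i^c}\left(\frac{\partial V_i^c}{\partial x_i^s}\frac{\partial x_i^s}{\partial \theta} + \frac{\partial V_i^c}{\partial y_i^s}\frac{\partial y_i^s}{\partial \theta}\right)$$
  so we need ∂V/∂x and ∂x/∂θ (and likewise for y).
  The paper gives:
$$\frac{\partial V_i^c}{\partial x_i^s} = \sum_{n=1}^{H}\sum_{m=1}^{W} U_{nm}^c \max(0, 1 - |y_i^s - n|)\begin{cases} 0 & \text{if } |m - x_i^s| \ge 1 \\ 1 & \text{if } m \ge x_i^s \\ -1 & \text{if } m < x_i^s \end{cases}$$

  So, just as in bilinear interpolation itself, only the four neighbouring points need to be evaluated; all other terms are 0. ∂V/∂y is obtained in the same way.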
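Because this piecewise derivative is easy to get wrong (the sign and the |m − x| ≥ 1 case), a quick way to convince oneself is to compare it with a central finite difference. In this hedged C++ sketch (hypothetical names `sample` and `dV_dx`; x is the row and y the column coordinate), the analytic ∂V/∂x matches the numerical estimate at a non-integer point. At integer coordinates the function has a kink, and the formula picks the one-sided derivative.

```cpp
#include <cmath>
#include <vector>

// Bilinear sample of a row-major H x W image at (x = row, y = column), zero outside.
double sample(const std::vector<double>& U, int H, int W, double x, double y) {
    int x0 = static_cast<int>(std::floor(x)), y0 = static_cast<int>(std::floor(y));
    double v = 0.0;
    for (int m = x0; m <= x0 + 1; ++m)
        for (int n = y0; n <= y0 + 1; ++n)
            if (m >= 0 && m < H && n >= 0 && n < W)
                v += U[m * W + n] * std::fmax(0.0, 1.0 - std::fabs(x - m))
                                  * std::fmax(0.0, 1.0 - std::fabs(y - n));
    return v;
}

// Analytic dV/dx following the piecewise formula: U[m][n] * max(0, 1 - |y - n|)
// times +1 if m >= x, -1 if m < x, and 0 if |m - x| >= 1.
double dV_dx(const std::vector<double>& U, int H, int W, double x, double y) {
    int x0 = static_cast<int>(std::floor(x)), y0 = static_cast<int>(std::floor(y));
    double g = 0.0;
    for (int m = x0; m <= x0 + 1; ++m)
        for (int n = y0; n <= y0 + 1; ++n) {
            if (m < 0 || m >= H || n < 0 || n >= W) continue;
            if (std::fabs(m - x) >= 1.0) continue;
            double sign = (m >= x) ? 1.0 : -1.0;
            g += sign * U[m * W + n] * std::fmax(0.0, 1.0 - std::fabs(y - n));
        }
    return g;
}
```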

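  For ∂x/∂θ and ∂y/∂θ, start from the sampling-grid formula:
$$x_i^s = \theta_{11} x_i^t + \theta_{12} y_i^t + \theta_{13}, \qquad y_i^s = \theta_{21} x_i^t + \theta_{22} y_i^t + \theta_{23}$$
  from which it follows immediately that
$$\frac{\partial x_i^s}{\partial \theta_{11}} = x_i^t,\quad \frac{\partial x_i^s}{\partial \theta_{12}} = y_i^t,\quad \frac{\partial x_i^s}{\partial \theta_{13}} = 1,\quad \frac{\partial y_i^s}{\partial \theta_{21}} = x_i^t,\quad \frac{\partial y_i^s}{\partial \theta_{22}} = y_i^t,\quad \frac{\partial y_i^s}{\partial \theta_{23}} = 1$$
  while x_i^s does not depend on θ_21, θ_22, θ_23 (nor y_i^s on θ_11, θ_12, θ_13). Applying the chain rule then gives ∂V/∂θ.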
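Putting this last step into code: given ∂L/∂x_i^s and ∂L/∂y_i^s for every output location (already summed over channels), ∂L/∂θ is a sum of these gradients weighted by the target coordinates. The sketch below uses hypothetical names (`GridPoint`, `grad_theta`) and mirrors the six dTheta updates in Backward_cpu further down.

```cpp
#include <array>
#include <cstddef>
#include <vector>

struct GridPoint { double xt, yt; };  // target (output) coordinate

// Given dL/dx_s and dL/dy_s for every output location, accumulate dL/dtheta using
// x_s = t11*x_t + t12*y_t + t13 and y_s = t21*x_t + t22*y_t + t23.
std::array<double, 6> grad_theta(const std::vector<GridPoint>& grid,
                                 const std::vector<double>& dL_dxs,
                                 const std::vector<double>& dL_dys) {
    std::array<double, 6> g{};
    for (std::size_t i = 0; i < grid.size(); ++i) {
        g[0] += dL_dxs[i] * grid[i].xt;  // dx_s/dt11 = x_t
        g[1] += dL_dxs[i] * grid[i].yt;  // dx_s/dt12 = y_t
        g[2] += dL_dxs[i];               // dx_s/dt13 = 1
        g[3] += dL_dys[i] * grid[i].xt;  // dy_s/dt21 = x_t
        g[4] += dL_dys[i] * grid[i].yt;  // dy_s/dt22 = y_t
        g[5] += dL_dys[i];               // dy_s/dt23 = 1
    }
    return g;
}
```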

  This completes the derivation of the STN.

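Code analysis

The code below is the CPU implementation of a Caffe layer, SpatialTransformerLayer (st_layer.cpp). N, C, H and W (the batch size, number of channels, height and width of the input U) are member variables set outside the functions shown here.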

LayerSetUp:

template <typename Dtype>
void SpatialTransformerLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {

    string prefix = "\t\tSpatial Transformer Layer:: LayerSetUp: \t";

    if(this->layer_param_.st_param().transform_type() == "affine") {
        transform_type_ = "affine";
    } else {
        CHECK(false) << prefix << "Transformation type only supports affine now!" << std::endl;
    }

    if(this->layer_param_.st_param().sampler_type() == "bilinear") {
        sampler_type_ = "bilinear";
    } else {
        CHECK(false) << prefix << "Sampler type only supports bilinear now!" << std::endl;
    }

    // Set the flag in both cases so it is never left uninitialized
    to_compute_dU_ = this->layer_param_.st_param().to_compute_du();

    std::cout<<prefix<<"Getting output_H_ and output_W_"<<std::endl;

    output_H_ = bottom[0]->shape(2);
    if(this->layer_param_.st_param().has_output_h()) {
        output_H_ = this->layer_param_.st_param().output_h();
    }
    output_W_ = bottom[0]->shape(3);
    if(this->layer_param_.st_param().has_output_w()) {
        output_W_ = this->layer_param_.st_param().output_w();
    }

    std::cout<<prefix<<"output_H_ = "<<output_H_<<", output_W_ = "<<output_W_<<std::endl;

    std::cout<<prefix<<"Getting pre-defined parameters"<<std::endl;

    pre_defined_count = 0;  // reset before counting the pre-defined entries of theta
    is_pre_defined_theta[0] = false;
    if(this->layer_param_.st_param().has_theta_1_1()) {
        is_pre_defined_theta[0] = true;
        ++ pre_defined_count;
        pre_defined_theta[0] = this->layer_param_.st_param().theta_1_1();
        std::cout<<prefix<<"Getting pre-defined theta[1][1] = "<<pre_defined_theta[0]<<std::endl;
    }

    is_pre_defined_theta[1] = false;
    if(this->layer_param_.st_param().has_theta_1_2()) {
        is_pre_defined_theta[1] = true;
        ++ pre_defined_count;
        pre_defined_theta[1] = this->layer_param_.st_param().theta_1_2();
        std::cout<<prefix<<"Getting pre-defined theta[1][2] = "<<pre_defined_theta[1]<<std::endl;
    }

    is_pre_defined_theta[2] = false;
    if(this->layer_param_.st_param().has_theta_1_3()) {
        is_pre_defined_theta[2] = true;
        ++ pre_defined_count;
        pre_defined_theta[2] = this->layer_param_.st_param().theta_1_3();
        std::cout<<prefix<<"Getting pre-defined theta[1][3] = "<<pre_defined_theta[2]<<std::endl;
    }

    is_pre_defined_theta[3] = false;
    if(this->layer_param_.st_param().has_theta_2_1()) {
        is_pre_defined_theta[3] = true;
        ++ pre_defined_count;
        pre_defined_theta[3] = this->layer_param_.st_param().theta_2_1();
        std::cout<<prefix<<"Getting pre-defined theta[2][1] = "<<pre_defined_theta[3]<<std::endl;
    }

    is_pre_defined_theta[4] = false;
    if(this->layer_param_.st_param().has_theta_2_2()) {
        is_pre_defined_theta[4] = true;
        ++ pre_defined_count;
        pre_defined_theta[4] = this->layer_param_.st_param().theta_2_2();
        std::cout<<prefix<<"Getting pre-defined theta[2][2] = "<<pre_defined_theta[4]<<std::endl;
    }

    is_pre_defined_theta[5] = false;
    if(this->layer_param_.st_param().has_theta_2_3()) {
        is_pre_defined_theta[5] = true;
        ++ pre_defined_count;
        pre_defined_theta[5] = this->layer_param_.st_param().theta_2_3();
        std::cout<<prefix<<"Getting pre-defined theta[2][3] = "<<pre_defined_theta[5]<<std::endl;
    }

    // check the validation for the parameter theta
    CHECK(bottom[1]->count(1) + pre_defined_count == 6) << "The dimension of theta is not six!"
            << " Only " << bottom[1]->count(1) << " + " << pre_defined_count << std::endl;
    CHECK(bottom[1]->shape(0) == bottom[0]->shape(0)) << "The first dimension of theta and " <<
            "U should be the same" << std::endl;

    // initialize the matrix for output grid
    std::cout<<prefix<<"Initializing the matrix for output grid"<<std::endl;

    vector<int> shape_output(2);
    shape_output[0] = output_H_ * output_W_; shape_output[1] = 3;
    output_grid.Reshape(shape_output);

    Dtype* data = output_grid.mutable_cpu_data();
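    // Initialize the matrix holding the coordinates of the output V, normalized to [-1, 1)
    // (index s -> s / size * 2 - 1). Each row is the homogeneous coordinate [x, y, 1],
    // where x is the row (height) and y the column (width) coordinate.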
    for(int i=0; i<output_H_ * output_W_; ++i) {
        data[3 * i] = (i / output_W_) * 1.0 / output_H_ * 2 - 1;
        data[3 * i + 1] = (i % output_W_) * 1.0 / output_W_ * 2 - 1;
        data[3 * i + 2] = 1;
    }

    // initialize the matrix for input grid
    std::cout<<prefix<<"Initializing the matrix for input grid"<<std::endl;

    vector<int> shape_input(3);
    shape_input[0] = bottom[1]->shape(0); shape_input[1] = output_H_ * output_W_; shape_input[2] = 2;
    input_grid.Reshape(shape_input);

    std::cout<<prefix<<"Initialization finished."<<std::endl;
}
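The normalization used in this initialization, together with its inverse in transform_forward_cpu, can be checked in isolation. Below is a hedged C++ sketch whose helper names are hypothetical but whose arithmetic is copied from the layer. Note that the upper end of the range is 1 − 2/size rather than exactly 1, and that mapping back recovers the pixel index s (up to rounding).

```cpp
#include <cstddef>
#include <vector>

// Same normalization as the output_grid loop in LayerSetUp: an index s in
// [0, size) maps to s / size * 2 - 1, which lies in [-1, 1 - 2 / size].
double to_normalized(int s, int size) { return s * 1.0 / size * 2 - 1; }

// Inverse used in transform_forward_cpu: x = (p + 1) / 2 * size.
double to_pixel(double p, int size) { return (p + 1) / 2 * size; }

// Builds the (H*W) x 3 output grid, one homogeneous row [x, y, 1] per output pixel,
// with x the normalized row and y the normalized column coordinate.
std::vector<double> make_output_grid(int H, int W) {
    std::vector<double> grid(static_cast<std::size_t>(H) * W * 3);
    for (int i = 0; i < H * W; ++i) {
        grid[3 * i] = to_normalized(i / W, H);
        grid[3 * i + 1] = to_normalized(i % W, W);
        grid[3 * i + 2] = 1;
    }
    return grid;
}
```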
 
 

Forward:

template <typename Dtype>
void SpatialTransformerLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {

    string prefix = "\t\tSpatial Transformer Layer:: Forward_cpu: \t";

    // CHECK(false) << "Don't use the CPU implementation! If you really want to, delete the" <<
            // " CHECK in st_layer.cpp file. Line number: 240-241." << std::endl;

    if(global_debug) std::cout<<prefix<<"Starting!"<<std::endl;
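    // U is the input image (either the input of the whole network or the feature map of some layer)
    // theta is the affine matrix computed earlier by the localisation network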
    const Dtype* U = bottom[0]->cpu_data();
    const Dtype* theta = bottom[1]->cpu_data();
    const Dtype* output_grid_data = output_grid.cpu_data();

    Dtype* input_grid_data = input_grid.mutable_cpu_data();
    Dtype* V = top[0]->mutable_cpu_data();

    caffe_set(input_grid.count(), (Dtype)0, input_grid_data);
    caffe_set(top[0]->count(), (Dtype)0, V);

    // for each input
    for(int i = 0; i < N; ++i) {

        Dtype* coordinates = input_grid_data + (output_H_ * output_W_ * 2) * i;
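        // For every point of the output V, compute the corresponding coordinate in the input U:
        // coordinates (HW x 2) = output_grid (HW x 3) * theta^T (3 x 2).
        // Note: theta + 6 * i assumes bottom[1] carries all 6 parameters per sample (no pre-defined entries).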
        caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, output_H_ * output_W_, 2, 3, (Dtype)1.,
              output_grid_data, theta + 6 * i, (Dtype)0., coordinates);

        int row_idx; Dtype px, py;

        for(int j = 0; j < C; ++j)
            for(int s = 0; s < output_H_; ++s)
                for(int t = 0; t < output_W_; ++t) {

                    row_idx = output_W_ * s + t;

                    px = coordinates[row_idx * 2];
                    py = coordinates[row_idx * 2 + 1];
                    // Obtain the value of each point of V by bilinear interpolation (implementation shown next)
                    V[top[0]->offset(i, j, s, t)] = transform_forward_cpu(
                            U + bottom[0]->offset(i, j, 0, 0), px, py);
                }
    }

    if(global_debug) std::cout<<prefix<<"Finished."<<std::endl;
}
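caffe_cpu_gemm is Caffe's wrapper around BLAS gemm, so the call above computes coordinates = output_grid × θᵀ for one sample. The following plain-loop sketch spells out that product (`grid_times_theta_T` is a hypothetical name; θ is read as a row-major 2×3 matrix, as in the layer).

```cpp
#include <cstddef>
#include <vector>

// Plain-loop equivalent of
//   caffe_cpu_gemm(CblasNoTrans, CblasTrans, HW, 2, 3, 1, output_grid, theta, 0, coords)
// i.e. coords (HW x 2) = output_grid (HW x 3) * theta^T (3 x 2).
// theta points to 6 values stored row-major as a 2x3 matrix.
void grid_times_theta_T(const std::vector<double>& output_grid, const double* theta,
                        int HW, std::vector<double>& coords) {
    coords.assign(static_cast<std::size_t>(HW) * 2, 0.0);
    for (int r = 0; r < HW; ++r)
        for (int c = 0; c < 2; ++c) {
            double acc = 0.0;
            for (int k = 0; k < 3; ++k)
                acc += output_grid[r * 3 + k] * theta[c * 3 + k];
            coords[r * 2 + c] = acc;  // c = 0: x_s (row), c = 1: y_s (column)
        }
}
```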
 
 

transform_forward_cpu:

template <typename Dtype>
Dtype SpatialTransformerLayer<Dtype>::transform_forward_cpu(const Dtype* pic, Dtype px, Dtype py) {

    bool debug = false;

    string prefix = "\t\tSpatial Transformer Layer:: transform_forward_cpu: \t";

    if(debug) std::cout<<prefix<<"Starting!\t"<<std::endl;
    if(debug) std::cout<<prefix<<"(px, py) = ("<<px<<", "<<py<<")"<<std::endl;

    Dtype res = (Dtype)0.;
    // Map the coordinates normalized to [-1, 1] back to pixel coordinates
    // (x indexes rows, i.e. height; y indexes columns, i.e. width)
    Dtype x = (px + 1) / 2 * H;
    Dtype y = (py + 1) / 2 * W;

    if(debug) std::cout<<prefix<<"(x, y) = ("<<x<<", "<<y<<")"<<std::endl;

    int m, n; Dtype w;
    // The four blocks below are the four neighbours used by bilinear interpolation;
    // neighbours outside the image are skipped (zero padding), and res accumulates the
    // interpolated value. std::max<Dtype> and std::abs (from <cmath>) keep the arithmetic
    // in floating point.
    m = floor(x); n = floor(y); w = 0;
    if(debug) std::cout<<prefix<<"1: (m, n) = ("<<m<<", "<<n<<")"<<std::endl;

    if(m >= 0 && m < H && n >= 0 && n < W) {
        w = std::max<Dtype>(0, 1 - std::abs(x - m)) * std::max<Dtype>(0, 1 - std::abs(y - n));
        res += w * pic[m * W + n];
        if(debug) std::cout<<prefix<<"w = "<<w<<", pic[m, n] = "<<pic[m * W + n]<<std::endl;
    }

    m = floor(x) + 1; n = floor(y); w = 0;
    if(debug) std::cout<<prefix<<"2: (m, n) = ("<<m<<", "<<n<<")"<<std::endl;

    if(m >= 0 && m < H && n >= 0 && n < W) {
        w = std::max<Dtype>(0, 1 - std::abs(x - m)) * std::max<Dtype>(0, 1 - std::abs(y - n));
        res += w * pic[m * W + n];
        if(debug) std::cout<<prefix<<"w = "<<w<<", pic[m, n] = "<<pic[m * W + n]<<std::endl;
    }

    m = floor(x); n = floor(y) + 1; w = 0;
    if(debug) std::cout<<prefix<<"3: (m, n) = ("<<m<<", "<<n<<")"<<std::endl;

    if(m >= 0 && m < H && n >= 0 && n < W) {
        w = std::max<Dtype>(0, 1 - std::abs(x - m)) * std::max<Dtype>(0, 1 - std::abs(y - n));
        res += w * pic[m * W + n];
        if(debug) std::cout<<prefix<<"w = "<<w<<", pic[m, n] = "<<pic[m * W + n]<<std::endl;
    }

    m = floor(x) + 1; n = floor(y) + 1; w = 0;
    if(debug) std::cout<<prefix<<"4: (m, n) = ("<<m<<", "<<n<<")"<<std::endl;

    if(m >= 0 && m < H && n >= 0 && n < W) {
        w = std::max<Dtype>(0, 1 - std::abs(x - m)) * std::max<Dtype>(0, 1 - std::abs(y - n));
        res += w * pic[m * W + n];
        if(debug) std::cout<<prefix<<"w = "<<w<<", pic[m, n] = "<<pic[m * W + n]<<std::endl;
    }

    if(debug) std::cout<<prefix<<"Finished. \tres = "<<res<<std::endl;

    return res;
}
 
 

Backward:

template <typename Dtype>
void SpatialTransformerLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down,
    const vector<Blob<Dtype>*>& bottom) {

        string prefix = "\t\tSpatial Transformer Layer:: Backward_cpu: \t";

        CHECK(false) << "Don't use the CPU implementation! If you really want to, delete the" <<
                " CHECK in st_layer.cpp file. Line number: 420-421." << std::endl;

        if(global_debug) std::cout<<prefix<<"Starting!"<<std::endl;

        const Dtype* dV = top[0]->cpu_diff();
        const Dtype* input_grid_data = input_grid.cpu_data();
        const Dtype* U = bottom[0]->cpu_data();

        Dtype* dU = bottom[0]->mutable_cpu_diff();
        Dtype* dTheta = bottom[1]->mutable_cpu_diff();
        Dtype* input_grid_diff = input_grid.mutable_cpu_diff();

        caffe_set(bottom[0]->count(), (Dtype)0, dU);
        caffe_set(bottom[1]->count(), (Dtype)0, dTheta);
        caffe_set(input_grid.count(), (Dtype)0, input_grid_diff);

        for(int i = 0; i < N; ++i) {

            const Dtype* coordinates = input_grid_data + (output_H_ * output_W_ * 2) * i;
            Dtype* coordinates_diff = input_grid_diff + (output_H_ * output_W_ * 2) * i;

            int row_idx; Dtype px, py, dpx, dpy, delta_dpx, delta_dpy;

            for(int s = 0; s < output_H_; ++s)
                for(int t = 0; t < output_W_; ++t) {

                    row_idx = output_W_ * s + t;

                    px = coordinates[row_idx * 2];
                    py = coordinates[row_idx * 2 + 1];

                    for(int j = 0; j < C; ++j) {

                        delta_dpx = delta_dpy = (Dtype)0.;
                        // Compute dU and the gradient w.r.t. the sampling coordinates (implementation shown next)
                        transform_backward_cpu(dV[top[0]->offset(i, j, s, t)], U + bottom[0]->offset(i, j, 0, 0),
                                px, py, dU + bottom[0]->offset(i, j, 0, 0), delta_dpx, delta_dpy);

                        coordinates_diff[row_idx * 2] += delta_dpx;
                        coordinates_diff[row_idx * 2 + 1] += delta_dpy;
                    }

                    dpx = coordinates_diff[row_idx * 2];
                    dpy = coordinates_diff[row_idx * 2 + 1];

                    // Corresponds to dx/dTheta and dy/dTheta derived above
                    dTheta[6 * i] += dpx * (s * 1.0 / output_H_ * 2 - 1);
                    dTheta[6 * i + 1] += dpx * (t * 1.0 / output_W_ * 2 - 1);
                    dTheta[6 * i + 2] += dpx;
                    dTheta[6 * i + 3] += dpy * (s * 1.0 / output_H_ * 2 - 1);
                    dTheta[6 * i + 4] += dpy * (t * 1.0 / output_W_ * 2 - 1);
                    dTheta[6 * i + 5] += dpy;
                }
        }

        if(global_debug) std::cout<<prefix<<"Finished."<<std::endl;
}
 
 

transform_backward_cpu:

template <typename Dtype>
void SpatialTransformerLayer<Dtype>::transform_backward_cpu(Dtype dV, const Dtype* U, const Dtype px,
        const Dtype py, Dtype* dU, Dtype& dpx, Dtype& dpy) {

    bool debug = false;

    string prefix = "\t\tSpatial Transformer Layer:: transform_backward_cpu: \t";

    if(debug) std::cout<<prefix<<"Starting!"<<std::endl;

    // Map the coordinates normalized to [-1, 1] back to pixel coordinates
    // (x indexes rows, i.e. height; y indexes columns, i.e. width)
    Dtype x = (px + 1) / 2 * H;
    Dtype y = (py + 1) / 2 * W;
    if(debug) std::cout<<prefix<<"(x, y) = ("<<x<<", "<<y<<")"<<std::endl;

    int m, n; Dtype w, wx, wy;

    // The four blocks below again correspond to the four bilinear-interpolation neighbours.
    // wx and wy are the two 1-D bilinear weights; std::max<Dtype> and std::abs (from <cmath>)
    // keep the arithmetic in floating point.
    m = floor(x); n = floor(y); w = 0;
    if(debug) std::cout<<prefix<<"(m, n) = ("<<m<<", "<<n<<")"<<std::endl;

    if(m >= 0 && m < H && n >= 0 && n < W) {
        wx = std::max<Dtype>(0, 1 - std::abs(x - m));
        wy = std::max<Dtype>(0, 1 - std::abs(y - n));
        w = wx * wy;

        // (m, n) in U is a neighbour of the point sampled for (s, t) in V,
        // so the gradient propagates back to it with the bilinear weight
        dU[m * W + n] += w * dV;

        // dV/dx and dV/dy from the derivation above, times dx/dpx = H/2 and dy/dpy = W/2
        if(std::abs(x - m) < 1) {
            if(m >= x) {
                dpx += wy * U[m * W + n] * dV * H / 2;
                if(debug) std::cout<<prefix<<"dpx += "<<wy<<" * "<<U[m * W + n]<<" * "<<dV<<" * "<<H / 2.0<<std::endl;
            } else {
                dpx -= wy * U[m * W + n] * dV * H / 2;
                if(debug) std::cout<<prefix<<"dpx -= "<<wy<<" * "<<U[m * W + n]<<" * "<<dV<<" * "<<H / 2.0<<std::endl;
            }
        }

        if(std::abs(y - n) < 1) {
            if(n >= y) {
                dpy += wx * U[m * W + n] * dV * W / 2;
                if(debug) std::cout<<prefix<<"dpy += "<<wx<<" * "<<U[m * W + n]<<" * "<<dV<<" * "<<W / 2.0<<std::endl;
            } else {
                dpy -= wx * U[m * W + n] * dV * W / 2;
                if(debug) std::cout<<prefix<<"dpy -= "<<wx<<" * "<<U[m * W + n]<<" * "<<dV<<" * "<<W / 2.0<<std::endl;
            }
        }
    }

    m = floor(x) + 1; n = floor(y); w = 0;
    if(debug) std::cout<<prefix<<"(m, n) = ("<<m<<", "<<n<<")"<<std::endl;

    if(m >= 0 && m < H && n >= 0 && n < W) {
        wx = std::max<Dtype>(0, 1 - std::abs(x - m));
        wy = std::max<Dtype>(0, 1 - std::abs(y - n));
        w = wx * wy;

        dU[m * W + n] += w * dV;

        if(std::abs(x - m) < 1) {
            if(m >= x) {
                dpx += wy * U[m * W + n] * dV * H / 2;
                if(debug) std::cout<<prefix<<"dpx += "<<wy<<" * "<<U[m * W + n]<<" * "<<dV<<" * "<<H / 2.0<<std::endl;
            } else {
                dpx -= wy * U[m * W + n] * dV * H / 2;
                if(debug) std::cout<<prefix<<"dpx -= "<<wy<<" * "<<U[m * W + n]<<" * "<<dV<<" * "<<H / 2.0<<std::endl;
            }
        }

        if(std::abs(y - n) < 1) {
            if(n >= y) {
                dpy += wx * U[m * W + n] * dV * W / 2;
                if(debug) std::cout<<prefix<<"dpy += "<<wx<<" * "<<U[m * W + n]<<" * "<<dV<<" * "<<W / 2.0<<std::endl;
            } else {
                dpy -= wx * U[m * W + n] * dV * W / 2;
                if(debug) std::cout<<prefix<<"dpy -= "<<wx<<" * "<<U[m * W + n]<<" * "<<dV<<" * "<<W / 2.0<<std::endl;
            }
        }
    }

    m = floor(x); n = floor(y) + 1; w = 0;
    if(debug) std::cout<<prefix<<"(m, n) = ("<<m<<", "<<n<<")"<<std::endl;

    if(m >= 0 && m < H && n >= 0 && n < W) {
        wx = std::max<Dtype>(0, 1 - std::abs(x - m));
        wy = std::max<Dtype>(0, 1 - std::abs(y - n));
        w = wx * wy;

        dU[m * W + n] += w * dV;

        if(std::abs(x - m) < 1) {
            if(m >= x) {
                dpx += wy * U[m * W + n] * dV * H / 2;
                if(debug) std::cout<<prefix<<"dpx += "<<wy<<" * "<<U[m * W + n]<<" * "<<dV<<" * "<<H / 2.0<<std::endl;
            } else {
                dpx -= wy * U[m * W + n] * dV * H / 2;
                if(debug) std::cout<<prefix<<"dpx -= "<<wy<<" * "<<U[m * W + n]<<" * "<<dV<<" * "<<H / 2.0<<std::endl;
            }
        }

        if(std::abs(y - n) < 1) {
            if(n >= y) {
                dpy += wx * U[m * W + n] * dV * W / 2;
                if(debug) std::cout<<prefix<<"dpy += "<<wx<<" * "<<U[m * W + n]<<" * "<<dV<<" * "<<W / 2.0<<std::endl;
            } else {
                dpy -= wx * U[m * W + n] * dV * W / 2;
                if(debug) std::cout<<prefix<<"dpy -= "<<wx<<" * "<<U[m * W + n]<<" * "<<dV<<" * "<<W / 2.0<<std::endl;
            }
        }
    }

    m = floor(x) + 1; n = floor(y) + 1; w = 0;
    if(debug) std::cout<<prefix<<"(m, n) = ("<<m<<", "<<n<<")"<<std::endl;

    if(m >= 0 && m < H && n >= 0 && n < W) {
        wx = std::max<Dtype>(0, 1 - std::abs(x - m));
        wy = std::max<Dtype>(0, 1 - std::abs(y - n));
        w = wx * wy;

        dU[m * W + n] += w * dV;

        if(std::abs(x - m) < 1) {
            if(m >= x) {
                dpx += wy * U[m * W + n] * dV * H / 2;
                if(debug) std::cout<<prefix<<"dpx += "<<wy<<" * "<<U[m * W + n]<<" * "<<dV<<" * "<<H / 2.0<<std::endl;
            } else {
                dpx -= wy * U[m * W + n] * dV * H / 2;
                if(debug) std::cout<<prefix<<"dpx -= "<<wy<<" * "<<U[m * W + n]<<" * "<<dV<<" * "<<H / 2.0<<std::endl;
            }
        }

        if(std::abs(y - n) < 1) {
            if(n >= y) {
                dpy += wx * U[m * W + n] * dV * W / 2;
                if(debug) std::cout<<prefix<<"dpy += "<<wx<<" * "<<U[m * W + n]<<" * "<<dV<<" * "<<W / 2.0<<std::endl;
            } else {
                dpy -= wx * U[m * W + n] * dV * W / 2;
                if(debug) std::cout<<prefix<<"dpy -= "<<wx<<" * "<<U[m * W + n]<<" * "<<dV<<" * "<<W / 2.0<<std::endl;
            }
        }
    }

    if(debug) std::cout<<prefix<<"Finished."<<std::endl;
}
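A convenient way to validate the H/2 and W/2 factors and the signs in Backward_cpu and transform_backward_cpu is a finite-difference check of ∂L/∂θ. The following is a self-contained, single-channel C++ re-implementation written for this purpose only; all names are hypothetical and it is not the Caffe layer itself. It follows the same conventions as the code above, and with L = ΣV its analytic gradient should agree with central differences of the forward pass as long as no sampling point lies on an integer pixel position (where the bilinear kernel has a kink).

```cpp
#include <array>
#include <cmath>
#include <vector>

// Single-channel image; conventions follow the layer: x indexes rows, y columns,
// output index s maps to s / outH * 2 - 1, and a normalized coordinate p maps
// back to pixels via (p + 1) / 2 * size.
struct Image { int H, W; std::vector<double> data; };

double tri(double d) { return std::fmax(0.0, 1.0 - std::fabs(d)); }  // bilinear kernel

// Normalized target coords (xt, yt) and source pixel coords (x, y) of output pixel (s, t).
void source_coords(const Image& U, const std::array<double, 6>& th, int s, int t, int outH,
                   int outW, double& xt, double& yt, double& x, double& y) {
    xt = s * 1.0 / outH * 2 - 1;
    yt = t * 1.0 / outW * 2 - 1;
    x = (th[0] * xt + th[1] * yt + th[2] + 1) / 2 * U.H;
    y = (th[3] * xt + th[4] * yt + th[5] + 1) / 2 * U.W;
}

std::vector<double> st_forward(const Image& U, const std::array<double, 6>& th,
                               int outH, int outW) {
    std::vector<double> V(outH * outW, 0.0);
    for (int s = 0; s < outH; ++s)
        for (int t = 0; t < outW; ++t) {
            double xt, yt, x, y;
            source_coords(U, th, s, t, outH, outW, xt, yt, x, y);
            int m0 = static_cast<int>(std::floor(x)), n0 = static_cast<int>(std::floor(y));
            for (int m = m0; m <= m0 + 1; ++m)
                for (int n = n0; n <= n0 + 1; ++n)
                    if (m >= 0 && m < U.H && n >= 0 && n < U.W)
                        V[s * outW + t] += tri(x - m) * tri(y - n) * U.data[m * U.W + n];
        }
    return V;
}

// dL/dtheta given dL/dV, following transform_backward_cpu and Backward_cpu.
std::array<double, 6> st_backward_theta(const Image& U, const std::array<double, 6>& th,
                                        int outH, int outW, const std::vector<double>& dV) {
    std::array<double, 6> g{};
    for (int s = 0; s < outH; ++s)
        for (int t = 0; t < outW; ++t) {
            double xt, yt, x, y;
            source_coords(U, th, s, t, outH, outW, xt, yt, x, y);
            int m0 = static_cast<int>(std::floor(x)), n0 = static_cast<int>(std::floor(y));
            double dpx = 0.0, dpy = 0.0;  // gradients w.r.t. the normalized source coords
            for (int m = m0; m <= m0 + 1; ++m)
                for (int n = n0; n <= n0 + 1; ++n) {
                    if (m < 0 || m >= U.H || n < 0 || n >= U.W) continue;
                    double u = U.data[m * U.W + n] * dV[s * outW + t];
                    if (std::fabs(x - m) < 1) dpx += (m >= x ? 1 : -1) * tri(y - n) * u * U.H / 2.0;
                    if (std::fabs(y - n) < 1) dpy += (n >= y ? 1 : -1) * tri(x - m) * u * U.W / 2.0;
                }
            g[0] += dpx * xt; g[1] += dpx * yt; g[2] += dpx;
            g[3] += dpy * xt; g[4] += dpy * yt; g[5] += dpy;
        }
    return g;
}
```

Within one pixel cell the output is a quadratic function of θ, so central differences are exact up to rounding there, which makes this a tight check.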
