TNN 的 resize 虽然分通道提供了多个接口,但底层是一起的。整个实现对于灰度图优化非常有限,而3通道或4通道的图像会有加速。缩放的映射关系较为简单,主要分为三步:
- 一维位置索引和插值系数计算;
- 行内像素插值;
- 相邻行结果插值。
ResizeBilinearC1
// Single-channel entry point: forwards to ResizeBilinearC1Impl, passing src_w and w a
// second time in the positions that presumably hold the source/destination row strides
// (i.e. tightly packed rows) — TODO confirm against the Impl signature.
return ResizeBilinearC1Impl(src, batch, src_w, src_h, src_w, dst, w, h, w);
ResizeBilinearC1Impl
ResizeBilinearPreparation 在行列方向上独立计算对应源图上的位置和系数。
根据线程数开辟 rows0 和 rows1 缓存,rows0_t 和 rows1_t 为指向各线程所属缓存段的指针。
OMP_PARALLEL_FOR_ 置于目的行循环之前,使其并行执行;OMP_TID_ 调用 omp_get_thread_num 获取当前线程号。
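prev_sy 数组记录每个线程上一次处理的源图行。OpenMP 下每个目的行相互独立地分配给线程,因此能否复用取决于调度方式:若为静态调度(具体取决于 OMP_PARALLEL_FOR_ 的定义),每个线程处理连续的一段目的行,相邻目的行通常映射到相同或相邻的源行,复用大多能命中;若为动态调度,复用就只能靠运气了。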
ResizeBilinearOneRow 传入一个比较复杂的 ResizeBilinearKernelParm 结构体。
// Build the shared lookup tables (xofs/yofs/ialpha/ibeta in buf) for a 1-channel image.
ResizeBilinearPreparation(1);
// Bundle the tables and image pointers consumed by the per-row kernel.
ResizeBilinearKernelParm param(xofs, yofs, ialpha, ibeta, src, dst, src_plane, src_stride, schannel);
// loop body
// One pair of horizontally-resized row buffers per thread, w shorts each (1 channel).
int max_num_threads = OMP_MAX_THREADS_NUM_;
short* rows0 = new short[w * max_num_threads];
short* rows1 = new short[w * max_num_threads];
// NOTE(review): these three are variable-length arrays, a compiler extension in C++
// rather than standard C++; std::vector would be the portable alternative.
short* rows0_t[max_num_threads];
short* rows1_t[max_num_threads];
int prev_sy[max_num_threads];
for (int b = 0; b < batch; ++b) {
// Reset per image: cached rows belong to the previous image. -2 matches neither sy nor
// sy - 1 for any sy >= 0, so each thread's first row recomputes both source rows.
// The row pointers are reset too because ResizeGetAdjacentRows may swap them.
for (int t = 0; t < max_num_threads; ++t) {
prev_sy[t] = -2;
rows0_t[t] = rows0 + t * w;
rows1_t[t] = rows1 + t * w;
}
// Destination rows are processed in parallel; each thread only touches the buffers
// and prev_sy slot selected by its own thread id.
OMP_PARALLEL_FOR_
for (int dy = 0; dy < h; dy++) {
int thread_id = OMP_TID_;
ResizeBilinearOneRow(param, thread_id, rows0_t, rows1_t, prev_sy, b, w, h, stride, dy);
}
}
delete[] rows0;
delete[] rows1;
// buf was allocated (via GetResizeBuf) inside ResizeBilinearPreparation.
delete[] buf;
ResizeBilinearPreparation
GetResizeBuf 开辟一块公共的连续内存,并计算输入图上的对应位置(xofs 和 yofs)和一维插值系数(ialpha 和 ibeta)。
内部计算好后外部通过指针访问,不如 OpenCV 直观。
// ResizeBilinearPreparation(channel)
// Expands directly into the caller's scope. Requires src_w, src_h, w, h and src_stride
// to be visible at the expansion site, and declares:
//   schannel  - the channel count that was passed in;
//   buf       - int table filled by GetResizeBuf; the caller releases it with delete[];
//   xofs      - first w ints: source offset used for each destination column;
//   yofs      - next h ints: source row used for each destination row;
//   ialpha    - shorts starting at int offset w + h: horizontal fixed-point weight pairs;
//   ibeta     - shorts starting at int offset 2 * w + h: vertical fixed-point weight pairs;
//   src_plane - src_h * src_stride, the distance between consecutive source images.
// Intentionally NOT wrapped in do { } while (0): the declarations must remain visible
// to the code that follows the expansion.
#define ResizeBilinearPreparation(channel)                          \
    int schannel  = (channel);                                      \
    int* buf      = nullptr;                                        \
    GetResizeBuf(src_w, src_h, w, h, schannel, &buf);               \
    int* xofs     = &buf[0];                                        \
    int* yofs     = &buf[w];                                        \
    short* ialpha = reinterpret_cast<short*>(&buf[w + h]);          \
    short* ibeta  = reinterpret_cast<short*>(&buf[2 * w + h]);      \
    int src_plane = src_stride * src_h;
GetResizeBuf
GetResizeBufPreparation 开辟一段连续内存并定义4个参数指针。
CalculatePositionAndRatio 计算对应位置和一维插值系数。
横向和纵向操作是相同的。
// Allocate *buf and declare scale_x/scale_y plus the xofs/yofs/ialpha/ibeta views into it.
GetResizeBufPreparation(short);
// Horizontal pass gets the channel count c: the kernels use xofs[dx] as an element offset
// inside an interleaved row (S1 + sx, with the next pixel at + c). Presumably
// CalculatePositionAndRatio multiplies the column index by this factor — confirm.
CalculatePositionAndRatio(w, scale_x, src_w, c, xofs, ialpha);
// Vertical pass uses factor 1: yofs[dy] is used as a plain source row index.
CalculatePositionAndRatio(h, scale_y, src_h, 1, yofs, ibeta);
GetResizeBufPreparation
缩放系数为源图尺寸除以目的图尺寸。
// GetResizeBufPreparation(type)
// Expands into the caller's scope. Requires src_w, src_h, w, h and an out-parameter
// `int** buf` to be visible, allocates one int array of 2 * (w + h) elements into *buf,
// and declares:
//   scale_x / scale_y - source size divided by destination size, in double precision;
//   xofs              - ints [0, w);
//   yofs              - ints [w, w + h);
//   ialpha            - `type` view starting at int w + h (w ints of room);
//   ibeta             - `type` view starting at int 2 * w + h (h ints of room).
// The ialpha/ibeta regions hold 2 coefficients per output column/row only when
// 2 * sizeof(type) <= sizeof(int), as with the `short` used by the callers.
// Not wrapped in do { } while (0) on purpose: the declarations must outlive the expansion.
#define GetResizeBufPreparation(type)                                   \
    double scale_x = static_cast<double>(src_w) / w;                    \
    double scale_y = static_cast<double>(src_h) / h;                    \
    *buf           = new int[2 * (w + h)];                              \
    int* xofs      = &(*buf)[0];                                        \
    int* yofs      = &(*buf)[w];                                        \
    type* ialpha   = (type*)(&(*buf)[w + h]);                           \
    type* ibeta    = (type*)(&(*buf)[2 * w + h]);
CalculatePositionAndRatio
CalculatePosition 根据索引和缩放系数计算源图上的对应位置并将其修剪到边界内,返回位置比例系数 rat_f(即第二个采样点的权重)。
将每对插值系数 (1 − rat_f, rat_f) 放大 $2^{11} = 2048$ 倍,以 short 整数形式交错保存。
// Interpolation weights are stored in 11-bit fixed point: 1.0 is represented as 2048.
const int INTER_RESIZE_COEF_BITS = 11;
const int INTER_RESIZE_COEF_SCALE = 1 << INTER_RESIZE_COEF_BITS;
for (int i = 0; i < length; i++) {
// rat_f weights the second (right/bottom) source sample; CalculatePosition presumably
// also stores the clamped source offset for index i into `position` — confirm.
float rat_f = CalculatePosition(position, i, scale, border, channel);
float a0 = (1.f - rat_f) * INTER_RESIZE_COEF_SCALE;
float a1 = rat_f * INTER_RESIZE_COEF_SCALE;
// Interleaved pairs: ratio[2i] weights the left/top sample, ratio[2i+1] the right/bottom one.
ratio[i * 2] = SATURATE_CAST_SHORT(a0);
ratio[i * 2 + 1] = SATURATE_CAST_SHORT(a1);
}
ResizeBilinearOneRow
借助 yofs 得到对应源图上的行。
ResizeGetAdjacentRows 对于单通道无优化,将同行两元素加权。
ResizeCalculateOneRow 由两个相邻的源行得到目的行。
// Top source row of the two-row window used for destination row dy.
int sy = param.yofs[dy];
// Make this thread's rows0/rows1 hold horizontally-resized source rows sy and sy + 1,
// reusing what the thread computed for its previous destination row when possible.
ResizeGetAdjacentRows(sy, prev_sy[thread_id], &rows0_t[thread_id], &rows1_t[thread_id], param.xofs,
param.src + b * param.src_plane, param.src_stride, param.schannel, w, param.ialpha);
prev_sy[thread_id] = sy;
// vresize
// Vertical weights for rows sy (b0) and sy + 1 (b1).
short b0 = param.ibeta[dy * 2];
short b1 = param.ibeta[dy * 2 + 1];
// Destination row dy of image b: rows are `stride` apart and images are h rows apart.
uint8_t* Dp = param.dst + stride * (b * h + dy);
ResizeCalculateOneRow(rows0_t[thread_id], rows1_t[thread_id], b0, b1, w, param.schannel, Dp);
ResizeGetAdjacentRows
如果 prev_sy 是 sy 的前一行,那么 sy 行已经处理过(只需交换两个缓存指针),只需处理 sy + 1 行。S1 指向 sy + 1 行,S1p 为当前目的像素在该行中对应的源位置。
$$\frac{x_2 - x}{x_2 - x_1} f(Q_{21}) + \frac{x - x_1}{x_2 - x_1} f(Q_{22}) = \alpha_0 f(Q_{21}) + \alpha_1 f(Q_{22})$$
不论通道数量如何,每次循环处理一个目的像素。
加权结果右移 4 位:系数放大了 2048 倍,最大乘积约为 255 × 2048,右移 4 位后约为 32640,可以存入 short。
// Case 1: same source row as last time, so both cached rows are still valid.
if (sy == prev_sy) {
// reuse all rows
} else if (sy == prev_sy + 1) {
// hresize one row
// Case 2: the window moved down by one source row. The buffer holding row sy (the old
// *rows1) becomes *rows0 by a pointer swap; only row sy + 1 is recomputed into *rows1.
short* rows0_old = *rows0;
*rows0 = *rows1;
*rows1 = rows0_old;
const uint8_t* S1 = src + src_stride * (sy + 1);
short* rows1p = *rows1;
for (int dx = 0; dx < w; dx++) {
int sx = xofs[dx];
short a0 = ialphap[0];
short a1 = ialphap[1];
// Left source pixel of this output column; the right one starts c bytes later.
const uint8_t* S1p = S1 + sx;
#ifndef TNN_USE_NEON
// Scalar: per-channel a0 * left + a1 * right. With a0 + a1 around 2048, the >> 4 keeps
// the result at most about (255 * 2048) >> 4 = 32640, which fits in a short.
for (int dc = 0; dc < c; ++dc) {
rows1p[dc] = (S1p[dc] * a0 + S1p[dc + c] * a1) >> 4;
}
#else
if (c == 2) {
// Build weights [a0, a0, a1, a1]. This load reads 4 shorts, i.e. 2 beyond the current
// pair; the extra lanes are discarded by taking val[0] of the zip.
int16x4_t _a0a1XX = vld1_s16(ialphap);
int16x4_t _a0a0a1a1 = vzip_s16(_a0a1XX, _a0a1XX).val[0];
uint8x8_t _S1 = uint8x8_t();
_S1 = vld1_lane_u8(S1p, _S1, 0);
_S1 = vld1_lane_u8(S1p + 1, _S1, 1);
_S1 = vld1_lane_u8(S1p + 2, _S1, 2);
_S1 = vld1_lane_u8(S1p + 3, _S1, 3);
int16x8_t _S116 = vreinterpretq_s16_u16(vmovl_u8(_S1));
int16x4_t _S1lowhigh = vget_low_s16(_S116);
// Lanes: [L.c0*a0, L.c1*a0, R.c0*a1, R.c1*a1]; adding the halves gives both channels.
int32x4_t _S1ma0a1 = vmull_s16(_S1lowhigh, _a0a0a1a1);
int32x2_t _rows1low = vadd_s32(vget_low_s32(_S1ma0a1), vget_high_s32(_S1ma0a1));
// Lanes 2-3 are filler. The 4-short store below writes 2 shorts past this pixel; the
// next iteration overwrites them, and for the last pixel the row buffer presumably
// has spare room allocated by the caller — confirm.
int32x4_t _rows1 = vcombine_s32(_rows1low, vget_high_s32(_S1ma0a1));
int16x4_t _rows1_sr4 = vshrn_n_s32(_rows1, 4);
vst1_s16(rows1p, _rows1_sr4);
} else if (c == 3) {
int16x4_t _a0 = vdup_n_s16(a0);
int16x4_t _a1 = vdup_n_s16(a1);
// Load the 6 bytes of two 3-channel pixels lane by lane; lanes 6-7 stay zero.
uint8x8_t _S1 = uint8x8_t();
_S1 = vld1_lane_u8(S1p, _S1, 0);
_S1 = vld1_lane_u8(S1p + 1, _S1, 1);
_S1 = vld1_lane_u8(S1p + 2, _S1, 2);
_S1 = vld1_lane_u8(S1p + 3, _S1, 3);
_S1 = vld1_lane_u8(S1p + 4, _S1, 4);
_S1 = vld1_lane_u8(S1p + 5, _S1, 5);
int16x8_t _S116 = vreinterpretq_s16_u16(vmovl_u8(_S1));
// _S1low = elements 0..3 (left pixel first), _S1high = elements 3..6 (right pixel first).
int16x4_t _S1low = vget_low_s16(_S116);
int16x4_t _S1high = vext_s16(_S1low, vget_high_s16(_S116), 3);
int32x4_t _rows1 = vmull_s16(_S1low, _a0);
_rows1 = vmlal_s16(_rows1, _S1high, _a1);
int16x4_t _rows1_sr4 = vshrn_n_s32(_rows1, 4);
// Stores 4 shorts; the 4th is junk and is overwritten next (rows1p advances by 3).
vst1_s16(rows1p, _rows1_sr4);
} else if (c == 4) {
int16x4_t _a0 = vdup_n_s16(a0);
int16x4_t _a1 = vdup_n_s16(a1);
// One 8-byte load: low half is the left pixel, high half the right pixel.
uint8x8_t _S1 = vld1_u8(S1p);
int16x8_t _S116 = vreinterpretq_s16_u16(vmovl_u8(_S1));
int16x4_t _S1low = vget_low_s16(_S116);
int16x4_t _S1high = vget_high_s16(_S116);
int32x4_t _rows1 = vmull_s16(_S1low, _a0);
_rows1 = vmlal_s16(_rows1, _S1high, _a1);
int16x4_t _rows1_sr4 = vshrn_n_s32(_rows1, 4);
vst1_s16(rows1p, _rows1_sr4);
} else {
// Other channel counts (including 1) use the scalar formula.
for (int dc = 0; dc < c; ++dc) {
rows1p[dc] = (S1p[dc] * a0 + S1p[dc + c] * a1) >> 4;
}
}
#endif
ialphap += 2;
rows1p += c;
}
否则分别对 sy 和 sy + 1 两行做水平插值,即计算 $\alpha_0 f(Q_{11}) + \alpha_1 f(Q_{12})$ 和 $\alpha_0 f(Q_{21}) + \alpha_1 f(Q_{22})$。
vdup_n_s16 广播参数。
vld1_lane_u8 逐个加载源像素。
对于两通道的图,_S0 和 _S1 的 8 个 lane 只用了一半;三通道则用了 3/4。
_S016 和 _S116 为临时变量,_S0lowhigh 和 _S1lowhigh 取其前 4 个元素,即左右两个源像素各自的两个通道。_S0S1low_S0S1high 将左右两个源像素分开存储:val[0] 存放两行的左像素,val[1] 存放两行的右像素。
vmovl_u8 将 _S0 和 _S1 中的 8 位无符号元素零扩展为 16 位(这是加宽操作,并非左移),vreinterpretq_s16_u16 再将结果按位重新解释为有符号数。vget_low_s16 从 _S016 和 _S116 中取出有效的一半元素。
vtrn_s32 转置元素。转置前先用 vreinterpret_s32_s16 把每个像素的两个通道(两个 int16)视为一个 int32,这样转置后同一像素的两个通道仍在一起,而两行的左源像素在一起,右源像素在一起。即 _S0S1low_S0S1high.val[0] 为 $[f(Q_{11}), f(Q_{21})]$,_S0S1low_S0S1high.val[1] 为 $[f(Q_{12}), f(Q_{22})]$。
_rows01 向量的前两个元素为 sy 行的加权和,后两个元素为 sy+1 行的加权和。
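vext_s16 从第二个操作数向量的低端提取 n 个元素,从第一个操作数的高端提取其余元素,并将它们组合成结果向量。第二个操作数的元素放在结果向量的高位部分,第一个操作数的元素放在低位部分。这里两个输入向量相同,相当于循环移位:vext_s16(_rows01_sr4, _rows01_sr4, 2) 把 sy+1 行的两个结果移到前两个位置。三通道分支中两个操作数不同,vext_s16(_S0low, vget_high_s16(_S016), 3) 得到第 3~6 个元素,即右侧源像素的 3 个通道外加一个零元素。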
vst1_s16 一次存储 4 个元素,而每个像素只前进 c 个元素,因此会越界写入后面的位置。中间像素多写的部分会被下一次迭代覆盖,最后一个像素则会写出行尾,所以申请内存时尾部要多留一点。两次调用分别保存 _rows01_sr4(前两个元素为 sy 行的结果)和循环移位后的 _rows1_sr4(前两个元素为 sy+1 行的结果)。
} else {
// hresize two rows
// Case 3: no overlap with the cached window, so recompute source rows sy and sy + 1.
const uint8_t* S0 = src + src_stride * (sy);
const uint8_t* S1 = src + src_stride * (sy + 1);
short* rows0p = *rows0;
short* rows1p = *rows1;
for (int dx = 0; dx < w; dx++) {
int sx = xofs[dx];
short a0 = ialphap[0];
short a1 = ialphap[1];
const uint8_t* S0p = S0 + sx;
const uint8_t* S1p = S1 + sx;
#ifndef TNN_USE_NEON
for (int dc = 0; dc < c; ++dc) {
rows0p[dc] = (S0p[dc] * a0 + S0p[dc + c] * a1) >> 4;
rows1p[dc] = (S1p[dc] * a0 + S1p[dc + c] * a1) >> 4;
}
#else
if (c == 2) {
int16x4_t _a0 = vdup_n_s16(a0);
int16x4_t _a1 = vdup_n_s16(a1);
// Only lanes 0-3 (two 2-channel pixels) of each vector are used.
uint8x8_t _S0 = uint8x8_t();
uint8x8_t _S1 = uint8x8_t();
_S0 = vld1_lane_u8(S0p, _S0, 0);
_S0 = vld1_lane_u8(S0p + 1, _S0, 1);
_S0 = vld1_lane_u8(S0p + 2, _S0, 2);
_S0 = vld1_lane_u8(S0p + 3, _S0, 3);
_S1 = vld1_lane_u8(S1p, _S1, 0);
_S1 = vld1_lane_u8(S1p + 1, _S1, 1);
_S1 = vld1_lane_u8(S1p + 2, _S1, 2);
_S1 = vld1_lane_u8(S1p + 3, _S1, 3);
int16x8_t _S016 = vreinterpretq_s16_u16(vmovl_u8(_S0));
int16x8_t _S116 = vreinterpretq_s16_u16(vmovl_u8(_S1));
int16x4_t _S0lowhigh = vget_low_s16(_S016);
int16x4_t _S1lowhigh = vget_low_s16(_S116);
// Treat each pixel's 2 channels as one 32-bit lane and transpose: val[0] holds the left
// pixel of rows sy and sy + 1, val[1] the right pixel of both rows.
int32x2x2_t _S0S1low_S0S1high = vtrn_s32(vreinterpret_s32_s16(_S0lowhigh), vreinterpret_s32_s16(_S1lowhigh));
// Result lanes 0-1: row sy channels; lanes 2-3: row sy + 1 channels.
int32x4_t _rows01 = vmull_s16(vreinterpret_s16_s32(_S0S1low_S0S1high.val[0]), _a0);
_rows01 = vmlal_s16(_rows01, vreinterpret_s16_s32(_S0S1low_S0S1high.val[1]), _a1);
int16x4_t _rows01_sr4 = vshrn_n_s32(_rows01, 4);
// Rotate by two lanes so row sy + 1's results come first.
int16x4_t _rows1_sr4 = vext_s16(_rows01_sr4, _rows01_sr4, 2);
// Each 4-short store writes 2 extra shorts that the next pixel overwrites.
vst1_s16(rows0p, _rows01_sr4);
vst1_s16(rows1p, _rows1_sr4);
三通道和四通道的加载稍有不同,计算方式相同。
} else if (c == 3) {
int16x4_t _a0 = vdup_n_s16(a0);
int16x4_t _a1 = vdup_n_s16(a1);
// Load 6 bytes (two 3-channel pixels) per row lane by lane; lanes 6-7 stay zero.
uint8x8_t _S0 = uint8x8_t();
uint8x8_t _S1 = uint8x8_t();
_S0 = vld1_lane_u8(S0p, _S0, 0);
_S0 = vld1_lane_u8(S0p + 1, _S0, 1);
_S0 = vld1_lane_u8(S0p + 2, _S0, 2);
_S0 = vld1_lane_u8(S0p + 3, _S0, 3);
_S0 = vld1_lane_u8(S0p + 4, _S0, 4);
_S0 = vld1_lane_u8(S0p + 5, _S0, 5);
_S1 = vld1_lane_u8(S1p, _S1, 0);
_S1 = vld1_lane_u8(S1p + 1, _S1, 1);
_S1 = vld1_lane_u8(S1p + 2, _S1, 2);
_S1 = vld1_lane_u8(S1p + 3, _S1, 3);
_S1 = vld1_lane_u8(S1p + 4, _S1, 4);
_S1 = vld1_lane_u8(S1p + 5, _S1, 5);
int16x8_t _S016 = vreinterpretq_s16_u16(vmovl_u8(_S0));
int16x8_t _S116 = vreinterpretq_s16_u16(vmovl_u8(_S1));
// low = elements 0..3 (left pixel first), high = elements 3..6 (right pixel first).
int16x4_t _S0low = vget_low_s16(_S016);
int16x4_t _S1low = vget_low_s16(_S116);
int16x4_t _S0high = vext_s16(_S0low, vget_high_s16(_S016), 3);
int16x4_t _S1high = vext_s16(_S1low, vget_high_s16(_S116), 3);
int32x4_t _rows0 = vmull_s16(_S0low, _a0);
int32x4_t _rows1 = vmull_s16(_S1low, _a0);
_rows0 = vmlal_s16(_rows0, _S0high, _a1);
_rows1 = vmlal_s16(_rows1, _S1high, _a1);
int16x4_t _rows0_sr4 = vshrn_n_s32(_rows0, 4);
int16x4_t _rows1_sr4 = vshrn_n_s32(_rows1, 4);
// 4 shorts stored per row; the 4th is junk and is overwritten next (pointers advance by 3).
vst1_s16(rows0p, _rows0_sr4);
vst1_s16(rows1p, _rows1_sr4);
} else if (c == 4) {
int16x4_t _a0 = vdup_n_s16(a0);
int16x4_t _a1 = vdup_n_s16(a1);
// One 8-byte load per row: low half is the left pixel, high half the right pixel.
uint8x8_t _S0 = vld1_u8(S0p);
uint8x8_t _S1 = vld1_u8(S1p);
int16x8_t _S016 = vreinterpretq_s16_u16(vmovl_u8(_S0));
int16x8_t _S116 = vreinterpretq_s16_u16(vmovl_u8(_S1));
int16x4_t _S0low = vget_low_s16(_S016);
int16x4_t _S1low = vget_low_s16(_S116);
int16x4_t _S0high = vget_high_s16(_S016);
int16x4_t _S1high = vget_high_s16(_S116);
int32x4_t _rows0 = vmull_s16(_S0low, _a0);
int32x4_t _rows1 = vmull_s16(_S1low, _a0);
_rows0 = vmlal_s16(_rows0, _S0high, _a1);
_rows1 = vmlal_s16(_rows1, _S1high, _a1);
int16x4_t _rows0_sr4 = vshrn_n_s32(_rows0, 4);
int16x4_t _rows1_sr4 = vshrn_n_s32(_rows1, 4);
vst1_s16(rows0p, _rows0_sr4);
vst1_s16(rows1p, _rows1_sr4);
} else {
// Other channel counts (including 1) use the scalar formula.
for (int dc = 0; dc < c; ++dc) {
rows0p[dc] = (S0p[dc] * a0 + S0p[dc + c] * a1) >> 4;
rows1p[dc] = (S1p[dc] * a0 + S1p[dc + c] * a1) >> 4;
}
}
#endif
ialphap += 2;
rows0p += c;
rows1p += c;
}
}
ResizeCalculateOneRow
不区分通道,每次处理8个元素。
_v2 用于对最终结果四舍五入。
#ifndef TNN_USE_NEON
// Scalar build: every output element goes through the tail loop.
int remain = w * c;
#else
// NEON build: nn blocks of 8 output bytes, then a scalar tail of `remain` elements.
int nn = (w * c) >> 3;
int remain = (w * c) - (nn << 3);
int16x4_t _b0 = vdup_n_s16(b0);
int16x4_t _b1 = vdup_n_s16(b1);
// Rounding bias added before the final >> 2.
int32x4_t _v2 = vdupq_n_s32(2);
_rows0p_sr4_mb0 和 _rows0p_1_sr4_mb0 为 $\beta_0(\alpha_0 f(Q_{11}) + \alpha_1 f(Q_{12}))$,_rows1p_sr4_mb1 和 _rows1p_1_sr4_mb1 为 $\beta_1(\alpha_0 f(Q_{21}) + \alpha_1 f(Q_{22}))$。
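vsraq_n_s32 将向量右移常数位后累加到目标向量上。中间结果先右移 16 位,再右移 2 位。原因是 α 和 β 都放大了 $2^{11}$ 倍,水平插值后已右移 4 位,所以乘积共放大 $2^{11-4+11} = 2^{18}$ 倍。右移 16 位后加上 _v2 中的 2 再右移 2 位,就完成了四舍五入,并还原到像素取值范围。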
_acc16 和 _acc16_1 为运算结果。vcombine_s16 将二者拼接起来;vqmovun_s16 是将有符号 16 位整数饱和收窄为无符号 8 位整数的向量指令。
for (; nn > 0; nn--) {
int16x4_t _rows0p_sr4 = vld1_s16(rows0p);
int16x4_t _rows1p_sr4 = vld1_s16(rows1p);
int16x4_t _rows0p_1_sr4 = vld1_s16(rows0p + 4);
int16x4_t _rows1p_1_sr4 = vld1_s16(rows1p + 4);
int32x4_t _rows0p_sr4_mb0 = vmull_s16(_rows0p_sr4, _b0);
int32x4_t _rows1p_sr4_mb1 = vmull_s16(_rows1p_sr4, _b1);
int32x4_t _rows0p_1_sr4_mb0 = vmull_s16(_rows0p_1_sr4, _b0);
int32x4_t _rows1p_1_sr4_mb1 = vmull_s16(_rows1p_1_sr4, _b1);
// acc = 2 + (row0 * b0 >> 16) + (row1 * b1 >> 16), for elements 0..3 and 4..7.
int32x4_t _acc = _v2;
_acc = vsraq_n_s32(_acc, _rows0p_sr4_mb0, 16);
_acc = vsraq_n_s32(_acc, _rows1p_sr4_mb1, 16);
int32x4_t _acc_1 = _v2;
_acc_1 = vsraq_n_s32(_acc_1, _rows0p_1_sr4_mb0, 16);
_acc_1 = vsraq_n_s32(_acc_1, _rows1p_1_sr4_mb1, 16);
// Final rounded >> 2, then saturate the 8 results to uint8 and store them.
int16x4_t _acc16 = vshrn_n_s32(_acc, 2);
int16x4_t _acc16_1 = vshrn_n_s32(_acc_1, 2);
uint8x8_t _D = vqmovun_s16(vcombine_s16(_acc16, _acc16_1));
vst1_u8(Dp, _D);
Dp += 8;
rows0p += 8;
rows1p += 8;
}
#endif
// Tail (or the whole row without NEON): same fixed-point formula as the vector loop.
// NOTE(review): the vector path saturates via vqmovun_s16, while this cast to uint8_t
// truncates; the two agree only while the result stays within 0..255.
for (; remain; --remain) {
*Dp++ = (uint8_t)(
((short)((b0 * (short)(*rows0p++)) >> 16) + (short)((b1 * (short)(*rows1p++)) >> 16) + 2) >> 2);
}