这两天看到一个很好玩的图像裁剪算法,是十年前的提出的一个很经典的算法。原来的图像裁剪算法大多会使得图像失真,而这个算法提出一种基于能量的原理来相适应的去裁剪。简单来说就是给每个像素值赋予一个能量值,然后根据这个像素值得8连通域进行动态规划求取最小值,然后逐行或逐列地应用这一算法会获得一条能量线,其实就是相邻两行(列)的最小像素值的像素所在连线,然后将其从原始图像删除,具体删除多少条这样的能连线,根据需要裁剪的scale而定。上面说的能量其实和像素梯度没什么区别,在图像里面很重要的信息其边缘轮廓纹理等会变化很大,梯度也会很明显,那么就可以认为这样的重要内容的能量很大。
paper:下载
论文给出的基本能量方程:
这里其实就是根据sobel算子提取x和y轴的梯度,然后绝对值相加
算法的基本步骤:
1.为每个像素分配一个能量值
2.找到能量值最小的像素的八连通路径
3.删除路径中的所有像素
- 4.重复前面1-3步,直到删除的行/列数量达到理想状态
论文中给出的动态规划算法求最小能量值的方法:、
然后是代码测试:
import sys
import numpy as np
from imageio import imread, imwrite
from scipy.ndimage.filters import convolve
import cv2
import numba
import matplotlib.pyplot as plt
# 图像处理时候查看处理进度条
from tqdm import trange
def calc_energy(img):
'''
这里计算原始图像的能量图,其实也就是计算梯度
'''
# 这里其实就是X和Y轴的sobel算子
filter_du = np.array([
[1.0, 2.0, 1.0],
[0.0, 0.0, 0.0],
[-1.0, -2.0, -1.0]])
# RGB三颜色通道,为每个通道都复制一份相同的滤波器
filter_du = np.stack([filter_du] * 3, axis=2)
filter_dv = np.array([
[1.0, 0.0, -1.0],
[2.0, 0.0, -2.0],
[1.0, 0.0, -1.0]])
filter_dv = np.stack([filter_dv] * 3, axis=2)
print(type(img))
img = img.astype('float32')
# 然后根据论文中的能量最基本的公式计算原始图像的能量图
convolved = np.absolute(convolve(img, filter_du)) + np.absolute(convolve(img, filter_dv))
# 将RGB三颜色通道的能量进行相加得到原始图的能量图
energy_map = convolved.sum(axis=2)
return energy_map
# @numba.jit
def minimum_seam(img):
r, c, _ = img.shape
energy_map = calc_energy(img)
imwrite(energy_map, 'energy.jpg')
M = energy_map.copy()
backtrack = np.zeros_like(M, dtype=np.int)
# 从第二行开始
for i in range(1, r):
for j in range(0, c):
# 这里处理左侧边缘部分,确保数组不会越界
# 使用动态规划算法求解最小值
if j == 0:
idx = np.argmin(M[i-1, j:j+2])
backtrack[i, j] = idx + j
min_energy = M[i-1, idx + j]
else:
idx = np.argmin(M[i-1, j-1:j+2])
backtrack[i, j] = idx + j - 1
min_energy = M[i-1, idx + j - 1]
M[i, j] += min_energy
return M, backtrack
# @numba.jit
def carve_column(img):
r, c, _ = img.shape
M, backtrack = minimum_seam(img)
# 创建一个和原始图一样大小的矩阵,初始化为True
mask = np.ones((r, c), dtype=np.bool)
# 逐行找到最小元素的位置,采用的是动态规划算法
j = np.argmin(M[-1])
for i in reversed(range(r)):
# 把需要删除的像素位置标记
mask[i, j] = False
j = backtrack[i, j]
# RGB三颜色通道同时标记
mask = np.stack([mask]*3, axis=2)
# reshape成和原始图像大小一致的维度
img = img[mask].reshape((r, c - 1, 3))
return img, mask, M
def crop_c(img, scale_c):
r, c, _ = img.shape
new_c = int(scale_c * c)
mask, M = None
# 删除每次的最小能量线, 由scale_c控制做几次
for i in trange(c - new_c):
img, _, _ = carve_column(img)
return img
def crop_r(img, scale_r):
# 对原始图像矩阵进行旋转90°
img = np.rot90(img, 1, (0, 1))
# 本来进行列的删除,现在旋转后进行行的删除,然后在做矩阵的三次旋转旋转回来
img = crop_c(img, scale_r)
img = np.rot90(img, 3, (0, 1))
return img
def main():
if len(sys.argv) != 6:
print('usage:carver.py<r/c> <scale> <image_in> <image_out>', file=sys.stderr)
sys.exit(1)
which_axis = sys.argv[1]
scale = float(sys.argv[2])
in_filename = sys.argv[3]
out_filename = sys.argv[4]
img = imread(in_filename)
r, c, _ = img.shape
print(r)
print(c)
if which_axis == 'r':
out = crop_r(img, scale)
elif which_axis == 'c':
out = crop_c(img, scale)
else:
print('usage:carver.py<r/c/a> <scale> <image_in> <image_out>', file=sys.stderr)
sys.exit(1)
imwrite(out_filename, out)
if __name__ == '__main__':
# main()
img = imread('screen.jpg')
energy = calc_energy(img)
plt.imshow(energy)
plt.show()



