# -*- coding: utf-8 -*-
"""
Created on Tue Nov 26 08:57:10 2019
@author: lee
"""
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import sys
# Adjust the recursion limit. The only recursion in this script is bounded
# by the Huffman tree, which has at most 256 leaves (one per byte value), so
# a modest floor is plenty. A huge limit (formerly 1000000) lets runaway
# recursion overflow the C stack and crash the interpreter instead of raising
# RecursionError.
sys.setrecursionlimit(max(sys.getrecursionlimit(), 10000))
# Huffman tree node class.
class Tree_Nodes(object):
    """Node of a Huffman tree.

    ``value`` is the node weight: a symbol frequency for a leaf, or the sum
    of the children's weights for a node made by ``build_father``.
    ``left``/``right`` are the children and ``father`` is the parent
    (``None`` for the root).
    """

    def __init__(self, value=None, left=None, right=None, father=None):
        self.value = value
        self.left = left
        self.right = right
        self.father = father

    @staticmethod
    def build_father(left, right):
        """Create and return a parent node for *left* and *right*.

        The parent's weight is ``left.value + right.value``. Both children
        get their ``father`` link pointed at the new node.
        """
        n = Tree_Nodes(value=left.value + right.value, left=left, right=right)
        left.father = right.father = n
        return n

    @staticmethod
    def encode(n):
        """Return the Huffman code of node *n* as ASCII bytes of b'0'/b'1'.

        Walks from *n* up to the root. A left child contributes ``0`` and a
        right child ``1``; the bits are returned root-first. The walk is
        iterative, so it does not depend on the recursion limit.

        A leaf that is also the root (a one-symbol alphabet) gets ``b'0'``
        rather than the empty code. With an empty code, that symbol would
        produce no output bits and its data would be lost. An internal root
        still gets ``b''``.
        """
        bits = []
        node = n
        while node.father is not None:
            # Identity check: is this node its parent's left child?
            bits.append(b'0' if node.father.left is node else b'1')
            node = node.father
        if not bits and n.left is None and n.right is None:
            return b'0'
        return b''.join(reversed(bits))
# Keep `rank` singular values to obtain an approximate image matrix.
def get_approx_matrix(U, sigma, V, rank):
    """Return the truncated-SVD reconstruction of a matrix as ``uint8``.

    Parameters:
        U, sigma, V: factors as returned by ``np.linalg.svd``. ``V`` is the
            ``vh`` factor, whose rows are the right singular vectors.
        rank: number of leading singular values to keep. Values above
            ``len(sigma)`` are clamped; previously they raised IndexError.
            Values <= 0 yield an all-zero matrix.

    Returns:
        An ``(U.shape[0], V.shape[1])`` ``uint8`` array. The reconstruction
        is clipped to [0, 255] and then truncated (not rounded) to integers.
    """
    m = U.shape[0]
    n = V.shape[1]
    k = max(0, min(int(rank), len(sigma)))
    # Scaling U's first k columns by sigma and multiplying by V's first k rows
    # sums the k outer products sigma[j] * u_j v_j^T in a single matmul.
    A = (U[:, :k] * sigma[:k]) @ V[:k, :]
    A = np.broadcast_to(A, (m, n))
    return np.clip(A, 0, 255).astype("uint8")
# Compress an image with SVD.
def get_svd_image(file_path, rank=100):
    """Save a truncated-SVD approximation of the image at *file_path*.

    Shows the original image, then decomposes each of the R, G, B channels
    with ``np.linalg.svd``. Each channel is rebuilt from its first ``rank``
    singular values via ``get_approx_matrix``. The result is saved as
    ``<name><rank><ext>`` next to the input, e.g. ``photo.jpg`` ->
    ``photo100.jpg`` with the default rank.

    The image is converted to RGB first, so grayscale, palette and RGBA
    inputs work. Previously a 2-D grayscale array made ``A[:, :, 0]`` raise
    IndexError.

    Returns:
        The path of the saved file.
    """
    import os

    # splitext handles dots in directory names ("./photo.png"); the old
    # ``split(".")`` unpacking raised ValueError on such paths.
    name, suffix = os.path.splitext(file_path)
    with Image.open(file_path, 'r') as img:
        size = img.size
        A = np.array(img.convert("RGB"))
    ax = plt.subplot(1, 1, 1)
    ax.imshow(mpimg.imread(file_path))
    ax.set_title("Original Image")
    ax.set_axis_off()
    plt.show()
    print("\norigin_image——size:", size)
    factors = [np.linalg.svd(A[:, :, channel]) for channel in range(3)]
    # len(sigma) is the number of singular values, i.e. min(height, width).
    print("\norigin_image——red_rank:", len(factors[0][1]),
          "green_rank:", len(factors[1][1]),
          "blue_rank:", len(factors[2][1]))
    channels = [get_approx_matrix(u, sigma, v, rank) for u, sigma, v in factors]
    I = np.stack(channels, 2)
    output_path = name + str(rank) + suffix
    Image.fromarray(I).save(output_path)
    print("\nSVD compression OK!")
    return output_path
# Build the Huffman tree.
def build_huffman_tree(l):
    """Merge the nodes in *l* into one Huffman tree and return ``[root]``.

    Each round stably sorts the pending nodes by ``value``. It joins the two
    lightest under a new parent (``Tree_Nodes.build_father``) and appends
    that parent at the end of the list. Tie-breaking is therefore fully
    determined by the input order, which lets the decoder rebuild exactly the
    tree the encoder used.

    *l* itself is not reordered. A list with zero or one node is returned
    unchanged; previously an empty list raised IndexError. The loop is
    iterative, so it does not depend on the recursion limit.
    """
    pending = l
    while len(pending) > 1:
        pending = sorted(pending, key=lambda node: node.value)
        parent = Tree_Nodes.build_father(pending[0], pending[1])
        pending = pending[2:]
        pending.append(parent)
    return pending
# Encode using the constructed Huffman tree.
def huffman_encode(echo=False):
    """Compute every symbol's Huffman code into the global ``ec_dict``.

    Reads the module-level ``node_dict`` (symbol -> leaf node) and stores
    ``ec_dict[symbol] = Tree_Nodes.encode(leaf)``. When *echo* is truthy,
    each symbol and its code are printed.
    """
    for symbol, leaf in node_dict.items():
        ec_dict[symbol] = Tree_Nodes.encode(leaf)
        if echo:
            print(symbol)
            print(ec_dict[symbol])  # print the symbol/code pair
# Then Huffman-encode the (SVD-compressed) image file.
def encodeImage(inputfile):
    """Huffman-compress *inputfile* byte-wise into ``<root>.ys`` beside it.

    Fills the module-level ``count_dict`` (byte -> frequency), ``node_dict``
    (byte -> leaf node) and ``ec_dict`` (byte -> code).

    Output layout (unchanged from earlier versions):
      * file name (text after the last ``/``) + ``\\n``, UTF-8 encoded;
      * number of distinct bytes, 2 bytes big-endian;
      * ``bit_width`` (bytes per stored frequency), 1 byte;
      * for each distinct byte, in ``ec_dict`` order: the byte itself, then
        its frequency in ``bit_width`` bytes big-endian;
      * the concatenated codes packed MSB-first, last byte zero-padded.

    Raises:
        ValueError: if *inputfile* is empty (there is nothing to encode).
    """
    import os
    from collections import Counter

    print("\nStarting huffman_encode... \n")
    with open(inputfile, "rb") as img:
        data = img.read()  # one read instead of one read() call per byte
    count = len(data)
    print("Bite Size of SVD_image:", count)
    if not count:
        raise ValueError("cannot Huffman-encode an empty file: " + inputfile)

    # Count byte frequencies. Counter keeps first-occurrence order, which
    # fixes the header order and hence the tree's tie-breaking.
    for value, freq in Counter(data).items():
        key = bytes((value,))
        count_dict[key] = count_dict.get(key, 0) + freq
    print("\nRead OK! \n")
    print("The count_dict:", count_dict)  # weight table; may be silenced

    # Build the Huffman tree from the weights and derive the code table.
    nodes = []
    for key, freq in count_dict.items():
        node_dict[key] = Tree_Nodes(freq)
        nodes.append(node_dict[key])
    build_huffman_tree(nodes)
    huffman_encode(False)
    print("\nEncode OK!")

    # File header: each frequency is stored in the fewest whole bytes that
    # hold the largest one (no longer capped at 4 bytes).
    max_count = max(count_dict.values())
    print("\nhead:", max_count)
    bit_width = max(1, (max_count.bit_length() + 7) // 8)
    print("\nbit_width:", bit_width)

    # Byte value -> code, so the hot loop is a plain list lookup.
    table = [b''] * 256
    for key, code in ec_dict.items():
        table[key[0]] = code

    with open(os.path.splitext(inputfile)[0] + ".ys", 'wb') as o:
        o.write((inputfile.split('/')[-1] + '\n').encode(encoding="utf-8"))
        o.write(int.to_bytes(len(ec_dict), 2, byteorder='big'))
        o.write(int.to_bytes(bit_width, 1, byteorder='big'))
        for key in ec_dict:
            o.write(key)
            o.write(int.to_bytes(count_dict[key], bit_width, byteorder='big'))
        print('\nhead OK!\n')

        # Encode in ~1% slices. Whole bytes are flushed each round, and the
        # leftover (< 8) bits carry over to the next slice.
        step = max(1, count // 100)
        pending = b''
        last = 0
        for start in range(0, count, step):
            pending += b''.join(map(table.__getitem__, data[start:start + step]))
            usable = len(pending) - len(pending) % 8
            if usable:
                o.write(int(pending[:usable], 2).to_bytes(usable // 8, 'big'))
                pending = pending[usable:]
            tem = int(min(start + step, count) / count * 100)
            if tem > last:
                print("encode:", tem, '%')
                last = tem
        if pending:
            # Zero-pad the final partial byte on the right (MSB-first packing).
            pending += b'0' * (8 - len(pending))
            o.write(int(pending, 2).to_bytes(1, 'big'))
    print("\nFile encode successful!\nYou can find the compressed file in the folder where the original image is stored.")
# File decompression.
def decodeImage(inputfile):
    """Decode a ``.ys`` file produced by ``encodeImage`` back to its bytes.

    Parses the header: the stored file name, the number of distinct bytes,
    the frequency width, and the byte -> frequency table. It then rebuilds
    the same Huffman tree through the module-level ``node_dict``/``nodes``/
    ``ec_dict`` tables and fills ``inverse_dict`` (code -> byte). The output
    is written in the directory of *inputfile* (text up to the last ``/``)
    under the stored file name.

    Decoding stops once the number of bytes recorded in the header (the sum
    of the frequencies) has been produced. Without this, the zero bits that
    pad the last payload byte could decode into spurious trailing bytes.
    """
    print("Starting decode...\n")
    folder = inputfile[:len(inputfile) - len(inputfile.split('/')[-1])]
    with open(inputfile, 'rb') as img:
        stored_name = img.readline().decode(encoding="utf-8").replace('\n', '')
        count = int.from_bytes(img.read(2), byteorder='big')  # number of distinct bytes
        bit_width = int.from_bytes(img.read(1), byteorder='big')  # bytes per frequency
        de_dict = {}
        for _ in range(count):  # parse the header table
            key = img.read(1)
            de_dict[key] = int.from_bytes(img.read(bit_width), byteorder='big')
        position = img.tell()
        payload = img.read()
    eof = position + len(payload)

    # Rebuild the tree exactly as the encoder did (same insertion order, hence
    # the same tie-breaking), then derive the code table and its inverse.
    for key, freq in de_dict.items():
        node_dict[key] = Tree_Nodes(freq)
        nodes.append(node_dict[key])
    if de_dict:
        build_huffman_tree(nodes)
        huffman_encode(False)
    for key, code in ec_dict.items():
        inverse_dict[code] = key

    remaining = sum(de_dict.values())
    decoded = bytearray()
    code = b''
    last = 0
    for byte in payload:
        if remaining == 0:
            break  # only zero padding is left
        position += 1
        for shift in range(7, -1, -1):  # MSB first, matching the encoder
            code += b'1' if (byte >> shift) & 1 else b'0'
            symbol = inverse_dict.get(code)
            if symbol is not None:
                decoded += symbol
                code = b''
                remaining -= 1
                if remaining == 0:
                    break
        tem = int(position / eof * 100)
        if tem > last:
            print("decode:", tem, '%')  # decompression progress
            last = tem

    with open(folder + stored_name, 'wb') as o:
        o.write(decoded)
    print("\nFile decode successful!\nYou can find the uncompressed file in the folder where the compressed file is stored. ")
if __name__ == '__main__':
    import os

    # Data initialisation: module-level tables shared by the Huffman helpers.
    node_dict = {}     # node dict: maps each original byte to its Huffman leaf
    count_dict = {}    # weight dict: byte -> frequency
    ec_dict = {}       # byte -> Huffman code
    nodes = []         # Huffman tree leaf list (used by decodeImage)
    inverse_dict = {}  # Huffman code -> byte (used by decodeImage)
    if input("1:压缩图片\t2:解压图片\n\n请输入你要执行的操作:") == '1':
        img_path = input("请输入要压缩的图片:")
        get_svd_image(img_path)
        # splitext copes with dots elsewhere in the path ("./photo.png"),
        # where the old ``split(".")`` unpacking raised ValueError.
        name, suffix = os.path.splitext(img_path)
        svd_path = name + '100' + suffix  # file written by get_svd_image
        ax = plt.subplot(1, 1, 1)
        img = mpimg.imread(svd_path)
        ax.imshow(img)
        ax.set_title("SVD Image")
        ax.set_axis_off()
        plt.savefig(name + '_svd' + suffix)
        plt.show()
        encodeImage(svd_path)
    else:
        decodeImage(input("请输入要解压的图片:"))