算数编码是一种用浮点数来编码信息的方式,之前读过的论文有提到用算数编码来达到数据压缩的目的,尤其是在参数交换的时候,可以达到<1bpp(bit per parameter)的通信效率。
例如编码一串字符信息101(二进制),预估的0和1的出现率分别为0.3和0.7。
possibilities_character = [0.3,0.7]
那么我们按如下方法生成编码的浮点数:
初始化区间[0,1];把区间按概率比例划分为[0,0,3]和[0.3,1]
编码第一个字符1,那么总的区间变成了[0.3,1],再按概率比例划分为[0.3,0.3+区间长度*0.3]以及[0.3+区间长度*0.3,1],这里区间长度是0.7,所以最终的区间划分为[0.3,0.51],[0.51,1]
编码第二个字符0,那么总的区间变成了[0.3,0.51],再按概率比例划分为[0.3,0.363]和[0.363,0.51]
编码第三个字符1,那么总的区间变成了[0.363,0.51]。至此结束编码,不再划分。
101就对应了区间[0.363,0.51],那么在这个区间任选一个浮点数,都可以作为编码结果,例如0.4。(代码实现中实际取了区间的中点)
解码过程如下;
必须要给定需要解码字符的长度,此时为3。
初始化区间[0,1];把区间按概率比例划分为[0,0,3]和[0.3,1]。
0.4落在第二个区间,所以解码出的第一位为1。并且更新总的区间[0.3,0.1],划分为[0.3,0.51]和[0.51,1]。
0.4又落在第一个区间,所以解码出的第二位为0。并且更新总的区间[0.3,0.51],划分为[0.3,0.363]和[0.363,0.51]。
0.4又落在第二个区间,所以解码出的第三为1。达到了需要解码的字符长度,解码结束。
可以发现编码和解码的区间计算几乎完全相同。
字符长度可以拓展,但是实际上取决于浮点数的有效精度。
字符类别可以拓展,例如26个英文字符,那么对于总的区间也被按照概率划分为26段。
代码实现中添加了伯努利采样,对于概率序列[0.2,0.3,0.5],会生成一个三位的01序列,分别以0.2,0.3,0.5的概率生成1,否则生成0。这部分知识与本文无关,只是为了生成初始的01序列。
代码以编码和解码二进制掩码为例子。
import random
#伯努利采样,输入概率向量,返回伯努利采样的掩码(0和1构成的列表)
def bernoulli_sampling(probabilities):
binary_mask = [1 if random.random() < p else 0 for p in probabilities]
return binary_mask
probabilities = [0.5, 0.3, 0.2, 0.85]
binary_mask = bernoulli_sampling(probabilities)
print("原始概率序列:", probabilities)
print("伯努利采样得到的二进制掩码:", binary_mask)
def arithmetic_encoding(binary_mask):
# 定义初始区间 [0, 1)
low, high = 0.0, 1.0
result = 0.0
for bit in binary_mask:
# 计算新的区间范围
range_width = high - low
high = low + range_width * (bit / 2 + 0.5)
low = low + range_width * (bit / 2)
# 取区间中任意一个值作为结果(这里取中点)
result = (low + high) / 2
return result
# 使用之前得到的二进制掩码进行算术编码
arithmetic_code = arithmetic_encoding(binary_mask)
print("二进制掩码:", binary_mask)
print("算术编码结果:", arithmetic_code)
#
probabilities_character = [0.3,0.7]
# 算数解码,传入编码的浮点数,
def arithmetic_decoding(encoded_value, probabilities_character, length):
# 初始化区间 [0, 1)
low, high = 0.0, 1.0
decoded_mask = []
for _ in range(length):
# 计算当前区间范围
range_width = high - low
# 根据区间和概率确定二进制位
if encoded_value < (low + range_width * probabilities_character[0]):
bit = 0
#更新区间,解码为0,就取左边
high = low + range_width * probabilities_character[0]
else:
bit = 1
#更新区间,解码为1,就取右边
low = low + range_width * probabilities_character[0]
# 添加解码的二进制位
decoded_mask.append(bit)
return decoded_mask
# 使用之前的算术编码结果进行解码
decoded_binary_mask = arithmetic_decoding(arithmetic_code, probabilities, len(probabilities))
print("算术编码结果:", arithmetic_code)
print("解码得到的二进制掩码:", decoded_binary_mask)