【Python系列】之画BD-RATE及码率波动图示例

本文介绍如何使用Python进行SSIM/PSNR等波动图绘制及BDRATE图处理,涉及字符处理、文件读写、数据计算与图表生成。通过示例代码,展示了如何从解码输出文件中提取关键数据,计算滑动窗口平均值,并将结果绘制成图表。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

  本文主要通过此示例介绍如何使用 Python 的画图函数,以及各种字符串的处理方法。

1、说明

  此示例可以画SSIM/PSNR等波动图,也可以画BDRATE图,处理的文本格式待补充:

待补充
2、功能代码
# -*- coding: utf-8 -*-

import os
import re
import sys
import glob
import subprocess
import numpy as np
import matplotlib.pyplot as plt

## Names of the output files; the functions below build their paths as
## outdir + delimiter + <name>.
psnr_summary_file = 'psnr_summary_result.txt'
ssim_summary_file = 'ssim_summary_result.txt'
bits_summary_file = 'bits_summary_result.txt'
psnr_summary_table = 'psnr_summary_table.txt'

## Field separator used when joining summary rows, and the OS path separator.
space = ' '
delimiter = os.sep

##计算一个列表里所有值的总和
def sum_list(list):
    """Return the sum of all values in ``list`` (0 for an empty sequence).

    The parameter keeps its original name for backward compatibility, so the
    builtin ``list`` is shadowed inside this function.

    The previous implementation recursed once per element and copied the
    tail of the list on every call, which was quadratic and raised a
    recursion-depth error for inputs of roughly a thousand items or more.
    """
    return sum(list)
    
##以滑动窗口跟滑动步长计算一个列表(每个滑动窗求取窗口内的平均值)
def cacl_avg_in_windowsize(list, winsize, winstep):
    """Average ``list`` over a sliding window and return the results as strings.

    A window of ``winsize`` consecutive values is moved over ``list`` in
    steps of ``winstep``; only windows that fit completely are used.  The
    returned list holds one average per window followed by one extra entry:
    the mean of all window averages.  Every entry is converted with ``str``.

    Args:
        list: sequence of numbers (name kept for backward compatibility; it
            shadows the builtin inside this function).
        winsize: number of values per window, must be > 0.
        winstep: distance between the starts of two windows, must be > 0.

    Returns:
        A list of strings; empty when no complete window fits into ``list``
        (the previous code divided by zero in that case).

    Raises:
        ValueError: if ``winsize`` or ``winstep`` is not positive.
    """
    if winsize <= 0 or winstep <= 0:
        raise ValueError("winsize and winstep must be positive")

    window_avgs = []
    # The last valid window starts at len(list) - winsize.
    for start in range(0, len(list) - winsize + 1, winstep):
        window = list[start:start + winsize]
        window_avgs.append(float(sum(window)) / winsize)

    if not window_avgs:
        return []

    window_avgs.append(sum(window_avgs) / len(window_avgs))
    # Return a real list (not a lazy map object) so callers can index it and
    # iterate over it more than once.
    return [str(value) for value in window_avgs]
    
##从解码输出文件中提取psnr\ssim
def extract_psnr_ssim(outdir, outtxt, keys, winsize, winstep):
    """Extract per-frame SSIM/PSNR from a decoder log and append summaries.

    The log ``outtxt`` is scanned for the rows located between the line
    containing ``frame_num`` and the line containing ``SUMMARY``.  Rows are
    comma separated; field 6 is collected as SSIM and field 9 as PSNR
    (presumably the luma values, judging by the variable names -- confirm
    against the decoder's log format).  Both series are reduced with
    ``cacl_avg_in_windowsize`` and one row per metric, made of ``keys``
    followed by the averages, is appended to the SSIM and PSNR summary files
    inside ``outdir``.

    Args:
        outdir: directory that holds the summary files.
        outtxt: path of the decoder log; opened as given.
        keys: list of strings that prefixes every summary row.
        winsize: window size forwarded to ``cacl_avg_in_windowsize``.
        winstep: window step forwarded to ``cacl_avg_in_windowsize``.
    """
    y_ssim = []
    y_psnr = []
    in_frame_table = False
    with open(outtxt, 'r') as file_out:
        for line in file_out:
            if "frame_num" in line:
                in_frame_table = True
                continue
            # The original test was `find("SUMMARY") != 1`, which is true for
            # nearly every line and therefore discarded all frame rows.
            if "SUMMARY" in line:
                break
            if "Y_PSNR" in line or "Y_SSIM" in line:
                continue
            if not in_frame_table or not line.strip():
                continue
            line_word = line.strip().split(',')
            y_ssim.append(float(line_word[6]))
            y_psnr.append(float(line_word[9]))

    y_ssim_win_avg = cacl_avg_in_windowsize(y_ssim, winsize, winstep)
    y_psnr_win_avg = cacl_avg_in_windowsize(y_psnr, winsize, winstep)

    def write_summary_to_file(outdir, summary_file, data):
        """Append ``keys`` + ``data`` as one space separated row."""
        one_line = space.join(list(keys) + list(data))
        # `with` guarantees the close that the bare `file.close` never did.
        with open(outdir + delimiter + summary_file, 'a') as summary:
            summary.write(one_line + '\n')

    write_summary_to_file(outdir, ssim_summary_file, y_ssim_win_avg)
    write_summary_to_file(outdir, psnr_summary_file, y_psnr_win_avg)

##从解码输出文件中提取码率    
def extract_bits(outdir, outtxt, keys, winsize, winstep):
    """Extract two per-frame bit columns from a decoder log and summarise them.

    Rows between the line containing ``frame_num`` and the line containing
    ``SUMMARY`` are split on commas; field 1 and field 2 are collected as two
    series.  Each series is reduced with ``cacl_avg_in_windowsize`` and
    appended to the bits summary file in ``outdir``:

    * the first series is written with ``keys[0]`` replaced by ``'Anchor'``
      and only if an identical row is not already in the file;
    * the second series is written with ``keys`` unchanged.

    Args:
        outdir: directory that holds the summary file.
        outtxt: decoder log, either a usable path or a file name relative to
            ``outdir`` (both forms are accepted; the previous code always
            prefixed ``outdir`` even when the caller had already joined it).
        keys: list of strings that prefixes every summary row.
        winsize: window size forwarded to ``cacl_avg_in_windowsize``.
        winstep: window step forwarded to ``cacl_avg_in_windowsize``.
    """
    if os.path.isfile(outtxt):
        log_path = outtxt
    else:
        log_path = outdir + delimiter + outtxt

    y_bit1 = []
    y_bit2 = []
    in_frame_table = False
    with open(log_path, 'r') as file_out:
        for line in file_out:
            if 'frame_num' in line:
                in_frame_table = True
                continue
            if 'SUMMARY' in line:
                break
            # Same header-row skip as extract_psnr_ssim, which parses the
            # same log; these rows are not frame data.
            if 'Y_PSNR' in line or 'Y_SSIM' in line:
                continue
            if not in_frame_table or not line.strip():
                continue
            line_word = line.strip().split(',')
            y_bit1.append(float(line_word[1]))
            y_bit2.append(float(line_word[2]))

    y_bit1_win_avg = cacl_avg_in_windowsize(y_bit1, winsize, winstep)
    y_bit2_win_avg = cacl_avg_in_windowsize(y_bit2, winsize, winstep)

    def write_summary_to_file(outdir, summary_file, data, check):
        """Append one row; with ``check == 1`` write it as a de-duplicated 'Anchor' row."""
        summary_path = outdir + delimiter + summary_file
        row_keys = list(keys)
        if check == 1:
            row_keys[0] = 'Anchor'
        # Build the row for both branches; previously `one_line` was only
        # defined when check == 1, so the other branch raised a NameError.
        one_line = space.join(row_keys + list(data))

        if check == 1 and os.path.isfile(summary_path):
            # Read through a separate handle: the read position of a file
            # opened with 'a+' is platform dependent.
            with open(summary_path, 'r') as summary:
                if any(one_line == existing.strip() for existing in summary):
                    return
        with open(summary_path, 'a') as summary:
            summary.write(one_line + '\n')

    write_summary_to_file(outdir, bits_summary_file, y_bit1_win_avg, 1)
    write_summary_to_file(outdir, bits_summary_file, y_bit2_win_avg, 0)
    
##解码
def exec_decode_process(outdir, decoder, anchor_streams, ref_streams, winsize, winstep):
    """Run the comparison tool for every reference stream and collect its metrics.

    For each file in ``ref_streams`` the function
      1. derives ``[decoder, '<name>_<WxH>', '<bitrate>']`` from the file name
         (expected pattern ``name_WxH_<rate>[k|m]``; ``m`` is multiplied by 1024),
      2. picks the matching anchor from ``anchor_streams``: first a path that
         contains the reference name without suffix, otherwise a path that
         contains ``'<name>_<WxH>.'``,
      3. runs ``xxxx.exe`` through the shell with its output redirected to
         ``<outdir>/<decoder>_<ref file name>.txt``,
      4. feeds that log to ``extract_psnr_ssim`` and ``extract_bits``.

    The process exits with status 1 when no anchor is found or the tool
    reports failure.

    Args:
        outdir: directory receiving logs and summary files.
        decoder: label stored as the first key of each summary row.
        anchor_streams: list of anchor file paths.
        ref_streams: list of reference file paths.
        winsize: window size forwarded to the extract functions.
        winstep: window step forwarded to the extract functions.
    """
    limit_frames = '10'

    def extract_info_from_name(filename):
        """Return (key list, resolution) parsed from e.g. 'xulie_832x480_26k'."""
        tokens = filename.replace('_', ' ').split()
        stream_name = []
        stream_resolution = ''
        target_bit = ''
        for idx, token in enumerate(tokens):
            if re.match(r'[0-9]+x[0-9]+', token):
                stream_resolution = token
                # A name such as 'xulie_832x480' has no bitrate token; the
                # unguarded keys[i+1] used to raise IndexError here.
                if idx + 1 < len(tokens):
                    target_bit = tokens[idx + 1].lower()
                break
            stream_name.append(token)

        if target_bit.endswith('m'):
            bit_rate = float(target_bit[:-1]) * 1024
        elif target_bit.endswith('k'):
            bit_rate = float(target_bit[:-1])
        elif target_bit:
            bit_rate = float(target_bit)
        else:
            # No bitrate in the name: fall back to 0 so the row stays usable.
            bit_rate = 0.0

        stream_name.append(stream_resolution)
        key = [decoder, '_'.join(stream_name), str(bit_rate)]
        return key, stream_resolution

    def build_input_args(stream_name, suffix, resolution):
        """Build the tool's argument string for one input (raw yuv or bitstream)."""
        if suffix == ".yuv":
            args = ' -y ' + stream_name
            args = args + ' -r ' + resolution
            args = args + ' -ys ' + ' 0 '
            args = args + ' -ye ' + limit_frames
        else:
            args = ' -s ' + stream_name
            args = args + ' -ss ' + ' 0 '
            args = args + ' -se ' + limit_frames
        return args

    for ref_str_file in ref_streams:
        # dir/xxxx.h264 -> xxxx.h264 (was the misspelt `restrip`).
        ref_str_name = os.path.split(ref_str_file.rstrip(delimiter))[1]
        # splitext also copes with an empty suffix, unlike name[0:-len(suffix)].
        ref_str_name_no_suffix, ref_str_suffix = os.path.splitext(ref_str_name)

        ref_key_word, ref_resolution = extract_info_from_name(ref_str_name_no_suffix)

        anch_candidates = [x for x in anchor_streams if ref_str_name_no_suffix in x]
        if not anch_candidates:
            prefix = re.match(r"[a-zA-Z0-9_-]+[0-9]+x[0-9]+", ref_str_name_no_suffix)
            if prefix:
                anch_yuv_name = prefix.group() + '.'
                anch_candidates = [x for x in anchor_streams if anch_yuv_name in x]
            if not anch_candidates:
                print("No " + ref_str_name_no_suffix + " in specified dir")
                sys.exit(1)
        anch_str_file = anch_candidates[0]
        anch_str_name = os.path.split(anch_str_file.rstrip(delimiter))[1]
        anch_str_suffix = os.path.splitext(anch_str_name)[1]

        ref_out_name = decoder + "_" + ref_str_name + ".txt"
        ref_out_txt = os.path.join(outdir, ref_out_name)

        # NOTE(review): only the file names (not their directories) are passed
        # to the tool, exactly as before -- confirm the tool's working directory.
        anch_format = build_input_args(anch_str_name, anch_str_suffix, ref_resolution)
        ref_format = build_input_args(ref_str_name, ref_str_suffix, ref_resolution)

        # The redirection requires a shell; decode_process rejects paths with spaces.
        cmd = space.join(['xxxx.exe',
                          anch_format,
                          ref_format,
                          '>', ref_out_txt])

        ret = subprocess.call(cmd, shell=True)
        # subprocess.call returns the exit status; 0 is the conventional
        # success value (the old `ret != 1` reported every success as failure).
        if ret != 0:
            print("Decode fail")
            sys.exit(1)

        extract_psnr_ssim(outdir, ref_out_txt, list(ref_key_word), winsize, winstep)
        # extract_bits resolves its log relative to outdir, so pass the bare
        # file name (the old call used the undefined/duplicated path).
        extract_bits(outdir, ref_out_name, list(ref_key_word), winsize, winstep)
        
##准备解码码流
def decode_process(outdir, decoder, src_anchor, src_ref, winsize, winstep):
    """Collect the anchor/reference streams and hand them to ``exec_decode_process``.

    Every file directly inside ``src_anchor`` and ``src_ref`` is treated as a
    stream.  The process exits with status 1 when either directory yields no
    file or when one of the paths contains a space (the decode command is
    assembled as a single shell string).

    Args:
        outdir: output directory for logs and summaries.
        decoder: label forwarded to ``exec_decode_process``.
        src_anchor: directory containing the anchor streams.
        src_ref: directory containing the reference streams.
        winsize: sliding-window size forwarded unchanged.
        winstep: sliding-window step forwarded unchanged.
    """
    # An empty suffix matches every file.  The former value " " restricted the
    # glob to names ending in a blank, i.e. effectively to nothing.
    stream_suffix = ""
    # sorted() makes the anchor choice independent of directory listing order.
    anchor_streams = sorted(glob.glob(src_anchor + delimiter + "*" + stream_suffix))
    ref_streams = sorted(glob.glob(src_ref + delimiter + "*" + stream_suffix))

    if not anchor_streams or not ref_streams:
        print("No file in specified dir")
        sys.exit(1)
    if ' ' in src_anchor or ' ' in src_ref or ' ' in outdir:
        print("Path contains space")
        sys.exit(1)

    exec_decode_process(outdir, decoder, anchor_streams, ref_streams, winsize, winstep)
    
##从汇总文件中提取编码器的名字、序列名、码率信息
def get_encoder_seqs_bits(summary_file):
    """Return the distinct encoders, sequences and bitrates of a summary file.

    Every row of ``summary_file`` is split on whitespace; field 0 is the
    encoder, field 1 the sequence name and field 2 the bitrate.  Values are
    returned in order of first appearance.

    Args:
        summary_file: path of the summary file to read.

    Returns:
        Tuple ``(encodes_list, seqs_list, bits_list)`` of lists of strings.
    """
    encodes_list = []
    seqs_list = []
    bits_list = []
    with open(summary_file, "r") as file:
        # Explicit loop: map() used only for its side effects is lazy on
        # Python 3 and would never execute.
        for content in file:
            items = content.split()
            if len(items) < 3:
                # Blank or truncated rows carry no encoder/sequence/bitrate.
                continue
            if items[0] not in encodes_list:
                encodes_list.append(items[0])
            if items[1] not in seqs_list:
                seqs_list.append(items[1])
            if items[2] not in bits_list:
                bits_list.append(items[2])
    return (encodes_list, seqs_list, bits_list)


##排除某编码器之外的数据
def filter_by_enc(enc):
    """Return a predicate that is True for rows whose field 0 equals ``enc``."""
    # `==` replaces cmp(...) == 0; cmp() no longer exists on Python 3.
    return lambda x: x[0] == enc
    
##排除某序列之外的数据
def filter_by_seq(seq):
    """Return a predicate that is True for rows whose field 1 equals ``seq``."""
    # `==` replaces cmp(...) == 0; cmp() no longer exists on Python 3.
    return lambda x: x[1] == seq
    
##排除某码率之外的数据
def filter_by_bit(bit):
    """Return a predicate that is True for rows whose field 2 equals ``bit``."""
    # `==` replaces cmp(...) == 0; cmp() no longer exists on Python 3.
    return lambda x: x[2] == bit
            
##画BD-RATE图
def plot_bdrate_chart_process(outdir, summary_file, summary_table):
    """Plot one rate/PSNR chart per sequence and write the plotted points to a table.

    ``summary_file`` (inside ``outdir``) holds whitespace separated rows of
    the form ``encoder sequence bitrate value... overall_average``.  For every
    sequence, the rows of each encoder are sorted by bitrate and the last
    field is plotted against the bitrate.  The chart is saved as
    ``<outdir>/<sequence>.png``; the same points are written as fixed-width
    columns to ``summary_table`` (inside ``outdir``), one block per sequence.

    Args:
        outdir: directory containing the summary and receiving the outputs.
        summary_file: file name of the PSNR summary inside ``outdir``.
        summary_table: file name of the text table to (re)create.
    """
    summary_path = outdir + delimiter + summary_file
    (encs_list, seqs_list, _bits_list) = get_encoder_seqs_bits(summary_path)
    with open(summary_path, 'r') as summa_file:
        # Rows need at least encoder, sequence, bitrate and one value.
        rows = [fields for fields in (line.split() for line in summa_file)
                if len(fields) >= 4]

    def plot_bdrate_func(x, y, enc, seq):
        """Add one encoder's curve to the current figure."""
        plt.title(seq)
        plt.xlabel('Bit Rate')
        plt.ylabel('Y-PSNR')
        plt.grid(True)
        plt.plot(x, y, "-o", label="-".join([enc, seq]))
        # The rcParams key is "legend.fontsize"; "legend fontsize" is rejected.
        plt.rcParams["legend.fontsize"] = "x-small"
        plt.legend(shadow=True, loc=0)

    def format_bit_label(bit):
        """Turn '26.0' into '26kb' for the table header."""
        if bit.endswith('.0'):
            bit = bit[:-2]
        return bit + 'kb'

    def colect_bdrate_dot(bits, val, enc, seq, y_table):
        """Append this encoder's values to the table, creating the header first if needed."""
        def fix_len_align(x):
            return x.ljust(16, ' ')
        if len(y_table) == 0:
            y_title = [fix_len_align(seq)]
            y_title.extend(fix_len_align(format_bit_label(b)) for b in bits)
            y_table = [y_title]
        y_table.append([fix_len_align(enc)] + [fix_len_align(v) for v in val])
        return y_table

    with open(outdir + delimiter + summary_table, 'w') as table_file:
        # The loops previously iterated over the undefined names seqs / encs.
        for seq in seqs_list:
            psnr_table = []
            for enc in encs_list:
                datas = [r for r in rows if r[0] == enc and r[1] == seq]
                if not datas:
                    continue
                datas.sort(key=lambda r: float(r[2]))
                bits = [r[2] for r in datas]
                psnr_avg = [r[-1] for r in datas]

                # Plot numbers, not strings, so both axes are numeric.
                plot_bdrate_func([float(b) for b in bits],
                                 [float(p) for p in psnr_avg], enc, seq)
                psnr_table = colect_bdrate_dot(bits, psnr_avg, enc, seq, psnr_table)
            plt.savefig(outdir + delimiter + seq + ".png")
            plt.close()

            table_file.write("\n")
            for table_row in psnr_table:
                table_file.write(' '.join(table_row) + '\n')
    

###执行画波动图
def exec_plot_wave_chart_process(outdir, bit, enclist, seq, plot_table, y_label):
    """Draw the per-window values and their average for every encoder on one chart.

    ``plot_table[i]`` is the list of numeric strings that belongs to
    ``enclist[i]``: all entries but the last are plotted as the curve, the
    last entry is drawn as a horizontal average line.  The figure is saved
    under the name built from ``outdir``, ``seq``, ``bit`` and ``y_label``
    (unchanged from the original naming) and then closed.

    Args:
        outdir: directory receiving the PNG.
        bit: bitrate label (string) used in title, legend and file name.
        enclist: encoder labels, parallel to ``plot_table``.
        seq: sequence name used in title, legend and file name.
        plot_table: list of lists of numeric strings, one per encoder.
        y_label: label of the y axis, also part of the file name.
    """
    # The rcParams key is "legend.fontsize"; "legend fontsize" is rejected.
    plt.rcParams["legend.fontsize"] = "x-small"
    for enc, raw_values in zip(enclist, plot_table):
        # A real list is required: the values are sliced and indexed below.
        data_list = [float(value) for value in raw_values]
        if not data_list:
            continue
        data_each = data_list[0:-1]
        data_avg = data_list[-1]

        plt.title('_'.join([seq, bit]))
        plt.xlabel("Frame Idx")
        plt.ylabel(y_label)
        plt.grid(True)

        x = list(range(len(data_each)))
        plt.plot(x, data_each, "-o", label='_'.join([enc, seq, bit]))
        # Constant series at the average value, one point per window.
        data_avg_list = [data_avg] * len(data_each)
        plt.plot(x, data_avg_list, "-x", label='_'.join([enc, seq, bit, "avg"]))
        plt.legend(shadow=True, loc=0)
    plt.savefig('_'.join([outdir + delimiter, seq, bit, y_label]) + ".png")
    plt.close()
        

###画波动图前预处理
def plot_wave_chart_process(outdir, summary_file, y_label):
    """Plot one fluctuation chart per (sequence, bitrate) pair of a summary file.

    Rows of ``summary_file`` (inside ``outdir``) have the form
    ``encoder sequence bitrate value... overall_average``.  For every
    sequence/bitrate combination that has rows, the encoder names and their
    value lists are passed to ``exec_plot_wave_chart_process``.

    Args:
        outdir: directory containing the summary and receiving the charts.
        summary_file: file name of the summary inside ``outdir``.
        y_label: y-axis label forwarded to the plotting function.
    """
    summary_path = outdir + delimiter + summary_file
    (_encs, seqs, bits) = get_encoder_seqs_bits(summary_path)
    with open(summary_path, "r") as file:
        # Keep only rows that carry at least one value after the three keys.
        rows = [fields for fields in (line.split() for line in file)
                if len(fields) > 3]

    for seq in seqs:
        for bit in bits:
            datas = [r for r in rows if r[1] == seq and r[2] == bit]
            if not datas:
                # This sequence was not coded at this bitrate.
                continue
            enc_names = [r[0] for r in datas]
            # x[3:] is the whole value list; the old x[3] passed just the
            # first value, a single string.
            plot_tabel = [r[3:] for r in datas]
            exec_plot_wave_chart_process(outdir, bit, enc_names, seq, plot_tabel, y_label)
            
###创建目录
def creat_check_dir(path):
    """Make sure the directory ``path`` exists and return True.

    Surrounding whitespace is stripped from ``path`` first.  When the
    directory is missing it is created (including parents) and a
    confirmation line is printed.
    """
    target = path.strip()
    if os.path.exists(target):
        return True
    # Not there yet: create the full path.
    os.makedirs(target)
    print (target + ' Creat Successful')
    return True

### main函数入口
### Script entry point
if __name__ == '__main__':
    # Script name plus five mandatory arguments; more reference directories
    # may follow.  The old check (`< 3`) let the argv[4..6] reads below fail
    # with IndexError, and a second reference directory was required although
    # it was never used on its own.
    if len(sys.argv) < 6:
        print("usage: outdir windowsize windowstep anchordir ref1dir [ref2dir ...]")
        sys.exit(1)

    argc = len(sys.argv)
    outdir = sys.argv[1]
    winsize = int(sys.argv[2])
    winstep = int(sys.argv[3])
    anchodir = sys.argv[4]

    def clear_file(filename):
        """Create ``filename`` or truncate it to zero length."""
        open(filename, "w").close()

    ### Create the output directory first: the summary files live inside it,
    ### so clearing them beforehand failed when the directory did not exist.
    creat_check_dir(outdir)

    ### Start from empty summary files
    clear_file(outdir+delimiter+psnr_summary_file)
    clear_file(outdir+delimiter+ssim_summary_file)
    clear_file(outdir+delimiter+bits_summary_file)
    clear_file(outdir+delimiter+psnr_summary_table)

    ### Every argument from index 5 on is a reference directory
    for i in range(5, argc):
        os_path = os.path.split(sys.argv[i].rstrip(delimiter))
        os_path_suffix = os_path[1]
        # The encoder label is the directory name up to the first '_';
        # split() keeps the whole name when there is no underscore, whereas
        # slicing with find() == -1 dropped the last character.
        encoder = os_path_suffix.split('_')[0]
        decode_process(outdir, encoder, anchodir, sys.argv[i], winsize, winstep)

    plot_wave_chart_process(outdir, bits_summary_file, "BITS")
    plot_wave_chart_process(outdir, psnr_summary_file, "PSNR")
    plot_wave_chart_process(outdir, ssim_summary_file, "SSIM")

    plot_bdrate_chart_process(outdir, psnr_summary_file, psnr_summary_table)
            
           
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值