F. 实验7_21_编码、解码数组

设有一个整型数组a中存放n个互不相同的整数,a中元素取值范围为0到n-1。a数组的编码定义如下:a[0]的编码为0,放入b[0];a[i]的编码为在a[0]、a[1]、…、a[i-1]中比a[i]的值小的数的个数,放入b[i]。例如:a[6]={4,3,0,5,1,2}时,则b[6]={0,0,0,3,1,2}。你有两个任务,任务一是编码,即已知n与数组a,求数组b;任务二是解码,即已知n与数组b,求数组a。
输入与输出要求:首先输入两个整数n和flag。n代表数组元素的个数(1<=n<=100),flag代表任务类型(flag=1代表编码任务,flag=2代表解码任务)。然后是n个整数,当flag=1时,这n个数即代表数组a的内容;当flag=2时,这n个数即代表数组b的内容。
输出n个整数,当flag=1时即为编码后数组b的内容,当flag=2时,即为解码后数组a的内容。每个整数用空格分开,最后一个整数后是换行符。
程序运行效果:
Sample 1:
5↙
1↙
2 1 3 0 4↙
0 0 2 0 4
Sample 2:
8↙
2↙
0 1 2 0 1 1 4 5↙
3 6 7 0 2 1 4 5

#include<stdio.h>
int main()
{
	/*
	 * Encode/decode an array that is a permutation of 0..n-1.
	 *
	 * Input : n flag, then n integers (array a if flag==1, array b if flag==2).
	 * Output: n integers separated by single spaces, newline after the last.
	 *
	 * Encoding: b[i] = number of elements among a[0..i-1] that are < a[i].
	 * Decoding: reverses that definition (see comments below).
	 */
	int flag, n;
	scanf("%d%d", &n, &flag);
	int a[n], b[n];                    /* VLAs: n is guaranteed 1..100 by the task */
	if (flag == 1)
	{
		/* Task 1: read a[], produce b[]. */
		for (int i = 0; i < n; i++)
		{
			scanf("%d", &a[i]);
		}
		for (int i = 0; i < n; i++)
		{
			int cnt = 0;       /* earlier elements smaller than a[i] */
			for (int j = 0; j < i; j++)
			{
				if (a[j] < a[i])
				{
					cnt++;
				}
			}
			b[i] = cnt;
		}
		for (int i = 0; i < n; i++)
		{
			/* space between numbers, newline after the last one */
			printf(i < n - 1 ? "%d " : "%d\n", b[i]);
		}
	}
	if (flag == 2)
	{
		/* Task 2: read b[], reconstruct a[]. */
		for (int i = 0; i < n; i++)
		{
			scanf("%d", &b[i]);
		}
		/* number[] holds the values not yet placed into a[], kept in
		 * increasing order; initially all of 0..n-1 are unused. */
		int number[n];
		for (int i = 0; i < n; i++)
		{
			number[i] = i;
		}
		int remaining = n;         /* how many entries of number[] are still live */
		for (int i = n - 1; i >= 0; i--)
		{
			/* Among a[0..i] exactly b[i] values are smaller than a[i],
			 * and a[0..i] is precisely the set of unused values — so a[i]
			 * is the b[i]-th (0-based) smallest unused value. */
			a[i] = number[b[i]];
			/* remove it from the pool by shifting the tail left */
			for (int k = b[i]; k < remaining - 1; k++)
			{
				number[k] = number[k + 1];
			}
			remaining--;
		}
		for (int i = 0; i < n; i++)
		{
			printf(i < n - 1 ? "%d " : "%d\n", a[i]);
		}
	}
	return 0;
}

简化版

#include<stdio.h>
int main()
{
	/*
	 * Permutation coder (simplified version).
	 * Reads n and a task selector, then n integers; task 1 encodes array a
	 * into b (b[i] = count of smaller predecessors of a[i]), task 2 decodes
	 * b back into a. Output: the n result values, space-separated, with a
	 * trailing newline.
	 */
	int flag, n;
	scanf("%d%d", &n, &flag);
	int a[n];
	int b[n];
	if (flag == 1)
	{
		/* Encoding pass. */
		for (int i = 0; i < n; i++)
			scanf("%d", &a[i]);
		for (int i = 0; i < n; i++)
		{
			int smaller = 0;
			/* boolean comparison adds 0 or 1 per predecessor */
			for (int j = 0; j < i; j++)
				smaller += (a[j] < a[i]);
			b[i] = smaller;
		}
		for (int i = 0; i < n; i++)
			printf(i < n - 1 ? "%d " : "%d\n", b[i]);
	}
	if (flag == 2)
	{
		/* Decoding pass: number[] is the ascending pool of values not yet
		 * assigned; walking i from the end, a[i] is the b[i]-th entry of
		 * that pool, which is then removed by a left shift. */
		int number[n];
		for (int v = 0; v < n; v++)
			number[v] = v;
		for (int i = 0; i < n; i++)
			scanf("%d", &b[i]);
		for (int i = n - 1; i >= 0; i--)
		{
			a[i] = number[b[i]];
			for (int k = b[i] + 1; k < n; k++)
				number[k - 1] = number[k];
		}
		for (int i = 0; i < n; i++)
			printf(i < n - 1 ? "%d " : "%d\n", a[i]);
	}
	return 0;
}
1.简要介绍 图像分割是对图像中的每个像素加标签的一个过程,这一过程使得具有相同标签的像素具有某种共同视觉特性。 本示例简要介绍如何通过飞桨开源框架,实现图像分割。这里我们是采用了一个在图像分割领域比较熟知的U-Net网络结构,是一个基于FCN做改进后的一个深度学习网络,包含下采样(编码器,特征提取)和上采样(解码器,分辨率还原)两个阶段,因模型结构比较像U型而命名为U-Net,U-Net模型结构如图所示: 2.环境设置 导入一些比较基础常用的模块,确认自己的飞桨版本。 In [2] import os import io import numpy as np import matplotlib.pyplot as plt from PIL import Image as PilImage import paddle from paddle.nn import functional as F paddle.set_device('gpu') paddle.__version__ ---------------------------------------------------------------------------ValueError Traceback (most recent call last)/tmp/ipykernel_95/3499843494.py in <module> 8 from paddle.nn import functional as F 9 ---> 10 paddle.set_device('gpu') 11 paddle.__version__ /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/device.py in set_device(device) 130 if not core.is_compiled_with_cuda(): 131 raise ValueError( --> 132 "The device should not be 'gpu', " \ 133 "since PaddlePaddle is not compiled with CUDA") 134 place = core.CUDAPlace(ParallelEnv().dev_id) ValueError: The device should not be 'gpu', since PaddlePaddle is not compiled with CUDA 3.数据集 3.1 数据集下载 本案例使用Oxford-IIIT Pet数据集,官网:https://www.robots.ox.ac.uk/~vgg/data/pets 。 In [ ] !tar -xf data/data50154/images.tar.gz !tar -xf data/data50154/annotations.tar.gz 3.2 数据集概览 首先我们先看看下载到磁盘上的文件结构是什么样,来了解一下我们的数据集。 首先看一下images.tar.gz这个压缩包,该文件解压后得到一个images目录,这个目录比较简单,里面直接放的是用类名和序号命名好的图片文件,每个图片是对应的宠物照片。 . ├── samoyed_7.jpg ├── ...... └── samoyed_81.jpg 然后我们在看下annotations.tar.gz,文件解压后的目录里面包含以下内容,目录中的README文件将每个目录和文件做了比较详细的介绍,我们可以通过README来查看每个目录文件的说明。 . ├── README ├── list.txt ├── test.txt ├── trainval.txt ├── trimaps │ ├── Abyssinian_1.png │ ├── Abyssinian_10.png │ ├── ...... │ └── yorkshire_terrier_99.png └── xmls ├── Abyssinian_1.xml ├── Abyssinian_10.xml ├── ...... 
└── yorkshire_terrier_190.xml 本次我们主要使用到images和annotations/trimaps两个目录,即原图和三元图像文件,前者作为训练的输入数据,后者是对应的标签数据。 我们来看看这个数据集给我们提供了多少个训练样本。 In [ ] IMAGE_SIZE = (160, 160) train_images_path = "images/" label_images_path = "annotations/trimaps/" image_count = len([os.path.join(train_images_path, image_name) for image_name in os.listdir(train_images_path) if image_name.endswith('.jpg')]) print("用于训练的图片样本数量:", image_count) # 对数据集进行处理,划分训练集、测试集 def _sort_images(image_dir, image_type): """ 对文件夹内的图像进行按照文件名排序 """ files = [] for image_name in os.listdir(image_dir): if image_name.endswith('.{}'.format(image_type)) \ and not image_name.startswith('.'): files.append(os.path.join(image_dir, image_name)) return sorted(files) def write_file(mode, images, labels): with open('./{}.txt'.format(mode), 'w') as f: for i in range(len(images)): f.write('{}\t{}\n'.format(images[i], labels[i])) """ 由于所有文件都是散落在文件夹中,在训练时我们需要使用的是数据集和标签对应的数据关系, 所以我们第一步是对原始的数据集进行整理,得到数据集和标签两个数组,分别一一对应。 这样可以在使用的时候能够很方便的找到原始数据和标签的对应关系,否则对于原有的文件夹图片数据无法直接应用。 在这里是用了一个非常简单的方法,按照文件名称进行排序。 因为刚好数据和标签的文件名是按照这个逻辑制作的,名字都一样,只有扩展名不一样。 """ images = _sort_images(train_images_path, 'jpg') labels = _sort_images(label_images_path, 'png') eval_num = int(image_count * 0.15) write_file('train', images[:-eval_num], labels[:-eval_num]) write_file('test', images[-eval_num:], labels[-eval_num:]) write_file('predict', images[-eval_num:], labels[-eval_num:]) 3.3 PetDataSet数据集抽样展示 划分好数据集之后,我们来查验一下数据集是否符合预期,我们通过划分的配置文件读取图片路径后再加载图片数据来用matplotlib进行展示,这里要注意的是对于分割的标签文件因为是1通道的灰度图片,需要在使用imshow接口时注意下传参cmap='gray'。 In [ ] with open('./train.txt', 'r') as f: i = 0 for line in f.readlines(): image_path, label_path = line.strip().split('\t') image = np.array(PilImage.open(image_path)) label = np.array(PilImage.open(label_path)) if i > 2: break # 进行图片的展示 plt.figure() plt.subplot(1,2,1), plt.title('Train Image') plt.imshow(image.astype('uint8')) plt.axis('off') plt.subplot(1,2,2), plt.title('Label') plt.imshow(label.astype('uint8'), cmap='gray') plt.axis('off') 
plt.show() i = i + 1 3.4 数据集类定义 飞桨(PaddlePaddle)数据集加载方案是统一使用Dataset(数据集定义) + DataLoader(多进程数据集加载)。 首先我们先进行数据集的定义,数据集定义主要是实现一个新的Dataset类,继承父类paddle.io.Dataset,并实现父类中以下两个抽象方法,__getitem__和__len__: class MyDataset(Dataset): def __init__(self): ... # 每次迭代时返回数据和对应的标签 def __getitem__(self, idx): return x, y # 返回整个数据集的总数 def __len__(self): return count(samples) 在数据集内部可以结合图像数据预处理相关API进行图像的预处理(改变大小、反转、调整格式等)。 由于加载进来的图像不一定都符合自己的需求,举个例子,已下载的这些图片里面就会有RGBA格式的图片,这个时候图片就不符合我们所需3通道的需求,我们需要进行图片的格式转换,那么这里我们直接实现了一个通用的图片读取接口,确保读取出来的图片都是满足我们的需求。 另外图片加载出来的默认shape是HWC,这个时候要看看是否满足后面训练的需要,如果Layer的默认格式和这个不是符合的情况下,需要看下Layer有没有参数可以进行格式调整。不过如果layer较多的话,还是直接调整原数据Shape比较好,否则每个layer都要做参数设置,如果有遗漏就会导致训练出错,那么在本案例中是直接对数据源的shape做了统一调整,从HWC转换成了CHW,因为飞桨的卷积等API的默认输入格式为CHW,这样处理方便后续模型训练。 In [ ] import random from paddle.io import Dataset from paddle.vision.transforms import transforms as T class PetDataset(Dataset): """ 数据集定义 """ def __init__(self, mode='train'): """ 构造函数 """ self.image_size = IMAGE_SIZE self.mode = mode.lower() assert self.mode in ['train', 'test', 'predict'], \ "mode should be 'train' or 'test' or 'predict', but got {}".format(self.mode) self.train_images = [] self.label_images = [] with open('./{}.txt'.format(self.mode), 'r') as f: for line in f.readlines(): image, label = line.strip().split('\t') self.train_images.append(image) self.label_images.append(label) def _load_img(self, path, color_mode='rgb', transforms=[]): """ 统一的图像处理接口封装,用于规整图像大小和通道 """ with open(path, 'rb') as f: img = PilImage.open(io.BytesIO(f.read())) if color_mode == 'grayscale': # if image is not already an 8-bit, 16-bit or 32-bit grayscale image # convert it to an 8-bit grayscale image. 
if img.mode not in ('L', 'I;16', 'I'): img = img.convert('L') elif color_mode == 'rgba': if img.mode != 'RGBA': img = img.convert('RGBA') elif color_mode == 'rgb': if img.mode != 'RGB': img = img.convert('RGB') else: raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"') return T.Compose([ T.Resize(self.image_size) ] + transforms)(img) def __getitem__(self, idx): """ 返回 image, label """ train_image = self._load_img(self.train_images[idx], transforms=[ T.Transpose(), T.Normalize(mean=127.5, std=127.5) ]) # 加载原始图像 label_image = self._load_img(self.label_images[idx], color_mode='grayscale', transforms=[T.Grayscale()]) # 加载Label图像 # 返回image, label train_image = np.array(train_image, dtype='float32') label_image = np.array(label_image, dtype='int64') return train_image, label_image def __len__(self): """ 返回数据集总数 """ return len(self.train_images) 4.模型组网 U-Net是一个U型网络结构,可以看做两个大的阶段,图像先经过Encoder编码器进行下采样得到高级语义特征图,再经过Decoder解码器上采样将特征图恢复到原图片的分辨率。 4.1 定义SeparableConv2D接口 我们为了减少卷积操作中的训练参数来提升性能,是继承paddle.nn.Layer自定义了一个SeparableConv2D Layer类,整个过程是把filter_size * filter_size * num_filters的Conv2D操作拆解为两个子Conv2D,先对输入数据的每个通道使用filter_size * filter_size * 1的卷积核进行计算,输入输出通道数目相同,之后在使用1 * 1 * num_filters的卷积核计算。 In [ ] from paddle.nn import functional as F class SeparableConv2D(paddle.nn.Layer): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=None, weight_attr=None, bias_attr=None, data_format="NCHW"): super(SeparableConv2D, self).__init__() self._padding = padding self._stride = stride self._dilation = dilation self._in_channels = in_channels self._data_format = data_format # 第一次卷积参数,没有偏置参数 filter_shape = [in_channels, 1] + self.convert_to_list(kernel_size, 2, 'kernel_size') self.weight_conv = self.create_parameter(shape=filter_shape, attr=weight_attr) # 第二次卷积参数 filter_shape = [out_channels, in_channels] + self.convert_to_list(1, 2, 'kernel_size') self.weight_pointwise = self.create_parameter(shape=filter_shape, attr=weight_attr) 
self.bias_pointwise = self.create_parameter(shape=[out_channels], attr=bias_attr, is_bias=True) def convert_to_list(self, value, n, name, dtype=np.int): if isinstance(value, dtype): return [value, ] * n else: try: value_list = list(value) except TypeError: raise ValueError("The " + name + "'s type must be list or tuple. Received: " + str( value)) if len(value_list) != n: raise ValueError("The " + name + "'s length must be " + str(n) + ". Received: " + str(value)) for single_value in value_list: try: dtype(single_value) except (ValueError, TypeError): raise ValueError( "The " + name + "'s type must be a list or tuple of " + str( n) + " " + str(dtype) + " . Received: " + str( value) + " " "including element " + str(single_value) + " of type" + " " + str(type(single_value))) return value_list def forward(self, inputs): conv_out = F.conv2d(inputs, self.weight_conv, padding=self._padding, stride=self._stride, dilation=self._dilation, groups=self._in_channels, data_format=self._data_format) out = F.conv2d(conv_out, self.weight_pointwise, bias=self.bias_pointwise, padding=0, stride=1, dilation=1, groups=1, data_format=self._data_format) return out 4.2 定义Encoder编码器 我们将网络结构中的Encoder下采样过程进行了一个Layer封装,方便后续调用,减少代码编写,下采样是有一个模型逐渐向下画曲线的一个过程,这个过程中是不断的重复一个单元结构将通道数不断增加,形状不断缩小,并且引入残差网络结构,我们将这些都抽象出来进行统一封装。 In [ ] class Encoder(paddle.nn.Layer): def __init__(self, in_channels, out_channels): super(Encoder, self).__init__() self.relus = paddle.nn.LayerList( [paddle.nn.ReLU() for i in range(2)]) self.separable_conv_01 = SeparableConv2D(in_channels, out_channels, kernel_size=3, padding='same') self.bns = paddle.nn.LayerList( [paddle.nn.BatchNorm2D(out_channels) for i in range(2)]) self.separable_conv_02 = SeparableConv2D(out_channels, out_channels, kernel_size=3, padding='same') self.pool = paddle.nn.MaxPool2D(kernel_size=3, stride=2, padding=1) self.residual_conv = paddle.nn.Conv2D(in_channels, out_channels, kernel_size=1, stride=2, padding='same') def forward(self, inputs): 
previous_block_activation = inputs y = self.relus[0](inputs) y = self.separable_conv_01(y) y = self.bns[0](y) y = self.relus[1](y) y = self.separable_conv_02(y) y = self.bns[1](y) y = self.pool(y) residual = self.residual_conv(previous_block_activation) y = paddle.add(y, residual) return y 4.3 定义Decoder解码器 在通道数达到最大得到高级语义特征图后,网络结构会开始进行decode操作,进行上采样,通道数逐渐减小,对应图片尺寸逐步增加,直至恢复到原图像大小,那么这个过程里面也是通过不断的重复相同结构的残差网络完成,我们也是为了减少代码编写,将这个过程定义一个Layer来放到模型组网中使用。 In [ ] class Decoder(paddle.nn.Layer): def __init__(self, in_channels, out_channels): super(Decoder, self).__init__() self.relus = paddle.nn.LayerList( [paddle.nn.ReLU() for i in range(2)]) self.conv_transpose_01 = paddle.nn.Conv2DTranspose(in_channels, out_channels, kernel_size=3, padding=1) self.conv_transpose_02 = paddle.nn.Conv2DTranspose(out_channels, out_channels, kernel_size=3, padding=1) self.bns = paddle.nn.LayerList( [paddle.nn.BatchNorm2D(out_channels) for i in range(2)] ) self.upsamples = paddle.nn.LayerList( [paddle.nn.Upsample(scale_factor=2.0) for i in range(2)] ) self.residual_conv = paddle.nn.Conv2D(in_channels, out_channels, kernel_size=1, padding='same') def forward(self, inputs): previous_block_activation = inputs y = self.relus[0](inputs) y = self.conv_transpose_01(y) y = self.bns[0](y) y = self.relus[1](y) y = self.conv_transpose_02(y) y = self.bns[1](y) y = self.upsamples[0](y) residual = self.upsamples[1](previous_block_activation) residual = self.residual_conv(residual) y = paddle.add(y, residual) return y 4.4 训练模型组网 按照U型网络结构格式进行整体的网络结构搭建,三次下采样,四次上采样。 In [ ] class PetNet(paddle.nn.Layer): def __init__(self, num_classes): super(PetNet, self).__init__() self.conv_1 = paddle.nn.Conv2D(3, 32, kernel_size=3, stride=2, padding='same') self.bn = paddle.nn.BatchNorm2D(32) self.relu = paddle.nn.ReLU() in_channels = 32 self.encoders = [] self.encoder_list = [64, 128, 256] self.decoder_list = [256, 128, 64, 32] # 根据下采样个数和配置循环定义子Layer,避免重复写一样的程序 for out_channels in self.encoder_list: block = 
self.add_sublayer('encoder_{}'.format(out_channels), Encoder(in_channels, out_channels)) self.encoders.append(block) in_channels = out_channels self.decoders = [] # 根据上采样个数和配置循环定义子Layer,避免重复写一样的程序 for out_channels in self.decoder_list: block = self.add_sublayer('decoder_{}'.format(out_channels), Decoder(in_channels, out_channels)) self.decoders.append(block) in_channels = out_channels self.output_conv = paddle.nn.Conv2D(in_channels, num_classes, kernel_size=3, padding='same') def forward(self, inputs): y = self.conv_1(inputs) y = self.bn(y) y = self.relu(y) for encoder in self.encoders: y = encoder(y) for decoder in self.decoders: y = decoder(y) y = self.output_conv(y) return y 4.5 模型可视化 调用飞桨提供的summary接口对组建好的模型进行可视化,方便进行模型结构和参数信息的查看和确认。 In [ ] num_classes = 4 network = PetNet(num_classes) model = paddle.Model(network) model.summary((-1, 3,) + IMAGE_SIZE) 5.模型训练 5.1 启动模型训练 使用模型代码进行Model实例生成,使用prepare接口定义优化器、损失函数和评价指标等信息,用于后续训练使用。在所有初步配置完成后,调用fit接口开启训练执行过程,调用fit时只需要将前面定义好的训练数据集、测试数据集、训练轮次(Epoch)和批次大小(batch_size)配置好即可。 In [ ] train_dataset = PetDataset(mode='train') # 训练数据集 val_dataset = PetDataset(mode='test') # 验证数据集 optim = paddle.optimizer.RMSProp(learning_rate=0.001, rho=0.9, momentum=0.0, epsilon=1e-07, centered=False, parameters=model.parameters()) model.prepare(optim, paddle.nn.CrossEntropyLoss(axis=1)) model.fit(train_dataset, val_dataset, epochs=15, batch_size=32, save_dir='model', verbose=1) 6.模型预测 6.1 预测数据集准备和预测 继续使用PetDataset来实例化待预测使用的数据集。这里我们为了方便没有在另外准备预测数据,复用了评估数据。 我们可以直接使用model.predict接口来对数据集进行预测操作,只需要将预测数据集传递到接口内即可。 In [ ] path = r"model/final.pdparams" state_dict = paddle.load(path) network.set_state_dict(state_dict) model = paddle.Model(network) predict_dataset = PetDataset(mode='predict') model.prepare(paddle.nn.CrossEntropyLoss(axis=1)) predict_results = model.predict(predict_dataset) 6.2 预测结果可视化 从我们的预测数据集中抽3个动物来看看预测的效果,展示一下原图、标签图和预测结果。 In [ ] plt.figure(figsize=(10, 10)) i = 0 mask_idx = 0 with open('./predict.txt', 'r') as f: for line in 
f.readlines(): image_path, label_path = line.strip().split('\t') resize_t = T.Compose([ T.Resize(IMAGE_SIZE) ]) image = resize_t(PilImage.open(image_path)) label = resize_t(PilImage.open(label_path)) image = np.array(image).astype('uint8') label = np.array(label).astype('uint8') if i > 8: break plt.subplot(3, 3, i + 1) plt.imshow(image) plt.title('Input Image') plt.axis("off") plt.subplot(3, 3, i + 2) plt.imshow(label, cmap='gray') plt.title('Label') plt.axis("off") # 模型只有一个输出,所以我们通过predict_results[0]来取出1000个预测的结果 # 映射原始图片的index来取出预测结果,提取mask进行展示 data = predict_results[0][mask_idx][0].transpose((1, 2, 0)) mask = np.argmax(data, axis=-1) plt.subplot(3, 3, i + 3) plt.imshow(mask.astype('uint8'), cmap='gray') plt.title('Predict') plt.axis("off") i += 3 mask_idx += 1 plt.show() 改进网络重新训练,以提高分割准确度,考虑增加BN和残差网络,并画出损失函数和准确率趋势的折线图
最新发布
11-08
接着解读:import math import numpy as np import pymongo import torch.optim as optim from tqdm import tqdm import copy from Autoencoder.data.datasets import get_loader from Autoencoder.loss.distortion import * from Autoencoder.net import channel, network from Diffusion import ChannelDiffusionTrainer, ChannelDiffusionSampler # from Diffusion.Autoencoder import AE from Diffusion.Model import UNet from Scheduler import GradualWarmupScheduler from torchvision.utils import save_image import torch.nn as nn def train_CHDDIM(config, CHDDIM_config): encoder = network.JSCC_encoder(config, config.C).cuda() encoder_path = config.encoder_path pass_channel = channel.Channel(config) encoder.eval() CDDM_config=copy.deepcopy(config) CDDM_config.batch_size=config.CDDM_batch trainLoader, _ = get_loader(CDDM_config) CHDDIM = UNet(T=CHDDIM_config.T, ch=CHDDIM_config.channel, ch_mult=CHDDIM_config.channel_mult, attn=CHDDIM_config.attn, num_res_blocks=CHDDIM_config.num_res_blocks, dropout=CHDDIM_config.dropout, input_channel=CHDDIM_config.C).cuda() # encoder = torch.nn.DataParallel(encoder, device_ids=CHDDIM_config.device_ids) # # CHDDIM = torch.nn.DataParallel(CHDDIM, device_ids=CHDDIM_config.device_ids) encoder.load_state_dict(torch.load(encoder_path)) # if CHDDIM_config.training_load_weight is not None: # ckpt = torch.load(CHDDIM_config.save_weight_dir + CHDDIM_config.training_load_weight) # CHDDIM.load_state_dict(ckpt) # encoder = encoder.cuda(device=CHDDIM_config.device_ids[0]) # # CHDDIM = CHDDIM.cuda(device=CHDDIM_config.device_ids[0]) optimizer = torch.optim.AdamW( CHDDIM.parameters(), lr=CHDDIM_config.lr, weight_decay=1e-4) # print(CHDDIM_config.lr) cosineScheduler = optim.lr_scheduler.CosineAnnealingLR( optimizer=optimizer, T_max=CHDDIM_config.epoch, eta_min=0, last_epoch=-1) warmUpScheduler = GradualWarmupScheduler( optimizer=optimizer, multiplier=CHDDIM_config.multiplier, warm_epoch=0.1, # CHDDIM_config.epoch // 10, after_scheduler=cosineScheduler) trainer = 
ChannelDiffusionTrainer(model=CHDDIM,noise_schedule=CHDDIM_config.noise_schedule, re_weight=CHDDIM_config.re_weight,beta_1=CHDDIM_config.snr_max, beta_T=CHDDIM_config.snr_min, T=CHDDIM_config.T).cuda() # start training all_loss=[] for e in range(CHDDIM_config.epoch): ave_loss=0 with tqdm(trainLoader, dynamic_ncols=True) as tqdmDataLoader: for images, labels in tqdmDataLoader: # train optimizer.zero_grad() x_0 = images.cuda() # 在这里加上一个干扰图片 # idx_perm = torch.randperm(images.size(0)).cuda() # I_images = images.index_select(0, idx_perm) # 干扰与原图片存在一个batch里面 # feature_I, _ = encoder(I_images) feature, _ = encoder(x_0) # y = torch.cat([feature, feature_I], dim = 1) y = feature # print(y.shape) # print("y:",y) # print("mean:",feature.mode()) y, pwr = pass_channel.complex_normalize(y, power=1) # normalize if config.channel_type == "rayleigh": _, h = pass_channel.reyleigh_layer(y) elif config.channel_type == 'awgn': h = torch.ones(y.shape).cuda() else: raise ValueError loss = trainer(y, h, config.SNRs, channel_type=config.channel_type) loss.backward() ave_loss+=loss torch.nn.utils.clip_grad_norm_( CHDDIM.parameters(), CHDDIM_config.grad_clip) optimizer.step() tqdmDataLoader.set_postfix(ordered_dict={ "epoch": e, "state": 'train_CDDM'+'re_weight'+str(CHDDIM_config.re_weight), "loss: ": loss.item(), "noise_schedule":CHDDIM_config.noise_schedule, "input shape: ": x_0.shape, "CBR": feature.numel() / 2 / x_0.numel(), "LR": optimizer.state_dict()['param_groups'][0]["lr"] }) warmUpScheduler.step() if (e + 1) % CHDDIM_config.save_model_freq == 0: torch.save(CHDDIM.state_dict(), CHDDIM_config.save_path) all_loss.append(ave_loss.item()/50) #print(all_loss) def eval_JSCC_with_CDDM(config, CHDDIM_config): encoder = network.JSCC_encoder(config, config.C).cuda() decoder = network.JSCC_decoder(config, config.C).cuda() encoder_path = config.encoder_path decoder_path = config.re_decoder_path pass_channel = channel.Channel(config) encoder.eval() decoder.eval() _, testLoader = 
get_loader(config) CHDDIM = UNet(T=CHDDIM_config.T, ch=CHDDIM_config.channel, ch_mult=CHDDIM_config.channel_mult, attn=CHDDIM_config.attn, num_res_blocks=CHDDIM_config.num_res_blocks, dropout=CHDDIM_config.dropout, input_channel=CHDDIM_config.C).cuda() # encoder = torch.nn.DataParallel(encoder, device_ids=CHDDIM_config.device_ids) # decoder = torch.nn.DataParallel(decoder, device_ids=CHDDIM_config.device_ids) # CHDDIM = torch.nn.DataParallel(CHDDIM, device_ids=CHDDIM_config.device_ids) # # encoder = encoder.cuda(device=CHDDIM_config.device_ids[0]) # decoder = decoder.cuda(device=CHDDIM_config.device_ids[0]) # CHDDIM = CHDDIM.cuda(device=CHDDIM_config.device_ids[0]) encoder.load_state_dict(torch.load(encoder_path)) ckpt = torch.load(CHDDIM_config.save_path) CHDDIM.load_state_dict(ckpt) CHDDIM.eval() decoder.load_state_dict(torch.load(decoder_path)) sampler = ChannelDiffusionSampler(model=CHDDIM, noise_schedule=CHDDIM_config.noise_schedule,t_max=CHDDIM_config.t_max,beta_1=CHDDIM_config.snr_max, beta_T=CHDDIM_config.snr_min, T=CHDDIM_config.T).cuda() if config.dataset == "CIFAR10": CalcuSSIM = MS_SSIM(window_size=3, data_range=1., levels=4, channel=3).cuda() else: CalcuSSIM = MS_SSIM(data_range=1., levels=4, channel=3).cuda() # start training snr_in = config.SNRs - CHDDIM_config.large_snr matric_aver = 0 mse1_aver = 0 mse2_aver = 0 # sigma_eps_aver=torch.zeros() with tqdm(testLoader, dynamic_ncols=True) as tqdmtestLoader: for i, (images, labels) in enumerate(tqdmtestLoader): # train x_0 = images.cuda() feature, _ = encoder(x_0) y = feature y_0 = y y, pwr, h = pass_channel.forward(y, snr_in) # normalize sigma_square = 1.0 / (2 * 10 ** (snr_in / 10)) if config.channel_type == "awgn": y_awgn = torch.cat((torch.real(y), torch.imag(y)), dim=2) mse1 = torch.nn.MSELoss()(y_awgn * math.sqrt(2), y_0 * math.sqrt(2) / torch.sqrt(pwr)) elif config.channel_type == 'rayleigh': y_mmse = y * torch.conj(h) / (torch.abs(h) ** 2 + sigma_square * 2) y_mmse = 
torch.cat((torch.real(y_mmse), torch.imag(y_mmse)), dim=2) mse1 = torch.nn.MSELoss()(y_mmse * math.sqrt(2), y_0 * math.sqrt(2) / torch.sqrt(pwr)) else: raise ValueError y = y / math.sqrt(1 + sigma_square) # 这里可能改一下 feature_hat = sampler(y, snr_in, snr_in + CHDDIM_config.large_snr, h, config.channel_type) mse2 = torch.nn.MSELoss()(feature_hat * math.sqrt(2), y_0 * math.sqrt(2) / torch.sqrt(pwr)) feature_hat = feature_hat * torch.sqrt(pwr) x_0_hat = decoder(feature_hat) # optimizer1.step() # optimizer2.step() if config.loss_function == "MSE": mse = torch.nn.MSELoss()(x_0 * 255., x_0_hat.clamp(0., 1.) * 255) psnr = 10 * math.log10(255. * 255. / mse.item()) matric = psnr #save_image(x_0_hat,"/home/wutong/semdif_revise/DIV2K_JSCCCDDM_rayleigh_PSNR_10dB/{}.png".format(i)) elif config.loss_function == "MSSSIM": msssim = 1 - CalcuSSIM(x_0, x_0_hat.clamp(0., 1.)).mean().item() matric = msssim #save_image(x_0_hat,"/home/wutong/semdif_revise/DIV2K_JSCCCDDM_rayleigh_MSSSIM_10dB/{}.png".format(i)) mse1_aver += mse1.item() mse2_aver += mse2.item() matric_aver += matric CBR = feature.numel() / 2 / x_0.numel() tqdmtestLoader.set_postfix(ordered_dict={ "dataset": config.dataset, "re_weight":str(CHDDIM_config.re_weight), "state": 'eval JSCC with CDDM' + config.loss_function, "channel": config.channel_type, "noise_schedule":CHDDIM_config.noise_schedule, "CBR": CBR, "SNR": snr_in, "matric ": matric, "MSE_channel": mse1.item(), "MSE_channel+CDDM": mse2.item(), "T_max":CHDDIM_config.t_max }) mse1_aver = (mse1_aver / (i + 1)) mse2_aver = (mse2_aver / (i + 1)) matric_aver = (matric_aver / (i + 1)) if config.loss_function == "MSE": name = 'PSNR' elif config.loss_function == "MSSSIM": name = "MSSSIM" else: raise ValueError #print("matric:{}",matric_aver) myclient = pymongo.MongoClient(config.database_address) mydb = myclient[config.dataset] if 'SNRs' in config.encoder_path: mycol = mydb[name + '_' + config.channel_type + '_SNRs_' + 'JSCC+CDDM' + '_CBR_' + str(CBR)] mydic = {'SNR': snr_in, 
name: matric_aver} mycol.insert_one(mydic) print('writing successfully', mydic) mycol = mydb['MSE' + name + '_' + config.channel_type + '_SNRs_' + 'JSCC' + '_CBR_' + str(CBR)] mydic = {'SNR': snr_in, 'MSE': mse1_aver} mycol.insert_one(mydic) print('writing successfully', mydic) mycol = mydb["MSE" + name + '_' + config.channel_type + '_SNRs_' + 'JSCC+CDDM' + '_CBR_' + str(CBR)] mydic = {'SNR': snr_in, 'MSE': mse2_aver} mycol.insert_one(mydic) print('writing successfully', mydic) elif 'CBRs' in config.encoder_path: mycol = mydb[name + '_' + config.channel_type + '_CBRs_' + 'JSCC+CDDM' + '_SNR_' + str(snr_in)] mydic = {'CBR': CBR, name: matric_aver} mycol.insert_one(mydic) print('writing successfully', mydic) mycol = mydb['MSE' + name + '_' + config.channel_type + '_CBRs_' + 'JSCC' + '_SNR_' + str(snr_in)] mydic = {'CBR': CBR, 'MSE': mse1_aver} mycol.insert_one(mydic) print('writing successfully', mydic) mycol = mydb["MSE" + name + '_' + config.channel_type + '_CBRs_' + 'JSCC+CDDM' + '_SNR_' + str(snr_in)] mydic = {'CBR': CBR, 'MSE': mse2_aver} mycol.insert_one(mydic) print('writing successfully', mydic) else: raise ValueError # print(psnr_aver/100,mse1_aver/100,mse2_aver/100) # eval_psnr = np.array(sigma_eps_aver / 100) # print(PSNR_all) # # print(eval_psnr.shape) # file = ('./CDESC_sigma_eps_rayleigh_decoderSNR5.csv'.format(CHDDIM_config.train_snr)) # data = pd.DataFrame(eval_psnr) # data.to_csv(file, index=False) # eval_psnr = np.array(PSNR_all) # print(eval_psnr) # # print(eval_psnr.shape) # file = ('./CDESC_PSNR_rayleigh_decoderSNR5.csv'.format(CHDDIM_config.train_snr)) # data = pd.DataFrame(eval_psnr) # data.to_csv(file, index=False) def eval_JSCC_with_CDDM_SNRs(config, CHDDIM_config): encoder = network.JSCC_encoder(config, config.C).cuda() decoder = network.JSCC_decoder(config, config.C).cuda() encoder_path = config.encoder_path decoder_path = config.re_decoder_path pass_channel = channel.Channel(config) encoder.eval() decoder.eval() _, testLoader = 
get_loader(config) CHDDIM = UNet(T=CHDDIM_config.T, ch=CHDDIM_config.channel, ch_mult=CHDDIM_config.channel_mult, attn=CHDDIM_config.attn, num_res_blocks=CHDDIM_config.num_res_blocks, dropout=CHDDIM_config.dropout, input_channel=CHDDIM_config.C).cuda() # encoder = torch.nn.DataParallel(encoder, device_ids=CHDDIM_config.device_ids) # decoder = torch.nn.DataParallel(decoder, device_ids=CHDDIM_config.device_ids) # CHDDIM = torch.nn.DataParallel(CHDDIM, device_ids=CHDDIM_config.device_ids) # # encoder = encoder.cuda(device=CHDDIM_config.device_ids[0]) # decoder = decoder.cuda(device=CHDDIM_config.device_ids[0]) # CHDDIM = CHDDIM.cuda(device=CHDDIM_config.device_ids[0]) encoder.load_state_dict(torch.load(encoder_path)) ckpt = torch.load(CHDDIM_config.save_path) CHDDIM.load_state_dict(ckpt) CHDDIM.eval() decoder.load_state_dict(torch.load(decoder_path)) sampler = ChannelDiffusionSampler(model=CHDDIM, beta_1=CHDDIM_config.snr_max, beta_T=CHDDIM_config.snr_min, T=CHDDIM_config.T).cuda() if config.dataset == "CIFAR10": CalcuSSIM = MS_SSIM(window_size=3, data_range=1., levels=4, channel=3).cuda() else: CalcuSSIM = MS_SSIM(data_range=1., levels=4, channel=3).cuda() # start training #snr_in = config.SNRs - 3 matric_all = [] # sigma_eps_aver=torch.zeros() for snr_in in config.all_SNRs: matric_aver = 0 with tqdm(testLoader, dynamic_ncols=True) as tqdmtestLoader: for i, (images, labels) in enumerate(tqdmtestLoader): # train x_0 = images.cuda() feature, _ = encoder(x_0) y = feature y, pwr, h = pass_channel.forward(y, snr_in) # normalize sigma_square = 1.0 / (2 * 10 ** (snr_in / 10)) y = y / math.sqrt(1 + sigma_square) # 这里可能改一下 feature_hat = sampler(y, snr_in, config.SNRs, h, config.channel_type) feature_hat = feature_hat * torch.sqrt(pwr) x_0_hat = decoder(feature_hat) # optimizer1.step() # optimizer2.step() if config.loss_function == "MSE": mse = torch.nn.MSELoss()(x_0 * 255., x_0_hat.clamp(0., 1.) * 255) psnr = 10 * math.log10(255. * 255. 
/ mse.item()) matric = psnr elif config.loss_function == "MSSSIM": msssim = 1 - CalcuSSIM(x_0, x_0_hat.clamp(0., 1.)).mean().item() matric = msssim matric_aver += matric CBR = feature.numel() / 2 / x_0.numel() tqdmtestLoader.set_postfix(ordered_dict={ "dataset": config.dataset, "state": 'eval JSCC with CDDM for all SNRs' + config.loss_function, "channel": config.channel_type, "CBR": CBR, "SNR": snr_in, "matric ": matric, }) matric_aver = (matric_aver / (i + 1)) matric_all.append(matric_aver) if config.loss_function == "MSE": name = 'PSNR' elif config.loss_function == "MSSSIM": name = "MSSSIM" else: raise ValueError myclient = pymongo.MongoClient(config.database_address) mydb = myclient[config.dataset] mycol = mydb[name + '_' + config.channel_type + '_allSNRs_' + 'JSCC+CDDM' +'_CBR_' + str(CBR)] mydic = {'SNR': config.all_SNRs, name: matric_all} mycol.insert_one(mydic) print('writing successfully', mydic) def eval_JSCC_SNRs(config): encoder = network.JSCC_encoder(config, config.C).cuda() decoder = network.JSCC_decoder(config, config.C).cuda() encoder_path = config.encoder_path decoder_path = config.decoder_path pass_channel = channel.Channel(config) encoder.eval() decoder.eval() _, testLoader = get_loader(config) encoder.load_state_dict(torch.load(encoder_path)) decoder.load_state_dict(torch.load(decoder_path)) if config.dataset == "CIFAR10": CalcuSSIM = MS_SSIM(window_size=3, data_range=1., levels=4, channel=3).cuda() else: CalcuSSIM = MS_SSIM(data_range=1., levels=4, channel=3).cuda() # start training #snr_in = config.SNRs - 3 matric_all = [] # sigma_eps_aver=torch.zeros() for snr_in in config.all_SNRs: matric_aver = 0 with tqdm(testLoader, dynamic_ncols=True) as tqdmtestLoader: for i, (images, labels) in enumerate(tqdmtestLoader): # train x_0 = images.cuda() feature, _ = encoder(x_0) y = feature y, pwr, h = pass_channel.forward(y, snr_in) # normalize sigma_square_fix = 1.0 / (10 ** (snr_in / 10)) if config.channel_type=='rayleigh': y = y * (torch.conj(h)) / 
(torch.abs(h) ** 2 + sigma_square_fix) elif config.channel_type=='awgn': y=y else: raise ValueError y = torch.cat((torch.real(y), torch.imag(y)), dim=2) feature_hat = y * torch.sqrt(pwr) x_0_hat = decoder(feature_hat) if config.loss_function == "MSE": mse = torch.nn.MSELoss()(x_0 * 255., x_0_hat.clamp(0., 1.) * 255) psnr = 10 * math.log10(255. * 255. / mse.item()) matric = psnr elif config.loss_function == "MSSSIM": msssim = 1 - CalcuSSIM(x_0, x_0_hat.clamp(0., 1.)).mean().item() matric = msssim matric_aver += matric CBR = feature.numel() / 2 / x_0.numel() tqdmtestLoader.set_postfix(ordered_dict={ "dataset": config.dataset, "state": 'eval JSCC for all SNRs' + config.loss_function, "channel": config.channel_type, "CBR": CBR, "SNR": snr_in, "matric ": matric, }) matric_aver = (matric_aver / (i + 1)) matric_all.append(matric_aver) if config.loss_function == "MSE": name = 'PSNR' elif config.loss_function == "MSSSIM": name = "MSSSIM" else: raise ValueError myclient = pymongo.MongoClient(config.database_address) mydb = myclient[config.dataset] mycol = mydb[name + '_' + config.channel_type + '_allSNRs_' + 'JSCC' +'_CBR_' + str(CBR)] mydic = {'SNR': config.all_SNRs, name: matric_all} mycol.insert_one(mydic) print('writing successfully', mydic) def eval_JSCC_with_CDDM_delte_h(config, CHDDIM_config): encoder = network.JSCC_encoder(config, config.C).cuda() decoder = network.JSCC_decoder(config, config.C).cuda() encoder_path = config.encoder_path decoder_path = config.re_decoder_path pass_channel = channel.Channel(config) encoder.eval() decoder.eval() _, testLoader = get_loader(config) CHDDIM = UNet(T=CHDDIM_config.T, ch=CHDDIM_config.channel, ch_mult=CHDDIM_config.channel_mult, attn=CHDDIM_config.attn, num_res_blocks=CHDDIM_config.num_res_blocks, dropout=CHDDIM_config.dropout, input_channel=CHDDIM_config.C).cuda() # encoder = torch.nn.DataParallel(encoder, device_ids=CHDDIM_config.device_ids) # decoder = torch.nn.DataParallel(decoder, device_ids=CHDDIM_config.device_ids) # 
CHDDIM = torch.nn.DataParallel(CHDDIM, device_ids=CHDDIM_config.device_ids) # # encoder = encoder.cuda(device=CHDDIM_config.device_ids[0]) # decoder = decoder.cuda(device=CHDDIM_config.device_ids[0]) # CHDDIM = CHDDIM.cuda(device=CHDDIM_config.device_ids[0]) encoder.load_state_dict(torch.load(encoder_path)) ckpt = torch.load(CHDDIM_config.save_path) CHDDIM.load_state_dict(ckpt) CHDDIM.eval() decoder.load_state_dict(torch.load(decoder_path)) sampler = ChannelDiffusionSampler(model=CHDDIM, beta_1=CHDDIM_config.snr_max, beta_T=CHDDIM_config.snr_min, T=CHDDIM_config.T).cuda() if config.dataset == "CIFAR10": CalcuSSIM = MS_SSIM(window_size=3, data_range=1., levels=4, channel=3).cuda() else: CalcuSSIM = MS_SSIM(data_range=1., levels=4, channel=3).cuda() # start training snr_in = config.SNRs - 3 for h_sigma in config.h_sigma: matric_aver = 0 mse1_aver = 0 mse2_aver = 0 # sigma_eps_aver=torch.zeros() with tqdm(testLoader, dynamic_ncols=True) as tqdmtestLoader: for i, (images, labels) in enumerate(tqdmtestLoader): # train x_0 = images.cuda() feature, _ = encoder(x_0) y = feature y_0 = y y, pwr, h = pass_channel.forward(y, snr_in) # normalize delte_h = h_sigma * ( torch.normal(mean=0.0, std=1, size=np.shape(h)) + 1j * torch.normal(mean=0.0, std=1, size=np.shape( h))) / np.sqrt(2) h = h + delte_h.cuda() sigma_square = 1.0 / (2 * 10 ** (snr_in / 10)) if config.channel_type == "awgn": y_awgn = torch.cat((torch.real(y), torch.imag(y)), dim=2) mse1 = torch.nn.MSELoss()(y_awgn * math.sqrt(2), y_0 * math.sqrt(2) / torch.sqrt(pwr)) elif config.channel_type == 'rayleigh': y_mmse = y * torch.conj(h) / (torch.abs(h) ** 2 + sigma_square * 2) y_mmse = torch.cat((torch.real(y_mmse), torch.imag(y_mmse)), dim=2) mse1 = torch.nn.MSELoss()(y_mmse * math.sqrt(2), y_0 * math.sqrt(2) / torch.sqrt(pwr)) else: raise ValueError y = y / math.sqrt(1 + sigma_square) # 这里可能改一下 feature_hat = sampler(y, snr_in, snr_in + 3, h, config.channel_type) mse2 = torch.nn.MSELoss()(feature_hat * math.sqrt(2), y_0 
* math.sqrt(2) / torch.sqrt(pwr)) feature_hat = feature_hat * torch.sqrt(pwr) x_0_hat = decoder(feature_hat) # optimizer1.step() # optimizer2.step() if config.loss_function == "MSE": mse = torch.nn.MSELoss()(x_0 * 255., x_0_hat.clamp(0., 1.) * 255) psnr = 10 * math.log10(255. * 255. / mse.item()) matric = psnr elif config.loss_function == "MSSSIM": msssim = 1 - CalcuSSIM(x_0, x_0_hat.clamp(0., 1.)).mean().item() matric = msssim mse1_aver += mse1.item() mse2_aver += mse2.item() matric_aver += matric CBR = feature.numel() / 2 / x_0.numel() tqdmtestLoader.set_postfix(ordered_dict={ "dataset": config.dataset, "state": 'eval delte h JSCC with CDDM_' + config.loss_function, "h sigma": h_sigma, "channel": config.channel_type, "CBR": CBR, "SNR": snr_in, "matric ": matric, "MSE_channel": mse1.item(), "MSE_channel+CDDM": mse2.item() }) mse1_aver = (mse1_aver / (i + 1)) mse2_aver = (mse2_aver / (i + 1)) matric_aver = (matric_aver / (i + 1)) if config.loss_function == "MSE": name = 'PSNR' elif config.loss_function == "MSSSIM": name = "MSSSIM" else: raise ValueError myclient = pymongo.MongoClient(config.database_address) mydb = myclient[config.dataset] if 'SNRs' in config.encoder_path: mycol = mydb[name + '_' + config.channel_type + '_SNRs_' + 'JSCC+CDDM' + '_h_sigma_' + str( h_sigma) + '_CBR_' + str(CBR)] mydic = {'SNR': snr_in, name: matric_aver} mycol.insert_one(mydic) print('writing successfully', mydic) mycol = mydb['MSE' + name + '_' + config.channel_type + '_SNRs_' + 'JSCC' + '_h_sigma_' + str( h_sigma) + '_CBR_' + str(CBR)] mydic = {'SNR': snr_in, 'MSE': mse1_aver} mycol.insert_one(mydic) print('writing successfully', mydic) mycol = mydb["MSE" + name + '_' + config.channel_type + '_SNRs_' + 'JSCC+CDDM' + '_h_sigma_' + str( h_sigma) + '_CBR_' + str(CBR)] mydic = {'SNR': snr_in, 'MSE': mse2_aver} mycol.insert_one(mydic) print('writing successfully', mydic) elif 'CBRs' in config.encoder_path: mycol = mydb[name + '_' + config.channel_type + '_CBRs_' + 'JSCC+CDDM' + 
'_h_sigma_' + str( h_sigma) + '_SNR_' + str(snr_in)] mydic = {'CBR': CBR, name: matric_aver} mycol.insert_one(mydic) print('writing successfully', mydic) mycol = mydb['MSE' + name + '_' + config.channel_type + '_CBRs_' + 'JSCC' + '_h_sigma_' + str( h_sigma) + '_SNR_' + str(snr_in)] mydic = {'CBR': CBR, 'MSE': mse1_aver} mycol.insert_one(mydic) print('writing successfully', mydic) mycol = mydb["MSE" + name + '_' + config.channel_type + '_CBRs_' + 'JSCC+CDDM' + '_h_sigma_' + str( h_sigma) + '_SNR_' + str(snr_in)] mydic = {'CBR': CBR, 'MSE': mse2_aver} mycol.insert_one(mydic) print('writing successfully', mydic) else: raise ValueError def train_JSCC_with_CDDM(config, CHDDIM_config): encoder = network.JSCC_encoder(config, config.C).cuda() decoder = network.JSCC_decoder(config, config.C).cuda() encoder_path = config.encoder_path decoder_path = config.decoder_path pass_channel = channel.Channel(config) trainLoader, _ = get_loader(config) encoder.eval() CHDDIM = UNet(T=CHDDIM_config.T, ch=CHDDIM_config.channel, ch_mult=CHDDIM_config.channel_mult, attn=CHDDIM_config.attn, num_res_blocks=CHDDIM_config.num_res_blocks, dropout=CHDDIM_config.dropout, input_channel=CHDDIM_config.C).cuda() # encoder = torch.nn.DataParallel(encoder, device_ids=CHDDIM_config.device_ids) # decoder = torch.nn.DataParallel(decoder, device_ids=CHDDIM_config.device_ids) # CHDDIM = torch.nn.DataParallel(CHDDIM, device_ids=CHDDIM_config.device_ids) # # encoder = encoder.cuda(device=CHDDIM_config.device_ids[0]) # decoder = decoder.cuda(device=CHDDIM_config.device_ids[0]) # CHDDIM = CHDDIM.cuda(device=CHDDIM_config.device_ids[0]) encoder.load_state_dict(torch.load(encoder_path)) decoder.load_state_dict(torch.load(decoder_path)) ckpt = torch.load(CHDDIM_config.save_path) CHDDIM.load_state_dict(ckpt) CHDDIM.eval() sampler = ChannelDiffusionSampler(model=CHDDIM, noise_schedule=CHDDIM_config.noise_schedule,t_max=CHDDIM_config.t_max, beta_1=CHDDIM_config.snr_max, beta_T=CHDDIM_config.snr_min, 
T=CHDDIM_config.T).cuda() # optimizer_encoder = torch.optim.AdamW( # encoder.parameters(), lr=CHDDIM_config.lr, weight_decay=1e-4) optimizer_decoder = torch.optim.AdamW( decoder.parameters(), lr=CHDDIM_config.lr, weight_decay=1e-4) # start training if config.dataset == "CIFAR10": CalcuSSIM = MS_SSIM(window_size=3, data_range=1., levels=4, channel=3).cuda() else: CalcuSSIM = MS_SSIM(data_range=1., levels=4, channel=3).cuda() for e in range(config.retrain_epoch): with tqdm(trainLoader, dynamic_ncols=True) as tqdmtrainLoader: for i, (images, labels) in enumerate(tqdmtrainLoader): # train snr = config.SNRs - CHDDIM_config.large_snr x_0 = images.cuda() feature, _ = encoder(x_0) y = feature y, pwr, h = pass_channel.forward(y, snr) # normalize sigma_square = 1.0 / (2 * 10 ** (snr / 10)) y = y / math.sqrt(1 + sigma_square) # 这里可能改一下 feature_hat = sampler(y, snr, snr + CHDDIM_config.large_snr, h, config.channel_type) feature_hat = feature_hat * torch.sqrt(pwr) x_0_hat = decoder(feature_hat) # mse1=torch.nn.MSEloss()() if config.loss_function == "MSE": loss = torch.nn.MSELoss()(x_0, x_0_hat) elif config.loss_function == "MSSSIM": loss = CalcuSSIM(x_0, x_0_hat.clamp(0., 1.)).mean() else: raise ValueError optimizer_decoder.zero_grad() loss.backward() optimizer_decoder.step() # optimizer_encoder.step() if config.loss_function == "MSE": mse = torch.nn.MSELoss()(x_0 * 255., x_0_hat.clamp(0., 1.) * 255) psnr = 10 * math.log10(255. * 255. 
/ mse.item()) matric = psnr elif config.loss_function == "MSSSIM": msssim = 1 - CalcuSSIM(x_0, x_0_hat.clamp(0., 1.)).mean().item() matric = msssim tqdmtrainLoader.set_postfix(ordered_dict={ "dataset": config.dataset, "state": "train_decoder" + config.loss_function, "noise_schedule":CHDDIM_config.noise_schedule, "channel": config.channel_type, "CBR:": feature.numel() / 2 / x_0.numel(), "SNR": snr, "matric": matric, "T_max":CHDDIM_config.t_max }) if (e + 1) % config.retrain_save_model_freq == 0: torch.save(decoder.state_dict(), config.re_decoder_path) def train_DnCNN(config,CHDDIM_config): from DnCNN.models import DnCNN import torch.nn as nn encoder = network.JSCC_encoder(config, config.C).cuda() encoder_path = config.encoder_path pass_channel = channel.Channel(config) encoder.eval() CNN_config=copy.deepcopy(config) CNN_config.batch_size=config.CDDM_batch trainLoader, _ = get_loader(CNN_config) DeCNN=DnCNN(config.C).cuda() DeCNN.train() # encoder = torch.nn.DataParallel(encoder, device_ids=CHDDIM_config.device_ids) # # CHDDIM = torch.nn.DataParallel(CHDDIM, device_ids=CHDDIM_config.device_ids) encoder.load_state_dict(torch.load(encoder_path)) # if CHDDIM_config.training_load_weight is not None: # ckpt = torch.load(CHDDIM_config.save_weight_dir + CHDDIM_config.training_load_weight) # CHDDIM.load_state_dict(ckpt) # encoder = encoder.cuda(device=CHDDIM_config.device_ids[0]) # # CHDDIM = CHDDIM.cuda(device=CHDDIM_config.device_ids[0]) optimizer = torch.optim.Adam( DeCNN.parameters(), lr=CNN_config.learning_rate) # print(CHDDIM_config.lr) # start training for e in range(CHDDIM_config.epoch): with tqdm(trainLoader, dynamic_ncols=True) as tqdmDataLoader: for images, labels in tqdmDataLoader: # train optimizer.zero_grad() x_0 = images.cuda() feature, _ = encoder(x_0) y = feature # print(y.shape) # print("y:",y) # print("mean:",feature.mode()) y, pwr, h = pass_channel.forward(y, config.SNRs) sigma_square = 1.0 / (2 * 10 ** (config.SNRs / 10)) if config.channel_type == 
"awgn": y_awgn = torch.cat((torch.real(y), torch.imag(y)), dim=2) noise=y_awgn-feature #mse1 = torch.nn.MSELoss()(y_awgn * math.sqrt(2), y * math.sqrt(2) / torch.sqrt(pwr)) receive=y_awgn elif config.channel_type == 'rayleigh': y_mmse = y * torch.conj(h) / (torch.abs(h) ** 2 + sigma_square * 2) y_mmse = torch.cat((torch.real(y_mmse), torch.imag(y_mmse)), dim=2) noise=y_mmse-feature #mse1 = torch.nn.MSELoss()(y_mmse * math.sqrt(2), y* math.sqrt(2) / torch.sqrt(pwr)) receive=y_mmse else: raise ValueError output=DeCNN(receive) loss=nn.MSELoss()(output, noise) loss.backward() optimizer.step() tqdmDataLoader.set_postfix(ordered_dict={ "epoch": e, "channel_type":config.channel_type, "state": 'train_DeCNN', "loss: ": loss.item(), "input shape: ": x_0.shape, "CBR": feature.numel() / 2 / x_0.numel(), "LR": optimizer.state_dict()['param_groups'][0]["lr"] }) if (e + 1) % CHDDIM_config.save_model_freq == 0: torch.save(DeCNN.state_dict(), CHDDIM_config.save_path)
08-14
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值