Multi-scale Transformer Network with Edge-aware Pre-training for Cross-Modality MRImage Synthesis 代码复现和讲解
复现
解压后打开,放入服务器中
2、 安装环境
pip install -r requirements.txt
注意torch版本
3、下载数据集
-
Download BraTS2020 dataset from kaggle. The file name should be
./data/archive.zip. Unzip the file in./data/.
data/MICCAI_BraTS2020_TrainingData
4、运行预处理
python utils/preprocessing.py

5、预训练
python pretrain.py
更改/options/pretrain_options.py文件中的默认设置。
权重将保存在。/weight/EdgeMAE/中。

6、微调
python Finetune.py
可以更改./options/finetune_options.py中的默认设置,特别是data_rate选项,以调整配对数据的数量以进行微调。此外,您可以增加num_workers来加速微调。
重量将保存在。./weight/finetuned/。注意,对于MT-Net,输入大小必须为256×256。
7、使用
test.py 进行测试
合成图片将保存在./snapshot/test/
代码讲解
preprocessing
读取t1\t2\flair\t1ce 四个模态的数据和gt 数据
for i,(t1_path, t2_path, t1ce_path, flair_path, gt_path) in enumerate(zip(t1_list,t2_list,t1ce_list,flair_list,gt_list)):
print('preprocessing the',i+1,'th subject')
t1_img = nib.load(t1_path) # (240,140,155)
t2_img = nib.load(t2_path)
flair_img = nib.load(flair_path)
t1ce_img = nib.load(t1ce_path)
gt_img = nib.load(gt_path)
转化为np然后标准化
#to numpy
t1_data = t1_img.get_fdata()
t2_data = t2_img.get_fdata()
flair_data = flair_img.get_fdata()
t1ce_data = t1ce_img.get_fdata()
gt_data = gt_img.get_fdata()
gt_data = gt_data.astype(np.uint8)
gt_data[gt_data == 4] = 3 #label 3 is missing in BraTS 2020
t1_data = normalize(t1_data) # normalize to [0,1]
t2_data = normalize(t2_data)
t1ce_data = normalize(t1ce_data)
flair_data = normalize(flair_data)
压缩为一个5通道数据
tensor = np.stack([t1_data, t2_data, t1ce_data, flair_data, gt_data]) # (4, 240, 240, 155)
截取一部分保存
if i < train_len:
for j in range(60):
Tensor = tensor[:, 10:210, 25:225, 50 + j]
np.save(train_path + str(60 * i + j + 1) + '.npy', Tensor)
else:
for j in range(60):
Tensor = tensor[:, 10:210, 25:225, 50 + j]
np.save(test_path + str(60 * (i - train_len) + j + 1) + '.npy', Tensor)
每个np为(5,200,200,60)
pertain
mae网络
mae = EdgeMAE(img_size=opt.img_size,patch_size=opt.patch_size, embed_dim=opt.dim_encoder, depth=opt.depth, num_heads=opt.num_heads, in_chans=1,
decoder_embed_dim=opt.dim_decoder, decoder_depth=opt.decoder_depth, decoder_num_heads=opt.decoder_num_heads,
mlp_ratio=opt.mlp_ratio,norm_pix_loss=False,patchwise_loss=opt.use_patchwise_loss)
定义了编码器 ,包括PatchEmbed,然后定义好了block块
# MAE encoder specifics
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
num_patches

最低0.47元/天 解锁文章
1710

被折叠的 条评论
为什么被折叠?



