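I looked up and suddenly it was already February 5th. Winter break is almost over, and I have made almost no research progress. I panicked instantly; I could almost see myself delivering takeout after graduation, and I can't accept that. But after everything I've been through, I've realized I'm only human: I'm panicking today, and tomorrow I might just give up. That's how people are. So I'm not going to worry about tomorrow. I only care about today, and today I'm going to study hard for three hours.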
1
import yaml
import torch
from pathlib import Path
# PLSNet, BasicTrain and init_dataloader come from the project's own modules (not shown here)


def main(args):
    # Load the YAML config into a nested dict
    with open(args.config_filename) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    # Build the train/val/test dataloaders and read the data dimensions
    dataloaders, node_size, node_feature_size, timeseries_size = \
        init_dataloader(config['data'])
    config['train']["seq_len"] = timeseries_size
    config['train']["node_size"] = node_size
    model = PLSNet(config['model'], node_size,
                   node_feature_size, timeseries_size)
    use_train = BasicTrain
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config['train']['lr'],
        weight_decay=config['train']['weight_decay'])
    opts = (optimizer,)
    loss_name = 'loss'
    if config['train']["group_loss"]:
        loss_name = f"{loss_name}_group_loss"
    if config['train']["sparsity_loss"]:
        loss_name = f"{loss_name}_sparsity_loss"
    save_folder_name = Path(config['train']['log_folder'])/Path(config['model']['type'])/Path(
        # date_time +
        f"{config['data']['dataset']}_{config['data']['atlas']}")
    train_process = use_train(
        config['train'], model, opts, dataloaders, save_folder_name)
    train_process.train()
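main() only ever reads `args.config_filename`, so `args` presumably comes from a command-line parser. The sketch below shows one way that attribute could be produced with argparse. The parser itself, the flag name and the default path `setting/abide_PLSNet.yaml` are my assumptions, not the repo's actual entry point.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical CLI: argparse turns "--config_filename" into the
    # attribute args.config_filename, which is what main() reads.
    parser = argparse.ArgumentParser(description="Train PLSNet (sketch)")
    parser.add_argument(
        "--config_filename",
        type=str,
        default="setting/abide_PLSNet.yaml",  # assumed default location
        help="path to the YAML config file",
    )
    return parser


# In a real script this would be parse_args() (reads sys.argv);
# an explicit list keeps the example reproducible.
args = build_parser().parse_args(["--config_filename", "setting/abide_PLSNet.yaml"])
print(args.config_filename)  # setting/abide_PLSNet.yaml
# main(args) would then open this file and load the config
```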
2
with open(args.config_filename) as f:
    config = yaml.load(f, Loader=yaml.Loader)
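Here args.config_filename is clearly abide_PLSNet.yaml in the setting folder. I thought I had copy-pasted this yaml config file from somewhere myself.
Looking through it again, I must have mixed things up: the file was already in the original folder.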
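Can f be called a handle? Actually, yes. `open()` returns a file object, which is what people usually mean by a file handle. f does not hold the file's contents; it is the channel through which `yaml.load` reads the text and parses it.
So this step reads the yaml config file, parses it, and stores the result in config. From here on, everything in the file is accessed through config.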
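To convince myself that f is a handle and not the text, here is a small standard-library check. It writes a made-up two-line file (not the real config), opens it, and reads through the handle. `yaml.load(f, ...)` does the same kind of reading internally before parsing.

```python
import io
import os
import tempfile

# Write a small throwaway file so the example is self-contained
# (the content is a made-up fragment, not the real abide_PLSNet.yaml).
path = os.path.join(tempfile.mkdtemp(), "demo.yaml")
with open(path, "w") as out:
    out.write("data:\n  dataset: ABIDE\n")

with open(path) as f:
    # f is a file object (a text-mode handle), not the text itself
    print(isinstance(f, io.TextIOWrapper))  # True
    text = f.read()  # reading through the handle yields the contents
    print(repr(text))

# Leaving the with-block closes the handle; the text we read is still available
print(f.closed)  # True
```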
3
abide_PLSNet.yaml has three sections: data, model and train.
data:
  dataset: ABIDE
  atlas: else
  batch_size: 16
  test_batch_size: 16
  val_batch_size: 16
  train_set: 0.7
  val_set: 0.1
  fold: 0
  # time_seires: /data/CodeGoat24/FBNETGEN/ABIDE_pcp/abide.npy
  # time_seires: /data/CodeGoat24/FBNETGEN_ho/ABIDE_pcp/abide.npy
  time_seires: ./data/CodeGoat24/FBNETGEN_AAL/ABIDE_pcp/abide.npy
  # time_seires: /data/CodeGoat24/FBNETGEN_cc400/ABIDE_pcp/abide.npy
model:
  type: PLSNet
  extractor_type: attention
  embedding_size: 8
  window_size: 4
  cnn_pool_size: 16
  num_gru_layers: 4
  dropout: 0.5
train:
  lr: 1.0e-4
  weight_decay: 1.0e-4
  epochs: 500
  pool_ratio: 0.7
  optimizer: adam
  stepsize: 200
  group_loss: true
  sparsity_loss: true
  sparsity_loss_weight: 0.5e-4
  log_folder: result
  # uniform or pearson
  pure_gnn_graph: pearson
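Plugging the values from this file into the loss-name and save-path code in main() shows where the results end up. The sketch below copies that logic verbatim but feeds it a hand-written dict containing only the keys it needs, instead of the parsed YAML. In the code shown, loss_name is computed but never passed on.

```python
from pathlib import Path

# Values copied from abide_PLSNet.yaml above (only the keys this logic uses)
config = {
    "data": {"dataset": "ABIDE", "atlas": "else"},
    "model": {"type": "PLSNet"},
    "train": {"group_loss": True, "sparsity_loss": True, "log_folder": "result"},
}

# Same logic as main(): append a suffix for each enabled auxiliary loss
loss_name = "loss"
if config["train"]["group_loss"]:
    loss_name = f"{loss_name}_group_loss"
if config["train"]["sparsity_loss"]:
    loss_name = f"{loss_name}_sparsity_loss"

# Same path construction as main(): log_folder / model type / dataset_atlas
save_folder_name = Path(config["train"]["log_folder"]) / Path(config["model"]["type"]) / Path(
    f"{config['data']['dataset']}_{config['data']['atlas']}")

print(loss_name)                    # loss_group_loss_sparsity_loss
print(save_folder_name.as_posix())  # result/PLSNet/ABIDE_else
```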
dataloaders, node_size, node_feature_size, timeseries_size = \
    init_dataloader(config['data'])
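Now I get it: yaml.load turns the config into a dictionary. data is a nested sub-dictionary with its own key-value pairs, which is why the code passes config['data'] to init_dataloader.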
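For reference, here is roughly what that nested dictionary looks like, written by hand as a Python literal. It is abridged to a few keys from the file above and is not actual yaml.load output. The seq_len value of 100 is taken from the time-point dimension of the abide.npy data printed further down.

```python
# Roughly the dict that yaml.load builds from abide_PLSNet.yaml (abridged):
# each top-level section becomes a sub-dict of key/value pairs.
config = {
    "data": {"dataset": "ABIDE", "atlas": "else", "batch_size": 16,
             "train_set": 0.7, "val_set": 0.1,
             "time_seires": "./data/CodeGoat24/FBNETGEN_AAL/ABIDE_pcp/abide.npy"},
    "model": {"type": "PLSNet", "window_size": 4, "dropout": 0.5},
    "train": {"lr": 1.0e-4, "weight_decay": 1.0e-4, "epochs": 500},
}

data_config = config["data"]        # what init_dataloader receives
print(data_config["batch_size"])    # 16
print(config["train"]["lr"])        # 0.0001

# main() also writes new keys into a sub-dict after loading the data
config["train"]["seq_len"] = 100    # 100 time points, per the abide.npy shape
print(sorted(config["train"]))      # ['epochs', 'lr', 'seq_len', 'weight_decay']
```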
4
import numpy as np
import torch
import torch.utils.data as utils
import csv
from nilearn.connectome import ConnectivityMeasure
from sklearn import preprocessing
import pandas as pd
import matplotlib.pyplot as plt
from scipy.io import loadmat
from nilearn import plotting, datasets


class StandardScaler:
    """
    Standardize the input with a fixed mean and std
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, data):
        return (data - self.mean) / self.std

    def inverse_transform(self, data):
        return (data * self.std) + self.mean


def init_dataloader(dataset_config):
    # The .npy file holds a pickled dict, so allow_pickle=True and .item() are needed
    data = np.load(dataset_config["time_seires"], allow_pickle=True).item()
    final_fc = data["timeseires"]    # time series: (subjects, ROIs, time points)
    final_pearson = data["corr"]     # correlation matrices: (subjects, ROIs, ROIs)
    labels = data["label"]           # one label per subject

    _, _, timeseries = final_fc.shape
    _, node_size, node_feature_size = final_pearson.shape

    # Z-score all time series with one global mean and std
    scaler = StandardScaler(mean=np.mean(
        final_fc), std=np.std(final_fc))
    final_fc = scaler.transform(final_fc)

    # One identity matrix per subject
    pseudo = []
    for i in range(len(final_fc)):
        pseudo.append(np.diag(np.ones(final_pearson.shape[1])))

    # The hard-coded size must equal node_size; with this abide.npy (111 ROIs)
    # only the final else branch fits, which matches atlas: else in the config
    if 'cc200' in dataset_config['atlas']:
        pseudo_arr = np.concatenate(pseudo, axis=0).reshape((-1, 200, 200))
    elif 'aal' in dataset_config['atlas']:
        pseudo_arr = np.concatenate(pseudo, axis=0).reshape((-1, 116, 116))
    elif 'cc400' in dataset_config['atlas']:
        pseudo_arr = np.concatenate(pseudo, axis=0).reshape((-1, 392, 392))
    else:
        pseudo_arr = np.concatenate(pseudo, axis=0).reshape((-1, 111, 111))

    # Convert everything to float32 tensors
    final_fc, final_pearson, labels, pseudo_arr = [torch.from_numpy(
        data).float() for data in (final_fc, final_pearson, labels, pseudo_arr)]

    # Split sizes come from train_set / val_set; the remainder goes to test
    length = final_fc.shape[0]
    train_length = int(length*dataset_config["train_set"])
    val_length = int(length*dataset_config["val_set"])

    dataset = utils.TensorDataset(
        final_fc,
        final_pearson,
        labels,
        pseudo_arr
    )

    # Random split (uses torch's global RNG)
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_length, val_length, length-train_length-val_length])

    # All three loaders use batch_size; val_batch_size and test_batch_size are not read here
    train_dataloader = utils.DataLoader(
        train_dataset, batch_size=dataset_config["batch_size"], shuffle=True, drop_last=False)
    val_dataloader = utils.DataLoader(
        val_dataset, batch_size=dataset_config["batch_size"], shuffle=True, drop_last=False)
    test_dataloader = utils.DataLoader(
        test_dataset, batch_size=dataset_config["batch_size"], shuffle=True, drop_last=False)

    return (train_dataloader, val_dataloader, test_dataloader), node_size, node_feature_size, timeseries
The code above is dataloader.py, and that is what I'm working through next.
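Before digging into the real data, here is a small numeric sanity check of three steps in init_dataloader, run on a random toy array with made-up sizes. It covers the global z-scoring done by StandardScaler, the identity-matrix pseudo_arr, and the random_split lengths for the 1009 subjects in abide.npy. The np.stack/np.eye comparison is just my own way of showing what the concatenate + reshape produces. It is not a change to the original code.

```python
import numpy as np


class StandardScaler:
    """Same scaler as in dataloader.py: one global mean/std for the whole array."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, data):
        return (data - self.mean) / self.std

    def inverse_transform(self, data):
        return (data * self.std) + self.mean


rng = np.random.default_rng(0)
# Toy stand-in for data["timeseires"]: (subjects, ROIs, time points); sizes are made up
ts = rng.normal(loc=3.0, scale=10.0, size=(5, 4, 6))

scaler = StandardScaler(mean=np.mean(ts), std=np.std(ts))
z = scaler.transform(ts)
print(z.mean(), z.std())  # close to 0 and 1
assert np.allclose(scaler.inverse_transform(z), ts)

# pseudo_arr: one identity matrix per subject. Stacking np.eye(node_size)
# gives the same array as the concatenate + reshape in dataloader.py.
n_subjects, node_size = ts.shape[0], ts.shape[1]
pseudo = [np.diag(np.ones(node_size)) for _ in range(n_subjects)]
pseudo_arr = np.concatenate(pseudo, axis=0).reshape((-1, node_size, node_size))
assert np.array_equal(pseudo_arr, np.stack([np.eye(node_size)] * n_subjects))

# random_split lengths for 1009 subjects with train_set=0.7 and val_set=0.1
length = 1009
train_length = int(length * 0.7)
val_length = int(length * 0.1)
test_length = length - train_length - val_length
print(train_length, val_length, test_length)  # 706 100 203
```

One thing this makes visible: the mean and std are computed over all subjects before random_split, so the validation and test subjects also contribute to the normalization statistics.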
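What is this time_series thing anyway? It turns out to be the entry `time_seires: ./data/CodeGoat24/FBNETGEN_AAL/ABIDE_pcp/abide.npy`, a .npy file (note the spelling: the config key is time_seires and the key inside the file is timeseires). And this .npy file is yet another format: a Python dict saved with NumPy, which is why init_dataloader calls `np.load(..., allow_pickle=True).item()`. Layer upon layer, really.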
{'timeseires': array([[[ -4.671351, -12.931799, -20.238678, ..., -4.6999 ,
-2.881232, -5.373659],
[ -6.610331, -10.146871, -11.706947, ..., -8.545325,
-4.824956, -3.549801],
[ 6.539966, -3.924838, -22.075667, ..., -10.810907,
-9.507657, -11.782674],
...,
[ 23.815056, 14.159346, -4.235137, ..., 2.318333,
-0.419391, -6.08656 ],
[ 29.358294, 29.569822, 12.97839 , ..., -35.478521,
-32.452544, -38.755226],
[ 12.700474, 10.523634, -1.103658, ..., 16.980012,
18.994251, 7.554179]],
[[ 15.055096, 14.30167 , 8.35954 , ..., -2.320052,
5.738047, 18.929523],
[ -2.650435, -5.070064, -6.582098, ..., -6.500024,
-6.520322, -2.209336],
[ 7.958244, 7.113133, 3.146248, ..., -7.816352,
0.160806, 12.985252],
...,
[ 0.216612, -7.946295, -14.338741, ..., -2.661149,
-1.293657, 4.052661],
[ 0.595963, -1.462514, -3.352045, ..., -7.601621,
-1.38041 , 7.066522],
[ -1.754612, -2.342497, -2.100473, ..., -4.888682,
0.817308, 6.089231]],
[[ 10.699614, 12.640328, 9.107096, ..., 25.303575,
14.631509, -2.431361],
[ 8.946028, 11.992252, 10.768465, ..., 17.12695 ,
4.901006, -9.559137],
[ 12.502747, 17.729441, 17.117685, ..., 13.455708,
-13.292004, -35.746319],
...,
[-13.968082, -18.558227, -20.739389, ..., 30.461445,
48.840466, 47.846905],
[ 7.775516, 7.945796, 4.77154 , ..., 35.356228,
42.096015, 35.434606],
[ 3.752828, 5.027397, 4.62332 , ..., 22.339197,
34.736423, 33.95411 ]],
...,
[[ -9.594532, -18.276965, -12.977462, ..., 22.482481,
9.076281, -3.665357],
[ -1.753394, -3.048763, -0.66247 , ..., 3.479453,
6.685637, 5.947072],
[ -7.053816, -4.420122, 6.336983, ..., 23.162804,
14.313564, 2.746504],
...,
[ 14.337799, 18.14827 , -0.895313, ..., 37.246694,
58.276117, 34.074897],
[-15.100426, -22.201254, -24.41196 , ..., 60.251682,
66.325714, 32.992202],
[-13.52109 , -18.166892, -16.496382, ..., 38.908534,
41.846779, 15.183713]],
[[ -4.960212, -10.467057, -2.350238, ..., 8.175621,
-3.594708, -13.050947],
[ -2.188337, -5.102677, -1.05976 , ..., 3.500581,
5.308049, -0.665579],
[ -7.689325, -4.299846, 4.479113, ..., 7.190843,
-6.158385, -10.379174],
...,
[ 4.838299, -1.461336, 3.449214, ..., 22.214853,
10.277305, -20.902231],
[ -9.16024 , -6.62222 , 5.643113, ..., 1.50843 ,
-7.206676, -21.635724],
[ -7.149388, -5.05326 , 5.05269 , ..., 12.605102,
7.582936, -10.857408]],
[[ -8.999973, -8.265457, 0.224464, ..., 2.243862,
-2.492217, -13.048096],
[ -1.115053, 0.311383, -0.91307 , ..., 3.98221 ,
1.695672, -5.227182],
[ 2.39671 , 2.200741, 5.486765, ..., 13.607921,
7.737062, -13.571547],
...,
[-24.200021, -15.751269, 3.40317 , ..., 3.005727,
3.171528, -1.583004],
[-12.074557, -7.229352, -1.982399, ..., 4.310723,
-9.359018, -3.962187],
[ -7.750716, -6.153427, -2.874191, ..., 4.213236,
-2.779235, -1.473658]]], shape=(1009, 111, 100)), 'label': array([0., 0., 0., ..., 0., 0., 0.], shape=(1009,)), 'corr': array([[[ 0. , 0.97663709, 0.82530072, ..., 0.13740499,
0.35407144, 0.08548723],
[ 0.97663709, 0. , 0.81736156, ..., 0.12545135,
0.31653123, 0.08523937],
[ 0.82530072, 0.81736156, 0. , ..., 0.32199278,
0.48073326, 0.23490678],
...,
[ 0.13740499, 0.12545135, 0.32199278, ..., 0. ,
0.91433765, 0.76759212],
[ 0.35407144, 0.31653123, 0.48073326, ..., 0.91433765,
0. , 1.1123392 ],
[ 0.08548723, 0.08523937, 0.23490678, ..., 0.76759212,
1.1123392 , 0. ]],
[[ 0. , 0.62647921, 0.68386934, ..., -0.05088815,
0.60446506, 0.41812769],
[ 0.62647921, 0. , 0.54538175, ..., 0.09402108,
0.39961469, 0.28481799],
[ 0.68386934, 0.54538175, 0. , ..., -0.06276508,
0.30605398, 0.22979597],
...,
[-0.05088815, 0.09402108, -0.06276508, ..., 0. ,
-0.01265819, 0.08046417],
[ 0.60446506, 0.39961469, 0.30605398, ..., -0.01265819,
0. , 0.85968089],
[ 0.41812769, 0.28481799, 0.22979597, ..., 0.08046417,
0.85968089, 0. ]],
[[ 0. , 0.52387629, 0.39586813, ..., 0.49770721,
0.77686881, 0.67761376],
[ 0.52387629, 0. , 0.43492339, ..., 0.08623291,
0.30371602, 0.28660143],
[ 0.39586813, 0.43492339, 0. , ..., -0.10309543,
-0.00903498, -0.06115111],
...,
[ 0.49770721, 0.08623291, -0.10309543, ..., 0. ,
0.69001642, 0.72559831],
[ 0.77686881, 0.30371602, -0.00903498, ..., 0.69001642,
0. , 1.37230362],
[ 0.67761376, 0.28660143, -0.06115111, ..., 0.72559831,
1.37230362, 0. ]],
...,
[[ 0. , 0.28196335, 0.62367273, ..., 0.31654042,
0.30184585, 0.24758647],
[ 0.28196335, 0. , 0.52368336, ..., -0.05335835,
-0.01349599, -0.08051841],
[ 0.62367273, 0.52368336, 0. , ..., 0.07688247,
0.15630406, 0.09915298],
...,
[ 0.31654042, -0.05335835, 0.07688247, ..., 0. ,
0.33613925, 0.35266004],
[ 0.30184585, -0.01349599, 0.15630406, ..., 0.33613925,
0. , 1.26734374],
[ 0.24758647, -0.08051841, 0.09915298, ..., 0.35266004,
1.26734374, 0. ]],
[[ 0. , 0.3318541 , 0.3479959 , ..., 0.17348234,
0.33035729, 0.28096653],
[ 0.3318541 , 0. , 0.13110274, ..., 0.16131239,
0.18152441, 0.0901034 ],
[ 0.3479959 , 0.13110274, 0. , ..., 0.03690945,
0.19735146, 0.23887414],
...,
[ 0.17348234, 0.16131239, 0.03690945, ..., 0. ,
0.21702285, 0.29280098],
[ 0.33035729, 0.18152441, 0.19735146, ..., 0.21702285,
0. , 0.63961068],
[ 0.28096653, 0.0901034 , 0.23887414, ..., 0.29280098,
0.63961068, 0. ]],
[[ 0. , 0.61490946, 0.86683146, ..., 0.28751356,
0.35820952, 0.36030001],
[ 0.61490946, 0. , 0.44664105, ..., 0.39767059,
0.28708429, 0.39182893],
[ 0.86683146, 0.44664105, 0. , ..., 0.18048986,
0.38597263, 0.43650839],
...,
[ 0.28751356, 0.39767059, 0.18048986, ..., 0. ,
0.22311195, 0.35032753],
[ 0.35820952, 0.28708429, 0.38597263, ..., 0.22311195,
0. , 1.03944121],
[ 0.36030001, 0.39182893, 0.43650839, ..., 0.35032753,
1.03944121, 0. ]]], shape=(1009, 111, 111)), 'pcorr': array([[[ 0. , 0.0845738 , 0.05283968, ..., -0.02852326,
-0.05030299, -0.00900934],
[ 0.0845738 , 0. , 0.05794923, ..., -0.06134557,
0.02918709, -0.04988144],
[ 0.05283968, 0.05794923, 0. , ..., 0.08585561,
0.0933945 , -0.01767182],
...,
[-0.02852326, -0.06134557, 0.08585561, ..., 0. ,
0.0742375 , 0.11123407],
[-0.05030299, 0.02918709, 0.0933945 , ..., 0.0742375 ,
0. , 0.26601118],
[-0.00900934, -0.04988144, -0.01767182, ..., 0.11123407,
0.26601118, 0. ]],
[[ 0. , 0.02515851, -0.02687615, ..., -0.07865261,
0.03765896, 0.03533854],
[ 0.02515851, 0. , 0.03731194, ..., -0.04081472,
-0.02018316, 0.02902786],
[-0.02687615, 0.03731194, 0. , ..., 0.02341885,
-0.0237432 , 0.07630574],
...,
[-0.07865261, -0.04081472, 0.02341885, ..., 0. ,
0.05092006, 0.03885538],
[ 0.03765896, -0.02018316, -0.0237432 , ..., 0.05092006,
0. , 0.08766317],
[ 0.03533854, 0.02902786, 0.07630574, ..., 0.03885538,
0.08766317, 0. ]],
[[ 0. , 0.0345411 , 0.02792029, ..., -0.01667542,
0.01530761, -0.00780056],
[ 0.0345411 , 0. , 0.05455556, ..., 0.0643729 ,
0.01796368, 0.03800115],
[ 0.02792029, 0.05455556, 0. , ..., -0.01042823,
-0.03121966, 0.01426741],
...,
[-0.01667542, 0.0643729 , -0.01042823, ..., 0. ,
-0.07791113, -0.04875181],
[ 0.01530761, 0.01796368, -0.03121966, ..., -0.07791113,
0. , 0.13051308],
[-0.00780056, 0.03800115, 0.01426741, ..., -0.04875181,
0.13051308, 0. ]],
...,
[[ 0. , 0.0855516 , 0.13597267, ..., 0.16457852,
0.04648458, 0.03031205],
[ 0.0855516 , 0. , -0.02231855, ..., -0.01696728,
0.06332429, -0.05170068],
[ 0.13597267, -0.02231855, 0. , ..., -0.14505349,
-0.05634801, 0.06301302],
...,
[ 0.16457852, -0.01696728, -0.14505349, ..., 0. ,
-0.09211864, 0.03218952],
[ 0.04648458, 0.06332429, -0.05634801, ..., -0.09211864,
0. , 0.56874401],
[ 0.03031205, -0.05170068, 0.06301302, ..., 0.03218952,
0.56874401, 0. ]],
[[ 0. , 0.00369714, 0.01803972, ..., 0.02636135,
0.01498399, 0.00107738],
[ 0.00369714, 0. , -0.04215888, ..., 0.08585435,
0.01525622, -0.0044203 ],
[ 0.01803972, -0.04215888, 0. , ..., 0.00435903,
-0.04385825, 0.01943377],
...,
[ 0.02636135, 0.08585435, 0.00435903, ..., 0. ,
0.08473985, 0.14461459],
[ 0.01498399, 0.01525622, -0.04385825, ..., 0.08473985,
0. , 0.1207058 ],
[ 0.00107738, -0.0044203 , 0.01943377, ..., 0.14461459,
0.1207058 , 0. ]],
[[ 0. , 0.02612483, 0.03413983, ..., 0.06152281,
-0.02745578, -0.01738546],
[ 0.02612483, 0. , -0.04195744, ..., 0.01392757,
-0.07636841, 0.00512999],
[ 0.03413983, -0.04195744, 0. , ..., -0.00615086,
0.00395361, 0.02567149],
...,
[ 0.06152281, 0.01392757, -0.00615086, ..., 0. ,
0.01919661, 0.07781511],
[-0.02745578, -0.07636841, 0.00395361, ..., 0.01919661,
0. , 0.16214421],
[-0.01738546, 0.00512999, 0.02567149, ..., 0.07781511,
0.16214421, 0. ]]], shape=(1009, 111, 111)), 'site': array(['PITT', 'PITT', 'PITT', ..., 'SBL', 'MAX_MUN', 'MAX_MUN'],
shape=(1009,), dtype='<U8')}
I couldn't open this file directly on my laptop, so I wrote a separate Python script to inspect it.
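An inspection script along those lines might look like the sketch below. It is not the script I actually used. To stay self-contained, it first saves a tiny hypothetical dict with the same keys as abide.npy but made-up shapes, then loads it the same way init_dataloader does and prints each key's shape and dtype.

```python
import os
import tempfile

import numpy as np

# Hypothetical stand-in for abide.npy: same keys, tiny made-up shapes and values
fake = {
    "timeseires": np.zeros((3, 5, 8)),
    "label": np.array([0.0, 1.0, 0.0]),
    "corr": np.zeros((3, 5, 5)),
    "pcorr": np.zeros((3, 5, 5)),
    "site": np.array(["PITT", "SBL", "MAX_MUN"]),
}
path = os.path.join(tempfile.mkdtemp(), "abide_demo.npy")
np.save(path, fake)  # a dict is stored as a pickled 0-d object array


def describe_npy_dict(path):
    """Load a .npy file that holds a pickled dict and summarize each entry."""
    data = np.load(path, allow_pickle=True).item()  # .item() unwraps the dict
    summary = {}
    for key, value in data.items():
        summary[key] = (value.shape, str(value.dtype))
        print(f"{key:>10}: shape={value.shape}, dtype={value.dtype}")
    return summary


summary = describe_npy_dict(path)
```

Pointed at the real abide.npy, this should list the five keys in the output above, with shapes (1009, 111, 100), (1009,), (1009, 111, 111), (1009, 111, 111) and (1009,). init_dataloader only uses timeseires, corr and label; pcorr and site are loaded but ignored. Since allow_pickle=True unpickles arbitrary objects, it should only be used on files you trust.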