孤独症项目(2)

猛的一看今天已经2月5号了,寒假马上就要结束了,可自己的科研进展几乎没有。我瞬间感觉很恐慌,我似乎看到了我毕业后去送外卖的样子,我不能接受这个结果。但是我在经历了这么多事情之后,我认识到我也是肉体凡胎,我今天很恐慌,可能我明天就摆烂了,人就是这样。所以我不想管明天的事情,我只管今天,我今天要努力学习三个小时。
1

def main(args):
    with open(args.config_filename) as f:
        config = yaml.load(f, Loader=yaml.Loader)

        dataloaders, node_size, node_feature_size, timeseries_size = \
            init_dataloader(config['data'])

        config['train']["seq_len"] = timeseries_size
        config['train']["node_size"] = node_size



        model = PLSNet(config['model'], node_size,
                         node_feature_size, timeseries_size)
        use_train = BasicTrain




        optimizer = torch.optim.Adam(
            model.parameters(), lr=config['train']['lr'],
            weight_decay=config['train']['weight_decay'])
        opts = (optimizer,)

        loss_name = 'loss'
        if config['train']["group_loss"]:
            loss_name = f"{loss_name}_group_loss"
        if config['train']["sparsity_loss"]:
            loss_name = f"{loss_name}_sparsity_loss"


        save_folder_name = Path(config['train']['log_folder'])/Path(config['model']['type'])/Path(
            # date_time +
            f"{config['data']['dataset']}_{config['data']['atlas']}")

        train_process = use_train(
            config['train'], model, opts, dataloaders, save_folder_name)

        train_process.train()

2

    with open(args.config_filename) as f:
        config = yaml.load(f, Loader=yaml.Loader)

这里的args.config_filename很显然是setting文件夹下的abide_PLSNet.yaml,这个yaml配置文件好像也是我粘贴复制来的。
我看了一遍,应该是我记混了,这个是原来文件夹下面就有的。
f这个东西能叫句柄吗,应该不能叫吧,它是一个变量,代表这个配置文件的内容。
这里就是读取这个yaml配置文件的内容,然后保存在config里面,通过config来调用内容
3
abide_PLSNet.yaml包括三部分的内容,分别是data,model和train

data:
  dataset: ABIDE
  atlas: else
  batch_size: 16
  test_batch_size: 16
  val_batch_size: 16
  train_set: 0.7
  val_set: 0.1
  fold: 0
#  time_seires: /data/CodeGoat24/FBNETGEN/ABIDE_pcp/abide.npy
#  time_seires: /data/CodeGoat24/FBNETGEN_ho/ABIDE_pcp/abide.npy
  time_seires: ./data/CodeGoat24/FBNETGEN_AAL/ABIDE_pcp/abide.npy
#  time_seires: /data/CodeGoat24/FBNETGEN_cc400/ABIDE_pcp/abide.npy




model:
  type: PLSNet
  extractor_type: attention
  embedding_size: 8
  window_size: 4

  cnn_pool_size: 16

  num_gru_layers: 4

  dropout: 0.5



train:
  lr: 1.0e-4
  weight_decay: 1.0e-4
  epochs: 500
  pool_ratio: 0.7
  optimizer: adam
  stepsize: 200
  group_loss: true
  sparsity_loss: true
  sparsity_loss_weight: 0.5e-4
  log_folder: result
  # uniform or pearson
  pure_gnn_graph: pearson
        dataloaders, node_size, node_feature_size, timeseries_size = \
            init_dataloader(config['data'])

我懂了,原来config是变成字典了,data相当于一个子字典,然后里面有键值对
4

import numpy as np
import torch
import torch.utils.data as utils
import csv

from nilearn.connectome import ConnectivityMeasure
from sklearn import preprocessing
import pandas as pd
import matplotlib.pyplot as plt
from scipy.io import loadmat
from nilearn import plotting, datasets

class StandardScaler:
    """
    Standard the input
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, data):
        return (data - self.mean) / self.std

    def inverse_transform(self, data):
        return (data * self.std) + self.mean



        
def init_dataloader(dataset_config):



    data = np.load(dataset_config["time_seires"], allow_pickle=True).item()
    final_fc = data["timeseires"]
    final_pearson = data["corr"]
    labels = data["label"]


    _, _, timeseries = final_fc.shape

    _, node_size, node_feature_size = final_pearson.shape

    scaler = StandardScaler(mean=np.mean(
        final_fc), std=np.std(final_fc))
    
    final_fc = scaler.transform(final_fc)


    pseudo = []
    for i in range(len(final_fc)):
        pseudo.append(np.diag(np.ones(final_pearson.shape[1])))

    if 'cc200' in  dataset_config['atlas']:
        pseudo_arr = np.concatenate(pseudo, axis=0).reshape((-1, 200, 200))
    elif 'aal' in dataset_config['atlas']:
        pseudo_arr = np.concatenate(pseudo, axis=0).reshape((-1, 116, 116))
    elif 'cc400' in dataset_config['atlas']:
        pseudo_arr = np.concatenate(pseudo, axis=0).reshape((-1, 392, 392))
    else:
        pseudo_arr = np.concatenate(pseudo, axis=0).reshape((-1, 111, 111))



    final_fc, final_pearson, labels, pseudo_arr = [torch.from_numpy(
        data).float() for data in (final_fc, final_pearson, labels, pseudo_arr)]

    length = final_fc.shape[0]
    train_length = int(length*dataset_config["train_set"])
    val_length = int(length*dataset_config["val_set"])


    dataset = utils.TensorDataset(
        final_fc,
        final_pearson,
        labels,
        pseudo_arr
    )

    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_length, val_length, length-train_length-val_length])

    train_dataloader = utils.DataLoader(
        train_dataset, batch_size=dataset_config["batch_size"], shuffle=True, drop_last=False)

    val_dataloader = utils.DataLoader(
        val_dataset, batch_size=dataset_config["batch_size"], shuffle=True, drop_last=False)

    test_dataloader = utils.DataLoader(
        test_dataset, batch_size=dataset_config["batch_size"], shuffle=True, drop_last=False)

    return (train_dataloader, val_dataloader, test_dataloader), node_size, node_feature_size, timeseries

接下来就是dataloader.py的内容了
我说这个time_series是啥啊,原来是 time_seires: ./data/CodeGoat24/FBNETGEN_AAL/ABIDE_pcp/abide.npy,一个点npy文件。原来这个npy文件又是一个处理格式,真是一层套一层啊

{'timeseires': array([[[ -4.671351, -12.931799, -20.238678, ...,  -4.6999  ,
          -2.881232,  -5.373659],
        [ -6.610331, -10.146871, -11.706947, ...,  -8.545325,
          -4.824956,  -3.549801],
        [  6.539966,  -3.924838, -22.075667, ..., -10.810907,
          -9.507657, -11.782674],
        ...,
        [ 23.815056,  14.159346,  -4.235137, ...,   2.318333,
          -0.419391,  -6.08656 ],
        [ 29.358294,  29.569822,  12.97839 , ..., -35.478521,
         -32.452544, -38.755226],
        [ 12.700474,  10.523634,  -1.103658, ...,  16.980012,
          18.994251,   7.554179]],

       [[ 15.055096,  14.30167 ,   8.35954 , ...,  -2.320052,
           5.738047,  18.929523],
        [ -2.650435,  -5.070064,  -6.582098, ...,  -6.500024,
          -6.520322,  -2.209336],
        [  7.958244,   7.113133,   3.146248, ...,  -7.816352,
           0.160806,  12.985252],
        ...,
        [  0.216612,  -7.946295, -14.338741, ...,  -2.661149,
          -1.293657,   4.052661],
        [  0.595963,  -1.462514,  -3.352045, ...,  -7.601621,
          -1.38041 ,   7.066522],
        [ -1.754612,  -2.342497,  -2.100473, ...,  -4.888682,
           0.817308,   6.089231]],

       [[ 10.699614,  12.640328,   9.107096, ...,  25.303575,
          14.631509,  -2.431361],
        [  8.946028,  11.992252,  10.768465, ...,  17.12695 ,
           4.901006,  -9.559137],
        [ 12.502747,  17.729441,  17.117685, ...,  13.455708,
         -13.292004, -35.746319],
        ...,
        [-13.968082, -18.558227, -20.739389, ...,  30.461445,
          48.840466,  47.846905],
        [  7.775516,   7.945796,   4.77154 , ...,  35.356228,
          42.096015,  35.434606],
        [  3.752828,   5.027397,   4.62332 , ...,  22.339197,
          34.736423,  33.95411 ]],

       ...,

       [[ -9.594532, -18.276965, -12.977462, ...,  22.482481,
           9.076281,  -3.665357],
        [ -1.753394,  -3.048763,  -0.66247 , ...,   3.479453,
           6.685637,   5.947072],
        [ -7.053816,  -4.420122,   6.336983, ...,  23.162804,
          14.313564,   2.746504],
        ...,
        [ 14.337799,  18.14827 ,  -0.895313, ...,  37.246694,
          58.276117,  34.074897],
        [-15.100426, -22.201254, -24.41196 , ...,  60.251682,
          66.325714,  32.992202],
        [-13.52109 , -18.166892, -16.496382, ...,  38.908534,
          41.846779,  15.183713]],

       [[ -4.960212, -10.467057,  -2.350238, ...,   8.175621,
          -3.594708, -13.050947],
        [ -2.188337,  -5.102677,  -1.05976 , ...,   3.500581,
           5.308049,  -0.665579],
        [ -7.689325,  -4.299846,   4.479113, ...,   7.190843,
          -6.158385, -10.379174],
        ...,
        [  4.838299,  -1.461336,   3.449214, ...,  22.214853,
          10.277305, -20.902231],
        [ -9.16024 ,  -6.62222 ,   5.643113, ...,   1.50843 ,
          -7.206676, -21.635724],
        [ -7.149388,  -5.05326 ,   5.05269 , ...,  12.605102,
           7.582936, -10.857408]],

       [[ -8.999973,  -8.265457,   0.224464, ...,   2.243862,
          -2.492217, -13.048096],
        [ -1.115053,   0.311383,  -0.91307 , ...,   3.98221 ,
           1.695672,  -5.227182],
        [  2.39671 ,   2.200741,   5.486765, ...,  13.607921,
           7.737062, -13.571547],
        ...,
        [-24.200021, -15.751269,   3.40317 , ...,   3.005727,
           3.171528,  -1.583004],
        [-12.074557,  -7.229352,  -1.982399, ...,   4.310723,
          -9.359018,  -3.962187],
        [ -7.750716,  -6.153427,  -2.874191, ...,   4.213236,
          -2.779235,  -1.473658]]], shape=(1009, 111, 100)), 'label': array([0., 0., 0., ..., 0., 0., 0.], shape=(1009,)), 'corr': array([[[ 0.        ,  0.97663709,  0.82530072, ...,  0.13740499,
          0.35407144,  0.08548723],
        [ 0.97663709,  0.        ,  0.81736156, ...,  0.12545135,
          0.31653123,  0.08523937],
        [ 0.82530072,  0.81736156,  0.        , ...,  0.32199278,
          0.48073326,  0.23490678],
        ...,
        [ 0.13740499,  0.12545135,  0.32199278, ...,  0.        ,
          0.91433765,  0.76759212],
        [ 0.35407144,  0.31653123,  0.48073326, ...,  0.91433765,
          0.        ,  1.1123392 ],
        [ 0.08548723,  0.08523937,  0.23490678, ...,  0.76759212,
          1.1123392 ,  0.        ]],

       [[ 0.        ,  0.62647921,  0.68386934, ..., -0.05088815,
          0.60446506,  0.41812769],
        [ 0.62647921,  0.        ,  0.54538175, ...,  0.09402108,
          0.39961469,  0.28481799],
        [ 0.68386934,  0.54538175,  0.        , ..., -0.06276508,
          0.30605398,  0.22979597],
        ...,
        [-0.05088815,  0.09402108, -0.06276508, ...,  0.        ,
         -0.01265819,  0.08046417],
        [ 0.60446506,  0.39961469,  0.30605398, ..., -0.01265819,
          0.        ,  0.85968089],
        [ 0.41812769,  0.28481799,  0.22979597, ...,  0.08046417,
          0.85968089,  0.        ]],

       [[ 0.        ,  0.52387629,  0.39586813, ...,  0.49770721,
          0.77686881,  0.67761376],
        [ 0.52387629,  0.        ,  0.43492339, ...,  0.08623291,
          0.30371602,  0.28660143],
        [ 0.39586813,  0.43492339,  0.        , ..., -0.10309543,
         -0.00903498, -0.06115111],
        ...,
        [ 0.49770721,  0.08623291, -0.10309543, ...,  0.        ,
          0.69001642,  0.72559831],
        [ 0.77686881,  0.30371602, -0.00903498, ...,  0.69001642,
          0.        ,  1.37230362],
        [ 0.67761376,  0.28660143, -0.06115111, ...,  0.72559831,
          1.37230362,  0.        ]],

       ...,

       [[ 0.        ,  0.28196335,  0.62367273, ...,  0.31654042,
          0.30184585,  0.24758647],
        [ 0.28196335,  0.        ,  0.52368336, ..., -0.05335835,
         -0.01349599, -0.08051841],
        [ 0.62367273,  0.52368336,  0.        , ...,  0.07688247,
          0.15630406,  0.09915298],
        ...,
        [ 0.31654042, -0.05335835,  0.07688247, ...,  0.        ,
          0.33613925,  0.35266004],
        [ 0.30184585, -0.01349599,  0.15630406, ...,  0.33613925,
          0.        ,  1.26734374],
        [ 0.24758647, -0.08051841,  0.09915298, ...,  0.35266004,
          1.26734374,  0.        ]],

       [[ 0.        ,  0.3318541 ,  0.3479959 , ...,  0.17348234,
          0.33035729,  0.28096653],
        [ 0.3318541 ,  0.        ,  0.13110274, ...,  0.16131239,
          0.18152441,  0.0901034 ],
        [ 0.3479959 ,  0.13110274,  0.        , ...,  0.03690945,
          0.19735146,  0.23887414],
        ...,
        [ 0.17348234,  0.16131239,  0.03690945, ...,  0.        ,
          0.21702285,  0.29280098],
        [ 0.33035729,  0.18152441,  0.19735146, ...,  0.21702285,
          0.        ,  0.63961068],
        [ 0.28096653,  0.0901034 ,  0.23887414, ...,  0.29280098,
          0.63961068,  0.        ]],

       [[ 0.        ,  0.61490946,  0.86683146, ...,  0.28751356,
          0.35820952,  0.36030001],
        [ 0.61490946,  0.        ,  0.44664105, ...,  0.39767059,
          0.28708429,  0.39182893],
        [ 0.86683146,  0.44664105,  0.        , ...,  0.18048986,
          0.38597263,  0.43650839],
        ...,
        [ 0.28751356,  0.39767059,  0.18048986, ...,  0.        ,
          0.22311195,  0.35032753],
        [ 0.35820952,  0.28708429,  0.38597263, ...,  0.22311195,
          0.        ,  1.03944121],
        [ 0.36030001,  0.39182893,  0.43650839, ...,  0.35032753,
          1.03944121,  0.        ]]], shape=(1009, 111, 111)), 'pcorr': array([[[ 0.        ,  0.0845738 ,  0.05283968, ..., -0.02852326,
         -0.05030299, -0.00900934],
        [ 0.0845738 ,  0.        ,  0.05794923, ..., -0.06134557,
          0.02918709, -0.04988144],
        [ 0.05283968,  0.05794923,  0.        , ...,  0.08585561,
          0.0933945 , -0.01767182],
        ...,
        [-0.02852326, -0.06134557,  0.08585561, ...,  0.        ,
          0.0742375 ,  0.11123407],
        [-0.05030299,  0.02918709,  0.0933945 , ...,  0.0742375 ,
          0.        ,  0.26601118],
        [-0.00900934, -0.04988144, -0.01767182, ...,  0.11123407,
          0.26601118,  0.        ]],

       [[ 0.        ,  0.02515851, -0.02687615, ..., -0.07865261,
          0.03765896,  0.03533854],
        [ 0.02515851,  0.        ,  0.03731194, ..., -0.04081472,
         -0.02018316,  0.02902786],
        [-0.02687615,  0.03731194,  0.        , ...,  0.02341885,
         -0.0237432 ,  0.07630574],
        ...,
        [-0.07865261, -0.04081472,  0.02341885, ...,  0.        ,
          0.05092006,  0.03885538],
        [ 0.03765896, -0.02018316, -0.0237432 , ...,  0.05092006,
          0.        ,  0.08766317],
        [ 0.03533854,  0.02902786,  0.07630574, ...,  0.03885538,
          0.08766317,  0.        ]],

       [[ 0.        ,  0.0345411 ,  0.02792029, ..., -0.01667542,
          0.01530761, -0.00780056],
        [ 0.0345411 ,  0.        ,  0.05455556, ...,  0.0643729 ,
          0.01796368,  0.03800115],
        [ 0.02792029,  0.05455556,  0.        , ..., -0.01042823,
         -0.03121966,  0.01426741],
        ...,
        [-0.01667542,  0.0643729 , -0.01042823, ...,  0.        ,
         -0.07791113, -0.04875181],
        [ 0.01530761,  0.01796368, -0.03121966, ..., -0.07791113,
          0.        ,  0.13051308],
        [-0.00780056,  0.03800115,  0.01426741, ..., -0.04875181,
          0.13051308,  0.        ]],

       ...,

       [[ 0.        ,  0.0855516 ,  0.13597267, ...,  0.16457852,
          0.04648458,  0.03031205],
        [ 0.0855516 ,  0.        , -0.02231855, ..., -0.01696728,
          0.06332429, -0.05170068],
        [ 0.13597267, -0.02231855,  0.        , ..., -0.14505349,
         -0.05634801,  0.06301302],
        ...,
        [ 0.16457852, -0.01696728, -0.14505349, ...,  0.        ,
         -0.09211864,  0.03218952],
        [ 0.04648458,  0.06332429, -0.05634801, ..., -0.09211864,
          0.        ,  0.56874401],
        [ 0.03031205, -0.05170068,  0.06301302, ...,  0.03218952,
          0.56874401,  0.        ]],

       [[ 0.        ,  0.00369714,  0.01803972, ...,  0.02636135,
          0.01498399,  0.00107738],
        [ 0.00369714,  0.        , -0.04215888, ...,  0.08585435,
          0.01525622, -0.0044203 ],
        [ 0.01803972, -0.04215888,  0.        , ...,  0.00435903,
         -0.04385825,  0.01943377],
        ...,
        [ 0.02636135,  0.08585435,  0.00435903, ...,  0.        ,
          0.08473985,  0.14461459],
        [ 0.01498399,  0.01525622, -0.04385825, ...,  0.08473985,
          0.        ,  0.1207058 ],
        [ 0.00107738, -0.0044203 ,  0.01943377, ...,  0.14461459,
          0.1207058 ,  0.        ]],

       [[ 0.        ,  0.02612483,  0.03413983, ...,  0.06152281,
         -0.02745578, -0.01738546],
        [ 0.02612483,  0.        , -0.04195744, ...,  0.01392757,
         -0.07636841,  0.00512999],
        [ 0.03413983, -0.04195744,  0.        , ..., -0.00615086,
          0.00395361,  0.02567149],
        ...,
        [ 0.06152281,  0.01392757, -0.00615086, ...,  0.        ,
          0.01919661,  0.07781511],
        [-0.02745578, -0.07636841,  0.00395361, ...,  0.01919661,
          0.        ,  0.16214421],
        [-0.01738546,  0.00512999,  0.02567149, ...,  0.07781511,
          0.16214421,  0.        ]]], shape=(1009, 111, 111)), 'site': array(['PITT', 'PITT', 'PITT', ..., 'SBL', 'MAX_MUN', 'MAX_MUN'],
      shape=(1009,), dtype='<U8')}

我用笔记本还打不开这个文件,我又写了个python文件来看一下。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值