SAITS模型在PhysioNet2012数据集上的缺失值填补实践笔记

# 导入必要的库
from benchpots.datasets import preprocess_physionet2012  # 导入physionet2012数据集预处理函数
import numpy as np
import os
import torch
from pypots.data.saving import pickle_dump

# 设置随机种子以保证实验可复现
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

# 加载并预处理physionet2012数据集
physionet2012_dataset = preprocess_physionet2012(
    subset="set-a",
    pattern="point",
    rate=0.1,
)

# 打印数据集的键
print(physionet2012_dataset.keys())

# 创建测试集的缺失值指示掩码
test_X_nan = np.isnan(physionet2012_dataset["test_X"])
test_X_ori_nan = np.isnan(physionet2012_dataset["test_X_ori"])
physionet2012_dataset["test_X_indicating_mask"] = test_X_nan ^ test_X_ori_nan

# 将原始测试数据中的NaN值替换为0
physionet2012_dataset["test_X_ori"] = np.nan_to_num(physionet2012_dataset["test_X_ori"])

# 构建数据集字典
train_set = {"X": physionet2012_dataset["train_X"]}
val_set = {
    "X": physionet2012_dataset["val_X"],
    "X_ori": physionet2012_dataset["val_X_ori"]
}
test_set = {
    "X": physionet2012_dataset["test_X"],
    "X_ori": physionet2012_dataset["test_X_ori"]
}

# 导入SAITS模型及相关组件
from pypots.imputation import SAITS
from pypots.optim import Adam
from pypots.nn.functional import calc_mse

# 自动选择设备
device = "cuda" if torch.cuda.is_available() else "cpu"

# 初始化SAITS模型
saits = SAITS(
    n_steps=physionet2012_dataset['n_steps'],
    n_features=physionet2012_dataset['n_features'],
    n_layers=3,
    d_model=64,
    n_heads=4,
    d_k=16,
    d_v=16,
    d_ffn=128,
    dropout=0.1,
    ORT_weight=1,
    MIT_weight=1,
    batch_size=32,
    epochs=10,
    patience=3,
    optimizer=Adam(lr=1e-3),
    num_workers=0,
    device=device,
    saving_path="result_saving/imputation/saits",
    model_saving_strategy="best",
)

# 确保保存路径存在
os.makedirs("result_saving/imputation/saits", exist_ok=True)
os.makedirs("result_saving", exist_ok=True)

# 训练模型
try:
    saits.fit(train_set, val_set)
except Exception as e:
   
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值