# 导入必要的库
from benchpots.datasets import preprocess_physionet2012 # 导入physionet2012数据集预处理函数
import numpy as np
import os
import torch
from pypots.data.saving import pickle_dump
# Seed every random number generator used by this script so that repeated
# runs produce the same results.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    # Seed all visible CUDA devices, not only the current one.
    torch.cuda.manual_seed_all(SEED)
# Load and preprocess the PhysioNet-2012 data through benchpots.
# NOTE(review): presumably `rate` is the fraction of observed values that is
# artificially masked with the "point" pattern — confirm against benchpots.
preprocessing_options = dict(subset="set-a", pattern="point", rate=0.1)
physionet2012_dataset = preprocess_physionet2012(**preprocessing_options)

# Show which entries the preprocessing step returned.
print(physionet2012_dataset.keys())
# Build the indicating mask for the test set: it is True at every position
# where exactly one of "test_X" / "test_X_ori" is NaN.
test_X_nan = np.isnan(physionet2012_dataset["test_X"])
test_X_ori_nan = np.isnan(physionet2012_dataset["test_X_ori"])
physionet2012_dataset["test_X_indicating_mask"] = np.logical_xor(
    test_X_nan,
    test_X_ori_nan,
)

# Replace the NaNs in the original test data with 0. This must happen after
# the mask above has been computed, because the mask depends on those NaNs.
test_X_ori_filled = np.nan_to_num(physionet2012_dataset["test_X_ori"])
physionet2012_dataset["test_X_ori"] = test_X_ori_filled
# Assemble the dictionaries that are later handed to the model.
# The training set carries only "X"; the validation and test sets also carry
# the matching "X_ori" entry.
train_set = dict(X=physionet2012_dataset["train_X"])
val_set = dict(
    X=physionet2012_dataset["val_X"],
    X_ori=physionet2012_dataset["val_X_ori"],
)
test_set = dict(
    X=physionet2012_dataset["test_X"],
    X_ori=physionet2012_dataset["test_X_ori"],
)
# 导入SAITS模型及相关组件
from pypots.imputation import SAITS
from pypots.optim import Adam
from pypots.nn.functional import calc_mse
# Use the GPU when PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# All SAITS constructor arguments, gathered in one mapping so they can be
# read (and logged) in a single place.
saits_hyperparameters = {
    # Input dimensions come from the preprocessed dataset.
    "n_steps": physionet2012_dataset["n_steps"],
    "n_features": physionet2012_dataset["n_features"],
    # Network architecture.
    "n_layers": 3,
    "d_model": 64,
    "n_heads": 4,
    "d_k": 16,
    "d_v": 16,
    "d_ffn": 128,
    "dropout": 0.1,
    # Loss weights.
    # NOTE(review): presumably ORT/MIT are the two SAITS training-task
    # losses — confirm against the PyPOTS documentation.
    "ORT_weight": 1,
    "MIT_weight": 1,
    # Training loop settings.
    "batch_size": 32,
    "epochs": 10,
    "patience": 3,
    "optimizer": Adam(lr=1e-3),
    "num_workers": 0,
    "device": device,
    # Where and how the model is saved.
    "saving_path": "result_saving/imputation/saits",
    "model_saving_strategy": "best",
}

# Instantiate the SAITS imputation model.
saits = SAITS(**saits_hyperparameters)
# Make sure the model saving directory exists.
# os.makedirs creates every missing parent directory as well, so
# "result_saving" does not need a separate call.
os.makedirs("result_saving/imputation/saits", exist_ok=True)
# 训练模型
try:
saits.fit(train_set, val_set)
except Exception as e:
SAITS模型在PhysioNet2012数据集上的缺失值填补实践笔记
最新推荐文章于 2025-05-19 09:47:58 发布

最低0.47元/天 解锁文章
1041

被折叠的 条评论
为什么被折叠?



