今日任务:下载并部署模型
新训练的模型数据:
Fold 1: Epoch [1/10], Step [100/395], Loss: 0.5403 Epoch [1/10], Step [200/395], Loss: 0.4706 Epoch [1/10], Step [300/395], Loss: 0.5848 Epoch [1/10], Validation Accuracy: 81.85% Epoch [2/10], Step [100/395], Loss: 0.5793 Epoch [2/10], Step [200/395], Loss: 0.4042 Epoch [2/10], Step [300/395], Loss: 0.4152 Epoch [2/10], Validation Accuracy: 84.50% Epoch [3/10], Step [100/395], Loss: 0.5450 Epoch [3/10], Step [200/395], Loss: 0.4540 Epoch [3/10], Step [300/395], Loss: 0.6119 Epoch [3/10], Validation Accuracy: 83.31% Epoch [4/10], Step [100/395], Loss: 0.5222 Epoch [4/10], Step [200/395], Loss: 0.4575 Epoch [4/10], Step [300/395], Loss: 0.4059 Epoch [4/10], Validation Accuracy: 87.94% Epoch [5/10], Step [100/395], Loss: 0.5911 Epoch [5/10], Step [200/395], Loss: 0.4194 Epoch [5/10], Step [300/395], Loss: 0.5268 Epoch [5/10], Validation Accuracy: 85.29% Epoch [6/10], Step [100/395], Loss: 0.1904 Epoch [6/10], Step [200/395], Loss: 0.4084 Epoch [6/10], Step [300/395], Loss: 0.3588 Epoch [6/10], Validation Accuracy: 86.74% Epoch [7/10], Step [100/395], Loss: 0.1685 Epoch [7/10], Step [200/395], Loss: 0.2216 Epoch [7/10], Step [300/395], Loss: 0.3419 Epoch [7/10], Validation Accuracy: 76.28% Epoch [8/10], Step [100/395], Loss: 0.2877 Epoch [8/10], Step [200/395], Loss: 0.2870 Epoch [8/10], Step [300/395], Loss: 0.2535 Epoch [8/10], Validation Accuracy: 91.81% Epoch [9/10], Step [100/395], Loss: 0.1896 Epoch [9/10], Step [200/395], Loss: 0.2026 Epoch [9/10], Step [300/395], Loss: 0.1269 Epoch [9/10], Validation Accuracy: 92.63% Epoch [10/10], Step [100/395], Loss: 0.2447 Epoch [10/10], Step [200/395], Loss: 0.1344 Epoch [10/10], Step [300/395], Loss: 0.3051 Epoch [10/10], Validation Accuracy: 94.06% -------------------- Fold 2: Epoch [1/10], Step [100/395], Loss: 0.7398 Epoch [1/10], Step [200/395], Loss: 0.5177 Epoch [1/10], Step [300/395], Loss: 0.4803 Epoch [1/10], Validation Accuracy: 69.64% Epoch [2/10], Step [100/395], Loss: 0.6056 Epoch [2/10], Step [200/395], Loss: 0.5051 Epoch [2/10], Step [300/395], Loss: 0.3503 Epoch [2/10], Validation Accuracy: 83.95% Epoch [3/10], Step [100/395], Loss: 0.3278 Epoch [3/10], Step [200/395], Loss: 0.3536 Epoch [3/10], Step [300/395], Loss: 0.3474 Epoch [3/10], Validation Accuracy: 85.29% Epoch [4/10], Step [100/395], Loss: 0.5932 Epoch [4/10], Step [200/395], Loss: 0.4186 Epoch [4/10], Step [300/395], Loss: 0.4706 Epoch [4/10], Validation Accuracy: 77.64% Epoch [5/10], Step [100/395], Loss: 0.5504 Epoch [5/10], Step [200/395], Loss: 0.4621 Epoch [5/10], Step [300/395], Loss: 0.4277 Epoch [5/10], Validation Accuracy: 86.18% Epoch [6/10], Step [100/395], Loss: 0.5386 Epoch [6/10], Step [200/395], Loss: 0.2419 Epoch [6/10], Step [300/395], Loss: 0.2551 Epoch [6/10], Validation Accuracy: 88.35% Epoch [7/10], Step [100/395], Loss: 0.4689 Epoch [7/10], Step [200/395], Loss: 0.2680 Epoch [7/10], Step [300/395], Loss: 0.4049 Epoch [7/10], Validation Accuracy: 93.39% Epoch [8/10], Step [100/395], Loss: 0.1264 Epoch [8/10], Step [200/395], Loss: 0.2044 Epoch [8/10], Step [300/395], Loss: 0.1533 Epoch [8/10], Validation Accuracy: 93.74% Epoch [9/10], Step [100/395], Loss: 0.1288 Epoch [9/10], Step [200/395], Loss: 0.0413 Epoch [9/10], Step [300/395], Loss: 0.0411 Epoch [9/10], Validation Accuracy: 94.31% Epoch [10/10], Step [100/395], Loss: 0.1194 Epoch [10/10], Step [200/395], Loss: 0.2166 Epoch [10/10], Step [300/395], Loss: 0.1666 Epoch [10/10], Validation Accuracy: 95.50% -------------------- Fold 3: Epoch [1/10], Step [100/395], Loss: 0.5017 Epoch [1/10], Step [200/395], Loss: 0.2869 Epoch [1/10], Step [300/395], Loss: 0.4488 Epoch [1/10], Validation Accuracy: 77.27% Epoch [2/10], Step [100/395], Loss: 0.3520 Epoch [2/10], Step [200/395], Loss: 0.2961 Epoch [2/10], Step [300/395], Loss: 0.6696 Epoch [2/10], Validation Accuracy: 56.64% Epoch [3/10], Step [100/395], Loss: 0.5241 Epoch [3/10], Step [200/395], Loss: 0.4306 Epoch [3/10], Step [300/395], Loss: 0.3621 Epoch [3/10], Validation Accuracy: 80.79% Epoch [4/10], Step [100/395], Loss: 0.5825 Epoch [4/10], Step [200/395], Loss: 0.4053 Epoch [4/10], Step [300/395], Loss: 0.4310 Epoch [4/10], Validation Accuracy: 85.85% Epoch [5/10], Step [100/395], Loss: 0.3651 Epoch [5/10], Step [200/395], Loss: 0.3938 Epoch [5/10], Step [300/395], Loss: 0.2980 Epoch [5/10], Validation Accuracy: 85.72% Epoch [6/10], Step [100/395], Loss: 0.3965 Epoch [6/10], Step [200/395], Loss: 0.1742 Epoch [6/10], Step [300/395], Loss: 0.2613 Epoch [6/10], Validation Accuracy: 87.75% Epoch [7/10], Step [100/395], Loss: 0.2838 Epoch [7/10], Step [200/395], Loss: 0.1560 Epoch [7/10], Step [300/395], Loss: 0.4377 Epoch [7/10], Validation Accuracy: 90.02% Epoch [8/10], Step [100/395], Loss: 0.1986 Epoch [8/10], Step [200/395], Loss: 0.2513 Epoch [8/10], Step [300/395], Loss: 0.2369 Epoch [8/10], Validation Accuracy: 90.94% Epoch [9/10], Step [100/395], Loss: 0.2996 Epoch [9/10], Step [200/395], Loss: 0.1340 Epoch [9/10], Step [300/395], Loss: 0.2382 Epoch [9/10], Validation Accuracy: 90.65% Epoch [10/10], Step [100/395], Loss: 0.2572 Epoch [10/10], Step [200/395], Loss: 0.1092 Epoch [10/10], Step [300/395], Loss: 0.2229 Epoch [10/10], Validation Accuracy: 90.30% -------------------- Fold 4: Epoch [1/10], Step [100/395], Loss: 0.4642 Epoch [1/10], Step [200/395], Loss: 0.5792 Epoch [1/10], Step [300/395], Loss: 0.3839 Epoch [1/10], Validation Accuracy: 68.73% Epoch [2/10], Step [100/395], Loss: 0.3857 Epoch [2/10], Step [200/395], Loss: 0.4441 Epoch [2/10], Step [300/395], Loss: 0.3173 Epoch [2/10], Validation Accuracy: 80.41% Epoch [3/10], Step [100/395], Loss: 0.5653 Epoch [3/10], Step [200/395], Loss: 0.4694 Epoch [3/10], Step [300/395], Loss: 0.3736 Epoch [3/10], Validation Accuracy: 85.85% Epoch [4/10], Step [100/395], Loss: 0.5064 Epoch [4/10], Step [200/395], Loss: 0.5156 Epoch [4/10], Step [300/395], Loss: 0.3889 Epoch [4/10], Validation Accuracy: 82.41% Epoch [5/10], Step [100/395], Loss: 0.5918 Epoch [5/10], Step [200/395], Loss: 0.4050 Epoch [5/10], Step [300/395], Loss: 0.2634 Epoch [5/10], Validation Accuracy: 85.69% Epoch [6/10], Step [100/395], Loss: 0.4605 Epoch [6/10], Step [200/395], Loss: 0.5615 Epoch [6/10], Step [300/395], Loss: 0.2193 Epoch [6/10], Validation Accuracy: 90.16% Epoch [7/10], Step [100/395], Loss: 0.2892 Epoch [7/10], Step [200/395], Loss: 0.1804 Epoch [7/10], Step [300/395], Loss: 0.2801 Epoch [7/10], Validation Accuracy: 95.72% Epoch [8/10], Step [100/395], Loss: 0.1843 Epoch [8/10], Step [200/395], Loss: 0.3311 Epoch [8/10], Step [300/395], Loss: 0.0513 Epoch [8/10], Validation Accuracy: 91.38% Epoch [9/10], Step [100/395], Loss: 0.1424 Epoch [9/10], Step [200/395], Loss: 0.3028 Epoch [9/10], Step [300/395], Loss: 0.1823 Epoch [9/10], Validation Accuracy: 95.45% Epoch [10/10], Step [100/395], Loss: 0.0841 Epoch [10/10], Step [200/395], Loss: 0.1645 Epoch [10/10], Step [300/395], Loss: 0.0951 Epoch [10/10], Validation Accuracy: 93.96% -------------------- Fold 5: Epoch [1/10], Step [100/395], Loss: 0.5558 Epoch [1/10], Step [200/395], Loss: 0.5379 Epoch [1/10], Step [300/395], Loss: 0.3384 Epoch [1/10], Validation Accuracy: 81.93% Epoch [2/10], Step [100/395], Loss: 0.3611 Epoch [2/10], Step [200/395], Loss: 0.4635 Epoch [2/10], Step [300/395], Loss: 0.2369 Epoch [2/10], Validation Accuracy: 85.56% Epoch [3/10], Step [100/395], Loss: 0.3294 Epoch [3/10], Step [200/395], Loss: 0.3171 Epoch [3/10], Step [300/395], Loss: 0.2673 Epoch [3/10], Validation Accuracy: 88.51% Epoch [4/10], Step [100/395], Loss: 0.3625 Epoch [4/10], Step [200/395], Loss: 0.2983 Epoch [4/10], Step [300/395], Loss: 0.4526 Epoch [4/10], Validation Accuracy: 92.06% Epoch [5/10], Step [100/395], Loss: 0.1157 Epoch [5/10], Step [200/395], Loss: 0.1774 Epoch [5/10], Step [300/395], Loss: 0.1523 Epoch [5/10], Validation Accuracy: 95.88% Epoch [6/10], Step [100/395], Loss: 0.0898 Epoch [6/10], Step [200/395], Loss: 0.0891 Epoch [6/10], Step [300/395], Loss: 0.1342 Epoch [6/10], Validation Accuracy: 96.39% Epoch [7/10], Step [100/395], Loss: 0.0743 Epoch [7/10], Step [200/395], Loss: 0.0704 Epoch [7/10], Step [300/395], Loss: 0.0073 Epoch [7/10], Validation Accuracy: 98.80% Epoch [8/10], Step [100/395], Loss: 0.0132 Epoch [8/10], Step [200/395], Loss: 0.0029 Epoch [8/10], Step [300/395], Loss: 0.0069 Epoch [8/10], Validation Accuracy: 99.22% Epoch [9/10], Step [100/395], Loss: 0.0098 Epoch [9/10], Step [200/395], Loss: 0.0344 Epoch [9/10], Step [300/395], Loss: 0.0075 Epoch [9/10], Validation Accuracy: 98.83% Epoch [10/10], Step [100/395], Loss: 0.0506 Epoch [10/10], Step [200/395], Loss: 0.0626 Epoch [10/10], Step [300/395], Loss: 0.0172 Epoch [10/10], Validation Accuracy: 99.33% --------------------
保存模型:
torch.save(model.state_dict(), "E:/gxy/gxy/D/teach/footbed/model/model_weights.pth")
出现新的问题:实际测试中的传感器没有经过校准,与数据集中的数据量纲不同、分布也不同。经过了解,不能直接通过归一化处理数据,此处引入分位数处理方法。
好的,我现在需要详细解释分位数变换(Quantile Transform)的原理和操作流程,特别是在处理来自不同传感器的测试数据和实际数据时的应用。用户之前的问题是关于如何处理传感器不同且数据未校准的情况,我的回答中提到了使用分位数变换作为解决方案之一。现在用户希望更深入地了解这个方法的原理和具体操作步骤。
首先,我需要回忆分位数变换的基本概念。分位数变换属于非参数方法,其核心思想是将数据映射到指定的分布,通常是均匀分布或正态分布。这种方法不依赖于数据的原始分布,而是通过匹配分位数来调整数据,从而消除不同传感器带来的量纲差异,同时保留数据的序数关系。
接下来,我需要确定分位数变换的具体步骤。通常,这个过程包括拟合阶段和转换阶段。在训练阶段,使用训练数据来估计每个特征的分位数,然后保存这些分位数参数。在应用阶段,使用这些参数来转换新的数据,包括测试数据或实际应用中的数据。
然后,我需要思考如何解释分位数变换如何消除传感器差异。例如,不同传感器可能有不同的测量范围或灵敏度,导致相同信号在不同传感器上的数值分布不同。分位数变换通过将每个传感器的数据映射到相同的分位数分布,从而对齐它们的统计分布,使模型能够更好地处理这些数据。
还需要考虑分位数变换的优缺点。优点包括消除量纲差异、保留数据序数关系、适用于非正态分布数据等。缺点可能包括计算复杂度较高,特别是对于大数据集,以及需要保存大量的分位数参数。
此外,用户提到实际数据是定性的,只能表现出变化趋势,没有经过校准。分位数变换在这种情况下是否适用?应该说明分位数变换如何处理这种情况,例如通过关注数据的相对顺序而不是绝对数值,从而捕捉趋势变化。
需要举一个具体的例子来说明分位数变换的操作步骤。例如,假设训练数据来自传感器A,实际数据来自传感器B,两者的数据分布不同。通过分位数变换,将传感器B的数据转换为传感器A的分位数分布,从而使得两者的数据在统计上对齐。
另外,要提到在实施分位数变换时需要注意的问题,比如如何处理超出训练数据范围的新数据点,是否需要外推或者进行截断处理,这可能会影响变换后的数据分布。
最后,需要总结分位数变换的适用场景和操作流程,确保用户能够理解并正确应用这一方法。
分位数变换(Quantile Transform)原理与操作流程详解
1. 核心原理
分位数变换是一种非参数方法,通过将数据的分位数映射到目标分布(如均匀分布或正态分布),从而消除不同传感器或数据源的量纲差异。其核心思想是:保留数据的序数关系(相对大小顺序),但强制数据服从相同分布。
-
数学原理:
- 对训练数据的每个特征,计算其经验分位数(如 0%, 25%, 50%, 75%, 100% 分位数)。
- 将实际数据按训练数据的分位数进行插值,映射到目标分布。
-
直观理解:
- 假设传感器 A 的数据分布为
[1, 2, 3, 4, 5]
,传感器 B 的数据分布为[10, 20, 30, 40, 50]
。 - 分位数变换后,两者均被映射到统一分布(如均匀分布
[0.0, 0.25, 0.5, 0.75, 1.0]
)。 - 模型不再关注绝对数值,而是数据的相对位置。
- 假设传感器 A 的数据分布为
2. 操作流程
**(1) 训练阶段:拟合分位数参数**
python
复制
from sklearn.preprocessing import QuantileTransformer
# 示例训练数据(假设形状为 [n_samples, n_features])
X_train = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
# 初始化分位数变换器(默认映射到均匀分布)
qt = QuantileTransformer(n_quantiles=100, output_distribution='uniform', random_state=42)
# 拟合训练数据并转换
X_train_transformed = qt.fit_transform(X_train)
# 保存分位数参数(关键!)
np.save("quantile_transformer_params.npy", qt.quantiles_)
**(2) 实际数据应用阶段:使用训练集参数转换**
python
复制
# 加载保存的分位数参数
qt = QuantileTransformer()
qt.quantiles_ = np.load("quantile_transformer_params.npy")
# 实际数据(来自不同传感器,假设单位不同)
X_actual = np.array([[10, 20], [30, 40], [50, 60]])
# 应用变换
X_actual_transformed = qt.transform(X_actual)
3. 关键参数解析
参数 | 说明 |
---|---|
n_quantiles | 计算的分位数数量(默认=1000,建议根据数据量调整,数据量小则降低) |
output_distribution | 目标分布,可选 'uniform' (均匀分布)或 'normal' (正态分布) |
ignore_implicit_zeros | 是否忽略稀疏数据中的隐式零(适用于文本或图像数据) |
subsample | 当数据量大于 1e4 时,随机采样子样本加速计算(例如 subsample=1e4 ) |
4. 优点与适用场景
- 优点:
- 消除量纲差异:不同传感器的绝对数值差异被归一化。
- 保留序数关系:数据的相对大小(如趋势变化)不变。
- 非参数鲁棒性:不假设数据分布,适用于任意分布类型。
- 适用场景:
- 传感器校准不可行(如硬件差异、成本限制)。
- 数据仅能反映趋势(如定性数据、未校准传感器)。
5. 处理数据超出训练范围的情况
当实际数据超出训练数据范围时,分位数变换默认会进行 外推(Extrapolation):
- 外推模式:
- 低于训练集最小值 → 映射到目标分布的最小值以下。
- 高于训练集最大值 → 映射到目标分布的最大值以上。
- 风险:
- 极端值可能导致变换后数据超出目标分布范围(如均匀分布下的 [0,1])。
- 解决方案:
python
复制
# 方法1:截断超出范围的值 X_actual_transformed = np.clip(X_actual_transformed, 0.0, 1.0) # 方法2:调整分位数变换参数以包含更大范围 qt = QuantileTransformer(output_distribution='uniform', n_quantiles=1000)
6. 示例:传感器数据对齐
假设传感器 A(训练数据)和传感器 B(实际数据)测量同一物理量但单位不同:
- 传感器 A(训练数据):温度(单位:°C),范围
[0, 100]
。 - 传感器 B(实际数据):温度(单位:°F),范围
[32, 212]
。
**(1) 原始数据分布**
传感器 | 数据示例 | 分位数(0%, 25%, 50%, 75%, 100%) |
---|---|---|
A | [0, 25, 50, 75, 100] | [0, 25, 50, 75, 100] |
B | [32, 77, 122, 167, 212] | [32, 77, 122, 167, 212] |
**(2) 分位数变换后**
- 训练数据(传感器 A)和目标分布(均匀分布):
[0.0, 0.25, 0.5, 0.75, 1.0]
- 实际数据(传感器 B)按传感器 A 的分位数映射:
传感器 B 值:32 → 对应传感器 A 的 0°C → 变换后 0.0 传感器 B 值:77 → 对应传感器 A 的 25°C → 变换后 0.25 传感器 B 值:122 → 对应传感器 A 的 50°C → 变换后 0.5 ...
7. 总结
- 操作步骤:
- 训练阶段:用训练数据拟合分位数变换器,保存参数。
- 实际应用:加载参数,对所有新数据(包括不同传感器数据)应用相同变换。
- 核心价值:通过强制数据分布一致,使模型无需关心传感器硬件差异,专注于数据内在模式。
实操规划:
使用分位数变换(Quantile Transform)时,原始数据集需要经过预处理,但不需要改变其结构;模型本身通常不需要调整结构,但需要重新训练以适应变换后的数据分布。以下是具体分析:
1. 原始数据集的变化
**(1) 预处理流程变化**
- 需要新增步骤:在数据预处理阶段加入分位数变换。
- 不改变数据结构:样本数量、特征数量、标签等均保持不变,仅特征的数值分布被映射到目标分布(如均匀分布或正态分布)。
**(2) 示例代码**
python
复制
from sklearn.preprocessing import QuantileTransformer
# 训练阶段:拟合分位数参数并转换数据
qt = QuantileTransformer(output_distribution='normal') # 映射到正态分布
X_train_transformed = qt.fit_transform(X_train) # 训练数据需要变换
# 测试/实际数据阶段:加载参数并应用
X_test_transformed = qt.transform(X_test) # 测试数据同样需要变换
**(3) 注意事项**
- 保存参数:必须保存训练集的
qt
对象(尤其是qt.quantiles_
),确保后续数据使用相同的变换规则。 - 数据一致性:测试数据和实际数据必须使用训练集的参数,不能重新拟合分位数变换器!
2. 模型的变化
**(1) 模型是否需要调整结构?**
- 不需要:分位数变换仅改变输入特征的分布,不改变模型的输入维度或结构。
- 适用性:所有模型(如神经网络、决策树、SVM 等)均可直接使用变换后的数据,无需修改模型结构。
**(2) 模型是否需要重新训练?**
- 必须重新训练:变换后的数据分布与原始数据不同,模型需要在新分布上重新学习。
python
复制
# 错误做法:直接在原始模型上预测变换后的数据(模型未见过新分布,效果可能下降) predictions = old_model.predict(X_test_transformed) # 正确做法:用变换后的数据重新训练模型 model.fit(X_train_transformed, y_train) predictions = model.predict(X_test_transformed)
**(3) 不同模型的敏感性**
模型类型 | 分位数变换的收益 | 是否需要重新训练 |
---|---|---|
线性模型(如线性回归、逻辑回归) | 收益显著:线性模型假设特征服从正态分布,分位数变换可提升性能。 | 必须重新训练 |
树模型(如随机森林、XGBoost) | 收益较小:树模型不依赖数据分布,但变换可能改善特征分裂点的选择。 | 可选择性重新训练(推荐重新训练) |
神经网络 | 收益中等:网络可通过学习适应任意分布,但分位数变换可加速收敛。 | 必须重新训练 |
3. 完整流程示例
**(1) 训练阶段**
python
复制
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import QuantileTransformer
# 1. 加载原始数据
X_train, y_train = load_train_data() # 原始数据未变换
# 2. 分位数变换预处理
qt = QuantileTransformer(output_distribution='normal')
X_train_transformed = qt.fit_transform(X_train)
# 3. 训练模型(使用变换后的数据)
model = RandomForestClassifier()
model.fit(X_train_transformed, y_train)
# 4. 保存模型和分位数变换器
import joblib
joblib.dump(model, "model.pkl")
joblib.dump(qt, "quantile_transformer.pkl")
**(2) 推理阶段**
python
复制
# 1. 加载模型和分位数变换器
model = joblib.load("model.pkl")
qt = joblib.load("quantile_transformer.pkl")
# 2. 加载新数据(原始格式)
X_new = load_new_data()
# 3. 应用分位数变换(使用训练集的参数)
X_new_transformed = qt.transform(X_new)
# 4. 预测
predictions = model.predict(X_new_transformed)
4. 常见问题解答
**(1) 分位数变换是否适用于所有数据类型?**
- 适用:数值型数据(连续或离散)。
- 不适用:分类数据(需先编码为数值,如 One-Hot)。
**(2) 分位数变换后数据超出 [0,1] 范围怎么办?**
- 原因:目标分布设置为正态分布(
output_distribution='normal'
)时,数据范围是 (-∞, +∞)。 - 处理:无需额外操作,模型可直接使用。
**(3) 是否需要在交叉验证中集成分位数变换?**
- 需要:将分位数变换作为 Pipeline 的一部分,避免数据泄露。
python
复制
from sklearn.pipeline import Pipeline pipeline = Pipeline([ ('quantile', QuantileTransformer()), ('model', RandomForestClassifier()) ])
总结
- 数据集:需新增分位数变换预处理步骤,但无需改变数据结构。
- 模型:无需调整结构,但必须用变换后的数据重新训练。
- 核心原则:保持训练和推理阶段的数据处理流程一致,避免信息泄露。
好的,用户现在的问题是关于在使用KFold交叉验证的同时如何进行分位数变换。他们提供的代码已经实现了交叉验证,但需要将分位数变换整合进去,同时避免数据泄露。我需要仔细分析他们的代码结构,然后找到合适的地方插入分位数变换的步骤,并确保每个fold的训练集和验证集正确处理。
首先,用户之前已经了解了分位数变换的基本原理和操作流程,现在的问题是在交叉验证的每个fold中正确应用它。关键点在于,每个fold的训练集和验证集划分后,必须只在训练集上拟合分位数变换器,然后用这个变换器去转换验证集,而不是在整个数据集上拟合或者每个fold都重新拟合。这样才能避免数据泄露,确保验证集的独立性。
接下来,我需要考虑如何修改他们的现有代码。用户当前的代码在交叉验证的循环中划分了训练集和验证集,然后将它们转换为PyTorch的DataLoader。分位数变换应该在数据转换为Tensor之前进行,也就是在划分完训练集和验证集之后,但在创建TensorDataset之前。
另外,每个fold的训练集需要独立计算分位数参数,这样每个fold的模型都是在各自训练集的分布上训练的。同时,验证集必须使用对应训练集的参数进行变换,而不是整个数据集或者其他fold的参数。这样能保证每个fold的数据处理都是独立的,避免信息泄露。
还需要注意,分位数变换器在训练集上拟合后,要应用到验证集。在PyTorch中,数据通常是numpy数组的形式,所以可以先用sklearn的QuantileTransformer处理,然后再转换为Tensor。每个fold循环中,需要为训练集创建新的分位数变换器,拟合训练数据,然后转换训练集和验证集。
用户可能容易犯的错误是在交叉验证的外部就进行了分位数变换,这样会导致所有fold都使用整个数据集的统计量,从而引入数据泄露。因此,必须确保每个fold内部处理自己的训练集和验证集,确保验证集的数据不参与任何拟合过程。
最后,需要验证代码修改后的正确性,确保每个fold的训练集和验证集都正确应用了分位数变换,同时模型训练和验证的结果是合理的。可能需要检查变换后的数据分布是否在训练集和验证集上有预期的效果,比如验证集的数据是否基于训练集的参数进行变换,而不是自身的统计量。
总结来说,步骤应该是:在交叉验证的每个fold循环中,划分训练集和验证集后,对训练集进行分位数变换拟合,然后应用该变换到训练集和验证集,再继续后续的数据处理和模型训练。这样每个fold都是独立处理的,避免了数据泄露,保证了交叉验证的有效性。
在 KFold 交叉验证中应用分位数变换(Quantile Transform),需要确保每个 fold 的训练集和验证集独立处理,避免数据泄露。以下是针对你的代码的 分步修改方案:
1. 修改后的交叉验证流程
python
复制
from sklearn.preprocessing import QuantileTransformer
# 4. KFold 交叉验证
kf = KFold(n_splits=k, shuffle=True, random_state=42)
for fold, (train_index, val_index) in enumerate(kf.split(X, y)):
print(f"Fold {fold + 1}:")
# 划分训练集和验证集(原始数据)
X_train_raw, X_val_raw = X[train_index], X[val_index]
y_train, y_val = y[train_index], y[val_index]
# --- 新增分位数变换步骤 ---
# 初始化分位数变换器(每个 fold 独立)
qt = QuantileTransformer(n_quantiles=100, output_distribution='uniform', random_state=42)
# 仅在训练集上拟合分位数变换器
X_train_transformed = qt.fit_transform(X_train_raw)
# 使用训练集的参数变换验证集
X_val_transformed = qt.transform(X_val_raw)
# -------------------------
# 转换为 PyTorch 数据集(使用变换后的数据)
train_dataset = TensorDataset(
torch.from_numpy(X_train_transformed).float(),
torch.from_numpy(y_train).long()
)
val_dataset = TensorDataset(
torch.from_numpy(X_val_transformed).float(),
torch.from_numpy(y_val).long()
)
# 后续代码保持不变...
# 创建数据加载器、初始化模型、训练循环等
2. 关键修改说明
**(1) 分位数变换的独立性**
- 每个 fold 独立拟合:在交叉验证的每个 fold 中,仅用当前训练集 (
X_train_raw
) 拟合分位数变换器 (qt.fit_transform
),确保验证集的数据不参与任何参数计算。 - 验证集使用训练集的参数:验证集 (
X_val_transformed
) 必须使用当前训练集的变换参数 (qt.transform
),而非自身统计量。
**(2) 避免数据泄露**
- 原始数据划分:在未变换的原始数据 (
X_train_raw
,X_val_raw
) 上进行数据划分。 - 变换后数据隔离:变换后的训练集和验证集 (
X_train_transformed
,X_val_transformed
) 完全隔离,确保验证集不泄露信息。
**(3) 分位数变换参数设置**
- **
n_quantiles
**:根据数据量调整,数据量小则减少(如 100)。 - **
output_distribution
**:可选'uniform'
(均匀分布)或'normal'
(正态分布)。 - **
random_state
**:固定随机种子,确保可复现性。
3. 完整代码示例
python
复制
import torch
import numpy as np
from sklearn.model_selection import KFold
from sklearn.preprocessing import QuantileTransformer
from torch.utils.data import TensorDataset, DataLoader
# 假设 X 和 y 是 NumPy 数组,形状为 (样本数, 特征数) 和 (样本数,)
X = np.random.randn(1000, 21) # 示例数据
y = np.random.randint(0, 2, size=1000) # 示例标签
# 超参数
k = 5
batch_size = 64
num_epochs = 10
# KFold 交叉验证
kf = KFold(n_splits=k, shuffle=True, random_state=42)
for fold, (train_index, val_index) in enumerate(kf.split(X, y)):
print(f"Fold {fold + 1}:")
# 划分原始数据
X_train_raw, X_val_raw = X[train_index], X[val_index]
y_train, y_val = y[train_index], y[val_index]
# 分位数变换(仅在训练集拟合,验证集应用)
qt = QuantileTransformer(n_quantiles=100, output_distribution='uniform', random_state=42)
X_train_transformed = qt.fit_transform(X_train_raw)
X_val_transformed = qt.transform(X_val_raw)
# 转换为 PyTorch 数据集
train_dataset = TensorDataset(
torch.from_numpy(X_train_transformed).float(),
torch.from_numpy(y_train).long()
)
val_dataset = TensorDataset(
torch.from_numpy(X_val_transformed).float(),
torch.from_numpy(y_val).long()
)
# 创建 DataLoader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
# 初始化模型、定义损失函数和优化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LSTMModel(input_size=21, hidden_size=128, num_layers=2, num_classes=2).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 训练循环
for epoch in range(num_epochs):
model.train()
for i, (sequences, labels) in enumerate(train_loader):
sequences, labels = sequences.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(sequences)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 打印训练日志...
# 验证
model.eval()
correct, total = 0, 0
with torch.no_grad():
for sequences, labels in val_loader:
sequences, labels = sequences.to(device), labels.to(device)
outputs = model(sequences)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f"Validation Accuracy: {100 * correct / total:.2f}%")
print("-" * 20)
4. 注意事项
- 数据类型的转换:确保分位数变换后的数据为
float32
(与 PyTorch 默认张量类型一致)。 - 分位数变换器的保存:若需部署模型,需保存每个 fold 的分位数变换器(或选择最佳 fold 的变换器)。
python
复制
import joblib joblib.dump(qt, f"qt_fold_{fold}.pkl")
- 超范围数据处理:若验证集数据超出训练集范围,可截断或记录异常:
python
复制
X_val_transformed = np.clip(X_val_transformed, 0.0, 1.0)
5. 总结
通过在每个 fold 中独立进行分位数变换,确保交叉验证的评估结果反映模型在真实场景中的泛化性能。此方法可有效消除不同传感器或数据源的量纲差异,同时严格避免数据泄露。
4. 部署时的使用示例
假设在交叉验证中选择第 3 个 fold 的模型性能最佳,部署代码如下:
python
复制
# 加载最佳 fold 的模型和变换器
model = torch.load("model_fold_3.pth")
qt = joblib.load("qt_fold_3.pkl")
# 处理新数据
new_data_raw = load_new_data() # 原始数据
new_data_transformed = qt.transform(new_data_raw)
new_data_transformed = np.clip(new_data_transformed, 0.0, 1.0) # 截断超范围值
# 推理
with torch.no_grad():
inputs = torch.from_numpy(new_data_transformed).float().to(device)
outputs = model(inputs)
5. 注意事项
- 训练集是否需要截断?
不需要!训练集的分位数变换是基于自身数据的,理论上应分布在[0, 1]
内(若output_distribution='uniform'
)。 - 何时需要截断?
仅在验证集或实际数据中可能出现超出训练集范围的情况。 - 正态分布场景:
若output_distribution='normal'
,目标分布是标准正态分布(范围(-∞, +∞)
),此时无需截断。
2025.3.30修改模型:
import os
import re
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import QuantileTransformer
# 提取文件名特征的函数
def extract_features(file_name):
match = re.match(r"([A-Za-z]{2})([A-Za-z]{2})(\d{2})_(\d{2})\.txt", file_name)
if match:
data_source = match.group(1) # 数据来源 (Ga, Ju, Si)
gait_type = match.group(4) # 步态编号 (如 01, 10)
# 转换数据来源为完整名称
data_source_mapping = {
"Ga": 0, # 0代表Galit Yogev
"Ju": 1, # 1代表Hausdorff
"Si": 2 # 2代表Silvi Frenkel-Toledo
}
# 转换步态编号为描述
gait_type_mapping = {
"01": 0, # 0代表正常行走
"02": 1, # 1代表双任务行走
"03": 2, # 0代表正常行走
"04": 3, # 1代表双任务行走
"05": 4, # 0代表正常行走
"06": 5, # 1代表双任务行走
"07": 6, # 0代表正常行走
"08": 7, # 1代表双任务行走
"09": 8, # 0代表正常行走
"10": 9 # 1代表双任务行走
}
data_source_code = data_source_mapping.get(data_source, -1) # 默认为-1表示未知
gait_type_code = gait_type_mapping.get(gait_type, -1) # 默认为-1表示未知
return data_source_code, gait_type_code
else:
return None, None
# 使用滑动窗口切片时间序列数据的函数
def sliding_window_slice(data, window_size, step_size):
slices = []
for start in range(0, len(data) - window_size + 1, step_size):
end = start + window_size
slice_data = data[start:end]
slices.append(slice_data)
#print(f"{file_name} 的切片数量: {len(slices)}")#打印一下每个文件的切片数量,确保它符合预期,有可能卡慎用
return slices
def add_features_as_columns(slice_data, data_source, gait_type):
"""
将 data_source 和 gait_type 作为两列数据添加到 slice_data 后面。
Args:
slice_data (numpy.ndarray): 形状为 (600, 19) 的数据。
data_source (int): 数据来源特征。
gait_type (int): 步态类型特征。
Returns:
numpy.ndarray: 添加特征后的数据,形状为 (600, 21)。
"""
# 1. 将 data_source 和 gait_type 扩展为列向量
data_source_column = np.full((slice_data.shape[0], 1), data_source) # 形状为 (600, 1)
gait_type_column = np.full((slice_data.shape[0], 1), gait_type) # 形状为 (600, 1)
# 2. 将扩展后的列向量拼接到 slice_data 后面
new_data = np.concatenate((slice_data, data_source_column, gait_type_column), axis=1)
return new_data
# 处理文件夹中的所有txt文件,生成数据集
def process_files_in_folder(folder_path, window_size=600, step_size=100):
data = []
labels = []
# 获取文件夹中的所有txt文件
files = os.listdir(folder_path)
txt_files = [file for file in files if file.endswith('.txt')]
for file_name in txt_files:
# 提取文件名特征(数据来源、步态类型)
data_source, gait_type = extract_features(file_name)
#print(f"文件名: {file_name} , 数据来源: {data_source}, 步态类型: {gait_type}")#检查文件名特征提取是否正确
if data_source is not None and gait_type is not None:
# 加载传感器数据(假设数据是以空格或逗号分隔的数字)
file_path = os.path.join(folder_path, file_name)
sensor_data = np.loadtxt(file_path) # 根据实际格式调整加载方式
# 使用滑动窗口切片
slices = sliding_window_slice(sensor_data, window_size, step_size)
# 给每个切片添加标签(例如根据文件名中的组别信息)
label = 0 if 'Co' in file_name else 1 # Co -> 0, Pt -> 1
#print(f"文件名: {file_name}, 标签: {label}")#验证标签是否正确
# 将数据来源和步态类型特征附加到每个切片
for slice_data in slices:
# 将数据来源和步态类型特征附加到传感器数据
#feature_data = np.append(slice_data, [data_source, gait_type]) # 拼接特征到数据切片
feature_data=add_features_as_columns(slice_data, data_source, gait_type)
data.append(feature_data)
labels.append(label)
#print(f"文件名: {file_name}, 标签: {label}")
else:
print(f"文件名 {file_name} 格式不正确,跳过")
# 将所有数据堆叠成训练集
data = np.array(data)
labels = np.array(labels)
print(f"切片数据集形状:{data[100].shape}")
print(f"堆叠后数据集形状:{data.shape}")
print(data[1000][:, -2:])
print(f"{data_source}")
print(f"{gait_type}")
print(f"加入特征后数据形状: {feature_data.shape}")
print(f"传感器数据形状: {slice_data.shape}")
return data, labels
# 设置文件夹路径
folder_path = 'gait-in-parkinsons-disease-1.0.0/gait-in-parkinsons-disease-1.0.0' # 请替换为实际路径
# 处理文件并生成训练数据集
X, y = process_files_in_folder(folder_path)
print(f"训练数据集形状: {X.shape}")
print(f"标签形状: {y.shape}")
#print(f"文件名: {file_name}, 标签: {label}")
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from sklearn.model_selection import KFold
from sklearn.preprocessing import QuantileTransformer
# 1. 定义LSTM模型
class LSTMModel(nn.Module):
# ... (与您提供的代码相同)
def __init__(self, input_size, hidden_size, num_layers, num_classes):
super(LSTMModel, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
# 定义LSTM层
self.lstm = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True # 输入输出张量格式为(batch, seq_len, feature)
)
# 定义全连接层
self.fc = nn.Linear(hidden_size, num_classes)
def forward(self, x):
# 初始化隐藏状态和细胞状态
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
# 前向传播LSTM
out, _ = self.lstm(x, (h0, c0)) # 输出维度(batch_size, seq_length, hidden_size)
# 只取最后一个时间步的输出
out = out[:, -1, :]
# 全连接层
out = self.fc(out)
return out
# 2. 超参数设置
input_size = 21
hidden_size = 128
num_layers = 2
num_classes = 2
batch_size = 64
learning_rate = 0.001
num_epochs = 10
k = 5 # 折数
# 3. 数据准备 (假设您的数据是numpy数组格式)
# X.shape = (31549, 600, 21)
# y.shape = (31549,)
# 4. KFold 交叉验证
kf = KFold(n_splits=k, shuffle=True, random_state=42)
for fold, (train_index, val_index) in enumerate(kf.split(X, y)):
print(f"Fold {fold + 1}:")
# 划分训练集和验证集
X_train_raw, X_val_raw = X[train_index], X[val_index]
y_train, y_val = y[train_index], y[val_index]
# --- 分位数变换(保持时间序列结构)---
# 1. 将训练集重塑为二维 (n_samples * n_timesteps, n_features)
n_samples_train, n_timesteps, n_features = X_train_raw.shape
X_train_2d = X_train_raw.reshape(-1, n_features)
# --- 新增分位数变换步骤 ---
# 初始化分位数变换器(每个 fold 独立)
qt = QuantileTransformer(n_quantiles=100, output_distribution='uniform', random_state=42)
# 仅在训练集上拟合分位数变换器
X_train_transformed_2d = qt.fit_transform(X_train_2d)
#X_train_transformed = qt.fit_transform(X_train_raw)
# 3. 还原为三维
X_train_transformed = X_train_transformed_2d.reshape(n_samples_train, n_timesteps, n_features)
# 使用训练集的参数变换验证集
#X_val_transformed = qt.transform(X_val_raw)
# -------------------------
# --- 超范围处理(截断到 [0, 1])---
#X_val_transformed = np.clip(X_val_transformed, 0.0, 1.0)
# -------------------------------
# 4. 处理验证集(同样先展平再变换)
n_samples_val = X_val_raw.shape[0]
X_val_2d = X_val_raw.reshape(-1, n_features)
X_val_transformed_2d = qt.transform(X_val_2d)
X_val_transformed = X_val_transformed_2d.reshape(n_samples_val, n_timesteps, n_features)
# -----------------------------------
# 转换为PyTorch数据集
train_dataset = TensorDataset(
torch.from_numpy(X_train_transformed).float(),
torch.from_numpy(y_train).long()
)
val_dataset = TensorDataset(
torch.from_numpy(X_val_transformed).float(),
torch.from_numpy(y_val).long()
)
# 创建数据加载器
train_loader = DataLoader(
dataset=train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=2
)
val_loader = DataLoader(
dataset=val_dataset,
batch_size=batch_size,
shuffle=False, # 验证集不需要打乱
num_workers=2
)
# 5. 初始化模型
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LSTMModel(input_size, hidden_size, num_layers, num_classes).to(device)
# 6. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# 7. 训练循环
for epoch in range(num_epochs):
model.train()
for i, (sequences, labels) in enumerate(train_loader):
sequences = sequences.to(device)
labels = labels.to(device)
#optimizer.zero_grad()
# 前向传播
outputs = model(sequences)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i+1) % 100 == 0:
print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
# 8. 验证
model.eval()
correct = 0
total = 0
with torch.no_grad():
for sequences, labels in val_loader:
sequences = sequences.to(device)
labels = labels.to(device)
outputs = model(sequences)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f"Epoch [{epoch+1}/{num_epochs}], Validation Accuracy: {100 * correct / total:.2f}%")
# 保存当前 fold 的分位数变换器
import joblib
joblib.dump(qt, f"qt_fold_{fold}.pkl") # 文件名示例:qt_fold_0.pkl
print("-" * 20)
Fold 1: Epoch [1/10], Step [100/395], Loss: 0.5710 Epoch [1/10], Step [200/395], Loss: 0.5688 Epoch [1/10], Step [300/395], Loss: 0.4431 Epoch [1/10], Validation Accuracy: 69.08% Epoch [2/10], Step [100/395], Loss: 0.5198 Epoch [2/10], Step [200/395], Loss: 0.5560 Epoch [2/10], Step [300/395], Loss: 0.4902 Epoch [2/10], Validation Accuracy: 75.31% Epoch [3/10], Step [100/395], Loss: 0.5717 Epoch [3/10], Step [200/395], Loss: 0.4435 Epoch [3/10], Step [300/395], Loss: 0.5549 Epoch [3/10], Validation Accuracy: 72.19% Epoch [4/10], Step [100/395], Loss: 0.4831 Epoch [4/10], Step [200/395], Loss: 0.5100 Epoch [4/10], Step [300/395], Loss: 0.4934 Epoch [4/10], Validation Accuracy: 71.36% Epoch [5/10], Step [100/395], Loss: 0.3872 Epoch [5/10], Step [200/395], Loss: 0.5233 Epoch [5/10], Step [300/395], Loss: 0.4830 Epoch [5/10], Validation Accuracy: 67.05% Epoch [6/10], Step [100/395], Loss: 0.5394 Epoch [6/10], Step [200/395], Loss: 0.5739 Epoch [6/10], Step [300/395], Loss: 0.4218 Epoch [6/10], Validation Accuracy: 75.80% Epoch [7/10], Step [100/395], Loss: 0.4925 Epoch [7/10], Step [200/395], Loss: 0.3945 Epoch [7/10], Step [300/395], Loss: 0.3966 Epoch [7/10], Validation Accuracy: 76.04% Epoch [8/10], Step [100/395], Loss: 0.5676 Epoch [8/10], Step [200/395], Loss: 0.4444 Epoch [8/10], Step [300/395], Loss: 0.4245 Epoch [8/10], Validation Accuracy: 82.66% Epoch [9/10], Step [100/395], Loss: 0.3606 Epoch [9/10], Step [200/395], Loss: 0.3105 Epoch [9/10], Step [300/395], Loss: 0.2725 Epoch [9/10], Validation Accuracy: 88.61% Epoch [10/10], Step [100/395], Loss: 0.1843 Epoch [10/10], Step [200/395], Loss: 0.1532 Epoch [10/10], Step [300/395], Loss: 0.3624 Epoch [10/10], Validation Accuracy: 89.33% -------------------- Fold 2: Epoch [1/10], Step [100/395], Loss: 0.5394 Epoch [1/10], Step [200/395], Loss: 0.6227 Epoch [1/10], Step [300/395], Loss: 0.5261 Epoch [1/10], Validation Accuracy: 74.90% Epoch [2/10], Step [100/395], Loss: 0.4530 Epoch [2/10], Step [200/395], Loss: 0.4903 Epoch [2/10], Step [300/395], Loss: 0.4873 Epoch [2/10], Validation Accuracy: 77.29% Epoch [3/10], Step [100/395], Loss: 0.5052 Epoch [3/10], Step [200/395], Loss: 0.4676 Epoch [3/10], Step [300/395], Loss: 0.4754 Epoch [3/10], Validation Accuracy: 80.10% Epoch [4/10], Step [100/395], Loss: 0.4640 Epoch [4/10], Step [200/395], Loss: 0.5620 Epoch [4/10], Step [300/395], Loss: 0.5141 Epoch [4/10], Validation Accuracy: 78.70% Epoch [5/10], Step [100/395], Loss: 0.5870 Epoch [5/10], Step [200/395], Loss: 0.3820 Epoch [5/10], Step [300/395], Loss: 0.4283 Epoch [5/10], Validation Accuracy: 79.22% Epoch [6/10], Step [100/395], Loss: 0.4936 Epoch [6/10], Step [200/395], Loss: 0.5366 Epoch [6/10], Step [300/395], Loss: 0.4548 Epoch [6/10], Validation Accuracy: 71.28% Epoch [7/10], Step [100/395], Loss: 0.4862 Epoch [7/10], Step [200/395], Loss: 0.4579 Epoch [7/10], Step [300/395], Loss: 0.4709 Epoch [7/10], Validation Accuracy: 79.45% Epoch [8/10], Step [100/395], Loss: 0.5021 Epoch [8/10], Step [200/395], Loss: 0.3528 Epoch [8/10], Step [300/395], Loss: 0.4208 Epoch [8/10], Validation Accuracy: 86.23% Epoch [9/10], Step [100/395], Loss: 0.4185 Epoch [9/10], Step [200/395], Loss: 0.3183 Epoch [9/10], Step [300/395], Loss: 0.1801 Epoch [9/10], Validation Accuracy: 90.32% Epoch [10/10], Step [100/395], Loss: 0.2121 Epoch [10/10], Step [200/395], Loss: 0.3103 Epoch [10/10], Step [300/395], Loss: 0.2780 Epoch [10/10], Validation Accuracy: 89.16% -------------------- Fold 3: Epoch [1/10], Step [100/395], Loss: 0.5319 Epoch [1/10], Step [200/395], Loss: 0.5289 Epoch [1/10], Step [300/395], Loss: 0.5112 Epoch [1/10], Validation Accuracy: 77.02% Epoch [2/10], Step [100/395], Loss: 0.5750 Epoch [2/10], Step [200/395], Loss: 0.4664 Epoch [2/10], Step [300/395], Loss: 0.4821 Epoch [2/10], Validation Accuracy: 69.95% Epoch [3/10], Step [100/395], Loss: 0.4684 Epoch [3/10], Step [200/395], Loss: 0.5249 Epoch [3/10], Step [300/395], Loss: 0.4191 Epoch [3/10], Validation Accuracy: 74.91% Epoch [4/10], Step [100/395], Loss: 0.4511 Epoch [4/10], Step [200/395], Loss: 0.4825 Epoch [4/10], Step [300/395], Loss: 0.5972 Epoch [4/10], Validation Accuracy: 71.97% Epoch [5/10], Step [100/395], Loss: 0.4689 Epoch [5/10], Step [200/395], Loss: 0.5675 Epoch [5/10], Step [300/395], Loss: 0.5797 Epoch [5/10], Validation Accuracy: 72.04% Epoch [6/10], Step [100/395], Loss: 0.4590 Epoch [6/10], Step [200/395], Loss: 0.4361 Epoch [6/10], Step [300/395], Loss: 0.4596 Epoch [6/10], Validation Accuracy: 79.02% Epoch [7/10], Step [100/395], Loss: 0.3407 Epoch [7/10], Step [200/395], Loss: 0.3697 Epoch [7/10], Step [300/395], Loss: 0.4824 Epoch [7/10], Validation Accuracy: 80.10% Epoch [8/10], Step [100/395], Loss: 0.3347 Epoch [8/10], Step [200/395], Loss: 0.3416 Epoch [8/10], Step [300/395], Loss: 0.3426 Epoch [8/10], Validation Accuracy: 86.97% Epoch [9/10], Step [100/395], Loss: 0.3905 Epoch [9/10], Step [200/395], Loss: 0.4525 Epoch [9/10], Step [300/395], Loss: 0.1979 Epoch [9/10], Validation Accuracy: 90.43% Epoch [10/10], Step [100/395], Loss: 0.2089 Epoch [10/10], Step [200/395], Loss: 0.2229 Epoch [10/10], Step [300/395], Loss: 0.1659 Epoch [10/10], Validation Accuracy: 93.47% -------------------- Fold 4: Epoch [1/10], Step [100/395], Loss: 0.5331 Epoch [1/10], Step [200/395], Loss: 0.5718 Epoch [1/10], Step [300/395], Loss: 0.5269 Epoch [1/10], Validation Accuracy: 73.66% Epoch [2/10], Step [100/395], Loss: 0.6224 Epoch [2/10], Step [200/395], Loss: 0.4136 Epoch [2/10], Step [300/395], Loss: 0.4466 Epoch [2/10], Validation Accuracy: 81.87% Epoch [3/10], Step [100/395], Loss: 0.5485 Epoch [3/10], Step [200/395], Loss: 0.3710 Epoch [3/10], Step [300/395], Loss: 0.6995 Epoch [3/10], Validation Accuracy: 72.73% Epoch [4/10], Step [100/395], Loss: 0.4283 Epoch [4/10], Step [200/395], Loss: 0.5224 Epoch [4/10], Step [300/395], Loss: 0.4817 Epoch [4/10], Validation Accuracy: 83.90% Epoch [5/10], Step [100/395], Loss: 0.4079 Epoch [5/10], Step [200/395], Loss: 0.4213 Epoch [5/10], Step [300/395], Loss: 0.4585 Epoch [5/10], Validation Accuracy: 57.37% Epoch [6/10], Step [100/395], Loss: 0.3714 Epoch [6/10], Step [200/395], Loss: 0.4363 Epoch [6/10], Step [300/395], Loss: 0.3943 Epoch [6/10], Validation Accuracy: 68.73% Epoch [7/10], Step [100/395], Loss: 0.4605 Epoch [7/10], Step [200/395], Loss: 0.5365 Epoch [7/10], Step [300/395], Loss: 0.4774 Epoch [7/10], Validation Accuracy: 79.00% Epoch [8/10], Step [100/395], Loss: 0.6591 Epoch [8/10], Step [200/395], Loss: 0.4983 Epoch [8/10], Step [300/395], Loss: 0.2841 Epoch [8/10], Validation Accuracy: 79.46% Epoch [9/10], Step [100/395], Loss: 0.4531 Epoch [9/10], Step [200/395], Loss: 0.3549 Epoch [9/10], Step [300/395], Loss: 0.2137 Epoch [9/10], Validation Accuracy: 88.80% Epoch [10/10], Step [100/395], Loss: 0.2804 Epoch [10/10], Step [200/395], Loss: 0.1934 Epoch [10/10], Step [300/395], Loss: 0.2239 Epoch [10/10], Validation Accuracy: 91.05% -------------------- Fold 5: Epoch [1/10], Step [100/395], Loss: 0.9139 Epoch [1/10], Step [200/395], Loss: 0.4847 Epoch [1/10], Step [300/395], Loss: 0.5156 Epoch [1/10], Validation Accuracy: 70.44% Epoch [2/10], Step [100/395], Loss: 0.5444 Epoch [2/10], Step [200/395], Loss: 0.4427 Epoch [2/10], Step [300/395], Loss: 0.6123 Epoch [2/10], Validation Accuracy: 71.12% Epoch [3/10], Step [100/395], Loss: 0.5134 Epoch [3/10], Step [200/395], Loss: 0.5476 Epoch [3/10], Step [300/395], Loss: 0.6857 Epoch [3/10], Validation Accuracy: 54.54% Epoch [4/10], Step [100/395], Loss: 0.6450 Epoch [4/10], Step [200/395], Loss: 0.5678 Epoch [4/10], Step [300/395], Loss: 0.5123 Epoch [4/10], Validation Accuracy: 71.63% Epoch [5/10], Step [100/395], Loss: 0.5825 Epoch [5/10], Step [200/395], Loss: 0.5873 Epoch [5/10], Step [300/395], Loss: 0.5484 Epoch [5/10], Validation Accuracy: 70.45% Epoch [6/10], Step [100/395], Loss: 0.4344 Epoch [6/10], Step [200/395], Loss: 0.4760 Epoch [6/10], Step [300/395], Loss: 0.4506 Epoch [6/10], Validation Accuracy: 73.48% Epoch [7/10], Step [100/395], Loss: 0.5211 Epoch [7/10], Step [200/395], Loss: 0.4966 Epoch [7/10], Step [300/395], Loss: 0.5997 Epoch [7/10], Validation Accuracy: 79.68% Epoch [8/10], Step [100/395], Loss: 0.4426 Epoch [8/10], Step [200/395], Loss: 0.4270 Epoch [8/10], Step [300/395], Loss: 0.3234 Epoch [8/10], Validation Accuracy: 81.07% Epoch [9/10], Step [100/395], Loss: 0.4746 Epoch [9/10], Step [200/395], Loss: 0.3826 Epoch [9/10], Step [300/395], Loss: 0.4755 Epoch [9/10], Validation Accuracy: 85.43% Epoch [10/10], Step [100/395], Loss: 0.3076 Epoch [10/10], Step [200/395], Loss: 0.4571 Epoch [10/10], Step [300/395], Loss: 0.3124 Epoch [10/10], Validation Accuracy: 89.94% --------------------