海洋中尺度涡旋检测模型的训练与评估
1. 数据读取与类别频率分析
首先,从指定路径读取样本数据:
- 从 /home/username/ML_eddies/cds_ssh_1998 - 2018_10day_interval/subset_pet_masks_with_adt_1998 - 2018_lat14N - 46N_lon166W - 134W.npz 读取 987 个样本。
- 从 /home/username/ML_eddies/cds_ssh_2019_10day_interval/subset_pet_masks_with_adt_2019_lat14N - 46N_lon166W - 134W.npz 读取 47 个样本。
接着,分析训练数据集中类别的分布情况,以识别类别不平衡问题。以下是具体代码:
import numpy as np
train_masks = train_loader.dataset.masks.copy()
class_frequency = np.bincount(train_masks.flatten())
total_pixels = sum(class_frequency)
print(
f"Total number of pixels in training set: {total_pixels/1e6:.2f} megapixels"
f" across {len(train_masks)} SSH maps\n"
f"Number of pixels that
超级会员免费看
订阅专栏 解锁全文
445

被折叠的 条评论
为什么被折叠?



