数据集可视化
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
# Use a CJK-capable sans-serif font so the Chinese title renders.
# 'Microsoft YaHei' is usually only installed on Windows; the later entries are
# fallbacks for other platforms.
plt.rcParams['font.sans-serif'] = [
    'Microsoft YaHei', 'SimHei', 'PingFang SC', 'Noto Sans CJK SC', 'DejaVu Sans',
]

# Load the Iris dataset.
iris = datasets.load_iris()
X = iris.data[:, :2]  # keep only the first two features: sepal length and sepal width
y = iris.target
# Names of the two selected feature columns (used as axis labels).
feature_names = iris.feature_names[:2]
# Species names, indexed by label value.
target_names = iris.target_names
# One colour per class, in label order.
colors = ['red', 'green', 'blue']

# Scatter plot: one series per class so each species gets its own legend entry.
plt.figure(figsize=(10, 6))
for target, (name, color) in enumerate(zip(target_names, colors)):
    mask = (y == target)  # boolean mask selecting the samples of this class
    plt.scatter(X[mask, 0], X[mask, 1], label=name, c=color, edgecolor='k', s=50)

# Axis labels, title and legend.
plt.xlabel(feature_names[0])
plt.ylabel(feature_names[1])
plt.title('鸢尾花数据集散点图')
plt.legend()
# Display the figure: required when run as a script, harmless in a notebook.
plt.show()
代码：加载鸢尾花数据集并构建 PyTorch 数据集与数据加载器
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
# Use a CJK-capable sans-serif font so Chinese titles/labels render.
# 'Microsoft YaHei' is usually only installed on Windows; the later entries are
# fallbacks for other platforms.
plt.rcParams['font.sans-serif'] = [
    'Microsoft YaHei', 'SimHei', 'PingFang SC', 'Noto Sans CJK SC', 'DejaVu Sans',
]
# Load the Iris dataset and optionally shuffle it.
def load_data(shuffle=True, seed=None):
    """Load the Iris dataset, optionally shuffling samples and labels together.

    Args:
        shuffle: If True, apply one random permutation to the rows of both
            ``X`` and ``y`` (the same permutation, so sample/label pairs stay
            aligned).
        seed: Optional integer seed for the permutation. With the default
            ``None`` the global NumPy RNG (``np.random``) is used, so the order
            differs between calls unless ``np.random.seed`` was set beforehand.
            Pass an int to get the same order on every call.

    Returns:
        ``(X, y)``: ``iris.data`` and ``iris.target`` from scikit-learn's
        ``load_iris``, reordered when ``shuffle`` is True.
    """
    iris = load_iris()
    X, y = iris.data, iris.target
    if shuffle:
        # A private, seeded RNG makes the order reproducible without touching
        # the global NumPy random state; seed=None keeps the global-RNG behaviour.
        rng = np.random if seed is None else np.random.RandomState(seed)
        indices = rng.permutation(X.shape[0])
        X, y = X[indices], y[indices]
    return X, y
# Iris dataset wrapped as a PyTorch Dataset, one instance per split.
class IrisDataset(Dataset):
    """PyTorch ``Dataset`` over one split (train/dev/test) of the Iris data.

    Each instance loads the full dataset and reorders it with a permutation
    drawn from ``np.random.RandomState(seed)`` before slicing out its split.
    Because the permutation depends only on ``seed``, instances built with the
    same seed slice the same ordering, so their splits are disjoint and
    reproducible. Shuffling each instance with the unseeded global RNG would
    give every split a different ordering, letting e.g. dev samples leak into
    the training split.

    Args:
        mode: 'train' -> the first ``num_train`` samples; 'dev' -> the next
            ``num_dev`` samples; any other value -> all remaining samples
            (the test split).
        num_train: number of samples in the training split.
        num_dev: number of samples in the dev (validation) split.
        seed: seed of the shared shuffle. Instances meant to form one
            non-overlapping train/dev/test partition must use the same seed.
    """

    def __init__(self, mode='train', num_train=120, num_dev=15, seed=12):
        super().__init__()
        X, y = load_data(shuffle=False)
        # Same seed -> same permutation in every instance -> non-overlapping splits.
        order = np.random.RandomState(seed).permutation(len(y))
        X, y = X[order], y[order]

        train_end = num_train
        dev_end = num_train + num_dev
        if mode == 'train':
            self.X, self.y = X[:train_end], y[:train_end]
        elif mode == 'dev':
            self.X, self.y = X[train_end:dev_end], y[train_end:dev_end]
        else:
            self.X, self.y = X[dev_end:], y[dev_end:]

    def __getitem__(self, idx):
        """Return ``(features, label)``: float32 NumPy features, int64 tensor label."""
        return self.X[idx].astype(np.float32), torch.tensor(self.y[idx], dtype=torch.long)

    def __len__(self):
        # Number of samples in this split.
        return len(self.y)
# Seed every RNG used here: NumPy's global RNG drives the shuffle in
# load_data(shuffle=True), and torch's RNG drives the DataLoader's shuffling.
SEED = 12
np.random.seed(SEED)
torch.manual_seed(SEED)
# Build the training split.
train_dataset = IrisDataset(mode='train')
print("length of train set: ", len(train_dataset))  # report the training-set size
# Mini-batch size.
batch_size = 16
# Wrap the training split in a DataLoader that reshuffles on every pass.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dev_dataset = IrisD