深度学习实验9-前馈神经网络(5)鸢尾花分类(代码)

数据集可视化

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']
# 加载鸢尾花数据集
iris = datasets.load_iris()
X = iris.data[:, :2]  # 只取前两个特征:萼片长度和萼片宽度
y = iris.target

# 特征名称
feature_names = iris.feature_names[:2]

# 品种名称
target_names = iris.target_names

# 创建一个颜色映射
colors = ['red', 'green', 'blue']

# 绘制散点图
plt.figure(figsize=(10, 6))
for target, color in zip(range(3), colors):
    # 获取属于当前类别的样本
    mask = (y == target)
    plt.scatter(X[mask, 0], X[mask, 1], label=target_names[target], c=color, edgecolor='k', s=50)

# 添加标签和标题
plt.xlabel(feature_names[0])
plt.ylabel(feature_names[1])
plt.title('鸢尾花数据集散点图')
plt.legend()

代码

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']  # 设置中文字体

# 加载并打乱鸢尾花数据集
def load_data(shuffle=True):
    iris = load_iris()
    X = iris.data
    y = iris.target
    if shuffle:
        indices = np.arange(X.shape[0])
        np.random.shuffle(indices)
        X = X[indices]
        y = y[indices]
    return X, y

# 定义鸢尾花数据集类
class IrisDataset(Dataset):
    def __init__(self, mode='train', num_train=120, num_dev=15):
        super(IrisDataset, self).__init__()
        X, y = load_data(shuffle=True)
        if mode == 'train':
            self.X, self.y = X[:num_train], y[:num_train]
        elif mode == 'dev':
            self.X, self.y = X[num_train:num_train + num_dev], y[num_train:num_train + num_dev]
        else:
            self.X, self.y = X[num_train + num_dev:], y[num_train + num_dev:]

    def __getitem__(self, idx):
        return self.X[idx].astype(np.float32), torch.tensor(self.y[idx], dtype=torch.long)  # 转换为浮点型和长整型

    def __len__(self):
        return len(self.y)

# 设置随机种子
torch.manual_seed(12)

# 创建数据集
train_dataset = IrisDataset(mode='train')
print("length of train set: ", len(train_dataset))  # 打印训练集长度

# 批量大小
batch_size = 16

# 加载数据
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dev_dataset = IrisD
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值