pytorch入门实战(李宏毅hw2_仿_读代码)

这篇文章主要是看看别人的代码怎么写,自己目前对于分类问题还是不太会(菜哭了)

数据集

kaggle

数据地址:
https://www.kaggle.com/competitions/ml2021spring-hw2/data
代码地址:
https://www.kaggle.com/code/lizupeng/notebook4bf3cf8e90

代码

# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

“”“kaggle经典开头代码,没什么好说的”“”

———————————————分割线——————————————

print('Loading data ...')

data_root = '/kaggle/input/ml2021spring-hw2/timit_11/timit_11/'
# 加载训练数据集
train = np.load(data_root + 'train_11.npy')
train_label = np.load(data_root + 'train_label_11.npy')
# 测试集
test = np.load(data_root + 'test_11.npy')
"""
这里采用的np.load,不是pd
我们的文件形式.npy
ChatGPT:
np.load() 是 NumPy 库中用于加载 NumPy 保存的数组数据的函数。它允许你从磁盘上的 .npy 文件中加载数据并将其转换为 NumPy 数组,以便在 Python 中使用。
numpy.load(file, mmap_mode=None, allow_pickle=True, fix_imports=True, encoding='ASCII')
参数说明:

file: 要加载的 .npy 文件的路径。
mmap_mode:指定文件的内存映射模式。默认为 None,表示不使用内存映射。其他选项包括 'r'(只读模式),'r+'(读写模式)等。
allow_pickle:布尔值,指示是否允许加载包含 Python 对象的 .npy 文件。默认为 True。
fix_imports:布尔值,指示是否自动修复 Python 2 和 Python 3 之间的 pickle 文件中的 import 语句。默认为 True。
encoding:指定文件的编码。默认为 'ASCII'。
"""
print('Size of training data: {}'.format(train.shape))
print('Size of training data: {}'.format(test.shape))

Loading data ...
Size of training data: (1229932, 429)
Size of training data: (451552, 429)

# 定义数据集,经典的三个函数
import torch
from torch.utils.data import Dataset

# 定义数据集
class TIMITDataset(Dataset):
    # 测试节和训练集都需要用这个类
    def __init__(self, X, y=None):
        self.data = torch.from_numpy(X).float()
        # y=None 对应的是测试集
        if y is not None:
            y = y.astype(np.int32)
            # 多分类任务,label是一个long
            self.label = torch.LongTensor(y)
        else:
            self.
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值