This post is mainly about reading how other people write their code; I'm still not very good at classification problems myself.
Dataset
Kaggle
Data:
https://www.kaggle.com/competitions/ml2021spring-hw2/data
Code:
https://www.kaggle.com/code/lizupeng/notebook4bf3cf8e90
Code
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory
import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session
"""Classic Kaggle opening boilerplate; nothing special here."""
--------------------------------------------------------------
print('Loading data ...')
data_root = '/kaggle/input/ml2021spring-hw2/timit_11/timit_11/'
# Load the training set
train = np.load(data_root + 'train_11.npy')
train_label = np.load(data_root + 'train_label_11.npy')
# Test set
test = np.load(data_root + 'test_11.npy')
"""
这里采用的np.load,不是pd
我们的文件形式.npy
ChatGPT:
np.load() 是 NumPy 库中用于加载 NumPy 保存的数组数据的函数。它允许你从磁盘上的 .npy 文件中加载数据并将其转换为 NumPy 数组,以便在 Python 中使用。
numpy.load(file, mmap_mode=None, allow_pickle=True, fix_imports=True, encoding='ASCII')
参数说明:
file: 要加载的 .npy 文件的路径。
mmap_mode:指定文件的内存映射模式。默认为 None,表示不使用内存映射。其他选项包括 'r'(只读模式),'r+'(读写模式)等。
allow_pickle:布尔值,指示是否允许加载包含 Python 对象的 .npy 文件。默认为 True。
fix_imports:布尔值,指示是否自动修复 Python 2 和 Python 3 之间的 pickle 文件中的 import 语句。默认为 True。
encoding:指定文件的编码。默认为 'ASCII'。
"""
print('Size of training data: {}'.format(train.shape))
print('Size of testing data: {}'.format(test.shape))
Loading data ...
Size of training data: (1229932, 429)
Size of testing data: (451552, 429)
# Define the dataset: the classic three methods (__init__, __getitem__, __len__)
import torch
from torch.utils.data import Dataset
# Define the dataset
class TIMITDataset(Dataset):
    # Both the training set and the test set use this class
    def __init__(self, X, y=None):
        self.data = torch.from_numpy(X).float()
        # y=None corresponds to the test set
        if y is not None:
            y = y.astype(np.int32)
            # Multi-class task: labels must be long integers
            self.label = torch.LongTensor(y)
        else:
            self.label = None
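To show how the "classic three methods" fit together end to end, here is a self-contained sketch that wires the dataset into a DataLoader. The `__getitem__`/`__len__` bodies follow the standard pattern for this class; the dummy data, BATCH_SIZE, and the 39-class label range are assumptions for illustration, not values taken from the notebook:

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class TIMITDataset(Dataset):
    def __init__(self, X, y=None):
        self.data = torch.from_numpy(X).float()
        # y=None corresponds to the test set
        self.label = torch.LongTensor(y.astype(np.int64)) if y is not None else None

    def __getitem__(self, idx):
        # Training/validation: return (features, label); test: features only
        if self.label is not None:
            return self.data[idx], self.label[idx]
        return self.data[idx]

    def __len__(self):
        return len(self.data)

# Dummy stand-ins for train / train_label (429 features per frame, as above)
X = np.random.rand(32, 429).astype(np.float32)
y = np.random.randint(0, 39, size=32)  # TIMIT HW2 has 39 phoneme classes

BATCH_SIZE = 8  # assumed value for illustration
train_loader = DataLoader(TIMITDataset(X, y), batch_size=BATCH_SIZE, shuffle=True)

xb, yb = next(iter(train_loader))
print(xb.shape, yb.shape)  # torch.Size([8, 429]) torch.Size([8])
```

The DataLoader handles batching and shuffling, so the Dataset class only needs to answer "how many samples?" and "give me sample i".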