一、准备数据集
- Dataset
其使用需要自定义一个类继承Dataset—》class ImdbDataset(Dataset):
其主要作用是从文件中获取数据,主要有三个方法:
__init__该方法是初始化方法,用来获取所有文件的路径
__getitems__该方法是通过index索引获取到每个文件的内容(包括content,label),此处利用tokenlize对获取到的文本了预处理(处理掉不必要的字符)
__len__该方法是获取整个数据集的长度
注意: 数据集分为训练集和测试集,所以此处在处理数据集的时候,需要标注如 def init(self, train=True): 默认是训练集
class ImdbDataset(Dataset):
def __init__(self, train=True):
self.train_data_path = r"aclImdb\train"
self.test_data_path = r"aclImdb\test"
data_path = self.train_data_path if train else self.test_data_path
temp_data_path = [os.path.join(data_path, "pos"), os.path.join(data_path, "neg")]
self.total_file_path = []
for path in temp_data_path:
file_name_list = os.listdir(path)
file_path_list = [os.path.join(path, i) for i in file_name_list if i.endswith(".txt")]
self.total_file_path.extend(file_path_list)
def __getitem__(self, index):
file_path = self.total_file_path[index]
label_str = file_path.split("\\")[-2]
label = 0 if label_str == "neg" else 1
content = open(file_path,encoding="utf-8")