昇思MindSpore学习笔记5-02生成式--RNN实现情感分类

摘要:

        记录MindSpore AI框架使用RNN网络对自然语言进行情感分类的过程、步骤和方法。

        包括环境准备、下载数据集、数据集加载和预处理、构建模型、模型训练、模型测试等。

一、

情感分类。

RNN网络模型

实现效果:

        输入: This film is terrible

        正确标签: Negative

        预测标签: Negative

        输入: This film is great

        正确标签: Positive

        预测标签: Positive

二、环境准备

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 查看当前 mindspore 版本
!pip show mindspore

输出:

Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 

手动安装tqdm和requests库

!pip install tqdm requests

三、加载数据

IMDB影评数据集

        Positive

        Negative

预训练词向量

        编码自然语言单词

        获取文本语义特征

选取Glove词向量作为Embedding。

1.下载数据集和预训练词向量

requests库

        http请求

tqdm库

        下载百分比进度

IO方式下载临时文件

保存至指定路径并返回

import os
import shutil
import requests
import tempfile
from tqdm import tqdm
from typing import IO
from pathlib import Path
​
# 指定保存路径为 `home_path/.mindspore_examples`
cache_dir = Path.home() / '.mindspore_examples'
​
def http_get(url: str, temp_file: IO):
    """使用requests库下载数据,并使用tqdm库进行流程可视化"""
    req = requests.get(url, stream=True)
    content_length = req.headers.get('Content-Length')
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit='B', total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()
​
def download(file_name: str, url: str):
    """下载数据并存为指定名称"""
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)
    cache_path = os.path.join(cache_dir, file_name)
    cache_exist = os.path.exists(cache_path)
    if not cache_exist:
        with tempfile.NamedTemporaryFile() as temp_file:
            http_get(url, temp_file)
            temp_file.flush()
            temp_file.seek(0)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)
    return cache_path

下载IMDB数据集(使用华为云镜像):

imdb_path = download('aclImdb_v1.tar.gz', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/aclImdb_v1.tar.gz')
imdb_path

输出:

100%|██████████| 84125825/84125825 [00:02<00:00, 38051349.02B/s]
'/home/nginx/.mindspore_examples/aclImdb_v1.tar.gz'

2.IMDB加载模块

Python的tarfile库

        读取IMDB数据集tar.gz文件

        所有数据和标签分别进行存放

IMDB数据集解压目录如下:

    ├── aclImdb
    │   ├── imdbEr.txt
    │   ├── imdb.vocab
    │   ├── README
    │   ├── test
    │   └── train
    │         ├── neg
    │         ├── pos
    ...

数据集分两部分加载

        train

                neg

                pos

        Test

                neg

                pos

import re
import six
import string
import tarfile
​
class IMDBData():
    """IMDB数据集加载器
​
    加载IMDB数据集并处理为一个Python迭代对象。
​
    """
    label_map = {
        "pos": 1,
        "neg": 0
    }
    def __init__(self, path, mode="train"):
        self.mode = mode
        self.path = path
        self.docs, self.labels = [], []
​
        self._load("pos")
        self._load("neg")
​
    def _load(self, label):
        pattern = re.compile(r"aclImdb/{}/{}/.*\.txt$".format(self.mode, label))
        # 将数据加载至内存
        with tarfile.open(self.path) as tarf:
            tf = tarf.next()
            while tf is not None:
                if bool(pattern.match(tf.name)):
                    # 对文本进行分词、去除标点和特殊字符、小写处理
                    self.docs.append(str(tarf.extractfile(tf).read().rstrip(six.b("\n\r"))
                                         .translate(None, six.b(string.punctuation)).lower()).split())
                    self.labels.append([self.label_map[label]])
                tf = tarf.next()
​
    def __getitem__(self, idx):
        return self.docs[idx], self.labels[idx]
​
    def __len__(self):
        return len(self.docs)

3.加载训练数据集

imdb_train = IMDBData(imdb_path, 'train')
len(imdb_train)

输出:

25000

mindspore.dataset.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

muren

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值