【NLP251】命名实体识别常用模块(基于Transformer分类)

1. 从JSON格式的数据中加载并预处理样本供Ner任务训练和推理使用

class JsonNerDataset(Dataset):
    """
    定义一个加载json格式原始命名实体识别格式数据的Dataset
    一行一条样本(json字符串),包含: originalText、entities
    """

    def __init__(self,
                 data_path, tokenizer, categories, target_padding_idx=-100,
                 add_special_token=False, first_special_token='[CLS]', last_special_token='[SEP]',
                 sentence_max_len=126
                 ):
        """
        Dataset over JSON-lines NER data: one sample per line, each a JSON object
        with keys ``originalText`` and ``entities``.

        :param data_path: path to the JSON-lines data file.
        :param tokenizer: BERT-style tokenizer exposing ``convert_tokens_to_ids``.
        :param categories: mapping from label string (O / S-x / B-x / M-x / E-x)
            to integer class id.
        :param target_padding_idx: label id used on padding positions; defaults to
            -100, which common loss functions ignore.
        :param add_special_token: when True, long sentences are split into
            sub-sentences and wrapped with the first/last special tokens.
        :param first_special_token: token prepended to each sub-sequence.
        :param last_special_token: token appended to each sub-sequence.
        :param sentence_max_len: maximum characters per split sub-sentence.
            Previously hard-coded to 126; now a parameter (default unchanged).
            For a standard BERT, ``tokenizer.max_len_single_sentence`` would be 510.
        """
        super(JsonNerDataset, self).__init__()
        self.tokenizer: BertTokenizer = tokenizer
        # Max characters per sub-sentence, before [CLS]/[SEP] are attached.
        self.sentence_max_len = sentence_max_len
        self.categories = categories
        self.token_padding_idx = self.tokenizer.convert_tokens_to_ids('[PAD]')
        self.target_padding_idx = target_padding_idx
        self.add_special_token = add_special_token
        self.first_token_id = self.tokenizer.convert_tokens_to_ids(first_special_token)
        self.last_token_id = self.tokenizer.convert_tokens_to_ids(last_special_token)

        # Eagerly parse the whole file into (token_ids, label_ids, length) tuples.
        self.records = self.load_records(data_path)

    def load_records(self, data_path) -> list:
        """
        Parse the JSON-lines file into training records.

        Each returned record is a tuple ``(x, y, num_tokens)`` where ``x`` is the
        list of token ids, ``y`` the list of label ids (same length as ``x``), and
        ``num_tokens`` is ``len(x)``.

        :param data_path: path of the JSON-lines file; one JSON object per line
            with keys ``originalText`` and ``entities``.
        :return: list of ``(token_ids, label_ids, num_tokens)`` tuples.
        :raises ValueError: when an entity span is empty or inverted
            (``end_pos <= start_pos``).
        """
        records = []
        with open(data_path, "r", encoding="utf-8") as reader:
            for line in reader:
                # 1. Load the raw sample (one JSON object per line).
                record = json.loads(line.strip())

                # 2. Convert the raw text.
                entities = record['entities']
                # NOTE(review): special_token_processor is defined elsewhere —
                # presumably normalizes special characters; if it changes the text
                # length, entity offsets would drift. TODO confirm.
                text = special_token_processor(record['originalText'])
                chars = list(text)  # character-level split: one character per token

                # 3. Build the per-character BMES/O label sequence.
                labels = ['O'] * len(chars)
                for entity in entities:
                    label_type = entity['label_type']
                    start_pos = entity['start_pos']  # inclusive start
                    end_pos = entity['end_pos']  # exclusive end
                    if end_pos - start_pos == 1:
                        # Single-character entity.
                        labels[start_pos] = f'S-{label_type}'
                    elif end_pos - start_pos > 1:
                        # Multi-character entity: B at start, E at end, M between.
                        labels[start_pos] = f'B-{label_type}'
                        labels[end_pos - 1] = f'E-{label_type}'
                        for i in range(start_pos + 1, end_pos - 1):
                            labels[i] = f'M-{label_type}'
                    else:
                        raise ValueError(f"数据异常:{record}")

                if self.add_special_token:
                    # Split chars/labels so a single sample never exceeds the
                    # max sentence length (510 for BERT minus specials).
                    for sub_chars, sub_labels in split_chars(chars, labels, self.sentence_max_len):
                        x = self.tokenizer.convert_tokens_to_ids(sub_chars)  # map each character to its token id
                        assert len(sub_chars) == len(x), "bert进行token id转换后,更改了列表长度."
                        y = [self.categories[c] for c in sub_labels]

                        # Wrap with [CLS]/[SEP] ids; specials are labeled 'O'.
                        x.insert(0, self.first_token_id)
                        x.append(self.last_token_id)
                        y.insert(0, self.categories['O'])
                        y.append(self.categories['O'])

                        assert len(x) == len(y), f"输入token的长度必须和标签的长度一致,当前长度为:{len(x)} - {len(y)} - {record}"
                        records.append((x, y, len(x)))
                else:
                    # No splitting and no special tokens: one record per line.
                    x = self.tokenizer.convert_tokens_to_ids(chars)  # map each character to its token id
                    assert len(chars) == len(x), "bert进行token id转换后,更改了列表长度."
                    y = [self.categories[c] for c in labels]

                    assert len(x) == len(y), f"输入token的长度必须和标签的长度一致,当前长度为:{len(x)} - {len(y)} - {record}"
                    records.append((x, y, len(x)))
        return records

    def __getitem__(self, index):
        """
        获取index对应的样本信息,包括x和y
        :param index: 样本的索引id
        :return: x,y
        """
        x, y, num_tokens = self.records[index]
        return copy.deepcopy(x), copy.deepcopy(y), num_tokens

    def __len__(self):
        return len(self.records)

    def collate(self, batch):
        max_length = max([t[2] for t in batch])
        x, y, mask = [], [], []
        for i in range(len(batch))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值