import sys
import os
import copy
import json
import logging
import argparse
import torch
import numpy as np
from tqdm import tqdm, trange
import torch.nn as nn
from torchcrf import CRF
from torch.utils.data import TensorDataset
from seqeval.metrics import precision_score, recall_score, f1_score
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import BertConfig, AdamW, get_linear_schedule_with_warmup
from transformers import (
BertModel,
BertTokenizer,
BertPreTrainedModel,
)
sys.argv = ['']  # Clear the command-line arguments so argparse below falls back to its defaults (notebook workaround)
del sys
logger = logging.getLogger(__name__)
Define a function that loads the tokenizer (BERT's own WordPiece tokenizer, which splits words into sub-word tokens and maps them to vocabulary ids)
def load_tokenizer(args):
    tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
return tokenizer
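A quick, illustrative sanity check of what the tokenizer produces: it splits a word into WordPiece sub-tokens and maps tokens to vocabulary ids (the exact pieces and ids depend on the vocabulary, so the comments below are only indicative).
# Illustrative only -- not part of the training pipeline.
demo_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
print(demo_tokenizer.tokenize("flights"))                                    # e.g. ['flights'] or sub-word pieces
print(demo_tokenizer.convert_tokens_to_ids(['[CLS]', 'flights', '[SEP]']))   # three vocabulary ids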
Define the remaining helper functions for the evaluation metrics
def get_intent_acc(preds, labels):
acc = (preds == labels).mean()
return {
"intent_acc": acc
}
def get_slot_metrics(preds, labels):
assert len(preds) == len(labels)
return {
"slot_precision": precision_score(labels, preds),
"slot_recall": recall_score(labels, preds),
"slot_f1": f1_score(labels, preds)
}
def get_sentence_frame_acc(intent_preds, intent_labels, slot_preds, slot_labels):
"""For the cases that intent and all the slots are correct (in one sentence)"""
# Get the intent comparison result
intent_result = (intent_preds == intent_labels)
    # Get the slot comparison result
slot_result = []
for preds, labels in zip(slot_preds, slot_labels):
assert len(preds) == len(labels)
one_sent_result = True
for p, l in zip(preds, labels):
if p != l:
one_sent_result = False
break
slot_result.append(one_sent_result)
slot_result = np.array(slot_result)
sementic_acc = np.multiply(intent_result, slot_result).mean()
return {
"sementic_frame_acc": sementic_acc
}
def compute_metrics(intent_preds, intent_labels, slot_preds, slot_labels):
assert len(intent_preds) == len(intent_labels) == len(slot_preds) == len(slot_labels)
results = {}
intent_result = get_intent_acc(intent_preds, intent_labels)
slot_result = get_slot_metrics(slot_preds, slot_labels)
sementic_result = get_sentence_frame_acc(intent_preds, intent_labels, slot_preds, slot_labels)
results.update(intent_result)
results.update(slot_result)
results.update(sementic_result)
return results
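A minimal toy example (made-up values; the ATIS-style tag names are only for illustration) showing the input format compute_metrics expects: intent predictions and labels as 1-D arrays of label ids, and slot predictions and labels as per-sentence lists of BIO tag strings, which is what seqeval consumes.
# Toy sanity check, illustrative only.
_intent_preds  = np.array([3, 3])
_intent_labels = np.array([3, 1])
_slot_preds    = [['O', 'B-toloc.city_name'], ['O', 'O']]
_slot_labels   = [['O', 'B-toloc.city_name'], ['O', 'B-fromloc.city_name']]
print(compute_metrics(_intent_preds, _intent_labels, _slot_preds, _slot_labels))
# intent_acc = 0.5, sementic_frame_acc = 0.5 (only the first sentence has the intent and all slots correct);
# the slot precision/recall/f1 are computed by seqeval over the BIO-tagged spans.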
The argparse module makes it easy to write user-friendly command-line interfaces. The program defines the arguments it needs, and argparse figures out how to parse them from sys.argv; it also generates help and usage messages automatically and reports an error when the user passes invalid arguments.
https://docs.python.org/zh-cn/3/library/argparse.html
Create the argparse parser and set the task name
parser = argparse.ArgumentParser()
parser.add_argument("--task", default='atis', required=False, type=str, help="The name of the task to train")
Add the remaining hyperparameters
parser.add_argument("--model_dir", default="./save_model", required=False, type=str, help="Path to save, load model")
parser.add_argument("--data_dir", default="./data", type=str, help="The input data dir")
parser.add_argument("--intent_label_file", default="intent_label.txt", type=str, help="Intent Label file")
parser.add_argument("--slot_label_file", default="slot_label.txt", type=str, help="Slot Label file")
parser.add_argument("--model_type", default="bert", type=str, help=" Bert is the Model")
parser.add_argument('--seed', type=int, default=1234, help="random seed for initialization")
parser.add_argument("--train_batch_size", default=32, type=int, help="Batch size for training.")
parser.add_argument("--eval_batch_size", default=64, type=int, help="Batch size for evaluation.")
parser.add_argument("--max_seq_len", default=50, type=int, help="The maximum total input sequence length after tokenization.")
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--num_train_epochs", default=2.0, type=float, help="Total number of training epochs to perform.")
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--dropout_rate", default=0.1, type=float, help="Dropout for fully-connected layers")
parser.add_argument('--logging_steps', type=int, default=200, help="Log every X updates steps.")
parser.add_argument('--save_steps', type=int, default=200, help="Save checkpoint every X updates steps.")
parser.add_argument("--do_train", default=True, action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", default=True, action="store_true", help="Whether to run eval on the test set.")
parser.add_argument("--no_cuda", default=True, action="store_true", help="Avoid using CUDA when available")
parser.add_argument("--ignore_index", default=0, type=int,
help='Specifies a target value that is ignored and does not contribute to the input gradient')
parser.add_argument('--slot_loss_coef', type=float, default=1.0, help='Coefficient for the slot loss.')
# CRF option
parser.add_argument("--use_crf", default=True, action="store_true", help="Whether to use CRF")
parser.add_argument("--slot_pad_label", default="PAD", type=str, help="Pad token for slot label pad (to be ignore when calculate loss)")
Parse the arguments (all defaults are used here, since sys.argv was cleared above)
args = parser.parse_args()
args.model_name_or_path = 'bert-base-uncased'
Take a look at what args contains
args
Namespace(adam_epsilon=1e-08, data_dir='./data', do_eval=True, do_train=True, dropout_rate=0.1, eval_batch_size=64, gradient_accumulation_steps=1, ignore_index=0, intent_label_file='intent_label.txt', learning_rate=5e-05, logging_steps=200, max_grad_norm=1.0, max_seq_len=50, max_steps=-1, model_dir='./save_model', model_name_or_path='bert-base-uncased', model_type='bert', no_cuda=True, num_train_epochs=2.0, save_steps=200, seed=1234, slot_label_file='slot_label.txt', slot_loss_coef=1.0, slot_pad_label='PAD', task='atis', train_batch_size=32, use_crf=True, warmup_steps=0, weight_decay=0.0)
Data input module
Official docs: https://huggingface.co/transformers/main_classes/processors.html?highlight=inputexample#transformers.data.processors.utils.InputExample
class InputExample(object):
"""
A single training/test example for simple sequence classification.
Args:
guid: Unique id for the example.
words: list. The words of the sequence.
intent_label: (Optional) string. The intent label of the example.
slot_labels: (Optional) list. The slot labels of the example.
"""
def __init__(self, guid, words, intent_label=None, slot_labels=None):
self.guid = guid
self.words = words
self.intent_label = intent_label
self.slot_labels = slot_labels
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self, input_ids, attention_mask, token_type_ids, intent_label_id, slot_labels_ids):
self.input_ids = input_ids
self.attention_mask = attention_mask
self.token_type_ids = token_type_ids
self.intent_label_id = intent_label_id
self.slot_labels_ids = slot_labels_ids
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class JointProcessor(object):
"""Processor for the JointBERT data set """
def __init__(self, args):
self.args = args
self.intent_labels = [label.strip() for label in open("data/atis/intent_label.txt", 'r', encoding='utf-8')]
self.slot_labels = [label.strip() for label in open("data/atis/slot_label.txt", 'r', encoding='utf-8')]
self.input_text_file = 'seq.in'
self.intent_label_file = 'label'
self.slot_labels_file = 'seq.out'
@classmethod
def _read_file(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r", encoding="utf-8") as f:
lines = []
for line in f:
lines.append(line.strip())
return lines
def _create_examples(self, texts, intents, slots, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for i, (text, intent, slot) in enumerate(zip(texts, intents, slots)):
guid = "%s-%s" % (set_type, i)
# 1. input_text
words = text.split() # Some are spaced twice
# 2. intent
intent_label = self.intent_labels.index(intent) if intent in self.intent_labels else self.intent_labels.index("UNK")
# 3. slot
slot_labels = []
for s in slot.split():
slot_labels.append(self.slot_labels.index(s) if s in self.slot_labels else self.slot_labels.index("UNK"))
assert len(words) == len(slot_labels)
examples.append(InputExample(guid=guid, words=words, intent_label=intent_label, slot_labels=slot_labels))
return examples
def get_examples(self, mode):
"""
Args:
mode: train, dev, test
"""
data_path = os.path.join(self.args.data_dir, self.args.task, mode)
logger.info("LOOKING AT {}".format(data_path))
return self._create_examples(texts=self._read_file(os.path.join(data_path, self.input_text_file)),
intents=self._read_file(os.path.join(data_path, self.intent_label_file)),
slots=self._read_file(os.path.join(data_path, self.slot_labels_file)),
set_type=mode)
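To make the data layout concrete, here is an illustrative sketch (not taken from the actual files; the label indices are hypothetical) of how one line triplet from seq.in, label and seq.out ends up as an InputExample:
# Illustrative only: one ATIS-style sentence.
# seq.in : "show flights from boston to denver"
# label  : "atis_flight"
# seq.out: "O O O B-fromloc.city_name O B-toloc.city_name"
demo_example = InputExample(
    guid="train-0",
    words="show flights from boston to denver".split(),
    intent_label=0,                  # hypothetical index of 'atis_flight' in intent_label.txt
    slot_labels=[2, 2, 2, 5, 2, 7],  # hypothetical indices of the BIO tags in slot_label.txt
)
print(demo_example)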
Convert the examples into BERT input features
def convert_examples_to_features(examples, max_seq_len, tokenizer,
pad_token_label_id=-100,
cls_token_segment_id=0,
pad_token_segment_id=0,
sequence_a_segment_id=0,
mask_padding_with_zero=True):
# Setting based on the current model type
cls_token = tokenizer.cls_token
sep_token = tokenizer.sep_token
unk_token = tokenizer.unk_token
pad_token_id = tokenizer.pad_token_id
features = []
for (ex_index, example) in enumerate(examples):
if ex_index % 5000 == 0:
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
# Tokenize word by word (for NER)
tokens = []
slot_labels_ids = []
for word, slot_label in zip(example.words, example.slot_labels):
word_tokens = tokenizer.tokenize(word)
if not word_tokens:
word_tokens = [unk_token] # For handling the bad-encoded word
tokens.extend(word_tokens)
# Use the real label id for the first token of the word, and padding ids for the remaining tokens
slot_labels_ids.extend([int(slot_label)] + [pad_token_label_id] * (len(word_tokens) - 1))
# Account for [CLS] and [SEP]
special_tokens_count = 2
if len(tokens) > max_seq_len - special_tokens_count:
tokens = tokens[:(max_seq_len - special_tokens_count)]
slot_labels_ids = slot_labels_ids[:(max_seq_len - special_tokens_count)]
# Add [SEP] token
tokens += [sep_token]
slot_labels_ids += [pad_token_label_id]
token_type_ids = [sequence_a_segment_id] * len(tokens)
# Add [CLS] token
tokens = [cls_token] + tokens
slot_labels_ids = [pad_token_label_id] + slot_labels_ids
token_type_ids = [cls_token_segment_id] + token_type_ids
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
# Zero-pad up to the sequence length.
padding_length = max_seq_len - len(input_ids)
input_ids = input_ids + ([pad_token_id] * padding_length)
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
slot_labels_ids = slot_labels_ids + ([pad_token_label_id] * padding_length)
assert len(input_ids) == max_seq_len, "Error with input length {} vs {}".format(len(input_ids), max_seq_len)
assert len(attention_mask) == max_seq_len, "Error with attention mask length {} vs {}".format(len(attention_mask), max_seq_len)
assert len(token_type_ids) == max_seq_len, "Error with token type length {} vs {}".format(len(token_type_ids), max_seq_len)
assert len(slot_labels_ids) == max_seq_len, "Error with slot labels length {} vs {}".format(len(slot_labels_ids), max_seq_len)
intent_label_id = int(example.intent_label)
features.append(
InputFeatures(input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
intent_label_id=intent_label_id,
slot_labels_ids=slot_labels_ids
))
return features
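Continuing the illustrative example above, a sketch of the resulting InputFeatures: [CLS] and [SEP] are added, only the first sub-token of each word keeps the real slot label id (the rest get pad_token_label_id), and everything is padded to max_seq_len. This reuses the hypothetical demo_example and demo_tokenizer from the earlier illustrative cells.
# Illustrative only.
demo_features = convert_examples_to_features([demo_example], max_seq_len=50,
                                              tokenizer=demo_tokenizer, pad_token_label_id=0)
f = demo_features[0]
print(len(f.input_ids), len(f.attention_mask), len(f.slot_labels_ids))  # all 50
print(f.input_ids[:10])        # [CLS] id, the word-piece ids, the [SEP] id, then 0-padding
print(f.slot_labels_ids[:10])  # 0 for [CLS]/[SEP]/padding/extra sub-tokens, real slot ids elsewhere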
Convert the features into a TensorDataset
def load_and_cache_examples(tokenizer, mode):
    processor = JointProcessor(args)
# Load data features from cache or dataset file
cached_features_file = os.path.join(
args.data_dir,
'cached_{}_{}_{}_{}'.format(
mode,
"atis",
list(filter(None, args.model_name_or_path .split("/"))).pop(),
args.max_seq_len
)
)
if os.path.exists(cached_features_file):
logger.info("Loading features from cached file %s", cached_features_file)
features = torch.load(cached_features_file)
else:
# Load data features from dataset file
logger.info("Creating features from dataset file at %s", args.data_dir)
if mode == "train":
examples = processor.get_examples("train")
elif mode == "dev":
examples = processor.get_examples("dev")
elif mode == "test":
examples = processor.get_examples("test")
else:
raise Exception("For mode, Only train, dev, test is available")
# Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
        pad_token_label_id = args.ignore_index
features = convert_examples_to_features(examples, args.max_seq_len, tokenizer,
pad_token_label_id=pad_token_label_id)
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
# Convert to Tensors and build dataset
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
all_intent_label_ids = torch.tensor([f.intent_label_id for f in features], dtype=torch.long)
all_slot_labels_ids = torch.tensor([f.slot_labels_ids for f in features], dtype=torch.long)
dataset = TensorDataset(all_input_ids, all_attention_mask,
all_token_type_ids, all_intent_label_ids, all_slot_labels_ids)
return dataset
Build the datasets
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
train_dataset = load_and_cache_examples(tokenizer, mode="train")
dev_dataset = load_and_cache_examples(tokenizer, mode="dev")
test_dataset = load_and_cache_examples(tokenizer, mode="test")
The model
Intent classification head
class IntentClassifier(nn.Module):
def __init__(self, input_dim, num_intent_labels, dropout_rate=0.):
super(IntentClassifier, self).__init__()
self.dropout = nn.Dropout(dropout_rate)
self.linear = nn.Linear(input_dim, num_intent_labels)
def forward(self, x):
x = self.dropout(x)
return self.linear(x)
Slot classification head
class SlotClassifier(nn.Module):
def __init__(self, input_dim, num_slot_labels, dropout_rate=0.):
super(SlotClassifier, self).__init__()
self.dropout = nn.Dropout(dropout_rate)
self.linear = nn.Linear(input_dim, num_slot_labels)
def forward(self, x):
x = self.dropout(x)
return self.linear(x)
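A quick shape sketch (illustrative; the hidden size 768 matches bert-base, the label counts are hypothetical) of how the two heads are used: the intent head sees one pooled [CLS] vector per sentence, while the slot head is applied to every token position.
# Illustrative shape check only.
_intent_head = IntentClassifier(input_dim=768, num_intent_labels=22, dropout_rate=0.1)
_slot_head   = SlotClassifier(input_dim=768, num_slot_labels=120, dropout_rate=0.1)
_pooled   = torch.randn(32, 768)      # [batch, hidden]          -> pooled [CLS] output
_sequence = torch.randn(32, 50, 768)  # [batch, seq_len, hidden] -> per-token output
print(_intent_head(_pooled).shape)    # torch.Size([32, 22])
print(_slot_head(_sequence).shape)    # torch.Size([32, 50, 120])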
JointBERT model definition
class JointBERT(BertPreTrainedModel):
def __init__(self, config, args, intent_label_lst, slot_label_lst):
super(JointBERT, self).__init__(config)
self.args = args
self.num_intent_labels = len(intent_label_lst)
self.num_slot_labels = len(slot_label_lst)
self.bert = BertModel(config=config) # Load pretrained bert
self.intent_classifier = IntentClassifier(config.hidden_size, self.num_intent_labels, args.dropout_rate)
self.slot_classifier = SlotClassifier(config.hidden_size, self.num_slot_labels, args.dropout_rate)
if args.use_crf:
self.crf = CRF(num_tags=self.num_slot_labels, batch_first=True)
def forward(self, input_ids, attention_mask, token_type_ids, intent_label_ids, slot_labels_ids):
outputs = self.bert(input_ids, attention_mask=attention_mask,
token_type_ids=token_type_ids) # sequence_output, pooled_output, (hidden_states), (attentions)
sequence_output = outputs[0]
pooled_output = outputs[1] # [CLS]
intent_logits = self.intent_classifier(pooled_output)
slot_logits = self.slot_classifier(sequence_output)
total_loss = 0
# 1. Intent Softmax
if intent_label_ids is not None:
if self.num_intent_labels == 1:
intent_loss_fct = nn.MSELoss()
intent_loss = intent_loss_fct(intent_logits.view(-1), intent_label_ids.view(-1))
else:
intent_loss_fct = nn.CrossEntropyLoss()
intent_loss = intent_loss_fct(intent_logits.view(-1, self.num_intent_labels), intent_label_ids.view(-1))
total_loss += intent_loss
# 2. Slot Softmax
if slot_labels_ids is not None:
if self.args.use_crf:
slot_loss = self.crf(slot_logits, slot_labels_ids, mask=attention_mask.byte(), reduction='mean')
slot_loss = -1 * slot_loss # negative log-likelihood
else:
slot_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
active_labels = slot_labels_ids.view(-1)[active_loss]
slot_loss = slot_loss_fct(active_logits, active_labels)
else:
slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))
total_loss += self.args.slot_loss_coef * slot_loss
outputs = ((intent_logits, slot_logits),) + outputs[2:] # add hidden states and attention if they are here
outputs = (total_loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions) # Logits is a tuple of intent and slot logits
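To make the return value concrete, a small smoke test (illustrative only: a tiny randomly initialized config and dummy label lists, no pretrained weights) showing how a batch flows through JointBERT and how the outputs tuple is unpacked; the Trainer below does the same with real data.
# Illustrative smoke test with random weights -- not the model that gets trained.
_cfg = BertConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2, intermediate_size=128)
_demo_model = JointBERT(_cfg, args, intent_label_lst=['UNK', 'atis_flight'], slot_label_lst=['PAD', 'UNK', 'O'])
_ids  = torch.randint(0, _cfg.vocab_size, (2, 8))   # [batch=2, seq_len=8]
_mask = torch.ones(2, 8, dtype=torch.long)
_segs = torch.zeros(2, 8, dtype=torch.long)
_out = _demo_model(_ids, _mask, _segs,
                   intent_label_ids=torch.tensor([1, 0]),
                   slot_labels_ids=torch.randint(0, 3, (2, 8)))
print(_out[0])                              # total loss = intent loss + slot_loss_coef * slot loss
print(_out[1][0].shape, _out[1][1].shape)   # intent logits [2, 2], slot logits [2, 8, 3]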
Training the model
class Trainer(object):
def __init__(self, args, train_dataset=None, dev_dataset=None, test_dataset=None):
self.args = args
self.train_dataset = train_dataset
self.dev_dataset = dev_dataset
self.test_dataset = test_dataset
self.intent_label_lst = [label.strip() for label in open("data/atis/intent_label.txt", 'r', encoding='utf-8')]
self.slot_label_lst = [label.strip() for label in open("data/atis/slot_label.txt", 'r', encoding='utf-8')]
# Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
self.pad_token_label_id = args.ignore_index
self.config = BertConfig.from_pretrained(self.args.model_name_or_path, finetuning_task=self.args.task)
self.model = JointBERT.from_pretrained(self.args.model_name_or_path,
config=self.config,
args=self.args,
intent_label_lst=self.intent_label_lst,
slot_label_lst=self.slot_label_lst)
# GPU or CPU
self.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
self.model.to(self.device)
def train(self):
train_sampler = RandomSampler(self.train_dataset)
train_dataloader = DataLoader(self.train_dataset, sampler=train_sampler, batch_size=self.args.train_batch_size)
if self.args.max_steps > 0:
t_total = self.args.max_steps
self.args.num_train_epochs = self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
else:
t_total = len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
'weight_decay': self.args.weight_decay},
{'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=t_total)
global_step = 0
tr_loss = 0.0
self.model.zero_grad()
train_iterator = trange(int(self.args.num_train_epochs), desc="Epoch")
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration")
for step, batch in enumerate(epoch_iterator):
self.model.train()
batch = tuple(t.to(self.device) for t in batch) # GPU or CPU
inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'intent_label_ids': batch[3],
'slot_labels_ids': batch[4]}
if self.args.model_type != 'distilbert':
inputs['token_type_ids'] = batch[2]
outputs = self.model(**inputs)
loss = outputs[0]
if self.args.gradient_accumulation_steps > 1:
loss = loss / self.args.gradient_accumulation_steps
loss.backward()
tr_loss += loss.item()
if (step + 1) % self.args.gradient_accumulation_steps == 0:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
optimizer.step()
scheduler.step() # Update learning rate schedule
self.model.zero_grad()
global_step += 1
if self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0:
self.evaluate("dev")
if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
self.save_model()
if 0 < self.args.max_steps < global_step:
epoch_iterator.close()
break
if 0 < self.args.max_steps < global_step:
train_iterator.close()
break
return global_step, tr_loss / global_step
def evaluate(self, mode):
if mode == 'test':
dataset = self.test_dataset
elif mode == 'dev':
dataset = self.dev_dataset
else:
raise Exception("Only dev and test dataset available")
eval_sampler = SequentialSampler(dataset)
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=self.args.eval_batch_size)
eval_loss = 0.0
nb_eval_steps = 0
intent_preds = None
slot_preds = None
out_intent_label_ids = None
out_slot_labels_ids = None
self.model.eval()
for batch in tqdm(eval_dataloader, desc="Evaluating"):
batch = tuple(t.to(self.device) for t in batch)
with torch.no_grad():
inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'intent_label_ids': batch[3],
'slot_labels_ids': batch[4]}
if self.args.model_type != 'distilbert':
inputs['token_type_ids'] = batch[2]
outputs = self.model(**inputs)
tmp_eval_loss, (intent_logits, slot_logits) = outputs[:2]
eval_loss += tmp_eval_loss.mean().item()
nb_eval_steps += 1
# Intent prediction
if intent_preds is None:
intent_preds = intent_logits.detach().cpu().numpy()
out_intent_label_ids = inputs['intent_label_ids'].detach().cpu().numpy()
else:
intent_preds = np.append(intent_preds, intent_logits.detach().cpu().numpy(), axis=0)
out_intent_label_ids = np.append(
out_intent_label_ids, inputs['intent_label_ids'].detach().cpu().numpy(), axis=0)
# Slot prediction
if slot_preds is None:
if self.args.use_crf:
# decode() in `torchcrf` returns list with best index directly
slot_preds = np.array(self.model.crf.decode(slot_logits))
else:
slot_preds = slot_logits.detach().cpu().numpy()
out_slot_labels_ids = inputs["slot_labels_ids"].detach().cpu().numpy()
else:
if self.args.use_crf:
slot_preds = np.append(slot_preds, np.array(self.model.crf.decode(slot_logits)), axis=0)
else:
slot_preds = np.append(slot_preds, slot_logits.detach().cpu().numpy(), axis=0)
out_slot_labels_ids = np.append(out_slot_labels_ids, inputs["slot_labels_ids"].detach().cpu().numpy(), axis=0)
eval_loss = eval_loss / nb_eval_steps
results = {
"loss": eval_loss
}
# Intent result
intent_preds = np.argmax(intent_preds, axis=1)
# Slot result
if not self.args.use_crf:
slot_preds = np.argmax(slot_preds, axis=2)
slot_label_map = {i: label for i, label in enumerate(self.slot_label_lst)}
out_slot_label_list = [[] for _ in range(out_slot_labels_ids.shape[0])]
slot_preds_list = [[] for _ in range(out_slot_labels_ids.shape[0])]
for i in range(out_slot_labels_ids.shape[0]):
for j in range(out_slot_labels_ids.shape[1]):
if out_slot_labels_ids[i, j] != self.pad_token_label_id:
out_slot_label_list[i].append(slot_label_map[out_slot_labels_ids[i][j]])
slot_preds_list[i].append(slot_label_map[slot_preds[i][j]])
total_result = compute_metrics(intent_preds, out_intent_label_ids, slot_preds_list, out_slot_label_list)
results.update(total_result)
logger.info("***** Eval results *****")
for key in sorted(results.keys()):
# logger.info(" %s = %s", key, str(results[key]))
print(" %s = %s" %(key, str(results[key])))
return results
def save_model(self):
# Save model checkpoint (Overwrite)
if not os.path.exists(self.args.model_dir):
os.makedirs(self.args.model_dir)
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
model_to_save.save_pretrained(self.args.model_dir)
# Save training arguments together with the trained model
torch.save(self.args, os.path.join(self.args.model_dir, 'training_args.bin'))
print("Saving model checkpoint to %s" % self.args.model_dir)
def load_model(self):
# Check whether model exists
if not os.path.exists(self.args.model_dir):
raise Exception("Model doesn't exists! Train first!")
try:
            self.model = JointBERT.from_pretrained(self.args.model_dir,
                                                   args=self.args,
                                                   intent_label_lst=self.intent_label_lst,
                                                   slot_label_lst=self.slot_label_lst)
            self.model.to(self.device)
            print("***** Model Loaded *****")
        except Exception:
            raise Exception("Some model files might be missing...")
Start training the model
trainer = Trainer(args, train_dataset, dev_dataset, test_dataset)
trainer.train()
Epoch: 0%| | 0/2 [00:00<?, ?it/s]
Iteration: 0%| | 0/140 [00:00<?, ?it/s][A
... (per-step tqdm output for iterations 1-139 trimmed; roughly 1.5 s/it throughout) ...
Iteration: 100%|██████████| 140/140 [03:31<00:00, 1.51s/it][A
Epoch: 50%|█████ | 1/2 [03:31<03:31, 211.30s/it]
Iteration: 0%| | 0/140 [00:00<?, ?it/s][A
... (per-step tqdm output for iterations 1-59 trimmed; roughly 1.5 s/it) ...
Evaluating: 100%|██████████| 8/8 [00:06<00:00, 1.17it/s][A[A
/home/frank/miniconda3/envs/bio-bert-bilstm-crf/lib/python3.7/site-packages/seqeval/metrics/sequence_labeling.py:171: UserWarning: UNK seems not to be NE tag.
warnings.warn('{} seems not to be NE tag.'.format(chunk))
intent_acc = 0.85
loss = 2.148504465818405
sementic_frame_acc = 0.708
slot_f1 = 0.9299883313885647
slot_precision = 0.9278230500582072
slot_recall = 0.9321637426900585
Iteration: 43%|████▎ | 60/140 [01:38<04:57, 3.71s/it][A
Saving model checkpoint to ./save_model
... (per-step tqdm output for iterations 61-139 trimmed; roughly 1.5 s/it) ...
Iteration: 100%|██████████| 140/140 [03:39<00:00, 1.57s/it][A
Epoch: 100%|██████████| 2/2 [07:11<00:00, 215.63s/it]
(280, 5.539662108251027)
Evaluate the trained model on the test set
trainer.evaluate("test")
Evaluating: 100%|██████████| 14/14 [00:11<00:00, 1.17it/s]
intent_acc = 0.8443449048152296
loss = 2.3734985419682095
sementic_frame_acc = 0.7077267637178052
slot_f1 = 0.916652105539053
slot_precision = 0.909500693481276
slot_recall = 0.9239168721380768
{'loss': 2.3734985419682095,
'intent_acc': 0.8443449048152296,
'slot_precision': 0.909500693481276,
'slot_recall': 0.9239168721380768,
'slot_f1': 0.916652105539053,
'sementic_frame_acc': 0.7077267637178052}
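Finally, a hedged sketch (not part of the original notebook) of how the checkpoint saved in ./save_model could be reloaded and used to predict the intent and slots of a new sentence. The sentence, the dummy label ids and the word-alignment trick below are all illustrative; the dummy non-pad slot ids only serve to mark the first sub-token of each word so the predictions can be mapped back to words.
# Illustrative inference sketch.
trainer.load_model()
model, device = trainer.model, trainer.device
model.eval()

sentence = "show me flights from boston to denver"
words = sentence.split()
example = InputExample(guid="predict-0", words=words,
                       intent_label=0, slot_labels=[1] * len(words))  # dummy labels (1 = any non-pad id)
feature = convert_examples_to_features([example], args.max_seq_len, tokenizer,
                                        pad_token_label_id=args.ignore_index)[0]

input_ids      = torch.tensor([feature.input_ids], dtype=torch.long).to(device)
attention_mask = torch.tensor([feature.attention_mask], dtype=torch.long).to(device)
token_type_ids = torch.tensor([feature.token_type_ids], dtype=torch.long).to(device)

with torch.no_grad():
    _, (intent_logits, slot_logits) = model(input_ids, attention_mask, token_type_ids,
                                            intent_label_ids=torch.tensor([0]).to(device),
                                            slot_labels_ids=torch.tensor([feature.slot_labels_ids]).to(device))[:2]

print("intent:", trainer.intent_label_lst[int(intent_logits.argmax(dim=-1))])
if args.use_crf:
    slot_ids = model.crf.decode(slot_logits)[0]
else:
    slot_ids = slot_logits.argmax(dim=-1).squeeze(0).tolist()
# Keep only the positions whose dummy label id is not the pad id, i.e. the first sub-token of each word.
slot_tags = [trainer.slot_label_lst[i] for i, keep in zip(slot_ids, feature.slot_labels_ids)
             if keep != args.ignore_index]
print("slots :", list(zip(words, slot_tags)))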