基于预训练模型 ERNIE-Gram 实现语义匹配-模型预测

尝试直接使用训练好的参数,进行预测

使用 Lcqmc 数据集的测试集作为我们的预测数据

添加链接描述

在这里插入图片描述

加载预测数据

test_ds = load_dataset("lcqmc", splits=["test"])

在这里插入图片描述

生成预测数据

predict_data_loader =paddle.io.DataLoader(
        dataset=test_ds.map(trans_func),
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        return_list=True)

定义预测模型

pretrained_model = paddlenlp.transformers.ErnieGramModel.from_pretrained('ernie-gram-zh')

model = PointwiseMatching(pretrained_model)

加载已训练好的模型参数

state_dict = paddle.load("test/ernie_gram_zh_pointwise_matching_model/model_20000/model_state.pdparams")
model.set_dict(state_dict)

在这里插入图片描述

定义预测函数

def predict(model, data_loader):
    batch_probs = []

    # 预测阶段打开 eval 模式,模型中的 dropout 等操作会关掉
    model.eval()

    with paddle.no_grad():
        for batch_data in data_loader:
            input_ids, token_type_ids = batch_data
            input_ids = paddle.to_tensor(input_ids)
            token_type_ids = paddle.to_tensor(token_type_ids)

            # 获取每个样本的预测概率: [batch_size, 2] 的矩阵
            batch_prob = model(
                input_ids=input_ids, token_type_ids=token_type_ids).numpy()

            batch_probs.append(batch_prob)
            if(batch_probs.__len__()==10):
                batch_probs = np.concatenate(batch_probs, axis=0)
                return batch_probs
   
# 执行预测函数
y_probs = predict(model, predict_data_loader)

# 根据预测概率获取预测 label
y_preds = np.argmax(y_probs, axis=1)

在这里插入图片描述

预测函数里的batch_probs存放前向计算的结果,为了方便观看处理后的数据,我设定了batch_probs的大小为10就返回结果

在这里插入图片描述

# 按照千言文本相似度竞赛的提交格式将预测结果存储在 lcqmc.tsv 中,用来后续提交
with open("lcqmc.tsv", 'w', encoding="utf-8") as f:
    f.write("index\tprediction\n")
    for idx, y_pred in enumerate(y_preds):
        f.write("{}\t{}\n".format(idx, y_pred))
        text_pair = test_ds[idx]
        text_pair["label"] = y_pred
        print(text_pair)

在这里插入图片描述

可见,在不需要微调的情况下,预训练这种模式使得任务完成地特别简单。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值