Python scikit-learn wrapper: a scikit-learn wrapper for Google's BERT model

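This article introduces a scikit-learn wrapper for finetuning Google's BERT model on text and token sequence tasks. It includes the pretrained SciBERT and BioBERT models for the scientific and biomedical domains, provides classification, regression, and token sequence classification, and supports hyperparameter tuning.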

scikit-learn wrapper to finetune BERT

A scikit-learn wrapper to finetune Google's BERT model for text and token sequence tasks based on the huggingface pytorch port.

Includes a configurable MLP as the final classifier/regressor for text and text pair tasks

Includes token sequence classifier for NER, PoS, and chunking tasks

Includes SciBERT and BioBERT pretrained models for scientific and biomedical domains.

installation

Requires Python >= 3.5 and PyTorch >= 0.4.1.

git clone -b master https://github.com/charles9n/bert-sklearn
cd bert-sklearn
pip install .

basic operation

model.fit(X, y), i.e. finetune BERT, where:

X: list, pandas dataframe, or numpy array of text, text pairs, or token lists

y: list, pandas dataframe, or numpy array of labels/targets
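For concreteness, here is what X and y might look like for each kind of task, using plain lists (the source also allows pandas dataframes and numpy arrays). The sentences, labels, and the `misaligned_rows` helper are hypothetical illustrations, not part of bert-sklearn. The one constraint the helper checks, exactly one tag per token for token sequence tasks, follows from what token classification means.

```python
# Hypothetical toy data showing the shape of X and y for each estimator.

# Single-text classification/regression: one string per example.
X_text = ["a gripping, well-acted film", "two hours I will never get back"]
y_text = ["pos", "neg"]

# Text-pair tasks: two strings per example (e.g. premise and hypothesis).
X_pairs = [["A man is playing a guitar.", "A person is making music."],
           ["A dog runs in the park.", "The cat is asleep indoors."]]
y_pairs = ["entailment", "contradiction"]

# Token sequence classification: one token list per example, one tag per token.
X_tokens = [["John", "lives", "in", "New", "York"],
            ["Acme", "Corp", "hired", "Mary"]]
y_tags = [["B-PER", "O", "O", "B-LOC", "I-LOC"],
          ["B-ORG", "I-ORG", "O", "B-PER"]]


def misaligned_rows(X, y):
    """Hypothetical helper: return indices of examples whose token and tag counts differ."""
    if len(X) != len(y):
        raise ValueError(f"{len(X)} token lists but {len(y)} tag lists")
    return [i for i, (tokens, tags) in enumerate(zip(X, y)) if len(tokens) != len(tags)]


print(misaligned_rows(X_tokens, y_tags))  # []
```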

from bert_sklearn import BertClassifier
from bert_sklearn import BertRegressor
from bert_sklearn import BertTokenClassifier
from bert_sklearn import load_model

# define model
model = BertClassifier()         # text/text pair classification
# model = BertRegressor()        # text/text pair regression
# model = BertTokenClassifier()  # token sequence classification

# finetune model
model.fit(X_train, y_train)

# make predictions
y_pred = model.predict(X_test)

# make probability predictions (classifiers only)
y_prob = model.predict_proba(X_test)

# score model on test data
model.score(X_test, y_test)

# save model to disk
savefile = '/data/mymodel.bin'
model.save(savefile)

# load model from disk
new_model = load_model(savefile)

# do stuff with new model
new_model.score(X_test, y_test)

See the demo notebook.

model options

# try different options...
model.bert_model = 'bert-large-uncased'
model.num_mlp_layers = 3
model.max_seq_length = 196
model.epochs = 4
model.learning_rate = 4e-5
model.gradient_accumulation_steps = 4

# finetune
model.fit(X_train, y_train)

# do stuff...
model.score(X_test, y_test)
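The `gradient_accumulation_steps` option refers to a standard trick: instead of updating the weights after every mini-batch, gradients from several smaller micro-batches are accumulated and a single optimizer step is taken. This lets a large model such as bert-large train with a bigger effective batch than fits in GPU memory at once. The pure-Python sketch below does not use bert-sklearn or PyTorch; the one-parameter model, the data, and the function names are illustrative assumptions. It checks the key property: with equal-size micro-batches and each micro-batch loss scaled by 1/steps, the accumulated gradient equals the full-batch gradient. How bert-sklearn derives its per-step batch size from its settings is not shown here.

```python
import random


def grad_mse(w, batch):
    """Gradient of mean squared error w.r.t. w for the toy model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w, batch, accumulation_steps):
    """Accumulate gradients over equal micro-batches, each scaled by 1/accumulation_steps.

    This mirrors the usual pattern of backpropagating loss / accumulation_steps on
    each micro-batch and taking the optimizer step only after the last one.
    """
    if len(batch) % accumulation_steps:
        raise ValueError("batch size must be divisible by accumulation_steps")
    size = len(batch) // accumulation_steps
    total = 0.0
    for k in range(accumulation_steps):
        micro_batch = batch[k * size:(k + 1) * size]
        total += grad_mse(w, micro_batch) / accumulation_steps
    return total


random.seed(0)
xs = [random.uniform(-1.0, 1.0) for _ in range(32)]
data = [(x, 3.0 * x + random.gauss(0.0, 0.1)) for x in xs]

full = grad_mse(0.5, data)                                       # one batch of 32
accumulated = accumulated_grad(0.5, data, accumulation_steps=4)  # 4 micro-batches of 8
print(full, accumulated)  # equal up to floating-point rounding
```

The equality only holds because the micro-batches have equal size; with a ragged last micro-batch, the scaling would need to use the actual sample counts.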

hyperparameter tuning

from sklearn.model_selection import GridSearchCV
from bert_sklearn import BertClassifier

params = {'epochs': [3, 4], 'learning_rate': [2e-5, 3e-5, 5e-5]}

# wrap classifier in GridSearchCV
clf = GridSearchCV(BertClassifier(validation_fraction=0),
                   params,
                   scoring='accuracy',
                   verbose=True)

# fit gridsearch
clf.fit(X_train, y_train)
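Every candidate in the grid is a full BERT finetuning run, so it helps to count the fits before launching the search. The sketch below enumerates the same set of combinations GridSearchCV would try for `params`, using only `itertools`. The fold count is an assumption (scikit-learn's default `cv=5` since version 0.22), and the extra fit comes from GridSearchCV's default `refit=True`, which retrains the best candidate on all of the training data.

```python
from itertools import product

params = {'epochs': [3, 4], 'learning_rate': [2e-5, 3e-5, 5e-5]}

# Every combination of the listed values is one candidate configuration.
keys = list(params)
candidates = [dict(zip(keys, values)) for values in product(*params.values())]

n_folds = 5  # assumption: scikit-learn's default cv=5 (version 0.22 and later)
refit = 1    # default refit=True retrains the best candidate once more

n_finetunes = len(candidates) * n_folds + refit
print(len(candidates), n_finetunes)  # 6 31
```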

GLUE datasets

The train and dev sets from the GLUE (General Language Understanding Evaluation) benchmark were used with the bert-base-uncased model, and the results were compared against those reported in the Google paper and on the GLUE leaderboard.

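| | MNLI (m/mm) | QQP | QNLI | SST-2 | CoLA | STS-B | MRPC | RTE |
|---|---|---|---|---|---|---|---|---|
| BERT base (leaderboard) | 84.6/83.4 | 89.2 | 90.1 | 93.5 | 52.1 | 87.1 | 84.8 | 66.4 |
| bert-sklearn | 83.7/83.9 | 90.2 | 88.6 | 92.32 | 58.1 | 89.7 | 86.8 | 64.6 |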

Individual runs can be found here.

CoNLL-2003 Named Entity Recognition (NER)

NER results for the CoNLL-2003 shared task:

| | dev F1 | test F1 |
|---|---|---|
| BERT paper | 96.4 | 92.4 |
| bert-sklearn | 96.04 | 91.97 |

Span-level stats on the test set:

processed 46666 tokens with 5648 phrases; found: 5740 phrases; correct: 5173.
accuracy: 98.15%; precision: 90.12%; recall: 91.59%; FB1: 90.85
LOC: precision: 92.24%; recall: 92.69%; FB1: 92.46 1676
MISC: precision: 78.07%; recall: 81.62%; FB1: 79.81 734
ORG: precision: 87.64%; recall: 90.07%; FB1: 88.84 1707
PER: precision: 96.00%; recall: 96.35%; FB1: 96.17 1623
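The span-level numbers follow conlleval-style scoring: a predicted phrase counts as correct only when both its type and its boundaries match a gold phrase, so precision = correct / found and recall = correct / gold phrases. The per-type counts in the last column (1676 + 734 + 1707 + 1623) add up to the 5740 phrases found. The sketch below is a simplified re-implementation for illustration, not the conlleval script or bert-sklearn's own evaluation code. It assumes IOB2 (BIO) tags, and the function names are hypothetical. It also reproduces the overall figures from the reported counts.

```python
def spans(tags):
    """Extract (type, start, end) chunks from one IOB2 tag sequence; end is exclusive."""
    chunks, start, kind = [], 0, None
    for i, tag in enumerate(list(tags) + ["O"]):  # sentinel "O" closes a trailing chunk
        prefix, _, label = tag.partition("-")
        # An open chunk ends on "O", on a new "B-", or on an "I-" of another type.
        if kind is not None and (prefix != "I" or label != kind):
            chunks.append((kind, start, i))
            kind = None
        # A chunk starts on "B-", or on an "I-" with no open chunk (a stray "I-").
        if prefix == "B" or (prefix == "I" and kind is None):
            start, kind = i, label
    return chunks


def prf(correct, found, gold):
    """Precision, recall and F1 from phrase counts."""
    precision = correct / found if found else 0.0
    recall = correct / gold if gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def span_scores(gold_seqs, pred_seqs):
    """Span-level P/R/F1: a phrase is correct only if its type and both boundaries match."""
    correct = found = gold = 0
    for gold_tags, pred_tags in zip(gold_seqs, pred_seqs):
        g, p = set(spans(gold_tags)), set(spans(pred_tags))
        correct += len(g & p)
        found += len(p)
        gold += len(g)
    return prf(correct, found, gold)


# Toy check: the PER phrase matches, but the LOC phrase is predicted as ORG.
print(span_scores([["B-PER", "I-PER", "O", "B-LOC"]],
                  [["B-PER", "I-PER", "O", "B-ORG"]]))  # (0.5, 0.5, 0.5)

# Reported test-set counts: correct 5173, found 5740, gold 5648.
p, r, f = prf(5173, 5740, 5648)
print(f"precision: {p:.2%}; recall: {r:.2%}; FB1: {100 * f:.2f}")
# precision: 90.12%; recall: 91.59%; FB1: 90.85
```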

See the ner_english notebook for a demo using the 'bert-base-cased' model.

NCBI Biomedical NER

NER results using bert-sklearn with SciBERT and BioBERT on the NCBI Disease Corpus disease name recognition task.

The previous state of the art (SOTA) for this task is an F1 of 87.34 on the test set.

| | test F1 (bert-sklearn) | test F1 (from papers) |
|---|---|---|
| BERT base cased | 85.09 | 85.49 |
| SciBERT basevocab cased | 88.29 | 86.91 |
| SciBERT scivocab cased | 87.73 | 86.45 |
| BioBERT pubmed_v1.0 | 87.86 | 87.38 |
| BioBERT pubmed_pmc_v1.0 | 88.26 | 89.36 |
| BioBERT pubmed_v1.1 | 87.26 | NA |

See the ner_NCBI_disease_BioBERT_SciBERT notebook for a demo using the SciBERT and BioBERT models.

See the SciBERT and BioBERT papers for more information on the respective models.

Other examples

See the IMDb notebook for a text classification demo on the Internet Movie Database review sentiment task.

See the chunking_english notebook for a demo on syntactic chunking using the CoNLL-2000 chunking task data.

See the ner_chinese notebook for a demo using 'bert-base-chinese' for Chinese NER.

tests

Run tests with pytest:

python -m pytest -sv tests/

