scikit-learn wrapper to finetune BERT
A scikit-learn wrapper to finetune Google's BERT model for text and token sequence tasks based on the huggingface pytorch port.
Includes configurable MLP as final classifier/regressor for text and text pair tasks
Includes token sequence classifier for NER, PoS, and chunking tasks
Includes SciBERT and BioBERT pretrained models for scientific and biomedical domains.
installation
requires python >= 3.5 and pytorch >= 0.4.1
git clone -b master https://github.com/charles9n/bert-sklearn
cd bert-sklearn
pip install .
basic operation
model.fit(X, y), i.e. finetune BERT
X: list, pandas dataframe, or numpy array of text, text pairs, or token lists
y: list, pandas dataframe, or numpy array of labels/targets
from bert_sklearn import BertClassifier
from bert_sklearn import BertRegressor
from bert_sklearn import BertTokenClassifier
from bert_sklearn import load_model
# define model
model = BertClassifier() # text/text pair classification
# model = BertRegressor() # text/text pair regression
# model = BertTokenClassifier() # token sequence classification
# finetune model
model.fit(X_train, y_train)
# make predictions
y_pred = model.predict(X_test)
# make probability predictions
y_prob = model.predict_proba(X_test)
# score model on test data
model.score(X_test, y_test)
# save model to disk
savefile = '/data/mymodel.bin'
model.save(savefile)
# load model from disk
new_model = load_model(savefile)
# do stuff with new model
new_model.score(X_test, y_test)
See demo notebook.
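As a rough illustration of the three input shapes accepted for X and y above, the sketch below builds toy data as plain Python lists. The example sentences, labels, and the `check_token_alignment` helper are hypothetical and are not part of bert-sklearn; representing a text pair as a two-element list is an assumption.

```python
# Toy data only: every sentence and label below is made up for illustration.

# 1) Single texts (e.g. for BertClassifier or BertRegressor)
X_text = ["a great movie", "a dull plot"]
y_text = ["pos", "neg"]

# 2) Text pairs (assumed here to be one [text_a, text_b] entry per example)
X_pairs = [["A man is eating.", "Someone eats."],
           ["A man is eating.", "Nobody is eating."]]
y_pairs = ["entailment", "contradiction"]

# 3) Token lists (e.g. for BertTokenClassifier): one label per token
X_tokens = [["John", "lives", "in", "Paris"], ["Call", "Mary"]]
y_tokens = [["B-PER", "O", "O", "B-LOC"], ["O", "B-PER"]]


def check_token_alignment(X, y):
    """Hypothetical helper: True if every token list has one label per token."""
    return len(X) == len(y) and all(len(t) == len(l) for t, l in zip(X, y))


print(check_token_alignment(X_tokens, y_tokens))
```

Checking alignment before calling fit is cheap and catches the most common data preparation mistake in token-level tasks: a sentence whose label list is shorter or longer than its token list.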
model options
# try different options...
model.bert_model = 'bert-large-uncased'
model.num_mlp_layers = 3
model.max_seq_length = 196
model.epochs = 4
model.learning_rate = 4e-5
model.gradient_accumulation_steps = 4
# finetune
model.fit(X_train, y_train)
# do stuff...
model.score(X_test, y_test)
hyperparameter tuning
from sklearn.model_selection import GridSearchCV
params = {'epochs':[3, 4], 'learning_rate':[2e-5, 3e-5, 5e-5]}
# wrap classifier in GridSearchCV
clf = GridSearchCV(BertClassifier(validation_fraction=0),
                   params,
                   scoring='accuracy',
                   verbose=True)
# fit gridsearch
clf.fit(X_train, y_train)
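Because every grid point triggers full finetuning runs, it can help to count them before starting. The sketch below enumerates the grid above with the standard library only. The fold count of 5 is an assumption (it is the GridSearchCV default in scikit-learn 0.22 and later; older releases used 3), as is the final refit on the full training set (`refit=True` is the default).

```python
from itertools import product

params = {'epochs': [3, 4], 'learning_rate': [2e-5, 3e-5, 5e-5]}

# Enumerate every combination in the grid, as a grid search would.
names = sorted(params)
candidates = [dict(zip(names, values))
              for values in product(*(params[name] for name in names))]

cv_folds = 5  # assumption: default cv in scikit-learn >= 0.22
# One finetuning run per candidate per fold, plus one final refit.
n_fits = len(candidates) * cv_folds + 1

for candidate in candidates:
    print(candidate)
print("finetuning runs:", n_fits)
```

Under these assumptions the grid has 6 candidates and needs 31 finetuning runs, so passing a smaller `cv` value or trimming the grid may be worthwhile on limited hardware.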
GLUE datasets
The train and dev data sets from the GLUE (General Language Understanding Evaluation) benchmark were used with the bert-base-uncased model, and the results were compared against those reported in the Google paper and on the GLUE leaderboard.
| | MNLI(m/mm) | QQP | QNLI | SST-2 | CoLA | STS-B | MRPC | RTE |
|---|---|---|---|---|---|---|---|---|
| BERT base (leaderboard) | 84.6/83.4 | 89.2 | 90.1 | 93.5 | 52.1 | 87.1 | 84.8 | 66.4 |
| bert-sklearn | 83.7/83.9 | 90.2 | 88.6 | 92.32 | 58.1 | 89.7 | 86.8 | 64.6 |
Individual runs can be found here.
CoNLL-2003 Named Entity Recognition(NER)
NER results for CoNLL-2003 shared task
| | dev f1 | test f1 |
|---|---|---|
| BERT paper | 96.4 | 92.4 |
| bert-sklearn | 96.04 | 91.97 |
Span level stats on test:
processed 46666 tokens with 5648 phrases; found: 5740 phrases; correct: 5173.
accuracy: 98.15%; precision: 90.12%; recall: 91.59%; FB1: 90.85
LOC: precision: 92.24%; recall: 92.69%; FB1: 92.46 1676
MISC: precision: 78.07%; recall: 81.62%; FB1: 79.81 734
ORG: precision: 87.64%; recall: 90.07%; FB1: 88.84 1707
PER: precision: 96.00%; recall: 96.35%; FB1: 96.17 1623
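The overall span-level figures above follow directly from the three phrase counts on the first line of the report. A minimal sketch of that arithmetic, using the standard precision, recall, and F1 definitions (the `span_metrics` function name is hypothetical):

```python
def span_metrics(n_gold, n_found, n_correct):
    """Precision, recall, and F1 from span (phrase) counts."""
    precision = n_correct / n_found   # share of predicted phrases that are right
    recall = n_correct / n_gold       # share of gold phrases that were found
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Counts taken from the report: 5648 gold phrases, 5740 found, 5173 correct.
p, r, f1 = span_metrics(n_gold=5648, n_found=5740, n_correct=5173)
print("precision: {:.2f}%; recall: {:.2f}%; FB1: {:.2f}".format(
    100 * p, 100 * r, 100 * f1))
# prints: precision: 90.12%; recall: 91.59%; FB1: 90.85
```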
See ner_english notebook for a demo using 'bert-base-cased' model.
NCBI Biomedical NER
NER results using bert-sklearn with SciBERT and BioBERT on the NCBI disease corpus name recognition task.
Previous SOTA for this task is 87.34 for f1 on the test set.
| | test f1 (bert-sklearn) | test f1 (from papers) |
|---|---|---|
| BERT base cased | 85.09 | 85.49 |
| SciBERT basevocab cased | 88.29 | 86.91 |
| SciBERT scivocab cased | 87.73 | 86.45 |
| BioBERT pubmed_v1.0 | 87.86 | 87.38 |
| BioBERT pubmed_pmc_v1.0 | 88.26 | 89.36 |
| BioBERT pubmed_v1.1 | 87.26 | NA |
See ner_NCBI_disease_BioBERT_SciBERT notebook for a demo using SciBERT and BioBERT models.
See SciBERT paper and BioBERT paper for more info on the respective models.
Other examples
See IMDb notebook for a text classification demo on the Internet Movie Database review sentiment task.
See chunking_english notebook for a demo on syntactic chunking using the CoNLL-2000 chunking task data.
See ner_chinese notebook for a demo using 'bert-base-chinese' for Chinese NER.
tests
Run tests with pytest:
python -m pytest -sv tests/