机器学习_LightGBM callback示例

lightgbm在train的时候有callback的接口,我们需要将训练过程的损失下降情况进行记录就需要这个接口。本文笔者就是以记录训练迭代过程的损失为出发点,写一个简单的lightgbm中callback的使用方法。

一、callbacks接口

callbacks参数输入要求

(function) train: (params:
… callbacks: List[(…) -> Any] | None = None) -> Booster
callbacks : list of callable, or None, optional (default=None)
入参是一个list,list中的对象都是callback方法。callbacks在官方文档中主要是四种方法

  • early_stopping
    • 停止迭代
    • lightgbm.early_stopping(stopping_rounds, first_metric_only=False, verbose=True, min_delta=0.0)
  • log_evaluation
    • 记录迭代过程的指标, 可以在日志中输出
    • lightgbm.log_evaluation(period=1, show_stdv=True)
  • record_evaluation(eval_result)
    • 把迭代过程指标记录到输入的空字典中
    • lightgbm.record_evaluation(eval_result) ;eval_result 可以为一个空字典eval_result = {}
  • reset_parameter(**kwargs)
    • 每次迭代更新数据
    • List of parameters for each boosting round or a callable that calculates the parameter in terms of current number of round

示例

eval_result = {}
lgb_model = lgb.train(lgb_param, train_set=tr_lgb_dt , valid_sets=[tr_lgb_dt, te_lgb_dt], 
          verbose_eval=20,
          callbacks=[lgb.log_evaluation, lgb.early_stopping(50, first_metric_only=True), lgb.record_evaluation(eval_result)]
          )

二、完整iris案例


from sklearn.datasets import load_iris
import lightgbm as lgb
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
plt.style.use('ggplot')
import warnings
warnings.filterwarnings('ignore')


iris = load_iris()
df = pd.DataFrame(data=iris.data, columns=[i[:-5].replace(' ','_') for i in iris.feature_names])
df['target'] = iris.target

tr_x, te_x, tr_y, te_y = train_test_split(df.drop(columns='target'), df['target'], test_size=0.2)
tr_lgb_dt = lgb.Dataset(tr_x, label=tr_y.values)
te_lgb_dt = lgb.Dataset(te_x, label=te_y.values)


lgb_param = {
    'objective': 'multiclass',
    'metric': ['multi_logloss', 'multi_error'],
    'num_class': 3,
    'n_jobs': 4,
    'num_iterations': 300,
    'learning_rate': 0.02,
    'max_depth': 4,
    'lambda_l2': 0.8,
    'verbose': -1
}
eval_result={}
lgb_model = lgb.train(lgb_param, train_set=tr_lgb_dt , valid_sets=[tr_lgb_dt, te_lgb_dt], 
          verbose_eval=20,
          callbacks=[lgb.log_evaluation, lgb.early_stopping(50, first_metric_only=True), lgb.record_evaluation(eval_result)]
          )


# plot loss
plt.title('train_loss')
for data_name, metric_res in eval_result.items():
    for metric_name, log_ in metric_res.items():
        plt.plot(log_, label = f'{data_name}-{metric_name}', 
                color='steelblue' if 'train' in data_name else 'darkred', 
                linestyle=None if 'train' in data_name else '-.',
                alpha=0.7)

plt.legend()
plt.show()

在这里插入图片描述

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Scc_hy

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值