利用MLPack插件在DuckDB中机器学习

先安装mlpack插件

D load httpfs;
D INSTALL mlpack FROM community;
100% ▕██████████████████████████████████████▏ (00:00:06.43 elapsed)  

鸢尾花数据集(Iris Dataset)是机器学习中最经典的入门数据集之一。

鸢尾花数据集包含了三种鸢尾花(Setosa、Versicolor、Virginica)每种花的 4 个特征:花萼长度、花萼宽度、花瓣长度和花瓣宽度。

接下来我们的任务是基于这些特征来预测鸢尾花的种类。
示例脚本中有一处错误,mlpack_adaboost_train函数误写作mlpack_adaboost,已改正。

load httpfs;
load mlpack;
.timer on


-- Perform adaBoost (using weak learner 'Perceptron' by default)
-- Read 'features' into 'X', 'labels' into 'Y', use optional parameters
-- from 'Z', and prepare model storage in 'M'
CREATE TABLE X AS SELECT * FROM read_csv("https://eddelbuettel.github.io/duckdb-mlpack/data/iris.csv");
CREATE TABLE Y AS SELECT * FROM read_csv("https://eddelbuettel.github.io/duckdb-mlpack/data/iris_labels.csv");
CREATE TABLE Z (name VARCHAR, value VARCHAR);
INSERT INTO Z VALUES ('iterations', '50'), ('tolerance', '1e-7');
CREATE TABLE M (key VARCHAR, json VARCHAR);

-- Train model for 'Y' on 'X' using parameters 'Z', store in 'M'
CREATE TEMP TABLE A AS SELECT * FROM mlpack_adaboost_train("X", "Y", "Z", "M");

-- Count by predicted group
SELECT COUNT(*) as n, predicted FROM A GROUP BY predicted;

-- Model 'M' can be used to predict
CREATE TABLE N (x1 DOUBLE, x2 DOUBLE, x3 DOUBLE, x4 DOUBLE);
-- inserting approximate column mean values
INSERT INTO N VALUES (5.843, 3.054, 3.759, 1.199);
-- inserting approximate column mean values, min values, max values
INSERT INTO N VALUES (5.843, 3.054, 3.759, 1.199), (4.3, 2.0, 1.0, 0.1), (7.9, 4.4, 6.9, 2.5);
-- and this predict one element each
SELECT * FROM mlpack_adaboost_pred("N", "M");

执行结果如下:

root@66d4e20ec1d7:/par# ./duckdb141 mlpack
DuckDB v1.4.1 (Andium) b390a7c376
Enter ".help" for usage hints.
D .read ml.txt
Run Time (s): real 1.646 user 0.012000 sys 0.004000
Run Time (s): real 2.675 user 0.008000 sys 0.004000
Run Time (s): real 0.042 user 0.000000 sys 0.000000
Run Time (s): real 0.042 user 0.000000 sys 0.000000
Run Time (s): real 0.041 user 0.000000 sys 0.000000
Misclassified: 1
Run Time (s): real 0.118 user 0.192000 sys 0.000000
┌───────┬───────────┐
│   n   │ predicted │
│ int64 │   int32   │
├───────┼───────────┤
│    500 │
│    491 │
│    512 │
└───────┴───────────┘
Run Time (s): real 0.001 user 0.000000 sys 0.000000
Run Time (s): real 0.040 user 0.000000 sys 0.000000
Run Time (s): real 0.042 user 0.004000 sys 0.000000
Run Time (s): real 0.041 user 0.000000 sys 0.000000
┌───────────┐
│ predicted │
│   int32   │
├───────────┤
│         1 │
│         1 │
│         0 │
│         2 │
└───────────┘
Run Time (s): real 0.003 user 0.004000 sys 0.000000

查看表中数据

D from x;
┌─────────┬─────────┬─────────┬─────────┐
│ column0 │ column1 │ column2 │ column3 │
│ doubledoubledoubledouble  │
├─────────┼─────────┼─────────┼─────────┤
│     5.13.51.40.2 │
│      ·  │      ·  │      ·  │      ·  │
│     5.93.05.11.8 │
├─────────┴─────────┴─────────┴─────────┤
│ 150 rows (40 shown)         4 columns │
└───────────────────────────────────────┘
Run Time (s): real 0.146 user 0.016000 sys 0.000000
D from y;
┌────────────┐
│  column0   │
│   int64    │
├────────────┤
│          0 │
│          · │
│          2 │
├────────────┤
│  150 rows  │
│ (40 shown) │
└────────────┘
Run Time (s): real 0.001 user 0.000000 sys 0.000000
D from z;
┌────────────┬─────────┐
│    name    │  value  │
│  varcharvarchar │
├────────────┼─────────┤
│ iterations │ 50      │
│ tolerance  │ 1e-7    │
└────────────┴─────────┘
Run Time (s): real 0.001 user 0.000000 sys 0.000000
D from m;
┌─────────┬──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐
│   key   │                                                                                       json                                                                                       │
│ varcharvarchar                                                                                      │
├─────────┼──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ model   │ {\n    "x": {\n        "cereal_class_version": 1,\n        "numClasses": 3,\n        "tolerance": 1e-7,\n        "maxIterations": 50,\n        "alpha": [\n            1.68364…  │
└─────────┴──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘
Run Time (s): real 0.001 user 0.000000 sys 0.000000
D from n;
┌────────┬────────┬────────┬────────┐
│   x1   │   x2   │   x3   │   x4   │
│ doubledoubledoubledouble │
├────────┼────────┼────────┼────────┤
│  5.8433.0543.7591.199 │
│  5.8433.0543.7591.199 │
│    4.32.01.00.1 │
│    7.94.46.92.5 │
└────────┴────────┴────────┴────────┘
Run Time (s): real 0.001 user 0.004000 sys 0.000000
D from a;
┌────────────┐
│ predicted  │
│   int32    │
├────────────┤
│          0 │
│          · │
│          2 │
├────────────┤
│  150 rows  │
│ (40 shown) │
└────────────┘
Run Time (s): real 0.001 user 0.000000 sys 0.000000

因为数据集很小,才150行,虽然迭代50次,训练模型和预测都非常快,模型的精度也还可以,只有1个分类错误。

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值