
☞ ░ 前往老猿Python博客 ░ https://blog.youkuaiyun.com/LaoYuanPython
Joblib 是 scikit-learn 的依赖库之一,在机器学习领域应用广泛,主要用于高效序列化(保存/加载)和并行计算,特别适合处理大数据(如 NumPy 数组)和机器学习模型,在处理大型数值数据时比 Python 内置的 pickle 更高效。
一、安装
安装库和检测安装是否成功的指令如下:
pip install joblib
python -c "import joblib; print(joblib.__version__)"
执行检测指令后输出joblib库的版本号。如图:
joblib 通常与 numpy 和 scipy 一起使用,安装完成后,就可以使用 joblib 进行高效的并行计算和缓存功能了。
二. 主要功能
1、高效序列化(保存 & 加载数据/模型)
- 替代 pickle:joblib 在存储 NumPy 数组、SciPy 稀疏矩阵和机器学习模型时,比 pickle 更快、更节省内存
- 适用于 Scikit-learn 模型:Scikit-learn 官方推荐使用 joblib 保存和加载训练好的模型。
2、并行计算(Parallel Computing)
- 多线程/多进程支持:可以轻松并行化 for 循环,适用于 CPU 密集型任务(如网格搜索、交叉验证)
- 替代 multiprocessing:相比 Python 自带的 multiprocessing,joblib 的 API 更简单,且能更好地处理大数据。
三、核心API
1、使用dump和load保存和加载数据/模型
示例:
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
# 训练一个模型
data = load_iris()
model = RandomForestClassifier().fit(data.data, data.target)
# 保存模型
joblib.dump(model, "iris_model.pkl")
# 加载模型
loaded_model = joblib.load("iris_model.pkl")
print(loaded_model.predict([[5.1, 3.5, 1.4, 0.2]])) # 输出: [0]
2、并行计算(Parallel & delayed)
joblib.Parallel 用于并行执行任务,delayed 用于定义要并行化的函数。
Parallel适用于机器学习超参数调优、大数据批处理(如特征工程)、模拟实验(如蒙特卡洛模拟)等场景。
Parallel参数说明:
- n_jobs:并行任务数(-1 表示使用所有 CPU 核心)
- backend:并行后端(“threading” 或 “multiprocessing”,默认 “loky”),CPU密集型任务使用默认的 multiprocessing 后端,IO密集型任务使用 threading 后端
- verbose:是否显示进度(verbose=10 会打印进度)。
装饰器delayed使函数其延迟执行,delayed 不会立即执行函数,而是 返回一个延迟计算对象,记录要调用的函数及其参数。当配合 Parallel 使用时,这些延迟的任务会被分配到多个进程或线程中并行执行。
注意:
- 并行处理时,被调用的函数需要是可序列化的
- 避免在并行函数中使用全局变量
示例:
from joblib import Parallel, delayed
import math
# 基本并行计算示例
def square(x):
return x ** 2
results = Parallel(n_jobs=2)(delayed(square)(i) for i in range(10))
print(results) # 输出: [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
# 复杂函数并行
def process_item(x, power):
return math.pow(x, power), x * 2
results = Parallel(n_jobs=4)(
delayed(process_item)(i, 3) for i in range(5)
)
print(results) # 输出: [(0.0, 0), (1.0, 2), (8.0, 4), (27.0, 6), (64.0, 8)]
3、后端并行parallel_backend
parallel_backend 是一个上下文管理器(通过 joblib 提供),用于临时修改 scikit-learn 中并行任务的默认后端(如多进程、多线程或分布式计算)。它影响所有内部使用 joblib.Parallel 的代码(如 GridSearchCV、RandomizedSearchCV 或任何设置 n_jobs 的模型)。
案例:
from joblib import parallel_backend,Parallel,delayed
import time
def slow_function(x):
time.sleep(1)
return x * 2
# 使用上下文管理器控制并行后端
with parallel_backend('threading', n_jobs=2):
results = Parallel()(delayed(slow_function)(i) for i in range(4))
# 原本需要4秒,并行后约2秒完成
print(results) # 输出: [0, 2, 4, 6]
# 批处理模式减少通信开销
results = Parallel(batch_size=2)(delayed(slow_function)(i) for i in range(4))
上面代码中,with语句中的parallel_backend(‘threading’, n_jobs=2)设置并行计算的后端(backend)为 threading,并限制最多 2 个线程,这个语句影响with语句下面的Parallel,Parallel没有显式指定 n_jobs,它会使用 parallel_backend 设置的 n_jobs=2,因此,slow_function 会在 2 个线程上并行执行,4 个任务总共耗时约 2 秒。
最后一行的Parallel(batch_size=2)没有使用 parallel_backend,所以默认使用 joblib 的默认后端(通常是 loky,即多进程),batch_size=2 表示任务分批处理,每批次提交 2 个任务给工作进程,可以减少进程间通信(IPC)的开销,但不会改变并行度(n_jobs 默认是 1,除非显式设置)。
4、内存映射(mmap_mode)
适用于超大数组,允许部分加载数据而不全部读入内存:
import joblib,numpy
large_array = numpy.random.rand(100000, 100)
joblib.dump(large_array, "large_array.pkl", compress=3)
# 以内存映射方式加载(不全部读入内存)
loaded_array = joblib.load("large_array.pkl", mmap_mode="r")
print(loaded_array[:10]) # 只读取前10行
5、缓存计算结果(Memory)
避免重复计算,适合耗时的预处理:
from joblib import Memory
import time
def timing_decorator(func):
def wrapper(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
print(f"通过装饰器获取的{func.__name__} 执行耗时: {end_time - start_time:.6f}秒")
return result
return wrapper
# 设置缓存目录
memory = Memory("cachedir", verbose=0)
@timing_decorator
@memory.cache
def Squart(x):
start_time = time.time() # 记录开始时间
time.sleep(2) # 模拟耗时计算
end_time = time.time() # 记录结束时间
print(f"Squart函数执行耗时: {end_time - start_time:.6f}秒")
return x ** 2
# 第一次运行会计算并缓存
print("函数执行结果:",Squart(6)) # 耗时 ~2秒
# 第二次运行直接读取缓存
print("函数执行结果:",Squart(6)) # 耗时 ~0秒
执行结果如下:
Squart函数执行耗时: 2.000085秒
通过装饰器获取的Squart 执行耗时: 2.002621秒
函数执行结果: 36
通过装饰器获取的Squart 执行耗时: 0.000834秒
函数执行结果: 36
如果再重复执行一次,输出结果如下:
Squart函数执行耗时: 2.000085秒
通过装饰器获取的Squart 执行耗时: 2.002621秒
函数执行结果: 36
通过装饰器获取的Squart 执行耗时: 0.000834秒
函数执行结果: 36
可以看到,函数重复执行,只有第一次是执行了函数体的所有代码,后面重复执行都是从缓存中获取的数据,根本没有执行函数体的内容。
注意:
- 内存缓存对函数的输入参数进行哈希处理,确保参数是可哈希的
- 在 Python 中,可哈希(hashable) 的数据类型是指那些不可变(immutable) 并且能正确实现 hash() 和 eq() 方法的对象。可哈希的对象可以用作字典的键或集合的元素,而 不可哈希(unhashable) 的类型则不能
- 可哈希(Hashable)的数据类型:整数(int) 、浮点数(float)、布尔值(bool)、字符串(str)、元组(tuple)(要求其内所有元素都可哈希)、冻结集合(frozenset)、字节(bytes),元组可以哈希是因为元组是不可变的
- 不可哈希(Unhashable)的数据类型:列表(list)、字典(dict)、集合(set)、字节数组(bytearray)、自定义类(默认) 、函数(包括 lambda)
四、实战案例:机器学习管道
机器学习管道(Pipeline)是指 将数据预处理、模型训练、评估等步骤标准化为可复用的流程。
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from joblib import Parallel, Memory, dump, load,delayed
# 1. 准备数据
digits = load_digits()
X, Y = digits.data, digits.target
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2)
# 2. 设置缓存
memory = Memory('./cachedir', verbose=0)
# 3. 缓存预处理函数
@memory.cache
def preprocess_data(X):
return X / 16.0 # 简单归一化
X_train_processed = preprocess_data(X_train)
X_test_processed = preprocess_data(X_test)
# 4. 并行训练多个模型
def train_model(n_estimators):
model = RandomForestClassifier(n_estimators=n_estimators)
model.fit(X_train_processed, Y_train)
return model.score(X_test_processed, Y_test)
n_estimators_list = [10, 50, 100, 200]
scores = Parallel(n_jobs=4)(
delayed(train_model)(n) for n in n_estimators_list
)
#选择最佳模型
best_n = n_estimators_list[scores.index(max(scores))] # 动态选择
print("模型得分:", scores,"最佳模型的n_estimators:",best_n)
best_model = RandomForestClassifier(n_estimators=best_n)
best_model.fit(X_train_processed, Y_train)
# 5. 保存最佳模型
dump(best_model, 'best_model.joblib')
# 6. 后续加载使用
loaded_model = load('best_model.joblib')
print("测试准确率:", loaded_model.score(X_test_processed, Y_test))
执行时输出如下:
模型得分: [0.9527777777777777, 0.9722222222222222, 0.9722222222222222, 0.975] 最佳模型的n_estimators: 200
测试准确率: 0.975
上段代码中:
- sklearn.datasets.load_digits() 是 Scikit-learn 提供的一个内置函数,用于加载经典的手写数字数据集(Digits Dataset)
- RandomForestClassifier 是 Scikit-learn 中基于随机森林(Random Forest)算法的分类器,属于集成学习(Ensemble Learning)方法。它通过构建多棵决策树并综合它们的预测结果来提高分类准确率和鲁棒性,其中参数n_estimators为决策树的数量
- train_test_split 是 Scikit-learn 中用于将数据集随机划分为训练集和测试集的函数,是机器学习流程中的关键步骤
- train_model:接收一个参数 n_estimators(决策树的数量),用该参数初始化随机森林模型,并在训练集(X_train_processed, Y_train)上训练,返回模型在测试集(X_test_processed, Y_test)上的准确率(score),方便后续评估不同参数的准确率
上段代码符合机器学习管道的核心特征:
-
模块化步骤
数据准备 → 预处理 → 模型训练 → 评估 → 持久化
每个步骤职责明确,通过函数封装(如 preprocess_data、train_model)。 -
自动化缓存
预处理结果自动缓存(@memory.cache),避免重复计算,提升管道效率。 -
并行化处理
超参数搜索通过 Parallel 并行化,加速模型选择。 -
端到端流程
从原始数据 (load_digits) 到最终模型部署 (load),形成完整闭环。
在这段代码中,如果只是运行这个示例程序,先 dump 后 load 没什么作用,但作为机器学习管道的示例,这样的作用是为了实现 模型的持久化(Serialization)与复用,这是机器学习工作流中的关键步骤。
四、常见问题
1、joblib 和 pickle 怎么选?
如果是 Scikit-learn 模型、NumPy 数组,优先用 joblib。
如果是 普通 Python 对象(如字典、列表),可以用 pickle。
2、并行计算时 n_jobs 怎么设置?
- n_jobs=-1:使用所有 CPU 核心(推荐)
- n_jobs=1:单线程(调试时用)
- n_jobs=4:手动指定 4 个线程。
3、为什么 joblib.load 报错?
可能原因:
- 文件路径错误
- 存储的 Python 版本和加载的版本不一致(如 Python 3.7 保存,Python 3.10 加载)
- 文件损坏(可尝试重新保存)。
五、小结
joblib 是一个用于 Python 的轻量级流水线工具库,是机器学习工程师和数据科学家的必备工具,特别适合模型持久化、并行加速、大数据处理、快速磁盘缓存。本文介绍了joblib的功能、安装,并结合案例介绍了主要的API能力。
更多人工智能知识学习过程中可能遇到的疑难问题及解决办法请关注专栏《零基础机器学习入门》及付费专栏《机器学习疑难问题集》后续的文章。
写博不易,敬请支持:
如果阅读本文于您有所获,敬请点赞、评论、收藏,谢谢大家的支持!
关于老猿的付费专栏
- 付费专栏《https://blog.youkuaiyun.com/laoyuanpython/category_9607725.html 使用PyQt开发图形界面Python应用》专门介绍基于Python的PyQt图形界面开发基础教程,对应文章目录为《 https://blog.youkuaiyun.com/LaoYuanPython/article/details/107580932 使用PyQt开发图形界面Python应用专栏目录》;
- 付费专栏《https://blog.youkuaiyun.com/laoyuanpython/category_10232926.html moviepy音视频开发专栏 )详细介绍moviepy音视频剪辑合成处理的类相关方法及使用相关方法进行相关剪辑合成场景的处理,对应文章目录为《https://blog.youkuaiyun.com/LaoYuanPython/article/details/107574583 moviepy音视频开发专栏文章目录》;
- 付费专栏《https://blog.youkuaiyun.com/laoyuanpython/category_10581071.html OpenCV-Python初学者疑难问题集》为《https://blog.youkuaiyun.com/laoyuanpython/category_9979286.html OpenCV-Python图形图像处理 》的伴生专栏,是笔者对OpenCV-Python图形图像处理学习中遇到的一些问题个人感悟的整合,相关资料基本上都是老猿反复研究的成果,有助于OpenCV-Python初学者比较深入地理解OpenCV,对应文章目录为《https://blog.youkuaiyun.com/LaoYuanPython/article/details/109713407 OpenCV-Python初学者疑难问题集专栏目录 》
- 付费专栏《https://blog.youkuaiyun.com/laoyuanpython/category_10762553.html Python爬虫入门 》站在一个互联网前端开发小白的角度介绍爬虫开发应知应会内容,包括爬虫入门的基础知识,以及爬取优快云文章信息、博主信息、给文章点赞、评论等实战内容。
前两个专栏都适合有一定Python基础但无相关知识的小白读者学习,第三个专栏请大家结合《https://blog.youkuaiyun.com/laoyuanpython/category_9979286.html OpenCV-Python图形图像处理 》的学习使用。
对于缺乏Python基础的同仁,可以通过老猿的免费专栏《https://blog.youkuaiyun.com/laoyuanpython/category_9831699.html 专栏:Python基础教程目录)从零开始学习Python。
如果有兴趣也愿意支持老猿的读者,欢迎购买付费专栏。