技术背景
笔者在执行一个Jax的任务中,又发现了一个奇怪的问题,就是明明只分配了很小的矩阵空间,但是在多次的任务执行之后,显存突然就爆了。而且此时已经按照Jax的官方说明配置了 XLA_PYTHON_CLIENT_PREALLOCATE 这个参数为 false ,也就是不进行显存的预分配(默认会分配90%的显存空间以供使用)。然后在网上找到了一些类似的问题,比如参考链接中的1、2、3、4,都是在一些操作后发现未释放显存,这里提供一个实例问题和处理的思路,如果有更好的方案欢迎大家在评论区留言。
问题复现
在未执行任何GPU的任务时,我们可以看到此时 nvidia-smi 的输出如下:
Tue Dec 14 16:14:32 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.42.01 Driver Version: 470.42.01 CUDA Version: 11.4 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Quadro RTX 4000 On | 00000000:03:00.0 On | N/A |
| 30% 43C P8 20W / 125W | 1260MiB / 7979MiB | 10% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 1 Quadro RTX 4000 On | 00000000:A6:00.0 Off | N/A |
| 30% 34C P8 7W / 125W | 10MiB / 7982MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================

在执行Jax任务时,作者遇到显存未被释放的问题,即使设置了不预分配显存。通过问题复现和解决思路,作者发现即使删除对象,显存仍被占用。为了解决这个问题,采用了进程管理的方式,通过multiprocessing创建子进程执行任务,确保显存及时释放,避免显存占用堆积。此方法在不影响返回值的情况下,有效控制了GPU显存的占用。
最低0.47元/天 解锁文章
2509

被折叠的 条评论
为什么被折叠?



