一文讲明白大模型显存占用(只考虑单卡)

转载自 | 知乎  作者 | 然荻

1.告诉你一个模型的参数量,你要怎么估算出训练和推理时的显存占用?

2.Lora相比于全参训练节省的显存是哪一部分?Qlora相比Lora呢?

3.混合精度训练的具体流程是怎么样的?

这是我曾在面试中被问到的问题,为了巩固相关的知识,打算系统的写一篇文章,帮助自己复习备战秋招的同时,希望也能帮到各位小伙伴。这篇文章将围绕大模型在单卡训练或推理时的显存占用进行系统学习分析,其中有的知识点可能不会涉及太过深入点到为止(因为我也不会),但尽量保证整个读下来逻辑通畅,通俗易懂(只有小白最懂小白!)。

1.数据精度

想要计算显存,从“原子”层面来看,就需要知道我们的使用数据的精度,因为精度代表了数据存储的方式,决定了一个数据占多少bit。我们都知道:

1 byte = 8 bits

1 KB = 1,024 bytes

1 MB = 1,024 KB

1 GB = 1,024 MB

由此可以明白,一个含有1G参数的模型,如果每一个参数都是32bit(4byte),那么直接加载模型就会占用4x1G的显存。

1.1常见的几种精度类型

个人认为只需掌握下图几个常见的数据类型就好,对于更多的精度类型都是可以做到触类旁通发,图源英伟达安培架构白皮书:

图片

各种精度的数据结构

可以非常直观地看到,浮点数主要是由符号位(sign)、指数位(exponent)和小数位(mantissa)三部分组成。符号位都是1位(0表示正,1表示负),指数位影响浮点数范围,小数位影响精度。其中TF32并不是有32bit,只有19bit不要记错了。BF16指的是Brain Float 16,由Google Brain团队提出。

1.2 具体计算例子

我硕士话,讲太多不如一个形象的图片或者例子来得直接,下面我们将通过一个例子来深入理解如何通过这三个部分来得到我们最终的数据:我以BF16,如今业界用的最广泛的精度类型来举个栗子,下面的数完全是我用克劳德大哥随机画的:

  • 题目:

图片

随机生成的BF16精度数据

- 先给出具体计算公式:

图片

- 然后step by step地分析(不是,怎么还对自己使用上Cot了)

符号位Sign = 1,代表是负数

指数位Exponent = 17,中间一坨是

图片

 

小数位Mantissa = 3,后面那一坨是 

图片

  • 最终结果

三个部分乘起来就是最终结果 -8.004646331359449e-34

  • 注意事项

中间唯一需要注意的地方就是指数位是的全0和全1状态是特殊情况,不能用公式,如果想要深入了解可以看这个博客: 彻底搞懂float16与float32的计算方式-优快云博客 如果感兴趣想更加深入了解如何从FP32转换为BF

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值