算一算大模型显存占用
大模型在部署的时候,肯定离不开一个概念,叫显存占用。部署大模型,很多时候难点就在于它的大,之因为大,所以要么就根本跑不起来 (显存溢出),要么就推理很慢 (模型太大了)。所以对于大模型的推理优化,它跟 “小模型”的推理优化存在明显的不同之处,当然这里重要不是说这个,这里主要来看怎么计算大模型的显存占用。
比如对于目前比较流行的大模型 LLama2 来说,它就有7B、13B、70B三个版本。B这个单位是十亿的意思,而 M 这个单位是百万的意思。所以像LLama2这种大模型,就可以称之为十亿、百亿级的大模型了。
然后还要清楚一个概念是模型精度,对于深度学习模型来说,一般有的精度就是float32、float16、int8、int4这些,后面的int8、int4这些低精度基本就是专门给部署的时候推理加速用的。如一个float32会占用4个字节32个比特,往后就减半减半,如int8是1字节占用8比特,int4 的占用空间会更加小。
这样有了参数量和模型精度,就可以计算出模型的显存占用了。
比如以 LLama2-13B 为例。
对于float32精度:
$$ \frac{13 \times 10^9 \times 4}{1024^3} \approx 48.42 \text{ G} $$
对于float16精度:
$$ \frac{13 \times 10^9 \times 2}{1024^3} \approx 24.21 \text{ G} $$
再来算一下 LLama2-7B 模型的显存占用。
对于float32精度:
$$ \frac{7 \times 10^9 \times 4}{1024^3} \approx 26.08 \text{ G} $$
对于float16精度显存降一半:
$$ \frac{7 \times 10^9 \times 2}{1024^3} \approx 13 \text{ G} $$
对于int8精度显存再降一半:
$$ \frac{7 \times 10^9 \times 1}{1024^3} \approx 6.5 \text{ G} $$
对于int4精度显存再再降一半:
$$ \frac{7 \times 10^9 \times 0.5}{1024^3} \approx 3.2 \text{ G} $$
上面的显存计算只是适用于模型前向推理,不适用于模型训练,因为训练过程中还会受梯度、优化器参数、bs等影响,一般经验来说,训练时的显存占用会是推理时的好多倍,甚至十几倍。
当然,上述的推理显存占用是理论值,实际肯定会更加多一些的,所以需要预留一些余量。打个比方,比如实测LLama2-13B,上面计算出来差不多是要48.21G,实测就需要52G的样子。当然,这种计算方式也适用于CNN模型前向推理的显存占用。