【与模型打交道 · 第 2 篇】PyTorch 显存 OOM 频繁?4 个策略让你的 GPU 稳定运行 AI 大模型

在这里插入图片描述

与模型打交道 · 第 2 篇 | 预估阅读:10 分钟


上一篇换了 4 家 AI 模型,代码只动了 1 行——这个架构设计让老板随便折腾

凌晨 3 点的电话

小禾以为把 LLM 适配层搞定后,可以睡个安稳觉了。

直到凌晨 2:47,手机响了。

02:47 - 告警邮件:服务响应超时
02:48 - 告警短信:健康检查连续失败
02:49 - 电话铃声:老板的来电...

小禾睡眼惺忪地爬起来,打开电脑。

ssh production-server
nvidia-smi

屏幕上的数字让他彻底清醒了:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 535.154.05   Driver Version: 535.154.05   CUDA Version: 12.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  NVIDIA RTX 4090     On   | 00000000:01:00.0 Off |                  Off |
| 89%   82C    P2   321W / 450W |  24564MiB / 24564MiB |    100%      Default |
+-------------------------------+----------------------+----------------------+

显存:100%。

温度:82°C。

GPU 利用率:100%。

服务日志里全是:

CUDA out of memory. Tried to allocate 512.00 MiB (GPU 0; 24.00 GiB total capacity; 23.12 GiB already allocated)

小禾叹了口气,重启服务后一切恢复正常。

凌晨 3 点 15 分,小禾躺回床上,盯着天花板,开始思考人生。

这种事情,不能再来第二次了。


问题出在哪?

第二天一早,小禾开始分析昨晚的问题。

打开监控数据,他发现了一个有趣的规律:

23:00  显存使用:8GB / 24GB   ✓
23:30  显存使用:12GB / 24GB  ✓
00:00  显存使用:16GB / 24GB  !
00:30  显存使用:20GB / 24GB  !
01:00  显存使用:23GB / 24GB  !
01:30  显存使用:24GB / 24GB  💥 OOM

显存一直在涨,从没降下来过。

小禾检查了代码,发现了问题:

# 问题代码 1:张量没有及时释放
def generate_image(prompt: str):
    # 生成图片
    image_tensor = model(prompt)

    # 转成 PIL 图片
    pil_image = to_pil(image_tensor)

    return pil_image
    # image_tensor 还在 GPU 上!
    # 虽然函数返回了,但 Python 的垃圾回收不一定立即执行
# 问题代码 2:中间结果累积
results = []
for prompt in prompts:
    output = model(prompt)
    results.append(output)  # GPU 张量不断累积
    # 100 个请求后,显存就满了
# 问题代码 3:并发过高
# 10 个用户同时请求
# 每个请求需要 2GB 显存
# 10 × 2GB = 20GB
# 加上模型本身 6GB
# 总计 26GB > 24GB
# 💥

小禾总结了三个杀手:

  1. 张量没释放:生成完的数据还占着显存
  2. 中间变量累积:循环里每次生成都在堆积
  3. 并发无限制:来多少请求接多少,显存不够也硬上

第一道防线:限制并发

最简单的办法,不让太多请求同时跑。

小禾加了一个信号量:

import asyncio
from fastapi import FastAPI, HTTPException

# 最多允许 2 个请求同时使用 GPU
gpu_semaphore = asyncio.Semaphore(2)

@app.post("/generate")
async def generate(request: GenerateRequest):
    """生成图片接口"""

    # 检查当前是否已经满载
    if gpu_semaphore.locked() and gpu_semaphore._value == 0:
        # GPU 忙不过来,直接告诉用户
        raise HTTPException(
            status_code=503,
            detail="服务器忙,请稍后再试",
            headers={"Retry-After": "30"}
        )

    async with gpu_semaphore:
        # 这里最多只有 2 个请求在跑
        result = await model.generate(request.prompt)
        return result

这招很简单,但效果立竿见影。

就像餐厅门口放了个叫号机,座位满了就外面等着,不会把厨房挤爆。


第二道防线:显存门槛检查

光限制并发还不够。

万一一个请求要生成超大图片,一个请求就把显存吃满呢?

小禾加了个显存检查:

import torch

# 至少要有 4GB 空闲才接受新请求
MIN_FREE_MEMORY_GB = 4.0

def get_gpu_memory_info():
    """获取 GPU 显存信息"""
    if not torch.cuda.is_available():
        return None

    device = torch.cuda.current_device()
    total = torch.cuda.get_device_properties(device).total_memory
    reserved = torch.cuda.memory_reserved(device)
    allocated = torch.cuda.memory_allocated(device)

    return {
        "total_gb": total / 1024**3,
        "reserved_gb": reserved / 1024**3,
        "allocated_gb": allocated / 1024**3,
        "free_gb": (total - reserved) / 1024**3,
    }

@app.post("/generate")
async def generate(request: GenerateRequest):
    """生成图片接口"""

    # 检查显存余量
    memory_info = get_gpu_memory_info()
    if memory_info and memory_info["free_gb"] < MIN_FREE_MEMORY_GB:
        raise HTTPException(
            status_code=503,
            detail=f"显存不足,需要 {MIN_FREE_MEMORY_GB}GB,当前只有 {memory_info['free_gb']:.1f}GB",
        )

    async with gpu_semaphore:
        result = await model.generate(request.prompt)
        return result

现在不只是看并发数,还要看实际的显存情况。

就像银行取款,不只是看排队人数,还要看 ATM 机里还有没有钱。


第三道防线:自动清理机制

预防是一方面,及时清理也很重要。

小禾写了个清理函数,每次生成后都调用:

import gc
import torch

def cleanup_gpu():
    """清理 GPU 资源"""

    # 第一步:触发 Python 垃圾回收
    # 这会清理所有没有引用的对象
    gc.collect()

    # 第二步:清理 PyTorch 的显存缓存
    if torch.cuda.is_available():
        # empty_cache() 释放缓存的显存
        # PyTorch 会缓存显存以加速后续分配
        # 但长期运行后,缓存会越来越大
        torch.cuda.empty_cache()

        # synchronize() 等待所有 CUDA 操作完成
        # 确保之前的操作真的结束了
        torch.cuda.synchronize()

@app.post("/generate")
async def generate(request: GenerateRequest):
    """生成图片接口"""

    try:
        result = await model.generate(request.prompt)
        return result
    finally:
        # 无论成功还是失败,都要清理
        cleanup_gpu()

这里有个关键:用 finally 而不是直接写在后面。

因为如果生成过程出错了,没有 finally 的话清理代码不会执行。

就像去完厕所要冲水,不管用得舒不舒服都得冲。


第四道防线:上下文管理器

小禾觉得每个接口都写 try...finally 太啰嗦,于是封装了一个上下文管理器:

from contextlib import contextmanager

@contextmanager
def gpu_memory_guard(min_free_gb: float = 2.0):
    """
    GPU 显存守卫

    使用方式:
    with gpu_memory_guard(min_free_gb=4.0):
        result = model.generate(prompt)
    """

    # 进入前:检查并尝试清理
    memory_info = get_gpu_memory_info()
    if memory_info:
        free = memory_info["total_gb"] - memory_info["reserved_gb"]
        if free < min_free_gb:
            # 先清理一波试试
            cleanup_gpu()
            # 再检查一次
            memory_info = get_gpu_memory_info()
            free = memory_info["total_gb"] - memory_info["reserved_gb"]
            if free < min_free_gb:
                raise RuntimeError(f"显存不足: {free:.1f}GB < {min_free_gb}GB")

    try:
        yield  # 执行业务代码
    finally:
        # 退出时:无论如何都清理
        cleanup_gpu()

# 使用示例
async def generate_image_safe(prompt: str) -> Image:
    """安全的图片生成"""
    with gpu_memory_guard(min_free_gb=4.0):
        tensor = model(prompt)
        # 把结果搬到 CPU 再返回
        # 这样 GPU 上的张量就可以被释放了
        image = tensor.cpu().numpy()
        return Image.fromarray(image)

现在只要 with gpu_memory_guard() 一包,进门检查显存,出门自动清理。

代码简洁,逻辑清晰。


优雅处理 OOM

说了这么多预防措施,但世事无常,OOM 还是可能发生。

小禾设计了一个优雅降级方案:

@app.post("/generate")
async def generate(request: GenerateRequest):
    """生成图片,支持自动降级"""

    # 第一次尝试:原始尺寸
    try:
        return await _do_generate(
            request.prompt,
            width=request.width,
            height=request.height
        )
    except torch.cuda.OutOfMemoryError:
        app_logger.warning(
            f"OOM!尝试降级... 原尺寸: {request.width}x{request.height}"
        )
        cleanup_gpu()

    # 第二次尝试:尺寸减半
    try:
        result = await _do_generate(
            request.prompt,
            width=request.width // 2,
            height=request.height // 2
        )
        return {
            **result,
            "warning": "因显存不足,图片尺寸已自动缩小",
            "original_size": f"{request.width}x{request.height}",
            "actual_size": f"{request.width//2}x{request.height//2}",
        }
    except torch.cuda.OutOfMemoryError:
        app_logger.error("二次降级仍然 OOM")
        cleanup_gpu()

    # 第三次尝试:最小尺寸
    try:
        result = await _do_generate(
            request.prompt,
            width=512,
            height=512
        )
        return {
            **result,
            "warning": "因显存严重不足,已使用最小尺寸",
            "actual_size": "512x512",
        }
    except torch.cuda.OutOfMemoryError:
        # 实在不行了,返回错误
        raise HTTPException(
            status_code=503,
            detail="GPU 显存已耗尽,请稍后再试或减小图片尺寸"
        )

这个设计的好处是:

  1. 用户不会直接看到 500 错误,总能拿到结果
  2. 结果里带了告警信息,用户知道发生了什么
  3. 有条不紊地尝试,而不是一失败就崩溃

就像餐厅没有牛排了,服务员会问:“牛排没了,鸡排可以吗?”

而不是直接让你饿着走。


监控和告警

预防做了,降级也做了,但小禾还是不放心。

他需要一个监控系统,能在问题变严重之前提醒他:

import asyncio
from datetime import datetime

class GPUMonitor:
    """GPU 显存监控器"""

    def __init__(
        self,
        warning_threshold: float = 0.8,   # 80% 警告
        critical_threshold: float = 0.95,  # 95% 严重
        check_interval: int = 30           # 30 秒检查一次
    ):
        self.warning_threshold = warning_threshold
        self.critical_threshold = critical_threshold
        self.check_interval = check_interval
        self._running = False

    async def start(self):
        """启动监控"""
        self._running = True
        app_logger.info("🔍 GPU 监控已启动")

        while self._running:
            await self._check()
            await asyncio.sleep(self.check_interval)

    def stop(self):
        """停止监控"""
        self._running = False

    async def _check(self):
        """执行一次检查"""
        memory_info = get_gpu_memory_info()
        if not memory_info:
            return

        usage_ratio = memory_info["reserved_gb"] / memory_info["total_gb"]
        timestamp = datetime.now().strftime("%H:%M:%S")

        if usage_ratio > self.critical_threshold:
            # 🚨 严重告警
            app_logger.critical(
                f"[{timestamp}] 🚨 GPU 显存严重不足!"
                f"使用率: {usage_ratio:.1%} "
                f"({memory_info['reserved_gb']:.1f}GB / {memory_info['total_gb']:.1f}GB)"
            )
            # 发送告警通知(邮件、钉钉等)
            await self._send_alert("critical", usage_ratio, memory_info)

        elif usage_ratio > self.warning_threshold:
            # ⚠️ 警告
            app_logger.warning(
                f"[{timestamp}] ⚠️ GPU 显存使用较高: {usage_ratio:.1%}"
            )

        else:
            # ✓ 正常
            app_logger.debug(f"[{timestamp}] ✓ GPU 显存正常: {usage_ratio:.1%}")

    async def _send_alert(self, level: str, usage: float, info: dict):
        """发送告警通知"""
        # 这里可以集成钉钉、Slack、邮件等
        pass

# 在应用启动时启动监控
@app.on_event("startup")
async def startup():
    monitor = GPUMonitor()
    asyncio.create_task(monitor.start())

现在每 30 秒检查一次显存,80% 警告,95% 严重告警。

再也不用等到凌晨 3 点才知道出问题了。


健康检查端点

运维同学需要一个接口来检查 GPU 状态,小禾加了个健康检查端点:

@app.get("/health/gpu")
async def gpu_health():
    """GPU 健康检查接口"""

    memory_info = get_gpu_memory_info()

    if not memory_info:
        return {
            "status": "no_gpu",
            "message": "没有检测到 GPU"
        }

    usage = memory_info["reserved_gb"] / memory_info["total_gb"]

    # 判断状态
    if usage < 0.8:
        status = "healthy"
    elif usage < 0.95:
        status = "warning"
    else:
        status = "critical"

    return {
        "status": status,
        "gpu": {
            "total_gb": round(memory_info["total_gb"], 2),
            "used_gb": round(memory_info["reserved_gb"], 2),
            "free_gb": round(memory_info["free_gb"], 2),
            "usage_percent": round(usage * 100, 1),
        },
        "thresholds": {
            "warning": "80%",
            "critical": "95%"
        }
    }

现在运维可以把这个接口配到监控系统里,一目了然。


收益总结

小禾算了笔账:

指标改造前改造后
凌晨被叫醒次数每周 2-3 次0 次
OOM 导致服务宕机经常极少
用户看到 500 错误频繁几乎没有
问题发现时间凌晨 3 点用户投诉提前 10 分钟告警
显存使用效率乱七八糟稳定在 60-70%

最重要的是:小禾终于能睡个安稳觉了。


小禾的感悟

那个凌晨 3 点的电话,
让我明白一个道理:

GPU 不是取款机,
你想取多少就取多少。

它更像一个游泳池,
容量有限,
来的人多了就会挤。

预防 > 治疗,
监控 > 救火。

与其半夜爬起来重启,
不如提前做好限制和告警。

资源管理不是可选项,
是必修课。

现在每次看到 nvidia-smi,
我都会心里默念:
"别爆、别爆、别爆..."

但有了这套方案,
我终于可以安心睡觉了。

小禾关掉监控面板,看了眼 GPU 使用率:62%。

稳得一批。


下一篇预告:一个请求加载模型要 30 秒,用户早跑了

模型预加载,让首次请求不再等待。

敬请期待。

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

程序员义拉冠

你的鼓励是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值