A GPU-grabbing script for shared multi-user servers

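The idea: a monitor loop polls every GPU on the node with pynvml, and whenever a card's memory usage falls below 50%, it spawns a daemon thread that pins a large placeholder tensor on that card with PyTorch and holds it until the process is killed.
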
import time
import pynvml
import torch
from threading import Thread

# Global set tracking which GPUs this script currently holds
occupied_gpus = set()


def occupy_gpu(gpu_id, target_percent=95):
    """持续占用指定GPU的显存到目标百分比"""
    try:
        # Bind this thread to the target device
        torch.cuda.set_device(gpu_id)
        device = torch.device(f"cuda:{gpu_id}")

        # Query total and currently free memory via NVML
        # (NVML init is reference-counted, so calling it again per thread is harmless)
        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
        mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        total_mem = mem_info.total

        # Aim for target_percent of total memory with a 5% safety margin, but
        # never request more than is actually free, otherwise the allocation
        # OOMs whenever another process already holds part of the card
        target_bytes = int(total_mem * (target_percent / 100) * 0.95)
        allocate_mem = min(target_bytes, int(mem_info.free * 0.95))

        # Placeholder tensor: float32 elements are 4 bytes each; holding the
        # reference is what keeps the memory allocated
        block = torch.ones((allocate_mem // 4,), dtype=torch.float32, device=device)

        # Mark the GPU as held (idempotent: the monitor also reserves it
        # before starting this thread)
        occupied_gpus.add(gpu_id)
        print(f"Successfully occupied GPU {gpu_id}")

        # Sleep forever; as long as this thread is alive, `block` stays
        # referenced and the memory stays allocated
        while True:
            time.sleep(3600)
    except Exception as e:
        print(f"Error in GPU {gpu_id}: {str(e)}")
    finally:
        # Whenever this thread exits, mark the GPU as free again
        occupied_gpus.discard(gpu_id)
        print(f"Released GPU {gpu_id}")


def monitor_gpus():
    pynvml.nvmlInit()
    device_count = pynvml.nvmlDeviceGetCount()

    if device_count != 8:
        print(f"Warning: Detected {device_count} GPUs, expected 8")

    while True:
        for gpu_id in range(device_count):
            try:
                # Skip GPUs this script already holds
                if gpu_id in occupied_gpus:
                    continue

                # Check the current memory utilization
                handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
                mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
                used_percent = (mem_info.used / mem_info.total) * 100

                # If less than half the memory is in use, try to grab the card
                if used_percent < 50:
                    print(f"GPU {gpu_id} usage {used_percent:.1f}% < 50%, attempting to occupy...")
                    # Reserve the slot immediately so the next scan cannot
                    # spawn a duplicate occupier; occupy_gpu's finally block
                    # un-reserves it if the allocation fails
                    occupied_gpus.add(gpu_id)
                    t = Thread(target=occupy_gpu, args=(gpu_id,))
                    t.daemon = True
                    t.start()

            except pynvml.NVMLError as e:
                print(f"Error accessing GPU {gpu_id}: {str(e)}")

        # Don't return once every GPU is held: the occupier threads are
        # daemons, so exiting the main thread would kill them and release
        # all the memory. Keep scanning instead; held GPUs are skipped
        # above, and a card whose thread died gets re-grabbed on the next pass.
        if len(occupied_gpus) == device_count:
            print("All GPUs are occupied. Holding them; press Ctrl+C to release.")

        time.sleep(60)  # rescan every 60 seconds


if __name__ == "__main__":
    try:
        monitor_gpus()
    except KeyboardInterrupt:
        print("Stopping monitoring...")
    finally:
        pynvml.nvmlShutdown()
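
To check from another shell that the grab actually took, a minimal NVML read-out like the sketch below is enough. It uses the same pynvml dependency as the grabber and only reads memory counters, so it is safe to run alongside it:

import pynvml

pynvml.nvmlInit()
try:
    for gpu_id in range(pynvml.nvmlDeviceGetCount()):
        handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
        mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
        # A grabbed card should sit near the target_percent mark
        print(f"GPU {gpu_id}: {mem.used / 1024**3:.1f}/{mem.total / 1024**3:.1f} GiB "
              f"({mem.used / mem.total * 100:.0f}% used)")
finally:
    pynvml.nvmlShutdown()

When you actually want to train on a grabbed card, just kill the grabber process: the CUDA driver frees all of a process's allocations on exit.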