# Python CLIP多模态模型完整教程(从入门到实战,代码可复现)

86 阅读15分钟

CLIP(Contrastive Language-Image Pretraining)是OpenAI推出的跨文本-图像的多模态预训练模型,核心价值在于打破了图像与文本的模态壁垒,实现了“以文找图”“以图判文”“零样本图像分类”等核心功能,无需针对特定任务微调即可落地,解决了传统计算机视觉模型“依赖大量标注数据、泛化能力弱”的痛点。

一、前置准备:环境搭建与核心库安装

1. 核心依赖库说明

CLIP的运行依赖以下核心库,其中torch为深度学习框架,clip为OpenAI官方CLIP实现,Pillow用于图像处理,matplotlib用于结果可视化:

  • torch & torchvision:深度学习核心,支持GPU加速(推荐);
  • clip:OpenAI官方CLIP库,提供预训练模型与核心接口;
  • Pillow:图像读取、预处理工具;
  • matplotlib:结果可视化,展示图像与匹配结果;
  • numpy:数值计算辅助,处理嵌入向量与相似度。

2. 环境搭建步骤(推荐conda虚拟环境)

# 1. 创建conda虚拟环境(Python 3.9-3.10兼容性最佳,避免版本冲突)
conda create -n clip_env python=3.9
conda activate clip_env

# 2. 安装PyTorch(优先匹配本地CUDA版本,支持GPU加速;无GPU则安装CPU版本)
# GPU版本(推荐,需提前安装CUDA 11.7+/12.0+,参考PyTorch官网:https://pytorch.org/)
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# CPU版本(无GPU设备,运行速度较慢,仅用于入门测试)
pip3 install torch torchvision torchaudio

# 3. 安装CLIP官方库与其他依赖(国内清华镜像加速,避免下载超时)
pip install git+https://github.com/openai/CLIP.git -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install Pillow matplotlib numpy -i https://pypi.tuna.tsinghua.edu.cn/simple

3. 环境验证与预训练模型说明

运行以下代码,验证环境是否配置成功,并了解CLIP的预训练模型(不同规模平衡速度与精度):

import clip
import torch
from PIL import Image

# 验证库导入是否成功
print("CLIP库导入成功")
print(f"PyTorch版本:{torch.__version__}")
print(f"CUDA是否可用:{torch.cuda.is_available()}")

# 查看CLIP支持的预训练模型(核心模型列表)
available_models = clip.available_models()
print("CLIP支持的预训练模型:", available_models)

核心预训练模型选择指南(落地优先)

模型名称模型规模推理速度精度表现适用场景
ViT-B/32小(约120M参数)快(GPU≈100fps,CPU≈5fps)平衡入门测试、实时落地、边缘设备
ViT-B/16中(约120M参数)中(GPU≈50fps,CPU≈2fps)较高精度要求中等,无严格实时性需求
ViT-L/14大(约300M参数)慢(GPU≈20fps,CPU≈1fps)高精度需求,离线任务、服务器端
ViT-L/14@336px超大极慢极高科研场景、超高精度图像分析

落地首选ViT-B/32,平衡速度与精度,无需高性能GPU,入门门槛低。

二、CLIP核心基础:文本-图像多模态匹配(入门第一战)

CLIP的核心能力是将图像和文本转换到同一高维嵌入空间,通过计算向量相似度判断两者的匹配程度,核心流程为:「加载模型与预处理→图像/文本预处理→提取嵌入向量→计算余弦相似度→匹配结果可视化」。

完整实战代码(单图像+多文本匹配)

import clip
import torch
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# 1. 配置设备(优先使用GPU,无GPU则使用CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备:{device}")

# 2. 加载CLIP预训练模型与图像/文本预处理工具
# 选择ViT-B/32模型,落地首选
model, preprocess = clip.load("ViT-B/32", device=device)

# 3. 加载并预处理测试图像
# 本地图像路径(替换为你的测试图像,如cat.jpg、dog.jpg)
img_path = "test_cat.jpg"
try:
    # 读取图像
    raw_image = Image.open(img_path).convert("RGB")
    # 预处理图像(CLIP要求的标准化、尺寸调整等)
    processed_image = preprocess(raw_image).unsqueeze(0).to(device)
except Exception as e:
    raise Exception(f"图像加载失败:{e},请检查文件路径与格式")

# 4. 定义待匹配的文本提示词(自定义类别,判断图像与哪个文本最匹配)
text_prompts = [
    "a photo of a cat",
    "a photo of a dog",
    "a photo of a bird",
    "a photo of a flower",
    "a photo of a car"
]
# 预处理文本(CLIP要求的文本编码)
text_tokens = clip.tokenize(text_prompts).to(device)

# 5. 提取图像与文本的嵌入向量(关闭梯度计算,提升推理速度)
with torch.no_grad():
    # 提取图像嵌入向量
    image_embedding = model.encode_image(processed_image)
    # 提取文本嵌入向量
    text_embedding = model.encode_text(text_tokens)

# 6. 计算余弦相似度(判断图像与每个文本的匹配程度)
# 归一化嵌入向量(提升相似度计算准确性)
image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)
text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)

# 计算相似度(batch维度匹配)
similarity = (100.0 * image_embedding @ text_embedding.T).softmax(dim=-1)
# 转换为numpy数组,便于后续处理
similarity_np = similarity.cpu().numpy()[0]

# 7. 可视化结果(图像+相似度排名)
plt.figure(figsize=(12, 6))

# 显示原始图像
plt.subplot(1, 2, 1)
plt.imshow(raw_image)
plt.title("Test Image")
plt.axis("off")

# 显示文本相似度排名
plt.subplot(1, 2, 2)
sorted_indices = np.argsort(similarity_np)[::-1]
sorted_prompts = [text_prompts[i] for i in sorted_indices]
sorted_scores = [similarity_np[i] for i in sorted_indices]

# 绘制柱状图
colors = ["#ff6b6b" if i == 0 else "#4ecdc4" for i in range(len(sorted_prompts))]
plt.barh(range(len(sorted_prompts)), sorted_scores, color=colors)
plt.yticks(range(len(sorted_prompts)), sorted_prompts)
plt.xlabel("Similarity Score (Probability)")
plt.title("Text-Image Similarity Ranking")
plt.xlim(0, 1)

# 标注最高相似度得分
plt.text(sorted_scores[0]+0.01, 0, f"{sorted_scores[0]:.4f}", va="center")

plt.tight_layout()
plt.savefig("clip_text_image_matching.png", dpi=300, bbox_inches="tight")
plt.show()

# 8. 输出匹配结果
best_match_idx = sorted_indices[0]
print(f"\n最佳匹配文本:{text_prompts[best_match_idx]}")
print(f"匹配相似度:{similarity_np[best_match_idx]:.4f}")

核心知识点解析

  1. 预处理的重要性:CLIP的preprocessclip.tokenize()分别对图像和文本进行标准化处理,确保输入格式符合模型预训练要求,否则会导致结果失效;
  2. 嵌入向量归一化:通过norm(dim=-1, keepdim=True)归一化后,余弦相似度可直接反映匹配程度,避免向量幅值对结果的干扰;
  3. 余弦相似度计算:使用矩阵乘法@快速计算批量文本与图像的相似度,效率远高于循环计算;
  4. 设备适配:优先使用GPU加速,unsqueeze(0)将单张图像/文本转换为批量格式(CLIP模型要求批量输入)。

三、核心实战1:CLIP零样本图像分类(无需标注,快速落地)

传统图像分类模型需要大量标注数据训练,而CLIP的零样本分类能力可直接使用自定义文本类别完成分类,无需训练,是落地的核心场景,适合快速实现图像分类需求(如商品分类、场景分类、动物分类)。

完整实战代码(自定义类别分类)

import clip
import torch
from PIL import Image
import os

# 1. 配置设备与加载模型
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# 2. 定义零样本分类核心函数
def clip_zero_shot_classification(img_path, class_names, prompt_template="a photo of a {}"):
    """
    CLIP零样本图像分类
    :param img_path: 图像路径
    :param class_names: 分类类别列表(如["cat", "dog", "bird"])
    :param prompt_template: 文本提示词模板,提升分类精度
    :return: 分类结果(最佳类别、相似度得分)
    """
    # 步骤1:加载并预处理图像
    try:
        raw_image = Image.open(img_path).convert("RGB")
        processed_image = preprocess(raw_image).unsqueeze(0).to(device)
    except Exception as e:
        raise Exception(f"图像加载失败:{e}")
    
    # 步骤2:构建文本提示词(使用模板优化,提升精度)
    text_prompts = [prompt_template.format(cls) for cls in class_names]
    text_tokens = clip.tokenize(text_prompts).to(device)
    
    # 步骤3:提取嵌入向量并计算相似度
    with torch.no_grad():
        image_embedding = model.encode_image(processed_image)
        text_embedding = model.encode_text(text_tokens)
        
        # 归一化与相似度计算
        image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)
        text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
        similarity = (100.0 * image_embedding @ text_embedding.T).softmax(dim=-1)
    
    # 步骤4:解析分类结果
    similarity_np = similarity.cpu().numpy()[0]
    best_class_idx = similarity_np.argmax()
    best_class = class_names[best_class_idx]
    best_score = similarity_np[best_class_idx]
    
    return best_class, best_score, similarity_np

# 3. 定义分类任务(自定义类别,可根据需求修改)
# 示例1:动物分类
# class_names = ["cat", "dog", "goldfish", "parrot", "rabbit"]
# 示例2:场景分类
class_names = ["mountain", "beach", "city", "forest", "desert"]
# 示例3:商品分类
# class_names = ["phone", "laptop", "book", "cup", "watch"]

# 4. 执行零样本分类(替换为你的测试图像路径)
img_path = "test_beach.jpg"
best_class, best_score, all_scores = clip_zero_shot_classification(
    img_path=img_path,
    class_names=class_names,
    prompt_template="a photo of a {} scene"  # 针对场景分类优化模板
)

# 5. 输出分类结果
print("="*50)
print("CLIP零样本图像分类结果")
print("="*50)
print(f"测试图像:{img_path}")
print(f"最佳分类结果:{best_class}")
print(f"匹配相似度:{best_score:.4f}")
print("\n所有类别相似度得分:")
for cls, score in zip(class_names, all_scores):
    print(f"{cls}: {score:.4f}")

零样本分类优化技巧(关键提升精度)

  1. 提示词模板工程:使用a photo of a {}a picture of a {} scene等模板,比直接使用类别名称精度提升10%-30%,核心是贴近CLIP预训练的文本格式;
  2. 类别细化描述:区分相似类别(如"golden retriever" vs "labrador"),避免类别过于宽泛导致分类模糊;
  3. 多模板投票融合:使用多个不同模板生成提示词,对分类结果投票,减少单模板的随机性(如["a photo of a {}", "a picture of a {}", "an image of a {}"]);
  4. 模型升级:若精度不足,可升级为ViT-B/16ViT-L/14模型,牺牲部分速度换取更高精度。

四、核心实战2:CLIP图像检索(文本找图像/图像找图像)

CLIP的另一核心落地场景是图像检索,支持两种模式:「文本到图像检索」(输入文本,找到最匹配的图像)、「图像到图像检索」(输入查询图像,找到相似图像),适合构建图像搜索引擎、商品图库检索等系统。

完整实战代码(构建小型图像检索库)

import clip
import torch
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt

# 1. 配置设备与加载模型
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# 2. 构建图像检索库(预处理图像并缓存嵌入向量)
def build_image_retrieval_database(img_dir):
    """
    构建图像检索库,缓存所有图像的嵌入向量
    :param img_dir: 图像文件夹路径
    :return: 图像路径列表、图像嵌入向量矩阵
    """
    img_paths = []
    img_embeddings = []
    
    # 遍历图像文件夹
    for filename in os.listdir(img_dir):
        if filename.lower().endswith((".jpg", ".png", ".jpeg")):
            img_path = os.path.join(img_dir, filename)
            try:
                # 预处理图像
                raw_image = Image.open(img_path).convert("RGB")
                processed_image = preprocess(raw_image).unsqueeze(0).to(device)
                
                # 提取嵌入向量
                with torch.no_grad():
                    img_embedding = model.encode_image(processed_image)
                    # 归一化并转换为numpy数组(便于后续检索)
                    img_embedding = img_embedding / img_embedding.norm(dim=-1, keepdim=True)
                    img_embedding_np = img_embedding.cpu().numpy()[0]
                
                # 缓存结果
                img_paths.append(img_path)
                img_embeddings.append(img_embedding_np)
            except Exception as e:
                print(f"跳过无效图像 {img_path}{e}")
    
    # 转换为矩阵(n_images × embedding_dim)
    img_embeddings_mat = np.array(img_embeddings)
    
    print(f"图像检索库构建完成,共加载 {len(img_paths)} 张有效图像")
    return img_paths, img_embeddings_mat

# 3. 文本到图像检索(核心功能)
def text_to_image_retrieval(query_text, img_paths, img_embeddings_mat, top_k=5):
    """
    文本到图像检索,返回Top-K最匹配的图像
    :param query_text: 检索文本
    :param img_paths: 图像路径列表
    :param img_embeddings_mat: 图像嵌入向量矩阵
    :param top_k: 返回最匹配的前K张图像
    :return: Top-K图像路径、Top-K相似度得分
    """
    # 预处理文本并提取嵌入向量
    text_tokens = clip.tokenize([query_text]).to(device)
    with torch.no_grad():
        text_embedding = model.encode_text(text_tokens)
        text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
        text_embedding_np = text_embedding.cpu().numpy()[0]
    
    # 计算文本与所有图像的余弦相似度
    similarities = np.dot(img_embeddings_mat, text_embedding_np.T)
    
    # 排序并获取Top-K结果
    sorted_indices = np.argsort(similarities)[::-1][:top_k]
    top_k_img_paths = [img_paths[i] for i in sorted_indices]
    top_k_scores = [similarities[i] for i in sorted_indices]
    
    return top_k_img_paths, top_k_scores

# 4. 图像到图像检索(核心功能)
def image_to_image_retrieval(query_img_path, img_paths, img_embeddings_mat, top_k=5):
    """
    图像到图像检索,返回Top-K最相似的图像
    """
    # 预处理查询图像并提取嵌入向量
    try:
        raw_image = Image.open(query_img_path).convert("RGB")
        processed_image = preprocess(raw_image).unsqueeze(0).to(device)
    except Exception as e:
        raise Exception(f"查询图像加载失败:{e}")
    
    with torch.no_grad():
        query_embedding = model.encode_image(processed_image)
        query_embedding = query_embedding / query_embedding.norm(dim=-1, keepdim=True)
        query_embedding_np = query_embedding.cpu().numpy()[0]
    
    # 计算查询图像与所有图像的余弦相似度
    similarities = np.dot(img_embeddings_mat, query_embedding_np.T)
    
    # 排序并获取Top-K结果(排除自身,若查询图像在检索库中)
    sorted_indices = np.argsort(similarities)[::-1]
    # 过滤自身(若存在)
    if query_img_path in img_paths:
        self_idx = img_paths.index(query_img_path)
        sorted_indices = [idx for idx in sorted_indices if idx != self_idx]
    
    # 取Top-K
    top_k_indices = sorted_indices[:top_k]
    top_k_img_paths = [img_paths[i] for i in top_k_indices]
    top_k_scores = [similarities[i] for i in top_k_indices]
    
    return top_k_img_paths, top_k_scores

# 5. 可视化检索结果
def visualize_retrieval_results(query, top_k_img_paths, top_k_scores, retrieval_type="text"):
    """
    可视化Top-K检索结果
    """
    plt.figure(figsize=(15, 10))
    
    # 显示查询信息
    plt.subplot(2, 3, 1)
    if retrieval_type == "text":
        plt.text(0.5, 0.5, query, ha="center", va="center", fontsize=14)
        plt.title("Query Text")
    else:
        query_img = Image.open(query)
        plt.imshow(query_img)
        plt.title("Query Image")
    plt.axis("off")
    
    # 显示Top-K图像
    for i, (img_path, score) in enumerate(zip(top_k_img_paths, top_k_scores)):
        plt.subplot(2, 3, i+2)
        img = Image.open(img_path)
        plt.imshow(img)
        plt.title(f"Top {i+1}\nScore: {score:.4f}")
        plt.axis("off")
    
    plt.tight_layout()
    plt.savefig(f"clip_{retrieval_type}_retrieval_result.png", dpi=300, bbox_inches="tight")
    plt.show()

# 6. 执行图像检索(实战流程)
if __name__ == "__main__":
    # 步骤1:构建图像检索库(替换为你的图像文件夹路径)
    img_dir = "image_database"
    if not os.path.exists(img_dir):
        os.makedirs(img_dir)
        print(f"请在 {img_dir} 文件夹中放入测试图像后重新运行")
        exit(1)
    img_paths, img_embeddings_mat = build_image_retrieval_database(img_dir)
    
    # 步骤2:文本到图像检索(自定义查询文本)
    query_text = "a photo of a beach with blue water"
    top_k = 5
    top_k_imgs_text, top_k_scores_text = text_to_image_retrieval(
        query_text=query_text,
        img_paths=img_paths,
        img_embeddings_mat=img_embeddings_mat,
        top_k=top_k
    )
    print(f"\n文本检索完成,返回前 {top_k} 张匹配图像")
    visualize_retrieval_results(query_text, top_k_imgs_text, top_k_scores_text, retrieval_type="text")
    
    # 步骤3:图像到图像检索(替换为你的查询图像路径)
    query_img_path = "test_beach.jpg"
    top_k_imgs_img, top_k_scores_img = image_to_image_retrieval(
        query_img_path=query_img_path,
        img_paths=img_paths,
        img_embeddings_mat=img_embeddings_mat,
        top_k=top_k
    )
    print(f"\n图像检索完成,返回前 {top_k} 张相似图像")
    visualize_retrieval_results(query_img_path, top_k_imgs_img, top_k_scores_img, retrieval_type="image")

图像检索落地优化技巧

  1. 嵌入向量缓存:一次性预处理所有图像并缓存嵌入向量,避免每次检索重复提取,大幅提升检索速度;
  2. 向量数据库集成:当图像数量超过1万张时,使用FAISS、Chroma、Pinecone等向量数据库,提升批量检索效率(CLIP嵌入维度为512,适合向量数据库存储);
  3. Top-K取值优化:根据需求调整Top-K,通常取5-10,兼顾检索效果与展示效率;
  4. 图像预处理优化:过滤模糊、低分辨率图像,提升检索库的整体质量,避免无效匹配。

五、CLIP提示词工程(关键优化,提升效果30%+)

CLIP的性能高度依赖文本提示词的质量,相同模型下,优秀的提示词可使精度提升30%以上,核心是贴近CLIP预训练的文本分布(以英文为主,描述简洁、符合自然语言习惯)。

核心提示词技巧与示例

  1. 基础模板优先:使用固定模板构建提示词,避免直接使用类别名称,常用模板:
    • a photo of a {class}(通用物体分类)
    • a picture of a {class} scene(场景分类)
    • an image of a {class} with {attribute}(带属性的分类,如"a photo of a cat with black fur"
  2. 类别细化描述:区分相似类别,提升分类精度,示例:
    • 差:"car"
    • 好:"a photo of a sports car""a photo of a family sedan"
  3. 场景与风格补充:添加场景、光照、风格等信息,贴近图像实际内容,示例:
    • 差:"dog"
    • 好:"a photo of a dog running outdoors in the sun"
  4. 多提示词投票融合:使用多个模板生成提示词,对结果取平均或投票,减少随机性,示例:
    prompt_templates = [
        "a photo of a {}",
        "a picture of a {}",
        "an image of a {}",
        "a photo of a {} in natural light"
    ]
    
  5. 避免过度冗余:提示词简洁明了,避免添加无关信息(如"a very beautiful photo of a cat that is cute and fluffy"),冗余信息会降低匹配精度。

实战:提示词效果对比

import clip
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# 加载测试图像
img_path = "test_golden_retriever.jpg"
raw_image = Image.open(img_path).convert("RGB")
processed_image = preprocess(raw_image).unsqueeze(0).to(device)

# 定义不同质量的提示词
bad_prompts = ["dog", "golden", "retriever"]
good_prompts = [
    "a photo of a golden retriever",
    "a picture of a golden retriever dog",
    "an image of a golden retriever running outdoors"
]

# 计算相似度
def calculate_similarity(image, prompts):
    text_tokens = clip.tokenize(prompts).to(device)
    with torch.no_grad():
        img_emb = model.encode_image(image)
        text_emb = model.encode_text(text_tokens)
        img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
        text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
        similarity = (100.0 * img_emb @ text_emb.T).softmax(dim=-1)
    return similarity.cpu().numpy()[0]

# 对比结果
bad_similarities = calculate_similarity(processed_image, bad_prompts)
good_similarities = calculate_similarity(processed_image, good_prompts)

print("低质量提示词相似度:", bad_similarities)
print("高质量提示词相似度:", good_similarities)
print("平均相似度提升:", (good_similarities.mean() - bad_similarities.mean()):.4f)

六、常见问题与避坑指南(CLIP落地专属)

  1. 模型加载缓慢/下载失败
    • 解决方案:手动下载CLIP预训练权重(参考OpenAI CLIP仓库),放置到~/.cache/clip/目录下;使用国内镜像加速,或更换网络环境;
  2. 零样本分类精度低
    • 解决方案:优化提示词模板、使用多提示词融合、升级更大模型(如ViT-B/16)、补充少量样本做少样本分类;
  3. 显存不足(OOM错误)
    • 解决方案:使用更小模型(ViT-B/32)、降低批量大小、使用CPU版本、启用半精度推理(model.half());
  4. 图像检索速度慢
    • 解决方案:缓存嵌入向量、使用向量数据库(FAISS)、优化图像预处理、使用轻量化模型;
  5. 中文提示词效果差
    • 解决方案:CLIP预训练以英文为主,中文效果有限,优先使用英文提示词;若需中文,可使用中文微调版CLIP(如Chinese-CLIP、AltCLIP);
  6. 预处理格式错误
    • 解决方案:确保图像转换为RGB格式(Image.open().convert("RGB")),文本输入为字符串列表,避免特殊字符。