Soup's Blog

Back

Vision-Language Models(VLM)学习(三)复现OpenAI的CLIP模型Blur image

前面已经学习了CLIP模型的原理,本节就基于MNIST手写数字数据集实现一个CLIP模型。

首先实现Images Encoder

Images Encoder是由一个ResNet网络实现的。经过卷积操作后,通过x=self.wi(x.view(x.size(0),-1))将特征展平,然后用一个Linear层,将维度映射到8

接下来是Text Encoder

Text Encoder是由几个Linear构成的,首先输入的就是一个长度为10one-hot向量,经过编码得到特征,也是一个长度为8的特征向量。

接下来就是CLIP代码:

class CLIP(nn.Module):
    def __init__(self,):
        super().__init__()
        self.img_enc=ImgEncoder()
        self.text_enc=TextEncoder()

    def forward(self,img_x,text_x):
        img_emb=self.img_enc(img_x)
        text_emb=self.text_enc(text_x)
        return img_emb@text_emb.T
bash

这里就直接调用Encoder进行编码,然后进行矩阵乘法操作。

接下来就是训练代码:

值的注意的是,在训练的过程中,每一个批次的样本都要包括10个不同手写数字的样本,这是因为我们在做CLIP训练时,如果同一个批次出现两张同类别的手写数字图,就会出现下面的情况: 在这里插入图片描述 假设T1T_1TNT_N是同一个标签,假设是9,而I1I_1是数字9对应的图片,此时在计算Loss时,会强制认为T1T_1I1I_1是匹配的,从而打压模型对TNT_NI1I_1的判断(事实上TNT_NI1I_1也是匹配的,但是这时模型强制认为T1T_1I1I_1是匹配的),所以会导致模型“脑裂”。

所以训练代码中的:

是来保证每轮训练的 batch 必须包含所有10个类别。

最后就是推理代码:

这里是实现了两种任务,第一种就是分类任务,第二种就是通过图片相似度实现以图搜图的任务。

Vision-Language Models(VLM)学习(三)复现OpenAI的CLIP模型
http://www.soupcola.top/blog/vlm_blogs/vlm_blog-3
Author Soup Cola
Published at 2026年1月31日