

前面已经学习了CLIP模型的原理,本节就基于MNIST手写数字数据集实现一个CLIP模型。
首先实现Images Encoder:
class ResidualBlock(nn.Module):
def __init__(self,in_channels,out_channels,stride):
super().__init__()
self.conv1=nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=3,padding=1,stride=stride)
self.bn1=nn.BatchNorm2d(out_channels)
self.conv2=nn.Conv2d(in_channels=out_channels,out_channels=out_channels,kernel_size=3,padding=1,stride=1)
self.bn2=nn.BatchNorm2d(out_channels)
self.conv3=nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=1,padding=0,stride=stride)
def forward(self,x):
y=F.relu(self.bn1(self.conv1(x)))
y=self.bn2(self.conv2(y))
z=self.conv3(x)
return F.relu(y+z)
class ImgEncoder(nn.Module):
def __init__(self):
super().__init__()
self.res_block1=ResidualBlock(in_channels=1,out_channels=16,stride=2) # (batch,16,14,14)
self.res_block2=ResidualBlock(in_channels=16,out_channels=4,stride=2) # (batch,4,7,7)
self.res_block3=ResidualBlock(in_channels=4,out_channels=1,stride=2) # (batch,1,4,4)
self.wi=nn.Linear(in_features=16,out_features=8)
self.ln=nn.LayerNorm(8)
def forward(self,x):
x=self.res_block1(x)
x=self.res_block2(x)
x=self.res_block3(x)
x=self.wi(x.view(x.size(0),-1))
x=self.ln(x)
return xbashImages Encoder是由一个ResNet网络实现的。经过卷积操作后,通过x=self.wi(x.view(x.size(0),-1))将特征展平,然后用一个Linear层,将维度映射到8。
接下来是Text Encoder:
class TextEncoder(nn.Module):
def __init__(self):
super().__init__()
self.emb=nn.Embedding(num_embeddings=10,embedding_dim=16)
self.dense1=nn.Linear(in_features=16,out_features=64)
self.dense2=nn.Linear(in_features=64,out_features=16)
self.wt=nn.Linear(in_features=16,out_features=8)
self.ln=nn.LayerNorm(8)
def forward(self,x):
x=self.emb(x)
x=F.relu(self.dense1(x))
x=F.relu(self.dense2(x))
x=self.wt(x)
x=self.ln(x)
return xbashText Encoder是由几个Linear构成的,首先输入的就是一个长度为10的one-hot向量,经过编码得到特征,也是一个长度为8的特征向量。
接下来就是CLIP代码:
class CLIP(nn.Module):
def __init__(self,):
super().__init__()
self.img_enc=ImgEncoder()
self.text_enc=TextEncoder()
def forward(self,img_x,text_x):
img_emb=self.img_enc(img_x)
text_emb=self.text_enc(text_x)
return img_emb@text_emb.Tbash这里就直接调用Encoder进行编码,然后进行矩阵乘法操作。
接下来就是训练代码:
DEVICE='cuda' if torch.cuda.is_available() else 'cpu' # 设备
dataset=MNIST() # 数据集
model=CLIP().to(DEVICE) # 模型
try: # 加载模型
model.load_state_dict(torch.load('model.pth'))
except:
pass
optimzer=torch.optim.Adam(model.parameters(),lr=1e-3) # 优化器
'''
训练模型
'''
ITER_BATCH_COUNT=100000 # 迭代次数
BATCH_SIZE=64 # 从batch内选出10个不一样的数字
TARGET_COUNT=10 # 共10种数字
dataloader=DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=10,persistent_workers=True) # 数据加载器
for i in range(ITER_BATCH_COUNT):
while True:
imgs,labels=next(iter(dataloader))
if torch.unique(labels).shape[0]<TARGET_COUNT: # 未覆盖10种数字
continue
# 挑选出10个数字
target=set()
indexes=[]
for j in range(BATCH_SIZE):
if labels[j].item() in target:
continue
target.add(labels[j].item())
indexes.append(j)
if len(target)==TARGET_COUNT:
break
imgs=imgs[indexes]
labels=labels[indexes]
break
logits=model(imgs.to(DEVICE),labels.to(DEVICE))
targets=torch.arange(0,TARGET_COUNT).to(DEVICE)
loss_i=F.cross_entropy(logits,targets)
loss_t=F.cross_entropy(logits.permute(1,0),targets)
loss=(loss_i+loss_t)/2
optimzer.zero_grad()
loss.backward()
optimzer.step()
if i%1000==0:
print('iter:{},loss:{}'.format(i,loss))
torch.save(model.state_dict(),'.model.pth')
os.replace('.model.pth','model.pth')bash值的注意的是,在训练的过程中,每一个批次的样本都要包括10个不同手写数字的样本,这是因为我们在做CLIP训练时,如果同一个批次出现两张同类别的手写数字图,就会出现下面的情况:
假设和是同一个标签,假设是9,而是数字9对应的图片,此时在计算Loss时,会强制认为和是匹配的,从而打压模型对和的判断(事实上和也是匹配的,但是这时模型强制认为和是匹配的),所以会导致模型“脑裂”。
所以训练代码中的:
while True:
imgs,labels=next(iter(dataloader))
if torch.unique(labels).shape[0]<TARGET_COUNT: # 未覆盖10种数字
continue
# 挑选出10个数字
target=set()
indexes=[]
for j in range(BATCH_SIZE):
if labels[j].item() in target:
continue
target.add(labels[j].item())
indexes.append(j)
if len(target)==TARGET_COUNT:
break
imgs=imgs[indexes]
labels=labels[indexes]
breakbash是来保证每轮训练的 batch 必须包含所有10个类别。
最后就是推理代码:
DEVICE='cuda' if torch.cuda.is_available() else 'cpu' # 设备
dataset=MNIST() # 数据集
model=CLIP().to(DEVICE) # 模型
model.load_state_dict(torch.load('model.pth'))
model.eval() # 预测模式
'''
1、对图片分类
'''
image,label=dataset[0]
print('正确分类:',label)
plt.imshow(image.permute(1,2,0))
plt.show()
targets=torch.arange(0,10) #10种分类
logits=model(image.unsqueeze(0).to(DEVICE),targets.to(DEVICE)) # 1张图片 vs 10种分类
print(logits)
print('CLIP分类:',logits.argmax(-1).item())
'''
2、图像相似度
'''
other_images=[]
other_labels=[]
for i in range(1,101):
other_image,other_label=dataset[i]
other_images.append(other_image)
other_labels.append(other_label)
# 其他100张图片的向量
other_img_embs=model.img_enc(torch.stack(other_images,dim=0).to(DEVICE))
# 当前图片的向量
img_emb=model.img_enc(image.unsqueeze(0).to(DEVICE))
# 计算当前图片和100张其他图片的相似度
logtis=img_emb@other_img_embs.T
values,indexs=logtis[0].topk(5) # 5个最相似的
plt.figure(figsize=(15,15))
for i,img_idx in enumerate(indexs):
plt.subplot(1,5,i+1)
plt.imshow(other_images[img_idx].permute(1,2,0))
plt.title(other_labels[img_idx])
plt.axis('off')
plt.show()bash这里是实现了两种任务,第一种就是分类任务,第二种就是通过图片相似度实现以图搜图的任务。