Поиск
ruCLIP Base [vit-base-patch32-384]-image

ruCLIP Base [vit-base-patch32-384]

Russian Contrastive Language–Image Pre-training. Модель-ранжировщик текстов и изображений, 150 млн параметров.

ruCLIP - это мультимодальная модель для ранжирования изображений и подписей к ним, а также получения семантической близости изображений и текстов. Архитектура впервые представлена OpenAI.

Лицензия

Apache 2.0

Размер файлов

0.6 GB

Версия

0.1

ruCLIP (Russian Contrastive Language – Image Pre-training) обучена для русского языка на открытых данных, собранных из Рунета.

240 миллионов уникальных пар картинка-текст в обучающей выборке.

Информация об использовании модели:

ruCLIP — это модель, состоящая из двух частей (или нейронных сетей):

  1. Image Encoder — часть для кодирования изображений и перевода их в общее векторное пространство. В качестве архитектуры в оригинальной работе берутся ResNet разных размеров и Visual Transformer — тоже разных размеров. В ruCLIP Base в качестве image encoder используется ViT-B/32.
  2. Text Encoder — часть для кодирования текстов и перевода их в общее векторное пространство. В качестве архитектуры используется текстовый Transformer.

Similarity

ruclip0

[{"собачка": 0.9324831366539001}, {"кошка": 0.02790665067732334}, {"мышка": 0.029953204095363617}, {"машина": 0.0034359991550445557}, {"стол": 0.0031091528944671154}, {"дом": 0.0018060706788673997}, {"жидкость": 0.0013057001633569598}]

KFServing

Класс KFServingRuClipModel представлен ниже.

Вы можете подавать на вход модели:

  • ссылки на изображения
  • картинки в формате base64

На выходе модель покажет близость между текстами и картинками. Чем ближе значение к 1, тем ближе семантическое сходство картинки и текста.

import os
import json
from collections import OrderedDict
from typing import Dict
import kfserving
import requests
import torch
import numpy as np
import io
from io import BytesIO
import base64
from PIL import Image
import re

from ruclip import CLIP, RuCLIPProcessor

def open_images_base64(img_strs):
    return [Image.open(BytesIO(base64.b64decode(img_str))) for img_str in img_strs]

def open_image_link(links):
    imgs = []
    for img_link in links:
        response = requests.get(img_link)
        imgs.append(Image.open(BytesIO(response.content)))
    return imgs

def create_image(sim_plt):
    my_stringIObytes = io.BytesIO()
    sim_plt.savefig(my_stringIObytes, format="jpg")
    my_stringIObytes.seek(0)
    my_base64 = base64.b64encode(my_stringIObytes.read())
    return my_base64

class KFServingRuClipModel(kfserving.KFModel):
    def __init__(self, title: str, model_path="./ruclip-vit-base-patch32-384"):
        super().__init__(name)
        self.name = name
        self.ready = False
        self.model_path = model_path

    def load(self):
        self.device = "cuda"
        self.clip = CLIP.from_pretrained(self.model_path).eval().to(self.device)
        self.clip_processor = RuCLIPProcessor.from_pretrained(self.model_path)
        self.ready = True
        
    def get_text_latents(self, texts):
        with torch.no_grad():
            inputs = self.clip_processor(text=texts, images=None)
            text_latents = self.clip.encode_text(
                input_ids=inputs["input_ids"].to(self.device), 
            )
            text_latents = text_latents / text_latents.norm(dim=-1, keepdim=True)
        return text_latents
        
    def get_logits(self, text_latents, pil_images):
        with torch.no_grad():
            inputs = self.clip_processor(text=None, images=pil_images)
            image_latents = self.clip.encode_image(
                pixel_values=inputs["pixel_values"].to(self.device)
            )
            image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True)
            logits_per_text = torch.matmul(text_latents, image_latents.t())
            logits_per_image = logits_per_text.t()
        return logits_per_text, logits_per_image
    
    def get_similarity_scores(self, texts, images):
        """
        Find the most similar image to text.
        `texts`: array of texts or one text ["some_desc"]
        `images`: array of images.
        """
        text_latents = self.get_text_latents(texts)
        results = []
        for pil_image in images:
            _, logits_per_image = self.get_logits(text_latents, [pil_image])
            probs_raw = (logits_per_image * self.clip.logit_scale.exp().detach()).softmax(dim=-1)[0]
            label_id = probs_raw.argmax().item()
            confidence = probs_raw.max().item()
            
            probs = []
            for i in range(len(texts)):
                probs.append({texts[i]: probs_raw[i].item()})

            buffered = BytesIO()
            pil_image.save(buffered, format="JPEG")
            img = base64.b64encode(buffered.getvalue()).decode("utf-8")
            results.append({"image": img, "text": texts[label_id], "confidence" : confidence, "all_res": probs})
        
        return results
    
    def predict(self, request: Dict) -> Dict:
        texts = request["instances"][0]["texts"]
        
        img_strs = request["instances"][0].get("images", None)
        if img_strs is not None:
            images = open_images_base64(img_strs)
        
        images_links = request["instances"][0].get("image_links", None)

        if images_links is not None:
            images = open_image_link(images_links)
          
        error_msg = None
        predictions = []
        try:
            predictions = self.get_similarity_scores(texts, images)
        except Exception as ex:
            print(ex)
            error_msg = ex

        if error_msg is not None:
            return {"predictions": predictions, "error_message": str(error_msg)}
        else:
            return {"predictions": predictions}

Функция predict возвращает массив со словарями для каждой картинки вида:

     predictions: [
				{
					"image": "base64 image", 
					"text": "наиболее подходящий текст для картинки" ,
          "confidence": "мера близости для лучшей картинки", 
          "all_res": [{"менее вероятный текст": 0.31}, {"другой текст": 0.33},} ...]
				}
     ]

Пример работы с моделью

!pip install -r requirements.txt
from kfserving_ru_CLIP import KFServingRuClipModel

model = KFServingRuClipModel("kfserving-clip")
model.load()

# zero-shot and links
url_cat = "https://cs11.livemaster.ru/storage/topic/NxN/2c/9b/9cf0a41d13ecb11439e6145dff576315df83op.jpg?h=3KvOPndE06tlraLLSmkHPQ"
url_cat2 = "https://pbs.twimg.com/profile_images/560798448962633728/rDEdUfV_.jpeg"
url_dog = "https://ichef.bbci.co.uk/news/640/cpsprodpb/475B/production/_98776281_gettyimages-521697453.jpg"
result = model.predict({"instances": [{
        "texts": ["собачка", "кошка", "мышка", "машина", "стол", "дом", "жидкость"],
        "image_links": [url_dog, url_cat, url_cat2]
    }]
})

"Res: ", result

`{"predictions": "[{"image": b"/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQ...", "text": "собачка", "confidence": 0.9324831366539001, "all_res": [{"собачка": 0.9324831366539001}, {"кошка": 0.02790665067732334}, {"мышка": 0.029953204095363617}, {"машина": 0.0034359991550445557}, {"стол": 0.0031091528944671154}, {"дом": 0.0018060706788673997}, {"жидкость": 0.0013057001633569598}]}, ... ]", "error_message": None}`

Оценки модели на популярных датасетах

Косинусная близость между текстами и картинками для модели ruCLIP

ruclip3

Предсказания топ 5 классов для изображений с помощью ruCLIP

ruclip4

Сравнение моделей на задаче zero-shot классификации для разных датасетов. Жирным выделена лучшая метрика для каждого из датасетов без учета оригинального CLIP без переводчика.

ruCLIP Small [rugpt3-small]ruCLIP Base [vit-base-patch32-384]CLIP [vit-base-patch16-224] original + MTCLIP [vit-base-patch16-224] original
Food1010.1380.6420.6630.883
CIFAR100.8080.8620.8590.893
CIFAR1000.4400.5290.6030.647
Birdsnap0.03600.1610.1260.396
SUN3970.2580.5100.4470.631
Stanford Cars0.0230.5720.5670.637
DTD0.1690.3900.2430.432
MNIST0.1370.4040.5590.559
STL100.9100.9460.9670.970
PCam0.4840.5060.6030.573
CLEVR0.1040.1880.2400.240
Rendered SST20.4830.5080.4840.484
FGVC Aircraft0.0200.0530.2200.244
Oxford Pets0.4620.5870.5070.874
Caltech1010.590.8340.7910.883
HatefulMemes0.5270.5370.5790.589
ImageNet0.5380.4510.3920.638
Flowers1020.0630.4490.3570.697

Звездочками показана средняя zero-shot оценка моделей на 16 датасетах. Также, как и в статье, на признаках, которые достает CLIP для изображений были обучены логистические регрессии с использованием 1-2-4-8-16 изображений для каждого класса. Поскольку признаки, которые извлекаются у openai и openai_mt одинаковые — для openai_mt нет отдельного графика few-shot классификации. Также мы посчитали усредненный few-shot график для модели ruCLIP Base без учета трех датасетов - PCam, Oxford Pets и FGVC Aircraft, на которых модель проигрывает заметнее остальных можно видеть (пунктирная линия), что среднее качество становится лучше в сравнении с ruCLIP Small.

ruclip_base_1

То же самое, но отдельно по каждому датасету.

ruclip_base_2

Сравнение linear-prob метрики для трех моделей на разных датасетах.

ruCLIP Small [rugpt3-small]ruCLIP Base [vit-base-patch32-384]CLIP [vit-base-patch16-224] original
Food1010.8740.8510.901
CIFAR100.9480.9340.953
CIFAR1000.7940.7450.808
Birdsnap0.5840.4340.664
SUN3970.7530.7210.777
Stanford Cars0.8060.7660.866
DTD0.7380.7030.770
MNIST0.9850.9650.989
STL100.9770.9680.982
PCam0.8330.8350.830
CLEVR0.5240.3080.604
Rendered SST20.5680.6510.606
FGVC Aircraft0.5000.2830.604
Oxford Pets0.8950.7300.931
Caltech1010.9370.9220.956
HatefulMemes0.6380.5810.645

Здесь указаны графики корреляции zero-shot и linear-prob результатов для разных моделей.

ruCLIP Base [vit-base-patch32-384]

ruclip_base_3

ruCLIP Small [rugpt3-small]

ruclip_base_4

CLIP [vit-base-patch16-224] original + MT

ruclip_base_5

CLIP [vit-base-patch16-224] original

ruclip_base_6

Сравнение разных моделей на ImageNet датасетах

resnet101CLIP [vit-base-patch16-224] originalCLIP [vit-base-patch16-224] original + MTruCLIP Base [vit-base-patch32-384]ruCLIP Small [rugpt3-small]
ImageNet0.7390.6380.3920.4510.538
ImageNetV20.6180.5820.3530.3890.458
ImageNet-R0.2720.4900.3530.4730.241
ImageNet-A0.0220.2650.1570.1140.080
ImageNet-Sketch0.2650.4480.2910.3740.251

Zero-shot классификация для разных датасетов на моделях ruCLIP.

ruCLIP Base [vit-base-patch32-384]ruCLIP Large [vit-large-patch14-224]ruCLIP Large [vit-large-patch14-336] exclusiveruCLIP Base [vit-base-patch16-384] exclusive
Food101, acc0.6420.5970.712 💥0.689
CIFAR10, acc0.8620.8780.906 💥0.845
CIFAR100, acc0.5290.5110.591 💥0.569
Birdsnap, acc0.1610.1720.213 💥0.195
SUN397, acc0.5100.4840.523 💥0.521
Stanford Cars, acc0.5720.5590.659 💥0.626
DTD, acc0.3900.3700.4080.421 💥
MNIST, acc0.4040.3370.2420.478 💥
STL10, acc0.9460.9340.9560.964 💥
PCam, acc0.5060.5200.554 💥0.501
CLEVR, acc0.188 💥0.1520.1420.132
Rendered SST2, acc0.5080.5290.539 💥0.525
ImageNet, acc0.4510.4260.488 💥0.482
FGVC Aircraft, mean-per-class0.0530.0460.075 💥0.046
Oxford Pets, mean-per-class0.5870.6040.5460.635 💥
Caltech101, mean-per-class0.8340.7770.835 💥0.835 💥
Flowers102, mean-per-class0.4490.4550.517 💥0.452
Hateful Memes, roc-auc0.5370.5300.5190.543💥

Few-shot классификация для разных датасетов на моделях ruCLIP.

ruCLIP Base [vit-base-patch32-384]ruCLIP Large [vit-large-patch14-224]ruCLIP Large [vit-large-patch14-336] exclusive ruCLIP Base [vit-base-patch16-384] exclusive 
Food1010.8510.8400.896 💥0.890
CIFAR100.9340.9270.943 💥0.942
CIFAR1000.7450.7340.7700.773 💥
Birdsnap0.4340.5670.6090.612 💥
SUN3970.7210.7310.759 💥0.758
Stanford Cars0.7660.7970.8310.840 💥
DTD0.7030.7110.7310.749 💥
MNIST0.9650.9490.9490.971 💥
STL100.9680.9730.981 💥0.974
PCam0.8350.7910.8070.846 💥
CLEVR0.3080.3580.3180.378 💥
Rendered SST20.6510.6510.6370.661 💥
FGVC Aircraft0.2830.2900.3410.362 💥
Oxford Pets0.7300.8190.7530.856 💥
Caltech1010.9220.9140.937 💥0.932
HatefulMemes0.5810.5630.585 💥0.578

Полезные ссылки

Обратная связь

Круглосуточная поддержка по телефону 8 800 444-24-99, почте support@cloud.ru и в Telegram