Модель
ruCLIP Base [vit-base-patch32-384]

Russian Contrastive Language–Image Pre-training. Модель-ранжировщик текстов и изображений, 150 млн параметров.

ruCLIP - это мультимодальная модель для ранжирования изображений и подписей к ним, а также получения семантической близости изображений и текстов. Архитектура впервые представлена OpenAI.

Лицензия
Apache 2.0
Размер файлов
0.6 GB
Версия
0.1
Бесплатно
Подключить
Категории
cvNLPclipgpt3rucliprugpt3pytorchcomputer visionnatural language generation
Разработчик
Sber AI, SberDevices
Контакты
gpt3support@sber.ru
Описание

ruCLIP (Russian Contrastive Language – Image Pre-training) обучена для русского языка на открытых данных, собранных из Рунета.

240 миллионов уникальных пар картинка-текст в обучающей выборке.

Информация об использовании модели:

ruCLIP — это модель, состоящая из двух частей (или нейронных сетей):

  1. Image Encoder — часть для кодирования изображений и перевода их в общее векторное пространство. В качестве архитектуры в оригинальной работе берутся ResNet разных размеров и Visual Transformer — тоже разных размеров. В ruCLIP Base в качестве image encoder используется ViT-B/32.
  2. Text Encoder — часть для кодирования текстов и перевода их в общее векторное пространство. В качестве архитектуры используется текстовый Transformer.

Similarity

ruclip0

[{"собачка": 0.9324831366539001}, {"кошка": 0.02790665067732334}, {"мышка": 0.029953204095363617}, {"машина": 0.0034359991550445557}, {"стол": 0.0031091528944671154}, {"дом": 0.0018060706788673997}, {"жидкость": 0.0013057001633569598}]

KFServing

Класс KFServingRuClipModel представлен ниже.

Вы можете подавать на вход модели:

  • ссылки на изображения
  • картинки в формате base64

На выходе модель покажет близость между текстами и картинками. Чем ближе значение к 1, тем ближе семантическое сходство картинки и текста.

import os
import json
from collections import OrderedDict
from typing import Dict
import kfserving
import requests
import torch
import numpy as np
import io
from io import BytesIO
import base64
from PIL import Image
import re

from ruclip import CLIP, RuCLIPProcessor

def open_images_base64(img_strs):
    return [Image.open(BytesIO(base64.b64decode(img_str))) for img_str in img_strs]

def open_image_link(links):
    imgs = []
    for img_link in links:
        response = requests.get(img_link)
        imgs.append(Image.open(BytesIO(response.content)))
    return imgs

def create_image(sim_plt):
    my_stringIObytes = io.BytesIO()
    sim_plt.savefig(my_stringIObytes, format="jpg")
    my_stringIObytes.seek(0)
    my_base64 = base64.b64encode(my_stringIObytes.read())
    return my_base64

class KFServingRuClipModel(kfserving.KFModel):
    def __init__(self, name: str, model_path="./ruclip-vit-base-patch32-384"):
        super().__init__(name)
        self.name = name
        self.ready = False
        self.model_path = model_path

    def load(self):
        self.device = "cuda"
        self.clip = CLIP.from_pretrained(self.model_path).eval().to(self.device)
        self.clip_processor = RuCLIPProcessor.from_pretrained(self.model_path)
        self.ready = True
        
    def get_text_latents(self, texts):
        with torch.no_grad():
            inputs = self.clip_processor(text=texts, images=None)
            text_latents = self.clip.encode_text(
                input_ids=inputs["input_ids"].to(self.device), 
            )
            text_latents = text_latents / text_latents.norm(dim=-1, keepdim=True)
        return text_latents
        
    def get_logits(self, text_latents, pil_images):
        with torch.no_grad():
            inputs = self.clip_processor(text=None, images=pil_images)
            image_latents = self.clip.encode_image(
                pixel_values=inputs["pixel_values"].to(self.device)
            )
            image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True)
            logits_per_text = torch.matmul(text_latents, image_latents.t())
            logits_per_image = logits_per_text.t()
        return logits_per_text, logits_per_image
    
    def get_similarity_scores(self, texts, images):
        """
        Find the most similar image to text.
        `texts`: array of texts or one text ["some_desc"]
        `images`: array of images.
        """
        text_latents = self.get_text_latents(texts)
        results = []
        for pil_image in images:
            _, logits_per_image = self.get_logits(text_latents, [pil_image])
            probs_raw = (logits_per_image * self.clip.logit_scale.exp().detach()).softmax(dim=-1)[0]
            label_id = probs_raw.argmax().item()
            confidence = probs_raw.max().item()
            
            probs = []
            for i in range(len(texts)):
                probs.append({texts[i]: probs_raw[i].item()})

            buffered = BytesIO()
            pil_image.save(buffered, format="JPEG")
            img = base64.b64encode(buffered.getvalue()).decode("utf-8")
            results.append({"image": img, "text": texts[label_id], "confidence" : confidence, "all_res": probs})
        
        return results
    
    def predict(self, request: Dict) -> Dict:
        texts = request["instances"][0]["texts"]
        
        img_strs = request["instances"][0].get("images", None)
        if img_strs is not None:
            images = open_images_base64(img_strs)
        
        images_links = request["instances"][0].get("image_links", None)

        if images_links is not None:
            images = open_image_link(images_links)
          
        error_msg = None
        predictions = []
        try:
            predictions = self.get_similarity_scores(texts, images)
        except Exception as ex:
            print(ex)
            error_msg = ex

        if error_msg is not None:
            return {"predictions": predictions, "error_message": str(error_msg)}
        else:
            return {"predictions": predictions}

Функция predict возвращает массив со словарями для каждой картинки вида:

     predictions: [
				{
					"image": "base64 image", 
					"text": "наиболее подходящий текст для картинки" ,
          "confidence": "мера близости для лучшей картинки", 
          "all_res": [{"менее вероятный текст": 0.31}, {"другой текст": 0.33},} ...]
				}
     ]

Пример работы с моделью

!pip install -r requirements.txt
from kfserving_ru_CLIP import KFServingRuClipModel

model = KFServingRuClipModel("kfserving-clip")
model.load()

# zero-shot and links
url_cat = "https://cs11.livemaster.ru/storage/topic/NxN/2c/9b/9cf0a41d13ecb11439e6145dff576315df83op.jpg?h=3KvOPndE06tlraLLSmkHPQ"
url_cat2 = "https://pbs.twimg.com/profile_images/560798448962633728/rDEdUfV_.jpeg"
url_dog = "https://ichef.bbci.co.uk/news/640/cpsprodpb/475B/production/_98776281_gettyimages-521697453.jpg"
result = model.predict({"instances": [{
        "texts": ["собачка", "кошка", "мышка", "машина", "стол", "дом", "жидкость"],
        "image_links": [url_dog, url_cat, url_cat2]
    }]
})

"Res: ", result

{"predictions": "[{"image": b"/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQ...", "text": "собачка", "confidence": 0.9324831366539001, "all_res": [{"собачка": 0.9324831366539001}, {"кошка": 0.02790665067732334}, {"мышка": 0.029953204095363617}, {"машина": 0.0034359991550445557}, {"стол": 0.0031091528944671154}, {"дом": 0.0018060706788673997}, {"жидкость": 0.0013057001633569598}]}, ... ]", "error_message": None}

Оценки модели на популярных датасетах

Косинусная близость между текстами и картинками для модели ruCLIP

ruclip3

Предсказания топ 5 классов для изображений с помощью ruCLIP

ruclip4

Сравнение моделей на задаче zero-shot классификации для разных датасетов. Жирным выделена лучшая метрика для каждого из датасетов без учета оригинального CLIP без переводчика.

| | ruCLIP Small [rugpt3-small] | ruCLIP Base [vit-base-patch32-384] | CLIP [vit-base-patch16-224] original + MT | CLIP [vit-base-patch16-224] original | | --- | --- | --- | --- | --- | | Food101 | 0.138 | 0.642 | 0.663 | 0.883 | | CIFAR10 | 0.808 | 0.862 | 0.859 | 0.893 | | CIFAR100 | 0.440 | 0.529 | 0.603 | 0.647 | | Birdsnap | 0.0360 | 0.161 | 0.126 | 0.396 | | SUN397 | 0.258 | 0.510 | 0.447 | 0.631 | | Stanford Cars | 0.023 | 0.572 | 0.567 | 0.637 | | DTD | 0.169 | 0.390 | 0.243 | 0.432 | | MNIST | 0.137 | 0.404 | 0.559 | 0.559 | | STL10 | 0.910 | 0.946 | 0.967 | 0.970 | | PCam | 0.484 | 0.506 | 0.603 | 0.573 | | CLEVR | 0.104 | 0.188 | 0.240 | 0.240 | | Rendered SST2 | 0.483 | 0.508 | 0.484 | 0.484 | | FGVC Aircraft | 0.020 | 0.053 | 0.220 | 0.244 | | Oxford Pets | 0.462 | 0.587 | 0.507 | 0.874 | | Caltech101 | 0.59 | 0.834 | 0.791 | 0.883 | | HatefulMemes | 0.527 | 0.537 | 0.579 | 0.589 | | ImageNet | 0.538 | 0.451 | 0.392 | 0.638 | | Flowers102 | 0.063 | 0.449 | 0.357 | 0.697 |

Звездочками показана средняя zero-shot оценка моделей на 16 датасетах. Также, как и в статье, на признаках, которые достает CLIP для изображений были обучены логистические регрессии с использованием 1-2-4-8-16 изображений для каждого класса. Поскольку признаки, которые извлекаются у openai и openai_mt одинаковые — для openai_mt нет отдельного графика few-shot классификации. Также мы посчитали усредненный few-shot график для модели ruCLIP Base без учета трех датасетов - PCam, Oxford Pets и FGVC Aircraft, на которых модель проигрывает заметнее остальных можно видеть (пунктирная линия), что среднее качество становится лучше в сравнении с ruCLIP Small.

ruclip_base_1

То же самое, но отдельно по каждому датасету.

ruclip_base_2

Сравнение linear-prob метрики для трех моделей на разных датасетах.

| | ruCLIP Small [rugpt3-small] | ruCLIP Base [vit-base-patch32-384] | CLIP [vit-base-patch16-224] original | | --- | --- | --- | --- | | Food101 | 0.874 | 0.851 | 0.901 | | CIFAR10 | 0.948 | 0.934 | 0.953 | | CIFAR100 | 0.794 | 0.745 | 0.808 | | Birdsnap | 0.584 | 0.434 | 0.664 | | SUN397 | 0.753 | 0.721 | 0.777 | | Stanford Cars | 0.806 | 0.766 | 0.866 | | DTD | 0.738 | 0.703 | 0.770 | | MNIST | 0.985 | 0.965 | 0.989 | | STL10 | 0.977 | 0.968 | 0.982 | | PCam | 0.833 | 0.835 | 0.830 | | CLEVR | 0.524 | 0.308 | 0.604 | | Rendered SST2 | 0.568 | 0.651 | 0.606 | | FGVC Aircraft | 0.500 | 0.283 | 0.604 | | Oxford Pets | 0.895 | 0.730 | 0.931 | | Caltech101 | 0.937 | 0.922 | 0.956 | | HatefulMemes | 0.638 | 0.581 | 0.645 |

Здесь указаны графики корреляции zero-shot и linear-prob результатов для разных моделей.

ruCLIP Base [vit-base-patch32-384]

ruclip_base_3

ruCLIP Small [rugpt3-small]

ruclip_base_4

CLIP [vit-base-patch16-224] original + MT

ruclip_base_5

CLIP [vit-base-patch16-224] original

ruclip_base_6

Сравнение разных моделей на ImageNet датасетах

| | resnet101 | CLIP [vit-base-patch16-224] original | CLIP [vit-base-patch16-224] original + MT | ruCLIP Base [vit-base-patch32-384] | ruCLIP Small [rugpt3-small] | | --- | --- | --- | --- | --- | --- | | ImageNet | 0.739 | 0.638 | 0.392 | 0.451 | 0.538 | | ImageNetV2 | 0.618 | 0.582 | 0.353 | 0.389 | 0.458 | | ImageNet-R | 0.272 | 0.490 | 0.353 | 0.473 | 0.241 | | ImageNet-A | 0.022 | 0.265 | 0.157 | 0.114 | 0.080 | | ImageNet-Sketch | 0.265 | 0.448 | 0.291 | 0.374 | 0.251 |

Zero-shot классификация для разных датасетов на моделях ruCLIP.

| | ruCLIP Base [vit-base-patch32-384] | ruCLIP Large [vit-large-patch14-224] | ruCLIP Large [vit-large-patch14-336] exclusive | ruCLIP Base [vit-base-patch16-384] exclusive | | --- | --- | --- | --- | --- | | Food101, acc | 0.642 | 0.597 | 0.712 💥 | 0.689 | | CIFAR10, acc | 0.862 | 0.878 | 0.906 💥 | 0.845 | | CIFAR100, acc | 0.529 | 0.511 | 0.591 💥 | 0.569 | | Birdsnap, acc | 0.161 | 0.172 | 0.213 💥 | 0.195 | | SUN397, acc | 0.510 | 0.484 | 0.523 💥 | 0.521 | | Stanford Cars, acc | 0.572 | 0.559 | 0.659 💥 | 0.626 | | DTD, acc | 0.390 | 0.370 | 0.408 | 0.421 💥 | | MNIST, acc | 0.404 | 0.337 | 0.242 | 0.478 💥 | | STL10, acc | 0.946 | 0.934 | 0.956 | 0.964 💥 | | PCam, acc | 0.506 | 0.520 | 0.554 💥 | 0.501 | | CLEVR, acc | 0.188 💥 | 0.152 | 0.142 | 0.132 | | Rendered SST2, acc | 0.508 | 0.529 | 0.539 💥 | 0.525 | | ImageNet, acc | 0.451 | 0.426 | 0.488 💥 | 0.482 | | FGVC Aircraft, mean-per-class | 0.053 | 0.046 | 0.075 💥 | 0.046 | | Oxford Pets, mean-per-class | 0.587 | 0.604 | 0.546 | 0.635 💥 | | Caltech101, mean-per-class | 0.834 | 0.777 | 0.835 💥 | 0.835 💥 | | Flowers102, mean-per-class | 0.449 | 0.455 | 0.517 💥 | 0.452 | | Hateful Memes, roc-auc | 0.537 | 0.530 | 0.519 | 0.543💥 |

Few-shot классификация для разных датасетов на моделях ruCLIP.

| | ruCLIP Base [vit-base-patch32-384] | ruCLIP Large [vit-large-patch14-224] | ruCLIP Large [vit-large-patch14-336] exclusive  | ruCLIP Base [vit-base-patch16-384] exclusive  | | --- | --- | --- | --- | --- | | Food101 | 0.851 | 0.840 | 0.896 💥 | 0.890 | | CIFAR10 | 0.934 | 0.927 | 0.943 💥 | 0.942 | | CIFAR100 | 0.745 | 0.734 | 0.770 | 0.773 💥 | | Birdsnap | 0.434 | 0.567 | 0.609 | 0.612 💥 | | SUN397 | 0.721 | 0.731 | 0.759 💥 | 0.758 | | Stanford Cars | 0.766 | 0.797 | 0.831 | 0.840 💥 | | DTD | 0.703 | 0.711 | 0.731 | 0.749 💥 | | MNIST | 0.965 | 0.949 | 0.949 | 0.971 💥 | | STL10 | 0.968 | 0.973 | 0.981 💥 | 0.974 | | PCam | 0.835 | 0.791 | 0.807 | 0.846 💥 | | CLEVR | 0.308 | 0.358 | 0.318 | 0.378 💥 | | Rendered SST2 | 0.651 | 0.651 | 0.637 | 0.661 💥 | | FGVC Aircraft | 0.283 | 0.290 | 0.341 | 0.362 💥 | | Oxford Pets | 0.730 | 0.819 | 0.753 | 0.856 💥 | | Caltech101 | 0.922 | 0.914 | 0.937 💥 | 0.932 | | HatefulMemes | 0.581 | 0.563 | 0.585 💥 | 0.578 |

Полезные ссылки