Модель
ruCLIP Large [vit-large-patch14-336] exclusive

Russian Contrastive Language–Image Pre-training. Модель-ранжировщик текстов и изображений, 430 млн параметров.

ruCLIP - это мультимодальная модель для ранжирования изображений и подписей к ним, а также получения семантической близости изображений и текстов. Архитектура впервые представлена OpenAI.

Лицензия
Apache 2.0
Размер файлов
1.5 GB
Версия
0.1
Бесплатно
Подключить
Категории
cvNLPclipgpt3rucliprugpt3pytorchcomputer visionnatural language generation
Разработчик
Sber AI, SberDevices
Контакты
gpt3support@sber.ru
Описание

ruCLIP (Russian Contrastive Language – Image Pre-training) обучена для русского языка на открытых данных, собранных из Рунета.

240 миллионов уникальных пар картинка-текст в обучающей выборке.

Информация об использовании модели:

Эксклюзивная версия модели ruCLIP, доступная только на платформе Cloud. Модель отличает сочетание меньшего размера патча - 14 и большего размера входных изображений - 336. Данное сочетание привело к наиболее высоким качественным результатам оценки модели - на 12 из 18 датасетах (zero-shot задача), на 5 из 16 датасетах (few-shot задача).

ruCLIP — это модель, состоящая из двух частей (или нейронных сетей):

  1. Image Encoder — часть для кодирования изображений и перевода их в общее векторное пространство. В качестве архитектуры в оригинальной работе берутся ResNet разных размеров и Visual Transformer — тоже разных размеров. В ruCLIP Large в качестве image encoder используется ViT-L/14.
  2. Text Encoder — часть для кодирования текстов и перевода их в общее векторное пространство. В качестве архитектуры используется текстовый Transformer.

Similarity

ruclip0

[{"собачка": 0.9913390278816223}, {"кошка": 0.0070372591726481915}, {"мышка": 5.551231879508123e-05}, {"машина": 0.000418326846556738}, {"стол": 0.0001436278544133529}, {"дом": 0.00010760522854980081}, {"жидкость": 0.0008986211032606661}]

KFServing

Класс KFServingRuClipModel представлен ниже.

Вы можете подавать на вход модели:

  • ссылки на изображения
  • картинки в формате base64

На выходе модель покажет близость между текстами и картинками. Чем ближе значение к 1, тем ближе семантическое сходство картинки и текста.

import os
import json
from collections import OrderedDict
from typing import Dict
import kfserving
import requests
import torch
import numpy as np
import io
from io import BytesIO
import base64
from PIL import Image
import re

from ruclip import CLIP, RuCLIPProcessor

def open_images_base64(img_strs):
    return [Image.open(BytesIO(base64.b64decode(img_str))) for img_str in img_strs]

def open_image_link(links):
    imgs = []
    for img_link in links:
        response = requests.get(img_link)
        imgs.append(Image.open(BytesIO(response.content)))
    return imgs

def create_image(sim_plt):
    my_stringIObytes = io.BytesIO()
    sim_plt.savefig(my_stringIObytes, format="jpg")
    my_stringIObytes.seek(0)
    my_base64 = base64.b64encode(my_stringIObytes.read())
    return my_base64

class KFServingRuClipModel(kfserving.KFModel):
    def __init__(self, name: str, model_path="./ruclip-vit-large-patch14-336"):
        super().__init__(name)
        self.name = name
        self.ready = False
        self.model_path = model_path

    def load(self):
        self.device = "cuda"
        self.clip = CLIP.from_pretrained(self.model_path).eval().to(self.device)
        self.clip_processor = RuCLIPProcessor.from_pretrained(self.model_path)
        self.ready = True
        
    def get_text_latents(self, texts):
        with torch.no_grad():
            inputs = self.clip_processor(text=texts, images=None)
            text_latents = self.clip.encode_text(
                input_ids=inputs["input_ids"].to(self.device), 
            )
            text_latents = text_latents / text_latents.norm(dim=-1, keepdim=True)
        return text_latents
        
    def get_logits(self, text_latents, pil_images):
        with torch.no_grad():
            inputs = self.clip_processor(text=None, images=pil_images)
            image_latents = self.clip.encode_image(
                pixel_values=inputs["pixel_values"].to(self.device)
            )
            image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True)
            logits_per_text = torch.matmul(text_latents, image_latents.t())
            logits_per_image = logits_per_text.t()
        return logits_per_text, logits_per_image
    
    def get_similarity_scores(self, texts, images):
        """
        Find the most similar image to text.
        `texts`: array of texts or one text ["some_desc"]
        `images`: array of images.
        """
        text_latents = self.get_text_latents(texts)
        results = []
        for pil_image in images:
            _, logits_per_image = self.get_logits(text_latents, [pil_image])
            probs_raw = (logits_per_image * self.clip.logit_scale.exp().detach()).softmax(dim=-1)[0]
            label_id = probs_raw.argmax().item()
            confidence = probs_raw.max().item()
            
            probs = []
            for i in range(len(texts)):
                probs.append({texts[i]: probs_raw[i].item()})

            buffered = BytesIO()
            pil_image.save(buffered, format="JPEG")
            img = base64.b64encode(buffered.getvalue()).decode("utf-8")
            results.append({"image": img, "text": texts[label_id], "confidence" : confidence, "all_res": probs})
        
        return results
    
    def predict(self, request: Dict) -> Dict:
        texts = request["instances"][0]["texts"]
        
        img_strs = request["instances"][0].get("images", None)
        if img_strs is not None:
            images = open_images_base64(img_strs)
        
        images_links = request["instances"][0].get("image_links", None)

        if images_links is not None:
            images = open_image_link(images_links)
          
        error_msg = None
        predictions = []
        try:
            predictions = self.get_similarity_scores(texts, images)
        except Exception as ex:
            print(ex)
            error_msg = ex

        if error_msg is not None:
            return {"predictions": predictions, "error_message": str(error_msg)}
        else:
            return {"predictions": predictions}

Функция predict возвращает массив со словарями для каждой картинки вида:

     predictions: [
				{
					"image": "base64 image", 
					"text": "наиболее подходящий текст для картинки" ,
          "confidence": "мера близости для лучшей картинки", 
          "all_res": [{"менее вероятный текст": 0.31}, {"другой текст": 0.33},} ...]
				}
     ]

Пример работы с моделью

!pip install -r requirements.txt
from kfserving_ru_CLIP import KFServingRuClipModel

model = KFServingRuClipModel("kfserving-clip")
model.load()

# zero-shot and links
url_cat = "https://cs11.livemaster.ru/storage/topic/NxN/2c/9b/9cf0a41d13ecb11439e6145dff576315df83op.jpg?h=3KvOPndE06tlraLLSmkHPQ"
url_cat2 = "https://pbs.twimg.com/profile_images/560798448962633728/rDEdUfV_.jpeg"
url_dog = "https://ichef.bbci.co.uk/news/640/cpsprodpb/475B/production/_98776281_gettyimages-521697453.jpg"
result = model.predict({"instances": [{
        "texts": ["собачка", "кошка", "мышка", "машина", "стол", "дом", "жидкость"],
        "image_links": [url_dog, url_cat, url_cat2]
    }]
})

"Res: ", result

{"predictions": "[{"image": b"/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQ...", "text": "собачка", "confidence": 0.9913390278816223,"all_res": [{"собачка": 0.9913390278816223}, {"кошка": 0.0070372591726481915}, {"мышка": 5.551231879508123e-05}, {"машина": 0.000418326846556738}, {"стол": 0.0001436278544133529}, {"дом": 0.00010760522854980081}, {"жидкость": 0.0008986211032606661}]}, ... ]", "error_message": None}

Оценки модели на популярных датасетах

Косинусная близость между текстами и картинками для модели ruCLIP

ruclip3

Предсказания топ 5 классов для изображений с помощью ruCLIP

ruclip4

Сравнение моделей на задаче zero-shot классификации для разных датасетов. Жирным выделена лучшая метрика для каждого из датасетов без учета оригинального CLIP без переводчика.

ruCLIP Small [rugpt3-small]ruCLIP Large [vit-large-patch14-336] exclusiveCLIP [vit-base-patch16-224] original + MTCLIP [vit-base-patch16-224] original
Food1010.1370.7110.6630.882
CIFAR100.8080.9050.8590.892
CIFAR1000.4390.5910.6020.647
Birdsnap0.0350.2130.1260.395
SUN3970.2570.5230.4470.631
Stanford Cars0.0220.6580.5670.637
DTD0.1680.4070.2420.432
MNIST0.1370.2410.5580.558
STL100.9090.9560.9660.970
PCam0.4840.5530.6020.572
CLEVR0.1040.1420.2400.240
Rendered SST20.4820.5380.4830.483
FGVC Aircraft0.0190.0750.2190.244
Oxford Pets0.4620.5450.5060.873
Caltech1010.5890.8350.7910.882
HatefulMemes0.5270.5180.5790.589
ImageNet0.5370.4880.3910.638
Flowers1020.0630.5160.3560.696

Звездочками показана средняя zero-shot оценка моделей на 16 датасетах. Также, как и в статье, на признаках, которые достает CLIP для изображений были обучены логистические регрессии с использованием 1-2-4-8-16 изображений для каждого класса. Поскольку признаки, которые извлекаются у openai и openai_mt одинаковые — для openai_mt нет отдельного графика few-shot классификации. Также мы посчитали усредненный few-shot график для модели ruCLIP Large exclusive без учета трех датасетов - PCam, Oxford Pets и FGVC Aircraft, на которых модель проигрывает достаточно сильно и можно видеть (пунктирная линия), что среднее качество немного превосходит ruCLIP Small.

ruclip_large_ex_1

То же самое, но отдельно по каждому датасету.

ruclip_large_ex_2

Сравнение linear-prob метрики для трех моделей на разных датасетах.

ruCLIP Small [rugpt3-small]ruCLIP Large [vit-large-patch14-336] exclusiveCLIP [vit-base-patch16-224] original
Food1010.8730.8960.900
CIFAR100.9480.9430.953
CIFAR1000.7940.7690.807
Birdsnap0.5840.6090.664
SUN3970.7520.7580.776
Stanford Cars0.8060.8310.865
DTD0.7370.7300.769
MNIST0.9850.9490.988
STL100.9770.9810.982
PCam0.8330.8060.830
CLEVR0.5240.3180.604
Rendered SST20.5680.6370.605
FGVC Aircraft0.4990.3400.604
Oxford Pets0.8940.7530.931
Caltech1010.9360.9360.955
HatefulMemes0.6370.5850.645

Здесь указаны графики корреляции zero-shot и linear-prob результатов для разных моделей.

ruCLIP Large [vit-large-patch16-336] exclusive

ruclip_large_ex_3

ruCLIP Small [rugpt3-small]

ruclip_base_4

CLIP [vit-base-patch16-224] original + MT

ruclip_base_5

CLIP [vit-base-patch16-224] original

ruclip_base_6

Сравнение разных моделей на ImageNet датасетах.

resnet101CLIP [vit-base-patch16-224] originalCLIP [vit-base-patch16-224] original + MTruCLIP Large [vit-large-patch14-336] exclusiveruCLIP Small [rugpt3-small]
ImageNet0.7380.6380.3910.4880.537
ImageNetV20.6180.5810.3530.4300.458
ImageNet-R0.2710.4890.3520.5250.240
ImageNet-A0.0210.2650.1560.2290.080
ImageNet-Sketch0.2640.4480.2910.4190.250

Zero-shot классификация для разных датасетов на моделях ruCLIP.

ruCLIP Large [vit-large-patch14-336] exclusiveruCLIP Large [vit-large-patch14-224]ruCLIP Base [vit-base-patch32-384]ruCLIP Base [vit-base-patch16-384] exclusive
Food101, acc0.712 💥0.5970.6420.689
CIFAR10, acc0.906 💥0.8780.8620.845
CIFAR100, acc0.591 💥0.5110.5290.569
Birdsnap, acc0.213 💥0.1720.1610.195
SUN397, acc0.523 💥0.4840.5100.521
Stanford Cars, acc0.659 💥0.5590.5720.626
DTD, acc0.4080.3700.3900.421 💥
MNIST, acc0.2420.3370.4040.478 💥
STL10, acc0.9560.9340.9460.964 💥
PCam, acc0.554 💥0.5200.5060.501
CLEVR, acc0.1420.1520.188 💥0.132
Rendered SST2, acc0.539 💥0.5290.5080.525
ImageNet, acc0.488 💥0.4260.4510.482
FGVC Aircraft, mean-per-class0.075 💥0.0460.0530.046
Oxford Pets, mean-per-class0.5460.6040.5870.635 💥
Caltech101, mean-per-class0.835 💥0.7770.8340.835 💥
Flowers102, mean-per-class0.517 💥0.4550.4490.452
Hateful Memes, roc-auc0.5190.5300.5370.543💥

Few-shot классификация для разных датасетов на моделях ruCLIP.

ruCLIP Large [vit-large-patch14-336] exclusiveruCLIP Large [vit-large-patch14-224]ruCLIP Base [vit-base-patch32-384]ruCLIP Base [vit-base-patch16-384] exclusive
Food1010.896 💥0.8400.8510.890
CIFAR100.943 💥0.9270.9340.942
CIFAR1000.7700.7340.7450.773 💥
Birdsnap0.6090.5670.4340.612 💥
SUN3970.759 💥0.7310.7210.758
Stanford Cars0.8310.7970.7660.840 💥
DTD0.7310.7110.7030.749 💥
MNIST0.9490.9490.9650.971 💥
STL100.981 💥0.9730.9680.974
PCam0.8070.7910.8350.846 💥
CLEVR0.3180.3580.3080.378 💥
Rendered SST20.6370.6510.6510.661 💥
FGVC Aircraft0.3410.2900.2830.362 💥
Oxford Pets0.7530.8190.7300.856 💥
Caltech1010.937 💥0.9140.9220.932
HatefulMemes0.585 💥0.5630.5810.578

Полезные ссылки