AI-сервис
Kandinsky

Большая модель (12 млрд. параметров) генерации изображений по текстовому описанию на русском языке.

Лицензия
Apache 2.0
Размер файлов
31.2 GB
Версия
0.1
Бесплатно
Подключить
Категории
ruDALL-EKandinskytransformerimage generation
Разработчик
Sber AI, Sber Devices
Контакты
gpt3support@sber.ru
Описание

Описание

В основе архитектуры ruDALL-E XXL Kandinsky — так называемый трансформер, который состоит из энкодера и декодера. Общая идея состоит в том, чтобы вычислить embedding по входным данным с помощью энкодера, а затем с учетом известного выхода правильным образом декодировать этот embedding.

Отличия моделей ruDALL-E XL от Kandinsky лучше всего показать в таблице:

ModelParamsLayersHidden SizeNumAttentionHeadsOptimizer
ruDALL-E XL1.3B24204816AdamW
Kandinsky12B64384060AdamW-8bit-bnb + deepspeed zero3

Характеристики

  • Размер файлов: 31.2 GB
  • Модель GPU: 224xA100 (Обучение), A100/3xV100 (Инференс)
  • Фреймворк: pytorch
  • Формат: checkpoint
  • Версия: 0.1

На первом этапе модель Kandinsky обучалась на протяжении двух месяцев, и этот процесс занял 20 352 GPU-V100 дней. В рамках этой фазы обучения использовался датасет без фильтрации, состоящий из 52 млн пар изображений и текстовых описаний к ним; впоследствии он был сокращён до 28 млн пар. В состав данных вошли такие известные датасеты, как ConceptualCaptions, YFCC100m (описания были переведены на русский язык системой машинного перевода), русская Википедия и другие. Первый этап обучения продолжался в течение 250 тыс. итераций.

После этого была выполнена вторая фаза обучения pretrained модели на новых отфильтрованных данных (7 680 GPU-A100 дней). В состав обучающего датасета на этот раз вошли исключительно нативные русскоязычные данные (без автоматического перевода с других языков): русская часть датасета laion5B, vist, flickr8, flickr30, ru_wiki, CelebA и др. Из датасетов путём фильтрации были исключены изображения с водяными знаками, а также выполнен реранкинг пар с помощью модели ruCLIP. В общей сложности набор данных для второй фазы обучения составил 119 млн пар, обучение длилось 60 тыс. итераций.

Информация об использовании модели

Модель принимает на вход текстовое описание на русском языке, параметры для модели и выполняет генерацию заданного количества сэмплов-изображений. Генерируемые изображения имеют размер 256x256 пикселей.

Подробная техническая статья про обучение модели и ее возможности представлена по ссылке.

KFServing

Класс KFServingKandinskyModel представлен ниже.

На вход модели подаются текстовое описание и параметры для генерации (опционально).

Модель сгенерирует картинки, отберет лучшие с помощью нейросети ruCLIP и улучшит их качество, используя модель Super Resolution. На выходе модель выдаст список изображений и значения близости этих изображений с текстом, полученные с помощью ruCLIP.

Модель Kandinsky можно запускать на одной GPU Tesla A100, а также на трёх или более GPU Tesla V100. Код для запуска модели на трёх GPU Tesla V100 представлен ниже.

import kfserving
from typing import List, Dict
import re
import os
import numpy as np
import json
from PIL import Image
import base64
from io import BytesIO
import io

import torch, torchvision

import ruclip

from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_ruclip
from rudalle import get_rudalle_model, get_tokenizer, get_vae
from rudalle import utils
from rudalle.dalle import MODELS

import math
from torch.nn import LayerNorm

from rudalle.dalle.utils import divide, split_tensor_along_last_dim
from rudalle.dalle.image_attention import get_conv_mask, get_row_mask, get_col_mask
from rudalle.dalle.transformer import *


def check_device(obj, target_device):
    if obj.device != target_device:
        obj = obj.to(target_device)
    return obj


class DalleTransformerParallel(DalleTransformer):
    _mask_map = []

    def __init__(self, *args, **kwargs):
        super(DalleTransformerParallel, self).__init__(*args, **kwargs)
        
        self.text_seq_length = kwargs['text_seq_length']
        self.image_tokens_per_dim = kwargs['image_tokens_per_dim']
        self.is_bool_mask = kwargs['is_bool_mask']

    def forward(self, hidden_states, attention_mask, cache=None, use_cache=False, gradient_checkpointing=None):
        if cache is None:
            cache = {}
        # Immutable caching uses much more VRAM.
        # present_cache = {}

        if gradient_checkpointing:
            assert not use_cache
            layers = []

        for i, layer in enumerate(self.layers):
            mask = attention_mask
            layer_mask = self._get_layer_mask(i)[:mask.size(2), :mask.size(3)]
            
            attention_mask = check_device(attention_mask, layer_mask.device)
            
            mask = torch.mul(attention_mask, layer_mask)
            if gradient_checkpointing:
                layers.append(Layer(layer,
                                    # only get the embeddings, not present_has_cache
                                    lambda x: x[0],
                                    mask,
                                    use_cache=False, has_cache=False))
            else:
                hidden_states = check_device(hidden_states, self.layer_devices[i])
                mask = check_device(mask, self.layer_devices[i])
                    
                hidden_states, layer_cache = layer(
                    hidden_states, mask, cache.get(i), mlp_cache=i == len(self.layers)-1, use_cache=use_cache)
                cache[i] = layer_cache
                
        if gradient_checkpointing:
            hidden_states = torch.utils.checkpoint.checkpoint_sequential(
                layers, gradient_checkpointing, hidden_states)

        hidden_states = rescale_max(hidden_states, self.custom_relax)
        
        hidden_states = check_device(hidden_states, layer_mask.device)
        
        output = self.final_layernorm(hidden_states)
        return output, cache
    
    def to_parallel(self, device_main, devices):
        self.devices = devices
        self.device_main = device_main
        self.devices_count = len(devices)
        
        # set different devices to Transformer layers.
        self.layer_devices = [devices[i//math.ceil(len(self.layers)/len(devices))] for i in range(len(self.layers))]
        for i, layer in enumerate(self.layers):
            layer.to(self.layer_devices[i])
            
        self.final_layernorm.to(device_main)
        
        # reregister buffers with main device
        row_mask = get_row_mask(self.text_seq_length, self.image_tokens_per_dim, is_bool_mask=self.is_bool_mask).to(device_main)
        col_mask = get_col_mask(self.text_seq_length, self.image_tokens_per_dim, is_bool_mask=self.is_bool_mask).to(device_main)
        conv_mask = get_conv_mask(self.text_seq_length, self.image_tokens_per_dim, is_bool_mask=self.is_bool_mask,
                                  hf_version=self.hf_version).to(device_main)
        self.register_buffer('row_mask', row_mask)
        self.register_buffer('col_mask', col_mask)
        self.register_buffer('conv_mask', conv_mask)


from rudalle.realesrgan.model import RealESRGAN
from rudalle.realesrgan import MODELS as REALESRGAN_MODELS

def load_realesrgan(name, device='cpu', fp16=False, cache_dir='/tmp/rudalle'):
    assert name in REALESRGAN_MODELS
    config = REALESRGAN_MODELS[name]
    
    model = RealESRGAN(device, config['scale'], fp16=fp16)
    cache_dir = os.path.join(cache_dir, name)
    model.load_weights(os.path.join(cache_dir, config['filename']))
    return model

from ruclip import MODELS as RUCLIP_MODELS
from ruclip.model import CLIP
from ruclip.processor import RuCLIPProcessor
from ruclip.predictor import Predictor

def load_ruclip(name, device='cpu', cache_dir='/tmp/ruclip'):
    assert name in RUCLIP_MODELS, f'All models: {RUCLIP_MODELS.keys()}'
    config = RUCLIP_MODELS[name]

    clip = CLIP.from_pretrained(cache_dir).eval().to(device)
    clip_processor = RuCLIPProcessor.from_pretrained(cache_dir)
    return clip, clip_processor


def load_dalle_model_parallel(chekpoint_path, device_main, device_ids, **model_kwargs):
    config = MODELS['Kandinsky'].copy()
    config['model_params'].update(model_kwargs)
    config_tranformer = config['model_params'].copy()
    config_tranformer.update(dict(mlp_activation='gelu_jit', hf_version='v3'))
    for k in ['embedding_dropout_prob', 'image_vocab_size', 'vocab_size']:
        config_tranformer.pop(k)
    
    print('Creating Kandinsky model')
    dalle = get_rudalle_model('Kandinsky', pretrained=False, fp16=False, device='cpu', **config['model_params'])
    del dalle.transformer
    dalle.to(device_main)
    
    dalle.transformer = DalleTransformerParallel(**config_tranformer)
    print('Loading Kandinsky checkpoint')
    dalle.load_state_dict(torch.load(chekpoint_path, map_location='cpu'))
    dalle.transformer.to_parallel(device_main, device_ids)
    dalle.eval()
    return dalle


def encode_img(pil_image, format="JPEG") -> str:
    buff = BytesIO()
    pil_image.save(buff, format=format)
    img_b64 = base64.b64encode(buff.getvalue()).decode('utf-8')
    return img_b64
    

class KFServingKandinskyModel(kfserving.KFModel):
    def __init__(self, name: str):
        super().__init__(name)
        self.name = name
        self.ready = False
        self.model = None
        self.gpu = True
        self.device_main = torch.device('cuda:0')
        self.device = self.device_main
        self.devices = [torch.device('cuda:0'), torch.device('cuda:1'), torch.device('cuda:2')]

    def load(self):
        print('Loading models...')
        cache_dir = './weights'
        rudalle_path = './weights/rudalle_kandinsky.bin'
        
        self.model = load_dalle_model_parallel(rudalle_path, self.device, self.devices,
                                               cogview_layernorm_prescale=False, custom_relax=False, is_bool_mask=False)
        print('rudalle loaded')
        #
        self.realesrgan = load_realesrgan('x4', device=self.device, cache_dir=cache_dir)
        print('realesrgan loaded')
        #
        self.tokenizer = get_tokenizer(path=os.path.join(cache_dir, 'tokenizer', 'bpe.model'))
        print('tokenizer loaded')
        #
        dwt=False
        filename = 'vqgan.gumbelf8-sber.model.ckpt'
        checkpoint = torch.load(os.path.join(cache_dir, 'vae', filename), map_location='cpu')
        self.vae = get_vae(pretrained=False, dwt=dwt).to(self.device)
        self.vae.model.load_state_dict(checkpoint['state_dict'], strict=False)
        print('vae loaded')
        #
        ruclip_model_name = 'ruclip-vit-large-patch14-336'
        self.clip, self.processor = load_ruclip(ruclip_model_name, device=self.device, 
                                                cache_dir=os.path.join(cache_dir, ruclip_model_name))
        self.clip_predictor = ruclip.Predictor(self.clip, self.processor, self.device, bs=8)
        print('ruclip loaded')
        
        self.ready = True
        
    def predict(self, request: Dict) -> Dict:
        query = request['instances'][0]['query']
        
        # set default generation params
        gen_params = {
            "images_to_gen": 8,
            "images_to_cherry_pick": 4,
            "top_k": 4096,
            "top_p": 0.975,
            "super_res": True
        }
        
        if "params" in request['instances'][0].keys():
            custom_params = request['instances'][0]['params']
            gen_params.update(custom_params)
        
        #
        query = request['instances'][0]['query']
        
        images, dalle_scores = generate_images(query, self.tokenizer, self.model, self.vae,
            top_k=gen_params['top_k'], top_p=gen_params['top_p'], 
            images_num=gen_params['images_to_gen'], bs=2
        )
        top_images, clip_scores = cherry_pick_by_ruclip(images, query, self.clip_predictor, count=gen_params['images_to_cherry_pick'])
        if gen_params['super_res']:
            top_images = super_resolution(top_images, self.realesrgan)
        
        encoded_images = [encode_img(img) for img in top_images]
        
        return {"ok": True, "images": encoded_images, "clip_scores": clip_scores}

Функция predict возвращает словарь с изображениями, закодированными в base64, значениями близости для каждого изображения, полученные ru-CLIP’ом, и флагом ok, равным True, если модель отработала успешна, и равным False в противном случае.

{
	"ok": True,
	"images": ["base64_image_1", "base64_image_2", "base64_image_3", ...],
	"clip_scores": [0.63, 0.46, 0.45, ...]
}

Пример работы с моделью

Установим зависимости

pip install -r requirements.txt

Подгружаем модель и запускаем генерацию

import base64
import requests
from io import BytesIO
from PIL import Image
import numpy as np
from io import BytesIO

from rudalle.pipelines import show

def decode_img(img_b64):
    bin_img = base64.b64decode(img_b64)
    buff = BytesIO(bin_img)
    return Image.open(buff)

from app import KFServingKandinskyModel

model = KFServingKandinskyModel("kfserving-kandinsky")
model.load()

result = model.predict({
    "instances":[{
        "query": "пейзаж со снежными горами и озером розового цвета"
    }]
})

if result["ok"]:
    images = [decode_img(img_base64) for img_base64 in result['images']]
    print(result['clip_scores'])
    show(images)
[0.6338722705841064, 0.4660634696483612, 0.4531598687171936, 0.43820977210998535]

Изображение

Можно поменять параметры для генерации, передав новые параметры в поле params

Возможные параметры:

  • images_to_gen - integer. Количество изображений для генерации
  • images_to_cherry_pick - integer. Количество изображений для отбора моделью ru-CLIP
  • super_res - boolean. True, если нужно сделать улучшение качества изображения, False иначе
  • top_k, top_p - integer, float. Настраиваемые параметры генерации.
result = model.predict({
  "instances":[{
      "query": "пейзаж со снежными горами и озером розового цвета",
      "params": {
          "images_to_gen": 12,
          "images_to_cherry_pick": 12,
      }
  }]
})

if result["ok"]:
  images = [decode_img(img_base64) for img_base64 in result['images']]
  print(result['clip_scores'])
  show(images)
[0.5908635258674622, 0.5827779769897461, 0.5569754838943481, 0.5007778406143188, 0.4987826347351074, 0.45289450883865356, 0.40974923968315125, 0.38415390253067017, 0.380871444940567, 0.36455070972442627, 0.3052992820739746, 0.30278700590133667]

Изображение

Процесс обучения модели

На графиках показаны изменения текстовой, визуальной и общей функций потерь для первой (синий) и второй (оранжевый) фаз обучения на валидационной выборке. В качестве такой выборки мы использовали часть MS-COCO Validation Dataset, которая состояла из 422 пар. Каждая пара была проверена вручную: описание переведено на русский язык автоматическим переводчиком и скорректировано при необходимости.

Изображение

Анализируя данные горизонтальной оси, можно заметить, что первая фаза обучения включала в себя около 3,5 эпох, а вторая — 1,5 эпохи. При первом взгляде на графики возникает естественный вопрос по поводу разрыва между первой и второй фазами. На самом деле, этому есть несколько объяснений: изменение кодовой базы для тренировки модели (ушли от подхода Megatron model-parallel), изменение количества карт в обучении (что, соответственно, привело к сбросу всех состояний оптимизатора DeepSpeed Zero3 после первой фазы), а также совершенно новые данные для обучения.

Эксперименты

Большинство работ по созданию моделей генерации изображений по текстовым описаниям используют метрику Frechet Inception Distance (FID) в качестве основной метрики качества. Каноничный датасет, который используется в ходе оценки качества — MS-COCO validation set (30 тыс. изображений). Использовав машинный перевод описаний на русский язык, мы смогли проверить модель Kandinsky в ряду известных на текущий момент моделей. Как видно из таблицы, Kandinsky показывает лучшее значение метрики FID среди моделей с аналогичной архитектурой.

Название моделиZero-shot FID
Авторегрессионные модели
ruDALL-E XL18.6
Kandinsky15.4💥
minDALL-E24.6
CogView27.1
DALL-E27.5
Диффузионные модели
GLIDE12.24
DALL-E 210.39
Imagen7.27

Human evaluation

Для оценки сгенерированных изображений мы использовали подход Human evaluation, аналогичный тому, который был применён для качественной оценки DALL-E, GLIDE, DALL-E 2 и др. Вниманию независимых наблюдателей были представлены пары изображений — одно из пары получено с помощью ruDALL-E XL, другое сгенерировано Kandinsky — с вопросом: «Какая картинка реалистичнее?». Вторым заданием было оценить, какая картинка больше соответствует текстовому описанию (и соответствует ли ему в принципе).

Для генерации было использовано 422 описания из валидационной выборки MS-COCO, каждая пара изображений была оценена 3-5 людьми. Как видно из графика, модель Kandinsky выигрывает у предшественника и с точки зрения реалистичности генерируемых изображений, и с точки зрения их соответствия текстовому запросу (и хотя изображения не всегда подходят к  запросам – столбец None на графике, – показатель соответствия более, чем в 70% случаев для обеих моделей весьма высок).

Изображение