Это демо приложение на базе Gradio, демонстрирующее возможности горячей замены PEFT-адаптеров, а именно LoRA, над одной и той же LLM прямо в Runtime. Выполнено в рамках тестового задания на позицию ML Engineer.
Создать окружение Python 3.10.10 с помощью Conda или Pyenv:
conda create -n myenv python=3.10.10 && conda activate myenvУстановить необходимые пакеты:
pip install -r requirements.txtИ запустить:
python -m appСобрать образ:
docker build -t llm-lora-hotswap .И запустить приложение:
docker run --gpus all --name hotswap-app --net host --rm -it llm-lora-hotswapПосле запуска приложение должно быть доступно на 7860 порту localhost (порт Gradio по умолчанию).
В данном проекте я набросал черновую структуру классов, которые могли бы быть прототипом для решения в продакшене. Поскольку у каждого адаптера могут быть свои нюансы токенизации, пре/пост-обработки сообщений, они инкапсулируются в классах-наследниках LLMAdapterBase.
Для самих ответов LLM я реализовал поддержку стриминга токенов, чтобы можно было отдать первый токен как можно быстрее на целевой интерфейс (в данном случае — в UI чата).
- Llama2 7B GPTQ — LLM, квантизованная с помощью метода GPTQ до 4bit. Выбрал её вместо квантизации через
bitsandbytes, поскольку по тестам GPTQ даёт выше качество итоговой модели. На моей локальной машине установлена RTX 3070 на 8Gb VRAM, поэтому нужна была хотя бы 4bit версия - Saiga2 LoRA — адаптер поверх Llama 2, дообученный на инструктивно-диалоговом датасете Сайга
- Llama 2 LoRA OpenAssistant Guanaco (блогпост) — адаптер, дообученный на очищенной части датасета OpenAssistant (OASTT)
Цель данного демо — показать с помощью простого кода реализацию горячей замены и предоставить интерактивный интерфейс для демонстрации работы. Если бы потребовалось реализовывать подобный функционал в виде REST API, я бы посмотрел такие решения как OpenLLM. Согласно вот этому обзору фреймворков для инференса и текущей документации, OpenLLM — единственный, который поддерживает адаптеры и их подмену в Runtime.
Однако OpenLLM сам по себе, и в особенности с адаптерами, будет давать низкий RPS и высокий Latency. Дело в том, что при сёрвинге в продакшене кучи адаптеров страдает батчинг запросов — следует задуматься над более эффективной утилизацией GPU и формированием батчей. Я нашёл пару многообещающих решений для этой проблемы:
Другой открытый вопрос — допустимо ли использовать адаптер, обученный поверх модели в половинной точности, с моделью квантизованной до 4bit. Допускаю, что может присутствовать деградация в качестве вывода такого адаптера. В продакшен разработке следовало бы проверить качество этой связки на downstream задачах.
Кроме того, при обучении своего адаптера под такой юзкейс я бы сразу смотрел в сторону quantization-aware методов:
- LoftQ: LoRA-Fine-Tuning-Aware Quantization for Large Language Models
- QA-LoRA: Quantization-Aware Low-Rank Adaptation of Large Language Models
- LQ-LoRA: Low-rank plus Quantized Matrix Decomposition for Efficient Language Model Finetuning
Основную массу времени я потратил на поиск базовой модели и подходящих адаптеров. Согласно заданию, я нацелился на использование Saiga2 от Ильи Гусева, однако она в свою очередь является файнтюном над базовой Llama2, а не Chat/Instruct версией — что, как оказалось, редкость, если хочется использовать модель в чатботе. Большинство LoRA-адаптеров для чата файнтюнятся именно от Chat/Instruct-модели.
Также оказалась не совсем прозрачной логика методов add_adapter(), set_adapter(), load_adapter() из библиотеки PEFT. Так, к примеру, добавление адаптера с помощью конфига и метода add_adapter() не инициализирует сами веса адаптера и, судя по всему, нацелено именно на юзкейс файнтюнинга модели.
Для инференса же необходимо вызывать именно load_adapter() с указанием идентификатора модели-адаптера с хаба (или локальной папки). Чтобы разобраться с этим, пришлось посмотреть код соответвующих методов, поскольку документация на момент написания очень расплывчатая.
Было забавно наблюдать, как подключенный адаптер Сайги не работает и базовая модель при виде русских символов в промпте выдаёт код, причём на C/C++ и под платформу Windows...
Кроме того, пришлось дописать некоторую логику по пост-процессингу вывода модели OASTT и попотеть над подбором гиперпараметров для генерации. Так модель очевидно плохо уловила специфику чата и старается продолжать реплики за человека. Поэтому я отлавливаю токены, соответствующие началу реплики ### Human: и останавливаю генерацию на них.