https://github.com/salesforce/LAVIS/tree/main/lavis/models/blip2_models
Introduction
- 기존 Vision-language pretraining (VLP) 연구의 한계 => end-to-end fashion으로 큰 규모의 모델과 데이터셋을 학습하려다보니 computational cost가 무척 큼
- 일반적이고 compute-efficient한 VLP method 제시, pre-trained vision model과 language model 모델을 가져옴
- 연산의 효율성과 catastrophic forgetting 방지를 위해 unimodal pretrained model은 frozen
- cross-modal alignment를 위해 lightweight transformer인 Q-former 도입
- set of learnable query vectors frozen image encoder의 output 중 유의미한 visual features 추출
- pre-training stage는 크게 두 단계로 구성됨
- vision-language representation learning
- vision-to-language generative learning
- 결과적으로 modality gap을 Q-Former로 잘 메꿨으며, LLM의 power인 제로샷 image-to-text generation도 잘 함
Related Work
- end-to-end fashion의 VLP 작업들
- architecture: dual-encoder, fusion-encoder, encoder-decorder, unified transformer ...
- various objectives: image-text contrastive learning, masked language modeling ...
- 높은 연산 비용 뿐 아니라 end-to-end이므로 우수한 pretrained unimodal model을 유연하게 적용하지 못함
- 사전 학습된 모델을 가져와 VLP 도중에는 freeze 하는 방식들
- 주된 목표는 visual feature를 text space로 align하는 것
- Frozen => fintunes image encoder, encoder의 output이 soft prompt for LLM
- Flamingo => LLM에 cross-attention layer를 삽입하여 visual feature도 반영
Model Architecture
- modality gap을 연결하기 위해 Querying Transformer (Q-Former) 도입
- 두 종류의 transformer submodule이 있어 self attention layer를 공유함
- image transformer(좌): frozen image encoder의 output이 매 블록에 inject됨
- text transformer(우): text decoder 혹은 text encoder 둘 다 가능
- set of learnable query 임베딩이 image transformer의 인풋으로 들어간다
- 당연히 frozen image features와 cross attention으로 상호작용하고, shared self attention layer에서 text와도 상호작용
- 이 논문에서는 32 queries x 768 dimension each => 이게 곧 output query representation Z
- xattn layer는 랜덤하게 초기화 / self attn layer는 BERT_base의 weight로 초기화
two-stage pre-training
Vision-Language representation learning stage (frozen image encoder 만 연결 )
- Q-Former의 queries가 text 관점에서 가장 informative한 visual representation을 추출하도록 학습시키고픔
- text와 relevant한 visual features를 추출하는 것을 학습시켜 LLM이 VL alignment를 배워야하는 부담 감소
- 세 가지 objective 적용
- Image-Text Contrastive Learning(ITC)
- image transformer의 output query representation Z와 [CLS] token이 text transformer를 거쳐 나온 output embedding t 를 align 하고자 했다
- queries와 text는 서로를 볼 수 없게 self-attention mask를 적용 => image 표현과 text 표현을 align
- Image-grounded Text Generation (ITG)
- query는 query끼리만 볼 수 있고, text token은 query와 이전 text만 볼 수 있게 masking을 적용
- input images라는 condition이 주어졌을 때 text를 생성하도록 학습
- Image-Text Matching (ITM)
- image and text representation을 좀 더 fine-grained한 관점에서 align하기 위한 단계이다
- hard negative image-text pair 쌍을 추가로 수집하여 이 pair가 positive(match)인지 negative(unmatch)인지 binary classification task를 수행
- Image-Text Contrastive Learning(ITC)
Vision-to-Language generative learning stage (from froze encoder)
- Q-former에 frozen LLM을 붙여 LLM의 generative language capability를 최대한 살리고자 함
- output query embedding Z와 LLM의 text embedding을 맞춰주기 위해 FC layer 추가하여 projection
- decoder-based LLM => soft visual prompts가 다이렉트로 input으로 들어가 학습
- encoder-decoder-based LLM => soft visual prompts + prefix text 넣어서 학습
Model Pre-training
- BLIP과 동일한 pre-training dataset 상ㅇ, CapFilt method 사용해 web에서 syntethic captions data 마련
- image encoder: CLIP의 ViT 모델을 가져와 second last layer의 output feature를 사용함
- language model: unsupervised-trained OPT model / instruction-trained FlanT5 model
Experiments
간략하게만 짚어보자면
- stronger image encoder / stronger LLM을 사용하면 성능이 좋아진다
- Vision-language representation learning (stage 1)은 상당히 효과적이었다
- LLM이 vision-language alignment를 배워야하는 burdern을 줄여주었다, Q-Former가 text-relevant한 vision feature를 효과적으로 뽑아주기 때문에
- 만약 이 단계가 없었다면 Flamingo의 Perceiver Resampler와 비슷했을 것, vision-to-language generative learning으로만 학습하는 것이므로
- representation learning이 없을 때 catastrophic forgetting이 발생하기도 했다
- image captioning task를 위해 LLM initial input에 "a photo of" 프롬프트를 추가해 finetune
- image encoder + Q-Former 학습, 상당한 성능 개선
- VQA task를 위해 image encoder + Q-Former 만 추가적으로 학습 => open-ended generation에서도 SOTA 성능
- Image-Text Retrieval을 위해 first-stage-pretrained 모델만 똑 떼와서 image encoder + Q-Former finetune
- SOTA 급의 성능, ITG loss 또한 generation 관련이지만 image-text retrieval에도 유용하더라
Limitation
- in-context VQA example을 잘 하지 못한다
- single image-text pair로만 학습했기 때문에 여러 image-text pairs의 상관관계를 충분히 학습하지 못함
- Flamingo의 경우 자체적인 M3W dataset을 구축하여 multiple image-text pairs per sequence 사용
- LLM의 한계가 그대로 반영된다 (frozen model)
- incorrect한 추론 경로, up-to-date한 information은 잘 모름
- 공격적인 언어, social bias, 민감한 정보 유출