AI/vision

[논문 리뷰] Flamingo: a Visual Language Model for Few-Shot Learning

민사민서 2025. 1. 3. 16:53

Introduction

  • VLM의 시초라고도 할 수 있는 모델. 아래와 같은 achievement를 달성하였다
    • pretrained vision-only modellanguage-only model을 효과적으로 연결(bridge)했다
    • visual + textual data가 랜덤하게 interleave된 시퀀스도 처리 가능하다 => large scale web data를 긁어올 수 있었음
    • 이미지/동영상 모두 원활하게 처리 가능하다

  • in-context few-shot learning capability를 통해 별도의 fine-tune 없이도 여러 vision&language task에서 SOTA를 달성했다
    • 기존 computer vision 분야에서의 국룰은 large supervised data로 pretrain → task of interest에 맞게 finetune
    • flamingo는 few input/output examples만 프롬프트로 제공하면 (few-shot learning) 된다
    • 16개의 멀티모달 task 중 6개에서 finetuned SOTA 모델을 앞섰다 (1000배나 적은 task-specific training data로도)
  • VQA와 같은 open-ended task도 잘한다 !! (대신 classification은 약간 못하는)
  • 문제를 visual input conditioning이 주어진 text prediction problem으로 치환하여 해결한다
    • 2개의 pre-trained + frozen model이 존재한다 = vision model (perceive visual scenes) + large LM (perform basic reasoning)
    • perceiver-based 아키텍쳐 덕분에 high-resolution image나 video도 원활하게 처리할 수 있다
  • 머신러닝 목적으로 annote된 데이터셋을 사용하지 않았다, 대신 web에서 얻은 mixture of large-scale multimodal data를 활용

 

Model Structure

Flamingo 아키텍쳐이다

Vision Encoder에서 spatio-temporal feature를 추출하고 → perceiver resampler를 거쳐 fixed number of visual token이 나오고 → cross attention layer와 interleave + frozen LM layer 를 거쳐 → next token prediction task를 수행한다

l-th language token은 앞선 language token과 image/video preceding token의 영향을 받는다 (실제로는 masked cross attention 때문에 직전 visual feature에만 영향을 받음)

 

Vision Encoder

- Normalizer-Free ResNet(NFNet) F6 모델을 사용했으며, pretrain 후에 모델 학습 시에는 freeze 했다

- language encoder로는 BERT 아키텍쳐를 사용했다

- image/text pair 각각에 대해 인코딩한 뒤 embedding에 대해 contrastive learning 방식을 적용했다 = paried embedding의 similarity는 maximize하고 unpaired embedding의 similarity는 minimize하는 방식

- CLIP과 같이 두 개의 contrastive loss를 정의하고 sum을 최소화하는 방향으로 multi-class cross entropy loss 

=> Flamingo 모델에는 contrastive vision encoder의 weight만 가져다썼다고 한다 (only vision encoder part)

 

Perceiver Resampler

- varying size의 image/video feature를 인풋으로 받아 fixed number of visual outputs(64)를 내보낸다

- ablation study에 따르면 이런 resampling module이 있을 때와 단순 MLP layer로 대체했을 때와 성능 차이가 크다고 함

- x = predefined size의 learnable latent input query이다

- X_f = flatten된 visual feature이다, given frame에 따라 learnt temporal position encoding을 더하고 (이미지는 frame=1 video로 간주) 쭉 flatten해서 생성한다, spatial grid position encoding은 따로 하지 않음

- Query로 X, Key/Value로 X_f와 X를 concat한 값을 넣어 attention 수행 + FFN 통과시킴

- 최종적인 output token tnsms learnt latent queries의 개수와 동일

 

gated xattn-dense layers

- pretrained LM blocks를 두고 그 사이에 gated cross-attention dense block을 삽입 (learnable, trained from scratch)

- tanh gating mechanism // ablation study에서 효과적임이 증명됨

- perceiver resampler로부터 나온 vision output을 K,V / language input을 Q로 해서 cross attention

 

interleave visual data and text

- masked cross attention 방식을 사용한다

- text token이 주어지면 model은 last preceding image/video에 대응되는 visual token하고만 cross-attends 한다, 하늘색 부분은 전부 masking 되어있는 것 = 즉 한 번에 한 개의 image만 directly attend한다

- all previous images에 대한 dependency는 LM에서의 self-attention에 남아있겠죠

 

Training

  • M3W dataset (interleaved iamge and text dataset)
    • Web에서 HTML 추출하고 DOM 기반으로 element 추출, 태그 추가
  • ALIGN, LTIP, VTP 등등의 데이터셋 사용

 

Task Adaptation with few-shot in-context learning

  • (image,text) 혹은 (video,text) 형태로 example pairs를 제공하고, query visual input을 더해 프롬프트를 제공

 

Analysis

ablation studies에서 주목할만한 관점들?

  • gated cross attention dense layer에서 0-initialized tanh gating을 없앴더니 training instability가 커지더라
  • gated xatten-dense layer를 추가하면 compute 속도는 느려지지만 performance가 올라간다 => Flamingo-9B에서는 4개 layer마다 1개씩 Insert
  • Perceiver Resampler를 MLP로 바꾸니 성능이 떨어지더라
  • Freezing LM component가 생각보다 중요하다, freeze하지 않고 같이 fine-tune 했더니 오히려 성능이 많이 떨어지더라 = model이 새로운 objective를 학습하는 과정에서 사전지식을 점차 잊어버리는 = catastrophic forgetting

한계?

  • LM(Transformer decoder)이 text generation을 담당, LM에서 발생하는 hallucination, ungrounded guess, poor generalization when sequnece is too long 이런 문제들이 존재
  • classification 성능은 그리 좋지 못하다, 애초에 open-ended task를 타겟으로 한 것
  • in-context learning은 뛰어난 방법이지만 어떻게 사용하냐에 따라 성능이 달라짐, sensitive to various aspects of demonstrations