인공지능/다양한 인공지능

[Meta-Learning] Prototypical Networks for Few-shot Learning 논문 리뷰

전공생 2023. 8. 24. 11:01

Introduction

이 논문은 이전 연구인 Matching Network와 비슷하게 구현된다.

Matching Network [1]

  • query set에 대한 class를 예측하기 위하여 어텐션 메커니즘을 사용하여 support set의 embedding 값을 학습한다.
    • support set: 라벨이 주어진 데이터셋. support set으로 few-shot task를 위한 준비를 해놓는다(few-shot 학습을 한다).
    • query set: 라벨이 주어지지 않은 데이터셋. support set을 통해 few-shot 학습을 하면, query point에 대해 class를 예측하고, query의 라벨과 비교하여 loss값을 계산한다.
  • 학습할 때 랜덤 추출된 minibatch—episode라 함—를 사용한다.
    • 각 episode는 class와 data point를 일부 랜덤 추출함으로써 few-shot task에서처럼 설계된다.
    • 즉, 학습 데이터도 epoch당 모든 클래스를 사용하지 않고 일부의 클래스를 사용하여 학습함으로써, 일반 모델 학습 중에서도 few-shot learning을 위한 능력이 생기도록 한다.
  • embedding 공간에 weighted nearest-neighbor 분류기를 통해 classification을 한다.
  • Matching Network에 대해 꽤나 자세하게 설명된 이유는 이 논문과 비슷한 아이디어여서도 있고, 뒤에서 Matching Network와 비교하는 내용이 많기 때문인듯 하다.

또 이전의 few-shot learning 중에 meta-learning을 이용한 접근법도 있었다.

Ravi와 Larochelle의 논문 [2]

  • 주어진 episode로 LSTM을 학습하여 classifier를 업데이트 한다. → 모델을 일반화
  • 이 연구는 여러 episode로 하나의 모데을 학습하는 것보단, LSTM의 meta-learner가 각 episode를 위한 custom된 모델을 train하도록 학습시키는게 낫다고 한다.
    • 각 episode에 맞게 모델들을 병렬적으로 학습하고 앙상블하여 모델을 쓴다는 의미인듯 하다.

이러한 이전의 few-shot 연구들은 오버피팅에 취약하다는 문제가 있다.

  • few-shot learning을 위한 데이터가 매우 적다보니까 분류기가 매우 단순한 귀납적(inductive) 편향을 가져야한다는 가정 하에 작업한다.
  • 그래서 고안된 것이 이 논문에서 제안한 Prototypical Networks이다.
    • 각 class에 대한 단일 프로토타입(prototype) 표현을 중심으로 point들이 군집되는 embedding이 존재한다는 아이디어를 기반으로 한다.

이를 위해

  • input값을 neural network를 사용하여 embedding 공간으로 비선형적인 mapping을 하고
  • class의 프로토타입을 embedding 공간의 support set의 평균으로 취한다.
    • 이때 classification은 query의 embedded point가 가장 가까운 class 프로토타입을 찾아서 수행된다.
    • zero-shot learning에도 이와 같은 방법이 사용된다. 이때 각 class는 해당 class를 잘 설명하는 meta-data를 가지게 된다(zero-shot의 경우 제공되는 support set이 없기 때문에 meta-data로 이를 대신한다). 따라서 이 shared space로 가는 meta-data의 embedding을 각 class를 위한 프로토타입을 제공하도록 학습된다.
  • classification을 할때 거리 함수를 어떤 것으로 사용하느냐가 성능에 큰 영향을 주는데, Euclidean 거리를 사용했을 때가 가장 성능이 좋았다고 한다.

prototypical Network는 기존의 meta-learning 알고리즘보다 더 간단하고 효율적인 알고리즘이라고 한다. 실제로, prototypical network를 적용하여 few-shot learning연구를 하는 사례들이 많은 것을 보니 좋은 알고리즘으로 보인다.

Method

Notation

  • $N$ : training set 크기
  • $\mathcal{D} = \left\{ \left( x_1, y_1\right), ..., \left( x_N, y_N\right) \right\}$ : training set
  • $S$ : support set
  • $x_i \in \mathbb{R}^D, y_i \in \left\{ 1, ..., K \right\}$
  • $S_k$ : clsss k에 대한 support set

Model

  1. Select class indices for episode
    • training set $\mathcal{D}$는 총 K개의 class(label)의 데이터로 구성된다. 이중에서 few-shot task를 모방하여 episode를 만들기 위해 일부의 class만 선택한다. 1~K 숫자 중 하나($N_C$)를 선택하여 $ V = \left\{1,.., N_C \right\} $의 클래스를 episode로 선택하여 사용한다.
  2. Select support examples
    • episode에 속하는 클래스 집합 V에 대해
    • 해당 클래스에 해당하는 데이터셋 중에서 일부를 추출하여 support set $S_k$(해당 클래스 $k$라고 가정하면)를 구성한다.
  3. Select query examples
    • 해당 클래스에 해당하는 데이터셋 중에서 (support set으로 뽑은 것을 제외하고) 일부를 추출하여 query set $Q_k$를 구성한다.
  4. Compute prototype($\mathbf{c}_k$) from support examples
    • prototypical network는 M차원의 프로토타입을 계산한다. 각 클래스의 프로토타입은 학습가능한 파라미터인 $\phi$에 대한 embedding 함수 $f_{\phi}:\mathbb{R}^D \rightarrow \mathbb{R}^M$를 통해 계산된다. 각 프로토타입은 해당 class에 속하는 embedded support points의 평균 vector가 된다. $$ \mathbf{x}_k = \frac{1}{|S_k|} \sum{\left( \mathbf{x}_i, y_i\right)\in S_k} {f_\phi \left(\mathbf{x}_i\right)} $$
  5. Update loss
    • loss는 위 알고리즘에 나와있는 식으로 계산되어 업데이트 된다. loss는 query point의 embedding인 $f_{\phi}(\mathbf{x})$와 프로토타입 $\mathbf{c}_k$ 사이의 거리를 통해 계산된다.
    • loss는 SGD를 통해 최소화된다.

Prototypical Network as Mixture Density Estimation

distance 함수로 regular Bregman divergence를 쓰면 Prototypical Networks는 support set에 의해 exponential family density로 mixture density estimation과 동등한 식이 된다.

Reinterpretation as a Linear Model

distance 함수로 Euclidean distance를 사용할 경우 $p_{\phi}(y=k|x)$는 특정한 파라미터를 가지는 선형 모델과 같아진다. 연구에서는 squared Euclidean distance를 주로 사용하였는데, 이는 선형 모델과 똑같음에도 불구하고 효율적이라고 한다. embedding 함수에 필요한 모든 비선형성이 충분히 포함되어 학습될 수 있기 때문인데, 이러한 개념은 Deep Neural Network에 적용된 개념과 같다.

Comparison to Matching Networks

one-shot learning의 경우 class당 하나의 support point가 있기 때문에 $\mathbf{c}_k=\mathbf{x}_k$가 되기 때문에 Prototypical Network와 Matching Network가 동일한 방법이라고 할 수 있다.

반면, few-shot learning의 경우에는 두 방법은 확연히 다른 방법이라고 할 수 있다.

Design Choices

Distance metric

Matching Network에서는 distance 함수로 cosine distance를 사용하였다. Matching Network와 Prototypical Network 모두 다양한 distance 함수를 사용할 수 있는데, 둘다 Euclidean distance를 사용했을때 좋은 성능을 가졌다. Prototypical Network에서는 cosine distance를 사용할 경우 Bregman divergence가 되지 않기 때문에 network가 mixture density esimation과 같아지지 않아서 그런 성능이 나오는 것으로 보인다.

Episode composition

episode를 구성하는 간단한 방법은 episode에 사용되는 class 개수 $N_C$와 class당 사용되는 support points 개수 $N_S$를 test 때와 학습 시에 동일한 값으로 사용하는 것이다(즉, $N_C$-way와 $N_S$-shot을 동일하게 맞춤).

그러나 실험해본 결과, test때 사용되는 way보다 학습할 때 더 큰 값을 사용했을 때 더 좋은 성능을 보였다.

반면, shot은 학습과 test에서 동일한 값을 사용할 때 가장 좋은 성능을 보였다.

Zero-shot Learning

few-shot learning의 경우 training points의 support set이 주어진 반면, zero-shot learning의 경우 class당 class의 meta-data vector $v_k$가 주어진다. 이 meta-data는 미리 결정되거나 raw-text로부터 학습된다.

Prototypical Network에서는 $v_k$를 구별되는 embedding으로 mapping해주는 $g_{\phi}$를 정의하여 few-shot learning과 동일하게 적용한다. 실험해본 결과, 이때 query point와 meta-data vector는 다른 도메인으로부터 생겨난 것이기 때문에, query embedding인 $f_{\phi}$를 제한하지 않고 프로토타입 embedding $g_{\phi}$를 unit length를 가지도록 고정하는게 도움이 되었다.

왼쪽은 3-way 5-shot의 few-shot learning으로 볼 수 있고, 오른쪽은 zero-shot learning이다. zero-shot learning의 경우 support points가 없기 때문에 meta-data vector로 프로토타입을 정의한다. 그외의 계산은 두 방법 모두 동일하다.


Reference

[1] Matching networks for one shot learning

[2] Optimization as a model for few-shot learning