ELECTRA
- MLM 방식에서 좀 더 확장된 태스크와 NSP으로 학습을 진행한다.
- 마스킹된 토큰을 맞히는 방식에서 더 나아가 MLM 예측값이 실제 토큰인지 아닌지 맞히는 태스크를 진행한다.→ 기존 BERT 모델에서도 마스킹된 15% 단어 내에서 10%는 마스킹 토큰[MASK] 대신 틀린 단어를 넣어 두고, 또 다른 10%는 마스킹하지 않는 방식으로 fine-tuning 학습 데이터와의 간극을 좁히고자 하였다. 하지만 MLM이 모델 학습의 주요 태스크이므로 MASKED 데이터를 안쓰는 것이 아니다. 반면 ELECTRA 모델에서 적용된 replaced token detection task의 경우 인풋 데이터로 실제 문장에서 일부 단어가 대체된 문장이 활용되므로 구멍없는 멀쩡한 문장을 쓴다는 점에서 MLM에서의 간극을 해결한다.
- → 기존 MLM 학습 데이터와 fine-tuning에서 활용되는 데이터와의 불일치 문제를 해결하기 위한 방안
- 사전 학습을 위해 생성자(generator)와 판별자(discriminator)로 불리는 BERT 모델을 활용한다.
Generator
실제 문장에서 일부 단어가 수정된 문장을 입력값으로 써야하므로 이를 만드는 작업이 선행되어야 한다. 이때는 MLM 태스크를 적용한다.
MASK 처리하기 전 문장의 임베딩 벡터 $e(x)$와 이를 MASK처리한 뒤 12개의 인코더로 구성된 Generator에 입력한 결과인 $h_G(X)$ 이 두 가지 데이터에서 MASK된 위치에 해당하는 두 벡터 간 내적과 소프트맥스 함수 적용을 통해 MASK된 단어가 vocab에서 어떤 단어에 해당할 것 같은지에 대한 확률 분포를 출력하게 한다.
$P_G(x_t|X) = \frac{exp(e(x_t)^Th_G(X)t)}{\sum{x^\prime}exp(e(x^\prime)^Th_G(X)_t)}$
그 결과의 argmax값을 인덱싱한 단어가 MASK되었던 단어의 대체 단어가 되고, 결론적으로 수정된 문장이 생성된다. 논문에서는 이를 $X^{corrupt}$라고 한다.
Discriminator
생성자를 통해 얻은 $X^{corrupt}$를 또 다른 BERT 모델의 입력값으로 쓴다. 여기에서는 인코딩 결과를 시그모이드 함수를 지닌 피드포워드에 입력하여 각각의 토큰이 변경된 토큰인지 아닌지에 대한 이진 분류 태스크를 통해 학습을 진행한다.
$D(X, t) = sigmoid(w^Th_D(X)_t)$
ELECTRA 모델 학습
결론적으로 아래와 같은 태스크를 수행한다.
- uniform distribution에서 마스킹할 지점을 선정한다.
- 마스킹 지점에 해당하는 단어를 [MASK]로 대체한다. → $X^{masked}$
- 생성자를 통해 마스킹 된 단어를 예측하는 태스크를 수행한다.
- 예측 단어로 [MASK]를 대체한다. → $X^{corrupt}$
또한 위 사진을 통해 Loss function이 두 개 존재하는 것을 확인할 수 있다.
- 마스킹 부분이 실제값이 될 likelihood를 $-log$를 취해 minimize하는 방식으로 weight 학습
- 생성자의 경우 마스크된 값 하나를 보는 것이었지만 판별자는 모든 토큰에 대해서 정답(0 또는 1)에 근사하도록 weight를 학습한다.
효율적인 모델 학습 방법 탐색
논문에서는 효율적인 모델 학습을 위한 실험을 진행했고, 그 중 모델 사이즈를 변경하는 것과 학습 알고리즘을 변경한 실험 결과에 대해 소개하고 있다. 그 중 모델 사이즈 변경과 관련해서 아래와 같이 소개되었다.
논문에서는 모델의 효율적인 사전 학습을 위해 생성자와 판별자의 weight를 공유할 것을 제안했다. 그렇게 하면 생성자와 판별자 이 두 BERT 모델의 weight가 하나로 묶이게 된다. 논문을 읽다보면 이 말을 왜 했나 싶었다. 왜냐하면 이 문장 이후로는 생성자와 판별자의 weight를 공유한 사례가 더이상 소개되지 않기 때문이다. 개인적으로 weight가 공유되는 것이 중요하다는 것을 말하고 싶어서였던 것 같기도 하고...
어쨌든 그 외 실험 결과에 따르면, 생성자의 사이즈를 줄이고, 토큰과 포지셔널 임베딩 레이어의 weight만 공유하여 학습하는 것이 GLUE score가 높게 측정되는 것을 확인할 수 있었다고 한다. 특히 생성자의 크기가 판별자와 동일하거나 더 커지는 경우 성능이 저하되는 것을 확인할 수 있다.
'데이터사이언스 이론 공부' 카테고리의 다른 글
BERT의 파생모델 [DistilBERT] (0) | 2022.12.30 |
---|---|
BERT의 파생모델 [SpanBERT] (0) | 2022.12.29 |
BERT의 파생 모델 [RoBERTa] 특징 (0) | 2022.12.22 |
BERT의 파생 모델 [ALBERT] (1) | 2022.12.20 |
BERT 모델에서의 임베딩 벡터 추출 방식에 관하여 (0) | 2022.12.17 |