MLA(Multi-Head Latent Attention) 쉽게 이해하기(with예시)

멀티헤드 어텐션에서 각 헤드(Head)가 서로 다른 가중치를 사용한다는 것은, 각 헤드마다 독립적인 쿼리(Query)키(Key)밸류(Value) 변환 행렬을 갖는다는 의미입니다. 이를 구체적인 예시와 함께 설명하겠습니다.


1. 입력 데이터 예시

  • 입력 문장: "The cat sat on the mat."
  • 각 단어의 임베딩 벡터를 아래와 같이 가정합니다.
    • The: [0.2, 0.5, 0.1]
    • cat: [0.7, 0.3, 0.4]
    • sat: [0.4, 0.6, 0.2]
    • on: [0.1, 0.8, 0.3]
    • the: [0.3, 0.2, 0.9]
    • mat: [0.5, 0.4, 0.7]

2. 멀티헤드 어텐션 구조 (2개 헤드 가정)

  • 헤드 1과 헤드 2는 각각 다른 가중치 행렬로 입력을 변환합니다.
    • 예시 가중치 행렬 (3차원 → 2차원으로 변환):
      • 헤드 1의 가중치:
        • WQ1 ​: [[0.1,0.2], [0.3,0.4], [0.5,0.6]]
        • WK1 ​: [[0.2,0.1], [0.4,0.3], [0.6,0.5]]
        • WV1 ​: [[0.3,0.4], [0.5,0.6], [0.7,0.8]]
      • 헤드 2의 가중치:
        • WQ2 ​: [[0.6,0.5], [0.4,0.3], [0.2,0.1]]
        • WK2 ​: [[0.5,0.6], [0.3,0.4], [0.1,0.2]]
        • WV2 ​: [[0.8,0.7], [0.6,0.5], [0.4,0.3]]

3. 헤드별 계산 과정

헤드 1의 계산

  1. 쿼리(Q1), 키(K1), 밸류(V1) 생성:
    • 예: 단어 cat의 쿼리 계산
      Qcat1 = cat 임베딩×WQ1 = [0.7, 0.3, 0.4]×WQ1 = [0.7∗0.1+0.3∗0.3+0.4∗0.5, 0.7∗0.2+0.3∗0.4+0.4∗0.6] = [0.07+0.09+0.2 = 0.36, 0.14+0.12+0.24 = 0.5]
      → Qcat1 = [0.36,0.5]
    • 모든 단어에 대해 Q1, K1, V1를 계산합니다.
  2. 어텐션 스코어 계산:
    • cat과 sat의 어텐션 스코어 예시:
      Score1 = Qcat1⋅Ksat1 = [0.36,0.5]⋅[Ksat1]
      → 헤드 1은 cat과 sat의 위치적 관계(예: 문법적 구조)에 집중할 수 있습니다.

헤드 2의 계산

  1. 쿼리(Q2), 키(K2), 밸류(V2) 생성:
    • 예: 단어 cat의 쿼리 계산
      Qcat2 = cat 임베딩×WQ2 = [0.7, 0.3, 0.4]×WQ2 = [0.7∗0.6+0.3∗0.4+0.4∗0.2, 0.7∗0.5+0.3∗0.3+0.4∗0.1] = [0.42+0.12+0.08=0.62, 0.35+0.09+0.04=0.48]
      → Qcat2 = [0.62,0.48]
  2. 어텐션 스코어 계산:
    • cat과 mat의 어텐션 스코어 예시:
      Score2 = Qcat2⋅Kmat2 = [0.62,0.48]⋅[Kmat2]
      → 헤드 2는 cat과 mat의 의미적 관계(예: “고양이가 매트 위에 앉음”)에 집중할 수 있습니다.

4. 결과 통합

  • 헤드 1과 헤드 2의 결과를 결합합니다.
    • 예: cat에 대한 최종 어텐션 값
      • 헤드 1 결과: 위치적 관계 강조 → sat과 높은 어텐션
      • 헤드 2 결과: 의미적 관계 강조 → mat과 높은 어텐션
    • 두 결과를 연결(concatenate)하거나, 평균내어 최종 출력을 생성합니다.

핵심 요약

  1. 다양한 관점 포착:
    • 헤드 1: 문법적 구조 (예: 동사 sat과 명사 cat의 관계).
    • 헤드 2: 의미적 관계 (예: cat과 mat의 물리적 위치).
  2. 가중치의 역할:
    • 각 헤드의 WQ, WK, WV​가 다르기 때문에, 동일한 입력이라도 다른 방식으로 변환됩니다.
    • 예를 들어, 헤드 1의 가중치는 위치 정보를, 헤드 2의 가중치는 의미 정보를 추출하도록 학습됩니다.
  3. 실제 학습 과정:
    • 가중치 행렬은 초기에 랜덤하게 초기화되며, **역전파(backpropagation)**를 통해 최적화됩니다.
    • 모델은 태스크(예: 번역, 분류)에 맞게 각 헤드가 어떤 정보에 집중할지 자동으로 학습합니다.

예시 도식화

입력: "The cat sat on the mat"
       │
       ├─ 헤드 1 ──→ [위치적 관계 집중] → "cat" ↔ "sat"
       │
       └─ 헤드 2 ──→ [의미적 관계 집중] → "cat" ↔ "mat"

멀티헤드 어텐션은 다양한 유형의 관계를 병렬로 포착하여 모델의 성능을 향상시킵니다!

답글 남기기

이메일 주소는 공개되지 않습니다. 필수 필드는 *로 표시됩니다

error: