멀티헤드 어텐션에서 각 헤드(Head)가 서로 다른 가중치를 사용한다는 것은, 각 헤드마다 독립적인 쿼리(Query), 키(Key), 밸류(Value) 변환 행렬을 갖는다는 의미입니다. 이를 구체적인 예시와 함께 설명하겠습니다.
1. 입력 데이터 예시
- 입력 문장:
"The cat sat on the mat."
- 각 단어의 임베딩 벡터를 아래와 같이 가정합니다.
The
: [0.2, 0.5, 0.1]cat
: [0.7, 0.3, 0.4]sat
: [0.4, 0.6, 0.2]on
: [0.1, 0.8, 0.3]the
: [0.3, 0.2, 0.9]mat
: [0.5, 0.4, 0.7]
2. 멀티헤드 어텐션 구조 (2개 헤드 가정)
- 헤드 1과 헤드 2는 각각 다른 가중치 행렬로 입력을 변환합니다.
- 예시 가중치 행렬 (3차원 → 2차원으로 변환):
- 헤드 1의 가중치:
- WQ1 : [[0.1,0.2], [0.3,0.4], [0.5,0.6]]
- WK1 : [[0.2,0.1], [0.4,0.3], [0.6,0.5]]
- WV1 : [[0.3,0.4], [0.5,0.6], [0.7,0.8]]
- 헤드 2의 가중치:
- WQ2 : [[0.6,0.5], [0.4,0.3], [0.2,0.1]]
- WK2 : [[0.5,0.6], [0.3,0.4], [0.1,0.2]]
- WV2 : [[0.8,0.7], [0.6,0.5], [0.4,0.3]]
- 헤드 1의 가중치:
- 예시 가중치 행렬 (3차원 → 2차원으로 변환):
3. 헤드별 계산 과정
헤드 1의 계산
- 쿼리(Q1), 키(K1), 밸류(V1) 생성:
- 예: 단어
cat
의 쿼리 계산
Qcat1 = cat 임베딩×WQ1 = [0.7, 0.3, 0.4]×WQ1 = [0.7∗0.1+0.3∗0.3+0.4∗0.5, 0.7∗0.2+0.3∗0.4+0.4∗0.6] = [0.07+0.09+0.2 = 0.36, 0.14+0.12+0.24 = 0.5]
→ Qcat1 = [0.36,0.5] - 모든 단어에 대해 Q1, K1, V1를 계산합니다.
- 예: 단어
- 어텐션 스코어 계산:
cat
과sat
의 어텐션 스코어 예시:
Score1 = Qcat1⋅Ksat1 = [0.36,0.5]⋅[Ksat1]
→ 헤드 1은cat
과sat
의 위치적 관계(예: 문법적 구조)에 집중할 수 있습니다.
헤드 2의 계산
- 쿼리(Q2), 키(K2), 밸류(V2) 생성:
- 예: 단어
cat
의 쿼리 계산
Qcat2 = cat 임베딩×WQ2 = [0.7, 0.3, 0.4]×WQ2 = [0.7∗0.6+0.3∗0.4+0.4∗0.2, 0.7∗0.5+0.3∗0.3+0.4∗0.1] = [0.42+0.12+0.08=0.62, 0.35+0.09+0.04=0.48]
→ Qcat2 = [0.62,0.48]
- 예: 단어
- 어텐션 스코어 계산:
cat
과mat
의 어텐션 스코어 예시:
Score2 = Qcat2⋅Kmat2 = [0.62,0.48]⋅[Kmat2]
→ 헤드 2는cat
과mat
의 의미적 관계(예: “고양이가 매트 위에 앉음”)에 집중할 수 있습니다.
4. 결과 통합
- 헤드 1과 헤드 2의 결과를 결합합니다.
- 예:
cat
에 대한 최종 어텐션 값- 헤드 1 결과: 위치적 관계 강조 →
sat
과 높은 어텐션 - 헤드 2 결과: 의미적 관계 강조 →
mat
과 높은 어텐션
- 헤드 1 결과: 위치적 관계 강조 →
- 두 결과를 연결(concatenate)하거나, 평균내어 최종 출력을 생성합니다.
- 예:
핵심 요약
- 다양한 관점 포착:
- 헤드 1: 문법적 구조 (예: 동사
sat
과 명사cat
의 관계). - 헤드 2: 의미적 관계 (예:
cat
과mat
의 물리적 위치).
- 헤드 1: 문법적 구조 (예: 동사
- 가중치의 역할:
- 각 헤드의 WQ, WK, WV가 다르기 때문에, 동일한 입력이라도 다른 방식으로 변환됩니다.
- 예를 들어, 헤드 1의 가중치는 위치 정보를, 헤드 2의 가중치는 의미 정보를 추출하도록 학습됩니다.
- 실제 학습 과정:
- 가중치 행렬은 초기에 랜덤하게 초기화되며, **역전파(backpropagation)**를 통해 최적화됩니다.
- 모델은 태스크(예: 번역, 분류)에 맞게 각 헤드가 어떤 정보에 집중할지 자동으로 학습합니다.
예시 도식화
입력: "The cat sat on the mat" │ ├─ 헤드 1 ──→ [위치적 관계 집중] → "cat" ↔ "sat" │ └─ 헤드 2 ──→ [의미적 관계 집중] → "cat" ↔ "mat"
멀티헤드 어텐션은 다양한 유형의 관계를 병렬로 포착하여 모델의 성능을 향상시킵니다!