📚 Paper

Attention is All You Need

์žฅ์˜์ค€ 2023. 6. 16. 02:08

Background

The Seq2Seq model

  • The encoder and the decoder are each built from an RNN.
  • How it works
    1. When the three tokens of a source sentence meaning "I love Hodu" are fed into the LSTM cells one at a time, a hidden state is produced at each step.
    2. Once the hidden states for all tokens have been produced, the last hidden state is a vector that compresses the whole sentence; this is called the Context Vector.
    3. Starting from the Context Vector, each subsequent token fed into the decoder produces a hidden state used to predict the next token.
  • Problems
    1. When the sequence gets long, gradient vanishing causes the information from early tokens to fade out of the Context Vector.
    2. A single context vector cannot hold all of the encoder's information, so the decoder's translation quality degrades (a bottleneck effect).
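
To make the flow concrete, here is a minimal sketch (assuming PyTorch; the sizes, token ids, and variable names are illustrative, not taken from any paper) of an LSTM encoder compressing a sentence into a single context vector:

```python
# Minimal sketch: an LSTM encoder that compresses a tokenized sentence
# into one context vector. All sizes and token ids are illustrative.
import torch
import torch.nn as nn

vocab_size, embed_dim, hidden_dim = 100, 32, 64
embedding = nn.Embedding(vocab_size, embed_dim)
encoder = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

src = torch.tensor([[5, 17, 42]])            # three source tokens, batch of 1
outputs, (h_n, c_n) = encoder(embedding(src))
# outputs: one hidden state per step, shape (1, 3, hidden_dim)
# h_n:     the last hidden state only,  shape (1, 1, hidden_dim)
context_vector = h_n[-1]                      # the "Context Vector" the decoder starts from
```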

๐Ÿ’ก ์ด๋Ÿฌํ•œ ๋ฌธ์ œ์ ๋“ค์— ๋Œ€ํ•œ ํ•ด๊ฒฐ์„ ์œ„ํ•ด Attention์„ ์ด์šฉํ•˜์ž๋Š” ์•„์ด๋””์–ด๊ฐ€ ์ œ์•ˆ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.

Attention example (Luong attention)

  1. The encoder hidden states are kept as one matrix. The hidden state obtained by feeding the decoder's start token into its LSTM cell is dotted with each encoder hidden state to get the alignment scores, and a softmax turns them into attention weights.
  2. The attention weights are multiplied with the hidden-state matrix to form a weighted sum.
  3. The attention value obtained this way is concatenated with the decoder hidden state and passed through tanh; a softmax then gives the probabilities used to predict the next token.
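
The steps above can be written out in a few lines of NumPy; this is only a sketch of dot-product attention with made-up shapes, not an exact implementation of any particular model:

```python
# Sketch of the attention step above with toy shapes (NumPy only).
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

encoder_states = np.random.randn(3, 4)       # 3 encoder hidden states kept as one matrix
decoder_state = np.random.randn(4)            # decoder hidden state after the start token

alignment_scores = encoder_states @ decoder_state      # (3,) one score per source token
attention_weights = softmax(alignment_scores)          # (3,) sums to 1
attention_value = attention_weights @ encoder_states   # (4,) weighted sum of encoder states

# concatenate with the decoder state and squash with tanh; a softmax over the
# vocabulary (omitted here) would then give the next-token probabilities
combined = np.tanh(np.concatenate([attention_value, decoder_state]))
```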

💡 With Attention, the decoder can refer to all of the encoder's hidden states when producing the current output.


Training Architecture

Unlike the Seq2Seq model, where tokens enter the encoder sequentially and their order is preserved, the Transformer feeds all tokens into the encoder at once, so order information is not preserved by the architecture itself.

This problem is solved with positional encoding.

Positional Encoding

Positional encoding should have the following properties:

  1. Each token position must receive a unique value.
  2. The difference between two positions should carry a consistent meaning regardless of where they fall in the sentence.
  3. It should generalize to sentences longer than those seen during training.

๐Ÿ’ก ์ด๋ ‡๊ฒŒ embedding์— position encoding์ด ๋”ํ•ด์ง€๋ฉด, ์‹œ๊ณต๊ฐ„์  ํŠน์„ฑ์„ ๋ฐ˜์˜ํ•œ ํ–‰๋ ฌ๋กœ ๋ณ€ํ™˜๋ฉ๋‹ˆ๋‹ค.

Encoder structure

  1. The input is copied three times and multiplied by the Query, Key, and Value weight matrices to produce the Q, K, and V matrices.
  2. Q is matrix-multiplied with K (transposed), the result is scaled, and after masking and a softmax it is matrix-multiplied with V. To attend to several different kinds of information at once, the multi-head attention technique is used.
  3. The output matrices of the heads are then concatenated, a residual connection and layer normalization are applied, and the result is passed to a feed-forward network.
  4. A residual connection and layer normalization are applied once more to obtain the output.

How Q, K, and V are computed (figure)
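
Written as formulas (these are the definitions from the paper; X denotes the encoder input after positional encoding):

$$Q = XW^Q, \qquad K = XW^K, \qquad V = XW^V$$

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

$$\mathrm{MultiHead} = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\,W^O$$

In the paper, d_model = 512, h = 8 heads, and d_k = d_model / h = 64.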

Note that the output has the same size as the input!

The paper uses six encoder layers. Because the output size equals the input size, the same encoder layer structure can simply be stacked six times.
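
Putting steps 1-4 together, a simplified single encoder layer could look like the sketch below (PyTorch; dropout and the padding mask are omitted, and all names and defaults are illustrative assumptions, not the paper's reference code):

```python
# Illustrative sketch of one encoder layer following steps 1-4 above.
import torch
import torch.nn as nn
import torch.nn.functional as F

class EncoderLayer(nn.Module):
    def __init__(self, d_model=512, n_heads=8, d_ff=2048):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        # step 1: weights that turn the (copied) input into Q, K, V
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)   # projection applied after the concat
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )

    def forward(self, x):                          # x: (batch, seq_len, d_model)
        b, s, d = x.shape
        # step 1: project, then split into heads
        q = self.w_q(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)
        k = self.w_k(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)
        v = self.w_v(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)
        # step 2: scaled dot-product attention in every head
        scores = q @ k.transpose(-2, -1) / (self.d_k ** 0.5)
        weights = F.softmax(scores, dim=-1)
        heads = weights @ v                        # (batch, n_heads, seq_len, d_k)
        # step 3: concat the heads, residual connection + layer normalization
        concat = heads.transpose(1, 2).reshape(b, s, d)
        x = self.norm1(x + self.w_o(concat))
        # step 4: feed-forward network, then residual + layer normalization again
        return self.norm2(x + self.ffn(x))         # same shape as the input

out = EncoderLayer()(torch.randn(1, 3, 512))       # (1, 3, 512) in, (1, 3, 512) out
```

Because the output shape matches the input shape, the same layer can be applied six times in a row.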

Decoder structure

  1. During training, the target-sentence input tokens are converted to one-hot vectors, passed through the embedding layer, and positional encoding is added.
  2. The encoder's steps 1-3 (the self-attention process) are repeated; the result is used as the Query, while the encoder's output is copied twice and used as the Key and the Value.
    • Masking differs between the encoder's and the decoder's self-attention:
- Unlike the encoder, which uses only padding masking, the decoder masks both padding and the upper triangle of the attention-score matrix (a look-ahead mask).
- The decoder uses this masking so that, when computing attention weights, a token cannot refer to the tokens that come after it.
- If the upper triangle were not masked, then in the first row the token "I" would already receive information from "love", "Hodu", ..., "too" ahead of time (see the sketch below).
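
A small NumPy sketch of this look-ahead mask (the sentence and sizes are illustrative):

```python
# Sketch of the decoder's look-ahead (upper-triangular) masking, NumPy only.
import numpy as np

seq_len = 4
scores = np.random.randn(seq_len, seq_len)            # raw Q·K^T scores for 4 target tokens

# mask everything above the diagonal so position i cannot attend to positions j > i
look_ahead = np.triu(np.ones((seq_len, seq_len)), k=1).astype(bool)
scores[look_ahead] = -1e9                              # ~zero weight after the softmax

weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
# row 0 ("I") now puts all of its weight on itself and none on the later tokens
```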

3. Multi-head attention is then performed using the formulas below (this block is also called encoder-decoder attention, or cross-attention).

How Q, K, and V are computed (figure)
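
In this block only the source of Q, K, and V changes: the query comes from the decoder's self-attention output and the key/value come from the encoder output (X_dec and X_enc are just notation for this post):

$$Q = X_{\mathrm{dec}}W^Q, \qquad K = X_{\mathrm{enc}}W^K, \qquad V = X_{\mathrm{enc}}W^V$$

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$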

4. The head outputs are concatenated, a residual connection and layer normalization are applied, and the result is passed to the feed-forward network.

5. A residual connection and layer normalization are applied once more to produce the output.

Here, too, the output has the same size as the input.

The paper stacks six decoder layers; the output of the final layer is used for prediction.
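
Since every layer maps an input of shape (batch, seq_len, d_model) to an output of the same shape, stacking is just repeated application. A minimal sketch, reusing the illustrative EncoderLayer from above (the decoder stack works the same way):

```python
# Stacking works because each layer preserves the (batch, seq_len, d_model) shape.
import torch.nn as nn

class EncoderStack(nn.Module):
    def __init__(self, n_layers=6, d_model=512):
        super().__init__()
        self.layers = nn.ModuleList([EncoderLayer(d_model) for _ in range(n_layers)])

    def forward(self, x):
        for layer in self.layers:        # apply the same-shaped layer six times
            x = layer(x)
        return x
```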
