Background
Seq2Seq ๋ชจ๋ธ
- Encoder์ Decoder๋ฅผ ๊ฐ๊ฐ RNN์ผ๋ก ๊ตฌ์ฑํ๋ ๋ฐฉ์์ ๋๋ค.
- ๋์์๋ฆฌ
- ‘๋๋’, ‘ํธ๋๋ฅผ’, ‘์ฌ๋ํด’๋ผ๋ 3๊ฐ์ ํ ํฐ๋ค์ ์์ฐจ์ ์ผ๋ก LSTM ์ ์ ๋ฃ์ผ๋ฉด , hidden state๋ฅผ ํ๋์ฉ ์ถ๋ ฅํฉ๋๋ค.
- ์ด๋ ๊ฒ ํ ํฐ๋ค์ hidden state๋ค์ด ์ถ๋ ฅ๋๋ฉด, ๋ง์ง๋ง hidden state๋ ์ ๋ณด๋ฅผ ์์ถํ vector๊ฐ ๋๊ณ , ์ด๋ฅผ Context Vector ๋ผ๊ณ ์นญํฉ๋๋ค.
- Context Vector๋ฅผ ํตํด ์ดํ token๋ค์ ๋ฃ์์ ๋ ๋ค์ token ์์ธก์ ์ํ hidden state๊ฐ ์ถ๋ ฅ๋ฉ๋๋ค.
- ๋ฌธ์ ์
- Sequence๊ฐ ๊ธธ์ด์ง๋ ๊ฒฝ์ฐ์๋ Gradient Vanishing ๋ฌธ์ ๊ฐ ๋ฐ์ํ์ฌ Context Vector์ ์ ์์ token๋ค์ ์ ๋ณด๊ฐ ์์ค๋๋ ๋ฌธ์ ๊ฐ ๋ฐ์ํ์ต๋๋ค.
- Context vector๋ก๋ encoder์ ์ ๋ณด๋ฅผ ๋ชจ๋ ๋ด๊ธฐ ํ๋ค์๊ณ , ๊ทธ์ ๋ฐ๋ผ decoder์ ๋ฒ์ญ ํ์ง์ด ์ ํ๋๋ Bottleneck Effect๊ฐ ๋ฐ์ํ์ต๋๋ค.
๐ก ์ด๋ฌํ ๋ฌธ์ ์ ๋ค์ ๋ํ ํด๊ฒฐ์ ์ํด Attention์ ์ด์ฉํ์๋ ์์ด๋์ด๊ฐ ์ ์๋์์ต๋๋ค.
Attention ์์ (Long-term Attention)
- hidden state๋ค์ ํ๋์ ํ๋ ฌ๋ก ๋ณด์กดํ ํ, decoder์ ์์ ํ ํฐ์ LSTM ์ ์ ๋ฃ์ hidden state์ ๋ด์ ํ์ฌ Alignment Scores์ Alignment Weights๋ฅผ ๊ตฌํฉ๋๋ค.
- Attention Weights๋ฅผ hidden state ํ๋ ฌ์ ๊ณฑํ์ฌ ๊ฐ์คํฉ์ ๊ตฌํฉ๋๋ค.
- ์์ ๊ณผ์ ๋ค์ ์ํํ์ฌ ๊ตฌํ Attention Value๋ค์ concatํ๊ณ , tanh, softmax ํจ์๋ฅผ ์ด์ฉํด ์์ธก ํ๋ฅ ์ ๊ตฌํ์ฌ ๋ค์ token์ ์์ธกํฉ๋๋ค.
๐ก Attention์ ํตํด ํ์ฌ ์ถ๋ ฅ ๋์์ ๋ํด ๋ชจ๋ encoder์ hidden state๋ฅผ ๊ตฌํ ์ ์๊ฒ ๋์์ต๋๋ค.
Training Architecture
Token๋ค์ด encoder์ sequentialํ๊ฒ ๋ค์ด๊ฐ์ ์์๊ฐ ๋ณด์กด๋๋ Seq2Seq ๋ชจ๋ธ๊ณผ๋ ๋ค๋ฅด๊ฒ, Transformer๋ ํ ๋ฒ์ ๋ชจ๋ ํ ํฐ์ encoder์ ๋ฃ๋ ๋ฐฉ์์ผ๋ก ์ค๊ณ๋์ด ์์์ ๊ด๋ จ๋ ์ ๋ณด๊ฐ ๋ณด์กด๋์ง ์์ต๋๋ค.
์ด๋ฌํ ๋ฌธ์ ์ ์ Position Encoding์ผ๋ก ํด๊ฒฐํ์ต๋๋ค.
Positional Encoding
positional encoding์ ๋ค์๊ณผ ๊ฐ์ ํน์ง์ ๊ฐ์ง๋๋ค:
- ๊ฐ ํ ํฐ์ ์์น๋ง๋ค ์ ์ผํ ๊ฐ์ ์ง๋ ์ผ ํฉ๋๋ค.
- ํ ํฐ ๊ฐ ์ฐจ์ด๊ฐ ์ผ์ ํ ์๋ฏธ๋ฅผ ์ง๋ ์ผ ํฉ๋๋ค.
- ๋ ๊ธด ๊ธธ์ด์ ๋ฌธ์ฅ์ด ์ ๋ ฅ๋์ด๋ ์ผ๋ฐํ๊ฐ ๊ฐ๋ฅํด์ผ ํฉ๋๋ค.
๐ก ์ด๋ ๊ฒ embedding์ position encoding์ด ๋ํด์ง๋ฉด, ์๊ณต๊ฐ์ ํน์ฑ์ ๋ฐ์ํ ํ๋ ฌ๋ก ๋ณํ๋ฉ๋๋ค.
Encoder์ ๊ตฌ์กฐ
- Input์ 3๊ฐ ๋ณต์ฌํ์ฌ, Qurey, Key, Value ๊ฐ๊ฐ์ ๊ฐ์ค์น๋ฅผ ๊ณฑํด Q, K, V ํ๋ ฌ์ ์์ฑํฉ๋๋ค.
- Q์ K์ ํ๋ ฌ๊ณฑ์ ์ํํ์ฌ ์ค์ผ์ผ๋งํ๊ณ , masking๊ณผ softmax ํจ์๋ฅผ ๊ฑฐ์ณ ๋์จ ๊ฒฐ๊ณผ๋ฌผ์ V์์ ํ๋ ฌ๊ณฑ์ ๊ณ์ฐํฉ๋๋ค. ์ด๋, ๋ค์ํ task ์ํ์ ์ํด multi-head attention๊ธฐ๋ฒ์ ์ฌ์ฉํฉ๋๋ค.
- ์ดํ ๋ง๋ค์ด์ง 2๊ฐ์ ํ๋ ฌ์ concatํ๊ณ , residual connection, layer normalization์ ์ํํ์ฌ feed-forward network์ ๋๊น๋๋ค.
- ๋ค์ residual connection, layer normalization์ ์ํํ์ฌ output์ ์ป์ต๋๋ค.
์ด๋, output๊ณผ input์ ํฌ๊ธฐ๋ ๋์ผํฉ๋๋ค !!!
ํด๋น ๋ ผ๋ฌธ์์๋ 6๊ฐ์ encoder layer๋ฅผ ์ฌ์ฉํฉ๋๋ค. output๊ณผ input์ ํฌ๊ธฐ๊ฐ ๋์ผํ๋ฏ๋ก, ๊ฐ์ ๊ตฌ์กฐ์ encoder layer๋ฅผ 6๋ฒ ์ฌ์ฉํ ์ ์์ต๋๋ค.
Decoder์ ๊ตฌ์กฐ
- ํ๊น ๋ฌธ์ฅ ํ์ต ์ input token๋ค์ one-hot-encoding์ผ๋ก ๋ณํํ์ฌ embedding layer์ ๋ฃ๊ณ , positional encoding์ ์ํํฉ๋๋ค.
- ์์ encoder์ 1~3๋ฒ ๊ณผ์ (self-attention ๊ณผ์ )์ ๋ฐ๋ณตํ์ฌ ์ฐ์ถ๋ ๊ฒฐ๊ณผ๋ฅผ Query๋ก, Encoder์ ouput์ 2๊ฐ ๋ณต์ฌํ์ฌ Key์ Value ๊ฐ์ผ๋ก ์ฌ์ฉํฉ๋๋ค.
-
- Self-attention ์ํ ์ encoder์ decoder์ ๋ค๋ฅธ masking ๊ธฐ๋ฒ
- decoder์์ ํด๋น masking ๊ธฐ๋ฒ์ ์ฌ์ฉํ๋ ์ด์ ๋ attention์ ๊ฐ์ค์น๋ฅผ ๊ตฌํ ๋, ๋์ token ์ดํ์ token์ ์ฐธ์กฐํ์ง ๋ชปํ๋๋ก ํ๊ธฐ ์ํด์์ ๋๋ค.
- ๋ง์ฝ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด ์ผ๊ฐํ๋ ฌ์ ์๋ถ๋ถ์ maskingํ์ง ์๋๋ค๋ฉด, ์ฒซ๋ฒ์งธ ํ์์ I๋ love, Hodu~too์ ๋ชจ๋ ์ ๋ณด๋ค์ ๋ฏธ๋ฆฌ ๋ฐ์์ค๋ ๋ฌธ์ ๊ฐ ๋ฐ์ํฉ๋๋ค.
3. ์๋ ์์์ ํ์ฉํ์ฌ multi-head attention์ ์ํํฉ๋๋ค. (encoder-decoder self attention์ด๋ผ๊ณ ๋ ๋ถ๋ฆ ๋๋ค.)
4. ์์ฑ๋ ๋ ํ๋ ฌ์ concat ํ๊ณ , residual connection, layer normalization์ ์ํํ์ฌ feed forward network์ ๋๊น๋๋ค.
5. ์ดํ ๋ค์ residual connection, layer normalization์ ์ํํ์ฌ output์ ์์ฑํฉ๋๋ค.
์ด๋๋, output๊ณผ input์ ํฌ๊ธฐ๋ ๋์ผํฉ๋๋ค.
๋ ผ๋ฌธ์์๋ 6๊ฐ์ decoder layer๋ฅผ ์ฌ์ฉํ์ฌ ์ต์ข ๊ฒฐ๊ณผ๋ฌผ์ ์์ฑํ์ฌ ์์ธก์ ์ฌ์ฉํฉ๋๋ค.