๐Ÿ’ป ํ”„๋กœ์ ํŠธ/๐Ÿงธ TOY-PROJECTS

[ํ† ์ดํ”„๋กœ์ ํŠธ-๊ณต๊ฐ ์ฑ—๋ด‡] ํ”„๋กœ์ ํŠธ ๊ฐœ์š”

์žฅ์˜์ค€ 2023. 7. 5. 03:45

For this toy project I picked NLP as the theme: a chatbot that empathizes with you, or counsels you, based on the feelings you type in.

์ฐธ๊ณ ์ž๋ฃŒ:

https://wikidocs.net/157001 

https://hoit1302.tistory.com/162#[1]kogpt2%EA%B8%B0%EB%B0%98%EC%8B%AC%EB%A6%AC%EC%BC%80%EC%96%B4%EC%B1%97%EB%B4%87

 

๋ณธ ํ”„๋กœ์ ํŠธ๋Š” skt์˜ ์ƒ์„ฑ ๋ชจ๋ธ, KoGPT2๋ฅผ fine-tuning ํ•˜์—ฌ ์‚ฌ์šฉํ–ˆ์œผ๋ฉฐ,

๋ฐ์ดํ„ฐ์…‹์€ ์œ ๋ช…ํ•œ ์†ก์˜์ˆ™ ๋‹˜์˜ ์ฑ—๋ด‡ ๋ฐ์ดํ„ฐ์…‹์„ ์‚ฌ์šฉํ–ˆ๋‹ค.

์ž์„ธํ•œ ์ฝ”๋“œ๋Š” ๊นƒํ—ˆ๋ธŒ๋ฅผ ์ฐธ๊ณ ํ•˜๋ฉด ์ข‹์„ ๊ฒƒ ๊ฐ™๋‹ค.


1. ๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ

First, load the dataset into df; running df.head() gives the following:

There is a label column here: 0 is everyday conversation, 1 is breakup (negative emotion), and 2 is love (positive emotion).

ํ•™์Šต์„ ์‹œ์ผœ๋ณด๊ณ  ํšŒ๊ณ ํ•ด ๋ณด๋Š”๋ฐ, ์ด label์„ ๋„ฃ์–ด์„œ ํ•™์Šต์‹œ์ผฐ์œผ๋ฉด ํ›จ์”ฌ ์ข‹์•˜์„ ๊ฒƒ ๊ฐ™๋‹ค. (์ข‹์•˜๋‹ค๊ธฐ๋ณด๋‹ค๋Š” ๋‹ต๋ณ€์„ ํ‰๊ฐ€ํ•  ๋•Œ ์ง€ํ‘œ๋กœ ํ‰๊ฐ€ํ•˜๊ธฐ ์ข‹์•˜์„ ๊ฒƒ ๊ฐ™๋‹ค.)

But since my initial goal was simply to feed the chatbot some text and get an answer back, I trained with the label removed.

After that, I stored the data as a list in a variable called conversations.

Output of conversations
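The preprocessing above can be sketched as follows. The Q/A/label column names follow the public chatbot dataset; the two sample rows here are made up:

```python
import pandas as pd

# A minimal sketch of the preprocessing step. The Q/A/label columns follow
# the public chatbot dataset; these two sample rows are invented.
df = pd.DataFrame({
    "Q": ["12์‹œ ๋•ก!", "ํ•™๊ต ๊ฐ€๊ธฐ ์‹ซ์–ด"],
    "A": ["ํ•˜๋ฃจ๊ฐ€ ๋˜ ๊ฐ€๋„ค์š”.", "์กฐ๊ธˆ๋งŒ ํž˜๋‚ด์š”."],
    "label": [0, 0],
})

# Drop the label and keep (question, answer) pairs as a list of lists
conversations = df[["Q", "A"]].values.tolist()
```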

2. Tokenizing

Tokenizing์˜ ์›๋ฆฌ

์œ„ ๋ฐ์ดํ„ฐ์…‹(ํ…์ŠคํŠธ)์„ ์ปดํ“จํ„ฐ๊ฐ€ ์ดํ•ดํ•˜๊ธฐ์—๋Š” ๋ถ„๋ช… ํ•œ๊ณ„๊ฐ€ ์žˆ๋‹ค.

So we need a step that splits the text in some way, maps each piece to a specific number (an id), and feeds those ids to the model as input.

There are three main approaches to splitting input text into smaller units:

  1. word-based
  2. character-based
  3. subword-based

๊ฐ๊ฐ์— ๊ด€ํ•ด ๊ฐ„๋‹จํ•˜๊ฒŒ ์„ค๋ช…ํ•˜์ž๋ฉด, word-based๋Š” ๋‹จ์ˆœํ•˜๊ฒŒ ๋‹จ์–ด ๋‹จ์œ„๋กœ ๋ถ„๋ฆฌํ•˜๊ณ , ๊ฐ ๋‹จ์–ด๋ณ„๋กœ ๊ณ ์œ ์˜ id ๊ฐ’์„ ๋ถ€์—ฌํ•˜๋Š” ๋ฐฉ์‹,

character-based๋Š” ๋ชจ๋“  ํ…์ŠคํŠธ๋ฅผ character ๋‹จ์œ„๋กœ ์ž๋ฅด๊ณ , ๊ฐ character๋งˆ๋‹ค ๊ณ ์œ ์˜ id๋ฅผ ๋ถ€์—ฌํ•˜๋Š” ๋ฐฉ์‹,

subword-based๋Š” ์ž์ฃผ ์‚ฌ์šฉ๋˜๋Š” ๋‹จ์–ด๋“ค์€ ๊ทธ๋ƒฅ ์‚ฌ์šฉํ•˜๊ณ , ์ž์ฃผ ๋“ฑ์žฅํ•˜์ง€ ์•Š๋Š” ๋‹จ์–ด๋“ค์€ ์˜๋ฏธ ์žˆ๋Š” subword๋กœ ๋ถ„๋ฆฌํ•˜๋Š” ๋ฐฉ์‹์ด๋‹ค.
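As a rough illustration of the three approaches (the subword split shown is hypothetical, written in the WordPiece-style `##` notation):

```python
text = "unhappiness is temporary"

# word-based: split on whitespace, one id per whole word
words = text.split()

# character-based: one id per character
chars = list("unhappiness")

# subword-based: frequent words stay whole, rare words break into pieces
# (hypothetical split; '##' marks the continuation of a word)
subwords = ["un", "##happi", "##ness", "is", "temporary"]

print(words)      # ['unhappiness', 'is', 'temporary']
print(chars[:3])  # ['u', 'n', 'h']
```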

The tokenizer I used here, loaded through PreTrainedTokenizerFast, takes the subword-based approach.

Setting up the tokenizer

tokenizer์—์„œ๋Š” ํ…์ŠคํŠธ์˜ ์‹œ์ž‘, ๋, ๋ฏธ์ง€์˜ ํ† ํฐ, ํŒจ๋”ฉ, ๋งˆ์Šคํฌ, ๊ตฌ๋ถ„ ๋“ฑ์˜ ์—ญํ• ์„ ์ˆ˜ํ–‰ํ•˜๋Š” bos_token, eos_token, unk_token, pad_token, mask_token, sep_token ๋“ฑ์˜ ํŠน์ˆ˜ ํ† ํฐ์„ ์ง€์ •ํ•ด ์ฃผ์—ˆ๋‹ค.

from transformers import PreTrainedTokenizerFast

BOS = '<bos>'
EOS = '<eos>'
MASK = '<unused0>'
PAD = '<pad>'
SEP = '<sep>'

tokenizer = PreTrainedTokenizerFast.from_pretrained(
    "skt/kogpt2-base-v2",
    bos_token=BOS,
    eos_token=EOS,
    unk_token='<unk>',
    pad_token=PAD,
    mask_token=MASK,
    sep_token=SEP,
)

unk ํ† ํฐ์€ ์ž‘์„ฑํ•˜๊ณ  ์ƒ๊ฐ๋‚˜์„œ ์•ˆ์ด์˜์ง€๋งŒ ๊ท€์ฐฎ์•„์„œ ๋ฌธ์ž์—ด๋กœ ๋„ฃ์—ˆ๋‹ค

๊ฐ ํ† ํฐ์— ๊ด€ํ•œ ๊ฐ„๋‹จํ•œ ์„ค๋ช…์€ ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค:

  • bos_token: marks the beginning of a sentence
  • eos_token: marks the end of a sentence
  • unk_token: stands in for unknown words
  • pad_token: used for padding
  • mask_token: used for masking
  • sep_token: separates segments

ํ† ํฐ ์‚ฌ์šฉ์˜ ์˜ˆ๋ฅผ ๋“ค์–ด๋ณด์ž.

If I run '์˜ค๋Š˜ ์šฐ์šธํ•ด์„œ ํŒŒ๋งˆํ–ˆ์–ด' ("I was feeling down today, so I got a perm") through the tokenizer, it gets tokenized like this:

['<bos>', '์˜ค๋Š˜', '์šฐ์šธ_', '_ํ•ด์„œ', 'ํŒŒ๋งˆ_', '_ํ–ˆ์–ด', '<eos>']

Now, if the word 'ํŒŒ๋งˆ' (perm) were not in the tokenizer's vocabulary, it would be tokenized like this instead:

['<bos>', '์˜ค๋Š˜', '์šฐ์šธ_', '_ํ•ด์„œ', '<unk>', '_ํ–ˆ์–ด', '<eos>']

 

๊ทธ๋ ‡๋‹ค๋ฉด, ์ด์ œ ์ด ํ† ํฐ๋“ค์„ ์ปดํ“จํ„ฐ๊ฐ€ ์•Œ์•„๋“ค์„ ์ˆ˜ ์žˆ๋„๋ก ์ˆซ์ž๋กœ ์ธ์ฝ”๋”ฉ ์‹œ์ผœ๋ณด์ž.

ํ•จ์ˆ˜๋Š” ๊ฐ„๋‹จํ•˜๋‹ค. tokenizer.encode(์ธ์ฝ”๋”ฉ ์‹œํ‚ฌ ๋ฌธ์žฅ)๋ฅผ ์‹คํ–‰ํ•ด ์ฃผ๋ฉด ๋œ๋‹ค.

๊ทธ๋Ÿผ ์œ„์˜ ํ† ํฐ๋“ค์€ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์ธ์ฝ”๋”ฉ ๋œ๋‹ค (์ˆซ์ž๋Š” ์ž„์˜์˜ ์ˆซ์ž์ž„):

[1, 8723, 1289, 7893, 4536, 1632, 0]

Tokens that are registered in the tokenizer's vocabulary are encoded with their existing index,

while tokens the user adds directly, which have no index yet, are encoded with a newly assigned one.

Tokenizing ์ ์šฉ

๊ทธ๋Ÿผ ์ด ๋ฐฉ๋ฒ•์„ ๊ฐ€๊ณตํ•œ ๋ฐ์ดํ„ฐ์…‹์— ์ด์šฉํ•ด ๋ณด์ž. ์ฝ”๋“œ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค:

import torch
from torch.utils.data import Dataset

class ChatDataset(Dataset):
    def __init__(self, conversations, tokenizer, max_length=150):
        self.tokenizer = tokenizer
        self.inputs = []
        self.labels = []

        BOS_ID = tokenizer.bos_token_id
        EOS_ID = tokenizer.eos_token_id
        SEP_ID = tokenizer.sep_token_id
        PAD_ID = tokenizer.pad_token_id

        for conversation in conversations:
            question, answer = conversation

            # Tokenize question and answer separately
            question_ids = tokenizer.encode(question)
            answer_ids = tokenizer.encode(answer)

            # Build [BOS] question [SEP] answer [EOS]
            input_ids = [BOS_ID] + question_ids + [SEP_ID] + answer_ids + [EOS_ID]

            # Build labels: mask everything except the answer with -100
            label_ids = [-100]*(len(question_ids)+2) + answer_ids + [-100]*1

            # Truncate so every sample fits in max_length
            input_ids = input_ids[:max_length]
            label_ids = label_ids[:max_length]

            # Pad to max_length
            while len(input_ids) < max_length:
                input_ids.append(PAD_ID)
                label_ids.append(-100)

            self.inputs.append(input_ids)
            self.labels.append(label_ids)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return torch.tensor(self.inputs[idx]), torch.tensor(self.labels[idx])

๋ฌธ๋งฅ์„ ์ž˜ ํ‘œํ˜„ํ•˜๊ธฐ ์œ„ํ•ด ์งˆ๋ฌธ๊ณผ ๊ทธ์— ๋Œ€ํ•œ ๋‹ต๋ณ€์„ ํ† ํฐํ™”ํ•˜์—ฌ [BOS] ์งˆ๋ฌธ [SEP] ๋‹ต๋ณ€ [EOS] ํ˜•ํƒœ๋กœ ๋งŒ๋“ค์–ด ์ฃผ๋‹ค. ์ด๋Ÿฌํ•œ ๋ฐฉ์‹์œผ๋กœ ๊ตฌ๋ถ„๋˜๊ฒŒ ์ธ์ฝ”๋”ฉํ•จ์œผ๋กœ์จ ๋ชจ๋ธ์ด ์งˆ๋ฌธ๊ณผ ๋‹ต๋ณ€์˜ ๊ฒฝ๊ณ„๋ฅผ ์ธ์‹ํ•˜๊ณ  ์งˆ๋ฌธ์— ๋Œ€ํ•œ ๋‹ต๋ณ€์„ ์ž˜ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ๋‹ค.

๋˜ํ•œ ์ฑ—๋ด‡์˜ ํ•™์Šต์„ ์œ„ํ•ด label๋„ ๋งŒ๋“ค์–ด์•ผ ํ–ˆ๋‹ค. ์šฐ๋ฆฌ์˜ ๋ชฉํ‘œ๋Š” ์งˆ๋ฌธ์— ๋งž๋Š” ๋‹ต๋ณ€์„ ์ƒ์„ฑํ•˜๋Š” ๊ฒƒ์ด๊ธฐ ๋•Œ๋ฌธ์—, ๋‹ต๋ณ€ ๋ถ€๋ถ„๋งŒ label๋กœ ์‚ฌ์šฉํ–ˆ์Šต๋‹ˆ๋‹ค.

์ด๋Š” [-100]*(len(question_ids)+2) + answer_ids + [-100]*1 ํ˜•์‹์œผ๋กœ ์ธ์ฝ”๋”ฉํ–ˆ๋‹ค. ์ด์™€ ๊ฐ™์ด [BOS] ์งˆ๋ฌธ [SEP], [EOS] ๋ถ€๋ถ„์„ [-100]์œผ๋กœ ๋งŒ๋“  ์ด์œ ๋Š” ๋‹ต๋ณ€ ๋ถ€๋ถ„(answer_ids)๋งŒ ์ถ”์ถœํ•˜์—ฌ label๋กœ ๋งŒ๋“ค๊ธฐ ์œ„ํ•ด์„œ์ด๋‹ค.

 

So why encode with -100?

Because -100 is treated as a special value by some of PyTorch's loss functions, such as CrossEntropyLoss, where it is the default ignore_index.

When a label is -100, the loss at that position is excluded from the total loss; whatever the model predicts there has no effect on the loss.

In the code above, this is used so that the positions corresponding to the question and the padding contribute no loss when training the GPT-2 model. This way, the model learns to minimize the loss only over the answer portion.

3. ๋ชจ๋ธ ํ•™์Šต

์˜ตํ‹ฐ๋งˆ์ด์ €๋Š” ์ผ๋ฐ˜์ ์œผ๋กœ ์ž์—ฐ์–ด ์ฒ˜๋ฆฌ ๋ฐ ๋”ฅ๋Ÿฌ๋‹ ์ž‘์—…์—์„œ ํšจ๊ณผ์ ์œผ๋กœ ์‚ฌ์šฉ๋˜๋Š” ์˜ตํ‹ฐ๋งˆ์ด์ € ์ค‘ ํ•˜๋‚˜ AdamW๋ฅผ ์‚ฌ์šฉํ–ˆ๋‹ค.

AdamW regularizes the model's weights (decoupled weight decay), so it is said to keep the weights under control while still optimizing.

์ž‘์„ฑ๋œ ์ฝ”๋“œ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค:

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import GPT2LMHeadModel
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model / optimizer / dataloader setup (lr and batch_size are typical values,
# not taken from the original post)
model = GPT2LMHeadModel.from_pretrained("skt/kogpt2-base-v2")
model.resize_token_embeddings(len(tokenizer))  # in case the special tokens grew the vocab
model.to(device)

optimizer = AdamW(model.parameters(), lr=5e-5)
dataloader = DataLoader(ChatDataset(conversations, tokenizer), batch_size=32, shuffle=True)

# Training loop
for epoch in range(10):  # adjust the number of epochs as needed
    total_loss = 0
    for i, (inputs, labels) in enumerate(tqdm(dataloader)):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # The model computes the LM loss internally from the labels
        outputs = model(input_ids=inputs, labels=labels)

        loss = outputs.loss
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {total_loss/(i+1)}")

# Save the model
torch.save(model.state_dict(), "koGpt2_chatbot.pt")

Running this, the loss per epoch came out like this:

๋‹คํ–‰ํžˆ๋„ loss๋Š” ๊ณ„์† ์ค„์–ด๋“œ๋Š” ๋ชจ์Šต..

4. Generating Answers

Finally, I wrote the answer-generation code as follows:

# ๋‹ต๋ณ€ ์ƒ์„ฑ
generated = model.generate(
               inputs,
               max_length=60,
               temperature=0.2,
               no_repeat_ngram_size=3,
               pad_token_id=tokenizer.pad_token_id,
               eos_token_id=tokenizer.eos_token_id,
               bos_token_id=tokenizer.bos_token_id
           )

generated_text = tokenizer.decode(generated[:, inputs.shape[-1]:][0], skip_special_tokens=True)
  • ์ƒ์„ฑ ๊ณผ์ •์—์„œ ์ƒ˜ํ”Œ๋ง์˜ ๋‹ค์–‘์„ฑ์„ ์กฐ์ ˆํ•˜๋Š” ์—ญํ• ์„ ํ•˜๋Š” temperature๋Š” 0.2๋กœ ์„ค์ •ํ–ˆ๋‹ค.
    • temperature๊ฐ€ ๋‚ฎ์„์ˆ˜๋ก ๊ฐ€์žฅ ๊ฐ€๋Šฅ์„ฑ์ด ๋†’์€ ๋‹จ์–ด๋ฅผ ์„ ํƒํ•˜๋„๋ก ํ•˜์—ฌ ์ƒ์„ฑ๋œ ํ…์ŠคํŠธ๊ฐ€ ๋ณด๋‹ค ์ผ๊ด€๋˜๊ณ  ๊ฒฐ์ •์ ์ธ ํŠน์„ฑ์„ ์ง€๋‹ˆ๊ณ , ๋ฐ˜๋Œ€๋กœ ๋†’์„์ˆ˜๋ก ๋ฌด์ž‘์œ„์„ฑ์„ ๋„์ž…ํ•˜์—ฌ ๋” ๋‹ค์–‘ํ•˜๊ณ  ์ฐฝ์˜์ ์ธ ๋‹ต๋ณ€์„ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ๋‹ค๊ณ  ํ•œ๋‹ค.
  • ์ƒ์„ฑ ๊ณผ์ •์—์„œ ๋ฐ˜๋ณต๋˜๋Š” n-gram์„ ๋ฐฉ์ง€ํ•˜๋Š” ์—ญํ• ์„ ํ•˜๋Š” no_repeat_ngram_size๋Š” 3์œผ๋กœ ์ง€์ •ํ•ด ์ฃผ์—ˆ๋‹ค.

5. Results

Afterwards, I put together a simple front end and back end and wired up the model. Since I had done it once before, the front and back only took 30 minutes.

The results look like this (screenshots attached):

์ฒ˜์Œ์—๋Š” ๊ดœ์ฐฎ์•˜๋Š”๋ฐ, ๋‹ค์–‘ํ•œ ์งˆ๋ฌธ์„ ํ•˜๋‹ˆ๊นŒ ๋‹ต๋ณ€์ด ์ด์ƒํ•ด์ง„๋‹ค. epoch์„ ๋” ๋Š˜๋ ค์„œ ํ•™์Šตํ•ด ๋ด์•ผ๊ฒ ๋‹ค.

๊ทธ๋ฆฌ๊ณ , ๋‹ต๋ณ€์˜ ์ฒซ ๋ถ€๋ถ„์€ ํ•ญ์ƒ ์ด์ƒํ•˜๊ฒŒ ๋‚˜์˜ค๋Š”๋ฐ, ๋‚ด๊ฐ€ ์ž‘์„ฑํ•œ ํ…์ŠคํŠธ์—์„œ ์ด์–ด์„œ ์ž‘์„ฑํ•˜๋ ค๊ณ  ํ•ด์„œ ๊ทธ๋Ÿฐ๊ฐ€..??

๊ทธ ์ด์œ ๋ฅผ ์ž˜ ๋ชจ๋ฅด๊ฒ ์ง€๋งŒ ๊ณ ์ณ๋ด์•ผ๊ฒ ๋‹ค...


In a single day, I did some simple data preprocessing, wrote the training code, trained the model, and got a result out.

I was in a bit of a rush, so the work isn't polished, but next time I'd like to try a larger dataset and include the emotion labels!!