์ด๋ฒ์ NLP๋ฅผ ์ฃผ์ ๋ก ์์ ํ ์ด ํ๋ก์ ํธ๋ฅผ ์งํํ๋๋ฐ, ์์ ์ ์ฌ๋ฆฌ๋ฅผ ์์ฑํ๋ฉด ๊ทธ๊ฒ์ ๊ณต๊ฐํด ์ฃผ๊ฑฐ๋ ์๋ดํด ์ฃผ๋ ์ฑ๋ด์ ๋ง๋ค์ด๋ดค๋ค.
์ฐธ๊ณ ์๋ฃ:
๋ณธ ํ๋ก์ ํธ๋ skt์ ์์ฑ ๋ชจ๋ธ, KoGPT2๋ฅผ fine-tuning ํ์ฌ ์ฌ์ฉํ์ผ๋ฉฐ,
๋ฐ์ดํฐ์ ์ ์ ๋ช ํ ์ก์์ ๋์ ์ฑ๋ด ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ๋ค.
์์ธํ ์ฝ๋๋ ๊นํ๋ธ๋ฅผ ์ฐธ๊ณ ํ๋ฉด ์ข์ ๊ฒ ๊ฐ๋ค.
1. ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ
์ฐ์ ์ฌ์ฉํ ๋ฐ์ดํฐ๋ฅผ df์ ์ ์ฅํ๊ณ df.head()๋ฅผ ์คํํด ๋ณด๋ฉด ๋ค์๊ณผ ๊ฐ๋ค:
์ฌ๊ธฐ์ label์ด ์๋๋ฐ, 0์ ์ผ์๋ค๋ฐ์ฌ, 1์ ์ด๋ณ(๋ถ์ ๊ฐ์ ), 2๋ ์ฌ๋(๊ธ์ ๊ฐ์ )์ด๋ค.
ํ์ต์ ์์ผ๋ณด๊ณ ํ๊ณ ํด ๋ณด๋๋ฐ, ์ด label์ ๋ฃ์ด์ ํ์ต์์ผฐ์ผ๋ฉด ํจ์ฌ ์ข์์ ๊ฒ ๊ฐ๋ค. (์ข์๋ค๊ธฐ๋ณด๋ค๋ ๋ต๋ณ์ ํ๊ฐํ ๋ ์งํ๋ก ํ๊ฐํ๊ธฐ ์ข์์ ๊ฒ ๊ฐ๋ค.)
ํ์ง๋ง, ์ฒ์์ ๋๋ ์ฑ๋ด์ ํ ์คํธ๋ฅผ ๋ฃ์์ ๋ answer์ ๋ฐ๋ ๊ฒ์ ๋ชฉํ๋ก ํด์ label์ ์์ค ์ฑ ํ์ต์์ผฐ๋ค.
์ดํ, conversations๋ผ๋ ๋ณ์์ ๋ฆฌ์คํธ๋ก ์ ์ฅํ๋ค.
2. Tokenizing
Tokenizing์ ์๋ฆฌ
์ ๋ฐ์ดํฐ์ (ํ ์คํธ)์ ์ปดํจํฐ๊ฐ ์ดํดํ๊ธฐ์๋ ๋ถ๋ช ํ๊ณ๊ฐ ์๋ค.
๊ทธ๋ฌ๋ฏ๋ก, ํด๋น ํ ์คํธ๋ฅผ ์ด๋ค ์์ผ๋ก ๋ถ๋ฆฌํด์, ๋ถ๋ฆฌ๋ ํ ์คํธ๋ฅผ ํน์ ํ ์ซ์(id)์ ๋์์ํค๊ณ , ํด๋น id๋ฅผ ๋ชจ๋ธ์ ์ ๋ ฅ์ผ๋ก ๋ฃ์ด์ฃผ๋ ๊ณผ์ ์ด ํ์ํ๋ค.
์ด๋, ์ ๋ ฅ์ผ๋ก ๋ค์ด์จ ํ ์คํธ๋ฅผ ์กฐ๊ธ ๋ ์์ ๋จ์๋ก ๋ถ๋ฆฌํ๋ ๊ณผ์ ์ ๋ํ์ ์ผ๋ก 3๊ฐ์ง๊ฐ ์๋ค:
- word-based
- character-based
- subword-based
๊ฐ๊ฐ์ ๊ดํด ๊ฐ๋จํ๊ฒ ์ค๋ช ํ์๋ฉด, word-based๋ ๋จ์ํ๊ฒ ๋จ์ด ๋จ์๋ก ๋ถ๋ฆฌํ๊ณ , ๊ฐ ๋จ์ด๋ณ๋ก ๊ณ ์ ์ id ๊ฐ์ ๋ถ์ฌํ๋ ๋ฐฉ์,
character-based๋ ๋ชจ๋ ํ ์คํธ๋ฅผ character ๋จ์๋ก ์๋ฅด๊ณ , ๊ฐ character๋ง๋ค ๊ณ ์ ์ id๋ฅผ ๋ถ์ฌํ๋ ๋ฐฉ์,
subword-based๋ ์์ฃผ ์ฌ์ฉ๋๋ ๋จ์ด๋ค์ ๊ทธ๋ฅ ์ฌ์ฉํ๊ณ , ์์ฃผ ๋ฑ์ฅํ์ง ์๋ ๋จ์ด๋ค์ ์๋ฏธ ์๋ subword๋ก ๋ถ๋ฆฌํ๋ ๋ฐฉ์์ด๋ค.
์ฌ๊ธฐ์ ๋ด๊ฐ ์ฌ์ฉํ PreTrainedTokenizerFast๋ผ๋ tokenizer๋ subword-based ๋ฐฉ์์ ์ฌ์ฉํ๋ค.
Tokenizing ๋ฐฉ์
tokenizer์์๋ ํ ์คํธ์ ์์, ๋, ๋ฏธ์ง์ ํ ํฐ, ํจ๋ฉ, ๋ง์คํฌ, ๊ตฌ๋ถ ๋ฑ์ ์ญํ ์ ์ํํ๋ bos_token, eos_token, unk_token, pad_token, mask_token, sep_token ๋ฑ์ ํน์ ํ ํฐ์ ์ง์ ํด ์ฃผ์๋ค.
BOS = '<bos>'
EOS = '<eos>'
MASK = '<unused0>'
PAD = '<pad>'
SEP = '<sep>'
tokenizer = PreTrainedTokenizerFast.from_pretrained(
"skt/kogpt2-base-v2",
bos_token=BOS,
eos_token=EOS,
unk_token='<unk>',
pad_token=PAD,
mask_token=MASK,
sep_token=SEP
)
unk ํ ํฐ์ ์์ฑํ๊ณ ์๊ฐ๋์ ์์ด์์ง๋ง ๊ท์ฐฎ์์ ๋ฌธ์์ด๋ก ๋ฃ์๋ค
๊ฐ ํ ํฐ์ ๊ดํ ๊ฐ๋จํ ์ค๋ช ์ ๋ค์๊ณผ ๊ฐ๋ค:
- bos_token: ๋ฌธ์ฅ์ ์์์ ๋ํ๋ด๋ ํ ํฐ
- eos_token: ๋ฌธ์ฅ์ ๋์ ๋ํ๋ด๋ ํ ํฐ
- unk_token: ์ ์ ์๋ ๋จ์ด๋ฅผ ๋ํ๋ด๋ ํ ํฐ
- pad_token: ํจ๋ฉ์ ๋ํ๋ด๋ ํ ํฐ
- mask_token: ๋ง์คํฌ๋ฅผ ๋ํ๋ด๋ ํ ํฐ
- sep_token: ๊ตฌ๋ถ์ ๋ํ๋ด๋ ํ ํฐ
ํ ํฐ ์ฌ์ฉ์ ์๋ฅผ ๋ค์ด๋ณด์.
๋ง์ฝ ๋ด๊ฐ '์ค๋ ์ฐ์ธํด์ ํ๋งํ์ด'๋ฅผ ํ ํฌ๋์ด์ ๋ฅผ ํตํด ํ ํฐํ์ํค๋ฉด ๋ค์๊ณผ ๊ฐ์ด ํ ํฐํ๋๋ค:
['<bos>', '์ค๋', '์ฐ์ธ_', '_ํด์', 'ํ๋ง_', '_ํ์ด', '<eos>']
์ฌ๊ธฐ์ ๋ง์ฝ ๋ด๊ฐ ์ฌ์ฉํ ํ ํฌ๋์ด์ ์ ์ฌ์ ์ 'ํ๋ง'๋ผ๋ ๋จ์ด๊ฐ ์์ผ๋ฉด ๋ค์๊ณผ ๊ฐ์ด ํ ํฐํ๋๋ค:
['<bos>', '์ค๋', '<unk', '_ํด์', 'ํ๋ง_', '_ํ์ด', '<eos>']
๊ทธ๋ ๋ค๋ฉด, ์ด์ ์ด ํ ํฐ๋ค์ ์ปดํจํฐ๊ฐ ์์๋ค์ ์ ์๋๋ก ์ซ์๋ก ์ธ์ฝ๋ฉ ์์ผ๋ณด์.
ํจ์๋ ๊ฐ๋จํ๋ค. tokenizer.encode(์ธ์ฝ๋ฉ ์ํฌ ๋ฌธ์ฅ)๋ฅผ ์คํํด ์ฃผ๋ฉด ๋๋ค.
๊ทธ๋ผ ์์ ํ ํฐ๋ค์ ๋ค์๊ณผ ๊ฐ์ด ์ธ์ฝ๋ฉ ๋๋ค (์ซ์๋ ์์์ ์ซ์์):
[1, 8723, 1289, 7893, 4536, 1632, 0]
์ด๋ ๊ฒ tokenizer์ ์ฌ์ ์ ๋ฑ๋ก๋์ด index๊ฐ ์๋ ๊ฒฝ์ฐ๋ ํด๋น index๋ก,
์ฌ์ฉ์๊ฐ token์ ์ง์ ๋ฃ์ด์ค์ index๊ฐ ์๋ ๊ฒฝ์ฐ์๋ ์๋ก ํ ๋น๋ index๋ก ์ธ์ฝ๋ฉ ๋๋ค.
Tokenizing ์ ์ฉ
๊ทธ๋ผ ์ด ๋ฐฉ๋ฒ์ ๊ฐ๊ณตํ ๋ฐ์ดํฐ์ ์ ์ด์ฉํด ๋ณด์. ์ฝ๋๋ ๋ค์๊ณผ ๊ฐ๋ค:
class ChatDataset(Dataset):
def __init__(self, conversations, tokenizer, max_length=150):
self.tokenizer = tokenizer
self.inputs = []
self.labels = []
BOS_ID = tokenizer.bos_token_id
EOS_ID = tokenizer.eos_token_id
SEP_ID = tokenizer.sep_token_id
PAD_ID = tokenizer.pad_token_id
for conversation in conversations:
question, answer = conversation
# ํ ํฐํ
question_ids = tokenizer.encode(question)
answer_ids = tokenizer.encode(answer)
# [BOS] ์ง๋ฌธ [SEP] ๋ต๋ณ [EOS] ํํ๋ก ๋ง๋ฌ
input_ids = [BOS_ID] + question_ids + [SEP_ID] + answer_ids + [EOS_ID]
# ๋ ์ด๋ธ ์์ฑ (-100์ผ๋ก ํจ๋ฉ)
label_ids = [-100]*(len(question_ids)+2) + answer_ids + [-100]*1
# ํจ๋ฉ
while len(input_ids) < max_length:
input_ids.append(PAD_ID)
label_ids.append(-100)
self.inputs.append(input_ids)
self.labels.append(label_ids)
def __len__(self):
return len(self.inputs)
def __getitem__(self, idx):
return torch.tensor(self.inputs[idx]), torch.tensor(self.labels[idx])
๋ฌธ๋งฅ์ ์ ํํํ๊ธฐ ์ํด ์ง๋ฌธ๊ณผ ๊ทธ์ ๋ํ ๋ต๋ณ์ ํ ํฐํํ์ฌ [BOS] ์ง๋ฌธ [SEP] ๋ต๋ณ [EOS] ํํ๋ก ๋ง๋ค์ด ์ฃผ๋ค. ์ด๋ฌํ ๋ฐฉ์์ผ๋ก ๊ตฌ๋ถ๋๊ฒ ์ธ์ฝ๋ฉํจ์ผ๋ก์จ ๋ชจ๋ธ์ด ์ง๋ฌธ๊ณผ ๋ต๋ณ์ ๊ฒฝ๊ณ๋ฅผ ์ธ์ํ๊ณ ์ง๋ฌธ์ ๋ํ ๋ต๋ณ์ ์ ์์ฑํ ์ ์๋ค.
๋ํ ์ฑ๋ด์ ํ์ต์ ์ํด label๋ ๋ง๋ค์ด์ผ ํ๋ค. ์ฐ๋ฆฌ์ ๋ชฉํ๋ ์ง๋ฌธ์ ๋ง๋ ๋ต๋ณ์ ์์ฑํ๋ ๊ฒ์ด๊ธฐ ๋๋ฌธ์, ๋ต๋ณ ๋ถ๋ถ๋ง label๋ก ์ฌ์ฉํ์ต๋๋ค.
์ด๋ [-100]*(len(question_ids)+2) + answer_ids + [-100]*1 ํ์์ผ๋ก ์ธ์ฝ๋ฉํ๋ค. ์ด์ ๊ฐ์ด [BOS] ์ง๋ฌธ [SEP], [EOS] ๋ถ๋ถ์ [-100]์ผ๋ก ๋ง๋ ์ด์ ๋ ๋ต๋ณ ๋ถ๋ถ(answer_ids)๋ง ์ถ์ถํ์ฌ label๋ก ๋ง๋ค๊ธฐ ์ํด์์ด๋ค.
๊ทธ๋ ๋ค๋ฉด, [-100]์ผ๋ก ์ธ์ฝ๋ฉํ๋ ์ด์ ๋ ๋ฌด์์ผ๊น?
CrossEntropyLoss์ ๊ฐ์ ๋ช๋ช PyTorch์ ์์ค ํจ์์์๋ -100์ ํน์ํ ๊ฐ์ผ๋ก ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ด๋ค.
๋ ์ด๋ธ์ด -100์ธ ๊ฒฝ์ฐ, ํด๋น ์์น์ ์์ค์ ์ ์ฒด ์์ค ๊ณ์ฐ์์ ๋ฌด์๋๋ค. ์ฆ, ๊ทธ ์์น์ ์์ธก ๊ฐ์ด ์ด๋ค ๊ฒ์ด๋ ๊ฐ์ ์์ค์๋ ์ํฅ์ ์ฃผ์ง ์๋ ๊ฒ์ด๋ค.
์์์ ์์ฑํ ์ฝ๋์์๋ GPT-2 ๋ชจ๋ธ์ ํ์ตํ ๋, ์ง๋ฌธ ๋ถ๋ถ์ ํด๋นํ๋ ์์น์ ํจ๋ฉ ๋ถ๋ถ์ ์์ค์ ๊ณ์ฐ๋์ง ์๋๋ก ํ๊ธฐ ์ํด ์ด๋ฌํ ๋ฐฉ๋ฒ์ ์ฌ์ฉํ๋ค. ์ด๋ ๊ฒ ํ๋ฉด, ๋ชจ๋ธ์ ๋ต๋ณ ๋ถ๋ถ์ ๋ํด์๋ง ์์ค์ ์ต์ํํ๋ ๋ฐฉํฅ์ผ๋ก ํ์ตํ๊ฒ ๋๋ค.
3. ๋ชจ๋ธ ํ์ต
์ตํฐ๋ง์ด์ ๋ ์ผ๋ฐ์ ์ผ๋ก ์์ฐ์ด ์ฒ๋ฆฌ ๋ฐ ๋ฅ๋ฌ๋ ์์ ์์ ํจ๊ณผ์ ์ผ๋ก ์ฌ์ฉ๋๋ ์ตํฐ๋ง์ด์ ์ค ํ๋ AdamW๋ฅผ ์ฌ์ฉํ๋ค.
AdamW๋ ๋ชจ๋ธ์ ๊ฐ์ค์น๋ฅผ ์ ๊ทํํ๋ฏ๋ก ๋ชจ๋ธ์ ๊ฐ์ค์น๋ฅผ ์ ์ดํ๋ฉด์๋ ์ต์ ํ๋ฅผ ์ํํ ์ ์๋ค๊ณ ํ๋ค.
์์ฑ๋ ์ฝ๋๋ ๋ค์๊ณผ ๊ฐ๋ค:
from tqdm import tqdm
# ํ๋ จ ๋ฃจํ
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
for epoch in range(10): # ์ํญ ์๋ฅผ ์ ์ ํ๊ฒ ์กฐ์ ํด์ฃผ์ธ์.
total_loss = 0
for i, (inputs, labels) in enumerate(tqdm(dataloader)):
inputs = inputs.to(device)
labels = labels.to(device)
outputs = model(input_ids=inputs, labels=labels)
loss = outputs.loss
total_loss += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch {epoch+1}, Loss: {total_loss/(i+1)}")
# ๋ชจ๋ธ ์ ์ฅ
torch.save(model.state_dict(), "koGpt2_chatbot.pt")
์ด๋ ๊ฒ ์คํํ์ฌ ์ถ์ถ๋ ๊ฐ epoch ๋น loss๋ ๋ค์๊ณผ ๊ฐ๋ค:
4. ๋ต๋ณ ์์ฑ
๋ง์ง๋ง์ผ๋ก, ๋ต๋ณ ์์ฑ์ ๋ค์๊ณผ ๊ฐ์ด ์ฝ๋๋ฅผ ์์ฑํ๋ค:
# ๋ต๋ณ ์์ฑ
generated = model.generate(
inputs,
max_length=60,
temperature=0.2,
no_repeat_ngram_size=3,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id
)
generated_text = tokenizer.decode(generated[:, inputs.shape[-1]:][0], skip_special_tokens=True)
- ์์ฑ ๊ณผ์ ์์ ์ํ๋ง์ ๋ค์์ฑ์ ์กฐ์ ํ๋ ์ญํ ์ ํ๋ temperature๋ 0.2๋ก ์ค์ ํ๋ค.
- temperature๊ฐ ๋ฎ์์๋ก ๊ฐ์ฅ ๊ฐ๋ฅ์ฑ์ด ๋์ ๋จ์ด๋ฅผ ์ ํํ๋๋ก ํ์ฌ ์์ฑ๋ ํ ์คํธ๊ฐ ๋ณด๋ค ์ผ๊ด๋๊ณ ๊ฒฐ์ ์ ์ธ ํน์ฑ์ ์ง๋๊ณ , ๋ฐ๋๋ก ๋์์๋ก ๋ฌด์์์ฑ์ ๋์ ํ์ฌ ๋ ๋ค์ํ๊ณ ์ฐฝ์์ ์ธ ๋ต๋ณ์ ์์ฑํ ์ ์๋ค๊ณ ํ๋ค.
- ์์ฑ ๊ณผ์ ์์ ๋ฐ๋ณต๋๋ n-gram์ ๋ฐฉ์งํ๋ ์ญํ ์ ํ๋ no_repeat_ngram_size๋ 3์ผ๋ก ์ง์ ํด ์ฃผ์๋ค.
5. ๊ฒฐ๊ณผ
์ดํ, ๊ฐ๋จํ๊ฒ ํ๋ก ํธ์ ๋ฐฑ์ ๊ตฌ์ฑํ๊ณ ๋ชจ๋ธ์ ์ฐ๊ฒฐ์์ผ ๋ดค๋ค. ํ์คํ ํ๋ฒ ํด๋ณธ ์ง๋ผ, 30๋ถ ๋ง์ ํ๋ก ํธ, ๋ฐฑ์ ๊ตฌ์ฑํ๋ค.
๊ฒฐ๊ณผ๋ ๋ค์๊ณผ ๊ฐ๋ค (์ฌ์ง ์ฒจ๋ถ):
์ฒ์์๋ ๊ด์ฐฎ์๋๋ฐ, ๋ค์ํ ์ง๋ฌธ์ ํ๋๊น ๋ต๋ณ์ด ์ด์ํด์ง๋ค. epoch์ ๋ ๋๋ ค์ ํ์ตํด ๋ด์ผ๊ฒ ๋ค.
๊ทธ๋ฆฌ๊ณ , ๋ต๋ณ์ ์ฒซ ๋ถ๋ถ์ ํญ์ ์ด์ํ๊ฒ ๋์ค๋๋ฐ, ๋ด๊ฐ ์์ฑํ ํ ์คํธ์์ ์ด์ด์ ์์ฑํ๋ ค๊ณ ํด์ ๊ทธ๋ฐ๊ฐ..??
๊ทธ ์ด์ ๋ฅผ ์ ๋ชจ๋ฅด๊ฒ ์ง๋ง ๊ณ ์ณ๋ด์ผ๊ฒ ๋ค...
ํ๋ฃจ ๋ง์ ๋ฐ์ดํฐ๋ฅผ ๊ฐ๋จํ๊ฒ ๊ฐ๊ณตํ๊ณ ํ์ตํ๋ ์ฝ๋๋ฅผ ์์ฑํ์ฌ ํ์ต์์ผ ๋ณด๊ณ , ๊ฒฐ๊ณผ๋ฌผ์ ๋ด๋ดค๋ค.
์กฐ๊ธ ๋นจ๋ฆฌ ํด์ผ ํด์ ์ ๊ตํ๊ฒ ์์ ์ ๋ชป ํ์ง๋ง, ๋ค์์๋ ์ข ๋ ๋ง์ ๋ฐ์ดํฐ์ ์ ๊ฐ์ label์ ํฌํจ์์ผ ์์ ํด๋ณด๊ณ ์ถ๋ค !!
'๐ป ํ๋ก์ ํธ > ๐งธ TOY-PROJECTS' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[DeepLook] 6. ์ต์ข ๊ฒฐ๊ณผ๋ฌผ, ์ดํ ์๋ฌ ํธ๋ค๋ง๊ณผ ๋ง๋ฌด๋ฆฌ (2) | 2023.06.21 |
---|---|
[DeepLook] 5. ๋ฐฑ์๋ ์ฐ๊ฒฐ (0) | 2023.06.21 |
[DeepLook] 4. ๋ชจ๋ธ ์ ์ ๋ฐ ํ์ต (0) | 2023.06.21 |
[DeepLook] 3. ์ ์ฒ๋ฆฌ (haar-cascade ์๊ณ ๋ฆฌ์ฆ) (0) | 2023.06.20 |
[DeepLook] 2. AI ์์ ์ค๊ณ ๊ณผ์ / ํฌ๋กค๋ง (0) | 2023.06.20 |