[DeepLook] 4. Model Selection and Training

장영준 2023. 6. 21. 02:18

After preprocessing the photos, the next step was to choose a model and train it.

The candidates were ResNet, EfficientNet, and ArcFace. I chose ArcFace because it performs best at measuring face similarity. (Strictly speaking, ArcFace is an additive angular margin head and loss that sits on top of a CNN backbone. In the code below, that backbone is ResNet-50.)

The full details are in the Colab notebook.

1. Creating the CSV file

First, I created a CSV file that lists the name of each preprocessed photo (in the form `name_number`) together with the person it is labeled as.

SMO stands for 손명오 (Son Myeong-o) and IJH for 임지연 (Lim Ji-yeon).
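The post doesn't show the code for this step. Below is a minimal sketch. It assumes the preprocessed images are in one folder and are named `<INITIAL>_<number>.<ext>` (for example `SMO_1.jpg`). The helper name `build_label_csv`, the output path, and the column names `file_name`/`class` are all hypothetical. The column names were chosen to match the ones used later. An `index` column is added because the later code drops one.

```python
import pandas as pd
from pathlib import Path

def build_label_csv(image_dir, csv_path="labels.csv"):
    """Hypothetical helper: one row per image, labeled by the initial before '_'."""
    rows = []
    for path in sorted(Path(image_dir).iterdir()):
        if path.suffix.lower() not in {".jpg", ".jpeg", ".png"}:
            continue  # skip anything that is not an image
        initial = path.stem.split("_")[0]  # 'SMO_12' -> 'SMO'
        rows.append({"file_name": str(path), "class": initial})
    df = pd.DataFrame(rows).reset_index()  # adds an 'index' column in front
    df.to_csv(csv_path, index=False)
    return df
```

Reading the CSV back with `pd.read_csv` gives the `df` that the split below starts from.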

2. Splitting into train and validation sets

For each class (person), I put 70% of the images into the training set and 30% into a held-out set, which the code calls `valid`.

import pandas as pd
from sklearn.model_selection import train_test_split

tmp = df.copy()  # df: the DataFrame built from the CSV in step 1

# collect the unique initials into one list (in order of first appearance)
class_name_list = tmp['class'].unique().tolist()
print(class_name_list)  # ['SMO', 'CHJ', 'CDE', 'HAR', 'JJJ', 'JSI', 'OJY', 'SHE', 'SHG', 'IDH', 'IJH']

# split every class 70% train / 30% valid so that each person appears in both sets
train_parts, valid_parts = [], []
for class_name in class_name_list:
    tmp_with_class = tmp.loc[tmp['class'] == class_name]
    train_tmp, valid_tmp = train_test_split(tmp_with_class, test_size=0.3, random_state=42)
    train_parts.append(train_tmp)
    valid_parts.append(valid_tmp)

# concatenate once at the end instead of growing a DataFrame inside the loop
train = pd.concat(train_parts)
valid = pd.concat(valid_parts)

As a result, `train` holds 70% of the images for each initial, and `valid` holds the remaining 30%.
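The same per-class split can also be done with pandas alone. Below is a sketch that assumes `df` has a unique index and a `class` column. `GroupBy.sample` (pandas 1.1 or later) draws 70% of each class, and the remaining rows become the validation set. The per-class counts match the `train_test_split` loop up to rounding, but the specific rows chosen will differ even with the same seed. The `split_per_class` name is hypothetical.

```python
import pandas as pd

def split_per_class(df, train_frac=0.7, seed=42):
    """Sample train_frac of every class for training; everything else is validation."""
    train = df.groupby("class").sample(frac=train_frac, random_state=seed)
    valid = df.drop(index=train.index)  # requires a unique index
    return train, valid
```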

3. Applying one-hot encoding

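# one-hot encode 'class'; columns come out in sorted order (CDE, CHJ, ..., SMO)
# dtype=int keeps 0/1 integers (recent pandas versions return booleans by default)
one_hot_encoded = pd.get_dummies(tmp['class'], dtype=int)
train_one_hot_encoded = pd.get_dummies(train['class'], dtype=int)
valid_one_hot_encoded = pd.get_dummies(valid['class'], dtype=int)

# replace the 'class' column with the one-hot columns (rows align on the index)
data = pd.concat([tmp, one_hot_encoded], axis=1).drop(['class'], axis=1)
train = pd.concat([train, train_one_hot_encoded], axis=1).drop(['class'], axis=1)
valid = pd.concat([valid, valid_one_hot_encoded], axis=1).drop(['class'], axis=1)

valid  # display in the notebook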

I then applied one-hot encoding to each initial, and the result looked like this:

Output of `valid`: the one-hot encoding has been applied.
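`pd.get_dummies` sorts its columns alphabetically. That means position *i* in a one-hot row is the *i*-th initial in sorted order, and the training loop can therefore turn a one-hot label back into a class index with `torch.argmax(y, dim=1)`. The small check below uses made-up labels and NumPy's `argmax` in place of PyTorch's.

```python
import numpy as np
import pandas as pd

labels = pd.Series(["SMO", "CDE", "IJH", "SMO"])
one_hot = pd.get_dummies(labels, dtype=int)  # columns sorted: CDE, IJH, SMO

class_names = list(one_hot.columns)
class_idx = np.argmax(one_hot.to_numpy(), axis=1)  # same role as torch.argmax(y, dim=1)
decoded = [class_names[i] for i in class_idx]

print(class_names)  # ['CDE', 'IJH', 'SMO']
print(class_idx)    # [2 0 1 2]
print(decoded)      # ['SMO', 'CDE', 'IJH', 'SMO']
```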

4. CustomDataset

First, I set up data augmentation for each split with the albumentations library.

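import albumentations as A
from albumentations.pytorch import ToTensorV2

# training: resize, random augmentations, then ImageNet normalization
train_transform = A.Compose([
    A.Resize(224, 224),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.5),  # default limits (±0.2), half of the time
    A.RandomBrightnessContrast(brightness_limit=(-0.3, 0.3), contrast_limit=(-0.3, 0.3), p=1),  # stronger, always applied
    A.ChannelShuffle(p=0.2),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()  # HWC numpy array -> CHW torch tensor
])

# validation: deterministic resize and the same normalization, no augmentation
valid_transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])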

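Train and valid are transformed differently mainly to prevent overfitting. The random flips, brightness/contrast changes, and channel shuffles are applied only to the training images. The model therefore sees a slightly different version of each photo every epoch, which makes it harder to simply memorize the training set.

The validation images get only the deterministic resize and normalization. Every evaluation therefore sees the same inputs, and the validation score reflects how well the model generalizes to unaltered photos.

This is also why part of the data is held out for validation in the first place. A model scored on the same data it was trained on can look good just by memorizing it (overfitting), and only a separate validation set reveals that.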
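Both pipelines end with the same `A.Normalize` step. With its default `max_pixel_value` of 255, it scales pixels to [0, 1] and then standardizes each channel with the ImageNet mean and std that the pretrained ResNet-50 expects. After that, `ToTensorV2` moves the channel axis first. Below is a NumPy sketch of that arithmetic. It is not albumentations' own code, and the helper names are hypothetical.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize(image_hwc_uint8):
    """(pixel / 255 - mean) / std for each RGB channel."""
    x = image_hwc_uint8.astype(np.float32) / 255.0
    return (x - MEAN) / STD

def to_chw(image_hwc):
    """Channels first: (H, W, C) -> (C, H, W), as ToTensorV2 does."""
    return np.transpose(image_hwc, (2, 0, 1))

img = np.full((224, 224, 3), 255, dtype=np.uint8)  # an all-white RGB image
out = to_chw(normalize(img))
print(out.shape)     # (3, 224, 224)
print(out[:, 0, 0])  # (1 - mean) / std per channel, roughly 2.25, 2.43, 2.64
```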

Next, I defined a `CustomDataset` as follows.

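import cv2
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, file_list, label_list, transform=None):
        self.file_list = file_list    # image file paths
        self.label_list = label_list  # one-hot label rows, in the same order as file_list
        self.transform = transform    # albumentations pipeline (train_transform or valid_transform)

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        path = self.file_list[index]
        image = cv2.imread(path)
        if image is None:  # cv2.imread returns None instead of raising
            raise FileNotFoundError(f"Could not read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB

        if self.transform:
            image = self.transform(image=image)["image"]  # CHW tensor after ToTensorV2

        # float tensor, so torch.argmax works whatever dtype the label array has
        label = torch.as_tensor(self.label_list[index], dtype=torch.float32)
        return image, label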

I then wrapped the datasets in `DataLoader`s.

from torch.utils.data import DataLoader

# file paths and one-hot label rows (as float32, so argmax works later)
train_files = train["file_name"].tolist()
train_labels = train.drop(["file_name", "index"], axis=1).to_numpy(dtype="float32")
valid_files = valid["file_name"].tolist()
valid_labels = valid.drop(["file_name", "index"], axis=1).to_numpy(dtype="float32")

# datasets: augmentation for train only
train_dataset = CustomDataset(train_files, train_labels, transform=train_transform)
valid_dataset = CustomDataset(valid_files, valid_labels, transform=valid_transform)

# data loaders
batch_size = 16

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)  # evaluation order does not matter

Let's check the shapes of one batch.

for x,y in valid_loader:
  print(f'Image Shape: {x.shape}')
  print(f'Label Shape: {y.shape}')
  break

As intended, the images come out as 224 × 224 (the 3 is the RGB channels, and the leading 16 is the batch size). Each label has 11 entries, one per initial.
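Those batch shapes come from the DataLoader's default collate function. It stacks the per-sample `(3, 224, 224)` images and `(11,)` label rows along a new first axis. Below is a NumPy stand-in for that stacking, built from dummy samples. It is illustrative only; the real collate function produces torch tensors.

```python
import numpy as np

batch_size, num_classes = 16, 11
# dummy (image, one-hot label) pairs standing in for CustomDataset items
samples = [
    (np.zeros((3, 224, 224), dtype=np.float32), np.eye(num_classes)[i % num_classes])
    for i in range(batch_size)
]

images = np.stack([img for img, _ in samples])      # (16, 3, 224, 224)
labels = np.stack([label for _, label in samples])  # (16, 11)
print(images.shape, labels.shape)
```

The last batch of an epoch is smaller when the dataset size is not a multiple of 16, unless `drop_last=True` is passed to `DataLoader`.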

5. Defining the model

ArcFace deserves a post of its own, so I'll explain it in detail next time and only include the code here.
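For reference while reading `ArcMarginProduct` below, here is the logit it computes for a sample whose true class is $y$ (with `easy_margin=False`). The symbols map to the code as follows: $s$ is `scale` (30), $m$ is `margin` (0.5), $\cos(\pi - m)$ is `th`, and $m\sin(\pi - m)$ is `mm`. This is the standard ArcFace formulation, and applying `CrossEntropyLoss` to these logits gives the loss $\mathcal{L}$.

```latex
\[
\cos\theta_j = \frac{W_j^{\top} x}{\lVert W_j \rVert \, \lVert x \rVert}
\]
\[
\phi_y =
\begin{cases}
\cos(\theta_y + m) = \cos\theta_y \cos m - \sin\theta_y \sin m, & \text{if } \cos\theta_y > \cos(\pi - m) \\
\cos\theta_y - m \sin(\pi - m), & \text{otherwise}
\end{cases}
\]
\[
z_j =
\begin{cases}
s\,\phi_y, & j = y \\
s\cos\theta_j, & j \neq y
\end{cases}
\qquad
\mathcal{L} = -\log \frac{e^{z_y}}{\sum_{k} e^{z_k}}
\]
```

The margin only lowers the true class's logit. The network therefore has to shrink $\theta_y$ further than plain softmax would require, which pulls embeddings of the same person closer together.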

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class ArcMarginProduct(nn.Module):
    """ArcFace head: adds an angular margin to the true class, then scales the cosines."""

    def __init__(self, in_features, out_features, scale=30.0, margin=0.50, easy_margin=False, device='cuda'):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.scale = scale
        self.margin = margin
        self.device = device
        # one weight vector (class center) per class
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.th = math.cos(math.pi - margin)  # past this point theta + m would exceed pi
        self.mm = math.sin(math.pi - margin) * margin

    def forward(self, input, label):
        # cosine between normalized embeddings and normalized class weights
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # clamp avoids NaN (and infinite gradients) when cosine rounds to +-1
        sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(min=1e-7))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # apply the margin only to each sample's true class
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        return output * self.scale

class CustomArcFaceModel(nn.Module):
    def __init__(self, num_classes, device='cuda'):
        super().__init__()
        self.device = device
        # ImageNet-pretrained ResNet-50 without its final fc layer -> (B, 2048, 1, 1)
        # (weights API needs torchvision >= 0.13; older versions use pretrained=True)
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.arc_margin_product = ArcMarginProduct(2048, num_classes, device=self.device)
        nn.init.kaiming_normal_(self.arc_margin_product.weight)  # replaces the Xavier init

    def forward(self, x, labels=None):
        features = self.backbone(x)
        features = F.normalize(features)                  # L2-normalize the 2048 channels
        features = features.view(features.size(0), -1)    # (B, 2048)
        if labels is not None:
            return self.arc_margin_product(features, labels)  # training: ArcFace logits
        return features                                       # inference: face embedding

    def cosine_similarity(self, x1, x2):
        return torch.dot(x1, x2) / (torch.norm(x1) * torch.norm(x2))

    def find_most_similar_celebrity(self, user_face_embedding, celebrity_face_embeddings):
        max_similarity = -1
        most_similar_celebrity_index = -1

        for i, celebrity_embedding in enumerate(celebrity_face_embeddings):
            similarity = self.cosine_similarity(user_face_embedding, celebrity_embedding)
            if similarity > max_similarity:
                max_similarity = similarity
                most_similar_celebrity_index = i

        return most_similar_celebrity_index, max_similarity

When an image is passed in, the code above does the following:

  1. The input image is passed to the `forward` method of `CustomArcFaceModel`.
  2. The image goes through the backbone stored in `self.backbone` (here ResNet-50 with its final fully connected layer removed).
  3. The backbone turns the image into a feature map. Because the backbone ends in global average pooling, this map is 2048 × 1 × 1 per image.
  4. The feature map is L2-normalized.
  5. The normalized feature map is flattened into a 2048-dimensional vector.
  6. If labels are given, the feature vector goes through `self.arc_margin_product`, an `ArcMarginProduct` layer.
  7. `ArcMarginProduct` takes the vector and the labels and computes the cosine similarity between the vector and each class's weight vector. It then adds the angular margin to the true class and multiplies by the scale to produce the logits. The ArcFace loss itself comes from applying `CrossEntropyLoss` to these logits during training.
  8. The logits are returned. Without labels, steps 6 to 8 are skipped and the normalized vector itself is returned as the face embedding.
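Step 7 is easy to check by hand with a NumPy re-implementation of `ArcMarginProduct.forward` (with `easy_margin=False`). This is a sketch that mirrors the PyTorch code, not the module itself, and the name `arc_margin_logits` is hypothetical. Clipping `1 - cos²` at zero is an addition of this sketch, to keep `sqrt` away from tiny negative rounding errors.

```python
import math
import numpy as np

def arc_margin_logits(features, weight, labels, scale=30.0, margin=0.5):
    """NumPy version of ArcMarginProduct.forward with easy_margin=False."""
    x = features / np.linalg.norm(features, axis=1, keepdims=True)
    w = weight / np.linalg.norm(weight, axis=1, keepdims=True)
    cosine = x @ w.T                                        # (batch, num_classes)
    sine = np.sqrt(np.clip(1.0 - cosine ** 2, 0.0, None))
    phi = cosine * math.cos(margin) - sine * math.sin(margin)  # cos(theta + m)

    th = math.cos(math.pi - margin)
    mm = math.sin(math.pi - margin) * margin
    phi = np.where(cosine > th, phi, cosine - mm)           # fallback when theta + m > pi

    one_hot = np.zeros_like(cosine)
    one_hot[np.arange(len(labels)), labels] = 1.0           # margin only on the true class
    return scale * (one_hot * phi + (1.0 - one_hot) * cosine)

weight = np.array([[1.0, 0.0], [0.0, 1.0]])                 # two made-up class centers
feat = np.array([[0.5, math.sqrt(3) / 2]])                  # 60 degrees from class 0
print(arc_margin_logits(feat, weight, np.array([0])))
```

Take a sample at 60° from its class weight. Its true-class logit drops from 30·cos 60° = 15 to 30·cos(60° + 0.5 rad) ≈ 0.71, while the other class's logit is unchanged.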

In short, `CustomArcFaceModel` turns an input image into a feature vector with the ResNet-50 backbone, and during training the `ArcMarginProduct` layer turns that vector into logits.
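At inference time, `forward` is called without labels and returns the normalized embedding. `find_most_similar_celebrity` then compares that embedding against stored celebrity embeddings one at a time. The same search can be done with a single matrix product. The NumPy sketch below assumes the celebrity embeddings are stacked into an `(N, 2048)` array; the helper name `most_similar` and the tiny 3-dimensional vectors are made up.

```python
import numpy as np

def most_similar(user_embedding, celebrity_embeddings):
    """Return (index, cosine similarity) of the closest celebrity embedding."""
    u = user_embedding / np.linalg.norm(user_embedding)
    c = celebrity_embeddings / np.linalg.norm(celebrity_embeddings, axis=1, keepdims=True)
    sims = c @ u  # cosine similarity to every celebrity at once
    best = int(np.argmax(sims))
    return best, float(sims[best])

celebs = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.6, 0.8, 0.0]])
idx, sim = most_similar(np.array([0.0, 2.0, 0.1]), celebs)
print(idx, sim)
```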

6. Training the model

I wrote the training loop so that it prints the loss and accuracy on both sets after every epoch.

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F

def train(model, optimizer, criterion, train_loader, valid_loader, device, epochs, scheduler=None):
    model.to(device)
    best_accuracy = 0.0

    for epoch in range(epochs):
        # ---- training ----
        model.train()
        train_loss, train_corrects = 0.0, 0

        for x, y in tqdm(train_loader, desc=f"Epoch {epoch + 1} - Training"):
            x = x.to(device)
            y = torch.argmax(y, dim=1).to(device)  # one-hot -> class index

            optimizer.zero_grad()
            output = model(x, y)  # ArcFace logits (margin applied to the true class)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()

            preds = output.argmax(dim=1)  # accuracy under the margin, so slightly pessimistic
            train_loss += loss.item() * x.size(0)
            train_corrects += (preds == y).sum().item()

        train_loss /= len(train_loader.dataset)
        train_accuracy = train_corrects / len(train_loader.dataset)

        # ---- validation ----
        model.eval()
        valid_loss, valid_corrects = 0.0, 0
        head = model.arc_margin_product

        with torch.no_grad():
            for x, y in tqdm(valid_loader, desc=f"Epoch {epoch + 1} - Validation"):
                x = x.to(device)
                y = torch.argmax(y, dim=1).to(device)

                features = model(x)         # normalized embeddings, computed without labels
                output = head(features, y)  # margin logits, only for a loss comparable to training
                loss = criterion(output, y)
                # predict from plain scaled cosines so the true label is not used (as at inference)
                cos_logits = F.linear(features, F.normalize(head.weight)) * head.scale
                preds = cos_logits.argmax(dim=1)

                valid_loss += loss.item() * x.size(0)
                valid_corrects += (preds == y).sum().item()

        valid_loss /= len(valid_loader.dataset)
        valid_accuracy = valid_corrects / len(valid_loader.dataset)

        print(f"Epoch {epoch + 1}/{epochs} - Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, "
              f"Valid Loss: {valid_loss:.4f}, Valid Acc: {valid_accuracy:.4f}")

        if scheduler is not None:
            scheduler.step(valid_accuracy)  # mode='max': lower the LR when valid accuracy stalls

        if valid_accuracy > best_accuracy:  # keep the best checkpoint
            best_accuracy = valid_accuracy
            torch.save(model.state_dict(), "arcface.pth")

    return model

# hyperparameters
num_classes = 11  # number of classes (CDE, CHJ, HAR, IDH, IJH, JJJ, JSI, OJY, SHE, SHG, SMO)
embedding_size = 2048
learning_rate = CFG['LEARNING_RATE']  # CFG is defined elsewhere in the notebook (not shown in this post)
epochs = 150
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# build the model
model = CustomArcFaceModel(num_classes, device=device)

# optimizer, loss function, and learning-rate scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=2, threshold_mode='abs', min_lr=1e-8)

# train
trained_model = train(model, optimizer, criterion, train_loader, valid_loader, device, epochs, scheduler)
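`ReduceLROnPlateau` is meant to be stepped once per epoch with the monitored metric, here the validation accuracy. With `mode='max'`, `factor=0.5`, `patience=2`, and `threshold_mode='abs'`, it then behaves roughly like the pure-Python sketch below. The sketch is simplified: it ignores `cooldown`, assumes PyTorch's default `threshold=1e-4`, and uses the hypothetical name `plateau_schedule`.

```python
def plateau_schedule(accuracies, lr=1e-3, factor=0.5, patience=2, threshold=1e-4, min_lr=1e-8):
    """Simplified ReduceLROnPlateau(mode='max', threshold_mode='abs'); LR after each epoch."""
    best = float("-inf")
    bad_epochs = 0
    history = []
    for acc in accuracies:
        if acc > best + threshold:  # counts as an improvement
            best, bad_epochs = acc, 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:   # third epoch in a row without improvement
            lr = max(lr * factor, min_lr)
            bad_epochs = 0
        history.append(lr)
    return history

print(plateau_schedule([0.50, 0.60, 0.60, 0.60, 0.60, 0.65]))
# [0.001, 0.001, 0.001, 0.001, 0.0005, 0.0005]
```

So with `patience=2`, the learning rate is halved on the third consecutive epoch whose validation accuracy fails to beat the best so far.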

After setting the hyperparameters and training the model, I got the following final results.

The training accuracy is high while the validation accuracy is low by comparison, which points to overfitting.

I plan to tune the hyperparameters further and train again.