๐Ÿ’ป ํ”„๋กœ์ ํŠธ/๐Ÿงธ TOY-PROJECTS

[DeepLook] 5. Connecting the Backend

์žฅ์˜์ค€ 2023. 6. 21. 04:12

Until now I had only handled frontend work, with just a taste of backend development in Node.js, when I was suddenly given the task of connecting the web frontend to the backend.

์‹œ๊ฐ„์ด ์—†๊ณ , ๋ชจ๋ธ์ด ๊ทœ๋ชจ๊ฐ€ ์žˆ๋Š” ๋ชจ๋ธ์€ ์•„๋‹ˆ์–ด์„œ flask๋ฅผ ์‚ฌ์šฉํ•˜๊ธฐ๋กœ ๊ฒฐ์ •ํ–ˆ๋‹ค.

๋” ์ž์„ธํ•œ ์ฝ”๋“œ๋ฅผ ๋ณด๊ณ  ์‹ถ์œผ๋ฉด ๊นƒํ—ˆ๋ธŒ๋ฅผ ์ฐธ๊ณ ํ•˜๊ธธ ๋ฐ”๋ž€๋‹ค. (์ข€ ๋‚œ์žกํ•  ์ˆ˜ ์žˆ์Œ ์ฃผ์˜)


๋ฐฑ์—”๋“œ ๋ชจ๋ธ๊ณผ ์—ฐ๊ฒฐ ์‹œ ๊ฑฐ์นœ ๋‹จ๊ณ„๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค:

  1. pickle ํŒŒ์ผ์„ ๋งŒ๋“ค์–ด ๋ชจ๋“  ์—ฐ์˜ˆ์ธ์˜ ์–ผ๊ตด ์ž„๋ฒ ๋”ฉ์„ ์ €์žฅํ•œ๋‹ค.
  2. client ์ธก์—์„œ post ์š”์ฒญ์„ ๋ณด๋ƒˆ์„ ๋•Œ ์ด๋ฏธ์ง€ ๋ฐ์ดํ„ฐ๋ฅผ ๋ฐ›์•„ ์ž„๋ฒ ๋”ฉ์„ ์ถ”์ถœํ•œ๋‹ค.
  3. ์ถ”์ถœํ•œ ์ž„๋ฒ ๋”ฉ๊ณผ ๊ฐ€์žฅ ์œ ์‚ฌํ•œ ์ž„๋ฒ ๋”ฉ์„ pickle ํŒŒ์ผ๋กœ๋ถ€ํ„ฐ ๋กœ๋”ฉํ•œ๋‹ค.
  4. ์ถ”์ถœ๋œ ์–ผ๊ตด ์ž„๋ฒ ๋”ฉ๊ณผ ์‚ฌ์ „ ํ•™์Šต๋œ ์ž„๋ฒ ๋”ฉ ๊ฐ„์˜ ์œ ์‚ฌ๋„๋ฅผ ๊ณ„์‚ฐํ•œ๋‹ค.
  5. ๊ฐ€์žฅ ์œ ์‚ฌํ•œ ์œ ๋ช…์ธ์˜ ์ด๋‹ˆ์…œ๊ณผ ์œ ์‚ฌ๋„๋ฅผ JSON ํ˜•์‹์œผ๋กœ ๋ฐ˜ํ™˜ํ•œ๋‹ค.
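
At their core, steps 3–5 amount to a cosine-similarity nearest-neighbor lookup over the stored embeddings. A minimal NumPy sketch of that idea (the helper names here are illustrative, not the project's actual functions):

```python
import numpy as np

def cosine_similarity(a, b):
    # Cosine similarity between two 1-D embedding vectors.
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def most_similar(query_embedding, stored_embeddings):
    # Return the index of the closest stored embedding and its similarity.
    sims = [cosine_similarity(query_embedding, e) for e in stored_embeddings]
    best = int(np.argmax(sims))
    return best, sims[best]
```

The real code does the same thing, just wrapped in the model class and keyed by celebrity initials.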

๊ฐ ๋‹จ๊ณ„์— ๊ด€ํ•œ ์ž์„ธํ•œ ์„ค๋ช…์„ ์•„๋ž˜์— ์ž‘์„ฑํ•ด๋ณธ๋‹ค.

1. Saving every celebrity's face embeddings with pickle

ํ•ด๋‹น ์ž‘์—…์„ ์ˆ˜ํ–‰ํ•˜๊ธฐ ์œ„ํ•ด ์ž‘์„ฑํ•œ ์ฝ”๋“œ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค:

# pickle ํŒŒ์ผ์— ์ด๋ฏธ์ง€๋ฅผ ์ €์žฅํ•˜๊ธฐ ์œ„ํ•ด ๋งŒ๋“ค์—ˆ์Šต๋‹ˆ๋‹ค.
import os
import glob
import cv2
import torch
import pickle
import numpy as np
from arcface_model import CustomArcFaceModel
from albumentations import Compose, Resize
from albumentations.pytorch import ToTensorV2

def load_image(image_path):
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image

def preprocess_image(image):
    if isinstance(image, np.ndarray):
        preprocess = Compose([
            Resize(224, 224),
            ToTensorV2()
        ])
        image = preprocess(image=image)['image']
        image = image.float() / 255.0
        return image.unsqueeze(0)
    elif isinstance(image, torch.Tensor):
        return image.float() / 255.0
    else:
        raise ValueError("Unsupported image format.")


num_classes = 11  # ๋ถ„๋ฅ˜ํ•  ํด๋ž˜์Šค์˜ ์ˆ˜
device = torch.device('cpu')

model = CustomArcFaceModel(num_classes)
model.load_state_dict(torch.load('arcface.pth', map_location=torch.device('cpu')))
model.eval()

celebrity_image_dict = {}

celebrity_initial_list = ['shg', 'idh', 'she', 'ijh', 'cde', 'chj', 'har', 'jjj', 'jsi', 'ojy', 'smo']

embeddings_dict = {}

for celebrity_initial in celebrity_initial_list:
    
  image_folder = f'/Users/jang-youngjoon/dev-projects/youtuber-look-alike/pre-processed-image/{celebrity_initial}/'
  image_files = glob.glob(os.path.join(image_folder, '*.jpg'))
  embeddings = []

  for image_file in image_files:
        image = load_image(image_file)
        preprocessed_image = preprocess_image(image).to(device)

        with torch.no_grad():
            embedding = model(preprocessed_image)
            embeddings.append(embedding.squeeze().cpu().numpy())

  embeddings_dict[celebrity_initial] = embeddings

  with open('trained_celebrity_embeddings.pkl', 'wb') as f:
      pickle.dump(embeddings_dict, f)

์ฝ”๋“œ์— ๊ด€ํ•œ ์„ค๋ช…์€ ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค:

  1. load_image(image_path) reads the image at the given path and converts it to RGB.
  2. preprocess_image(image) preprocesses the input image.
    If the image is a NumPy array, it is resized, converted to a tensor, and normalized to values between 0 and 1. If it is already a tensor, only the normalization is applied.
  3. num_classes specifies the number of classes to classify.
  4. celebrity_initial_list is the list of initials of the celebrities whose embeddings will be extracted.
  5. embeddings_dict is a dictionary keyed by each celebrity's initials, whose values are the lists of embeddings extracted from that celebrity's images.
  6. For each element of celebrity_initial_list, the code:
    • fetches the list of image files from that celebrity's image folder,
    • extracts a face embedding from each image and appends it to a list,
    • stores the list of extracted embeddings in embeddings_dict.
  7. Once embeddings have been extracted for every celebrity, embeddings_dict is saved to trained_celebrity_embeddings.pkl.

2. Receiving and preprocessing the POST request from the client

์‹œ๊ฐ„์ด ์—†์–ด์„œ presigned-url ๊ธฐ๋Šฅ์„ ๊ตฌํ˜„ํ•˜์ง€ ๋ชปํ•  ๊ฒƒ ๊ฐ™์•„, ๊ทธ๋ƒฅ client ์ธก์œผ๋กœ๋ถ€ํ„ฐ base64๋กœ encoding๋œ ๊ฐ’์„ ๋ฐ›์•˜๋‹ค.

@app.route("/flask/predict", methods=["POST"])
@cross_origin("*")
def predict():
    try:
        image_data = request.form.get("image")
        image_decoded = base64.b64decode(image_data)
        nparr = np.frombuffer(image_decoded, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # ... (the full handler, including the response, appears at the end of this post)
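
The handler's first lines undo what the client does before sending: the raw image bytes are base64-encoded into an ASCII-safe string for the "image" form field. A minimal sketch of that round trip (helper names are illustrative, not from the project):

```python
import base64

def encode_image_bytes(raw: bytes) -> str:
    # Client side: base64-encode raw image bytes for the "image" form field.
    return base64.b64encode(raw).decode("ascii")

def decode_image_field(field: str) -> bytes:
    # Server side: recover the raw bytes, as base64.b64decode does in the handler.
    return base64.b64decode(field)
```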

์ดํ›„, ๋ฐ›์€ ์ด๋ฏธ์ง€๋ฅผ ์ด์ „๊ณผ ๊ฐ™์ด cropํ•˜๊ณ  resizing ํ•˜๋Š” ์ž‘์—…์„ ์ˆ˜ํ–‰ํ–ˆ๋‹ค.

def crop_face(image):
    face_cascade = cv2.CascadeClassifier("./haarcascade_frontalface_default.xml")
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    faces = face_cascade.detectMultiScale(gray, 1.1, 5)

    for x, y, w, h in faces:
        cropped_image = image[y : y + h, x : x + w]
        if cropped_image.shape[0] > 0 and cropped_image.shape[1] > 0:  # make sure the crop is non-empty before resizing
            return cv2.resize(cropped_image, (224, 224))
    return None  # return None instead of a 404

์ด๋•Œ, ์ด๋ฏธ์ง€์—์„œ ์–ผ๊ตด์ด ๊ฒ€์ถœ๋˜์ง€ ์•Š๋Š” ๊ฒฝ์šฐ, None์„ ๋ฐ˜ํ™˜ํ–ˆ๋‹ค.

3. Finding the most similar celebrity embedding

์–ผ๊ตด์ด ๊ฒ€์ถœ๋œ ๊ฒฝ์šฐ, pickle ํŒŒ์ผ์—์„œ ๊ฐ€์žฅ ์œ ์‚ฌํ•œ ์—ฐ์˜ˆ์ธ ์ž„๋ฒ ๋”ฉ์„ ์ฐพ์•„์ฃผ๋Š” ์ฝ”๋“œ๋ฅผ ์ž‘์„ฑํ–ˆ๋‹ค.

# Method of CustomArcFaceModel (called as model.find_most_similar_celebrity below).
def find_most_similar_celebrity(self, user_face_embedding, celebrity_face_embeddings):
    max_similarity = -1
    most_similar_celebrity_index = -1

    for i, celebrity_embedding in enumerate(celebrity_face_embeddings):
        similarity = self.cosine_similarity(user_face_embedding, celebrity_embedding)
        if similarity > max_similarity:
            max_similarity = similarity
            most_similar_celebrity_index = i

    return most_similar_celebrity_index, max_similarity

def predict_celebrity(image):
    with torch.no_grad():
        cropped_image = crop_face(image)
        if cropped_image is None:  # no face was detected
            return [None, 0]  # use None for the initial and 0 for the similarity
        else:
            image = preprocess(image=cropped_image)["image"]
            image = image.float() / 255.0
            image = image.unsqueeze(0).to(device)
            user_face_embedding = model(image).squeeze()

            closest_celebrity, max_similarity = model.find_most_similar_celebrity(
                user_face_embedding, trained_embeddings
            )
            return [closest_celebrity, max_similarity.item()]

์ตœ์ข…์ ์œผ๋กœ, ๊ฐ€์žฅ ์œ ์‚ฌํ•œ celebrity initial(์ด๋‹ˆ์…œ)๊ณผ accracy(์œ ์‚ฌ๋„)๋ฅผ json ํ˜•์‹์œผ๋กœ ๋ฐ˜ํ™˜ํ–ˆ๋‹ค.

@app.route("/flask/predict", methods=["POST"])
@cross_origin("*")
def predict():
    try:
        image_data = request.form.get("image")
        image_decoded = base64.b64decode(image_data)
        nparr = np.frombuffer(image_decoded, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        prediction = predict_celebrity(image)
        celebrity_initial = get_initial(prediction[0])
        print("initial:", celebrity_initial, "accuracy:", prediction[1])
        return jsonify(
            {"celebrity_initial": celebrity_initial, "accuracy": prediction[1]}
        )
    except Exception as e:
        print(e)
        return jsonify({"error": "Error occurred during prediction"}), 500
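
The try/except shape of this handler can be exercised without the model by stubbing the prediction out. A sketch using Flask's test client (the route and response keys come from the code above; the stubbed result and validation logic are assumptions):

```python
import base64
from flask import Flask, request, jsonify

app = Flask(__name__)

@app.route("/flask/predict", methods=["POST"])
def predict_stub():
    try:
        image_data = request.form.get("image")
        if image_data is None:
            raise ValueError("missing 'image' form field")
        base64.b64decode(image_data, validate=True)  # decode as the real handler does
        # The real handler would run predict_celebrity here; return a fixed result instead.
        return jsonify({"celebrity_initial": "shg", "accuracy": 0.9})
    except Exception as e:
        print(e)
        return jsonify({"error": "Error occurred during prediction"}), 500
```

With app.test_client(), a well-formed request hits the 200 path and a request with no "image" field falls through to the 500 branch.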

4. ๋ฐฐํฌ

๋ฐฐํฌ๋Š” ๋ฐฑ์—”๋“œ๋ฅผ ๋งก๋Š” ๋™๋ฃŒ๊ฐ€ ๋„์™€์ฃผ์—ˆ๊ณ , ๋‚˜๋Š” docker ํŒŒ์ผ์„ ๋งŒ๋“ค์–ด buildํ•˜๋Š” ์ž‘์—…๊นŒ์ง€๋งŒ ์ˆ˜ํ–‰ํ–ˆ๋‹ค.

The Dockerfile I wrote is as follows:

FROM python:3.10.9

COPY . /deep-look-ai
WORKDIR /deep-look-ai

RUN python3 -m pip install --upgrade pip
RUN pip3 install -r requirements.txt
RUN apt-get update && apt-get install -y libgl1-mesa-glx
CMD ["python3", "-m", "flask_app", "run", "--host=0.0.0.0", "--port=6000"]
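
Building and running the container locally would look roughly like this (the image tag is an assumption; the port matches the CMD above):

```shell
# Build the image from the Dockerfile in the project root.
docker build -t deep-look-ai .

# Run it, exposing the Flask port used in the CMD.
docker run -p 6000:6000 deep-look-ai
```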

์ด๋ ‡๊ฒŒ, ์„ฑ๊ณต์ ์œผ๋กœ ๋ฐฑ์—”๋“œ์™€ ์—ฐ๊ฒฐํ•  ์ˆ˜ ์žˆ๊ฒŒ ๋˜์—ˆ๋‹ค.

์ด ๊ณผ์ •์—์„œ ๊ฐœ์„ ํ•  ๋ฌธ์ œ์ ์€, POST api์˜ return ๊ฐ’์œผ๋กœ ์˜ˆ์ธก ๊ฒฐ๊ณผ๊ฐ’์„ ์ฃผ์—ˆ๋‹ค๋Š” ์ ์ด๋‹ค.

์ด ๋ถ€๋ถ„์„ GET api๋กœ ๋ฐ”๊พธ์–ด ์ „๋‹ฌํ•ด ์ฃผ์—ˆ๋‹ค๋ฉด, HTTP ํ”„๋กœํ† ์ฝœ ๊ทœ์น™์„ ์ข€ ๋” ์ž˜ ์ง€ํ‚ฌ ์ˆ˜ ์žˆ์—ˆ์„ ๊ฒƒ ๊ฐ™๋‹ค.

๋‹ค์Œ์€ ์ดํ›„ ์ƒํ™ฉ์— ๋Œ€ํ•œ ์—๋Ÿฌ ํ•ธ๋“ค๋ง์— ๋Œ€ํ•ด ๋‹ค๋ค„๋ณผ ์˜ˆ์ •์ด๋‹ค.