์ด์ ๊น์ง๋ ํ๋ก ํธ ์ชฝ ์์ ๋ง ๋งก์์ ํ๊ณ , ๋ฐฑ์๋๋ node js๋ก ํ๋ฒ ๋ง๋ณธ ์ํ์๋๋ฐ, ๊ฐ์์ค๋ ์น๊ณผ ๋ฐฑ์ ์ฐ๊ฒฐํด์ผ ํ๋ ํ์คํฌ๊ฐ ์ฃผ์ด์ก๋ค.
์๊ฐ์ด ์๊ณ , ๋ชจ๋ธ์ด ๊ท๋ชจ๊ฐ ์๋ ๋ชจ๋ธ์ ์๋์ด์ flask๋ฅผ ์ฌ์ฉํ๊ธฐ๋ก ๊ฒฐ์ ํ๋ค.
๋ ์์ธํ ์ฝ๋๋ฅผ ๋ณด๊ณ ์ถ์ผ๋ฉด ๊นํ๋ธ๋ฅผ ์ฐธ๊ณ ํ๊ธธ ๋ฐ๋๋ค. (์ข ๋์กํ ์ ์์ ์ฃผ์)
๋ฐฑ์๋ ๋ชจ๋ธ๊ณผ ์ฐ๊ฒฐ ์ ๊ฑฐ์น ๋จ๊ณ๋ ๋ค์๊ณผ ๊ฐ๋ค:
- pickle ํ์ผ์ ๋ง๋ค์ด ๋ชจ๋ ์ฐ์์ธ์ ์ผ๊ตด ์๋ฒ ๋ฉ์ ์ ์ฅํ๋ค.
- client ์ธก์์ post ์์ฒญ์ ๋ณด๋์ ๋ ์ด๋ฏธ์ง ๋ฐ์ดํฐ๋ฅผ ๋ฐ์ ์๋ฒ ๋ฉ์ ์ถ์ถํ๋ค.
- ์ถ์ถํ ์๋ฒ ๋ฉ๊ณผ ๊ฐ์ฅ ์ ์ฌํ ์๋ฒ ๋ฉ์ pickle ํ์ผ๋ก๋ถํฐ ๋ก๋ฉํ๋ค.
- ์ถ์ถ๋ ์ผ๊ตด ์๋ฒ ๋ฉ๊ณผ ์ฌ์ ํ์ต๋ ์๋ฒ ๋ฉ ๊ฐ์ ์ ์ฌ๋๋ฅผ ๊ณ์ฐํ๋ค.
- ๊ฐ์ฅ ์ ์ฌํ ์ ๋ช ์ธ์ ์ด๋์ ๊ณผ ์ ์ฌ๋๋ฅผ JSON ํ์์ผ๋ก ๋ฐํํ๋ค.
๊ฐ ๋จ๊ณ์ ๊ดํ ์์ธํ ์ค๋ช ์ ์๋์ ์์ฑํด๋ณธ๋ค.
1. pickle๋ก ์ฐ์์ธ๋ค์ ๋ชจ๋ ์ผ๊ตด ์๋ฒ ๋ฉ ์ ์ฅ
ํด๋น ์์ ์ ์ํํ๊ธฐ ์ํด ์์ฑํ ์ฝ๋๋ ๋ค์๊ณผ ๊ฐ๋ค:
# pickle ํ์ผ์ ์ด๋ฏธ์ง๋ฅผ ์ ์ฅํ๊ธฐ ์ํด ๋ง๋ค์์ต๋๋ค.
import os
import glob
import cv2
import torch
import pickle
import numpy as np
from arcface_model import CustomArcFaceModel
from albumentations import Compose, Resize
from albumentations.pytorch import ToTensorV2
def load_image(image_path):
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
return image
def preprocess_image(image):
if isinstance(image, np.ndarray):
preprocess = Compose([
Resize(224, 224),
ToTensorV2()
])
image = preprocess(image=image)['image']
image = image.float() / 255.0
return image.unsqueeze(0)
elif isinstance(image, torch.Tensor):
return image.float() / 255.0
else:
raise ValueError("Unsupported image format.")
num_classes = 11 # ๋ถ๋ฅํ ํด๋์ค์ ์
device = torch.device('cpu')
model = CustomArcFaceModel(num_classes)
model.load_state_dict(torch.load('arcface.pth', map_location=torch.device('cpu')))
model.eval()
celebrity_image_dict = {}
celebrity_initial_list = ['shg', 'idh', 'she', 'ijh', 'cde', 'chj', 'har', 'jjj', 'jsi', 'ojy', 'smo']
embeddings_dict = {}
for celebrity_initial in celebrity_initial_list:
image_folder = f'/Users/jang-youngjoon/dev-projects/youtuber-look-alike/pre-processed-image/{celebrity_initial}/'
image_files = glob.glob(os.path.join(image_folder, '*.jpg'))
embeddings = []
for image_file in image_files:
image = load_image(image_file)
preprocessed_image = preprocess_image(image).to(device)
with torch.no_grad():
embedding = model(preprocessed_image)
embeddings.append(embedding.squeeze().cpu().numpy())
embeddings_dict[celebrity_initial] = embeddings
with open('trained_celebrity_embeddings.pkl', 'wb') as f:
pickle.dump(embeddings_dict, f)
์ฝ๋์ ๊ดํ ์ค๋ช ์ ๋ค์๊ณผ ๊ฐ๋ค:
- load_image(image_path) ํจ์๋ ์ฃผ์ด์ง ์ด๋ฏธ์ง ๊ฒฝ๋ก์์ ์ด๋ฏธ์ง๋ฅผ ์ฝ์ด์ RGB ํ์์ผ๋ก ๋ณํ๋ค.
- preprocess_image(image) ํจ์๋ ์
๋ ฅ ์ด๋ฏธ์ง๋ฅผ ์ฌ์ ์ฒ๋ฆฌํ๋ค.
์ด๋ฏธ์ง๊ฐ NumPy ๋ฐฐ์ด์ธ ๊ฒฝ์ฐ, ํฌ๊ธฐ๋ฅผ ์ฌ์กฐ์ ํ๊ณ ํ ์๋ก ๋ณํํ ํ 0์์ 1 ์ฌ์ด์ ๊ฐ์ผ๋ก ์ ๊ทํํฉ๋๋ค. ์ด๋ฏธ์ง๊ฐ ์ด๋ฏธ ํ ์์ธ ๊ฒฝ์ฐ์๋ ์ ๊ทํ๋ง ์ํํ๋ค. - num_classes ๋ณ์๋ ๋ถ๋ฅํ ํด๋์ค์ ์๋ฅผ ์ง์ ํ๋ค.
- celebrity_image_dict ๋ณ์๋ ์๋ฒ ๋ฉ์ ์ ์ฅํ ๋์ ๋๋ฆฌ์ด๋ค.
- celebrity_initial_list ๋ณ์๋ ์๋ฒ ๋ฉ์ ์ถ์ถํ ์ฐ์์ธ์ ์ด๋์ ๋ชฉ๋ก์ด๋ค.
- embeddings_dict ๋ณ์๋ ์ฐ์์ธ์ ์ด๋ฆ์ ํค๋ก ํ๊ณ , ํด๋น ์ฐ์์ธ์ ์ด๋ฏธ์ง์ ๋ํ ์๋ฒ ๋ฉ ๋ชฉ๋ก์ ๊ฐ์ผ๋ก ๊ฐ๋ ๋์ ๋๋ฆฌ์ด๋ค.
- ์ฃผ์ด์ง celebrity_initial_list์ ๊ฐ ์์์ ๋ํด ๋ค์ ์์
์ ์ํํ๋ค:
- ํด๋น ์ ๋ช ์ธ์ ์ด๋ฏธ์ง ํด๋์์ ์ด๋ฏธ์ง ํ์ผ ๋ชฉ๋ก์ ๊ฐ์ ธ์จ๋ค.
- ๊ฐ ์ด๋ฏธ์ง์ ๋ํด ์ผ๊ตด ์๋ฒ ๋ฉ์ ์ถ์ถํ๊ณ ๋ฆฌ์คํธ์ ์ถ๊ฐํ๋ค.
- ์ถ์ถํ ์๋ฒ ๋ฉ ๋ชฉ๋ก์ embeddings_dict์ ์ ์ฅํ๋ค.
- ๋ชจ๋ ์ ๋ช ์ธ์ ๋ํ ์๋ฒ ๋ฉ์ด ์ถ์ถ๋ ํ, trained_celebrity_embeddings.pkl ํ์ผ์ embeddings_dict๋ฅผ ์ ์ฅํ๋ค.
2. Client ์ธก์ผ๋ก๋ถํฐ POST ์์ฒญ ๋ฐ์ ์ ์ฒ๋ฆฌ
์๊ฐ์ด ์์ด์ presigned-url ๊ธฐ๋ฅ์ ๊ตฌํํ์ง ๋ชปํ ๊ฒ ๊ฐ์, ๊ทธ๋ฅ client ์ธก์ผ๋ก๋ถํฐ base64๋ก encoding๋ ๊ฐ์ ๋ฐ์๋ค.
@app.route("/flask/predict", methods=["POST"])
@cross_origin("*")
def predict():
try:
image_data = request.form.get("image")
image_decoded = base64.b64decode(image_data)
nparr = np.frombuffer(image_decoded, np.uint8)
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
์ดํ, ๋ฐ์ ์ด๋ฏธ์ง๋ฅผ ์ด์ ๊ณผ ๊ฐ์ด cropํ๊ณ resizing ํ๋ ์์ ์ ์ํํ๋ค.
def crop_face(image):
face_cascade = cv2.CascadeClassifier("./haarcascade_frontalface_default.xml")
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
faces = face_cascade.detectMultiScale(gray, 1.1, 5)
for x, y, w, h in faces:
cropped_image = image[y : y + h, x : x + w]
resized_image = cv2.resize(cropped_image, (224, 224))
if resized_image.shape[0] > 0 and resized_image.shape[1] > 0: # ์ด๋ฏธ์ง๊ฐ ์กด์ฌํ๋์ง ํ์ธ
return resized_image
return None # 404 ๋์ None ๋ฐํ
์ด๋, ์ด๋ฏธ์ง์์ ์ผ๊ตด์ด ๊ฒ์ถ๋์ง ์๋ ๊ฒฝ์ฐ, None์ ๋ฐํํ๋ค.
3. ์ ์ฌํ ์ฐ์์ธ ์๋ฒ ๋ฉ ์ฐพ๊ธฐ
์ผ๊ตด์ด ๊ฒ์ถ๋ ๊ฒฝ์ฐ, pickle ํ์ผ์์ ๊ฐ์ฅ ์ ์ฌํ ์ฐ์์ธ ์๋ฒ ๋ฉ์ ์ฐพ์์ฃผ๋ ์ฝ๋๋ฅผ ์์ฑํ๋ค.
def find_most_similar_celebrity(self, user_face_embedding, celebrity_face_embeddings):
max_similarity = -1
most_similar_celebrity_index = -1
for i, celebrity_embedding in enumerate(celebrity_face_embeddings):
similarity = self.cosine_similarity(user_face_embedding, celebrity_embedding)
if similarity > max_similarity:
max_similarity = similarity
most_similar_celebrity_index = i
return most_similar_celebrity_index, max_similarity
def predict_celebrity(image):
with torch.no_grad():
cropped_image = crop_face(image)
if cropped_image is None: # None์ธ ๊ฒฝ์ฐ ์ฒดํฌ
return [None, 0] # celebrity_initial ๋ฐ ์ ํ๋๋ฅผ None, 0์ผ๋ก ์ค์
else:
image = preprocess(image=cropped_image)["image"]
image = image.float() / 255.0
image = image.unsqueeze(0).to(device)
user_face_embedding = model(image).squeeze()
closest_celebrity, max_similarity = model.find_most_similar_celebrity(
user_face_embedding, trained_embeddings
)
return [closest_celebrity, max_similarity.item()]
์ต์ข ์ ์ผ๋ก, ๊ฐ์ฅ ์ ์ฌํ celebrity initial(์ด๋์ )๊ณผ accracy(์ ์ฌ๋)๋ฅผ json ํ์์ผ๋ก ๋ฐํํ๋ค.
@app.route("/flask/predict", methods=["POST"])
@cross_origin("*")
def predict():
try:
image_data = request.form.get("image")
image_decoded = base64.b64decode(image_data)
nparr = np.frombuffer(image_decoded, np.uint8)
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
prediction = predict_celebrity(image)
celebrity_initial = get_initial(prediction[0])
print("์ด๋์
:", celebrity_initial, "์ ํ๋:", prediction[1])
return jsonify(
{"celebrity_initial": celebrity_initial, "accuracy": prediction[1]}
)
except Exception as e:
print(e)
return jsonify({"error": "Error occurred during prediction"}), 500
4. ๋ฐฐํฌ
๋ฐฐํฌ๋ ๋ฐฑ์๋๋ฅผ ๋งก๋ ๋๋ฃ๊ฐ ๋์์ฃผ์๊ณ , ๋๋ docker ํ์ผ์ ๋ง๋ค์ด buildํ๋ ์์ ๊น์ง๋ง ์ํํ๋ค.
์์ฑํ dockerfile์ ๋ค์๊ณผ ๊ฐ๋ค:
FROM python:3.10.9
COPY . /deep-look-ai
WORKDIR /deep-look-ai
RUN python3 -m pip install --upgrade pip
RUN pip3 install -r requirements.txt
RUN apt-get update
RUN apt-get -y install libgl1-mesa-glx
CMD ["python3", "-m", "flask_app", "run", "--host=0.0.0.0", "--port=6000"]
์ด๋ ๊ฒ, ์ฑ๊ณต์ ์ผ๋ก ๋ฐฑ์๋์ ์ฐ๊ฒฐํ ์ ์๊ฒ ๋์๋ค.
์ด ๊ณผ์ ์์ ๊ฐ์ ํ ๋ฌธ์ ์ ์, POST api์ return ๊ฐ์ผ๋ก ์์ธก ๊ฒฐ๊ณผ๊ฐ์ ์ฃผ์๋ค๋ ์ ์ด๋ค.
์ด ๋ถ๋ถ์ GET api๋ก ๋ฐ๊พธ์ด ์ ๋ฌํด ์ฃผ์๋ค๋ฉด, HTTP ํ๋กํ ์ฝ ๊ท์น์ ์ข ๋ ์ ์งํฌ ์ ์์์ ๊ฒ ๊ฐ๋ค.
๋ค์์ ์ดํ ์ํฉ์ ๋ํ ์๋ฌ ํธ๋ค๋ง์ ๋ํด ๋ค๋ค๋ณผ ์์ ์ด๋ค.
'๐ป ํ๋ก์ ํธ > ๐งธ TOY-PROJECTS' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[ํ ์ดํ๋ก์ ํธ-๊ณต๊ฐ ์ฑ๋ด] ํ๋ก์ ํธ ๊ฐ์ (0) | 2023.07.05 |
---|---|
[DeepLook] 6. ์ต์ข ๊ฒฐ๊ณผ๋ฌผ, ์ดํ ์๋ฌ ํธ๋ค๋ง๊ณผ ๋ง๋ฌด๋ฆฌ (2) | 2023.06.21 |
[DeepLook] 4. ๋ชจ๋ธ ์ ์ ๋ฐ ํ์ต (0) | 2023.06.21 |
[DeepLook] 3. ์ ์ฒ๋ฆฌ (haar-cascade ์๊ณ ๋ฆฌ์ฆ) (0) | 2023.06.20 |
[DeepLook] 2. AI ์์ ์ค๊ณ ๊ณผ์ / ํฌ๋กค๋ง (0) | 2023.06.20 |