深度有趣 | 20 CycleGAN性別轉換
介紹可用於實現多種非配對影象翻譯任務的CycleGAN模型,並完成性別轉換任務
原理
和pix2pix不同,CycleGAN不需要嚴格配對的圖片,只需要兩類(domain)即可,例如一個資料夾都是蘋果圖片,另一個資料夾都是橘子圖片
使用A和B兩類圖片,就可以實現A到B的翻譯和B到A的翻譯
論文官方網站上提供了詳細的例子和介紹, ofollow,noindex">junyanz.github.io/CycleGAN/ ,例如蘋果和橘子、馬和斑馬、夏天和冬天、照片和藝術作品等

以及論文的官方Github專案, github.com/junyanz/Cyc… ,使用PyTorch實現
CycleGAN由兩個生成器G和F,以及兩個判別器Dx和Dy組成

G接受真的X並輸出假的Y,即完成X到Y的翻譯;F接受真的Y並輸出假的X,即完成Y到X的翻譯;Dx接受真假X並進行判別,Dy接受真假Y並進行判別
CycleGAN的損失函式和標準GAN差不多,只是寫兩套而已
除此之外,為了避免mode collasp問題,CycleGAN還考慮了迴圈一致損失(Cycle Consistency Loss)
因此CycleGAN的總損失如下,G、F、Dx、Dy分別需要min、max其中的部分損失項
實現
在論文的具體實現中,使用了兩個tricks
- 使用Least-Square Loss即最小平方誤差代替標準的GAN損失
- 以G為例,維護一個歷史假Y圖片集合,例如50張。每次G生成假Y之後將其加到集合中,再從集合中隨機地取出一張假Y,和一張真Y一起輸入給判別器進行判別。這樣一來,假Y集合代表了G根據X生成Y的平均能力,使得訓練更加穩定
使用以下專案訓練CycleGAN模型, TensorFlow" rel="nofollow,noindex">github.com/vanhuyz/Cyc… ,主要包括幾個程式碼:
-
build_data.py
:將圖片資料整理為tfrecords檔案 -
ops.py
:定義了一些小的網路模組 -
generator.py
:生成器的定義 -
discriminator.py
:判別器的定義 -
model.py
:使用生成器和判別器定義CycleGAN -
train.py
:訓練模型的程式碼 -
export_graph.py
:將訓練好的模型打包成.pd
檔案 -
inference.py
:使用打包好的.pb
檔案翻譯圖片,即使用模型進行推斷
生成器和判別器結構如下,如果感興趣可以進一步閱讀專案原始碼

性別轉換
使用CelebA中的男性圖片和女性圖片,訓練一個實現性別轉換的CycleGAN
將CelebA資料集中的圖片處理成 256*256
大小,並按照性別儲存至male和female兩個資料夾,分別包含84434張男性圖片和118165張女性圖片
# -*- coding: utf-8 -*- from imageio import imread, imsave import cv2 import glob, os from tqdm import tqdm data_dir = 'data' male_dir = 'data/male' female_dir = 'data/female' if not os.path.exists(data_dir): os.mkdir(data_dir) if not os.path.exists(male_dir): os.mkdir(male_dir) if not os.path.exists(female_dir): os.mkdir(female_dir) WIDTH = 256 HEIGHT = 256 def read_process_save(read_path, save_path): image = imread(read_path) h = image.shape[0] w = image.shape[1] if h > w: image = image[h // 2 - w // 2: h // 2 + w // 2, :, :] else: image = image[:, w // 2 - h // 2: w // 2 + h // 2, :] image = cv2.resize(image, (WIDTH, HEIGHT)) imsave(save_path, image) target = 'Male' with open('list_attr_celeba.txt', 'r') as fr: lines = fr.readlines() all_tags = lines[0].strip('\n').split() for i in tqdm(range(1, len(lines))): line = lines[i].strip('\n').split() if int(line[all_tags.index(target) + 1]) == 1: read_process_save(os.path.join('celeba', line[0]), os.path.join(male_dir, line[0])) # 男 else: read_process_save(os.path.join('celeba', line[0]), os.path.join(female_dir, line[0])) # 女 複製程式碼
使用 build_data.py
將圖片轉換成tfrecords格式
python CycleGAN-TensorFlow/build_data.py --X_input_dir data/male/ --Y_input_dir data/female/ --X_output_file data/male.tfrecords --Y_output_file data/female.tfrecords 複製程式碼
使用 train.py
訓練CycleGAN模型
python CycleGAN-TensorFlow/train.py --X data/male.tfrecords --Y data/female.tfrecords --image_size 256 複製程式碼
訓練開始後,會生成checkpoints資料夾,並根據當前日期和時間生成一個子資料夾,例如 20180507-0231
,其中包括用於顯示tensorboard的 events.out.tfevents
檔案,以及和模型相關的一些檔案
使用tensorboard檢視模型訓練細節,執行以下命令後訪問6006埠即可
tensorboard --logdir=checkpoints/20180507-0231 複製程式碼
以下是迭代185870次之後,tensorboard的IMAGES頁面

模型訓練沒有迭代次數限制,所以感覺效果不錯或者迭代次數差不多了,便可以終止訓練
使用 export_graph.py
將模型打包成 .pb
檔案,生成的檔案在pretrained資料夾中
python CycleGAN-TensorFlow/export_graph.py --checkpoint_dir checkpoints/20180507-0231/ --XtoY_model male2female.pb --YtoX_model female2male.pb --image_size 256 複製程式碼
通過 inference.py
使用模型處理圖片
python CycleGAN-TensorFlow/inference.py --model pretrained/male2female.pb --input Trump.jpg --output Trump_female.jpg --image_size 256 複製程式碼
python CycleGAN-TensorFlow/inference.py --model pretrained/female2male.pb --input Hillary.jpg --output Hillary_male.jpg --image_size 256 複製程式碼
在程式碼中使用模型處理多張圖片
# -*- coding: utf-8 -*- import tensorflow as tf import numpy as np from model import CycleGAN from imageio import imread, imsave import glob import os image_file = 'face.jpg' W = 256 result = np.zeros((4 * W, 5 * W, 3)) for gender in ['male', 'female']: if gender == 'male': images = glob.glob('../faces/male/*.jpg') model = '../pretrained/male2female.pb' r = 0 else: images = glob.glob('../faces/female/*.jpg') model = '../pretrained/female2male.pb' r = 2 graph = tf.Graph() with graph.as_default(): graph_def = tf.GraphDef() with tf.gfile.FastGFile(model, 'rb') as model_file: graph_def.ParseFromString(model_file.read()) tf.import_graph_def(graph_def, name='') with tf.Session(graph=graph) as sess: input_tensor = graph.get_tensor_by_name('input_image:0') output_tensor = graph.get_tensor_by_name('output_image:0') for i, image in enumerate(images): image = imread(image) output = sess.run(output_tensor, feed_dict={input_tensor: image}) with open(image_file, 'wb') as f: f.write(output) output = imread(image_file) maxv = np.max(output) minv = np.min(output) output = ((output - minv) / (maxv - minv) * 255).astype(np.uint8) result[r * W: (r + 1) * W, i * W: (i + 1) * W, :] = image result[(r + 1) * W: (r + 2) * W, i * W: (i + 1) * W, :] = output os.remove(image_file) imsave('CycleGAN性別轉換結果.jpg', result) 複製程式碼

視訊性別轉換
對一段視訊,識別每一幀可能包含的人臉,檢測人臉對應的性別,並使用CycleGAN完成性別的雙向轉換
使用以下專案實現性別的檢測, github.com/yu4u/age-ge… ,通過Keras訓練模型,可以檢測出人臉的性別和年齡
舉個例子,使用OpenCV獲取攝像頭圖片,通過dlib檢測人臉,並得到每一個檢測結果對應的年齡和性別
# -*- coding: utf-8 -*- from wide_resnet import WideResNet import numpy as np import cv2 import dlib depth = 16 width = 8 img_size = 64 model = WideResNet(img_size, depth=depth, k=width)() model.load_weights('weights.hdf5') def draw_label(image, point, label, font=cv2.FONT_HERSHEY_SIMPLEX, font_scale=1, thickness=2): size = cv2.getTextSize(label, font, font_scale, thickness)[0] x, y = point cv2.rectangle(image, (x, y - size[1]), (x + size[0], y), (255, 0, 0), cv2.FILLED) cv2.putText(image, label, point, font, font_scale, (255, 255, 255), thickness) detector = dlib.get_frontal_face_detector() cap = cv2.VideoCapture(0) cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640) cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480) while True: ret, image_np = cap.read() image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB) img_h = image_np.shape[0] img_w = image_np.shape[1] detected = detector(image_np, 1) faces = [] if len(detected) > 0: for i, d in enumerate(detected): x0, y0, x1, y1, w, h = d.left(), d.top(), d.right(), d.bottom(), d.width(), d.height() cv2.rectangle(image_np, (x0, y0), (x1, y1), (255, 0, 0), 2) x0 = max(int(x0 - 0.25 * w), 0) y0 = max(int(y0 - 0.45 * h), 0) x1 = min(int(x1 + 0.25 * w), img_w - 1) y1 = min(int(y1 + 0.05 * h), img_h - 1) w = x1 - x0 h = y1 - y0 if w > h: x0 = x0 + w // 2 - h // 2 w = h x1 = x0 + w else: y0 = y0 + h // 2 - w // 2 h = w y1 = y0 + h faces.append(cv2.resize(image_np[y0: y1, x0: x1, :], (img_size, img_size))) faces = np.array(faces) results = model.predict(faces) predicted_genders = results[0] ages = np.arange(0, 101).reshape(101, 1) predicted_ages = results[1].dot(ages).flatten() for i, d in enumerate(detected): label = '{}, {}'.format(int(predicted_ages[i]), 'F' if predicted_genders[i][0] > 0.5 else 'M') draw_label(image_np, (d.left(), d.top()), label) cv2.imshow('gender and age', cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)) if cv2.waitKey(25) & 0xFF == ord('q'): cap.release() cv2.destroyAllWindows() break 複製程式碼
將以上專案和CycleGAN應用於視訊的雙向性別轉換,首先提取出視訊中的人臉,記錄人臉出現的幀數、位置以及對應的性別,視訊共830幀,檢測出721張人臉
# -*- coding: utf-8 -*- from wide_resnet import WideResNet import numpy as np import cv2 import dlib import pickle depth = 16 width = 8 img_size = 64 model = WideResNet(img_size, depth=depth, k=width)() model.load_weights('weights.hdf5') detector = dlib.get_frontal_face_detector() cap = cv2.VideoCapture('../friends.mp4') pos = [] frame_id = -1 while cap.isOpened(): ret, image_np = cap.read() frame_id += 1 if len((np.array(image_np)).shape) == 0: break image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB) img_h = image_np.shape[0] img_w = image_np.shape[1] detected = detector(image_np, 1) if len(detected) > 0: for d in detected: x0, y0, x1, y1, w, h = d.left(), d.top(), d.right(), d.bottom(), d.width(), d.height() x0 = max(int(x0 - 0.25 * w), 0) y0 = max(int(y0 - 0.45 * h), 0) x1 = min(int(x1 + 0.25 * w), img_w - 1) y1 = min(int(y1 + 0.05 * h), img_h - 1) w = x1 - x0 h = y1 - y0 if w > h: x0 = x0 + w // 2 - h // 2 w = h x1 = x0 + w else: y0 = y0 + h // 2 - w // 2 h = w y1 = y0 + h face = cv2.resize(image_np[y0: y1, x0: x1, :], (img_size, img_size)) result = model.predict(np.array([face])) pred_gender = result[0][0][0] if pred_gender > 0.5: pos.append([frame_id, y0, y1, x0, x1, h, w, 'F']) else: pos.append([frame_id, y0, y1, x0, x1, h, w, 'M']) print(frame_id + 1, len(pos)) with open('../pos.pkl', 'wb') as fw: pickle.dump(pos, fw) cap.release() cv2.destroyAllWindows() 複製程式碼
再使用CycleGAN,將原視訊中出現的人臉轉換成相反的性別,並寫入新的視訊檔案
# -*- coding: utf-8 -*- import tensorflow as tf import numpy as np from model import CycleGAN from imageio import imread import os import cv2 import pickle from tqdm import tqdm with open('../pos.pkl', 'rb') as fr: pos = pickle.load(fr) cap = cv2.VideoCapture('../friends.mp4') ret, image_np = cap.read() out = cv2.VideoWriter('../output.mp4', -1, cap.get(cv2.CAP_PROP_FPS), (image_np.shape[1], image_np.shape[0])) frames = [] while cap.isOpened(): ret, image_np = cap.read() if len((np.array(image_np)).shape) == 0: break frames.append(cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)) image_size = 256 image_file = 'face.jpg' for gender in ['M', 'F']: if gender == 'M': model = '../pretrained/male2female.pb' else: model = '../pretrained/female2male.pb' graph = tf.Graph() with graph.as_default(): graph_def = tf.GraphDef() with tf.gfile.FastGFile(model, 'rb') as model_file: graph_def.ParseFromString(model_file.read()) tf.import_graph_def(graph_def, name='') with tf.Session(graph=graph) as sess: input_tensor = graph.get_tensor_by_name('input_image:0') output_tensor = graph.get_tensor_by_name('output_image:0') for i in tqdm(range(len(pos))): fid, y0, y1, x0, x1, h, w, g = pos[i] if g == gender: face = cv2.resize(frames[fid - 1][y0: y1, x0: x1, :], (image_size, image_size)) output_face = sess.run(output_tensor, feed_dict={input_tensor: face}) with open(image_file, 'wb') as f: f.write(output_face) output_face = imread(image_file) maxv = np.max(output_face) minv = np.min(output_face) output_face = ((output_face - minv) / (maxv - minv) * 255).astype(np.uint8) output_face = cv2.resize(output_face, (w, h)) frames[fid - 1][y0: y1, x0: x1, :] = output_face for frame in frames: out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) os.remove(image_file) cap.release() out.release() cv2.destroyAllWindows() 複製程式碼
生成的視訊檔案只有影象、沒有聲音,可以使用 ffmpeg
進一步處理
如果沒有 ffmpeg
則下載並安裝, www.ffmpeg.org/download.ht…
進入命令列,從原始視訊中提取音訊
ffmpeg -i friends.mp4 -f mp3 -vn sound.mp3 複製程式碼
將提取的音訊和生成的視訊合成在一起
ffmpeg -i output.mp4 -i sound.mp3 combine.mp4 複製程式碼
其他
專案還提供了四個訓練好的模型, github.com/vanhuyz/Cyc… ,包括蘋果到橘子、橘子到蘋果、馬到斑馬、斑馬到馬,如果感興趣可以嘗試一下
用CycleGAN不僅可以完成兩類圖片之間的轉換,也可以實現兩個物體之間的轉換,例如將一個人翻譯成另一個人
可以考慮從一部電影中提取出兩個角色對應的圖片,訓練CycleGAN之後,即可將一個人翻譯成另一個人
還有一些比較大膽的嘗試, 提高駕駛技術:用GAN去除(愛情)動作片中的馬賽克和衣服