程式人生 > facenet:triplet-loss 理解與 train_tripletloss.py 程式碼理解


損失函式確定好之後,如何在訓練時為每個 anchor 尋找對應的 positive 樣本與 negative 樣本,就成為一個需要著重考慮的問題。理論上講,我們應當選擇 hard positive(與 anchor 同一個人、但 embedding 距離最遠的樣本)與 hard negative(與 anchor 不同人、但 embedding 距離最近的樣本),如 FaceNet 論文所述。


對於如何訓練該模型,可以參考 github 上 https://github.com/davidsandberg/facenet 的實現(tf 版本)。以下是我為 train_tripletloss.py 加上的一些註釋,註釋為個人理解,有不對的地方請大家指正;其中有些部分參考了下面這篇文章:

http://www.mamicode.com/info-detail-2096766.html

"""Training a face recognizer with TensorFlow based on the FaceNet paper
FaceNet: A Unified Embedding for Face Recognition and Clustering: http://arxiv.org/abs/1503.03832
Create the model directory if it doesn't exist os.makedirs(model_dir) # Write arguments to a text file facenet.write_arguments_to_file(args, os.path.join(log_dir, 'arguments.txt')) # Store some git revision info in a text file in the log directory src_path,_ = os.path.split(os.path.realpath(__file__)) facenet.store_revision_info(src_path, log_dir, ' '.join(sys.argv)) np.random.seed(seed=args.seed) train_set = facenet.get_dataset(args.data_dir) print('Model directory: %s' % model_dir) print('Log directory: %s' % log_dir) if args.pretrained_model: print('Pre-trained model: %s' % os.path.expanduser(args.pretrained_model)) if args.lfw_dir: print('LFW directory: %s' % args.lfw_dir) # Read the file containing the pairs used for testing pairs = lfw.read_pairs(os.path.expanduser(args.lfw_pairs)) # Get the paths for the corresponding images lfw_paths, actual_issame = lfw.get_paths(os.path.expanduser(args.lfw_dir), pairs) with tf.Graph().as_default(): tf.set_random_seed(args.seed) global_step = tf.Variable(0, trainable=False) # Placeholder for the learning rate learning_rate_placeholder = tf.placeholder(tf.float32, name='learning_rate') batch_size_placeholder = tf.placeholder(tf.int32, name='batch_size') phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train') image_paths_placeholder = tf.placeholder(tf.string, shape=(None,3), name='image_paths') labels_placeholder = tf.placeholder(tf.int64, shape=(None,3), name='labels') input_queue = data_flow_ops.FIFOQueue(capacity=100000, dtypes=[tf.string, tf.int64], shapes=[(3,), (3,)], shared_name=None, name=None) enqueue_op = input_queue.enqueue_many([image_paths_placeholder, labels_placeholder]) nrof_preprocess_threads = 4 images_and_labels = [] for _ in range(nrof_preprocess_threads): filenames, label = input_queue.dequeue() images = [] for filename in tf.unstack(filenames): file_contents = tf.read_file(filename) image = tf.image.decode_image(file_contents, channels=3) if args.random_crop: image = tf.random_crop(image, 
[args.image_size, args.image_size, 3]) else: image = tf.image.resize_image_with_crop_or_pad(image, args.image_size, args.image_size) if args.random_flip: image = tf.image.random_flip_left_right(image) #pylint: disable=no-member image.set_shape((args.image_size, args.image_size, 3)) images.append(tf.image.per_image_standardization(image)) images_and_labels.append([images, label]) image_batch, labels_batch = tf.train.batch_join( images_and_labels, batch_size=batch_size_placeholder, shapes=[(args.image_size, args.image_size, 3), ()], enqueue_many=True, capacity=4 * nrof_preprocess_threads * args.batch_size, allow_smaller_final_batch=True) image_batch = tf.identity(image_batch, 'image_batch') image_batch = tf.identity(image_batch, 'input') labels_batch = tf.identity(labels_batch, 'label_batch') # Build the inference graph prelogits, _ = network.inference(image_batch, args.keep_probability, phase_train=phase_train_placeholder, bottleneck_layer_size=args.embedding_size, weight_decay=args.weight_decay) embeddings = tf.nn.l2_normalize(prelogits, 1, 1e-10, name='embeddings') # Split embeddings into anchor, positive and negative and calculate triplet loss anchor, positive, negative = tf.unstack(tf.reshape(embeddings, [-1,3,args.embedding_size]), 3, 1) triplet_loss = facenet.triplet_loss(anchor, positive, negative, args.alpha) learning_rate = tf.train.exponential_decay(learning_rate_placeholder, global_step, args.learning_rate_decay_epochs*args.epoch_size, args.learning_rate_decay_factor, staircase=True) tf.summary.scalar('learning_rate', learning_rate) # Calculate the total losses regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) total_loss = tf.add_n([triplet_loss] + regularization_losses, name='total_loss') # Build a Graph that trains the model with one batch of examples and updates the model parameters train_op = facenet.train(total_loss, global_step, args.optimizer, learning_rate, args.moving_average_decay, tf.global_variables()) # Create a 
saver saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=3) # Build the summary operation based on the TF collection of Summaries. summary_op = tf.summary.merge_all() # Start running operations on the Graph. gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_memory_fraction) sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) # Initialize variables sess.run(tf.global_variables_initializer(), feed_dict={phase_train_placeholder:True}) sess.run(tf.local_variables_initializer(), feed_dict={phase_train_placeholder:True}) summary_writer = tf.summary.FileWriter(log_dir, sess.graph) coord = tf.train.Coordinator() tf.train.start_queue_runners(coord=coord, sess=sess) with sess.as_default(): if args.pretrained_model: print('Restoring pretrained model: %s' % args.pretrained_model) saver.restore(sess, os.path.expanduser(args.pretrained_model)) # Training and validation loop epoch = 0 while epoch < args.max_nrof_epochs: step = sess.run(global_step, feed_dict=None) epoch = step // args.epoch_size # Train for one epoch train(args, sess, train_set, epoch, image_paths_placeholder, labels_placeholder, labels_batch, batch_size_placeholder, learning_rate_placeholder, phase_train_placeholder, enqueue_op, input_queue, global_step, embeddings, total_loss, train_op, summary_op, summary_writer, args.learning_rate_schedule_file, args.embedding_size, anchor, positive, negative, triplet_loss) # Save variables and the metagraph if it doesn't exist already save_variables_and_metagraph(sess, saver, summary_writer, model_dir, subdir, step) # Evaluate on LFW if args.lfw_dir: evaluate(sess, lfw_paths, embeddings, labels_batch, image_paths_placeholder, labels_placeholder, batch_size_placeholder, learning_rate_placeholder, phase_train_placeholder, enqueue_op, actual_issame, args.batch_size, args.lfw_nrof_folds, log_dir, step, summary_writer, args.embedding_size) return model_dir def train(args, sess, dataset, epoch, image_paths_placeholder, labels_placeholder, 
labels_batch, batch_size_placeholder, learning_rate_placeholder, phase_train_placeholder, enqueue_op, input_queue, global_step, embeddings, loss, train_op, summary_op, summary_writer, learning_rate_schedule_file, embedding_size, anchor, positive, negative, triplet_loss): batch_number = 0 if args.learning_rate>0.0: lr = args.learning_rate else: lr = facenet.get_learning_rate_from_file(learning_rate_schedule_file, epoch) while batch_number < args.epoch_size: # Sample people randomly from the dataset #從dataset中隨機抽取樣本資料 image_paths, num_per_class = sample_people(dataset, args.people_per_batch, args.images_per_person) print('Running forward pass on sampled images: ', end='') start_time = time.time() #樣本個數nrof_examples=batch中的人數people_per_batch×每個人對應的照片數量images_per_person nrof_examples = args.people_per_batch * args.images_per_person labels_array = np.reshape(np.arange(nrof_examples),(-1,3)) image_paths_array = np.reshape(np.expand_dims(np.array(image_paths),1), (-1,3)) sess.run(enqueue_op, {image_paths_placeholder: image_paths_array, labels_placeholder: labels_array}) #emb_array為所有樣本對應的embedding特徵資訊 emb_array = np.zeros((nrof_examples, embedding_size)) #需要執行的batch數量nrof_batches為總樣本數除以batch_size nrof_batches = int(np.ceil(nrof_examples / args.batch_size)) for i in range(nrof_batches): batch_size = min(nrof_examples-i*args.batch_size, args.batch_size) #喂資料,得到一個batch的emb和lab資訊 emb, lab = sess.run([embeddings, labels_batch], feed_dict={batch_size_placeholder: batch_size, learning_rate_placeholder: lr, phase_train_placeholder: True}) emb_array[lab,:] = emb print('%.3f' % (time.time()-start_time)) # Select triplets based on the embeddings print('Selecting suitable triplets for training') #select_triplets用來找尋triplets triplets, nrof_random_negs, nrof_triplets = select_triplets(emb_array, num_per_class, image_paths, args.people_per_batch, args.alpha) selection_time = time.time() - start_time print('(nrof_random_negs, nrof_triplets) = (%d, %d): time=%.3f seconds' % 
(nrof_random_negs, nrof_triplets, selection_time)) # Perform training on the selected triplets #nrof_batches代表需要執行的batch數=nrof_triplets×3(圖片數量)/batch_size nrof_batches = int(np.ceil(nrof_triplets*3/args.batch_size)) #triplet_paths表示所有選中的圖片路徑 triplet_paths = list(itertools.chain(*triplets)) labels_array = np.reshape(np.arange(len(triplet_paths)),(-1,3)) triplet_paths_array = np.reshape(np.expand_dims(np.array(triplet_paths),1), (-1,3)) sess.run(enqueue_op, {image_paths_placeholder: triplet_paths_array, labels_placeholder: labels_array}) nrof_examples = len(triplet_paths) train_time = 0 i = 0 emb_array = np.zeros((nrof_examples, embedding_size)) loss_array = np.zeros((nrof_triplets,)) summary = tf.Summary() step = 0 #不斷執行batch while i < nrof_batches: start_time = time.time() batch_size = min(nrof_examples-i*args.batch_size, args.batch_size) feed_dict = {batch_size_placeholder: batch_size, learning_rate_placeholder: lr, phase_train_placeholder: True} err, _, step, emb, lab = sess.run([loss, train_op, global_step, embeddings, labels_batch], feed_dict=feed_dict) emb_array[lab,:] = emb loss_array[i] = err duration = time.time() - start_time print('Epoch: [%d][%d/%d]\tTime %.3f\tLoss %2.3f' % (epoch, batch_number+1, args.epoch_size, duration, err)) batch_number += 1 i += 1 train_time += duration summary.value.add(tag='loss', simple_value=err) # Add validation loss and accuracy to summary #pylint: disable=maybe-no-member summary.value.add(tag='time/selection', simple_value=selection_time) summary_writer.add_summary(summary, step) return step def select_triplets(embeddings, nrof_images_per_class, image_paths, people_per_batch, alpha): """ Select the triplets for training """ trip_idx = 0 #某個人的圖片的embedding在emb_arr中的開始索引 emb_start_idx = 0 num_trips = 0 triplets = [] # VGG Face: Choosing good triplets is crucial and should strike a balance between # selecting informative (i.e. challenging) examples and swamping training with examples that # are too hard. 
This is achieve by extending each pair (a, p) to a triplet (a, p, n) by sampling # the image n at random, but only between the ones that violate the triplet loss margin. The # latter is a form of hard-negative mining, but it is not as aggressive (and much cheaper) than # choosing the maximally violating example, as often done in structured output learning. #遍歷每一個人 for i in xrange(people_per_batch): #這個人有多少張圖片 nrof_images = int(nrof_images_per_class[i]) #遍歷第i個人所有照片對應的embedding特徵 for j in xrange(1,nrof_images): #a_idx表示第j張圖片在emb_arr中的位置 a_idx = emb_start_idx + j - 1 #neg_dists_sqr表示這張圖片跟其他所有圖片的歐式距離 neg_dists_sqr = np.sum(np.square(embeddings[a_idx] - embeddings),