1. 程式人生 > >圖融合之加載子圖:Tensorflow.contrib.slim與tf.train.Saver之坑


sce npe ids import print filters cpu xxx true

import tensorflow as tf
import tensorflow.contrib.slim as slim

import rawpy
import numpy as np
import tensorflow as tf
import struct
import glob
import os
from PIL import Image
import time

__sony__ = 0
__huawei__ = 1
__blackberry__ = 2

__stage_raw2raw__ = 0
__stage_raw2rgb__ = 1
__stage_overall__ = 2

# File-name prefixes of the dataset splits. They are concatenated into the
# glob patterns below, so they have to be strings.
train_prefix = '0'
valid_prefix = '1'
test_prefix = '2'

# ============ CONFIGURATION ============
USE_GPU = False
if USE_GPU:
    # environment values must be strings
    os.environ['CUDA_VISIBLE_DEVICES'] = '2'

# change this to switch between datasets
source_id = __sony__

# switch between training stages
training_stage = __stage_raw2rgb__

# patch size should be set on running
patch_size = (512, 512)
# patch_size = (2840, 4248)

# switch between training and validation
current_prefix = train_prefix

# model saving settings
max_epoch = 2000
save_epoch_delay = 1
model_dir = './result_raw2raw/'
out_dir = './output_raw2raw/'
log_dir = './log_raw2raw/'
learn_rate = 1e-2
# ============ CONFIGURATION ============

if source_id == __blackberry__:
    WHITE_LEVEL = 1023
    BLACK_LEVEL = 64
    HEIGHT = 3024
    WIDTH = 4032
elif source_id == __sony__:
    WHITE_LEVEL = 16383
    BLACK_LEVEL = 512
    HEIGHT = 2848
    WIDTH = 4256
elif source_id == __huawei__:
    WHITE_LEVEL = 1023
    BLACK_LEVEL = 64
    HEIGHT = 2976
    WIDTH = 3968

if USE_GPU:
    data_dir = '../see_in_the_dark/dataset/Sony_small/'
else:
    data_dir = 'D:/data/Sony_small/'

# !!!!!! DO NOT TOUCH THIS SETTING !!!!!!
fixed_size = (128, 128)
num_of_denoise_filter = 3
standard_brightness = 0.1
# !!!!!! DO NOT TOUCH THIS SETTING !!!!!!


def _ensure_dir(path):
    """Create directory `path` (and missing parents) unless it already exists."""
    if not os.path.exists(path):
        os.makedirs(path)


def has_nan_in_tensor(x):
    """Return True when array `x` holds a NaN (the only value where x != x)."""
    return np.sum(x != x) > 0


def raw_from_file(path):
    """Load the raw sensor plane of `path` as a float32 array of shape [h, w, 1].

    The reader and the border crop depend on the module-level `source_id`:
    Sony drops the first and last column, Huawei drops the first and last
    row, BlackBerry files are read as a flat dump of unsigned 16-bit values.

    Raises:
        AssertionError: if `source_id` is not one of the known sources.
    """
    if source_id == __sony__:
        with rawpy.imread(path) as data:
            # astype() copies, so the array stays valid after the file closes
            raw = data.raw_image_visible.astype(np.float32)
        raw = raw.reshape(HEIGHT, WIDTH)
        # convert from RGBG into standard GRGB format:
        # cut the strips of left and right borders
        h, w = raw.shape[0], raw.shape[1]
        return np.reshape(raw[:, 1:w - 1], [h, w - 2, 1])
    elif source_id == __huawei__:
        with rawpy.imread(path) as data:
            raw = data.raw_image_visible.astype(np.float32)
        raw = raw.reshape(HEIGHT, WIDTH)
        # convert from BGRG into standard GRGB format:
        # cut the strips of top and bottom borders
        h, w = raw.shape[0], raw.shape[1]
        return np.reshape(raw[1:h - 1, :], [h - 2, w, 1])
    elif source_id == __blackberry__:
        with open(path, 'rb') as f:
            data = f.read()
        data = struct.unpack('H' * (len(data) // 2), data)
        raw = np.float32(data).reshape(HEIGHT, WIDTH)
        h, w = raw.shape[0], raw.shape[1]
        return np.reshape(raw, [h, w, 1])
    # explicit raise so the check survives `python -O`
    raise AssertionError('unknown source_id: %r' % (source_id,))


def rgb_from_file(path):
    """Develop `path` with rawpy and return float32 RGB scaled into [0, 1].

    The same border strip as in `raw_from_file` is removed so that the RGB
    image stays aligned with the cropped raw plane.

    Raises:
        NameError: if the current `source_id` cannot be read through rawpy.
    """
    if source_id not in (__sony__, __huawei__):
        raise NameError('file type [%d] does not support rawpy!' % source_id)
    with rawpy.imread(path) as raw:
        rgb = np.float32(
            raw.postprocess(
                use_camera_wb=True,
                half_size=False,
                no_auto_bright=True,
                output_bps=16
            )
        ) / 65535.0
    if source_id == __sony__:
        return rgb[:, 1:-1, :]
    return rgb[1:-1, :, :]


def black_level_correction(bayer):
    """Subtract BLACK_LEVEL, rescale by the white/black span, clamp below at 0."""
    with tf.name_scope('black_level_corr'):
        r = 1.0 / (WHITE_LEVEL - BLACK_LEVEL)
        return tf.nn.relu((bayer - BLACK_LEVEL) * r)


def bound(bayer):
    """Clamp `bayer` into [0, 1]."""
    return tf.minimum(tf.maximum(bayer, 0), 1)


def bayer_to_rgb(bayer):
    """Collapse every 2x2 Bayer cell into one RGB pixel (half resolution)."""
    with tf.name_scope('bayer2rgb'):
        filters = np.array([
            [0.0, 1.0, 0.0, 0.0],  # R
            [0.5, 0.0, 0.0, 0.5],  # (G1+G2)/2
            [0.0, 0.0, 1.0, 0.0],  # B
        ]).reshape([1, 3, 2, 2]).transpose([2, 3, 0, 1])
        return tf.nn.conv2d(
            bayer, filters, strides=(1, 2, 2, 1), padding='VALID',
            name='bayer_converter'
        )


def demosaic(rgb):
    """Bilinearly resize `rgb` to the module-level `patch_size`."""
    with tf.name_scope('demosaic'):
        return tf.image.resize_bilinear(rgb, patch_size)


def color_correction(rgb, color_matrix):
    """Apply the 3x3 `color_matrix` to every pixel via a 1x1 convolution."""
    with tf.name_scope('color_corr'):
        filters = tf.reshape(color_matrix, [1, 1, 3, 3])
        return tf.nn.conv2d(rgb, filters, (1, 1, 1, 1), 'SAME', name='output')


def min_max_normalize(rgb):
    """Rescale `rgb` by its global minimum and maximum (1e-8 added for stability)."""
    _min = tf.reduce_min(rgb)
    _max = tf.reduce_max(rgb)
    return (rgb - _min + 1e-8) / (_max - _min + 1e-8)


def gaussian_norm(rgb):
    """Subtract the global mean and divide by the global standard deviation."""
    _mean = tf.reduce_mean(rgb)
    _vari = tf.sqrt(tf.reduce_mean(tf.square(rgb - _mean)))
    return (rgb - _mean) / _vari


# not supported on SNPE, so do it on cpu of mobile phone
# in case of negative value, normalize it before power operation
def gamma_correction(rgb, gamma):
    """Raise the min-max normalized `rgb` to the power `gamma`."""
    with tf.name_scope('gamma_corr'):
        return tf.pow(min_max_normalize(rgb), gamma)


def lrelu(x):
    """Leaky ReLU with slope 0.2 on the negative side."""
    return tf.maximum(x * 0.2, x)


def network_raw2raw(inputs):
    """Build the raw-to-raw network: five dilated 3x3 convs and a 1x1 output conv.

    Variable scopes are g_conv1 .. g_conv5 and g_conv_last.
    """
    with tf.name_scope('raw2raw'):
        net = inputs
        # NOTE(review): the initializer is passed as a class, not an instance;
        # presumably it gets instantiated with its default value -- confirm
        # this is the intended initialisation.
        for idx, rate in enumerate((1, 2, 4, 8, 16), start=1):
            net = slim.conv2d(net, 32, [3, 3], rate=rate, activation_fn=lrelu,
                              weights_initializer=tf.initializers.constant,
                              scope='g_conv%d' % idx)
        net = slim.conv2d(net, 1, [1, 1], rate=1, activation_fn=None,
                          scope='g_conv_last')
        return net


def show(rgb, title):
    """Display `rgb` (values scaled by 255 and cast to uint8) in an image viewer."""
    im = Image.fromarray(np.uint8(rgb * 255))
    im.show(title)


def save(rgb, path):
    """Write `rgb` (values scaled by 255 and cast to uint8) to `path`."""
    im = Image.fromarray(np.uint8(rgb * 255))
    im.save(path)


def concat(ims):
    """Concatenate the images in `ims` along axis 1."""
    return np.concatenate(ims, axis=1)


def get_color_matrix_and_gamma(bayer):
    """Build the sub-network that predicts the ISP parameters.

    Returns:
        (color_matrix, gamma): a [3, 3] tensor and a [1] tensor; gamma is
        clamped below at 1e-3.
    """
    with tf.name_scope('isp_param_gen'):
        with tf.name_scope('common_extractor'):
            channels = tf.layers.conv2d(bayer, 3, kernel_size=3, strides=2, padding='valid')
            activations = tf.nn.tanh(channels)
            channels = tf.layers.conv2d(activations, 5, kernel_size=3, strides=2, padding='valid')
            activations = tf.nn.relu(channels)

        with tf.name_scope('color_matrix'):
            channels_cm = tf.layers.conv2d(activations, 7, kernel_size=3, strides=2, padding='valid')
            activations_cm = tf.nn.tanh(channels_cm)
            channels_cm = tf.layers.conv2d(activations_cm, 5, kernel_size=3, strides=2, padding='valid')
            channels_flat_cm = tf.reshape(
                channels_cm,
                [-1, channels_cm.shape[1] * channels_cm.shape[2] * channels_cm.shape[3]])
            color_matrix = tf.reshape(tf.layers.dense(channels_flat_cm, 9), [3, 3])

        with tf.name_scope('gamma'):
            channels_gamma = tf.layers.conv2d(activations, 7, kernel_size=3, strides=2, padding='valid')
            activations_gama = tf.nn.tanh(channels_gamma)
            channels_gamma = tf.layers.conv2d(activations_gama, 5, kernel_size=3, strides=2, padding='valid')
            channels_flat_gamma = tf.reshape(
                channels_gamma,
                [-1, channels_gamma.shape[1] * channels_gamma.shape[2] * channels_gamma.shape[3]])
            gamma = tf.reshape(tf.maximum(tf.layers.dense(channels_flat_gamma, 1), 1e-3), [1])

        return color_matrix, gamma


def build_isp_process_flow(bayer, color_matrix, gamma):
    """Chain demosaic -> color correction -> gamma correction."""
    with tf.name_scope('isp_flow'):
        return gamma_correction(
            color_correction(
                demosaic(
                    bayer
                ),
                color_matrix
            ),
            gamma
        )


# in form of NHWC
def color_normalize(rgb):
    """Divide every pixel by the sum of its channels (clamped below at 1e-7)."""
    return rgb / tf.expand_dims(tf.maximum(tf.reduce_sum(rgb, axis=3), 1e-7), axis=-1)


def color_loss(rgb_out, rgb_gt):
    """Mean absolute difference between the channel-normalized images."""
    return tf.reduce_mean(tf.abs(color_normalize(rgb_out) - color_normalize(rgb_gt)))


# load images from files
gt_files = glob.glob(data_dir + '/long/' + current_prefix + '*.ARW')
in_files = [None] * len(gt_files)
train_ids = [None] * len(gt_files)
gt_raws = [None] * len(train_ids)
gt_rgbs = [None] * len(train_ids)
in_raws = [None] * len(train_ids)

# Reorganize the raw files according to their training id
for i, gt_path in enumerate(gt_files):
    # basename() picks the file name with the separator of the running OS,
    # independent of the USE_GPU switch
    train_ids[i] = os.path.basename(gt_path)[1:5]
    # for input files, multiple ones may relate to single ground truth file
    in_files[i] = glob.glob(data_dir + '/short/' + current_prefix + train_ids[i] + '*.ARW')
    in_raws[i] = [None] * len(in_files[i])


def get_gt_file_by_train_id(tid):
    """Return the ground-truth file path stored at index `tid`."""
    return gt_files[tid]


def get_in_file_by_train_id_file_id(tid, fid):
    """Return input file number `fid` of the ground truth at index `tid`."""
    return in_files[tid][fid]


def _crop_patch_pair(first, second):
    """Cut the same random `patch_size` window from both arrays, adding a batch axis."""
    h, w = first.shape[0], first.shape[1]
    y, x = np.random.randint(0, h - patch_size[0]), np.random.randint(0, w - patch_size[1])
    return (
        np.expand_dims(first[y:y + patch_size[0], x:x + patch_size[1], :], axis=0),
        np.expand_dims(second[y:y + patch_size[0], x:x + patch_size[1], :], axis=0)
    )


def get_patch_pair_raw_raw(raw_in, raw_gt):
    """Return an aligned random patch pair (input raw, ground-truth raw)."""
    return _crop_patch_pair(raw_in, raw_gt)


def get_patch_pair_raw_rgb(raw, rgb):
    """Return an aligned random patch pair (input raw, ground-truth rgb)."""
    return _crop_patch_pair(raw, rgb)


def _rand_patch_stream(gt_cache, load_gt, make_pair):
    """Yield random (input, ground truth) patch pairs forever.

    `gt_cache` is the per-id ground-truth cache list, `load_gt` reads a
    ground-truth file from disk and `make_pair` crops the patch pair.
    """
    if not train_ids:
        # an empty id list would otherwise spin forever without yielding
        raise RuntimeError('no ground truth files found under ' + data_dir)
    while True:
        for ind in np.random.permutation(len(train_ids)):
            # read the cache first so a hit never reuses a stale image
            gt = gt_cache[ind]
            if gt is None:
                # resource not found in cache, load it from disk
                gt = load_gt(get_gt_file_by_train_id(ind))
            fid = np.random.randint(0, len(in_files[ind]))
            in_raw = in_raws[ind][fid]
            if in_raw is None:
                in_raw = raw_from_file(get_in_file_by_train_id_file_id(ind, fid))
            # cache them when using GPU on linux server since memory is sufficient
            if USE_GPU:
                gt_cache[ind] = gt
                in_raws[ind][fid] = in_raw
            yield make_pair(in_raw, gt)


def get_rand_patch_from_file_raw2rgb():
    """Endless generator of (input raw patch, ground-truth rgb patch)."""
    return _rand_patch_stream(gt_rgbs, rgb_from_file, get_patch_pair_raw_rgb)


def get_rand_patch_from_file_raw2raw():
    """Endless generator of (input raw patch, ground-truth raw patch)."""
    return _rand_patch_stream(gt_raws, raw_from_file, get_patch_pair_raw_raw)


# basic nodes
t_bayer_in = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 1], name='input')
t_bayer_gt = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 1])
t_bayer_std = black_level_correction(t_bayer_in)
t_bayer_gt_std = black_level_correction(t_bayer_gt)
t_bayer_boosted = network_raw2raw(tf.minimum(300 * t_bayer_std, 1.0))
t_half_rgb = bayer_to_rgb(t_bayer_std)
t_half_rgb_boosted = bayer_to_rgb(bound(t_bayer_boosted))
t_half_rgb_gt = bayer_to_rgb(t_bayer_gt_std)
t_half_rgb_resized = tf.image.resize_bilinear(t_half_rgb, fixed_size)
t_rgb_gt = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 3])

# ISP nodes
t_color_matrix, t_gamma = get_color_matrix_and_gamma(t_half_rgb_resized)

# training raw2raw alone
# t_err_raw = tf.reduce_mean(tf.abs(t_half_rgb_gt - t_half_rgb_boosted))
t_err_raw = tf.reduce_mean(tf.abs(gaussian_norm(t_half_rgb_boosted) - gaussian_norm(t_half_rgb_gt)))

# training raw2rgb alone
t_half_rgb_freeze = tf.stop_gradient(t_half_rgb_boosted)
t_rgb_freeze = build_isp_process_flow(t_half_rgb_freeze, t_color_matrix, t_gamma)
# t_err_rgb = tf.reduce_mean(tf.abs(t_rgb_gt - t_rgb_freeze))
t_err_rgb = color_loss(t_rgb_freeze, t_rgb_gt) + tf.abs(t_gamma[0] - 1.0 / 2.5)
# t_err_rgb = color_loss(t_rgb_freeze, t_rgb_gt)

# training overall model
t_rgb_final = build_isp_process_flow(t_half_rgb_boosted, t_color_matrix, t_gamma)
# t_err_overall = tf.reduce_mean(tf.abs(t_rgb_gt - t_rgb_final))
t_err_overall = color_loss(t_rgb_final, t_rgb_gt)


def clean_no_grad_vars(vs, gs):
    """Drop every (variable, gradient) position whose gradient is None."""
    vs_clear = []
    gs_clear = []
    for v, g in zip(vs, gs):
        if g is not None:
            vs_clear.append(v)
            gs_clear.append(g)
    return vs_clear, gs_clear


def make_var_grad_pairs(vs, gs):
    """Return the list of (gradient, variable) tuples, one per variable."""
    return [(gs[i], vs[i]) for i in range(len(vs))]


def _after_train_step(sess, saver, counter, epoch, t_start):
    """Report speed every 100 steps and save checkpoints on schedule.

    Returns:
        (finished, t_start): `finished` is True once `counter` exceeds
        max_epoch * len(train_ids); `t_start` is the (possibly refreshed)
        timestamp for the next speed report.
    """
    if counter % 100 == 0:
        t_stop = time.perf_counter()
        print('Speed: %.6f' % ((t_stop - t_start) / 100))
        t_start = t_stop
    if counter > max_epoch * len(train_ids):
        saver.save(sess, os.path.join(model_dir, str(epoch)))
        print('Training done.')
        return True, t_start
    if counter % (len(train_ids) * save_epoch_delay) == 0:
        saver.save(sess, os.path.join(model_dir, str(epoch)))
        print('Model saved.')
    return False, t_start


def train():
    """Run the training stage selected by the module-level `training_stage`.

    Restores the latest checkpoint found in `model_dir` (if any), writes
    summaries to `log_dir` and saves a checkpoint every `save_epoch_delay`
    epochs until `max_epoch` epochs have been processed.
    """
    print('Staged training begins...')
    t_opt = tf.train.GradientDescentOptimizer(learning_rate=learn_rate)
    sess = tf.Session()
    logger = None
    try:
        t_minimizer_raw2raw = t_opt.minimize(t_err_raw)
        t_minimizer_raw2rgb = t_opt.minimize(t_err_rgb)
        t_minimizer_overall = t_opt.minimize(t_err_overall)

        # include = ['g_conv1', 'g_conv2', 'g_conv3', 'g_conv4', 'g_conv5', 'g_conv_last']
        # variables_to_restore = slim.get_variables_to_restore(include=include)
        # saver = tf.train.Saver(variables_to_restore)
        saver = tf.train.Saver(tf.global_variables())
        sess.run(tf.global_variables_initializer())

        # logger
        _ensure_dir(log_dir)
        logger = tf.summary.FileWriter(log_dir, graph=sess.graph)
        t_sum_raw = tf.summary.scalar('raw2raw_loss', t_err_raw)
        t_sum_rgb = tf.summary.scalar('raw2rgb_loss', t_err_rgb)
        t_sum_all = tf.summary.scalar('overall_loss', t_err_overall)

        # latest_checkpoint() returns None when model_dir holds no checkpoint
        latest = tf.train.latest_checkpoint(model_dir)
        if latest is None:
            _ensure_dir(model_dir)
        else:
            print('Restoring model...')
            saver.restore(sess, latest)

        # bind saver to the full graph instead of a sub-graph
        saver = tf.train.Saver(tf.global_variables())

        # first stage: raw to raw training
        if training_stage == __stage_raw2raw__:
            print('Stage I: train to map input raw into ground truth raw')
            patches = get_rand_patch_from_file_raw2raw()
            counter = 0
            t_start = time.perf_counter()
            for raw_in, raw_gt in patches:
                _, err_raw2raw, sum_raw = sess.run(
                    [t_minimizer_raw2raw, t_err_raw, t_sum_raw],
                    feed_dict={
                        t_bayer_in: raw_in,
                        t_bayer_gt: raw_gt
                    }
                )
                logger.add_summary(sum_raw, counter)
                epoch = counter // len(train_ids)
                print('Epoch# %d Counter# %d Loss= %.7f' % (epoch, counter, err_raw2raw))
                counter += 1
                finished, t_start = _after_train_step(sess, saver, counter, epoch, t_start)
                if finished:
                    break

        # second stage: raw to rgb training
        if training_stage == __stage_raw2rgb__:
            print('Stage II: train to map generated raw into ground truth rgb')
            # gradient clip
            # t_vs = tf.trainable_variables()
            # t_gs = tf.gradients(t_err_rgb, t_vs)
            # t_vs, t_gs = clean_no_grad_vars(t_vs, t_gs)
            # t_var_grad_pairs = make_var_grad_pairs(t_vs, t_gs)
            # t_minimizer_raw2rgb = t_opt.apply_gradients(t_var_grad_pairs)
            patches = get_rand_patch_from_file_raw2rgb()
            counter = 0
            t_start = time.perf_counter()
            for raw_in, rgb_gt in patches:
                _, err_raw2rgb, sum_rgb, gamma = sess.run(
                    [t_minimizer_raw2rgb, t_err_rgb, t_sum_rgb, t_gamma],
                    feed_dict={
                        t_bayer_in: raw_in,
                        t_rgb_gt: rgb_gt
                    }
                )
                # _, err_raw2rgb, grads, sum_rgb, gamma = sess.run(
                #     [t_minimizer_raw2rgb, t_err_rgb, t_gs, t_sum_rgb, t_gamma],
                #     feed_dict={
                #         t_bayer_in: raw_in,
                #         t_rgb_gt: rgb_gt
                #     }
                # )
                logger.add_summary(sum_rgb, counter)
                epoch = counter // len(train_ids)
                # gamma is fetched as a 1-element array; index it before formatting
                print('Epoch# %d Counter# %d Loss= %.7f Gamma=%.6f'
                      % (epoch, counter, err_raw2rgb, 1.0 / gamma[0]))
                # Gradient check
                # for i in range(len(grads)):
                #     if has_nan_in_tensor(grads[i]):
                #         print('Nan value found in gradient: %s!' % t_gs[i].name)
                counter += 1
                finished, t_start = _after_train_step(sess, saver, counter, epoch, t_start)
                if finished:
                    # the patch generator is endless, so leave the loop explicitly
                    break

        # third stage: overall training
        if training_stage == __stage_overall__:
            print('Stage III: train to map input raw into ground truth rgb')
            patches = get_rand_patch_from_file_raw2rgb()
            counter = 0
            t_start = time.perf_counter()
            for raw_in, rgb_gt in patches:
                _, err_overall, sum_all = sess.run(
                    [t_minimizer_overall, t_err_overall, t_sum_all],
                    feed_dict={
                        t_bayer_in: raw_in,
                        t_rgb_gt: rgb_gt
                    }
                )
                logger.add_summary(sum_all, counter)
                epoch = counter // len(train_ids)
                print('Epoch# %d Counter# %d Loss= %.7f' % (epoch, counter, err_overall))
                counter += 1
                finished, t_start = _after_train_step(sess, saver, counter, epoch, t_start)
                if finished:
                    break
    finally:
        # finalization
        if logger is not None:
            logger.close()
        sess.close()


def test_half_rgb():
    """Restore the latest checkpoint and dump 20 boosted/ground-truth comparisons.

    Each output image places the boosted half-resolution RGB next to the
    ground-truth half-resolution RGB and is written to `out_dir`.

    Raises:
        IOError: if `model_dir` holds no checkpoint to restore.
    """
    print('Testing Half RGB reconstruction...')
    sess = tf.Session()
    try:
        t_vars = tf.global_variables()
        # var_names = []
        # for v in t_vars:
        #     var_names.append(v.name)
        #     print(v.name)
        saver = tf.train.Saver(t_vars)

        latest = tf.train.latest_checkpoint(model_dir)
        if latest is None:
            raise IOError('path not found! no checkpoint in ' + model_dir)
        saver.restore(sess, latest)
        print('Model loaded.')

        _ensure_dir(out_dir)

        patches = get_rand_patch_from_file_raw2raw()
        counter = 0
        for raw_in, raw_gt in patches:
            half_rgb_boosted, half_rgb_gt = sess.run(
                [t_half_rgb_boosted, t_half_rgb_gt],
                feed_dict={
                    t_bayer_in: raw_in,
                    t_bayer_gt: raw_gt
                }
            )
            im_cmp = concat((half_rgb_boosted[0], half_rgb_gt[0]))
            # show(im_cmp, str(counter))
            save(im_cmp, (out_dir + '/HALF_%04d.jpg') % counter)
            counter += 1
            if counter >= 20:
                break
    finally:
        sess.close()


if __name__ == '__main__':
    # test_half_rgb()
    train()



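問題是:直接采用tf.train.Saver()的話,等效於saver = tf.train.Saver(tf.global_variables())。也就是說,Saver會嘗試保存/恢復當前圖中的所有變量;當checkpoint裏只包含子圖(例如只有raw2raw網絡)的權重時,restore會因為找不到其余變量而報錯,所以加載子圖時必須顯式傳入var_list。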





slim.get_variables_to_restore(include=include) 中 include 是一個 scope 名稱的 list,內部用正則從變量名的開頭進行匹配(re.match),效果相當於前綴匹配,原理是:if v.name.startswith('VAR_NAME_PREFIX'): ADD_TO_LIST(ret)

於是當你的include list中有conv2d這個變量名稱前綴時,所有的conv2d_xxx都會被自動添加到列表中;如果include裏同時還有conv2d_1,那麽conv2d_1下的變量就會被匹配兩次。而且SLIM不會進行查重檢查!!!於是你得到的var_list中將會出現重復的變量,導致創建Saver加載模型時報錯:ValueError: At least two variables have the same name: conv2d_1/bias !!!


