1. 程式人生 > tensorflow 固定部分引數訓練,只訓練部分引數

tensorflow 固定部分引數訓練,只訓練部分引數

def var_filter(var_list, last_layers=(0,)):
    """Yield the variables whose name contains one of the selected keywords.

    Each entry of ``last_layers`` is an index into the fixed keyword table
    ``['fine_tune', 'layer_11', 'layer_10', 'layer_9', 'layer_8']``. A
    variable is yielded (once, in input order) when any selected keyword is
    a substring of its ``name`` attribute.

    Args:
        var_list: Iterable of objects exposing a ``name`` string attribute.
        last_layers: Iterable of integer indices into the keyword table.
            May be any iterable, including a one-shot iterator; it is
            consumed exactly once.

    Yields:
        The elements of ``var_list`` whose name matches a selected keyword.

    Raises:
        IndexError: If an index in ``last_layers`` is outside the keyword
            table. Raised when iteration of the generator starts.
    """
    filter_keywords = ['fine_tune', 'layer_11', 'layer_10', 'layer_9', 'layer_8']
    # Resolve the keywords once, up front: this avoids repeating the lookup
    # for every variable and, more importantly, keeps a one-shot iterable
    # passed as ``last_layers`` from being exhausted after the first variable.
    keywords = [filter_keywords[layer] for layer in last_layers]
    for var in var_list:
        if any(kw in var.name for kw in keywords):
            yield var
            
def set_optimizer(self, n):
    """Create ``self.train_op`` so that only a filtered subset of variables is updated.

    The variables returned by ``tf.trainable_variables()`` are narrowed with
    ``var_filter`` using the keyword indices ``range(n)``; the resulting list
    is passed as ``var_list`` to ``self.optim.minimize`` together with
    ``self.loss`` and ``self.global_step``.

    Args:
        n: Number of leading keyword indices handed to ``var_filter``.
    """
    all_trainable = tf.trainable_variables()
    layer_indices = range(n)
    # Materialize the generator so the optimizer receives a concrete list.
    selected_vars = list(var_filter(all_trainable, last_layers=layer_indices))
    self.train_op = self.optim.minimize(
        self.loss,
        global_step=self.global_step,
        var_list=selected_vars,
    )