
TensorFlow Learning in Practice (3): Training and Prediction with TFRecord Data and the tf.estimator API

Using MNIST as an example, this post walks through training a model and running predictions with TFRecord-format data and the tf.estimator API.
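Throughout the post, the code refers to a FLAGS object (FLAGS.batch_size, FLAGS.train_dir, FLAGS.max_step) whose definition is omitted. A plausible definition, consistent with the values used later (batch size 128, 12,000 steps, checkpoints under ./train), is sketched below; the actual flag setup in the original script may differ:

import tensorflow as tf

# Assumed flag definitions; adjust to match the real training script.
tf.app.flags.DEFINE_integer('batch_size', 128, 'Mini-batch size.')
tf.app.flags.DEFINE_integer('max_step', 12000, 'Total number of training steps.')
tf.app.flags.DEFINE_string('train_dir', './train', 'Directory for checkpoints and summaries.')
FLAGS = tf.app.flags.FLAGS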


1. Data Input

def input_fn(filenames, training):
    # Read serialized examples from the TFRecord files and parse each record.
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(parse_data)

    if training:
        # Shuffle with a buffer large enough to cover the MNIST training set.
        dataset = dataset.shuffle(buffer_size=50000)
    dataset = dataset.batch(FLAGS.batch_size)
    if training:
        # Repeat indefinitely; the step count is controlled by Estimator.train().
        dataset = dataset.repeat()

    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels

The TFRecord data is parsed with the tf.data API; the implementation of parse_data is covered in the previous post, TensorFlow Learning in Practice (2): Training and Prediction with TFRecord Data and the tf.data API. The official documentation on datasets states that when combining tf.data with an Estimator, dataset.make_one_shot_iterator() must be used: "Note: Currently, one-shot iterators are the only type that is easily usable with an Estimator."
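For convenience, here is a minimal parse_data sketch consistent with the input_fn above. The feature keys (img_raw, label) and the raw-bytes encoding are assumptions; the authoritative version is in the previous post:

def parse_data(serialized_example):
    # Assumed record schema: a raw 28x28 grayscale image and an int64 label.
    features = tf.parse_single_example(
        serialized_example,
        features={
            'img_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    image = tf.reshape(image, [28, 28, 1])
    image = tf.cast(image, tf.float32) / 255.0  # scale pixels to [0, 1]
    label = tf.cast(features['label'], tf.int32)
    return image, label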

2. Model Definition

def model_fn(features, labels, mode):
    with tf.variable_scope('conv1'):
        conv1 = tf.layers.conv2d(inputs=features,
                                 filters=32,
                                 kernel_size=[5, 5],
                                 padding='same',
                                 activation=tf.nn.relu)

    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)      # output: 14*14*32

    with tf.variable_scope('conv2'):
        conv2 = tf.layers.conv2d(inputs=pool1,
                                 filters=64,
                                 kernel_size=[5, 5],
                                 padding='same',
                                 activation=tf.nn.relu)

    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)      # output: 7*7*64

    with tf.variable_scope('fc1'):
        pool2_flat = tf.reshape(pool2, [-1, 7*7*64])
        fc1 = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)
        # Dropout is only active in TRAIN mode.
        dropout1 = tf.layers.dropout(inputs=fc1, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)

    with tf.variable_scope('logits'):
        logits = tf.layers.dense(inputs=dropout1, units=10)     # the cross-entropy loss is computed from these logits

    predictions = {
        # Generate predictions (for PREDICT and EVAL mode)
        "classes": tf.argmax(input=logits, axis=1),
        # Add `softmax_tensor` to the graph. It is used for PREDICT and by the
        # `logging_hook`.
        "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions["classes"])
    tf.summary.scalar('accuracy', accuracy[1])

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_global_step()
        # train() builds the optimizer and returns the training op; it is defined
        # alongside model_fn (see the sketch below).
        train_op = train(loss, global_step)
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

    # Add evaluation metrics (for EVAL mode)
    eval_metric_ops = {"eval_accuracy": accuracy}
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

Depending on the value of mode (tf.estimator.ModeKeys), model_fn returns a different EstimatorSpec: predictions only for PREDICT, the loss and a train_op for TRAIN, and the loss plus evaluation metrics for EVAL.
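The train(loss, global_step) helper called in the TRAIN branch is not shown above. A minimal sketch, assuming a plain Adam optimizer (the original helper may use a different optimizer or learning-rate schedule):

def train(loss, global_step):
    # Hypothetical optimizer; swap in whatever the original code uses.
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
    return optimizer.minimize(loss, global_step=global_step)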

3. Model Training and Evaluation

def train():
    # Save a checkpoint every 100 steps and keep at most 5 of them.
    my_checkpoint_config = tf.estimator.RunConfig(save_checkpoints_steps=100, keep_checkpoint_max=5)

    mnist_classifier = tf.estimator.Estimator(model_fn=mnist.model_fn, model_dir=FLAGS.train_dir,
                                              config=my_checkpoint_config)
    # Optional hook that logs the softmax probabilities every 100 iterations.
    tensor_to_log = {'probabilities': 'softmax_tensor'}
    logging_hook = tf.train.LoggingTensorHook(tensors=tensor_to_log, every_n_iter=100)

    mnist_classifier.train(input_fn=lambda: mnist.input_fn(['./train_img.tfrecords'], True),
                           # hooks=[logging_hook],
                           steps=FLAGS.max_step)

    # Evaluate on the validation set once training has finished.
    eval_results = mnist_classifier.evaluate(input_fn=lambda: mnist.input_fn(['./validation_img.tfrecords'], False))
    print(eval_results)

Training output:

...
INFO:tensorflow:loss = 4.674489e-05, step = 11500 (0.549 sec)
INFO:tensorflow:Saving checkpoints for 11600 into ./train\model.ckpt.
INFO:tensorflow:global_step/sec: 181.878
INFO:tensorflow:loss = 0.0001392595, step = 11600 (0.550 sec)
INFO:tensorflow:Saving checkpoints for 11700 into ./train\model.ckpt.
INFO:tensorflow:global_step/sec: 73.6977
INFO:tensorflow:loss = 1.4009732e-05, step = 11700 (1.356 sec)
INFO:tensorflow:Saving checkpoints for 11800 into ./train\model.ckpt.
INFO:tensorflow:global_step/sec: 179.256
INFO:tensorflow:loss = 0.00017982977, step = 11800 (0.558 sec)
INFO:tensorflow:Saving checkpoints for 11900 into ./train\model.ckpt.
INFO:tensorflow:global_step/sec: 182.132
INFO:tensorflow:loss = 0.00027710196, step = 11900 (0.549 sec)
INFO:tensorflow:Saving checkpoints for 12000 into ./train\model.ckpt.
INFO:tensorflow:Loss for final step: 4.0867322e-05.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2018-08-29-13:18:34
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from ./train\model.ckpt-12000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2018-08-29-13:18:34
INFO:tensorflow:Saving dict for global step 12000: eval_accuracy = 0.9934, global_step = 12000, loss = 0.0540578
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 12000: ./train\model.ckpt-12000
{'eval_accuracy': 0.9934, 'loss': 0.0540578, 'global_step': 12000}

Process finished with exit code 0

After 12,000 training steps with a batch size of 128, evaluation on the validation set at the end of training gives an accuracy of 99.34%.
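The title also mentions prediction. With the PREDICT branch of model_fn in place, inference runs through Estimator.predict; a minimal sketch, assuming a hypothetical test_img.tfrecords file written in the same format (when input_fn returns a (features, labels) tuple, predict simply ignores the labels):

predictions = mnist_classifier.predict(
    input_fn=lambda: mnist.input_fn(['./test_img.tfrecords'], False))
for p in predictions:
    # Each element is a dict with the keys defined in model_fn.
    print(p['classes'], p['probabilities'])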

In the train() code above, evaluation runs only once, after training finishes. I wanted to evaluate on the validation set every 100 training steps to watch the accuracy, but the official documentation does not show how to do this. It does note that each additional call to tf.estimator.Estimator().train() accumulates training steps, i.e. training continues from the result of the previous call.

That is, the following two variants train to the same final result:

# Three calls of 100 steps each...
mnist_classifier.train(input_fn=lambda: mnist.input_fn(['./train_img.tfrecords'], True), steps=100)
mnist_classifier.train(input_fn=lambda: mnist.input_fn(['./train_img.tfrecords'], True), steps=100)
mnist_classifier.train(input_fn=lambda: mnist.input_fn(['./train_img.tfrecords'], True), steps=100)

# ...train the same model as one call of 300 steps.
mnist_classifier.train(input_fn=lambda: mnist.input_fn(['./train_img.tfrecords'], True), steps=300)

So I modified the training code as follows:

def train():
    my_checkpoint_config = tf.estimator.RunConfig(save_checkpoints_steps=100, keep_checkpoint_max=5)

    mnist_classifier = tf.estimator.Estimator(model_fn=mnist.model_fn, model_dir=FLAGS.train_dir,
                                              config=my_checkpoint_config)
    tensor_to_log = {'probabilities': 'softmax_tensor'}
    logging_hook = tf.train.LoggingTensorHook(tensors=tensor_to_log, every_n_iter=100)

    # Alternate: train for 100 steps, then evaluate on the validation set.
    for i in range(FLAGS.max_step // 100):
        mnist_classifier.train(input_fn=lambda: mnist.input_fn(['./train_img.tfrecords'], True),
                               # hooks=[logging_hook],
                               steps=100)

        eval_results = mnist_classifier.evaluate(input_fn=lambda: mnist.input_fn(['./validation_img.tfrecords'], False))
        print(eval_results)

Training output:

...
{'eval_accuracy': 0.9922, 'loss': 0.06816594, 'global_step': 11300}
{'eval_accuracy': 0.9922, 'loss': 0.068535455, 'global_step': 11400}
{'eval_accuracy': 0.9922, 'loss': 0.06853329, 'global_step': 11500}
{'eval_accuracy': 0.9924, 'loss': 0.06850766, 'global_step': 11600}
{'eval_accuracy': 0.9922, 'loss': 0.068637684, 'global_step': 11700}
{'eval_accuracy': 0.992, 'loss': 0.069534324, 'global_step': 11800}
{'eval_accuracy': 0.9916, 'loss': 0.07054804, 'global_step': 11900}
{'eval_accuracy': 0.9916, 'loss': 0.07171986, 'global_step': 12000}

Process finished with exit code 0

This approach is less efficient: looking at the source, every call to train() ends up creating a new session via MonitoredTrainingSession and restoring the previous checkpoint first. I have not yet found a good way to evaluate on the validation set while training is in progress.
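One built-in alternative worth noting is tf.estimator.train_and_evaluate, which interleaves training and evaluation driven by checkpoint saves. It still restores a checkpoint for each evaluation, so the per-evaluation overhead is similar, but it avoids the hand-written loop; a minimal sketch:

train_spec = tf.estimator.TrainSpec(
    input_fn=lambda: mnist.input_fn(['./train_img.tfrecords'], True),
    max_steps=FLAGS.max_step)
eval_spec = tf.estimator.EvalSpec(
    input_fn=lambda: mnist.input_fn(['./validation_img.tfrecords'], False),
    throttle_secs=60)  # evaluate at most once per minute, on the newest checkpoint
tf.estimator.train_and_evaluate(mnist_classifier, train_spec, eval_spec)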

Finally: the complete code