TensorFlow Learning in Practice (3): Training and Prediction with TFRecord Data and the tf.estimator API
Posted by 阿新, 2019-02-06
This post uses MNIST as an example to show how to train a model and run predictions with TFRecord data and the tf.estimator API.
1. Data input
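```python
import tensorflow as tf  # TensorFlow 1.x API

# FLAGS (batch_size, ...) and parse_data are defined elsewhere in the author's code;
# parse_data is covered in the previous post.
def input_fn(filenames, training):
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(parse_data)  # decode each serialized record into (image, label)
    if training:
        dataset = dataset.shuffle(buffer_size=50000)  # shuffle training data only
    dataset = dataset.batch(FLAGS.batch_size)
    if training:
        dataset = dataset.repeat()  # loop forever; train() stops after `steps`
    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels
```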
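The TFRecord data is parsed with the tf.data API; the implementation of parse_data is given in the previous post, "TensorFlow Learning in Practice (2): Training and Prediction with TFRecord Data and the tf.data API". Shuffling and repeat() are applied only when training is True, so evaluation makes a single pass over the validation file in its stored order. The official documentation (the Datasets guide) says that dataset.make_one_shot_iterator() should be used together with an Estimator: "Note: Currently, one-shot iterators are the only type that is easily usable with an Estimator."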
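The order of operations in input_fn matters. The plain-Python sketch below (no TensorFlow; all function names are hypothetical) imitates its two branches on a list of ten integers. The buffered shuffle follows the behavior the tf.data documentation describes for Dataset.shuffle (fill a buffer of buffer_size elements, emit a randomly chosen one, refill its slot), but it is only an illustration, not TensorFlow's implementation.

```python
import random
from itertools import islice


def buffered_shuffle(items, buffer_size, rng):
    """Mimic Dataset.shuffle: fill a buffer, then emit a random slot and refill it."""
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        slot = rng.randrange(buffer_size)
        yield buffer[slot]
        buffer[slot] = item
    rng.shuffle(buffer)  # input exhausted: drain what is left in random order
    yield from buffer


def batched(items, batch_size):
    """Mimic Dataset.batch: by default the last, shorter batch is kept."""
    batch = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def train_stream(data, buffer_size, batch_size, seed=0):
    """training=True branch: shuffle, batch, then repeat() forever (new shuffle per epoch)."""
    rng = random.Random(seed)
    while True:
        yield from batched(buffered_shuffle(data, buffer_size, rng), batch_size)


def eval_stream(data, batch_size):
    """training=False branch: one ordered pass, then the stream ends."""
    yield from batched(data, batch_size)


data = list(range(10))
first_epoch = list(islice(train_stream(data, buffer_size=4, batch_size=3), 4))
print([len(b) for b in first_epoch])                       # [3, 3, 3, 1]
print(sorted(x for b in first_epoch for x in b) == data)  # True
print([len(b) for b in eval_stream(data, batch_size=3)])   # [3, 3, 3, 1]
```

Because batch() comes before repeat(), a batch never mixes the end of one epoch with the start of the next, and each epoch ends with a short batch whenever the dataset size is not a multiple of batch_size.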
2. Model definition
```python
import tensorflow as tf  # TensorFlow 1.x API


def model_fn(features, labels, mode):
    # features: a batch of 28x28x1 MNIST images
    with tf.variable_scope('conv1'):
        conv1 = tf.layers.conv2d(inputs=features, filters=32, kernel_size=[5, 5],
                                 padding='same', activation=tf.nn.relu)
        pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)  # 14*14*32
    with tf.variable_scope('conv2'):
        conv2 = tf.layers.conv2d(inputs=pool1, filters=64, kernel_size=[5, 5],
                                 padding='same', activation=tf.nn.relu)
        pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)  # 7*7*64
    with tf.variable_scope('fc1'):
        pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])
        fc1 = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)
        # Dropout is active only in TRAIN mode
        dropout1 = tf.layers.dropout(inputs=fc1, rate=0.4,
                                     training=mode == tf.estimator.ModeKeys.TRAIN)
    with tf.variable_scope('logits'):
        logits = tf.layers.dense(inputs=dropout1, units=10)  # used for the cross-entropy loss

    predictions = {
        # Generate predictions (for PREDICT and EVAL mode)
        "classes": tf.argmax(input=logits, axis=1),
        # Add `softmax_tensor` to the graph. It is used for PREDICT and by the
        # `logging_hook`.
        "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
    }
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions["classes"])
    tf.summary.scalar('accuracy', accuracy[1])  # accuracy is a (value, update_op) pair

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_global_step()
        # train() is the author's optimizer helper in the same module (not shown in this post)
        train_op = train(loss, global_step)
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

    # Add evaluation metrics (for EVAL mode)
    eval_metric_ops = {"eval_accuracy": accuracy}
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
```
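The shape comments in model_fn (14*14*32, 7*7*64) follow from 'same' padding keeping the spatial size of each convolution and each 2x2, stride-2 max-pool (which uses 'valid' padding by default) halving it. The plain-Python sketch below reproduces that arithmetic, assuming 28x28x1 MNIST inputs, and also counts each layer's trainable parameters (weights plus biases); the helper names are illustrative only.

```python
import math


def conv_same(size, stride=1):
    # 'same' padding: output = ceil(input / stride)
    return math.ceil(size / stride)


def pool_valid(size, pool=2, stride=2):
    # 'valid' padding (the max_pooling2d default): floor((n - pool) / stride) + 1
    return (size - pool) // stride + 1


def mnist_cnn_summary(height=28, in_channels=1):
    layers = []
    h = pool_valid(conv_same(height))  # conv1 (5x5, 32 filters) + pool1
    layers.append(("conv1", (h, h, 32), 5 * 5 * in_channels * 32 + 32))
    h = pool_valid(conv_same(h))  # conv2 (5x5, 64 filters) + pool2
    layers.append(("conv2", (h, h, 64), 5 * 5 * 32 * 64 + 64))
    flat = h * h * 64  # length of each row after tf.reshape(pool2, [-1, 7*7*64])
    layers.append(("fc1", (1024,), flat * 1024 + 1024))
    layers.append(("logits", (10,), 1024 * 10 + 10))
    return layers


for name, shape, params in mnist_cnn_summary():
    print(f"{name:7s} output={shape} params={params:,}")
print("total params:", sum(p for _, _, p in mnist_cnn_summary()))  # 3274634
```

Almost all of the roughly 3.27 million parameters sit in fc1, the 3136-to-1024 dense layer.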
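model_fn returns a different tf.estimator.EstimatorSpec depending on mode (a tf.estimator.ModeKeys value): only predictions in PREDICT mode, loss and train_op in TRAIN mode, and loss plus eval_metric_ops in EVAL mode. Dropout is likewise enabled only in TRAIN mode.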
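As a rough mental model of what model_fn hands back in each mode, the dependency-free sketch below computes the same quantities on a toy batch of logits: argmax classes, softmax probabilities, the batch-mean sparse softmax cross-entropy (the default reduction of tf.losses.sparse_softmax_cross_entropy with unit weights), and accuracy. It then returns only the fields the matching EstimatorSpec would carry. The toy logits and the dict-based "spec" are hypothetical stand-ins, not TensorFlow objects; note also that tf.metrics.accuracy accumulates over all evaluation batches, whereas the toy computes it on one batch.

```python
import math


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def toy_model_outputs(logits, labels, mode):
    """Return what each EstimatorSpec would carry, computed on plain Python lists."""
    probabilities = [softmax(row) for row in logits]
    classes = [max(range(len(row)), key=row.__getitem__) for row in logits]  # argmax
    predictions = {"classes": classes, "probabilities": probabilities}
    if mode == "infer":  # the string value of tf.estimator.ModeKeys.PREDICT
        return {"predictions": predictions}
    # Mean of -log p(true label) over the batch
    loss = sum(-math.log(p[y]) for p, y in zip(probabilities, labels)) / len(labels)
    if mode == "train":  # ModeKeys.TRAIN
        return {"loss": loss, "train_op": "<one optimizer step>"}
    accuracy = sum(c == y for c, y in zip(classes, labels)) / len(labels)  # ModeKeys.EVAL
    return {"loss": loss, "eval_metric_ops": {"eval_accuracy": accuracy}}


logits = [[2.0, 0.5, 0.1], [0.2, 0.3, 3.0]]  # toy batch: 2 examples, 3 classes
labels = [0, 1]
print(toy_model_outputs(logits, labels, "infer")["predictions"]["classes"])  # [0, 2]
print(toy_model_outputs(logits, labels, "eval")["eval_metric_ops"])  # {'eval_accuracy': 0.5}
```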
3. Training and validation
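```python
import tensorflow as tf  # TensorFlow 1.x API
import mnist  # the author's module that defines input_fn and model_fn

# FLAGS (train_dir, max_step, batch_size) are defined elsewhere, e.g. with tf.app.flags
def train():
    # Save a checkpoint every 100 steps and keep only the 5 most recent ones
    my_checkpoint_config = tf.estimator.RunConfig(save_checkpoints_steps=100,
                                                  keep_checkpoint_max=5)
    mnist_classifier = tf.estimator.Estimator(model_fn=mnist.model_fn,
                                              model_dir=FLAGS.train_dir,
                                              config=my_checkpoint_config)
    # Log the tensor named 'softmax_tensor' (created in model_fn) every 100 steps
    tensor_to_log = {'probabilities': 'softmax_tensor'}
    logging_hook = tf.train.LoggingTensorHook(tensors=tensor_to_log, every_n_iter=100)
    mnist_classifier.train(input_fn=lambda: mnist.input_fn(['./train_img.tfrecords'], True),
                           # hooks=[logging_hook],
                           steps=FLAGS.max_step)
    eval_results = mnist_classifier.evaluate(
        input_fn=lambda: mnist.input_fn(['./validation_img.tfrecords'], False))
    print(eval_results)
```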
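The RunConfig above saves a checkpoint every 100 steps but keeps only the five most recent ones, which is consistent with the "Saving checkpoints" line every 100 steps in the log that follows and with evaluation restoring model.ckpt-12000. The plain-Python sketch below is a simplified model of that retention policy (it only models the periodic saves and the save at the final step, and assumes the oldest checkpoint is deleted once more than keep_checkpoint_max exist); it is not TensorFlow's Saver code.

```python
from collections import deque


def surviving_checkpoints(max_step, save_every=100, keep_max=5):
    """Steps whose checkpoints remain on disk at the end of training (simplified model)."""
    kept = deque(maxlen=keep_max)  # appending beyond maxlen discards the oldest entry
    for step in range(save_every, max_step + 1, save_every):
        kept.append(step)  # periodic save every `save_every` steps
    if not kept or kept[-1] != max_step:
        kept.append(max_step)  # a final checkpoint is written when training stops
    return list(kept)


print(surviving_checkpoints(12000))  # [11600, 11700, 11800, 11900, 12000]
```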
Training output:
...
INFO:tensorflow:loss = 4.674489e-05, step = 11500 (0.549 sec)
INFO:tensorflow:Saving checkpoints for 11600 into ./train\model.ckpt.
INFO:tensorflow:global_step/sec: 181.878
INFO:tensorflow:loss = 0.0001392595, step = 11600 (0.550 sec)
INFO:tensorflow:Saving checkpoints for 11700 into ./train\model.ckpt.
INFO:tensorflow:global_step/sec: 73.6977
INFO:tensorflow:loss = 1.4009732e-05, step = 11700 (1.356 sec)
INFO:tensorflow:Saving checkpoints for 11800 into ./train\model.ckpt.
INFO:tensorflow:global_step/sec: 179.256
INFO:tensorflow:loss = 0.00017982977, step = 11800 (0.558 sec)
INFO:tensorflow:Saving checkpoints for 11900 into ./train\model.ckpt.
INFO:tensorflow:global_step/sec: 182.132
INFO:tensorflow:loss = 0.00027710196, step = 11900 (0.549 sec)
INFO:tensorflow:Saving checkpoints for 12000 into ./train\model.ckpt.
INFO:tensorflow:Loss for final step: 4.0867322e-05.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2018-08-29-13:18:34
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from ./train\model.ckpt-12000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2018-08-29-13:18:34
INFO:tensorflow:Saving dict for global step 12000: eval_accuracy = 0.9934, global_step = 12000, loss = 0.0540578
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 12000: ./train\model.ckpt-12000
{'eval_accuracy': 0.9934, 'loss': 0.0540578, 'global_step': 12000}
Process finished with exit code 0
Training ran for 12,000 steps with a batch_size of 128. Evaluating on the validation set after training finished gave an accuracy of 99.34%.
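The 12,000 steps can be related to epochs. Because input_fn batches before it repeats, each pass over the training file ends with one short batch, so one epoch is ceil(train_size / batch_size) steps. The post does not say how many images train_img.tfrecords holds; the sketch below assumes the common 55,000-image MNIST training split, which is an assumption rather than a figure from the post.

```python
import math


def epochs_trained(steps, batch_size, train_size):
    """(steps per epoch, examples seen, epochs) when batch() comes before repeat()."""
    steps_per_epoch = math.ceil(train_size / batch_size)  # last batch of each epoch is short
    full_epochs, extra_steps = divmod(steps, steps_per_epoch)
    # The leftover steps never reach the short final batch, so they are all full batches
    examples_seen = full_epochs * train_size + extra_steps * batch_size
    return steps_per_epoch, examples_seen, steps / steps_per_epoch


# train_size=55000 is an assumed value (the usual MNIST training split)
per_epoch, seen, epochs = epochs_trained(12000, 128, train_size=55000)
print(per_epoch, seen, round(epochs, 2))  # 430 1534920 27.91
```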
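Here validation runs only after training has finished. I wanted to evaluate on the validation set every 100 steps to watch the accuracy, but I couldn't find how to do that in the official documentation. The documentation does say that training steps accumulate across repeated calls to tf.estimator.Estimator.train(): each call picks up from the result of the previous one and continues training.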
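In other words, the following two versions train the same number of steps (both end at global step 300 when starting from an empty model_dir). They are not bit-for-bit identical, though: every train() call invokes input_fn again, so the shuffled training stream starts over with each call.

```python
mnist_classifier.train(input_fn=lambda: mnist.input_fn(['./train_img.tfrecords'], True), steps=100)
mnist_classifier.train(input_fn=lambda: mnist.input_fn(['./train_img.tfrecords'], True), steps=100)
mnist_classifier.train(input_fn=lambda: mnist.input_fn(['./train_img.tfrecords'], True), steps=100)
```

and

```python
mnist_classifier.train(input_fn=lambda: mnist.input_fn(['./train_img.tfrecords'], True), steps=300)
```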
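The Estimator.train documentation distinguishes two stopping arguments: steps is incremental (two calls with steps=100 train 200 steps in total), while max_steps is an absolute target for the global step (a second call with max_steps=100 does nothing). The toy class below models only that global-step bookkeeping, with the step standing in for the value restored from the latest checkpoint in model_dir; FakeEstimator is a hypothetical stand-in, not the real API.

```python
class FakeEstimator:
    """Tracks global_step the way successive Estimator.train() calls do (toy model)."""

    def __init__(self):
        self.global_step = 0  # in TensorFlow this is restored from the latest checkpoint

    def train(self, steps=None, max_steps=None):
        # Simplification: the toy requires exactly one argument. The real train() also
        # accepts neither, and then runs until input_fn is exhausted.
        if (steps is None) == (max_steps is None):
            raise ValueError("pass exactly one of steps / max_steps")
        target = self.global_step + steps if steps is not None else max_steps
        while self.global_step < target:
            self.global_step += 1  # one optimizer update on one batch
        return self


a = FakeEstimator()
for _ in range(3):
    a.train(steps=100)
b = FakeEstimator().train(steps=300)
print(a.global_step, b.global_step)  # 300 300
print(FakeEstimator().train(max_steps=100).train(max_steps=100).global_step)  # 100
```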
So I changed the training code as follows:
```python
# Imports and FLAGS as in the previous script
def train():
    my_checkpoint_config = tf.estimator.RunConfig(save_checkpoints_steps=100, keep_checkpoint_max=5)
    mnist_classifier = tf.estimator.Estimator(model_fn=mnist.model_fn, model_dir=FLAGS.train_dir,
                                              config=my_checkpoint_config)
    tensor_to_log = {'probabilities': 'softmax_tensor'}
    logging_hook = tf.train.LoggingTensorHook(tensors=tensor_to_log, every_n_iter=100)
    # Alternate: train for 100 steps, then evaluate on the validation set
    for i in range(FLAGS.max_step // 100):
        mnist_classifier.train(input_fn=lambda: mnist.input_fn(['./train_img.tfrecords'], True),
                               # hooks=[logging_hook],
                               steps=100)
        eval_results = mnist_classifier.evaluate(
            input_fn=lambda: mnist.input_fn(['./validation_img.tfrecords'], False))
        print(eval_results)
```
Training output:
...
{'eval_accuracy': 0.9922, 'loss': 0.06816594, 'global_step': 11300}
{'eval_accuracy': 0.9922, 'loss': 0.068535455, 'global_step': 11400}
{'eval_accuracy': 0.9922, 'loss': 0.06853329, 'global_step': 11500}
{'eval_accuracy': 0.9924, 'loss': 0.06850766, 'global_step': 11600}
{'eval_accuracy': 0.9922, 'loss': 0.068637684, 'global_step': 11700}
{'eval_accuracy': 0.992, 'loss': 0.069534324, 'global_step': 11800}
{'eval_accuracy': 0.9916, 'loss': 0.07054804, 'global_step': 11900}
{'eval_accuracy': 0.9916, 'loss': 0.07171986, 'global_step': 12000}
Process finished with exit code 0
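This approach is less efficient. Reading the source shows that every train() call eventually creates a new MonitoredTrainingSession for training and restores the previous training result from the checkpoint, and each evaluate() call likewise builds a fresh graph and restores the latest checkpoint (see the "Restoring parameters" line in the evaluation log). So far I haven't found a good way to evaluate on the validation set during training.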
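A back-of-the-envelope count shows where the cost comes from. With max_step = 12000 and 100-step chunks, the loop makes 120 train() calls and 120 evaluate() calls, each paying the graph-build and checkpoint-restore cost, versus one of each in the original script. The log above reports roughly 180 global_step/sec, so a 100-step chunk is only about half a second of actual training. The post does not measure the per-call overhead, so the sketch below leaves it as a parameter; the 1.0 second used in the example is a made-up placeholder, and time spent on the evaluation passes themselves is not included.

```python
def chunked_loop_cost(max_step, chunk, steps_per_sec, overhead_per_call_sec):
    """Rough cost model of the train/evaluate loop; evaluation compute is not included."""
    n_chunks = max_step // chunk                      # number of train() and evaluate() calls
    training_sec = max_step / steps_per_sec           # time spent on optimizer steps
    setup_sec = 2 * n_chunks * overhead_per_call_sec  # graph build + restore for each call
    return n_chunks, training_sec, setup_sec


# 180 steps/sec is read off the log above; 1.0 s per call is a hypothetical placeholder
n, train_s, setup_s = chunked_loop_cost(12000, 100, steps_per_sec=180, overhead_per_call_sec=1.0)
print(n, round(train_s, 1), setup_s)  # 120 66.7 240.0
```

With chunk=12000 (one train() followed by one evaluate(), as in the first script) the setup term drops to two calls, which is why the single-call version finishes faster.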