1. 程式人生 > >用tensorflow訓練自己的資料_3、訓練模型

用tensorflow訓練自己的資料_3、訓練模型

訓練模型的時候,維數一定要匹配,同時要了解你自己的資料的格式,和讀取的型別,一個one_hot編碼用的函式和非one_hot用的函式完全不一樣,這也是我當時一直出現問題的原因。

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 25 11:32:40 2018

@author: huangxudong
"""
import dr_alexnet
import tensorflow as tf
import read_data2

#定義網路超引數
learning_rate=0.01
train_iters=2000
batch_size=5
capacity=256
display_step=10
#讀取資料
tra_list,tra_labels,val_list,val_labels=read_data2.get_files('/home/bigvision/Desktop/DR_model',0.2)
tra_list_batch,tra_label_batch=read_data2.get_batch(tra_list,tra_labels,512,512,batch_size,capacity)
val_list_batch,val_label_batch=read_data2.get_batch(val_list,val_labels,512,512,batch_size,capacity)

#定義網路引數
n_class=6       #標記維度
dropout=0.75
skip=[]
#輸入佔位符
x=tf.placeholder(tf.float32,[None,786432])  #2800*2100*3,512*512*3
y=tf.placeholder(tf.int32,[None])
#print(y.shape)
keep_prob=tf.placeholder(tf.float32)  #dropout


''''構建模型,定義損失函式和優化器'''''
pred=dr_alexnet.alexNet(x,dropout,n_class,skip)
#定義損失函式和優化器
cost=tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,logits=pred.fc3))
optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
#評估函式,優化函式
correct_pred=tf.nn.in_top_k(pred.fc3,y,1)  #1表示列上去最大,0是行,這個地方如果是one_hot就是tf.argmax
accuracy=tf.reduce_mean(tf.cast(correct_pred,tf.float32))    #改型別


'''訓練模型'''
init=tf.global_variables_initializer()   #初始化所有變數

with tf.Session() as sess:
    sess.run(init)
    coord=tf.train.Coordinator()      
    threads= tf.train.start_queue_runners(coord=coord)    
    step=1
    #開始訓練,達到最大訓練次數
    while step*batch_size<train_iters:       
        batch_x,batch_y=tra_list_batch.eval(session=sess),tra_label_batch.eval(session=sess)
        batch_x=batch_x.reshape((batch_size,786432))
        batch_y=batch_y.T
        
        sess.run(optimizer,feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})
        if step%display_step==2:            
            #計算損失值和準確度,輸出
            loss,acc=sess.run([cost,accuracy],feed_dict={x:batch_x,y:batch_y,keep_prob:1.})
            print("Iter"+str(step*batch_size)+",Minibatch Loss="+ "{:.6f}".format(loss)+", Training Acc"+ "{:.5f}".format(acc))
        step+=1
    print("Optimization Finished!")
    coord.request_stop()     
    coord.join(threads)            #多執行緒進行batch送入

feed_dict字典讀取資料的時候不能是tensor型別,必須是list,numpy型別(還有一個忘了),所以在送入batch資料的時候加入了.eval(session.sess),當初這塊也是磨了很久。希望以後不在犯錯

本人新人,對大家有幫助的話就點贊哦