1. 程式人生 > >第三章- 【TensorFlow 入門】

第三章- 【TensorFlow 入門】

主要內容:TensorFlow的計算模型、資料模型和執行模型

1、TensorFlow計算模型—計算圖

計算圖是TensorFlow中最基本的一個概念,TensorFlow中的所有計算都會被轉化為計算圖上的節點。計算圖是用來描述TensorFlow中的計算。
計算圖概念
Tensor: 張量,張量可以被簡單理解為多維陣列。
Flow: 中文表示為“流”,直觀地表達了張量之間通過計算相互轉化的過程。

TensorFlow中的每一個計算都是計算圖上的一個節點,而節點之間的邊描述了計算之間的依賴關係。
依賴關係:一個運算的輸入依賴於另外一個運算的輸出,那麼這兩個運算有依賴關係。

1.1計算圖的使用

TensorFlow程式分為兩個階段:第一個階段需要定義計算圖中所有的計算,第二階段為執行計算
計算定義示例如下:

import tensorflow as tf
a=tf.constant([1.0,2.0],name='a')
b=tf.constant([2.0,3.0],name='b')
result = a + b 

TensorFlow程式中,系統會自動維護一個預設的計算圖,通過tf.get_default_graph函式可以獲取當前預設的計算圖,可以通過a.graph可以檢視張量所屬的計算圖,如果沒有特意指定,所以這個計算圖應該等於當前預設的計算圖。print(a.graph is tf.get_default_grapha())


除了預設的計算圖,TensorFlow支援通過tf.Graph函式來生成新的計算圖。不同的計算圖上的張量和運算都不會共享。以下程式碼示意瞭如何在不同計算圖上定義和使用變數。

import tensorflow as tf
g1=tf.Graph()
with g1.as_default():
   #在計算圖g1中定義變數"v",並設定初始值為0。
	v=tf.get_variable(“v”,shape=[1],initializer=tf.zeros_initializer)
	
g2=tf.Graph()
with g2.as_default():
   #在計算圖g2中定義變數"v",並設定初始值為1。
	v=tf.get_variable(“v”,shape=[1],initializer=tf.ones_initializer)

#在計算圖g1中讀取變數“v”的取值。
with tf.Session(graph=g1) as sess:
     tf.global_variables_initializer().run()
     with tf.variable_scope("",reuse=True):
     		#在計算圖g1中,變數“v”的取值應該為0,所以下面這行會輸出[0.]。
     		print(sess.run(tf.get_variable("v")))

#在計算圖g2中讀取變數“v”的取值。
with tf.Session(graph=g2) as sess:
     tf.global_variables_initializer().run()
     with tf.variable_scope("",reuse=True):
     		#在計算圖g2中,變數“v”的取值應該為1,所以下面這行會輸出[1.]。
     		print(sess.run(tf.get_variable("v")))

TensorFlow中的計算圖不僅僅可以隔離張量和計算,它還提供了管理張量和計算的機制。計算圖可以通過tf.Graph.device函式來指定執行計算的裝置。這為TensorFlow使用GPU提供了機制。以下程式可以將加法計算跑在GPU上。

g = tf.Graph()
#指定計算執行的裝置
with g.device('/gpu:0'):
	result = a + b

在一個計算圖中,可以通過集合(collection)來管理不同類別的資源。
比如通過tf.add_to_collection函式可以將資源加入一個或多個集合中,然後通過tf.get_collection獲取一個集合裡面的所有資源。這裡的資源可以是張量、變數或者執行TensorFlow程式所需要的佇列資源,等等。

2、TensorFlow數學模型—張量

在TensorFlow程式中,所有的資料都通過張量的形式來表示,從功能的角度上看,張量可以被簡單理解為多維陣列。張量是TensorFlow管理資料的形式。

  • 零階張量: 標量,也就是一個數;
  • 一階張量: 向量,也就是一維陣列;
  • n階張量: n維陣列。

張量中並沒有真正儲存數字,它儲存是如何得到這些數字的計算過程。

import tensorflow as tf
# tf.constant是一個計算,這個計算的結果為一個張量,就儲存在變數a中。
a = tf.constant([1.0,2.0],name = 'a')
b = tf.constant([2.0,3.0],name = 'b')
result = tf.add(a,b,name='add')
print(result)
"""
輸出:
Tensor("add:0",shape=(2,),dtype=float32)
""" 

張量中主要儲存了三個屬性:名字(name)、維度(shape)和型別(type)

張量的名字不僅是一個張量的唯一識別符號,它同樣也給出了這個張量是如何計算出來的。計算圖的每一個節點代表了一個計算,計算的結果就儲存在張量之中。張量和計算圖上節點所代表的計算結果是對應的。這樣張量的命名就可以通過“node:src_output”的形式來給出。其中node為節點的名稱,src_output表示當前張量來自節點的第幾個輸出。比如上面程式碼打出來的"add:0"就說明result這個張量是計算節點“add”輸出的第一個結果(編號從0開始)。
第二個屬性是張量的維度(shape)。這個屬性描述了一個張量的維度資訊。比如上面樣例中shape=(2,)說明了張量result是一個一維陣列,這個陣列的長度為2。
第三個屬性是型別(type),每一個張量會有一個唯一的型別。

TensorFlow支援14種不同的型別
實數(tf.float32、tf.float64)、整數(tf.int8、tf.int16、tf.int32、tf.int64、tf.uint8)、布林型(tf.bool)和複數(tf.complex64、tf.complex128)。

2.1張量的使用

張量的使用主要包括兩大類

  • 中間結果的引用。
a = tf.constant([1.0,2.0],name='a')
b = tf.constant([3.0,4.0],name='b')
result = a + b
print(a)
#輸出結果為
Tensor("a:0", shape=(2,), dtype=float32)

其中a和b其實就是對常量生成這個運算結果的引用,這在做加法運算時就可以直接使用這兩個變數,而不需要再去生成這些常量。通過張量來引用計算的中間結果

  • 獲得計算結果,也就是得到真實的數字。
    在以上的例項中,可以建立會話,呼叫run函式來得到張量的具體數字。
tf.Session().run(result)
#輸出結果為
array([4., 6.], dtype=float32)

3.TensorFlow執行模型—會話

會話(session):執行定義好的計算。會話擁有並管理TensorFlow程式執行時的所有資源。所有計算完成後需要關閉會話來幫助系統回收資源,否則就可能出現資源洩露的問題。TensorFlow中使用會話的模式有兩種,第一種模式需要明確呼叫會話生成函式和關閉會話函式,這種模式的流程如下:

#建立一個會話
sess = tf.Session()
#使用建立好的會話來得到關心的運算的結果
sess.run(result)
#關閉會話使得本次執行中使用到的資源可以被釋放
sess.close()

缺點:當程式出現異常退出時,關閉會話的函式可能就不會被執行從而導致資源洩露。為了解決這個問題,可以採用上下文管理器來使用會話。如下所示:

#建立一個會話,並通過python的上下文管理器來管理這個會話。
with tf.Session() as sess:
	sess.run()
#不需要呼叫sess.close()函式來關閉會話。
#上下文退出時會話關閉和資源釋放也自動完成了。

TensorFlow不會生成預設的會話,而是需要手動指定。當預設的會話別指定之後,可以通過tf.Tensor.eval函式來計算一個張量的取值。以下程式碼展示了通過設定預設會話來計算張量的取值。

sess=tf.Session()
with sess.as_default():
	print(result.eval())

以下程式碼可以完成相同的功能

sess=tf.Session()
#以下兩個命令有相同的功能
print(sess.run(result))
print(result.eval(session=sess))

在互動環境下,通過設定預設會話的方式來獲取張量的取值更加方便。所以TensorFlow提供了一種在互動環境下直接構建預設會話的函式。這個函式就是tf.InteractiveSession。使用這個函式會自動生成的會話註冊為預設會話。如下:

sess = tf.InteractiveSeesion()
print(result.eval())
sess.close()