1. 程式人生 > >【AI實戰】快速掌握TensorFlow(二):計算圖、會話

【AI實戰】快速掌握TensorFlow(二):計算圖、會話


在前面的文章中,我們已經完成了AI基礎環境的搭建(見文章:Ubuntu + Anaconda + TensorFlow + GPU + PyCharm搭建AI基礎環境),以及初步瞭解了TensorFlow的特點和基本操作(見文章:快速掌握TensorFlow(一)),接下來將繼續學習掌握TensorFlow。

本文主要是學習掌握TensorFlow的計算圖、會話操作。

 

 
計算圖是TensorFlow的核心概念,使用圖(Graph)來表示計算任務,由節點和邊組成。TensorFlow由前端負責構建計算圖,後端負責執行計算圖。
為了執行圖的計算,圖必須在會話(Session)裡面啟動,會話將圖的操作分發到CPU、GPU等裝置上執行。
下面將介紹如何在TensorFlow裡面建立會話、圖以及基本操作。

1、圖(Graph)
TensorFlow Python庫已經有一個預設圖 (default graph),如果沒有建立新的計算圖,則預設情況下是在這個default graph裡面建立節點和邊。
在圖裡面新增節點非常方便。例如現在要建立這樣的計算圖,兩個張量相加,如下圖:
 
程式碼如下:

import tensorflow as tf 
a=tf.constant([1.0,2.0], name='a') 
b=tf.constant([3.0,4.0], name='b') 
result = tf.add(a,b)

現在預設圖就有了三個節點,兩個constant(),和一個add()。
為了真正使兩個張量相加並得到結果,就必須在會話裡面啟動這個圖。

2、會話(Session)
要啟動計算圖,首先要建立一個Session物件。
使用tf.Session()建立會話,呼叫run()函式執行計算圖。如果沒有傳入任何建立引數,會話構造器將啟動預設圖。如果要指定某個計算圖,則傳入計算圖引數(如g1),則建立會話方式為tf.Session(graph=g1)建立會話(Session)主要有以下三種方式:
(1)建立一個會話

#啟動預設圖
sess=tf.Session()
result_value = sess.run(result)
print(result_value)
# ==> [4.0 6.0]

# 任務完成, 關閉會話.
sess.close()

(2) 建立一個會話
Session在使用完後需要關閉以釋放資源,除了顯式呼叫close外,也可以使用“with”程式碼塊 來自動完成關閉動作。程式碼如下:

with tf.Session() as sess:
    result_value = sess.run(result)
    print(result_value)
    # ==> [4.0 6.0]

(3)建立一個預設的會話

sess=tf.Session()
with sess.as_default():
    result_value = result.eval()
    print(result_value)

當指定預設會話後,可以通過tf.Tensor.eval函式來計算一個張量的取值。

(4)建立一個互動式會話
在互動式環境下(例如IPython),使用設定預設會話的方式來獲取張量的取值更加方便,TensorFlow提供了一種在互動式環境下直接構建預設會話的函式:tf.InteractiveSession,該函式會自動將生成的會話註冊為預設會話,使用 tf.Tensor.eval()代替 Session.run(),程式碼如下:

sess= tf.InteractiveSession()
result_value = result.eval()
print(result_value)
sess.close()

3、構建多個計算圖
在TensorFlow中可以構建多個計算圖,計算圖之間的張量和運算是不會共享的,通過這種方式,可以在同個專案中構建多個網路模型,而相互之間不會受影響。
使用tf.Graph()函式構建圖,構建多個計算圖的方式如下:

# 構建計算圖g1
g1=tf.Graph()
with g1.as_default():
    # 在計算圖g1中定義變數'v',並設定初始值為0。
    v=tf.get_variable('v',initializer=tf.zeros_initializer()(shape = [1]))
    
# 構建計算圖g2
g2=tf.Graph()
with g2.as_default():
    # 在計算圖g2中定義變數'v',並設定初始值微1。
    v=tf.get_variable('v',initializer=tf.ones_initializer()(shape = [1]))

# 在計算圖g1中讀取變數'v'的取值
with tf.Session(graph=g1) as sess:
    tf.global_variables_initializer().run()
    with tf.variable_scope('',reuse=True):
        print(sess.run(tf.get_variable('v')))
        # 輸出結果[0.]

# 在計算圖g2中讀取變數'v'的取值
with tf.Session(graph=g2) as sess:
    tf.global_variables_initializer().run()
    with tf.variable_scope('',reuse=True):
        print(sess.run(tf.get_variable('v')))
        # 輸出結果[1.]。

4、指定執行裝置
如果電腦有多個GPU,可以在圖、會話中指定要執行的裝置
(1)在圖中指定執行裝置

g=tf.Graph()
# 指定計算執行的裝置。
with g.device('/gpu:0'):
    result=tf.add(a,b)

(2)在會話中指定執行裝置

with tf.Session() as sess:
  with tf.device("/gpu:0"):
    result=tf.add(a,b)

執行的裝置用字串進行標識,目前支援的裝置包括:

  • "/cpu:0": 機器的 CPU
  • "/gpu:0": 機器的第一個 GPU,如果有的話
  • "/gpu:1": 機器的第二個 GPU,以此類推

通過以上介紹,已經瞭解了圖、會話的基本操作,使用圖 (graph) 來表示計算任務,使用會話 (Session) 來執行圖。

接下來,我們將有更多講解TensorFlow的精彩內容,敬請期待!

 

推薦相關閱讀