1. 程式人生 > >tensorflow中的Session()和run()

tensorflow中的Session()和run()

創建 需要 ons 參與 鍵值對 連接 edt 分布 const

Session()方法

tensorflow的內核使用更加高效的C++作為後臺,以支撐它的密集計算。tensorflow把前臺(即python程序)與後臺程序之間的連接稱為"會話(Session)"

Session作為會話,主要功能是指定操作對象的執行環境,Session類構造函數有3個可選參數。

  • target(可選):指定連接的執行引擎,多用於分布式場景。
  • graph(可選):指定要在Session對象中參與計算的圖(graph)。
  • config(可選):輔助配置Session對象所需的參數(限制CPU或GPU使用數目,設置優化參數以及設置日誌選項等)。

run()方法
Session對象創建完畢,便可以使用它最重要的方法run()來啟動所需要的數據流圖進行計算。

run()方法有4個參數:

run(
    fetches,
    feed_dict=None
    options=None,
    run_metadata=None
)

(1).fetches參數

  • ‘取得之物‘,表示數據流圖中能接收的任意數據流圖元素,各類Op/Tensor對象。Op,run()將返回None;Tensor,rnu()將返回Numpy數組。
import tensorflow as tf
from collections import namedtuple

a = tf.constant([10, 20])
b = tf.constant([1.0, 2.0])
session = tf.Session()

v1 = session.run(a) #fetches參數為單個張量值,返回值為Numpy數組
print(v1)
v2 = session.run([a, b]) #fetches參數為python類表,包括兩個numpy的1維矩陣
print(v2)
v3 = session.run(tf.global_variables_initializer()) #fetches 為Op類型
print(v3)
session.close()
[10 20]
[array([10, 20], dtype=int32), array([ 1.,  2.], dtype=float32)]
None

(2). feed_dict參數

  • 可選項,給數據流圖提供運行時數據。feed_dict的數據結構為python中的字典,其元素為各種鍵值對。"key"為各種Tensor對象的句柄;"value"很廣泛,但必須和“鍵”的類型相匹配,或能轉換為同一類型。
import tensorflow as tf

a = tf.add(1, 2)
b = tf.multiply(a, 2)
session = tf.Session()
v1 = session.run(b)
print(v1)

replace_dict = {a:20}
v2 = session.run(b, feed_dict = replace_dict)
print(v2)
6
40

tensorflow中的Session()和run()