1. 程式人生 > >PyTorch學習:動態圖和靜態圖

PyTorch學習:動態圖和靜態圖

動態圖和靜態圖

目前神經網路框架分為靜態圖框架和動態圖框架,PyTorch 和 TensorFlow、Caffe 等框架最大的區別就是他們擁有不同的計算圖表現形式。 TensorFlow 使用靜態圖,這意味著我們先定義計算圖,然後不斷使用它,而在 PyTorch 中,每次都會重新構建一個新的計算圖。通過這次課程,我們會了解靜態圖和動態圖之間的優缺點。

對於使用者來說,兩種形式的計算圖有著非常大的區別,同時靜態圖和動態圖都有他們各自的優點,比如動態圖比較方便debug,使用者能夠用任何他們喜歡的方式進行debug,同時非常直觀,而靜態圖是通過先定義後執行的方式,之後再次執行的時候就不再需要重新構建計算圖,所以速度會比動態圖更快。

# tensorflow
import tensorflow as tf
first_counter = tf.constant(0)
second_counter = tf.constant(10)
# tensorflow
import tensorflow as tf
first_counter = tf.constant(0)
second_counter = tf.constant(10)
def cond(first_counter, second_counter, *args):
    return first_counter < second_counter
def body(first_counter, second_counter):
    first_counter = tf.add(first_counter, 2)
    second_counter = tf.add(second_counter, 1)
    return first_counter, second_counter
c1, c2 = tf.while_loop(cond, body, [first_counter, second_counter])
with tf.Session() as sess:
    counter_1_res, counter_2_res = sess.run([c1, c2])
print(counter_1_res)
print(counter_2_res)

可以看到 TensorFlow 需要將整個圖構建成靜態的,換句話說,每次執行的時候圖都是一樣的,是不能夠改變的,所以不能直接使用 Python 的 while 迴圈語句,需要使用輔助函式 tf.while_loop 寫成 TensorFlow 內部的形式

# pytorch
import torch
first_counter = torch.Tensor([0])
second_counter = torch.Tensor([10])

while (first_counter < second_counter)[0]:
    first_counter += 2
    second_counter += 1

print(first_counter)
print(second_counter)

可以看到 PyTorch 的寫法跟 Python 的寫法是完全一致的,沒有任何額外的學習成本