1. 程式人生 > >實現屬於自己的TensorFlow(1):計算圖與前向傳播

實現屬於自己的TensorFlow(1):計算圖與前向傳播

前段時間因為課題需要使用了一段時間TensorFlow,感覺這種框架很有意思,除了可以搭建複雜的神經網路,也可以優化其他自己需要的計算模型,所以一直想自己學習一下寫一個類似的圖計算框架。前幾天組會開完決定著手實現一個模仿TensorFlow介面的簡陋版本圖計算框架以學習計算圖程式的編寫以及前向傳播和反向傳播的實現。目前實現了前向傳播和反向傳播以及梯度下降優化器,並寫了個優化線性模型的例子。

雖然前向傳播反向傳播這些原理了解起來並不是很複雜,但是真正著手寫起來才發現,裡面還是有很多細節需要學習和處理才能對實際的模型進行優化(例如Loss函式對每個計算節點矩陣求導的處理)。其中SimpleFlow的程式碼並沒有考慮太多的東西比如dtype

和張量size的檢查等,因為只是為了實現主要圖計算功能並沒有考慮任何的優化, 內部張量運算使用的Numpy的介面(畢竟是學習和練手的目的嘛)。好久時間沒更新部落格了,在接下來的幾篇裡面我將把實現的過程的細節總結一下,希望可以給後面學習的童鞋做個參考。

正文

本文主要介紹計算圖以及前向傳播的實現, 主要涉及圖的構建以及通過對構建好的圖進行後序遍歷然後進行前向傳播計算得到具體節點上的輸出值。

先貼上一個簡單的實現效果吧:

Python
123456789 importsimpleflow assf# Create a graphwithsf.Graph().as_default():a=sf.constant(1.0,name='a')b=sf.constant(2.0,name='b')result=sf.add(a,b,name='result')# Create a session to computewithtf.Session()assess:print(sess.run(result))

計算圖(Computational Graph)

計算圖是計算代數中的一個基礎處理方法,我們可以通過一個有向圖來表示一個給定的數學表示式,並可以根據圖的特點快速方便對錶達式中的變數進行求導。而神經網路的本質就是一個多層複合函式, 因此也可以通過一個圖來表示其表示式。

本部分主要總結計算圖的實現,在計算圖這個有向圖中,每個節點代表著一種特定的運算例如求和,乘積,向量乘積,平方等等… 例如求和表示式$f(x,y)=x+y$使用有向圖表示為:

 

表示式$f(x,y,z)=z(x+y)$使用有向圖表示為:

與TensorFlow的實現不同,為了簡化,在SimpleFlow中我並沒有定義Tensor類來表示計算圖中節點之間的資料流動,而是直接定義節點的型別,其中主要定義了四種類型來表示圖中的節點:

  1. Operation: 操作節點主要接受一個或者兩個輸入節點然後進行簡單的操作運算,例如上圖中的加法操作和乘法操作等。
  2. Variable: 沒有輸入節點的節點,此節點包含的資料在運算過程中是可以變化的。
  3. Constant: 類似Variable節點,也沒有輸入節點,此節點中的資料在圖的運算過程中不會發生變化
  4. Placeholder: 同樣沒有輸入節點,此節點的資料是通過圖建立好以後通過使用者傳入的

其實圖中的所有節點都可以看成是某種操作,其中VariableConstantPlaceholder都是一種特殊的操作,只是相對於普通的Operation而言,他們沒有輸入,但是都會有輸出(像上圖中的xxyy節點,他們本身輸出自身的值到++節點中去),通常會輸出到Operation節點,進行進一步的計算。

下面我們主要介紹如何實現計算圖的基本元件: 節點和邊。

Operation節點

節點表示操作,邊代表節點接收和輸出的資料,操作節點需要含有以下屬性:

  1. input_nodes: 輸入節點,裡面存放與當前節點相連線的輸入節點的引用
  2. output_nodes: 輸出節點, 存放以當前節點作為輸入的節點,也就是當前節點的去向
  3. output_value: 儲存當前節點的數值, 如果是Add節點,此變數就儲存兩個輸入節點output_value的和
  4. name: 當前節點的名稱
  5. graph: 此節點所屬的圖
Python
123456789101112131415161718192021222324252627282930313233343536 classOperation(object):''' Base class for all operations in simpleflow.    An operation is a node in computational graph receiving zero or more nodes    as input and produce zero or more nodes as output. Vertices could be an    operation, variable or placeholder.    '''def__init__(self,*input_nodes,name=None):''' Operation constructor.        :param input_nodes: Input nodes for the operation node.        :type input_nodes: Objects of `Operation`, `Variable` or `Placeholder`.        :param name: The operation name.        :type name: str.        '''# Nodes received by this operation.self.input_nodes=input_nodes# Nodes that receive this operation node as input.self.output_nodes=[]# Output value of this operation in session execution.self.output_value=None# Operation name.self.name=name# Graph the operation belongs to.self.graph=DEFAULT_GRAPH# Add this operation node to destination lists in its input nodes.fornode ininput_nodes:node.output_nodes.append(self)# Add this operation to default graph.self.graph.operations.append(self)defcompute_output(self):''' Compute and return the output value of the operation.        '''raiseNotImplementedErrordefcompute_gradient(self,grad=None):''' Compute and return the gradient of the operation wrt inputs.        '''raiseNotImplementedError

在初始化方法中除了定義上面提到的屬性外,還需要進行兩個操作:

  1. 將當前節點的引用新增到他輸入節點的output_nodes這樣可以在輸入節點中找到當前節點。
  2. 將當前節點的引用新增到圖中,方便後面對圖中的資源進行回收等操作

另外,每個操作節點還有兩個必須的方法: comput_outputcompute_gradient. 他們分別負責根據輸入節點的值計算當前節點的輸出值和根據操作屬性和當前節點的值計算梯度。關於梯度的計算將在後續的文章中詳細介紹,本文只對節點輸出值的計算進行介紹。

下面我以求和操作為例來說明具體操作節點的實現:

Python
12345678910111213141516171819 classAdd(Operation):''' An addition operation.    '''def__init__(self,x,y,name=None):''' Addition constructor.        :param x: The first input node.        :type x: Object of `Operation`, `Variable` or `Placeholder`.        :param y: The second input node.        :type y: Object of `Operation`, `Variable` or `Placeholder`.        :param name: The operation name.        :type name: str.        '''super(self.__class__,self).__init__(x,y,name=name)defcompute_output(self):''' Compute and return the value of addition operation.        '''x,y=self.input_nodesself.output_value=np.add(x.output_value,y.output_value)returnself.output_value

可見,計算當前節點output_value的值的前提條件就是他的輸入節點的值在此之前已經計算得到了

Variable節點

Operation節點類似,Variable節點也需要output_valueoutput_nodes等屬性,但是它沒有輸入節點,也就沒有input_nodes屬性了,而是需要在建立的時候確定一個初始值initial_value:

Python
123456789101112131415161718192021222324252627282930 classVariable(object):''' Variable node in computational graph.    '''def__init__(self,initial_value=None,name=None,trainable=True):''' Variable constructor.        :param initial_value: The initial value of the variable.        :type initial_value: number or a ndarray.        :param name: Name of the variable.        :type name: str.        '''# Variable initial value.self.initial_value=initial_value# Output value of this operation in session execution.self.output_value=None# Nodes that receive this variable node as input.self.output_nodes=[]# Variable name.self.name=name# Graph the variable belongs to.self.graph=DEFAULT_GRAPH# Add to the currently active default graph.self.graph.variables.append(self)iftrainable:self.graph.trainable_variables.append(self)defcompute_output(self):''' Compute and return the variable value.        '''ifself.output_value isNone:self.output_value=self.initial_valuereturnself.output_value

Constant節點和Placeholder節點

計算圖物件

在定義了圖中的節點後我們需要將定義好的節點放入到一個圖中統一保管,因此就需要定義一個Graph類來存放建立的節點,方便統一操作圖中節點的資源。

Python
12345678 classGraph(object):''' Graph containing all computing nodes.    '''def__init__(self):''' Graph constructor.        '''self.operations,self.constants,self.placeholders=[],[],[]self.variables,self.trainable_variables=[],[]

為了提供一個預設的圖,在匯入simpleflow模組的時候建立一個全域性變數來引用預設的圖:

Python
1234 from.graph importGraph# Create a default graph.importbuiltinsDEFAULT_GRAPH=builtins.DEFAULT_GRAPH=Graph()

為了模仿TensorFlow的介面,我們給Graph新增上下文管理器協議方法使其成為一個上下文管理器, 同時也新增一個as_default方法:

Python
123456789101112131415161718 classGraph(object):#...def__enter__(self):''' Reset default graph.        '''globalDEFAULT_GRAPHself.old_graph=DEFAULT_GRAPHDEFAULT_GRAPH=selfreturnselfdef__exit__(self,exc_type,exc_value,exc_tb):''' Recover default graph.        '''globalDEFAULT_GRAPHDEFAULT_GRAPH=self.old_graphdefas_default(self):''' Set this graph as global default graph.        '''returnself

這樣在進入with程式碼塊之前先儲存舊的預設圖物件然後將當前圖賦值給全域性圖物件,這樣with程式碼塊中的節點預設會新增到當前的圖中。最後退出with程式碼塊時再對圖進行恢復即可。這樣我們可以按照TensorFlow的方式來在某個圖中建立節點.

Ok,根據上