1. 程式人生 > >tensorflow源碼解析之framework-shapeinference

tensorflow源碼解析之framework-shapeinference

包含 簡單 核心api 更多 新的 ram 類的成員 cef 函數的參數

目錄

  1. 核心概念
  2. ShapeInference

1. 核心概念

前面我們講到op的時候,提到了一個操作的註冊器OpRegistry,並且提到,其中註冊的數據是一個結構OpRegistrationData,這個結構中除了OpDef之外,還包含了一個OpShapeInferenceFn,這個數據是做什麽用的呢?
我們知道,op只是定義了操作的輸入輸出和參數,但並沒有定義操作具體的輸入形狀,舉個例子,MatMul操作,代表矩陣乘法,這只是一個抽象的表示,沒有具體說,這個矩陣乘法代表的是[2,3]x[3,4]=[2,4],還是[100,200]x[200,300]=[100,300]。所以在實際應用中,輸入的真實形狀我們是不知道的,但是為了產生輸出,我們必須知道輸出的形狀,好給它申請對應大小的內存空間。所以,我們需要為每一個操作,配備一個形狀推斷的函數,這就是ShapeInference的由來。

2. ShapeInference

上面提到了,操作註冊器中用到的是OpRegistrationData,而不是ShapeInference,這兩者有什麽關系呢?回想一下前面講過的OpKernelContext,其實它們的功能很像。OpKernelContext是作為OpKernel的核心API Compute函數的參數,所有計算相關的參數都會包含在這個對象中。ShapeInference也是一樣,我們把所有跟形狀推斷相關的數據和功能函數封裝在一個ShapeInference對象中,然後把這個對象傳遞給OpShapeInferenceFn,就可以實現形狀推斷。這種設計實現了數據部分和實現邏輯的解耦。
在具體看ShapeInference類之前,我們先要看一些輔助類:

class Dimension {
  private:
    //...
    const int64 value_;
};
class DimensionHandle {
  private:
    //...
    const Dimension* ptr_ = nullptr;
};
class Shape {
    //...
  private:
    const int32 rank_;
    const std::vector<DimensionHandle> dims_;
};
class ShapeHandle {
    //...
  private:
    const Shape* ptr = nullptr;
};
class DimensionOrConstant {
  public:
    //...
    DimensionHandle dim;
    int64 val;
};
class ShapeAndType {
    ShapeHandle shape;
    DataType dtype = DT_INVALID;
};

這幾個類都比較簡單。在下面用到時能夠認得就好了。
下面我們看下InferenceContext這個類:

class InferenceContext {
  public:
    InferenceContext(int graph_def_version, const NodeDef* node_def, const OpDef& op_def, const std::vector<ShapeHandle>& input_shapes, const std::vector<const Tensor*>& input_tensors, const std::vector<ShapeHandle>& input_tensors_as_shapes, std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_shapes_and_types);//構造函數
    Status Run(const std::function<Status(shape_inference::InferenceContext* c)>& fn);//運行一個以this為參數的函數,沒錯,這裏運行的就是OpShapeInferenceFn
    bool MergeInput(int idx, ShapeHandle shape);
    bool RelaxInput(int idx, ShapeHandle shape);
  private:
    ShapeManager shape_manager_;
    std::vector<ShapeHandle> inputs_;
    std::vector<const Tensor*> input_tensors_;
    std::vector<bool> requested_input_tensor_;
    std::vector<ShapeHandle> outputs_;
    std::vector<ShapeHandle> input_tensors_as_shapes_;
    std::vector<bool> requested_input_tensor_as_partial_shape_;
    std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_shapes_and_types_;
    std::vector<std::unique_ptr<std::vector<ShapeAndType>>> output_handle_shapes_and_types_;
    const int graph_def_version_;
    const NodeDef& node_def_;
    NameRangeMap input_name_map_;
    NameRangeMap output_name_map_;
    Status construction_status_;
};

前面已經介紹過了這個類的作用,是作為真正的形狀推斷函數的參數,為形狀推斷提供足夠的數據和功能函數支持,那麽這個類的成員就比較清晰了,首先私有的一大堆成員,為形狀推斷提供數據支持,而大量的共有API函數,為形狀推斷提供公用的功能函數,比如上面提到的MergeInput和RelaxOutput,下面我們重點介紹下這兩個函數的功能:
MergeInput函數是將輸入索引idx處的輸入與shape合並,具體的合並規則是:

  • 如果ShapeHandles是一樣的,或者shape是未知的,那麽輸入維度不變。否則,如果輸入維度是未知的,那麽輸出是shape;
  • 如果兩個形狀都是已知的,它們必須擁有相同的rank;
  • 對於任意一個維度,如果在兩個形狀中這個維度都已知,那麽它們必須相等;
  • 如果一個形狀在任意維度上的信息都多於另一個形狀,那麽擁有更多信息的形狀將被返回。否則,一個新的形狀將被構建並返回,這個新的形狀綜合了輸入的兩個形狀的信息;
  • 比如,合並[2,?]和[?,2]將得到[2,2];
  • 比如,[2,2]不能被合並到[1,2]
    如果說MergeInput函數對輸入形狀是“收縮”的,那麽“RelaxInput”函數對輸入形狀就是“擴張”的,它傾向於讓形狀變的更模糊,具體的規則是:
  • 如果ShapeHandles是一樣的,那麽對應的shape將會被返回;
  • 如果任一個ShapeHandle是未知的,那麽一個未知的ShapeHandle將會被返回;
  • 如果兩個形狀的rank已知,但不同,那麽一個未知ShapeHandle將會被返回;
  • 對於任一維度,如果任一shape是未知的,那麽對應的輸出維度也是未知的;
  • 對於任一維度,如果兩個shape對應的維度位置都是已知的,但並不相同,那麽對應的輸出維度也是未知的;
  • 如果兩個shape的rank和對應維度大小都一樣,那麽這個形狀將會被返回;
  • 例如,[2,?]和[?,2]會得到[?,?];
  • 例如,[2,2]和[3,2]會得到[?,2];
  • 例如,[2,2]和[1,2,3]會得到?

tensorflow源碼解析之framework-shapeinference