1. 程式人生 > >tensorflow源碼解析之common_runtime-graph_optimizer

tensorflow源碼解析之common_runtime-graph_optimizer

部分 eat moved 相互 示例 data pes unique 模式

目錄

  1. 核心概念
  2. graph_optimizer
  3. function
  4. optimization_registry

1. 核心概念

本篇主要講圖的優化叠代器。我們在構建原始圖的時候,專註於達到目的,但不會去考慮圖的執行效率。如果把圖的設計過程比喻為高級語言的編寫,那麽圖的優化過程就相當於,將高級語言編譯為機器語言的過程中,為了能夠加速進行的編譯優化。比如,將相同的常數折疊,將Identity節點去除等等。本節主要用來討論,跟圖優化相關的類和函數。

2. graph_optimizer

進行圖優化,需要有一個統一的入口,它的輸入是圖本身,以及圖執行的環境,以及優化的配置,輸出是優化後的圖。這個入口就是GraphOptimizer,我們先來看看它的結構和接口:

class GraphOptimizer {
  public:
    GraphOptimizer(const OptimizerOptions& opts);
    void Optimize(FunctionLibraryRuntime* runtime, Env* env, Device* device, std::unique_ptr<Graph>* graph, const std::unordered_map<const Node*, std::vector<PartialTensorShape>>* shape_map);
  private:
    OptimizerOptions opts_;
};

顯然,其中的Optimize就是這個類最重要的API,它將圖優化配置opts中的優化過程應用的graph上。可能會將graph替換為另外一個圖對象。device是這張圖將要運行的設備,它使得優化算法可以考慮針對設備應當考慮的優化選項。shape_map如果是非空的話,它將圖中節點的名稱映射為部分可知的節點輸出形狀,可能在某些圖優化中會被應用,比如常量折疊優化。
關於圖優化,我們需要了解的更為細致一些,所以,先看一下這個類的構造函數具體的實現方式。

GraphOptimizer::GraphOptimizer(const OptimizerOptions& opts) : opts_(opts) {
    if(opts_.opt_level()>=OptimizerOptions::L1){
        opts_.set_do_common_subexpression_elimination(true);
        opts_.set_do_constant_folding(true);
    }
}

通過這個函數我們了解到,優化配置是有級別概念的,當級別大於等於1時,某些默認的優化配置需要被開啟,比如“公共子項消除”和“常量折疊”。這些內容我們在具體的優化步驟中也會看到。下面就來看一下核心API,Optimize的內容:

void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env, Device* device, std::unique_ptr<Graph>* graph, const std::unordered_map<const Node*, std::vector<PartialTensorShape>>* shape_map){
    Graph* g = graph->get();
    DumpGraph("Initial",g);//導出當前圖的結構
    
    bool changed = true;
    const int kMaxRounds = 10;
    for(int rounds = 0; rounds < kMaxRounds; ++rounds){
        changed = false;
        if(RemoveListArrayConverter(g)){
            DumpGraph("RemoveListArrayConverter", g);
            changed = true;
        }
        if(opts_.do_function_inlining() && RemoveDeadNodes(g)){
            DumpGraph("RemoveDeadNodes", g);
            changed = true;
        }
        if(opts_.do_function_inlining() && RemoveIdentityNodes(g)){
            DumpGraph("RemoveIdentityNodes", g);
            changed = true;
        }
        if(opts_.do_constant_folding()){
            ConstantFoldingOptions cf_opts;
            cf_opts.shape_map = shape_map;
            bool was_mutated;
            ConstantFold(cf_opts, runtime, env, device, g, &was_mutated).IgnoreError();
            if(was_mutated){
                RemoveDeadNodes(g);
                DumpGraph("ConstFolding",g);
                changed = true;
            }
        }
        if(opts_.do_function_inlining() && FixupSourceAndSinkEdges(g)){
            DumpGraph("FixupSourceAndSinkEdges",g);
            changed = true;
        }
        if(opts_.do_common_subexpression_elimination() && OptimizeCSE(g,nullptr)){
            DumpGraph("ExpandInlineFunctions",g);
            changed = true;
        }
        if(!changed) break;
    }
    
    //由於flib_def永遠不會消失,因此我們可以放心的使用它來構建新圖
    std::unique_ptr<Graph> copy(new Graph(g->flib_def()));
    CopyGraph(*g, copy.get());
    graph->swap(copy);
    
    DumpGraph("ReCopy", graph->get());
}

在對圖進行優化時,我們不可能一蹴而就的,因為優化之間會相互影響,比如我們對圖進行了A優化,對於A優化來說,此時圖已經是最優的了,但之後我們又對圖進行了B優化,此時對於B優化來說,圖已經是最優的了,但對於A優化來說則未必。因此圖優化是一個循環上升的過程,TF設置了最高的優化是10遍,對於大多數圖來說,也就足夠了。
在圖優化的過程中,我們發現了很多之前沒見過的函數,這些函數的定義都在function.h文件中,為了加深對於圖優化過程的理解,下面我們了解下這個文件中的函數。

3. function

function.h文件中,沒有類定義,全部都是硬生生的函數定義,幹貨滿滿。

//kernel生成器,根據FunctionLibraryRuntime和NodeDef來生成kernel
typedef std::function<Status(FunctionLibraryRuntime*, const NodeDef&, std::unique_ptr<OpKernel>*)> CustomKernelCreator;
void RegisterDefaultCustomKernelCreator(CusteomKernelCreator cb);//kernel生成器的註冊器

//創建一個FunctionLibraryRuntime,用來實例化lib_def中的函數,並在device上運行,如果custom_kernel_creator是非空的,它會被返回的runtime用來生成kernel
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, CusteomKernelCreator custom_kernel_creator);

//與之前的函數類似,只不過返回的runtime直接利用RegisterDefaultCustomKernelCreator註冊的全局custom_kernel_creator來生成新的kernel
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(const DeviceMgr* device_mgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options);

//函數體的內容
struct FunctionBody {
    FunctionDef fdef;
    Graph* graph = nullptr;
    DataTypeVector arg_types;
    DataTypeVector ret_types;
    gtl::InlinedVector<Node*, 4> arg_nodes;
    gtl::InlinedVector<Node*, 4> ret_nodes;
    
    FuntionBody(){}
    FunctionBody(const FunctionDef& f, DataTypeSlice arg_types, DataTypeSlice ret_types, Graph* g);
    ~FunctionBody();
};

//刪除以下節點,第一,無狀態的,第二,無參數的,第三,對輸出無貢獻的
bool RemoveDeadNodes(Graph* g);

//尋找如下的模式,src-(in)->node-(out)->dst,如果node是identity節點,in是唯一的輸入數據邊,out是唯一的輸出數據邊,則使用src->dst重寫以上模式
bool RemoveIdentityNodes(Graph* g);

//將圖中的_ListToArray和_ArrayToList轉化為Identity節點
bool RemoveListArrayConverter(Graph* g);

//對於圖中的每個節點,如果lib指明這個節點是一個函數調用,那麽內聯這個函數體。如果至少一個節點被內聯了,返回true。
bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph);

//將graph中的內容導出到日誌文件,如果日誌級別足夠高的話
void DumpGraph(StringPiece label, const Graph* g);

//應用圖重寫的優化,例如內聯、死節點移除等
void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g);

//將一個函數的圖轉化為GraphDef
void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false);

//給定一個數值函數,返回它的導數函數
FunctionBody* SymbolicGradient(const FunctionBody& f);

//將一個FunctionDef示例化為一個graph,設置fbody指向擁有FunctionDef的FunctionBody
Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs, const FunctionLibraryDefinition* const lib_def, const std::function<Status(const string&, const OpDef**)>& get_func_sig, FunctionBody** fbody);

現在回過頭來看GraphOptimizer類中的Optimize函數,首先它把Array和List相互轉換節點變為Identity節點,然後刪除了死節點,刪除Identity節點,進行常量折疊,修復輸入輸出邊,進行公共子項消除,最終完成了對圖的優化。

4. optimization_registry

optimization_registry.h文件中,包含了一些維護一個全局的圖優化遍歷註冊器所需要的類,在會話初始化一張圖時,會使用這個全局優化遍歷註冊器來對圖進行優化。
首先我們來看第一個類,GraphOptimizationPassOptions,顧名思義,它包含了圖優化遍歷所需要的參數。這些足夠作為一個字典的鍵值,我們通常會使用一個字典來保持各個圖優化遍歷器的狀態。

struct GraphOptimizationPassOptions {
    string session_handle;
    const SessionOptions* session_options = nullptr;
    const CostModel* cost_model = nullptr;
    FunctionLibraryDefinition* flib_def = nullptr;
    const DeviceSet* device_set = nullptr;
    //如果優化遍歷在圖分割之前被使用,那麽它優化的對象就是這個graph,如果是圖分割之後被使用,那麽這個graph是null
    std::unique_ptr<Graph>* graph = nullptr;
    //進行圖分割後的優化遍歷時使用
    std::unordered_map<string, std::unique_ptr<Graph>* partition_graphs = nullptr;
};

圖優化遍歷,按照在圖分割之前還是之後進行,可以分為兩類,但我們使用了GraphOptimizationPassOptions這樣一個接口。
接下來是GraphOptimizationPass類,所有的圖優化遍歷類,都是這個類的子類,它的結構也非常簡單。

class GraphOptimizationPass {
  public:
    virtual ~GraphOptimizationPass() {}
    virtual Status Run(const GraphOptimizationPassOption& options) = 0;
};

當我們擁有了多種圖優化遍歷的算法之後,需要對這些進行統一管理,因此TF提出了一種對圖優化遍歷算法進行統一註冊和管理的類:

//這裏的鍵值為phase,圖優化遍歷算法是按照phase的升序順序執行的,在一個phase內部,執行順序是未定義的
typedef std::map<int, std::vector<std::unique_ptr<GraphOptimizationPass>>> GraphOptimizationPasses;

class OptimizationPassRegistry {
  public:
    enum Grouping {
        PRE_PLACEMENT,//在cost model賦值之後,在節點放置算法之前
        POST_PLACEMENT,//在節點放置算法之後
        POST_REWRITE_FOR_EXEC,//在利用feed/fetch節點進行重寫之後
        POST_PARTITIONING,//在圖分割之後
    };
    void Register(Grouping grouping, int phase, std::unique_ptr<GraphOptimizationPass> pass);//註冊圖優化遍歷算法
    Status RunGrouping(Grouping grouping, const GraphOptimizationPassOptions& options);//運行一個groupping中所有的圖優化遍歷算法,按照phase的升序運行
    static OptimizationPassRegistry* Global();//返回一個全局的圖優化遍歷註冊器
  private:
    std::map<Grouping, GraphOptimizationPasses> groups_;
};

總結一下,groups是一個雙層的映射,先從Grouping映射到圖優化遍歷算法組,這個算法組本身也是個映射,從phase映射到真正的圖優化遍歷算法,如下:

graph LR
Grouping-->GraphOptimizationPasses
phase-->GraphOptimizationPass

最後,TF為剛才的註冊器提供了一個全局的入口:

class OptimizationPassRegistration {
  public:
    OptimizationPassRegistration(OptimizationPassRegistry::Grouping grouping, int phase, std::unique_ptr<GraphOptimizationPass> pass){
        OptimizationPassRegistry::Global->Register(grouping,phase,std::move(pass));
    }
};

tensorflow源碼解析之common_runtime-graph_optimizer