程式設計介面的新動態:用 TensorFlow 高層 API 來構建機器學習模型 丨 Google 開發者大會 2018
Google 開發者大會 (Google Developer Days,簡稱 GDD) 是展示 Google 最新開發者產品和平臺的全球盛會,旨在幫助你快速開發優質應用,發展和留住活躍使用者群,充分利用各種工具獲得更多收益。2018 Google 開發者大會於 9 月 20 日和 21 日於上海舉辦。:point_right: ofollow,noindex">Google 開發者大會 2018 掘金專題
2018 年 9 月 21 日 馮亦菲(Google Brain 軟體工程師)帶來一場《用 TensorFlow 高層 API 來構建機器學習模型》的演講,本文將對演講做一個回顧。
如何機器學習模型

構建機器學習模型是一個循序漸進的過程:
- 首先獲取資料集,並明確機器學習所要解決的問題;
- 接下來處理資料集,讓我們的模型能夠快速的理解形式;
- 緊接著搭建機器學習模型的結構,並且訓練評估我們的模型;
- 最後,當模型達到我們預先設定的目標,就可以將其打包投入到生產環境中去。
機器學習問題
推測自然保護區型別

本例中的所使用的資料集為科羅拉多州森林植被資料集。該資料集記錄了美國科羅拉多州不同地塊的森林植被型別,每個樣本包含了描述每塊土地的若干特徵,包括海拔、坡度、到水源的距離、遮陽情況和土壤型別,並且隨同給出了地塊的已知森林植被型別。資料集下載
資料樣本
按照連結地址下載的原始資料集是以逗號分割、每行有55個整數列,如下圖:

每一列所代表的含義如下圖所示:

其中,上圖中橘黃色方框選中的區域就是我們此次想要推測的自然保護區型別。
使用 TensorFlow構建模型
在瞭解我們所有的資料集之後,可以著手使用 TensorFlow 來構建我們的模型。
第一步 import tensorflow
import tensorflow as tf 複製程式碼
接下來馮亦菲推薦使用 Eager Execution
立即執行
tf.enable_eager_execution() 複製程式碼
eager 執行模型

Eager Execution 的優點如下:
- 快速除錯即刻的執行錯誤並通過 Python 工具進行整合
- 藉助易於使用的 Python 控制流支援動態模型
- 為自定義和高階梯度提供強大支援
- 適用於幾乎所有可用的 TensorFlow 運算
使用 tensorflow
提供的 Dataset
載入資料.我們下載的資料格式為 .csv
,所以我們用的是 CsvDataset
。
dataset = tf.contrib.data.CsvDataset( filenames = ['covtype.csv.train'] record_defaults = [tf.int32] * 55) print(list(dataset.take(1))) 複製程式碼
檢視 dataset
中第一行資料。

上圖中雖然我們可以看出每一行原始資料的含義,但是想要帶入到機器學習的模型中,我們還需要對資料進行解析。
解析原始資料集
col_names = ['elevation','aspect','slope'...] //特徵名稱 def _parse_csv_row(*vals): soil_type_t = tf.convert_to_tensor(vals[14:54]) //土壤型別是一個特徵,而不是 40 個單獨的特徵 feat_vals = vals[:10] + (soil_type_t, vals[54]) //重新組成 12 個特徵 features = dict(zip(col_names, feat_vals)) //給每個特徵取一個名稱 class_label = tf.argmax(vals[10:14], axis = 0) //生成自然保護區標籤,在原始資料集中它的一個長度為 4 return features, class_label 複製程式碼
解析資料
dataset = dataset.map(_parse_csv_row).batch(64) //利用解析函式解析,並以 64 為單位對訓練集進行分組 print(list(dataset.take(1))) 複製程式碼
因為我們使用的是 Eager
執行模式,所以我們可以直接列印檢視資料,如下圖所示:

定義資料特徵
使用 feature_column
, 將原始數值變成模型可理解的數值。
# Cover_Type / integer / 1 to 7 cover_type = tf.keras.feature_column. categorical_column_with_identity( 'cover_type', num_buckets=8 ) //把離散的類別,變成對模型有意義的連續的數值 cover_embedding = tf.kears.feature_column. embedding_column(cover_type,dimension = 10) numeric_features = [tf.keras.feature_column. numeric_column(feat) for feat in numeric_cols] soil_type = tf.keras.feature_column. numeric_column(soil_type, shape = (40,)) columns = numeric_features + [ soil_type, cover_embedding] feature_layer = tf.keras.feature_column. FeatureLayer(columns) 複製程式碼
構造模型
這裡我們使用的是 keras
API 來搭建我們的模型,通過 keras
可以像搭積木一樣來構造模型。
model = tf.keras.Sequential([ feature_layer, tf.keras.layers.Dense(256, activation = tf.nn.relu), tf.keras.layers.Dense(16, activation = tf.nn.relu), tf.keras.layers.Dense(8, activation = tf.nn.relu), tf.keras.layers.Dense(4, activation = tf.nn.softmax) ]) model.fit(dataset,steps_per_epoch = NUM_TRAIN_EXAMPLES/64) 複製程式碼
訓練結果如下:

驗證模型
載入驗證資料
def load_data(*filenames): dataset = tf.contrib.data.CsvDataset( filenames,record_defaults) dataset = dataset.map(_parse_csv_row) dataset = dataset.batch(64) return dataset 複製程式碼
驗證
test_data = load_data('covtype.csv.test') loss, accury = model.evaluate( test_data, steps = 50 ) print(loss, accury) 複製程式碼
損失值與準去率如下圖所示:

輸出模型
如果驗證結果達到我們所設定的要求,可以利用 SavedModel
打包。
export_dir = tf.contrib.saved_model. save_keras_model(model, 'keras_nn') //重新訓練已訓練的模型 restored_model = tf.contrib.saved_model. save_keras_model(export_dir) 複製程式碼

至此我們完成了構造機器學習模型的完整流程。在我們日常的生產環境中,不存在一次構建完美模型的情況。所以,我們可以根據上面構建模型的 5 個步驟,就單個步驟或多這個步驟進行改進。例如,我們可以獲得跟多原始資料,以提高模型的普遍性等等。
Wide & Deep 模型
前面我們使用的是 keras
模型,下面我們使用 Wide & Deep
模型就實現。
model = tf.estimator.DNNLinearCombinedClassifier( linear_feature_columns = [cover_type, soil_type], dnn_feature_columns = numeric_features, dnn_hidden_unites = [256, 16, 8] n_classes = 4 ) //訓練 model.train( input_fn = lambda: load_data('covtype.csv.train')) //驗證 model.evaluate( input_fn = lambda: load_data('covtype.csv.test')) //輸出 features_sample = list(dataset.take(1))[0][0] input_receiver_fn = tf.estimator.export. bulid_raw_serving_input_receiver_fn( features_sample) //重新訓練已訓練的模型 model.export_saved_model( export_dir_base = 'wide_deep', serving_input_receiver_fn = input_receiver_fn) 複製程式碼
以上就是本次演講的全部內容,希望對大家有所幫助。