1. 程式人生 > >【Tensorflow】怎樣為你的網路預加工和打包訓練資料?(二):小資料集的處理方案

【Tensorflow】怎樣為你的網路預加工和打包訓練資料?(二):小資料集的處理方案

實驗環境:python2.7

第二篇我們來講一講小資料集的處理方法,小資料集一般多以文字儲存為主,csv是一種流行的資料格式,另外也有txt等。當然也會有.mat或者.npy這種經過處理的格式。

一.處理csv格式資料集

實驗資料集是鳶尾花卉資料集iris,格式是.csv,需要的同學可以到這裡下載

為了工程需要我直接介紹讀取該型別資料的最快方法,通過一些庫,我們是可以用很少的步驟就讀取進來訓練的,這裡用到的是一個各種資料操作方法的集合庫,pandas。

下載pandas:

sudo pip install pandas

然後匯入:
import pandas

使用read_csv函式快速讀取一個csv檔案,到底有多方便?一句話就夠了
data = pandas.read_csv("iris.csv")

此時返回的data我們可以看看它是長什麼樣的:

我們再對比一下,csv檔案中的資料:

這時候你應該發現問題了,讀取csv的時候預設把第一行作為列標題讀進來了,導致後續的資料就不對了,顯然一句話搞定的東西會出現很多問題。注意資料集的特殊性,iris資料集是不帶有標題列的,所以我們就要說明一下,新增這一個引數:

data = pandas.read_csv("iris.csv", header=None)


現在輸出就對了,可以看到系統自動為列生成了一組索引,當然我們可以自定義索引的名字:

data = pandas.read_csv("iris.csv", header=None, prefix='col')


在數字前面加字串

也可以分別指定具體的名字:

data = pandas.read_csv("iris.csv", header=None, 
                       names=['atr1','atr2','atr3','atr4','label'])


讓我們列印資料的格式看看:

print type(data)
print type(data["atr1"])
print type(data["atr1"][0])


可以看到具體元素的值是numpy的,但是其餘的都還是pandas的自帶格式,怎麼轉換呢,如下:

train_data = data.as_matrix(columns=['atr1','atr2','atr3','atr4'])
label = data.as_matrix(columns=['label'])
print train_data,label

這樣我們就把指定的幾列轉換為numpy陣列了,但是,還是會出現一個問題,讀取csv預設的元素type是np.float64,也就是說label也是np.float64型別的,處理方案可以對讀取完畢的numpy陣列處理,也可以讀取的時候處理,如下:
data = pandas.read_csv("iris.csv", header=None, 
                       names=['atr1','atr2','atr3','atr4','label'], 
                       dtype={'label':np.int8})

完整程式如下,這裡我用了np.squeeze來去掉長度為1的維度,這個應該好理解:
import pandas
import numpy as np
data = pandas.read_csv("iris.csv", header=None, 
                       names=['atr1','atr2','atr3','atr4','label'], 
                       dtype={'label':np.int8})

train_data = data.as_matrix(columns=['atr1','atr2','atr3','atr4'])
label = data.as_matrix(columns=['label'])
label = np.squeeze(label)

就這麼幾行,資料集就匯入了!

二.txt的處理方法

和上面類似,txt檔案也是可以用read_csv來處理的,因為兩者的根本區別只是分隔符不同而已,舉一個例子:在我的用tensorflow實現usps和mnist資料集的遷移學習使用到的資料集usps,我們將它下載下來,手工刪除第一行10 256的分類說明和尾行的-1

因為這兩行會影響我們結果的生成,然後呼叫:

data = pandas.read_csv("usps_train.jf", sep='\s+', header=None)

資料就生成好了,這裡我們指定了sep分割符的型別是空格或者多於一個空格,總共7291個樣本,第一列為標籤,後面256列分別表示畫素值。

當然你也可以像我在用tensorflow實現usps和mnist資料集的遷移學習中的做法一樣,用python原生的方法讀取,秀一秀你的程式碼技術偷笑偷笑,但是做工程的話,還是以方便為主,一句話就搞定的事,何樂而不為呢?

三.延伸

補充一下,遇到csv較大記憶體不夠的情況,可以嘗試使用read_csv中的分成chunk分塊讀取的方案,這裡我就不描述了(搞deep learning的我相信大家的記憶體都很大,不會被小小几個G難住吧,哈哈)

附上分塊讀取的解決方案,和read_csv函式引數的詳解