
LR Logistic Regression Code (Gradient Descent)

I've spent the past few days deriving the logistic regression formulas, so I wrote an implementation myself, and found that actually writing the code is a bit different from working through the formulas.

I won't repeat the derivation here. In any case there is no closed-form solution, so gradient descent is the only option; later I'll optimize this into stochastic gradient descent and Newton's method, since batch gradient descent simply doesn't work on large datasets.
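For reference, the update that the code below implements is just the standard gradient-ascent step on the log-likelihood, stated here in matrix form because that is exactly what the line `weights = weights + alpha * data_set.transpose() * loss` computes. With $X$ the $m \times 3$ data matrix (constant column included), $y$ the $m \times 1$ label vector, and $\sigma$ the sigmoid:

$$\ell(w) = \sum_{i=1}^{m} \Big[\, y_i \log \sigma(x_i^\top w) + (1 - y_i) \log\big(1 - \sigma(x_i^\top w)\big) \Big],$$

$$\nabla_w \ell = X^\top\big(y - \sigma(Xw)\big), \qquad w \leftarrow w + \alpha\, X^\top\big(y - \sigma(Xw)\big).$$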

The code follows. I borrowed a little from other people's code here, mainly for the weight update and the matrix operations; when I first worked straight from the formulas, I kept running into problems.

#!/usr/bin/python
# coding=utf-8
import pprint
import numpy as np

def load_data(path):
    '''
    :param path: path to the data file; returns the sample data and labels
                 as lists, ready to be converted to matrices for matrix ops
    :return: data_set, label_set
    '''
    data_set = []
    label_set = []
    with open(path) as file_object:
        for line in file_object.readlines():
            lineArr = [float(x) for x in line.strip().split()]
            label_set.append(int(lineArr[-1]))  # last column is the label
            lineArr[-1] = 1  # replace it with a constant 1 so w and b merge into one weight vector
            data_set.append(lineArr)
    # normalization, left alone for now, to be added later:
    # data_set[0] = (data_set[0] - min(data_set[0])) / (max(data_set[0]) - min(data_set[0]))
    return data_set, label_set

class myLRregression(object):
    def __init__(self):
        '''
        Initialize the data and the weights.
        '''
        self.data_set = []
        self.label_set = []
        self.weight = []

    def sigmoid(self, inX):
        '''
        :param inX: input to the sigmoid function
        :return: 1 / (1 + e^(-inX))
        '''
        return 1.0 / (1 + np.exp(-inX))

    def load_data(self, path):
        '''
        :param path: path to the data file; fills self.data_set and self.label_set
        '''
        with open(path) as file_object:
            for line in file_object.readlines():
                lineArr = [float(x) for x in line.strip().split()]
                self.label_set.append(int(lineArr[-1]))  # last column is the label
                lineArr[-1] = 1  # constant 1 so that w and b merge into one weight vector
                self.data_set.append(lineArr)

    def get_data_set(self):
        pprint.pprint(self.data_set)

    def train(self, train_fun):
        '''
        :return: the trained weights
        '''
        max_iter = 20000                   # maximum number of iterations
        alpha = 0.01                       # learning rate
        data_set = np.mat(self.data_set)   # convert to matrices for the updates
        label_set = np.mat(self.label_set).transpose()
        weights = np.mat(np.ones((data_set.shape[1], 1)))  # initialize weights to 1
        print(weights)
        if train_fun == "gradDescent":
            for i in range(max_iter):
                # careful with the matrix shapes here: loss is (m, 1),
                # data_set.transpose() * loss is (n, 1), same as weights
                loss = label_set - self.sigmoid(data_set * weights)
                weights = weights + alpha * data_set.transpose() * loss
        print(weights)
        self.weight = weights
        return weights, data_set, label_set

if __name__ == '__main__':
    print("-------start load data-----")
    path = "./LR/testSet.txt"
    LR = myLRregression()
    LR.load_data(path)
    weights, data_set, label_set = LR.train(train_fun='gradDescent')
    pprint.pprint(LR.weight)
    pprint.pprint(LR.data_set)
    pprint.pprint(LR.label_set)
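Since I plan to switch to stochastic gradient descent later, here is a minimal sketch of what that variant would look like (a standalone helper written for illustration only; the name stoc_grad_descent is my own, not part of the class above). Each update uses a single randomly chosen sample instead of the full batch matrix product, so each step is O(n) instead of O(m*n):

import random
import numpy as np

def stoc_grad_descent(data_set, label_set, max_iter=200, alpha=0.01):
    # illustrative sketch of the stochastic variant mentioned above
    data = np.array(data_set)      # shape (m, n), constant column included
    labels = np.array(label_set)   # shape (m,)
    m, n = data.shape
    weights = np.ones(n)           # initialize weights to 1, as in train()
    for _ in range(max_iter):
        for i in random.sample(range(m), m):  # one shuffled pass over the data
            h = 1.0 / (1 + np.exp(-np.dot(data[i], weights)))
            error = labels[i] - h
            weights = weights + alpha * error * data[i]  # single-sample update
    return weights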

The test data is as follows:

-0.017612  14.053064  0
-1.395634  4.662541   1
-0.752157  6.538620   0
-1.322371  7.152853   0
0.423363   11.054677  0
0.406704   7.067335   1
0.667394   12.741452  0
-2.460150  6.866805   1
0.569411   9.548755   0
-0.026632  10.427743  0
0.850433   6.920334   1
1.347183   13.175500  0
1.176813   3.167020   1
-1.781871  9.097953   0
-0.566606  5.749003   1
0.931635   1.589505   1
-0.024205  6.151823   1
-0.036453  2.690988   1
-0.196949  0.444165   1
1.014459   5.754399   1
1.985298   3.230619   1
-1.693453  -0.557540  1
-0.576525  11.778922  0
-0.346811  -1.678730  1
-2.124484  2.672471   1
1.217916   9.597015   0
-0.733928  9.098687   0
-3.642001  -1.618087  1
0.315985   3.523953   1
1.416614   9.619232   0
-0.386323  3.989286   1
0.556921   8.294984   1
1.224863   11.587360  0
-1.347803  -2.406051  1
1.196604   4.951851   1
0.275221   9.543647   0
0.470575   9.332488   0
-1.889567  9.542662   0
-1.527893  12.150579  0
-1.185247  11.309318  0
-0.445678  3.297303   1
1.042222   6.105155   1
-0.618787  10.320986  0
1.152083   0.548467   1
0.828534   2.676045   1
-1.237728  10.549033  0
-0.683565  -2.166125  1
0.229456   5.921938   1
-0.959885  11.555336  0
0.492911   10.993324  0
0.184992   8.721488   0
-0.355715  10.325976  0
-0.397822  8.058397   0
0.824839   13.730343  0
1.507278   5.027866   1
0.099671   6.835839   1
-0.344008  10.717485  0
1.785928   7.718645   1
-0.918801  11.560217  0
-0.364009  4.747300   1
-0.841722  4.119083   1
0.490426   1.960539   1
-0.007194  9.075792   0
0.356107   12.447863  0
0.342578   12.281162  0
-0.810823  -1.466018  1
2.530777   6.476801   1
1.296683   11.607559  0
0.475487   12.040035  0
-0.783277  11.009725  0
0.074798   11.023650  0
-1.337472  0.468339   1
-0.102781  13.763651  0
-0.147324  2.874846   1
0.518389   9.887035   0
1.015399   7.571882   0
-1.658086  -0.027255  1
1.319944   2.171228   1
2.056216   5.019981   1
-0.851633  4.375691   1
-1.510047  6.061992   0
-1.076637  -3.181888  1
1.821096   10.283990  0
3.010150   8.401766   1
-1.099458  1.688274   1
-0.834872  -1.733869  1
-0.846637  3.849075   1
1.400102   12.628781  0
1.752842   5.468166   1
0.078557   0.059736   1
0.089392   -0.715300  1
1.825662   12.693808  0
0.197445   9.744638   0
0.126117   0.922311   1
-0.679797  1.220530   1
0.677983   2.556666   1
0.761349   10.693862  0
-2.168791  0.143632   1
1.388610   9.341997   0
0.317029   14.739025  0
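To sanity-check the result, the two classes and the fitted decision boundary can be plotted. Below is a minimal sketch; plot_boundary is an illustrative helper of my own, assuming the weights returned by train() above and relying on the constant column being the last feature, as load_data arranges it:

import numpy as np
import matplotlib.pyplot as plt

def plot_boundary(data_set, label_set, weights):
    data = np.array(data_set)
    labels = np.array(label_set)
    w = np.asarray(weights).flatten()   # [w0, w1, w2], bias w2 last
    pos = data[labels == 1]
    neg = data[labels == 0]
    plt.scatter(pos[:, 0], pos[:, 1], marker='s', label='class 1')
    plt.scatter(neg[:, 0], neg[:, 1], marker='o', label='class 0')
    # boundary: w0*x1 + w1*x2 + w2 = 0  =>  x2 = -(w2 + w0*x1) / w1
    x1 = np.arange(-3.5, 3.5, 0.1)
    x2 = -(w[2] + w[0] * x1) / w[1]
    plt.plot(x1, x2)
    plt.xlabel('x1')
    plt.ylabel('x2')
    plt.legend()
    plt.show()

After training, this would be called as plot_boundary(LR.data_set, LR.label_set, weights).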