
L2 Norm Penalty and High-Dimensional Linear Regression

%matplotlib inline
import mxnet
from mxnet import nd,autograd
from mxnet import gluon,init
from mxnet.gluon import data as gdata,loss as gloss,nn
import gluonbook as gb



n_train, n_test, num_inputs = 20,100,200

true_w = nd.ones((num_inputs, 1)) * 0.01
true_b = 0.05

features = nd.random.normal(shape=(n_train+n_test, num_inputs))
labels = nd.dot(features, true_w) + true_b
labels += nd.random.normal(scale=0.01, shape=labels.shape)
train_feature = features[:n_train, :]
test_feature = features[n_train:, :]
train_labels = labels[:n_train]
test_labels = labels[n_train:]
#print(features,train_feature,test_feature)

# Initialize the model parameters
def init_params():
    w = nd.random.normal(scale=1, shape=(num_inputs, 1))
    b = nd.zeros(shape=(1,))
    w.attach_grad()
    b.attach_grad()
    return [w, b]

# Training and testing setup
batch_size = 1
num_epochs = 100
lr = 0.03
train_iter = gdata.DataLoader(gdata.ArrayDataset(train_feature, train_labels),
                              batch_size=batch_size, shuffle=True)

# Define the network
def linreg(X, w, b):
    return nd.dot(X, w) + b

# Loss function
def squared_loss(y_hat, y):
    """Squared loss."""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

# L2 norm penalty
def l2_penalty(w):
    return (w ** 2).sum() / 2

def sgd(params, lr, batch_size):
    for param in params:
        param[:] = param - lr * param.grad / batch_size

def fit_and_plot(lambd):
    w, b = init_params()
    train_ls, test_ls = [], []
    for _ in range(num_epochs):
        for X, y in train_iter:
            with autograd.record():
                # Add the L2 norm penalty term to the loss.
                l = squared_loss(linreg(X, w, b), y) + lambd * l2_penalty(w)
            l.backward()
            sgd([w, b], lr, batch_size)
        train_ls.append(squared_loss(linreg(train_feature, w, b),
                                     train_labels).mean().asscalar())
        test_ls.append(squared_loss(linreg(test_feature, w, b),
                                    test_labels).mean().asscalar())
    gb.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',
                range(1, num_epochs + 1), test_ls, ['train', 'test'])
    print('L2 norm of w:', w.norm().asscalar())
fit_and_plot(0)
fit_and_plot(3)
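
The imports above already bring in gluon, init, nn, and gloss, so the same experiment can also be run concisely with Gluon's built-in weight decay. The following is a minimal sketch (the name fit_and_plot_gluon is made up here; it reuses train_iter, lr, num_epochs, and batch_size from above): the wd hyperparameter of gluon.Trainer applies the L2 penalty inside the optimizer, so no explicit penalty term appears in the loss.

def fit_and_plot_gluon(wd):
    net = nn.Sequential()
    net.add(nn.Dense(1))
    net.initialize(init.Normal(sigma=1))
    # Decay only the weight; the bias gets a plain SGD update.
    trainer_w = gluon.Trainer(net.collect_params('.*weight'), 'sgd',
                              {'learning_rate': lr, 'wd': wd})
    trainer_b = gluon.Trainer(net.collect_params('.*bias'), 'sgd',
                              {'learning_rate': lr})
    loss = gloss.L2Loss()
    for _ in range(num_epochs):
        for X, y in train_iter:
            with autograd.record():
                l = loss(net(X), y)
            l.backward()
            # Step both trainers so every parameter is updated.
            trainer_w.step(batch_size)
            trainer_b.step(batch_size)
    print('L2 norm of w:', net[0].weight.data().norm().asscalar())

Calling fit_and_plot_gluon(0) and fit_and_plot_gluon(3) should show the same qualitative contrast as the scratch version: with decay enabled, the L2 norm of w shrinks and the gap between training and test loss narrows.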

With a training set this small (20 examples against 200 input dimensions), the model overfits easily: the training loss ends up far below the test loss. A standard remedy is weight decay, i.e. L2 norm regularization.

Take linear regression as an example:

loss(w1, w2, b) = (1/n) * sum_i (x1^(i) * w1 + x2^(i) * w2 + b - y^(i))^2 / 2, the squared loss averaged over the n training examples,

with weight parameters w = [w1, w2].

The new loss function adds the penalty term: loss(w1, w2, b) + (lambd / (2n)) * ||w||^2.
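
To make the penalty's contribution concrete, a quick autograd check (a toy sketch; the values of n, lambd, and w below are made up) confirms that the gradient of (lambd / (2n)) * ||w||^2 with respect to w is (lambd / n) * w, the term that pulls each weight toward zero:

from mxnet import nd, autograd

n, lambd = 4, 3.0
w = nd.array([1.0, -2.0])
w.attach_grad()
with autograd.record():
    penalty = lambd / (2 * n) * (w ** 2).sum()
penalty.backward()
print(w.grad)          # [ 0.75 -1.5 ]
print(lambd / n * w)   # the analytic gradient (lambd / n) * w: matches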

Iteration equations:
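
The mini-batch SGD updates, with the gradient of the penalty folded in (B is the mini-batch, |B| its size, lr the learning rate):

w1 ← (1 - lr * lambd / |B|) * w1 - (lr / |B|) * sum_{i in B} x1^(i) * (x1^(i) * w1 + x2^(i) * w2 + b - y^(i))

w2 ← (1 - lr * lambd / |B|) * w2 - (lr / |B|) * sum_{i in B} x2^(i) * (x1^(i) * w1 + x2^(i) * w2 + b - y^(i))

Each step first scales w1 and w2 by the factor (1 - lr * lambd / |B|) < 1 before the usual gradient step, which is why L2 regularization is also called weight decay.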