sklearn手寫數字識別
阿新 • • 發佈:2018-11-09
"""Handwritten-digit recognition with a hand-rolled 64-100-10 neural network.

The network is trained with stochastic back-propagation on scikit-learn's
``digits`` dataset.  Only NumPy is needed for the network itself; scikit-learn
and matplotlib are used by :func:`main` for data loading, plotting and
reporting.
"""
import numpy as np

# Network structure: 64 inputs -> 100 hidden units -> 10 outputs.
# Input-to-hidden weight matrix, initialised uniformly in [-1, 1).
V = np.random.random((64, 100)) * 2 - 1
# Hidden-to-output weight matrix, initialised uniformly in [-1, 1).
W = np.random.random((100, 10)) * 2 - 1


def sigmoid(x):
    """Return the logistic activation ``1 / (1 + exp(-x))`` of *x*."""
    return 1 / (1 + np.exp(-x))


def dsigmoid(x):
    """Return the sigmoid derivative, given the sigmoid *output* ``x``.

    Note that ``x`` is the already-activated value, not the pre-activation:
    if ``s = sigmoid(z)`` then ``ds/dz = s * (1 - s)``.
    """
    return x * (1 - x)


def predict(x):
    """Run a forward pass through the global network.

    Args:
        x: One sample or a batch of samples with 64 features each.

    Returns:
        The output-layer activations, one column per class.
    """
    hidden = sigmoid(np.dot(x, V))
    return sigmoid(np.dot(hidden, W))


def train(X, y, steps=10000, lr=0.11, *, X_val=None, y_val=None):
    """Train the global weights ``V`` and ``W`` in place with stochastic BP.

    Each step picks one random sample, runs a forward pass, and applies one
    gradient update.  The loop runs ``steps + 1`` times so that the final
    step is also a reporting step.

    Args:
        X: Training samples, one row per sample.
        y: One-hot encoded training labels, one row per sample.
        steps: Number of update steps.
        lr: Learning rate.
        X_val: Optional validation samples.  When given together with
            ``y_val``, accuracy on them is printed every 1000 steps.
        y_val: Optional integer class labels for ``X_val``.
    """
    global V, W
    X = np.asarray(X)
    y = np.asarray(y)
    report_progress = X_val is not None and y_val is not None

    for n in range(steps + 1):
        # Pick one random training sample and make it a 1 x n_features row.
        i = np.random.randint(X.shape[0])
        x = np.atleast_2d(X[i])

        # Forward pass.
        hidden = sigmoid(np.dot(x, V))
        output = sigmoid(np.dot(hidden, W))

        # Backward pass: both deltas are computed with the pre-update W.
        output_delta = (y[i] - output) * dsigmoid(output)
        hidden_delta = output_delta.dot(W.T) * dsigmoid(hidden)

        # In-place updates keep existing references to V and W valid.
        W += lr * hidden.T.dot(output_delta)
        V += lr * x.T.dot(hidden_delta)

        # Every 1000 steps, report accuracy on the validation data.
        if report_progress and n % 1000 == 0:
            predictions = np.argmax(predict(X_val), axis=1)
            acc = np.mean(np.equal(predictions, y_val))
            print('steps: ', n, 'accuracy: ', acc)


def main():
    """Load the digits dataset, train the network and print the evaluation."""
    # Imported here so the network helpers above only require NumPy.
    import matplotlib.pyplot as plt
    from sklearn.datasets import load_digits
    from sklearn.metrics import classification_report, confusion_matrix
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import LabelBinarizer

    # Load data.
    digits = load_digits()
    print(digits.images.shape)

    # Show the first image.
    plt.imshow(digits.images[0], cmap='gray')
    plt.show()

    # Flattened pixel data and integer labels.
    X = digits.data
    y = digits.target
    print(X.shape)
    print(y.shape)
    print(X[:3])
    print(y[:3])

    # Split the data, then one-hot encode the training labels.
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    labels_train = LabelBinarizer().fit_transform(y_train)
    print(y_train[:5])
    print(labels_train[:5])

    train(X_train, labels_train, 30000, X_val=X_test, y_val=y_test)

    predictions = np.argmax(predict(X_test), axis=1)
    # scikit-learn metrics take (y_true, y_pred) in that order; swapping them
    # exchanges precision with recall and transposes the confusion matrix.
    print(classification_report(y_test, predictions))
    print(confusion_matrix(y_test, predictions))


if __name__ == "__main__":
    main()
(1797, 8, 8) (1797, 64) (1797,) [[ 0. 0. 5. 13. 9. 1. 0. 0. 0. 0. 13. 15. 10. 15. 5. 0. 0. 3. 15. 2. 0. 11. 8. 0. 0. 4. 12. 0. 0. 8. 8. 0. 0. 5. 8. 0. 0. 9. 8. 0. 0. 4. 11. 0. 1. 12. 7. 0. 0. 2. 14. 5. 10. 12. 0. 0. 0. 0. 6. 13. 10. 0. 0. 0.] [ 0. 0. 0. 12. 13. 5. 0. 0. 0. 0. 0. 11. 16. 9. 0. 0. 0. 0. 3. 15. 16. 6. 0. 0. 0. 7. 15. 16. 16. 2. 0. 0. 0. 0. 1. 16. 16. 3. 0. 0. 0. 0. 1. 16. 16. 6. 0. 0. 0. 0. 1. 16. 16. 6. 0. 0. 0. 0. 0. 11. 16. 10. 0. 0.] [ 0. 0. 0. 4. 15. 12. 0. 0. 0. 0. 3. 16. 15. 14. 0. 0. 0. 0. 8. 13. 8. 16. 0. 0. 0. 0. 1. 6. 15. 11. 0. 0. 0. 1. 8. 13. 15. 1. 0. 0. 0. 9. 16. 16. 5. 0. 0. 0. 0. 3. 13. 16. 16. 11. 5. 0. 0. 0. 0. 3. 11. 16. 9. 0.]] [0 1 2] [9 6 6 4 5] [[0 0 0 0 0 0 0 0 0 1] [0 0 0 0 0 0 1 0 0 0] [0 0 0 0 0 0 1 0 0 0] [0 0 0 0 1 0 0 0 0 0] [0 0 0 0 0 1 0 0 0 0]] steps: 0 accuracy: 0.08222222222222222 steps: 1000 accuracy: 0.6666666666666666 steps: 2000 accuracy: 0.7555555555555555 steps: 3000 accuracy: 0.7777777777777778 steps: 4000 accuracy: 0.7777777777777778 steps: 5000 accuracy: 0.7977777777777778 steps: 6000 accuracy: 0.8422222222222222 steps: 7000 accuracy: 0.8266666666666667 steps: 8000 accuracy: 0.8422222222222222 steps: 9000 accuracy: 0.8377777777777777 steps: 10000 accuracy: 0.8555555555555555 steps: 11000 accuracy: 0.8555555555555555 steps: 12000 accuracy: 0.8533333333333334 steps: 13000 accuracy: 0.8533333333333334 steps: 14000 accuracy: 0.8644444444444445 steps: 15000 accuracy: 0.9 steps: 16000 accuracy: 0.9333333333333333 steps: 17000 accuracy: 0.9488888888888889 steps: 18000 accuracy: 0.9488888888888889 steps: 19000 accuracy: 0.9555555555555556 steps: 20000 accuracy: 0.96 steps: 21000 accuracy: 0.9555555555555556 steps: 22000 accuracy: 0.9622222222222222 steps: 23000 accuracy: 0.9511111111111111 steps: 24000 accuracy: 0.9488888888888889 steps: 25000 accuracy: 0.9555555555555556 steps: 26000 accuracy: 0.9666666666666667 steps: 27000 accuracy: 0.9622222222222222 steps: 28000 accuracy: 0.9666666666666667 steps: 29000 
accuracy: 0.9666666666666667 steps: 30000 accuracy: 0.9533333333333334 precision recall f1-score support 0 1.00 1.00 1.00 40 1 0.98 0.85 0.91 53 2 0.98 1.00 0.99 49 3 0.96 0.96 0.96 47 4 0.91 1.00 0.95 42 5 0.98 0.98 0.98 41 6 0.96 0.98 0.97 50 7 1.00 0.93 0.96 44 8 0.84 0.90 0.87 41 9 0.93 0.95 0.94 43 micro avg 0.95 0.95 0.95 450 macro avg 0.95 0.95 0.95 450 weighted avg 0.96 0.95 0.95 450 [[40 0 0 0 0 0 0 0 0 0] [ 0 45 1 0 3 0 0 0 4 0] [ 0 0 49 0 0 0 0 0 0 0] [ 0 0 0 45 0 0 0 0 0 2] [ 0 0 0 0 42 0 0 0 0 0] [ 0 0 0 0 0 40 1 0 0 0] [ 0 0 0 0 0 0 49 0 1 0] [ 0 0 0 1 1 0 0 41 1 0] [ 0 1 0 1 0 0 1 0 37 1] [ 0 0 0 0 0 1 0 0 1 41]]