Simple Linear Regression in Python

“If you can’t explain it simply, you don’t understand it well enough.”

Simple linear regression is a statistical method that allows us to summarise and study relationships between two continuous (quantitative) variables. I hope today to prove to myself that I understand and can demonstrate linear regression by coding it from scratch in Python without using Scikit Learn.

Import the libraries

I’ll start by importing the libraries — numpy and matplotlib. I’m using Anaconda, so I’ll use %matplotlib inline to display the charts in the notebook.
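A minimal sketch of this setup (assuming a Jupyter notebook, since %matplotlib inline is an IPython magic):

```python
import numpy as np
import matplotlib.pyplot as plt

# display charts inline in the notebook
%matplotlib inline
```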

Create the data

I’m going to be creating my own data set for this example, which I’ll intentionally set to have a linear relationship.

I’ll use numpy to create two arrays, X and y. The linspace function returns evenly spaced numbers over an interval that I specify: in the parentheses, enter the low number, the high number, and the number of values to generate. For this example, I want a simple range of values for X.

I want to add some noise to the data, so I’ll create a variable using numpy’s random.uniform function. This returns samples from a uniform distribution; in the parentheses, feed in the lower boundary, the upper boundary, and the output shape.
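Something like the following; the exact range, slope, and noise bounds are my own choices here, since the original values aren’t shown:

```python
# 50 evenly spaced values for X (low, high, number of values) - the numbers are illustrative
X = np.linspace(0, 10, 50)

# uniform noise: lower boundary, upper boundary, output shape
noise = np.random.uniform(-1.5, 1.5, 50)

# y has a linear relationship with X, plus the noise
y = X + noise
```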

Plot the data

Use matplotlib to plot a basic scatter chart of X and y.
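In matplotlib that is just:

```python
plt.scatter(X, y)
plt.xlabel('X')
plt.ylabel('y')
plt.show()
```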

Line of best fit

The line of best fit is a straight line that will go through the centre of the data points on our scatter plot. The closer the points are to the line, the stronger the correlation between the two variables. This correlation can be positive or negative.

I can clearly see that there is a strong, positive correlation between the two variables: as X increases, so does y. I want to quantify the strength of this correlation mathematically.

The equation of a straight line is y = mx + b, where m is the slope of the line and b is the y-intercept.

I already have the X and y values, so now I need to calculate m and b. The formulas for these can be written as:
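Written out in terms of sums and means (these match the descriptions below):

m = (sum(X·y) − mean(y)·sum(X)) / (sum(X²) − mean(X)·sum(X))

b = (mean(y)·sum(X²) − mean(X)·sum(X·y)) / (sum(X²) − mean(X)·sum(X))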

Maths is not my strongest suit, so for me, these formulas were pretty intimidating at first. Sidenote — for anyone interested in going in-depth into the maths behind Machine Learning algorithms, I would highly recommend Lazy Programmer’s courses on Udemy.

Back to the formulas: the denominators in both are the same, the sum of X squared minus the mean of X multiplied by the sum of X. Rather than calculate this twice, I'll create a denominator variable. The most efficient way of calculating the sum of X squared in numpy is to take the dot product of X with itself.
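As a sketch in numpy, using the X array created above:

```python
# sum of X squared is the dot product of X with itself
denominator = X.dot(X) - X.mean() * X.sum()
```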

Now that I have the denominator, I'll write the numerators for m and b.

The numerator for m is the sum of X multiplied by y (the dot product of X and y), minus the mean of y multiplied by the sum of X.

The numerator of b is the mean of y multiplied by the sum of X squared (the dot product of X with itself), minus the mean of X multiplied by the sum of X multiplied by y (the dot product of X and y).
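Putting each numerator over the shared denominator gives m and b; a sketch:

```python
# numerator of m: sum of X*y minus mean(y) * sum(X)
m = (X.dot(y) - y.mean() * X.sum()) / denominator

# numerator of b: mean(y) * sum(X squared) minus mean(X) * sum(X*y)
b = (y.mean() * X.dot(X) - X.mean() * X.dot(y)) / denominator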

I can now plug these into the linear equation to calculate the predicted y values (the line of best fit).
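For example:

```python
# predicted y values along the line of best fit
y_pred = m * X + b
```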

Using matplotlib, I can now plot the line of best fit.
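A minimal version (the styling choices are mine):

```python
plt.scatter(X, y)
plt.plot(X, y_pred, color='red')  # the line of best fit
plt.show()
```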

Calculating R squared

Now that I have fitted the prediction line, I want to calculate how close the data is to this line by using the coefficient of determination — R squared.

SSres: I'll start by calculating the sum of the squared residuals, where each residual is the vertical distance between an actual data point on the y axis and its corresponding predicted value on the regression line. The reason for squaring these distances is that if I have one value five units above the line (+5) and another five units below the line (-5), they would cancel each other out and give the impression that the data is closer to the line than it actually is. Squaring the distances ensures they are all positive values.

SStot: then I'll calculate the distance between each actual data point on the y axis and the mean of y, again squaring the result, and sum these up.

So the formula for R squared is 1 - SSres / SStot:

I'll start by creating variables for the two sets of differences, the actual y values minus the predicted values and the actual y values minus the mean of y, and then take the dot product of each with itself to get the sums of squares.

I can then print the R squared value.
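A sketch of those steps, using the y_pred values from above:

```python
# differences between the actual values and the prediction line (for SSres)
res = y - y_pred

# differences between the actual values and the mean of y (for SStot)
tot = y - y.mean()

# R squared = 1 - SSres / SStot; the dot product of each vector with itself gives the sum of squares
r_squared = 1 - res.dot(res) / tot.dot(tot)
print(r_squared)
```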

So, 91.5% of the variation in y can be explained using X.

I hope that you enjoyed this, and maybe even found it useful. Huge credit goes to my three favourite ML tutors: Jose Portilla, Lazy Programmer, and Sentdex.