Polynomial regression

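Simple linear regression looks for the straight line that best fits the data. Many datasets, however, follow a curve that no straight line can capture. Polynomial regression handles this case by adding powers of the original feature (such as x squared) as extra features and then fitting an ordinary linear regression on the expanded feature set. The model is still linear in its coefficients, even though the fitted curve is not a straight line in x.

Consider the following data, generated from a quadratic function plus a little Gaussian noise: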

import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (8,6)

x = np.random.uniform(-3, 3, size = 100)
y = 2 * x ** 2 + 3 * x + 3 + np.random.normal(0, 1, size = 100) # add a little noise

plt.scatter(x, y)

Fitting an ordinary linear regression to this data gives:

from sklearn.linear_model import LinearRegression

X = x.reshape(-1,1)
lin_reg = LinearRegression()
lin_reg.fit(X, y)
y_pred = lin_reg.predict(X)
plt.scatter(x, y)
plt.scatter(x, y_pred, color ='r')
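
To put a number on how well the straight line fits, one option is the R^2 score that LinearRegression.score returns. The sketch below regenerates data of the same form as above; the fixed seed and the names rng and r2_linear are assumptions added here for reproducibility, not part of the original example.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Assumption: a fixed seed so the numbers are repeatable between runs.
rng = np.random.default_rng(0)
x = rng.uniform(-3, 3, size=100)
y = 2 * x ** 2 + 3 * x + 3 + rng.normal(0, 1, size=100)

X = x.reshape(-1, 1)
lin_reg = LinearRegression().fit(X, y)

# R^2 on the training data: 1.0 is a perfect fit, 0.0 is no better than the mean.
r2_linear = lin_reg.score(X, y)
print(f"R^2 of the straight-line fit: {r2_linear:.3f}")
```

A value well below 1 here reflects the quadratic term that the straight line cannot represent.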

The straight line clearly fits the data poorly. To fix this, add the square of X as a second feature:

X2 = np.hstack([X, X**2])
lin_reg2 = LinearRegression()
lin_reg2.fit(X2, y)
y_pred2 = lin_reg2.predict(X2)
plt.scatter(x, y)
plt.scatter(x, y_pred2, color ='r')
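
Since the data was generated from 2 * x ** 2 + 3 * x + 3, one way to sanity-check the fit is to compare the learned coefficients with those values. This is a minimal sketch under the same assumptions as the data above; the fixed seed is added here only for reproducibility.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Assumption: a fixed seed so the numbers are repeatable between runs.
rng = np.random.default_rng(0)
x = rng.uniform(-3, 3, size=100)
y = 2 * x ** 2 + 3 * x + 3 + rng.normal(0, 1, size=100)

X = x.reshape(-1, 1)
X2 = np.hstack([X, X ** 2])          # columns: x, x^2
lin_reg2 = LinearRegression().fit(X2, y)

# coef_[0] multiplies x, coef_[1] multiplies x^2.
print("coefficients:", lin_reg2.coef_)
print("intercept:", lin_reg2.intercept_)
```

With this amount of noise the coefficients should land close to 3 and 2 and the intercept close to 3, though not exactly on them.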

scikit-learn already provides a transformer for this, sklearn.preprocessing.PolynomialFeatures, so we do not have to build the extra features by hand:

from sklearn.preprocessing import PolynomialFeatures
poly = PolynomialFeatures(degree=2) # generate all polynomial features up to degree 2
X2 = poly.fit_transform(X) # the transformer must be fitted before it can transform

# Training
lin_reg = LinearRegression()
lin_reg.fit(X2, y)
y_pred = lin_reg.predict(X2)
plt.scatter(x, y)
plt.scatter(x, y_pred, color ='r')
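
To see what PolynomialFeatures actually generates, it can help to run it on a tiny input. The array X_small below is a made-up example chosen for readability, not data from the article.

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures

# Hypothetical three-sample input with a single feature.
X_small = np.array([[1.0], [2.0], [3.0]])

poly = PolynomialFeatures(degree=2)
X_small_poly = poly.fit_transform(X_small)

# Each row becomes [1, x, x^2].
print(X_small_poly)
# [[1. 1. 1.]
#  [1. 2. 4.]
#  [1. 3. 9.]]
```

The leading column of ones is a constant (bias) feature. LinearRegression fits its own intercept by default, so that column is redundant in this setup; passing include_bias=False to PolynomialFeatures leaves it out.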

You can also chain these steps in a Pipeline, which is more convenient because a single fit or predict call runs every step in order:

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

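poly_reg = Pipeline([
    ("poly", PolynomialFeatures(degree=2)),   # expand x into 1, x, x^2
    ("std_scaler", StandardScaler()),         # standardize the expanded features
    ("lin_reg", LinearRegression())           # fit the linear model
])

poly_reg.fit(X, y) # the pipeline must be fitted before predicting
y_pred = poly_reg.predict(X)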
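
A fitted pipeline takes raw x values directly, so it can also be evaluated on an evenly spaced grid and compared with the function that generated the data. The following is a minimal sketch; the fixed seed, the grid x_new, and the name max_gap are assumptions introduced for illustration.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures, StandardScaler

# Assumption: a fixed seed so the numbers are repeatable between runs.
rng = np.random.default_rng(0)
x = rng.uniform(-3, 3, size=100)
y = 2 * x ** 2 + 3 * x + 3 + rng.normal(0, 1, size=100)
X = x.reshape(-1, 1)

poly_reg = Pipeline([
    ("poly", PolynomialFeatures(degree=2)),
    ("std_scaler", StandardScaler()),
    ("lin_reg", LinearRegression()),
])
poly_reg.fit(X, y)

# Predict on a sorted grid; the pipeline applies every step to the raw input.
x_new = np.linspace(-3, 3, 50)
y_new = poly_reg.predict(x_new.reshape(-1, 1))

# Largest distance between the fitted curve and the noise-free function.
y_true = 2 * x_new ** 2 + 3 * x_new + 3
max_gap = np.max(np.abs(y_new - y_true))
print(f"largest gap to the true curve: {max_gap:.3f}")
```

Because x_new is sorted, plt.plot(x_new, y_new) would draw the fit as a smooth curve instead of the scattered red points used above.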

Reference: Polynomial regression, Cloud+Community, Tencent Cloud. https://cloud.tencent.com/developer/article/1734927