Train Machine Learning algorithms with Python scikit-learn
One of the most widely used Python libraries for machine learning is scikit-learn. It provides implementations of a broad range of algorithms and is relatively simple to use out of the box. Let’s look at an example of logistic regression with the classic ‘iris’ dataset.
# Import libraries
import pandas as pd
import seaborn as sns
import sklearn.metrics  # import the metrics submodule explicitly; it is used for evaluation below
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
Load dataset
# Load sample dataset into a DataFrame
X = sns.load_dataset('iris')
X
| | sepal_length | sepal_width | petal_length | petal_width | species |
---|---|---|---|---|---|
0 | 5.1 | 3.5 | 1.4 | 0.2 | setosa |
1 | 4.9 | 3.0 | 1.4 | 0.2 | setosa |
2 | 4.7 | 3.2 | 1.3 | 0.2 | setosa |
3 | 4.6 | 3.1 | 1.5 | 0.2 | setosa |
4 | 5.0 | 3.6 | 1.4 | 0.2 | setosa |
... | ... | ... | ... | ... | ... |
145 | 6.7 | 3.0 | 5.2 | 2.3 | virginica |
146 | 6.3 | 2.5 | 5.0 | 1.9 | virginica |
147 | 6.5 | 3.0 | 5.2 | 2.0 | virginica |
148 | 6.2 | 3.4 | 5.4 | 2.3 | virginica |
149 | 5.9 | 3.0 | 5.1 | 1.8 | virginica |
150 rows × 5 columns
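`sns.load_dataset()` downloads the sample file the first time it is called, so it needs network access. As a possible offline alternative, the sketch below builds a similar DataFrame from the copy of iris bundled with scikit-learn. The renamed columns and the `df` variable are assumptions made here to mirror the seaborn layout; they are not part of the original walkthrough.

```python
from sklearn.datasets import load_iris

# Load the copy of iris that ships with scikit-learn (no download needed)
iris = load_iris(as_frame=True)

# Rename the columns to match the seaborn naming used in this post
df = iris.frame.rename(columns={
    "sepal length (cm)": "sepal_length",
    "sepal width (cm)": "sepal_width",
    "petal length (cm)": "petal_length",
    "petal width (cm)": "petal_width",
})

# scikit-learn stores the target as integers 0..2; map them to species names
df["species"] = iris.target_names[df.pop("target").to_numpy()]

print(df.head())
```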
# Pop the target column 'species' and store in a distinct Series
y = X.pop('species')
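Note that `pop()` modifies the DataFrame in place: the column is removed from `X` and returned as a Series. A minimal sketch on a small made-up frame (the `toy` data is hypothetical and only illustrates the behaviour):

```python
import pandas as pd

# Hypothetical two-row frame, used only to illustrate DataFrame.pop()
toy = pd.DataFrame({
    "petal_length": [1.4, 5.2],
    "species": ["setosa", "virginica"],
})

target = toy.pop("species")

print(list(toy.columns))  # the 'species' column is gone from the frame
print(target.tolist())    # and is now held in a separate Series
```

After this step, the features and the target can be passed to scikit-learn as two separate objects.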
Split into train and test sets
# Use `train_test_split()` function from sklearn to extract a 20% random test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
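The split above is random, so the test set (and therefore the scores) will change from run to run. If reproducibility or balanced classes matter, one option is to pass `random_state` and `stratify`. The sketch below is a variation, not the call used in this post; it assumes the scikit-learn copy of iris (integer-coded target) to stay self-contained, and the seed value `42` is arbitrary.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Features as a NumPy array, target as integers 0..2
X, y = load_iris(return_X_y=True)

# random_state fixes the shuffle; stratify=y keeps class proportions in both sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(X_train.shape, X_test.shape)
print(np.bincount(y_test))  # number of test samples per class
```

With stratification, each species should appear in the test set in the same proportion as in the full dataset.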
Fit model
# Instantiate a logistic regression model with default parameters
model = LogisticRegression()
# Fit model on training set
model.fit(X_train, y_train)
# Predict on test set
y_pred = model.predict(X_test)
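With default parameters, `LogisticRegression` may emit a convergence warning on unscaled features. A common remedy is to standardize the inputs and raise `max_iter`. The sketch below is one possible variation rather than the model used in this post; the pipeline, the `max_iter=1000` value, and the fixed seed are assumptions chosen for illustration. It also shows `predict_proba()`, which returns class probabilities instead of hard labels.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Self-contained data loading (scikit-learn's bundled iris, integer-coded target)
X, y = load_iris(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)

# Standardize features, then fit logistic regression with a higher iteration cap
pipeline = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
pipeline.fit(X_train, y_train)

y_pred = pipeline.predict(X_test)        # hard class labels
proba = pipeline.predict_proba(X_test)   # one probability per class and sample

print(pipeline[-1].classes_)   # column order of `proba`
print(proba[:3].round(3))
```

Fitting the scaler inside the pipeline means it only sees the training data, which avoids leaking test-set statistics into the model.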
Evaluate model performance
# Classification report
print(sklearn.metrics.classification_report(y_test, y_pred))
| | precision | recall | f1-score | support |
---|---|---|---|---|
setosa | 1.00 | 1.00 | 1.00 | 6 |
versicolor | 0.83 | 1.00 | 0.91 | 10 |
virginica | 1.00 | 0.86 | 0.92 | 14 |
accuracy | | | 0.93 | 30 |
macro avg | 0.94 | 0.95 | 0.94 | 30 |
weighted avg | 0.94 | 0.93 | 0.93 | 30 |
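If the scores are needed programmatically rather than as printed text, `classification_report()` also accepts `output_dict=True`. The sketch below uses four made-up labels (hypothetical, not the predictions from this post) to show the shape of the result.

```python
import pandas as pd
from sklearn.metrics import classification_report

# Hypothetical labels, used only to illustrate the dictionary output
y_true = ["setosa", "versicolor", "virginica", "virginica"]
y_hat = ["setosa", "versicolor", "versicolor", "virginica"]

report = classification_report(y_true, y_hat, output_dict=True)

# Nested dict: one entry per class, plus 'accuracy', 'macro avg', 'weighted avg'
print(report["versicolor"]["precision"])
print(report["accuracy"])

# Optional: view the same numbers as a DataFrame
print(pd.DataFrame(report).T)
```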
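# Plot confusion matrix
# (plot_confusion_matrix() was removed in scikit-learn 1.2; ConfusionMatrixDisplay replaces it)
sklearn.metrics.ConfusionMatrixDisplay.from_estimator(model, X_test, y_test)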
<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay at 0x7fb3d0117510>
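The classification report and the confusion matrix describe the same predictions. As a rough cross-check, the sketch below assumes the confusion matrix that appears consistent with the report above (two virginica samples predicted as versicolor, everything else correct) and recomputes the per-class metrics with NumPy. The matrix `cm` is inferred from the printed report, not read from the plot.

```python
import numpy as np

# Assumed confusion matrix (rows = true class, columns = predicted class),
# inferred from the classification report above
labels = ["setosa", "versicolor", "virginica"]
cm = np.array([
    [6, 0, 0],
    [0, 10, 0],
    [0, 2, 12],
])

true_positives = np.diag(cm)
precision = true_positives / cm.sum(axis=0)  # correct / all predicted as the class
recall = true_positives / cm.sum(axis=1)     # correct / all actually in the class
f1 = 2 * precision * recall / (precision + recall)
accuracy = true_positives.sum() / cm.sum()

for name, p, r, f in zip(labels, precision, recall, f1):
    print(f"{name:<12} precision={p:.2f} recall={r:.2f} f1={f:.2f}")
print(f"accuracy={accuracy:.2f}")
```

Under this assumption, the lower precision for versicolor and the lower recall for virginica both come from the same two misclassified samples.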