
K-Means Clustering

K-Means Clustering is an unsupervised learning algorithm that groups data points into k clusters, in which each observation belongs to the cluster with the nearest mean (the cluster's centroid). Because it is unsupervised, it works on the data points alone and needs no labels.

Mathematical Model

The K-means algorithm clusters data by separating samples into k groups, choosing the groups so as to minimize a criterion known as the inertia, or within-cluster sum of squares.
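Written out in the form used by the scikit-learn documentation, the inertia for samples x_1, ..., x_n and a set C of k centroids μ_j is the sum of squared Euclidean distances from each sample to its nearest centroid:

```latex
\text{inertia} = \sum_{i=1}^{n} \min_{\mu_j \in C} \lVert x_i - \mu_j \rVert^2
```

Each centroid μ_j is the mean of the samples currently assigned to cluster j.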

Model Training Process

Roughly, the model training process is as follows:

  1. Select initial cluster centroids.

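  2. Assign each data point to the nearest centroid, as measured by Euclidean distance.

  3. Recompute each cluster centroid as the mean of the data points assigned to it.

  4. Repeat steps 2-3 until the assignments stop changing (or an iteration limit is reached). This reaches a local minimum of the within-cluster sum of squares, not necessarily the best possible one, so in practice the algorithm is often run several times from different initial centroids.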
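The four steps above can be sketched in plain NumPy. This is a minimal illustration, not scikit-learn's implementation: the helper name k_means, picking k random data points as the initial centroids, leaving a centroid in place when its cluster becomes empty, and the np.allclose convergence test are all assumptions made for this example (scikit-learn, for instance, defaults to k-means++ initialization).

```python
import numpy as np


def k_means(X, k, n_iter=100, seed=0):
    """Hypothetical minimal k-means sketch; not scikit-learn's implementation."""
    rng = np.random.default_rng(seed)
    # Step 1 (assumption): pick k distinct data points at random as initial centroids.
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    for _ in range(n_iter):
        # Step 2: assign each point to its nearest centroid by Euclidean distance.
        distances = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        labels = distances.argmin(axis=1)
        # Step 3: move each centroid to the mean of its assigned points;
        # keep the old centroid if no points were assigned (assumption).
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        # Step 4: stop once the centroids no longer move.
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    # Inertia: sum of squared distances from each point to its cluster's centroid.
    inertia = float(((X - centroids[labels]) ** 2).sum())
    return labels, centroids, inertia


# Usage: two well-separated pairs of points.
X = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
labels, centroids, inertia = k_means(X, k=2)
print(labels)
print(inertia)  # 1.0
```

Here each point lies 0.5 from its cluster mean, so the inertia is 4 × 0.25 = 1.0; which cluster gets label 0 depends on the random initialization.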

Python Example


"""
k_means_clustering_with_scikit_learn.py
clusters data points and displays the results
for more information on the k-means algorithm, see:
https://scikit-learn.org/stable/modules/clustering.html#k-means
"""

# Import needed libraries.
import matplotlib.pyplot as plotlib
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Set number of samples and clusters.
number_of_samples = 2000
number_of_clusters = 5

# Seed used for both data generation and centroid initialization,
# so that every run produces the same result.
random_state = 150

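# Standard deviation of each generated blob. With no centers argument,
# make_blobs creates 3 blobs (one per value here), so asking KMeans for
# number_of_clusters = 5 clusters means at least one blob is split.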
standard_deviation = [1.0, 2.5, 0.5]

# Create data points for clustering.
# For details on the make_blobs function, see:
# https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_blobs.html
X_varied, y_varied = make_blobs(n_samples=number_of_samples,
                                cluster_std=standard_deviation,
                                random_state=random_state)

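# Create clusters.
# For details on the KMeans class, see:
# https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn.cluster.KMeans
# n_init is the number of runs from different initial centroids; the run
# with the lowest inertia is kept. Setting it explicitly avoids the
# FutureWarning that scikit-learn 1.2 and 1.3 raise about its changing default.
kmeans = KMeans(n_clusters=number_of_clusters,
                n_init=10,
                random_state=random_state)
y_pred = kmeans.fit_predict(X_varied)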

# Plot the clustered data points.
plotlib.scatter(X_varied[:, 0], X_varied[:, 1], c=y_pred)
plotlib.title("Data in " + str(number_of_clusters) + " Clusters")

# Display the plot.
plotlib.show()

Results are below: