K-Means Clustering
K-Means Clustering is an unsupervised learning algorithm that groups data points into k clusters, where each observation belongs to the cluster with the nearest mean (the cluster centroid).
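The "nearest mean" rule can be illustrated with a short NumPy sketch. The function name `assign_to_nearest_mean`, the sample points, and the two means are hypothetical values chosen for illustration; the means are simply given here rather than learned.

```python
import numpy as np

def assign_to_nearest_mean(points, means):
    """Return, for each point, the index of the closest mean (Euclidean distance)."""
    # distances[i, j] is the Euclidean distance from points[i] to means[j].
    distances = np.linalg.norm(points[:, np.newaxis, :] - means[np.newaxis, :, :], axis=2)
    # Each point belongs to the cluster whose mean is closest.
    return distances.argmin(axis=1)

# Hypothetical data: two points near the origin, two near (10, 10).
points = np.array([[0.0, 0.0], [0.5, 1.0], [9.0, 9.5], [10.0, 10.0]])
means = np.array([[0.0, 0.5], [10.0, 10.0]])
print(assign_to_nearest_mean(points, means))  # [0 0 1 1]
```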
Mathematical Model
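The K-means algorithm clusters data by separating samples into k groups, minimizing a criterion known as the inertia, or within-cluster sum of squares: the sum, over all samples, of the squared Euclidean distance between each sample and the centroid of the cluster it belongs to.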
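As a minimal sketch of this criterion, the helper below (a hypothetical name, not part of scikit-learn) computes the inertia directly from the samples, their cluster labels, and the centroids, then compares it with the `inertia_` attribute of a fitted scikit-learn `KMeans` model. The dataset size, number of centers, and seeds are arbitrary assumptions for the example.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

def inertia(X, labels, centroids):
    """Within-cluster sum of squares: squared distance from each sample to its own centroid, summed."""
    # centroids[labels] picks, for every sample, the centroid of its assigned cluster.
    return float(((X - centroids[labels]) ** 2).sum())

# Arbitrary example data: 300 samples drawn around 3 centers.
X, _ = make_blobs(n_samples=300, centers=3, random_state=0)
model = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)

manual = inertia(X, model.labels_, model.cluster_centers_)
# The two values should agree up to floating-point rounding.
print(manual, model.inertia_)
```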
Model Training Process
Roughly, the model training process is as follows:
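1. Select initial cluster centroids.
2. Assign each data point to the nearest centroid, as measured by Euclidean distance.
3. Compute new cluster centroids as the mean of the data points assigned to each cluster.
4. Repeat steps 2-3 until the centroids stop changing (or move less than a small tolerance). The result is a local minimum of the within-cluster sum of squares, which is not guaranteed to be the best possible clustering.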
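The four steps above can be sketched in plain NumPy. This is an illustrative version of the basic iteration (often called Lloyd's algorithm), not scikit-learn's implementation: it assumes random initialization from the data points rather than the k-means++ scheme scikit-learn uses by default, and the function name `k_means`, its parameters, and the toy data are hypothetical.

```python
import numpy as np

def k_means(X, k, max_iter=100, tol=1e-6, seed=0):
    """Minimal k-means sketch: returns (centroids, labels)."""
    rng = np.random.default_rng(seed)
    # Step 1: pick k distinct samples as the initial centroids (random init, not k-means++).
    centroids = X[rng.choice(len(X), size=k, replace=False)].copy()
    for _ in range(max_iter):
        # Step 2: assign each sample to its nearest centroid by Euclidean distance.
        distances = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        labels = distances.argmin(axis=1)
        # Step 3: recompute each centroid as the mean of its assigned samples;
        # an empty cluster keeps its previous centroid.
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        # Step 4: stop once the centroids have (almost) stopped moving.
        shift = np.linalg.norm(new_centroids - centroids)
        centroids = new_centroids
        if shift < tol:
            break
    # Final assignment so the labels match the returned centroids.
    labels = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2).argmin(axis=1)
    return centroids, labels

# Hypothetical toy data: two well-separated groups of three points.
X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
              [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]])
centroids, labels = k_means(X, k=2)
print(centroids)
print(labels)
```

Because the initialization is random, different seeds can converge to different local minima on less clearly separated data; scikit-learn's `n_init` parameter addresses this by running several initializations and keeping the one with the lowest inertia.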
Python Example
```python
"""
k_means_clustering_with_scikit_learn.py
Clusters data points and displays the results.
For more information on the k-means algorithm, see:
https://scikit-learn.org/stable/modules/clustering.html#k-means
"""

# Import needed libraries.
import numpy as np
import matplotlib.pyplot as plotlib
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Set number of samples and clusters.
number_of_samples = 2000
number_of_clusters = 5

# Set the random number state used both to generate the data points
# and to initialize the cluster centers.
random_state = 150

# Set standard deviations for data points creation.
# Three values are given, so make_blobs generates three blobs.
standard_deviation = [1.0, 2.5, 0.5]

# Create data points for clustering.
# For details on the make_blobs function, see:
# https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_blobs.html
X_varied, y_varied = make_blobs(n_samples=number_of_samples,
                                cluster_std=standard_deviation,
                                random_state=random_state)

# Create clusters.
# For details on the KMeans function, see:
# https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn.cluster.KMeans
y_pred = KMeans(n_clusters=number_of_clusters,
                random_state=random_state).fit_predict(X_varied)

# Plot the clustered data points, colored by predicted cluster.
plotlib.scatter(X_varied[:, 0], X_varied[:, 1], c=y_pred)
plotlib.title("Data in " + str(number_of_clusters) + " Clusters")

# Display the plot.
plotlib.show()
```
Results are below: