Processing math: 100%


This page describes clustering algorithms in MLlib. The guide for clustering in the RDD-based API also has relevant information about these algorithms.

Table of Contents


k-means is one of the most commonly used clustering algorithms that clusters the data points into a predefined number of clusters. The MLlib implementation includes a parallelized variant of the k-means++ method called kmeans||.

KMeans is implemented as an Estimator and generates a KMeansModel as the base model.

Input Columns

Param name Type(s) Default Description
featuresCol Vector "features" Feature vector

Output Columns

Param name Type(s) Default Description
predictionCol Int "prediction" Predicted cluster center


Refer to the Scala API docs for more details.

import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.evaluation.ClusteringEvaluator

// Loads data.
val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

// Trains a k-means model.
val kmeans = new KMeans().setK(2).setSeed(1L)
val model = kmeans.fit(dataset)

// Make predictions
val predictions = model.transform(dataset)

// Evaluate clustering by computing Silhouette score
val evaluator = new ClusteringEvaluator()

val silhouette = evaluator.evaluate(predictions)
println(s"Silhouette with squared euclidean distance = $silhouette")

// Shows the result.
println("Cluster Centers: ")
model.clusterCenters.foreach(println)

Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala" in the Spark repo.

Refer to the Java API docs for more details.

import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.ml.evaluation.ClusteringEvaluator;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Loads data.
Dataset<Row> dataset = spark.read().format("libsvm").load("data/mllib/sample_kmeans_data.txt");

// Trains a k-means model.
KMeans kmeans = new KMeans().setK(2).setSeed(1L);
KMeansModel model = kmeans.fit(dataset);

// Make predictions
Dataset<Row> predictions = model.transform(dataset);

// Evaluate clustering by computing Silhouette score
ClusteringEvaluator evaluator = new ClusteringEvaluator();

double silhouette = evaluator.evaluate(predictions);
System.out.println("Silhouette with squared euclidean distance = " + silhouette);

// Shows the result.
Vector[] centers = model.clusterCenters();
System.out.println("Cluster Centers: ");
for (Vector center: centers) {
  System.out.println(center);
}

Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java" in the Spark repo.

Refer to the Python API docs for more details.

from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import ClusteringEvaluator

# Loads data.
dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

# Trains a k-means model.
kmeans = KMeans().setK(2).setSeed(1)
model = kmeans.fit(dataset)

# Make predictions
predictions = model.transform(dataset)

# Evaluate clustering by computing Silhouette score
evaluator = ClusteringEvaluator()

silhouette = evaluator.evaluate(predictions)
print("Silhouette with squared euclidean distance = " + str(silhouette))

# Shows the result.
centers = model.clusterCenters()
print("Cluster Centers: ")
for center in centers:
    print(center)

Find full example code at "examples/src/main/python/ml/kmeans_example.py" in the Spark repo.

Refer to the R API docs for more details.

# Fit a k-means model with spark.kmeans
t <- as.data.frame(Titanic)
training <- createDataFrame(t)
df_list <- randomSplit(training, c(7,3), 2)
kmeansDF <- df_list[[1]]
kmeansTestDF <- df_list[[2]]
kmeansModel <- spark.kmeans(kmeansDF, ~ Class + Sex + Age + Freq, k = 3)

# Model summary
summary(kmeansModel)

# Get fitted result from the k-means model
head(fitted(kmeansModel))

# Prediction
kmeansPredictions <- predict(kmeansModel, kmeansTestDF)
head(kmeansPredictions)

Find full example code at "examples/src/main/r/ml/kmeans.R" in the Spark repo.

Latent Dirichlet allocation (LDA)

LDA is implemented as an Estimator that supports both EMLDAOptimizer and OnlineLDAOptimizer, and generates a LDAModel as the base model. Expert users may cast a LDAModel generated by EMLDAOptimizer to a DistributedLDAModel if needed.


Refer to the Scala API docs for more details.


import org.apache.spark.ml.clustering.LDA

// Loads data.
val dataset = spark.read.format("libsvm")
  .load("data/mllib/sample_lda_libsvm_data.txt")

// Trains a LDA model.
val lda = new LDA().setK(10).setMaxIter(10)
val model = lda.fit(dataset)

val ll = model.logLikelihood(dataset)
val lp = model.logPerplexity(dataset)
println(s"The lower bound on the log likelihood of the entire corpus: $ll")
println(s"The upper bound on perplexity: $lp")

// Describe topics.
val topics = model.describeTopics(3)
println("The topics described by their top-weighted terms:")
topics.show(false)

// Shows the result.
val transformed = model.transform(dataset)
transformed.show(false)

Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala" in the Spark repo.

Refer to the Java API docs for more details.

import org.apache.spark.ml.clustering.LDA;
import org.apache.spark.ml.clustering.LDAModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

// Loads data.
Dataset<Row> dataset = spark.read().format("libsvm")
  .load("data/mllib/sample_lda_libsvm_data.txt");

// Trains a LDA model.
LDA lda = new LDA().setK(10).setMaxIter(10);
LDAModel model = lda.fit(dataset);

double ll = model.logLikelihood(dataset);
double lp = model.logPerplexity(dataset);
System.out.println("The lower bound on the log likelihood of the entire corpus: " + ll);
System.out.println("The upper bound on perplexity: " + lp);

// Describe topics.
Dataset<Row> topics = model.describeTopics(3);
System.out.println("The topics described by their top-weighted terms:");
topics.show(false);

// Shows the result.
Dataset<Row> transformed = model.transform(dataset);
transformed.show(false);

Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java" in the Spark repo.

Refer to the Python API docs for more details.

from pyspark.ml.clustering import LDA

# Loads data.
dataset = spark.read.format("libsvm").load("data/mllib/sample_lda_libsvm_data.txt")

# Trains a LDA model.
lda = LDA(k=10, maxIter=10)
model = lda.fit(dataset)

ll = model.logLikelihood(dataset)
lp = model.logPerplexity(dataset)
print("The lower bound on the log likelihood of the entire corpus: " + str(ll))
print("The upper bound on perplexity: " + str(lp))

# Describe topics.
topics = model.describeTopics(3)
print("The topics described by their top-weighted terms:")
topics.show(truncate=False)

# Shows the result
transformed = model.transform(dataset)
transformed.show(truncate=False)

Find full example code at "examples/src/main/python/ml/lda_example.py" in the Spark repo.

Refer to the R API docs for more details.

# Load training data
df <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm")
training <- df
test <- df

# Fit a latent dirichlet allocation model with spark.lda
model <- spark.lda(training, k = 10, maxIter = 10)

# Model summary
summary(model)

# Posterior probabilities
posterior <- spark.posterior(model, test)
head(posterior)

# The log perplexity of the LDA model
logPerplexity <- spark.perplexity(model, test)
print(paste0("The upper bound on perplexity: ", logPerplexity))

Find full example code at "examples/src/main/r/ml/lda.R" in the Spark repo.

Bisecting k-means

Bisecting k-means is a kind of hierarchical clustering using a divisive (or “top-down”) approach: all observations start in one cluster, and splits are performed recursively as one moves down the hierarchy.

Bisecting K-means can often be much faster than regular K-means, but it will generally produce a different clustering.

BisectingKMeans is implemented as an Estimator and generates a BisectingKMeansModel as the base model.


Refer to the Scala API docs for more details.

import org.apache.spark.ml.clustering.BisectingKMeans
import org.apache.spark.ml.evaluation.ClusteringEvaluator

// Loads data.
val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

// Trains a bisecting k-means model.
val bkm = new BisectingKMeans().setK(2).setSeed(1)
val model = bkm.fit(dataset)

// Make predictions
val predictions = model.transform(dataset)

// Evaluate clustering by computing Silhouette score
val evaluator = new ClusteringEvaluator()

val silhouette = evaluator.evaluate(predictions)
println(s"Silhouette with squared euclidean distance = $silhouette")

// Shows the result.
println("Cluster Centers: ")
val centers = model.clusterCenters
centers.foreach(println)

Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/BisectingKMeansExample.scala" in the Spark repo.

Refer to the Java API docs for more details.

import org.apache.spark.ml.clustering.BisectingKMeans;
import org.apache.spark.ml.clustering.BisectingKMeansModel;
import org.apache.spark.ml.evaluation.ClusteringEvaluator;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Loads data.
Dataset<Row> dataset = spark.read().format("libsvm").load("data/mllib/sample_kmeans_data.txt");

// Trains a bisecting k-means model.
BisectingKMeans bkm = new BisectingKMeans().setK(2).setSeed(1);
BisectingKMeansModel model = bkm.fit(dataset);

// Make predictions
Dataset<Row> predictions = model.transform(dataset);

// Evaluate clustering by computing Silhouette score
ClusteringEvaluator evaluator = new ClusteringEvaluator();

double silhouette = evaluator.evaluate(predictions);
System.out.println("Silhouette with squared euclidean distance = " + silhouette);

// Shows the result.
System.out.println("Cluster Centers: ");
Vector[] centers = model.clusterCenters();
for (Vector center : centers) {
  System.out.println(center);
}

Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java" in the Spark repo.

Refer to the Python API docs for more details.

from pyspark.ml.clustering import BisectingKMeans
from pyspark.ml.evaluation import ClusteringEvaluator

# Loads data.
dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

# Trains a bisecting k-means model.
bkm = BisectingKMeans().setK(2).setSeed(1)
model = bkm.fit(dataset)

# Make predictions
predictions = model.transform(dataset)

# Evaluate clustering by computing Silhouette score
evaluator = ClusteringEvaluator()

silhouette = evaluator.evaluate(predictions)
print("Silhouette with squared euclidean distance = " + str(silhouette))

# Shows the result.
print("Cluster Centers: ")
centers = model.clusterCenters()
for center in centers:
    print(center)

Find full example code at "examples/src/main/python/ml/bisecting_k_means_example.py" in the Spark repo.

Refer to the R API docs for more details.

t <- as.data.frame(Titanic)
training <- createDataFrame(t)

# Fit bisecting k-means model with four centers
model <- spark.bisectingKmeans(training, Class ~ Survived, k = 4)

# get fitted result from a bisecting k-means model
fitted.model <- fitted(model, "centers")

# Model summary
head(summary(fitted.model))

# fitted values on training data
fitted <- predict(model, training)
head(select(fitted, "Class", "prediction"))

Find full example code at "examples/src/main/r/ml/bisectingKmeans.R" in the Spark repo.

Gaussian Mixture Model (GMM)

A Gaussian Mixture Model represents a composite distribution whereby points are drawn from one of k Gaussian sub-distributions, each with its own probability. The implementation uses the expectation-maximization algorithm to induce the maximum-likelihood model given a set of samples.

GaussianMixture is implemented as an Estimator and generates a GaussianMixtureModel as the base model.

Input Columns

Param name Type(s) Default Description
featuresCol Vector "features" Feature vector

Output Columns

Param name Type(s) Default Description
predictionCol Int "prediction" Predicted cluster center
probabilityCol Vector "probability" Probability of each cluster


Refer to the Scala API docs for more details.


import org.apache.spark.ml.clustering.GaussianMixture

// Loads data
val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

// Trains Gaussian Mixture Model
val gmm = new GaussianMixture()
  .setK(2)
val model = gmm.fit(dataset)

// output parameters of mixture model model
for (i <- 0 until model.getK) {
  println(s"Gaussian $i:\nweight=${model.weights(i)}\n" +
      s"mu=${model.gaussians(i).mean}\nsigma=\n${model.gaussians(i).cov}\n")
}

Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/GaussianMixtureExample.scala" in the Spark repo.

Refer to the Java API docs for more details.

import org.apache.spark.ml.clustering.GaussianMixture;
import org.apache.spark.ml.clustering.GaussianMixtureModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Loads data
Dataset<Row> dataset = spark.read().format("libsvm").load("data/mllib/sample_kmeans_data.txt");

// Trains a GaussianMixture model
GaussianMixture gmm = new GaussianMixture()
  .setK(2);
GaussianMixtureModel model = gmm.fit(dataset);

// Output the parameters of the mixture model
for (int i = 0; i < model.getK(); i++) {
  System.out.printf("Gaussian %d:\nweight=%f\nmu=%s\nsigma=\n%s\n\n",
          i, model.weights()[i], model.gaussians()[i].mean(), model.gaussians()[i].cov());
}

Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java" in the Spark repo.

Refer to the Python API docs for more details.

from pyspark.ml.clustering import GaussianMixture

# loads data
dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

gmm = GaussianMixture().setK(2).setSeed(538009335)
model = gmm.fit(dataset)

print("Gaussians shown as a DataFrame: ")
model.gaussiansDF.show(truncate=False)

Find full example code at "examples/src/main/python/ml/gaussian_mixture_example.py" in the Spark repo.

Refer to the R API docs for more details.

# Load training data
df <- read.df("data/mllib/sample_kmeans_data.txt", source = "libsvm")
training <- df
test <- df

# Fit a gaussian mixture clustering model with spark.gaussianMixture
model <- spark.gaussianMixture(training, ~ features, k = 2)

# Model summary
summary(model)

# Prediction
predictions <- predict(model, test)
head(predictions)

Find full example code at "examples/src/main/r/ml/gaussianMixture.R" in the Spark repo.

Power Iteration Clustering (PIC)

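Power Iteration Clustering (PIC) is a scalable graph clustering algorithm developed by Lin and Cohen. From the abstract: PIC finds a very low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise similarity matrix of the data.

spark.ml's PowerIterationClustering implementation takes the following parameters:

k: the number of clusters to create
initMode: param for the initialization algorithm
maxIter: param for maximum number of iterations
srcCol: param for the name of the input column for source vertex IDs
dstCol: name of the input column for destination vertex IDs
weightCol: param for weight column name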


Refer to the Scala API docs for more details.


import org.apache.spark.ml.clustering.PowerIterationClustering

val dataset = spark.createDataFrame(Seq(
  (0L, 1L, 1.0),
  (0L, 2L, 1.0),
  (1L, 2L, 1.0),
  (3L, 4L, 1.0),
  (4L, 0L, 0.1)
)).toDF("src", "dst", "weight")

val model = new PowerIterationClustering().
  setK(2).
  setMaxIter(20).
  setInitMode("degree").
  setWeightCol("weight")

val prediction = model.assignClusters(dataset).select("id", "cluster")

// Shows the cluster assignment
prediction.show(false)

Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/PowerIterationClusteringExample.scala" in the Spark repo.

Refer to the Java API docs for more details.

import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.clustering.PowerIterationClustering;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

List<Row> data = Arrays.asList(
  RowFactory.create(0L, 1L, 1.0),
  RowFactory.create(0L, 2L, 1.0),
  RowFactory.create(1L, 2L, 1.0),
  RowFactory.create(3L, 4L, 1.0),
  RowFactory.create(4L, 0L, 0.1)
);

StructType schema = new StructType(new StructField[]{
  new StructField("src", DataTypes.LongType, false, Metadata.empty()),
  new StructField("dst", DataTypes.LongType, false, Metadata.empty()),
  new StructField("weight", DataTypes.DoubleType, false, Metadata.empty())
});

Dataset<Row> df = spark.createDataFrame(data, schema);

PowerIterationClustering model = new PowerIterationClustering()
  .setK(2)
  .setMaxIter(10)
  .setInitMode("degree")
  .setWeightCol("weight");

Dataset<Row> result = model.assignClusters(df);
result.show(false);

Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaPowerIterationClusteringExample.java" in the Spark repo.

Refer to the Python API docs for more details.

from pyspark.ml.clustering import PowerIterationClustering

df = spark.createDataFrame([
    (0, 1, 1.0),
    (0, 2, 1.0),
    (1, 2, 1.0),
    (3, 4, 1.0),
    (4, 0, 0.1)
], ["src", "dst", "weight"])

pic = PowerIterationClustering(k=2, maxIter=20, initMode="degree", weightCol="weight")

# Shows the cluster assignment
pic.assignClusters(df).show()

Find full example code at "examples/src/main/python/ml/power_iteration_clustering_example.py" in the Spark repo.

Refer to the R API docs for more details.

df <- createDataFrame(list(list(0L, 1L, 1.0), list(0L, 2L, 1.0),
                           list(1L, 2L, 1.0), list(3L, 4L, 1.0),
                           list(4L, 0L, 0.1)),
                      schema = c("src", "dst", "weight"))
# assign clusters
clusters <- spark.assignClusters(df, k = 2L, maxIter = 20L,
                                 initMode = "degree", weightCol = "weight")

showDF(arrange(clusters, clusters$id))

Find full example code at "examples/src/main/r/ml/powerIterationClustering.R" in the Spark repo.