diff --git a/docs/api/stats/sql.md b/docs/api/stats/sql.md
new file mode 100644
index 0000000000..fe7c0e90ef
--- /dev/null
+++ b/docs/api/stats/sql.md
@@ -0,0 +1,31 @@
+## Overview
+
+Sedona's stats module provides Scala and Python functions for conducting geospatial
+statistical analysis on dataframes with spatial columns.
+The stats module is built on top of the core module and provides a set of functions
+that can be used to perform spatial analysis on these dataframes. The stats module
+is designed to be used with the core module and the viz module to provide a
+complete set of geospatial analysis tools.
+
+## Using DBSCAN
+
+The DBSCAN function is provided at `org.apache.sedona.stats.clustering.DBSCAN.dbscan` in Scala/Java and `sedona.stats.clustering.dbscan.dbscan` in Python.
+
+The function annotates a dataframe with a cluster label for each data record using the DBSCAN algorithm.
+The dataframe should contain at least one `GeometryType` column. Rows must be unique. If exactly one
+geometry column is present it will be used automatically. If more than one is present, the one named
+'geometry' will be used; if none of them is named 'geometry', the column name must be provided.
+The cluster label is added in a new column named 'cluster', alongside a boolean 'isCore' column.
+
+### Parameters
+
+Names in parentheses are the Python parameter names.
+
+- dataframe - dataframe to cluster. Must contain at least one GeometryType column
+- epsilon - maximum distance between two records for them to be considered neighbors (the epsilon parameter of the DBSCAN algorithm)
+- minPts (min_pts) - minimum number of neighbors, the record itself included, required for a record to be a core point
+- geometry - name of the geometry column. May be omitted when the column can be determined automatically
+- includeOutliers (include_outliers) - whether to include outliers in the output. Default is true
+- useSpheroid (use_spheroid) - whether to use a spheroidal instead of a cartesian distance calculation. Default is false
+
+The output is the input DataFrame with the cluster label added to each row. Outliers will have a cluster value of -1 if included.
diff --git a/docs/tutorial/sql.md b/docs/tutorial/sql.md
index 7cf36cab40..ae71b10e9c 100644
--- a/docs/tutorial/sql.md
+++ b/docs/tutorial/sql.md
@@ -739,6 +739,60 @@ The coordinates of polygons have been changed. The output will be like this:
```
+## Cluster with DBSCAN
+
+Sedona provides an implementation of the [DBSCAN](https://en.wikipedia.org/wiki/Dbscan) algorithm to cluster spatial data.
+
+The algorithm is available as a Scala and Python function called on a spatial dataframe. The returned dataframe has an additional column added containing the unique identifier of the cluster that record is a member of and a boolean column indicating if the record is a core point.
+
+The first parameter is the dataframe; the next two are the epsilon and minPts parameters of the DBSCAN algorithm.
+
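+=== "Scala"
+
+    ```scala
+    import org.apache.sedona.stats.clustering.DBSCAN.dbscan
+
+    dbscan(df, 0.1, 5).show()
+    ```
+
+=== "Java"
+
+    ```java
+    import org.apache.sedona.stats.clustering.DBSCAN;
+    // Scala default arguments are not available from Java: geometry, includeOutliers and useSpheroid are passed explicitly
+    DBSCAN.dbscan(df, 0.1, 5, null, true, false).show();
+    ```
+
+=== "Python"
+
+    ```python
+    from sedona.stats.clustering.dbscan import dbscan
+
+    dbscan(df, 0.1, 5).show()
+    ```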
+
+The output will look like this:
+
+```
++----------------+---+------+-------+
+|        geometry| id|isCore|cluster|
++----------------+---+------+-------+
+|   POINT (2.5 4)|  3| false|      1|
+|     POINT (3 4)|  2| false|      1|
+|     POINT (3 5)|  5| false|      1|
+|     POINT (1 3)|  9|  true|      0|
+| POINT (2.5 4.5)|  7|  true|      1|
+|     POINT (1 2)|  1|  true|      0|
+| POINT (1.5 2.5)|  4|  true|      0|
+| POINT (1.2 2.5)|  8|  true|      0|
+|   POINT (1 2.5)| 11|  true|      0|
+|     POINT (1 5)| 10| false|     -1|
+|     POINT (5 6)| 12| false|     -1|
+|POINT (12.8 4.5)|  6| false|     -1|
+|     POINT (4 3)| 13| false|     -1|
++----------------+---+------+-------+
+```
+
## Run spatial queries
After creating a Geometry type column, you are able to run spatial queries.
diff --git a/mkdocs.yml b/mkdocs.yml
index 149742f6ca..68ce12073c 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -83,6 +83,8 @@ nav:
- Parameter: api/sql/Parameter.md
- RDD (core):
- Scala/Java doc: api/java-api.md
+ - Stats:
+ - DataFrame: api/stats/sql.md
- Viz:
- DataFrame/SQL: api/viz/sql.md
- RDD: api/viz/java-api.md
diff --git a/pom.xml b/pom.xml
index f4bca2f27d..d512b2cccc 100644
--- a/pom.xml
+++ b/pom.xml
@@ -85,6 +85,7 @@
3.3.0
3.3
2.17.2
+ 0.8.3-spark3.4
1.19.0
1.7.36
@@ -394,6 +395,10 @@
true
+
+ Spark Packages
+ https://repos.spark-packages.org/
+
@@ -578,6 +583,8 @@
${scala.compat.version}
${spark.version}
${scala.version}
+ ${log4j.version}
+ ${graphframe.version}
@@ -686,6 +693,7 @@
3.0.3
3.0
2.17.2
+ 0.8.1-spark3.0
true
@@ -703,6 +711,7 @@
3.1.2
3.1
2.17.2
+ 0.8.2-spark3.1
true
@@ -720,6 +729,7 @@
3.2.0
3.2
2.17.2
+ 0.8.2-spark3.2
true
@@ -738,6 +748,7 @@
3.3.0
3.3
2.17.2
+ 0.8.3-spark3.4
@@ -752,6 +763,7 @@
3.4.0
3.4
2.19.0
+ 0.8.3-spark3.4
true
@@ -768,6 +780,7 @@
3.5.0
3.5
2.20.0
+ 0.8.3-spark3.5
true
diff --git a/python/Pipfile b/python/Pipfile
index 3440e38da2..6c7e142b38 100644
--- a/python/Pipfile
+++ b/python/Pipfile
@@ -10,6 +10,8 @@ jupyter="*"
mkdocs="*"
pytest-cov = "*"
+scikit-learn = "*"
+
[packages]
pandas="<=1.5.3"
numpy="<2"
diff --git a/python/sedona/stats/__init__.py b/python/sedona/stats/__init__.py
new file mode 100644
index 0000000000..a67d5ea255
--- /dev/null
+++ b/python/sedona/stats/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/python/sedona/stats/clustering/__init__.py b/python/sedona/stats/clustering/__init__.py
new file mode 100644
index 0000000000..d2399abf8a
--- /dev/null
+++ b/python/sedona/stats/clustering/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""The clustering module contains spark based implementations of popular geospatial clustering algorithms.
+
+These implementations are designed to scale to larger datasets and support various geometric feature types.
+"""
diff --git a/python/sedona/stats/clustering/dbscan.py b/python/sedona/stats/clustering/dbscan.py
new file mode 100644
index 0000000000..bb816e61aa
--- /dev/null
+++ b/python/sedona/stats/clustering/dbscan.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""DBSCAN is a popular clustering algorithm for spatial data.
+
+It identifies groups of data where enough records are close enough to each other. This implementation leverages spark,
+sedona and graphframes to support large scale datasets and various, heterogeneous geometric feature types.
+"""
+from typing import Optional
+
+from pyspark.sql import DataFrame, SparkSession
+
+ID_COLUMN_NAME = "__id"  # NOTE(review): not referenced in this module; same value as ID_COLUMN in the Scala DBSCAN object
+DEFAULT_MAX_SAMPLE_SIZE = 1000000  # 1 million; NOTE(review): not referenced in this module - confirm it is still needed
+
+
+def dbscan(
+ dataframe: DataFrame,
+ epsilon: float,
+ min_pts: int,
+ geometry: Optional[str] = None,
+ include_outliers: bool = True,
+ use_spheroid=False,
+):
+ """Annotates a dataframe with a cluster label for each data record using the DBSCAN algorithm.
+
+ The dataframe should contain at least one GeometryType column. Rows must be unique. If one geometry column is
+ present it will be used automatically. If two are present, the one named 'geometry' will be used. If more than one
+ are present and neither is named 'geometry', the column name must be provided.
+
+ Args:
+ dataframe: spark dataframe containing the geometries
+ epsilon: minimum distance parameter of DBSCAN algorithm
+ min_pts: minimum number of points parameter of DBSCAN algorithm
+ geometry: name of the geometry column
+ include_outliers: whether to return outlier points. If True, outliers are returned with a cluster value of -1.
+ Default is False
+ use_spheroid: whether to use a cartesian or spheroidal distance calculation. Default is false
+
+ Returns:
+ A PySpark DataFrame containing the cluster label for each row
+ """
+ sedona = SparkSession.getActiveSession()
+
+ result_df = sedona._jvm.org.apache.sedona.stats.clustering.DBSCAN.dbscan(
+ dataframe._jdf,
+ float(epsilon),
+ min_pts,
+ geometry,
+ include_outliers,
+ use_spheroid,
+ )
+
+ return DataFrame(result_df, sedona)
diff --git a/python/tests/stats/__init__.py b/python/tests/stats/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/python/tests/stats/test_dbscan.py b/python/tests/stats/test_dbscan.py
new file mode 100644
index 0000000000..60cc8a991e
--- /dev/null
+++ b/python/tests/stats/test_dbscan.py
@@ -0,0 +1,247 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pyspark.sql.functions as f
+import pytest
+
+from itertools import product
+from sedona.sql.st_constructors import ST_MakePoint
+from sedona.sql.st_functions import ST_Buffer
+from sklearn.cluster import DBSCAN as sklearnDBSCAN
+from sedona.stats.clustering.dbscan import dbscan
+
+from tests.test_base import TestBase
+
+
+class TestDBScan(TestBase):
+
+ @pytest.fixture
+ def sample_data(self):
+ return [
+ {"id": 1, "x": 1.0, "y": 2.0},
+ {"id": 2, "x": 3.0, "y": 4.0},
+ {"id": 3, "x": 2.5, "y": 4.0},
+ {"id": 4, "x": 1.5, "y": 2.5},
+ {"id": 5, "x": 3.0, "y": 5.0},
+ {"id": 6, "x": 12.8, "y": 4.5},
+ {"id": 7, "x": 2.5, "y": 4.5},
+ {"id": 8, "x": 1.2, "y": 2.5},
+ {"id": 9, "x": 1.0, "y": 3.0},
+ {"id": 10, "x": 1.0, "y": 5.0},
+ {"id": 11, "x": 1.0, "y": 2.5},
+ {"id": 12, "x": 5.0, "y": 6.0},
+ {"id": 13, "x": 4.0, "y": 3.0},
+ ]
+
+ @pytest.fixture
+ def sample_dataframe(self, sample_data):
+ return (
+ self.spark.createDataFrame(sample_data)
+ .select(ST_MakePoint("x", "y").alias("arealandmark"), "id")
+ .repartition(9)
+ )
+
+ def get_expected_result(self, input_data, epsilon, min_pts, include_outliers=True):
+ labels = (
+ sklearnDBSCAN(eps=epsilon, min_samples=min_pts)
+ .fit([[datum["x"], datum["y"]] for datum in input_data])
+ .labels_
+ )
+ expected = [(x[0] + 1, x[1]) for x in list(enumerate(labels))]
+ clusters = [x for x in set(labels) if (x != -1 or include_outliers)]
+ cluster_members = {
+ frozenset([y[0] for y in expected if y[1] == x]) for x in clusters
+ }
+ return cluster_members
+
+ def get_actual_results(
+ self,
+ input_data,
+ epsilon,
+ min_pts,
+ geometry=None,
+ id=None,
+ include_outliers=True,
+ ):
+ result = dbscan(
+ input_data, epsilon, min_pts, geometry, include_outliers=include_outliers
+ )
+
+ result.show()
+
+ id = id or "id"
+ clusters_members = [
+ (x[id], x.cluster)
+ for x in result.collect()
+ if x.cluster != -1 or include_outliers
+ ]
+
+ result.unpersist()
+
+ clusters = {
+ frozenset([y[0] for y in clusters_members if y[1] == x])
+ for x in set([y[1] for y in clusters_members])
+ }
+
+ return clusters
+
+ @pytest.mark.parametrize("epsilon", [0.6, 0.7, 0.8])
+ @pytest.mark.parametrize("min_pts", [3, 4, 5])
+ def test_dbscan_valid_parameters(
+ self, sample_data, sample_dataframe, epsilon, min_pts
+ ):
+ # repeated broadcast joins with this small data size use a lot of RAM on broadcast references
+ self.spark.conf.set("sedona.join.autoBroadcastJoinThreshold", -1)
+ self.spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
+
+ assert self.get_expected_result(
+ sample_data, epsilon, min_pts
+ ) == self.get_actual_results(sample_dataframe, epsilon, min_pts)
+
+ def test_dbscan_valid_parameters_default_column_name(
+ self, sample_data, sample_dataframe
+ ):
+ # repeated broadcast joins with this small data size use a lot of RAM on broadcast references
+ self.spark.conf.set("sedona.join.autoBroadcastJoinThreshold", -1)
+ self.spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
+
+ df = sample_dataframe.select(
+ "id", f.col("arealandmark").alias("geometryFieldName")
+ )
+ epsilon = 0.6
+ min_pts = 4
+
+ assert self.get_expected_result(
+ sample_data, epsilon, min_pts
+ ) == self.get_actual_results(df, epsilon, min_pts)
+
+ def test_dbscan_valid_parameters_polygons(self, sample_data, sample_dataframe):
+ # repeated broadcast joins with this small data size use a lot of RAM on broadcast references
+ self.spark.conf.set("sedona.join.autoBroadcastJoinThreshold", -1)
+ self.spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
+
+ df = sample_dataframe.select(
+ "id", ST_Buffer(f.col("arealandmark"), 0.000001).alias("geometryFieldName")
+ )
+ epsilon = 0.6
+ min_pts = 4
+
+ assert self.get_expected_result(
+ sample_data, epsilon, min_pts
+ ) == self.get_actual_results(df, epsilon, min_pts)
+
+ def test_dbscan_supports_other_distance_function(self, sample_dataframe):
+ # repeated broadcast joins with this small data size use a lot of RAM on broadcast references
+ self.spark.conf.set("sedona.join.autoBroadcastJoinThreshold", -1)
+ self.spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
+
+ df = sample_dataframe.select(
+ "id", ST_Buffer(f.col("arealandmark"), 0.000001).alias("geometryFieldName")
+ )
+ epsilon = 0.6
+ min_pts = 4
+
+ dbscan(
+ df,
+ epsilon,
+ min_pts,
+ "geometryFieldName",
+ use_spheroid=True,
+ )
+
+ def test_dbscan_invalid_epsilon(self, sample_dataframe):
+ with pytest.raises(Exception):
+ dbscan(sample_dataframe, -0.1, 5, "arealandmark")
+
+ def test_dbscan_invalid_min_pts(self, sample_dataframe):
+ with pytest.raises(Exception):
+ dbscan(sample_dataframe, 0.1, -5, "arealandmark")
+
+ def test_dbscan_invalid_geometry_column(self, sample_dataframe):
+ with pytest.raises(Exception):
+ dbscan(sample_dataframe, 0.1, 5, "invalid_column")
+
+ def test_return_empty_df_when_no_clusters(self, sample_data, sample_dataframe):
+ # repeated broadcast joins with this small data size use a lot of RAM on broadcast references
+ self.spark.conf.set("sedona.join.autoBroadcastJoinThreshold", -1)
+ self.spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
+
+ epsilon = 0.1
+ min_pts = 10000
+
+ assert (
+ dbscan(
+ sample_dataframe,
+ epsilon,
+ min_pts,
+ "arealandmark",
+ include_outliers=False,
+ ).count()
+ == 0
+ )
+ # picked some coefficient we know yields clusters and thus hit the happy case
+ assert (
+ dbscan(
+ sample_dataframe,
+ epsilon,
+ min_pts,
+ "arealandmark",
+ include_outliers=False,
+ ).schema
+ == dbscan(sample_dataframe, 0.6, 3, "arealandmark").schema
+ )
+
+ def test_dbscan_doesnt_duplicate_border_points_in_two_clusters(self):
+ # repeated broadcast joins with this small data size use a lot of RAM on broadcast references
+ self.spark.conf.set("sedona.join.autoBroadcastJoinThreshold", -1)
+ self.spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
+
+ input_df = self.spark.createDataFrame(
+ [
+ {"id": 10, "x": 1.0, "y": 1.8},
+ {"id": 11, "x": 1.0, "y": 1.9},
+ {"id": 12, "x": 1.0, "y": 2.0},
+ {"id": 13, "x": 1.0, "y": 2.1},
+ {"id": 14, "x": 2.0, "y": 2.0},
+ {"id": 15, "x": 3.0, "y": 1.9},
+ {"id": 16, "x": 3.0, "y": 2.0},
+ {"id": 17, "x": 3.0, "y": 2.1},
+ {"id": 18, "x": 3.0, "y": 2.2},
+ ]
+ ).select(ST_MakePoint("x", "y").alias("geometry"), "id")
+
+ # make sure no id occurs more than once
+ output_df = dbscan(input_df, 1.0, 4)
+
+ assert output_df.count() == 9
+ assert output_df.select("cluster").distinct().count() == 2
+
+ def test_return_outliers_false_doesnt_return_outliers(
+ self, sample_data, sample_dataframe
+ ):
+ # repeated broadcast joins with this small data size use a lot of RAM on broadcast references
+ self.spark.conf.set("sedona.join.autoBroadcastJoinThreshold", -1)
+ self.spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
+
+ epsilon = 0.6
+ min_pts = 3
+
+ assert self.get_expected_result(
+ sample_data, epsilon, min_pts, include_outliers=False
+ ) == self.get_actual_results(
+ sample_dataframe, epsilon, min_pts, include_outliers=False
+ )
diff --git a/python/tests/test_base.py b/python/tests/test_base.py
index 7742d31464..8bd4b6461f 100644
--- a/python/tests/test_base.py
+++ b/python/tests/test_base.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
+from tempfile import mkdtemp
from sedona.spark import *
from sedona.utils.decorators import classproperty
@@ -27,6 +28,8 @@ def spark(self):
spark = SedonaContext.create(
SedonaContext.builder().master("local[*]").getOrCreate()
)
+ spark.sparkContext.setCheckpointDir(mkdtemp())
+
setattr(self, "__spark", spark)
return getattr(self, "__spark")
diff --git a/spark/common/pom.xml b/spark/common/pom.xml
index 948a35efe9..774b85c24f 100644
--- a/spark/common/pom.xml
+++ b/spark/common/pom.xml
@@ -157,6 +157,11 @@
+
+ graphframes
+ graphframes
+ ${graphframe.version}-s_${scala.compat.version}
+
org.scala-lang
scala-library
diff --git a/spark/common/src/main/scala/org/apache/sedona/stats/Util.scala b/spark/common/src/main/scala/org/apache/sedona/stats/Util.scala
new file mode 100644
index 0000000000..cdfe5fca23
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/sedona/stats/Util.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.stats
+
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+
+private[stats] object Util {
+  /**
+   * Picks the geometry column of a dataframe: the only GeometryType column if there is exactly one,
+   * otherwise the one named "geometry".
+   * @throws IllegalArgumentException if there is no GeometryType column or the choice is ambiguous
+   */
+  def getGeometryColumnName(dataframe: DataFrame): String =
+    dataframe.schema.fields.filter(_.dataType == GeometryUDT).map(_.name) match {
+      case Array() =>
+        throw new IllegalArgumentException(
+          "No GeometryType column found. Provide a dataframe containing a geometry column.")
+      case Array(onlyName) => onlyName
+      case names if names.contains("geometry") => "geometry"
+      case _ =>
+        throw new IllegalArgumentException(
+          "Multiple GeometryType columns found. Provide the column name as an argument.")
+    }
+}
diff --git a/spark/common/src/main/scala/org/apache/sedona/stats/clustering/DBSCAN.scala b/spark/common/src/main/scala/org/apache/sedona/stats/clustering/DBSCAN.scala
new file mode 100644
index 0000000000..5bc691c2dd
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/sedona/stats/clustering/DBSCAN.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.stats.clustering
+
+import org.apache.sedona.stats.Util.getGeometryColumnName
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.sedona_sql.expressions.st_functions.{ST_Distance, ST_DistanceSpheroid}
+import org.apache.spark.sql.{Column, DataFrame, SparkSession}
+import org.graphframes.GraphFrame
+
+object DBSCAN {
+
+  // Column that temporarily holds a caller-supplied "id" column while "id" is used internally
+  private val ID_COLUMN = "__id"
+
+  /**
+   * Annotates a dataframe with a cluster label for each data record using the DBSCAN algorithm.
+   * The dataframe should contain at least one GeometryType column. Rows must be unique. If exactly
+   * one geometry column is present it will be used automatically. If more than one is present,
+   * the one named 'geometry' will be used; if none of them is named 'geometry', the column name
+   * must be provided. A 'cluster' column and a boolean 'isCore' column are added to the output.
+   * Intermediate results are checkpointed, so a checkpoint directory must be set on the
+   * SparkContext before calling this function.
+   *
+   * @param dataframe
+   *   dataframe to cluster. Must contain at least one GeometryType column
+   * @param epsilon
+   *   maximum distance between two records for them to be considered neighbors
+   * @param minPts
+   *   minimum number of neighbors (the record itself included) for a record to be a core point
+   * @param geometry
+   *   name of the geometry column, or null to detect it automatically
+   * @param includeOutliers
+   *   whether to include outliers in the output. Default is true
+   * @param useSpheroid
+   *   whether to use a spheroidal instead of a cartesian distance calculation. Default is false
+   * @return
+   *   The input DataFrame with the cluster label added to each row. Outliers will have a cluster
+   *   value of -1 if included.
+   */
+  def dbscan(
+      dataframe: DataFrame,
+      epsilon: Double,
+      minPts: Int,
+      geometry: String = null,
+      includeOutliers: Boolean = true,
+      useSpheroid: Boolean = false): DataFrame = {
+
+    // geometry stays a nullable String (not an Option) so Java and Python callers can pass null
+    val geometryCol = Option(geometry).getOrElse(getGeometryColumnName(dataframe))
+
+    validateInputs(dataframe, epsilon, minPts, geometryCol)
+
+    val distanceFunction: (Column, Column) => Column =
+      if (useSpheroid) ST_DistanceSpheroid else ST_Distance
+
+    // "id" is used internally (and by GraphFrame) as the vertex key. A content hash of the row
+    // serves as that key, which is why rows must be unique. A pre-existing "id" column is
+    // parked under ID_COLUMN and restored before returning.
+    val hasIdColumn = dataframe.columns.contains("id")
+    val renamedDf = if (hasIdColumn) dataframe.withColumnRenamed("id", ID_COLUMN) else dataframe
+    val idDataframe = renamedDf.withColumn("id", sha2(to_json(struct("*")), 256))
+
+    // Self-join every record with all records within epsilon (itself included) to count its
+    // neighbors. The checkpoint materializes the result and truncates the lineage before it is
+    // reused several times below.
+    val isCorePointsDF = idDataframe
+      .alias("left")
+      .join(
+        idDataframe.alias("right"),
+        distanceFunction(col(s"left.$geometryCol"), col(s"right.$geometryCol")) <= epsilon)
+      .groupBy(col("left.id"))
+      .agg(
+        first(struct("left.*")).alias("leftContents"),
+        count(col("right.id")).alias("neighbors_count"),
+        collect_list(col("right.id")).alias("neighbors"))
+      .withColumn("isCore", col("neighbors_count") >= lit(minPts))
+      .select("leftContents.*", "neighbors", "isCore")
+      .checkpoint()
+
+    val corePointsDF = isCorePointsDF.filter(col("isCore"))
+    val borderPointsDF = isCorePointsDF.filter(!col("isCore"))
+
+    // Edges exist only between core points: connected core points form one cluster.
+    val coreEdgesDf = corePointsDF
+      .select(col("id").alias("src"), explode(col("neighbors")).alias("dst"))
+      .alias("left")
+      .join(corePointsDF.alias("right"), col("left.dst") === col("right.id"))
+      .select(col("left.src"), col("right.id").alias("dst"))
+
+    val connectedComponentsDF = GraphFrame(corePointsDF, coreEdgesDf).connectedComponents.run
+
+    // A non-core point next to core points of several clusters joins exactly one of them (the
+    // smallest component id), so no record is duplicated. Non-core points without a core
+    // neighbor drop out of this inner join and become outliers below.
+    val borderComponentsDF = borderPointsDF
+      .select(struct("*").alias("leftContent"), explode(col("neighbors")).alias("neighbor"))
+      .join(connectedComponentsDF.alias("right"), col("neighbor") === col("right.id"))
+      .groupBy(col("leftContent.id"))
+      .agg(
+        first(col("leftContent")).alias("leftContent"),
+        min(col("right.component")).alias("component"))
+      .select("leftContent.*", "component")
+
+    val clusteredPointsDf = borderComponentsDF.union(connectedComponentsDF)
+
+    // Everything that is neither a core point nor attached to one is an outlier (cluster -1).
+    val outliersDf = idDataframe
+      .join(clusteredPointsDf, Seq("id"), "left_anti")
+      .withColumn("isCore", lit(false))
+      .withColumn("component", lit(-1))
+      .withColumn("neighbors", array().cast("array<string>"))
+
+    val completedDf = (
+      if (includeOutliers) clusteredPointsDf.unionByName(outliersDf)
+      else clusteredPointsDf
+    ).withColumnRenamed("component", "cluster")
+
+    // Drop the helper columns and give the caller's own "id" column back, if there was one.
+    val withoutHelpers = completedDf.drop("neighbors", "id")
+    if (hasIdColumn) withoutHelpers.withColumnRenamed(ID_COLUMN, "id") else withoutHelpers
+  }
+
+  /** Fails fast with an IllegalArgumentException (via require) on invalid arguments. */
+  private def validateInputs(
+      dataframe: DataFrame,
+      epsilon: Double,
+      minPts: Int,
+      geometry: String): Unit = {
+    require(epsilon > 0, "epsilon must be greater than 0")
+    require(minPts > 0, "minPts must be greater than 0")
+    require(dataframe.columns.contains(geometry), "geometry column not found in dataframe")
+    require(
+      dataframe.schema(geometry).dataType == GeometryUDT,
+      "geometry column must be of type GeometryType")
+  }
+}