diff --git a/docs/api/stats/sql.md b/docs/api/stats/sql.md
new file mode 100644
index 0000000000..fe7c0e90ef
--- /dev/null
+++ b/docs/api/stats/sql.md
@@ -0,0 +1,29 @@
+## Overview
+
+Sedona's stats module provides Scala and Python functions for conducting geospatial
+statistical analysis on dataframes with spatial columns.
+It is built on top of the core module and is designed to be used alongside the core
+and viz modules to provide a complete set of geospatial analysis tools.
+
+## Using DBSCAN
+
+The DBSCAN function is provided at `org.apache.sedona.stats.clustering.DBSCAN.dbscan` in Scala/Java and `sedona.stats.clustering.dbscan.dbscan` in Python.
+
+The function annotates a dataframe with a cluster label for each data record using the DBSCAN algorithm.
+The dataframe should contain at least one `GeometryType` column. Rows must be unique. If one
+geometry column is present it will be used automatically. If two are present, the one named
+'geometry' will be used. If more than one is present and none is named 'geometry', the
+column name must be provided. The new column will be named 'cluster'.
+
+### Parameters
+
+Names in parentheses are the Python parameter names.
+
+- dataframe - dataframe to cluster. Must contain at least one GeometryType column
+- epsilon - minimum distance parameter of the DBSCAN algorithm
+- minPts (min_pts) - minimum number of points parameter of the DBSCAN algorithm
+- geometry - name of the geometry column (optional)
+- includeOutliers (include_outliers) - whether to include outliers in the output. Default is true
+- useSpheroid (use_spheroid) - whether to use a cartesian or spheroidal distance calculation. Default is false
+
+The output is the input DataFrame with the cluster label added to each row. Outliers will have a cluster value of -1 if included.
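A quick illustration of the documented parameters, with every argument spelled out. This is a sketch rather than shipped documentation; it assumes an active Sedona-enabled SparkSession and a dataframe `df` whose geometry column is named `geom` (both names are illustrative):

```python
from sedona.stats.clustering.dbscan import dbscan

clustered = dbscan(
    df,                     # dataframe to cluster; must hold a GeometryType column
    epsilon=0.5,            # neighborhood radius
    min_pts=4,              # neighbors required (self included) for a core point
    geometry="geom",        # needed here only because the column is not named "geometry"
    include_outliers=True,  # outliers are kept and labeled with cluster -1
    use_spheroid=False,     # cartesian rather than spheroidal distances
)
clustered.show()
```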
+ +=== "Scala" + + ```scala + import org.apache.sedona.stats.DBSCAN.dbscan + + dbscan(df, 0.1, 5).show() + ``` + +=== "Java" + + ```java + import org.apache.sedona.stats.DBSCAN; + + DBSCAN.dbscan(df, 0.1, 5).show(); + ``` + +=== "Python" + + ```python + from sedona.stats.dbscan import dbscan + + dbscan(df, 0.1, 5).show() + ``` + +The output will look like this: + +``` ++----------------+---+------+-------+ +| geometry| id|isCore|cluster| ++----------------+---+------+-------+ +| POINT (2.5 4)| 3| false| 1| +| POINT (3 4)| 2| false| 1| +| POINT (3 5)| 5| false| 1| +| POINT (1 3)| 9| true| 0| +| POINT (2.5 4.5)| 7| true| 1| +| POINT (1 2)| 1| true| 0| +| POINT (1.5 2.5)| 4| true| 0| +| POINT (1.2 2.5)| 8| true| 0| +| POINT (1 2.5)| 11| true| 0| +| POINT (1 5)| 10| false| -1| +| POINT (5 6)| 12| false| -1| +|POINT (12.8 4.5)| 6| false| -1| +| POINT (4 3)| 13| false| -1| ++----------------+---+------+-------+ +``` + ## Run spatial queries After creating a Geometry type column, you are able to run spatial queries. diff --git a/mkdocs.yml b/mkdocs.yml index 149742f6ca..68ce12073c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -83,6 +83,8 @@ nav: - Parameter: api/sql/Parameter.md - RDD (core): - Scala/Java doc: api/java-api.md + - Stats: + - DataFrame: api/stats/sql.md - Viz: - DataFrame/SQL: api/viz/sql.md - RDD: api/viz/java-api.md diff --git a/pom.xml b/pom.xml index f4bca2f27d..d512b2cccc 100644 --- a/pom.xml +++ b/pom.xml @@ -85,6 +85,7 @@ 3.3.0 3.3 2.17.2 + 0.8.3-spark3.4 1.19.0 1.7.36 @@ -394,6 +395,10 @@ true + + Spark Packages + https://repos.spark-packages.org/ + @@ -578,6 +583,8 @@ ${scala.compat.version} ${spark.version} ${scala.version} + ${log4j.version} + ${graphframe.version} @@ -686,6 +693,7 @@ 3.0.3 3.0 2.17.2 + 0.8.1-spark3.0 true @@ -703,6 +711,7 @@ 3.1.2 3.1 2.17.2 + 0.8.2-spark3.1 true @@ -720,6 +729,7 @@ 3.2.0 3.2 2.17.2 + 0.8.2-spark3.2 true @@ -738,6 +748,7 @@ 3.3.0 3.3 2.17.2 + 0.8.3-spark3.4 @@ -752,6 +763,7 @@ 3.4.0 3.4 2.19.0 + 0.8.3-spark3.4 true @@ -768,6 +780,7 @@ 3.5.0 3.5 2.20.0 + 0.8.3-spark3.5 true diff --git a/python/Pipfile b/python/Pipfile index 3440e38da2..6c7e142b38 100644 --- a/python/Pipfile +++ b/python/Pipfile @@ -10,6 +10,8 @@ jupyter="*" mkdocs="*" pytest-cov = "*" +scikit-learn = "*" + [packages] pandas="<=1.5.3" numpy="<2" diff --git a/python/sedona/stats/__init__.py b/python/sedona/stats/__init__.py new file mode 100644 index 0000000000..a67d5ea255 --- /dev/null +++ b/python/sedona/stats/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/mkdocs.yml b/mkdocs.yml
index 149742f6ca..68ce12073c 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -83,6 +83,8 @@ nav:
       - Parameter: api/sql/Parameter.md
       - RDD (core):
           - Scala/Java doc: api/java-api.md
+      - Stats:
+          - DataFrame: api/stats/sql.md
       - Viz:
          - DataFrame/SQL: api/viz/sql.md
          - RDD: api/viz/java-api.md
diff --git a/pom.xml b/pom.xml
index f4bca2f27d..d512b2cccc 100644
--- a/pom.xml
+++ b/pom.xml
@@ -85,6 +85,7 @@
         <spark.version>3.3.0</spark.version>
         <spark.compat.version>3.3</spark.compat.version>
         <log4j.version>2.17.2</log4j.version>
+        <graphframe.version>0.8.3-spark3.4</graphframe.version>
         <jts.version>1.19.0</jts.version>
         <slf4j.version>1.7.36</slf4j.version>
@@ -394,6 +395,10 @@
                 <enabled>true</enabled>
             </snapshots>
         </repository>
+        <repository>
+            <name>Spark Packages</name>
+            <url>https://repos.spark-packages.org/</url>
+        </repository>
     </repositories>
@@ -578,6 +583,8 @@
                         <scala.compat.version>${scala.compat.version}</scala.compat.version>
                         <spark.version>${spark.version}</spark.version>
                         <scala.version>${scala.version}</scala.version>
+                        <log4j.version>${log4j.version}</log4j.version>
+                        <graphframe.version>${graphframe.version}</graphframe.version>
@@ -686,6 +693,7 @@
                 <spark.version>3.0.3</spark.version>
                 <spark.compat.version>3.0</spark.compat.version>
                 <log4j.version>2.17.2</log4j.version>
+                <graphframe.version>0.8.1-spark3.0</graphframe.version>
                 <skip.deploy.common.modules>true</skip.deploy.common.modules>
@@ -703,6 +711,7 @@
                 <spark.version>3.1.2</spark.version>
                 <spark.compat.version>3.1</spark.compat.version>
                 <log4j.version>2.17.2</log4j.version>
+                <graphframe.version>0.8.2-spark3.1</graphframe.version>
                 <skip.deploy.common.modules>true</skip.deploy.common.modules>
@@ -720,6 +729,7 @@
                 <spark.version>3.2.0</spark.version>
                 <spark.compat.version>3.2</spark.compat.version>
                 <log4j.version>2.17.2</log4j.version>
+                <graphframe.version>0.8.2-spark3.2</graphframe.version>
                 <skip.deploy.common.modules>true</skip.deploy.common.modules>
@@ -738,6 +748,7 @@
                 <spark.version>3.3.0</spark.version>
                 <spark.compat.version>3.3</spark.compat.version>
                 <log4j.version>2.17.2</log4j.version>
+                <graphframe.version>0.8.3-spark3.4</graphframe.version>
             </properties>
@@ -752,6 +763,7 @@
                 <spark.version>3.4.0</spark.version>
                 <spark.compat.version>3.4</spark.compat.version>
                 <log4j.version>2.19.0</log4j.version>
+                <graphframe.version>0.8.3-spark3.4</graphframe.version>
                 <skip.deploy.common.modules>true</skip.deploy.common.modules>
@@ -768,6 +780,7 @@
                 <spark.version>3.5.0</spark.version>
                 <spark.compat.version>3.5</spark.compat.version>
                 <log4j.version>2.20.0</log4j.version>
+                <graphframe.version>0.8.3-spark3.5</graphframe.version>
                 <skip.deploy.common.modules>true</skip.deploy.common.modules>
diff --git a/python/Pipfile b/python/Pipfile
index 3440e38da2..6c7e142b38 100644
--- a/python/Pipfile
+++ b/python/Pipfile
@@ -10,6 +10,8 @@ jupyter="*"
 mkdocs="*"
 pytest-cov = "*"
 
+scikit-learn = "*"
+
 [packages]
 pandas="<=1.5.3"
 numpy="<2"
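Because GraphFrames is published to the Spark Packages repository rather than Maven Central (hence the new `<repository>` entry above), user applications that call the DBSCAN function outside this build also need the jar on their classpath. A hedged sketch of one way to do that from PySpark; the coordinates follow the `${graphframe.version}-s_${scala.compat.version}` pattern used in the pom and must be matched to your Spark and Scala versions:

```python
from sedona.spark import SedonaContext

# Pull graphframes from the spark-packages repository at session startup.
config = (
    SedonaContext.builder()
    .config("spark.jars.packages", "graphframes:graphframes:0.8.3-spark3.4-s_2.12")
    .config("spark.jars.repositories", "https://repos.spark-packages.org")
    .getOrCreate()
)
sedona = SedonaContext.create(config)
```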
diff --git a/python/sedona/stats/__init__.py b/python/sedona/stats/__init__.py
new file mode 100644
index 0000000000..a67d5ea255
--- /dev/null
+++ b/python/sedona/stats/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/python/sedona/stats/clustering/__init__.py b/python/sedona/stats/clustering/__init__.py
new file mode 100644
index 0000000000..d2399abf8a
--- /dev/null
+++ b/python/sedona/stats/clustering/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""The clustering module contains Spark-based implementations of popular geospatial clustering algorithms.
+
+These implementations are designed to scale to large datasets and support various geometric feature types.
+"""
diff --git a/python/sedona/stats/clustering/dbscan.py b/python/sedona/stats/clustering/dbscan.py
new file mode 100644
index 0000000000..bb816e61aa
--- /dev/null
+++ b/python/sedona/stats/clustering/dbscan.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""DBSCAN is a popular clustering algorithm for spatial data.
+
+It identifies groups of data where enough records are close enough to each other. This implementation leverages Spark,
+Sedona and GraphFrames to support large-scale datasets and heterogeneous geometric feature types.
+"""
+from typing import Optional
+
+from pyspark.sql import DataFrame, SparkSession
+
+ID_COLUMN_NAME = "__id"
+DEFAULT_MAX_SAMPLE_SIZE = 1000000  # 1 million
+
+
+def dbscan(
+    dataframe: DataFrame,
+    epsilon: float,
+    min_pts: int,
+    geometry: Optional[str] = None,
+    include_outliers: bool = True,
+    use_spheroid: bool = False,
+):
+    """Annotates a dataframe with a cluster label for each data record using the DBSCAN algorithm.
+
+    The dataframe should contain at least one GeometryType column. Rows must be unique. If one geometry column is
+    present it will be used automatically. If two are present, the one named 'geometry' will be used. If more than one
+    is present and none is named 'geometry', the column name must be provided.
+
+    Args:
+        dataframe: spark dataframe containing the geometries
+        epsilon: minimum distance parameter of the DBSCAN algorithm
+        min_pts: minimum number of points parameter of the DBSCAN algorithm
+        geometry: name of the geometry column
+        include_outliers: whether to return outlier points. If True, outliers are returned with a cluster value of -1.
+            Default is True
+        use_spheroid: whether to use a cartesian or spheroidal distance calculation. Default is False
+
+    Returns:
+        A PySpark DataFrame containing the cluster label for each row
+    """
+    sedona = SparkSession.getActiveSession()
+
+    result_df = sedona._jvm.org.apache.sedona.stats.clustering.DBSCAN.dbscan(
+        dataframe._jdf,
+        float(epsilon),
+        min_pts,
+        geometry,
+        include_outliers,
+        use_spheroid,
+    )
+
+    return DataFrame(result_df, sedona)
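Worth noting about the wrapper above: it resolves the JVM through `SparkSession.getActiveSession()`, so a Sedona-enabled session must already exist when `dbscan` is called; there is no session argument. A minimal sketch, assuming GraphFrames is on the classpath (see the build changes above):

```python
from tempfile import mkdtemp

from sedona.spark import SedonaContext
from sedona.stats.clustering.dbscan import dbscan

# SedonaContext.create registers Sedona's types and functions and leaves the
# session active, which is what lets the wrapper reach the JVM-side DBSCAN.
sedona = SedonaContext.create(SedonaContext.builder().master("local[*]").getOrCreate())
sedona.sparkContext.setCheckpointDir(mkdtemp())  # needed by the GraphFrames step

df = sedona.sql("SELECT ST_Point(1.0, 2.0) AS geometry, 1 AS id")
dbscan(df, 1.0, 1).show()
```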
diff --git a/python/tests/stats/__init__.py b/python/tests/stats/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/python/tests/stats/test_dbscan.py b/python/tests/stats/test_dbscan.py
new file mode 100644
index 0000000000..60cc8a991e
--- /dev/null
+++ b/python/tests/stats/test_dbscan.py
@@ -0,0 +1,247 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pyspark.sql.functions as f
+import pytest
+
+from sedona.sql.st_constructors import ST_MakePoint
+from sedona.sql.st_functions import ST_Buffer
+from sklearn.cluster import DBSCAN as sklearnDBSCAN
+from sedona.stats.clustering.dbscan import dbscan
+
+from tests.test_base import TestBase
+
+
+class TestDBScan(TestBase):
+
+    @pytest.fixture
+    def sample_data(self):
+        return [
+            {"id": 1, "x": 1.0, "y": 2.0},
+            {"id": 2, "x": 3.0, "y": 4.0},
+            {"id": 3, "x": 2.5, "y": 4.0},
+            {"id": 4, "x": 1.5, "y": 2.5},
+            {"id": 5, "x": 3.0, "y": 5.0},
+            {"id": 6, "x": 12.8, "y": 4.5},
+            {"id": 7, "x": 2.5, "y": 4.5},
+            {"id": 8, "x": 1.2, "y": 2.5},
+            {"id": 9, "x": 1.0, "y": 3.0},
+            {"id": 10, "x": 1.0, "y": 5.0},
+            {"id": 11, "x": 1.0, "y": 2.5},
+            {"id": 12, "x": 5.0, "y": 6.0},
+            {"id": 13, "x": 4.0, "y": 3.0},
+        ]
+
+    @pytest.fixture
+    def sample_dataframe(self, sample_data):
+        return (
+            self.spark.createDataFrame(sample_data)
+            .select(ST_MakePoint("x", "y").alias("arealandmark"), "id")
+            .repartition(9)
+        )
+
+    def get_expected_result(self, input_data, epsilon, min_pts, include_outliers=True):
+        labels = (
+            sklearnDBSCAN(eps=epsilon, min_samples=min_pts)
+            .fit([[datum["x"], datum["y"]] for datum in input_data])
+            .labels_
+        )
+        # enumerate is 0-based while the sample ids start at 1
+        expected = [(x[0] + 1, x[1]) for x in list(enumerate(labels))]
+        clusters = [x for x in set(labels) if (x != -1 or include_outliers)]
+        cluster_members = {
+            frozenset([y[0] for y in expected if y[1] == x]) for x in clusters
+        }
+        return cluster_members
+
+    def get_actual_results(
+        self,
+        input_data,
+        epsilon,
+        min_pts,
+        geometry=None,
+        id=None,
+        include_outliers=True,
+    ):
+        result = dbscan(
+            input_data, epsilon, min_pts, geometry, include_outliers=include_outliers
+        )
+
+        result.show()
+
+        id = id or "id"
+        clusters_members = [
+            (x[id], x.cluster)
+            for x in result.collect()
+            if x.cluster != -1 or include_outliers
+        ]
+
+        result.unpersist()
+
+        clusters = {
+            frozenset([y[0] for y in clusters_members if y[1] == x])
+            for x in set([y[1] for y in clusters_members])
+        }
+
+        return clusters
+
+    @pytest.mark.parametrize("epsilon", [0.6, 0.7, 0.8])
+    @pytest.mark.parametrize("min_pts", [3, 4, 5])
+    def test_dbscan_valid_parameters(
+        self, sample_data, sample_dataframe, epsilon, min_pts
+    ):
+        # repeated broadcast joins with this small data size use a lot of RAM on broadcast references
+        self.spark.conf.set("sedona.join.autoBroadcastJoinThreshold", -1)
+        self.spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
+
+        assert self.get_expected_result(
+            sample_data, epsilon, min_pts
+        ) == self.get_actual_results(sample_dataframe, epsilon, min_pts)
+
+    def test_dbscan_valid_parameters_default_column_name(
+        self, sample_data, sample_dataframe
+    ):
+        # repeated broadcast joins with this small data size use a lot of RAM on broadcast references
+        self.spark.conf.set("sedona.join.autoBroadcastJoinThreshold", -1)
+        self.spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
+
+        df = sample_dataframe.select(
+            "id", f.col("arealandmark").alias("geometryFieldName")
+        )
+        epsilon = 0.6
+        min_pts = 4
+
+        assert self.get_expected_result(
+            sample_data, epsilon, min_pts
+        ) == self.get_actual_results(df, epsilon, min_pts)
+
+    def test_dbscan_valid_parameters_polygons(self, sample_data, sample_dataframe):
+        # repeated broadcast joins with this small data size use a lot of RAM on broadcast references
+        self.spark.conf.set("sedona.join.autoBroadcastJoinThreshold", -1)
+        self.spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
+
+        df = sample_dataframe.select(
+            "id", ST_Buffer(f.col("arealandmark"), 0.000001).alias("geometryFieldName")
+        )
+        epsilon = 0.6
+        min_pts = 4
+
+        assert self.get_expected_result(
+            sample_data, epsilon, min_pts
+        ) == self.get_actual_results(df, epsilon, min_pts)
+
+    def test_dbscan_supports_other_distance_function(self, sample_dataframe):
+        # repeated broadcast joins with this small data size use a lot of RAM on broadcast references
+        self.spark.conf.set("sedona.join.autoBroadcastJoinThreshold", -1)
+        self.spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
+
+        df = sample_dataframe.select(
+            "id", ST_Buffer(f.col("arealandmark"), 0.000001).alias("geometryFieldName")
+        )
+        epsilon = 0.6
+        min_pts = 4
+
+        dbscan(
+            df,
+            epsilon,
+            min_pts,
+            "geometryFieldName",
+            use_spheroid=True,
+        )
+
+    def test_dbscan_invalid_epsilon(self, sample_dataframe):
+        with pytest.raises(Exception):
+            dbscan(sample_dataframe, -0.1, 5, "arealandmark")
+
+    def test_dbscan_invalid_min_pts(self, sample_dataframe):
+        with pytest.raises(Exception):
+            dbscan(sample_dataframe, 0.1, -5, "arealandmark")
+
+    def test_dbscan_invalid_geometry_column(self, sample_dataframe):
+        with pytest.raises(Exception):
+            dbscan(sample_dataframe, 0.1, 5, "invalid_column")
+
+    def test_return_empty_df_when_no_clusters(self, sample_data, sample_dataframe):
+        # repeated broadcast joins with this small data size use a lot of RAM on broadcast references
+        self.spark.conf.set("sedona.join.autoBroadcastJoinThreshold", -1)
+        self.spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
+
+        epsilon = 0.1
+        min_pts = 10000
+
+        assert (
+            dbscan(
+                sample_dataframe,
+                epsilon,
+                min_pts,
+                "arealandmark",
+                include_outliers=False,
+            ).count()
+            == 0
+        )
+        # pick coefficients we know yield clusters and thus hit the happy case
+        assert (
+            dbscan(
+                sample_dataframe,
+                epsilon,
+                min_pts,
+                "arealandmark",
+                include_outliers=False,
+            ).schema
+            == dbscan(sample_dataframe, 0.6, 3, "arealandmark").schema
+        )
+
+    def test_dbscan_doesnt_duplicate_border_points_in_two_clusters(self):
+        # repeated broadcast joins with this small data size use a lot of RAM on broadcast references
+        self.spark.conf.set("sedona.join.autoBroadcastJoinThreshold", -1)
+        self.spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
+
+        input_df = self.spark.createDataFrame(
+            [
+                {"id": 10, "x": 1.0, "y": 1.8},
+                {"id": 11, "x": 1.0, "y": 1.9},
+                {"id": 12, "x": 1.0, "y": 2.0},
+                {"id": 13, "x": 1.0, "y": 2.1},
+                {"id": 14, "x": 2.0, "y": 2.0},
+                {"id": 15, "x": 3.0, "y": 1.9},
+                {"id": 16, "x": 3.0, "y": 2.0},
+                {"id": 17, "x": 3.0, "y": 2.1},
+                {"id": 18, "x": 3.0, "y": 2.2},
+            ]
+        ).select(ST_MakePoint("x", "y").alias("geometry"), "id")
+
+        output_df = dbscan(input_df, 1.0, 4)
+
+        # make sure no id occurs more than once: the border point must land in exactly one cluster
+        assert output_df.count() == 9
+        assert output_df.select("cluster").distinct().count() == 2
+
+    def test_return_outliers_false_doesnt_return_outliers(
+        self, sample_data, sample_dataframe
+    ):
+        # repeated broadcast joins with this small data size use a lot of RAM on broadcast references
+        self.spark.conf.set("sedona.join.autoBroadcastJoinThreshold", -1)
+        self.spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
+
+        epsilon = 0.6
+        min_pts = 3
+
+        assert self.get_expected_result(
+            sample_data, epsilon, min_pts, include_outliers=False
+        ) == self.get_actual_results(
+            sample_dataframe, epsilon, min_pts, include_outliers=False
+        )
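The oracle pattern in these tests is worth calling out: both Sedona's output and scikit-learn's are reduced to a set of frozensets of member ids, which makes the equality check invariant to the arbitrary integer each implementation assigns to a cluster. The idea in isolation, detached from Spark (data is illustrative):

```python
# Two clusterings agree iff they group the same ids together, regardless of
# which integer label each cluster happens to receive.
def as_cluster_sets(pairs):  # pairs of (record_id, cluster_label)
    labels = {label for _, label in pairs if label != -1}  # -1 marks outliers
    return {frozenset(i for i, lbl in pairs if lbl == label) for label in labels}

ours = [(1, 0), (2, 0), (3, 1), (4, -1)]
theirs = [(1, 5), (2, 5), (3, 9), (4, -1)]
assert as_cluster_sets(ours) == as_cluster_sets(theirs)
```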
diff --git a/python/tests/test_base.py b/python/tests/test_base.py
index 7742d31464..8bd4b6461f 100644
--- a/python/tests/test_base.py
+++ b/python/tests/test_base.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from tempfile import mkdtemp
 from sedona.spark import *
 from sedona.utils.decorators import classproperty
@@ -27,6 +28,8 @@ def spark(self):
         spark = SedonaContext.create(
             SedonaContext.builder().master("local[*]").getOrCreate()
         )
+        spark.sparkContext.setCheckpointDir(mkdtemp())
+
         setattr(self, "__spark", spark)
     return getattr(self, "__spark")
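Context for the test_base.py change above: GraphFrames' connected-components algorithm, which the DBSCAN implementation calls (along with the implementation's own `.checkpoint()`), refuses to run without a Spark checkpoint directory, so the shared fixture sets one per test session. An application would do the same once at startup; given a SparkSession `spark`, a sketch (the path is illustrative, any writable location works):

```python
from tempfile import mkdtemp

# GraphFrames checkpoints intermediate connected-components results here.
spark.sparkContext.setCheckpointDir(mkdtemp())
```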
diff --git a/spark/common/pom.xml b/spark/common/pom.xml
index 948a35efe9..774b85c24f 100644
--- a/spark/common/pom.xml
+++ b/spark/common/pom.xml
@@ -157,6 +157,11 @@
         </dependency>
+        <dependency>
+            <groupId>graphframes</groupId>
+            <artifactId>graphframes</artifactId>
+            <version>${graphframe.version}-s_${scala.compat.version}</version>
+        </dependency>
         <dependency>
             <groupId>org.scala-lang</groupId>
             <artifactId>scala-library</artifactId>
diff --git a/spark/common/src/main/scala/org/apache/sedona/stats/Util.scala b/spark/common/src/main/scala/org/apache/sedona/stats/Util.scala
new file mode 100644
index 0000000000..cdfe5fca23
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/sedona/stats/Util.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.stats
+
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+
+private[stats] object Util {
+  def getGeometryColumnName(dataframe: DataFrame): String = {
+    val geomFields = dataframe.schema.fields.filter(_.dataType == GeometryUDT)
+
+    if (geomFields.isEmpty)
+      throw new IllegalArgumentException(
+        "No GeometryType column found. Provide a dataframe containing a geometry column.")
+
+    if (geomFields.length == 1)
+      return geomFields.head.name
+
+    if (geomFields.length > 1 && !geomFields.exists(_.name == "geometry"))
+      throw new IllegalArgumentException(
+        "Multiple GeometryType columns found. Provide the column name as an argument.")
+
+    "geometry"
+  }
+}
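`Util.getGeometryColumnName` encodes the resolution order the docs describe: fail on zero geometry columns, pick the single one when unambiguous, prefer 'geometry' among several, otherwise demand an explicit name. From the Python side the observable behavior is roughly the following (a sketch; `pts` and the column names are illustrative):

```python
from sedona.stats.clustering.dbscan import dbscan

# One geometry column: found automatically, whatever it is called.
dbscan(pts.select("id", "geom"), 0.5, 4)

# Several geometry columns, one called "geometry": that one wins.
dbscan(pts.select("id", "geometry", "buffered"), 0.5, 4)

# Several geometry columns, none called "geometry": must name one explicitly,
# otherwise the JVM side raises IllegalArgumentException.
dbscan(pts.select("id", "geom", "buffered"), 0.5, 4, geometry="geom")
```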
diff --git a/spark/common/src/main/scala/org/apache/sedona/stats/clustering/DBSCAN.scala b/spark/common/src/main/scala/org/apache/sedona/stats/clustering/DBSCAN.scala
new file mode 100644
index 0000000000..5bc691c2dd
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/sedona/stats/clustering/DBSCAN.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.stats.clustering
+
+import org.apache.sedona.stats.Util.getGeometryColumnName
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.sedona_sql.expressions.st_functions.{ST_Distance, ST_DistanceSpheroid}
+import org.apache.spark.sql.{Column, DataFrame, SparkSession}
+import org.graphframes.GraphFrame
+
+object DBSCAN {
+
+  private val ID_COLUMN = "__id"
+
+  /**
+   * Annotates a dataframe with a cluster label for each data record using the DBSCAN algorithm.
+   * The dataframe should contain at least one GeometryType column. Rows must be unique. If one
+   * geometry column is present it will be used automatically. If two are present, the one named
+   * 'geometry' will be used. If more than one is present and none is named 'geometry', the
+   * column name must be provided. The new column will be named 'cluster'.
+   *
+   * @param dataframe
+   *   dataframe to cluster. Must contain at least one GeometryType column
+   * @param epsilon
+   *   minimum distance parameter of the DBSCAN algorithm
+   * @param minPts
+   *   minimum number of points parameter of the DBSCAN algorithm
+   * @param geometry
+   *   name of the geometry column
+   * @param includeOutliers
+   *   whether to include outliers in the output. Default is true
+   * @param useSpheroid
+   *   whether to use a cartesian or spheroidal distance calculation. Default is false
+   * @return
+   *   The input DataFrame with the cluster label added to each row. Outliers will have a
+   *   cluster value of -1 if included.
+   */
+  def dbscan(
+      dataframe: DataFrame,
+      epsilon: Double,
+      minPts: Int,
+      geometry: String = null,
+      includeOutliers: Boolean = true,
+      useSpheroid: Boolean = false): DataFrame = {
+
+    // We want to disable broadcast joins because the broadcast references were using too much driver memory
+    val spark = SparkSession.getActiveSession.get
+
+    val geometryCol = geometry match {
+      case null => getGeometryColumnName(dataframe)
+      case _ => geometry
+    }
+
+    validateInputs(dataframe, epsilon, minPts, geometryCol)
+
+    val distanceFunction: (Column, Column) => Column =
+      if (useSpheroid) ST_DistanceSpheroid else ST_Distance
+
+    // Key the algorithm on a content hash of each row, preserving any caller-supplied id column
+    val hasIdColumn = dataframe.columns.contains("id")
+    val idDataframe = if (hasIdColumn) {
+      dataframe
+        .withColumnRenamed("id", ID_COLUMN)
+        .withColumn("id", sha2(to_json(struct("*")), 256))
+    } else {
+      dataframe.withColumn("id", sha2(to_json(struct("*")), 256))
+    }
+
+    val isCorePointsDF = idDataframe
+      .alias("left")
+      .join(
+        idDataframe.alias("right"),
+        distanceFunction(col(s"left.$geometryCol"), col(s"right.$geometryCol")) <= epsilon)
+      .groupBy(col(s"left.id"))
+      .agg(
+        first(struct("left.*")).alias("leftContents"),
+        count(col(s"right.id")).alias("neighbors_count"),
+        collect_list(col(s"right.id")).alias("neighbors"))
+      .withColumn("isCore", col("neighbors_count") >= lit(minPts))
+      .select("leftContents.*", "neighbors", "isCore")
+      .checkpoint() // truncate lineage; requires a checkpoint directory to be set
+
+    val corePointsDF = isCorePointsDF.filter(col("isCore"))
+    val borderPointsDF = isCorePointsDF.filter(!col("isCore"))
+
+    val coreEdgesDf = corePointsDF
+      .select(col("id").alias("src"), explode(col("neighbors")).alias("dst"))
+      .alias("left")
+      .join(corePointsDF.alias("right"), col("left.dst") === col(s"right.id"))
+      .select(col("left.src"), col(s"right.id").alias("dst"))
+
+    val connectedComponentsDF = GraphFrame(corePointsDF, coreEdgesDf).connectedComponents.run
+
+    // Border points adopt the smallest component id among their core neighbors
+    val borderComponentsDF = borderPointsDF
+      .select(struct("*").alias("leftContent"), explode(col("neighbors")).alias("neighbor"))
+      .join(connectedComponentsDF.alias("right"), col("neighbor") === col(s"right.id"))
+      .groupBy(col("leftContent.id"))
+      .agg(
+        first(col("leftContent")).alias("leftContent"),
+        min(col(s"right.component")).alias("component"))
+      .select("leftContent.*", "component")
+
+    val clusteredPointsDf = borderComponentsDF.union(connectedComponentsDF)
+
+    val outliersDf = idDataframe
+      .join(clusteredPointsDf, Seq("id"), "left_anti")
+      .withColumn("isCore", lit(false))
+      .withColumn("component", lit(-1))
+      .withColumn("neighbors", array().cast("array<string>"))
+
+    val completedDf = (
+      if (includeOutliers) clusteredPointsDf.unionByName(outliersDf)
+      else clusteredPointsDf
+    ).withColumnRenamed("component", "cluster")
+
+    val returnDf = if (hasIdColumn) {
+      completedDf.drop("neighbors", "id").withColumnRenamed(ID_COLUMN, "id")
+    } else {
+      completedDf.drop("neighbors", "id")
+    }
+
+    returnDf
+  }
+
+  private def validateInputs(
+      dataframe: DataFrame,
+      epsilon: Double,
+      minPts: Int,
+      geometry: String): Unit = {
+    require(epsilon > 0, "epsilon must be greater than 0")
+    require(minPts > 0, "minPts must be greater than 0")
+    require(dataframe.columns.contains(geometry), "geometry column not found in dataframe")
+    require(
+      dataframe.schema.fields(dataframe.schema.fieldIndex(geometry)).dataType == GeometryUDT,
+      "geometry column must be of type GeometryType")
+  }
+}
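Taken together, the Scala implementation is a three-stage dataframe pipeline: an epsilon distance self-join marks core points, GraphFrames connected components stitches cores into clusters, and border points adopt the smallest component id among their core neighbors. A condensed PySpark paraphrase of the first stage only, not the shipped code (`df`, `epsilon`, and `min_pts` are assumed to be in scope, with the geometry column named `geometry`):

```python
import pyspark.sql.functions as F
from sedona.sql.st_functions import ST_Distance

# Every point whose epsilon-neighborhood holds at least min_pts points is a
# core point; the self-join matches each row with itself, so the count
# includes the point.
left, right = df.alias("l"), df.alias("r")
is_core = (
    left.join(right, ST_Distance(F.col("l.geometry"), F.col("r.geometry")) <= epsilon)
    .groupBy("l.id")
    .agg(
        F.count("r.id").alias("neighbors_count"),
        F.collect_list("r.id").alias("neighbors"),
    )
    .withColumn("isCore", F.col("neighbors_count") >= F.lit(min_pts))
)
```

Counting the point itself mirrors scikit-learn's `min_samples` semantics, which is what lets the test suite compare results against sklearn's DBSCAN directly.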