#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional

from pyspark.storagelevel import StorageLevel
from pyspark.version import __version__

if __version__[:3] >= "3.4":
    from pyspark.sql.utils import is_remote
else:
    # All the Connect-related utilities are accessible starting from 3.4.x
    def is_remote() -> bool:
        return False


from pyspark.sql import SparkSession

from graphframes.classic.graphframe import GraphFrame as GraphFrameClassic
from graphframes.lib import Pregel

if __version__[:3] >= "3.4":
    from graphframes.connect.graphframe_client import GraphFrameConnect
else:

    class GraphFrameConnect:
        def __init__(self, *args, **kwargs) -> None:
            raise ValueError(
                "GraphFrameConnect requires PySpark >= 3.4; this branch should be unreachable."
            )


if TYPE_CHECKING:
    from pyspark.sql import Column, DataFrame


class GraphFrame:
    """
    Represents a graph with vertices and edges stored as DataFrames.

    :param v: :class:`DataFrame` holding vertex information.
        Must contain a column named "id" that stores unique vertex IDs.
    :param e: :class:`DataFrame` holding edge information.
        Must contain two columns "src" and "dst" storing source vertex IDs
        and destination vertex IDs of edges, respectively.

    >>> localVertices = [(1, "A"), (2, "B"), (3, "C")]
    >>> localEdges = [(1, 2, "love"), (2, 1, "hate"), (2, 3, "follow")]
    >>> v = spark.createDataFrame(localVertices, ["id", "name"])
    >>> e = spark.createDataFrame(localEdges, ["src", "dst", "action"])
    >>> g = GraphFrame(v, e)
    """

    @staticmethod
    def _from_impl(impl: GraphFrameClassic | GraphFrameConnect) -> "GraphFrame":
        return GraphFrame(impl.vertices, impl.edges)

    def __init__(self, v: DataFrame, e: DataFrame) -> None:
        if is_remote():
            self._impl = GraphFrameConnect(v, e)
        else:
            self._impl = GraphFrameClassic(v, e)

    @property
    def vertices(self) -> DataFrame:
        """
        :class:`DataFrame` holding vertex information, with unique column "id"
        for vertex IDs.
        """
        return self._impl.vertices

    @property
    def edges(self) -> DataFrame:
        """
        :class:`DataFrame` holding edge information, with unique columns "src"
        and "dst" storing source vertex IDs and destination vertex IDs of edges,
        respectively.
        """
        return self._impl.edges

    def __repr__(self) -> str:
        return self._impl.__repr__()

    def cache(self) -> "GraphFrame":
        """Persist the DataFrame representation of vertices and edges of the graph
        with the default storage level.
        """
        return GraphFrame._from_impl(self._impl.cache())

    def persist(self, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) -> "GraphFrame":
        """Persist the DataFrame representation of vertices and edges of the graph
        with the given storage level.
        """
        return GraphFrame._from_impl(self._impl.persist(storageLevel=storageLevel))

    def unpersist(self, blocking: bool = False) -> "GraphFrame":
        """Mark the DataFrame representation of vertices and edges of the graph as
        non-persistent, and remove all blocks for it from memory and disk.
        """
        return GraphFrame._from_impl(self._impl.unpersist(blocking=blocking))

    @property
    def outDegrees(self) -> DataFrame:
        """
        The out-degree of each vertex in the graph, returned as a DataFrame with two columns:

        - "id": the ID of the vertex
        - "outDegree" (integer): the out-degree of the vertex

        Note that vertices with 0 out-edges are not returned in the result.

        :return: DataFrame with new vertices column "outDegree"
        """
        return self._impl.outDegrees

    @property
    def inDegrees(self) -> DataFrame:
        """
        The in-degree of each vertex in the graph, returned as a DataFrame with two columns:

        - "id": the ID of the vertex
        - "inDegree" (integer): the in-degree of the vertex

        Note that vertices with 0 in-edges are not returned in the result.

        :return: DataFrame with new vertices column "inDegree"
        """
        return self._impl.inDegrees

    @property
    def degrees(self) -> DataFrame:
        """
        The degree of each vertex in the graph, returned as a DataFrame with two columns:

        - "id": the ID of the vertex
        - "degree" (integer): the degree of the vertex

        Note that vertices with 0 edges are not returned in the result.

        :return: DataFrame with new vertices column "degree"
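
        Example (a minimal sketch, assuming the graph ``g`` built in the class
        docstring and a running SparkSession)::

            g.degrees.show()  # one row per vertex that has at least one edge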
"""
return self._impl.degrees

    @property
    def triplets(self) -> DataFrame:
        """
        The triplets (source vertex)-[edge]->(destination vertex) for all edges in the graph.
        Returned as a :class:`DataFrame` with three columns:

        - "src": source vertex with schema matching "vertices"
        - "edge": edge with schema matching "edges"
        - "dst": destination vertex with schema matching "vertices"

        :return: DataFrame with columns "src", "edge", and "dst"
        """
        return self._impl.triplets

    @property
    def pregel(self) -> Pregel:
        """
        Get the :class:`graphframes.lib.Pregel` object for running Pregel computations.
        See :class:`graphframes.lib.Pregel` for more details.
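
        Example (a hedged sketch of the fluent API, adapted from the PageRank
        pattern documented in :class:`graphframes.lib.Pregel`; assumes the graph
        ``g`` from the class docstring)::

            from pyspark.sql import functions as F
            from graphframes.lib import Pregel

            n = g.vertices.count()
            # Vertices need an "outDegree" column for the message expression below.
            graph = GraphFrame(g.outDegrees, g.edges)
            ranks = (
                graph.pregel
                .setMaxIter(5)
                .withVertexColumn(
                    "rank",
                    F.lit(1.0 / n),
                    F.coalesce(Pregel.msg(), F.lit(0.0)) * F.lit(0.85) + F.lit(0.15 / n),
                )
                .sendMsgToDst(Pregel.src("rank") / Pregel.src("outDegree"))
                .aggMsgs(F.sum(Pregel.msg()))
                .run()
            )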
"""
return self._impl.pregel

    def find(self, pattern: str) -> DataFrame:
        """
        Motif finding.

        See Scala documentation for more details.

        :param pattern: String describing the motif to search for.
        :return: DataFrame with one Row for each instance of the motif found
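
        Example (a minimal sketch, assuming the graph ``g`` from the class
        docstring; finds pairs of vertices with edges in both directions)::

            mutual = g.find("(a)-[e1]->(b); (b)-[e2]->(a)")
            mutual.select("a.name", "b.name").show()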
"""
return self._impl.find(pattern=pattern)

    def filterVertices(self, condition: str | Column) -> "GraphFrame":
        """
        Filters the vertices based on the given condition and removes edges
        containing any dropped vertices.

        :param condition: String or Column describing the condition expression for filtering.
        :return: GraphFrame with filtered vertices and edges.
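
        Example (a minimal sketch, assuming the graph ``g`` from the class docstring)::

            g2 = g.filterVertices("id > 1")  # also drops edges touching vertex 1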
"""
return GraphFrame._from_impl(self._impl.filterVertices(condition=condition))

    def filterEdges(self, condition: str | Column) -> "GraphFrame":
        """
        Filters the edges based on the given condition, keeping all vertices.

        :param condition: String or Column describing the condition expression for filtering.
        :return: GraphFrame with filtered edges.
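
        Example (a minimal sketch, assuming the graph ``g`` from the class docstring)::

            follows = g.filterEdges("action = 'follow'")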
"""
return GraphFrame._from_impl(self._impl.filterEdges(condition=condition))

    def dropIsolatedVertices(self) -> "GraphFrame":
        """
        Drops isolated vertices, i.e. vertices that are not contained in any edges.

        :return: GraphFrame with filtered vertices.
        """
        return GraphFrame._from_impl(self._impl.dropIsolatedVertices())

    def bfs(
        self,
        fromExpr: str,
        toExpr: str,
        edgeFilter: str | None = None,
        maxPathLength: int = 10,
    ) -> DataFrame:
        """
        Breadth-first search (BFS).

        See Scala documentation for more details.

        :return: DataFrame with one Row for each shortest path between matching vertices.
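
        Example (a minimal sketch, assuming the graph ``g`` from the class
        docstring; searches for shortest paths from vertex "A" to vertex "C")::

            paths = g.bfs("name = 'A'", "name = 'C'", maxPathLength=5)
            paths.show()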
"""
return self._impl.bfs(
fromExpr=fromExpr,
toExpr=toExpr,
edgeFilter=edgeFilter,
maxPathLength=maxPathLength,
)

    def aggregateMessages(
        self,
        aggCol: Column | str,
        sendToSrc: Column | str | None = None,
        sendToDst: Column | str | None = None,
    ) -> DataFrame:
        """
        Aggregates messages from the neighbours.

        When specifying the messages and aggregation function, the user may reference
        columns using the static methods in :class:`graphframes.lib.AggregateMessages`.

        See Scala documentation for more details.

        :param aggCol: the requested aggregation output either as
            :class:`pyspark.sql.Column` or SQL expression string
        :param sendToSrc: message sent to the source vertex of each triplet either as
            :class:`pyspark.sql.Column` or SQL expression string (default: None)
        :param sendToDst: message sent to the destination vertex of each triplet either as
            :class:`pyspark.sql.Column` or SQL expression string (default: None)
        :return: DataFrame with columns for the vertex ID and the resulting aggregated message
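
        Example (a hedged sketch, assuming the graph ``g`` from the class
        docstring; counts incoming edges by sending 1 to each destination)::

            from pyspark.sql import functions as F
            from graphframes.lib import AggregateMessages as AM

            inDegrees = g.aggregateMessages(
                F.count(AM.msg).alias("inDegree"),
                sendToDst=F.lit(1),
            )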
"""
return self._impl.aggregateMessages(aggCol=aggCol, sendToSrc=sendToSrc, sendToDst=sendToDst)

    # Standard algorithms
    def connectedComponents(
        self,
        algorithm: str = "graphframes",
        checkpointInterval: int = 2,
        broadcastThreshold: int = 1000000,
    ) -> DataFrame:
        """
        Computes the connected components of the graph.

        See Scala documentation for more details.

        :param algorithm: connected components algorithm to use (default: "graphframes")
            Supported algorithms are "graphframes" and "graphx".
        :param checkpointInterval: checkpoint interval in terms of number of iterations (default: 2)
        :param broadcastThreshold: broadcast threshold in propagating component assignments
            (default: 1000000)
        :return: DataFrame with new vertices column "component"
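
        Example (a minimal sketch, assuming the graph ``g`` and the SparkSession
        ``spark`` from the class docstring; the default algorithm needs a
        checkpoint directory, and the path below is only an illustration)::

            spark.sparkContext.setCheckpointDir("/tmp/graphframes-checkpoints")
            components = g.connectedComponents()
            components.select("id", "component").show()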
"""
return self._impl.connectedComponents(
algorithm=algorithm,
checkpointInterval=checkpointInterval,
broadcastThreshold=broadcastThreshold,
)

    def labelPropagation(self, maxIter: int) -> DataFrame:
        """
        Runs static label propagation for detecting communities in networks.

        See Scala documentation for more details.

        :param maxIter: the number of iterations to be performed
        :return: DataFrame with new vertices column "label"
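
        Example (a minimal sketch, assuming the graph ``g`` from the class docstring)::

            communities = g.labelPropagation(maxIter=5)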
"""
return self._impl.labelPropagation(maxIter=maxIter)

    def shortestPaths(self, landmarks: list[Any]) -> DataFrame:
        """
        Runs the shortest path algorithm from a set of landmark vertices in the graph.

        See Scala documentation for more details.

        :param landmarks: a set of one or more landmarks
        :return: DataFrame with new vertices column "distances"
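
        Example (a minimal sketch, assuming the graph ``g`` from the class
        docstring; landmarks are vertex IDs)::

            distances = g.shortestPaths(landmarks=[1, 3])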
"""
return self._impl.shortestPaths(landmarks=landmarks)

    def stronglyConnectedComponents(self, maxIter: int) -> DataFrame:
        """
        Runs the strongly connected components algorithm on this graph.

        See Scala documentation for more details.

        :param maxIter: the number of iterations to run
        :return: DataFrame with new vertex column "component"
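
        Example (a minimal sketch, assuming the graph ``g`` from the class docstring)::

            scc = g.stronglyConnectedComponents(maxIter=10)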
"""
return self._impl.stronglyConnectedComponents(maxIter=maxIter)

    def svdPlusPlus(
        self,
        rank: int = 10,
        maxIter: int = 2,
        minValue: float = 0.0,
        maxValue: float = 5.0,
        gamma1: float = 0.007,
        gamma2: float = 0.007,
        gamma6: float = 0.005,
        gamma7: float = 0.015,
    ) -> tuple[DataFrame, float]:
        """
        Runs the SVD++ algorithm.

        See Scala documentation for more details.

        :return: Tuple of DataFrame with new vertex columns storing learned model, and loss value
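
        Example (a minimal call sketch; SVD++ is intended for bipartite rating
        graphs, so treat this only as an illustration of the signature)::

            model, loss = g.svdPlusPlus(rank=5, maxIter=3)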
"""
return self._impl.svdPlusPlus(
rank=rank,
maxIter=maxIter,
minValue=minValue,
maxValue=maxValue,
gamma1=gamma1,
gamma2=gamma2,
gamma6=gamma6,
gamma7=gamma7,
)

    def triangleCount(self) -> DataFrame:
        """
        Counts the number of triangles passing through each vertex in this graph.

        See Scala documentation for more details.

        :return: DataFrame with new vertex column "count"
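
        Example (a minimal sketch, assuming the graph ``g`` from the class docstring)::

            triangles = g.triangleCount()  # per-vertex triangle counts in "count"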
"""
return self._impl.triangleCount()

    def powerIterationClustering(
        self, k: int, maxIter: int, weightCol: Optional[str] = None
    ) -> DataFrame:
        """
        Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed
        by Lin and Cohen. From the abstract: PIC finds a very low-dimensional embedding of
        a dataset using truncated power iteration on a normalized pair-wise similarity
        matrix of the data.

        :param k: the number of clusters to create
        :param maxIter: param for maximum number of iterations (>= 0)
        :param weightCol: optional name of the weight column; 1.0 is used if not provided
        :return: DataFrame with new column "cluster"
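
        Example (a minimal sketch, assuming the graph ``g`` from the class
        docstring; edge weights default to 1.0 when ``weightCol`` is omitted)::

            clusters = g.powerIterationClustering(k=2, maxIter=10)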
""" # noqa: E501
return self._impl.powerIterationClustering(k, maxIter, weightCol)


def _test():
    import doctest

    import graphframe

    globs = graphframe.__dict__.copy()
    globs["spark"] = SparkSession.builder.master("local[4]").appName("PythonTest").getOrCreate()
    (failure_count, test_count) = doctest.testmod(
        globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE
    )
    globs["spark"].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()