Skip to content

[FEATURE]: katz centrality #148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: pybind11
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp_easygraph/cpp_easygraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ PYBIND11_MODULE(cpp_easygraph, m) {

m.def("cpp_closeness_centrality", &closeness_centrality, py::arg("G"), py::arg("weight") = "weight", py::arg("cutoff") = py::none(), py::arg("sources") = py::none());
m.def("cpp_betweenness_centrality", &betweenness_centrality, py::arg("G"), py::arg("weight") = "weight", py::arg("cutoff") = py::none(),py::arg("sources") = py::none(), py::arg("normalized") = py::bool_(true), py::arg("endpoints") = py::bool_(false));
m.def("cpp_katz_centrality", &cpp_katz_centrality, py::arg("G"), py::arg("alpha") = 0.1, py::arg("beta") = 1.0, py::arg("max_iter") = 1000, py::arg("tol") = 1e-6, py::arg("normalized") = true);
m.def("cpp_k_core", &core_decomposition, py::arg("G"));
m.def("cpp_density", &density, py::arg("G"));
m.def("cpp_constraint", &constraint, py::arg("G"), py::arg("nodes") = py::none(), py::arg("weight") = py::none(), py::arg("n_workers") = py::none());
Expand Down
10 changes: 9 additions & 1 deletion cpp_easygraph/functions/centrality/centrality.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,12 @@

py::object closeness_centrality(py::object G, py::object weight, py::object cutoff, py::object sources);
py::object betweenness_centrality(py::object G, py::object weight, py::object cutoff, py::object sources,
py::object normalized, py::object endpoints);
py::object normalized, py::object endpoints);
py::object cpp_katz_centrality(
py::object G,
py::object py_alpha,
py::object py_beta,
py::object py_max_iter,
py::object py_tol,
py::object py_normalized
);
120 changes: 120 additions & 0 deletions cpp_easygraph/functions/centrality/katz_centrality.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#include <cmath>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "centrality.h"
#include "../../classes/graph.h"

namespace py = pybind11;

py::object cpp_katz_centrality(
py::object G,
py::object py_alpha,
py::object py_beta,
py::object py_max_iter,
py::object py_tol,
py::object py_normalized
) {
try {
Graph& graph = G.cast<Graph&>();
auto csr = graph.gen_CSR();
int n = csr->nodes.size();

if (n == 0) {
return py::dict();
}

// Initialize vectors
std::vector<double> x0(n, 1.0);
std::vector<double> x1(n);
std::vector<double>* x_prev = &x0;
std::vector<double>* x_next = &x1;

// Process beta parameter
std::vector<double> b(n);
if (py::isinstance<py::float_>(py_beta) || py::isinstance<py::int_>(py_beta)) {
double beta_val = py_beta.cast<double>();
for (int i = 0; i < n; i++) {
b[i] = beta_val;
}
} else if (py::isinstance<py::dict>(py_beta)) {
py::dict beta_dict = py_beta.cast<py::dict>();
for (int i = 0; i < n; i++) {
node_t internal_id = csr->nodes[i];
py::object node_obj = graph.id_to_node[py::cast(internal_id)];
if (beta_dict.contains(node_obj)) {
b[i] = beta_dict[node_obj].cast<double>();
} else {
b[i] = 1.0;
}
}
} else {
throw py::type_error("beta must be a float or a dict");
}

// Extract parameters
double alpha = py_alpha.cast<double>();
int max_iter = py_max_iter.cast<int>();
double tol = py_tol.cast<double>();
bool normalized = py_normalized.cast<bool>();

// Iterative updates
int iter = 0;
for (; iter < max_iter; iter++) {
for (int i = 0; i < n; i++) {
double sum = 0.0;
int start = csr->V[i];
int end = csr->V[i + 1];
for (int jj = start; jj < end; jj++) {
int j = csr->E[jj];
sum += (*x_prev)[j];
}
(*x_next)[i] = alpha * sum + b[i];
}

// Check convergence
double change = 0.0;
for (int i = 0; i < n; i++) {
change += std::abs((*x_next)[i] - (*x_prev)[i]);
}

if (change < tol) {
break;
}

std::swap(x_prev, x_next);
}

// Handle convergence failure
if (iter == max_iter) {
throw std::runtime_error("Katz centrality failed to converge in " + std::to_string(max_iter) + " iterations");
}

// Normalization
std::vector<double>& x_final = *x_next;
if (normalized) {
double norm = 0.0;
for (double val : x_final) {
norm += val * val;
}
norm = std::sqrt(norm);
if (norm > 0) {
for (int i = 0; i < n; i++) {
x_final[i] /= norm;
}
}
}

// Prepare results
py::dict result;
for (int i = 0; i < n; i++) {
node_t internal_id = csr->nodes[i];
py::object node_obj = graph.id_to_node[py::cast(internal_id)];
result[node_obj] = x_final[i];
}

return result;
} catch (const std::exception& e) {
throw std::runtime_error(e.what());
}
}
1 change: 1 addition & 0 deletions easygraph/functions/centrality/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from .flowbetweenness import *
from .laplacian import *
from .pagerank import *
from .katz_centrality import *
105 changes: 105 additions & 0 deletions easygraph/functions/centrality/katz_centrality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from easygraph.utils import *
import numpy as np
from easygraph.utils.decorators import *

__all__ = ["katz_centrality"]

@not_implemented_for("multigraph")
@hybrid("cpp_katz_centrality")
def katz_centrality(G, alpha=0.1, beta=1.0, max_iter=1000, tol=1e-6, normalized=True):
r"""
Compute the Katz centrality for nodes in a graph.

Katz centrality computes the influence of a node based on the total number
of walks between nodes, attenuated by a factor of their length. It is
defined as the solution to the linear system:

.. math::

x = \alpha A x + \beta

where:
- \( A \) is the adjacency matrix of the graph,
- \( \alpha \) is a scalar attenuation factor,
- \( \beta \) is the bias vector (typically all ones),
- and \( x \) is the resulting centrality vector.

The algorithm runs an iterative fixed-point method until convergence.

Parameters
----------
G : easygraph.Graph
An EasyGraph graph instance. Must be simple (non-multigraph).

alpha : float, optional (default=0.1)
Attenuation factor, must be smaller than the reciprocal of the largest
eigenvalue of the adjacency matrix to ensure convergence.

beta : float or dict, optional (default=1.0)
Bias term. Can be a constant scalar applied to all nodes, or a dictionary
mapping node IDs to values.

max_iter : int, optional (default=1000)
Maximum number of iterations before the algorithm terminates.

tol : float, optional (default=1e-6)
Convergence tolerance. Iteration stops when the L1 norm of the difference
between successive iterations is below this threshold.

normalized : bool, optional (default=True)
If True, the result vector will be normalized to unit norm (L2).

Returns
-------
dict
A dictionary mapping node IDs to Katz centrality scores.

Raises
------
RuntimeError
If the algorithm fails to converge within `max_iter` iterations.

Examples
--------
>>> import easygraph as eg
>>> from easygraph import katz_centrality
>>> G = eg.Graph()
>>> G.add_edges_from([(0, 1), (1, 2), (2, 3)])
>>> katz_centrality(G, alpha=0.05)
{0: 0.370..., 1: 0.447..., 2: 0.447..., 3: 0.370...}
"""
# Create node ordering
nodes = list(G.nodes)
n = len(nodes)
node_to_index = {node: i for i, node in enumerate(nodes)}
index_to_node = {i: node for i, node in enumerate(nodes)}

# Build adjacency matrix
A = np.zeros((n, n), dtype=np.float64)
for u in G.nodes:
for v in G.adj[u]:
A[node_to_index[u], node_to_index[v]] = 1.0

# Initialize x and beta
x = np.ones(n, dtype=np.float64)
if isinstance(beta, dict):
b = np.array([beta.get(index_to_node[i], 1.0) for i in range(n)])
else:
b = np.ones(n, dtype=np.float64) * beta

# Iterative update using vectorized ops
for _ in range(max_iter):
x_new = alpha * A @ x + b
if np.linalg.norm(x_new - x, ord=1) < tol:
break
x = x_new
else:
raise RuntimeError(f"Katz centrality failed to converge in {max_iter} iterations")

if normalized:
norm = np.linalg.norm(x)
if norm > 0:
x /= norm

result = {index_to_node[i]: float(x[i]) for i in range(n)}
return result
Loading