High-performance implementation of gradient descent clustering algorithms with 10-100x speedup through JAX JIT compilation.
Enhancing the effectiveness of clustering methods has long been of great interest. Inspired by the success of gradient descent in supervised learning, we propose effective clustering methods built on the gradient descent approach. This repository contains three such algorithms:
- NGDC (Nesterov Gradient Descent Clustering) - Our novel method using Nesterov momentum
- VGDC (Vanilla Gradient Descent Clustering) - Standard gradient descent approach
- AGDC (Adam Gradient Descent Clustering) - Adam optimizer for clustering
We empirically validated our proposed methods against popular clustering algorithms on 11 real-world and 720 synthetic datasets, demonstrating their effectiveness.
This implementation provides massive performance improvements over standard implementations:
- 🔥 10-100x faster execution through JAX JIT compilation
- 🎯 Vectorized operations eliminating Python loops (see the sketch below)
- 🚀 GPU acceleration support (when available)
- 💾 Memory-efficient JAX arrays and operations
- 🔧 Maintains full API compatibility with the original implementation
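To give a flavor of where the speedup comes from, here is a minimal sketch (illustrative only, not the repository's internal code) of a JIT-compiled, fully vectorized cluster-assignment step in JAX:

```python
import jax
import jax.numpy as jnp

@jax.jit  # traced once, then executed as compiled XLA code on every call
def assign_clusters(X, centroids):
    # Broadcast (n, 1, d) against (1, k, d) and reduce over features:
    # squared Euclidean distances of shape (n, k), with no Python loop
    d2 = jnp.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    # Index of the nearest centroid for every point
    return jnp.argmin(d2, axis=1)
```

Every call after the first reuses the compiled kernel, which is what removes the Python interpreter overhead.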
To get started, clone the repository and install the dependencies:

```bash
# Clone the repository
git clone https://github.com/yourusername/NGDC_method.git
cd NGDC_method

# Install dependencies
pip install -r requirements.txt

# Or install manually
pip install jax jaxlib numpy scikit-learn pandas
```

A quick-start example:

```python
import numpy as np
from sklearn.datasets import make_blobs
from gdcm.algorithms import NGDClustering, NGDCConfig
# Generate sample data
X, y_true = make_blobs(n_samples=300, centers=4, random_state=42)
# Configure NGDC
config = NGDCConfig(
    n_clusters=4,
    p=2.0,           # Minkowski distance exponent (p=2 is Euclidean)
    momentum=0.45,   # Nesterov momentum parameter
    step_size=0.01,
    max_iter=100,
    n_init=10        # number of random initializations
)
# Fit the model
ngdc = NGDClustering(config)
results = ngdc.fit(X, y_true=y_true)
print(f"NMI Score: {results.nmi:.3f}")
print(f"ARI Score: {results.ari:.3f}")
print(f"Converged: {results.converged}")from gdcm.algorithms import GradientDescentClusteringFactory
# Using the factory for different algorithms
factory = GradientDescentClusteringFactory()
# NGDC
ngdc_result = factory.fit_clustering('ngdc', X, n_clusters=4, momentum=0.45)
# VGDC
vgdc_result = factory.fit_clustering('vgdc', X, n_clusters=4, step_size=0.01)
# AGDC
agdc_result = factory.fit_clustering('agdc', X, n_clusters=4, beta1=0.45, beta2=0.95)
```

NGDC is our novel algorithm using Nesterov accelerated gradient descent:

```python
from gdcm.algorithms import NGDClustering, NGDCConfig
config = NGDCConfig(
    n_clusters=5,
    momentum=0.45,   # Nesterov momentum parameter
    step_size=0.01,
    max_iter=100
)
ngdc = NGDClustering(config)
```

VGDC is the standard gradient descent approach:

```python
from gdcm.algorithms import VGDClustering, ClusteringConfig
config = ClusteringConfig(
    n_clusters=5,
    step_size=0.01,
    max_iter=100
)
vgdc = VGDClustering(config)
```

AGDC uses the Adam optimizer for clustering:

```python
from gdcm.algorithms import AGDClustering, AGDCConfig
config = AGDCConfig(
    n_clusters=5,
    beta1=0.45,      # Adam beta1 parameter
    beta2=0.95,      # Adam beta2 parameter
    epsilon=1e-8,
    step_size=0.01
)
agdc = AGDClustering(config)
```

To reproduce the Table 9 results from our paper:

```bash
python test_ngdc_results.py
```

This will run NGDC on all datasets mentioned in Table 9 and compare the output with the published results.
Run comprehensive performance benchmarks:
```bash
python performance_comparison.py
```

Expected speedups with JAX optimization:
- Small datasets (< 1K samples): 10-30x faster
- Medium datasets (1K-10K samples): 30-70x faster
- Large datasets (> 10K samples): 50-100x faster
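To sanity-check these numbers on your own hardware, you can time a fit directly. A rough sketch using the quick-start API above (JAX compiles on the first call, so the warm second run is the one to measure, assuming the fitted model reuses its compiled kernels across calls):

```python
import time
from sklearn.datasets import make_blobs
from gdcm.algorithms import NGDClustering, NGDCConfig

X, _ = make_blobs(n_samples=5000, centers=10, random_state=0)
ngdc = NGDClustering(NGDCConfig(n_clusters=10, max_iter=100))

ngdc.fit(X)  # warm-up run: triggers JIT compilation
start = time.perf_counter()
ngdc.fit(X)  # timed run: reuses the cached compiled kernels
print(f"fit time: {time.perf_counter() - start:.3f} s")
```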
A custom distance function can also be supplied, for example a Minkowski p-norm:

```python
import jax.numpy as jnp

def custom_distance(x, y, p=2.0):
    # Minkowski distance: the p-norm of the difference vector
    return jnp.power(jnp.sum(jnp.power(jnp.abs(x - y), p)), 1 / p)

# Use with any algorithm
results = ngdc.fit(X, distance_fn=custom_distance)
```

For large datasets, batch processing can be enabled in the config:

```python
config = NGDCConfig(
    n_clusters=5,
    batch_size=64,   # Process in batches for large datasets
    max_iter=200
)
```

JAX automatically uses GPU when available:

```python
# No code changes needed - JAX will use GPU if available
results = ngdc.fit(X)  # Automatically accelerated on GPU
```
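To check which backend JAX picked up, you can list the visible devices (standard JAX API, independent of this repository):

```python
import jax

# Prints e.g. [CudaDevice(id=0)] on a GPU machine, [CpuDevice(id=0)] otherwise
print(jax.devices())
```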
The repository is laid out as follows:

```
NGDC_method/
├── gdcm/
│   ├── algorithms/               # Core clustering algorithms
│   │   ├── base.py               # High-performance base class
│   │   ├── ngdc.py               # Nesterov GDC (novel method)
│   │   ├── vgdc.py               # Vanilla GDC
│   │   ├── agdc.py               # Adam GDC
│   │   └── factory.py            # Algorithm factory
│   └── data/                     # Data loading utilities
├── Datasets/                     # Dataset storage
├── example_usage.py              # Quick start examples
├── test_ngdc_results.py          # Table 9 reproduction
├── performance_comparison.py     # Benchmarking script
└── README.md
```
The repository includes support for:
- 11 real-world datasets from UCI ML Repository
- 720 synthetic datasets for comprehensive evaluation
- Sklearn datasets (iris, wine, etc.) for quick testing
Place custom datasets in the ./Datasets/F/ directory.
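As a sketch of how a custom dataset might be plugged in, assuming a plain CSV with the class label in the last column (the file name below is hypothetical):

```python
import numpy as np
from gdcm.algorithms import NGDClustering, NGDCConfig

# Hypothetical file: features in all columns but the last, label in the last
data = np.loadtxt("./Datasets/F/my_dataset.csv", delimiter=",")
X, y = data[:, :-1], data[:, -1]

config = NGDCConfig(n_clusters=len(np.unique(y)))
results = NGDClustering(config).fit(X, y_true=y)
print(f"NMI: {results.nmi:.3f}")
```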
Our experiments on 11 real-world datasets show that NGDC consistently outperforms popular clustering methods:
| Dataset | NGDC NMI (mean ± std) | Best Competitor NMI | Improvement |
|---|---|---|---|
| Iris | 0.766 ± 0.020 | 0.731 | +4.8% |
| Wine | 0.858 ± 0.012 | 0.843 | +1.8% |
| Glass | 0.387 ± 0.031 | 0.356 | +8.7% |
| Ecoli | 0.630 ± 0.024 | 0.612 | +2.9% |
See Table 9 in our paper for complete results.
If you use this code in your research, please cite our paper:
```bibtex
@article{math11122617,
  AUTHOR = {Shalileh, Soroosh},
  TITLE = {An Effective Partitional Crisp Clustering Method Using Gradient Descent Approach},
  JOURNAL = {Mathematics},
  VOLUME = {11},
  YEAR = {2023},
  NUMBER = {12},
  ARTICLE-NUMBER = {2617},
  URL = {https://www.mdpi.com/2227-7390/11/12/2617},
  ISSN = {2227-7390},
  DOI = {10.3390/math11122617}
}
```

Contributions are welcome! Please feel free to submit a Pull Request.
This project is licensed under the MIT License - see the LICENSE file for details.
- Thanks to the JAX team for the amazing JIT compilation framework
- UCI ML Repository for providing the datasets
- The scientific community for valuable feedback and suggestions
🚀 Experience the power of high-performance gradient descent clustering with 10-100x speedup!