Examples

Image Color Quantization

import numpy as np
from PIL import Image
from kmeans import KMeans

# Load the image and force 3-channel RGB. Without the conversion a
# grayscale or 4-channel file would either fail the (n_pixels, 3) reshape
# below or, worse, pass it with the channel values scrambled.
img = Image.open('photo.jpg').convert('RGB')
img_array = np.array(img)

# Flatten to one row per pixel, shape (n_pixels, 3), as floats for clustering
pixels = img_array.reshape(-1, 3).astype(np.float64)

# Cluster colors: the 16 centroids act as the reduced color palette
kmeans_model = KMeans(n_clusters=16)
kmeans_model.fit(pixels)

# Replace every pixel with the centroid of the cluster it was assigned to
quantized = kmeans_model.centroids_[kmeans_model.labels_]

# Round to the nearest integer level before casting: a bare uint8 cast
# truncates toward zero, which biases every channel darker. The clip keeps
# the cast from wrapping around if a value lands outside 0..255.
quantized_img = (
    np.clip(np.rint(quantized), 0, 255)
    .astype(np.uint8)
    .reshape(img_array.shape)
)

# Save result
Image.fromarray(quantized_img).save('quantized.jpg')

Customer Segmentation

import pandas as pd
from sklearn.preprocessing import StandardScaler

from kmeans import KMeans

# Columns fed to the clustering step
feature_cols = ['age', 'income', 'spending_score']

# Load customer data
df = pd.read_csv('customers.csv')
features = df[feature_cols].values

# Normalize features so that no single column dominates the distance
# computation just because it is measured on a larger scale
scaler = StandardScaler()
features_scaled = scaler.fit_transform(features)

# Cluster customers and store each row's cluster index on the frame
model = KMeans(n_clusters=5)
df['segment'] = model.fit_predict(features_scaled)

# Analyze segments: per-segment averages of the numeric columns.
# numeric_only=True is required on pandas >= 2.0, where mean() raises a
# TypeError if the CSV also carries text columns (names, gender, ...).
print(df.groupby('segment').mean(numeric_only=True))

Anomaly Detection

import numpy as np

from kmeans import KMeans

# Cluster normal data
normal_data = np.random.randn(1000, 5)
model = KMeans(n_clusters=3)
model.fit(normal_data)

# Calibrate the cutoff on the normal (training) data: the distance below
# which 95% of known-normal points sit from their nearest centroid.
# Deriving the percentile from the test set itself would flag ~5% of the
# test points no matter what they look like.
train_labels = model.predict(normal_data)
train_distances = np.linalg.norm(
    normal_data - model.centroids_[train_labels], axis=1
)
threshold = np.percentile(train_distances, 95)

# Check new points
test_data = np.random.randn(100, 5)
labels = model.predict(test_data)

# Distance from each test point to its assigned (nearest) centroid,
# computed for all rows at once
distances = np.linalg.norm(test_data - model.centroids_[labels], axis=1)

# Flag anomalies: points farther from their centroid than the cutoff
anomalies = test_data[distances > threshold]
print(f"Found {len(anomalies)} anomalies")

Time Series Clustering

import matplotlib.pyplot as plt
import numpy as np  # needed for np.random below; the snippet must run on its own

from kmeans import KMeans

# Assume time_series is shape (n_series, n_timepoints)
time_series = np.random.randn(100, 50)

# Cluster time series; n_clusters is defined once so the model and the
# plot grid below cannot drift apart
n_clusters = 5
model = KMeans(n_clusters=n_clusters)
labels = model.fit_predict(time_series)

# Plot every member series of each cluster (faint) with the cluster
# centroid drawn on top in red. squeeze=False keeps `axes` a 2-D array
# even when n_clusters is 1, so the loop works for any cluster count.
fig, axes = plt.subplots(
    1, n_clusters, figsize=(4 * n_clusters, 4), squeeze=False
)
for i, ax in enumerate(axes[0]):
    cluster_series = time_series[labels == i]
    # Transpose so each column is one series, i.e. one line per series
    ax.plot(cluster_series.T, alpha=0.3)
    ax.plot(model.centroids_[i], 'r-', linewidth=2)
    ax.set_title(f'Cluster {i}')
plt.show()