TenderAutomateSystem/kmeans.py

125 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project recommender
@File kmeans.py
@IDE PyCharm
@Author rengengchen
@Time 2023/12/29 11:53
"""
import torch
import time
from tqdm import tqdm
class KMEANS:
    """K-means clustering implemented with torch tensors.

    After ``fit`` the following attributes are populated:

    * ``centers`` -- ``[n_clusters, n_features]`` cluster centers.
    * ``labels`` -- ``[n_samples]`` index of the nearest center per sample
      (computed against the centers of the last assignment step).
    * ``dists`` -- ``[n_samples, n_clusters]`` squared Euclidean distances
      from that same assignment step.
    * ``representative_samples`` -- ``[n_clusters]`` index of the sample
      closest to each center.
    * ``count`` -- number of iterations that were executed.
    """

    def __init__(self, n_clusters=20, max_iter=None, verbose=True, device=torch.device("cpu")):
        """Store the configuration; no computation happens here.

        Args:
            n_clusters: number of clusters to form.
            max_iter: if ``None``, iterate until the change in summed
                distances drops below ``1e-3``; otherwise run exactly this
                many iterations (at least one) without a convergence check.
            verbose: print the variation and representative indices after
                every iteration.
            device: torch device on which all tensors are kept.
        """
        self.n_clusters = n_clusters
        self.labels = None
        self.dists = None  # shape: [x.shape[0], n_clusters]
        self.centers = None
        self.variation = torch.Tensor([float("Inf")]).to(device)
        self.verbose = verbose
        self.started = False
        self.representative_samples = None
        self.max_iter = max_iter
        self.count = 0
        self.device = device

    def fit(self, X):
        """Cluster the rows of ``X``.

        Args:
            X: 2-D tensor (or array-like convertible by ``torch.as_tensor``)
                of shape ``[n_samples, n_features]``. Non-floating input is
                converted to float so that means can be computed.

        Raises:
            ValueError: if ``n_clusters`` is smaller than 1, or ``X`` is not
                2-D or has no rows.
        """
        if self.n_clusters < 1:
            raise ValueError("n_clusters must be at least 1")
        x = torch.as_tensor(X, device=self.device)
        if not x.is_floating_point():
            x = x.float()
        if x.dim() != 2 or x.shape[0] == 0:
            raise ValueError("X must be a non-empty 2-D tensor of shape [n_samples, n_features]")

        # Reset per-fit state so that the same instance can be fitted again.
        self.count = 0
        self.started = False
        self.variation = torch.Tensor([float("Inf")]).to(self.device)

        self.centers = self._init_centers(x)
        while True:
            # Assign every sample to its nearest center.
            self.nearest_center(x)
            # Move every center to the mean of its assigned samples.
            self.update_center(x)
            self.count += 1
            if self.verbose:
                print(self.variation, torch.argmin(self.dists, (0)))
            if self.max_iter is None:
                if torch.abs(self.variation) < 1e-3:
                    break
            elif self.count >= self.max_iter:
                break
        self.representative_sample()

    def _init_centers(self, x):
        """Pick initial centers by deterministic farthest-first traversal.

        The first center is the first sample; each further center is the
        sample whose squared distance to its nearest already-chosen center
        is largest. Using the nearest-center distance (rather than a sum over
        all centers) keeps an already-chosen point from being selected again
        while unchosen distinct points remain.
        """
        centers = x[0].reshape(1, -1)
        nearest = ((x - centers[0]) ** 2).sum(1)
        for _ in range(self.n_clusters - 1):
            chosen = x[nearest.argmax(0)].reshape(1, -1)
            centers = torch.cat((centers, chosen), 0)
            nearest = torch.minimum(nearest, ((x - chosen) ** 2).sum(1))
        return centers

    def nearest_center(self, x):
        """Compute squared distances to all centers and label each sample.

        Updates ``labels`` and ``dists``; from the second call on, also sets
        ``variation`` to the summed difference between the previous and the
        new distance matrices.
        """
        # One vectorised pass per center keeps temporary memory at the size
        # of ``x`` instead of materialising an [n, k, d] tensor.
        dists = torch.stack([((x - center) ** 2).sum(1) for center in self.centers], 1)
        self.labels = torch.argmin(dists, 1)
        if self.started:
            self.variation = torch.sum(self.dists - dists)
        self.dists = dists
        self.started = True

    def update_center(self, x):
        """Replace each center by the mean of the samples labelled with it.

        A cluster without any samples keeps its previous center; taking the
        mean of an empty selection would yield NaN.
        """
        centers = self.centers.clone()
        for i in range(self.n_clusters):
            members = x[self.labels == i]
            if members.shape[0] > 0:
                centers[i] = members.mean(0)
        self.centers = centers

    def representative_sample(self):
        """Record, per cluster, the index of the sample nearest its center.

        These samples serve as concrete, easy-to-inspect representatives of
        each cluster.
        """
        self.representative_samples = torch.argmin(self.dists, (0))
def time_clock(matrix, device):
    """Return the average wall-clock seconds per k-means iteration.

    A ``KMEANS`` model with ``max_iter=10`` and ``verbose=False`` is fitted
    on ``matrix`` and the elapsed time is divided by the model's iteration
    counter.

    Args:
        matrix: data tensor passed to ``KMEANS.fit``.
        device: torch device (or device string) the model runs on.

    Returns:
        Elapsed seconds divided by ``k.count`` (treated as 1 if it is 0).
    """
    on_cuda = torch.device(device).type == "cuda"
    k = KMEANS(max_iter=10, verbose=False, device=device)
    # CUDA kernels are launched asynchronously; synchronise before reading
    # the clock so that only finished work is measured.
    if on_cuda:
        torch.cuda.synchronize(device)
    start = time.perf_counter()
    k.fit(matrix)
    if on_cuda:
        torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - start
    # Guard against a zero iteration counter to avoid ZeroDivisionError.
    return elapsed / max(k.count, 1)
def choose_device(cuda=False):
    """Return the torch device to compute on.

    Args:
        cuda: when truthy, select the first CUDA device (``cuda:0``);
            otherwise select the CPU.

    Returns:
        The corresponding ``torch.device``. Availability of CUDA is not
        checked here.
    """
    name = "cuda:0" if cuda else "cpu"
    return torch.device(name)
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    # Benchmark: per-iteration k-means time for 10000 random samples at
    # increasing feature dimensions, first on the CPU and then on the GPU.
    dims = [20, 100, 500, 2000, 8000, 20000]

    plt.figure()

    # CPU timings, plotted in red.
    device = choose_device(False)
    cpu_speeds = [
        time_clock(torch.rand((10000, dim)).to(device), device)
        for dim in tqdm(dims)
    ]
    l1, = plt.plot(dims, cpu_speeds, color='r', label='CPU')

    # GPU timings, plotted in green. choose_device(True) always selects
    # cuda:0, so this part presumably requires a CUDA-capable setup.
    device = choose_device(True)
    gpu_speeds = [
        time_clock(torch.rand((10000, dim)).to(device), device)
        for dim in tqdm(dims)
    ]
    l2, = plt.plot(dims, gpu_speeds, color='g', label="GPU")