#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :recommender
@File    :kmeans.py
@IDE     :PyCharm
@Author  :rengengchen
@Time    :2023/12/29 11:53

K-means clustering on torch tensors, plus a small script that times one
k-means iteration on CPU and GPU for several feature dimensions.
"""
import time

import torch

try:
    from tqdm import tqdm
except ImportError:  # tqdm only decorates the benchmark loops below
    def tqdm(iterable, *args, **kwargs):
        """Fallback used when tqdm is not installed: iterate without a bar."""
        return iterable


class KMEANS:
    """K-means clustering with squared Euclidean distances.

    Attributes set by ``fit``:
        labels: index of the nearest center for every sample.
        dists: squared distances, shape ``[n_samples, n_clusters]``.
        centers: cluster centers, shape ``[n_clusters, n_features]``.
        variation: sum of (previous dists - current dists); ``inf`` until
            two assignment passes have been made.
        representative_samples: for every cluster, the index of the sample
            closest to its center.
        count: number of iterations executed by the last ``fit`` call.
    """

    def __init__(self, n_clusters=20, max_iter=None, verbose=True, device=torch.device("cpu")):
        """Store the configuration; no computation happens here.

        Args:
            n_clusters: number of clusters to fit.
            max_iter: if ``None``, iterate until ``|variation| < 1e-3``;
                otherwise run exactly this many iterations (at least one).
            verbose: print the variation and representative samples after
                every iteration.
            device: device the computation runs on.
        """
        self.n_clusters = n_clusters
        self.labels = None
        self.dists = None  # shape: [x.shape[0], n_clusters]
        self.centers = None
        self.variation = torch.Tensor([float("Inf")]).to(device)
        self.verbose = verbose
        self.started = False
        self.representative_samples = None
        self.max_iter = max_iter
        self.count = 0
        self.device = device

    def fit(self, X):
        """Cluster the rows of ``X``.

        Args:
            X: 2-D tensor of shape ``[n_samples, n_features]``; it is moved
                to ``self.device`` if it is not already there.

        Raises:
            ValueError: if ``X`` is not 2-D or has no rows.
        """
        x = X.to(self.device)
        if x.dim() != 2 or x.shape[0] == 0:
            raise ValueError("X must be a non-empty 2-D tensor of shape [n_samples, n_features]")

        # Reset per-run state so that fit() can be called more than once.
        self.count = 0
        self.started = False
        self.variation = torch.Tensor([float("Inf")]).to(self.device)
        self.centers = self._init_centers(x)

        while True:
            # Assign every sample to its nearest center.
            self.nearest_center(x)
            # Move every center to the mean of its samples.
            self.update_center(x)
            # Count the iteration before testing the stop condition so that
            # max_iter=N runs exactly N iterations and count reports them.
            self.count += 1
            if self.verbose:
                print(self.variation, torch.argmin(self.dists, (0)))
            if self.max_iter is None:
                if torch.abs(self.variation) < 1e-3:
                    break
            elif self.count >= self.max_iter:
                break

        self.representative_sample()

    def _init_centers(self, x):
        """Pick initial centers by deterministic farthest-point seeding.

        The first center is ``x[0]``; every further center is the sample
        with the largest squared distance to its closest already chosen
        center (the distance k-means++ samples from, used greedily here).
        """
        centers = x[0].reshape(1, -1)
        closest = ((x - centers[0]) ** 2).sum(1)
        for _ in range(self.n_clusters - 1):
            new_center = x[closest.argmax(0)].reshape(1, -1)
            centers = torch.cat((centers, new_center), 0)
            closest = torch.minimum(closest, ((x - new_center) ** 2).sum(1))
        return centers

    def nearest_center(self, x):
        """Compute ``dists`` and ``labels`` for ``x`` and update ``variation``."""
        # One vectorised pass per center instead of one Python step (and one
        # growing torch.cat copy) per sample.
        dists = torch.stack([((x - center) ** 2).sum(1) for center in self.centers], 1)
        self.labels = torch.argmin(dists, 1)
        if self.started:
            self.variation = torch.sum(self.dists - dists)
        self.dists = dists
        self.started = True

    def update_center(self, x):
        """Replace every center with the mean of the samples assigned to it."""
        centers = []
        for i in range(self.n_clusters):
            members = x[self.labels == i]
            if members.shape[0] == 0:
                # The mean of an empty cluster is NaN and would poison every
                # later distance; keep the previous center instead.
                centers.append(self.centers[i])
            else:
                centers.append(torch.mean(members, (0)))
        self.centers = torch.stack(centers, 0)

    def representative_sample(self):
        """Store, per cluster, the index of the sample closest to its center.

        These samples serve as intuitive representatives of the clusters.
        """
        self.representative_samples = torch.argmin(self.dists, (0))


def time_clock(matrix, device):
    """Return the average seconds per iteration of a 10-iteration k-means fit.

    Args:
        matrix: 2-D sample tensor passed to ``KMEANS.fit``.
        device: device the model runs on.
    """
    start = time.perf_counter()
    k = KMEANS(max_iter=10, verbose=False, device=device)
    k.fit(matrix)
    if torch.device(device).type == "cuda":
        # CUDA kernels run asynchronously; wait for them before stopping
        # the clock so the measurement covers the actual work.
        torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - start
    return elapsed / max(k.count, 1)


def choose_device(cuda=False):
    """Return ``cuda:0`` when ``cuda`` is true, otherwise the CPU device."""
    if cuda:
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    return device


if __name__ == "__main__":
    import matplotlib.pyplot as plt

    plt.figure()
    feature_sizes = [20, 100, 500, 2000, 8000, 20000]

    device = choose_device(False)
    cpu_speeds = []
    for i in tqdm(feature_sizes):
        matrix = torch.rand((10000, i)).to(device)
        speed = time_clock(matrix, device)
        cpu_speeds.append(speed)
    l1, = plt.plot(feature_sizes, cpu_speeds, color='r', label='CPU')

    # NOTE: this half of the benchmark requires a CUDA-capable device.
    device = choose_device(True)
    gpu_speeds = []
    for i in tqdm(feature_sizes):
        matrix = torch.rand((10000, i)).to(device)
        speed = time_clock(matrix, device)
        gpu_speeds.append(speed)
    l2, = plt.plot(feature_sizes, gpu_speeds, color='g', label="GPU")