聚类分析之KMeans/Mean-shift/KNN

Python3.8

Python3.8

Conda
Python

Python 是一种高级、解释型、通用的编程语言,以其简洁易读的语法而闻名,适用于广泛的应用,包括Web开发、数据分析、人工智能和自动化脚本

一文读懂聚类分析:从核心概念到主流算法详解

在机器学习领域,数据挖掘的核心任务常被划分为监督学习无监督学习两大阵营。其中,聚类分析作为无监督学习的核心技术之一,凭借 “无需标注、自动归类” 的特性,在用户分群、异常检测、图像分割等场景中广泛应用。本文将从基础概念切入,拆解三大主流聚类算法的原理、公式与适用场景,同时厘清易混淆的相似算法,帮你快速掌握聚类分析的核心逻辑。

一、先理清楚:聚类分析的 “身份定位”

在展开算法前,我们首先要明确聚类分析的核心属性,避免与监督学习算法混淆:

  • 所属领域:无监督学习(无需人工标核心属性,避免与监督学习算法混淆:
  • 所属领域:无监督学习(无需人工标注数据类别,算法自动根据数据属性相似度划分)
  • 无监督学习三核心任务:聚类分析(本文重点)、关联规则(如 “买尿布的人常买啤酒”)、维度缩减(如 PCA 降维)
  • 易混淆概念区分
    • 逻辑回归、线性回归:均属于监督学习(需依赖标注好的 “特征 - 标签” 数据训练模型)
    • KNN(K 近邻):属于监督学习(通过 “待测样本周边 k 个邻居的类别” 判断自身类别)
    • KMeans(K 均值):属于无监督学习(通过 “距离中心点的远近” 自动划分簇类)

二、主流聚类算法拆解:原理、公式与特点

1. KMeans 聚类:“指定簇数 + 迭代收敛” 的经典方案

核心逻辑

KMeans 是最常用的聚类算法之一,核心思路是 “先指定簇的数量 k,再通过迭代更新簇中心点,最终让所有样本归到距离最近的簇中”,直到中心点位置不再变化(收敛)。

关键步骤与公式
  1. 初始化:随机选择 k 个样本作为初始簇中心点(记为u_{j}^{t},t 代表第 t 次迭代)
  1. 样本归类:计算每个样本x_{i}与所有簇中心点u_{j}^{t}的距离(常用欧氏距离dist(x_{i},u_{j}^{t})=\sqrt{\sum (x_{i}-u_{j}^{t})^{2}},将x_{i}归入距离最近的簇(即x_{i}u_{j}^{t}

更新中心点:对每个簇,计算该簇内所有样本的均值,作为新的簇中心点:u_{j}^{t+1}=\frac{1}{k_{j}}\sum_{x_{i}\in C_{j}}^{}x_{i}

其中C_{j}代表第 j 个簇,k_{j}是该簇内的样本数量

  1. 收敛判断:重复步骤 2-3,直到簇中心点的变化量小于预设阈值,或达到最大迭代次数。
优缺点与适用场景
  • 优点:计算速度快、对大规模数据友好、结果易解释
  • 缺点:需提前指定 k 值(k 的选择对结果影响大,常用肘部法则判断最优 k)、对异常值敏感、只适用于球形簇
  • 适用场景:数据分布相对规整、已知大致簇数量的场景(如用户消费等级分群、产品类别划分)

2. Mean - shift(均值漂移)聚类:“自动找簇 + 密度导向” 的灵活方案

核心逻辑

Mean - shift 是一种基于 “密度梯度上升” 的聚类算法,无需提前指定簇数量,而是通过 “让每个样本向数据密集区移动” 的方式,自动发现簇的数量和范围。核心思想是:数据密集的区域会形成 “引力中心”,样本会不断向引力中心漂移,最终汇聚成簇。

关键步骤与公式
  1. 定义搜索窗口:为每个样本设定一个以自身为中心、半径为 h 的搜索窗口(半径 h 需手动指定,影响簇的颗粒度)

计算均值漂移向量:在搜索窗口内,计算所有样本相对于当前中心点\( x_i \)的偏移均值,得到 “均值漂移向量” M (x):M(x)=\frac{1}{k}\sum_{u\in Window(x)}^{}(u-x_{i})

其中 k 是搜索窗口内的样本数量,u 是窗口内的样本

中心点漂移:将当前中心点沿均值漂移向量移动,更新为新的中心点:u^{t+1}=M^{t}+u^{t}

  1. 合并收敛点:重复步骤 2-3,直到中心点不再移动(收敛);最后将距离较近的收敛点合并,形成最终的簇。
优缺点与适用场景
  • 优点:无需指定簇数量(自动发现)、对非球形簇适配性好、对噪声有一定容忍度
  • 缺点:半径 h 的选择依赖经验(h 过大易合并簇,h 过小易分碎簇)、计算成本高于 KMeans
  • 适用场景:数据簇形状不规则、未知簇数量的场景(如图像分割、目标跟踪)

3. DBSCAN(基于密度的空间聚类):“密度筛选 + 噪声过滤” 的稳健方案

核心逻辑

DBSCAN 是典型的密度聚类算法,核心思路是 “通过‘核心点’的密度连接扩展簇,同时过滤掉低密度的噪声点”。它不依赖簇的形状假设,能有效处理非球形簇和异常值。

关键概念与步骤
  1. 两个核心参数
  • ε(epsilon):邻域半径(指定样本周边的搜索范围)
  • MinPts:邻域内的最小样本数(判断该样本是否为 “核心点” 的阈值)
  1. 样本分类
  • 核心点:邻域内(ε 范围内)样本数 ≥ MinPts 的样本(可作为簇的中心)
  • 边界点:邻域内样本数 < MinPts,但在某个核心点的邻域内(归属于核心点所在的簇)
  • 噪声点:既不是核心点也不是边界点的样本(直接过滤)
  1. 簇扩展规则:从任意核心点出发,将其所有密度可达的核心点和边界点归为一个簇;重复此过程,直到所有核心点都被分配到簇中。
优缺点与适用场景
  • 优点:无需指定簇数量、能识别任意形状的簇、可自动过滤噪声点
  • 缺点:对 ε 和 MinPts 的参数敏感(需通过领域知识或参数调优确定)、对高维数据效果较差(需先降维)
  • 适用场景:存在噪声数据、簇形状不规则的场景(如异常交易检测、地理区域划分)

三、必看对比:KMeans vs KNN

很多初学者会把 KMeans 和 KNN 搞混,两者虽都带 “K”,但本质完全不同,核心区别如下:

对比维度

KMeans(K 均值)

KNN(K 近邻)

所属学习类型

无监督学习(无标签数据)

监督学习(需标签数据)

“K” 的含义

预设的簇的数量

待测样本周边的邻居数量

核心逻辑

按 “距离中心点远近” 归类

按 “邻居的类别占比” 归类

计算时机

训练时完成聚类(离线计算)

预测时才计算邻居(在线计算)

典型用途

数据分群、聚类探索

分类任务(如疾病诊断)、回归任务

举个通俗例子:

  • 用 KMeans 给 100 个用户分群(k=3):算法自动把用户分成 3 类,无需知道 “每类用户是什么标签”;
  • 用 KNN 判断某用户是否为高价值客户(k=5):需先有 “已知高 / 低价值客户” 的标签数据,再看该用户周边 5 个邻居中 “高价值客户多还是低价值多”,以此判断类别。

四、算法实战

先简单介绍一下KMeans的用法:

#Kmeans
from sklearn.cluster import KMeans
KM=KMeans(n_clusters=3,random_state=0)
KM.fit(x)
#获取中心点
center =KM.cluster_centers_
#准确率
from sklearn.mestrucs import accuracy_score
accuracy=accuracy_score(y,y_predict)
#结果矫正
y_cal=[]
for i in y_predict:
    if i==0:
      y_cal.append(2)
    elif i==1:
      y_cal.append(1)
    else:
      y_cal.append(0)
print(y_predict,y_cal)

Mean-Shift:

#MeanShift
#自动计算带宽
from sklearn.cluster import MeanShift,estimate_bandwidth
bandwidth=estimate_bandwidth(x,n_samples=500)
#模型建立与训练
ms=MeanShift(bandwidth=bandwidth)
ms.fit(x)

KNN:

#KNN
from sklearn.neighbors import KNeighborsClassifier
KNN=KNeighborsClassifier(n_neighbors=3)
KNN.fit(x,y)

下面进入实例讲解:

需要下载sklearn库、pandas和numpy,使用到的data.csv可以参考大佬的资源:提取码1234DATA数据_免费高速下载|百度网盘-分享无限制

先导入 pandas 和 numpy 库,然后用 pandas 读取 csv 格式的 2D 数据文件,最后用 head () 显示数据前几行,方便快速查看数据结构和内容。

KMeans:
#实战 2D数据类别划分
import pandas as pd
import numpy as np
data=pd.read_csv('data.csv')
data.head()
V1V2labels
02.072345-3.2416930
117.93671015.7848100
21.0835767.3191760
311.12067014.4067800
423.7115502.5577290

首先,x = data.drop(['labels'], axis=1) 这句代码的作用是从原始数据中移除名为 “labels” 的列,得到的 x 就是用于模型训练的特征数据,包含了除标签之外的所有二维数据信息。接着,y = data.loc[:, 'labels'] 则是专门提取出 “labels” 列的数据,作为后续模型训练或验证时的标签,也就是我们希望预测或比对的类别信息。最后用 y.head() 展示标签列的前几行数据,方便快速了解标签的基本情况和格式。

#define x and y
x=data.drop(['labels'],axis=1)
y=data.loc[:,'labels']
y.head()

运行结果:

0    0
1    0
2    0
3    0
4    0
Name: labels, dtype: int64

`pd.value_counts(y)` 用于统计标签列`y`中各类类别出现的次数,能直观展示各类别数据的分布情况,比如不同类别分别有多少样本,帮助了解数据集中类别是否均衡,为后续聚类或分类分析提供基础信息。

pd.value_counts(y)

运行结果:

labels
2    1156
1     954
0     890
Name: count, dtype: int64

接着可视化未标记的二维数据分布。首先通过`%matplotlib inline`设置图表内嵌显示,导入matplotlib绘图库;然后创建画布,用散点图`plt.scatter()`绘制数据中V1和V2两个特征的分布,每个点代表一个样本;最后添加标题“unlabeled data”和坐标轴标签,通过`plt.show()`展示图像,直观呈现数据在二维空间中的聚集状态,为后续聚类分析提供可视化参考。

%matplotlib inline
from matplotlib import pyplot as plt
fig1=plt.figure()
plt.scatter(x.loc[:,'V1'],x.loc[:,'V2'])
plt.title("unlabeled data")
plt.xlabel('V1')
plt.ylabel('V2')
plt.show()

可视化带标签的二维数据分布,通过分别筛选出标签为0、1、2的样本,用散点图按标签类别绘制V1和V2特征的分布,不同标签对应不同颜色的点。图表添加了标题、坐标轴标签,并通过legend函数标注各类别对应的图例,最终展示出不同类别数据在二维空间中的实际分布情况,便于直观对比数据的真实聚类状态。

fig1=plt.figure()
label0=plt.scatter(x.loc[:,'V1'][y==0],x.loc[:,'V2'][y==0])
label1=plt.scatter(x.loc[:,'V1'][y==1],x.loc[:,'V2'][y==1])
label2=plt.scatter(x.loc[:,'V1'][y==2],x.loc[:,'V2'][y==2])
plt.title("unlabeled data")
plt.xlabel('V1')
plt.ylabel('V2')
plt.legend((label0,label1,label2),('label0','label1','label2'))
plt.show()

建立模型并训练:

#set kmean model
from sklearn.cluster import KMeans
KM=KMeans(n_clusters=3,random_state=0)
KM.fit(x)

从训练好的 `KMeans` 模型里获取各簇中心坐标并存入 `centers` 变量,随后创建新图形对象作为绘图“画布”。接着分别绘制属于标签 0、1、2 这三个不同簇的样本点,样本点的横、纵坐标对应数据中的 `V1` 和 `V2` 特征。之后为图形添加标题 “unlabeled data”,给横、纵坐标轴分别标注 “V1” 和 “V2”,同时添加图例以区分不同簇的样本点。最后,在图中绘制出各簇的中心位置,再通过 `plt.show()` 将整个图形展示出来。通过这样的可视化操作,能直观观察不同簇的样本分布以及各簇中心的位置,进而助力评估 K - Means 聚类的效果。

centers=KM.cluster_centers_
fig3=plt.figure()
label0=plt.scatter(x.loc[:,'V1'][y==0],x.loc[:,'V2'][y==0])
label1=plt.scatter(x.loc[:,'V1'][y==1],x.loc[:,'V2'][y==1])
label2=plt.scatter(x.loc[:,'V1'][y==2],x.loc[:,'V2'][y==2])
plt.title("unlabeled data")
plt.xlabel('V1')
plt.ylabel('V2')
plt.legend((label0,label1,label2),('label0','label1','label2'))
plt.scatter(centers[:,0],centers[:,1])
plt.show()

运行完下面这段代码我们会得到的这个模型的准确值,这里的结果是

0.31966666666666665

这是因为可能会出现标签混乱的问题。

#预测V1=80,V2=60
y_predict_test=KM.predict([[80,60]])
print(y_predict_test)
#predict based on traning data
y_predict=KM.predict(x)
print(pd.value_counts(y_predict),pd.value_counts(y))
from sklearn.metrics import accuracy_score
accuracy=accuracy_score(y,y_predict)
print(accuracy)

让我们对比一下生成的和正确的图片

fig4=plt.subplot(121)
label0=plt.scatter(x.loc[:,'V1'][y_predict==0],x.loc[:,'V2'][y_predict==0])
label1=plt.scatter(x.loc[:,'V1'][y_predict==1],x.loc[:,'V2'][y_predict==1])
label2=plt.scatter(x.loc[:,'V1'][y_predict==2],x.loc[:,'V2'][y_predict==2])
plt.title("predict data")
plt.xlabel('V1')
plt.ylabel('V2')
plt.legend((label0,label1,label2),('label0','label1','label2'))
plt.scatter(centers[:,0],centers[:,1])

fig4=plt.subplot(122)
label0=plt.scatter(x.loc[:,'V1'][y==0],x.loc[:,'V2'][y==0])
label1=plt.scatter(x.loc[:,'V1'][y==1],x.loc[:,'V2'][y==1])
label2=plt.scatter(x.loc[:,'V1'][y==2],x.loc[:,'V2'][y==2])
plt.title("labeled data")
plt.xlabel('V1')
plt.ylabel('V2')
plt.legend((label0,label1,label2),('label0','label1','label2'))
plt.scatter(centers[:,0],centers[:,1])
plt.show()

可以看到分类是正确的,但是块的label是混乱的

接下来需要矫正

#矫正
y_corrected=[]
for i in y_predict:
    if i==0:
        y_corrected.append(2)
    elif i==1:
        y_corrected.append(1) 
    else:
        y_corrected.append(0)
print(pd.value_counts(y_corrected),pd.value_counts(y))

再重新使用accuracy就会得出校正后的正确率的0.9几

Mean-Shift:
#自动计算带宽
from sklearn.cluster import MeanShift,estimate_bandwidth
bw=estimate_bandwidth(x,n_samples=500)
#模型建立与训练
ms=MeanShift(bandwidth=bw)
ms.fit(x)
y_predict_ms=ms.predict(x)
print(pd.value_counts(y_predict_ms),pd.value_counts(y))
fig9=plt.subplot(121)
label0=plt.scatter(x.loc[:,'V1'][y_predict_ms==0],x.loc[:,'V2'][y_predict_ms==0])
label1=plt.scatter(x.loc[:,'V1'][y_predict_ms==1],x.loc[:,'V2'][y_predict_ms==1])
label2=plt.scatter(x.loc[:,'V1'][y_predict_ms==2],x.loc[:,'V2'][y_predict_ms==2])
plt.title("ms data")
plt.xlabel('V1')
plt.ylabel('V2')
plt.legend((label0,label1,label2),('label0','label1','label2'))
plt.scatter(centers[:,0],centers[:,1])

fig10=plt.subplot(122)
label0=plt.scatter(x.loc[:,'V1'][y==0],x.loc[:,'V2'][y==0])
label1=plt.scatter(x.loc[:,'V1'][y==1],x.loc[:,'V2'][y==1])
label2=plt.scatter(x.loc[:,'V1'][y==2],x.loc[:,'V2'][y==2])
plt.title("labeled data")
plt.xlabel('V1')
plt.ylabel('V2')
plt.legend((label0,label1,label2),('label0','label1','label2'))
plt.scatter(centers[:,0],centers[:,1])
plt.show()

#矫正
y_corrected_ms=[]
for i in y_predict_ms:
    if i==0:
        y_corrected_ms.append(2)
    elif i==1:
        y_corrected_ms.append(1) 
    else:
        y_corrected_ms.append(0)
print(pd.value_counts(y_corrected_ms),pd.value_counts(y))

这里有一个细节需要注意,要改一下y_corrected_ms的格式

y_corrected_ms=np.array(y_corrected_ms)
fig11=plt.subplot(121)
label0=plt.scatter(x.loc[:,'V1'][y_corrected_ms==0],x.loc[:,'V2'][y_corrected_ms==0])
label1=plt.scatter(x.loc[:,'V1'][y_corrected_ms==1],x.loc[:,'V2'][y_corrected_ms==1])
label2=plt.scatter(x.loc[:,'V1'][y_corrected_ms==2],x.loc[:,'V2'][y_corrected_ms==2])
plt.title("y_corrected_ms data")
plt.xlabel('V1')
plt.ylabel('V2')
plt.legend((label0,label1,label2),('label0','label1','label2'))
plt.scatter(centers[:,0],centers[:,1])

fig8=plt.subplot(122)
label2=plt.scatter(x.loc[:,'V1'][y==0],x.loc[:,'V2'][y==0])
label1=plt.scatter(x.loc[:,'V1'][y==1],x.loc[:,'V2'][y==1])
label2=plt.scatter(x.loc[:,'V1'][y==2],x.loc[:,'V2'][y==2])
plt.title("labeled data")
plt.xlabel('V1')
plt.ylabel('V2')
plt.legend((label0,label1,label2),('label0','label1','label2'))
plt.scatter(centers[:,0],centers[:,1])
plt.show()

五、如何选择适合的聚类算法?

  1. 若数据规模大、簇是球形、已知大致簇数量 → 选KMeans(速度快、成本低)
  1. 若数据簇形状不规则、未知簇数量、需灵活适配 → 选Mean - shift(自动找簇、适配非球形)
  1. 若数据含噪声、需过滤异常值、簇形状复杂 → 选DBSCAN(密度筛选、抗噪声)

聚类分析的核心是 “让相似的数据聚在一起”,没有绝对最优的算法,只有 “最适配场景” 的选择。实际应用中,建议结合数据分布、业务需求和参数调优,才能得到理想的聚类结果。

您可能感兴趣的与本文相关的镜像

Python3.8

Python3.8

Conda
Python

Python 是一种高级、解释型、通用的编程语言,以其简洁易读的语法而闻名,适用于广泛的应用,包括Web开发、数据分析、人工智能和自动化脚本

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值