DBSCAN算法python实现(附完整数据集和代码)


目录

[TOC]

1. 算法思路

DBSCAN算法的核心是“延伸”。先找到一个未访问的点p,若该点是核心点,则创建一个新的簇C,将其邻域中的点放入该簇,并遍历其邻域中的点,若其邻域中有点q为核心点,则将q的邻域内的点也划入簇C,直到C不再扩展。直到最后所有的点都标记为已访问。

点p通过密度可达来扩大自己的“地盘”,实际上就是簇在“延伸”。

图示网站:https://www.naftaliharris.com/blog/visualizing-dbscan-clustering/ 可以看一下簇是如何延伸的。

2. 算法实现

2.1 计算两点之间的距离

# 计算两个点之间的欧式距离,参数为两个元组
def dist(t1, t2):
    dis = math.sqrt((np.power((t1[0]-t2[0]),2) + np.power((t1[1]-t2[1]),2)))
    # print("两点之间的距离为:"+str(dis))
    return dis

2.2 读取文件,加载数据集

def loadDataSet(fileName, splitChar='\t'):
    dataSet = []
    with open(fileName) as fr:
        for line in fr.readlines():
            curline = line.strip().split(splitChar)
            fltline = list(map(float, curline))
            dataSet.append(fltline)
    return dataSet

2.3 DBSCAN算法实现

1、标记点是否被访问:我设置了两个列表,一个存放未访问的点unvisited,一个存放已访问的点visited。每次访问一个点,unvisited列表remove该点,visited列表append该点,以此来实现点的标记改变。

2、C作为输出结果,初始时是一个长度为所有点的个数的值全为-1的列表。之后修改点对应的索引的值来设置点属于哪个簇

# DBSCAN算法,参数为数据集,Eps为指定半径参数,MinPts为制定邻域密度阈值
def dbscan(Data, Eps, MinPts):
    num = len(Data)  # 点的个数
    # print("点的个数:"+str(num))
    unvisited = [i for i in range(num)]  # 没有访问到的点的列表
    # print(unvisited)
    visited = []  # 已经访问的点的列表
    C = [-1 for i in range(num)]
    # C为输出结果,默认是一个长度为num的值全为-1的列表
    # 用k来标记不同的簇,k = -1表示噪声点
    k = -1
    # 如果还有没访问的点
    while len(unvisited) > 0:
        # 随机选择一个unvisited对象
        p = random.choice(unvisited)
        unvisited.remove(p)
        visited.append(p)
        # N为p的epsilon邻域中的对象的集合
        N = []
        for i in range(num):
            if (dist(Data[i], Data[p]) <= Eps):# and (i!=p):
                N.append(i)
        # 如果p的epsilon邻域中的对象数大于指定阈值,说明p是一个核心对象
        if len(N) >= MinPts:
            k = k+1
            # print(k)
            C[p] = k
            # 对于p的epsilon邻域中的每个对象pi
            for pi in N:
                if pi in unvisited:
                    unvisited.remove(pi)
                    visited.append(pi)
                    # 找到pi的邻域中的核心对象,将这些对象放入N中
                    # M是位于pi的邻域中的点的列表
                    M = []
                    for j in range(num):
                        if (dist(Data[j], Data[pi])<=Eps): #and (j!=pi):
                            M.append(j)
                    if len(M)>=MinPts:
                        for t in M:
                            if t not in N:
                                N.append(t)
                # 若pi不属于任何簇,C[pi] == -1说明C中第pi个值没有改动
                if C[pi] == -1:
                    C[pi] = k
        # 如果p的epsilon邻域中的对象数小于指定阈值,说明p是一个噪声点
        else:
            C[p] = -1

    return C

3. 问题记录

代码思路非常简单,让我以为实现起来也很简单。结果拖拖拉拉半个多月才终于将算法改好。

算法实现过程中遇到的问题其实是小问题,但是导致的结果非常严重。因为不起眼所以才难以察觉。

这是刚开始我运行算法得到的结果(Eps为10,MinPts为10):

img

Eps为2,MinPts为10(我改了点的大小):

img

可以看出图中颜色特别多,实际上就是聚成的簇太多,可实际上目测应该只有七八个簇。这是为什么呢?

原来是变量k的重复使用问题。

前面我用k来标识不同的簇,后面(如下图)我又将k变成了循环变量,注意M列表中都是整数,代表点在数据集中的索引,所以实际上是k在整数列表中遍历,覆盖掉了前面用来标识不同簇的k值,导致每次运行出来k取值特别多(如下下图)。

img

img

4. 运行结果

img

5. 完整代码

5.1 源数据

附数据集

链接:数据集788个点
提取码:rv06

5.2 源代码

# encoding:utf-8
import matplotlib.pyplot as plt
import random
import numpy as np
import math
from sklearn import datasets

list_1 = []
list_2 = []
# 数据集一:随机生成散点图,参数为点的个数
# def scatter(num):
#     for i in range(num):
#         x = random.randint(0, 100)
#         list_1.append(x)
#         y = random.randint(0, 100)
#         list_2.append(y)
#     print(list_1)
#     print(list_2)
#     data = list(zip(list_1, list_2))
#     print(data)
#     #plt.scatter(list_1, list_2)
#     #plt.show()
#     return data
#scatter(50)

def loadDataSet(fileName, splitChar='\t'):
    dataSet = []
    with open(fileName) as fr:
        for line in fr.readlines():
            curline = line.strip().split(splitChar)
            fltline = list(map(float, curline))
            dataSet.append(fltline)
    return dataSet

# 计算两个点之间的欧式距离,参数为两个元组
def dist(t1, t2):
    dis = math.sqrt((np.power((t1[0]-t2[0]),2) + np.power((t1[1]-t2[1]),2)))
    # print("两点之间的距离为:"+str(dis))
    return dis

# dis = dist((1,1),(3,4))
# print(dis)


# DBSCAN算法,参数为数据集,Eps为指定半径参数,MinPts为制定邻域密度阈值
def dbscan(Data, Eps, MinPts):
    num = len(Data)  # 点的个数
    # print("点的个数:"+str(num))
    unvisited = [i for i in range(num)]  # 没有访问到的点的列表
    # print(unvisited)
    visited = []  # 已经访问的点的列表
    C = [-1 for i in range(num)]
    # C为输出结果,默认是一个长度为num的值全为-1的列表
    # 用k来标记不同的簇,k = -1表示噪声点
    k = -1
    # 如果还有没访问的点
    while len(unvisited) > 0:
        # 随机选择一个unvisited对象
        p = random.choice(unvisited)
        unvisited.remove(p)
        visited.append(p)
        # N为p的epsilon邻域中的对象的集合
        N = []
        for i in range(num):
            if (dist(Data[i], Data[p]) <= Eps):# and (i!=p):
                N.append(i)
        # 如果p的epsilon邻域中的对象数大于指定阈值,说明p是一个核心对象
        if len(N) >= MinPts:
            k = k+1
            # print(k)
            C[p] = k
            # 对于p的epsilon邻域中的每个对象pi
            for pi in N:
                if pi in unvisited:
                    unvisited.remove(pi)
                    visited.append(pi)
                    # 找到pi的邻域中的核心对象,将这些对象放入N中
                    # M是位于pi的邻域中的点的列表
                    M = []
                    for j in range(num):
                        if (dist(Data[j], Data[pi])<=Eps): #and (j!=pi):
                            M.append(j)
                    if len(M)>=MinPts:
                        for t in M:
                            if t not in N:
                                N.append(t)
                # 若pi不属于任何簇,C[pi] == -1说明C中第pi个值没有改动
                if C[pi] == -1:
                    C[pi] = k
        # 如果p的epsilon邻域中的对象数小于指定阈值,说明p是一个噪声点
        else:
            C[p] = -1

    return C


# 数据集二:788个点
dataSet = loadDataSet('788points.txt', splitChar=',')
C = dbscan(dataSet, 2, 14)
print(C)
x = []
y = []
for data in dataSet:
    x.append(data[0])
    y.append(data[1])
plt.figure(figsize=(8, 6), dpi=80)
plt.scatter(x,y, c=C, marker='o')
plt.show()
# print(x)
# print(y)

文章作者: Leon
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Leon !
评论
 上一篇
机器学习系列之决策树算法(10):决策树模型,XGBoost,LightGBM和CatBoost模型可视化 机器学习系列之决策树算法(10):决策树模型,XGBoost,LightGBM和CatBoost模型可视化
安装 graphviz 参考文档 http://graphviz.readthedocs.io/en/stable/manual.html#installation graphviz安装包下载地址 https://www.graphviz
下一篇 
短文本聚类【DBSCAN】算法原理+Python代码实现+聚类结果展示 短文本聚类【DBSCAN】算法原理+Python代码实现+聚类结果展示
目录[TOC] 1. 算法原理1.1 常见的聚类算法聚类算法属于常见的无监督分类算法,在很多场景下都有应用,如用户聚类,文本聚类等。常见的聚类算法可以分成两类: 以 k-means 为代表的基于分区的算法 以层次聚类为代表的基于层次划分的
  目录