python plt让scatter能够使不同类别的点有不同的颜色、大小和形状

参考:https://blog.csdn.net/weixin_43769946/article/details/103522194

matplotlib 的 plt.scatter 可以给不同类别的点设置不同的颜色和大小,但 marker 参数一次只能接受一个值,因此不能直接为不同类别生成不同的形状。所以需要自己实现一个方法。

1.定义mscatter函数

import matplotlib.pyplot as plt
def mscatter(x, y, ax=None, m=None, **kw):
    """
    Draw a scatter plot in which every point can have its own marker.

    :param x: Sequence of x coordinates.
    :param y: Sequence of y coordinates.
    :param ax: Object whose ``scatter`` method is called. When omitted,
        the current pyplot axes (``plt.gca()``) are used.
    :param m: Optional sequence holding one marker per point. Each entry is
        either a ``MarkerStyle`` instance or a value ``MarkerStyle`` accepts
        (for example ``'*'``). It is applied only when ``len(m) == len(x)``;
        otherwise a warning is emitted and the default marker is kept.
    :param kw: Extra keyword arguments forwarded unchanged to ``ax.scatter``.
    :return: The collection returned by ``ax.scatter``.
    """
    import warnings

    # Compare with None explicitly: an axes-like object supplied by the
    # caller must never be swapped for plt.gca() just because it is falsy.
    if ax is None:
        ax = plt.gca()
    sc = ax.scatter(x, y, **kw)

    if m is None:
        return sc

    if len(m) != len(x):
        # Previously this case was skipped silently, which produced a plot
        # with default markers and no hint about why.
        warnings.warn(
            "mscatter: ignoring m because it has %d markers for %d points"
            % (len(m), len(x)),
            stacklevel=2,
        )
        return sc

    # Imported lazily: only needed when per-point markers are applied.
    import matplotlib.markers as mmarkers

    paths = []
    for marker in m:
        if isinstance(marker, mmarkers.MarkerStyle):
            marker_obj = marker
        else:
            marker_obj = mmarkers.MarkerStyle(marker)
        # Bake the marker's own transform into its path so that each point
        # is drawn with the final marker outline.
        paths.append(
            marker_obj.get_path().transformed(marker_obj.get_transform())
        )
    sc.set_paths(paths)
    return sc

2.例子

from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
import numpy as np

def mscatter(x, y, ax=None, m=None, **kw):
    """
    Draw a scatter plot in which every point can have its own marker.

    :param x: Sequence of x coordinates.
    :param y: Sequence of y coordinates.
    :param ax: Object whose ``scatter`` method is called. When omitted,
        the current pyplot axes (``plt.gca()``) are used.
    :param m: Optional sequence holding one marker per point. Each entry is
        either a ``MarkerStyle`` instance or a value ``MarkerStyle`` accepts
        (for example ``'*'``). It is applied only when ``len(m) == len(x)``;
        otherwise a warning is emitted and the default marker is kept.
    :param kw: Extra keyword arguments forwarded unchanged to ``ax.scatter``.
    :return: The collection returned by ``ax.scatter``.
    """
    import warnings

    # Compare with None explicitly: an axes-like object supplied by the
    # caller must never be swapped for plt.gca() just because it is falsy.
    if ax is None:
        ax = plt.gca()
    sc = ax.scatter(x, y, **kw)

    if m is None:
        return sc

    if len(m) != len(x):
        # Previously this case was skipped silently, which produced a plot
        # with default markers and no hint about why.
        warnings.warn(
            "mscatter: ignoring m because it has %d markers for %d points"
            % (len(m), len(x)),
            stacklevel=2,
        )
        return sc

    # Imported lazily: only needed when per-point markers are applied.
    import matplotlib.markers as mmarkers

    paths = []
    for marker in m:
        if isinstance(marker, mmarkers.MarkerStyle):
            marker_obj = marker
        else:
            marker_obj = mmarkers.MarkerStyle(marker)
        # Bake the marker's own transform into its path so that each point
        # is drawn with the final marker outline.
        paths.append(
            marker_obj.get_path().transformed(marker_obj.get_transform())
        )
    sc.set_paths(paths)
    return sc

def datasets():
    """
    Generate sample data with sklearn's ``make_blobs`` and plot it.

    :return: Tuple ``(X, Y)`` where ``X`` is the array returned by
        ``make_blobs`` and ``Y`` is ``X`` multiplied by a random 2x2 matrix.
    """
    # Fix minus signs not being rendered on the axes.
    plt.rcParams['axes.unicode_minus'] = False
    # random_state=1000 also gives good results.
    X, _ = make_blobs(n_samples=400, centers=3, random_state=1000)

    # Apply a reproducible random linear map to the generated blobs.
    rng = np.random.RandomState(70)
    Y = np.dot(X, rng.randn(2, 2))

    plt.scatter(Y[:, 0], Y[:, 1], s=30)
    plt.title("original data")
    plt.show()
    return X, Y


def GmmKmean(data, dataY):
    """
    Compare K-means with a Gaussian mixture model (GMM) on the same data.

    Each algorithm is fitted on ``dataY`` with three clusters/components and
    its predicted labels are shown in a separate scatter plot, using one
    marker shape per label.

    :param data: Not used by this function; kept so existing callers keep
        working.
    :param dataY: 2-D array of samples; columns 0 and 1 are plotted.
    :return: None
    """
    # One marker shape per predicted label (both models use three groups).
    map_marker = {0: '*', 1: 'P', 2: '^'}

    kmeansy = KMeans(n_clusters=3, random_state=1)
    kmeansy.fit(dataY)
    datay_kmeans = kmeansy.predict(dataY)

    # Visualize the K-means result together with its cluster centers.
    markers = [map_marker[label] for label in datay_kmeans]
    centers = kmeansy.cluster_centers_
    mscatter(dataY[:, 0], dataY[:, 1], c=datay_kmeans, s=30, m=markers,
             cmap='viridis')
    plt.scatter(centers[:, 0], centers[:, 1], c="red", s=100)
    plt.title("k-means cluster")
    plt.show()

    gmm = GaussianMixture(n_components=3, random_state=10)
    gmm.fit(dataY)
    labelsY = gmm.predict(dataY)

    # Visualize the GMM result.
    markersG = [map_marker[label] for label in labelsY]
    mscatter(dataY[:, 0], dataY[:, 1], c=labelsY, s=30, m=markersG,
             cmap='viridis')
    # NOTE(review): the SimHei font is requested although the title is
    # ASCII; presumably a leftover from a Chinese title. Confirm it is
    # still needed.
    plt.title("GMM cluster", fontproperties="SimHei")
    plt.show()
    return None


if __name__ == "__main__":
    data, dataY = datasets()
    GmmKmean(data, dataY)

 

运行结果

 

posted @ 2022-10-11 15:44  StarZhai  阅读(2835)  评论(0编辑  收藏  举报