GMM聚类算法

from pyspark.sql import Row
from pyspark.ml.clustering import GaussianMixture, GaussianMixtureModel
from pyspark.ml.linalg import Vectors

def f(x):
    """Build the Row fields for one parsed input record.

    Args:
        x: Indexable sequence of field values from one split input line.
            Indices 2, 24, 28 and 29 are read as features and index 22 as
            the label, so at least 30 fields are required.

    Returns:
        dict with two keys:
            'features': DenseVector of the four feature fields as floats.
            'label': field 22 converted to a string.

    Raises:
        IndexError: if ``x`` has fewer than 30 fields.
        ValueError: if a feature field cannot be parsed as a float.
    """
    # Convert explicitly to float rather than handing strings to
    # Vectors.dense and relying on implicit string-to-float coercion.
    feature_indices = (2, 24, 28, 29)
    rel = {}
    rel['features'] = Vectors.dense([float(x[i]) for i in feature_indices])
    rel['label'] = str(x[22])
    return rel

# Read the ';'-separated text file, turn every line into a Row via f(),
# and materialise the result as a DataFrame.
data = (
    spark.sparkContext
    .textFile("file:///home/hw17685187119/student2.txt")
    .map(lambda raw: raw.split(';'))
    .map(lambda fields: Row(**f(fields)))
    .toDF()
)

# Three-component Gaussian mixture with custom output column names.
gm = GaussianMixture(
    k=3,
    predictionCol="Prediction",
    probabilityCol="Probability",
)
gmm = gm.fit(data)

# Append the prediction/probability columns and print up to 150 rows
# without truncating cell contents.
result = gmm.transform(data)
result.show(n=150, truncate=False)


# gaussiansDF has one row per mixture component with 'mean' and 'cov' columns.
# Collect it once and pair each row with its weight: calling .head() inside
# the loop always returns the first row, so every iteration would report
# component 0's parameters.
gaussians = gmm.gaussiansDF.select('mean', 'cov').collect()
for i, (weight, gaussian) in enumerate(zip(gmm.weights, gaussians)):
    print(
        "Component {}: weight is {}\n mu vector is {}\n sigma matrix is {}".format(
            i, weight, gaussian['mean'], gaussian['cov']
        )
    )

posted @ 2021-06-16 13:57  Plum_Brilliant  阅读(253)  评论(0)  编辑  收藏  举报