PySpark LDA topic modeling on the 20 Newsgroups dataset


import re
from time import time

import numpy as np
from sklearn.datasets import fetch_20newsgroups

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql import SparkSession
from pyspark.sql import Row

from pyspark.ml.feature import CountVectorizer, HashingTF, IDF
from pyspark.ml.feature import RegexTokenizer, Tokenizer, StopWordsRemover
from pyspark.ml.clustering import LDA

np.random.seed(0)

# Every run of characters that are not ASCII letters is treated as one separator.
_NON_LETTERS = re.compile(r"[^A-Za-z]+")

NUM_FEATURES = 8000  # maximum vocabulary size kept by CountVectorizer
NUM_TOPICS = 20      # number of LDA topics (one per newsgroup category)


def clean_document(doc):
    """Return *doc* with each run of non-letter characters replaced by one space.

    Leading and trailing whitespace is stripped. The result stays a ``str``:
    encoding to bytes and then calling ``str()`` on them would leak a
    ``b'...'`` prefix into the text (and hence into the tokens).
    """
    return _NON_LETTERS.sub(" ", doc).strip()


if __name__ == "__main__":
    # A single SparkSession; its SparkContext is reused for the RDD work below.
    spark = (
        SparkSession.builder
        .master("local")
        .appName("LDA")
        .getOrCreate()
    )
    sc = spark.sparkContext

    try:
        tic = time()
        dataset = fetch_20newsgroups(
            shuffle=True, random_state=0, remove=("headers", "footers", "quotes")
        )
        train_corpus = dataset.data  # a list of 11314 documents / entries
        print("elapsed time: %.4f sec" % (time() - tic))

        # Distribute the corpus and clean each document on the executors.
        corpus_rdd = sc.parallelize(train_corpus).map(clean_document)
        rdd_row = corpus_rdd.map(lambda doc: Row(raw_corpus=doc))
        newsgroups = spark.createDataFrame(rdd_row)

        # RegexTokenizer splits on runs of whitespace and drops empty tokens.
        # The plain Tokenizer splits on single whitespace characters, which
        # produces "" tokens (and turns an empty document into [""]).
        tokenizer = RegexTokenizer(
            inputCol="raw_corpus", outputCol="tokens",
            pattern="\\s+", gaps=True, minTokenLength=1,
        )
        newsgroups = tokenizer.transform(newsgroups).drop("raw_corpus")

        stopwords = StopWordsRemover(inputCol="tokens", outputCol="tokens_filtered")
        newsgroups = stopwords.transform(newsgroups).drop("tokens")

        # CountVectorizer (not HashingTF) keeps an explicit vocabulary, which is
        # needed later to map LDA term indices back to words.
        count_vec_model = CountVectorizer(
            inputCol="tokens_filtered", outputCol="tf_features",
            vocabSize=NUM_FEATURES, minDF=2.0,
        ).fit(newsgroups)
        vocab = count_vec_model.vocabulary
        newsgroups = count_vec_model.transform(newsgroups).drop("tokens_filtered")

        # LDA models per-document term counts, so it is trained on the raw
        # CountVectorizer output rather than on TF-IDF weights.
        lda = LDA(k=NUM_TOPICS, featuresCol="tf_features", seed=0)
        model = lda.fit(newsgroups)

        topics = model.describeTopics()
        topics.show()

        # describeTopics() has only NUM_TOPICS rows, so collect it and look up
        # the words on the driver.
        for row in topics.collect():
            print("topic: ", row["topic"])
            print("----------")
            for term_index in row["termIndices"]:
                print(vocab[term_index])
            print("----------")
    finally:
        spark.stop()
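Sample output. Note: the blank entries and the `b'if` token in the topic word lists below are preprocessing artifacts. Calling `str()` on UTF-8-encoded bytes leaves a `b'...'` prefix in the text. Splitting the space-padded text on single whitespace characters also yields empty tokens.

elapsed time: 1.0284 sec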
+-----+--------------------+--------------------+
|topic|         termIndices|         termWeights|
+-----+--------------------+--------------------+
|    0|[0, 552, 967, 108...|[0.01258332472159...|
|    1|[0, 1004, 76, 40,...|[0.08220619222238...|
|    2|[3, 0, 373, 18, 2...|[0.11591833022404...|
|    3|[3541, 1057, 2060...|[0.02274190214796...|
|    4|[0, 87, 364, 3645...|[0.01822486526972...|
|    5|[104, 0, 527, 188...|[0.01069089006155...|
|    6|[1, 4, 7, 16, 8, ...|[0.40037927605170...|
|    7|[0, 1079, 24, 325...|[0.01809503390053...|
|    8|[0, 50, 148, 19, ...|[0.01119972590376...|
|    9|[0, 261, 356, 340...|[0.00977327728361...|
|   10|[0, 182, 35, 743,...|[0.01214197116201...|
|   11|[0, 308, 706, 561...|[0.01776294690247...|
|   12|[0, 69, 38, 9, 35...|[0.01212285192631...|
|   13|[179, 1219, 0, 11...|[0.01490648021084...|
|   14|[0, 93, 133, 83, ...|[0.01778007036882...|
|   15|[569, 949, 0, 124...|[0.02115331163136...|
|   16|[755, 56, 303, 20...|[0.02666471302923...|
|   17|[171, 0, 299, 895...|[0.01153933315409...|
|   18|[0, 208, 574, 116...|[0.01336781044283...|
|   19|[831, 0, 2149, 87...|[0.00780105775382...|
+-----+--------------------+--------------------+

topic:  0
----------

monitor
printer
vga
pin
print
cable
apple
please
video
----------
topic:  1
----------

pts
com
edu
la
pt
vs
period
pp
w
----------
topic:  2
----------
x

entry
n
output
oname
c
entries
eof
file
----------
topic:  3
----------
den
dod
tank
b'if

accelerators
ctrl
radius
td
rc
posted @ 2022-08-19 22:58  luoganttcc  阅读(13)  评论(0)  编辑  收藏  举报