Python Spark example

This is a job I wrote at the beginning of the year to flag abusive API access (large numbers of 403s); other follow-up measures hang off it. It's written in a fairly plain, easy-to-read style, so I'm keeping it here as a sample.

The data source is a Kafka stream, processed in real time. The rules are configured in MySQL; put simply, any client that racks up more than a configured number of 403s within one minute gets recorded.
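
A policy row, as dictfetchall() returns it below, is read through the keys domain, api, code and limit. A hypothetical example row (the concrete values here are made up purely for illustration):

    policy_row = {
        'domain': 'all',  # 'all' or a specific domain to match
        'api': 'all',     # 'all' or a specific request_path to match
        'code': 403,      # HTTP status code to count
        'limit': 100,     # hypothetical threshold: hits per minute before a client is recorded
    }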

  1 import json
  2 import logging
  3 from datetime import datetime
  4 
  5 import MySQLdb
  6 from pyspark import SparkContext, SparkConf
  7 from pyspark.streaming import StreamingContext
  8 from pyspark.streaming.kafka import KafkaUtils
  9 
 10 logger = logging.getLogger()
 11 hdlr = logging.FileHandler('nginx_log_stats.log')
 12 formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
 13 hdlr.setFormatter(formatter)
 14 logger.addHandler(hdlr)
 15 logger.setLevel(logging.ERROR)
 16 
 17 
 18 def is_in_one_minute(nginx_timestamp):
 19     """
 20     :param nginx_timestamp: e.g. "16/Feb/2017:08:23:59 +0000"
 21     :return:
 22     """
 23     now = datetime.utcnow()  # log timestamps carry +0000, so compare in UTC
 24     nginx_datetime = datetime.strptime(nginx_timestamp.split('+')[0].strip(),
 25                                        '%d/%b/%Y:%H:%M:%S')
 26     return abs((now - nginx_datetime).total_seconds()) <= 60
 27 
 28 
 29 
 30 # save one partition's results to MySQL (each partition opens its own connection)
 31 def saveToMysql(partition):
 32     host = "..."
 33     user = "..."
 34     password = "..."
 35     db_name = "..._monitor"
 36     db = MySQLdb.connect(host, user, password, db_name, charset='utf8')
 37     db.autocommit(True)
 38     cursor = db.cursor()
 39     for d1ct in partition:
 40         # parameterized query, so values taken from the logs cannot break the SQL
 41         sql = r"""INSERT INTO `security_suspect_request` (`domain`, `api`, `code`, `ip`, `access_count`) VALUES (%s, %s, %s, %s, %s)"""
 42         cursor.execute(sql, (d1ct['domain'], d1ct['path'], d1ct['response'], d1ct['ip'], d1ct['count']))
 43     cursor.close()
 44     db.close()
 45 
 46 
 47 def dictfetchall(cursor):
 48     "Return all rows from a cursor as a list of dicts"
 49     columns = [col[0] for col in cursor.description]
 50     return [
 51         dict(zip(columns, row))
 52         for row in cursor.fetchall()
 53         ]
 54 
 55 
 56 def filterPolicy(log):
 57     '''
 58 {
 59   "path": "/var/log/nginx/webapi..../access-log",
 60   "host": "ip-10-...",
 61   "clientip": "10....",
 62   "timestamp": "16/Feb/2017:08:23:59 +0000",
 63   "domain": "...com",
 64   "verb": "POST",
 65   "request_path": "/video/upload",
 66   "request_param": "sig=b400fdce...&userId=...",
 67   "httpversion": "1.1",
 68   "response": "403",
 69   "bytes": "0",
 70   "agent": "Dalvik/1.6.0 (Linux; U; Android 4.4.4; SM-T561 Build/KTU84P)",
 71   "response_time": "0.110",
 72   "topic": "nginx"
 73 }
 74     '''
 75     # return True to keep this record, False to drop it
 76     true_flag = 0
 77     this = json.loads(log[1])  # log is a (kafka key, message) tuple; the message is the JSON line
 78     # filter time
 79     if not is_in_one_minute(this['timestamp']):
 80         return False
 81     # filter condition
 82     for policy in filterVar.value:
 83         if policy['domain'] == 'all' or ('domain' in this.keys() and this['domain'] == policy['domain']):
 84             if policy['api'] == 'all' or ('request_path' in this.keys() and this['request_path'] == policy['api']):
 85                 if 'response' in this.keys() and this['response'] == str(policy['code']):
 86                     true_flag += 1
 87 
 88     return true_flag > 0
 89 
 90 
 91 def countMap(log):
 92     import json, re  # imported inside the function so the executor closures have them
 93     this = json.loads(log[1])
 94     # strip a trailing numeric id so e.g. /video/123 and /video/456 count as the same api
 95     key = this.get('domain', "") + "--" + re.sub(r'/\d+$', '', this.get('request_path', "")) \
 96           + "--" + this.get('clientip', "") + "--" + this.get('response', "")
 97     value = {'count': 1}
 98     return key, value
 99 
100 
101 def countReduce(prev, cur):
102     cur['count'] = cur['count'] + prev['count']
103     return cur
104 
105 
106 def output(tup1e):
107     """
108     takes a (key, value) tuple and splits the key back into its fields
109     """
110     tup1e[1]['domain'], tup1e[1]['path'], tup1e[1]['ip'], tup1e[1]['response'] = tup1e[0].split('--')
111     return tup1e[1]
112 
113 
114 def youAreUnderArrest(d1ct):
115     mylimit = None
116     for row in filterVar.value:
117         if row['domain'] == 'all' or row['domain'] == d1ct['domain']:
118             if row['api'] == 'all' or row['api'] == d1ct['path']:
119                 if row['code'] == int(d1ct['response']):
120                     mylimit = row['limit']
121 
122     return False if mylimit is None else d1ct['count'] >= mylimit
123 
124 
125 if __name__ == "__main__":
126     host = "..."
127     user = "..."
128     password = "..."
129     db_name = "..._monitor"
130     db = MySQLdb.connect(host, user, password, db_name, charset='utf8')
131     db.autocommit(True)
132     cur = db.cursor()
133     try:
134         # for now only one policy row is supported
135         cur.execute(r"""SELECT * FROM security_anti_hacker_policy""")
136         filter_option = dictfetchall(cur)
137     finally:
138         db.close()
139 
140     topic = 'nginx.log'
141     zkQuorum = '...:2181,...:2181,...:2181'
142     conf = (SparkConf()
143             .setMaster("spark://...:7077")
144             .setAppName("anti_hacker_stats")
145             .set("spark.driver.memory", "1g")
146             .set("spark.executor.memory", "1g")
147             .set("spark.cores.max", 2))
148     sc = SparkContext(conf=conf)
149     # broadcast the policy rows so every executor can read them
150     filterVar = sc.broadcast(filter_option)
151     ssc = StreamingContext(sc, 60)
152     kvs = KafkaUtils.createStream(ssc, zkQuorum, "anti-hacker", {topic: 1},
153                                   {"auto.offset.reset": 'largest'})
154     lines = kvs.filter(filterPolicy).map(countMap).reduceByKey(countReduce).map(output).filter(youAreUnderArrest)
155     lines.foreachRDD(lambda rdd: rdd.foreachPartition(saveToMysql))
156     # lines.saveAsTextFiles('test')
157     # lines = kvs.filter(filterPolicy)
158     # lines.pprint()
159     ssc.start()
160     ssc.awaitTermination()

Writing Spark in Python means the job has to be run with pyspark on the Spark server, which makes debugging inconvenient; Scala is the better choice, and there is a separate example for that.

A few key points:

  1. Spark is distributed by nature, so each RDD partition can be thought of as living on a different machine. A JDBC connection therefore cannot be shared across them; every partition has to open its own connection and do its own writes.
  2. Given point 1, how do you share data? The answer is the sc.broadcast call on line 150, which broadcasts the shared data to every executor. A minimal sketch of points 1 and 2 follows this list.
  3. The data format matters: you have to know exactly what the records in your data source look like.
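
To make points 1 and 2 concrete, here is a minimal, self-contained sketch, independent of the job above; connection parameters, names and the stand-in data are all placeholders. The shared config is broadcast once from the driver, and each partition writes through a connection it opens itself inside foreachPartition:

    import MySQLdb
    from pyspark import SparkContext

    sc = SparkContext(appName="broadcast_and_partition_write_sketch")

    # read-only data shipped from the driver to every executor
    limit_var = sc.broadcast({'code': 403, 'limit': 100})  # hypothetical shared config

    def save_partition(rows):
        # runs on an executor: a connection built on the driver cannot be
        # pickled and shared, so each partition opens (and closes) its own
        db = MySQLdb.connect("...", "...", "...", "..._monitor", charset='utf8')
        cursor = db.cursor()
        for row in rows:
            if row['count'] >= limit_var.value['limit']:
                cursor.execute(
                    "INSERT INTO `security_suspect_request` (`ip`, `access_count`) VALUES (%s, %s)",
                    (row['ip'], row['count']))
        db.commit()
        db.close()

    # stand-in data; in the real job this is the DStream produced above
    sc.parallelize([{'ip': '10.0.0.1', 'count': 120}]).foreachPartition(save_partition)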
posted @ 2017-10-09 11:59  Els0n