挑战全网比较全的spark求topN的思路(java、scala)

  • java版本:

1、自定义实现排序key,实现Ordered接口,根据指定的排序条件,重写compare 、less、greater等方法,封装多个字段进行排序;

  

// 代码示例其中CategorySortKey为自定义的key
 JavaPairRDD<CategorySortKey, String> sortedCategoryCountRDD = sortKey2countRDD.sortByKey(false);
  1 public class CategorySortKey implements Ordered<CategorySortKey>, Serializable {
  2     
  3     private static final long serialVersionUID = -6007890914324789180L;
  4     
  5     private long clickCount;
  6     private long orderCount;
  7     private long payCount;
  8     
  9     public CategorySortKey(long clickCount, long orderCount, long payCount) {
 10         this.clickCount = clickCount;
 11         this.orderCount = orderCount;
 12         this.payCount = payCount;
 13     }
 14 
 15     @Override
 16     public boolean $greater(CategorySortKey other) {
 17         if(clickCount > other.getClickCount()) {
 18             return true;
 19         } else if(clickCount == other.getClickCount() && 
 20                 orderCount > other.getOrderCount()) {
 21             return true;
 22         } else if(clickCount == other.getClickCount() &&
 23                 orderCount == other.getOrderCount() &&
 24                 payCount > other.getPayCount()) {
 25             return true;
 26         }
 27         return false;
 28     }
 29 
 30     @Override
 31     public boolean $greater$eq(CategorySortKey other) {
 32         if($greater(other)) {
 33             return true;
 34         } else if(clickCount == other.getClickCount() &&
 35                 orderCount == other.getOrderCount() &&
 36                 payCount == other.getPayCount()) {
 37             return true;
 38         }
 39         return false;
 40     }
 41     
 42     @Override
 43     public boolean $less(CategorySortKey other) {
 44         if(clickCount < other.getClickCount()) {
 45             return true;
 46         } else if(clickCount == other.getClickCount() && 
 47                 orderCount < other.getOrderCount()) {
 48             return true;
 49         } else if(clickCount == other.getClickCount() &&
 50                 orderCount == other.getOrderCount() &&
 51                 payCount < other.getPayCount()) {
 52             return true;
 53         }
 54         return false;
 55     }
 56 
 57     @Override
 58     public boolean $less$eq(CategorySortKey other) {
 59         if($less(other)) {
 60             return true;
 61         } else if(clickCount == other.getClickCount() &&
 62                 orderCount == other.getOrderCount() &&
 63                 payCount == other.getPayCount()) {
 64             return true;
 65         }
 66         return false;
 67     }
 68 
 69     @Override
 70     public int compare(CategorySortKey other) {
 71         if(clickCount - other.getClickCount() != 0) {
 72             return (int) (clickCount - other.getClickCount());
 73         } else if(orderCount - other.getOrderCount() != 0) {
 74             return (int) (orderCount - other.getOrderCount());
 75         } else if(payCount - other.getPayCount() != 0) {
 76             return (int) (payCount - other.getPayCount());
 77         }
 78         return 0;
 79     }
 80     
 81     @Override
 82     public int compareTo(CategorySortKey other) {
 83         if(clickCount - other.getClickCount() != 0) {
 84             return (int) (clickCount - other.getClickCount());
 85         } else if(orderCount - other.getOrderCount() != 0) {
 86             return (int) (orderCount - other.getOrderCount());
 87         } else if(payCount - other.getPayCount() != 0) {
 88             return (int) (payCount - other.getPayCount());
 89         }
 90         return 0;
 91     }
 92 
 93     public long getClickCount() {
 94         return clickCount;
 95     }
 96 
 97     public void setClickCount(long clickCount) {
 98         this.clickCount = clickCount;
 99     }
100 
101     public long getOrderCount() {
102         return orderCount;
103     }
104 
105     public void setOrderCount(long orderCount) {
106         this.orderCount = orderCount;
107     }
108 
109     public long getPayCount() {
110         return payCount;
111     }
112 
113     public void setPayCount(long payCount) {
114         this.payCount = payCount;
115     }  
116     
117 }
View Code

2、通过一些计算方式,获取到封装的实体类集合,然后利用java8的多条件排序(thenComparing;

 

3、若是想获取每个领域的topN:(如想获取每个班级的top10的数学成绩的同学姓名)

  A、先声明长度为N的数组,然后一个个去进行比较,然后数组移位,用大的替换小的;

 1 // 此代码是rdd操作的一部分,如foreach算子,每一个tuple2都会执行、
 2 long categoryid = tuple._1;
 3             Iterator<String> iterator = tuple._2.iterator();
 4 
 5             // 定义取topn的排序数组
 6             String[] top10Sessions = new String[10];
 7 
 8             while (iterator.hasNext()) {
 9               String sessionCount = iterator.next();
10               long count = Long.valueOf(sessionCount.split(",")[1]);
11 
12               // 遍历排序数组
13               for (int i = 0; i < top10Sessions.length; i++) {
14                 // 如果当前i位,没有数据,那么直接将i位数据赋值为当前sessionCount
15                 if (top10Sessions[i] == null) {
16                   top10Sessions[i] = sessionCount;
17                   break;
18                 } else {
19                   long _count = Long.valueOf(top10Sessions[i].split(",")[1]);
20 
21                   // 如果sessionCount比i位的sessionCount要大
22                   if (count > _count) {
23                     // 从排序数组最后一位开始,到i位,所有数据往后挪一位
24                     for (int j = 9; j > i; j--) {
25                       top10Sessions[j] = top10Sessions[j - 1];
26                     }
27                     // 将i位赋值为sessionCount
28                     top10Sessions[i] = sessionCount;
29                     break;
30                   }
31 
32                   // 比较小,继续外层for循环
33                 }
34               }
35             }
View Code
  • scala版本:

1、可将要排序的值形成一个tuple2mapRDD,然后执行rdd的sortBy(_._2);

2、通过一些计算方式,获取到封装的实体类集合,利用将 value 转换为数组,sortWith(lt : scala.Function2[A,A,scala.Boolean])进行指定方式排序;

3、若是想获取每个领域的topN:(如想获取每个班级的top10的数学成绩的同学姓名)

  A、先声明长度为N的数组,然后一个个去进行比较,然后数组移位,用大的替换小的。(此方式与java版本一致);

  B、scala:(1)按照key 对数据进行聚合(groupByKey)(2)将 value 转换为数组,利用 scala 的 sortBy 或者sortWith进行排序(mapValues),但是数据量太大会OOM。

 

 1     //    (省份,[(广告A,sum),(广告B,sum),(广告C,sum)])
 2     val groupRDD: RDD[(String, Iterable[(String, Int)])] = newMapRDD.groupByKey()
 3 
 4     //   将分组后的数据组内排序(降序),取前三名
 5     //保持key不变,对value进行操作,使用mapValues
 6     //降序(List):(Ordering.Int.reverse)
 7     val resultRDD: RDD[(String, List[(String, Int)])] = groupRDD.mapValues(
 8       iter => {
 9         iter.toList.sortBy(_._2)(Ordering.Int.reverse).take(3)
10       }
11     )

 

  C、弥补上个方案的缺点:此时可以拆解上述步骤:(1)取出所有的 key (2)对 key 进行迭代,每次取出一个 key 利用 spark 的排序算子进行排序

 1  //    (省份1,(广告A,sum1))、 (省份1,(广告B,sum2))
 2  //    (省份2,(广告A,sum3))、 (省份2,(广告B,sum4))
 3     val provinces: Array[String] = subTeacherAndOne.map(_._1).distinct().collect()
 4 
 5     for (province <- provinces){
 6       // 按照学科将数据进行过滤出来
 7       val filtered: RDD[((String, String), Int)] = subTeacherAndOne.filter(t => t._1.equals(province))
 8       // RDD的takeOrdered方法,可以设置排序规则,并指定取前几个值
 9       val res: Array[((String, String), Int)] = filtered.takeOrdered(2)((a, b) => a._2._2 - b._2._2)
10       // 打印结果
11       println(res.toBuffer)
12     }

  D、使用自定义分区器取TopN

 1 import org.apache.spark.{Partitioner, SparkConf, SparkContext}
 2 import org.apache.spark.rdd.RDD
 3 
 4 import scala.collection.mutable
 5 
 6 /**
 7  * 使用自定义分区器取topn,可以将较大的数据按照分区加载到内存中进行排序取topn,这样减小了
 8  * 内存的压力,同时也加快了数据,但是如果某个分区中的数据量过大的话,这个方法就不适用了
 9  */
10 object Teacher03 {
11   def main(args: Array[String]): Unit = {
12     // 创建配置参数对象,进行参数配置
13     val conf = new SparkConf().setAppName(this.getClass.getSimpleName)
14     var isLocal:Boolean = args(0).toBoolean
15     if(isLocal){
16       conf.setMaster("local[*]")
17     }
18     // 创建Spark入口
19     val sc = new SparkContext(conf)
20     // 读取文件
21     val lines: RDD[String] = sc.textFile(args(1))
22     // 对数据进行切分聚合处理
23     val subTeacherAndOne: RDD[((String, String), Int)] = lines.map(line => {
24       val strArr = line.split("/")
25       val url: String = strArr(2)
26       val teacher: String = strArr(3)
27       val subject = url.substring(0, url.indexOf("."))
28       ((subject, teacher), 1)
29     }).reduceByKey(_ + _)
30     // 将学科信息收集到客户端
31     val subjects: Array[String] = subTeacherAndOne.map(_._1._1).distinct().collect()
32     // 根据自定义的分区器进行分区,有几个学科就有几个分区
33     val partitioned: RDD[((String, String), Int)] = subTeacherAndOne.partitionBy(new SubjectPartitioner(subjects))
34     // 使用mapPartition进行区内排序取topn
35     val res: RDD[((String, String), Int)] = partitioned.mapPartitions(t => {
36       val it = t.toList.sortBy(_._2).take(2).iterator
37       it
38     })
39     // 打印结果
40     println(res.collect().toBuffer)
41     // 释放资源
42     sc.stop()
43   }
44 }
45 
46 /**
47  * 自定义分区器在创建好对象后,task会去执行里面的方法,getPartition(key: Any)这个方法
48  * 中的key,task会根据自己设定的key.asInstanceOf的类型,进行匹配对应的类型
49  * @param subjects
50  */
51 class SubjectPartitioner(subjects : Array[String]) extends Partitioner {
52   // 设置分区规则,按照学科指定分区
53   private val ruels = new mutable.HashMap[String, Int]()
54   var i:Int = 0
55   // 将指定每个学科归属哪个分区
56   for (sub <- subjects){
57     ruels(sub) = i
58     i += 1
59   }
60 
61   // 重写指定分区数量的方法
62   override def numPartitions: Int = subjects.length
63 
64   // 重写获取分区的方法
65   override def getPartition(key: Any): Int = {
66     val subject: String = key.asInstanceOf[(String, String)]._1
67     ruels(subject)
68   }
69 }

 

posted @ 2022-07-03 23:59  杰然不同2019  阅读(424)  评论(0编辑  收藏  举报