挑战全网比较全的spark求topN的思路(java、scala)
- java版本:
1、自定义实现排序key,实现Ordered接口,根据指定的排序条件,重写compare 、less、greater等方法,封装多个字段进行排序;
// 代码示例其中CategorySortKey为自定义的key
JavaPairRDD<CategorySortKey, String> sortedCategoryCountRDD = sortKey2countRDD.sortByKey(false);

1 public class CategorySortKey implements Ordered<CategorySortKey>, Serializable { 2 3 private static final long serialVersionUID = -6007890914324789180L; 4 5 private long clickCount; 6 private long orderCount; 7 private long payCount; 8 9 public CategorySortKey(long clickCount, long orderCount, long payCount) { 10 this.clickCount = clickCount; 11 this.orderCount = orderCount; 12 this.payCount = payCount; 13 } 14 15 @Override 16 public boolean $greater(CategorySortKey other) { 17 if(clickCount > other.getClickCount()) { 18 return true; 19 } else if(clickCount == other.getClickCount() && 20 orderCount > other.getOrderCount()) { 21 return true; 22 } else if(clickCount == other.getClickCount() && 23 orderCount == other.getOrderCount() && 24 payCount > other.getPayCount()) { 25 return true; 26 } 27 return false; 28 } 29 30 @Override 31 public boolean $greater$eq(CategorySortKey other) { 32 if($greater(other)) { 33 return true; 34 } else if(clickCount == other.getClickCount() && 35 orderCount == other.getOrderCount() && 36 payCount == other.getPayCount()) { 37 return true; 38 } 39 return false; 40 } 41 42 @Override 43 public boolean $less(CategorySortKey other) { 44 if(clickCount < other.getClickCount()) { 45 return true; 46 } else if(clickCount == other.getClickCount() && 47 orderCount < other.getOrderCount()) { 48 return true; 49 } else if(clickCount == other.getClickCount() && 50 orderCount == other.getOrderCount() && 51 payCount < other.getPayCount()) { 52 return true; 53 } 54 return false; 55 } 56 57 @Override 58 public boolean $less$eq(CategorySortKey other) { 59 if($less(other)) { 60 return true; 61 } else if(clickCount == other.getClickCount() && 62 orderCount == other.getOrderCount() && 63 payCount == other.getPayCount()) { 64 return true; 65 } 66 return false; 67 } 68 69 @Override 70 public int compare(CategorySortKey other) { 71 if(clickCount - other.getClickCount() != 0) { 72 return (int) (clickCount - other.getClickCount()); 73 } else if(orderCount - other.getOrderCount() != 0) { 74 return (int) (orderCount - other.getOrderCount()); 75 } else if(payCount - other.getPayCount() != 0) { 76 return (int) (payCount - other.getPayCount()); 77 } 78 return 0; 79 } 80 81 @Override 82 public int compareTo(CategorySortKey other) { 83 if(clickCount - other.getClickCount() != 0) { 84 return (int) (clickCount - other.getClickCount()); 85 } else if(orderCount - other.getOrderCount() != 0) { 86 return (int) (orderCount - other.getOrderCount()); 87 } else if(payCount - other.getPayCount() != 0) { 88 return (int) (payCount - other.getPayCount()); 89 } 90 return 0; 91 } 92 93 public long getClickCount() { 94 return clickCount; 95 } 96 97 public void setClickCount(long clickCount) { 98 this.clickCount = clickCount; 99 } 100 101 public long getOrderCount() { 102 return orderCount; 103 } 104 105 public void setOrderCount(long orderCount) { 106 this.orderCount = orderCount; 107 } 108 109 public long getPayCount() { 110 return payCount; 111 } 112 113 public void setPayCount(long payCount) { 114 this.payCount = payCount; 115 } 116 117 }
2、通过一些计算方式,获取到封装的实体类集合,然后利用java8的多条件排序(thenComparing);
3、若是想获取每个领域的topN:(如想获取每个班级的top10的数学成绩的同学姓名)
A、先声明长度为N的数组,然后一个个去进行比较,然后数组移位,用大的替换小的;

1 // 此代码是rdd操作的一部分,如foreach算子,每一个tuple2都会执行、 2 long categoryid = tuple._1; 3 Iterator<String> iterator = tuple._2.iterator(); 4 5 // 定义取topn的排序数组 6 String[] top10Sessions = new String[10]; 7 8 while (iterator.hasNext()) { 9 String sessionCount = iterator.next(); 10 long count = Long.valueOf(sessionCount.split(",")[1]); 11 12 // 遍历排序数组 13 for (int i = 0; i < top10Sessions.length; i++) { 14 // 如果当前i位,没有数据,那么直接将i位数据赋值为当前sessionCount 15 if (top10Sessions[i] == null) { 16 top10Sessions[i] = sessionCount; 17 break; 18 } else { 19 long _count = Long.valueOf(top10Sessions[i].split(",")[1]); 20 21 // 如果sessionCount比i位的sessionCount要大 22 if (count > _count) { 23 // 从排序数组最后一位开始,到i位,所有数据往后挪一位 24 for (int j = 9; j > i; j--) { 25 top10Sessions[j] = top10Sessions[j - 1]; 26 } 27 // 将i位赋值为sessionCount 28 top10Sessions[i] = sessionCount; 29 break; 30 } 31 32 // 比较小,继续外层for循环 33 } 34 } 35 }
- scala版本:
1、可将要排序的值形成一个tuple2的mapRDD,然后执行rdd的sortBy(_._2);
2、通过一些计算方式,获取到封装的实体类集合,利用将 value 转换为数组,sortWith(lt : scala.Function2[A,A,scala.Boolean])进行指定方式排序;
3、若是想获取每个领域的topN:(如想获取每个班级的top10的数学成绩的同学姓名)
A、先声明长度为N的数组,然后一个个去进行比较,然后数组移位,用大的替换小的。(此方式与java版本一致);
B、scala:(1)按照key 对数据进行聚合(groupByKey)(2)将 value 转换为数组,利用 scala 的 sortBy 或者sortWith进行排序(mapValues),但是数据量太大会OOM。
1 // (省份,[(广告A,sum),(广告B,sum),(广告C,sum)]) 2 val groupRDD: RDD[(String, Iterable[(String, Int)])] = newMapRDD.groupByKey() 3 4 // 将分组后的数据组内排序(降序),取前三名 5 //保持key不变,对value进行操作,使用mapValues 6 //降序(List):(Ordering.Int.reverse) 7 val resultRDD: RDD[(String, List[(String, Int)])] = groupRDD.mapValues( 8 iter => { 9 iter.toList.sortBy(_._2)(Ordering.Int.reverse).take(3) 10 } 11 )
C、弥补上个方案的缺点:此时可以拆解上述步骤:(1)取出所有的 key (2)对 key 进行迭代,每次取出一个 key 利用 spark 的排序算子进行排序
1 // (省份1,(广告A,sum1))、 (省份1,(广告B,sum2)) 2 // (省份2,(广告A,sum3))、 (省份2,(广告B,sum4)) 3 val provinces: Array[String] = subTeacherAndOne.map(_._1).distinct().collect() 4 5 for (province <- provinces){ 6 // 按照学科将数据进行过滤出来 7 val filtered: RDD[((String, String), Int)] = subTeacherAndOne.filter(t => t._1.equals(province)) 8 // RDD的takeOrdered方法,可以设置排序规则,并指定取前几个值 9 val res: Array[((String, String), Int)] = filtered.takeOrdered(2)((a, b) => a._2._2 - b._2._2) 10 // 打印结果 11 println(res.toBuffer) 12 }
D、使用自定义分区器取TopN
1 import org.apache.spark.{Partitioner, SparkConf, SparkContext} 2 import org.apache.spark.rdd.RDD 3 4 import scala.collection.mutable 5 6 /** 7 * 使用自定义分区器取topn,可以将较大的数据按照分区加载到内存中进行排序取topn,这样减小了 8 * 内存的压力,同时也加快了数据,但是如果某个分区中的数据量过大的话,这个方法就不适用了 9 */ 10 object Teacher03 { 11 def main(args: Array[String]): Unit = { 12 // 创建配置参数对象,进行参数配置 13 val conf = new SparkConf().setAppName(this.getClass.getSimpleName) 14 var isLocal:Boolean = args(0).toBoolean 15 if(isLocal){ 16 conf.setMaster("local[*]") 17 } 18 // 创建Spark入口 19 val sc = new SparkContext(conf) 20 // 读取文件 21 val lines: RDD[String] = sc.textFile(args(1)) 22 // 对数据进行切分聚合处理 23 val subTeacherAndOne: RDD[((String, String), Int)] = lines.map(line => { 24 val strArr = line.split("/") 25 val url: String = strArr(2) 26 val teacher: String = strArr(3) 27 val subject = url.substring(0, url.indexOf(".")) 28 ((subject, teacher), 1) 29 }).reduceByKey(_ + _) 30 // 将学科信息收集到客户端 31 val subjects: Array[String] = subTeacherAndOne.map(_._1._1).distinct().collect() 32 // 根据自定义的分区器进行分区,有几个学科就有几个分区 33 val partitioned: RDD[((String, String), Int)] = subTeacherAndOne.partitionBy(new SubjectPartitioner(subjects)) 34 // 使用mapPartition进行区内排序取topn 35 val res: RDD[((String, String), Int)] = partitioned.mapPartitions(t => { 36 val it = t.toList.sortBy(_._2).take(2).iterator 37 it 38 }) 39 // 打印结果 40 println(res.collect().toBuffer) 41 // 释放资源 42 sc.stop() 43 } 44 } 45 46 /** 47 * 自定义分区器在创建好对象后,task会去执行里面的方法,getPartition(key: Any)这个方法 48 * 中的key,task会根据自己设定的key.asInstanceOf的类型,进行匹配对应的类型 49 * @param subjects 50 */ 51 class SubjectPartitioner(subjects : Array[String]) extends Partitioner { 52 // 设置分区规则,按照学科指定分区 53 private val ruels = new mutable.HashMap[String, Int]() 54 var i:Int = 0 55 // 将指定每个学科归属哪个分区 56 for (sub <- subjects){ 57 ruels(sub) = i 58 i += 1 59 } 60 61 // 重写指定分区数量的方法 62 override def numPartitions: Int = subjects.length 63 64 // 重写获取分区的方法 65 override def getPartition(key: Any): Int = { 66 val subject: String = key.asInstanceOf[(String, String)]._1 67 ruels(subject) 68 } 69 }
【推荐】还在用 ECharts 开发大屏?试试这款永久免费的开源 BI 工具!
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步