spark记录(20)自定义累加器Accumulator

自定义累加器

/**
 * Custom Spark accumulator. A custom accumulator extends
 * {@code AccumulatorV2<IN, OUT>} and specifies the type being accumulated;
 * here both IN and OUT are {@link MyKey} (a person count plus an age sum).
 */
public class MyAccumulator extends AccumulatorV2<MyKey,MyKey> implements Serializable {

    /**
     * Accumulator state. It is initialized on the Driver, and the final
     * merged value also lives on the Driver.
     */
    private MyKey info = new MyKey(0, 0);

    public MyKey getInfo() {
        return info;
    }

    public void setInfo(MyKey info) {
        this.info = info;
    }

    /**
     * Whether this accumulator is still in its zero (initial) state.
     * Compared directly against the initial values (0, 0); this check is
     * a convention chosen by this implementation.
     * @return true when nothing has been accumulated yet
     */
    @Override
    public boolean isZero() {
        return info.getPersonAgeSum()==0 && info.getPersonNum()==0;
    }

    /**
     * Creates a new accumulator instance for each partition.
     * <p>
     * FIX: the original assigned {@code myAccumulator.info = this.info},
     * sharing one mutable {@code MyKey} between the copy and this
     * accumulator, so {@code add()} or {@code reset()} on either one
     * corrupted the other. The state is now deep-copied so the two
     * accumulators are independent, as the AccumulatorV2 contract expects.
     * @return an independent copy of this accumulator
     */
    @Override
    public AccumulatorV2<MyKey, MyKey> copy() {
        MyAccumulator myAccumulator = new MyAccumulator();
        myAccumulator.info = new MyKey(this.info.getPersonNum(), this.info.getPersonAgeSum());
        return myAccumulator;
    }

    /**
     * Resets the per-partition accumulator back to its zero state.
     */
    @Override
    public void reset() {
        info = new MyKey(0, 0);
    }

    /**
     * Defines how a single incoming record is folded into the state:
     * both the person count and the age sum are added component-wise.
     * @param v one new incoming record
     */
    @Override
    public void add(MyKey v) {
        info.setPersonNum(info.getPersonNum() + v.getPersonNum());
        info.setPersonAgeSum(info.getPersonAgeSum() + v.getPersonAgeSum());
    }

    /**
     * Merges the state accumulated by another partition's accumulator
     * into this one (component-wise addition).
     * @param other an accumulator from another partition
     */
    @Override
    public void merge(AccumulatorV2<MyKey, MyKey> other) {
        MyKey value = other.value();
        info.setPersonNum(info.getPersonNum()+value.getPersonNum());
        info.setPersonAgeSum(info.getPersonAgeSum()+value.getPersonAgeSum());
    }

    /**
     * The final accumulated state.
     * @return the current merged value
     */
    @Override
    public MyKey value() {
        return info;
    }
}

 

自定义key

/**
 * Serializable value holder pairing a person count with a running sum of ages.
 * Used as both the input and output type of the custom accumulator.
 */
public class MyKey implements Serializable {

    private Integer personNum;    // number of persons represented
    private Integer personAgeSum; // sum of those persons' ages

    /** No-arg constructor (leaves both fields null). */
    public MyKey() {
    }

    /**
     * @param personNum    number of persons represented
     * @param personAgeSum sum of their ages
     */
    public MyKey(Integer personNum, Integer personAgeSum) {
        this.personNum = personNum;
        this.personAgeSum = personAgeSum;
    }

    public Integer getPersonNum() {
        return this.personNum;
    }

    public void setPersonNum(Integer personNum) {
        this.personNum = personNum;
    }

    public Integer getPersonAgeSum() {
        return this.personAgeSum;
    }

    public void setPersonAgeSum(Integer personAgeSum) {
        this.personAgeSum = personAgeSum;
    }

    /** Human-readable rendering of both fields. */
    @Override
    public String toString() {
        return "MyKey{" + "personNum=" + personNum
                + ", personAgeSum=" + personAgeSum + '}';
    }
}

 运行:

/**
 * Driver program demonstrating the custom accumulator: registers it with
 * the SparkContext, feeds six "name age" records through a map, and prints
 * the merged count/sum.
 */
public class MyRun {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf();
        conf.setAppName("testAccumulator");
        conf.setMaster("local");

        JavaSparkContext sc = new JavaSparkContext(conf);
        try {
            MyAccumulator acc = new MyAccumulator();

            // Registration must happen on the driver before the job runs.
            sc.sc().register(acc, "PersonInfoAccumulator");
            JavaRDD<String> rdd = sc.parallelize(Arrays.asList(
                    "zhangsan 1", "lisi 2", "wangwu 3", "zhaoliu 4", "tianqi 5", "zhengba 6"
            ));

            rdd.map(new Function<String, String>() {
                @Override
                public String call(String v1) throws Exception {
                    // Each record contributes one person and its parsed age.
                    acc.add(new MyKey(1, Integer.parseInt(v1.split(" ")[1])));
                    return v1;
                }
            }).collect();

            // Accumulator values are only reliable after the action completes.
            System.out.println("value = " + acc.value());
        } finally {
            // FIX: the original never stopped the SparkContext, leaking the
            // local Spark runtime's threads and resources.
            sc.stop();
        }
    }
}

结果:

value = MyKey{personNum=6, personAgeSum=21}

posted @ 2019-11-05 21:22  kpsmile  阅读(323)  评论(0编辑  收藏  举报