- As Spark's implementation of a shared variable, an accumulator can noticeably cut network traffic when it is used to count records or compute metrics (a minimal built-in example follows this list).
- One node in a Spark cluster takes the Master role, as determined by the configuration files; its main responsibility is scheduling cluster resources across the Worker nodes.
- The actual computation of a Spark job runs in the executors on the worker nodes, and the variables a task references are serialized and shipped to every task, so with many tasks or large data a lot of time goes into copying.
- An accumulator avoids that overhead for counters: each task updates its own local copy on the executor, Spark merges the per-task results back on the driver, and only the small merged values travel over the network.
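As a minimal sketch of that pattern, the snippet below uses Spark's built-in LongAccumulator to count blank lines; the class name, master URL and sample data are made up for illustration, and in a real job the context would come from your existing SparkSession.

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.util.LongAccumulator;

import java.util.Arrays;

public class BuiltInAccumulatorExample {

    public static void main(String[] args) {
        // Hypothetical local context, only for illustration.
        JavaSparkContext sc = new JavaSparkContext("local[2]", "BuiltInAccumulatorExample");

        // Built-in accumulator: each task adds to its own copy, and Spark merges the
        // partial sums back on the driver when the tasks finish.
        LongAccumulator blankLines = sc.sc().longAccumulator("blank_lines");

        sc.parallelize(Arrays.asList("a", "", "b", "", "")).foreach(line -> {
            if (line.isEmpty()) {
                blankLines.add(1L);
            }
        });

        // Only the merged count travels back over the network, not the records themselves.
        System.out.println("blank lines: " + blankLines.value());
        sc.stop();
    }
}

For plain numeric counters like this, the built-in LongAccumulator (or DoubleAccumulator / CollectionAccumulator) is usually enough; the rest of this post builds a custom accumulator for string-typed counts.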
Now the goal is to implement accumulation over strings. For example, given many strings such as:
ONE:1 TWO:2 THREE:3 ONE:9
we want to sum up the counts of ONE, TWO and THREE. To do that, first define a custom AccumulatorV2:
import org.apache.spark.util.AccumulatorV2;

import java.util.HashMap;
import java.util.Map;

public class UserDefinedAccumulator extends AccumulatorV2<String, String> {

    private String one = "ONE";
    private String two = "TWO";
    private String three = "THREE";
    // Join the keys we want to count into one string, each with an initial count of 0;
    // all later accumulation updates this string, and value() returns it.
    private String data = one + ":0;" + two + ":0;" + three + ":0;";
    // The original (zero) state.
    private String zero = data;

    // The accumulator is "zero" when its current value equals the original state.
    @Override
    public boolean isZero() {
        return data.equals(zero);
    }

    // Create a new accumulator carrying the current value.
    @Override
    public AccumulatorV2<String, String> copy() {
        UserDefinedAccumulator acc = new UserDefinedAccumulator();
        acc.data = this.data;
        return acc;
    }

    // Reset the accumulator back to the original state.
    @Override
    public void reset() {
        data = zero;
    }

    // Fold an incoming "KEY:count" string into the current value.
    @Override
    public void add(String v) {
        data = mergeData(v, data, ";");
    }

    // Merge the result of another accumulator (e.g. from another task) into this one.
    @Override
    public void merge(AccumulatorV2<String, String> other) {
        data = mergeData(other.value(), data, ";");
    }

    // Return the current value of this accumulator.
    @Override
    public String value() {
        return data;
    }
    /**
     * Merge two strings of the form "KEY:count;KEY:count" by summing the counts of matching keys.
     *
     * @param data1   the first string
     * @param data2   the second string
     * @param delimit the delimiter between KEY:count pairs
     * @return the merged string
     */
    private String mergeData(String data1, String data2, String delimit) {
        // Parse both strings into key -> count maps.
        Map<String, Integer> map1 = new HashMap<>();
        for (String info : data1.split(delimit)) {
            String[] kv = info.split(":");
            if (kv.length == 2) {
                map1.put(kv[0].toUpperCase(), Integer.valueOf(kv[1]));
            }
        }
        Map<String, Integer> map2 = new HashMap<>();
        for (String info : data2.split(delimit)) {
            String[] kv = info.split(":");
            if (kv.length == 2) {
                map2.put(kv[0].toUpperCase(), Integer.valueOf(kv[1]));
            }
        }
        StringBuilder res = new StringBuilder();
        // Keys from the first map, adding the matching count from the second map when present.
        for (Map.Entry<String, Integer> entry : map1.entrySet()) {
            String key = entry.getKey();
            Integer value = entry.getValue();
            if (map2.containsKey(key)) {
                value = value + map2.get(key);
                map2.remove(key);
            }
            res.append(key).append(":").append(value).append(delimit);
        }
        // Keys that only appear in the second map.
        for (Map.Entry<String, Integer> entry : map2.entrySet()) {
            res.append(entry.getKey()).append(":").append(entry.getValue()).append(delimit);
        }
        // Drop the trailing delimiter.
        return res.substring(0, res.length() - 1);
    }
}
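Before wiring the class into a Spark job, its add/merge semantics can be sanity-checked with a small, hypothetical standalone test (in a real job Spark calls merge() itself when it combines the per-task copies):

public class UserDefinedAccumulatorCheck {

    public static void main(String[] args) {
        UserDefinedAccumulator acc1 = new UserDefinedAccumulator();
        acc1.add("ONE:1");
        acc1.add("TWO:2");

        UserDefinedAccumulator acc2 = new UserDefinedAccumulator();
        acc2.add("ONE:9");
        acc2.add("THREE:3");

        // In a real job this merge happens on the driver.
        acc1.merge(acc2);

        // Prints a string containing ONE:10, TWO:2 and THREE:3 (key order may vary).
        System.out.println(acc1.value());
    }
}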
Next, create a test class that uses the accumulator in a Spark job:
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.sql.SparkSession;
import scala.Tuple2;

import java.util.Arrays;
import java.util.Random;

public class CustomerAccumulator {

    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder()
                .master("local[4]")
                .appName("CustomerAccumulator")
                .getOrCreate();
        JavaSparkContext sc = JavaSparkContext.fromSparkContext(sparkSession.sparkContext());
        sc.setLogLevel("ERROR");

        JavaRDD<String> rdd = sc.parallelize(Arrays.asList("ONE", "TWO", "THREE", "ONE"));

        UserDefinedAccumulator count = new UserDefinedAccumulator();
        // Register the accumulator with the SparkContext.
        sc.sc().register(count, "user_count");

        // Attach a random count to each key, e.g. "ONE" becomes ("ONE", "ONE:7").
        JavaPairRDD<String, String> pairRDD = rdd.mapToPair((PairFunction<String, String, String>) s -> {
            int num = new Random().nextInt(10);
            return new Tuple2<>(s, s + ":" + num);
        });

        // Accumulate inside foreach; for updates made in actions, Spark applies each
        // task's updates exactly once, even if a task is retried.
        pairRDD.foreach((VoidFunction<Tuple2<String, String>>) tuple2 -> {
            System.out.println(tuple2._2);
            count.add(tuple2._2);
        });

        // Read the merged value back on the driver.
        System.out.println("the value of accumulator is:" + count.value());
    }
}
Check the run output (the tasks print the generated "KEY:count" strings, and the driver then prints the merged accumulator value; the counts differ from run to run because they are random):