Spark: Computing the Top N Within Each Group
Published: 2019-06-19


Prepare the test data source:

c1 85
c2 77
c3 88
c1 22
c1 66
c3 95
c3 54
c2 91
c2 66
c1 54
c1 65
c2 41
c4 65

Spark Scala implementation:

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object GroupTopN1 {
  System.setProperty("hadoop.home.dir", "D:\\Java_Study\\hadoop-common-2.2.0-bin-master")

  case class Rating(userId: String, rating: Long)

  def main(args: Array[String]) {
    val sparkConf = new SparkConf().setAppName("ALS with ML Pipeline")
    val spark = SparkSession
      .builder()
      .config(sparkConf)
      .master("local")
      .config("spark.sql.warehouse.dir", "/")
      .getOrCreate()
    import spark.implicits._
    import spark.sql

    val lines = spark.read.textFile("C:\\Users\\Administrator\\Desktop\\group.txt")
    val classScores = lines.map(line => Rating(line.split(" ")(0).toString, line.split(" ")(1).toLong))
    classScores.createOrReplaceTempView("tb_test")

    var df = sql(
      s"""|select
          | userId,
          | rating,
          | row_number()over(partition by userId order by rating desc) rn
          |from tb_test
          |having(rn<=3)
          |""".stripMargin)
    df.show()

    spark.stop()
  }
}

Output:

+------+------+---+
|userId|rating| rn|
+------+------+---+
|    c1|    85|  1|
|    c1|    66|  2|
|    c1|    65|  3|
|    c4|    65|  1|
|    c3|    95|  1|
|    c3|    88|  2|
|    c3|    54|  3|
|    c2|    91|  1|
|    c2|    77|  2|
|    c2|    66|  3|
+------+------+---+
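The same per-group Top N can also be expressed without a SQL string, via the DataFrame window API. This is only a minimal sketch, assuming the classScores Dataset and the import spark.implicits._ from the code above:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.row_number

// partition by user, order by rating descending, keep the first 3 rows per user
val w = Window.partitionBy("userId").orderBy($"rating".desc)
classScores
  .withColumn("rn", row_number().over(w))
  .filter($"rn" <= 3)
  .show()

This produces the same rows as the SQL version while keeping the column references at the DataFrame level instead of inside a SQL string.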

Spark Java implementation:

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

public class Test {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf().setAppName("ALS with ML Pipeline");
        SparkSession spark = SparkSession
                .builder()
                .config(sparkConf)
                .master("local")
                .config("spark.sql.warehouse.dir", "/")
                .getOrCreate();

        // Create an RDD of the raw text lines
        JavaRDD<String> peopleRDD = spark.sparkContext()
                .textFile("C:\\Users\\Administrator\\Desktop\\group.txt", 1)
                .toJavaRDD();

        // Build the schema: userId (string), rating (long)
        List<StructField> fields = new ArrayList<>();
        StructField field1 = DataTypes.createStructField("userId", DataTypes.StringType, true);
        StructField field2 = DataTypes.createStructField("rating", DataTypes.LongType, true);
        fields.add(field1);
        fields.add(field2);
        StructType schema = DataTypes.createStructType(fields);

        // Convert records of the RDD to Rows
        JavaRDD<Row> rowRDD = peopleRDD.map((Function<String, Row>) record -> {
            String[] attributes = record.split(" ");
            if (attributes.length != 2) {
                throw new Exception();
            }
            return RowFactory.create(attributes[0], Long.valueOf(attributes[1].trim()));
        });

        // Apply the schema to the RDD and register a temporary view
        Dataset<Row> peopleDataFrame = spark.createDataFrame(rowRDD, schema);
        peopleDataFrame.createOrReplaceTempView("tb_test");

        Dataset<Row> items = spark.sql(
                "select userId,rating,row_number()over(partition by userId order by rating desc) rn "
                        + "from tb_test "
                        + "having(rn<=3)");
        items.show();

        spark.stop();
    }
}

The output is the same as the output shown above.
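Both the Scala and Java versions above filter the window column with having(rn<=3), i.e. a HAVING clause without a GROUP BY; whether Spark accepts that form can depend on the Spark version. A more portable sketch (a hypothetical drop-in for the spark.sql(...) call, run against the same tb_test view) computes row_number() in a subquery and filters with WHERE:

val df = spark.sql(
  """|select userId, rating, rn
     |from (
     |  select userId,
     |         rating,
     |         row_number() over (partition by userId order by rating desc) as rn
     |  from tb_test
     |) t
     |where rn <= 3
     |""".stripMargin)
df.show()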

Implementing TopN with combineByKey in Java:

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;

import java.util.*;

public class SparkJava {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().master("local[*]").appName("Spark").getOrCreate();
        final JavaSparkContext ctx = JavaSparkContext.fromSparkContext(spark.sparkContext());

        List<String> data = Arrays.asList("a,110,a1", "b,122,b1", "c,123,c1", "a,210,a2", "b,212,b2",
                "a,310,a3", "b,312,b3", "a,410,a4", "b,412,b4");
        JavaRDD<String> javaRDD = ctx.parallelize(data);

        // Parse "key,value,..." lines into (key, value) pairs
        JavaPairRDD<String, Integer> javaPairRDD = javaRDD.mapToPair(new PairFunction<String, String, Integer>() {
            public Tuple2<String, Integer> call(String key) throws Exception {
                return new Tuple2<String, Integer>(key.split(",")[0], Integer.valueOf(key.split(",")[1]));
            }
        });

        final int topN = 3;
        JavaPairRDD<String, List<Integer>> combineByKeyRDD2 = javaPairRDD.combineByKey(
                // createCombiner: start the per-key list with the first value
                new Function<Integer, List<Integer>>() {
                    public List<Integer> call(Integer v1) throws Exception {
                        List<Integer> items = new ArrayList<Integer>();
                        items.add(v1);
                        return items;
                    }
                },
                // mergeValue: add the new value, then trim back down to the topN largest
                new Function2<List<Integer>, Integer, List<Integer>>() {
                    public List<Integer> call(List<Integer> v1, Integer v2) throws Exception {
                        v1.add(v2);
                        while (v1.size() > topN) {
                            Integer item = Collections.min(v1);
                            v1.remove(item);
                        }
                        return v1;
                    }
                },
                // mergeCombiners: merge two lists, then trim back down to the topN largest
                new Function2<List<Integer>, List<Integer>, List<Integer>>() {
                    public List<Integer> call(List<Integer> v1, List<Integer> v2) throws Exception {
                        v1.addAll(v2);
                        while (v1.size() > topN) {
                            Integer item = Collections.min(v1);
                            v1.remove(item);
                        }
                        return v1;
                    }
                });

        // Flatten K: String, V: List<Integer> into K: String, V: Integer pairs
        // old: [(a,[210, 310, 410]), (b,[122, 212, 312]), (c,[123])]
        // new: [(a,210), (a,310), (a,410), (b,122), (b,212), (b,312), (c,123)]
        JavaRDD<Tuple2<String, Integer>> javaTupleRDD = combineByKeyRDD2.flatMap(
                new FlatMapFunction<Tuple2<String, List<Integer>>, Tuple2<String, Integer>>() {
                    public Iterator<Tuple2<String, Integer>> call(Tuple2<String, List<Integer>> stringListTuple2) throws Exception {
                        List<Tuple2<String, Integer>> items = new ArrayList<Tuple2<String, Integer>>();
                        for (Integer v : stringListTuple2._2) {
                            items.add(new Tuple2<String, Integer>(stringListTuple2._1, v));
                        }
                        return items.iterator();
                    }
                });

        JavaRDD<Row> rowRDD = javaTupleRDD.map(new Function<Tuple2<String, Integer>, Row>() {
            public Row call(Tuple2<String, Integer> kv) throws Exception {
                String key = kv._1;
                Integer num = kv._2;
                return RowFactory.create(key, num);
            }
        });

        ArrayList<StructField> fields = new ArrayList<StructField>();
        StructField field = null;
        field = DataTypes.createStructField("key", DataTypes.StringType, true);
        fields.add(field);
        field = DataTypes.createStructField("TopN_values", DataTypes.IntegerType, true);
        fields.add(field);
        StructType schema = DataTypes.createStructType(fields);

        Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
        df.printSchema();
        df.show();
        spark.stop();
    }
}

Output:

root
 |-- key: string (nullable = true)
 |-- TopN_values: integer (nullable = true)

+---+-----------+
|key|TopN_values|
+---+-----------+
|  a|        210|
|  a|        310|
|  a|        410|
|  b|        122|
|  b|        212|
|  b|        312|
|  c|        123|
+---+-----------+

Implementing TopN in Spark with the combineByKeyWithClassTag function

This example uses the combineByKeyWithClassTag function together with a sorted TreeSet to keep the largest N elements in each group. The code follows these notes:

  • createCombiner simply puts the first element into a TreeSet and returns it;
  • mergeValue inserts the new element and then, while the set holds more than N elements, removes the smallest one;
  • mergeCombiners merges the two sets and then removes the smallest element repeatedly until only N elements remain.
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

import scala.collection.mutable

object Main {
  val N = 3

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("Spark")
      .getOrCreate()
    val sc = spark.sparkContext

    var SampleDataset = List(
      ("apple.com", 3L),
      ("apple.com", 4L),
      ("apple.com", 1L),
      ("apple.com", 9L),
      ("google.com", 4L),
      ("google.com", 1L),
      ("google.com", 2L),
      ("google.com", 3L),
      ("google.com", 11L),
      ("google.com", 32L),
      ("slashdot.org", 11L),
      ("slashdot.org", 12L),
      ("slashdot.org", 13L),
      ("slashdot.org", 14L),
      ("slashdot.org", 15L),
      ("slashdot.org", 16L),
      ("slashdot.org", 17L),
      ("slashdot.org", 18L),
      ("microsoft.com", 5L),
      ("microsoft.com", 2L),
      ("microsoft.com", 6L),
      ("microsoft.com", 9L),
      ("google.com", 4L))

    val urdd: RDD[(String, Long)] = sc.parallelize(SampleDataset).map((t) => (t._1, t._2))

    var topNs = urdd.combineByKeyWithClassTag(
      // createCombiner
      (firstInt: Long) => {
        var uset = new mutable.TreeSet[Long]()
        uset += firstInt
      },
      // mergeValue
      (uset: mutable.TreeSet[Long], value: Long) => {
        uset += value
        while (uset.size > N) {
          uset.remove(uset.min)
        }
        uset
      },
      // mergeCombiners
      (uset1: mutable.TreeSet[Long], uset2: mutable.TreeSet[Long]) => {
        var resultSet = uset1 ++ uset2
        while (resultSet.size > N) {
          resultSet.remove(resultSet.min)
        }
        resultSet
      }
    )

    import spark.implicits._
    topNs.flatMap(rdd => {
      var uset = new mutable.HashSet[String]()
      for (i <- rdd._2.toList) {
        uset += rdd._1 + "/" + i.toString
      }
      uset
    }).map(rdd => {
      (rdd.split("/")(0), rdd.split("/")(1))
    }).toDF("key", "TopN_values").show()
  }
}

Reference: https://blog.csdn.net/gpwner/article/details/78455234

Output:

+-------------+-----------+
|          key|TopN_values|
+-------------+-----------+
|   google.com|          4|
|   google.com|         11|
|   google.com|         32|
|microsoft.com|          9|
|microsoft.com|          6|
|microsoft.com|          5|
|    apple.com|          4|
|    apple.com|          9|
|    apple.com|          3|
| slashdot.org|         16|
| slashdot.org|         17|
| slashdot.org|         18|
+-------------+-----------+
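If the spark-mllib artifact is on the classpath (an extra dependency beyond what the examples above need, so treat this as an assumption), the RDD route can be shortened with MLPairRDDFunctions.topByKey, which keeps the N largest values per key using a bounded priority queue. A minimal sketch reusing urdd, N, and the spark.implicits._ import from the code above:

import org.apache.spark.mllib.rdd.MLPairRDDFunctions._

// topByKey(N) yields RDD[(String, Array[Long])] holding the N largest values per key
urdd.topByKey(N)
  .flatMap { case (key, values) => values.map(v => (key, v)) }
  .toDF("key", "TopN_values")
  .show()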

 
