package zqr.com; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.recommendation.ALS; import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; import org.apache.spark.mllib.recommendation.Rating; import scala.Tuple2; import scala.reflect.internal.Names; import java.util.*; public class MovieRecommended { public static void main(String args[]){ // 输入用户id System.out.println("输入用户id:"); Scanner scan=new Scanner(System.in); String number=scan.nextLine(); SparkConf conf = new SparkConf().setAppName("movie recommended system").setMaster("local"); JavaSparkContext sc = new JavaSparkContext(conf); // 其底层实际上就是Scala的SparkContext String path = "/usr/local/spark/testdata/ml-latest-small/ratings.csv"; JavaRDD<String> rating_data = sc.textFile(path); // rating_data.collect().forEach(System.out::println); Set<String>NotViewMovie=new HashSet<String>(); Set<String>ViewMovie=new HashSet<String>(); JavaRDD<Rating> mapdata = rating_data.map(s -> { String[] sarray = s.split(","); Rating rt=null; if(sarray[0].equals("userId")) { rt=null; }else if(Double.parseDouble(sarray[2])<5&&Double.parseDouble(sarray[2])>=0&&!sarray[2].isEmpty()){ rt=new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), Double.parseDouble(sarray[2])); if(!sarray[0].equals(number)){ NotViewMovie.add(sarray[1]); }else{ ViewMovie.add(sarray[1]); } }else{ rt=null; } return rt; }).filter(k->k!=null); //============================================================================================ NotViewMovie.addAll(ViewMovie); List<String> d=new ArrayList<String>(); for(String h:NotViewMovie){ d.add(h); } //===================================================================================== //mapdata.collect().forEach(System.out::println); // 隐性因子个数 int rank = 10; //迭代次数 int numIterations = 10; //lambda是ALS的正则化参数; MatrixFactorizationModel 
model = ALS.train(JavaRDD.toRDD(mapdata), rank, numIterations, 0.01); //System.out.println("model:"+model); // 评估评级数据模型 JavaRDD data=sc.parallelize(d); JavaRDD<Tuple2<Object, Object>> userProducts = data.map(r -> new Tuple2<>(number, r)); JavaPairRDD<Tuple2<Integer, Integer>, Double> predictions = JavaPairRDD.fromJavaRDD( model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD() .map(r -> new Tuple2<>(new Tuple2<>(r.user(), r.product()), r.rating())) ); //System.out.println("打印predictions的值"); //predictions.collect().forEach(System.out::println); //Map<String,Map<Double,String>> data=new HashMap<String,Map<Double,String>>(); Map<String,Double> tow=new TreeMap<String,Double>(); List list=predictions.collect(); for(Object x : list){ String string=x.toString(); String []arr=string.split("[,()]"); String uid=arr[2]; String mid=arr[3]; double pfen=Double.parseDouble(arr[5]); if(uid.equals(number)){ tow.put(mid,pfen); } //data.put(uid,tow); //System.out.println(uid+"--->"+mid+"---->"+pfen); } //System.out.println(list); // // for (Map.Entry<String, Double> entry : tow.entrySet()) { // System.out.println(entry.getKey() + ":" + entry.getValue()); // } //转换成list进行排序 List<Map.Entry<String, Double>> li = new ArrayList<Map.Entry<String,Double>>(tow.entrySet()); // 排序 Collections.sort(li, new Comparator<Map.Entry<String, Double>>() { //根据value排序 public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) { double result = o2.getValue() - o1.getValue(); if(result > 0) return 1; else if(result == 0) return 0; else return -1; } }); List<String> l=new ArrayList<String>(); int num=0; for (Map.Entry<String, Double> entry : li) { System.out.println(entry.getKey() + " " + entry.getValue()); l.add(entry.getKey().toString()); num++; if(num==10){ break; } } for(String x:l){ System.out.println(x); } String path1 = "/usr/local/spark/testdata/ml-latest-small/movies.csv"; JavaRDD<String> mviddata = sc.textFile(path1); // rating_data.collect().forEach(System.out::println); 
JavaRDD<Map<String,String>> move_pair = mviddata.map(s -> { String[] sarray = s.split(","); String x=sarray[0].toString(); String y=sarray[1].toString(); Map mp=new HashMap<String,String>(); mp.put(x,y); return mp; }); //move_pair.collect().forEach(System.out::println); for(String idx:l) { move_pair.foreach(x -> { if(x.containsKey(idx)){ String value=x.get(idx).toString(); System.out.println(value); } }); } //for (Map.Entry<String, Double> entry : tow.entrySet()) { //System.out.println(entry.getKey() + ":" + entry.getValue()); //} // model.save(sc.sc(), "target/movierd/myCollaborativeFilter"); } }
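// Source article title: "spark ML ALS实现电影推荐" (movie recommendation with Spark ALS).
// Page residue: "最新推荐文章于 2022-12-28 09:56:15 发布" (latest recommended article published 2022-12-28 09:56:15).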