ALSアルゴリズムによる推薦
31970 ワード
package zqr.com;
import breeze.optimize.linear.LinearProgram;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import scala.Tuple2;
import org.apache.spark.mllib.recommendation.Rating;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.util.*;
public class AlsAddFp {
public static String k="";
public static Map,String>map=new HashMap<>();
public static SetNotbygoods=new HashSet();
public static Setbygoods=new HashSet();
public static Map,String> md=new HashMap();
public static List d=new ArrayList();
public static void main(String args []) {
// id
System.out.println(" id:");
Scanner scan=new Scanner(System.in);
String number=scan.nextLine();
SparkConf conf = new SparkConf().setAppName("Spark WordCount written by Java").setMaster("local");
JavaSparkContext sc = new JavaSparkContext(conf); // Scala SparkContext
//
// //
// String path = "/usr/local/spark/data/mllib/sample_fpgrowth.txt";
//
// JavaRDD data = sc.textFile(path);
// data.collect().forEach(System.out::println);
//
//
// JavaRDD//
// String arr[]=x.split(" ");
// String vaule="";
// if(map.containsKey(arr[0])){
// for(int i=1;i// vaule = vaule +" "+arr[i];
// }
// map.put(arr[0],map.get(arr[0]) + " "+vaule.toString()) ;
// }
// else{
// for(int i=1;i// vaule = vaule +" "+arr[i];
// }
// map.put(arr[0],vaule.toString()) ;
// }
// return map;
//
// });
// data_deal.collect();
// for (Map.Entry entry : map.entrySet()) {
// //Map.entry ( - ) : entry
// //entry.getKey() ;entry.getValue(); entry.setValue();
// //map.entrySet() Set 。
// System.out.println("----------------------->>>>"+entry.getKey()+","+entry.getValue());
// }
//
=======================================================================================================================
//
//
// Map> mp=new HashMap>();
//
//
// for (Map.Entry entry : map.entrySet()) {
// Mapcharacter=new HashMap();
// //Map.entry ( - ) : entry
// //entry.getKey() ;entry.getValue(); entry.setValue();
// //map.entrySet() Set 。
// k=entry.getKey().toString();
// String []arr=entry.getValue().split(" ");
// for(String s:arr){
// if(character.containsKey(s)&&s.length()>0){
// int value=character.get(s);
// character.put(s,value+1);
// }else if(s!=" "&&s.length()>0){
// character.put(s,1);
// }
// }
// System.out.println("key= " + entry.getKey() + " and value= "
// + entry.getValue());
// mp.put(k,character);
// }
//
// for (Map.Entry> entry : mp.entrySet()) {
// //Map.entry ( - ) : entry
// //entry.getKey() ;entry.getValue(); entry.setValue();
// //map.entrySet() Set 。
// System.out.println("----------------------->>>>"+entry.getKey()+","+entry.getValue());
// for(Map.Entry entr : entry.getValue().entrySet()) {
// String str = entry.getKey()+","+entr.getKey()+","+entr.getValue()+"
";
//
// System.out.println(str);
//
// FileOutputStream fos = null;
// try {
// fos = new FileOutputStream("/home/zqr/ /file",true);
// } catch (FileNotFoundException e) {
// e.printStackTrace();
// }
// //true
// try {
// fos.write(str.getBytes());
// } catch (IOException e) {
// e.printStackTrace();
// }
// }
// }
//=======================================================================================================================
//=======================================================================================================================
sc.setLogLevel("WARN");
JavaRDD data1 = sc.textFile("/home/zqr/ /file");
// JavaRDD// Map> m=new HashMap>();
// String[] sarray = s.split(",");
// Map mm=new HashMap();
// mm.put(sarray[1],Integer.parseInt(sarray[2]));
// m.put(sarray[0],mm);
// return m;
// });
JavaRDD ratings = data1.map(s -> {
String[] sarray=s.split(",");
int i=0;
if(sarray[0]!=null) {
byte[] gc = sarray[1].getBytes();
i = (int) gc[0];
}
if(!sarray[0].equals(number)){
Notbygoods.add(sarray[1]);
// System.out.println(" 1!"+sarray[1]);
}else{
bygoods.add(sarray[1]);
// System.out.println(" 2!"+sarray[1]);
}
return new Rating(Integer.parseInt(sarray[0]),
i,
Integer.parseInt(sarray[2]));
});
ratings.collect();
//============================================================================================
System.out.println(Notbygoods);
System.out.println(bygoods);
for(String f:Notbygoods){
md.put(f,"");
}
//2System.out.println(md.entrySet());
for(String h:bygoods){
if(md.containsKey(h)){
md.remove(h);
}
}
//System.out.println(md.entrySet());
for (Map.Entry, String> en : md.entrySet()) {
//System.out.println(en.getKey() + ":" + en.getValue());
byte[] gc = en.getKey().getBytes();
int k = (int) gc[0];
d.add(k);
}
System.out.println(" list :"+d);
//=====================================================================================
/**
* ALS
*/
//
int rank = 10;
//
int numIterations = 10;
//lambda ALS ;
MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01);
System.out.println("model:"+model);
//
JavaRDD da=sc.parallelize(d);
JavaRDD, Object>> userProducts =
da.map(r -> new Tuple2<>(Integer.parseInt(number), r));
JavaPairRDD, Integer>, Double> predictions = JavaPairRDD.fromJavaRDD(
model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD()
.map(r -> new Tuple2<>(new Tuple2<>(r.user(), r.product()), r.rating()))
);
//===================================================================================
System.out.println(" predictions ");
//predictions.collect().forEach(System.out::println);
//=====================================================================================
Map,Double> to=new TreeMap,Double>();
List list=predictions.collect();
for(Object x : list){
String string=x.toString();
String []arr=string.split("[,)(]");
String uid=arr[2];
String mid=arr[3];
double pfen=Double.parseDouble(arr[5]);
if(uid.equals(number)){
to.put(mid,pfen);
}
}
// list
List, Double>> lil = new ArrayList,Double>>(to.entrySet());
//
Collections.sort(lil, new Comparator, Double>>() {
// value
public int compare(Map.Entry, Double> o1,
Map.Entry, Double> o2) {
double result = o1.getValue() - o2.getValue();
if(result < 0)
return 1;
else if(result == 0)
return 0;
else
return -1;
}
});
List l=new ArrayList();
int num=0;
for (Map.Entry, Double> entry : lil) {
System.out.println(entry.getKey() + " " + entry.getValue());
l.add(entry.getKey().toString());
num++;
if(num==3){
break;
}
}
for(String x:l){
int s=Integer.parseInt(x);
System.out.println((char)s);
}
}
}