k-meansクラスタリングアルゴリズムをインメモリで実行する
3193 ワード
<strong><span style="font-size:18px;">/***
* @author YangXin
* @info K-Means
*/
package unitNine;
import java.util.ArrayList;
import java.util.List;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.UncommonDistributions;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
public class KMeansExample {
private static void generateSamples(List<Vector> vectors, int num, double mx, double my, double sd){
for(int i = 0; i < num; i++){
vectors.add(new DenseVector(new double[]{UncommonDistributions.rNorm(mx, sd),UncommonDistributions.rNorm(my, sd) }));
}
}
public static void main(String[] args){
List<Vector> sampleData = new ArrayList<Vector>();
RandomPointsUtil.generateSamples(sampleData, 400, 1, 1, 3);
RandomPointsUtil.generateSamples(sampleData, 300, 1, 0, 0.5);
RandomPointsUtil.generateSamples(sampleData, 300, 0, 2, 0.1);
int k = 3;
List<Vector> randomPoints = RandomPointsUtil.chooseRandomPoints(
sampleData, k);
List<Cluster> clusters = new ArrayList<Cluster>();
int clusterId = 0;
for (Vector v : randomPoints) {
clusters.add(new Cluster(v, clusterId++, new EuclideanDistanceMeasure()));
}
List<List<Cluster>> finalClusters = KMeansClusterer.clusterPoints(
sampleData, clusters, new EuclideanDistanceMeasure(), 3, 0.01);
for (Cluster cluster : finalClusters.get(finalClusters.size() - 1)) {
System.out.println("Cluster id: " + cluster.getId() + " center: "
+ cluster.getCenter().asFormatString());
}
}
</span></strong>
<strong><span style="font-size:18px;">/***
* @author YangXin
* @info Random sample generation and seed-point selection helpers for the K-Means example
*/
package unitNine;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
/**
 * Static helpers used by the k-means example: random 2-D sample generation
 * and random selection of initial points.
 */
public class RandomPointsUtil {

    /**
     * Appends {@code num} two-dimensional points to {@code vectors}. The
     * first coordinate of each point is {@code UncommonDistributions.rNorm(mx, sd)}
     * and the second is {@code UncommonDistributions.rNorm(my, sd)}.
     *
     * @param vectors list that receives the generated points (mutated in place)
     * @param num     number of points to append; nothing is added when {@code num <= 0}
     * @param mx      first argument passed to {@code rNorm} for the x coordinate
     * @param my      first argument passed to {@code rNorm} for the y coordinate
     * @param sd      second argument passed to {@code rNorm} for both coordinates
     */
    public static void generateSamples(List<Vector> vectors, int num,
            double mx, double my, double sd) {
        for (int i = 0; i < num; i++) {
            // x is drawn before y, matching the array-initializer order.
            double x = org.apache.mahout.clustering.UncommonDistributions.rNorm(mx, sd);
            double y = org.apache.mahout.clustering.UncommonDistributions.rNorm(my, sd);
            vectors.add(new DenseVector(new double[] {x, y}));
        }
    }

    /**
     * Picks up to {@code k} elements from {@code vectors} in a single pass
     * using reservoir sampling, so that every element has the same chance of
     * being selected regardless of its position in the input.
     *
     * @param vectors source of candidate points; iterated exactly once
     * @param k       number of points to pick
     * @return a new list holding {@code k} chosen points, or all points when
     *         the input has fewer than {@code k} elements; empty when
     *         {@code k == 0}
     * @throws IllegalArgumentException if {@code k} is negative
     */
    public static List<Vector> chooseRandomPoints(Iterable<Vector> vectors, int k) {
        if (k < 0) {
            throw new IllegalArgumentException("k must not be negative: " + k);
        }
        List<Vector> chosenPoints = new ArrayList<Vector>(k);
        if (k == 0) {
            // Nothing to pick; also avoids asking the generator for nextInt(0).
            return chosenPoints;
        }
        Random random = RandomUtils.getRandom();
        int seen = 0; // number of elements consumed before the current one
        for (Vector value : vectors) {
            if (seen < k) {
                // Fill the reservoir with the first k elements.
                chosenPoints.add(value);
            } else {
                // Element number (seen + 1) is kept with probability
                // k / (seen + 1) and replaces a uniformly chosen slot.
                int slot = random.nextInt(seen + 1);
                if (slot < k) {
                    chosenPoints.set(slot, value);
                }
            }
            seen++;
        }
        return chosenPoints;
    }
}
</span></strong>