@Override public void selectSeeds(List<double[]> points, List<double[]> seeds) { if( seeds.size() > points.size() ) throw new IllegalArgumentException("More seeds requested than points!"); distance.resize(points.size()); // the first seed is randomly selected from the list of points double[] seed = points.get( rand.nextInt(points.size()) ); copyInto(seed,seeds.get(0)); // compute the distance each points is from the seed totalDistance = 0; for (int i = 0; i < points.size(); i++) { double[] p = points.get(i); double d = StandardKMeans_F64.distanceSq(p,seed); distance.data[i] = d; totalDistance += d; } // iteratively select the next seed and update the list of point distances for (int i = 1; i < seeds.size(); i++) { if( totalDistance == 0 ) { // if the total distance is zero that means there are duplicate points and that // all the unique points have already been added as seeds. just select a point // and copy it into rest of the seeds copyInto(seed, seeds.get(i)); } else { double target = rand.nextDouble(); copyInto(selectNextSeed(points, target), seeds.get(i)); updateDistances(points, seeds.get(i)); } } }
@Override public void selectSeeds(List<double[]> points, List<double[]> seeds) { if( seeds.size() > points.size() ) throw new IllegalArgumentException("More seeds requested than points!"); distance.resize(points.size()); // the first seed is randomly selected from the list of points double[] seed = points.get( rand.nextInt(points.size()) ); copyInto(seed,seeds.get(0)); // compute the distance each points is from the seed totalDistance = 0; for (int i = 0; i < points.size(); i++) { double[] p = points.get(i); double d = StandardKMeans_F64.distanceSq(p,seed); distance.data[i] = d; totalDistance += d; } // iteratively select the next seed and update the list of point distances for (int i = 1; i < seeds.size(); i++) { if( totalDistance == 0 ) { // if the total distance is zero that means there are duplicate points and that // all the unique points have already been added as seeds. just select a point // and copy it into rest of the seeds copyInto(seed, seeds.get(i)); } else { double target = rand.nextDouble(); copyInto(selectNextSeed(points, target), seeds.get(i)); updateDistances(points, seeds.get(i)); } } }
/**
 * Exercises {@code updateDistances} with three one-dimensional points and a new seed at -1.
 * The expected values are consistent with each point keeping the smaller of its previously
 * stored distance and its squared distance to the new seed, and with {@code totalDistance}
 * being the sum of the resulting distances.
 */
@Test
public void updateDistances() {
    final double tol = 1e-8;

    InitializePlusPlus initializer = new InitializePlusPlus();
    initializer.init(1,123);

    // distances stored before the new seed is taken into account
    initializer.distance.resize(3);
    initializer.distance.data = new double[]{3,6,1};

    // points located at 0, 1 and 4
    List<double[]> points = new ArrayList<double[]>();
    for (int k = 0; k < 3; k++) {
        double[] p = new double[1];
        p[0] = k*k;
        points.add(p);
    }

    // squared distances from the seed at -1 are 1, 4 and 25
    double[] newSeed = {-1};
    initializer.updateDistances(points, newSeed);

    double[] expected = {1, 4, 1};
    for (int k = 0; k < expected.length; k++) {
        assertEquals(expected[k], initializer.distance.get(k), tol);
    }
    assertEquals(6, initializer.totalDistance, tol);
}