/*
 * Decompiled with CFR 0.152.
 */
package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.vector.VectorUtil;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.function.BiFunction;

/**
 * Clusters a fixed set of vectors into {@code k} groups using k-means++ seeding followed by
 * Lloyd-style refinement rounds: recompute each centroid as the mean of its assigned points,
 * then reassign every point to its nearest centroid.
 *
 * <p>Per-cluster sums ({@code centroidNums}) and counts ({@code centroidDenoms}) are maintained
 * incrementally as points change cluster, so recomputing centroids does not need a full pass
 * over the points.
 *
 * <p>Randomness comes from an unseeded {@link Random}, so results differ between runs. Instances
 * hold mutable state with no synchronization and should not be shared across threads without
 * external locking.
 *
 * <p>NOTE(review): seeding samples points with probability proportional to the value returned by
 * {@code distanceFunction}, and centroids are arithmetic means; classic k-means++ assumes that
 * value is the squared Euclidean distance — confirm callers pass such a function.
 */
public class KMeansPlusPlusClusterer {
    private final int k;
    private final BiFunction<float[], float[], Float> distanceFunction;
    private final Random random;
    /** The input vectors, held by reference (not copied). */
    private final float[][] points;
    /** Index of the cluster each point is currently assigned to. */
    private final int[] assignments;
    /** Current centroid of each cluster; rows are arrays owned by this instance. */
    private final float[][] centroids;
    /** Number of points currently assigned to each cluster. */
    private final int[] centroidDenoms;
    /** Running element-wise sum of the points currently assigned to each cluster. */
    private final float[][] centroidNums;

    /**
     * Creates a clusterer over {@code points}, chooses the initial centroids with k-means++
     * seeding, and assigns every point to its nearest initial centroid.
     *
     * @param points the vectors to cluster; held by reference. All vectors are assumed to have
     *     the same dimension as {@code points[0]}
     * @param k the number of clusters; must be in {@code [1, points.length]}
     * @param distanceFunction distance between two vectors, where a smaller value means closer
     * @throws IllegalArgumentException if {@code k <= 0} or {@code k > points.length}
     */
    public KMeansPlusPlusClusterer(float[][] points, int k, BiFunction<float[], float[], Float> distanceFunction) {
        if (k <= 0) {
            throw new IllegalArgumentException("Number of clusters must be positive.");
        }
        if (k > points.length) {
            throw new IllegalArgumentException(String.format("Number of clusters %d cannot exceed number of points %d", k, points.length));
        }
        this.points = points;
        this.k = k;
        this.distanceFunction = distanceFunction;
        this.random = new Random();
        this.centroidDenoms = new int[k];
        this.centroidNums = new float[k][points[0].length];
        this.centroids = chooseInitialCentroids(points);
        this.assignments = new int[points.length];
        // The accumulators start empty, so the first assignment must ADD every point to its
        // cluster rather than treat it as a move out of a (nonexistent) previous cluster 0.
        initializeAssignments();
    }

    /**
     * Runs up to {@code maxIterations} rounds of {@link #clusterOnce()}, stopping early once a
     * round reassigns at most 1% of the points.
     *
     * @param maxIterations upper bound on the number of refinement rounds
     * @return this instance's centroid array (not a copy); later calls may replace its rows
     */
    public float[][] cluster(int maxIterations) {
        for (int i = 0; i < maxIterations; i++) {
            int changedCount = clusterOnce();
            if (changedCount <= 0.01 * points.length) {
                break;
            }
        }
        return centroids;
    }

    /**
     * Performs one refinement round. It recomputes each centroid from its assigned points, then
     * reassigns every point to its nearest centroid.
     *
     * @return the number of points whose cluster assignment changed
     */
    public int clusterOnce() {
        updateCentroids();
        return assignPointsToClusters();
    }

    /**
     * Chooses {@code k} initial centroids with k-means++ seeding.
     *
     * <p>The first centroid is a uniformly random point. Each later one is sampled with
     * probability proportional to the point's distance from the nearest centroid chosen so far.
     *
     * @return a new array of {@code k} centroids, each a copy of some input point
     */
    private float[][] chooseInitialCentroids(float[][] points) {
        float[][] chosen = new float[k][];
        // distances[j] = distance from points[j] to the nearest centroid chosen so far
        float[] distances = new float[points.length];
        Arrays.fill(distances, Float.MAX_VALUE);

        float[] first = copyOfPoint(random.nextInt(points.length));
        chosen[0] = first;
        updateMinDistances(distances, first);

        for (int c = 1; c < k; c++) {
            int selectedIdx = sampleProportionally(distances);
            if (selectedIdx == -1) {
                // All weights are zero (every point coincides with a chosen centroid), or
                // rounding left the draw unmatched: fall back to a uniform pick.
                selectedIdx = random.nextInt(points.length);
            }
            float[] next = copyOfPoint(selectedIdx);
            chosen[c] = next;
            updateMinDistances(distances, next);
        }
        return chosen;
    }

    /** Lowers each {@code distances[j]} to the distance from {@code points[j]} to {@code centroid} when that is smaller. */
    private void updateMinDistances(float[] distances, float[] centroid) {
        for (int j = 0; j < points.length; j++) {
            float d = distanceFunction.apply(points[j], centroid).floatValue();
            distances[j] = Math.min(distances[j], d);
        }
    }

    /**
     * Draws an index with probability proportional to its weight. For non-negative weights, an
     * index with zero weight is never returned.
     *
     * @param weights per-index sampling weights
     * @return the sampled index, or -1 if the weights sum to zero or rounding prevents a match
     */
    private int sampleProportionally(float[] weights) {
        // Accumulate in double: summing many floats in float loses precision on large inputs.
        double total = 0.0;
        for (float w : weights) {
            total += w;
        }
        double target = random.nextDouble() * total;
        double cumulative = 0.0;
        for (int j = 0; j < weights.length; j++) {
            cumulative += weights[j];
            // The strict '>' means the first index to cross the target must have added a
            // positive weight, so already-chosen points (distance 0) are never re-selected.
            if (cumulative > target) {
                return j;
            }
        }
        return -1;
    }

    /**
     * Assigns every point to its nearest centroid for the first time and seeds the per-cluster
     * sums and counts accordingly. Must run exactly once, while the accumulators are still empty.
     */
    private void initializeAssignments() {
        for (int i = 0; i < points.length; i++) {
            float[] point = points[i];
            int cluster = getNearestCluster(point, centroids);
            centroidDenoms[cluster]++;
            VectorUtil.addInPlace(centroidNums[cluster], point);
            assignments[i] = cluster;
        }
    }

    /**
     * Reassigns every point to its nearest centroid. The per-cluster sums and counts are updated
     * for each point that moves.
     *
     * <p>Assumes the accumulators already reflect {@code assignments}, which is established by
     * {@link #initializeAssignments()}.
     *
     * @return the number of points whose assignment changed
     */
    private int assignPointsToClusters() {
        int changedCount = 0;
        for (int i = 0; i < points.length; i++) {
            float[] point = points[i];
            int oldAssignment = assignments[i];
            int newAssignment = getNearestCluster(point, centroids);
            if (newAssignment == oldAssignment) {
                continue;
            }
            centroidDenoms[oldAssignment]--;
            centroidDenoms[newAssignment]++;
            VectorUtil.subInPlace(centroidNums[oldAssignment], point);
            VectorUtil.addInPlace(centroidNums[newAssignment], point);
            assignments[i] = newAssignment;
            changedCount++;
        }
        return changedCount;
    }

    /**
     * Returns the index of the centroid closest to {@code point}. Ties go to the lowest index.
     * Returns 0 if no distance is below {@link Float#MAX_VALUE} (for example, if all are NaN).
     */
    private int getNearestCluster(float[] point, float[][] centroids) {
        float minDistance = Float.MAX_VALUE;
        int nearestCluster = 0;
        for (int i = 0; i < k; i++) {
            float distance = distanceFunction.apply(point, centroids[i]).floatValue();
            if (distance < minDistance) {
                minDistance = distance;
                nearestCluster = i;
            }
        }
        return nearestCluster;
    }

    /**
     * Recomputes every centroid as the running sum of its assigned points divided by their
     * count. An empty cluster is re-seeded at a copy of a uniformly random input point.
     */
    private void updateCentroids() {
        for (int i = 0; i < centroids.length; i++) {
            int denom = centroidDenoms[i];
            if (denom == 0) {
                centroids[i] = copyOfPoint(random.nextInt(points.length));
                continue;
            }
            // Divide a copy so the running sum in centroidNums stays intact.
            float[] mean = Arrays.copyOf(centroidNums[i], centroidNums[i].length);
            VectorUtil.divInPlace(mean, denom);
            centroids[i] = mean;
        }
    }

    /** Returns a fresh copy of {@code points[index]}, so centroids never alias the caller's input vectors. */
    private float[] copyOfPoint(int index) {
        return Arrays.copyOf(points[index], points[index].length);
    }

    /**
     * Computes the element-wise mean of {@code points}: {@code VectorUtil.sum(points)} divided
     * by {@code points.size()}.
     *
     * @param points the vectors to average; must be non-empty
     * @return a new array holding the mean vector
     * @throws IllegalArgumentException if {@code points} is empty
     */
    public static float[] centroidOf(List<float[]> points) {
        if (points.isEmpty()) {
            throw new IllegalArgumentException("Can't compute centroid of empty points list");
        }
        float[] centroid = VectorUtil.sum(points);
        VectorUtil.divInPlace(centroid, points.size());
        return centroid;
    }
}

