/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor.normalization;

import com.google.common.primitives.Floats;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.commons.lang3.Validate;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplanationUtils;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationUtil;
import org.opensearch.neuralsearch.processor.util.ProcessorUtils;

public class MinMaxScoreNormalizationTechnique
implements ScoreNormalizationTechnique,
ExplainableTechnique {
    public static final String TECHNIQUE_NAME = "min_max";
    protected static final float MIN_SCORE = 0.001f;
    private static final float SINGLE_RESULT_SCORE = 1.0f;
    private static final String PARAM_NAME_LOWER_BOUNDS = "lower_bounds";
    private static final String PARAM_NAME_LOWER_BOUND_MODE = "mode";
    private static final String PARAM_NAME_LOWER_BOUND_MIN_SCORE = "min_score";
    private static final Set<String> SUPPORTED_PARAMETERS = Set.of("lower_bounds");
    private static final Map<String, Set<String>> NESTED_PARAMETERS = Map.of("lower_bounds", Set.of("mode", "min_score"));
    private final Optional<List<Pair<LowerBound.Mode, Float>>> lowerBoundsOptional;

    public MinMaxScoreNormalizationTechnique() {
        this(Map.of(), new ScoreNormalizationUtil());
    }

    public MinMaxScoreNormalizationTechnique(Map<String, Object> params, ScoreNormalizationUtil scoreNormalizationUtil) {
        scoreNormalizationUtil.validateParameters(params, SUPPORTED_PARAMETERS, NESTED_PARAMETERS);
        this.lowerBoundsOptional = this.getLowerBounds(params);
    }

    @Override
    public void normalize(NormalizeScoresDTO normalizeScoresDTO) {
        List<CompoundTopDocs> queryTopDocs = normalizeScoresDTO.getQueryTopDocs();
        MinMaxScores minMaxScores = this.getMinMaxScoresResult(queryTopDocs);
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) continue;
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            if (this.isLowerBoundsAndSubQueriesCountMismatched(topDocsPerSubQuery)) {
                throw new IllegalArgumentException(String.format(Locale.ROOT, "expected lower bounds array to contain %d elements matching the number of sub-queries, but found a mismatch", topDocsPerSubQuery.size()));
            }
            for (int j = 0; j < topDocsPerSubQuery.size(); ++j) {
                TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
                LowerBound lowerBound = this.getLowerBound(j);
                for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
                    scoreDoc.score = this.normalizeSingleScore(scoreDoc.score, minMaxScores.getMinScoresPerSubquery()[j], minMaxScores.getMaxScoresPerSubquery()[j], lowerBound);
                }
            }
        }
    }

    private boolean isLowerBoundsAndSubQueriesCountMismatched(List<TopDocs> topDocsPerSubQuery) {
        return this.lowerBoundsOptional.isPresent() && !topDocsPerSubQuery.isEmpty() && this.lowerBoundsOptional.get().size() != topDocsPerSubQuery.size();
    }

    private LowerBound getLowerBound(int subQueryIndex) {
        return this.lowerBoundsOptional.map(pairs -> new LowerBound(true, (LowerBound.Mode)((Object)((Object)((Pair)pairs.get(subQueryIndex)).getLeft())), ((Float)((Pair)pairs.get(subQueryIndex)).getRight()).floatValue())).orElseGet(LowerBound::new);
    }

    private MinMaxScores getMinMaxScoresResult(List<CompoundTopDocs> queryTopDocs) {
        int numOfSubqueries = ProcessorUtils.getNumOfSubqueries(queryTopDocs);
        float[] minScoresPerSubquery = this.getMinScores(queryTopDocs, numOfSubqueries);
        float[] maxScoresPerSubquery = this.getMaxScores(queryTopDocs, numOfSubqueries);
        return new MinMaxScores(minScoresPerSubquery, maxScoresPerSubquery);
    }

    @Override
    public String techniqueName() {
        return TECHNIQUE_NAME;
    }

    @Override
    public String describe() {
        return this.lowerBoundsOptional.map(lb -> {
            String lowerBounds = lb.stream().map(pair -> String.format(Locale.ROOT, "(%s, %s)", pair.getLeft(), pair.getRight())).collect(Collectors.joining(", ", "[", "]"));
            return String.format(Locale.ROOT, "%s, lower bounds %s", TECHNIQUE_NAME, lowerBounds);
        }).orElse(String.format(Locale.ROOT, "%s", TECHNIQUE_NAME));
    }

    @Override
    public Map<DocIdAtSearchShard, ExplanationDetails> explain(List<CompoundTopDocs> queryTopDocs) {
        MinMaxScores minMaxScores = this.getMinMaxScoresResult(queryTopDocs);
        HashMap<DocIdAtSearchShard, List<Float>> normalizedScores = new HashMap<DocIdAtSearchShard, List<Float>>();
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) continue;
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            int numberOfSubQueries = topDocsPerSubQuery.size();
            for (int subQueryIndex = 0; subQueryIndex < numberOfSubQueries; ++subQueryIndex) {
                TopDocs subQueryTopDoc = topDocsPerSubQuery.get(subQueryIndex);
                for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
                    DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, compoundQueryTopDocs.getSearchShard());
                    LowerBound lowerBound = this.getLowerBound(subQueryIndex);
                    float normalizedScore = this.normalizeSingleScore(scoreDoc.score, minMaxScores.getMinScoresPerSubquery()[subQueryIndex], minMaxScores.getMaxScoresPerSubquery()[subQueryIndex], lowerBound);
                    ScoreNormalizationUtil.setNormalizedScore(normalizedScores, docIdAtSearchShard, subQueryIndex, numberOfSubQueries, normalizedScore);
                    scoreDoc.score = normalizedScore;
                }
            }
        }
        return ExplanationUtils.getDocIdAtQueryForNormalization(normalizedScores, this);
    }

    private float[] getMaxScores(List<CompoundTopDocs> queryTopDocs, int numOfSubqueries) {
        float[] maxScores = new float[numOfSubqueries];
        Arrays.fill(maxScores, Float.MIN_VALUE);
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) continue;
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            for (int j = 0; j < topDocsPerSubQuery.size(); ++j) {
                maxScores[j] = Math.max(maxScores[j], Arrays.stream(topDocsPerSubQuery.get((int)j).scoreDocs).map(scoreDoc -> Float.valueOf(scoreDoc.score)).max(Float::compare).orElse(Float.valueOf(Float.MIN_VALUE)).floatValue());
            }
        }
        return maxScores;
    }

    private float[] getMinScores(List<CompoundTopDocs> queryTopDocs, int numOfScores) {
        float[] minScores = new float[numOfScores];
        Arrays.fill(minScores, Float.MAX_VALUE);
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) continue;
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            for (int j = 0; j < topDocsPerSubQuery.size(); ++j) {
                minScores[j] = Math.min(minScores[j], Arrays.stream(topDocsPerSubQuery.get((int)j).scoreDocs).map(scoreDoc -> Float.valueOf(scoreDoc.score)).min(Float::compare).orElse(Float.valueOf(Float.MAX_VALUE)).floatValue());
            }
        }
        return minScores;
    }

    private float normalizeSingleScore(float score, float minScore, float maxScore, LowerBound lowerBound) {
        if (Floats.compare((float)maxScore, (float)minScore) == 0 && Floats.compare((float)maxScore, (float)score) == 0) {
            return 1.0f;
        }
        if (!lowerBound.isEnabled()) {
            return LowerBound.Mode.IGNORE.normalize(score, minScore, maxScore, lowerBound.getMinScore());
        }
        return lowerBound.getMode().normalize(score, minScore, maxScore, lowerBound.getMinScore());
    }

    private Optional<List<Pair<LowerBound.Mode, Float>>> getLowerBounds(Map<String, Object> params) {
        if (Objects.isNull(params) || !params.containsKey(PARAM_NAME_LOWER_BOUNDS)) {
            return Optional.empty();
        }
        List lowerBoundsParams = Optional.ofNullable(params.get(PARAM_NAME_LOWER_BOUNDS)).filter(List.class::isInstance).map(List.class::cast).orElseThrow(() -> new IllegalArgumentException("lower_bounds must be a List"));
        if (lowerBoundsParams.size() > 5) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "lower_bounds size %d should be less than or equal to %d", lowerBoundsParams.size(), 5));
        }
        List lowerBounds = lowerBoundsParams.stream().map(this::parseLowerBound).collect(Collectors.toList());
        return Optional.of(lowerBounds);
    }

    private Pair<LowerBound.Mode, Float> parseLowerBound(Object boundObj) {
        if (!(boundObj instanceof Map)) {
            throw new IllegalArgumentException("each lower bound must be a map");
        }
        Map lowerBound = (Map)boundObj;
        String lowerBoundModeValue = Objects.toString(lowerBound.get(PARAM_NAME_LOWER_BOUND_MODE), "");
        LowerBound.Mode mode = LowerBound.Mode.fromString(lowerBoundModeValue);
        float minScore = this.extractAndValidateMinScore(lowerBound);
        return ImmutablePair.of((Object)((Object)mode), (Object)Float.valueOf(minScore));
    }

    private float extractAndValidateMinScore(Map<String, Object> lowerBound) {
        Object minScoreObj = lowerBound.get(PARAM_NAME_LOWER_BOUND_MIN_SCORE);
        if (minScoreObj == null) {
            return 0.0f;
        }
        try {
            float minScore = 0.0f;
            if (Objects.nonNull(lowerBound.get(PARAM_NAME_LOWER_BOUND_MIN_SCORE))) {
                minScore = Float.parseFloat(String.valueOf(lowerBound.get(PARAM_NAME_LOWER_BOUND_MIN_SCORE)));
            }
            Validate.isTrue((minScore >= -10000.0f && minScore <= 10000.0f ? 1 : 0) != 0, (String)"min_score must be a valid finite number between %f and %f", (Object[])new Object[]{Float.valueOf(-10000.0f), Float.valueOf(10000.0f)});
            return minScore;
        }
        catch (NumberFormatException e) {
            throw new IllegalArgumentException("invalid format for min_score: must be a valid float value", e);
        }
    }

    @Generated
    public String toString() {
        return "MinMaxScoreNormalizationTechnique(TECHNIQUE_NAME=min_max)";
    }

    private static class MinMaxScores {
        float[] minScoresPerSubquery;
        float[] maxScoresPerSubquery;

        @Generated
        public MinMaxScores(float[] minScoresPerSubquery, float[] maxScoresPerSubquery) {
            this.minScoresPerSubquery = minScoresPerSubquery;
            this.maxScoresPerSubquery = maxScoresPerSubquery;
        }

        @Generated
        public float[] getMinScoresPerSubquery() {
            return this.minScoresPerSubquery;
        }

        @Generated
        public float[] getMaxScoresPerSubquery() {
            return this.maxScoresPerSubquery;
        }
    }

    static class LowerBound {
        static final float MIN_LOWER_BOUND_SCORE = -10000.0f;
        static final float MAX_LOWER_BOUND_SCORE = 10000.0f;
        static final float DEFAULT_LOWER_BOUND_SCORE = 0.0f;
        private final boolean enabled;
        private final Mode mode;
        private final float minScore;

        LowerBound() {
            this(false, Mode.DEFAULT, 0.0f);
        }

        LowerBound(boolean enabled, Mode mode, float minScore) {
            this.enabled = enabled;
            this.mode = mode;
            this.minScore = minScore;
        }

        @Generated
        public boolean isEnabled() {
            return this.enabled;
        }

        @Generated
        public Mode getMode() {
            return this.mode;
        }

        @Generated
        public float getMinScore() {
            return this.minScore;
        }

        protected static enum Mode {
            APPLY{

                @Override
                public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) {
                    if (maxScore < lowerBoundScore || score < lowerBoundScore) {
                        return (score - minScore) / (maxScore - minScore);
                    }
                    return (score - lowerBoundScore) / (maxScore - lowerBoundScore);
                }
            }
            ,
            CLIP{

                @Override
                public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) {
                    if (score < minScore) {
                        return 0.0f;
                    }
                    if (maxScore < lowerBoundScore) {
                        return (score - minScore) / (maxScore - minScore);
                    }
                    return (score - lowerBoundScore) / (maxScore - lowerBoundScore);
                }
            }
            ,
            IGNORE{

                @Override
                public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) {
                    float normalizedScore = (score - minScore) / (maxScore - minScore);
                    return normalizedScore == 0.0f ? 0.001f : normalizedScore;
                }
            };

            public static final Mode DEFAULT;
            public static final String VALID_VALUES;

            public static Mode fromString(String value) {
                if (Objects.isNull(value)) {
                    throw new IllegalArgumentException("mode value cannot be null or empty");
                }
                if (value.trim().isEmpty()) {
                    return DEFAULT;
                }
                try {
                    return Mode.valueOf(value.toUpperCase(Locale.ROOT));
                }
                catch (IllegalArgumentException e) {
                    throw new IllegalArgumentException(String.format(Locale.ROOT, "invalid mode: %s, valid values are: %s", value, VALID_VALUES));
                }
            }

            public abstract float normalize(float var1, float var2, float var3, float var4);

            public String toString() {
                return this.name().toLowerCase(Locale.ROOT);
            }

            static {
                DEFAULT = APPLY;
                VALID_VALUES = Arrays.stream(Mode.values()).map(mode -> mode.name().toLowerCase(Locale.ROOT)).collect(Collectors.joining(", "));
            }
        }
    }
}

