/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.query.memoryoptsearch;

import java.io.IOException;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.join.DiversifyingNearestChildrenKnnCollectorManager;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.apache.lucene.search.knn.TopKnnCollectorManager;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.KNNWeight;
import org.opensearch.knn.index.query.memoryoptsearch.RadiusVectorSimilarityCollector;
import org.opensearch.knn.plugin.stats.KNNCounter;

/**
 * {@link KNNWeight} variant that runs approximate search through the segment's vector reader
 * ({@code reader.getVectorReader().search(...)}) and gathers hits with a Lucene {@link KnnCollector}.
 *
 * <p>The collector strategy is fixed once, at construction time, from {@code k} and the query:
 * <ul>
 *   <li>{@code k > 0} and no parents filter: top-k collection via {@link TopKnnCollectorManager};</li>
 *   <li>{@code k > 0} with a parents filter: collection via
 *       {@link DiversifyingNearestChildrenKnnCollectorManager};</li>
 *   <li>{@code k <= 0}: radial search via {@link RadiusVectorSimilarityCollector}, driven by the
 *       query radius.</li>
 * </ul>
 */
public class MemoryOptimizedKNNWeight extends KNNWeight {
    @Generated
    private static final Logger log = LogManager.getLogger(MemoryOptimizedKNNWeight.class);

    /** Produces the per-leaf collector used for every search issued by this weight. */
    private final KnnCollectorManager knnCollectorManager;

    /**
     * @param query        the k-NN query being executed
     * @param boost        score boost forwarded to {@link KNNWeight}
     * @param filterWeight weight of the optional filter query; when {@code null}, searches run with
     *                     an unbounded ({@link Integer#MAX_VALUE}) visit limit
     * @param searcher     searcher handed to the top-k collector managers
     * @param k            number of neighbours to collect; a value {@code <= 0} selects radial search
     */
    public MemoryOptimizedKNNWeight(KNNQuery query, float boost, Weight filterWeight, IndexSearcher searcher, int k) {
        super(query, boost, filterWeight);
        this.knnCollectorManager = createCollectorManager(query, searcher, k);
    }

    /**
     * Chooses the collector manager that matches the search mode.
     *
     * <p>Static so the constructor does not call an instance method on a partially built object.
     *
     * @param query    the k-NN query (supplies the parents filter and, for radial search, the radius)
     * @param searcher searcher handed to the top-k collector managers
     * @param k        number of neighbours; {@code <= 0} selects radial search
     * @return the manager used to create a collector for each leaf
     */
    private static KnnCollectorManager createCollectorManager(KNNQuery query, IndexSearcher searcher, int k) {
        if (k <= 0) {
            // Radial search. The radius is read each time a collector is created, not here, so a
            // missing radius only surfaces at search time.
            // NOTE(review): the first argument is presumably the graph-traversal similarity threshold
            // (radius scaled by the traversal ratio) and the second the result threshold — confirm
            // against RadiusVectorSimilarityCollector.
            return (visitLimit, searchStrategy, context) -> new RadiusVectorSimilarityCollector(
                KNNConstants.DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO.floatValue() * query.getRadius().floatValue(),
                query.getRadius().floatValue(),
                visitLimit
            );
        }
        if (query.getParentsFilter() == null) {
            return new TopKnnCollectorManager(k, searcher);
        }
        // Nested documents: collect hits through the parents filter.
        return new DiversifyingNearestChildrenKnnCollectorManager(k, query.getParentsFilter(), searcher);
    }

    /**
     * Runs approximate search over one segment.
     *
     * <p>For {@code k > 0} the target vector is, in order of preference:
     * <ol>
     *   <li>{@code quantizedTargetVector} when non-null, which requires the field's transfer data
     *       type to be {@link VectorDataType#BINARY};</li>
     *   <li>the query's byte vector when the query data type is BINARY or BYTE;</li>
     *   <li>the query's float vector otherwise.</li>
     * </ol>
     * For {@code k <= 0} (radial search) the query's float vector is always used.
     *
     * @param context               leaf being searched
     * @param reader                segment reader whose vector reader performs the search
     * @param fieldInfo             field metadata, used to look up the transfer data type
     * @param spaceType             forwarded to explanation handling
     * @param knnEngine             forwarded to explanation handling
     * @param vectorDataType        not used by this implementation
     * @param quantizedTargetVector quantized query vector, or {@code null} when not quantized
     * @param modelId               not used by this implementation
     * @param filterIdsBitSet       accepted doc ids from the filter
     * @param cardinality           number of filtered docs; {@code 0} means no accept-docs are passed
     * @param k                     number of neighbours; {@code <= 0} selects radial search
     * @return hits for this segment, or {@code EMPTY_TOPDOCS} when nothing was collected
     * @throws RuntimeException wrapping any failure, including a transfer data type mismatch; the
     *                          {@code GRAPH_QUERY_ERRORS} counter is incremented first
     */
    @Override
    protected TopDocs doANNSearch(LeafReaderContext context, SegmentReader reader, FieldInfo fieldInfo, SpaceType spaceType, KNNEngine knnEngine, VectorDataType vectorDataType, byte[] quantizedTargetVector, String modelId, BitSet filterIdsBitSet, int cardinality, int k) {
        try {
            if (k <= 0) {
                // Radial search: under a filter the visit limit is the filter cardinality itself.
                return this.queryIndex(this.knnQuery.getQueryVector(), cardinality, cardinality, context, filterIdsBitSet, reader, knnEngine, spaceType);
            }
            // NOTE(review): cardinality + 1 presumably mirrors Lucene's "cost + 1" visited limit, so
            // that visiting more docs than the filter admits can be detected — confirm.
            final Object targetVector = this.selectKnnTargetVector(fieldInfo, quantizedTargetVector);
            return this.queryIndex(targetVector, cardinality, cardinality + 1, context, filterIdsBitSet, reader, knnEngine, spaceType);
        } catch (Exception e) {
            // Every failure, including the IllegalStateException raised by selectKnnTargetVector, is
            // counted and rethrown unchecked.
            KNNCounter.GRAPH_QUERY_ERRORS.increment();
            throw new RuntimeException(e);
        }
    }

    /**
     * Picks the target vector for a top-k ({@code k > 0}) search.
     *
     * @param fieldInfo             field metadata, used to look up the transfer data type
     * @param quantizedTargetVector quantized query vector, or {@code null} when not quantized
     * @return a {@code byte[]} or {@code float[]} target vector
     * @throws IllegalStateException when a quantized vector is supplied but the field's transfer
     *                               data type is not {@link VectorDataType#BINARY}
     */
    private Object selectKnnTargetVector(FieldInfo fieldInfo, byte[] quantizedTargetVector) {
        if (quantizedTargetVector != null) {
            final VectorDataType transferType = this.quantizationService.getVectorDataTypeForTransfer(fieldInfo);
            if (transferType != VectorDataType.BINARY) {
                throw new IllegalStateException("VectorDataType for transfer acquired [" + transferType + "] while it is expected to get [" + VectorDataType.BINARY + "]");
            }
            return quantizedTargetVector;
        }
        final VectorDataType queryType = this.knnQuery.getVectorDataType();
        if (queryType == VectorDataType.BINARY || queryType == VectorDataType.BYTE) {
            return this.knnQuery.getByteQueryVector();
        }
        return this.knnQuery.getQueryVector();
    }

    /**
     * Executes the search against the segment's vector reader and post-processes the hits.
     *
     * @param targetVector               {@code float[]} or {@code byte[]} query vector (asserted)
     * @param cardinality                number of filtered docs; {@code 0} means no accept-docs are passed
     * @param visitLimitWhenFilterExists visit limit applied only when a filter weight is present
     * @param context                    leaf being searched
     * @param filterIdsBitSet            accepted doc ids from the filter
     * @param reader                     segment reader whose vector reader performs the search
     * @param knnEngine                  forwarded to explanation handling
     * @param spaceType                  forwarded to explanation handling
     * @return collected hits, or {@code EMPTY_TOPDOCS} when none were found
     * @throws IOException if creating the collector or running the search fails
     */
    private TopDocs queryIndex(Object targetVector, int cardinality, int visitLimitWhenFilterExists, LeafReaderContext context, BitSet filterIdsBitSet, SegmentReader reader, KNNEngine knnEngine, SpaceType spaceType) throws IOException {
        assert targetVector instanceof float[] || targetVector instanceof byte[];
        // Without a filter the search is not bounded by a visit limit.
        final int visitedLimit = this.getFilterWeight() == null ? Integer.MAX_VALUE : visitLimitWhenFilterExists;
        final KnnCollector knnCollector = this.knnCollectorManager.newCollector(visitedLimit, KnnSearchStrategy.Hnsw.DEFAULT, context);
        // A cardinality of 0 passes no accept-docs to the reader.
        final Bits acceptDocs = cardinality == 0 ? null : filterIdsBitSet;
        if (targetVector instanceof float[]) {
            reader.getVectorReader().search(this.knnQuery.getField(), (float[]) targetVector, knnCollector, acceptDocs);
        } else {
            reader.getVectorReader().search(this.knnQuery.getField(), (byte[]) targetVector, knnCollector, acceptDocs);
        }
        final TopDocs topDocs = knnCollector.topDocs();
        if (topDocs.scoreDocs.length == 0) {
            log.debug("[KNN] Query yielded 0 results");
            return EMPTY_TOPDOCS;
        }
        this.addExplainIfRequired(topDocs, knnEngine, spaceType);
        return topDocs;
    }
}

