/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.query;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.commons.lang.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.FilteredDocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.FilterDirectory;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.DocIdSetBuilder;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.common.io.PathUtils;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
import org.opensearch.knn.index.query.FilterIdsSelector;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.KNNQueryResult;
import org.opensearch.knn.index.query.KNNScorer;
import org.opensearch.knn.index.query.filtered.FilteredIdsKNNByteIterator;
import org.opensearch.knn.index.query.filtered.FilteredIdsKNNIterator;
import org.opensearch.knn.index.query.filtered.KNNIterator;
import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNByteIterator;
import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNIterator;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.plugin.stats.KNNCounter;

/**
 * {@link Weight} for a {@link KNNQuery}.
 *
 * <p>For each segment the weight either queries a native engine index through {@link JNIService}
 * (approximate search) or scores the documents accepted by the optional filter directly from the
 * field's binary doc values (exact search). The resulting doc-id to score map is wrapped in a
 * {@link KNNScorer}.
 */
public class KNNWeight extends Weight {
    @Generated
    private static final Logger log = LogManager.getLogger(KNNWeight.class);
    /** Model metadata source; must be set through {@link #initialize(ModelDao)} before model-based fields are searched. */
    private static ModelDao modelDao;
    private final KNNQuery knnQuery;
    private final float boost;
    private final NativeMemoryCacheManager nativeMemoryCacheManager;
    /** Weight of the filter query, or {@code null} when the query has no filter. */
    private final Weight filterWeight;

    /**
     * Creates a weight without a filter.
     *
     * @param query the k-NN query this weight belongs to
     * @param boost boost handed to the {@link KNNScorer}
     */
    public KNNWeight(KNNQuery query, float boost) {
        this(query, boost, null);
    }

    /**
     * Creates a weight with an optional filter.
     *
     * @param query        the k-NN query this weight belongs to
     * @param boost        boost handed to the {@link KNNScorer}
     * @param filterWeight weight of the filter query; {@code null} disables filtering
     */
    public KNNWeight(KNNQuery query, float boost, Weight filterWeight) {
        super(query);
        this.knnQuery = query;
        this.boost = boost;
        this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance();
        this.filterWeight = filterWeight;
    }

    /**
     * Sets the {@link ModelDao} shared by all instances.
     *
     * @param modelDao DAO used to look up model metadata for fields that carry a {@code model_id} attribute
     */
    public static void initialize(ModelDao modelDao) {
        KNNWeight.modelDao = modelDao;
    }

    /**
     * Returns a constant placeholder explanation; no per-document explanation is computed.
     */
    @Override
    public Explanation explain(LeafReaderContext context, int doc) {
        return Explanation.match(1.0f, "No Explanation");
    }

    /**
     * Builds the scorer for one segment.
     *
     * <p>With a filter present, an empty filter result short-circuits to an empty scorer, and a
     * small enough filter result (see {@link #canDoExactSearch(int)}) is scored exactly. Otherwise
     * the native index is queried; if that returns fewer than k hits although at least k documents
     * passed the filter, the segment is re-scored exactly.
     *
     * @param context the segment to search
     * @return a scorer over the matching documents, an empty scorer when nothing matched, or
     *         {@code null} when the approximate search returned {@code null} (field or engine file
     *         missing in the segment, or zero results)
     * @throws IOException if reading the segment fails
     */
    @Override
    public Scorer scorer(LeafReaderContext context) throws IOException {
        final BitSet filterBitSet = getFilteredDocsBitSet(context);
        final int cardinality = filterBitSet.cardinality();
        // No document passed the filter, so there is nothing to search.
        if (filterWeight != null && cardinality == 0) {
            return KNNScorer.emptyScorer(this);
        }

        final Map<Integer, Float> docIdsToScoreMap = new HashMap<>();
        if (filterWeight != null && canDoExactSearch(cardinality)) {
            docIdsToScoreMap.putAll(doExactSearch(context, filterBitSet, cardinality));
        } else {
            Map<Integer, Float> annResults = doANNSearch(context, filterBitSet, cardinality);
            if (annResults == null) {
                return null;
            }
            if (canDoExactSearchAfterANNSearch(cardinality, annResults.size())) {
                log.debug(
                    "Doing ExactSearch after doing ANNSearch as the number of documents returned are less than K, even when we have more than K filtered Ids. K: {}, ANNResults: {}, filteredIdCount: {}",
                    knnQuery.getK(),
                    annResults.size(),
                    cardinality
                );
                annResults = doExactSearch(context, filterBitSet, cardinality);
            }
            docIdsToScoreMap.putAll(annResults);
        }

        if (docIdsToScoreMap.isEmpty()) {
            return KNNScorer.emptyScorer(this);
        }
        return convertSearchResponseToScorer(docIdsToScoreMap);
    }

    /**
     * Evaluates the filter on the segment.
     *
     * @return the live documents accepted by the filter; a zero-length bit set when there is no
     *         filter or the filter produces no scorer for this segment
     */
    private BitSet getFilteredDocsBitSet(LeafReaderContext ctx) throws IOException {
        if (filterWeight == null) {
            return new FixedBitSet(0);
        }
        final Bits liveDocs = ctx.reader().getLiveDocs();
        final int maxDoc = ctx.reader().maxDoc();
        final Scorer scorer = filterWeight.scorer(ctx);
        if (scorer == null) {
            return new FixedBitSet(0);
        }
        return createBitSet(scorer.iterator(), liveDocs, maxDoc);
    }

    /**
     * Materialises the filter iterator into a bit set, dropping documents that are not live.
     *
     * @param filteredDocIdsIterator iterator over the documents accepted by the filter
     * @param liveDocs               live-docs bits of the segment, or {@code null} when it has no deletions
     * @param maxDoc                 size of the bit set to build
     */
    private BitSet createBitSet(DocIdSetIterator filteredDocIdsIterator, final Bits liveDocs, int maxDoc) throws IOException {
        // Without deletions a bit-set-backed iterator already holds exactly what we need; reuse it.
        if (liveDocs == null && filteredDocIdsIterator instanceof BitSetIterator) {
            return ((BitSetIterator) filteredDocIdsIterator).getBitSet();
        }
        final FilteredDocIdSetIterator liveOnlyIterator = new FilteredDocIdSetIterator(filteredDocIdsIterator) {
            @Override
            protected boolean match(int doc) {
                return liveDocs == null || liveDocs.get(doc);
            }
        };
        return BitSet.of(liveOnlyIterator, maxDoc);
    }

    /**
     * @return the doc ids selected by the query's parents filter for this segment, or {@code null}
     *         when the query has no parents filter
     */
    private int[] getParentIdsArray(LeafReaderContext context) throws IOException {
        if (knnQuery.getParentsFilter() == null) {
            return null;
        }
        // NOTE(review): assumes getBitSet(context) is non-null for segments searched here - confirm.
        return bitSetToIntArray(knnQuery.getParentsFilter().getBitSet(context));
    }

    /**
     * Copies the set bits of {@code bitSet} into an int array in iteration order.
     */
    private int[] bitSetToIntArray(BitSet bitSet) {
        final int cardinality = bitSet.cardinality();
        final int[] docIds = new int[cardinality];
        final BitSetIterator bitSetIterator = new BitSetIterator(bitSet, cardinality);
        int index = 0;
        for (int docId = bitSetIterator.nextDoc(); docId != DocIdSetIterator.NO_MORE_DOCS; docId = bitSetIterator.nextDoc()) {
            assert index < docIds.length;
            docIds[index++] = docId;
        }
        return docIds;
    }

    /**
     * Queries the native engine index of the segment.
     *
     * <p>Engine, space type and vector data type are taken from the model metadata when the field
     * carries a {@code model_id} attribute, otherwise from the field attributes with defaults.
     * The index is obtained from the native memory cache and queried under its read lock.
     *
     * @param context         the segment to search
     * @param filterIdsBitSet documents accepted by the filter
     * @param cardinality     number of set bits in {@code filterIdsBitSet}
     * @return doc id to engine-translated score, or {@code null} when the field or its engine file
     *         is missing from the segment or the query produced no results
     * @throws RuntimeException if the model is not created, the index cannot be loaded or is
     *                          closed, or the native query fails
     */
    private Map<Integer, Float> doANNSearch(LeafReaderContext context, BitSet filterIdsBitSet, int cardinality) throws IOException {
        final SegmentReader reader = Lucene.segmentReader(context.reader());
        final String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString();

        final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
        if (fieldInfo == null) {
            log.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName());
            return null;
        }

        // Assigned exactly once on each branch so they remain effectively final for the lambda below.
        final KNNEngine knnEngine;
        final SpaceType spaceType;
        final VectorDataType vectorDataType;
        final String modelId = fieldInfo.getAttribute("model_id");
        if (modelId != null) {
            final ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
            if (!ModelUtil.isModelCreated(modelMetadata)) {
                throw new RuntimeException("Model \"" + modelId + "\" is not created.");
            }
            knnEngine = modelMetadata.getKnnEngine();
            spaceType = modelMetadata.getSpaceType();
            vectorDataType = modelMetadata.getVectorDataType();
        } else {
            final String engineName = fieldInfo.attributes().getOrDefault("engine", KNNEngine.NMSLIB.getName());
            knnEngine = KNNEngine.getEngine(engineName);
            final String spaceTypeName = fieldInfo.attributes().getOrDefault("spaceType", SpaceType.L2.getValue());
            spaceType = SpaceType.getSpace(spaceTypeName);
            vectorDataType = VectorDataType.get(fieldInfo.attributes().getOrDefault("data_type", VectorDataType.FLOAT.getValue()));
        }

        final List<String> engineFiles = getEngineFiles(reader, knnEngine.getExtension());
        if (engineFiles.isEmpty()) {
            log.debug("[KNN] No engine index found for field {} for segment {}", knnQuery.getField(), reader.getSegmentName());
            return null;
        }
        final Path indexPath = PathUtils.get(directory, engineFiles.get(0));

        KNNCounter.GRAPH_QUERY_REQUESTS.increment();

        final NativeMemoryAllocation indexAllocation;
        try {
            indexAllocation = nativeMemoryCacheManager.get(
                new NativeMemoryEntryContext.IndexEntryContext(
                    indexPath.toString(),
                    NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(),
                    IndexUtil.getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName(), vectorDataType),
                    knnQuery.getIndexName(),
                    modelId
                ),
                true
            );
        } catch (ExecutionException e) {
            KNNCounter.GRAPH_QUERY_ERRORS.increment();
            throw new RuntimeException(e);
        }

        final FilterIdsSelector filterIdsSelector = FilterIdsSelector.getFilterIdSelector(filterIdsBitSet, cardinality);
        final long[] filterIds = filterIdsSelector.getFilterIds();
        final FilterIdsSelector.FilterIdsSelectorType filterType = filterIdsSelector.getFilterType();

        final KNNQueryResult[] results;
        // Hold the read lock for the whole native call so the allocation is not used unlocked.
        indexAllocation.readLock();
        try {
            if (indexAllocation.isClosed()) {
                throw new RuntimeException("Index has already been closed");
            }
            final int[] parentIds = getParentIdsArray(context);
            if (knnQuery.getK() > 0) {
                // Top-k search; binary vectors use the byte query vector.
                if (knnQuery.getVectorDataType() == VectorDataType.BINARY) {
                    results = JNIService.queryBinaryIndex(
                        indexAllocation.getMemoryAddress(),
                        knnQuery.getByteQueryVector(),
                        knnQuery.getK(),
                        knnQuery.getMethodParameters(),
                        knnEngine,
                        filterIds,
                        filterType.getValue(),
                        parentIds
                    );
                } else {
                    results = JNIService.queryIndex(
                        indexAllocation.getMemoryAddress(),
                        knnQuery.getQueryVector(),
                        knnQuery.getK(),
                        knnQuery.getMethodParameters(),
                        knnEngine,
                        filterIds,
                        filterType.getValue(),
                        parentIds
                    );
                }
            } else {
                // k <= 0 selects the radius search, capped by the context's max result window.
                results = JNIService.radiusQueryIndex(
                    indexAllocation.getMemoryAddress(),
                    knnQuery.getQueryVector(),
                    knnQuery.getRadius().floatValue(),
                    knnQuery.getMethodParameters(),
                    knnEngine,
                    knnQuery.getContext().getMaxResultWindow(),
                    filterIds,
                    filterType.getValue(),
                    parentIds
                );
            }
        } catch (Exception e) {
            KNNCounter.GRAPH_QUERY_ERRORS.increment();
            throw new RuntimeException(e);
        } finally {
            indexAllocation.readUnlock();
        }

        if (results.length == 0) {
            log.debug("[KNN] Query yielded 0 results");
            return null;
        }
        return Arrays.stream(results)
            .collect(Collectors.toMap(KNNQueryResult::getId, result -> Float.valueOf(knnEngine.score(result.getScore(), spaceType))));
    }

    /**
     * Lists the segment files that belong to this query's field for the given engine extension.
     *
     * <p>A file matches when its name ends with {@code "_" + field + extension}; for compound-file
     * segments {@code "c"} is appended to the extension. Matches are ordered by name length,
     * shortest first.
     *
     * @param reader    segment whose files are inspected
     * @param extension engine file extension
     * @return matching file names, possibly empty
     */
    @VisibleForTesting
    List<String> getEngineFiles(SegmentReader reader, String extension) throws IOException {
        final String engineExtension = reader.getSegmentInfo().info.getUseCompoundFile() ? extension + "c" : extension;
        final String underLineEngineSuffix = "_" + knnQuery.getField() + engineExtension;
        return reader.getSegmentInfo()
            .files()
            .stream()
            .filter(fileName -> fileName.endsWith(underLineEngineSuffix))
            .sorted(Comparator.comparingInt(String::length))
            .collect(Collectors.toList());
    }

    /**
     * Scores every document accepted by the filter and keeps the best {@code min(k, cardinality)}.
     *
     * <p>Failures are logged and reported as an empty result instead of being propagated.
     *
     * @param leafReaderContext the segment to search
     * @param filterIdsBitSet   documents accepted by the filter
     * @param cardinality       number of set bits in {@code filterIdsBitSet}
     * @return doc id to score for the retained documents; empty when the field is absent from the
     *         segment or scoring failed
     */
    private Map<Integer, Float> doExactSearch(LeafReaderContext leafReaderContext, BitSet filterIdsBitSet, int cardinality) {
        try {
            final KNNIterator iterator = getFilteredKNNIterator(leafReaderContext, filterIdsBitSet);
            if (iterator == null) {
                // The segment has no data for the field, so nothing can be scored.
                return Collections.emptyMap();
            }

            // Pre-populated queue (second argument): the top is always the current worst entry.
            final HitQueue queue = new HitQueue(Math.min(knnQuery.getK(), cardinality), true);
            ScoreDoc topDoc = queue.top();
            int docId;
            while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
                final float score = iterator.score();
                if (score > topDoc.score) {
                    // Overwrite the current worst entry in place and let the queue re-order itself.
                    topDoc.score = score;
                    topDoc.doc = docId;
                    topDoc = queue.updateTop();
                }
            }

            // Discard entries whose score is still negative (e.g. pre-populated slots never replaced).
            while (queue.size() > 0 && queue.top().score < 0.0f) {
                queue.pop();
            }

            final Map<Integer, Float> docToScore = new HashMap<>();
            while (queue.size() > 0) {
                final ScoreDoc doc = queue.pop();
                docToScore.put(doc.doc, doc.score);
            }
            return docToScore;
        } catch (Exception e) {
            log.error("Error while getting the doc values to do the k-NN Search for query : {}", knnQuery, e);
            return Collections.emptyMap();
        }
    }

    /**
     * Creates the iterator that scores the filtered documents against the query vector.
     *
     * <p>Binary queries use the byte iterators; queries with a parents filter use the nested variants.
     *
     * @return the iterator, or {@code null} when the segment has no field info for the query's field
     */
    private KNNIterator getFilteredKNNIterator(LeafReaderContext leafReaderContext, BitSet filterIdsBitSet) throws IOException {
        final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
        final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
        if (fieldInfo == null) {
            // Same guard as the approximate path: the field may be absent from this segment.
            log.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName());
            return null;
        }
        final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.getName());
        final SpaceType spaceType = getSpaceType(fieldInfo);
        final boolean nested = knnQuery.getParentsFilter() != null;

        if (VectorDataType.BINARY == knnQuery.getVectorDataType()) {
            if (!nested) {
                return new FilteredIdsKNNByteIterator(filterIdsBitSet, knnQuery.getByteQueryVector(), values, spaceType);
            }
            return new NestedFilteredIdsKNNByteIterator(
                filterIdsBitSet,
                knnQuery.getByteQueryVector(),
                values,
                spaceType,
                knnQuery.getParentsFilter().getBitSet(leafReaderContext)
            );
        }
        if (!nested) {
            return new FilteredIdsKNNIterator(filterIdsBitSet, knnQuery.getQueryVector(), values, spaceType);
        }
        return new NestedFilteredIdsKNNIterator(
            filterIdsBitSet,
            knnQuery.getQueryVector(),
            values,
            spaceType,
            knnQuery.getParentsFilter().getBitSet(leafReaderContext)
        );
    }

    /**
     * Wraps a non-empty doc id to score map in a {@link KNNScorer}.
     *
     * @param docsToScore matching doc ids with their scores; must not be empty
     */
    private Scorer convertSearchResponseToScorer(Map<Integer, Float> docsToScore) throws IOException {
        final int maxDoc = Collections.max(docsToScore.keySet()) + 1;
        final DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc);
        final DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(docsToScore.size());
        for (final int docId : docsToScore.keySet()) {
            setAdder.add(docId);
        }
        final DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator();
        return new KNNScorer(this, docIdSetIter, docsToScore, boost);
    }

    /**
     * Always reports the weight as cacheable.
     */
    @Override
    public boolean isCacheable(LeafReaderContext context) {
        return true;
    }

    /**
     * Maps a raw score to a positive value.
     *
     * @param score raw score
     * @return {@code 1 / (1 + score)} for non-negative input, {@code 1 - score} otherwise
     */
    public static float normalizeScore(float score) {
        if (score >= 0.0f) {
            return 1.0f / (1.0f + score);
        }
        return -score + 1.0f;
    }

    /**
     * Resolves the space type of a field, preferring the {@code spaceType} attribute and falling
     * back to the metadata of the model named by the {@code model_id} attribute.
     *
     * @throws IllegalArgumentException if the field has neither attribute
     */
    private SpaceType getSpaceType(FieldInfo fieldInfo) {
        final String spaceTypeString = fieldInfo.getAttribute("spaceType");
        if (StringUtils.isNotEmpty(spaceTypeString)) {
            return SpaceType.getSpace(spaceTypeString);
        }
        final String modelId = fieldInfo.getAttribute("model_id");
        if (StringUtils.isNotEmpty(modelId)) {
            final ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
            return modelMetadata.getSpaceType();
        }
        throw new IllegalArgumentException(
            String.format(Locale.ROOT, "Unable to find the Space Type from Field Info attribute for field %s", fieldInfo.getName())
        );
    }

    /**
     * Decides whether the filtered documents should be scored exactly instead of querying the
     * native index.
     *
     * <p>Radius queries never use exact search here. Otherwise exact search is chosen when at most
     * k documents passed the filter, when the configured threshold (if set) is not exceeded, or,
     * with no threshold configured, when {@code filterIdsCount * vectorLength} stays within
     * {@link KNNConstants#MAX_DISTANCE_COMPUTATIONS}.
     *
     * @param filterIdsCount number of documents accepted by the filter
     */
    private boolean canDoExactSearch(final int filterIdsCount) {
        final int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName());
        log.debug("Info for doing exact search filterIdsLength : {}, Threshold value: {}", filterIdsCount, filterThresholdValue);
        if (knnQuery.getRadius() != null) {
            return false;
        }
        if (filterIdsCount <= knnQuery.getK()) {
            return true;
        }
        if (isExactSearchThresholdSettingSet(filterThresholdValue)) {
            return filterThresholdValue >= filterIdsCount;
        }
        final int vectorLength = knnQuery.getVectorDataType() == VectorDataType.FLOAT
            ? knnQuery.getQueryVector().length
            : knnQuery.getByteQueryVector().length;
        // Multiply in long: an int product can wrap negative for large filter sets and high
        // dimensions, which would wrongly pass the check and trigger a huge exact scan.
        final long distanceComputations = (long) filterIdsCount * vectorLength;
        return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= distanceComputations;
    }

    /**
     * @return {@code true} when the threshold differs from the setting's default value
     */
    private boolean isExactSearchThresholdSettingSet(int filterThresholdValue) {
        return filterThresholdValue != KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE;
    }

    /**
     * @return {@code true} when a filter is present and the approximate search returned fewer than
     *         k results although at least k documents passed the filter
     */
    private boolean canDoExactSearchAfterANNSearch(final int filterIdsCount, final int annResultCount) {
        return filterWeight != null && filterIdsCount >= knnQuery.getK() && knnQuery.getK() > annResultCount;
    }
}

