/*
 * Decompiled with CFR 0.152.
 */
package com.o19s.es.ltr.feature.store;

import com.o19s.es.ltr.LtrQueryContext;
import com.o19s.es.ltr.feature.Feature;
import com.o19s.es.ltr.feature.FeatureSet;
import com.o19s.es.ltr.feature.store.ExtraLoggingSupplier;
import com.o19s.es.ltr.feature.store.FeatureSupplier;
import com.o19s.es.ltr.feature.store.StoredFeature;
import com.o19s.es.ltr.query.LtrRewritableQuery;
import com.o19s.es.ltr.query.LtrRewriteContext;
import com.o19s.es.ltr.ranker.LogLtrRanker;
import com.o19s.es.termstat.TermStatSupplier;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.TermToBytesRefAttribute;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;
import org.opensearch.common.lucene.search.function.LeafScoreFunction;
import org.opensearch.common.lucene.search.function.ScriptScoreFunction;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.analysis.NamedAnalyzer;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.ltr.settings.LTRSettings;
import org.opensearch.script.ScoreScript;
import org.opensearch.script.Script;

/**
 * A {@link Feature} whose value is computed by a {@link ScoreScript}.
 *
 * <p>At construction the script params are split in two. The optional {@value #EXTRA_SCRIPT_PARAMS}
 * entry is a map from query-param name to script-param name. Every other entry is a base param.
 *
 * <p>At query time ({@link #doToQuery}) the script is compiled with:
 * <ul>
 *   <li>the base params;</li>
 *   <li>the query-time params (renamed via the extra-params map where applicable);</li>
 *   <li>a {@link FeatureSupplier} under {@value #FEATURE_VECTOR};</li>
 *   <li>an {@link ExtraLoggingSupplier} under {@value #EXTRA_LOGGING};</li>
 *   <li>when a {@code term_stat} base param is present, term-statistics hooks under
 *       {@value #TERM_STAT}, {@value #MATCH_COUNT} and {@value #UNIQUE_TERMS}.</li>
 * </ul>
 */
public class ScriptFeature implements Feature {
    public static final String TEMPLATE_LANGUAGE = "script_feature";
    public static final String FEATURE_VECTOR = "feature_vector";
    public static final String TERM_STAT = "termStats";
    public static final String MATCH_COUNT = "matchCount";
    public static final String UNIQUE_TERMS = "uniqueTerms";
    public static final String EXTRA_LOGGING = "extra_logging";
    public static final String EXTRA_SCRIPT_PARAMS = "extra_script_params";

    /** Base script param holding the term-stat spec ({@code analyzer}, {@code fields}, {@code terms}). */
    private static final String TERM_STAT_SPEC = "term_stat";

    /**
     * Term statistics of the document currently being scored on this thread.
     *
     * <p>The scorer in {@link LtrScriptWeight} sets this right before invoking the script and removes
     * it afterwards. The suppliers registered under {@value #TERM_STAT} / {@value #MATCH_COUNT} read it.
     */
    private static final ThreadLocal<TermStatSupplier> CURRENT_TERM_STATS = new ThreadLocal<>();

    private final String name;
    private final Script script;
    private final Collection<String> queryParams;
    private final Map<String, Object> baseScriptParams;
    private final Map<String, String> extraScriptParams;

    /**
     * Creates a script feature.
     *
     * @param name feature name (non-null)
     * @param script script to run (non-null). Its {@value #EXTRA_SCRIPT_PARAMS} param, if present,
     *     is expected to map query-param names to script-param names.
     * @param queryParams names of the query-time params this feature requires
     */
    public ScriptFeature(String name, Script script, Collection<String> queryParams) {
        this.name = Objects.requireNonNull(name);
        this.script = Objects.requireNonNull(script);
        this.queryParams = queryParams;
        Map<String, Object> ltrScriptParams = new HashMap<>();
        Map<String, String> ltrExtraScriptParams = new HashMap<>();
        for (Map.Entry<String, Object> entry : script.getParams().entrySet()) {
            if (EXTRA_SCRIPT_PARAMS.equals(entry.getKey())) {
                // A null value keeps the empty map; otherwise doToQuery would NPE on it later.
                if (entry.getValue() != null) {
                    @SuppressWarnings("unchecked")
                    Map<String, String> renames = (Map<String, String>) entry.getValue();
                    ltrExtraScriptParams = renames;
                }
            } else {
                ltrScriptParams.put(String.valueOf(entry.getKey()), entry.getValue());
            }
        }
        this.baseScriptParams = ltrScriptParams;
        this.extraScriptParams = ltrExtraScriptParams;
    }

    /**
     * Parses a stored feature's JSON template as a {@link Script} (default language {@code native}).
     *
     * @param feature stored feature whose {@code template()} holds the script JSON
     * @return the compiled feature definition
     * @throws RuntimeException wrapping any {@link IOException} raised while parsing
     */
    public static ScriptFeature compile(StoredFeature feature) {
        // try-with-resources: the parser is Closeable and was previously never closed.
        try (XContentParser parser = XContentType.JSON.xContent()
                .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, feature.template())) {
            return new ScriptFeature(feature.name(), Script.parse(parser, "native"), feature.queryParams());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public String name() {
        return this.name;
    }

    /**
     * Builds the scoring query for this feature.
     *
     * @param context query context providing the shard context used to compile the script and to
     *     resolve fields and analyzers
     * @param featureSet feature set backing the {@link FeatureSupplier} exposed to the script
     * @param params query-time params; must contain every name in {@code queryParams}
     * @return an {@link LtrScript} wrapping the compiled {@link ScriptScoreFunction}
     * @throws IllegalArgumentException if required params are missing, if a term-stat spec lacks
     *     fields or terms, or if a field or analyzer cannot be resolved
     */
    @Override
    public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map<String, Object> params) {
        List<String> missingParams = this.queryParams.stream()
            .filter(x -> !params.containsKey(x))
            .collect(Collectors.toList());
        if (!missingParams.isEmpty()) {
            String names = String.join(",", missingParams);
            throw new IllegalArgumentException("Missing required param(s): [" + names + "]");
        }
        // Every query param is present past this point. Params listed in extra_script_params are
        // passed to the script under their mapped name; the rest keep their own name.
        Map<String, Object> queryTimeParams = new HashMap<>();
        Map<String, Object> extraQueryTimeParams = new HashMap<>();
        for (String param : this.queryParams) {
            if (this.extraScriptParams.containsKey(param)) {
                extraQueryTimeParams.put(this.extraScriptParams.get(param), params.get(param));
            } else {
                queryTimeParams.put(param, params.get(param));
            }
        }
        FeatureSupplier supplier = new FeatureSupplier(featureSet);
        ExtraLoggingSupplier extraLoggingSupplier = new ExtraLoggingSupplier();
        Map<String, Object> nparams = new HashMap<>();
        Set<Term> terms = new HashSet<>();
        if (this.baseScriptParams.containsKey(TERM_STAT_SPEC)) {
            @SuppressWarnings("unchecked")
            Map<String, Object> termSpec = (Map<String, Object>) this.baseScriptParams.get(TERM_STAT_SPEC);
            collectTermStatTerms(context, termSpec, params, terms);
            // NOTE(review): the decompiled form lost the functional-interface target type of these two
            // values (a lambda/method ref cannot target Object). Restore the original cast, presumably
            // a java.util.function supplier type -- confirm against upstream before changing.
            nparams.put(TERM_STAT, CURRENT_TERM_STATS::get);
            nparams.put(MATCH_COUNT, () -> CURRENT_TERM_STATS.get().getMatchedTermCount());
            nparams.put(UNIQUE_TERMS, terms.size());
        }
        // Order matters: later puts override earlier ones on key collisions.
        nparams.putAll(this.baseScriptParams);
        nparams.putAll(queryTimeParams);
        nparams.putAll(extraQueryTimeParams);
        nparams.put(FEATURE_VECTOR, supplier);
        nparams.put(EXTRA_LOGGING, extraLoggingSupplier);

        var shardContext = context.getQueryShardContext();
        Script boundScript = new Script(
            this.script.getType(), this.script.getLang(), this.script.getIdOrCode(), this.script.getOptions(), nparams);
        ScoreScript.Factory factory = shardContext.compile(boundScript, ScoreScript.CONTEXT);
        ScoreScript.LeafFactory leafFactory = factory.newFactory(nparams, shardContext.lookup(), shardContext.searcher());
        ScriptScoreFunction function = new ScriptScoreFunction(boundScript, leafFactory, shardContext.index().getName(),
            shardContext.getShardId(), shardContext.indexVersionCreated(), null);
        return new LtrScript(function, supplier, extraLoggingSupplier, terms);
    }

    /**
     * Analyzes every configured term string against every configured field and adds the produced
     * tokens to {@code terms}.
     *
     * <p>The analyzer is resolved from {@code analyzer}:
     * <ul>
     *   <li>a value starting with {@code !} names the analyzer literally;</li>
     *   <li>any other String names a query param holding the analyzer name;</li>
     *   <li>when absent, each field's search analyzer is used.</li>
     * </ul>
     */
    private static void collectTermStatTerms(LtrQueryContext context, Map<String, Object> termSpec,
                                             Map<String, Object> params, Set<Term> terms) {
        String analyzerName = null;
        Object analyzerNameObj = termSpec.get("analyzer");
        if (analyzerNameObj instanceof String) {
            String spec = (String) analyzerNameObj;
            analyzerName = spec.startsWith("!") ? spec.substring(1) : (String) params.get(spec);
        }
        List<String> fields = resolveList(termSpec.get("fields"), params);
        List<String> termList = resolveList(termSpec.get("terms"), params);
        if (fields == null || termList == null) {
            throw new IllegalArgumentException("Term Stats injection requires fields and terms");
        }
        var shardContext = context.getQueryShardContext();
        // Loop-invariant: a named analyzer does not depend on the field.
        NamedAnalyzer namedAnalyzer = analyzerName == null ? null : shardContext.getIndexAnalyzers().get(analyzerName);
        for (String field : fields) {
            NamedAnalyzer analyzer;
            if (analyzerName == null) {
                MappedFieldType fieldType = shardContext.getFieldType(field);
                if (fieldType == null) {
                    throw new IllegalArgumentException("No field found for [" + field + "]");
                }
                analyzer = fieldType.getTextSearchInfo().getSearchAnalyzer();
            } else {
                analyzer = namedAnalyzer;
            }
            if (analyzer == null) {
                throw new IllegalArgumentException("No analyzer found for [" + analyzerName + "]");
            }
            for (String termString : termList) {
                analyzeInto(analyzer, field, termString, terms);
            }
        }
    }

    /**
     * Resolves a term-stat list setting.
     *
     * <p>A String names a query param whose value is the list; a List is used inline. Any other
     * value, including null, yields null.
     */
    @SuppressWarnings("unchecked")
    private static List<String> resolveList(Object spec, Map<String, Object> params) {
        if (spec instanceof String) {
            return (List<String>) params.get(spec);
        }
        if (spec instanceof List) {
            return (List<String>) spec;
        }
        return null;
    }

    /**
     * Runs {@code text} through {@code analyzer} for {@code field} and adds one {@link Term} per token.
     *
     * <p>The stream is always ended and closed, honouring the TokenStream consumer workflow, so the
     * analyzer's reused components stay usable for the next call.
     */
    private static void analyzeInto(NamedAnalyzer analyzer, String field, String text, Set<Term> terms) {
        try (TokenStream ts = analyzer.tokenStream(field, text)) {
            TermToBytesRefAttribute termAtt = ts.getAttribute(TermToBytesRefAttribute.class);
            ts.reset();
            while (ts.incrementToken()) {
                terms.add(new Term(field, termAtt.getBytesRef()));
            }
            ts.end();
        } catch (IOException ignored) {
            // Best-effort, matching the previous behaviour: a failure while analyzing this string
            // only drops its tokens rather than failing the whole query.
        }
    }

    /**
     * Query wrapper around the compiled script function.
     *
     * <p>{@link #ltrRewrite} binds the feature-vector and extra-logging suppliers from the rewrite
     * context before scoring.
     */
    static class LtrScript extends Query implements LtrRewritableQuery {
        private final ScriptScoreFunction function;
        private final FeatureSupplier supplier;
        private final ExtraLoggingSupplier extraLoggingSupplier;
        private final Set<Term> terms;

        LtrScript(ScriptScoreFunction function, FeatureSupplier supplier, ExtraLoggingSupplier extraLoggingSupplier, Set<Term> terms) {
            this.function = function;
            this.supplier = supplier;
            this.extraLoggingSupplier = extraLoggingSupplier;
            this.terms = terms;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            // Check the class before casting; casting first threw ClassCastException for other Query types.
            return this.sameClassAs(o) && Objects.equals(this.function, ((LtrScript) o).function);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.classHash(), this.function);
        }

        @Override
        public String toString(String field) {
            return "LtrScript:" + field;
        }

        /**
         * Creates the scoring weight.
         *
         * <p>When scores are not needed, a match-all weight is returned instead.
         *
         * @throws IllegalStateException if the LTR plugin is disabled
         */
        @Override
        public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
            if (!LTRSettings.isLTRPluginEnabled()) {
                throw new IllegalStateException("LTR plugin is disabled. To enable, update ltr.plugin.enabled to true");
            }
            if (!scoreMode.needsScores()) {
                return new MatchAllDocsQuery().createWeight(searcher, scoreMode, 1.0f);
            }
            return new LtrScriptWeight(this, this.function, this.terms, searcher, scoreMode);
        }

        /**
         * Binds the feature-vector supplier and, when a log consumer is present, its extra-logging
         * map. Otherwise the extra-logging supplier yields null.
         */
        @Override
        public Query ltrRewrite(LtrRewriteContext context) throws IOException {
            this.supplier.set(context.getFeatureVectorSupplier());
            LogLtrRanker.LogConsumer consumer = context.getLogConsumer();
            if (consumer != null) {
                this.extraLoggingSupplier.setSupplier(consumer::getExtraLoggingMap);
            } else {
                this.extraLoggingSupplier.setSupplier(() -> null);
            }
            return this;
        }

        /**
         * Reports the term-stat terms to the visitor, but only if the visitor accepts every field
         * they belong to.
         */
        @Override
        public void visit(QueryVisitor visitor) {
            Set<String> fields = this.terms.stream().map(Term::field).collect(Collectors.toUnmodifiableSet());
            for (String field : fields) {
                if (!visitor.acceptField(field)) {
                    return;
                }
            }
            visitor.getSubVisitor(BooleanClause.Occur.SHOULD, this).consumeTerms(this, this.terms.toArray(new Term[0]));
        }
    }

    /**
     * Weight that scores every document in a segment by running the script.
     *
     * <p>Before each score, the current document's term statistics are published through
     * {@code CURRENT_TERM_STATS}.
     */
    static class LtrScriptWeight extends Weight {
        private final IndexSearcher searcher;
        private final ScoreMode scoreMode;
        private final ScriptScoreFunction function;
        private final Set<Term> terms;
        private final HashMap<Term, TermStates> termContexts;

        LtrScriptWeight(Query query, ScriptScoreFunction function, Set<Term> terms, IndexSearcher searcher, ScoreMode scoreMode) throws IOException {
            super(query);
            this.function = function;
            this.terms = terms;
            this.searcher = searcher;
            this.scoreMode = scoreMode;
            this.termContexts = new HashMap<>();
            if (scoreMode.needsScores()) {
                for (Term t : terms) {
                    TermStates ctx = TermStates.build(searcher, t, true);
                    if (ctx != null && ctx.docFreq() > 0) {
                        // NOTE(review): return values are discarded. Kept in case a searcher subclass relies
                        // on these calls (e.g. for stats collection) -- confirm before removing.
                        searcher.collectionStatistics(t.field());
                        searcher.termStatistics(t, ctx.docFreq(), ctx.totalTermFreq());
                    }
                    this.termContexts.put(t, ctx);
                }
            }
        }

        @Override
        public Explanation explain(LeafReaderContext context, int doc) throws IOException {
            return this.function.getLeafScoreFunction(context).explainScore(doc, Explanation.noMatch("none"));
        }

        /**
         * Returns a scorer that iterates all documents of the segment and scores each one with the
         * script.
         */
        public Scorer getScorer(final LeafReaderContext context) throws IOException {
            final LeafScoreFunction leafScoreFunction = this.function.getLeafScoreFunction(context);
            final DocIdSetIterator iterator = DocIdSetIterator.all(context.reader().maxDoc());
            final TermStatSupplier termStatSupplier = new TermStatSupplier();
            return new Scorer() {

                @Override
                public int docID() {
                    return iterator.docID();
                }

                @Override
                public float score() throws IOException {
                    CURRENT_TERM_STATS.set(termStatSupplier);
                    try {
                        if (!terms.isEmpty()) {
                            termStatSupplier.bump(searcher, context, this.docID(), terms, scoreMode, termContexts);
                        }
                        return (float) leafScoreFunction.score(iterator.docID(), 0.0f);
                    } finally {
                        // Always clear, even if bump or the script throws, so no stale stats leak.
                        CURRENT_TERM_STATS.remove();
                    }
                }

                @Override
                public DocIdSetIterator iterator() {
                    return iterator;
                }

                @Override
                public float getMaxScore(int upTo) throws IOException {
                    return Float.POSITIVE_INFINITY;
                }
            };
        }

        @Override
        public ScorerSupplier scorerSupplier(final LeafReaderContext context) throws IOException {
            final Scorer scorer = this.getScorer(context);
            return new ScorerSupplier() {

                @Override
                public Scorer get(long leadCost) throws IOException {
                    return scorer;
                }

                @Override
                public long cost() {
                    return context.reader().maxDoc();
                }
            };
        }

        /** Intentionally contributes no terms. */
        public void extractTerms(Set<Term> terms) {
        }

        /**
         * Never cacheable. Scores come from a script whose suppliers are bound per query in
         * {@code ltrRewrite}.
         */
        @Override
        public boolean isCacheable(LeafReaderContext ctx) {
            return false;
        }
    }
}

