/*
 * Decompiled with CFR 0.152.
 */
package com.o19s.es.ltr.logging;

import com.o19s.es.ltr.feature.FeatureSet;
import com.o19s.es.ltr.logging.LoggingSearchExtBuilder;
import com.o19s.es.ltr.query.RankerQuery;
import com.o19s.es.ltr.ranker.LogLtrRanker;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.document.DocumentField;
import org.opensearch.search.SearchHit;
import org.opensearch.search.fetch.FetchContext;
import org.opensearch.search.fetch.FetchSubPhase;
import org.opensearch.search.fetch.FetchSubPhaseProcessor;
import org.opensearch.search.rescore.QueryRescorer;
import org.opensearch.search.rescore.RescoreContext;

/**
 * Fetch sub-phase that attaches logged feature values to search hits.
 *
 * <p>When the search request carries an {@code ltr_log} search extension, every log spec it
 * declares is resolved to a {@link RankerQuery} (either through a named query or through a
 * rescore context), converted to its logging variant, and combined into a single boolean query.
 * During fetch, that query is scored against each hit and the values reported to the
 * {@link HitLogConsumer}s are stored on the hit under the {@code _ltrlog} field.
 */
public class LoggingFetchSubPhase implements FetchSubPhase {

    /**
     * Builds the processor that logs feature values for the current fetch.
     *
     * @param context the fetch context of the current request
     * @return a processor bound to the logging queries, or {@code null} when the request has no
     *         {@code ltr_log} search extension
     * @throws IOException if rewriting the query or creating its weight fails
     * @throws IllegalArgumentException if a log spec points to a missing named query, an invalid
     *         rescore index, or a query that is not a ranker query
     */
    @Override
    public FetchSubPhaseProcessor getProcessor(FetchContext context) throws IOException {
        LoggingSearchExtBuilder ext = (LoggingSearchExtBuilder) context.getSearchExt("ltr_log");
        if (ext == null) {
            return null;
        }

        BooleanQuery.Builder builder = new BooleanQuery.Builder();
        List<HitLogConsumer> loggers = new ArrayList<>();
        Map<String, Query> namedQueries = context.parsedQuery().namedFilters();
        // NOTE(review): rescore-based log specs are only resolved when at least one named query
        // exists; with no named queries every log spec is silently skipped. This looks like a
        // deliberate guard (possibly for contexts that carry no named queries) - confirm.
        if (!namedQueries.isEmpty()) {
            // Named-query specs are registered first, then rescore specs; the order of the
            // consumers in 'loggers' follows the order of the clauses in the boolean query.
            ext.logSpecsStream()
                .filter(spec -> spec.getNamedQuery() != null)
                .forEach(spec -> register(builder, loggers, this.extractQuery(spec, namedQueries)));
            ext.logSpecsStream()
                .filter(spec -> spec.getRescoreIndex() != null)
                .forEach(spec -> register(builder, loggers, this.extractRescore(spec, context.rescore())));
        }

        IndexSearcher searcher = context.searcher();
        Weight weight = searcher.rewrite(builder.build()).createWeight(searcher, ScoreMode.COMPLETE, 1.0F);
        return new LoggingFetchSubPhaseProcessor(weight, loggers);
    }

    /**
     * Adds a resolved logging query as a required clause and records its consumer.
     */
    private static void register(BooleanQuery.Builder builder, List<HitLogConsumer> loggers,
                                 Tuple<RankerQuery, HitLogConsumer> entry) {
        // MUST: a hit is only logged when every logging query matches it.
        builder.add(new BooleanClause(entry.v1(), BooleanClause.Occur.MUST));
        loggers.add(entry.v2());
    }

    /**
     * Resolves a log spec that references a named query.
     *
     * @param logSpec the log spec whose named query must be resolved
     * @param namedQueries the named queries of the parsed request, keyed by name
     * @return the logging variant of the ranker query together with its consumer
     * @throws IllegalArgumentException if no query has that name or it is not a ranker query
     */
    private Tuple<RankerQuery, HitLogConsumer> extractQuery(LoggingSearchExtBuilder.LogSpec logSpec, Map<String, Query> namedQueries) {
        Query q = namedQueries.get(logSpec.getNamedQuery());
        if (q == null) {
            throw new IllegalArgumentException("No query named [" + logSpec.getNamedQuery() + "] found");
        }
        RankerQuery ranker = this.inspectQuery(q).orElseThrow(() -> new IllegalArgumentException(
            "Query named [" + logSpec.getNamedQuery() + "] must be a [sltr] query ["
                // Report the wrapped query type when the named query is a boost wrapper.
                + (q instanceof BoostQuery ? ((BoostQuery) q).getQuery().getClass().getSimpleName() : q.getClass().getSimpleName())
                + "] found"));
        return this.toLogger(logSpec, ranker);
    }

    /**
     * Resolves a log spec that references a rescore context by index.
     *
     * @param logSpec the log spec whose rescore index must be resolved
     * @param contexts the rescore contexts of the request
     * @return the logging variant of the ranker query together with its consumer
     * @throws IllegalArgumentException if the index is negative or past the end of
     *         {@code contexts}, the context is not a query rescorer, or its query is not a
     *         ranker query
     */
    private Tuple<RankerQuery, HitLogConsumer> extractRescore(LoggingSearchExtBuilder.LogSpec logSpec, List<RescoreContext> contexts) {
        int index = logSpec.getRescoreIndex();
        // A negative index is rejected here so it surfaces as the same IllegalArgumentException
        // as any other invalid spec instead of an IndexOutOfBoundsException from List.get.
        if (index < 0 || index >= contexts.size()) {
            throw new IllegalArgumentException("rescore index [" + logSpec.getRescoreIndex() + "] is out of bounds, only [" + contexts.size() + "] rescore context(s) are available");
        }
        RescoreContext context = contexts.get(index);
        if (!(context instanceof QueryRescorer.QueryRescoreContext)) {
            throw new IllegalArgumentException("Expected a [QueryRescoreContext] but found a [" + context.getClass().getSimpleName() + "] at index [" + logSpec.getRescoreIndex() + "]");
        }
        QueryRescorer.QueryRescoreContext qrescore = (QueryRescorer.QueryRescoreContext) context;
        RankerQuery ranker = this.inspectQuery(qrescore.query()).orElseThrow(() -> new IllegalArgumentException(
            "Expected a [sltr] query but found a [" + qrescore.query().getClass().getSimpleName() + "] at index [" + logSpec.getRescoreIndex() + "]"));
        return this.toLogger(logSpec, ranker);
    }

    /**
     * Extracts the ranker query from {@code q}, looking through a single {@link BoostQuery}
     * wrapper.
     *
     * @param q the query to inspect
     * @return the ranker query, or empty when {@code q} is neither a ranker query nor a boost
     *         query directly wrapping one
     */
    private Optional<RankerQuery> inspectQuery(Query q) {
        if (q instanceof RankerQuery) {
            return Optional.of((RankerQuery) q);
        }
        if (q instanceof BoostQuery) {
            Query wrapped = ((BoostQuery) q).getQuery();
            if (wrapped instanceof RankerQuery) {
                return Optional.of((RankerQuery) wrapped);
            }
        }
        return Optional.empty();
    }

    /**
     * Creates the consumer for a log spec and the logging variant of the ranker query bound to it.
     *
     * @param logSpec the log spec providing the logger name and the missing-as-zero flag
     * @param query the ranker query to derive the logging query from
     * @return the logging query paired with the consumer it reports to
     */
    private Tuple<RankerQuery, HitLogConsumer> toLogger(LoggingSearchExtBuilder.LogSpec logSpec, RankerQuery query) {
        HitLogConsumer consumer = new HitLogConsumer(logSpec.getLoggerName(), query.featureSet(), logSpec.isMissingAsZero());
        RankerQuery loggingQuery = query.toLoggerQuery(consumer);
        return new Tuple<>(loggingQuery, consumer);
    }

    /**
     * Per-request processor that scores the combined logging query against each fetched hit.
     */
    static class LoggingFetchSubPhaseProcessor implements FetchSubPhaseProcessor {
        private final Weight weight;
        private final List<HitLogConsumer> loggers;
        /** Scorer for the current segment; {@code null} when the weight yields none for it. */
        private Scorer scorer;

        /**
         * @param weight weight of the combined logging query
         * @param loggers consumers to point at each hit before it is scored
         */
        LoggingFetchSubPhaseProcessor(Weight weight, List<HitLogConsumer> loggers) {
            this.weight = weight;
            this.loggers = loggers;
        }

        /**
         * Switches to a new segment by obtaining a fresh scorer for it.
         *
         * @param readerContext the segment about to be processed
         * @throws IOException if the scorer cannot be created
         */
        @Override
        public void setNextReader(LeafReaderContext readerContext) throws IOException {
            this.scorer = this.weight.scorer(readerContext);
        }

        /**
         * Logs feature values for one hit when the logging query matches it.
         *
         * @param hitContext the hit being fetched
         * @throws IOException if advancing or scoring fails
         */
        @Override
        public void process(FetchSubPhase.HitContext hitContext) throws IOException {
            if (this.scorer == null) {
                return;
            }
            int docId = hitContext.docId();
            // NOTE(review): advance() only moves forward, so this presumably relies on hits
            // being processed in increasing doc id order within a segment - confirm.
            if (this.scorer.iterator().advance(docId) == docId) {
                // Consumers must point at this hit before scoring so values land on it.
                for (HitLogConsumer logger : this.loggers) {
                    logger.nextDoc(hitContext.hit());
                }
                // The score itself is discarded; scoring is presumably what drives the
                // consumers' accept() callbacks.
                this.scorer.score();
            }
        }
    }

    /**
     * Log consumer that stores feature values on the current search hit.
     *
     * <p>Values are stored in the {@code _ltrlog} document field, whose single value is a map
     * from logger name to a list of {@code {name, value}} entries, one per feature of the set.
     */
    static class HitLogConsumer implements LogLtrRanker.LogConsumer {
        private static final String FIELD_NAME = "_ltrlog";
        private static final String EXTRA_LOGGING_NAME = "extra_logging";
        private final String name;
        private final FeatureSet set;
        private final boolean missingAsZero;
        /** Log entries of the current hit, indexed by feature ordinal. */
        private List<Map<String, Object>> currentLog;
        private SearchHit currentHit;
        /** Lazily created extra-logging map of the current hit; reset for every hit. */
        private Map<String, Object> extraLogging;

        /**
         * @param name key under which this consumer's log is stored on the hit
         * @param set feature set whose features are logged
         * @param missingAsZero whether features that report no value are logged with value 0
         */
        HitLogConsumer(String name, FeatureSet set, boolean missingAsZero) {
            this.name = name;
            this.set = set;
            this.missingAsZero = missingAsZero;
        }

        /**
         * Starts a fresh log with one entry per feature and clears the extra-logging map.
         */
        private void rebuild() {
            int featureCount = this.set.size();
            // One spare slot for the optional extra-logging entry.
            List<Map<String, Object>> ini = new ArrayList<>(featureCount + 1);
            for (int i = 0; i < featureCount; ++i) {
                Map<String, Object> defaultKeyVal = new HashMap<>();
                defaultKeyVal.put("name", this.set.feature(i).name());
                if (this.missingAsZero) {
                    defaultKeyVal.put("value", 0.0F);
                }
                ini.add(defaultKeyVal);
            }
            this.currentLog = ini;
            this.extraLogging = null;
        }

        /**
         * Records the value of one feature for the current hit.
         *
         * @param featureOrdinal ordinal of the feature within the feature set
         * @param score value to log for that feature
         */
        @Override
        public void accept(int featureOrdinal, float score) {
            assert this.currentLog != null;
            assert this.currentHit != null;
            this.currentLog.get(featureOrdinal).put("value", score);
        }

        /**
         * Returns the extra-logging map of the current hit, creating it and appending its
         * {@code extra_logging} entry to the current log on first use.
         *
         * @return the mutable extra-logging map of the current hit
         */
        @Override
        public Map<String, Object> getExtraLoggingMap() {
            if (this.extraLogging == null) {
                this.extraLogging = new HashMap<>();
                Map<String, Object> logEntry = new HashMap<>();
                logEntry.put("name", EXTRA_LOGGING_NAME);
                logEntry.put("value", this.extraLogging);
                this.currentLog.add(logEntry);
            }
            return this.extraLogging;
        }

        /**
         * Points this consumer at a new hit and registers a fresh log under this consumer's name.
         *
         * @param hit the hit that subsequent values belong to
         */
        void nextDoc(SearchHit hit) {
            DocumentField logs = hit.getFields().get(FIELD_NAME);
            if (logs == null) {
                // First consumer to reach this hit creates the shared log field.
                logs = this.newLogField();
                hit.setDocumentField(FIELD_NAME, logs);
            }
            Map<String, List<Map<String, Object>>> entries = logs.getValue();
            this.rebuild();
            this.currentHit = hit;
            entries.put(this.name, this.currentLog);
        }

        /**
         * Creates the {@code _ltrlog} field holding a single, initially empty, mutable map.
         *
         * @return the new document field
         */
        DocumentField newLogField() {
            List<Object> logList = Collections.singletonList(new HashMap<String, List<Map<String, Object>>>());
            return new DocumentField(FIELD_NAME, logList);
        }
    }
}

