/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.searchrelevance.judgments;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.StepListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.inject.Inject;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.search.SearchHit;
import org.opensearch.searchrelevance.common.MLConstants;
import org.opensearch.searchrelevance.dao.JudgmentCacheDao;
import org.opensearch.searchrelevance.dao.QuerySetDao;
import org.opensearch.searchrelevance.dao.SearchConfigurationDao;
import org.opensearch.searchrelevance.exception.SearchRelevanceException;
import org.opensearch.searchrelevance.judgments.BaseJudgmentsProcessor;
import org.opensearch.searchrelevance.ml.ChunkResult;
import org.opensearch.searchrelevance.ml.MLAccessor;
import org.opensearch.searchrelevance.model.JudgmentCache;
import org.opensearch.searchrelevance.model.JudgmentType;
import org.opensearch.searchrelevance.model.QuerySet;
import org.opensearch.searchrelevance.model.SearchConfiguration;
import org.opensearch.searchrelevance.model.builder.SearchRequestBuilder;
import org.opensearch.searchrelevance.stats.events.EventStatName;
import org.opensearch.searchrelevance.stats.events.EventStatsManager;
import org.opensearch.searchrelevance.utils.ParserUtils;
import org.opensearch.searchrelevance.utils.TimeUtils;
import org.opensearch.transport.client.Client;

/**
 * Generates relevance judgments for the queries of a query set by asking an LLM, through
 * {@link MLAccessor}, to rate the documents returned by one or more search configurations.
 *
 * <p>For every query the processor runs each search configuration, unions the hits by document
 * id, reuses ratings already stored in the judgment cache, sends the remaining documents to the
 * model, and writes the model's new ratings back to the judgment cache.
 *
 * <p>A query-set entry may carry a reference answer after the first {@code '#'}
 * ({@code "query text#reference answer"}). Only the part before the delimiter is searched for;
 * the full entry is what gets passed to the judgment cache and reported as the result's
 * {@code "query"}.
 *
 * <p>All work is synchronous: each search, cache lookup and model call is awaited with
 * {@code actionGet()} on the calling thread.
 */
public class LlmJudgmentsProcessor
implements BaseJudgmentsProcessor {
    private static final Logger LOGGER = LogManager.getLogger(LlmJudgmentsProcessor.class);
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
    /** Separates the searchable query text from an optional reference answer in a query-set entry. */
    private static final String QUERY_REFERENCE_DELIMITER = "#";
    /** Shape of one parsed chunk of model output: a JSON array of {@code {"id", "rating_score"}} objects. */
    private static final TypeReference<List<Map<String, Object>>> CHUNK_RATINGS_TYPE = new TypeReference<List<Map<String, Object>>>() {
    };
    private final MLAccessor mlAccessor;
    private final QuerySetDao querySetDao;
    private final SearchConfigurationDao searchConfigurationDao;
    private final JudgmentCacheDao judgmentCacheDao;
    private final Client client;

    @Inject
    public LlmJudgmentsProcessor(MLAccessor mlAccessor, QuerySetDao querySetDao, SearchConfigurationDao searchConfigurationDao, JudgmentCacheDao judgmentCacheDao, Client client) {
        this.mlAccessor = mlAccessor;
        this.querySetDao = querySetDao;
        this.searchConfigurationDao = searchConfigurationDao;
        this.judgmentCacheDao = judgmentCacheDao;
        this.client = client;
    }

    @Override
    public JudgmentType getJudgmentType() {
        return JudgmentType.LLM_JUDGMENT;
    }

    /**
     * Generates LLM judgments for the query set and search configurations described by {@code metadata}.
     *
     * <p>Reads the keys {@code querySetId}, {@code searchConfigurationList}, {@code size},
     * {@code modelId}, {@code tokenLimit}, {@code contextFields} and {@code ignoreFailure}.
     *
     * @param metadata request parameters; a missing or mistyped entry fails the listener
     * @param listener completed with one map per query, each holding {@code "query"} and
     *     {@code "ratings"} (a list of {@code {"docId", "rating"}} maps). It fails with a
     *     {@link SearchRelevanceException} on any error that is not ignored.
     */
    @Override
    public void generateJudgmentRating(Map<String, Object> metadata, ActionListener<List<Map<String, Object>>> listener) {
        try {
            EventStatsManager.increment(EventStatName.LLM_JUDGMENT_RATING_GENERATIONS);
            String querySetId = (String) metadata.get("querySetId");
            @SuppressWarnings("unchecked") // metadata values are untyped; the list holds configuration ids
            List<String> searchConfigurationList = (List<String>) metadata.get("searchConfigurationList");
            int size = (Integer) metadata.get("size");
            String modelId = (String) metadata.get("modelId");
            int tokenLimit = (Integer) metadata.get("tokenLimit");
            @SuppressWarnings("unchecked") // metadata values are untyped; the list holds source field names
            List<String> contextFields = (List<String>) metadata.get("contextFields");
            boolean ignoreFailure = (Boolean) metadata.get("ignoreFailure");

            QuerySet querySet = this.querySetDao.getQuerySetSync(querySetId);
            List<SearchConfiguration> searchConfigurations = searchConfigurationList.stream()
                .map(id -> this.searchConfigurationDao.getSearchConfigurationSync(id))
                .collect(Collectors.toList());
            List<Map<String, Object>> judgments = this.generateLLMJudgments(modelId, size, tokenLimit, contextFields, querySet, searchConfigurations, ignoreFailure);
            listener.onResponse(judgments);
        } catch (Exception e) {
            LOGGER.error("Failed to generate LLM judgments", e);
            listener.onFailure(new SearchRelevanceException("Failed to generate LLM judgments", e, RestStatus.INTERNAL_SERVER_ERROR));
        }
    }

    /**
     * Judges every query of {@code querySet} in turn.
     *
     * <p>When {@code ignoreFailure} is true, a query that fails is logged and left out of the
     * result; otherwise the first failure aborts the whole run.
     */
    private List<Map<String, Object>> generateLLMJudgments(String modelId, int size, int tokenLimit, List<String> contextFields, QuerySet querySet, List<SearchConfiguration> searchConfigurations, boolean ignoreFailure) {
        List<String> queryTextWithReferences = querySet.querySetQueries().stream()
            .map(entry -> entry.queryText())
            .collect(Collectors.toList());
        List<Map<String, Object>> allJudgments = new ArrayList<>();
        for (String queryTextWithReference : queryTextWithReferences) {
            try {
                Map<String, String> docIdToScore = this.processQueryText(modelId, size, tokenLimit, contextFields, searchConfigurations, queryTextWithReference, ignoreFailure);
                List<Map<String, String>> docIdRatings = docIdToScore.entrySet().stream()
                    .map(entry -> Map.of("docId", entry.getKey(), "rating", entry.getValue()))
                    .collect(Collectors.toList());
                Map<String, Object> judgmentForQuery = new HashMap<>();
                judgmentForQuery.put("query", queryTextWithReference);
                judgmentForQuery.put("ratings", docIdRatings);
                allJudgments.add(judgmentForQuery);
                LOGGER.debug("Processed query: {} with {} ratings", queryTextWithReference, docIdRatings.size());
            } catch (Exception e) {
                LOGGER.error("Failed to process query: {}", queryTextWithReference, e);
                if (!ignoreFailure) {
                    throw new SearchRelevanceException("Failed to generate LLM judgments", e, RestStatus.INTERNAL_SERVER_ERROR);
                }
            }
        }
        LOGGER.info("Completed processing {} queries", queryTextWithReferences.size());
        return allJudgments;
    }

    /**
     * Produces docId -> rating for a single query-set entry.
     *
     * <p>Cached ratings are reused; only uncached documents are sent to the model. When
     * {@code ignoreFailure} is true, processing errors are logged and whatever ratings were
     * gathered so far are returned.
     */
    private Map<String, String> processQueryText(String modelId, int size, int tokenLimit, List<String> contextFields, List<SearchConfiguration> searchConfigurations, String queryTextWithReference, boolean ignoreFailure) {
        String queryText = queryTextWithReference.split(QUERY_REFERENCE_DELIMITER, 2)[0];
        ConcurrentMap<String, String> docIdToScore = new ConcurrentHashMap<>();
        Map<String, SearchHit> allHits = this.collectHits(searchConfigurations, queryText, size, ignoreFailure);
        if (allHits.isEmpty()) {
            // Nothing to judge; this also covers an empty configuration list, where get(0) below would fail.
            LOGGER.info("No search hits to judge for query: {}", queryText);
            return docIdToScore;
        }
        try {
            // NOTE(review): every hit is keyed against the FIRST configuration's index, even when
            // configurations search different indices -- confirm this is intended.
            String index = searchConfigurations.get(0).index();
            List<String> docIds = new ArrayList<>(allHits.keySet());
            List<String> unprocessedDocIds = this.deduplicateFromProcessedDocs(index, queryTextWithReference, docIds, contextFields, docIdToScore);
            LOGGER.debug("Cached docIds: {}", docIdToScore.keySet());
            LOGGER.debug("Unprocessed docIds: {}", unprocessedDocIds);

            // compositeKey (index + docId) -> document context sent to the model
            Map<String, String> unionHits = new HashMap<>();
            for (String docId : unprocessedDocIds) {
                String compositeKey = ParserUtils.combinedIndexAndDocId(index, docId);
                unionHits.put(compositeKey, this.getContextSource(allHits.get(docId), contextFields));
            }
            LOGGER.info("UnionHits size: {}", unionHits.size());
            if (!unionHits.isEmpty()) {
                LOGGER.info("Processing {} uncached docs with LLM for query: {}", unionHits.size(), queryText);
                PlainActionFuture<Map<String, String>> llmFuture = PlainActionFuture.newFuture();
                this.generateLLMJudgmentForQueryText(modelId, queryTextWithReference, tokenLimit, contextFields, unionHits, docIdToScore, ignoreFailure, llmFuture);
                Map<String, String> llmRatings = llmFuture.actionGet();
                LOGGER.debug("LLM returned ratings: {}", llmRatings);
                docIdToScore.putAll(llmRatings);
            }
        } catch (Exception e) {
            LOGGER.error("Failed to process hits for query: {}", queryText, e);
            if (!ignoreFailure) {
                throw new SearchRelevanceException("Failed to process hits", e, RestStatus.INTERNAL_SERVER_ERROR);
            }
        }
        LOGGER.debug("Final docIdToScore size: {}, contents: {}", docIdToScore.size(), docIdToScore);
        return docIdToScore;
    }

    /**
     * Runs every search configuration for {@code queryText} and unions the hits by document id.
     * A later configuration's hit replaces an earlier one with the same id.
     */
    private Map<String, SearchHit> collectHits(List<SearchConfiguration> searchConfigurations, String queryText, int size, boolean ignoreFailure) {
        Map<String, SearchHit> allHits = new HashMap<>();
        for (SearchConfiguration searchConfiguration : searchConfigurations) {
            String index = searchConfiguration.index();
            String query = searchConfiguration.query();
            String searchPipeline = searchConfiguration.searchPipeline();
            try {
                SearchRequest searchRequest = SearchRequestBuilder.buildSearchRequest(index, query, queryText, searchPipeline, size);
                SearchResponse response = this.client.search(searchRequest).actionGet();
                for (SearchHit hit : response.getHits().getHits()) {
                    allHits.put(hit.getId(), hit);
                }
            } catch (Exception e) {
                LOGGER.error("Search failed for index: {}", index, e);
                if (!ignoreFailure) {
                    throw new SearchRelevanceException("Search failed", e, RestStatus.INTERNAL_SERVER_ERROR);
                }
            }
        }
        return allHits;
    }

    /**
     * Asks the model to rate {@code unprocessedUnionHits} for one query-set entry.
     *
     * <p>The listener receives {@code docIdToRating} merged with the new model ratings.
     * The model may answer in several chunks. Ratings are aggregated, and written to the
     * judgment cache, only when the last chunk arrives.
     *
     * <p>When {@code ignoreFailure} is false, the first failed or unparseable chunk fails the
     * listener. When it is true, such chunks are logged and skipped.
     */
    private void generateLLMJudgmentForQueryText(final String modelId, final String queryTextWithReference, int tokenLimit, final List<String> contextFields, Map<String, String> unprocessedUnionHits, Map<String, String> docIdToRating, final boolean ignoreFailure, final ActionListener<Map<String, String>> listener) {
        LOGGER.debug("calculating LLM evaluation with modelId: {} and unprocessed unionHits: {}", modelId, unprocessedUnionHits);
        LOGGER.debug("processed docIdToRating before llm evaluation: {}", docIdToRating);
        if (unprocessedUnionHits.isEmpty()) {
            LOGGER.info("All hits found in cache, returning cached results for query: {}", queryTextWithReference);
            listener.onResponse(docIdToRating);
            return;
        }
        // Split with a limit so parts[0] always exists, even for an entry made only of delimiters.
        String[] parts = queryTextWithReference.split(QUERY_REFERENCE_DELIMITER, 2);
        String queryText = parts[0];
        // split(...) without a limit drops trailing empty strings. So a reference that is empty
        // or made only of delimiters (e.g. "query#" or "query##") counts as absent.
        String referenceAnswer = queryTextWithReference.split(QUERY_REFERENCE_DELIMITER).length > 1 ? parts[1] : null;

        final ConcurrentHashMap<String, String> processedRatings = new ConcurrentHashMap<>(docIdToRating);
        // chunk index -> parsed ratings of that chunk; a chunk already parsed is not parsed again
        final ConcurrentMap<Integer, List<Map<String, Object>>> combinedResponses = new ConcurrentHashMap<>();
        // set once the listener has been failed, so it is not failed twice or answered afterwards
        final AtomicBoolean hasFailure = new AtomicBoolean(false);
        this.mlAccessor.predict(modelId, tokenLimit, queryText, referenceAnswer, unprocessedUnionHits, ignoreFailure, new ActionListener<ChunkResult>() {

            @Override
            public void onResponse(ChunkResult chunkResult) {
                try {
                    if (LlmJudgmentsProcessor.this.shouldFailImmediately(ignoreFailure, chunkResult)) {
                        String firstError = chunkResult.getFailedChunks().values().iterator().next();
                        this.handleProcessingError(new Exception(firstError), true);
                        return;
                    }
                    for (Map.Entry<Integer, String> entry : chunkResult.getSucceededChunks().entrySet()) {
                        Integer chunkIndex = entry.getKey();
                        if (combinedResponses.containsKey(chunkIndex)) {
                            continue;
                        }
                        try {
                            combinedResponses.put(chunkIndex, parseChunkRatings(entry.getValue()));
                        } catch (Exception parseError) {
                            if (!ignoreFailure) {
                                this.handleProcessingError(parseError, true);
                                return;
                            }
                            // One malformed chunk must not discard the chunks that did parse.
                            LOGGER.warn("Skipping unparseable response for chunk {}, continuing due to ignoreFailure=true", chunkIndex, parseError);
                        }
                    }
                    LlmJudgmentsProcessor.this.logFailedChunks(ignoreFailure, chunkResult);
                    if (chunkResult.isLastChunk() && !hasFailure.get()) {
                        LOGGER.info("Processing final results for query: {}. Successful chunks: {}, Failed chunks: {}", queryTextWithReference, chunkResult.getSuccessfulChunksCount(), chunkResult.getFailedChunksCount());
                        for (List<Map<String, Object>> ratings : combinedResponses.values()) {
                            for (Map<String, Object> rating : ratings) {
                                String compositeKey = (String) rating.get("id");
                                String ratingScore = String.valueOf(((Number) rating.get("rating_score")).doubleValue());
                                String docId = ParserUtils.getDocIdFromCompositeKey(compositeKey);
                                processedRatings.put(docId, ratingScore);
                                LlmJudgmentsProcessor.this.updateJudgmentCache(compositeKey, queryTextWithReference, contextFields, ratingScore, modelId);
                            }
                        }
                        listener.onResponse(processedRatings);
                    }
                } catch (Exception e) {
                    this.handleProcessingError(e, chunkResult.isLastChunk());
                }
            }

            @Override
            public void onFailure(Exception e) {
                this.handleProcessingError(e, true);
            }

            /**
             * Fails the listener, at most once, unless the error may be ignored.
             * An error on the last chunk always fails the listener, so the caller waiting on it
             * is always completed.
             */
            private void handleProcessingError(Exception e, boolean isLastChunk) {
                if (!ignoreFailure || isLastChunk) {
                    if (!hasFailure.getAndSet(true)) {
                        LOGGER.error("Failed to process chunk response", e);
                        listener.onFailure(new SearchRelevanceException("Failed to process chunk response", e, RestStatus.INTERNAL_SERVER_ERROR));
                    }
                } else {
                    LOGGER.warn("Error processing chunk, continuing due to ignoreFailure=true", e);
                }
            }
        });
    }

    /** Sanitizes one chunk of raw model output and parses it as a JSON array of rating objects. */
    private static List<Map<String, Object>> parseChunkRatings(String rawResponse) throws JsonProcessingException {
        LOGGER.debug("response before sanitization: {}", rawResponse);
        String sanitizedResponse = MLConstants.sanitizeLLMResponse(rawResponse);
        LOGGER.debug("response after sanitization: {}", sanitizedResponse);
        return OBJECT_MAPPER.readValue(sanitizedResponse, CHUNK_RATINGS_TYPE);
    }

    /**
     * Looks up each doc id in the judgment cache and copies any cached rating into {@code docIdToRating}.
     *
     * @return the doc ids without a usable cached rating, in the order of {@code docIds}.
     *     A failed cache lookup is logged and the doc is treated as uncached.
     */
    private List<String> deduplicateFromProcessedDocs(String targetIndex, String queryTextWithReference, List<String> docIds, List<String> contextFields, ConcurrentMap<String, String> docIdToRating) {
        List<String> unprocessedDocIds = new ArrayList<>(docIds.size());
        for (String docId : docIds) {
            String compositeKey = ParserUtils.combinedIndexAndDocId(targetIndex, docId);
            boolean cached = false;
            try {
                PlainActionFuture<SearchResponse> future = PlainActionFuture.newFuture();
                this.judgmentCacheDao.getJudgmentCache(queryTextWithReference, compositeKey, contextFields, future);
                // Check the returned hits rather than total hits, which may be untracked (null).
                SearchHit[] cacheHits = future.actionGet().getHits().getHits();
                if (cacheHits.length > 0) {
                    Map<String, Object> source = cacheHits[0].getSourceAsMap();
                    String rating = (String) source.get("rating");
                    String storedContextFields = (String) source.get("contextFieldsStr");
                    if (rating == null) {
                        LOGGER.warn("Cached judgment for docId: {} has no rating, re-judging it", docId);
                    } else {
                        LOGGER.debug("Found existing judgment for docId: {}, rating: {}, storedContextFields: {}", docId, rating, storedContextFields);
                        docIdToRating.put(docId, rating);
                        cached = true;
                    }
                }
            } catch (Exception e) {
                LOGGER.error("Failed to check judgment cache for queryTextWithReference: {} and docId: {}", queryTextWithReference, docId, e);
            }
            if (!cached) {
                unprocessedDocIds.add(docId);
            }
        }
        return unprocessedDocIds;
    }

    /**
     * Asynchronously creates the judgment cache index if needed and upserts one rating.
     * Failures are only logged: the cache is best-effort.
     */
    private void updateJudgmentCache(String compositeKey, String queryText, List<String> contextFields, String rating, String modelId) {
        JudgmentCache judgmentCache = new JudgmentCache(ParserUtils.generateUniqueId(queryText, compositeKey, contextFields), TimeUtils.getTimestamp(), queryText, compositeKey, contextFields, rating, modelId);
        StepListener<Void> createIndexStep = new StepListener<>();
        this.judgmentCacheDao.createIndexIfAbsent(createIndexStep);
        createIndexStep.whenComplete(
            ignored -> this.judgmentCacheDao.upsertJudgmentCache(
                judgmentCache,
                ActionListener.wrap(
                    response -> LOGGER.debug("Successfully processed judgment cache for queryText: {} and compositeKey: {}, contextFields: {}", queryText, compositeKey, contextFields),
                    e -> LOGGER.error("Failed to process judgment cache for queryText: {} and compositeKey: {}, contextFields: {}", queryText, compositeKey, contextFields, e)
                )
            ),
            e -> LOGGER.error("Failed to create judgment cache index for queryText: {} and compositeKey: {}, contextFields: {}", queryText, compositeKey, contextFields, e)
        );
    }

    /** True when failures are not ignored and the model reported at least one failed chunk. */
    private boolean shouldFailImmediately(boolean ignoreFailure, ChunkResult chunkResult) {
        return !ignoreFailure && !chunkResult.getFailedChunks().isEmpty();
    }

    /** Logs each failed chunk as a warning when failures are being ignored. */
    private void logFailedChunks(boolean ignoreFailure, ChunkResult chunkResult) {
        if (ignoreFailure) {
            chunkResult.getFailedChunks().forEach((index, error) -> LOGGER.warn("Chunk {} failed: {}", index, error));
        }
    }

    /**
     * Returns the document text given to the model.
     *
     * <p>Without context fields, this is the full {@code _source}. With context fields, it is a
     * JSON object containing only those fields that are present in the source.
     */
    private String getContextSource(SearchHit hit, List<String> contextFields) {
        if (contextFields == null || contextFields.isEmpty()) {
            return hit.getSourceAsString();
        }
        Map<String, Object> sourceAsMap = hit.getSourceAsMap();
        Map<String, Object> filteredSource = new HashMap<>();
        if (sourceAsMap != null) {
            for (String field : contextFields) {
                if (sourceAsMap.containsKey(field)) {
                    filteredSource.put(field, sourceAsMap.get(field));
                }
            }
        }
        try {
            return OBJECT_MAPPER.writeValueAsString(filteredSource);
        } catch (JsonProcessingException e) {
            LOGGER.error("Failed to process context source for hit: {}", hit.getId(), e);
            throw new RuntimeException("Failed to process context source", e);
        }
    }
}

