/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.searchrelevance.ml;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.searchrelevance.common.MLConstants;
import org.opensearch.searchrelevance.ml.ChunkResult;
import org.opensearch.searchrelevance.ml.TokenizerUtil;

public class MLAccessor {
    private MachineLearningNodeClient mlClient;
    private static final Logger LOGGER = LogManager.getLogger(MLAccessor.class);
    private static final int MAX_RETRY_NUMBER = 3;
    private static final long RETRY_DELAY_MS = 1000L;

    public MLAccessor(MachineLearningNodeClient mlClient) {
        this.mlClient = mlClient;
    }

    public void predict(String modelId, int tokenLimit, String searchText, String reference, Map<String, String> hits, final boolean ignoreFailure, final ActionListener<ChunkResult> progressListener) {
        final List<MLInput> mlInputs = this.getMLInputs(tokenLimit, searchText, reference, hits);
        LOGGER.info("Number of chunks: {}", (Object)mlInputs.size());
        final ConcurrentHashMap succeededChunks = new ConcurrentHashMap();
        final ConcurrentHashMap failedChunks = new ConcurrentHashMap();
        final AtomicInteger processedChunks = new AtomicInteger(0);
        int i = 0;
        while (i < mlInputs.size()) {
            final int chunkIndex = i++;
            this.predictSingleChunkWithRetry(modelId, mlInputs.get(chunkIndex), chunkIndex, 0, new ActionListener<String>(){

                public void onResponse(String response) {
                    LOGGER.info("Chunk {} processed successfully", (Object)chunkIndex);
                    String processedResponse = response.substring(1, response.length() - 1);
                    MLAccessor.this.handleChunkCompletion(chunkIndex, processedResponse, null, mlInputs.size(), succeededChunks, failedChunks, ignoreFailure, processedChunks, (ActionListener<ChunkResult>)progressListener);
                }

                public void onFailure(Exception e) {
                    LOGGER.error("Chunk {} failed after all retries", (Object)chunkIndex, (Object)e);
                    MLAccessor.this.handleChunkCompletion(chunkIndex, null, e, mlInputs.size(), succeededChunks, failedChunks, ignoreFailure, processedChunks, (ActionListener<ChunkResult>)progressListener);
                }
            });
        }
    }

    private void predictSingleChunkWithRetry(final String modelId, final MLInput mlInput, final int chunkIndex, final int retryCount, final ActionListener<String> chunkListener) {
        this.predictSingleChunk(modelId, mlInput, new ActionListener<String>(){

            public void onResponse(String response) {
                chunkListener.onResponse((Object)response);
            }

            public void onFailure(Exception e) {
                if (retryCount < 3) {
                    LOGGER.warn("Chunk {} failed, attempt {}/{}. Retrying...", (Object)chunkIndex, (Object)(retryCount + 1), (Object)3);
                    long delay = 1000L * (long)Math.pow(2.0, retryCount);
                    MLAccessor.this.scheduleRetry(() -> MLAccessor.this.predictSingleChunkWithRetry(modelId, mlInput, chunkIndex, retryCount + 1, (ActionListener<String>)chunkListener), delay);
                } else {
                    chunkListener.onFailure(e);
                }
            }
        });
    }

    private void scheduleRetry(Runnable runnable, long delayMs) {
        CompletableFuture.delayedExecutor(delayMs, TimeUnit.MILLISECONDS).execute(runnable);
    }

    public void predictSingleChunk(String modelId, MLInput mlInput, ActionListener<String> listener) {
        this.mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> listener.onResponse((Object)this.extractResponseContent((MLOutput)mlOutput)), arg_0 -> listener.onFailure(arg_0)));
    }

    private List<MLInput> getMLInputs(int tokenLimit, String searchText, String reference, Map<String, String> hits) {
        ArrayList<MLInput> mlInputs = new ArrayList<MLInput>();
        HashMap<String, String> currentChunk = new HashMap<String, String>();
        for (Map.Entry<String, String> entry : hits.entrySet()) {
            HashMap<String, String> tempChunk = new HashMap<String, String>(currentChunk);
            tempChunk.put(entry.getKey(), entry.getValue());
            String messages = this.formatMessages(searchText, reference, tempChunk);
            int totalTokens = TokenizerUtil.countTokens(messages);
            if (totalTokens > tokenLimit) {
                if (currentChunk.isEmpty()) {
                    LOGGER.warn("Entry with key {} causes total tokens to exceed limit of {}", (Object)entry.getKey(), (Object)tokenLimit);
                    HashMap<String, String> singleEntryChunk = new HashMap<String, String>();
                    HashMap<String, String> testChunk = new HashMap<String, String>();
                    testChunk.put(entry.getKey(), entry.getValue());
                    String testMessages = this.formatMessages(searchText, reference, testChunk);
                    int excessTokens = TokenizerUtil.countTokens(testMessages) - tokenLimit;
                    int currentTokens = TokenizerUtil.countTokens(entry.getValue());
                    String truncatedValue = TokenizerUtil.truncateString(entry.getValue(), Math.max(1, currentTokens - excessTokens));
                    singleEntryChunk.put(entry.getKey(), truncatedValue);
                    mlInputs.add(this.createMLInput(searchText, reference, singleEntryChunk));
                    continue;
                }
                mlInputs.add(this.createMLInput(searchText, reference, currentChunk));
                currentChunk = new HashMap();
                currentChunk.put(entry.getKey(), entry.getValue());
                continue;
            }
            currentChunk.put(entry.getKey(), entry.getValue());
        }
        if (!currentChunk.isEmpty()) {
            mlInputs.add(this.createMLInput(searchText, reference, currentChunk));
        }
        return mlInputs;
    }

    private String formatMessages(String searchText, String reference, Map<String, String> hits) {
        try {
            String hitsJson;
            try (XContentBuilder builder = XContentFactory.jsonBuilder();){
                builder.startArray();
                for (Map.Entry<String, String> hit : hits.entrySet()) {
                    builder.startObject();
                    builder.field("id", hit.getKey());
                    builder.field("source", hit.getValue());
                    builder.endObject();
                }
                builder.endArray();
                hitsJson = builder.toString();
            }
            String userContent = Objects.isNull(reference) || reference.isEmpty() ? String.format(Locale.ROOT, "SearchText - %s; Hits - %s", searchText, hitsJson) : String.format(Locale.ROOT, "SearchText: %s; Reference: %s; Hits: %s", searchText, reference, hitsJson);
            return String.format(Locale.ROOT, "[{\"role\":\"system\",\"content\":\"%s\"},{\"role\":\"user\",\"content\":\"%s\"}]", MLConstants.PROMPT_SEARCH_RELEVANCE, MLConstants.escapeJson(userContent));
        }
        catch (IOException e) {
            LOGGER.error("Error converting hits to JSON string", (Throwable)e);
            throw new IllegalArgumentException("Failed to process hits", e);
        }
    }

    private MLInput createMLInput(String searchText, String reference, Map<String, String> hits) {
        HashMap<String, String> parameters = new HashMap<String, String>();
        parameters.put("messages", this.formatMessages(searchText, reference, hits));
        return MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset((MLInputDataset)new RemoteInferenceInputDataSet(parameters)).build();
    }

    private String extractResponseContent(MLOutput mlOutput) {
        if (!(mlOutput instanceof ModelTensorOutput)) {
            throw new IllegalArgumentException("Expected ModelTensorOutput, but got " + mlOutput.getClass().getSimpleName());
        }
        ModelTensorOutput modelTensorOutput = (ModelTensorOutput)mlOutput;
        List tensorOutputList = modelTensorOutput.getMlModelOutputs();
        if (CollectionUtils.isEmpty((Collection)tensorOutputList) || CollectionUtils.isEmpty((Collection)((ModelTensors)tensorOutputList.get(0)).getMlModelTensors())) {
            throw new IllegalStateException("Empty model result produced. Expected at least [1] tensor output and [1] model tensor, but got [0]");
        }
        ModelTensor tensor = (ModelTensor)((ModelTensors)tensorOutputList.get(0)).getMlModelTensors().get(0);
        Map dataMap = tensor.getDataAsMap();
        Map choices = (Map)((List)dataMap.get("choices")).get(0);
        Map message = (Map)choices.get("message");
        String content = (String)message.get("content");
        return content;
    }

    private void handleChunkCompletion(int chunkIndex, String response, Exception error, int totalChunks, ConcurrentMap<Integer, String> succeededChunks, ConcurrentMap<Integer, String> failedChunks, boolean ignoreFailure, AtomicInteger processedChunks, ActionListener<ChunkResult> progressListener) {
        block6: {
            try {
                if (error != null) {
                    String errorMessage = error.getMessage();
                    failedChunks.put(chunkIndex, errorMessage);
                    if (!ignoreFailure) {
                        progressListener.onFailure(error);
                        return;
                    }
                } else {
                    succeededChunks.put(chunkIndex, response);
                }
                int processed = processedChunks.incrementAndGet();
                boolean isLastChunk = processed == totalChunks;
                ChunkResult result = new ChunkResult(chunkIndex, totalChunks, isLastChunk, new HashMap<Integer, String>(succeededChunks), new HashMap<Integer, String>(failedChunks));
                progressListener.onResponse((Object)result);
                if (isLastChunk) {
                    this.handleFinalStatus(result, ignoreFailure, progressListener);
                }
            }
            catch (Exception e) {
                LOGGER.error("Error handling chunk completion for chunk {}", (Object)chunkIndex, (Object)e);
                if (ignoreFailure) break block6;
                progressListener.onFailure(e);
            }
        }
    }

    private void handleFinalStatus(ChunkResult finalResult, boolean ignoreFailure, ActionListener<ChunkResult> progressListener) {
        if (finalResult.getFailedChunksCount() > 0 && !ignoreFailure) {
            String errorMessage = String.format(Locale.ROOT, "Failed to process %d out of %d chunks", finalResult.getFailedChunksCount(), finalResult.getTotalChunks());
            progressListener.onFailure((Exception)new RuntimeException(errorMessage));
        } else {
            progressListener.onResponse((Object)finalResult);
        }
    }
}

