/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.agent.tools;

import com.google.common.collect.ImmutableMap;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.StringJoiner;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.commons.text.StringSubstitutor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest;
import org.opensearch.agent.tools.utils.ToolHelper;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.spi.tools.WithModelTool;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.transport.client.Client;

@ToolAnnotation(value="CreateAnomalyDetectorTool")
public class CreateAnomalyDetectorTool
implements WithModelTool {
    @Generated
    private static final Logger log = LogManager.getLogger(CreateAnomalyDetectorTool.class);
    public static final String TYPE = "CreateAnomalyDetectorTool";
    private static final String DEFAULT_DESCRIPTION = "This is a tool used to help creating anomaly detector. It takes a required argument which is the name of the index, extract the index mappings and let the LLM to give the suggested aggregation field, aggregation method, category field and the date field which are required to create an anomaly detector.";
    private static final String EXTRACT_INFORMATION_REGEX = "(?s).*\\{category_field=([^|]*)\\|aggregation_field=([^|]*)\\|aggregation_method=([^}]*)}.*";
    private static final Set<String> VALID_FIELD_TYPES = Set.of("keyword", "constant_keyword", "wildcard", "long", "integer", "short", "byte", "double", "float", "half_float", "scaled_float", "unsigned_long", "ip");
    private static final String OUTPUT_KEY_INDEX = "index";
    private static final String OUTPUT_KEY_CATEGORY_FIELD = "categoryField";
    private static final String OUTPUT_KEY_AGGREGATION_FIELD = "aggregationField";
    private static final String OUTPUT_KEY_AGGREGATION_METHOD = "aggregationMethod";
    private static final String OUTPUT_KEY_DATE_FIELDS = "dateFields";
    private static final Map<String, String> DEFAULT_PROMPT_DICT = CreateAnomalyDetectorTool.loadDefaultPromptFromFile();
    private String name = "CreateAnomalyDetectorTool";
    private String description = "This is a tool used to help creating anomaly detector. It takes a required argument which is the name of the index, extract the index mappings and let the LLM to give the suggested aggregation field, aggregation method, category field and the date field which are required to create an anomaly detector.";
    private String version;
    private Client client;
    private String modelId;
    private ModelType modelType;
    private String contextPrompt;
    private Map<String, Object> attributes;

    public CreateAnomalyDetectorTool(Client client, String modelId, String modelType, String contextPrompt) {
        this.client = client;
        this.modelId = modelId;
        if (!ModelType.OPENAI.toString().equalsIgnoreCase(modelType) && !ModelType.CLAUDE.toString().equalsIgnoreCase(modelType)) {
            throw new IllegalArgumentException("Unsupported model_type: " + modelType);
        }
        this.modelType = ModelType.from(modelType);
        this.contextPrompt = contextPrompt.isEmpty() ? DEFAULT_PROMPT_DICT.getOrDefault(this.modelType.toString(), "") : contextPrompt;
    }

    public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
        String tenantId = parameters.get("tenant_id");
        Map<String, String> enrichedParameters = this.enrichParameters(parameters);
        String indexName = enrichedParameters.get(OUTPUT_KEY_INDEX);
        if (Strings.isNullOrEmpty((String)indexName)) {
            throw new IllegalArgumentException("Return this final answer to human directly and do not use other tools: 'Please provide index name'. Please try to directly send this message to human to ask for index name");
        }
        if (indexName.startsWith(".")) {
            throw new IllegalArgumentException("CreateAnomalyDetectionTool doesn't support searching indices starting with '.' since it could be system index, current searching index name: " + indexName);
        }
        GetMappingsRequest getMappingsRequest = (GetMappingsRequest)new GetMappingsRequest().indices(new String[]{indexName});
        this.client.admin().indices().getMappings(getMappingsRequest, ActionListener.wrap(response -> {
            Map mappings = response.getMappings();
            if (mappings.size() == 0) {
                throw new IllegalArgumentException("No mapping found for the index: " + indexName);
            }
            String firstIndexName = (String)mappings.keySet().toArray()[0];
            MappingMetadata mappingMetadata = (MappingMetadata)mappings.get(firstIndexName);
            Map mappingSource = (Map)mappingMetadata.getSourceAsMap().get("properties");
            if (Objects.isNull(mappingSource)) {
                throw new IllegalArgumentException("The index " + indexName + " doesn't have mapping metadata, please add data to it or using another index.");
            }
            HashMap<String, String> fieldsToType = new HashMap<String, String>();
            ToolHelper.extractFieldNamesTypes(mappingSource, fieldsToType, "", true);
            Set<String> dateFields = this.findDateTypeFields(fieldsToType);
            if (dateFields.isEmpty()) {
                throw new IllegalArgumentException("The index " + indexName + " doesn't have date type fields, cannot create an anomaly detector for it.");
            }
            StringJoiner dateFieldsJoiner = new StringJoiner(",");
            dateFields.forEach(dateFieldsJoiner::add);
            Map<String, String> filteredMapping = fieldsToType.entrySet().stream().filter(entry -> VALID_FIELD_TYPES.contains(entry.getValue())).collect(Collectors.toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue));
            String prompt = this.constructPrompt(filteredMapping, firstIndexName);
            RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(Collections.singletonMap("prompt", prompt)).build();
            MLPredictionTaskRequest request = new MLPredictionTaskRequest(this.modelId, MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset((MLInputDataset)inputDataSet).build(), null, tenantId);
            this.client.execute((ActionType)MLPredictionTaskAction.INSTANCE, (ActionRequest)request, ActionListener.wrap(mlTaskResponse -> {
                ModelTensorOutput modelTensorOutput = (ModelTensorOutput)mlTaskResponse.getOutput();
                ModelTensors modelTensors = (ModelTensors)modelTensorOutput.getMlModelOutputs().get(0);
                ModelTensor modelTensor = (ModelTensor)modelTensors.getMlModelTensors().get(0);
                Map dataAsMap = modelTensor.getDataAsMap();
                if (dataAsMap == null) {
                    listener.onFailure((Exception)new IllegalStateException("Remote endpoint fails to inference."));
                    return;
                }
                String finalResponse = (String)dataAsMap.get("response");
                if (Strings.isNullOrEmpty((String)finalResponse)) {
                    listener.onFailure((Exception)new IllegalStateException("Remote endpoint fails to inference, no response found."));
                    return;
                }
                Pattern pattern = Pattern.compile(EXTRACT_INFORMATION_REGEX);
                Matcher matcher = pattern.matcher(finalResponse);
                if (!matcher.matches()) {
                    log.error("The inference result from remote endpoint is not valid because the result: [" + finalResponse + "] cannot match the regex: (?s).*\\{category_field=([^|]*)\\|aggregation_field=([^|]*)\\|aggregation_method=([^}]*)}.*");
                    listener.onFailure((Exception)new IllegalStateException("The inference result from remote endpoint is not valid, cannot extract the key information from the result."));
                    return;
                }
                String categoryField = matcher.group(1).replaceAll("\"", "").strip();
                String aggregationField = matcher.group(2).replaceAll("\"", "").strip();
                String aggregationMethod = matcher.group(3).replaceAll("\"", "").strip();
                ImmutableMap result = ImmutableMap.of((Object)OUTPUT_KEY_INDEX, (Object)firstIndexName, (Object)OUTPUT_KEY_CATEGORY_FIELD, (Object)categoryField, (Object)OUTPUT_KEY_AGGREGATION_FIELD, (Object)aggregationField, (Object)OUTPUT_KEY_AGGREGATION_METHOD, (Object)aggregationMethod, (Object)OUTPUT_KEY_DATE_FIELDS, (Object)dateFieldsJoiner.toString());
                listener.onResponse((Object)AccessController.doPrivileged(() -> CreateAnomalyDetectorTool.lambda$run$1((Map)result)));
            }, e -> {
                log.error("fail to predict model: " + String.valueOf(e));
                listener.onFailure(e);
            }));
        }, e -> {
            log.error("failed to get mapping: " + String.valueOf(e));
            if (e.toString().contains("IndexNotFoundException")) {
                listener.onFailure((Exception)new IllegalArgumentException("Return this final answer to human directly and do not use other tools: 'The index doesn't exist, please provide another index and retry'. Please try to directly send this message to human to ask for index name"));
            } else {
                listener.onFailure(e);
            }
        }));
    }

    private Map<String, String> enrichParameters(Map<String, String> parameters) {
        HashMap<String, String> result;
        block2: {
            result = new HashMap<String, String>(parameters);
            try {
                Map chatParameters = (Map)StringUtils.gson.fromJson(parameters.get("input"), Map.class);
                result.putAll(chatParameters);
            }
            catch (Exception e) {
                String indexName = parameters.getOrDefault("input", "");
                if (indexName.isEmpty()) break block2;
                result.put(OUTPUT_KEY_INDEX, indexName);
            }
        }
        return result;
    }

    private Set<String> findDateTypeFields(Map<String, String> fieldsToType) {
        HashSet<String> result = new HashSet<String>();
        for (Map.Entry<String, String> entry : fieldsToType.entrySet()) {
            String value = entry.getValue();
            if (!value.equals("date") && !value.equals("date_nanos")) continue;
            result.add(entry.getKey());
        }
        return result;
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    private static Map<String, String> loadDefaultPromptFromFile() {
        try (InputStream inputStream = CreateAnomalyDetectorTool.class.getResourceAsStream("CreateAnomalyDetectorDefaultPrompt.json");){
            if (inputStream == null) return new HashMap<String, String>();
            Map map = (Map)StringUtils.gson.fromJson(new String(inputStream.readAllBytes(), StandardCharsets.UTF_8), Map.class);
            return map;
        }
        catch (IOException e) {
            log.error("Failed to load prompt from the file CreateAnomalyDetectorDefaultPrompt.json, error: ", (Throwable)e);
        }
        return new HashMap<String, String>();
    }

    private String constructPrompt(Map<String, String> fieldsToType, String indexName) {
        StringJoiner tableInfoJoiner = new StringJoiner("\n");
        for (Map.Entry<String, String> entry : fieldsToType.entrySet()) {
            tableInfoJoiner.add("- " + entry.getKey() + ": " + entry.getValue());
        }
        ImmutableMap indexInfo = ImmutableMap.of((Object)"indexName", (Object)indexName, (Object)"indexMapping", (Object)tableInfoJoiner.toString());
        StringSubstitutor substitutor = new StringSubstitutor((Map)indexInfo, "${indexInfo.", "}");
        return substitutor.replace(this.contextPrompt);
    }

    public boolean validate(Map<String, String> parameters) {
        return parameters != null && parameters.size() != 0;
    }

    public String getType() {
        return TYPE;
    }

    @Generated
    public void setVersion(String version) {
        this.version = version;
    }

    @Generated
    public void setClient(Client client) {
        this.client = client;
    }

    @Generated
    public void setModelId(String modelId) {
        this.modelId = modelId;
    }

    @Generated
    public void setModelType(ModelType modelType) {
        this.modelType = modelType;
    }

    @Generated
    public void setContextPrompt(String contextPrompt) {
        this.contextPrompt = contextPrompt;
    }

    @Generated
    public void setAttributes(Map<String, Object> attributes) {
        this.attributes = attributes;
    }

    @Generated
    public Client getClient() {
        return this.client;
    }

    @Generated
    public String getContextPrompt() {
        return this.contextPrompt;
    }

    @Generated
    public Map<String, Object> getAttributes() {
        return this.attributes;
    }

    @Generated
    public void setName(String name) {
        this.name = name;
    }

    @Generated
    public String getName() {
        return this.name;
    }

    @Generated
    public String getDescription() {
        return this.description;
    }

    @Generated
    public void setDescription(String description) {
        this.description = description;
    }

    @Generated
    public String getVersion() {
        return this.version;
    }

    @Generated
    public String getModelId() {
        return this.modelId;
    }

    @Generated
    public ModelType getModelType() {
        return this.modelType;
    }

    private static /* synthetic */ String lambda$run$1(Map result) throws Exception {
        return StringUtils.gson.toJson((Object)result);
    }

    static enum ModelType {
        CLAUDE,
        OPENAI;


        public static ModelType from(String value) {
            return ModelType.valueOf(value.toUpperCase(Locale.ROOT));
        }
    }

    public static class Factory
    implements WithModelTool.Factory<CreateAnomalyDetectorTool> {
        private Client client;
        private static Factory INSTANCE;

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        public static Factory getInstance() {
            if (INSTANCE != null) {
                return INSTANCE;
            }
            Class<CreateAnomalyDetectorTool> clazz = CreateAnomalyDetectorTool.class;
            synchronized (CreateAnomalyDetectorTool.class) {
                if (INSTANCE != null) {
                    // ** MonitorExit[var0] (shouldn't be in output)
                    return INSTANCE;
                }
                INSTANCE = new Factory();
                // ** MonitorExit[var0] (shouldn't be in output)
                return INSTANCE;
            }
        }

        public void init(Client client) {
            this.client = client;
        }

        public CreateAnomalyDetectorTool create(Map<String, Object> map) {
            String modelId = (String)map.getOrDefault("model_id", "");
            if (modelId.isEmpty()) {
                throw new IllegalArgumentException("model_id cannot be empty.");
            }
            String modelType = (String)map.getOrDefault("model_type", ModelType.CLAUDE.toString());
            if (modelType.isEmpty()) {
                modelType = ModelType.CLAUDE.toString();
            } else if (!ModelType.OPENAI.toString().equalsIgnoreCase(modelType) && !ModelType.CLAUDE.toString().equalsIgnoreCase(modelType)) {
                throw new IllegalArgumentException("Unsupported model_type: " + modelType);
            }
            String prompt = (String)map.getOrDefault("prompt", "");
            return new CreateAnomalyDetectorTool(this.client, modelId, modelType, prompt);
        }

        public String getDefaultDescription() {
            return CreateAnomalyDetectorTool.DEFAULT_DESCRIPTION;
        }

        public String getDefaultType() {
            return CreateAnomalyDetectorTool.TYPE;
        }

        public String getDefaultVersion() {
            return null;
        }

        public List<String> getAllModelKeys() {
            return List.of("model_id");
        }
    }
}

