/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangNetLayer;
import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/**
 * Language identification model built from a hidden {@link LangNetLayer} followed by a softmax {@link LangNetLayer}.
 * <p>
 * At inference time the model reads, from the field named by {@code embedded_vector_feature_name}, a list whose
 * {@link CustomWordEmbedding.StringLengthAndEmbedding} entries are each run through both layers. The per-entry softmax
 * probabilities are combined into a weighted average, where each entry's weight is the square of its UTF-8 length
 * (with a minimum weight of 1). The averaged probabilities are reported as a classification over
 * {@code LANGUAGE_NAMES}. Only {@link ClassificationConfig} is supported and feature importance is rejected.
 */
public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, LenientlyParsedTrainedModel, InferenceModel {

    public static final ParseField NAME = new ParseField("lang_ident_neural_network");
    public static final ParseField EMBEDDED_VECTOR_FEATURE_NAME = new ParseField("embedded_vector_feature_name");
    public static final ParseField HIDDEN_LAYER = new ParseField("hidden_layer");
    public static final ParseField SOFTMAX_LAYER = new ParseField("softmax_layer");

    // Declared after the ParseFields above because createParser reads them during static initialization.
    public static final ConstructingObjectParser<LangIdentNeuralNetwork, Void> STRICT_PARSER = createParser(false);
    public static final ConstructingObjectParser<LangIdentNeuralNetwork, Void> LENIENT_PARSER = createParser(true);

    // Output label for each index of the softmax layer; the order must match the trained model's output order.
    private static final List<String> LANGUAGE_NAMES = Arrays.asList(
        "eo", "co", "eu", "ta", "de", "mt", "ps", "te", "su", "uz",
        "zh-Latn", "ne", "nl", "sw", "sq", "hmn", "ja", "no", "mn", "so",
        "ko", "kk", "sl", "ig", "mr", "th", "zu", "ml", "hr", "bs",
        "lo", "sd", "cy", "hy", "uk", "pt", "lv", "iw", "cs", "vi",
        "jv", "be", "km", "mk", "tr", "fy", "am", "zh", "da", "sv",
        "fi", "ht", "af", "la", "id", "fil", "sm", "ca", "el", "ka",
        "sr", "it", "sk", "ru", "ru-Latn", "bg", "ny", "fa", "haw", "gl",
        "et", "ms", "gd", "bg-Latn", "ha", "is", "ur", "mi", "hi", "bn",
        "hi-Latn", "fr", "yi", "hu", "xh", "my", "tg", "ro", "ar", "lb",
        "el-Latn", "st", "ceb", "kn", "az", "si", "ky", "mg", "en", "gu",
        "es", "pl", "ja-Latn", "ga", "lt", "sn", "yo", "pa", "ku"
    );
    // NOTE(review): this is the index of the last entry of LANGUAGE_NAMES ("ku"), yet it is reported together with the
    // label "zxx" below; confirm this pairing is intended.
    private static final int MISSING_VALID_TXT_CLASSIFICATION = LANGUAGE_NAMES.size() - 1;
    private static final String MISSING_VALID_TXT_CLASSIFICATION_STR = "zxx";
    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LangIdentNeuralNetwork.class);

    private final LangNetLayer hiddenLayer;
    private final LangNetLayer softmaxLayer;
    private final String embeddedVectorFeatureName;

    /**
     * Builds the x-content parser for this model.
     *
     * @param lenient whether unknown fields are ignored; also selects the lenient or strict parser for the nested layers
     * @return a parser producing {@link LangIdentNeuralNetwork} instances
     */
    private static ConstructingObjectParser<LangIdentNeuralNetwork, Void> createParser(boolean lenient) {
        ConstructingObjectParser<LangIdentNeuralNetwork, Void> parser = new ConstructingObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            a -> new LangIdentNeuralNetwork((String) a[0], (LangNetLayer) a[1], (LangNetLayer) a[2])
        );
        parser.declareString(ConstructingObjectParser.constructorArg(), EMBEDDED_VECTOR_FEATURE_NAME);
        parser.declareObject(
            ConstructingObjectParser.constructorArg(),
            (p, c) -> lenient ? LangNetLayer.LENIENT_PARSER.apply(p, c) : LangNetLayer.STRICT_PARSER.apply(p, c),
            HIDDEN_LAYER
        );
        parser.declareObject(
            ConstructingObjectParser.constructorArg(),
            (p, c) -> lenient ? LangNetLayer.LENIENT_PARSER.apply(p, c) : LangNetLayer.STRICT_PARSER.apply(p, c),
            SOFTMAX_LAYER
        );
        return parser;
    }

    /** Parses a model, rejecting unknown fields. */
    public static LangIdentNeuralNetwork fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, null);
    }

    /** Parses a model, ignoring unknown fields. */
    public static LangIdentNeuralNetwork fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, null);
    }

    /**
     * @param embeddedVectorFeatureName name of the input field holding the per-segment embeddings; must not be null
     * @param hiddenLayer               first layer applied to each embedding; must not be null
     * @param softmaxLayer              output layer whose result is passed through softmax; must not be null
     */
    public LangIdentNeuralNetwork(String embeddedVectorFeatureName, LangNetLayer hiddenLayer, LangNetLayer softmaxLayer) {
        this.embeddedVectorFeatureName = ExceptionsHelper.requireNonNull(embeddedVectorFeatureName, EMBEDDED_VECTOR_FEATURE_NAME);
        this.hiddenLayer = ExceptionsHelper.requireNonNull(hiddenLayer, HIDDEN_LAYER);
        this.softmaxLayer = ExceptionsHelper.requireNonNull(softmaxLayer, SOFTMAX_LAYER);
    }

    /** Reads a model in the order written by {@link #writeTo(StreamOutput)}. */
    public LangIdentNeuralNetwork(StreamInput in) throws IOException {
        this.embeddedVectorFeatureName = in.readString();
        this.hiddenLayer = new LangNetLayer(in);
        this.softmaxLayer = new LangNetLayer(in);
    }

    /**
     * Predicts the language of the embeddings stored under {@code embedded_vector_feature_name}.
     *
     * @param fields            input document fields; the embedding field must hold a {@link List}
     * @param config            must be a {@link ClassificationConfig} that does not request feature importance
     * @param featureDecoderMap unused by this model
     * @return a classification result; an empty embedding list yields the "zxx" result with probability and score 1.0
     * @throws org.elasticsearch.ElasticsearchStatusException (bad request) when feature importance is requested, the
     *         config is not a classification config, or the embedding field is missing or not a list
     */
    @Override
    public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
        if (config.requestingImportance()) {
            throw ExceptionsHelper.badRequestException("[{}] model does not supports feature importance", NAME.getPreferredName());
        }
        if (!(config instanceof ClassificationConfig)) {
            throw ExceptionsHelper.badRequestException("[{}] model only supports classification", NAME.getPreferredName());
        }
        Object vector = fields.get(embeddedVectorFeatureName);
        if (!(vector instanceof List)) {
            throw ExceptionsHelper.badRequestException(
                "[{}] model could not find non-null collection of embeddings separated by unicode script type [{}]. Please verify that the input is a string.",
                NAME.getPreferredName(),
                embeddedVectorFeatureName
            );
        }
        List<?> embeddedVector = (List<?>) vector;
        ClassificationConfig classificationConfig = (ClassificationConfig) config;
        if (embeddedVector.isEmpty()) {
            return new ClassificationInferenceResults(
                MISSING_VALID_TXT_CLASSIFICATION,
                MISSING_VALID_TXT_CLASSIFICATION_STR,
                Collections.emptyList(),
                Collections.emptyList(),
                classificationConfig,
                1.0,
                1.0
            );
        }

        double[] probabilities = weightedLanguageProbabilities(embeddedVector);
        Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
            probabilities,
            LANGUAGE_NAMES,
            null,
            classificationConfig.getNumTopClasses(),
            PredictionFieldType.STRING
        );
        InferenceHelpers.TopClassificationValue classificationValue = topClasses.v1();
        assert classificationValue.getValue() >= 0 && classificationValue.getValue() < LANGUAGE_NAMES.size()
            : "Invalid language predicted. Predicted language index " + classificationValue.getValue();
        return new ClassificationInferenceResults(
            classificationValue.getValue(),
            LANGUAGE_NAMES.get(classificationValue.getValue()),
            topClasses.v2(),
            Collections.emptyList(),
            classificationConfig,
            classificationValue.getProbability(),
            classificationValue.getScore()
        );
    }

    /**
     * Averages the per-segment softmax probabilities, weighting each segment by its squared UTF-8 length (minimum 1).
     * Entries that are not {@link CustomWordEmbedding.StringLengthAndEmbedding} are skipped; if none qualify, the
     * returned array is all zeros.
     */
    private double[] weightedLanguageProbabilities(List<?> embeddedVector) {
        double[] probabilities = new double[LANGUAGE_NAMES.size()];
        // Weights are accumulated in double: squaring an int length overflows for segments longer than 46,340 bytes,
        // and a sum of many squares can overflow an int even when each square fits.
        double totalWeight = 0.0;
        for (Object entry : embeddedVector) {
            if (!(entry instanceof CustomWordEmbedding.StringLengthAndEmbedding)) {
                continue;
            }
            CustomWordEmbedding.StringLengthAndEmbedding segment = (CustomWordEmbedding.StringLengthAndEmbedding) entry;
            double utf8Len = segment.getUtf8StringLen();
            double weight = Math.max(utf8Len * utf8Len, 1.0);
            // The normalizer uses the same weight as the sum so the result is a true weighted average,
            // including for zero-length segments.
            totalWeight += weight;
            double[] h0 = hiddenLayer.productPlusBias(false, segment.getEmbedding());
            double[] segmentProbabilities = Statistics.softMax(softmaxLayer.productPlusBias(true, h0));
            for (int i = 0; i < segmentProbabilities.length; i++) {
                probabilities[i] += segmentProbabilities[i] * weight;
            }
        }
        if (totalWeight > 0.0) {
            for (int i = 0; i < probabilities.length; i++) {
                probabilities[i] /= totalWeight;
            }
        }
        return probabilities;
    }

    /** Not supported: this model cannot be evaluated on a pre-extracted feature vector. */
    @Override
    public InferenceResults infer(double[] embeddedVector, InferenceConfig config) {
        throw new UnsupportedOperationException("[lang_ident] does not support nested inference");
    }

    /** Accepts only a null or empty mapping; any actual feature re-indexing is unsupported. */
    @Override
    public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
        if (newFeatureIndexMapping != null && !newFeatureIndexMapping.isEmpty()) {
            throw new UnsupportedOperationException("[lang_ident] does not support nested inference");
        }
    }

    @Override
    public String[] getFeatureNames() {
        return new String[] { embeddedVectorFeatureName };
    }

    @Override
    public TargetType targetType() {
        return TargetType.CLASSIFICATION;
    }

    /** No additional validation is performed; the constructors already reject null components. */
    @Override
    public void validate() {}

    /** Estimates the operation count as the total number of bias and weight values across both layers. */
    @Override
    public long estimatedNumOperations() {
        return (long) hiddenLayer.getBias().length
            + hiddenLayer.getWeights().length
            + softmaxLayer.getBias().length
            + softmaxLayer.getWeights().length;
    }

    @Override
    public boolean supportsFeatureImportance() {
        return false;
    }

    @Override
    public long ramBytesUsed() {
        return SHALLOW_SIZE + RamUsageEstimator.sizeOf(hiddenLayer) + RamUsageEstimator.sizeOf(softmaxLayer);
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    /** Writes the feature name and both layers, in the order read by {@link #LangIdentNeuralNetwork(StreamInput)}. */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(embeddedVectorFeatureName);
        hiddenLayer.writeTo(out);
        softmaxLayer.writeTo(out);
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(EMBEDDED_VECTOR_FEATURE_NAME.getPreferredName(), embeddedVectorFeatureName);
        builder.field(HIDDEN_LAYER.getPreferredName(), (ToXContent) hiddenLayer);
        builder.field(SOFTMAX_LAYER.getPreferredName(), (ToXContent) softmaxLayer);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        LangIdentNeuralNetwork that = (LangIdentNeuralNetwork) o;
        return Objects.equals(embeddedVectorFeatureName, that.embeddedVectorFeatureName)
            && Objects.equals(hiddenLayer, that.hiddenLayer)
            && Objects.equals(softmaxLayer, that.softmaxLayer);
    }

    @Override
    public int hashCode() {
        return Objects.hash(embeddedVectorFeatureName, hiddenLayer, softmaxLayer);
    }
}

