/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.OptionalInt;
import java.util.stream.Collectors;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertJapaneseTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DebertaV2Tokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.MPNetTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RobertaTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.XLMRobertaTokenization;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertJapaneseTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DebertaV2Tokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.MPNetTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.RobertaTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.XLMRobertaTokenizer;

/**
 * Base class for the NLP tokenizers. Concrete subclasses supply the raw tokenization
 * ({@link #innerTokenize(String)}) and the model specific token ids; this class turns the raw
 * tokens into one or more {@link TokenizationResult.Tokens} by truncating the input or by
 * splitting it into overlapping windows ("spans").
 */
public abstract class NlpTokenizer implements Releasable {

    /**
     * Span value asking {@link #tokenize(String, Tokenization.Truncate, int, int, Integer)} to
     * derive the span with {@link #defaultSpanForChunking(int)}.
     */
    public static final int CALC_DEFAULT_SPAN_VALUE = -2;

    /**
     * Value meaning "no span": as a {@code span} argument it disables windowing, and it is the
     * "previous span" handed to the tokens builder for a result that has no preceding window.
     */
    private static final int UNSET_SPAN = -1;

    /** @return the id handed to {@link #createTokensBuilder(int, int, boolean)} as its first argument */
    abstract int clsTokenId();

    /** @return the id handed to {@link #createTokensBuilder(int, int, boolean)} as its second argument */
    abstract int sepTokenId();

    /** @return the maximum number of tokens in one result; the default window size */
    abstract int maxSequenceLength();

    /** @return whether the extra (special) tokens are counted against the window and added to results */
    abstract boolean isWithSpecialTokens();

    /** @return the number of extra tokens reserved when a single sequence is tokenized */
    abstract int numExtraTokensForSingleSequence();

    /** @return the number of extra tokens reserved when a pair of sequences is tokenized */
    abstract int getNumExtraTokensForSeqPair();

    /**
     * Default overlap between consecutive windows: half of the window left once the extra
     * tokens of a single sequence are subtracted.
     *
     * @param maxWindowSize the full window size, extra tokens included
     * @return the span used when {@link #CALC_DEFAULT_SPAN_VALUE} is requested
     */
    int defaultSpanForChunking(int maxWindowSize) {
        return (maxWindowSize - numExtraTokensForSingleSequence()) / 2;
    }

    /**
     * Combines individual tokenizations into a single result.
     *
     * @param var1 the tokenizations to combine
     * @return the combined result
     */
    public abstract TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokens> var1);

    /**
     * Tokenizes a single sequence, truncating it or splitting it into overlapping windows when
     * it does not fit into the window.
     *
     * @param seq        the text to tokenize
     * @param truncate   what to do when the input is longer than the window. Only one sequence
     *                   exists here, so {@code FIRST}, {@code SECOND} and {@code BALANCED} all cut
     *                   the tail; {@code NONE} is an error unless a span is set
     * @param span       number of tokens shared by consecutive windows; {@code -1} disables
     *                   windowing and {@link #CALC_DEFAULT_SPAN_VALUE} selects
     *                   {@link #defaultSpanForChunking(int)}
     * @param sequenceId identifier passed through to every produced result
     * @param windowSize window size in tokens (extra tokens included when they are enabled), or
     *                   {@code null} to use {@link #maxSequenceLength()}
     * @return one result when the input fits, was truncated or windowing is disabled, otherwise
     *         one result per window
     * @throws IllegalStateException if windowing cannot advance because the span is not smaller
     *                               than the number of tokens a window ended up holding
     */
    public final List<TokenizationResult.Tokens> tokenize(
        String seq,
        Tokenization.Truncate truncate,
        int span,
        int sequenceId,
        Integer windowSize
    ) {
        final int window = windowSize == null ? maxSequenceLength() : windowSize;
        InnerTokenization innerResult = innerTokenize(seq);
        List<? extends DelimitedToken.Encoded> tokenIds = innerResult.tokens();
        List<Integer> tokenPositionMap = innerResult.tokenPositionMap();
        final boolean withSpecialTokens = isWithSpecialTokens();
        final int extraTokens = withSpecialTokens ? numExtraTokensForSingleSequence() : 0;
        // Number of input tokens that fit in one window next to the extra tokens.
        final int contentWindow = window - extraTokens;
        final int numTokens = tokenIds.size() + extraTokens;
        boolean isTruncated = false;

        if (numTokens > window) {
            switch (truncate) {
                case FIRST, SECOND, BALANCED -> {
                    isTruncated = true;
                    tokenIds = tokenIds.subList(0, contentWindow);
                    tokenPositionMap = tokenPositionMap.subList(0, contentWindow);
                }
                case NONE -> {
                    if (span == UNSET_SPAN) {
                        throw ExceptionsHelper.badRequestException(
                            "Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]",
                            numTokens,
                            window
                        );
                    }
                }
            }
        }

        if (numTokens <= window || span == UNSET_SPAN) {
            return List.of(
                createTokensBuilder(clsTokenId(), sepTokenId(), withSpecialTokens).addSequence(encodings(tokenIds), tokenPositionMap)
                    .build(seq, isTruncated, innerResult.tokens(), UNSET_SPAN, sequenceId)
            );
        }

        // NOTE(review): an input that was truncated above while a span is set also ends up here
        // and is returned as a single window flagged as not truncated; confirm that callers never
        // combine truncation with a span.
        final int chunkSpan = span == CALC_DEFAULT_SPAN_VALUE ? defaultSpanForChunking(window) : span;
        List<TokenizationResult.Tokens> toReturn = new ArrayList<>();
        int splitEndPos = 0;
        int splitStartPos = 0;
        int spanPrev = UNSET_SPAN;
        while (splitEndPos < tokenIds.size()) {
            splitEndPos = Math.min(splitStartPos + contentWindow, tokenIds.size());
            // Do not end the window in the middle of a word.
            if (splitEndPos != tokenIds.size()) {
                splitEndPos = backUpToWordBoundary(tokenPositionMap, splitEndPos, splitStartPos);
            }
            List<? extends DelimitedToken.Encoded> windowTokens = tokenIds.subList(splitStartPos, splitEndPos);
            toReturn.add(
                createTokensBuilder(clsTokenId(), sepTokenId(), withSpecialTokens).addSequence(
                    encodings(windowTokens),
                    tokenPositionMap.subList(splitStartPos, splitEndPos)
                ).build(seq, false, windowTokens, spanPrev, sequenceId)
            );
            spanPrev = chunkSpan;
            final int prevSplitStart = splitStartPos;
            splitStartPos = splitEndPos - chunkSpan;
            // Without this guard a span that swallows the whole window makes the loop emit the
            // same window forever (or walk backwards to a negative index). Only raised while
            // input is left, so a final window is never rejected.
            if (splitStartPos <= prevSplitStart && splitEndPos < tokenIds.size()) {
                throw new IllegalStateException(
                    "Tokenization cannot be satisfied with the current span setting. Consider decreasing the span setting"
                );
            }
            // Back the next window up so that it starts on a whole word; the overlap with the
            // previous window grows by the number of tokens stepped back.
            if (splitStartPos < tokenIds.size()) {
                int wordStart = backUpToWordBoundary(tokenPositionMap, splitStartPos, prevSplitStart);
                spanPrev += splitStartPos - wordStart;
                splitStartPos = wordStart;
            }
        }
        return toReturn;
    }

    /**
     * Tokenizes a pair of sequences into a single result, truncating when the pair is too long.
     *
     * @param seq1       the first sequence
     * @param seq2       the second sequence
     * @param truncate   which sequence(s) to shorten when the pair does not fit
     * @param sequenceId identifier passed through to the produced result
     * @return the tokenized pair
     */
    public TokenizationResult.Tokens tokenize(String seq1, String seq2, Tokenization.Truncate truncate, int sequenceId) {
        return tokenize(seq1, innerTokenize(seq1), seq2, truncate, sequenceId);
    }

    /**
     * Same as {@link #tokenize(String, String, Tokenization.Truncate, int)} for callers that
     * already hold the tokenization of the first sequence.
     *
     * @param seq1            the first sequence
     * @param innerResultSeq1 the result of {@link #innerTokenize(String)} for {@code seq1}
     * @param seq2            the second sequence, tokenized here
     * @param truncate        which sequence(s) to shorten when the pair does not fit
     * @param sequenceId      identifier passed through to the produced result
     * @return the tokenized pair
     * @throws IllegalArgumentException if the tokenizer works without special tokens
     */
    public TokenizationResult.Tokens tokenize(
        String seq1,
        InnerTokenization innerResultSeq1,
        String seq2,
        Tokenization.Truncate truncate,
        int sequenceId
    ) {
        // Check the precondition before paying for the tokenization of the second sequence.
        if (isWithSpecialTokens() == false) {
            throw new IllegalArgumentException("Unable to do sequence pair tokenization without special tokens");
        }
        InnerTokenization innerResultSeq2 = innerTokenize(seq2);
        final int extraTokens = getNumExtraTokensForSeqPair();
        final int maxLength = maxSequenceLength();
        final int numTokens = innerResultSeq1.tokens().size() + innerResultSeq2.tokens().size() + extraTokens;
        FittedPair pair = numTokens > maxLength
            ? truncatePair(innerResultSeq1, innerResultSeq2, truncate, extraTokens, maxLength, numTokens)
            : untruncatedPair(innerResultSeq1, innerResultSeq2);
        return buildSinglePair(seq1, seq2, innerResultSeq1, innerResultSeq2, pair, sequenceId);
    }

    /**
     * Tokenizes a pair of sequences. With a non-negative {@code span} a pair that is too long is
     * split into windows over the second sequence, each window repeating the whole first
     * sequence; with a negative {@code span} the pair is truncated instead.
     *
     * @param seq1       the first sequence, never split
     * @param seq2       the second sequence, the one being windowed
     * @param truncate   which sequence(s) to shorten; only consulted when {@code span} is negative
     * @param span       number of second-sequence tokens shared by consecutive windows
     * @param sequenceId identifier passed through to every produced result
     * @return one result when the pair fits or was truncated, otherwise one result per window
     * @throws IllegalArgumentException if the tokenizer works without special tokens, if the first
     *                                  sequence leaves no room for the second one, or if the span
     *                                  is larger than the room left
     * @throws IllegalStateException    if windowing cannot advance with the given span
     */
    public List<TokenizationResult.Tokens> tokenize(String seq1, String seq2, Tokenization.Truncate truncate, int span, int sequenceId) {
        if (isWithSpecialTokens() == false) {
            throw new IllegalArgumentException("Unable to do sequence pair tokenization without special tokens");
        }
        InnerTokenization innerResultSeq1 = innerTokenize(seq1);
        InnerTokenization innerResultSeq2 = innerTokenize(seq2);
        final int extraTokens = getNumExtraTokensForSeqPair();
        final int maxLength = maxSequenceLength();
        final int numTokens = innerResultSeq1.tokens().size() + innerResultSeq2.tokens().size() + extraTokens;

        FittedPair pair = numTokens > maxLength && span < 0
            ? truncatePair(innerResultSeq1, innerResultSeq2, truncate, extraTokens, maxLength, numTokens)
            : untruncatedPair(innerResultSeq1, innerResultSeq2);
        // "<=": a pair that fits exactly needs no windowing. With "<" such a pair was pushed
        // through the span checks below and could be rejected although it fits.
        if (pair.truncated() || numTokens <= maxLength) {
            return List.of(buildSinglePair(seq1, seq2, innerResultSeq1, innerResultSeq2, pair, sequenceId));
        }

        List<? extends DelimitedToken.Encoded> tokenIdsSeq1 = pair.tokensSeq1();
        List<Integer> tokenPositionMapSeq1 = pair.positionsSeq1();
        List<? extends DelimitedToken.Encoded> tokenIdsSeq2 = pair.tokensSeq2();
        List<Integer> tokenPositionMapSeq2 = pair.positionsSeq2();

        List<TokenizationResult.Tokens> toReturn = new ArrayList<>();
        int splitEndPos = 0;
        int splitStartPos = 0;
        int spanPrev = UNSET_SPAN;
        List<Integer> seq1TokenIds = encodings(tokenIdsSeq1);
        // Room left for the second sequence once the first one and the extra tokens are placed.
        final int trueMaxSeqLength = maxLength - extraTokens - tokenIdsSeq1.size();
        if (trueMaxSeqLength <= 0) {
            throw new IllegalArgumentException(
                Strings.format(
                    "Unable to do sequence pair tokenization: the first sequence [%d tokens] is longer than the max sequence length [%d tokens]",
                    tokenIdsSeq1.size() + extraTokens,
                    maxLength
                )
            );
        }
        if (span > trueMaxSeqLength) {
            throw new IllegalArgumentException(
                Strings.format(
                    "Unable to do sequence pair tokenization: the combined first sequence, span length and delimiting tokens [%d + %d + %d = %d tokens] is longer than the max sequence length [%d tokens]. Reduce the size of the [span] window.",
                    tokenIdsSeq1.size(),
                    span,
                    extraTokens,
                    tokenIdsSeq1.size() + span + extraTokens,
                    maxLength
                )
            );
        }
        while (splitEndPos < tokenIdsSeq2.size()) {
            splitEndPos = Math.min(splitStartPos + trueMaxSeqLength, tokenIdsSeq2.size());
            // Do not end the window in the middle of a word.
            if (splitEndPos != tokenIdsSeq2.size()) {
                splitEndPos = backUpToWordBoundary(tokenPositionMapSeq2, splitEndPos, splitStartPos);
            }
            toReturn.add(
                createTokensBuilder(clsTokenId(), sepTokenId(), isWithSpecialTokens()).addSequencePair(
                    seq1TokenIds,
                    tokenPositionMapSeq1,
                    encodings(tokenIdsSeq2.subList(splitStartPos, splitEndPos)),
                    tokenPositionMapSeq2.subList(splitStartPos, splitEndPos)
                ).build(List.of(seq1, seq2), false, List.of(tokenIdsSeq1, tokenIdsSeq2.subList(splitStartPos, splitEndPos)), spanPrev, sequenceId)
            );
            spanPrev = span;
            final int prevSplitStart = splitStartPos;
            splitStartPos = splitEndPos - span;
            if (splitStartPos <= prevSplitStart) {
                throw new IllegalStateException(
                    "Tokenization cannot be satisfied with the current span setting. Consider decreasing the span setting"
                );
            }
            // Back the next window up so that it starts on a whole word.
            if (splitStartPos < tokenIdsSeq2.size()) {
                int wordStart = backUpToWordBoundary(tokenPositionMapSeq2, splitStartPos, prevSplitStart);
                spanPrev += splitStartPos - wordStart;
                splitStartPos = wordStart;
            }
        }
        return toReturn;
    }

    /** @return the builder that turns tokenizations into inference requests */
    public abstract NlpTask.RequestBuilder requestBuilder();

    /** @return the id of the padding token, if there is one */
    public abstract OptionalInt getPadTokenId();

    /** @return the padding token */
    public abstract String getPadToken();

    /** @return the id of the mask token, if there is one */
    public abstract OptionalInt getMaskTokenId();

    /** @return the mask token */
    public abstract String getMaskToken();

    /** @return the vocabulary of this tokenizer */
    public abstract List<String> getVocabulary();

    /**
     * @return the configured span; {@code -1} (no span) unless a subclass overrides it
     */
    public int getSpan() {
        return UNSET_SPAN;
    }

    /**
     * Creates the builder used to assemble a {@link TokenizationResult.Tokens}.
     *
     * @param var1 called with {@link #clsTokenId()}
     * @param var2 called with {@link #sepTokenId()}
     * @param var3 called with {@link #isWithSpecialTokens()}
     * @return a fresh builder
     */
    abstract TokenizationResult.TokensBuilder createTokensBuilder(int var1, int var2, boolean var3);

    /**
     * Performs the tokenizer specific tokenization, without truncation or windowing.
     *
     * @param var1 the text to tokenize
     * @return the encoded tokens together with their position map
     */
    public abstract InnerTokenization innerTokenize(String var1);

    /**
     * Creates the tokenizer matching the given tokenization configuration.
     *
     * @param vocabulary the vocabulary (and merges or scores, where the tokenizer needs them)
     * @param params     the tokenization configuration; its concrete type selects the tokenizer
     * @return the tokenizer
     * @throws IOException              declared for the tokenizer builders
     * @throws IllegalArgumentException if the configuration type is not known
     */
    public static NlpTokenizer build(Vocabulary vocabulary, Tokenization params) throws IOException {
        ExceptionsHelper.requireNonNull(params, NlpConfig.TOKENIZATION);
        ExceptionsHelper.requireNonNull(vocabulary, NlpConfig.VOCABULARY);
        if (params instanceof BertTokenization) {
            return BertTokenizer.builder(vocabulary.get(), params).build();
        }
        if (params instanceof BertJapaneseTokenization) {
            return BertJapaneseTokenizer.builder(vocabulary.get(), params).build();
        }
        if (params instanceof MPNetTokenization) {
            return MPNetTokenizer.mpBuilder(vocabulary.get(), params).build();
        }
        if (params instanceof RobertaTokenization robertaTokenization) {
            return RobertaTokenizer.builder(vocabulary.get(), vocabulary.merges(), robertaTokenization).build();
        }
        if (params instanceof XLMRobertaTokenization xlmRobertaTokenization) {
            return XLMRobertaTokenizer.builder(vocabulary.get(), vocabulary.scores(), xlmRobertaTokenization).build();
        }
        if (params instanceof DebertaV2Tokenization debertaV2Tokenization) {
            return DebertaV2Tokenizer.builder(vocabulary.get(), vocabulary.scores(), debertaV2Tokenization).build();
        }
        throw new IllegalArgumentException("unknown tokenization type [" + params.getName() + "]");
    }

    /** Extracts the encodings (token ids) of the given tokens into a new list. */
    private static List<Integer> encodings(List<? extends DelimitedToken.Encoded> tokens) {
        return tokens.stream().map(DelimitedToken.Encoded::getEncoding).collect(Collectors.toList());
    }

    /**
     * Moves {@code pos} backwards while the token at {@code pos} maps to the same position as the
     * token before it, never going below {@code lowerBound + 1}.
     */
    private static int backUpToWordBoundary(List<Integer> tokenPositionMap, int pos, int lowerBound) {
        int boundary = pos;
        while (boundary > lowerBound + 1 && Objects.equals(tokenPositionMap.get(boundary), tokenPositionMap.get(boundary - 1))) {
            boundary--;
        }
        return boundary;
    }

    /** Wraps both tokenizations unchanged. */
    private static FittedPair untruncatedPair(InnerTokenization first, InnerTokenization second) {
        return new FittedPair(first.tokens(), first.tokenPositionMap(), second.tokens(), second.tokenPositionMap(), false);
    }

    /**
     * Shortens a pair that holds {@code numTokens > maxLength} tokens according to {@code truncate}.
     * Throws a bad request exception for {@code NONE}, or when the sequence that must be kept
     * whole does not fit on its own.
     */
    private static FittedPair truncatePair(
        InnerTokenization first,
        InnerTokenization second,
        Tokenization.Truncate truncate,
        int extraTokens,
        int maxLength,
        int numTokens
    ) {
        List<? extends DelimitedToken.Encoded> tokensSeq1 = first.tokens();
        List<Integer> positionsSeq1 = first.tokenPositionMap();
        List<? extends DelimitedToken.Encoded> tokensSeq2 = second.tokens();
        List<Integer> positionsSeq2 = second.tokenPositionMap();
        // Tokens available to the two sequences together.
        final int budget = maxLength - extraTokens;
        boolean truncated = false;
        switch (truncate) {
            case FIRST -> {
                if (tokensSeq2.size() > budget) {
                    throw ExceptionsHelper.badRequestException(
                        "Attempting truncation [{}] but input is too large for the second sequence. The tokenized input length [{}] exceeds the maximum sequence length [{}], when taking special tokens into account",
                        truncate.toString(),
                        tokensSeq2.size(),
                        budget
                    );
                }
                truncated = true;
                final int keep = budget - tokensSeq2.size();
                tokensSeq1 = tokensSeq1.subList(0, keep);
                positionsSeq1 = positionsSeq1.subList(0, keep);
            }
            case SECOND -> {
                if (tokensSeq1.size() > budget) {
                    throw ExceptionsHelper.badRequestException(
                        "Attempting truncation [{}] but input is too large for the first sequence. The tokenized input length [{}] exceeds the maximum sequence length [{}], when taking special tokens into account",
                        truncate.toString(),
                        tokensSeq1.size(),
                        budget
                    );
                }
                truncated = true;
                final int keep = budget - tokensSeq1.size();
                tokensSeq2 = tokensSeq2.subList(0, keep);
                positionsSeq2 = positionsSeq2.subList(0, keep);
            }
            case BALANCED -> {
                truncated = true;
                final int half = budget / 2;
                // The first sequence gets at most half of the budget when the second one needs
                // more than half, otherwise whatever the second one leaves over.
                final int firstSequenceLength = tokensSeq2.size() > half
                    ? Math.min(tokensSeq1.size(), half)
                    : Math.min(tokensSeq1.size(), budget - tokensSeq2.size());
                final int secondSequenceLength = Math.min(tokensSeq2.size(), budget - firstSequenceLength);
                tokensSeq1 = tokensSeq1.subList(0, firstSequenceLength);
                positionsSeq1 = positionsSeq1.subList(0, firstSequenceLength);
                tokensSeq2 = tokensSeq2.subList(0, secondSequenceLength);
                positionsSeq2 = positionsSeq2.subList(0, secondSequenceLength);
            }
            case NONE -> throw ExceptionsHelper.badRequestException(
                "Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]",
                numTokens,
                maxLength
            );
        }
        return new FittedPair(tokensSeq1, positionsSeq1, tokensSeq2, positionsSeq2, truncated);
    }

    /** Builds the single, un-windowed result for a pair; the untruncated tokens are kept alongside. */
    private TokenizationResult.Tokens buildSinglePair(
        String seq1,
        String seq2,
        InnerTokenization innerResultSeq1,
        InnerTokenization innerResultSeq2,
        FittedPair pair,
        int sequenceId
    ) {
        return createTokensBuilder(clsTokenId(), sepTokenId(), isWithSpecialTokens()).addSequencePair(
            encodings(pair.tokensSeq1()),
            pair.positionsSeq1(),
            encodings(pair.tokensSeq2()),
            pair.positionsSeq2()
        ).build(List.of(seq1, seq2), pair.truncated(), List.of(innerResultSeq1.tokens(), innerResultSeq2.tokens()), UNSET_SPAN, sequenceId);
    }

    /** The tokens and position maps of a sequence pair after optional truncation. */
    private record FittedPair(
        List<? extends DelimitedToken.Encoded> tokensSeq1,
        List<Integer> positionsSeq1,
        List<? extends DelimitedToken.Encoded> tokensSeq2,
        List<Integer> positionsSeq2,
        boolean truncated
    ) {}

    /**
     * The raw output of a tokenizer.
     *
     * @param tokens           the encoded tokens
     * @param tokenPositionMap for each token, the position it maps to in the input; consecutive
     *                         tokens with equal values are treated as parts of the same word
     */
    public record InnerTokenization(List<? extends DelimitedToken.Encoded> tokens, List<Integer> tokenPositionMap) {
    }
}

