/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.util.prune;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.PriorityQueue;
import org.opensearch.common.collect.Tuple;
import org.opensearch.neuralsearch.util.prune.PruneType;

/**
 * Static helpers that prune a sparse vector (token -> score map) according to a
 * {@link PruneType} and a numeric prune ratio, and that validate prune ratios.
 *
 * <p>Every pruning strategy partitions the input into "high score" entries (kept)
 * and "low score" entries (pruned). The input map is never modified.
 */
public class PruneUtils {
    public static final String PRUNE_TYPE_FIELD = "prune_type";
    public static final String PRUNE_RATIO_FIELD = "prune_ratio";

    /**
     * Keeps the {@code k} entries with the largest values.
     *
     * @param sparseVector          the vector to prune; not modified
     * @param k                     number of entries to keep; truncated toward zero to an int.
     *                              A truncated value below 1 keeps nothing
     * @param requiresPrunedEntries whether the pruned (low score) entries must be collected
     * @return a tuple of (kept entries, pruned entries); the second element is {@code null}
     *         when {@code requiresPrunedEntries} is {@code false}
     */
    private static Tuple<Map<String, Float>, Map<String, Float>> pruneByTopK(Map<String, Float> sparseVector, float k, boolean requiresPrunedEntries) {
        int limit = (int) k;
        // Min-heap on value: the head is always the weakest of the current candidates.
        PriorityQueue<Map.Entry<String, Float>> pq = new PriorityQueue<Map.Entry<String, Float>>(
            (a, b) -> Float.compare(a.getValue(), b.getValue())
        );
        // With limit < 1 the heap would stay empty and peek() would return null,
        // so skip the selection entirely: nothing is kept, everything is pruned.
        if (limit > 0) {
            for (Map.Entry<String, Float> entry : sparseVector.entrySet()) {
                if (pq.size() < limit) {
                    pq.offer(entry);
                } else if (entry.getValue() > pq.peek().getValue()) {
                    // Strictly larger than the weakest kept entry: replace it.
                    pq.poll();
                    pq.offer(entry);
                }
            }
        }

        Map<String, Float> highScores = new HashMap<String, Float>();
        // Start from a full copy and remove every kept key; what remains was pruned.
        Map<String, Float> lowScores = requiresPrunedEntries ? new HashMap<String, Float>(sparseVector) : null;
        while (!pq.isEmpty()) {
            Map.Entry<String, Float> entry = pq.poll();
            highScores.put(entry.getKey(), entry.getValue());
            if (lowScores != null) {
                lowScores.remove(entry.getKey());
            }
        }
        return new Tuple<>(highScores, lowScores);
    }

    /**
     * Keeps the entries whose value is at least {@code ratio} times the maximum value
     * in the vector (the maximum is taken as 0 for an empty vector).
     *
     * @param sparseVector          the vector to prune; not modified
     * @param ratio                 fraction of the maximum value used as the keep threshold
     * @param requiresPrunedEntries whether the pruned (low score) entries must be collected
     * @return a tuple of (kept entries, pruned entries); the second element is {@code null}
     *         when {@code requiresPrunedEntries} is {@code false}
     */
    private static Tuple<Map<String, Float>, Map<String, Float>> pruneByMaxRatio(Map<String, Float> sparseVector, float ratio, boolean requiresPrunedEntries) {
        float maxValue = sparseVector.values().stream().max(Float::compareTo).orElse(0.0f);
        // Same float product the comparison needs for every entry; computed once.
        float threshold = ratio * maxValue;

        Map<String, Float> highScores = new HashMap<String, Float>();
        Map<String, Float> lowScores = requiresPrunedEntries ? new HashMap<String, Float>() : null;
        for (Map.Entry<String, Float> entry : sparseVector.entrySet()) {
            if (entry.getValue() >= threshold) {
                highScores.put(entry.getKey(), entry.getValue());
            } else if (lowScores != null) {
                lowScores.put(entry.getKey(), entry.getValue());
            }
        }
        return new Tuple<>(highScores, lowScores);
    }

    /**
     * Keeps the entries whose value is greater than or equal to {@code thresh}.
     *
     * @param sparseVector          the vector to prune; not modified
     * @param thresh                absolute keep threshold (inclusive)
     * @param requiresPrunedEntries whether the pruned (low score) entries must be collected
     * @return a tuple of (kept entries, pruned entries); the second element is {@code null}
     *         when {@code requiresPrunedEntries} is {@code false}
     */
    private static Tuple<Map<String, Float>, Map<String, Float>> pruneByValue(Map<String, Float> sparseVector, float thresh, boolean requiresPrunedEntries) {
        Map<String, Float> highScores = new HashMap<String, Float>();
        Map<String, Float> lowScores = requiresPrunedEntries ? new HashMap<String, Float>() : null;
        for (Map.Entry<String, Float> entry : sparseVector.entrySet()) {
            if (entry.getValue() >= thresh) {
                highScores.put(entry.getKey(), entry.getValue());
            } else if (lowScores != null) {
                lowScores.put(entry.getKey(), entry.getValue());
            }
        }
        return new Tuple<>(highScores, lowScores);
    }

    /**
     * Visits entries in descending value order and keeps an entry while the running sum
     * of visited values (including that entry) does not exceed {@code alpha} times the
     * sum of all values.
     *
     * @param sparseVector          the vector to prune; not modified
     * @param alpha                 fraction of the total mass that the kept entries may reach
     * @param requiresPrunedEntries whether the pruned (low score) entries must be collected
     * @return a tuple of (kept entries, pruned entries); the second element is {@code null}
     *         when {@code requiresPrunedEntries} is {@code false}
     */
    private static Tuple<Map<String, Float>, Map<String, Float>> pruneByAlphaMass(Map<String, Float> sparseVector, float alpha, boolean requiresPrunedEntries) {
        ArrayList<Map.Entry<String, Float>> sortedEntries = new ArrayList<Map.Entry<String, Float>>(sparseVector.entrySet());
        sortedEntries.sort(Map.Entry.comparingByValue(Comparator.reverseOrder()));

        // Total is accumulated in double and narrowed to float; the running sum is float.
        float sum = (float) sparseVector.values().stream().mapToDouble(Float::doubleValue).sum();
        float massLimit = alpha * sum;
        float topSum = 0.0f;

        Map<String, Float> highScores = new HashMap<String, Float>();
        Map<String, Float> lowScores = requiresPrunedEntries ? new HashMap<String, Float>() : null;
        for (Map.Entry<String, Float> entry : sortedEntries) {
            float value = entry.getValue();
            // The running sum advances for every entry, kept or not.
            topSum += value;
            if (topSum <= massLimit) {
                highScores.put(entry.getKey(), value);
            } else if (lowScores != null) {
                lowScores.put(entry.getKey(), value);
            }
        }
        return new Tuple<>(highScores, lowScores);
    }

    /**
     * Shared argument checks for the public pruning entry points.
     *
     * @throws IllegalArgumentException if {@code pruneType} or {@code sparseVector} is
     *         {@code null}, or if any value in the vector is {@code null} or not greater than zero
     */
    private static void validatePruneInputs(PruneType pruneType, Map<String, Float> sparseVector) {
        if (pruneType == null) {
            throw new IllegalArgumentException("Prune type must be provided");
        }
        if (sparseVector == null) {
            throw new IllegalArgumentException("Sparse vector must be provided");
        }
        for (Float value : sparseVector.values()) {
            // A null value would otherwise surface as an unboxing NullPointerException.
            if (value == null || value <= 0.0f) {
                throw new IllegalArgumentException("Pruned values must be positive");
            }
        }
    }

    /**
     * Splits a sparse vector into the entries kept and the entries pruned by the given strategy.
     *
     * @param pruneType    pruning strategy; must not be {@code null}
     * @param pruneRatio   strategy parameter (k, alpha, max-ratio or absolute threshold)
     * @param sparseVector vector to split; must not be {@code null}, all values must be positive
     * @return a tuple of (kept entries, pruned entries). For a prune type other than
     *         {@code TOP_K}, {@code ALPHA_MASS}, {@code MAX_RATIO} or {@code ABS_VALUE}, the
     *         first element is a copy of the whole vector and the second is an empty map
     * @throws IllegalArgumentException if an argument is {@code null} or a value is not positive
     */
    public static Tuple<Map<String, Float>, Map<String, Float>> splitSparseVector(PruneType pruneType, float pruneRatio, Map<String, Float> sparseVector) {
        validatePruneInputs(pruneType, sparseVector);
        switch (pruneType) {
            case TOP_K:
                return pruneByTopK(sparseVector, pruneRatio, true);
            case ALPHA_MASS:
                return pruneByAlphaMass(sparseVector, pruneRatio, true);
            case MAX_RATIO:
                return pruneByMaxRatio(sparseVector, pruneRatio, true);
            case ABS_VALUE:
                return pruneByValue(sparseVector, pruneRatio, true);
            default: {
                Map<String, Float> kept = new HashMap<String, Float>(sparseVector);
                Map<String, Float> pruned = new HashMap<String, Float>();
                return new Tuple<>(kept, pruned);
            }
        }
    }

    /**
     * Prunes a sparse vector with the given strategy and returns only the kept entries.
     *
     * @param pruneType    pruning strategy; must not be {@code null}
     * @param pruneRatio   strategy parameter (k, alpha, max-ratio or absolute threshold)
     * @param sparseVector vector to prune; must not be {@code null}, all values must be positive
     * @return a new map with the kept entries. For a prune type other than {@code TOP_K},
     *         {@code ALPHA_MASS}, {@code MAX_RATIO} or {@code ABS_VALUE}, the input map
     *         itself is returned (not a copy)
     * @throws IllegalArgumentException if an argument is {@code null} or a value is not positive
     */
    public static Map<String, Float> pruneSparseVector(PruneType pruneType, float pruneRatio, Map<String, Float> sparseVector) {
        validatePruneInputs(pruneType, sparseVector);
        switch (pruneType) {
            case TOP_K:
                return pruneByTopK(sparseVector, pruneRatio, false).v1();
            case ALPHA_MASS:
                return pruneByAlphaMass(sparseVector, pruneRatio, false).v1();
            case MAX_RATIO:
                return pruneByMaxRatio(sparseVector, pruneRatio, false).v1();
            case ABS_VALUE:
                return pruneByValue(sparseVector, pruneRatio, false).v1();
            default:
                return sparseVector;
        }
    }

    /**
     * Checks whether a prune ratio is acceptable for the given prune type.
     *
     * <ul>
     *   <li>{@code TOP_K}: greater than zero and integral</li>
     *   <li>{@code ALPHA_MASS}, {@code MAX_RATIO}: in the range [0, 1)</li>
     *   <li>{@code ABS_VALUE}: non-negative</li>
     *   <li>any other type: always valid</li>
     * </ul>
     *
     * @param pruneType  pruning strategy; must not be {@code null}
     * @param pruneRatio ratio to check
     * @return {@code true} if the ratio is valid for the type
     * @throws IllegalArgumentException if {@code pruneType} is {@code null}
     */
    public static boolean isValidPruneRatio(PruneType pruneType, float pruneRatio) {
        if (pruneType == null) {
            throw new IllegalArgumentException("Prune type cannot be null");
        }
        switch (pruneType) {
            case TOP_K:
                return pruneRatio > 0.0f && (double) pruneRatio == Math.floor(pruneRatio);
            case ALPHA_MASS:
            case MAX_RATIO:
                return pruneRatio >= 0.0f && pruneRatio < 1.0f;
            case ABS_VALUE:
                return pruneRatio >= 0.0f;
            default:
                return true;
        }
    }

    /**
     * Returns a human-readable description of the valid prune ratio range for a prune type.
     *
     * @param pruneType pruning strategy; must not be {@code null}
     * @return the description text for the type
     * @throws IllegalArgumentException if {@code pruneType} is {@code null}
     */
    public static String getValidPruneRatioDescription(PruneType pruneType) {
        if (pruneType == null) {
            throw new IllegalArgumentException("Prune type cannot be null");
        }
        switch (pruneType) {
            case TOP_K:
                return "prune_ratio should be positive integer.";
            case ALPHA_MASS:
            case MAX_RATIO:
                return "prune_ratio should be in the range [0, 1).";
            case ABS_VALUE:
                return "prune_ratio should be non-negative.";
            default:
                return "prune_ratio field is not supported when prune_type is none";
        }
    }
}

