/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.planner;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.elasticsearch.common.Strings;
import org.elasticsearch.compute.aggregation.IntermediateStateDesc;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
import org.elasticsearch.xpack.esql.core.expression.function.Function;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.esql.expression.function.aggregate.CountDistinct;
import org.elasticsearch.xpack.esql.expression.function.aggregate.FromPartial;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Max;
import org.elasticsearch.xpack.esql.expression.function.aggregate.MedianAbsoluteDeviation;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
import org.elasticsearch.xpack.esql.expression.function.aggregate.NumericAggregate;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialAggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
import org.elasticsearch.xpack.esql.expression.function.aggregate.ToPartial;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Top;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;

/**
 * Maps ES|QL aggregate expressions to the attributes describing the intermediate state that the
 * corresponding compute-engine aggregator functions exchange.
 * <p>
 * At class initialisation, for every class in {@code AGG_FUNCTIONS}, every supported input type and
 * extra configuration, and both grouping and non-grouping forms (only grouping for {@link Rate}),
 * the generated class
 * {@code <package>.<AggSimpleName><Type><Extra>[Grouping]AggregatorFunction} is loaded reflectively
 * and its static {@code intermediateStateDesc()} result is stored in a lookup table. A missing
 * generated class therefore fails class initialisation.
 * <p>
 * Each instance additionally caches the {@link NamedExpression}s it builds per aggregate. That cache
 * is an unsynchronized {@link HashMap}, so an instance should not be shared across threads without
 * external synchronization.
 */
final class AggregateMapper {
    private static final List<String> NUMERIC = List.of("Int", "Long", "Double");
    private static final List<String> SPATIAL = List.of("GeoPoint", "CartesianPoint");

    /** Aggregate function classes whose intermediate states are resolved eagerly. */
    private static final List<? extends Class<? extends Function>> AGG_FUNCTIONS = List.of(
        Count.class,
        CountDistinct.class,
        Max.class,
        MedianAbsoluteDeviation.class,
        Min.class,
        Percentile.class,
        SpatialCentroid.class,
        Sum.class,
        Values.class,
        Top.class,
        Rate.class,
        FromPartial.class,
        ToPartial.class
    );

    /**
     * Intermediate state descriptions for every registered (aggregate, type, extra, grouping)
     * combination. Must be declared after NUMERIC / SPATIAL / AGG_FUNCTIONS, which it reads during
     * static initialisation.
     */
    private static final Map<AggDef, List<IntermediateStateDesc>> MAPPER = AGG_FUNCTIONS.stream()
        .flatMap(AggregateMapper::typeAndNames)
        .flatMap(AggregateMapper::groupingAndNonGrouping)
        .collect(Collectors.toUnmodifiableMap(aggDef -> aggDef, AggregateMapper::lookupIntermediateState));

    /**
     * Per-instance cache of mapped expressions. The key includes the grouping flag because grouping
     * and non-grouping aggregators are resolved separately.
     */
    private final Map<CacheKey, List<NamedExpression>> cache = new HashMap<>();

    AggregateMapper() {}

    /**
     * Maps non-grouping aggregates to their intermediate-state attributes.
     *
     * @param aggregates aggregate expressions, possibly wrapped in an {@link Alias}
     * @return the mapped expressions; entries mapping to the same attribute appear once
     */
    public List<NamedExpression> mapNonGrouping(List<? extends Expression> aggregates) {
        return doMapping(aggregates, false);
    }

    /**
     * Maps a single non-grouping aggregate to its intermediate-state attributes.
     *
     * @param aggregate aggregate expression, possibly wrapped in an {@link Alias}
     * @return the mapped expressions (empty for field/metadata/reference attributes)
     */
    public List<NamedExpression> mapNonGrouping(Expression aggregate) {
        return map(aggregate, false).toList();
    }

    /**
     * Maps grouping aggregates to their intermediate-state attributes.
     *
     * @param aggregates aggregate expressions, possibly wrapped in an {@link Alias}
     * @return the mapped expressions; entries mapping to the same attribute appear once
     */
    public List<NamedExpression> mapGrouping(List<? extends Expression> aggregates) {
        return doMapping(aggregates, true);
    }

    /**
     * Maps a single grouping aggregate to its intermediate-state attributes.
     *
     * @param aggregate aggregate expression, possibly wrapped in an {@link Alias}
     * @return the mapped expressions (empty for field/metadata/reference attributes)
     */
    public List<NamedExpression> mapGrouping(Expression aggregate) {
        return map(aggregate, true).toList();
    }

    /** Maps every aggregate and collapses results that share the same attribute. */
    private List<NamedExpression> doMapping(List<? extends Expression> aggregates, boolean grouping) {
        AttributeMap<NamedExpression> attrToExpressions = new AttributeMap<>();
        aggregates.stream().flatMap(agg -> map(agg, grouping)).forEach(ne -> attrToExpressions.put(ne.toAttribute(), ne));
        return attrToExpressions.values().stream().toList();
    }

    /**
     * Returns the (cached) mapped expressions for {@code aggregate}. Caching guarantees the same
     * attribute instances are returned for repeated lookups of an equal expression in the same mode.
     */
    private Stream<NamedExpression> map(Expression aggregate, boolean grouping) {
        CacheKey key = new CacheKey(Alias.unwrap(aggregate), grouping);
        return cache.computeIfAbsent(key, k -> computeEntryForAgg(k.aggregate(), k.grouping())).stream();
    }

    /**
     * Builds the intermediate-state attributes for one unwrapped expression.
     *
     * @throws EsqlIllegalArgumentException if the expression is neither a known aggregate nor a
     *         field/metadata/reference attribute, or no intermediate state is registered for it
     */
    private static List<NamedExpression> computeEntryForAgg(Expression aggregate, boolean grouping) {
        AggDef aggDef = aggDefOrNull(aggregate, grouping);
        if (aggDef != null) {
            return isToNE(getNonNull(aggDef)).toList();
        }
        // Plain attributes carry no intermediate state of their own.
        if (aggregate instanceof FieldAttribute || aggregate instanceof MetadataAttribute || aggregate instanceof ReferenceAttribute) {
            return List.of();
        }
        throw new EsqlIllegalArgumentException("unknown agg: " + aggregate.getClass() + ": " + aggregate);
    }

    /** Looks up the registered intermediate state, failing loudly when the combination is unknown. */
    private static List<IntermediateStateDesc> getNonNull(AggDef aggDef) {
        List<IntermediateStateDesc> l = MAPPER.get(aggDef);
        if (l == null) {
            throw new EsqlIllegalArgumentException("Cannot find intermediate state for: " + aggDef);
        }
        return l;
    }

    /**
     * Expands an aggregate class into every (type, extra) suffix pair for which a generated
     * aggregator is expected to exist.
     */
    private static Stream<Tuple<Class<?>, Tuple<String, String>>> typeAndNames(Class<?> clazz) {
        List<String> types;
        List<String> extraConfigs = List.of("");
        if (NumericAggregate.class.isAssignableFrom(clazz)) {
            types = NUMERIC;
        } else if (Max.class.isAssignableFrom(clazz) || Min.class.isAssignableFrom(clazz)) {
            types = List.of("Boolean", "Int", "Long", "Double", "Ip", "BytesRef");
        } else if (clazz == Count.class) {
            types = List.of(""); // count is type-agnostic
        } else if (SpatialAggregateFunction.class.isAssignableFrom(clazz)) {
            types = SPATIAL;
            extraConfigs = List.of("SourceValues", "DocValues");
        } else if (Values.class.isAssignableFrom(clazz)) {
            types = List.of("Int", "Long", "Double", "Boolean", "BytesRef");
        } else if (Top.class.isAssignableFrom(clazz)) {
            types = List.of("Boolean", "Int", "Long", "Double", "Ip", "BytesRef");
        } else if (Rate.class.isAssignableFrom(clazz)) {
            types = List.of("Int", "Long", "Double");
        } else if (FromPartial.class.isAssignableFrom(clazz) || ToPartial.class.isAssignableFrom(clazz)) {
            types = List.of("");
        } else if (CountDistinct.class.isAssignableFrom(clazz)) {
            types = Stream.concat(NUMERIC.stream(), Stream.of("Boolean", "BytesRef")).toList();
        } else {
            assert false : "unknown aggregate type " + clazz;
            throw new IllegalArgumentException("unknown aggregate type " + clazz);
        }
        return combinations(types, extraConfigs).map(combo -> new Tuple<Class<?>, Tuple<String, String>>(clazz, combo));
    }

    /** Cartesian product of type suffixes and extra-configuration suffixes. */
    private static Stream<Tuple<String, String>> combinations(List<String> types, List<String> extraConfigs) {
        return types.stream().flatMap(type -> extraConfigs.stream().map(config -> new Tuple<>(type, config)));
    }

    /** Emits the grouping and, except for {@link Rate}, the non-grouping definition of a combination. */
    private static Stream<AggDef> groupingAndNonGrouping(Tuple<Class<?>, Tuple<String, String>> tuple) {
        Class<?> clazz = tuple.v1();
        String type = tuple.v2().v1();
        String extra = tuple.v2().v2();
        AggDef groupingDef = new AggDef(clazz, type, extra, true);
        if (Rate.class.isAssignableFrom(clazz)) {
            // Rate is only registered in its grouping form.
            return Stream.of(groupingDef);
        }
        return Stream.of(groupingDef, new AggDef(clazz, type, extra, false));
    }

    /**
     * Builds the lookup key for an aggregate function, or returns {@code null} when the expression
     * is not an {@link AggregateFunction}.
     */
    private static AggDef aggDefOrNull(Expression aggregate, boolean grouping) {
        if (aggregate instanceof AggregateFunction aggregateFunction) {
            Class<?> aggClass = aggregateFunction.getClass();
            String type = dataTypeToString(aggregateFunction.field().dataType(), aggClass);
            // Spatial aggregators are registered per extra config; the "SourceValues" variant is
            // always used for the lookup here (presumably its state matches "DocValues" — TODO confirm).
            String extra = aggregate instanceof SpatialAggregateFunction ? "SourceValues" : "";
            return new AggDef(aggClass, type, extra, grouping);
        }
        return null;
    }

    /**
     * Invokes the generated aggregator's static {@code intermediateStateDesc()} method.
     * The cast on {@code invokeExact} is required: it fixes the call-site type to {@code ()List},
     * matching the handle's type exactly.
     */
    @SuppressWarnings("unchecked")
    private static List<IntermediateStateDesc> lookupIntermediateState(AggDef aggDef) {
        try {
            return (List<IntermediateStateDesc>) lookup(aggDef.aggClazz(), aggDef.type(), aggDef.extra(), aggDef.grouping())
                .invokeExact();
        } catch (Throwable t) {
            throw new EsqlIllegalArgumentException(t);
        }
    }

    /** Resolves a handle to {@code static List intermediateStateDesc()} on the generated aggregator class. */
    private static MethodHandle lookup(Class<?> clazz, String type, String extra, boolean grouping) {
        try {
            return MethodHandles.lookup()
                .findStatic(
                    Class.forName(determineAggName(clazz, type, extra, grouping)),
                    "intermediateStateDesc",
                    MethodType.methodType(List.class)
                );
        } catch (ClassNotFoundException | IllegalAccessException | NoSuchMethodException e) {
            throw new EsqlIllegalArgumentException(e);
        }
    }

    /**
     * Builds the fully qualified name of the generated aggregator, e.g. {@code Max} + {@code "Int"}
     * + {@code ""} + grouping yields
     * {@code org.elasticsearch.compute.aggregation.MaxIntGroupingAggregatorFunction}.
     */
    private static String determineAggName(Class<?> clazz, String type, String extra, boolean grouping) {
        StringBuilder sb = new StringBuilder();
        sb.append(determinePackageName(clazz)).append(".");
        sb.append(clazz.getSimpleName());
        sb.append(type);
        sb.append(extra);
        sb.append(grouping ? "Grouping" : "");
        sb.append("AggregatorFunction");
        return sb.toString();
    }

    /** Spatial aggregators live in a dedicated sub-package. */
    private static String determinePackageName(Class<?> clazz) {
        if (clazz.getSimpleName().startsWith("Spatial")) {
            return "org.elasticsearch.compute.aggregation.spatial";
        }
        return "org.elasticsearch.compute.aggregation";
    }

    /**
     * Converts intermediate state descriptions into reference attributes. An explicit ES data type
     * on the description takes precedence over the one derived from its element type.
     */
    private static Stream<NamedExpression> isToNE(List<IntermediateStateDesc> intermediateStateDescs) {
        return intermediateStateDescs.stream().map(is -> {
            DataType dataType = Strings.isEmpty(is.dataType()) ? toDataType(is.type()) : DataType.fromEs(is.dataType());
            return new ReferenceAttribute(Source.EMPTY, is.name(), dataType);
        });
    }

    /** Maps a block element type to the ES|QL data type used for the intermediate attribute. */
    private static DataType toDataType(ElementType elementType) {
        return switch (elementType) {
            case BOOLEAN -> DataType.BOOLEAN;
            case BYTES_REF -> DataType.KEYWORD;
            case INT -> DataType.INTEGER;
            case LONG -> DataType.LONG;
            case DOUBLE -> DataType.DOUBLE;
            case FLOAT, NULL, DOC, COMPOSITE, UNKNOWN -> throw new EsqlIllegalArgumentException("unsupported agg type: " + elementType);
        };
    }

    /**
     * Maps an input data type to the type suffix used in generated aggregator class names.
     *
     * @throws EsqlIllegalArgumentException for data types no aggregator supports
     */
    private static String dataTypeToString(DataType type, Class<?> aggClass) {
        // These aggregators have a single, type-agnostic implementation.
        if (aggClass == Count.class || aggClass == ToPartial.class || aggClass == FromPartial.class) {
            return "";
        }
        // These aggregators have a dedicated IP implementation instead of the generic BytesRef one.
        if (type == DataType.IP && (aggClass == Max.class || aggClass == Min.class || aggClass == Top.class)) {
            return "Ip";
        }
        return switch (type) {
            case BOOLEAN -> "Boolean";
            case INTEGER, COUNTER_INTEGER -> "Int";
            case LONG, DATETIME, COUNTER_LONG, DATE_NANOS -> "Long";
            case DOUBLE, COUNTER_DOUBLE -> "Double";
            case KEYWORD, IP, VERSION, TEXT -> "BytesRef";
            case GEO_POINT -> "GeoPoint";
            case CARTESIAN_POINT -> "CartesianPoint";
            case SEMANTIC_TEXT, UNSUPPORTED, NULL, UNSIGNED_LONG, SHORT, BYTE, FLOAT, HALF_FLOAT, SCALED_FLOAT, OBJECT, SOURCE,
                DATE_PERIOD, TIME_DURATION, CARTESIAN_SHAPE, GEO_SHAPE, DOC_DATA_TYPE, TSID_DATA_TYPE, PARTIAL_AGG ->
                throw new EsqlIllegalArgumentException("illegal agg type: " + type.typeName());
        };
    }

    /** Key into the static intermediate-state table. */
    private record AggDef(Class<?> aggClazz, String type, String extra, boolean grouping) {}

    /** Key into the per-instance cache; includes the grouping mode so both modes never collide. */
    private record CacheKey(Expression aggregate, boolean grouping) {}
}

