/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.physical.local;

import java.lang.runtime.SwitchBootstraps;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialExtent;
import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerRules;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.EvalExec;
import org.elasticsearch.xpack.esql.plan.physical.FieldExtractExec;
import org.elasticsearch.xpack.esql.plan.physical.FilterExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;

/**
 * Local physical optimizer rule that looks for {@link SpatialExtent} aggregations over shape fields
 * ({@code GEO_SHAPE} / {@code CARTESIAN_SHAPE}) and, for the qualifying ones:
 * <ul>
 *   <li>switches the {@link SpatialExtent} to
 *       {@link MappedFieldType.FieldExtractPreference#EXTRACT_SPATIAL_BOUNDS}, and</li>
 *   <li>registers the corresponding attributes as bounds attributes on every {@link FieldExtractExec}
 *       that extracts them.</li>
 * </ul>
 * A field qualifies only if it has doc values, is not referenced by any other (non-{@link SpatialExtent})
 * aggregate function of the same aggregation, and is not referenced by an {@link EvalExec} or the condition
 * of a {@link FilterExec} visited during the same traversal.
 */
public class SpatialShapeBoundsExtraction
    extends PhysicalOptimizerRules.ParameterizedOptimizerRule<AggregateExec, LocalPhysicalOptimizerContext> {

    /**
     * Applies the rule to {@code aggregate}.
     *
     * @param aggregate the aggregation node the rule was matched on
     * @param ctx       the local optimizer context; its search stats are consulted for doc-values availability
     * @return {@code aggregate} itself when no attribute qualifies, otherwise the transformed plan
     */
    @Override
    protected PhysicalPlan rule(AggregateExec aggregate, LocalPhysicalOptimizerContext ctx) {
        Set<Attribute> foundAttributes = findSpatialShapeBoundsAttributes(aggregate, ctx);
        if (foundAttributes.isEmpty()) {
            return aggregate;
        }
        // A pattern switch throws NullPointerException on a null selector, matching the null check
        // the compiler generated for the original type switch.
        return aggregate.transformDown(PhysicalPlan.class, exec -> switch (exec) {
            case AggregateExec agg -> transformAggregateExec(agg, foundAttributes);
            case FieldExtractExec fieldExtractExec -> transformFieldExtractExec(fieldExtractExec, foundAttributes);
            default -> exec;
        });
    }

    /**
     * Collects the shape attributes for which bounds extraction can be used.
     * <p>
     * The plan is traversed via {@code transformDown} purely for its visiting side effect; every node is
     * returned unchanged. Attributes are added when an {@link AggregateExec} is visited and removed again
     * when an {@link EvalExec} or {@link FilterExec} that references them is visited.
     *
     * @param aggregate the root of the traversal
     * @param ctx       the local optimizer context used to check for doc values
     * @return a mutable set of the qualifying attributes, possibly empty
     */
    private static Set<Attribute> findSpatialShapeBoundsAttributes(AggregateExec aggregate, LocalPhysicalOptimizerContext ctx) {
        Set<Attribute> foundAttributes = new HashSet<>();
        aggregate.transformDown(UnaryExec.class, exec -> {
            switch (exec) {
                case AggregateExec agg -> {
                    List<AggregateFunction> aggregateFunctions = agg.aggregates()
                        .stream()
                        .flatMap(e -> extractAggregateFunction(e).stream())
                        .toList();
                    List<SpatialExtent> spatialExtents = aggregateFunctions.stream()
                        .filter(SpatialExtent.class::isInstance)
                        .map(SpatialExtent.class::cast)
                        .toList();
                    List<AggregateFunction> nonSpatialExtents = aggregateFunctions.stream()
                        .filter(a -> !(a instanceof SpatialExtent))
                        .toList();
                    // Fields that some other aggregate function also reads are excluded below.
                    // NOTE(review): presumably because those functions need the full field value, so
                    // bounds-only extraction would not be sufficient (or would duplicate extraction) - confirm.
                    // Declared as Set<?> because the element type is not imported in this file; only
                    // contains(Object) is needed.
                    Set<?> fieldsAppearingInNonSpatialExtents = nonSpatialExtents.stream()
                        .flatMap(af -> af.references().stream())
                        .filter(FieldAttribute.class::isInstance)
                        .map(f -> ((FieldAttribute) f).field())
                        .collect(Collectors.toSet());
                    spatialExtents.stream()
                        .map(AggregateFunction::field)
                        .filter(FieldAttribute.class::isInstance)
                        .map(FieldAttribute.class::cast)
                        .filter(
                            f -> isShape(f.field().getDataType())
                                && !fieldsAppearingInNonSpatialExtents.contains(f.field())
                                && ctx.searchStats().hasDocValues(f.fieldName())
                        )
                        .forEach(foundAttributes::add);
                }
                // Attributes referenced by an EVAL or a filter condition are dropped from the candidates.
                case EvalExec evalExec -> foundAttributes.removeAll(evalExec.references());
                case FilterExec filterExec -> foundAttributes.removeAll(filterExec.condition().references());
                default -> {
                    // Other unary nodes do not affect the candidate set.
                }
            }
            return exec;
        });
        return foundAttributes;
    }

    /**
     * Registers on {@code fieldExtractExec} those found attributes that it actually extracts.
     *
     * @param fieldExtractExec the extraction node to update
     * @param foundAttributes  the attributes selected for bounds extraction; not modified
     * @return the result of {@code withBoundsAttributes} for the intersection of both attribute collections
     */
    private static PhysicalPlan transformFieldExtractExec(FieldExtractExec fieldExtractExec, Set<Attribute> foundAttributes) {
        // Work on a copy so the shared set captured by the caller's lambda stays intact.
        Set<Attribute> boundsAttributes = new HashSet<>(foundAttributes);
        boundsAttributes.retainAll(fieldExtractExec.attributesToExtract());
        return fieldExtractExec.withBoundsAttributes(boundsAttributes);
    }

    /**
     * Rewrites every {@link SpatialExtent} in {@code agg} whose field is one of {@code foundAttributes} to use
     * {@link MappedFieldType.FieldExtractPreference#EXTRACT_SPATIAL_BOUNDS}; all others are left untouched.
     *
     * @param agg             the aggregation whose expressions are transformed
     * @param foundAttributes the attributes selected for bounds extraction
     * @return the transformed plan node
     */
    private static PhysicalPlan transformAggregateExec(AggregateExec agg, Set<Attribute> foundAttributes) {
        return agg.transformExpressionsDown(
            SpatialExtent.class,
            spatialExtent -> foundAttributes.contains(spatialExtent.field())
                ? spatialExtent.withFieldExtractPreference(MappedFieldType.FieldExtractPreference.EXTRACT_SPATIAL_BOUNDS)
                : spatialExtent
        );
    }

    /**
     * @param dataType the data type to test
     * @return {@code true} if {@code dataType} is {@code GEO_SHAPE} or {@code CARTESIAN_SHAPE}
     */
    private static boolean isShape(DataType dataType) {
        return dataType == DataType.GEO_SHAPE || dataType == DataType.CARTESIAN_SHAPE;
    }

    /**
     * Unwraps an aliased aggregate function.
     *
     * @param expr the named expression to inspect
     * @return the {@link AggregateFunction} when {@code expr} is an {@link Alias} whose child is one,
     *         otherwise an empty {@link Optional}
     */
    private static Optional<AggregateFunction> extractAggregateFunction(NamedExpression expr) {
        if (expr instanceof Alias as && as.child() instanceof AggregateFunction af) {
            return Optional.of(af);
        }
        return Optional.empty();
    }
}

