/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.scalar.ip.CIDRMatch;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter;

/**
 * Optimizer rule that merges the disjuncts of an {@link Or} which test the same expression.
 *
 * <ul>
 *   <li>{@code ==} against a foldable value and {@code IN} on the same left-hand expression are
 *       combined into a single {@link In} (or a single {@link Equals} when only one value remains).
 *   <li>{@link CIDRMatch} calls on the same IP expression are combined into one {@link CIDRMatch}
 *       holding all patterns.
 *   <li>When at least one {@link CIDRMatch} appears anywhere in the disjunction, every IP-typed
 *       {@code ==}/{@code IN} whose values are all foldable is rewritten into the
 *       {@link CIDRMatch} for its expression instead of an {@link In}.
 * </ul>
 *
 * <p>All other disjuncts are kept unchanged.
 */
public final class CombineDisjunctions
extends OptimizerRules.OptimizerExpressionRule<Or> {
    /** Registers the rule with {@code TransformDirection.UP}. */
    public CombineDisjunctions() {
        super(OptimizerRules.TransformDirection.UP);
    }

    /**
     * Builds {@code key IN (values)}.
     *
     * @param key    the tested expression; its source is reused for the result
     * @param values the list items
     * @param zoneId not used by this implementation (kept for signature compatibility)
     * @return the new {@link In}
     */
    protected static In createIn(Expression key, List<Expression> values, ZoneId zoneId) {
        return new In(key.source(), key, values);
    }

    /**
     * Builds {@code k == <first element of v>}; callers pass a single-element set.
     *
     * @param k           the tested expression; its source is reused for the result
     * @param v           the value set; only its first element is used
     * @param finalZoneId zone id handed to the {@link Equals} constructor
     * @return the new {@link Equals}
     */
    protected static Equals createEquals(Expression k, Set<Expression> v, ZoneId finalZoneId) {
        return new Equals(k.source(), k, v.iterator().next(), finalZoneId);
    }

    /**
     * Builds {@code CIDR_MATCH(k, v...)}.
     *
     * @param k the IP expression; its source is reused for the result
     * @param v the match patterns
     * @return the new {@link CIDRMatch}
     */
    protected static CIDRMatch createCIDRMatch(Expression k, List<Expression> v) {
        return new CIDRMatch(k.source(), k, v);
    }

    /**
     * Rewrites {@code or} by merging its equality, {@code IN} and {@code CIDR_MATCH} disjuncts per
     * tested expression.
     *
     * @param or the disjunction to simplify
     * @return the rewritten expression, or {@code or} itself when the rewrite is semantically equal
     */
    @Override
    public Expression rule(Or or) {
        // Typed as Expression: the replacement produced by combineOr is not necessarily an Or.
        Expression e = or;
        List<Expression> exps = Predicates.splitOr(e);

        // Tested expression -> values from foldable == tests and from IN lists.
        Map<Expression, Set<Expression>> ins = new LinkedHashMap<>();
        // IP expression -> patterns of every CIDR_MATCH on it.
        Map<Expression, Set<Expression>> cidrs = new LinkedHashMap<>();
        // IP-typed tested expression -> its ==/IN values converted into CIDR_MATCH-style literals.
        Map<Expression, Set<Expression>> ips = new LinkedHashMap<>();
        // IP-typed expressions tested by an IN containing a non-foldable item. Their values cannot
        // be turned into literals, so none of their ==/IN tests may be moved into a CIDR_MATCH.
        Set<Expression> unconvertibleIps = new LinkedHashSet<>();
        ZoneId zoneId = null;
        List<Expression> ors = new ArrayList<>();
        boolean changed = false;

        for (Expression exp : exps) {
            if (exp instanceof Equals) {
                Equals eq = (Equals) exp;
                // Only equalities against a foldable value can be merged.
                if (eq.right().foldable()) {
                    ins.computeIfAbsent(eq.left(), k -> new LinkedHashSet<>()).add(eq.right());
                    if (eq.left().dataType() == DataType.IP) {
                        ips.computeIfAbsent(eq.left(), k -> new LinkedHashSet<>()).add(toCidrLiteral(eq.right()));
                    }
                } else {
                    ors.add(exp);
                }
                // The first Equals seen supplies the zone id for every rebuilt Equals.
                if (zoneId == null) {
                    zoneId = eq.zoneId();
                }
            } else if (exp instanceof In) {
                In in = (In) exp;
                ins.computeIfAbsent(in.value(), k -> new LinkedHashSet<>()).addAll(in.list());
                if (in.value().dataType() == DataType.IP) {
                    // Mirror the Equals branch: never fold a non-foldable expression.
                    if (in.list().stream().allMatch(Expression::foldable)) {
                        List<Expression> values = new ArrayList<>(in.list().size());
                        for (Expression item : in.list()) {
                            values.add(toCidrLiteral(item));
                        }
                        ips.computeIfAbsent(in.value(), k -> new LinkedHashSet<>()).addAll(values);
                    } else {
                        unconvertibleIps.add(in.value());
                    }
                }
            } else if (exp instanceof CIDRMatch) {
                CIDRMatch cm = (CIDRMatch) exp;
                cidrs.computeIfAbsent(cm.ipField(), k -> new LinkedHashSet<>()).addAll(cm.matches());
            } else {
                ors.add(exp);
            }
        }

        // With a CIDR_MATCH present, IP equality/IN values become extra CIDR_MATCH patterns.
        // Their entries are then dropped from `ins`, so each value is tested exactly once.
        if (!cidrs.isEmpty()) {
            for (Map.Entry<Expression, Set<Expression>> entry : ips.entrySet()) {
                Expression field = entry.getKey();
                if (unconvertibleIps.contains(field)) {
                    // Keep this field's tests in `ins`; they are rebuilt as an IN/Equals below.
                    continue;
                }
                cidrs.computeIfAbsent(field, k -> new LinkedHashSet<>()).addAll(entry.getValue());
                ins.remove(field);
            }
        }

        if (!ins.isEmpty()) {
            final ZoneId finalZoneId = zoneId;
            ins.forEach((key, values) -> {
                Expression combined = values.size() == 1
                    ? createEquals(key, values, finalZoneId)
                    : createIn(key, new ArrayList<>(values), finalZoneId);
                ors.add(combined);
            });
            changed = true;
        }

        if (!cidrs.isEmpty()) {
            cidrs.forEach((key, patterns) -> ors.add(createCIDRMatch(key, new ArrayList<>(patterns))));
            changed = true;
        }

        if (changed) {
            Expression combined = Predicates.combineOr(ors);
            // Replace only on a semantic difference. The rebuilt tree may merely reorder the same
            // disjuncts; presumably returning it anyway would stop the rule reaching a fixed point.
            if (!e.semanticEquals(combined)) {
                e = combined;
            }
        }
        return e;
    }

    /**
     * Folds {@code foldable} and wraps the result in an IP-typed literal usable as a CIDR_MATCH
     * pattern. A {@link BytesRef} value is rendered with
     * {@link EsqlDataTypeConverter#ipToString}; any other value is kept as-is.
     *
     * @param foldable an expression whose {@code foldable()} is true
     * @return the literal built from the folded value
     */
    private static Literal toCidrLiteral(Expression foldable) {
        Object value = foldable.fold();
        if (value instanceof BytesRef) {
            value = EsqlDataTypeConverter.ipToString((BytesRef) value);
        }
        return new Literal(Source.EMPTY, value, DataType.IP);
    }
}

