/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.optimizer.rules.logical;

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.function.BiFunction;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
import org.elasticsearch.xpack.esql.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.expression.predicate.logical.And;
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull;
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNull;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;

/**
 * Optimizer rule that looks at the {@code IS NULL} / {@code IS NOT NULL} terms of an
 * {@link And} and uses the expressions they test to rewrite the other terms of the
 * same conjunction.
 */
public class PropagateNullable extends OptimizerRules.OptimizerExpressionRule<And> {

    /** Registers the rule with {@code TransformDirection.DOWN}. */
    public PropagateNullable() {
        super(OptimizerRules.TransformDirection.DOWN);
    }

    /**
     * Rewrites a conjunction based on the null checks it contains.
     *
     * @param and the conjunction to inspect
     * @param ctx optimizer context (not used by this rule)
     * @return a {@code FALSE} literal when the same expression is tested by both
     *         {@code IS NULL} and {@code IS NOT NULL}; a re-combined conjunction when at
     *         least one term was rewritten; otherwise the original {@code and} instance
     */
    @Override
    public Expression rule(And and, LogicalOptimizerContext ctx) {
        List<Expression> conjuncts = Predicates.splitAnd(and);

        // Operands of IS NULL / IS NOT NULL, kept in encounter order.
        LinkedHashSet<Expression> knownNull = new LinkedHashSet<>();
        LinkedHashSet<Expression> knownNotNull = new LinkedHashSet<>();
        // Every term that is neither IS NULL nor IS NOT NULL.
        List<Expression> remaining = new LinkedList<>();

        for (Expression conjunct : conjuncts) {
            if (conjunct instanceof IsNull) {
                knownNull.add(((IsNull) conjunct).field());
            } else if (conjunct instanceof IsNotNull) {
                knownNotNull.add(((IsNotNull) conjunct).field());
            } else {
                remaining.add(conjunct);
            }
        }

        // x IS NULL AND x IS NOT NULL can never hold.
        if (Sets.haveNonEmptyIntersection(knownNull, knownNotNull)) {
            return Literal.of(and, Boolean.FALSE);
        }

        // Both passes always run: first the null operands, then the not-null operands.
        boolean changed = replace(knownNull, remaining, conjuncts, this::nullify);
        changed |= replace(knownNotNull, remaining, conjuncts, this::nonNullify);

        return changed ? Predicates.combineAnd(conjuncts) : and;
    }

    /**
     * Applies {@code replacer} to every entry of {@code target} that contains (per
     * {@code anyMatch} with {@code semanticEquals}) one of the {@code pattern} expressions.
     * When the replacer returns a different instance, the entry is overwritten in
     * {@code target} and every element of {@code originalExpressions} that is semantically
     * equal to the old entry is swapped for the replacement.
     *
     * @param pattern             expressions to look for inside the target entries
     * @param target              candidate expressions; updated in place
     * @param originalExpressions full list of terms; updated in place alongside {@code target}
     * @param replacer            called as {@code (targetEntry, patternExpression)}; returning
     *                            the same instance signals "no change"
     * @return {@code true} if at least one entry was replaced
     */
    private static boolean replace(
        Iterable<Expression> pattern,
        List<Expression> target,
        List<Expression> originalExpressions,
        BiFunction<Expression, Expression, Expression> replacer
    ) {
        boolean changed = false;
        for (Expression known : pattern) {
            for (int idx = 0; idx < target.size(); idx++) {
                Expression current = target.get(idx);
                if (current.anyMatch(node -> known.semanticEquals(node))) {
                    Expression substitute = replacer.apply(current, known);
                    // Identity comparison: the replacer hands back the same object when it has nothing to do.
                    if (substitute != current) {
                        changed = true;
                        target.set(idx, substitute);
                        originalExpressions.replaceAll(original -> current.semanticEquals(original) ? substitute : original);
                    }
                }
            }
        }
        return changed;
    }

    /**
     * Rewrites {@code exp} given that {@code nullExp} is null.
     * <p>
     * For a {@link Coalesce}, the children semantically equal to {@code nullExp} are dropped,
     * provided something was dropped and at least one child is left. In every other case the
     * result of {@code Literal.of(exp, null)} is returned.
     *
     * @param exp     expression that references {@code nullExp}
     * @param nullExp expression tested by an {@code IS NULL} term
     * @return the rewritten expression
     */
    protected Expression nullify(Expression exp, Expression nullExp) {
        if (exp instanceof Coalesce) {
            List<Expression> keptChildren = new ArrayList<>(exp.children());
            boolean droppedAny = keptChildren.removeIf(child -> child.semanticEquals(nullExp));
            if (droppedAny && !keptChildren.isEmpty()) {
                return exp.replaceChildren(keptChildren);
            }
        }
        return Literal.of(exp, null);
    }

    /**
     * Hook for rewriting {@code exp} given that {@code nonNullExp} is not null.
     * This implementation performs no rewrite.
     *
     * @param exp        expression that references {@code nonNullExp}
     * @param nonNullExp expression tested by an {@code IS NOT NULL} term
     * @return {@code exp}, unchanged
     */
    protected Expression nonNullify(Expression exp, Expression nonNullExp) {
        return exp;
    }
}

