/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kylin.query.optrule;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.BasicSqlType;
import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.metadata.datatype.DataType;
import org.apache.kylin.query.calcite.KylinRelDataTypeSystem;
import org.apache.kylin.query.relnode.ContextUtil;
import org.apache.kylin.query.relnode.OlapAggregateRel;
import org.apache.kylin.query.relnode.OlapProjectRel;
import org.apache.kylin.query.util.AggExpressionUtil;
import org.apache.kylin.query.util.RuleUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Planner rule that pulls a {@code CAST} out of a SUM. It matches an {@link OlapAggregateRel} directly on top of an
 * {@link OlapProjectRel} and rewrites {@code SUM(CAST(x AS T))} into
 * {@code Project[CAST(s AS T')] <- Aggregate[s = SUM(x)] <- Project[x]}. Here {@code T'} is the cast's type, or the
 * original aggregate call's type when that has a larger precision.
 *
 * <p>A project column is only rewritten when all of the following hold:
 * <ul>
 *   <li>it is the argument of a SUM-family call (per {@link AggExpressionUtil#isSum});</li>
 *   <li>it is a cast recognised by {@link RuleUtils#containCast} whose operand is numeric per Kylin
 *       {@link DataType};</li>
 *   <li>it is not read by a group key, by a filter, or by any aggregate other than SUM/COUNT. Replacing the column
 *       in place would otherwise change the type of that other output column.</li>
 * </ul>
 * Aggregates with grouping sets are never rewritten.
 *
 * <p>The replacement expressions are built with {@link RelFactories#LOGICAL_BUILDER}.
 */
public class OlapSumCastTransposeRule extends RelOptRule {
    private static final Logger logger = LoggerFactory.getLogger(OlapSumCastTransposeRule.class);
    public static final OlapSumCastTransposeRule INSTANCE = new OlapSumCastTransposeRule(
            operand(OlapAggregateRel.class,
                    operand(OlapProjectRel.class, null, OlapSumCastTransposeRule::needSumCastTranspose, any())),
            RelFactories.LOGICAL_BUILDER, "OlapSumTransCastToThenRule");

    public OlapSumCastTransposeRule(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory, String description) {
        super(operand, relBuilderFactory, description);
    }

    /**
     * Operand predicate for the project below the aggregate.
     *
     * @return true when at least one projected expression is a cast per {@link RuleUtils#containCast}; false when,
     *         inside a Hep planner, the project's input is itself an {@link OlapAggregateRel} (presumably to avoid
     *         re-firing on projects that sit between two aggregates — TODO confirm intent)
     */
    public static boolean needSumCastTranspose(Project project) {
        RelNode input = project.getInput();
        if (input instanceof HepRelVertex && ((HepRelVertex) input).getCurrentRel() instanceof OlapAggregateRel) {
            return false;
        }
        for (RexNode expr : project.getProjects()) {
            if (RuleUtils.containCast(expr)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Fires only when at least one project column can be rewritten. The same eligibility check is used by
     * {@link #onMatch}, so the result does not depend on the order of the aggregate calls.
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        Aggregate aggregate = call.rel(0);
        Project project = call.rel(1);
        return !findTransposableIndices(aggregate, project).isEmpty();
    }

    /**
     * Builds the rewritten plan and registers it with {@link RelOptRuleCall#transformTo}. Any {@link Exception} is
     * logged and the rewrite is skipped, leaving the original plan in place.
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        try {
            RelBuilder relBuilder = call.builder();
            Aggregate originalAgg = call.rel(0);
            Project originalProject = call.rel(1);
            RelNode relNode = transposeSumCast(relBuilder, originalAgg, originalProject);
            ContextUtil.dumpCalcitePlan("new plan", relNode, logger);
            call.transformTo(relNode);
        } catch (Exception e) {
            logger.error("sql cannot apply sum cast transpose rule ", e);
        }
    }

    /** Returns true when {@code expr} is a cast (per {@link RuleUtils#containCast}) whose operand is numeric. */
    private static boolean isNumericCast(RexNode expr) {
        if (!RuleUtils.containCast(expr)) {
            return false;
        }
        RexNode operand = ((RexCall) expr).getOperands().get(0);
        DataType dataType = DataType.getType(operand.getType().getSqlTypeName().getName());
        return dataType.isNumberFamily() || dataType.isIntegerFamily();
    }

    /**
     * Returns the indices of {@code project} expressions that can be replaced in place by their un-cast operand.
     * The result is empty when {@code agg} uses grouping sets.
     */
    private static Set<Integer> findTransposableIndices(Aggregate agg, Project project) {
        Set<Integer> transposable = new HashSet<>();
        if (agg.getGroupType() != Aggregate.Group.SIMPLE) {
            // The rewrite rebuilds the aggregate from getGroupSet() only; grouping sets would be silently dropped.
            return transposable;
        }
        List<RexNode> exprs = project.getProjects();
        // Columns whose type must stay unchanged: group keys, filter columns, and arguments of aggregates whose
        // result type follows their argument type. COUNT is exempt because its result type does not depend on its
        // argument.
        Set<Integer> pinned = new HashSet<>(agg.getGroupSet().asList());
        for (AggregateCall aggCall : agg.getAggCallList()) {
            if (aggCall.filterArg >= 0) {
                pinned.add(aggCall.filterArg);
            }
            SqlKind kind = aggCall.getAggregation().getKind();
            if (AggExpressionUtil.isSum(kind)) {
                if (aggCall.getArgList().size() == 1) {
                    int index = aggCall.getArgList().get(0);
                    if (isNumericCast(exprs.get(index))) {
                        transposable.add(index);
                    }
                }
            } else if (kind != SqlKind.COUNT) {
                pinned.addAll(aggCall.getArgList());
            }
        }
        transposable.removeAll(pinned);
        return transposable;
    }

    /**
     * Maps every transposable project index to the SUM result type of its un-cast operand. The type is derived
     * with {@link KylinRelDataTypeSystem#deriveSumType}.
     */
    private static Map<Integer, RelDataType> deriveTransposedSumTypes(Aggregate agg, Project project) {
        KylinRelDataTypeSystem typeSystem = new KylinRelDataTypeSystem();
        SqlTypeFactoryImpl typeFactory = new SqlTypeFactoryImpl(typeSystem);
        Map<Integer, RelDataType> sumTypes = new HashMap<>();
        for (int index : findTransposableIndices(agg, project)) {
            RelDataType operandType = ((RexCall) project.getProjects().get(index)).getOperands().get(0).getType();
            sumTypes.put(index, typeSystem.deriveSumType(typeFactory, operandType));
        }
        return sumTypes;
    }

    /**
     * Returns the rewritten SUM type for {@code aggCall}, or {@code null} when the call is not rewritten. A call is
     * rewritten only if it is a SUM over a transposable column.
     */
    private static RelDataType transposedSumType(AggregateCall aggCall, Map<Integer, RelDataType> sumTypes) {
        if (!AggExpressionUtil.isSum(aggCall.getAggregation().getKind()) || aggCall.getArgList().size() != 1) {
            return null;
        }
        return sumTypes.get(aggCall.getArgList().get(0));
    }

    /** Builds Project(uncast) -> Aggregate -> Project(cast back) on top of the original project's input. */
    private RelNode transposeSumCast(RelBuilder relBuilder, Aggregate oldAgg, Project oldProject) {
        Map<Integer, RelDataType> sumTypes = deriveTransposedSumTypes(oldAgg, oldProject);
        relBuilder.push(oldProject.getInput());
        relBuilder.project(buildBottomProject(oldProject, sumTypes));
        RelBuilder.GroupKey groupKey = relBuilder.groupKey(oldAgg.getGroupSet());
        relBuilder.aggregate(groupKey, buildBottomAggregate(relBuilder, oldAgg, sumTypes));
        relBuilder.project(buildTopProject(relBuilder, oldProject, oldAgg, sumTypes));
        return relBuilder.build();
    }

    /** Copies the old projection and replaces each transposable cast with its operand, keeping column positions. */
    private static List<RexNode> buildBottomProject(Project oldProject, Map<Integer, RelDataType> sumTypes) {
        List<RexNode> bottomProjects = new ArrayList<>(oldProject.getProjects());
        for (int index : sumTypes.keySet()) {
            bottomProjects.set(index, ((RexCall) bottomProjects.get(index)).getOperands().get(0));
        }
        return bottomProjects;
    }

    /**
     * Re-creates each rewritten SUM with its derived type. Every other attribute (function, DISTINCT, filter,
     * collation, name) is copied from the original call. All other aggregate calls are passed through unchanged.
     */
    private static List<AggregateCall> buildBottomAggregate(RelBuilder relBuilder, Aggregate oldAgg,
            Map<Integer, RelDataType> sumTypes) {
        RelNode bottomInput = relBuilder.peek();
        int groupCount = oldAgg.getGroupCount();
        List<AggregateCall> bottomAggCalls = new ArrayList<>();
        for (AggregateCall aggCall : oldAgg.getAggCallList()) {
            RelDataType sumType = transposedSumType(aggCall, sumTypes);
            if (sumType == null) {
                bottomAggCalls.add(aggCall);
                continue;
            }
            bottomAggCalls.add(AggregateCall.create(aggCall.getAggregation(), aggCall.isDistinct(),
                    aggCall.isApproximate(), aggCall.ignoreNulls(), aggCall.getArgList(), aggCall.filterArg,
                    aggCall.distinctKeys, aggCall.getCollation(), groupCount, bottomInput, sumType,
                    aggCall.getName()));
        }
        return bottomAggCalls;
    }

    /**
     * Projects the group keys unchanged, then one expression per original aggregate call:
     * <ul>
     *   <li>a rewritten SUM is cast back to the cast's type, or to the call's type when that has a larger
     *       precision;</li>
     *   <li>any other SUM whose argument is a null literal becomes a BIGINT literal {@code 0};</li>
     *   <li>every other call is passed through.</li>
     * </ul>
     */
    private static List<RexNode> buildTopProject(RelBuilder relBuilder, Project oldProject, Aggregate oldAgg,
            Map<Integer, RelDataType> sumTypes) {
        RexBuilder rexBuilder = relBuilder.getRexBuilder();
        RelNode bottomAgg = relBuilder.peek();
        int groupCount = oldAgg.getGroupCount();
        List<RexNode> topProjects = new ArrayList<>();
        for (int i = 0; i < groupCount; i++) {
            topProjects.add(rexBuilder.makeInputRef(bottomAgg, i));
        }
        List<AggregateCall> aggCalls = oldAgg.getAggCallList();
        for (int j = 0; j < aggCalls.size(); j++) {
            AggregateCall aggCall = aggCalls.get(j);
            RexInputRef aggRef = rexBuilder.makeInputRef(bottomAgg, groupCount + j);
            if (!AggExpressionUtil.isSum(aggCall.getAggregation().getKind())) {
                topProjects.add(aggRef);
                continue;
            }
            RexNode value = oldProject.getProjects().get(aggCall.getArgList().get(0));
            if (transposedSumType(aggCall, sumTypes) != null) {
                RelDataType type = value.getType();
                if (type instanceof BasicSqlType && type.getPrecision() < aggCall.getType().getPrecision()) {
                    type = aggCall.getType();
                }
                topProjects.add(rexBuilder.makeCast(type, aggRef));
            } else if (RuleUtils.isNotNullLiteral(value)) {
                topProjects.add(aggRef);
            } else {
                topProjects.add(rexBuilder.makeBigintLiteral(BigDecimal.ZERO));
            }
        }
        return topProjects;
    }
}

