/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.physical.stream;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.convert.ConverterRule;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.flink.calcite.shaded.org.checkerframework.checker.nullness.qual.Nullable;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.plan.nodes.FlinkConventions;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalMultiJoin;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalMultiJoin;
import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution;
import org.apache.flink.table.runtime.operators.join.stream.keyselector.AttributeBasedJoinKeyExtractor;
import org.apache.flink.table.runtime.operators.join.stream.keyselector.JoinKeyExtractor;

public class StreamPhysicalMultiJoinRule
extends ConverterRule {
    public static final RelOptRule INSTANCE = new StreamPhysicalMultiJoinRule();

    private StreamPhysicalMultiJoinRule() {
        super(ConverterRule.Config.INSTANCE.withConversion(FlinkLogicalMultiJoin.class, FlinkConventions.LOGICAL(), FlinkConventions.STREAM_PHYSICAL(), "StreamPhysicalMultiJoinRule"));
    }

    @Override
    public RelNode convert(RelNode rel) {
        FlinkLogicalMultiJoin multiJoin = (FlinkLogicalMultiJoin)rel;
        Map<Integer, List<AttributeBasedJoinKeyExtractor.ConditionAttributeRef>> joinAttributeMap = this.createJoinAttributeMap(multiJoin);
        List inputRowTypes = multiJoin.getInputs().stream().map(i -> FlinkTypeFactory.toLogicalRowType(i.getRowType())).collect(Collectors.toList());
        AttributeBasedJoinKeyExtractor keyExtractor = new AttributeBasedJoinKeyExtractor(joinAttributeMap, inputRowTypes);
        List<RelNode> newInputs = this.createHashDistributedInputs(multiJoin.getInputs(), (JoinKeyExtractor)keyExtractor);
        RelTraitSet traitSet = rel.getTraitSet().replace(FlinkConventions.STREAM_PHYSICAL());
        return new StreamPhysicalMultiJoin(multiJoin.getCluster(), traitSet, newInputs, multiJoin.getJoinFilter(), multiJoin.getRowType(), multiJoin.getJoinConditions(), multiJoin.getJoinTypes(), joinAttributeMap, multiJoin.getPostJoinFilter(), multiJoin.getHints(), (JoinKeyExtractor)keyExtractor);
    }

    private List<RelNode> createHashDistributedInputs(List<RelNode> inputs, JoinKeyExtractor keyExtractor) {
        ArrayList<RelNode> newInputs = new ArrayList<RelNode>();
        for (int i = 0; i < inputs.size(); ++i) {
            RelNode input = inputs.get(i);
            RelTraitSet inputTraitSet = this.createInputTraitSet(input, keyExtractor, i);
            RelNode convertedInput = RelOptRule.convert(input, inputTraitSet.simplify());
            newInputs.add(convertedInput);
        }
        return newInputs;
    }

    private RelTraitSet createInputTraitSet(RelNode input, JoinKeyExtractor keyExtractor, int inputIndex) {
        int[] commonJoinKeyIndices = keyExtractor.getCommonJoinKeyIndices(inputIndex);
        RelTraitSet inputTraitSet = input.getTraitSet().replace(FlinkConventions.STREAM_PHYSICAL());
        if (commonJoinKeyIndices.length > 0) {
            FlinkRelDistribution hashDistribution = FlinkRelDistribution.hash(commonJoinKeyIndices, true);
            inputTraitSet = inputTraitSet.replace(hashDistribution);
        } else {
            inputTraitSet = inputTraitSet.replace(FlinkRelDistribution.SINGLETON());
        }
        return inputTraitSet;
    }

    private Map<Integer, List<AttributeBasedJoinKeyExtractor.ConditionAttributeRef>> createJoinAttributeMap(FlinkLogicalMultiJoin multiJoin) {
        HashMap<Integer, List<AttributeBasedJoinKeyExtractor.ConditionAttributeRef>> joinAttributeMap = new HashMap<Integer, List<AttributeBasedJoinKeyExtractor.ConditionAttributeRef>>();
        List<Integer> inputFieldCounts = multiJoin.getInputs().stream().map(input -> input.getRowType().getFieldCount()).collect(Collectors.toList());
        ArrayList<Integer> inputOffsets = new ArrayList<Integer>();
        int currentOffset = 0;
        for (Integer count : inputFieldCounts) {
            inputOffsets.add(currentOffset);
            currentOffset += count.intValue();
        }
        List<? extends RexNode> joinConditions = multiJoin.getJoinConditions();
        for (RexNode rexNode : joinConditions) {
            this.extractEqualityConditions(rexNode, inputOffsets, inputFieldCounts, joinAttributeMap);
        }
        return joinAttributeMap;
    }

    private void extractEqualityConditions(RexNode condition, List<Integer> inputOffsets, List<Integer> inputFieldCounts, Map<Integer, List<AttributeBasedJoinKeyExtractor.ConditionAttributeRef>> joinAttributeMap) {
        InputRef rightRef;
        InputRef leftRef;
        if (!(condition instanceof RexCall)) {
            return;
        }
        RexCall call = (RexCall)condition;
        SqlKind kind = call.getOperator().getKind();
        if (kind != SqlKind.EQUALS) {
            for (RexNode operand : call.getOperands()) {
                this.extractEqualityConditions(operand, inputOffsets, inputFieldCounts, joinAttributeMap);
            }
            return;
        }
        if (call.getOperands().size() != 2) {
            return;
        }
        RexNode op1 = call.getOperands().get(0);
        RexNode op2 = call.getOperands().get(1);
        if (!(op1 instanceof RexInputRef) || !(op2 instanceof RexInputRef)) {
            return;
        }
        InputRef inputRef1 = this.findInputRef(((RexInputRef)op1).getIndex(), inputOffsets, inputFieldCounts);
        InputRef inputRef2 = this.findInputRef(((RexInputRef)op2).getIndex(), inputOffsets, inputFieldCounts);
        if (inputRef1 == null || inputRef2 == null) {
            return;
        }
        if (inputRef1.inputIndex < inputRef2.inputIndex) {
            leftRef = inputRef1;
            rightRef = inputRef2;
        } else {
            leftRef = inputRef2;
            rightRef = inputRef1;
        }
        AttributeBasedJoinKeyExtractor.ConditionAttributeRef attrRef = new AttributeBasedJoinKeyExtractor.ConditionAttributeRef(leftRef.inputIndex, leftRef.attributeIndex, rightRef.inputIndex, rightRef.attributeIndex);
        joinAttributeMap.computeIfAbsent(rightRef.inputIndex, k -> new ArrayList()).add(attrRef);
    }

    private @Nullable InputRef findInputRef(int fieldIndex, List<Integer> inputOffsets, List<Integer> inputFieldCounts) {
        for (int i = 0; i < inputOffsets.size(); ++i) {
            int offset = inputOffsets.get(i);
            if (fieldIndex < offset || fieldIndex >= offset + inputFieldCounts.get(i)) continue;
            return new InputRef(i, fieldIndex - offset);
        }
        return null;
    }

    private static final class InputRef {
        private final int inputIndex;
        private final int attributeIndex;

        private InputRef(int inputIndex, int attributeIndex) {
            this.inputIndex = inputIndex;
            this.attributeIndex = attributeIndex;
        }
    }
}

