/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tez.runtime.library.cartesianproduct;

import com.google.common.primitives.Ints;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.tez.common.ReflectionUtils;
import org.apache.tez.dag.api.TezReflectionException;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.api.VertexManagerPluginContext;
import org.apache.tez.dag.api.event.VertexState;
import org.apache.tez.dag.api.event.VertexStateUpdate;
import org.apache.tez.runtime.api.TaskAttemptIdentifier;
import org.apache.tez.runtime.api.events.VertexManagerEvent;
import org.apache.tez.runtime.library.cartesianproduct.CartesianProductCombination;
import org.apache.tez.runtime.library.cartesianproduct.CartesianProductFilter;
import org.apache.tez.runtime.library.cartesianproduct.CartesianProductUserPayload;
import org.apache.tez.runtime.library.cartesianproduct.CartesianProductVertexManagerReal;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Vertex manager for a cartesian product vertex whose cartesian-product sources are partitioned.
 *
 * <p>The vertex parallelism is the number of partition combinations (one partition per
 * cartesian-product source, enumerated via {@link CartesianProductCombination} over the configured
 * per-source partition counts) that the optional {@link CartesianProductFilter} accepts. It is
 * computed once, when the first cartesian-product source reaches {@link VertexState#CONFIGURED}.
 *
 * <p>Destination tasks are scheduled in increasing task-id order, "slow start" style, based on the
 * fraction of finished cartesian-product source tasks: nothing below {@code minFraction},
 * everything above {@code maxFraction} (or once all source tasks have finished), and a linearly
 * growing number of tasks in between.
 *
 * <p>Nothing is scheduled until the vertex has started, every cartesian-product source has reached
 * {@link VertexState#CONFIGURED}, and every other input vertex (counted as a broadcast source) has
 * reached {@link VertexState#RUNNING}.
 */
class CartesianProductVertexManagerPartitioned
extends CartesianProductVertexManagerReal {
    private static final Logger LOG = LoggerFactory.getLogger(CartesianProductVertexManagerPartitioned.class);

    /** Slow-start lower bound used when the config does not set one. */
    private static final float DEFAULT_MIN_FRACTION = 0.25f;
    /** Slow-start upper bound used when the config does not set one. */
    private static final float DEFAULT_MAX_FRACTION = 0.75f;

    /** Names of the cartesian-product source vertices, in config order. */
    private List<String> sourceVertices;
    /** Number of partitions of each source, index-aligned with {@link #sourceVertices}. */
    private int[] numPartitions;
    private float minFraction;
    private float maxFraction;
    /** Number of accepted partition combinations; set by {@link #reconfigureVertex()}. */
    private int parallelism = 0;
    private boolean vertexStarted = false;
    private boolean vertexReconfigured = false;
    private int numCPSrcNotInConfiguredState = 0;
    private int numBroadcastSrcNotInRunningState = 0;
    /** Optional combination filter; {@code null} accepts every combination. */
    private CartesianProductFilter filter;
    /** Per cartesian-product source: ids of tasks already counted as finished. */
    private final Map<String, BitSet> sourceTaskCompleted = new HashMap<String, BitSet>();
    /** Distinct finished tasks across all cartesian-product sources. */
    private int numFinishedSrcTasks = 0;
    /** Sum of task counts of cartesian-product sources that have reached CONFIGURED. */
    private int totalNumSrcTasks = 0;
    /** Highest destination task id scheduled so far, or -1 if none. */
    private int lastScheduledTaskId = -1;

    public CartesianProductVertexManagerPartitioned(VertexManagerPluginContext context) {
        super(context);
    }

    /**
     * Reads the configuration, instantiates the optional filter and registers for state updates of
     * every input vertex (CONFIGURED for cartesian-product sources, RUNNING for the others).
     *
     * @param config cartesian product configuration
     * @throws TezReflectionException if the configured filter class cannot be instantiated
     */
    @Override
    public void initialize(CartesianProductUserPayload.CartesianProductConfigProto config) throws TezReflectionException {
        this.sourceVertices = config.getSourcesList();
        this.numPartitions = Ints.toArray(config.getNumPartitionsList());
        this.minFraction = config.hasMinFraction() ? config.getMinFraction() : DEFAULT_MIN_FRACTION;
        this.maxFraction = config.hasMaxFraction() ? config.getMaxFraction() : DEFAULT_MAX_FRACTION;

        if (config.hasFilterClassName()) {
            UserPayload userPayload = config.hasFilterUserPayload()
                ? UserPayload.create(ByteBuffer.wrap(config.getFilterUserPayload().toByteArray()))
                : null;
            try {
                this.filter = (CartesianProductFilter) ReflectionUtils.createClazzInstance(
                    config.getFilterClassName(),
                    new Class<?>[] {UserPayload.class},
                    new Object[] {userPayload});
            } catch (TezReflectionException e) {
                LOG.error("Creating filter failed");
                throw e;
            }
        }

        for (String sourceVertex : this.sourceVertices) {
            this.sourceTaskCompleted.put(sourceVertex, new BitSet());
        }
        for (String vertex : this.getContext().getInputVertexEdgeProperties().keySet()) {
            if (this.sourceVertices.contains(vertex)) {
                this.getContext().registerForVertexStateUpdates(vertex, EnumSet.of(VertexState.CONFIGURED));
                ++this.numCPSrcNotInConfiguredState;
            } else {
                this.getContext().registerForVertexStateUpdates(vertex, EnumSet.of(VertexState.RUNNING));
                ++this.numBroadcastSrcNotInRunningState;
            }
        }
        this.getContext().vertexReconfigurationPlanned();
    }

    /** Vertex manager events are not used by this manager. */
    @Override
    public void onVertexManagerEventReceived(VertexManagerEvent vmEvent) throws IOException {
    }

    /**
     * Marks the vertex as started, replays source completions that happened before the start and
     * then attempts to schedule tasks.
     */
    @Override
    public synchronized void onVertexStarted(List<TaskAttemptIdentifier> completions) throws Exception {
        this.vertexStarted = true;
        if (completions != null) {
            for (TaskAttemptIdentifier attempt : completions) {
                this.onSourceTaskCompleted(attempt);
            }
        }
        this.tryScheduleTask();
    }

    /**
     * Tracks source vertex state. The first CONFIGURED update triggers reconfiguration of this
     * vertex; each CONFIGURED source contributes its task count to the slow-start denominator.
     */
    @Override
    public synchronized void onVertexStateUpdated(VertexStateUpdate stateUpdate) throws IOException {
        VertexState state = stateUpdate.getVertexState();
        if (state == VertexState.CONFIGURED) {
            if (!this.vertexReconfigured) {
                this.reconfigureVertex();
            }
            --this.numCPSrcNotInConfiguredState;
            this.totalNumSrcTasks += this.getContext().getVertexNumTasks(stateUpdate.getVertexName());
        } else if (state == VertexState.RUNNING) {
            --this.numBroadcastSrcNotInRunningState;
        }
        this.tryScheduleTask();
    }

    /**
     * Records a finished task of a cartesian-product source. Completions from other vertices are
     * ignored, and a task id reported more than once is only counted the first time.
     */
    @Override
    public synchronized void onSourceTaskCompleted(TaskAttemptIdentifier attempt) throws Exception {
        int taskId = attempt.getTaskIdentifier().getIdentifier();
        String vertex = attempt.getTaskIdentifier().getVertexIdentifier().getName();
        BitSet completed = this.sourceTaskCompleted.get(vertex);
        if (completed == null) {
            return;
        }
        if (!completed.get(taskId)) {
            completed.set(taskId);
            ++this.numFinishedSrcTasks;
            this.tryScheduleTask();
        }
    }

    /**
     * Sets the vertex parallelism to the number of partition combinations accepted by the filter
     * (all of them when there is no filter) and signals that reconfiguration is done.
     */
    private void reconfigureVertex() throws IOException {
        HashMap<String, Integer> vertexPartitionMap = new HashMap<String, Integer>();
        CartesianProductCombination combination = new CartesianProductCombination(this.numPartitions);
        combination.firstTask();
        do {
            for (int i = 0; i < this.sourceVertices.size(); ++i) {
                vertexPartitionMap.put(this.sourceVertices.get(i), combination.getCombination().get(i));
            }
            if (this.filter == null || this.filter.isValidCombination(vertexPartitionMap)) {
                ++this.parallelism;
            }
        } while (combination.nextTask());
        this.getContext().reconfigureVertex(this.parallelism, null, null);
        this.vertexReconfigured = true;
        this.getContext().doneReconfiguringVertex();
    }

    /**
     * Schedules any destination tasks that the current slow-start position allows but that have not
     * been scheduled yet. Does nothing until the vertex has started and all inputs are ready.
     */
    private void tryScheduleTask() {
        if (!this.vertexStarted || this.numCPSrcNotInConfiguredState > 0 || this.numBroadcastSrcNotInRunningState > 0) {
            return;
        }
        int numTaskToSchedule = this.computeNumTasksToSchedule();
        if (numTaskToSchedule - 1 > this.lastScheduledTaskId) {
            List<VertexManagerPluginContext.ScheduleTaskRequest> scheduleTaskRequests =
                new ArrayList<VertexManagerPluginContext.ScheduleTaskRequest>(numTaskToSchedule - 1 - this.lastScheduledTaskId);
            for (int i = this.lastScheduledTaskId + 1; i < numTaskToSchedule; ++i) {
                scheduleTaskRequests.add(VertexManagerPluginContext.ScheduleTaskRequest.create(i, null));
            }
            this.lastScheduledTaskId = numTaskToSchedule - 1;
            this.getContext().scheduleTasks(scheduleTaskRequests);
        }
    }

    /**
     * Returns how many destination tasks (ids {@code 0 .. n-1}) may be scheduled given the fraction
     * of finished cartesian-product source tasks.
     */
    private int computeNumTasksToSchedule() {
        // All source tasks finished (this also covers zero source tasks): schedule everything.
        // Checked first so the fraction below never evaluates 0/0 (NaN).
        if (this.numFinishedSrcTasks >= this.totalNumSrcTasks) {
            return this.parallelism;
        }
        float percentFinishedSrcTask = (float) this.numFinishedSrcTasks / this.totalNumSrcTasks;
        if (percentFinishedSrcTask < this.minFraction) {
            return 0;
        }
        // When max <= min there is no ramp: reaching min schedules everything. Without this,
        // min == max made the ramp formula 0/0 = NaN, which casts to 0 tasks.
        if (percentFinishedSrcTask > this.maxFraction || this.maxFraction <= this.minFraction) {
            return this.parallelism;
        }
        return (int) ((percentFinishedSrcTask - this.minFraction)
            / (this.maxFraction - this.minFraction) * this.parallelism);
    }
}

