/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.spark.bulkwriter;

import com.google.common.collect.BoundType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Range;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.cassandra.spark.bulkwriter.BulkWriterContext;
import org.apache.cassandra.spark.bulkwriter.DirectDataTransferApi;
import org.apache.cassandra.spark.bulkwriter.DirectStreamSession;
import org.apache.cassandra.spark.bulkwriter.MockBulkWriterContext;
import org.apache.cassandra.spark.bulkwriter.MockScheduledExecutorService;
import org.apache.cassandra.spark.bulkwriter.MockTableWriter;
import org.apache.cassandra.spark.bulkwriter.NonValidatingTestSortedSSTableWriter;
import org.apache.cassandra.spark.bulkwriter.RingInstance;
import org.apache.cassandra.spark.bulkwriter.StreamSession;
import org.apache.cassandra.spark.bulkwriter.TokenRangeMappingUtils;
import org.apache.cassandra.spark.bulkwriter.TransportContext;
import org.apache.cassandra.spark.bulkwriter.token.ConsistencyLevel;
import org.apache.cassandra.spark.bulkwriter.token.MultiClusterReplicaAwareFailureHandler;
import org.apache.cassandra.spark.bulkwriter.token.ReplicaAwareFailureHandler;
import org.apache.cassandra.spark.bulkwriter.token.TokenRangeMapping;
import org.apache.cassandra.spark.common.model.CassandraInstance;
import org.apache.cassandra.spark.data.ReplicationFactor;
import org.apache.cassandra.spark.exception.ConsistencyNotSatisfiedException;
import org.apache.cassandra.spark.utils.DigestAlgorithm;
import org.apache.cassandra.spark.utils.XXHash32DigestAlgorithm;
import org.assertj.core.api.AbstractThrowableAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

/**
 * Verifies that {@link DirectStreamSession} enforces each {@link ConsistencyLevel.CL} when a
 * configurable number of replicas in each of two datacenters fail, either during the upload
 * step or during the commit step.
 */
public class StreamSessionConsistencyTest {
    private static final int NUMBER_DCS = 2;
    private static final int FILES_PER_SSTABLE = 8;
    private static final int REPLICATION_FACTOR = 3;
    private static final List<String> EXPECTED_INSTANCES =
            ImmutableList.of("DC1-i2", "DC1-i3", "DC1-i4", "DC2-i2", "DC2-i3", "DC2-i4");
    private static final Range<BigInteger> RANGE =
            Range.range(BigInteger.valueOf(101L), BoundType.CLOSED, BigInteger.valueOf(199L), BoundType.CLOSED);
    private static final ImmutableMap<String, Integer> RF_OPTIONS =
            ImmutableMap.of("DC1", REPLICATION_FACTOR, "DC2", REPLICATION_FACTOR);
    // NOTE(review): the meaning of the leading 0 and trailing 6 arguments is defined by
    // TokenRangeMappingUtils.buildTokenRangeMapping — confirm there before changing them.
    private static final TokenRangeMapping<RingInstance> TOKEN_RANGE_MAPPING =
            TokenRangeMappingUtils.buildTokenRangeMapping(0, RF_OPTIONS, 6);
    private static final Map<String, Object> COLUMN_BIND_VALUES =
            ImmutableMap.of("id", 0, "date", 1, "course", "course", "marks", 2);

    @TempDir
    private Path folder;
    private MockTableWriter tableWriter;
    private MockBulkWriterContext writerContext;
    private TransportContext.DirectDataBulkWriterContext transportContext;
    private final MockScheduledExecutorService executor = new MockScheduledExecutorService();
    private DigestAlgorithm digestAlgorithm;

    /**
     * Builds the parameter matrix: every consistency level crossed with every pair
     * {@code [dc1Failures, dc2Failures]}, where each count ranges over {@code 0..REPLICATION_FACTOR}.
     *
     * @return one {@code Object[]} of {@code {CL, List<Integer>}} per test invocation
     */
    public static Collection<Object[]> data() {
        List<ConsistencyLevel.CL> cls = Arrays.asList(ConsistencyLevel.CL.values());
        List<Integer> failures = IntStream.rangeClosed(0, REPLICATION_FACTOR).boxed().collect(Collectors.toList());
        List<List<Integer>> failuresPerDc = Lists.cartesianProduct(failures, failures);
        List<List<Object>> clsToFailures = Lists.cartesianProduct(cls, failuresPerDc);
        return clsToFailures.stream().map(List::toArray).collect(Collectors.toList());
    }

    /**
     * Creates fresh writer/transport fixtures for the given consistency level.
     */
    private void setup(ConsistencyLevel.CL consistencyLevel) {
        digestAlgorithm = new XXHash32DigestAlgorithm();
        tableWriter = new MockTableWriter(folder);
        writerContext = new MockBulkWriterContext(TOKEN_RANGE_MAPPING, "cassandra-5.0.5", consistencyLevel);
        writerContext.setReplicationFactor(
                new ReplicationFactor(ReplicationFactor.ReplicationStrategy.NetworkTopologyStrategy, RF_OPTIONS));
        transportContext = (TransportContext.DirectDataBulkWriterContext) writerContext.transportContext();
    }

    @ParameterizedTest(name = "CL: {0}, numFailures: {1}")
    @MethodSource(value = {"data"})
    public void testConsistencyLevelAndFailureInCommit(ConsistencyLevel.CL consistencyLevel, List<Integer> failuresPerDc)
            throws IOException, ExecutionException, InterruptedException {
        setup(consistencyLevel);
        AtomicInteger dc1Failures = new AtomicInteger(failuresPerDc.get(0));
        AtomicInteger dc2Failures = new AtomicInteger(failuresPerDc.get(1));
        Map<String, AtomicInteger> dcFailures = ImmutableMap.of("DC1", dc1Failures, "DC2", dc2Failures);
        boolean shouldFail = calculateFailure(consistencyLevel, dc1Failures.get(), dc2Failures.get());
        // Each commit call against a DC consumes one unit of that DC's failure budget; while the
        // budget is positive the call reports failure, afterwards it reports success.
        writerContext.setCommitResultSupplier((uuids, dc) -> {
            if (dcFailures.get(dc).getAndDecrement() > 0) {
                return new DirectDataTransferApi.RemoteCommitResult(false, uuids, null, "");
            }
            return new DirectDataTransferApi.RemoteCommitResult(true, null, null, "");
        });
        StreamSession<?> streamSession = createStreamSession(NonValidatingTestSortedSSTableWriter::new);
        streamSession.addRow(BigInteger.valueOf(102L), COLUMN_BIND_VALUES);
        Future<?> fut = streamSession.finalizeStreamAsync();
        assertFinalizeOutcome(fut, shouldFail, consistencyLevel);
        executor.assertFuturesCalled();
        // Commit failures are injected after upload, so every file is expected on every replica.
        Assertions.assertThat(totalUploadedFiles()).isEqualTo(totalFilesToUpload());
        Assertions.assertThat(uploadedInstanceNames()).containsAll(EXPECTED_INSTANCES);
    }

    @ParameterizedTest(name = "CL: {0}, numFailures: {1}")
    @MethodSource(value = {"data"})
    public void testConsistencyLevelAndFailureInUpload(ConsistencyLevel.CL consistencyLevel, List<Integer> failuresPerDc)
            throws IOException, ExecutionException, InterruptedException {
        setup(consistencyLevel);
        AtomicInteger dc1Failures = new AtomicInteger(failuresPerDc.get(0));
        AtomicInteger dc2Failures = new AtomicInteger(failuresPerDc.get(1));
        int numFailures = dc1Failures.get() + dc2Failures.get();
        Map<String, AtomicInteger> dcFailures = ImmutableMap.of("DC1", dc1Failures, "DC2", dc2Failures);
        boolean shouldFail = calculateFailure(consistencyLevel, dc1Failures.get(), dc2Failures.get());
        // Each upload call consumes one unit of the instance's DC failure budget; the supplier
        // returns true only once that budget is exhausted (presumably true == upload succeeds).
        writerContext.setUploadSupplier(instance -> dcFailures.get(instance.datacenter()).getAndDecrement() <= 0);
        StreamSession<?> streamSession = createStreamSession(NonValidatingTestSortedSSTableWriter::new);
        streamSession.addRow(BigInteger.valueOf(102L), COLUMN_BIND_VALUES);
        Future<?> fut = streamSession.finalizeStreamAsync();
        assertFinalizeOutcome(fut, shouldFail, consistencyLevel);
        executor.assertFuturesCalled();
        // The expectation is that a failing instance records one upload attempt and the remaining
        // FILES_PER_SSTABLE - 1 components of that SSTable are not attempted.
        int filesSkipped = numFailures * (FILES_PER_SSTABLE - 1);
        Assertions.assertThat(totalUploadedFiles()).isEqualTo(totalFilesToUpload() - filesSkipped);
        Assertions.assertThat(uploadedInstanceNames()).containsAll(EXPECTED_INSTANCES);
    }

    /**
     * Asserts that the finalize future either completes normally or fails with a
     * {@link ConsistencyNotSatisfiedException} cause, depending on {@code shouldFail}.
     */
    private void assertFinalizeOutcome(Future<?> fut, boolean shouldFail, ConsistencyLevel.CL consistencyLevel)
            throws ExecutionException, InterruptedException {
        if (!shouldFail) {
            fut.get();
            return;
        }
        Assertions.assertThatThrownBy(fut::get)
                  .isInstanceOf(ExecutionException.class)
                  .hasCauseExactlyInstanceOf(ConsistencyNotSatisfiedException.class)
                  .hasMessageContaining("Failed to write 1 ranges with " + consistencyLevel
                                        + " for job " + writerContext.job().getId()
                                        + " in phase UploadAndCommit.");
    }

    /** Number of files expected when every replica in every DC receives a full SSTable. */
    private static int totalFilesToUpload() {
        return FILES_PER_SSTABLE * REPLICATION_FACTOR * NUMBER_DCS;
    }

    /** Sums the recorded uploads across all instances of the mock context. */
    private int totalUploadedFiles() {
        return writerContext.getUploads().values().stream().mapToInt(Collection::size).sum();
    }

    /** Node names of every instance the mock context recorded an upload for. */
    private List<String> uploadedInstanceNames() {
        return writerContext.getUploads().keySet().stream()
                            .map(CassandraInstance::nodeName)
                            .collect(Collectors.toList());
    }

    /**
     * Computes whether the write should fail for the given consistency level.
     * Assumes DC1 is the local DC and each DC holds {@code REPLICATION_FACTOR} replicas.
     *
     * @param consistencyLevel the level under test
     * @param dc1Failures      number of failed replicas in DC1
     * @param dc2Failures      number of failed replicas in DC2
     * @return {@code true} when too many replicas failed to satisfy {@code consistencyLevel}
     * @throws IllegalArgumentException for consistency levels this test does not model
     */
    private boolean calculateFailure(ConsistencyLevel.CL consistencyLevel, int dc1Failures, int dc2Failures) {
        int localQuorum = REPLICATION_FACTOR / 2 + 1;
        int localFailuresViolatingQuorum = REPLICATION_FACTOR - localQuorum;
        int totalInstances = REPLICATION_FACTOR * NUMBER_DCS;
        int allDcsQuorum = totalInstances / 2 + 1;
        int totalFailures = dc1Failures + dc2Failures;
        switch (consistencyLevel) {
            case ALL:
                return totalFailures > 0;
            case TWO:
                return totalFailures > totalInstances - 2;
            case QUORUM:
                return totalFailures > totalInstances - allDcsQuorum;
            case LOCAL_QUORUM:
                return dc1Failures > localFailuresViolatingQuorum;
            case EACH_QUORUM:
                return dc1Failures > localFailuresViolatingQuorum || dc2Failures > localFailuresViolatingQuorum;
            case LOCAL_ONE:
                return dc1Failures > REPLICATION_FACTOR - 1;
            case ONE:
                return totalFailures > totalInstances - 1;
            default:
                throw new IllegalArgumentException("CL: " + String.valueOf(consistencyLevel) + " not supported");
        }
    }

    private StreamSession<?> createStreamSession(MockTableWriter.Creator writerCreator) {
        return new DirectStreamSession(writerContext,
                                       writerCreator.create(tableWriter, folder, digestAlgorithm, 1),
                                       transportContext,
                                       "sessionId",
                                       RANGE,
                                       new MultiClusterReplicaAwareFailureHandler<>(writerContext.cluster().getPartitioner()),
                                       executor);
    }
}

