/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler.adaptivebatch;

import java.util.List;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.JobManagerOptions;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingResultInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.VertexParallelismDecider;
import org.apache.flink.util.MathUtils;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DefaultVertexParallelismDecider
implements VertexParallelismDecider {
    private static final Logger LOG = LoggerFactory.getLogger(DefaultVertexParallelismDecider.class);
    private static final double CAP_RATIO_OF_BROADCAST = 0.5;
    private final int globalMaxParallelism;
    private final int globalMinParallelism;
    private final long dataVolumePerTask;
    private final int globalDefaultSourceParallelism;

    private DefaultVertexParallelismDecider(int globalMaxParallelism, int globalMinParallelism, MemorySize dataVolumePerTask, int globalDefaultSourceParallelism) {
        Preconditions.checkArgument((globalMinParallelism > 0 ? 1 : 0) != 0, (Object)"The minimum parallelism must be larger than 0.");
        Preconditions.checkArgument((globalMaxParallelism >= globalMinParallelism ? 1 : 0) != 0, (Object)"Maximum parallelism should be greater than or equal to the minimum parallelism.");
        Preconditions.checkArgument((globalDefaultSourceParallelism > 0 ? 1 : 0) != 0, (Object)"The default source parallelism must be larger than 0.");
        Preconditions.checkNotNull((Object)dataVolumePerTask);
        this.globalMaxParallelism = globalMaxParallelism;
        this.globalMinParallelism = globalMinParallelism;
        this.dataVolumePerTask = dataVolumePerTask.getBytes();
        this.globalDefaultSourceParallelism = globalDefaultSourceParallelism;
    }

    @Override
    public int decideParallelismForVertex(JobVertexID jobVertexId, List<BlockingResultInfo> consumedResults, int vertexMaxParallelism) {
        if (consumedResults.isEmpty()) {
            return this.computeSourceParallelism(jobVertexId, vertexMaxParallelism);
        }
        int minParallelism = this.globalMinParallelism;
        int maxParallelism = this.globalMaxParallelism;
        if ((vertexMaxParallelism = DefaultVertexParallelismDecider.getNormalizedMaxParallelism(vertexMaxParallelism)) < minParallelism) {
            LOG.info("The vertex maximum parallelism {} is smaller than the global minimum parallelism {}. Use {} as the lower bound to decide parallelism of job vertex {}.", new Object[]{vertexMaxParallelism, minParallelism, vertexMaxParallelism, jobVertexId});
            minParallelism = vertexMaxParallelism;
        }
        if (vertexMaxParallelism < maxParallelism) {
            LOG.info("The vertex maximum parallelism {} is smaller than the global maximum parallelism {}. Use {} as the upper bound to decide parallelism of job vertex {}.", new Object[]{vertexMaxParallelism, maxParallelism, vertexMaxParallelism, jobVertexId});
            maxParallelism = vertexMaxParallelism;
        }
        Preconditions.checkState((maxParallelism >= minParallelism ? 1 : 0) != 0);
        return this.computeParallelism(jobVertexId, consumedResults, maxParallelism, minParallelism);
    }

    private int computeSourceParallelism(JobVertexID jobVertexId, int maxParallelism) {
        if (this.globalDefaultSourceParallelism > maxParallelism) {
            LOG.info("The global default source parallelism {} is larger than the maximum parallelism {}. Use {} as the parallelism of source job vertex {}.", new Object[]{this.globalDefaultSourceParallelism, maxParallelism, maxParallelism, jobVertexId});
            return maxParallelism;
        }
        return this.globalDefaultSourceParallelism;
    }

    private int computeParallelism(JobVertexID jobVertexId, List<BlockingResultInfo> consumedResults, int maxParallelism, int minParallelism) {
        long broadcastBytes = consumedResults.stream().filter(BlockingResultInfo::isBroadcast).mapToLong(consumedResult -> consumedResult.getBlockingPartitionSizes().stream().reduce(0L, Long::sum)).sum();
        long nonBroadcastBytes = consumedResults.stream().filter(consumedResult -> !consumedResult.isBroadcast()).mapToLong(consumedResult -> consumedResult.getBlockingPartitionSizes().stream().reduce(0L, Long::sum)).sum();
        long expectedMaxBroadcastBytes = (long)Math.ceil((double)this.dataVolumePerTask * 0.5);
        if (broadcastBytes > expectedMaxBroadcastBytes) {
            LOG.info("The size of broadcast data {} is larger than the expected maximum value {} ('{}' * {}). Use {} as the size of broadcast data to decide the parallelism of job vertex {}.", new Object[]{new MemorySize(broadcastBytes), new MemorySize(expectedMaxBroadcastBytes), JobManagerOptions.ADAPTIVE_BATCH_SCHEDULER_AVG_DATA_VOLUME_PER_TASK.key(), 0.5, new MemorySize(expectedMaxBroadcastBytes), jobVertexId});
            broadcastBytes = expectedMaxBroadcastBytes;
        }
        int initialParallelism = (int)Math.ceil((double)nonBroadcastBytes / (double)(this.dataVolumePerTask - broadcastBytes));
        int parallelism = DefaultVertexParallelismDecider.normalizeParallelism(initialParallelism);
        LOG.debug("The size of broadcast data is {}, the size of non-broadcast data is {}, the initially decided parallelism of job vertex {} is {}, after normalization is {}", new Object[]{new MemorySize(broadcastBytes), new MemorySize(nonBroadcastBytes), jobVertexId, initialParallelism, parallelism});
        if (parallelism < minParallelism) {
            LOG.info("The initially normalized parallelism {} is smaller than the normalized minimum parallelism {}. Use {} as the finally decided parallelism of job vertex {}.", new Object[]{parallelism, minParallelism, minParallelism, jobVertexId});
            parallelism = minParallelism;
        } else if (parallelism > maxParallelism) {
            LOG.info("The initially normalized parallelism {} is larger than the normalized maximum parallelism {}. Use {} as the finally decided parallelism of job vertex {}.", new Object[]{parallelism, maxParallelism, maxParallelism, jobVertexId});
            parallelism = maxParallelism;
        }
        return parallelism;
    }

    @VisibleForTesting
    int getGlobalMaxParallelism() {
        return this.globalMaxParallelism;
    }

    @VisibleForTesting
    int getGlobalMinParallelism() {
        return this.globalMinParallelism;
    }

    static DefaultVertexParallelismDecider from(Configuration configuration) {
        int minParallelism;
        int maxParallelism = DefaultVertexParallelismDecider.getNormalizedMaxParallelism(configuration);
        Preconditions.checkState((maxParallelism >= (minParallelism = DefaultVertexParallelismDecider.getNormalizedMinParallelism(configuration)) ? 1 : 0) != 0, (Object)String.format("Invalid configuration: '%s' should be greater than or equal to '%s' and the range must contain at least one power of 2.", JobManagerOptions.ADAPTIVE_BATCH_SCHEDULER_MAX_PARALLELISM.key(), JobManagerOptions.ADAPTIVE_BATCH_SCHEDULER_MIN_PARALLELISM.key()));
        return new DefaultVertexParallelismDecider(maxParallelism, minParallelism, (MemorySize)configuration.get(JobManagerOptions.ADAPTIVE_BATCH_SCHEDULER_AVG_DATA_VOLUME_PER_TASK), (Integer)configuration.get(JobManagerOptions.ADAPTIVE_BATCH_SCHEDULER_DEFAULT_SOURCE_PARALLELISM));
    }

    private static int getNormalizedMaxParallelism(int maxParallelism) {
        return MathUtils.roundDownToPowerOf2((int)maxParallelism);
    }

    static int getNormalizedMaxParallelism(Configuration configuration) {
        return DefaultVertexParallelismDecider.getNormalizedMaxParallelism(configuration.getInteger(JobManagerOptions.ADAPTIVE_BATCH_SCHEDULER_MAX_PARALLELISM));
    }

    static int getNormalizedMinParallelism(Configuration configuration) {
        return MathUtils.roundUpToPowerOfTwo((int)configuration.getInteger(JobManagerOptions.ADAPTIVE_BATCH_SCHEDULER_MIN_PARALLELISM));
    }

    static int normalizeParallelism(int parallelism) {
        int down = MathUtils.roundDownToPowerOf2((int)parallelism);
        int up = MathUtils.roundUpToPowerOfTwo((int)parallelism);
        return parallelism < (up + down) / 2 ? down : up;
    }
}

