/*
 * Decompiled with CFR 0.152.
 */
package org.apache.seatunnel.engine.server.dag.execution;

import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import lombok.NonNull;
import org.apache.seatunnel.api.table.catalog.CatalogTable;
import org.apache.seatunnel.common.utils.SeaTunnelException;
import org.apache.seatunnel.engine.common.config.EngineConfig;
import org.apache.seatunnel.engine.common.utils.IdGenerator;
import org.apache.seatunnel.engine.core.dag.actions.AbstractAction;
import org.apache.seatunnel.engine.core.dag.actions.Action;
import org.apache.seatunnel.engine.core.dag.actions.ShuffleAction;
import org.apache.seatunnel.engine.core.dag.actions.ShuffleConfig;
import org.apache.seatunnel.engine.core.dag.actions.ShuffleMultipleRowStrategy;
import org.apache.seatunnel.engine.core.dag.actions.ShuffleStrategy;
import org.apache.seatunnel.engine.core.dag.actions.SinkAction;
import org.apache.seatunnel.engine.core.dag.actions.SinkConfig;
import org.apache.seatunnel.engine.core.dag.actions.SourceAction;
import org.apache.seatunnel.engine.core.dag.actions.TransformAction;
import org.apache.seatunnel.engine.core.dag.actions.TransformChainAction;
import org.apache.seatunnel.engine.core.dag.actions.UnknownActionException;
import org.apache.seatunnel.engine.core.dag.logical.LogicalDag;
import org.apache.seatunnel.engine.core.dag.logical.LogicalEdge;
import org.apache.seatunnel.engine.core.dag.logical.LogicalVertex;
import org.apache.seatunnel.engine.core.job.ConnectorJarIdentifier;
import org.apache.seatunnel.engine.core.job.JobImmutableInformation;
import org.apache.seatunnel.engine.server.dag.execution.ExecutionEdge;
import org.apache.seatunnel.engine.server.dag.execution.ExecutionPlan;
import org.apache.seatunnel.engine.server.dag.execution.ExecutionVertex;
import org.apache.seatunnel.engine.server.dag.execution.Pipeline;
import org.apache.seatunnel.engine.server.dag.execution.PipelineGenerator;
import org.apache.seatunnel.shade.com.google.common.base.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Translates a {@link LogicalDag} into an {@link ExecutionPlan}.
 *
 * <p>{@link #generate()} runs the translation in phases: logical edges become execution edges
 * between freshly created vertices, a {@link ShuffleAction} vertex may be inserted behind a
 * multi-table source, runs of {@link TransformAction}s are fused into {@link
 * TransformChainAction}s, and the resulting edges are grouped into {@link Pipeline}s.
 */
public class ExecutionPlanGenerator {
    private static final Logger log = LoggerFactory.getLogger(ExecutionPlanGenerator.class);
    private final LogicalDag logicalPlan;
    private final JobImmutableInformation jobImmutableInformation;
    private final EngineConfig engineConfig;
    /** Supplies ids for every execution, shuffle and transform-chain vertex created here. */
    private final IdGenerator idGenerator = new IdGenerator();

    /**
     * Creates a generator for one logical plan.
     *
     * @param logicalPlan the logical DAG to translate; must contain at least one edge
     * @param jobImmutableInformation job metadata; supplies the job id used by shuffle strategies
     *     and is attached to the generated plan
     * @param engineConfig engine settings; the checkpoint interval sizes the shuffle queue TTL
     * @throws NullPointerException if any argument is {@code null}
     * @throws IllegalArgumentException if {@code logicalPlan} has no edges
     */
    public ExecutionPlanGenerator(
            @NonNull LogicalDag logicalPlan,
            @NonNull JobImmutableInformation jobImmutableInformation,
            @NonNull EngineConfig engineConfig) {
        // Same checks Lombok's @NonNull generates; Lombok does not add its own when these exist.
        if (logicalPlan == null) {
            throw new NullPointerException("logicalPlan is marked non-null but is null");
        }
        if (jobImmutableInformation == null) {
            throw new NullPointerException("jobImmutableInformation is marked non-null but is null");
        }
        if (engineConfig == null) {
            throw new NullPointerException("engineConfig is marked non-null but is null");
        }
        Preconditions.checkArgument(
                !logicalPlan.getEdges().isEmpty(), "ExecutionPlan Builder must have LogicalPlan.");
        this.logicalPlan = logicalPlan;
        this.jobImmutableInformation = jobImmutableInformation;
        this.engineConfig = engineConfig;
    }

    /**
     * Runs all generation phases and returns the resulting plan.
     *
     * @return the execution plan built from the logical plan given to the constructor
     */
    public ExecutionPlan generate() {
        log.debug("Generate execution plan using logical plan:");
        Set<ExecutionEdge> executionEdges = generateExecutionEdges(logicalPlan.getEdges());
        log.debug("Phase 1: generate execution edge list {}", executionEdges);
        executionEdges = generateShuffleEdges(executionEdges);
        log.debug("Phase 2: generate shuffle edge list {}", executionEdges);
        executionEdges = generateTransformChainEdges(executionEdges);
        log.debug("Phase 3: generate transform chain edge list {}", executionEdges);
        List<Pipeline> pipelines = generatePipelines(executionEdges);
        log.debug("Phase 4: generate pipeline list {}", pipelines);
        ExecutionPlan executionPlan = new ExecutionPlan(pipelines, jobImmutableInformation);
        log.debug("Phase 5: generate execution plan: {}", executionPlan);
        return executionPlan;
    }

    /**
     * Creates a copy of {@code action} under a new id with the given parallelism.
     *
     * <p>A recreated {@link SinkAction} is given an empty upstream list.
     *
     * @param action the action to copy; must be a ShuffleAction, SinkAction, SourceAction,
     *     TransformAction or TransformChainAction
     * @param id id of the new action
     * @param parallelism parallelism set on the new action
     * @return the new action
     * @throws UnknownActionException if {@code action} is of any other type
     */
    public static Action recreateAction(Action action, Long id, int parallelism) {
        AbstractAction newAction;
        if (action instanceof ShuffleAction) {
            newAction =
                    new ShuffleAction(id, action.getName(), ((ShuffleAction) action).getConfig());
        } else if (action instanceof SinkAction) {
            newAction =
                    new SinkAction(
                            id,
                            action.getName(),
                            new ArrayList<Action>(),
                            ((SinkAction) action).getSink(),
                            action.getJarUrls(),
                            action.getConnectorJarIdentifiers(),
                            (SinkConfig) action.getConfig());
        } else if (action instanceof SourceAction) {
            newAction =
                    new SourceAction(
                            (long) id,
                            action.getName(),
                            ((SourceAction) action).getSource(),
                            action.getJarUrls(),
                            action.getConnectorJarIdentifiers());
        } else if (action instanceof TransformAction) {
            newAction =
                    new TransformAction(
                            (long) id,
                            action.getName(),
                            ((TransformAction) action).getTransform(),
                            action.getJarUrls(),
                            action.getConnectorJarIdentifiers());
        } else if (action instanceof TransformChainAction) {
            newAction =
                    new TransformChainAction(
                            (long) id,
                            action.getName(),
                            action.getJarUrls(),
                            action.getConnectorJarIdentifiers(),
                            ((TransformChainAction) action).getTransforms());
        } else {
            throw new UnknownActionException(action);
        }
        newAction.setParallelism(parallelism);
        return newAction;
    }

    /**
     * Phase 1: copies every logical edge into an {@link ExecutionEdge}.
     *
     * <p>Edges are processed in ascending (input vertex id, target vertex id) order, so new vertex
     * ids are assigned in a stable order. Each logical vertex is recreated exactly once, even when
     * it appears on several edges.
     */
    private Set<ExecutionEdge> generateExecutionEdges(Set<LogicalEdge> logicalEdges) {
        Set<ExecutionEdge> executionEdges = new LinkedHashSet<>();
        Map<Long, ExecutionVertex> logicalVertexIdToExecutionVertexMap = new HashMap<>();
        List<LogicalEdge> sortedLogicalEdges = new ArrayList<>(logicalEdges);
        sortedLogicalEdges.sort(
                (o1, o2) -> {
                    int byInput = Long.compare(o1.getInputVertexId(), o2.getInputVertexId());
                    return byInput != 0
                            ? byInput
                            : Long.compare(o1.getTargetVertexId(), o2.getTargetVertexId());
                });
        for (LogicalEdge logicalEdge : sortedLogicalEdges) {
            LogicalVertex logicalInputVertex = logicalEdge.getInputVertex();
            ExecutionVertex executionInputVertex =
                    logicalVertexIdToExecutionVertexMap.computeIfAbsent(
                            logicalInputVertex.getVertexId(),
                            vertexId -> createExecutionVertex(logicalInputVertex));
            LogicalVertex logicalTargetVertex = logicalEdge.getTargetVertex();
            ExecutionVertex executionTargetVertex =
                    logicalVertexIdToExecutionVertexMap.computeIfAbsent(
                            logicalTargetVertex.getVertexId(),
                            vertexId -> createExecutionVertex(logicalTargetVertex));
            executionEdges.add(new ExecutionEdge(executionInputVertex, executionTargetVertex));
        }
        return executionEdges;
    }

    /** Recreates a logical vertex's action under a fresh id and wraps it in an ExecutionVertex. */
    private ExecutionVertex createExecutionVertex(LogicalVertex logicalVertex) {
        long newId = idGenerator.getNextId();
        Action newAction =
                recreateAction(logicalVertex.getAction(), newId, logicalVertex.getParallelism());
        return new ExecutionVertex(newId, newAction, logicalVertex.getParallelism());
    }

    /**
     * Phase 2: inserts a {@link ShuffleAction} between a multi-table source and its sinks.
     *
     * <p>The edges are returned unchanged unless there is exactly one {@link SourceAction} vertex,
     * it reports more than one produced {@link CatalogTable}, and it feeds more than one vertex.
     * In that case the source is connected to a new shuffle vertex (with the source's parallelism)
     * and the shuffle is connected to every downstream sink, whose parallelism is forced to 1.
     *
     * @throws IllegalArgumentException if a downstream vertex of such a source is not a sink
     */
    private Set<ExecutionEdge> generateShuffleEdges(Set<ExecutionEdge> executionEdges) {
        Map<Long, List<ExecutionVertex>> targetVerticesMap = new LinkedHashMap<>();
        Set<ExecutionVertex> sourceExecutionVertices = new HashSet<>();
        for (ExecutionEdge edge : executionEdges) {
            ExecutionVertex leftVertex = edge.getLeftVertex();
            if (leftVertex.getAction() instanceof SourceAction) {
                sourceExecutionVertices.add(leftVertex);
            }
            targetVerticesMap
                    .computeIfAbsent(leftVertex.getVertexId(), id -> new ArrayList<>())
                    .add(edge.getRightVertex());
        }
        if (sourceExecutionVertices.size() != 1) {
            return executionEdges;
        }
        ExecutionVertex sourceExecutionVertex = sourceExecutionVertices.iterator().next();
        Action sourceAction = sourceExecutionVertex.getAction();
        List<CatalogTable> producedCatalogTables = new ArrayList<>();
        if (sourceAction instanceof SourceAction) {
            try {
                producedCatalogTables =
                        ((SourceAction) sourceAction).getSource().getProducedCatalogTables();
            } catch (UnsupportedOperationException ignored) {
                // Best effort: a source that cannot report its tables keeps the empty list,
                // which skips shuffle insertion below.
            }
        } else if (sourceAction instanceof TransformChainAction) {
            return executionEdges;
        } else {
            throw new SeaTunnelException(
                    "source action must be SourceAction or TransformChainAction");
        }
        // The source is a left vertex of some edge, so it always has a target list.
        List<ExecutionVertex> sinkVertices =
                targetVerticesMap.get(sourceExecutionVertex.getVertexId());
        if (producedCatalogTables == null
                || producedCatalogTables.size() <= 1
                || sinkVertices.size() <= 1) {
            return executionEdges;
        }
        Preconditions.checkArgument(
                sinkVertices.stream().allMatch(vertex -> vertex.getAction() instanceof SinkAction),
                "All downstream vertices of a multi-table source must be SinkAction");

        // Clamp instead of letting a huge checkpoint interval wrap to a negative TTL.
        long queueTtl = engineConfig.getCheckpointConfig().getCheckpointInterval() * 3L;
        ShuffleStrategy shuffleStrategy =
                ShuffleMultipleRowStrategy.builder()
                        .jobId(jobImmutableInformation.getJobId())
                        .inputPartitions(sourceAction.getParallelism())
                        .catalogTables(producedCatalogTables)
                        .queueEmptyQueueTtl((int) Math.min(Integer.MAX_VALUE, queueTtl))
                        .build();
        ShuffleConfig shuffleConfig = ShuffleConfig.builder().shuffleStrategy(shuffleStrategy).build();

        long shuffleVertexId = idGenerator.getNextId();
        String shuffleActionName = String.format("Shuffle [%s]", sourceAction.getName());
        ShuffleAction shuffleAction =
                new ShuffleAction(shuffleVertexId, shuffleActionName, shuffleConfig);
        shuffleAction.setParallelism(sourceAction.getParallelism());
        ExecutionVertex shuffleVertex =
                new ExecutionVertex(shuffleVertexId, shuffleAction, shuffleAction.getParallelism());

        Set<ExecutionEdge> newExecutionEdges = new LinkedHashSet<>();
        newExecutionEdges.add(new ExecutionEdge(sourceExecutionVertex, shuffleVertex));
        for (ExecutionVertex sinkVertex : sinkVertices) {
            sinkVertex.setParallelism(1);
            sinkVertex.getAction().setParallelism(1);
            newExecutionEdges.add(new ExecutionEdge(shuffleVertex, sinkVertex));
        }
        return newExecutionEdges;
    }

    /**
     * Phase 3: fuses linear runs of {@link TransformAction} vertices into single {@link
     * TransformChainAction} vertices.
     *
     * <p>Starting from each source vertex, reachable vertices are visited breadth-first and offered
     * to {@link #fillChainedTransformExecutionVertex}. Edges internal to a chain are removed from
     * {@code executionEdges} (which is mutated), and every remaining edge touching a chained vertex
     * is rebuilt against the chain vertex.
     */
    private Set<ExecutionEdge> generateTransformChainEdges(Set<ExecutionEdge> executionEdges) {
        Map<Long, List<ExecutionVertex>> inputVerticesMap = new HashMap<>();
        Map<Long, List<ExecutionVertex>> targetVerticesMap = new HashMap<>();
        Set<ExecutionVertex> sourceExecutionVertices = new HashSet<>();
        for (ExecutionEdge edge : executionEdges) {
            ExecutionVertex leftVertex = edge.getLeftVertex();
            ExecutionVertex rightVertex = edge.getRightVertex();
            if (leftVertex.getAction() instanceof SourceAction) {
                sourceExecutionVertices.add(leftVertex);
            }
            inputVerticesMap
                    .computeIfAbsent(rightVertex.getVertexId(), id -> new ArrayList<>())
                    .add(leftVertex);
            targetVerticesMap
                    .computeIfAbsent(leftVertex.getVertexId(), id -> new ArrayList<>())
                    .add(rightVertex);
        }
        Map<Long, List<ExecutionVertex>> readOnlyInputs =
                Collections.unmodifiableMap(inputVerticesMap);
        Map<Long, List<ExecutionVertex>> readOnlyTargets =
                Collections.unmodifiableMap(targetVerticesMap);

        Map<Long, ExecutionVertex> transformChainVertexMap = new HashMap<>();
        Map<Long, Long> chainedTransformVerticesMapping = new HashMap<>();
        // Revisiting a vertex is a no-op (it is already chained or is not a transform), so skip
        // it; this avoids re-expanding vertices reachable through several paths.
        Set<Long> visitedVertexIds = new HashSet<>();
        for (ExecutionVertex sourceVertex : sourceExecutionVertices) {
            List<ExecutionVertex> vertices = new ArrayList<>();
            vertices.add(sourceVertex);
            for (int index = 0; index < vertices.size(); index++) {
                ExecutionVertex vertex = vertices.get(index);
                if (!visitedVertexIds.add(vertex.getVertexId())) {
                    continue;
                }
                fillChainedTransformExecutionVertex(
                        vertex,
                        chainedTransformVerticesMapping,
                        transformChainVertexMap,
                        executionEdges,
                        readOnlyInputs,
                        readOnlyTargets);
                List<ExecutionVertex> targets = targetVerticesMap.get(vertex.getVertexId());
                if (targets != null) {
                    vertices.addAll(targets);
                }
            }
        }

        Set<ExecutionEdge> transformChainEdges = new LinkedHashSet<>();
        for (ExecutionEdge executionEdge : executionEdges) {
            ExecutionVertex leftVertex = executionEdge.getLeftVertex();
            ExecutionVertex rightVertex = executionEdge.getRightVertex();
            boolean needRebuild = false;
            Long leftChainId = chainedTransformVerticesMapping.get(leftVertex.getVertexId());
            if (leftChainId != null) {
                needRebuild = true;
                leftVertex = transformChainVertexMap.get(leftChainId);
            }
            Long rightChainId = chainedTransformVerticesMapping.get(rightVertex.getVertexId());
            if (rightChainId != null) {
                needRebuild = true;
                rightVertex = transformChainVertexMap.get(rightChainId);
            }
            transformChainEdges.add(
                    needRebuild ? new ExecutionEdge(leftVertex, rightVertex) : executionEdge);
        }
        return transformChainEdges;
    }

    /**
     * Fuses the run of transforms starting at {@code currentVertex} into one chain vertex.
     *
     * <p>Does nothing if {@code currentVertex} is already mapped to a chain or is not a {@link
     * TransformAction}. Otherwise every vertex of the run is mapped, in {@code
     * chainedTransformVerticesMapping}, to the id of a new {@link TransformChainAction} vertex,
     * which is registered in {@code transformChainVertexMap}. The chain takes its parallelism from
     * {@code currentVertex} and its transforms, names, jars and connector identifiers from the
     * fused vertices.
     */
    private void fillChainedTransformExecutionVertex(
            ExecutionVertex currentVertex,
            Map<Long, Long> chainedTransformVerticesMapping,
            Map<Long, ExecutionVertex> transformChainVertexMap,
            Set<ExecutionEdge> executionEdges,
            Map<Long, List<ExecutionVertex>> inputVerticesMap,
            Map<Long, List<ExecutionVertex>> targetVerticesMap) {
        if (chainedTransformVerticesMapping.containsKey(currentVertex.getVertexId())) {
            return;
        }
        List<ExecutionVertex> transformChainedVertices = new ArrayList<>();
        collectChainedVertices(
                currentVertex,
                transformChainedVertices,
                executionEdges,
                inputVerticesMap,
                targetVerticesMap);
        if (transformChainedVertices.isEmpty()) {
            return;
        }
        long newVertexId = idGenerator.getNextId();
        // Raw list: holds whatever TransformAction#getTransform returns, as the chain expects.
        List transforms = new ArrayList(transformChainedVertices.size());
        List<String> names = new ArrayList<>(transformChainedVertices.size());
        Set<URL> jars = new HashSet<>();
        Set<ConnectorJarIdentifier> identifiers = new HashSet<>();
        // currentVertex is the first element, so it is mapped here as well.
        for (ExecutionVertex vertex : transformChainedVertices) {
            chainedTransformVerticesMapping.put(vertex.getVertexId(), newVertexId);
            TransformAction action = (TransformAction) vertex.getAction();
            transforms.add(action.getTransform());
            jars.addAll(action.getJarUrls());
            identifiers.addAll(action.getConnectorJarIdentifiers());
            names.add(action.getName());
        }
        String transformChainActionName =
                String.format("TransformChain[%s]", String.join("->", names));
        TransformChainAction transformChainAction =
                new TransformChainAction(
                        newVertexId, transformChainActionName, jars, identifiers, transforms);
        transformChainAction.setParallelism(currentVertex.getAction().getParallelism());
        ExecutionVertex executionVertex =
                new ExecutionVertex(
                        newVertexId, transformChainAction, currentVertex.getParallelism());
        transformChainVertexMap.put(newVertexId, executionVertex);
    }

    /**
     * Appends to {@code chainedVertices} the linear run of transforms starting at {@code
     * currentVertex}.
     *
     * <p>The run extends while the current vertex is a {@link TransformAction}. Every vertex after
     * the first must have exactly one input, and the run follows a vertex only while it has exactly
     * one target (a vertex with no recorded targets ends the run). The edge joining two chained
     * vertices is removed from {@code executionEdges}.
     */
    private void collectChainedVertices(
            ExecutionVertex currentVertex,
            List<ExecutionVertex> chainedVertices,
            Set<ExecutionEdge> executionEdges,
            Map<Long, List<ExecutionVertex>> inputVerticesMap,
            Map<Long, List<ExecutionVertex>> targetVerticesMap) {
        ExecutionVertex vertex = currentVertex;
        while (vertex.getAction() instanceof TransformAction) {
            if (!chainedVertices.isEmpty()) {
                List<ExecutionVertex> inputs = inputVerticesMap.get(vertex.getVertexId());
                if (inputs == null || inputs.size() != 1) {
                    return;
                }
                executionEdges.remove(
                        new ExecutionEdge(chainedVertices.get(chainedVertices.size() - 1), vertex));
            }
            chainedVertices.add(vertex);
            List<ExecutionVertex> targets = targetVerticesMap.get(vertex.getVertexId());
            if (targets == null || targets.size() != 1) {
                return;
            }
            vertex = targets.get(0);
        }
    }

    /**
     * Phase 4: groups the edges into pipelines and prefixes each action name with its pipeline.
     *
     * <p>Every action is renamed to {@code "pipeline-<id> [<name>]"}.
     *
     * @throws IllegalArgumentException if two actions end up with the same prefixed name
     */
    private List<Pipeline> generatePipelines(Set<ExecutionEdge> executionEdges) {
        Set<ExecutionVertex> executionVertices = new LinkedHashSet<>();
        for (ExecutionEdge edge : executionEdges) {
            executionVertices.add(edge.getLeftVertex());
            executionVertices.add(edge.getRightVertex());
        }
        PipelineGenerator pipelineGenerator =
                new PipelineGenerator(
                        executionVertices, new ArrayList<ExecutionEdge>(executionEdges));
        List<Pipeline> pipelines = pipelineGenerator.generatePipelines();
        Set<String> duplicatedActionNames = new HashSet<>();
        Set<String> actionNames = new HashSet<>();
        for (Pipeline pipeline : pipelines) {
            Integer pipelineId = pipeline.getId();
            for (ExecutionVertex vertex : pipeline.getVertexes().values()) {
                Action action = vertex.getAction();
                String actionName = String.format("pipeline-%s [%s]", pipelineId, action.getName());
                action.setName(actionName);
                if (!actionNames.add(actionName)) {
                    duplicatedActionNames.add(actionName);
                }
            }
        }
        Preconditions.checkArgument(
                duplicatedActionNames.isEmpty(), "Action name is duplicated: " + duplicatedActionNames);
        return pipelines;
    }
}

