/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution.scheduler;

import com.google.common.base.Preconditions;
import io.trino.execution.NodeTaskMap;
import io.trino.execution.PartitionedSplitsInfo;
import io.trino.execution.RemoteTask;
import io.trino.execution.scheduler.NodeMap;
import io.trino.metadata.InternalNode;
import io.trino.spi.SplitWeight;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

public final class NodeAssignmentStats {
    private final NodeTaskMap nodeTaskMap;
    private final Map<InternalNode, PartitionedSplitsInfo> nodeTotalSplitsInfo;
    private final Map<String, PendingSplitInfo> stageQueuedSplitInfo;

    public NodeAssignmentStats(NodeTaskMap nodeTaskMap, NodeMap nodeMap, List<RemoteTask> existingTasks) {
        this.nodeTaskMap = Objects.requireNonNull(nodeTaskMap, "nodeTaskMap is null");
        int nodeMapSize = Objects.requireNonNull(nodeMap, "nodeMap is null").getNodesByHostAndPort().size();
        this.nodeTotalSplitsInfo = new HashMap<InternalNode, PartitionedSplitsInfo>(nodeMapSize);
        this.stageQueuedSplitInfo = new HashMap<String, PendingSplitInfo>(nodeMapSize);
        for (RemoteTask task : existingTasks) {
            Preconditions.checkArgument((this.stageQueuedSplitInfo.put(task.getNodeId(), new PendingSplitInfo(task.getQueuedPartitionedSplitsInfo(), task.getUnacknowledgedPartitionedSplitCount())) == null ? 1 : 0) != 0, (Object)"A single stage may not have multiple tasks running on the same node");
        }
        if (existingTasks.size() < nodeMapSize) {
            Function<String, PendingSplitInfo> createEmptySplitInfo = ignored -> new PendingSplitInfo(PartitionedSplitsInfo.forZeroSplits(), 0);
            for (InternalNode node : nodeMap.getNodesByHostAndPort().values()) {
                this.stageQueuedSplitInfo.computeIfAbsent(node.getNodeIdentifier(), createEmptySplitInfo);
            }
        }
    }

    public long getTotalSplitsWeight(InternalNode node) {
        PartitionedSplitsInfo nodeTotalSplits = this.nodeTotalSplitsInfo.computeIfAbsent(node, this.nodeTaskMap::getPartitionedSplitsOnNode);
        PendingSplitInfo stageInfo = this.stageQueuedSplitInfo.get(node.getNodeIdentifier());
        if (stageInfo == null) {
            return nodeTotalSplits.getWeightSum();
        }
        return Math.addExact(nodeTotalSplits.getWeightSum(), stageInfo.getAssignedSplitsWeight());
    }

    public long getQueuedSplitsWeightForStage(InternalNode node) {
        PendingSplitInfo stageInfo = this.stageQueuedSplitInfo.get(node.getNodeIdentifier());
        return stageInfo == null ? 0L : stageInfo.getQueuedSplitsWeight();
    }

    public int getUnacknowledgedSplitCountForStage(InternalNode node) {
        PendingSplitInfo stageInfo = this.stageQueuedSplitInfo.get(node.getNodeIdentifier());
        return stageInfo == null ? 0 : stageInfo.getUnacknowledgedSplitCount();
    }

    public void addAssignedSplit(InternalNode node, SplitWeight splitWeight) {
        this.getOrCreateStageSplitInfo(node).addAssignedSplit(splitWeight);
    }

    public void removeAssignedSplit(InternalNode node, SplitWeight splitWeight) {
        this.getOrCreateStageSplitInfo(node).removeAssignedSplit(splitWeight);
    }

    private PendingSplitInfo getOrCreateStageSplitInfo(InternalNode node) {
        String nodeId = node.getNodeIdentifier();
        PendingSplitInfo stageInfo = this.stageQueuedSplitInfo.get(nodeId);
        if (stageInfo == null) {
            stageInfo = new PendingSplitInfo(PartitionedSplitsInfo.forZeroSplits(), 0);
            this.stageQueuedSplitInfo.put(nodeId, stageInfo);
        }
        return stageInfo;
    }

    private static final class PendingSplitInfo {
        private final int queuedSplitCount;
        private final long queuedSplitsWeight;
        private final int unacknowledgedSplitCount;
        private int assignedSplits;
        private long assignedSplitsWeight;

        private PendingSplitInfo(PartitionedSplitsInfo queuedSplitsInfo, int unacknowledgedSplitCount) {
            this.queuedSplitCount = Objects.requireNonNull(queuedSplitsInfo, "queuedSplitsInfo is null").getCount();
            this.queuedSplitsWeight = queuedSplitsInfo.getWeightSum();
            this.unacknowledgedSplitCount = unacknowledgedSplitCount;
        }

        public int getAssignedSplitCount() {
            return this.assignedSplits;
        }

        public long getAssignedSplitsWeight() {
            return this.assignedSplitsWeight;
        }

        public int getQueuedSplitCount() {
            return this.queuedSplitCount + this.assignedSplits;
        }

        public long getQueuedSplitsWeight() {
            return Math.addExact(this.queuedSplitsWeight, this.assignedSplitsWeight);
        }

        public int getUnacknowledgedSplitCount() {
            return this.unacknowledgedSplitCount + this.assignedSplits;
        }

        public void addAssignedSplit(SplitWeight splitWeight) {
            ++this.assignedSplits;
            this.assignedSplitsWeight = Math.addExact(this.assignedSplitsWeight, splitWeight.getRawValue());
        }

        public void removeAssignedSplit(SplitWeight splitWeight) {
            --this.assignedSplits;
            this.assignedSplitsWeight -= splitWeight.getRawValue();
        }
    }
}

