/*
 * Decompiled with CFR 0.152.
 */
package org.apache.uniffle.shuffle.manager;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.MapOutputTracker;
import org.apache.spark.MapOutputTrackerMaster;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.SparkException;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.RssSparkShuffleUtils;
import org.apache.spark.shuffle.ShuffleManager;
import org.apache.spark.shuffle.SparkVersionUtils;
import org.apache.uniffle.client.api.CoordinatorClient;
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
import org.apache.uniffle.client.request.RssFetchClientConfRequest;
import org.apache.uniffle.client.response.RssFetchClientConfResponse;
import org.apache.uniffle.com.google.common.annotations.VisibleForTesting;
import org.apache.uniffle.com.google.common.collect.Maps;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.config.ConfigOption;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.shuffle.manager.RssShuffleManagerInterface;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Shared base class for the Uniffle (RSS) Spark shuffle managers.
 *
 * <p>Bundles helpers used by the concrete managers:
 *
 * <ul>
 *   <li>deriving the block id bit layout (sequence number / partition id / task attempt id bits)
 *       either from {@code RSS_MAX_PARTITIONS} or from the explicit block id bit options;
 *   <li>computing the task attempt id that is combined from a map index and an attempt number;
 *   <li>fetching dynamic client configuration from the coordinators;
 *   <li>dropping all map outputs of a shuffle on the driver's {@link MapOutputTrackerMaster},
 *       via reflection because the relevant tracker methods differ between Spark versions;
 *   <li>building the default {@link RemoteStorageInfo} from Spark / RSS configuration.
 * </ul>
 */
public abstract class RssShuffleManagerBase
implements RssShuffleManagerInterface,
ShuffleManager {
    private static final Logger LOG = LoggerFactory.getLogger(RssShuffleManagerBase.class);

    /** Prefix under which the RSS block id bit options are mirrored into the {@link SparkConf}. */
    private static final String SPARK_PREFIX = "spark.";

    /** RSS options with this prefix are copied (prefix stripped) into the remote storage conf. */
    private static final String RSS_HADOOP_PREFIX = "rss.hadoop.";

    /**
     * Becomes {@code true} only after both reflective tracker methods below have been resolved.
     * Because it is set after the fields are written and read before they are used, it also
     * publishes those fields safely to other threads.
     */
    private final AtomicBoolean isInitialized = new AtomicBoolean(false);
    /** Guards the one-time resolution of the reflective tracker methods. */
    private final Object trackerMethodsLock = new Object();
    private Method unregisterAllMapOutputMethod;
    private Method registerShuffleMethod;

    public abstract void configureBlockIdLayout(SparkConf sparkConf, RssConf rssConf);

    /**
     * Configures the block id bit layout in both {@code sparkConf} and {@code rssConf}.
     *
     * <p>If {@code RSS_MAX_PARTITIONS} is present in {@code sparkConf}, the layout is derived from
     * it. Otherwise the explicit block id bit options are used (from Spark conf, then RSS conf),
     * falling back to the default of {@code RSS_MAX_PARTITIONS}.
     *
     * @param sparkConf Spark configuration; read and updated
     * @param rssConf RSS configuration; read and updated
     * @param maxFailures value of Spark's max task failures setting
     * @param speculation whether speculative execution is enabled
     * @throws IllegalArgumentException if the configuration is inconsistent or unsupported
     */
    @VisibleForTesting
    protected static void configureBlockIdLayout(SparkConf sparkConf, RssConf rssConf, int maxFailures, boolean speculation) {
        if (sparkConf.contains(RssSparkConfig.RSS_MAX_PARTITIONS.key())) {
            configureBlockIdLayoutFromMaxPartitions(sparkConf, rssConf, maxFailures, speculation);
        } else {
            configureBlockIdLayoutFromLayoutConfig(sparkConf, rssConf, maxFailures, speculation);
        }
    }

    /**
     * Derives the block id layout from {@code RSS_MAX_PARTITIONS}.
     *
     * <p>The three fields share 63 bits. The partition id gets just enough bits for
     * {@code maxPartitions - 1}, and the task attempt id gets those bits plus the attempt-number
     * bits. The sequence number gets the remainder. If that remainder exceeds 31 bits, the surplus
     * (rounded up to an even count) is split equally between the partition id and the task attempt
     * id, and {@code RSS_MAX_PARTITIONS} is raised accordingly in {@code sparkConf}.
     */
    private static void configureBlockIdLayoutFromMaxPartitions(SparkConf sparkConf, RssConf rssConf, int maxFailures, boolean speculation) {
        String maxPartitionsKey = RssSparkConfig.RSS_MAX_PARTITIONS.key();
        int maxPartitions = sparkConf.getInt(maxPartitionsKey, ((Integer) RssSparkConfig.RSS_MAX_PARTITIONS.defaultValue().get()).intValue());
        if (maxPartitions <= 1) {
            throw new IllegalArgumentException("Value of " + maxPartitionsKey + " must be larger than 1: " + maxPartitions);
        }
        int attemptIdBits = getAttemptIdBits(getMaxAttemptNo(maxFailures, speculation));
        // Bits needed to represent partition ids 0 .. maxPartitions - 1.
        int partitionIdBits = 32 - Integer.numberOfLeadingZeros(maxPartitions - 1);
        int taskAttemptIdBits = partitionIdBits + attemptIdBits;
        int sequenceNoBits = 63 - partitionIdBits - taskAttemptIdBits;
        if (taskAttemptIdBits > 31) {
            throw new IllegalArgumentException("Cannot support " + maxPartitionsKey + "=" + maxPartitions + " partitions, as this would require to reserve more than 31 bits in the block id for task attempt ids. With spark.maxFailures=" + maxFailures + " and spark.speculation=" + (speculation ? "true" : "false") + " at most " + (1L << (31 - attemptIdBits)) + " partitions can be supported.");
        }
        if (sequenceNoBits > 31) {
            int spareBits = sequenceNoBits - 31;
            // Only hand out an even number of bits so partition and task attempt ids grow equally.
            spareBits += spareBits % 2;
            partitionIdBits += spareBits / 2;
            taskAttemptIdBits += spareBits / 2;
            maxPartitions = 1 << partitionIdBits;
            if (LOG.isInfoEnabled()) {
                LOG.info("Increasing " + maxPartitionsKey + " to " + maxPartitions + ", otherwise we would have to support 2^" + sequenceNoBits + " (more than 2^31) sequence numbers.");
            }
            sequenceNoBits -= spareBits;
            sparkConf.set(maxPartitionsKey, String.valueOf(maxPartitions));
        }
        rssConf.set(RssClientConf.BLOCKID_SEQUENCE_NO_BITS, sequenceNoBits);
        rssConf.set(RssClientConf.BLOCKID_PARTITION_ID_BITS, partitionIdBits);
        rssConf.set(RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS, taskAttemptIdBits);
        sparkConf.set(SPARK_PREFIX + RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key(), String.valueOf(sequenceNoBits));
        sparkConf.set(SPARK_PREFIX + RssClientConf.BLOCKID_PARTITION_ID_BITS.key(), String.valueOf(partitionIdBits));
        sparkConf.set(SPARK_PREFIX + RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS.key(), String.valueOf(taskAttemptIdBits));
    }

    /**
     * Takes the block id layout from the three explicit bit options.
     *
     * <p>Within each conf the three options must be given all together or not at all.
     * {@code spark.}-prefixed keys in {@code sparkConf} take precedence and are copied into
     * {@code rssConf}. If only {@code rssConf} has them, they are mirrored into {@code sparkConf}.
     * If neither has them, the default {@code RSS_MAX_PARTITIONS} is applied instead.
     */
    private static void configureBlockIdLayoutFromLayoutConfig(SparkConf sparkConf, RssConf rssConf, int maxFailures, boolean speculation) {
        String sparkSeqNoBitsKey = SPARK_PREFIX + RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key();
        String sparkPartIdBitsKey = SPARK_PREFIX + RssClientConf.BLOCKID_PARTITION_ID_BITS.key();
        String sparkTaskIdBitsKey = SPARK_PREFIX + RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS.key();
        List<String> sparkKeys = Arrays.asList(sparkSeqNoBitsKey, sparkPartIdBitsKey, sparkTaskIdBitsKey);
        boolean anySparkKey = sparkKeys.stream().anyMatch(sparkConf::contains);
        boolean allSparkKeys = sparkKeys.stream().allMatch(sparkConf::contains);
        if (anySparkKey && !allSparkKeys) {
            String existingKeys = Arrays.stream(sparkConf.getAll())
                    .map(t -> t._1)
                    .filter(sparkKeys::contains)
                    .collect(Collectors.joining(", "));
            throw new IllegalArgumentException("All block id bit config keys must be provided (" + String.join(", ", sparkKeys) + "), not just a sub-set: " + existingKeys);
        }

        List<ConfigOption<?>> rssKeys = Arrays.asList(
                RssClientConf.BLOCKID_SEQUENCE_NO_BITS,
                RssClientConf.BLOCKID_PARTITION_ID_BITS,
                RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS);
        boolean anyRssKey = rssKeys.stream().anyMatch(rssConf::contains);
        boolean allRssKeys = rssKeys.stream().allMatch(rssConf::contains);
        if (anyRssKey && !allRssKeys) {
            List<String> rssKeyNames = rssKeys.stream().map(ConfigOption::key).collect(Collectors.toList());
            String existingKeys = rssConf.getKeySet().stream()
                    .filter(rssKeyNames::contains)
                    .collect(Collectors.joining(", "));
            throw new IllegalArgumentException("All block id bit config keys must be provided (" + String.join(", ", rssKeyNames) + "), not just a sub-set: " + existingKeys);
        }

        if (allSparkKeys) {
            rssConf.set(RssClientConf.BLOCKID_SEQUENCE_NO_BITS, sparkConf.getInt(sparkSeqNoBitsKey, 0));
            rssConf.set(RssClientConf.BLOCKID_PARTITION_ID_BITS, sparkConf.getInt(sparkPartIdBitsKey, 0));
            rssConf.set(RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS, sparkConf.getInt(sparkTaskIdBitsKey, 0));
        } else if (allRssKeys) {
            sparkConf.set(sparkSeqNoBitsKey, rssConf.getValue(RssClientConf.BLOCKID_SEQUENCE_NO_BITS));
            sparkConf.set(sparkPartIdBitsKey, rssConf.getValue(RssClientConf.BLOCKID_PARTITION_ID_BITS));
            sparkConf.set(sparkTaskIdBitsKey, rssConf.getValue(RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS));
        } else {
            sparkConf.set(RssSparkConfig.RSS_MAX_PARTITIONS.key(), RssSparkConfig.RSS_MAX_PARTITIONS.defaultValueString());
            configureBlockIdLayoutFromMaxPartitions(sparkConf, rssConf, maxFailures, speculation);
        }
    }

    /**
     * Returns the largest attempt number a task can observe.
     *
     * @param maxFailures Spark's max task failures; values below 1 are treated as 1
     * @param speculation whether speculation is enabled (adds one extra attempt)
     * @return {@code max(maxFailures - 1, 0)}, plus one if {@code speculation} is set
     */
    protected static int getMaxAttemptNo(int maxFailures, boolean speculation) {
        int maxAttemptNo = maxFailures < 1 ? 0 : maxFailures - 1;
        return speculation ? maxAttemptNo + 1 : maxAttemptNo;
    }

    /**
     * Returns the number of bits needed to represent {@code maxAttemptNo}; 0 when it is 0.
     *
     * @param maxAttemptNo largest attempt number, as returned by {@link #getMaxAttemptNo}
     */
    protected static int getAttemptIdBits(int maxAttemptNo) {
        return 32 - Integer.numberOfLeadingZeros(maxAttemptNo);
    }

    public abstract long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo);

    /**
     * Packs a map index and an attempt number into a task attempt id.
     *
     * <p>The result is {@code mapIndex << attemptBits | attemptNo}, where {@code attemptBits}
     * is {@link #getAttemptIdBits} of {@link #getMaxAttemptNo}.
     *
     * @throws RssException if {@code attemptNo} exceeds the maximum attempt number, or if the
     *     result would need more than {@code maxTaskAttemptIdBits} bits
     */
    protected static long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo, int maxFailures, boolean speculation, int maxTaskAttemptIdBits) {
        int maxAttemptNo = getMaxAttemptNo(maxFailures, speculation);
        int attemptBits = getAttemptIdBits(maxAttemptNo);
        if (attemptNo > maxAttemptNo) {
            throw new RssException("Observing attempt number " + attemptNo + " while maxFailures is set to " + maxFailures + (speculation ? " with speculation enabled" : "") + ".");
        }
        int mapIndexBits = 32 - Integer.numberOfLeadingZeros(mapIndex);
        if (mapIndexBits + attemptBits > maxTaskAttemptIdBits) {
            throw new RssException("Observing mapIndex[" + mapIndex + "] that would produce a taskAttemptId with " + (mapIndexBits + attemptBits) + " bits which is larger than the allowed " + maxTaskAttemptIdBits + " bits (maxFailures[" + maxFailures + "], speculation[" + speculation + "]). Please consider providing more bits for taskAttemptIds.");
        }
        return ((long) mapIndex << attemptBits) | (long) attemptNo;
    }

    /**
     * Asks the configured coordinators, in order, for dynamic client configuration.
     *
     * <p>The first successful response is applied to {@code sparkConf}. All coordinator clients
     * are closed afterwards, even if a request throws.
     */
    protected static void fetchAndApplyDynamicConf(SparkConf sparkConf) {
        String clientType = (String) sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
        String coordinators = sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
        int timeoutMs = sparkConf.getInt(RssSparkConfig.RSS_ACCESS_TIMEOUT_MS.key(), ((Integer) RssSparkConfig.RSS_ACCESS_TIMEOUT_MS.defaultValue().get()).intValue());
        List<CoordinatorClient> coordinatorClients = CoordinatorClientFactory.getInstance().createCoordinatorClient(ClientType.valueOf(clientType), coordinators);
        try {
            for (CoordinatorClient client : coordinatorClients) {
                RssFetchClientConfResponse response = client.fetchClientConf(new RssFetchClientConfRequest(timeoutMs));
                if (response.getStatusCode() == StatusCode.SUCCESS) {
                    LOG.info("Success to get conf from {}", client.getDesc());
                    RssSparkShuffleUtils.applyDynamicClientConf(sparkConf, response.getClientConf());
                    break;
                }
                LOG.warn("Fail to get conf from {}", client.getDesc());
            }
        } finally {
            // Release every client, including when fetchClientConf throws.
            coordinatorClients.forEach(CoordinatorClient::close);
        }
    }

    /**
     * Drops all registered map outputs of {@code shuffleId} from the driver's map output tracker.
     *
     * <p>This is a no-op when stage resubmission is not supported. It uses the Spark-version-
     * specific tracker method when one was found. Otherwise it unregisters and re-registers the
     * shuffle.
     */
    @Override
    public void unregisterAllMapOutput(int shuffleId) throws SparkException {
        if (!RssSparkShuffleUtils.isStageResubmitSupported()) {
            return;
        }
        MapOutputTrackerMaster tracker = getMapOutputTrackerMaster();
        resolveTrackerMethods(tracker);
        Method unregisterAll = this.unregisterAllMapOutputMethod;
        if (unregisterAll != null) {
            try {
                unregisterAll.invoke(tracker, shuffleId);
            } catch (IllegalAccessException | InvocationTargetException e) {
                throw new RssException("Invoke unregisterAllMapOutput method failed", e);
            }
        } else {
            int numMaps = this.getNumMaps(shuffleId);
            int numReduces = this.getPartitionNum(shuffleId);
            defaultUnregisterAllMapOutput(tracker, this.registerShuffleMethod, shuffleId, numMaps, numReduces);
        }
    }

    /**
     * Resolves the reflective tracker methods exactly once (from the first tracker seen).
     *
     * <p>The flag is set only after both fields are assigned, so any thread that sees it as
     * {@code true} also sees the resolved methods.
     */
    private void resolveTrackerMethods(MapOutputTrackerMaster tracker) {
        if (this.isInitialized.get()) {
            return;
        }
        synchronized (this.trackerMethodsLock) {
            if (!this.isInitialized.get()) {
                this.unregisterAllMapOutputMethod = getUnregisterAllMapOutputMethod(tracker);
                this.registerShuffleMethod = getRegisterShuffleMethod(tracker);
                this.isInitialized.set(true);
            }
        }
    }

    /** Whether {@code MapOutputTrackerMaster.registerShuffle} also takes the number of reducers (Spark 3.2+). */
    private static boolean registerShuffleTakesNumReduces() {
        return SparkVersionUtils.MAJOR_VERSION > 3 || (SparkVersionUtils.isSpark3() && SparkVersionUtils.MINOR_VERSION >= 2);
    }

    /**
     * Fallback reset for a shuffle: unregister it, re-register it with the overload that matches
     * the running Spark version, then bump the tracker epoch.
     *
     * @throws SparkException if there is no driver-side tracker or no resolved registerShuffle method
     */
    private static void defaultUnregisterAllMapOutput(MapOutputTrackerMaster tracker, Method registerShuffle, int shuffleId, int numMaps, int numReduces) throws SparkException {
        if (tracker == null || registerShuffle == null) {
            throw new SparkException("default unregisterAllMapOutput should only be called on the driver side");
        }
        tracker.unregisterShuffle(shuffleId);
        try {
            // Exactly one overload matches the resolved Method; invoking the other would fail
            // with IllegalArgumentException (wrong number of arguments).
            if (registerShuffleTakesNumReduces()) {
                registerShuffle.invoke(tracker, shuffleId, numMaps, numReduces);
            } else {
                registerShuffle.invoke(tracker, shuffleId, numMaps);
            }
        } catch (IllegalAccessException | InvocationTargetException e) {
            throw new RssException("Invoke registerShuffle method failed", e);
        }
        tracker.incrementEpoch();
    }

    /**
     * Looks up the version-specific "drop all map outputs" method on the tracker's class.
     *
     * @return the method, or {@code null} when there is no tracker, Spark is 2.3 or older, the
     *     version is not recognized, or the method does not exist
     */
    private static Method getUnregisterAllMapOutputMethod(MapOutputTrackerMaster tracker) {
        if (tracker == null) {
            return null;
        }
        Class<?> klass = tracker.getClass();
        try {
            if (SparkVersionUtils.isSpark2() && SparkVersionUtils.MINOR_VERSION <= 3) {
                LOG.warn("Spark version <= 2.3, fallback to default method");
            } else if (SparkVersionUtils.isSpark2() || (SparkVersionUtils.isSpark3() && SparkVersionUtils.MINOR_VERSION <= 1)) {
                return klass.getDeclaredMethod("unregisterAllMapOutput", int.class);
            } else if (SparkVersionUtils.isSpark3()) {
                return klass.getDeclaredMethod("unregisterAllMapAndMergeOutput", int.class);
            } else {
                LOG.warn("Unknown spark version({}), fallback to default method", (Object) SparkVersionUtils.SPARK_VERSION);
            }
        } catch (NoSuchMethodException e) {
            LOG.warn("Got no such method error when get unregisterAllMapOutput method for spark version({})", (Object) SparkVersionUtils.SPARK_VERSION);
        }
        return null;
    }

    /**
     * Looks up {@code registerShuffle} on the tracker's class, with three int parameters on
     * Spark 3.2+ and two otherwise.
     *
     * @return the method, or {@code null} when there is no tracker or the method does not exist
     */
    private static Method getRegisterShuffleMethod(MapOutputTrackerMaster tracker) {
        if (tracker == null) {
            return null;
        }
        Class<?> klass = tracker.getClass();
        try {
            return registerShuffleTakesNumReduces()
                    ? klass.getDeclaredMethod("registerShuffle", int.class, int.class, int.class)
                    : klass.getDeclaredMethod("registerShuffle", int.class, int.class);
        } catch (NoSuchMethodException e) {
            LOG.warn("Got no such method error when get registerShuffle method for spark version({})", (Object) SparkVersionUtils.SPARK_VERSION);
            return null;
        }
    }

    /** Returns the current map output tracker if it is the driver-side master, else {@code null}. */
    private static MapOutputTrackerMaster getMapOutputTrackerMaster() {
        MapOutputTracker tracker = Optional.ofNullable(SparkEnv.get()).map(SparkEnv::mapOutputTracker).orElse(null);
        return tracker instanceof MapOutputTrackerMaster ? (MapOutputTrackerMaster) tracker : null;
    }

    /** Copies every entry of a Hadoop configuration into a new mutable map. */
    private static Map<String, String> parseRemoteStorageConf(Configuration conf) {
        Map<String, String> confItems = Maps.newHashMap();
        for (Map.Entry<String, String> entry : conf) {
            confItems.put(entry.getKey(), entry.getValue());
        }
        return confItems;
    }

    /**
     * Builds the default remote storage info from Spark / RSS configuration.
     *
     * <p>The path comes from {@code RSS_REMOTE_STORAGE_PATH} (empty if unset). The conf items
     * optionally start from the local Hadoop configuration. Any non-null
     * {@code rss.hadoop.*} RSS option is then added with the prefix stripped, overriding
     * entries with the same key.
     */
    protected static RemoteStorageInfo getDefaultRemoteStorageInfo(SparkConf sparkConf) {
        RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
        Map<String, String> confItems;
        if (rssConf.getBoolean(RssClientConf.RSS_CLIENT_REMOTE_STORAGE_USE_LOCAL_CONF_ENABLED)) {
            confItems = parseRemoteStorageConf(new Configuration(true));
        } else {
            confItems = Maps.newHashMap();
        }
        for (String key : rssConf.getKeySet()) {
            if (!key.startsWith(RSS_HADOOP_PREFIX)) {
                continue;
            }
            String val = rssConf.getString(key, null);
            if (val != null) {
                // Literal prefix strip (the previous replaceFirst interpreted '.' as a regex).
                confItems.put(key.substring(RSS_HADOOP_PREFIX.length()), val);
            }
        }
        return new RemoteStorageInfo(sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), ""), confItems);
    }
}

