/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.Numbers;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ShapPath;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;
import org.elasticsearch.xpack.core.ml.job.config.Operator;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/**
 * Inference-time representation of a single decision tree.
 *
 * <p>Nodes are stored in a flat array; node 0 is the root and inner nodes reference their children by array
 * index. Instances are parsed with {@link #fromXContent(XContentParser)} and must be prepared through
 * {@link #rewriteFeatureIndices(Map)} before any {@code infer} call, otherwise inference fails with a server error.
 */
public class TreeInferenceModel implements InferenceModel {

    private static final Logger LOGGER = LogManager.getLogger(TreeInferenceModel.class);
    public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TreeInferenceModel.class);

    // Constructor args, in order: feature_names, tree_structure, target_type (optional), classification_labels (optional).
    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<TreeInferenceModel, Void> PARSER = new ConstructingObjectParser<>(
        "tree_inference_model",
        true,
        a -> new TreeInferenceModel(
            (List<String>) a[0],
            (List<NodeBuilder>) a[1],
            a[2] == null ? null : TargetType.fromString((String) a[2]),
            (List<String>) a[3]
        )
    );

    static {
        PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), Tree.FEATURE_NAMES);
        PARSER.declareObjectArray(
            ConstructingObjectParser.constructorArg(),
            (p, c) -> NodeBuilder.PARSER.apply(p, c),
            Tree.TREE_STRUCTURE
        );
        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), Tree.TARGET_TYPE);
        PARSER.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), Tree.CLASSIFICATION_LABELS);
    }

    private final Node[] nodes;
    private String[] featureNames;
    private final TargetType targetType;
    private List<String> classificationLabels;
    // See maxLeafValue(): NaN unless the target type is classification.
    private final double highOrderCategory;
    private final int maxDepth;
    private final int leafSize;
    private volatile boolean preparedForInference = false;

    /**
     * Parses a tree model from XContent.
     *
     * @param parser parser positioned at the tree object
     * @return the parsed (not yet prepared) model
     */
    public static TreeInferenceModel fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * @param featureNames         names of the input features; {@code splitFeature} indices refer to this list
     * @param nodes                node builders in array order; must be non-empty, element 0 is the root
     * @param targetType           target type, defaults to {@link TargetType#REGRESSION} when {@code null}
     * @param classificationLabels optional labels for classification output, may be {@code null}
     * @throws IllegalArgumentException if {@code nodes} is empty
     */
    TreeInferenceModel(
        List<String> featureNames,
        List<NodeBuilder> nodes,
        @Nullable TargetType targetType,
        List<String> classificationLabels
    ) {
        this.featureNames = ExceptionsHelper.requireNonNull(featureNames, Tree.FEATURE_NAMES).toArray(new String[0]);
        if (ExceptionsHelper.requireNonNull(nodes, Tree.TREE_STRUCTURE).isEmpty()) {
            throw new IllegalArgumentException("[tree_structure] must not be empty");
        }
        this.nodes = nodes.stream().map(NodeBuilder::build).toArray(Node[]::new);
        this.targetType = targetType == null ? TargetType.REGRESSION : targetType;
        this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
        // maxLeafValue() reads targetType and nodes, so it must run after both are assigned.
        this.highOrderCategory = this.maxLeafValue();
        this.leafSize = firstLeafSize(this.nodes);
        this.maxDepth = getDepth(this.nodes, 0);
    }

    /** Width of the first leaf's value array, or 1 when no leaf exists. Assumes all leaves share this width. */
    private static int firstLeafSize(Node[] nodes) {
        for (Node node : nodes) {
            if (node instanceof LeafNode) {
                return ((LeafNode) node).leafValue.length;
            }
        }
        return 1;
    }

    @Override
    public String[] getFeatureNames() {
        return this.featureNames;
    }

    @Override
    public TargetType targetType() {
        return this.targetType;
    }

    @Override
    public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
        return this.innerInfer(InferenceModel.extractFeatures(this.featureNames, fields), config, featureDecoderMap);
    }

    @Override
    public InferenceResults infer(double[] features, InferenceConfig config) {
        return this.innerInfer(features, config, Collections.emptyMap());
    }

    /**
     * Walks the tree for {@code features} and builds the result, computing feature importance only when requested.
     *
     * @throws org.elasticsearch.ElasticsearchException (bad request) if the config does not support this target
     *         type, or (server error) if the model has not been prepared via {@link #rewriteFeatureIndices(Map)}
     */
    private InferenceResults innerInfer(double[] features, InferenceConfig config, Map<String, String> featureDecoderMap) {
        if (!config.isTargetTypeSupported(this.targetType)) {
            throw ExceptionsHelper.badRequestException(
                "Cannot infer using configuration for [{}] when model target_type is [{}]",
                config.getName(),
                this.targetType.toString()
            );
        }
        if (!this.preparedForInference) {
            throw ExceptionsHelper.serverError("model is not prepared for inference");
        }
        double[][] featureImportance = config.requestingImportance() ? this.featureImportance(features) : new double[][] {};
        return this.buildResult(this.getLeaf(features), featureImportance, featureDecoderMap, config);
    }

    /**
     * Converts a raw leaf value into the result type implied by {@code config} and the model's target type.
     * A {@link NullInferenceConfig} yields the raw leaf value unchanged.
     */
    private InferenceResults buildResult(
        double[] value,
        double[][] featureImportance,
        Map<String, String> featureDecoderMap,
        InferenceConfig config
    ) {
        assert value != null && value.length > 0;
        if (config instanceof NullInferenceConfig) {
            return new RawInferenceResults(value, featureImportance);
        }
        Map<String, double[]> decodedFeatureImportance = config.requestingImportance()
            ? InferenceHelpers.decodeFeatureImportances(
                featureDecoderMap,
                IntStream.range(0, featureImportance.length)
                    .boxed()
                    .collect(Collectors.toMap(i -> this.featureNames[i], i -> featureImportance[i]))
            )
            : Collections.emptyMap();
        switch (this.targetType) {
            case CLASSIFICATION: {
                ClassificationConfig classificationConfig = (ClassificationConfig) config;
                Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
                    this.classificationProbability(value),
                    this.classificationLabels,
                    null,
                    classificationConfig.getNumTopClasses(),
                    classificationConfig.getPredictionFieldType()
                );
                InferenceHelpers.TopClassificationValue classificationValue = topClasses.v1();
                return new ClassificationInferenceResults(
                    (double) classificationValue.getValue(),
                    InferenceHelpers.classificationLabel(classificationValue.getValue(), this.classificationLabels),
                    topClasses.v2(),
                    InferenceHelpers.transformFeatureImportanceClassification(
                        decodedFeatureImportance,
                        this.classificationLabels,
                        classificationConfig.getPredictionFieldType()
                    ),
                    config,
                    classificationValue.getProbability(),
                    classificationValue.getScore()
                );
            }
            case REGRESSION:
                return new RegressionInferenceResults(
                    value[0],
                    config,
                    InferenceHelpers.transformFeatureImportanceRegression(decodedFeatureImportance)
                );
        }
        throw new UnsupportedOperationException("unsupported target_type [" + this.targetType + "] for inference on tree model");
    }

    /**
     * Turns a leaf value into per-class probabilities.
     *
     * <p>Multi-valued leaves are soft-maxed. A single-valued leaf is treated as a whole class index and becomes a
     * one-hot vector whose length is {@code highOrderCategory + 1}.
     */
    private double[] classificationProbability(double[] inferenceValue) {
        if (inferenceValue.length > 1) {
            return Statistics.softMax(inferenceValue);
        }
        assert inferenceValue[0] == Math.rint(inferenceValue[0]);
        assert this.highOrderCategory == Math.rint(this.highOrderCategory);
        double[] oneHot = new double[(int) (this.highOrderCategory + 1.0)];
        oneHot[(int) inferenceValue[0]] = 1.0;
        return oneHot;
    }

    /** Follows {@link Node#compare} from the root until a leaf is reached and returns that leaf's value. */
    private double[] getLeaf(double[] features) {
        Node node = this.nodes[0];
        while (!node.isLeaf()) {
            node = this.nodes[node.compare(features)];
        }
        return ((LeafNode) node).leafValue;
    }

    /**
     * Computes SHAP-style feature importance for one input using {@link ShapPath}.
     *
     * @param fieldValues feature vector indexed like the split features
     * @return a {@code [fieldValues.length][leafSize]} array of contributions
     */
    public double[][] featureImportance(double[] fieldValues) {
        // new double[a][b] already allocates every row, zero-filled.
        double[][] featureImportance = new double[fieldValues.length][this.leafSize];
        // Triangular-number sizing, presumably so each recursion level gets its own path slice; see ShapPath.
        int arrSize = (this.maxDepth + 1) * (this.maxDepth + 2) / 2;
        ShapPath.PathElement[] elements = new ShapPath.PathElement[arrSize];
        for (int i = 0; i < arrSize; ++i) {
            elements[i] = new ShapPath.PathElement();
        }
        double[] scale = new double[arrSize];
        ShapPath initialPath = new ShapPath(elements, scale);
        this.shapRecursive(fieldValues, initialPath, 0, 1.0, 1.0, -1, featureImportance, 0);
        return featureImportance;
    }

    /**
     * Recursive step of the path-based importance computation. At a leaf it adds each path feature's weighted
     * contribution to {@code featureImportance}. At an inner node it recurses into the "hot" child (the one the
     * input follows) and the "cold" child, with fractions derived from the children's sample counts.
     */
    private void shapRecursive(
        double[] processedFeatures,
        ShapPath parentSplitPath,
        int nodeIndex,
        double parentFractionZero,
        double parentFractionOne,
        int parentFeatureIndex,
        double[][] featureImportance,
        int nextIndex
    ) {
        ShapPath splitPath = new ShapPath(parentSplitPath, nextIndex);
        Node currNode = this.nodes[nodeIndex];
        nextIndex = splitPath.extend(parentFractionZero, parentFractionOne, parentFeatureIndex, nextIndex);
        if (currNode.isLeaf()) {
            double[] leafValue = ((LeafNode) currNode).leafValue;
            // Index 0 is the root placeholder pushed with feature index -1, so real features start at 1.
            for (int i = 1; i < nextIndex; ++i) {
                double[] row = featureImportance[splitPath.featureIndex(i)];
                double scaled = splitPath.sumUnwoundPath(i, nextIndex) * (splitPath.fractionOnes(i) - splitPath.fractionZeros(i));
                for (int j = 0; j < leafValue.length; ++j) {
                    row[j] += scaled * leafValue[j];
                }
            }
            return;
        }
        InnerNode innerNode = (InnerNode) currNode;
        int hotIndex = currNode.compare(processedFeatures);
        int coldIndex = hotIndex == innerNode.leftChild ? innerNode.rightChild : innerNode.leftChild;
        double incomingFractionZero = 1.0;
        double incomingFractionOne = 1.0;
        int splitFeature = innerNode.splitFeature;
        int pathIndex = splitPath.findFeatureIndex(splitFeature, nextIndex);
        if (pathIndex > -1) {
            // The feature was already split on higher up: carry its fractions forward and remove the earlier entry.
            incomingFractionZero = splitPath.fractionZeros(pathIndex);
            incomingFractionOne = splitPath.fractionOnes(pathIndex);
            nextIndex = splitPath.unwind(pathIndex, nextIndex);
        }
        double hotFractionZero = (double) this.nodes[hotIndex].getNumberSamples() / (double) currNode.getNumberSamples();
        double coldFractionZero = (double) this.nodes[coldIndex].getNumberSamples() / (double) currNode.getNumberSamples();
        this.shapRecursive(processedFeatures, splitPath, hotIndex, incomingFractionZero * hotFractionZero, incomingFractionOne, splitFeature, featureImportance, nextIndex);
        this.shapRecursive(processedFeatures, splitPath, coldIndex, incomingFractionZero * coldFractionZero, 0.0, splitFeature, featureImportance, nextIndex);
    }

    @Override
    public boolean supportsFeatureImportance() {
        return true;
    }

    @Override
    public String getName() {
        return "tree";
    }

    /**
     * Remaps every inner node's split feature through {@code newFeatureIndexMapping}, keyed by feature name, and
     * marks the model as prepared for inference. Afterwards the local feature names and classification labels are
     * cleared. Calls made once the model is prepared are no-ops. A {@code null} or empty mapping only marks the
     * model as prepared.
     *
     * @throws IllegalArgumentException if a split feature's name is absent from the mapping; in that case no node
     *         is modified and the model stays unprepared
     */
    @Override
    public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
        LOGGER.debug(() -> new ParameterizedMessage("rewriting features {}", newFeatureIndexMapping));
        if (this.preparedForInference) {
            return;
        }
        if (newFeatureIndexMapping == null || newFeatureIndexMapping.isEmpty()) {
            this.preparedForInference = true;
            return;
        }
        // Resolve every replacement index before touching any node. A missing feature then leaves the tree intact
        // instead of half-remapped and already flagged as prepared, where later calls would silently skip it.
        int[] newSplitFeatures = new int[this.nodes.length];
        for (int i = 0; i < this.nodes.length; i++) {
            if (this.nodes[i].isLeaf()) {
                continue;
            }
            InnerNode innerNode = (InnerNode) this.nodes[i];
            Integer newSplitFeatureIndex = newFeatureIndexMapping.get(this.featureNames[innerNode.splitFeature]);
            if (newSplitFeatureIndex == null) {
                throw new IllegalArgumentException("[tree] failed to optimize for inference");
            }
            newSplitFeatures[i] = newSplitFeatureIndex;
        }
        for (int i = 0; i < this.nodes.length; i++) {
            if (!this.nodes[i].isLeaf()) {
                ((InnerNode) this.nodes[i]).splitFeature = newSplitFeatures[i];
            }
        }
        this.featureNames = new String[0];
        this.classificationLabels = null;
        this.preparedForInference = true;
    }

    public long ramBytesUsed() {
        long size = SHALLOW_SIZE;
        size += RamUsageEstimator.sizeOfCollection(this.classificationLabels);
        size += RamUsageEstimator.sizeOf(this.featureNames);
        size += RamUsageEstimator.sizeOf((Accountable[]) this.nodes);
        return size;
    }

    /**
     * Upper bound used to size one-hot classification vectors.
     *
     * @return NaN for non-classification models. Otherwise, the length of the first multi-valued leaf if one exists,
     *         else the maximum single leaf value (never below 0)
     */
    private double maxLeafValue() {
        if (this.targetType != TargetType.CLASSIFICATION) {
            return Double.NaN;
        }
        double max = 0.0;
        for (Node node : this.nodes) {
            if (!(node instanceof LeafNode)) {
                continue;
            }
            LeafNode leafNode = (LeafNode) node;
            if (leafNode.leafValue.length > 1) {
                return leafNode.leafValue.length;
            }
            max = Math.max(leafNode.leafValue[0], max);
        }
        return max;
    }

    public Node[] getNodes() {
        return this.nodes;
    }

    @Override
    public String toString() {
        return "TreeInferenceModel{nodes=" + Arrays.toString(this.nodes)
            + ", featureNames=" + Arrays.toString(this.featureNames)
            + ", targetType=" + this.targetType
            + ", classificationLabels=" + this.classificationLabels
            + ", highOrderCategory=" + this.highOrderCategory
            + ", maxDepth=" + this.maxDepth
            + ", leafSize=" + this.leafSize
            + ", preparedForInference=" + this.preparedForInference
            + '}';
    }

    /** Depth of the subtree rooted at {@code nodeIndex}; a leaf has depth 0. */
    private static int getDepth(Node[] nodes, int nodeIndex) {
        Node node = nodes[nodeIndex];
        if (node instanceof LeafNode) {
            return 0;
        }
        InnerNode innerNode = (InnerNode) node;
        return Math.max(getDepth(nodes, innerNode.leftChild), getDepth(nodes, innerNode.rightChild)) + 1;
    }

    /** Common base for tree nodes; only inner nodes support {@link #compare}. */
    public abstract static class Node implements Accountable {
        /** Returns the index of the child to follow for {@code features}; leaves reject this call. */
        int compare(double[] features) {
            throw new IllegalArgumentException("cannot call compare against a leaf node.");
        }

        abstract long getNumberSamples();

        public boolean isLeaf() {
            return this instanceof LeafNode;
        }
    }

    /** Terminal node holding the prediction value(s). */
    public static class LeafNode extends Node {
        public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LeafNode.class);
        private final double[] leafValue;
        private final long numberSamples;

        LeafNode(double[] leafValue, long numberSamples) {
            this.leafValue = leafValue;
            this.numberSamples = numberSamples;
        }

        @Override
        public long ramBytesUsed() {
            return SHALLOW_SIZE + RamUsageEstimator.sizeOf(this.leafValue);
        }

        @Override
        long getNumberSamples() {
            return this.numberSamples;
        }

        public double[] getLeafValue() {
            return this.leafValue;
        }

        @Override
        public String toString() {
            return "LeafNode{leafValue=" + Arrays.toString(this.leafValue) + ", numberSamples=" + this.numberSamples + '}';
        }
    }

    /** Split node: routes to {@code leftChild} when the operator test passes, else to {@code rightChild}. */
    public static class InnerNode extends Node {
        public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(InnerNode.class);
        private final Operator operator;
        private final double threshold;
        // Not final: rewriteFeatureIndices remaps it.
        private int splitFeature;
        private final boolean defaultLeft;
        private final int leftChild;
        private final int rightChild;
        private final long numberSamples;

        InnerNode(Operator operator, double threshold, int splitFeature, boolean defaultLeft, int leftChild, int rightChild, long numberSamples) {
            this.operator = operator;
            this.threshold = threshold;
            this.splitFeature = splitFeature;
            this.defaultLeft = defaultLeft;
            this.leftChild = leftChild;
            this.rightChild = rightChild;
            this.numberSamples = numberSamples;
        }

        /** Missing (non-valid double) feature values follow {@code defaultLeft}. */
        @Override
        public int compare(double[] features) {
            double feature = features[this.splitFeature];
            if (isMissing(feature)) {
                return this.defaultLeft ? this.leftChild : this.rightChild;
            }
            return this.operator.test(feature, this.threshold) ? this.leftChild : this.rightChild;
        }

        @Override
        long getNumberSamples() {
            return this.numberSamples;
        }

        private static boolean isMissing(double feature) {
            return !Numbers.isValidDouble(feature);
        }

        @Override
        public long ramBytesUsed() {
            return SHALLOW_SIZE;
        }

        @Override
        public String toString() {
            return "InnerNode{operator=" + this.operator
                + ", threshold=" + this.threshold
                + ", splitFeature=" + this.splitFeature
                + ", defaultLeft=" + this.defaultLeft
                + ", leftChild=" + this.leftChild
                + ", rightChild=" + this.rightChild
                + ", numberSamples=" + this.numberSamples
                + '}';
        }
    }

    /** Mutable parse target for one node; {@link #build()} yields a leaf when no left child was set. */
    static class NodeBuilder {
        private static final ObjectParser<NodeBuilder, Void> PARSER = new ObjectParser<>(
            "tree_inference_model_node",
            true,
            NodeBuilder::new
        );

        static {
            PARSER.declareDouble(NodeBuilder::setThreshold, TreeNode.THRESHOLD);
            PARSER.declareField(NodeBuilder::setOperator, p -> Operator.fromString(p.text()), TreeNode.DECISION_TYPE, ObjectParser.ValueType.STRING);
            PARSER.declareInt(NodeBuilder::setLeftChild, TreeNode.LEFT_CHILD);
            PARSER.declareInt(NodeBuilder::setRightChild, TreeNode.RIGHT_CHILD);
            PARSER.declareBoolean(NodeBuilder::setDefaultLeft, TreeNode.DEFAULT_LEFT);
            PARSER.declareInt(NodeBuilder::setSplitFeature, TreeNode.SPLIT_FEATURE);
            PARSER.declareDoubleArray(NodeBuilder::setLeafValue, TreeNode.LEAF_VALUE);
            PARSER.declareLong(NodeBuilder::setNumberSamples, TreeNode.NUMBER_SAMPLES);
        }

        private Operator operator = Operator.LTE;
        private double threshold = Double.NaN;
        private int splitFeature = -1;
        private boolean defaultLeft = false;
        private int leftChild = -1;
        private int rightChild = -1;
        private long numberSamples;
        private double[] leafValue = new double[0];

        NodeBuilder() {}

        public NodeBuilder setOperator(Operator operator) {
            this.operator = operator;
            return this;
        }

        public NodeBuilder setThreshold(double threshold) {
            this.threshold = threshold;
            return this;
        }

        public NodeBuilder setSplitFeature(int splitFeature) {
            this.splitFeature = splitFeature;
            return this;
        }

        public NodeBuilder setDefaultLeft(boolean defaultLeft) {
            this.defaultLeft = defaultLeft;
            return this;
        }

        public NodeBuilder setLeftChild(int leftChild) {
            this.leftChild = leftChild;
            return this;
        }

        public NodeBuilder setRightChild(int rightChild) {
            this.rightChild = rightChild;
            return this;
        }

        public NodeBuilder setNumberSamples(long numberSamples) {
            this.numberSamples = numberSamples;
            return this;
        }

        // Parser entry point: leaf_value arrives as a boxed list.
        private NodeBuilder setLeafValue(List<Double> leafValue) {
            return this.setLeafValue(leafValue.stream().mapToDouble(Double::doubleValue).toArray());
        }

        public NodeBuilder setLeafValue(double[] leafValue) {
            this.leafValue = leafValue;
            return this;
        }

        Node build() {
            if (this.leftChild < 0) {
                return new LeafNode(this.leafValue, this.numberSamples);
            }
            return new InnerNode(this.operator, this.threshold, this.splitFeature, this.defaultLeft, this.leftChild, this.rightChild, this.numberSamples);
        }
    }
}

