package org.apache.ignite.ml.tree;

import android.graphics.ColorSpace;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.Collections;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.preprocessing.encoding.EncoderPreprocessor;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.apache.ignite.ml.tree.data.DecisionTreeData;
import org.apache.ignite.ml.tree.data.DecisionTreeDataBuilder;
import org.apache.ignite.ml.tree.impurity.ImpurityMeasure;
import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator;
import org.apache.ignite.ml.tree.impurity.util.StepFunction;
import org.apache.ignite.ml.tree.impurity.util.StepFunctionCompressor;
import org.apache.ignite.ml.tree.leaf.DecisionTreeLeafBuilder;

/**
 * Base trainer for binary decision trees.
 *
 * <p>The tree is grown top-down by {@link #split}. At every node the dataset is asked (via
 * {@link ImpurityMeasureCalculator}) for one impurity step function per feature column. The per-partition results are
 * merged by {@link #reduce}. The best candidate split is chosen by {@link #calculateBestSplitPoint}. Recursion stops
 * when {@code maxDeep} is reached or no candidate split qualifies; a leaf is then produced by the configured
 * {@link DecisionTreeLeafBuilder}.
 *
 * @param <T> Impurity measure type used to compare candidate splits.
 */
public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends SingleLabelDatasetTrainer<DecisionTreeNode> {
    /** Maximum depth; a node at this depth is always turned into a leaf. */
    int maxDeep;

    /** A split is accepted only if it lowers impurity (relative to the first step value) by more than this. */
    double minImpurityDecrease;

    /** Optional compressor applied to each partition's step functions before reduction; may be {@code null}. */
    StepFunctionCompressor<T> compressor;

    /** Builder used to create leaf nodes. */
    private final DecisionTreeLeafBuilder decisionTreeLeafBuilder;

    /** Flag forwarded to {@link DecisionTreeDataBuilder}; presumably enables index structures (TODO confirm). */
    protected boolean usingIdx = true;

    /**
     * Candidate split of a node: the impurity measure obtained by the split, the feature column it tests and the
     * threshold it compares against.
     *
     * @param <T> Impurity measure type.
     */
    public static class SplitPoint<T extends ImpurityMeasure<T>> implements Serializable {
        /** */
        private static final long serialVersionUID = -1758525953544425043L;

        /** Impurity measure associated with this split; lower (per {@code compareTo}) is better. */
        private final T val;

        /** Index of the feature column the split tests. */
        private final int col;

        /** Threshold: rows with {@code x[col] > threshold} go to the "then" branch, the rest to "else". */
        private final double threshold;

        /**
         * Creates a split point.
         *
         * @param val Impurity measure of the split.
         * @param col Feature column index.
         * @param threshold Split threshold.
         */
        SplitPoint(T val, int col, double threshold) {
            this.val = val;
            this.col = col;
            this.threshold = threshold;
        }
    }

    /**
     * Creates a decision tree trainer.
     *
     * @param maxDeep Maximum tree depth.
     * @param minImpurityDecrease Minimal impurity decrease required to accept a split.
     * @param compressor Optional step-function compressor (may be {@code null}).
     * @param decisionTreeLeafBuilder Builder for leaf nodes.
     */
    public DecisionTree(int maxDeep, double minImpurityDecrease, StepFunctionCompressor<T> compressor,
        DecisionTreeLeafBuilder decisionTreeLeafBuilder) {
        this.maxDeep = maxDeep;
        this.minImpurityDecrease = minImpurityDecrease;
        this.compressor = compressor;
        this.decisionTreeLeafBuilder = decisionTreeLeafBuilder;
    }

    /**
     * Appends a textual rendering of {@code node} (and its subtree) to {@code builder}. A {@code null} node
     * renders as nothing.
     *
     * @param node Node to render.
     * @param depth Depth of {@code node}; used for indentation and to omit the prefix on the root.
     * @param builder Output buffer.
     * @param pretty If {@code true}, indent with tabs and put each node on its own line.
     * @param isThen Whether {@code node} is the "then" child of its parent (selects the "then"/"else" prefix).
     * @throws IllegalArgumentException If the node is neither a leaf nor a conditional node.
     */
    private static void printTree(DecisionTreeNode node, int depth, StringBuilder builder, boolean pretty,
        boolean isThen) {
        if (node == null) {
            return;
        }

        if (pretty) {
            builder.append(String.join("", Collections.nCopies(depth, "\t")));
        }

        if (node instanceof DecisionTreeLeafNode) {
            DecisionTreeLeafNode leaf = (DecisionTreeLeafNode) node;
            builder.append(String.format("%s return ", isThen ? "then" : "else"))
                .append(String.format("%.4f", leaf.getVal()));
            return;
        }

        if (!(node instanceof DecisionTreeConditionalNode)) {
            throw new IllegalArgumentException("Unsupported decision tree node type: " + node.getClass().getName());
        }

        DecisionTreeConditionalNode cond = (DecisionTreeConditionalNode) node;
        // The root has no "then"/"else" prefix; children are labelled by which branch they are.
        String prefix = depth == 0 ? "" : (isThen ? "then " : "else ");

        builder.append(String.format("%sif (x", prefix))
            .append(cond.getCol())
            .append(" > ")
            .append(String.format("%.4f", cond.getThreshold()))
            .append(pretty ? ")\n" : ") ");

        printTree(cond.getThenNode(), depth + 1, builder, pretty, true);
        builder.append(pretty ? "\n" : " ");
        printTree(cond.getElseNode(), depth + 1, builder, pretty, false);
    }

    /**
     * Builds a dataset from {@code datasetBuilder}, trains a tree on it and always closes the dataset, including
     * when training fails.
     *
     * @param datasetBuilder Dataset builder.
     * @param preprocessor Upstream preprocessor producing feature vectors and labels.
     * @param <K> Key type.
     * @param <V> Value type.
     * @return Root node of the trained tree.
     * @throws RuntimeException Wrapping any exception thrown while building, training or closing the dataset.
     */
    @Override
    public <K, V> DecisionTreeNode fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder,
        Preprocessor<K, V> preprocessor) {
        try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(
            envBuilder,
            new EmptyContextBuilder<>(),
            new DecisionTreeDataBuilder<>(preprocessor, usingIdx),
            learningEnvironment())) {
            return fit(dataset);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** {@inheritDoc} */
    @Override
    public boolean isUpdateable(DecisionTreeNode mdl) {
        return true;
    }

    /** {@inheritDoc} */
    @SuppressWarnings("unchecked") // super returns this trainer, which is a DecisionTree<T>.
    @Override
    public DecisionTree<T> withEnvironmentBuilder(LearningEnvironmentBuilder learningEnvironmentBuilder) {
        return (DecisionTree<T>) super.withEnvironmentBuilder(learningEnvironmentBuilder);
    }

    /**
     * "Updates" a model by retraining from scratch; the existing model {@code mdl} is ignored.
     *
     * @param mdl Existing model (unused).
     * @param datasetBuilder Dataset builder.
     * @param preprocessor Upstream preprocessor.
     * @param <K> Key type.
     * @param <V> Value type.
     * @return Newly trained tree.
     */
    @Override
    public <K, V> DecisionTreeNode updateModel(DecisionTreeNode mdl, DatasetBuilder<K, V> datasetBuilder,
        Preprocessor<K, V> preprocessor) {
        return fit(datasetBuilder, preprocessor);
    }

    /**
     * Trains a tree on an already built dataset, starting from the root with a filter that accepts every row.
     *
     * @param dataset Dataset to train on (not closed by this method).
     * @param <K> Unused; kept for source compatibility with existing callers.
     * @param <V> Unused; kept for source compatibility with existing callers.
     * @return Root node of the trained tree.
     */
    public <K, V> DecisionTreeNode fit(Dataset<EmptyContext, DecisionTreeData> dataset) {
        return split(dataset, row -> true, 0, getImpurityMeasureCalculator(dataset));
    }

    /**
     * Recursively grows the subtree for the rows selected by {@code filter}.
     *
     * @param dataset Dataset.
     * @param filter Predicate selecting the rows that reach this node.
     * @param depth Depth of this node (root is 0).
     * @param impurityCalc Impurity calculator.
     * @return Conditional node if an acceptable split exists and depth allows it, otherwise a leaf.
     */
    private DecisionTreeNode split(Dataset<EmptyContext, DecisionTreeData> dataset, TreeFilter filter, int depth,
        ImpurityMeasureCalculator<T> impurityCalc) {
        if (depth >= maxDeep) {
            return decisionTreeLeafBuilder.createLeafNode(dataset, filter);
        }

        StepFunction<T>[] criteria = calculateImpurityForAllColumns(dataset, filter, impurityCalc, depth);
        if (criteria == null) {
            // No partition produced any step functions, so there is nothing to split on.
            return decisionTreeLeafBuilder.createLeafNode(dataset, filter);
        }

        SplitPoint<T> splitPnt = calculateBestSplitPoint(criteria);
        if (splitPnt == null) {
            return decisionTreeLeafBuilder.createLeafNode(dataset, filter);
        }

        return new DecisionTreeConditionalNode(
            splitPnt.col,
            splitPnt.threshold,
            split(dataset, updatePredicateForThenNode(filter, splitPnt), depth + 1, impurityCalc),
            split(dataset, updatePredicateForElseNode(filter, splitPnt), depth + 1, impurityCalc),
            null
        );
    }

    /**
     * Computes, over all partitions, one impurity step function per feature column for the rows selected by
     * {@code filter}. Each partition's result is optionally compressed, then results are merged with
     * {@link #reduce}.
     *
     * @param dataset Dataset.
     * @param filter Row filter of the current node.
     * @param impurityCalc Impurity calculator.
     * @param depth Depth of the current node.
     * @return Merged step functions, or {@code null} if no partition produced any.
     */
    private StepFunction<T>[] calculateImpurityForAllColumns(Dataset<EmptyContext, DecisionTreeData> dataset,
        TreeFilter filter, ImpurityMeasureCalculator<T> impurityCalc, int depth) {
        return dataset.compute(
            part -> {
                StepFunction<T>[] raw = impurityCalc.calculate(part, filter, depth);
                return compressor != null ? compressor.compress(raw) : raw;
            },
            this::reduce
        );
    }

    /**
     * Picks the best split across all columns: among candidate steps whose impurity decrease relative to the
     * column's first step value exceeds {@code minImpurityDecrease}, the one whose measure is smallest per
     * {@code compareTo}. The first and last steps of each column are never candidates.
     *
     * @param criteria One step function per feature column.
     * @return Best split, or {@code null} if no candidate qualifies.
     */
    private SplitPoint<T> calculateBestSplitPoint(StepFunction<T>[] criteria) {
        SplitPoint<T> best = null;

        for (int col = 0; col < criteria.length; col++) {
            StepFunction<T> criterionForCol = criteria[col];

            double[] arguments = criterionForCol.getX();
            T[] values = criterionForCol.getY();

            for (int leftSize = 1; leftSize < values.length - 1; leftSize++) {
                boolean decreasesEnough = values[0].impurity() - values[leftSize].impurity() > minImpurityDecrease;

                if (decreasesEnough && (best == null || values[leftSize].compareTo(best.val) < 0)) {
                    best = new SplitPoint<>(values[leftSize], col, calculateThreshold(arguments, leftSize));
                }
            }
        }

        return best;
    }

    /**
     * Merges two per-column step-function arrays element-wise via {@link StepFunction#add}. A {@code null} side
     * (partition without results) yields the other side unchanged. Neither input array is modified.
     *
     * @param a First array (may be {@code null}).
     * @param b Second array (may be {@code null}); assumed to have the same column count as {@code a}.
     * @return Merged array.
     */
    private StepFunction<T>[] reduce(StepFunction<T>[] a, StepFunction<T>[] b) {
        if (a == null) {
            return b;
        }
        if (b == null) {
            return a;
        }

        StepFunction<T>[] res = Arrays.copyOf(a, a.length);

        for (int i = 0; i < res.length; i++) {
            res[i] = res[i].add(b[i]);
        }

        return res;
    }

    /**
     * Returns the midpoint between step {@code idx} and step {@code idx + 1} as the split threshold.
     *
     * @param arguments Step arguments (x values).
     * @param idx Index of the left step.
     * @return Threshold value.
     */
    private double calculateThreshold(double[] arguments, int idx) {
        return (arguments[idx] + arguments[idx + 1]) / 2.0;
    }

    /**
     * Narrows {@code filter} to the rows routed to the "then" branch ({@code x[col] > threshold}).
     *
     * @param filter Parent filter.
     * @param splitPnt Split point.
     * @return Filter for the "then" child.
     */
    private TreeFilter updatePredicateForThenNode(TreeFilter filter, SplitPoint<T> splitPnt) {
        return filter.and(features -> features[splitPnt.col] > splitPnt.threshold);
    }

    /**
     * Narrows {@code filter} to the rows routed to the "else" branch ({@code x[col] <= threshold}).
     *
     * @param filter Parent filter.
     * @param splitPnt Split point.
     * @return Filter for the "else" child.
     */
    private TreeFilter updatePredicateForElseNode(TreeFilter filter, SplitPoint<T> splitPnt) {
        return filter.and(features -> features[splitPnt.col] <= splitPnt.threshold);
    }

    /**
     * Renders a tree as pseudo-code ({@code if (xN > t) then ... else ...}). The root is rendered as an "else"
     * child, which only matters if the root is a leaf.
     *
     * @param node Root of the tree to render (may be {@code null}, which renders as an empty string).
     * @param pretty If {@code true}, output is multi-line and tab-indented; otherwise it is a single line.
     * @return Textual representation of the tree.
     */
    public static String printTree(DecisionTreeNode node, boolean pretty) {
        StringBuilder builder = new StringBuilder();
        printTree(node, 0, builder, pretty, false);
        return builder.toString();
    }

    /**
     * Returns the impurity calculator to use when training on {@code dataset}.
     *
     * @param dataset Dataset.
     * @return Impurity measure calculator.
     */
    protected abstract ImpurityMeasureCalculator<T> getImpurityMeasureCalculator(
        Dataset<EmptyContext, DecisionTreeData> dataset);
}
