package org.apache.ignite.ml.naivebayes.compound;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.environment.deploy.DeployableObject;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;

/* loaded from: input_file:org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesModel.class */
/**
 * Naive Bayes model that combines an optional {@link GaussianNaiveBayesModel} with an optional
 * {@link DiscreteNaiveBayesModel}.
 *
 * <p>Prediction starts from the natural logarithm of each prior probability, adds the per-class
 * values returned by {@code probabilityPowers} of every configured sub-model, and returns the
 * label stored at the index of the largest total. Each sub-model only sees the features that are
 * not listed in its "feature ids to skip" collection.
 *
 * <p>The model is configured through the fluent {@code with...} methods; prior probabilities and
 * labels must be set before {@link #predict(Vector)} is called.
 */
public class CompoundNaiveBayesModel implements IgniteModel<Vector, Double>, Exportable<CompoundNaiveBayesModel>, DeployableObject {
    /** Serial version uid. */
    private static final long serialVersionUID = -5045925321135798960L;

    /** Prior probability of each class; index-aligned with {@link #labels}. */
    private double[] priorProbabilities;

    /** Class labels; index-aligned with {@link #priorProbabilities}. */
    private double[] labels;

    /** Gaussian sub-model, or {@code null} when not configured. */
    private GaussianNaiveBayesModel gaussianModel;

    /** Discrete sub-model, or {@code null} when not configured. */
    private DiscreteNaiveBayesModel discreteModel;

    /** Indices of features that are hidden from the Gaussian sub-model. */
    private Collection<Integer> gaussianFeatureIdsToSkip = Collections.emptyList();

    /** Indices of features that are hidden from the discrete sub-model. */
    private Collection<Integer> discreteFeatureIdsToSkip = Collections.emptyList();

    /**
     * Hands this model over to the given exporter.
     *
     * @param exporter Exporter that persists the model.
     * @param p Exporter-specific target passed through unchanged.
     * @param <P> Type of the exporter-specific target.
     */
    @Override
    public <P> void saveModel(Exporter<CompoundNaiveBayesModel, P> exporter, P p) {
        exporter.save(this, p);
    }

    /**
     * Predicts the label for the given feature vector.
     *
     * @param vector Feature vector containing the features of both sub-models.
     * @return Label whose accumulated score is the largest; on ties the label with the lowest
     *     index wins.
     */
    @Override
    public Double predict(Vector vector) {
        double[] scores = new double[this.priorProbabilities.length];
        for (int cls = 0; cls < scores.length; cls++) {
            scores[cls] = Math.log(this.priorProbabilities[cls]);
        }

        if (this.discreteModel != null) {
            Vector discreteFeatures = skipFeatures(vector, this.discreteFeatureIdsToSkip);
            scores = sum(scores, this.discreteModel.probabilityPowers(discreteFeatures));
        }
        if (this.gaussianModel != null) {
            Vector gaussianFeatures = skipFeatures(vector, this.gaussianFeatureIdsToSkip);
            scores = sum(scores, this.gaussianModel.probabilityPowers(gaussianFeatures));
        }

        // Strict '>' keeps the first maximum, so ties resolve to the lowest class index.
        int best = 0;
        for (int cls = 1; cls < scores.length; cls++) {
            if (scores[cls] > scores[best]) {
                best = cls;
            }
        }
        return Double.valueOf(this.labels[best]);
    }

    /**
     * @return Gaussian sub-model, or {@code null} when none was set.
     */
    public GaussianNaiveBayesModel getGaussianModel() {
        return this.gaussianModel;
    }

    /**
     * @return Discrete sub-model, or {@code null} when none was set.
     */
    public DiscreteNaiveBayesModel getDiscreteModel() {
        return this.discreteModel;
    }

    /**
     * Sets the prior probabilities; the array is copied.
     *
     * @param priorProbabilities Prior probability per class.
     * @return This model.
     */
    public CompoundNaiveBayesModel withPriorProbabilities(double[] priorProbabilities) {
        this.priorProbabilities = priorProbabilities.clone();
        return this;
    }

    /**
     * Sets the class labels; the array is copied.
     *
     * @param labels Label per class.
     * @return This model.
     */
    public CompoundNaiveBayesModel withLabels(double[] labels) {
        this.labels = labels.clone();
        return this;
    }

    /**
     * Sets the Gaussian sub-model.
     *
     * @param gaussianModel Gaussian sub-model; {@code null} disables its contribution.
     * @return This model.
     */
    public CompoundNaiveBayesModel withGaussianModel(GaussianNaiveBayesModel gaussianModel) {
        this.gaussianModel = gaussianModel;
        return this;
    }

    /**
     * Sets the discrete sub-model.
     *
     * @param discreteModel Discrete sub-model; {@code null} disables its contribution.
     * @return This model.
     */
    public CompoundNaiveBayesModel withDiscreteModel(DiscreteNaiveBayesModel discreteModel) {
        this.discreteModel = discreteModel;
        return this;
    }

    /**
     * Sets the feature indices that are removed before the Gaussian sub-model is consulted.
     * The collection is stored by reference, not copied.
     *
     * @param gaussianFeatureIdsToSkip Feature indices to skip.
     * @return This model.
     */
    public CompoundNaiveBayesModel withGaussianFeatureIdsToSkip(Collection<Integer> gaussianFeatureIdsToSkip) {
        this.gaussianFeatureIdsToSkip = gaussianFeatureIdsToSkip;
        return this;
    }

    /**
     * Sets the feature indices that are removed before the discrete sub-model is consulted.
     * The collection is stored by reference, not copied.
     *
     * @param discreteFeatureIdsToSkip Feature indices to skip.
     * @return This model.
     */
    public CompoundNaiveBayesModel withDiscreteFeatureIdsToSkip(Collection<Integer> discreteFeatureIdsToSkip) {
        this.discreteFeatureIdsToSkip = discreteFeatureIdsToSkip;
        return this;
    }

    /**
     * Adds two arrays element by element into a new array.
     *
     * @param first First addend.
     * @param second Second addend; expected to have the same length as {@code first}.
     * @return New array holding the element-wise sums.
     */
    private static double[] sum(double[] first, double[] second) {
        assert first.length == second.length;

        double[] total = new double[first.length];
        for (int i = 0; i < first.length; i++) {
            total[i] = first[i] + second[i];
        }
        return total;
    }

    /**
     * Builds a new vector from {@code vector} without the features whose index is contained in
     * {@code featureIdsToSkip}; the relative order of the remaining features is preserved.
     *
     * @param vector Source feature vector.
     * @param featureIdsToSkip Indices of the features to drop.
     * @return Vector holding only the retained features.
     */
    private static Vector skipFeatures(Vector vector, Collection<Integer> featureIdsToSkip) {
        int featureCnt = vector.size();

        // Size by what is actually kept rather than by size() - featureIdsToSkip.size(): the
        // latter is wrong (and overflows the array) when the ids contain duplicates or indices
        // outside the vector.
        double[] kept = new double[featureCnt];
        int keptCnt = 0;
        for (int featureId = 0; featureId < featureCnt; featureId++) {
            if (!featureIdsToSkip.contains(Integer.valueOf(featureId))) {
                kept[keptCnt] = vector.get(featureId);
                keptCnt++;
            }
        }
        return VectorUtils.of(Arrays.copyOf(kept, keptCnt));
    }

    /**
     * Returns the discrete and the Gaussian sub-model, in that order. An entry is {@code null}
     * when the corresponding sub-model has not been set.
     *
     * @return Fixed-size list with the two sub-models.
     */
    @Override
    public List<Object> getDependencies() {
        return Arrays.asList(this.discreteModel, this.gaussianModel);
    }
}
