/*
 * Decompiled with CFR 0.152.
 */
package com.o19s.es.ltr.ranker.dectree;

import com.o19s.es.ltr.ranker.DenseFeatureVector;
import com.o19s.es.ltr.ranker.DenseLtrRanker;
import com.o19s.es.ltr.ranker.normalizer.Normalizer;
import java.util.Objects;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;

/**
 * Ranker that scores a feature vector as a weighted sum of decision-tree outputs.
 *
 * <p>Each tree is evaluated against the dense feature scores, its result is multiplied by the
 * weight at the same index, the products are summed, and the sum is passed through the
 * configured {@link Normalizer}.
 */
public class NaiveAdditiveDecisionTree extends DenseLtrRanker implements Accountable {
    /** Shallow size of one ranker instance (excludes the referenced arrays and trees). */
    private static final long BASE_RAM_USED =
            RamUsageEstimator.shallowSizeOfInstance(NaiveAdditiveDecisionTree.class);

    private final Node[] trees;
    private final float[] weights;
    private final int modelSize;
    private final Normalizer normalizer;

    /**
     * Creates a ranker from parallel arrays of trees and weights.
     *
     * <p>The arrays are stored as given (no defensive copy).
     *
     * @param trees root node of each tree
     * @param weights multiplier applied to the output of the tree at the same index; expected to
     *     have the same length as {@code trees} (checked only when assertions are enabled)
     * @param modelSize value reported by {@link #size()}
     * @param normalizer applied to the weighted sum to produce the final score
     */
    public NaiveAdditiveDecisionTree(Node[] trees, float[] weights, int modelSize, Normalizer normalizer) {
        assert trees.length == weights.length;
        this.trees = trees;
        this.weights = weights;
        this.modelSize = modelSize;
        this.normalizer = normalizer;
    }

    /** Returns the fixed name of this ranker implementation. */
    @Override
    public String name() {
        return "naive_additive_decision_tree";
    }

    /**
     * Computes the normalized weighted sum of all tree outputs for the given vector.
     *
     * @param vector feature vector whose {@code scores} array is handed to every tree
     * @return the normalizer's result for the accumulated sum
     */
    @Override
    protected float score(DenseFeatureVector vector) {
        float sum = 0.0f;
        float[] scores = vector.scores;
        for (int i = 0; i < this.trees.length; ++i) {
            sum += this.weights[i] * this.trees[i].eval(scores);
        }
        return this.normalizer.normalize(sum);
    }

    /** Returns the model size supplied at construction time. */
    @Override
    protected int size() {
        return this.modelSize;
    }

    /**
     * Estimates the memory held by this ranker: its own shallow size plus the weights array
     * and the tree array as measured by {@link RamUsageEstimator}.
     */
    @Override
    public long ramBytesUsed() {
        return BASE_RAM_USED
                + RamUsageEstimator.sizeOf(this.weights)
                + RamUsageEstimator.sizeOf(this.trees);
    }

    /** A node of a decision tree: either an inner {@link Split} or a terminal {@link Leaf}. */
    public interface Node extends Accountable {
        /** Returns {@code true} if this node is terminal. */
        boolean isLeaf();

        /**
         * Evaluates the (sub)tree rooted at this node.
         *
         * @param scores feature values, indexed by feature ordinal
         * @return the output of the leaf that is reached
         */
        float eval(float[] scores);
    }

    /** Inner node that routes evaluation to one of two children based on a feature threshold. */
    public static class Split implements Node {
        private static final long BASE_RAM_USED = RamUsageEstimator.shallowSizeOfInstance(Split.class);
        private final Node left;
        private final Node right;
        private final int feature;
        private final float threshold;

        /**
         * @param left child taken when {@code threshold > scores[feature]}; must not be null
         * @param right child taken otherwise; must not be null
         * @param feature index into the scores array
         * @param threshold value the feature score is compared against
         * @throws NullPointerException if {@code left} or {@code right} is null
         */
        public Split(Node left, Node right, int feature, float threshold) {
            this.left = Objects.requireNonNull(left);
            this.right = Objects.requireNonNull(right);
            this.feature = feature;
            this.threshold = threshold;
        }

        @Override
        public boolean isLeaf() {
            return false;
        }

        /**
         * Walks iteratively from this split down to a leaf and returns the leaf's output.
         *
         * <p>At each split the left child is taken when the threshold is strictly greater than
         * the feature score; in every other case (including a NaN score, for which the
         * comparison is false) the right child is taken.
         */
        @Override
        public float eval(float[] scores) {
            Node n = this;
            while (!n.isLeaf()) {
                assert n instanceof Split;
                // Explicit downcast: any non-leaf node is expected to be a Split.
                Split s = (Split) n;
                n = s.threshold > scores[s.feature] ? s.left : s.right;
            }
            assert n instanceof Leaf;
            return n.eval(scores);
        }

        /** Shallow size of this split plus the reported size of both subtrees. */
        @Override
        public long ramBytesUsed() {
            return BASE_RAM_USED + this.left.ramBytesUsed() + this.right.ramBytesUsed();
        }
    }

    /** Terminal node holding a constant output value. */
    public static class Leaf implements Node {
        private static final long BASE_RAM_USED = RamUsageEstimator.shallowSizeOfInstance(Leaf.class);
        private final float output;

        /**
         * @param output value returned by {@link #eval(float[])}
         */
        public Leaf(float output) {
            this.output = output;
        }

        @Override
        public boolean isLeaf() {
            return true;
        }

        /** Returns the constant output; {@code scores} is not read. */
        @Override
        public float eval(float[] scores) {
            return this.output;
        }

        /** Shallow size of a leaf instance. */
        @Override
        public long ramBytesUsed() {
            return BASE_RAM_USED;
        }
    }
}

