/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.bayes;

import java.util.Enumeration;
import java.util.Random;
import java.util.StringTokenizer;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.bayes.blr.GaussianPriorImpl;
import weka.classifiers.bayes.blr.LaplacePriorImpl;
import weka.classifiers.bayes.blr.Prior;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.SerializedObject;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;

public class BayesianLogisticRegression
extends Classifier
implements OptionHandler,
TechnicalInformationHandler {
    static final long serialVersionUID = -8013478897911757631L;
    public static double[] LogLikelihood;
    public static double[] InputHyperparameterValues;
    boolean debug = false;
    public boolean NormalizeData = false;
    public double Tolerance = 5.0E-4;
    public double Threshold = 0.5;
    public static final int GAUSSIAN = 1;
    public static final int LAPLACIAN = 2;
    public static final Tag[] TAGS_PRIOR;
    public int PriorClass = 1;
    public int NumFolds = 2;
    public int m_seed = 1;
    public static final int NORM_BASED = 1;
    public static final int CV_BASED = 2;
    public static final int SPECIFIC_VALUE = 3;
    public static final Tag[] TAGS_HYPER_METHOD;
    public int HyperparameterSelection = 1;
    public int ClassIndex = -1;
    public double HyperparameterValue = 0.27;
    public String HyperparameterRange = "R:0.01-316,3.16";
    public int maxIterations = 100;
    public int iterationCounter = 0;
    public double[] BetaVector;
    public double[] DeltaBeta;
    public double[] DeltaUpdate;
    public double[] Delta;
    public double[] Hyperparameters;
    public double[] R;
    public double[] DeltaR;
    public double Change;
    public Filter m_Filter;
    protected Instances m_Instances;
    protected Prior m_PriorUpdate;

    public String globalInfo() {
        return "Implements Bayesian Logistic Regression for both Gaussian and Laplace Priors.\n\nFor more information, see\n\n" + this.getTechnicalInformation();
    }

    public void initialize() throws Exception {
        int i;
        this.Change = 0.0;
        if (this.NormalizeData) {
            this.m_Filter = new Normalize();
            this.m_Filter.setInputFormat(this.m_Instances);
            this.m_Instances = Filter.useFilter(this.m_Instances, this.m_Filter);
        }
        String attName = "(intercept)";
        String attAtZero = this.m_Instances.attribute(0).name();
        int attNameIncr = 0;
        if (attAtZero.startsWith(attName)) {
            if (attAtZero.indexOf(41) < attAtZero.length() - 1) {
                String tempNum = attAtZero.substring(attAtZero.indexOf(41) + 1, attAtZero.length());
                attNameIncr = Integer.parseInt(tempNum);
                ++attNameIncr;
            }
            attName = attName + "" + attNameIncr;
        }
        Attribute att = new Attribute(attName);
        this.m_Instances.insertAttributeAt(att, 0);
        for (i = 0; i < this.m_Instances.numInstances(); ++i) {
            Instance instance = this.m_Instances.instance(i);
            instance.setValue(0, 1.0);
        }
        int numOfAttributes = this.m_Instances.numAttributes();
        int numOfInstances = this.m_Instances.numInstances();
        this.ClassIndex = this.m_Instances.classIndex();
        this.iterationCounter = 0;
        switch (this.HyperparameterSelection) {
            case 1: {
                this.HyperparameterValue = this.normBasedHyperParameter();
                if (!this.debug) break;
                System.out.println("Norm-based Hyperparameter: " + this.HyperparameterValue);
                break;
            }
            case 2: {
                this.HyperparameterValue = this.CVBasedHyperparameter();
                if (!this.debug) break;
                System.out.println("CV-based Hyperparameter: " + this.HyperparameterValue);
            }
        }
        this.BetaVector = new double[numOfAttributes];
        this.Delta = new double[numOfAttributes];
        this.DeltaBeta = new double[numOfAttributes];
        this.Hyperparameters = new double[numOfAttributes];
        this.DeltaUpdate = new double[numOfAttributes];
        for (int j = 0; j < numOfAttributes; ++j) {
            this.BetaVector[j] = 0.0;
            this.Delta[j] = 1.0;
            this.DeltaBeta[j] = 0.0;
            this.DeltaUpdate[j] = 0.0;
            this.Hyperparameters[j] = this.HyperparameterValue;
        }
        this.DeltaR = new double[numOfInstances];
        this.R = new double[numOfInstances];
        for (i = 0; i < numOfInstances; ++i) {
            this.DeltaR[i] = 0.0;
            this.R[i] = 0.0;
        }
        this.m_PriorUpdate = this.PriorClass == 1 ? new GaussianPriorImpl() : new LaplacePriorImpl();
    }

    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.BINARY_ATTRIBUTES);
        result.enable(Capabilities.Capability.BINARY_CLASS);
        result.setMinimumNumberInstances(0);
        return result;
    }

    public void buildClassifier(Instances data) throws Exception {
        this.getCapabilities().testWithFail(data);
        this.m_Instances = new Instances(data);
        this.initialize();
        do {
            for (int j = 0; j < this.m_Instances.numAttributes(); ++j) {
                if (j == this.ClassIndex) continue;
                this.DeltaUpdate[j] = this.m_PriorUpdate.update(j, this.m_Instances, this.BetaVector[j], this.Hyperparameters[j], this.R, this.Delta[j]);
                this.DeltaBeta[j] = Math.min(Math.max(this.DeltaUpdate[j], 0.0 - this.Delta[j]), this.Delta[j]);
                for (int i = 0; i < this.m_Instances.numInstances(); ++i) {
                    Instance instance = this.m_Instances.instance(i);
                    if (instance.value(j) == 0.0) continue;
                    this.DeltaR[i] = this.DeltaBeta[j] * instance.value(j) * BayesianLogisticRegression.classSgn(instance.classValue());
                    int n = i;
                    this.R[n] = this.R[n] + this.DeltaR[i];
                }
                int n = j;
                this.BetaVector[n] = this.BetaVector[n] + this.DeltaBeta[j];
                this.Delta[j] = Math.max(2.0 * Math.abs(this.DeltaBeta[j]), this.Delta[j] / 2.0);
            }
        } while (!this.stoppingCriterion());
        this.m_PriorUpdate.computelogLikelihood(this.BetaVector, this.m_Instances);
        this.m_PriorUpdate.computePenalty(this.BetaVector, this.Hyperparameters);
    }

    public static double classSgn(double value) {
        if (value == 0.0) {
            return -1.0;
        }
        return 1.0;
    }

    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = null;
        result = new TechnicalInformation(TechnicalInformation.Type.TECHREPORT);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Alexander Genkin and David D. Lewis and David Madigan");
        result.setValue(TechnicalInformation.Field.YEAR, "2004");
        result.setValue(TechnicalInformation.Field.TITLE, "Large-scale bayesian logistic regression for text categorization");
        result.setValue(TechnicalInformation.Field.INSTITUTION, "DIMACS");
        result.setValue(TechnicalInformation.Field.URL, "http://www.stat.rutgers.edu/~madigan/PAPERS/shortFat-v3a.pdf");
        return result;
    }

    public static double bigF(double r, double sigma) {
        double funcValue = 0.25;
        double absR = Math.abs(r);
        if (absR > sigma) {
            funcValue = 1.0 / (2.0 + Math.exp(absR - sigma) + Math.exp(sigma - absR));
        }
        return funcValue;
    }

    public boolean stoppingCriterion() {
        double sum_deltaR = 0.0;
        double sum_R = 1.0;
        double value = 0.0;
        for (int i = 0; i < this.m_Instances.numInstances(); ++i) {
            sum_deltaR += Math.abs(this.DeltaR[i]);
            sum_R += Math.abs(this.R[i]);
        }
        double delta = Math.abs(sum_deltaR - this.Change);
        this.Change = delta / sum_R;
        if (this.debug) {
            System.out.println(this.Change + " <= " + this.Tolerance);
        }
        boolean shouldStop = this.Change <= this.Tolerance || this.iterationCounter >= this.maxIterations;
        ++this.iterationCounter;
        this.Change = sum_deltaR;
        return shouldStop;
    }

    public static double logisticLinkFunction(double r) {
        return Math.exp(r) / (1.0 + Math.exp(r));
    }

    public static double sgn(double r) {
        double sgn = 0.0;
        if (r > 0.0) {
            sgn = 1.0;
        } else if (r < 0.0) {
            sgn = -1.0;
        }
        return sgn;
    }

    public double normBasedHyperParameter() {
        double mean = 0.0;
        for (int i = 0; i < this.m_Instances.numInstances(); ++i) {
            Instance instance = this.m_Instances.instance(i);
            double sqr_sum = 0.0;
            for (int j = 0; j < this.m_Instances.numAttributes(); ++j) {
                if (j == this.ClassIndex) continue;
                sqr_sum += instance.value(j) * instance.value(j);
            }
            mean += sqr_sum;
        }
        return (double)this.m_Instances.numAttributes() / (mean /= (double)this.m_Instances.numInstances());
    }

    public double classifyInstance(Instance instance) throws Exception {
        double sum_R = 0.0;
        double classification = 0.0;
        sum_R = this.BetaVector[0];
        for (int j = 0; j < instance.numAttributes(); ++j) {
            if (j == this.ClassIndex - 1) continue;
            sum_R += this.BetaVector[j + 1] * instance.value(j);
        }
        classification = (sum_R = BayesianLogisticRegression.logisticLinkFunction(sum_R)) > this.Threshold ? 1.0 : 0.0;
        return classification;
    }

    public String toString() {
        if (this.m_Instances == null) {
            return "Bayesian logistic regression: No model built yet.";
        }
        StringBuffer buf = new StringBuffer();
        String text = "";
        switch (this.HyperparameterSelection) {
            case 1: {
                text = "Norm-Based Hyperparameter Selection: ";
                break;
            }
            case 2: {
                text = "Cross-Validation Based Hyperparameter Selection: ";
                break;
            }
            case 3: {
                text = "Specified Hyperparameter: ";
            }
        }
        buf.append(text).append(this.HyperparameterValue).append("\n\n");
        buf.append("Regression Coefficients\n");
        buf.append("=========================\n\n");
        for (int j = 0; j < this.m_Instances.numAttributes(); ++j) {
            if (j == this.ClassIndex || this.BetaVector[j] == 0.0) continue;
            buf.append(this.m_Instances.attribute(j).name()).append(" : ").append(this.BetaVector[j]).append("\n");
        }
        buf.append("===========================\n\n");
        buf.append("Likelihood: " + this.m_PriorUpdate.getLoglikelihood() + "\n\n");
        buf.append("Penalty: " + this.m_PriorUpdate.getPenalty() + "\n\n");
        buf.append("Regularized Log Posterior: " + this.m_PriorUpdate.getLogPosterior() + "\n");
        buf.append("===========================\n\n");
        return buf.toString();
    }

    public double CVBasedHyperparameter() throws Exception {
        boolean size = false;
        double[] list = null;
        double MaxHypeValue = 0.0;
        double MaxLikelihood = 0.0;
        StringTokenizer tokenizer = new StringTokenizer(this.HyperparameterRange);
        String rangeType = tokenizer.nextToken(":");
        if (rangeType.equals("R")) {
            String temp = tokenizer.nextToken();
            tokenizer = new StringTokenizer(temp);
            double start = Double.parseDouble(tokenizer.nextToken("-"));
            tokenizer = new StringTokenizer(tokenizer.nextToken());
            double end = Double.parseDouble(tokenizer.nextToken(","));
            double multiplier = Double.parseDouble(tokenizer.nextToken());
            int steps = (int)((Math.log10(end) - Math.log10(start)) / Math.log10(multiplier) + 1.0);
            list = new double[steps];
            int count = 0;
            for (double i = start; i <= end; i *= multiplier) {
                list[count++] = i;
            }
        } else if (rangeType.equals("L")) {
            Vector<String> vec = new Vector<String>();
            while (tokenizer.hasMoreTokens()) {
                vec.add(tokenizer.nextToken(","));
            }
            list = new double[vec.size()];
            for (int i = 0; i < vec.size(); ++i) {
                list[i] = Double.parseDouble((String)vec.get(i));
            }
        }
        if (list != null) {
            int numFolds = this.NumFolds;
            Random random = new Random(this.m_seed);
            this.m_Instances.randomize(random);
            this.m_Instances.stratify(numFolds);
            for (int k = 0; k < list.length; ++k) {
                for (int i = 0; i < numFolds; ++i) {
                    Instances train = this.m_Instances.trainCV(numFolds, i, random);
                    SerializedObject so = new SerializedObject(this);
                    BayesianLogisticRegression blr = (BayesianLogisticRegression)so.getObject();
                    blr.setHyperparameterSelection(new SelectedTag(3, TAGS_HYPER_METHOD));
                    blr.setHyperparameterValue(list[k]);
                    blr.setPriorClass(new SelectedTag(this.PriorClass, TAGS_PRIOR));
                    blr.setThreshold(this.Threshold);
                    blr.setTolerance(this.Tolerance);
                    blr.buildClassifier(train);
                    Instances test = this.m_Instances.testCV(numFolds, i);
                    double val = blr.getLoglikeliHood(blr.BetaVector, test);
                    if (this.debug) {
                        System.out.println("Fold " + i + "Hyperparameter: " + list[k]);
                        System.out.println("===================================");
                        System.out.println(" Likelihood: " + val);
                    }
                    if (!(k == 0 | val > MaxLikelihood)) continue;
                    MaxLikelihood = val;
                    MaxHypeValue = list[k];
                }
            }
        } else {
            return this.HyperparameterValue;
        }
        return MaxHypeValue;
    }

    public double getLoglikeliHood(double[] betas, Instances instances) {
        this.m_PriorUpdate.computelogLikelihood(betas, instances);
        return this.m_PriorUpdate.getLoglikelihood();
    }

    public Enumeration listOptions() {
        Vector<Option> newVector = new Vector<Option>();
        newVector.addElement(new Option("\tShow Debugging Output\n", "D", 0, "-D"));
        newVector.addElement(new Option("\tDistribution of the Prior (1=Gaussian, 2=Laplacian)\n\t(default: 1=Gaussian)", "P", 1, "-P <integer>"));
        newVector.addElement(new Option("\tHyperparameter Selection Method (1=Norm-based, 2=CV-based, 3=specific value)\n\t(default: 1=Norm-based)", "H", 1, "-H <integer>"));
        newVector.addElement(new Option("\tSpecified Hyperparameter Value (use in conjunction with -H 3)\n\t(default: 0.27)", "V", 1, "-V <double>"));
        newVector.addElement(new Option("\tHyperparameter Range (use in conjunction with -H 2)\n\t(format: R:start-end,multiplier OR L:val(1), val(2), ..., val(n))\n\t(default: R:0.01-316,3.16)", "R", 1, "-R <string>"));
        newVector.addElement(new Option("\tTolerance Value\n\t(default: 0.0005)", "Tl", 1, "-Tl <double>"));
        newVector.addElement(new Option("\tThreshold Value\n\t(default: 0.5)", "S", 1, "-S <double>"));
        newVector.addElement(new Option("\tNumber Of Folds (use in conjuction with -H 2)\n\t(default: 2)", "F", 1, "-F <integer>"));
        newVector.addElement(new Option("\tMax Number of Iterations\n\t(default: 100)", "I", 1, "-I <integer>"));
        newVector.addElement(new Option("\tNormalize the data", "N", 0, "-N"));
        newVector.addElement(new Option("\tSeed for randomizing instances order\n\tin CV-based hyperparameter selection\n\t(default: 1)", "seed", 1, "-seed <number>"));
        return newVector.elements();
    }

    public void setOptions(String[] options) throws Exception {
        String iterations;
        String seed;
        String folds;
        String HyperValue;
        String Hype;
        String Thres;
        this.debug = Utils.getFlag('D', options);
        String Tol = Utils.getOption("Tl", options);
        if (Tol.length() != 0) {
            this.Tolerance = Double.parseDouble(Tol);
        }
        if ((Thres = Utils.getOption('S', options)).length() != 0) {
            this.Threshold = Double.parseDouble(Thres);
        }
        if ((Hype = Utils.getOption('H', options)).length() != 0) {
            this.HyperparameterSelection = Integer.parseInt(Hype);
        }
        if ((HyperValue = Utils.getOption('V', options)).length() != 0) {
            this.HyperparameterValue = Double.parseDouble(HyperValue);
        }
        String HyperparameterRange = Utils.getOption("R", options);
        String strPrior = Utils.getOption('P', options);
        if (strPrior.length() != 0) {
            this.PriorClass = Integer.parseInt(strPrior);
        }
        if ((folds = Utils.getOption('F', options)).length() != 0) {
            this.NumFolds = Integer.parseInt(folds);
        }
        if ((seed = Utils.getOption("seed", options)).length() > 0) {
            this.setSeed(Integer.parseInt(seed));
        }
        if ((iterations = Utils.getOption('I', options)).length() != 0) {
            this.maxIterations = Integer.parseInt(iterations);
        }
        this.NormalizeData = Utils.getFlag('N', options);
        Utils.checkForRemainingOptions(options);
    }

    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        result.add("-D");
        result.add("-Tl");
        result.add("" + this.Tolerance);
        result.add("-S");
        result.add("" + this.Threshold);
        result.add("-H");
        result.add("" + this.HyperparameterSelection);
        result.add("-V");
        result.add("" + this.HyperparameterValue);
        result.add("-R");
        result.add("" + this.HyperparameterRange);
        result.add("-P");
        result.add("" + this.PriorClass);
        result.add("-F");
        result.add("" + this.NumFolds);
        result.add("-seed");
        result.add("" + this.getSeed());
        result.add("-I");
        result.add("" + this.maxIterations);
        result.add("-N");
        return result.toArray(new String[result.size()]);
    }

    public static void main(String[] argv) {
        BayesianLogisticRegression.runClassifier(new BayesianLogisticRegression(), argv);
    }

    public String debugTipText() {
        return "Turns on debugging mode.";
    }

    public void setDebug(boolean debugMode) {
        this.debug = debugMode;
    }

    public String hyperparameterSelectionTipText() {
        return "Select the type of Hyperparameter to be used.";
    }

    public SelectedTag getHyperparameterSelection() {
        return new SelectedTag(this.HyperparameterSelection, TAGS_HYPER_METHOD);
    }

    public void setHyperparameterSelection(SelectedTag newMethod) {
        if (newMethod.getTags() == TAGS_HYPER_METHOD) {
            int c = newMethod.getSelectedTag().getID();
            if (c >= 1 && c <= 3) {
                this.HyperparameterSelection = c;
            } else {
                throw new IllegalArgumentException("Wrong selection type, -H value should be: 1 for norm-based, 2 for CV-based and 3 for specific value");
            }
        }
    }

    public String priorClassTipText() {
        return "The type of prior to be used.";
    }

    public void setPriorClass(SelectedTag newMethod) {
        if (newMethod.getTags() == TAGS_PRIOR) {
            int c = newMethod.getSelectedTag().getID();
            if (c == 1 || c == 2) {
                this.PriorClass = c;
            } else {
                throw new IllegalArgumentException("Wrong selection type, -P value should be: 1 for Gaussian or 2 for Laplacian");
            }
        }
    }

    public SelectedTag getPriorClass() {
        return new SelectedTag(this.PriorClass, TAGS_PRIOR);
    }

    public String thresholdTipText() {
        return "Set the threshold for classifiction. The logistic function doesn't return a class label but an estimate of p(y=+1|B,x(i)). These estimates need to be converted to binary class label predictions. values above the threshold are assigned class +1.";
    }

    public double getThreshold() {
        return this.Threshold;
    }

    public void setThreshold(double threshold) {
        this.Threshold = threshold;
    }

    public String toleranceTipText() {
        return "This value decides the stopping criterion.";
    }

    public double getTolerance() {
        return this.Tolerance;
    }

    public void setTolerance(double tolerance) {
        this.Tolerance = tolerance;
    }

    public String hyperparameterValueTipText() {
        return "Specific hyperparameter value. Used when the hyperparameter selection method is set to specific value";
    }

    public double getHyperparameterValue() {
        return this.HyperparameterValue;
    }

    public void setHyperparameterValue(double hyperparameterValue) {
        this.HyperparameterValue = hyperparameterValue;
    }

    public String numFoldsTipText() {
        return "The number of folds to use for CV-based hyperparameter selection.";
    }

    public int getNumFolds() {
        return this.NumFolds;
    }

    public void setNumFolds(int numFolds) {
        this.NumFolds = numFolds;
    }

    public String seedTipText() {
        return "Seed for randomizing instances order prior to CV-based hyperparameter selection";
    }

    public void setSeed(int seed) {
        this.m_seed = seed;
    }

    public int getSeed() {
        return this.m_seed;
    }

    public String maxIterationsTipText() {
        return "The maximum number of iterations to perform.";
    }

    public int getMaxIterations() {
        return this.maxIterations;
    }

    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    public String normalizeDataTipText() {
        return "Normalize the data.";
    }

    public boolean isNormalizeData() {
        return this.NormalizeData;
    }

    public void setNormalizeData(boolean normalizeData) {
        this.NormalizeData = normalizeData;
    }

    public String hyperparameterRangeTipText() {
        return "Hyperparameter value range. In case of CV-based Hyperparameters, you can specify the range in two ways: \nComma-Separated: L: 3,5,6 (This will be a list of possible values.)\nRange: R:0.01-316,3.16 (This will take values from 0.01-316 (inclusive) in multiplications of 3.16";
    }

    public String getHyperparameterRange() {
        return this.HyperparameterRange;
    }

    public void setHyperparameterRange(String hyperparameterRange) {
        this.HyperparameterRange = hyperparameterRange;
    }

    public boolean isDebug() {
        return this.debug;
    }

    public String getRevision() {
        return RevisionUtils.extract("$Revision: 7984 $");
    }

    static {
        TAGS_PRIOR = new Tag[]{new Tag(1, "Gaussian"), new Tag(2, "Laplacian")};
        TAGS_HYPER_METHOD = new Tag[]{new Tag(1, "Norm-based"), new Tag(2, "CV-based"), new Tag(3, "Specific value")};
    }
}

