/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.examples.ml;

import java.util.Arrays;
import org.apache.spark.examples.ml.JavaDocument;
import org.apache.spark.examples.ml.JavaLabeledDocument;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

/**
 * Spark ML example: selects hyperparameters for a text-classification pipeline
 * (Tokenizer -> HashingTF -> LogisticRegression) using k-fold cross-validation,
 * then applies the best model to a small unlabeled test set and prints the results.
 */
public class JavaModelSelectionViaCrossValidationExample {

    /**
     * Entry point. Creates (or reuses) a {@link SparkSession}, runs the example and
     * always stops the session, even if training or prediction fails.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .appName("JavaModelSelectionViaCrossValidationExample")
            .getOrCreate();
        try {
            // Labeled training documents: (id, text, label).
            Dataset<Row> training = spark.createDataFrame(Arrays.asList(
                new JavaLabeledDocument(0L, "a b c d e spark", 1.0),
                new JavaLabeledDocument(1L, "b d", 0.0),
                new JavaLabeledDocument(2L, "spark f g h", 1.0),
                new JavaLabeledDocument(3L, "hadoop mapreduce", 0.0),
                new JavaLabeledDocument(4L, "b spark who", 1.0),
                new JavaLabeledDocument(5L, "g d a y", 0.0),
                new JavaLabeledDocument(6L, "spark fly", 1.0),
                new JavaLabeledDocument(7L, "was mapreduce", 0.0),
                new JavaLabeledDocument(8L, "e spark program", 1.0),
                new JavaLabeledDocument(9L, "a e c l", 0.0),
                new JavaLabeledDocument(10L, "spark compile", 1.0),
                new JavaLabeledDocument(11L, "hadoop software", 0.0)
            ), JavaLabeledDocument.class);

            // Pipeline stages: "text" -> "words" -> "features" -> logistic regression.
            Tokenizer tokenizer = new Tokenizer()
                .setInputCol("text")
                .setOutputCol("words");
            // numFeatures / regParam set here are the stage defaults; during tuning
            // they are overridden by the values in the param grid below.
            HashingTF hashingTF = new HashingTF()
                .setNumFeatures(1000)
                .setInputCol(tokenizer.getOutputCol())
                .setOutputCol("features");
            LogisticRegression lr = new LogisticRegression()
                .setMaxIter(10)
                .setRegParam(0.01);
            Pipeline pipeline = new Pipeline()
                .setStages(new PipelineStage[] {tokenizer, hashingTF, lr});

            // 3 values of numFeatures x 2 values of regParam = 6 candidate param maps.
            ParamMap[] paramGrid = new ParamGridBuilder()
                .addGrid(hashingTF.numFeatures(), new int[] {10, 100, 1000})
                .addGrid(lr.regParam(), new double[] {0.1, 0.01})
                .build();

            // Each candidate is scored with 2-fold cross-validation; up to 2 candidates
            // are evaluated in parallel. The whole pipeline is treated as the estimator
            // so feature-extraction parameters are tuned alongside the classifier.
            CrossValidator cv = new CrossValidator()
                .setEstimator(pipeline)
                .setEvaluator(new BinaryClassificationEvaluator())
                .setEstimatorParamMaps(paramGrid)
                .setNumFolds(2)
                .setParallelism(2);

            // Runs cross-validation and refits the best param map on the full training set.
            CrossValidatorModel cvModel = cv.fit(training);

            // Unlabeled test documents: (id, text).
            Dataset<Row> test = spark.createDataFrame(Arrays.asList(
                new JavaDocument(4L, "spark i j k"),
                new JavaDocument(5L, "l m n"),
                new JavaDocument(6L, "mapreduce spark"),
                new JavaDocument(7L, "apache hadoop")
            ), JavaDocument.class);

            // cvModel uses the best model found during cross-validation.
            Dataset<Row> predictions = cvModel.transform(test);
            for (Row r : predictions.select("id", "text", "probability", "prediction").collectAsList()) {
                System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
                    + ", prediction=" + r.get(3));
            }
        } finally {
            // Ensure the session is released even when fitting or prediction throws.
            spark.stop();
        }
    }
}

