/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.sgd.fm;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.OutputFactory;
import org.tribuo.data.DataOptions;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.optimisers.GradientOptimiserOptions;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.evaluation.RegressionEvaluation;
import org.tribuo.regression.sgd.RegressionObjective;
import org.tribuo.regression.sgd.fm.FMRegressionTrainer;
import org.tribuo.regression.sgd.objectives.AbsoluteLoss;
import org.tribuo.regression.sgd.objectives.Huber;
import org.tribuo.regression.sgd.objectives.SquaredLoss;
import org.tribuo.util.Util;

public final class TrainTest {
    private static final Logger logger = Logger.getLogger(TrainTest.class.getName());

    public static void main(String[] args) throws IOException {
        ConfigurationManager cm;
        LabsLogFormatter.setAllLogFormatters();
        FMRegressionOptions o = new FMRegressionOptions();
        try {
            cm = new ConfigurationManager(args, (Options)o);
        }
        catch (UsageException e) {
            logger.info(e.getMessage());
            return;
        }
        if (o.general.trainingPath == null || o.general.testingPath == null) {
            logger.info(cm.usage());
            return;
        }
        logger.info("Configuring gradient optimiser");
        RegressionObjective obj = null;
        switch (o.loss) {
            case ABSOLUTE: {
                obj = new AbsoluteLoss();
                break;
            }
            case SQUARED: {
                obj = new SquaredLoss();
                break;
            }
            case HUBER: {
                obj = new Huber();
                break;
            }
            default: {
                logger.warning("Unknown objective function " + (Object)((Object)o.loss));
                logger.info(cm.usage());
                return;
            }
        }
        StochasticGradientOptimiser grad = o.gradientOptions.getOptimiser();
        logger.info(String.format("Set logging interval to %d", o.loggingInterval));
        RegressionFactory factory = new RegressionFactory();
        Pair data = o.general.load((OutputFactory)factory);
        Dataset train = (Dataset)data.getA();
        Dataset test = (Dataset)data.getB();
        logger.info("Feature domain - " + train.getFeatureIDMap());
        FMRegressionTrainer trainer = new FMRegressionTrainer(obj, grad, o.epochs, o.loggingInterval, o.minibatchSize, o.general.seed, o.factorSize, o.variance, o.standardise);
        logger.info("Training using " + ((Object)((Object)trainer)).toString());
        long trainStart = System.currentTimeMillis();
        Model model = trainer.train(train);
        long trainStop = System.currentTimeMillis();
        logger.info("Finished training regressor " + Util.formatDuration((long)trainStart, (long)trainStop));
        long testStart = System.currentTimeMillis();
        RegressionEvaluation evaluation = (RegressionEvaluation)factory.getEvaluator().evaluate(model, test);
        long testStop = System.currentTimeMillis();
        logger.info("Finished evaluating model " + Util.formatDuration((long)testStart, (long)testStop));
        System.out.println(evaluation.toString());
        if (o.general.outputPath != null) {
            o.general.saveModel(model);
        }
    }

    public static class FMRegressionOptions
    implements Options {
        public DataOptions general;
        public GradientOptimiserOptions gradientOptions;
        @Option(charName=105, longName="epochs", usage="Number of SGD epochs.")
        public int epochs = 5;
        @Option(charName=111, longName="objective", usage="Loss function.")
        public LossEnum loss = LossEnum.SQUARED;
        @Option(charName=112, longName="logging-interval", usage="Log the objective after <int> examples.")
        public int loggingInterval = 100;
        @Option(charName=122, longName="minibatch-size", usage="Minibatch size.")
        public int minibatchSize = 1;
        @Option(charName=100, longName="factor-size", usage="Factor size.")
        public int factorSize = 5;
        @Option(longName="variance", usage="Variance of the initialization gaussian.")
        public double variance = 0.5;
        @Option(longName="standardise", usage="Standardise the output regressors before model fitting.")
        public boolean standardise = false;

        public String getOptionsDescription() {
            return "Trains and tests a linear SGD regression model on the specified datasets.";
        }
    }

    public static enum LossEnum {
        ABSOLUTE,
        SQUARED,
        HUBER;

    }
}

