/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.AbstractReservoirTrainTestSplitter;

/**
 * Train/test splitter that keeps one
 * {@link AbstractReservoirTrainTestSplitter.SampleInfo} per class of the
 * dependent variable and selects the one matching each row's class value.
 *
 * <p>How the returned {@code SampleInfo} is used to decide between training
 * and test is defined by the superclass, not here.
 */
public class StratifiedTrainTestSplitter extends AbstractReservoirTrainTestSplitter {
    /** Per-class sample bookkeeping, keyed by the class value as it appears in a row. */
    private final Map<String, AbstractReservoirTrainTestSplitter.SampleInfo> classSamples = new HashMap<>();

    /**
     * Creates a splitter with one {@code SampleInfo} for every entry of {@code classCounts}.
     *
     * @param fieldNames        passed through to the superclass
     * @param dependentVariable passed through to the superclass
     * @param classCounts       class value mapped to the count handed to that class's {@code SampleInfo}
     * @param trainingPercent   passed through to the superclass
     * @param randomizeSeed     passed through to the superclass
     */
    public StratifiedTrainTestSplitter(List<String> fieldNames, String dependentVariable, Map<String, Long> classCounts, double trainingPercent, long randomizeSeed) {
        super(fieldNames, dependentVariable, trainingPercent, randomizeSeed);
        for (Map.Entry<String, Long> classCount : classCounts.entrySet()) {
            this.classSamples.put(classCount.getKey(), new AbstractReservoirTrainTestSplitter.SampleInfo(classCount.getValue()));
        }
    }

    /**
     * Looks up the {@code SampleInfo} registered for the class value found at
     * {@code dependentVariableIndex} in the given row.
     *
     * @param row the row whose class value selects the sample info
     * @return the sample info registered for that class value
     * @throws IllegalStateException if the class value was not among the keys supplied at construction
     */
    @Override
    protected AbstractReservoirTrainTestSplitter.SampleInfo getSampleInfo(String[] row) {
        String rowClass = row[this.dependentVariableIndex];
        AbstractReservoirTrainTestSplitter.SampleInfo known = this.classSamples.get(rowClass);
        if (known != null) {
            return known;
        }
        // A class that was never registered cannot be sampled; report the known classes to aid debugging.
        throw new IllegalStateException("Unknown class [" + rowClass + "]; expected one of " + this.classSamples.keySet());
    }
}

