/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.dataset;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.BulkDataIterable;
import ai.djl.training.dataset.DataIterable;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.training.dataset.Sampler;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.util.Progress;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.stream.Stream;

public class ArrayDataset
extends RandomAccessDataset {
    protected NDArray[] data;
    protected NDArray[] labels;

    public ArrayDataset(RandomAccessDataset.BaseBuilder<?> builder) {
        super(builder);
        if (builder instanceof Builder) {
            Builder builder2 = (Builder)builder;
            this.data = builder2.data;
            this.labels = builder2.labels;
            long size = this.data[0].size(0);
            if (Stream.of(this.data).anyMatch(array -> array.size(0) != size)) {
                throw new IllegalArgumentException("All the NDArray must have the same length!");
            }
            if (this.labels != null && Stream.of(this.labels).anyMatch(array -> array.size(0) != size)) {
                throw new IllegalArgumentException("All the NDArray must have the same length!");
            }
        }
    }

    ArrayDataset() {
    }

    @Override
    protected long availableSize() {
        return this.data[0].size(0);
    }

    @Override
    public Record get(NDManager manager, long index) {
        NDList datum = new NDList();
        NDList label = new NDList();
        for (NDArray array : this.data) {
            datum.add(array.get(manager, index));
        }
        if (this.labels != null) {
            for (NDArray array : this.labels) {
                label.add(array.get(manager, index));
            }
        }
        return new Record(datum, label);
    }

    public Batch getByIndices(NDManager manager, long ... indices) {
        try (NDArray ndIndices = manager.create(indices);){
            NDIndex index = new NDIndex("{}", ndIndices);
            NDList datum = new NDList();
            NDList label = new NDList();
            for (NDArray array : this.data) {
                datum.add(array.get(manager, index));
            }
            if (this.labels != null) {
                for (NDArray array : this.labels) {
                    label.add(array.get(manager, index));
                }
            }
            Batch batch = new Batch(manager, datum, label, indices.length, Batchifier.STACK, Batchifier.STACK, -1L, -1L);
            return batch;
        }
    }

    public Batch getByRange(NDManager manager, long fromIndex, long toIndex) {
        NDIndex index = new NDIndex().addSliceDim(fromIndex, toIndex);
        NDList datum = new NDList();
        NDList label = new NDList();
        for (NDArray array : this.data) {
            datum.add(array.get(manager, index));
        }
        if (this.labels != null) {
            for (NDArray array : this.labels) {
                label.add(array.get(manager, index));
            }
        }
        int size = Math.toIntExact(toIndex - fromIndex);
        return new Batch(manager, datum, label, size, Batchifier.STACK, Batchifier.STACK, -1L, -1L);
    }

    @Override
    protected RandomAccessDataset newSubDataset(int[] indices, int from, int to) {
        return new SubDataset(this, indices, from, to);
    }

    @Override
    protected RandomAccessDataset newSubDataset(List<Long> subIndices) {
        return new SubDatasetByIndices(this, subIndices);
    }

    @Override
    public Iterable<Batch> getData(NDManager manager, Sampler sampler, ExecutorService executorService) throws IOException, TranslateException {
        this.prepare();
        if (this.dataBatchifier == Batchifier.STACK && this.labelBatchifier == Batchifier.STACK) {
            return new BulkDataIterable(this, manager, sampler, this.dataBatchifier, this.labelBatchifier, this.pipeline, this.targetPipeline, executorService, this.prefetchNumber, this.device);
        }
        return new DataIterable(this, manager, sampler, this.dataBatchifier, this.labelBatchifier, this.pipeline, this.targetPipeline, executorService, this.prefetchNumber, this.device);
    }

    @Override
    public void prepare(Progress progress) throws IOException {
    }

    private static final class SubDatasetByIndices
    extends ArrayDataset {
        private ArrayDataset dataset;
        private List<Long> subIndices;

        public SubDatasetByIndices(ArrayDataset dataset, List<Long> subIndices) {
            this.dataset = dataset;
            this.subIndices = subIndices;
            this.sampler = dataset.sampler;
            this.dataBatchifier = dataset.dataBatchifier;
            this.labelBatchifier = dataset.labelBatchifier;
            this.pipeline = dataset.pipeline;
            this.targetPipeline = dataset.targetPipeline;
            this.prefetchNumber = dataset.prefetchNumber;
            this.device = dataset.device;
            this.limit = Long.MAX_VALUE;
        }

        @Override
        public Record get(NDManager manager, long index) {
            return this.dataset.get(manager, this.subIndices.get(Math.toIntExact(index)));
        }

        @Override
        public Batch getByIndices(NDManager manager, long ... indices) {
            long[] resolvedIndices = new long[indices.length];
            int i = 0;
            for (long index : indices) {
                resolvedIndices[i++] = this.subIndices.get(Math.toIntExact(index));
            }
            return this.dataset.getByIndices(manager, resolvedIndices);
        }

        @Override
        public Batch getByRange(NDManager manager, long fromIndex, long toIndex) {
            long[] resolvedIndices = new long[(int)(toIndex - fromIndex)];
            int i = 0;
            for (long index = fromIndex; index < toIndex; ++index) {
                resolvedIndices[i++] = this.subIndices.get(Math.toIntExact(index));
            }
            return this.dataset.getByIndices(manager, resolvedIndices);
        }

        @Override
        protected long availableSize() {
            return this.subIndices.size();
        }

        @Override
        public void prepare(Progress progress) {
        }
    }

    private static final class SubDataset
    extends ArrayDataset {
        private ArrayDataset dataset;
        private int[] indices;
        private int from;
        private int to;

        public SubDataset(ArrayDataset dataset, int[] indices, int from, int to) {
            this.dataset = dataset;
            this.indices = indices;
            this.from = from;
            this.to = to;
            this.sampler = dataset.sampler;
            this.dataBatchifier = dataset.dataBatchifier;
            this.labelBatchifier = dataset.labelBatchifier;
            this.pipeline = dataset.pipeline;
            this.targetPipeline = dataset.targetPipeline;
            this.prefetchNumber = dataset.prefetchNumber;
            this.device = dataset.device;
            this.limit = Long.MAX_VALUE;
        }

        @Override
        public Record get(NDManager manager, long index) {
            if (index >= this.size()) {
                throw new IndexOutOfBoundsException("index(" + index + ") > size(" + this.size() + ").");
            }
            return this.dataset.get(manager, this.indices[Math.toIntExact(index) + this.from]);
        }

        @Override
        public Batch getByIndices(NDManager manager, long ... indices) {
            long[] resolvedIndices = new long[indices.length];
            int i = 0;
            for (long index : indices) {
                resolvedIndices[i++] = this.indices[Math.toIntExact(index) + this.from];
            }
            return this.dataset.getByIndices(manager, resolvedIndices);
        }

        @Override
        public Batch getByRange(NDManager manager, long fromIndex, long toIndex) {
            return this.dataset.getByRange(manager, fromIndex + (long)this.from, toIndex + (long)this.from);
        }

        @Override
        protected long availableSize() {
            return this.to - this.from;
        }

        @Override
        public void prepare(Progress progress) {
        }
    }

    public static final class Builder
    extends RandomAccessDataset.BaseBuilder<Builder> {
        private NDArray[] data;
        private NDArray[] labels;

        @Override
        protected Builder self() {
            return this;
        }

        public Builder setData(NDArray ... data) {
            this.data = data;
            return this.self();
        }

        public Builder optLabels(NDArray ... labels) {
            this.labels = labels;
            return this.self();
        }

        public ArrayDataset build() {
            if (this.data == null || this.data.length == 0) {
                throw new IllegalArgumentException("Please pass in at least one data");
            }
            return new ArrayDataset(this);
        }
    }
}

