/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.sparse.algorithm.seismic;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;
import lombok.Generated;
import lombok.NonNull;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.util.BytesRef;
import org.opensearch.neuralsearch.sparse.accessor.SparseVectorReader;
import org.opensearch.neuralsearch.sparse.algorithm.seismic.RandomClusteringAlgorithm;
import org.opensearch.neuralsearch.sparse.algorithm.seismic.SeismicPostingClusterer;
import org.opensearch.neuralsearch.sparse.cache.CacheGatedForwardIndexReader;
import org.opensearch.neuralsearch.sparse.cache.CacheKey;
import org.opensearch.neuralsearch.sparse.cache.CacheableClusteredPostingWriter;
import org.opensearch.neuralsearch.sparse.cache.ClusteredPostingCache;
import org.opensearch.neuralsearch.sparse.cache.ForwardIndexCache;
import org.opensearch.neuralsearch.sparse.cache.ForwardIndexCacheItem;
import org.opensearch.neuralsearch.sparse.codec.MergeHelper;
import org.opensearch.neuralsearch.sparse.codec.SparseBinaryDocValuesPassThrough;
import org.opensearch.neuralsearch.sparse.common.MergeStateFacade;
import org.opensearch.neuralsearch.sparse.data.DocWeight;
import org.opensearch.neuralsearch.sparse.data.DocumentCluster;
import org.opensearch.neuralsearch.sparse.data.PostingClusters;

/**
 * Clusters the merged posting lists of a batch of terms during a segment merge.
 *
 * <p>For every term in the batch, the postings from all segments being merged are combined via
 * {@link MergeHelper#getMergedPostingForATerm}, clustered by a {@link SeismicPostingClusterer}
 * backed by a {@link RandomClusteringAlgorithm}, inserted into the writer of the
 * {@link ClusteredPostingCache} entry identified by {@code key}, and returned as
 * {@code (term, clusters)} pairs.
 */
public class BatchClusteringTask implements Supplier<List<Pair<BytesRef, PostingClusters>>> {
    @Generated
    private static final Logger log = LogManager.getLogger(BatchClusteringTask.class);
    /** Terms to cluster; deep copies owned by this task. */
    private final List<BytesRef> terms;
    /** Key of the clustered-posting cache entry that receives the clusters. */
    private final CacheKey key;
    /** Forwarded to {@link RandomClusteringAlgorithm}. */
    private final float summaryPruneRatio;
    /** Forwarded to {@link RandomClusteringAlgorithm}. */
    private final float clusterRatio;
    /** Forwarded to {@link SeismicPostingClusterer}. */
    private final int nPostings;
    /** Source of per-segment max docs and doc values producers for the merge. */
    private final MergeStateFacade mergeStateFacade;
    /** The sparse field whose postings are merged and clustered. */
    private final FieldInfo fieldInfo;
    /** Builds the merged posting list of a single term. */
    private final MergeHelper mergeHelper;

    /**
     * Creates a task for one batch of terms.
     *
     * @param terms terms to cluster; each {@link BytesRef} is deep-copied, so later changes to the
     *     caller's byte buffers do not affect this task
     * @param key key of the {@link ClusteredPostingCache} entry whose writer receives the clusters
     * @param summaryPruneRatio forwarded to {@link RandomClusteringAlgorithm}
     * @param clusterRatio forwarded to {@link RandomClusteringAlgorithm}
     * @param nPostings forwarded to {@link SeismicPostingClusterer}
     * @param mergeStateFacade merge state supplying per-segment max docs and doc values producers;
     *     must not be {@code null}
     * @param fieldInfo the sparse field being merged
     * @param mergeHelper used to build each term's merged posting list
     * @throws NullPointerException if {@code mergeStateFacade} or {@code terms} is {@code null}
     */
    public BatchClusteringTask(
            List<BytesRef> terms,
            CacheKey key,
            float summaryPruneRatio,
            float clusterRatio,
            int nPostings,
            @NonNull MergeStateFacade mergeStateFacade,
            FieldInfo fieldInfo,
            MergeHelper mergeHelper) {
        Objects.requireNonNull(mergeStateFacade, "mergeStateFacade is marked non-null but is null");
        this.terms = terms.stream().map(BytesRef::deepCopyOf).toList();
        this.key = key;
        this.summaryPruneRatio = summaryPruneRatio;
        this.clusterRatio = clusterRatio;
        this.nPostings = nPostings;
        this.mergeStateFacade = mergeStateFacade;
        this.fieldInfo = fieldInfo;
        this.mergeHelper = mergeHelper;
    }

    /**
     * Clusters every term in the batch and records the clusters in the clustered-posting cache.
     *
     * @return one {@code (term, clusters)} pair per term, in batch order; an empty list when the
     *     segments being merged contain no documents
     * @throws UncheckedIOException if building a merged posting list or clustering it fails with an
     *     {@link IOException} (the failure is also logged)
     */
    @Override
    public List<Pair<BytesRef, PostingClusters>> get() {
        List<Pair<BytesRef, PostingClusters>> postingClusters = new ArrayList<>(this.terms.size());
        int maxDocs = this.getTotalDocs();
        if (maxDocs == 0) {
            return postingClusters;
        }
        try {
            for (BytesRef term : this.terms) {
                // Handed to MergeHelper and afterwards only read: for each merged (new) doc id they
                // give the index of its source segment and its doc id within that segment.
                int[] newIdToFieldProducerIndex = new int[maxDocs];
                int[] newIdToOldId = new int[maxDocs];
                List<DocWeight> docWeights = this.mergeHelper.getMergedPostingForATerm(
                        this.mergeStateFacade, term, this.fieldInfo, newIdToFieldProducerIndex, newIdToOldId);
                SeismicPostingClusterer seismicPostingClusterer = new SeismicPostingClusterer(
                        this.nPostings,
                        new RandomClusteringAlgorithm(this.summaryPruneRatio, this.clusterRatio, newDocId -> {
                            int oldId = newIdToOldId[newDocId];
                            int segmentIndex = newIdToFieldProducerIndex[newDocId];
                            // NOTE(review): a fresh BinaryDocValues and reader are obtained for every
                            // lookup, presumably because doc ids may be requested out of order — confirm.
                            BinaryDocValues binaryDocValues =
                                    this.mergeStateFacade.getDocValuesProducers()[segmentIndex].getBinary(this.fieldInfo);
                            SparseVectorReader reader = this.getCacheGatedForwardIndexReader(binaryDocValues);
                            return reader.read(oldId);
                        }));
                List<DocumentCluster> clusters = seismicPostingClusterer.cluster(docWeights);
                postingClusters.add(Pair.of(term, new PostingClusters(clusters)));
                CacheableClusteredPostingWriter writer =
                        ClusteredPostingCache.getInstance().getOrCreate(this.key).getWriter();
                writer.insert(term, clusters);
            }
        } catch (IOException e) {
            log.error("cluster failed", e);
            throw new UncheckedIOException(e);
        }
        return postingClusters;
    }

    /** Returns the sum of {@code maxDocs} over all segments being merged. */
    private int getTotalDocs() {
        int totalDocs = 0;
        for (int segmentMaxDoc : this.mergeStateFacade.getMaxDocs()) {
            totalDocs += segmentMaxDoc;
        }
        return totalDocs;
    }

    /**
     * Builds a sparse-vector reader over the given doc values.
     *
     * @param binaryDocValues doc values of {@code fieldInfo} from one of the segments being merged
     * @return a {@link CacheGatedForwardIndexReader} when {@code binaryDocValues} is a
     *     {@link SparseBinaryDocValuesPassThrough}, using the reader and writer of that segment's
     *     {@link ForwardIndexCache} entry if one exists and {@code null} for both otherwise;
     *     {@link SparseVectorReader#NOOP_READER} for any other doc values (including {@code null})
     */
    private SparseVectorReader getCacheGatedForwardIndexReader(BinaryDocValues binaryDocValues) {
        if (!(binaryDocValues instanceof SparseBinaryDocValuesPassThrough sparseBinaryDocValues)) {
            return SparseVectorReader.NOOP_READER;
        }
        SegmentInfo segmentInfo = sparseBinaryDocValues.getSegmentInfo();
        CacheKey cacheKey = new CacheKey(segmentInfo, this.fieldInfo);
        ForwardIndexCacheItem index = (ForwardIndexCacheItem) ForwardIndexCache.getInstance().get(cacheKey);
        if (index == null) {
            return new CacheGatedForwardIndexReader(null, null, sparseBinaryDocValues);
        }
        return new CacheGatedForwardIndexReader(index.getReader(), index.getWriter(), sparseBinaryDocValues);
    }

    /** Returns the (unmodifiable) deep-copied terms this task clusters. */
    @Generated
    public List<BytesRef> getTerms() {
        return this.terms;
    }
}

