/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.ars_nouveau.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import org.apache.lucene.ars_nouveau.index.FieldInfo;
import org.apache.lucene.ars_nouveau.index.IndexReader;
import org.apache.lucene.ars_nouveau.index.LeafReader;
import org.apache.lucene.ars_nouveau.index.LeafReaderContext;
import org.apache.lucene.ars_nouveau.index.QueryTimeout;
import org.apache.lucene.ars_nouveau.search.BooleanClause;
import org.apache.lucene.ars_nouveau.search.BooleanQuery;
import org.apache.lucene.ars_nouveau.search.ConjunctionDISI;
import org.apache.lucene.ars_nouveau.search.DocIdSetIterator;
import org.apache.lucene.ars_nouveau.search.Explanation;
import org.apache.lucene.ars_nouveau.search.FieldExistsQuery;
import org.apache.lucene.ars_nouveau.search.FilteredDocIdSetIterator;
import org.apache.lucene.ars_nouveau.search.HitQueue;
import org.apache.lucene.ars_nouveau.search.IndexSearcher;
import org.apache.lucene.ars_nouveau.search.MatchNoDocsQuery;
import org.apache.lucene.ars_nouveau.search.Query;
import org.apache.lucene.ars_nouveau.search.QueryVisitor;
import org.apache.lucene.ars_nouveau.search.ScoreDoc;
import org.apache.lucene.ars_nouveau.search.ScoreMode;
import org.apache.lucene.ars_nouveau.search.Scorer;
import org.apache.lucene.ars_nouveau.search.ScorerSupplier;
import org.apache.lucene.ars_nouveau.search.TaskExecutor;
import org.apache.lucene.ars_nouveau.search.TimeLimitingKnnCollectorManager;
import org.apache.lucene.ars_nouveau.search.TopDocs;
import org.apache.lucene.ars_nouveau.search.TopDocsCollector;
import org.apache.lucene.ars_nouveau.search.TotalHits;
import org.apache.lucene.ars_nouveau.search.VectorScorer;
import org.apache.lucene.ars_nouveau.search.Weight;
import org.apache.lucene.ars_nouveau.search.knn.KnnCollectorManager;
import org.apache.lucene.ars_nouveau.search.knn.TopKnnCollectorManager;
import org.apache.lucene.ars_nouveau.util.BitSet;
import org.apache.lucene.ars_nouveau.util.BitSetIterator;
import org.apache.lucene.ars_nouveau.util.Bits;

/**
 * Base class for queries that find the {@code k} nearest neighbours of a target in a vector field.
 *
 * <p>All search work happens in {@link #rewrite(IndexSearcher)}. Every leaf is searched through
 * the searcher's {@link TaskExecutor}, and the per-leaf hits are merged into a global top
 * {@code k}. The query is then rewritten into a {@link DocAndScoreQuery} that replays exactly
 * those documents and scores. Subclasses supply the vector-type-specific parts through
 * {@link #approximateSearch} and {@link #createVectorScorer}.
 */
abstract class AbstractKnnVectorQuery extends Query {
    /** Shared empty result returned when a leaf cannot contribute any hits. */
    private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;

    /** Name of the vector field that is searched. */
    protected final String field;
    /** Number of nearest neighbours to return; validated to be at least 1. */
    protected final int k;
    /** Optional query restricting the candidate documents; {@code null} means "no filter". */
    protected final Query filter;

    /**
     * Creates a kNN query.
     *
     * @param field the vector field to search; must not be {@code null}
     * @param k the number of nearest neighbours to return; must be at least 1
     * @param filter optional filter restricting candidate documents, or {@code null}
     * @throws NullPointerException if {@code field} is {@code null}
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public AbstractKnnVectorQuery(String field, int k, Query filter) {
        this.field = Objects.requireNonNull(field, "field");
        // Validate before assigning so an invalid k is never stored, even transiently.
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got: " + k);
        }
        this.k = k;
        this.filter = filter;
    }

    /**
     * Runs the kNN search over every leaf of the searcher's reader and rewrites this query into a
     * query over the resulting hits.
     *
     * @param indexSearcher the searcher whose reader, executor and timeout are used
     * @return a {@link MatchNoDocsQuery} if no leaf produced a hit, otherwise a
     *     {@link DocAndScoreQuery} bound to {@code indexSearcher}'s reader
     * @throws IOException if searching a leaf fails
     */
    @Override
    public Query rewrite(IndexSearcher indexSearcher) throws IOException {
        IndexReader reader = indexSearcher.getIndexReader();

        final Weight filterWeight;
        if (filter != null) {
            // Candidates must match the user filter AND actually have a value in the vector field.
            BooleanQuery booleanQuery = new BooleanQuery.Builder()
                    .add(filter, BooleanClause.Occur.FILTER)
                    .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
                    .build();
            Query rewritten = indexSearcher.rewrite(booleanQuery);
            filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1.0f);
        } else {
            filterWeight = null;
        }

        TimeLimitingKnnCollectorManager knnCollectorManager = new TimeLimitingKnnCollectorManager(
                getKnnCollectorManager(k, indexSearcher), indexSearcher.getTimeout());
        TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
        List<LeafReaderContext> leafReaderContexts = reader.leaves();

        // One task per leaf. The typed list gives each lambda a Callable<TopDocs> target type.
        List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
        for (LeafReaderContext context : leafReaderContexts) {
            tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager));
        }
        TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);

        TopDocs topK = mergeLeafResults(perLeafResults);
        if (topK.scoreDocs.length == 0) {
            return new MatchNoDocsQuery();
        }
        return createRewrittenQuery(reader, topK);
    }

    /**
     * Searches a single leaf and converts its hit doc ids from leaf-local to top-level ids.
     *
     * <p>The returned {@link ScoreDoc}s are mutated in place by adding {@code ctx.docBase}.
     */
    private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight,
            TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) throws IOException {
        TopDocs results = getLeafResults(ctx, filterWeight, timeLimitingKnnCollectorManager);
        if (ctx.docBase > 0) {
            for (ScoreDoc scoreDoc : results.scoreDocs) {
                scoreDoc.doc += ctx.docBase;
            }
        }
        return results;
    }

    /**
     * Chooses between approximate and exact search for one leaf.
     *
     * <p>The strategy is:
     * <ul>
     *   <li>Without a filter, approximate search runs over the live docs with an unbounded
     *       ({@code Integer.MAX_VALUE}) limit.</li>
     *   <li>With a filter, the matching live docs are collected into a bit set. If there are at
     *       most {@code k} of them, they are scored exactly.</li>
     *   <li>Otherwise approximate search runs first, with a limit of {@code cost + 1}. The code
     *       falls back to exact search only when that result's total-hits relation is not
     *       {@code EQUAL_TO} and the query timeout has not fired.</li>
     * </ul>
     *
     * @return leaf-local hits (doc ids not yet rebased)
     */
    private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight,
            TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) throws IOException {
        LeafReader reader = ctx.reader();
        Bits liveDocs = reader.getLiveDocs();

        if (filterWeight == null) {
            return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager);
        }

        Scorer scorer = filterWeight.scorer(ctx);
        if (scorer == null) {
            // The filter matches nothing in this leaf.
            return NO_RESULTS;
        }

        BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, reader.maxDoc());
        int cost = acceptDocs.cardinality();
        QueryTimeout queryTimeout = timeLimitingKnnCollectorManager.getQueryTimeout();

        if (cost <= k) {
            // No more candidates than requested neighbours: score them all directly.
            return exactSearch(ctx, new BitSetIterator(acceptDocs, cost), queryTimeout);
        }

        TopDocs results = approximateSearch(ctx, acceptDocs, cost + 1, timeLimitingKnnCollectorManager);
        boolean timedOut = queryTimeout != null && queryTimeout.shouldExit();
        if (results.totalHits.relation() == TotalHits.Relation.EQUAL_TO || timedOut) {
            return results;
        }
        return exactSearch(ctx, new BitSetIterator(acceptDocs, cost), queryTimeout);
    }

    /**
     * Materializes the filter matches that are also live into a {@link BitSet}.
     *
     * <p>When there are no deletions and the iterator is already a {@link BitSetIterator}, its
     * underlying bit set is reused instead of copied.
     */
    private BitSet createBitSet(DocIdSetIterator iterator, final Bits liveDocs, int maxDoc) throws IOException {
        if (liveDocs == null && iterator instanceof BitSetIterator) {
            return ((BitSetIterator) iterator).getBitSet();
        }
        FilteredDocIdSetIterator filterIterator = new FilteredDocIdSetIterator(iterator) {
            @Override
            protected boolean match(int doc) {
                return liveDocs == null || liveDocs.get(doc);
            }
        };
        return BitSet.of(filterIterator, maxDoc);
    }

    /**
     * Returns the collector manager used for approximate search. Subclasses may override.
     *
     * @param k number of neighbours to collect
     * @param searcher the searcher being used
     * @return a {@link TopKnnCollectorManager} by default
     */
    protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
        return new TopKnnCollectorManager(k, searcher);
    }

    /**
     * Runs the index-backed approximate nearest-neighbour search for one leaf.
     *
     * @param context the leaf to search
     * @param acceptDocs documents allowed to be returned (may be {@code null} when every doc is
     *     live)
     * @param visitedLimit search limit; the caller passes {@code Integer.MAX_VALUE} or
     *     {@code cost + 1}. NOTE(review): presumably the maximum number of visited nodes; confirm
     *     against the implementations.
     * @param knnCollectorManager manager providing the per-leaf collector
     * @return leaf-local hits
     */
    protected abstract TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit,
            KnnCollectorManager knnCollectorManager) throws IOException;

    /**
     * Creates a scorer that computes this query's similarity for the vectors of {@code fi} in
     * {@code context}.
     *
     * @return the scorer, or {@code null} if none can be created (exact search then returns no
     *     hits)
     */
    abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException;

    /**
     * Scores every document that both has a vector and is accepted by {@code acceptIterator}, and
     * keeps the best {@code min(k, acceptIterator.cost())} of them.
     *
     * <p>If {@code queryTimeout} fires mid-iteration, scoring stops early. The total-hits relation
     * is then reported as {@code GREATER_THAN_OR_EQUAL_TO}. The total-hits value is always
     * {@code acceptIterator.cost()}.
     *
     * @param context the leaf to search
     * @param acceptIterator the candidate documents
     * @param queryTimeout optional timeout, may be {@code null}
     * @return leaf-local hits sorted by descending score
     * @throws ArithmeticException if {@code acceptIterator.cost()} does not fit in an int
     */
    protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator,
            QueryTimeout queryTimeout) throws IOException {
        FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
        if (fi == null || fi.getVectorDimension() == 0) {
            // The field is absent from this segment or has no vector dimension.
            return NO_RESULTS;
        }
        VectorScorer vectorScorer = createVectorScorer(context, fi);
        if (vectorScorer == null) {
            return NO_RESULTS;
        }

        int queueSize = Math.min(k, Math.toIntExact(acceptIterator.cost()));
        // The second argument asks HitQueue to pre-populate itself, so top() can be replaced
        // in place below. NOTE(review): sentinel score assumed to be very low; confirm in HitQueue.
        HitQueue queue = new HitQueue(queueSize, true);
        TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
        ScoreDoc topDoc = queue.top();

        DocIdSetIterator vectorIterator = vectorScorer.iterator();
        DocIdSetIterator conjunction =
                ConjunctionDISI.createConjunction(List.of(vectorIterator, acceptIterator), List.of());
        int doc;
        while ((doc = conjunction.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
            if (queryTimeout != null && queryTimeout.shouldExit()) {
                relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
                break;
            }
            assert vectorIterator.docID() == doc;
            float score = vectorScorer.score();
            if (score > topDoc.score) {
                // Overwrite the current weakest entry and restore heap order.
                topDoc.score = score;
                topDoc.doc = doc;
                topDoc = queue.updateTop();
            }
        }

        // Drop entries with negative scores (e.g. pre-populated slots that were never replaced).
        while (queue.size() > 0 && queue.top().score < 0.0f) {
            queue.pop();
        }

        // pop() yields the weakest first, so fill the array from the end to get descending order.
        ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
        for (int i = topScoreDocs.length - 1; i >= 0; --i) {
            topScoreDocs[i] = queue.pop();
        }
        TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation);
        return new TopDocs(totalHits, topScoreDocs);
    }

    /**
     * Merges per-leaf hits (already rebased to top-level doc ids) into the global top {@code k}.
     * Subclasses may override.
     */
    protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
        return TopDocs.merge(k, perLeafResults);
    }

    /**
     * Builds the {@link DocAndScoreQuery} that replays {@code topK}.
     *
     * <p>{@code topK.scoreDocs} is re-sorted in place by doc id. The first element's score is read
     * as the max score before that re-sort happens.
     */
    private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
        int len = topK.scoreDocs.length;
        assert len > 0;
        float maxScore = topK.scoreDocs[0].score;
        Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
        int[] docs = new int[len];
        float[] scores = new float[len];
        for (int i = 0; i < len; ++i) {
            docs[i] = topK.scoreDocs[i].doc;
            scores[i] = topK.scoreDocs[i].score;
        }
        int[] segmentStarts = findSegmentStarts(reader.leaves(), docs);
        return new DocAndScoreQuery(docs, scores, maxScore, segmentStarts, reader.getContext().id());
    }

    /**
     * Computes, for each leaf, the index in {@code docs} of its first hit.
     *
     * <p>The result has {@code leaves.size() + 1} entries. Entry {@code i} is the first index whose
     * doc is {@code >= leaves.get(i).docBase}, entry 0 is 0, and the last entry is
     * {@code docs.length}. Consecutive entries therefore delimit each leaf's slice.
     *
     * @param leaves the reader's leaves, in docBase order
     * @param docs top-level doc ids sorted ascending
     */
    static int[] findSegmentStarts(List<LeafReaderContext> leaves, int[] docs) {
        int[] starts = new int[leaves.size() + 1];
        starts[starts.length - 1] = docs.length;
        if (starts.length == 2) {
            // Single leaf: all hits belong to it.
            return starts;
        }
        int resultIndex = 0;
        for (int i = 1; i < starts.length - 1; ++i) {
            int upper = leaves.get(i).docBase;
            // Search only from the previous boundary onward; docs is sorted.
            resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
            if (resultIndex < 0) {
                // Not found: convert to the insertion point.
                resultIndex = -1 - resultIndex;
            }
            starts[i] = resultIndex;
        }
        return starts;
    }

    @Override
    public void visit(QueryVisitor visitor) {
        if (visitor.acceptField(field)) {
            visitor.visitLeaf(this);
        }
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        AbstractKnnVectorQuery that = (AbstractKnnVectorQuery) o;
        return k == that.k && Objects.equals(field, that.field) && Objects.equals(filter, that.filter);
    }

    @Override
    public int hashCode() {
        return Objects.hash(field, k, filter);
    }

    /** Returns the searched vector field. */
    public String getField() {
        return field;
    }

    /** Returns the number of neighbours requested. */
    public int getK() {
        return k;
    }

    /** Returns the filter query, or {@code null} if none was given. */
    public Query getFilter() {
        return filter;
    }

    /**
     * Query that matches a fixed, precomputed set of top-level doc ids with fixed scores.
     *
     * <p>It is only valid against the reader whose context id it was created with. Using it with
     * another reader makes {@link #createWeight} throw {@link IllegalStateException}.
     */
    static class DocAndScoreQuery extends Query {
        /** Top-level doc ids, sorted ascending. */
        private final int[] docs;
        /** Score of {@code docs[i]} at index {@code i}. */
        private final float[] scores;
        /** Highest score among {@link #scores}, reported as every scorer's max score. */
        private final float maxScore;
        /** Per-leaf slice boundaries into {@link #docs}; see {@code findSegmentStarts}. */
        private final int[] segmentStarts;
        /** Identity of the reader context the hits were computed against (compared by reference). */
        private final Object contextIdentity;

        DocAndScoreQuery(int[] docs, float[] scores, float maxScore, int[] segmentStarts, Object contextIdentity) {
            this.docs = docs;
            this.scores = scores;
            this.maxScore = maxScore;
            this.segmentStarts = segmentStarts;
            this.contextIdentity = contextIdentity;
        }

        /**
         * Creates a weight that replays the stored hits, scaling every score by {@code boost}.
         *
         * @throws IllegalStateException if {@code searcher} has a different reader context than
         *     the one this query was built from
         */
        @Override
        public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, final float boost) throws IOException {
            if (searcher.getIndexReader().getContext().id() != contextIdentity) {
                throw new IllegalStateException("This DocAndScore query was created by a different reader");
            }
            return new Weight(this) {

                @Override
                public Explanation explain(LeafReaderContext context, int doc) {
                    int found = Arrays.binarySearch(docs, doc + context.docBase);
                    if (found < 0) {
                        return Explanation.noMatch("not in top " + docs.length + " docs");
                    }
                    return Explanation.match(scores[found] * boost, "within top " + docs.length + " docs");
                }

                @Override
                public int count(LeafReaderContext context) {
                    return segmentStarts[context.ord + 1] - segmentStarts[context.ord];
                }

                @Override
                public ScorerSupplier scorerSupplier(final LeafReaderContext context) throws IOException {
                    if (segmentStarts[context.ord] == segmentStarts[context.ord + 1]) {
                        // No hits in this leaf.
                        return null;
                    }
                    Scorer scorer = new Scorer() {
                        // [lower, upper) is this leaf's slice of docs/scores.
                        final int lower = segmentStarts[context.ord];
                        final int upper = segmentStarts[context.ord + 1];
                        // Current index into docs/scores; -1 before the first nextDoc().
                        int upTo = -1;

                        @Override
                        public DocIdSetIterator iterator() {
                            return new DocIdSetIterator() {

                                @Override
                                public int docID() {
                                    // Must call the enclosing scorer's helper: an unqualified docID()
                                    // here would resolve to this method and recurse forever.
                                    return docIdNoShadow();
                                }

                                @Override
                                public int nextDoc() {
                                    if (upTo == -1) {
                                        upTo = lower;
                                    } else {
                                        ++upTo;
                                    }
                                    return docIdNoShadow();
                                }

                                @Override
                                public int advance(int target) throws IOException {
                                    return slowAdvance(target);
                                }

                                @Override
                                public long cost() {
                                    return upper - lower;
                                }
                            };
                        }

                        @Override
                        public float getMaxScore(int docId) {
                            return maxScore * boost;
                        }

                        @Override
                        public float score() {
                            return scores[upTo] * boost;
                        }

                        /** Leaf-local doc id at {@code upTo}, -1 before iteration, NO_MORE_DOCS after. */
                        private int docIdNoShadow() {
                            if (upTo == -1) {
                                return -1;
                            }
                            if (upTo >= upper) {
                                return DocIdSetIterator.NO_MORE_DOCS;
                            }
                            return docs[upTo] - context.docBase;
                        }

                        @Override
                        public int docID() {
                            return docIdNoShadow();
                        }
                    };
                    return new Weight.DefaultScorerSupplier(scorer);
                }

                @Override
                public boolean isCacheable(LeafReaderContext ctx) {
                    return true;
                }
            };
        }

        @Override
        public String toString(String field) {
            return "DocAndScoreQuery[" + docs[0] + ",...][" + scores[0] + ",...]," + maxScore;
        }

        @Override
        public void visit(QueryVisitor visitor) {
            visitor.visitLeaf(this);
        }

        @Override
        public boolean equals(Object obj) {
            if (!sameClassAs(obj)) {
                return false;
            }
            DocAndScoreQuery other = (DocAndScoreQuery) obj;
            return contextIdentity == other.contextIdentity
                    && Arrays.equals(docs, other.docs)
                    && Arrays.equals(scores, other.scores);
        }

        @Override
        public int hashCode() {
            return Objects.hash(classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores));
        }
    }
}

