package org.apache.lucene.ars_nouveau.util.hnsw;

import java.io.IOException;
import org.apache.lucene.ars_nouveau.search.KnnCollector;
import org.apache.lucene.ars_nouveau.search.TopKnnCollector;
import org.apache.lucene.ars_nouveau.util.BitSet;
import org.apache.lucene.ars_nouveau.util.Bits;
import org.apache.lucene.ars_nouveau.util.FixedBitSet;
import org.apache.lucene.ars_nouveau.util.SparseFixedBitSet;
import org.apache.lucene.ars_nouveau.util.hnsw.HnswGraphBuilder;

/* loaded from: input_file:org/apache/lucene/ars_nouveau/util/hnsw/HnswGraphSearcher.class */
/**
 * Greedy best-first search over an {@link HnswGraph}.
 *
 * <p>Each instance owns reusable scratch state (a candidate queue and a visited-node bit set) that
 * is reset at the start of every level search. That state is mutable and unsynchronized, so an
 * instance should not be shared between concurrent searches.
 */
public class HnswGraphSearcher {
    /** Sentinel returned by {@link #graphNextNeighbor} once the current neighbor list is exhausted. */
    private static final int NO_MORE_NEIGHBORS = Integer.MAX_VALUE;

    /** Nodes waiting to be expanded; the level search repeatedly pops the top of this queue. */
    private final NeighborQueue candidates;

    /**
     * Nodes already scored during the current search. {@link #prepareScratchState(int)} may replace
     * it with a larger {@link FixedBitSet} when the graph has grown past its length.
     */
    private BitSet visited;

    /**
     * Searcher specialized for {@link OnHeapHnswGraph}. It walks a node's {@link NeighborArray}
     * directly instead of going through {@link HnswGraph#seek} / {@link HnswGraph#nextNeighbor}.
     */
    private static class OnHeapHnswGraphSearcher extends HnswGraphSearcher {
        /** Neighbors of the node most recently passed to {@link #graphSeek}. */
        private NeighborArray cur;
        /** Index of the last neighbor returned from {@link #cur}; -1 right after a seek. */
        private int upto;

        private OnHeapHnswGraphSearcher(NeighborQueue candidates, BitSet visited) {
            super(candidates, visited);
        }

        @Override
        void graphSeek(HnswGraph graph, int level, int node) {
            this.cur = ((OnHeapHnswGraph) graph).getNeighbors(level, node);
            this.upto = -1;
        }

        @Override
        int graphNextNeighbor(HnswGraph graph) {
            if (++this.upto < this.cur.size()) {
                return this.cur.nodes()[this.upto];
            }
            return NO_MORE_NEIGHBORS;
        }
    }

    /**
     * Creates a searcher that reuses the given scratch structures.
     *
     * @param candidates queue used to hold nodes awaiting expansion; cleared before each level search
     * @param visited bit set used to mark scored nodes; cleared before each level search. It is only
     *     grown automatically if it is a {@link FixedBitSet}.
     */
    public HnswGraphSearcher(NeighborQueue candidates, BitSet visited) {
        this.candidates = candidates;
        this.visited = visited;
    }

    /**
     * Searches {@code graph} for the nodes scoring highest against {@code scorer} and feeds them into
     * {@code knnCollector}.
     *
     * @param scorer scores graph nodes against the query
     * @param knnCollector receives accepted results; its {@code k()} sizes the candidate queue
     * @param graph the graph to search
     * @param acceptOrds if non-null, only nodes whose bit is set are passed to the collector (all
     *     nodes are still traversed)
     * @throws IOException if reading the graph or scoring fails
     */
    public static void search(
            RandomVectorScorer scorer, KnnCollector knnCollector, HnswGraph graph, Bits acceptOrds)
            throws IOException {
        HnswGraphSearcher graphSearcher =
                new HnswGraphSearcher(
                        new NeighborQueue(knnCollector.k(), true),
                        new SparseFixedBitSet(getGraphSize(graph)));
        search(scorer, knnCollector, graph, graphSearcher, acceptOrds);
    }

    /**
     * Searches an on-heap graph and returns the populated collector.
     *
     * @param scorer scores graph nodes against the query
     * @param topK number of results to collect (also sizes the candidate queue)
     * @param graph the on-heap graph to search
     * @param acceptOrds if non-null, only nodes whose bit is set are collected
     * @param visitedLimit passed as the second argument to {@link TopKnnCollector}'s constructor
     *     (presumably the maximum number of nodes to visit; confirm against TopKnnCollector)
     * @return the {@link TopKnnCollector} holding the results
     * @throws IOException if scoring fails
     */
    public static KnnCollector search(
            RandomVectorScorer scorer, int topK, OnHeapHnswGraph graph, Bits acceptOrds, int visitedLimit)
            throws IOException {
        TopKnnCollector knnCollector = new TopKnnCollector(topK, visitedLimit);
        OnHeapHnswGraphSearcher graphSearcher =
                new OnHeapHnswGraphSearcher(
                        new NeighborQueue(topK, true), new SparseFixedBitSet(getGraphSize(graph)));
        search(scorer, knnCollector, graph, graphSearcher, acceptOrds);
        return knnCollector;
    }

    /**
     * Runs a full search with the given searcher. It first descends the upper levels to find a
     * good entry point, then performs a best-first search on level 0.
     */
    private static void search(
            RandomVectorScorer scorer,
            KnnCollector knnCollector,
            HnswGraph graph,
            HnswGraphSearcher graphSearcher,
            Bits acceptOrds)
            throws IOException {
        int entryPoint = graphSearcher.findBestEntryPoint(scorer, graph, knnCollector);
        // -1 means the graph is empty or the collector terminated early.
        if (entryPoint != -1) {
            graphSearcher.searchLevel(
                    knnCollector, scorer, 0, new int[] {entryPoint}, graph, acceptOrds);
        }
    }

    /**
     * Searches a single level of the graph starting from {@code eps}, without any accept filter,
     * collecting results into a new {@link HnswGraphBuilder.GraphBuilderKnnCollector}.
     *
     * @param scorer scores graph nodes against the query
     * @param topK passed to the {@link HnswGraphBuilder.GraphBuilderKnnCollector} constructor
     * @param level the graph level to search
     * @param eps entry points to start from
     * @param graph the graph to search
     * @return the collector holding the results
     * @throws IOException if reading the graph or scoring fails
     */
    public HnswGraphBuilder.GraphBuilderKnnCollector searchLevel(
            RandomVectorScorer scorer, int topK, int level, int[] eps, HnswGraph graph)
            throws IOException {
        HnswGraphBuilder.GraphBuilderKnnCollector results =
                new HnswGraphBuilder.GraphBuilderKnnCollector(topK);
        searchLevel(results, scorer, level, eps, graph, null);
        return results;
    }

    /**
     * Greedily descends from {@code graph.entryNode()} through levels {@code numLevels()-1} down to
     * 1. On each level it keeps moving to any neighbor with a strictly higher score until no neighbor
     * improves on the current node. The visited set is shared across all levels of this descent.
     *
     * @return the best node found on level 1. If the graph has a single level, it returns {@code
     *     entryNode()} unchanged. It returns -1 if the graph has no entry node or the collector
     *     terminated early.
     */
    private int findBestEntryPoint(RandomVectorScorer scorer, HnswGraph graph, KnnCollector collector)
            throws IOException {
        int currentEp = graph.entryNode();
        if (currentEp == -1 || graph.numLevels() == 1) {
            return currentEp;
        }
        int size = getGraphSize(graph);
        prepareScratchState(size);
        float currentScore = scorer.score(currentEp);
        collector.incVisitedCount(1);
        for (int level = graph.numLevels() - 1; level >= 1; level--) {
            boolean foundBetter = true;
            visited.set(currentEp);
            // Rescan from the (possibly updated) entry point until a full pass finds no improvement.
            while (foundBetter) {
                foundBetter = false;
                graphSeek(graph, level, currentEp);
                int friendOrd;
                // The loop must stop on the sentinel; otherwise it never terminates.
                while ((friendOrd = graphNextNeighbor(graph)) != NO_MORE_NEIGHBORS) {
                    assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
                    if (visited.getAndSet(friendOrd)) {
                        continue;
                    }
                    if (collector.earlyTerminated()) {
                        return -1;
                    }
                    float friendScore = scorer.score(friendOrd);
                    collector.incVisitedCount(1);
                    if (friendScore > currentScore) {
                        currentScore = friendScore;
                        currentEp = friendOrd;
                        foundBetter = true;
                    }
                }
            }
        }
        return collector.earlyTerminated() ? -1 : currentEp;
    }

    /**
     * Best-first search of one graph level.
     *
     * <p>Entry points are scored and seeded into the candidate queue. Candidates are then expanded
     * while the best remaining one scores at least the collector's minimum competitive similarity.
     * A neighbor is queued only if it scores strictly above that bound, and it is passed to the
     * collector only if it also passes {@code acceptOrds}.
     *
     * <p>The JADX decompiler widened this method from package-private to public. It is kept public
     * for compatibility.
     *
     * @param results collector receiving accepted nodes
     * @param scorer scores graph nodes against the query
     * @param level the graph level to search
     * @param eps entry points to start from
     * @param graph the graph to search
     * @param acceptOrds if non-null, only nodes whose bit is set are collected
     * @throws IOException if reading the graph or scoring fails
     */
    public void searchLevel(
            KnnCollector results,
            RandomVectorScorer scorer,
            int level,
            int[] eps,
            HnswGraph graph,
            Bits acceptOrds)
            throws IOException {
        int size = getGraphSize(graph);
        prepareScratchState(size);

        for (int ep : eps) {
            if (visited.getAndSet(ep)) {
                continue;
            }
            if (results.earlyTerminated()) {
                break;
            }
            float score = scorer.score(ep);
            results.incVisitedCount(1);
            candidates.add(ep, score);
            if (acceptOrds == null || acceptOrds.get(ep)) {
                results.collect(ep, score);
            }
        }

        // Lower bound a node's score must beat to be worth exploring; refreshed whenever the
        // collector reports that a collect() changed its state.
        float minAcceptedSimilarity = results.minCompetitiveSimilarity();
        while (candidates.size() > 0
                && !results.earlyTerminated()
                && candidates.topScore() >= minAcceptedSimilarity) {
            graphSeek(graph, level, candidates.pop());
            int friendOrd;
            while ((friendOrd = graphNextNeighbor(graph)) != NO_MORE_NEIGHBORS) {
                assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
                if (visited.getAndSet(friendOrd)) {
                    continue;
                }
                if (results.earlyTerminated()) {
                    break;
                }
                float friendSimilarity = scorer.score(friendOrd);
                results.incVisitedCount(1);
                if (friendSimilarity > minAcceptedSimilarity) {
                    candidates.add(friendOrd, friendSimilarity);
                    if (acceptOrds == null || acceptOrds.get(friendOrd)) {
                        if (results.collect(friendOrd, friendSimilarity)) {
                            minAcceptedSimilarity = results.minCompetitiveSimilarity();
                        }
                    }
                }
            }
        }
    }

    /**
     * Clears the candidate queue and the visited set, growing the visited set first if it is
     * shorter than {@code capacity}.
     */
    private void prepareScratchState(int capacity) {
        candidates.clear();
        if (visited.length() < capacity) {
            // Only a FixedBitSet can be grown here; any other BitSet would fail the cast below.
            assert visited instanceof FixedBitSet
                    : "Unexpected visited type: " + visited.getClass().getName();
            visited = FixedBitSet.ensureCapacity((FixedBitSet) visited, capacity);
        }
        visited.clear();
    }

    /** Positions neighbor iteration on {@code node} at {@code level}; overridden for on-heap graphs. */
    void graphSeek(HnswGraph graph, int level, int node) throws IOException {
        graph.seek(level, node);
    }

    /**
     * Returns the next neighbor of the node last passed to {@link #graphSeek}, or {@code
     * Integer.MAX_VALUE} when exhausted.
     */
    int graphNextNeighbor(HnswGraph graph) throws IOException {
        return graph.nextNeighbor();
    }

    /** Number of node ids the visited set must be able to hold: {@code maxNodeId() + 1}. */
    private static int getGraphSize(HnswGraph graph) {
        return graph.maxNodeId() + 1;
    }
}
