/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.search;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.document.LateInteractionField;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DoubleValues;
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MultiVectorSimilarity;

/**
 * A {@link DoubleValuesSource} that scores documents by comparing a query multi-vector against the
 * multi-vector stored for each document in a binary doc values field.
 *
 * <p>For every document, the field's {@link BinaryDocValues} value is decoded with
 * {@link LateInteractionField#decode} and compared to the query multi-vector by a
 * {@link MultiVectorSimilarity} (by default {@link ScoreFunction#SUM_MAX_SIM}). The
 * multi-vector similarity in turn uses a per-token {@link VectorSimilarityFunction} (by default
 * {@link VectorSimilarityFunction#COSINE}).
 */
public class LateInteractionFloatValuesSource extends DoubleValuesSource {
    private final String fieldName;
    private final float[][] queryVector;
    private final VectorSimilarityFunction vectorSimilarityFunction;
    private final MultiVectorSimilarity scoreFunction;

    /**
     * Creates a source using {@link VectorSimilarityFunction#COSINE} per token and
     * {@link ScoreFunction#SUM_MAX_SIM} to combine token similarities.
     *
     * @param fieldName name of the binary doc values field holding document multi-vectors
     * @param queryVector query multi-vector; must be non-empty with equal-length, non-empty rows
     * @throws NullPointerException if {@code fieldName} is null
     * @throws IllegalArgumentException if {@code queryVector} is null, empty or ragged
     */
    public LateInteractionFloatValuesSource(String fieldName, float[][] queryVector) {
        this(fieldName, queryVector, VectorSimilarityFunction.COSINE, ScoreFunction.SUM_MAX_SIM);
    }

    /**
     * Creates a source using the given per-token similarity and {@link ScoreFunction#SUM_MAX_SIM}.
     *
     * @param fieldName name of the binary doc values field holding document multi-vectors
     * @param queryVector query multi-vector; must be non-empty with equal-length, non-empty rows
     * @param vectorSimilarityFunction similarity applied between individual token vectors
     * @throws NullPointerException if {@code fieldName} or {@code vectorSimilarityFunction} is null
     * @throws IllegalArgumentException if {@code queryVector} is null, empty or ragged
     */
    public LateInteractionFloatValuesSource(
            String fieldName, float[][] queryVector, VectorSimilarityFunction vectorSimilarityFunction) {
        this(fieldName, queryVector, vectorSimilarityFunction, ScoreFunction.SUM_MAX_SIM);
    }

    /**
     * Creates a source with an explicit per-token similarity and multi-vector score function.
     *
     * <p>NOTE: {@code queryVector} is stored by reference, not copied. Mutating it after
     * construction changes the scores as well as {@link #equals} and {@link #hashCode}.
     *
     * @param fieldName name of the binary doc values field holding document multi-vectors
     * @param queryVector query multi-vector; must be non-empty with equal-length, non-empty rows
     * @param vectorSimilarityFunction similarity applied between individual token vectors
     * @param scoreFunction combines token similarities into a single document score
     * @throws NullPointerException if {@code fieldName}, {@code vectorSimilarityFunction} or
     *     {@code scoreFunction} is null
     * @throws IllegalArgumentException if {@code queryVector} is null, empty or ragged
     */
    public LateInteractionFloatValuesSource(
            String fieldName,
            float[][] queryVector,
            VectorSimilarityFunction vectorSimilarityFunction,
            MultiVectorSimilarity scoreFunction) {
        this.fieldName = Objects.requireNonNull(fieldName);
        this.queryVector = validateQueryVector(queryVector);
        this.vectorSimilarityFunction = Objects.requireNonNull(vectorSimilarityFunction);
        this.scoreFunction = Objects.requireNonNull(scoreFunction);
    }

    /**
     * Checks that the query multi-vector is non-null and non-empty, and that every token vector is
     * non-null with the same, non-zero length as the first one.
     *
     * @return the same array instance that was passed in
     */
    private float[][] validateQueryVector(float[][] queryVector) {
        if (queryVector == null || queryVector.length == 0) {
            throw new IllegalArgumentException("queryVector must not be null or empty");
        }
        if (queryVector[0] == null || queryVector[0].length == 0) {
            throw new IllegalArgumentException(
                    "composing token vectors in provided query vector should not be null or empty");
        }
        final int dimension = queryVector[0].length;
        for (int i = 1; i < queryVector.length; ++i) {
            if (queryVector[i] == null || queryVector[i].length != dimension) {
                throw new IllegalArgumentException(
                        "all composing token vectors in provided query vector should have the same length");
            }
        }
        return queryVector;
    }

    /**
     * Returns per-document values for the given segment. If the segment has no binary doc values
     * for the field, {@link DoubleValues#EMPTY} is returned.
     */
    @Override
    public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
        final BinaryDocValues values = ctx.reader().getBinaryDocValues(fieldName);
        if (values == null) {
            return DoubleValues.EMPTY;
        }
        // Capture the (final) fields once so the anonymous class does not need Outer.this access.
        final float[][] query = queryVector;
        final VectorSimilarityFunction similarity = vectorSimilarityFunction;
        final MultiVectorSimilarity function = scoreFunction;
        return new DoubleValues() {
            @Override
            public double doubleValue() throws IOException {
                // Decodes the current document's multi-vector on every call.
                return function.compare(query, LateInteractionField.decode(values.binaryValue()), similarity);
            }

            @Override
            public boolean advanceExact(int doc) throws IOException {
                return values.advanceExact(doc);
            }
        };
    }

    /** Scores depend only on stored vectors, so relevance scores are never needed. */
    @Override
    public boolean needsScores() {
        return false;
    }

    /** This source has nothing to rewrite; returns itself. */
    @Override
    public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException {
        return this;
    }

    @Override
    public int hashCode() {
        return Objects.hash(
                fieldName, Arrays.deepHashCode(queryVector), vectorSimilarityFunction, scoreFunction);
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        LateInteractionFloatValuesSource other = (LateInteractionFloatValuesSource) obj;
        // Objects.equals (not ==) for scoreFunction keeps equals consistent with hashCode, which
        // uses scoreFunction.hashCode(); for the built-in enum the two are equivalent.
        return Objects.equals(fieldName, other.fieldName)
                && vectorSimilarityFunction == other.vectorSimilarityFunction
                && Objects.equals(scoreFunction, other.scoreFunction)
                && Arrays.deepEquals(queryVector, other.queryVector);
    }

    @Override
    public String toString() {
        return "LateInteractionFloatValuesSource(fieldName=" + fieldName
                + " similarityFunction=" + vectorSimilarityFunction
                + " scoreFunction=" + scoreFunction.getClass()
                + " queryVector=" + Arrays.deepToString(queryVector)
                + ")";
    }

    /**
     * Always reports the values as cacheable.
     *
     * <p>NOTE(review): binary doc values may be updatable in place; consider whether this should
     * depend on the field's doc-values generation for {@code ctx} — confirm.
     */
    @Override
    public boolean isCacheable(LeafReaderContext ctx) {
        return true;
    }

    /** Built-in {@link MultiVectorSimilarity} implementations. */
    public enum ScoreFunction implements MultiVectorSimilarity {
        /**
         * Sum of maximum similarities: for every query token vector, take the maximum similarity
         * against all document token vectors, then sum those maxima over the query tokens.
         *
         * <p>A document with no token vectors scores {@link Float#MIN_VALUE}. An
         * {@link IllegalArgumentException} is thrown if a query and a document token vector differ
         * in dimension.
         */
        SUM_MAX_SIM {
            @Override
            public float compare(
                    float[][] queryVector, float[][] docVector, VectorSimilarityFunction vectorSimilarityFunction) {
                if (docVector.length == 0) {
                    return Float.MIN_VALUE;
                }
                float result = 0.0f;
                for (float[] q : queryVector) {
                    // Seed with -infinity, the identity for max. Float.MIN_VALUE is the smallest
                    // *positive* float and would clamp similarities <= 0. docVector is non-empty
                    // here, so the seed is always replaced by a real similarity.
                    float maxSim = Float.NEGATIVE_INFINITY;
                    for (float[] d : docVector) {
                        if (q.length != d.length) {
                            throw new IllegalArgumentException(
                                    "Provided multi-vectors are incompatible. Their composing token vectors should have the same dimension, got "
                                            + q.length + " != " + d.length);
                        }
                        maxSim = Float.max(maxSim, vectorSimilarityFunction.compare(q, d));
                    }
                    result += maxSim;
                }
                return result;
            }
        };
    }
}

