Skip to content

Commit

Permalink
[feature][dingo-executor] Support for binary vector indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
githubgxll authored and ketor committed Jan 23, 2025
1 parent 04037f7 commit 75d37b5
Show file tree
Hide file tree
Showing 33 changed files with 766 additions and 159 deletions.
5 changes: 4 additions & 1 deletion dingo-calcite/src/main/codegen/templates/Parser.jj
Original file line number Diff line number Diff line change
Expand Up @@ -3990,7 +3990,8 @@ List<Object> Expression2(ExprContext exprContext) :
{
if ((op == SqlUserDefinedOperators.L2_DISTANCE
|| op == SqlUserDefinedOperators.IP_DISTANCE
|| op == SqlUserDefinedOperators.COSINE_SIMILARITY)
|| op == SqlUserDefinedOperators.COSINE_SIMILARITY
|| op == SqlUserDefinedOperators.HAMMING_DISTANCE)
&& list.size() == 3) {
SqlNode call = op.createCall(s.end(this), (SqlNode)list.get(0), (SqlNode)list.get(2));
list.clear();
Expand Down Expand Up @@ -7851,6 +7852,7 @@ SqlBinaryOperator BinaryRowOperator() :
| <L2DISTANCE> { return SqlUserDefinedOperators.L2_DISTANCE; }
| <COSINESIMILARITY> { return SqlUserDefinedOperators.COSINE_SIMILARITY; }
| <IPDISTANCE> { return SqlUserDefinedOperators.IP_DISTANCE; }
| <HAMMINGDISTANCE> { return SqlUserDefinedOperators.HAMMING_DISTANCE; }
}

/**
Expand Down Expand Up @@ -8833,6 +8835,7 @@ void NonReservedKeyWord2of3() :
| < L2DISTANCE: "<->" >
| < COSINESIMILARITY: "<=>" >
| < IPDISTANCE: "<*>" >
| < HAMMINGDISTANCE: "<HD>">
| < DOUBLE_PERIOD: ".." >
| < QUOTE: "'" >
| < DOUBLE_QUOTE: "\"" >
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2369,9 +2369,10 @@ private static IndexDefinition fromSqlIndexDeclaration(
throw new RuntimeException("Column must be not null, column name: " + columnName);
}
} else if (i == 1) {
if (!columnDefinition.getTypeName().equals("ARRAY")
if (!columnDefinition.getTypeName().equals("BINARY")
&& (!columnDefinition.getTypeName().equals("ARRAY")
|| !(columnDefinition.getElementType() != null
&& columnDefinition.getElementType().equals("FLOAT"))) {
&& columnDefinition.getElementType().equals("FLOAT")))) {
throw new RuntimeException("Invalid column type: " + columnName);
}
if (columnDefinition.isNullable()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,12 @@ private String getCreateTable() {
case VECTOR_DISKANN:
type = "DISKANN";
break;
case VECTOR_BINARY_FLAT:
type = "BINARY_FLAT";
break;
case VECTOR_BINARY_IVF_FLAT:
type = "BINARY_IVF_FLAT";
break;
default:
type = "HNSW";
break;
Expand All @@ -234,6 +240,9 @@ private String getCreateTable() {
case "METRIC_TYPE_INNER_PRODUCT":
val = "INNER_PRODUCT";
break;
case "METRIC_TYPE_HAMMING" :
val = "HAMMING";
break;
}
}
createTableSqlStr.append(key).append("=").append(val).append(",");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import io.dingodb.exec.fun.special.ThrowFun;
import io.dingodb.exec.fun.vector.VectorCosineDistanceFun;
import io.dingodb.exec.fun.vector.VectorDistanceFun;
import io.dingodb.exec.fun.vector.VectorHammingDistanceFun;
import io.dingodb.exec.fun.vector.VectorIPDistanceFun;
import io.dingodb.exec.fun.vector.VectorImageFun;
import io.dingodb.exec.fun.vector.VectorL2DistanceFun;
Expand Down Expand Up @@ -336,6 +337,13 @@ private void init() {
family(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY),
SqlFunctionCategory.NUMERIC
);
registerFunction(
VectorHammingDistanceFun.NAME,
FLOAT,
DingoInferTypes.FLOAT,
family(SqlTypeFamily.BINARY, SqlTypeFamily.BINARY),
SqlFunctionCategory.NUMERIC
);
registerFunction(
VersionFun.NAME,
ReturnTypes.VARCHAR_2000,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@

import io.dingodb.common.table.DiskAnnTable;
import io.dingodb.exec.fun.vector.VectorCosineDistanceFun;
import io.dingodb.exec.fun.vector.VectorHammingDistanceFun;
import io.dingodb.exec.fun.vector.VectorIPDistanceFun;
import io.dingodb.exec.fun.vector.VectorL2DistanceFun;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql2rel.SqlCosineSimilarityOperator;
import org.apache.calcite.sql2rel.SqlDiskAnnOperator;
import org.apache.calcite.sql2rel.SqlDocumentOperator;
import org.apache.calcite.sql2rel.SqlFunctionScanOperator;
import org.apache.calcite.sql2rel.SqlHammingDistanceOperator;
import org.apache.calcite.sql2rel.SqlHybridSearchOperator;
import org.apache.calcite.sql2rel.SqlIPDistanceOperator;
import org.apache.calcite.sql2rel.SqlL2DistanceOperator;
Expand Down Expand Up @@ -68,6 +70,7 @@ public class SqlUserDefinedOperators {
public static SqlL2DistanceOperator L2_DISTANCE
= new SqlL2DistanceOperator(VectorL2DistanceFun.NAME, SqlKind.OTHER_FUNCTION);


public static SqlHammingDistanceOperator HAMMING_DISTANCE
= new SqlHammingDistanceOperator(VectorHammingDistanceFun.NAME, SqlKind.OTHER_FUNCTION);

}
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,9 @@ public static SqlOperator findSqlOperator(String metricType) {
case "COSINE":
metricTypeFullName = "cosineDistance";
break;
case "HAMMING":
metricTypeFullName = "hammingDistance";
break;
default:
metricTypeFullName = null;
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ public void onMatch(@NonNull RelOptRuleCall call) {
public Void visitCall(RexCall call) {
if (call.op.getName().equalsIgnoreCase(SqlUserDefinedOperators.COSINE_SIMILARITY.getName())
|| call.op.getName().equalsIgnoreCase(SqlUserDefinedOperators.IP_DISTANCE.getName())
|| call.op.getName().equalsIgnoreCase(SqlUserDefinedOperators.L2_DISTANCE.getName())) {
|| call.op.getName().equalsIgnoreCase(SqlUserDefinedOperators.L2_DISTANCE.getName())
|| call.op.getName().equalsIgnoreCase(SqlUserDefinedOperators.HAMMING_DISTANCE.getName())) {
vectorSelected.add(call);
}
return super.visitCall(call);
Expand Down Expand Up @@ -168,7 +169,8 @@ private static LogicalProject getPostVectorFiltering(LogicalProject project,
String opName = rexCall.op.getName();
if (opName.equalsIgnoreCase(SqlUserDefinedOperators.COSINE_SIMILARITY.getName())
|| opName.equalsIgnoreCase(SqlUserDefinedOperators.IP_DISTANCE.getName())
|| opName.equalsIgnoreCase(SqlUserDefinedOperators.L2_DISTANCE.getName())) {
|| opName.equalsIgnoreCase(SqlUserDefinedOperators.L2_DISTANCE.getName())
|| opName.equalsIgnoreCase(SqlUserDefinedOperators.HAMMING_DISTANCE.getName())) {
rn.set(true);
return inputRef;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import io.dingodb.common.util.Pair;
import io.dingodb.meta.entity.Column;
import io.dingodb.meta.entity.IndexTable;
import io.dingodb.meta.entity.IndexType;
import io.dingodb.meta.entity.Table;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.plan.RelOptRuleCall;
Expand All @@ -57,6 +58,7 @@
import static io.dingodb.calcite.rule.DingoGetByIndexRule.filterIndices;
import static io.dingodb.calcite.rule.DingoGetByIndexRule.filterScalarIndices;
import static io.dingodb.calcite.rule.DingoGetByIndexRule.getScalaIndices;
import static io.dingodb.calcite.utils.VectorUtils.parseBinaryStringToByteArray;
import static io.dingodb.calcite.visitor.function.DingoGetVectorByDistanceVisitFun.getTargetVector;

@Slf4j
Expand Down Expand Up @@ -90,10 +92,18 @@ public static RelNode getDingoGetVectorByDistance(RexNode condition, LogicalDing
if (condition != null) {
dispatchDistanceCondition(condition, selection, dingoTable);
}
int dimension;
if (((IndexTable) vector.getIndexTable()).getIndexType() == IndexType.VECTOR_BINARY_FLAT ||
((IndexTable) vector.getIndexTable()).getIndexType() == IndexType.VECTOR_BINARY_IVF_FLAT) {
byte[] binaryVector = parseBinaryStringToByteArray(vector.getOperands());
dimension = binaryVector.length;
} else {
List<Float> targetVector = getTargetVector(vector.getOperands());
dimension = targetVector.size();
}

List<Float> targetVector = getTargetVector(vector.getOperands());
// if filter matched point get by primary key, then DingoGetByKeys priority highest
Pair<Integer, Integer> vectorIdPair = getVectorIndex(dingoTable, targetVector.size());
Pair<Integer, Integer> vectorIdPair = getVectorIndex(dingoTable, dimension);
assert vectorIdPair != null;
RelTraitSet traitSet = vector.getTraitSet().replace(DingoRelStreaming.of(vector.getTable()));
boolean preFilter = vector.getHints() != null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ public void onMatch(RelOptRuleCall call) {
public Void visitCall(RexCall call) {
if (call.op.getName().equalsIgnoreCase(SqlUserDefinedOperators.COSINE_SIMILARITY.getName())
|| call.op.getName().equalsIgnoreCase(SqlUserDefinedOperators.IP_DISTANCE.getName())
|| call.op.getName().equalsIgnoreCase(SqlUserDefinedOperators.L2_DISTANCE.getName())) {
|| call.op.getName().equalsIgnoreCase(SqlUserDefinedOperators.L2_DISTANCE.getName())
|| call.op.getName().equalsIgnoreCase(SqlUserDefinedOperators.HAMMING_DISTANCE.getName())) {
vectorSelected.add(call);
}
return super.visitCall(call);
Expand Down
141 changes: 141 additions & 0 deletions dingo-calcite/src/main/java/io/dingodb/calcite/utils/VectorUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright 2021 DataCanvas
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.dingodb.calcite.utils;

import io.dingodb.exec.fun.vector.VectorImageFun;
import io.dingodb.exec.fun.vector.VectorTextFun;
import io.dingodb.exec.restful.VectorExtract;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlCharStringLiteral;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNumericLiteral;
import org.apache.calcite.sql.fun.SqlArrayValueConstructor;
import org.apache.calcite.util.NlsString;

import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

public final class VectorUtils {

private VectorUtils(){}

private static final int MULTIPLE = 8;

public static byte[] parseBinaryStringToBinary(List<Object> operandsList) {
if (operandsList.get(2) instanceof SqlCharStringLiteral) {
String binaryString = ((NlsString)((SqlCharStringLiteral) operandsList.get(2)).getValue()).getValue();
int length = binaryString.length();
int segmentLength = length / MULTIPLE;
byte[] byteArray = new byte[segmentLength];
for (int i = 0; i < segmentLength; i++) {
String segment = binaryString.substring(i * MULTIPLE, (i + 1) * MULTIPLE);
byteArray[i] = (byte) Integer.parseInt(segment, 2);
}
return byteArray;
}
throw new RuntimeException("vector load binary string param error");
}

public static byte[] parseBinaryStringToByteArray(List<Object> operandsList) {
if (operandsList.get(2) instanceof SqlCharStringLiteral) {
String binaryString = ((NlsString)((SqlCharStringLiteral) operandsList.get(2)).getValue()).getValue();
int length = binaryString.length();
byte[] byteArray = new byte[length];
for (int i = 0; i < length; i++) {
char c = binaryString.charAt(i);
byteArray[i] = (byte) c;
}
return byteArray;
}
throw new RuntimeException("vector load binary string param error");
}

public static Float[] getVectorFloats(List<Object> operandsList) {
Float[] floatArray = null;
Object call = operandsList.get(2);
if (call instanceof RexCall) {
RexCall rexCall = (RexCall) call;
floatArray = new Float[rexCall.getOperands().size()];
int vectorDimension = rexCall.getOperands().size();
for (int i = 0; i < vectorDimension; i++) {
RexLiteral literal = (RexLiteral) rexCall.getOperands().get(i);
floatArray[i] = literal.getValueAs(Float.class);
}
return floatArray;
}
SqlBasicCall basicCall = (SqlBasicCall) operandsList.get(2);
if (basicCall.getOperator() instanceof SqlArrayValueConstructor) {
List<SqlNode> operands = basicCall.getOperandList();
floatArray = new Float[operands.size()];
for (int i = 0; i < operands.size(); i++) {
floatArray[i] = (
(Number) Objects.requireNonNull(((SqlNumericLiteral) operands.get(i)).getValue())
).floatValue();
}
} else {
List<SqlNode> sqlNodes = basicCall.getOperandList();
if (sqlNodes.size() < 2) {
throw new RuntimeException("vector load param error");
}
List<Object> paramList = sqlNodes.stream().map(e -> {
if (e instanceof SqlLiteral) {
return ((SqlLiteral)e).getValue();
} else if (e instanceof SqlIdentifier) {
return ((SqlIdentifier)e).getSimple();
} else {
return e.toString();
}
}).collect(Collectors.toList());
if (paramList.get(1) == null || paramList.get(0) == null) {
throw new RuntimeException("vector load param error");
}
String param = paramList.get(1).toString();
if (param.contains("'")) {
param = param.replace("'", "");
}
String funcName = basicCall.getOperator().getName();
if (funcName.equalsIgnoreCase(VectorTextFun.NAME)) {
floatArray = VectorExtract.getTxtVector(
basicCall.getOperator().getName(),
paramList.get(0).toString(),
param);
} else if (funcName.equalsIgnoreCase(VectorImageFun.NAME)) {
if (paramList.size() < 3) {
throw new RuntimeException("vector load param error");
}
Object localPath = paramList.get(2);
if (!(localPath instanceof Boolean)) {
throw new RuntimeException("vector load param error");
}
floatArray = VectorExtract.getImgVector(
basicCall.getOperator().getName(),
paramList.get(0).toString(),
paramList.get(1),
(Boolean) paramList.get(2));
}
}
if (floatArray == null) {
throw new RuntimeException("vector load error");
}
return floatArray;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@
import java.util.function.Supplier;

import static io.dingodb.calcite.rel.DingoRel.dingo;
import static io.dingodb.calcite.utils.VectorUtils.getVectorFloats;
import static io.dingodb.calcite.utils.VectorUtils.parseBinaryStringToByteArray;
import static io.dingodb.calcite.visitor.function.DingoVectorVisitFun.getTopkParam;
import static io.dingodb.calcite.visitor.function.DingoVectorVisitFun.getVectorFloats;
import static io.dingodb.exec.utils.OperatorCodeUtils.VECTOR_POINT_DISTANCE;

public final class DingoGetVectorByDistanceVisitFun {
Expand All @@ -67,17 +68,27 @@ static class OperatorSupplier implements Supplier<Vertex> {
@Override
public Vertex get() {
DingoRelOptTable dingoRelOptTable = (DingoRelOptTable) rel.getTable();
List<Float> targetVector = getTargetVector(rel.getOperands());
IndexTable indexTable = getVectorIndexTable(dingoRelOptTable, targetVector.size());
int dimension;
List<Float> targetVector = null;
byte[] binaryVector = null;
boolean isBinaryVector = false;
if (((IndexTable) rel.getIndexTable()).getIndexType() == IndexType.VECTOR_BINARY_FLAT ||
((IndexTable) rel.getIndexTable()).getIndexType() == IndexType.VECTOR_BINARY_IVF_FLAT) {
binaryVector = getBinaryVector(rel.getOperands());
dimension = binaryVector.length;
isBinaryVector = true;
} else {
targetVector = getTargetVector(rel.getOperands());
dimension = targetVector.size();
}
IndexTable indexTable = getVectorIndexTable(dingoRelOptTable, dimension);
if (indexTable == null) {
throw new RuntimeException("not found vector index");
}
MetaService metaService = MetaService.root().getSubMetaService(dingoRelOptTable.getSchemaName());
NavigableMap<ByteArrayUtils.ComparableByteArray, RangeDistribution> distributions
= metaService.getRangeDistribution(rel.getIndexTableId());

int dimension = Integer.parseInt(indexTable.getProperties()
.getOrDefault("dimension", targetVector.size()).toString());
String algType;
if (indexTable.indexType == IndexType.VECTOR_FLAT) {
algType = "FLAT";
Expand All @@ -94,7 +105,9 @@ public Vertex get() {
distributions.firstEntry().getValue(),
rel.getVectorIndex(),
rel.getIndexTableId(),
isBinaryVector,
targetVector,
binaryVector,
dimension,
algType,
indexTable.getProperties().getProperty("metricType"),
Expand All @@ -106,6 +119,10 @@ public Vertex get() {
}
}

public static byte[] getBinaryVector(List<Object> operandList) {
return parseBinaryStringToByteArray(operandList);
}

public static List<Float> getTargetVector(List<Object> operandList) {
Float[] vector = getVectorFloats(operandList);
return Arrays.asList(vector);
Expand Down
Loading

0 comments on commit 75d37b5

Please # to comment.