From 04e549abf7b51bce090ddb45b95840f7591d0726 Mon Sep 17 00:00:00 2001 From: vga91 Date: Tue, 16 Apr 2024 17:13:30 +0200 Subject: [PATCH] wip - finishing chroma procs --- .../test/java/apoc/vectordb/ChromaDbTest.java | 39 ++-- .../src/main/java/apoc/vectordb/Chroma.java | 148 ------------ .../src/main/java/apoc/vectordb/ChromaDb.java | 219 ++++++++++++++++++ .../src/main/java/apoc/vectordb/Qdrant.java | 12 +- .../src/main/java/apoc/vectordb/VectorDb.java | 20 +- .../main/java/apoc/vectordb/VectorDbUtil.java | 4 +- .../apoc/vectordb/VectorEmbeddingConfig.java | 21 +- 7 files changed, 280 insertions(+), 183 deletions(-) delete mode 100644 extended/src/main/java/apoc/vectordb/Chroma.java create mode 100644 extended/src/main/java/apoc/vectordb/ChromaDb.java diff --git a/extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java b/extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java index 7f051980b0..7a5ac64ca2 100644 --- a/extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java +++ b/extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java @@ -37,19 +37,20 @@ public class ChromaDbTest { @ClassRule public static DbmsRule db = new ImpermanentDbmsRule(); - private static ChromaDBContainer qdrant = new ChromaDBContainer("chromadb/chroma:0.4.25.dev137"); + private static ChromaDBContainer chroma = new ChromaDBContainer("chromadb/chroma:0.4.25.dev137"); public static String HOST; + private static AtomicReference collId = new AtomicReference<>(); @BeforeClass public static void setUp() throws Exception { - qdrant.start(); + chroma.start(); - HOST = "localhost:" + qdrant.getMappedPort(8000); - TestUtil.registerProcedure(db, VectorDb.class, Qdrant.class); + HOST = "localhost:" + chroma.getMappedPort(8000); + TestUtil.registerProcedure(db, VectorDb.class, ChromaDb.class); apocConfig().setProperty(APOC_IMPORT_FILE_ENABLED, true); apocConfig().setProperty(APOC_EXPORT_FILE_ENABLED, true); - AtomicReference id = new AtomicReference<>(); + testCall(db, """ CALL 
apoc.vectordb.custom({ endpoint: $endpoint, @@ -63,7 +64,7 @@ public static void setUp() throws Exception { Map.of("endpoint", "http://" + HOST + "/api/v1/collections"), r -> { Map value = (Map) r.get("value"); - id.set((String) value.get("id")); + collId.set((String) value.get("id")); }); testCall(db, """ @@ -85,7 +86,7 @@ public static void setUp() throws Exception { payload: {city: "London", foo: "two"} } ]*/ - }, method: 'POST'})""", Map.of("endpoint", "http://" + HOST + "/api/v1/collections/%s/add".formatted(id.get())), + }, method: 'POST'})""", Map.of("endpoint", "http://" + HOST + "/api/v1/collections/%s/add".formatted(collId.get())), r -> { assertEquals(true, r.get("value")); }); @@ -103,8 +104,8 @@ public void after() { @Test public void getEmbeddings() { - testResult(db, "CALL apoc.vectordb.qdrant.get($host, 'test_collection', [1]) ", - Map.of("host", HOST), + testResult(db, "CALL apoc.vectordb.chroma.get($host, $collection, [1]) ", + Map.of("host", HOST, "collection", collId.get()), r -> { System.out.println("r = " + r.next()); }); @@ -115,8 +116,8 @@ public void getEmbedding() { // String filter = System.getenv("PINECONE_FILTER"); // Assume.assumeNotNull("No PINECONE_FILTER environment configured", host); // todo -> nResults: 10, ovvero limit, come parametro opzionale - testResult(db, "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5)", - Map.of("host", HOST, /*"filter", filter, */"conf", emptyMap()), + testResult(db, "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5)", + Map.of("host", HOST, "collection", collId.get(), /*"filter", filter, */"conf", emptyMap()), r -> { System.out.println("r = " + r.next()); System.out.println("r = " + r.next()); @@ -128,8 +129,8 @@ public void getEmbeddingWithYield() { // String filter = System.getenv("PINECONE_FILTER"); // Assume.assumeNotNull("No PINECONE_FILTER environment configured", host); // todo -> nResults: 10, ovvero limit, come parametro 
opzionale - testResult(db, "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5) YIELD metadata, id", - Map.of("host", HOST, /*"filter", filter, */"conf", emptyMap()), + testResult(db, "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5) YIELD metadata, id", + Map.of("host", HOST, "collection", collId.get(), /*"filter", filter, */"conf", emptyMap()), r -> { System.out.println("r = " + r.next()); System.out.println("r = " + r.next()); @@ -144,8 +145,8 @@ public void getEmbeddingWithCreateIndex() { "prop", "myId", "id", "foo", "create", true)); - testResult(db, "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", - Map.of("host", HOST, "conf", conf), + testResult(db, "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + Map.of("host", HOST, "collection", collId.get(), "conf", conf), r -> { System.out.println("r = " + r.next()); System.out.println("r = " + r.next()); @@ -178,8 +179,8 @@ public void getEmbeddingWithCreateExistingNode() { "label", "Test", "prop", "myId", "id", "foo")); - testResult(db, "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", - Map.of("host", HOST, "conf", conf), + testResult(db, "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + Map.of("host", HOST, "collection", collId.get(), "conf", conf), r -> { System.out.println("r = " + r.next()); System.out.println("r = " + r.next()); @@ -213,8 +214,8 @@ public void getEmbeddingWithCreateRelIndex() { "prop", "myId", "id", "foo", "create", true)); - testResult(db, "CALL apoc.vectordb.qdrant.query($host, 'test_collection', [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", - Map.of("host", HOST, "conf", conf), + testResult(db, "CALL apoc.vectordb.chroma.query($host, $collection, [0.2, 0.1, 0.9, 0.7], {}, 5, $conf)", + Map.of("host", HOST, "collection", collId.get(), "conf", conf), r -> { 
System.out.println("r = " + r.next()); System.out.println("r = " + r.next()); diff --git a/extended/src/main/java/apoc/vectordb/Chroma.java b/extended/src/main/java/apoc/vectordb/Chroma.java deleted file mode 100644 index 2f6ce7c06a..0000000000 --- a/extended/src/main/java/apoc/vectordb/Chroma.java +++ /dev/null @@ -1,148 +0,0 @@ -package apoc.vectordb; - -import apoc.result.MapResult; -import apoc.util.UrlResolver; -import org.neo4j.graphdb.GraphDatabaseService; -import org.neo4j.graphdb.Transaction; -import org.neo4j.graphdb.security.URLAccessChecker; -import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; -import org.neo4j.procedure.Context; -import org.neo4j.procedure.Description; -import org.neo4j.procedure.Mode; -import org.neo4j.procedure.Name; -import org.neo4j.procedure.Procedure; - -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.stream.Stream; - -import static apoc.ml.RestAPIConfig.ENDPOINT_KEY; -import static apoc.ml.RestAPIConfig.JSON_PATH; -import static apoc.ml.RestAPIConfig.METHOD_KEY; -import static apoc.vectordb.VectorDb.getEmbeddingResultStream; -import static apoc.vectordb.VectorEmbeddingConfig.EMBEDDING_KEY; -import static apoc.vectordb.VectorEmbeddingConfig.METADATA_KEY; - -public class Chroma { - - @Context - public ProcedureCallContext procedureCallContext; - - @Context - public Transaction tx; - - @Context - public GraphDatabaseService db; - - // todo - create an enum Factory in case of others VectorDbs - // e.g. 
ChromaType.from() - public static class ChromaEmbeddingType { - - public static VectorEmbeddingConfig fromGet(Map config, ProcedureCallContext procedureCallContext, List ids) { - List fields = procedureCallContext.outputFields().toList(); - -// // "with_payload": and "with_vectors": return the metadata and vector, if true -// // therefore is the RestAPI itself that doesn't return the data if `YIELD ` has not metadata/embedding -// Map additionalBodies = Map.of("with_payload", fields.contains("metadata"), -// "with_vectors", fields.contains("embedding"), -// "ids", ids); -// -// config.putIfAbsent(EMBEDDING_KEY, "vector"); -// config.putIfAbsent(METADATA_KEY, "payload"); -// config.putIfAbsent(JSON_PATH, "result"); -// -// config.putIfAbsent(METHOD_KEY, "POST"); - - return new VectorEmbeddingConfig(config, Map.of(), additionalBodies); - } - - public static VectorEmbeddingConfig fromQuery(Map config, ProcedureCallContext procedureCallContext, - List vector, Map filter, long limit) { - List fields = procedureCallContext.outputFields().toList(); - -// // "with_payload": and "with_vectors": return the metadata and vector, if true -// // therefore is the RestAPI itself that doesn't return the data if `YIELD ` has not metadata/embedding -// Map additionalBodies = Map.of("with_payload", fields.contains("metadata"), -// "with_vectors", fields.contains("embedding"), -// "vector", vector, -// "filter", filter, -// "limit", limit); -// -// config.putIfAbsent(EMBEDDING_KEY, "vector"); -// config.putIfAbsent(METADATA_KEY, "payload"); -// config.putIfAbsent(JSON_PATH, "result"); - - return new VectorEmbeddingConfig(config, Map.of(), additionalBodies); - } - } - - @Context - public URLAccessChecker urlAccessChecker; - - @Procedure("apoc.vectordb.qdrant.create") - @Description("apoc.vectordb.qdrant.create") - public Stream create(@Name("hostOrKey") String hostOrKey, - @Name("name") String name, - @Name("similarity") String similarity, - @Name("size") String size, - @Name(value = 
"configuration", defaultValue = "{}") Map configuration) { - // todo - create collection - return null; - } - - @Procedure("apoc.vectordb.qdrant.delete") - @Description("apoc.vectordb.qdrant.delete") - public Stream delete(@Name("hostOrKey") String hostOrKey, @Name("name") String name, @Name(value = "configuration", defaultValue = "{}") Map configuration) { - // todo - delete collection - return null; - } - - @Procedure("apoc.vectordb.qdrant.upsert") - @Description("apoc.vectordb.qdrant.upsert") - public Stream upsert(@Name("hostOrKey") String hostOrKey, @Name("vectors") List> vectors, @Name(value = "configuration", defaultValue = "{}") Map configuration) { - // todo - upsert vectors - return null; - } - - @Procedure(value = "apoc.vectordb.qdrant.get", mode = Mode.SCHEMA) - @Description("apoc.vectordb.qdrant.get()") - public Stream query(@Name("hostOrKey") String hostOrKey, - @Name("collection") String collection, - @Name("ids") List ids, - @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - var config = new HashMap<>(configuration); - - String qdrantUrl = getQdrantUrl(hostOrKey); - String endpoint = "%s/api/v1/collections/%s/get".formatted(qdrantUrl, collection); - config.putIfAbsent(ENDPOINT_KEY, endpoint); - - VectorEmbeddingConfig apiConfig = Qdrant.QdrantEmbeddingType.fromGet(config, procedureCallContext, ids); - return getEmbeddingResultStream(apiConfig, procedureCallContext, urlAccessChecker, db, tx); - } - - @Procedure(value = "apoc.vectordb.qdrant.query", mode = Mode.SCHEMA) - @Description("apoc.vectordb.qdrant.query()") - public Stream query(@Name("hostOrKey") String hostOrKey, - @Name("collection") String collection, - @Name(value = "vector", defaultValue = "[]") List vector, - @Name(value = "filter", defaultValue = "{}") Map filter, - @Name(value = "limit", defaultValue = "10") long limit, - @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - - - var config = new 
HashMap<>(configuration); - - String qdrantUrl = getQdrantUrl(hostOrKey); - String endpoint = "%s/collections/%s/points/search".formatted(qdrantUrl, collection); - config.putIfAbsent(ENDPOINT_KEY, endpoint); - - VectorEmbeddingConfig apiConfig = Qdrant.QdrantEmbeddingType.fromQuery(config, procedureCallContext, vector, filter, limit); - return getEmbeddingResultStream(apiConfig, procedureCallContext, urlAccessChecker, db, tx); - } - - - protected String getQdrantUrl(String hostOrKey) { - return new UrlResolver("http", "localhost", 6333).getUrl("qdrant", hostOrKey); - } -} diff --git a/extended/src/main/java/apoc/vectordb/ChromaDb.java b/extended/src/main/java/apoc/vectordb/ChromaDb.java new file mode 100644 index 0000000000..4f89eb5339 --- /dev/null +++ b/extended/src/main/java/apoc/vectordb/ChromaDb.java @@ -0,0 +1,219 @@ +package apoc.vectordb; + +import apoc.result.MapResult; +import apoc.util.UrlResolver; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.collections4.ListUtils; +import org.jetbrains.annotations.NotNull; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.Transaction; +import org.neo4j.graphdb.security.URLAccessChecker; +import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; +import org.neo4j.procedure.Context; +import org.neo4j.procedure.Description; +import org.neo4j.procedure.Mode; +import org.neo4j.procedure.Name; +import org.neo4j.procedure.Procedure; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import static apoc.ml.RestAPIConfig.ENDPOINT_KEY; +import static apoc.util.MapUtil.map; +import static apoc.vectordb.VectorDb.getEmbeddingResultStream; +import static apoc.vectordb.VectorEmbeddingConfig.*; + +public class ChromaDb { + + @Context + public ProcedureCallContext procedureCallContext; + + @Context + public Transaction tx; + + @Context + public GraphDatabaseService db; + + // 
todo - create an enum Factory in case of others VectorDbs + // e.g. ChromaType.from() + public static class ChromaEmbeddingType { + + public static VectorEmbeddingConfig fromGet(Map config, + ProcedureCallContext procedureCallContext, + List ids) { + + List fields = procedureCallContext.outputFields().toList(); + + Map additionalBodies = map("ids", ids); + + return getVectorEmbeddingConfig(config, fields, additionalBodies); + } + + public static VectorEmbeddingConfig fromQuery(Map config, + ProcedureCallContext procedureCallContext, + List vector, + Map filter, + long limit) { + + List fields = procedureCallContext.outputFields().toList(); + + Map additionalBodies = map("query_embeddings", List.of(vector), + "where", filter, + "n_results", limit); + + return getVectorEmbeddingConfig(config, fields, additionalBodies); + } + + + private static VectorEmbeddingConfig getVectorEmbeddingConfig(Map config, + List fields, + Map additionalBodies) { + ArrayList include = new ArrayList<>(); + if (fields.contains("metadata")) { + include.add("metadatas"); + } + if (fields.contains("text")) { + include.add("documents"); + } + if (fields.contains("embedding")) { + include.add("embeddings"); + } + + additionalBodies.put("include", include); + + return new VectorEmbeddingConfig(config, Map.of(), additionalBodies); + } + } + + @Context + public URLAccessChecker urlAccessChecker; + + @Procedure("apoc.vectordb.chroma.create") + @Description("apoc.vectordb.chroma.create") + public Stream create(@Name("hostOrKey") String hostOrKey, + @Name("name") String name, + @Name("similarity") String similarity, + @Name("size") String size, + @Name(value = "configuration", defaultValue = "{}") Map configuration) { + // todo - create collection + return null; + } + + @Procedure("apoc.vectordb.chroma.delete") + @Description("apoc.vectordb.chroma.delete") + public Stream delete(@Name("hostOrKey") String hostOrKey, @Name("name") String name, @Name(value = "configuration", defaultValue = "{}") Map 
configuration) { + // todo - delete collection + return null; + } + + @Procedure("apoc.vectordb.chroma.upsert") + @Description("apoc.vectordb.chroma.upsert") + public Stream upsert(@Name("hostOrKey") String hostOrKey, @Name("vectors") List> vectors, @Name(value = "configuration", defaultValue = "{}") Map configuration) { + // todo - upsert vectors + return null; + } + + @Procedure(value = "apoc.vectordb.chroma.get", mode = Mode.SCHEMA) + @Description("apoc.vectordb.chroma.get()") + public Stream query(@Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name("ids") List ids, + @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + var config = new HashMap<>(configuration); + + String qdrantUrl = getQdrantUrl(hostOrKey); + String endpoint = "%s/api/v1/collections/%s/get".formatted(qdrantUrl, collection); + config.putIfAbsent(ENDPOINT_KEY, endpoint); + + VectorEmbeddingConfig apiConfig = ChromaEmbeddingType.fromGet(config, procedureCallContext, ids); + return getEmbeddingResultStream(apiConfig, procedureCallContext, urlAccessChecker, db, tx, + v -> getForList((Map) v).stream()); + } + + @Procedure(value = "apoc.vectordb.chroma.query", mode = Mode.SCHEMA) + @Description("apoc.vectordb.chroma.query()") + public Stream query(@Name("hostOrKey") String hostOrKey, + @Name("collection") String collection, + @Name(value = "vector", defaultValue = "[]") List vector, + @Name(value = "filter", defaultValue = "{}") Map filter, + @Name(value = "limit", defaultValue = "10") long limit, + @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + + var config = new HashMap<>(configuration); + + String qdrantUrl = getQdrantUrl(hostOrKey); + String endpoint = "%s/api/v1/collections/%s/query".formatted(qdrantUrl, collection); + config.putIfAbsent(ENDPOINT_KEY, endpoint); + + VectorEmbeddingConfig apiConfig = ChromaEmbeddingType.fromQuery(config, procedureCallContext, vector, filter, 
limit); + return getEmbeddingResultStream(apiConfig, procedureCallContext, urlAccessChecker, db, tx, + v -> queryForList((Map) v).stream()); + } + + private static List queryForList(Map s) { + List distances = s.get("distances") == null + ? null + : ((List) s.get("distances")) + .get(0); + List metadatas = s.get("metadatas") == null + ? null + : ((List) s.get("metadatas")) + .get(0); + List documents = s.get("documents") == null + ? null + : ((List) s.get("documents")) + .get(0); + + final List result = new ArrayList<>(); + List ids = ((List) s.get("ids")).get(0); + for (int i = 0; i < ids.size(); i++) { + Map map = map(DEFAULT_ID, ids.get(i)); + if (CollectionUtils.isEmpty(distances)) { + map.put(DEFAULT_EMBEDDING, distances); + } + if (CollectionUtils.isEmpty(metadatas)) { + map.put(DEFAULT_METADATA, metadatas); + } + if (CollectionUtils.isEmpty(documents)) { + map.put(DEFAULT_TEXT, documents); + } + + result.add(map); + } + + return result; + } + + private static List getForList(Map s) { + List distances = (List) s.get("distances"); + List metadatas = (List) s.get("metadatas"); + List documents = (List) s.get("documents"); + + final List result = new ArrayList<>(); + List ids = (List) s.get("ids"); + for (int i = 0; i < ids.size(); i++) { + Map map = map(DEFAULT_ID, ids.get(i)); + if (CollectionUtils.isEmpty(distances)) { + map.put(DEFAULT_EMBEDDING, distances); + } + if (CollectionUtils.isEmpty(metadatas)) { + map.put(DEFAULT_METADATA, metadatas); + } + if (CollectionUtils.isEmpty(documents)) { + map.put(DEFAULT_TEXT, documents); + } + + result.add(map); + } + + return result; + } + + + protected String getQdrantUrl(String hostOrKey) { + return new UrlResolver("http", "localhost", 6333).getUrl("qdrant", hostOrKey); + } +} diff --git a/extended/src/main/java/apoc/vectordb/Qdrant.java b/extended/src/main/java/apoc/vectordb/Qdrant.java index d8ffe2817a..3d8442ebe9 100644 --- a/extended/src/main/java/apoc/vectordb/Qdrant.java +++ 
b/extended/src/main/java/apoc/vectordb/Qdrant.java @@ -45,12 +45,10 @@ public class Qdrant { // e.g. ChromaType.from() public static class QdrantEmbeddingType { - public static VectorEmbeddingConfig fromGet(Map config, ProcedureCallContext procedureCallContext, List ids) { + public static VectorEmbeddingConfig fromGet(Map config, ProcedureCallContext procedureCallContext, List ids) { List fields = procedureCallContext.outputFields().toList(); config.putIfAbsent(METHOD_KEY, "POST"); - - // "with_payload": and "with_vectors": return the metadata and vector, if true - // therefore is the RestAPI itself that doesn't return the data if `YIELD ` has not metadata/embedding + Map additionalBodies = map("ids", ids); return getVectorEmbeddingConfig(config, fields, additionalBodies); @@ -60,8 +58,6 @@ public static VectorEmbeddingConfig fromQuery(Map config, Proced List vector, Map filter, long limit) { List fields = procedureCallContext.outputFields().toList(); - // "with_payload": and "with_vectors": return the metadata and vector, if true - // therefore is the RestAPI itself that doesn't return the data if `YIELD ` has not metadata/embedding Map additionalBodies = map("vector", vector, "filter", filter, "limit", limit); @@ -69,6 +65,8 @@ public static VectorEmbeddingConfig fromQuery(Map config, Proced return getVectorEmbeddingConfig(config, fields, additionalBodies); } + // "with_payload": and "with_vectors": return the metadata and vector, if true + // therefore is the RestAPI itself that doesn't return the data if `YIELD ` has not metadata/embedding private static VectorEmbeddingConfig getVectorEmbeddingConfig(Map config, List fields, Map additionalBodies) { additionalBodies.put("with_payload", fields.contains("metadata")); additionalBodies.put("with_vectors", fields.contains("embedding")); @@ -113,7 +111,7 @@ public Stream upsert(@Name("hostOrKey") String hostOrKey, @Name("vect @Description("apoc.vectordb.qdrant.get()") public Stream query(@Name("hostOrKey") String 
hostOrKey, @Name("collection") String collection, - @Name("ids") List ids, + @Name("ids") List ids, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { var config = new HashMap<>(configuration); diff --git a/extended/src/main/java/apoc/vectordb/VectorDb.java b/extended/src/main/java/apoc/vectordb/VectorDb.java index c4490a49a6..f30cecac4c 100644 --- a/extended/src/main/java/apoc/vectordb/VectorDb.java +++ b/extended/src/main/java/apoc/vectordb/VectorDb.java @@ -27,6 +27,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.function.Function; import java.util.stream.Stream; import static apoc.ml.RestAPIConfig.JSON_PATH; @@ -62,8 +63,21 @@ public Stream get(@Name("hostOrKey") String hostOrKey, VectorEmbeddingConfig restAPIConfig = new VectorEmbeddingConfig(configuration, Map.of(), Map.of()); return getEmbeddingResultStream(restAPIConfig, procedureCallContext, urlAccessChecker, db, tx); } + + public static Stream getEmbeddingResultStream(VectorEmbeddingConfig conf, + ProcedureCallContext procedureCallContext, + URLAccessChecker urlAccessChecker, + GraphDatabaseService db, + Transaction tx) throws Exception { + return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, db, tx, v -> ((List) v).stream()); + } - public static Stream getEmbeddingResultStream(VectorEmbeddingConfig conf, ProcedureCallContext procedureCallContext, URLAccessChecker urlAccessChecker, GraphDatabaseService db, Transaction tx) throws Exception { + public static Stream getEmbeddingResultStream(VectorEmbeddingConfig conf, + ProcedureCallContext procedureCallContext, + URLAccessChecker urlAccessChecker, + GraphDatabaseService db, + Transaction tx, + Function> objectMapper) throws Exception { List fields = procedureCallContext.outputFields().toList(); boolean hasEmbedding = fields.contains("embedding"); @@ -73,10 +87,10 @@ public static Stream getEmbeddingResultStream(VectorEmbeddingCo 
VectorMappingConfig mapping = conf.getMapping(); return resultStream - .flatMap(v -> ((List>) v).stream()) + .flatMap(objectMapper) .map(m -> { // - long id = (long) m.get(conf.getIdKey()); + Object id = (long) m.get(conf.getIdKey()); List embedding = hasEmbedding ? (List) m.get(conf.getEmbeddingKey()) : null; Map metadata = hasMetadata ? (Map) m.get(conf.getMetadataKey()) : null; // in case of get operation, e.g. http://localhost:52798/collections/{coll_name}/points with Qdrant db, diff --git a/extended/src/main/java/apoc/vectordb/VectorDbUtil.java b/extended/src/main/java/apoc/vectordb/VectorDbUtil.java index 952b2e4496..417cf03b24 100644 --- a/extended/src/main/java/apoc/vectordb/VectorDbUtil.java +++ b/extended/src/main/java/apoc/vectordb/VectorDbUtil.java @@ -6,12 +6,12 @@ public class VectorDbUtil { public static class EmbeddingResult { - public final long id; + public final Object id; public final Double score; public final List embedding; public final Map metadata; - public EmbeddingResult(long id, Double score, List embedding, Map metadata) { + public EmbeddingResult(Object id, Double score, List embedding, Map metadata) { this.id = id; this.embedding = embedding; this.score = score; diff --git a/extended/src/main/java/apoc/vectordb/VectorEmbeddingConfig.java b/extended/src/main/java/apoc/vectordb/VectorEmbeddingConfig.java index 799a7bda08..50c28d4d11 100644 --- a/extended/src/main/java/apoc/vectordb/VectorEmbeddingConfig.java +++ b/extended/src/main/java/apoc/vectordb/VectorEmbeddingConfig.java @@ -8,10 +8,18 @@ public class VectorEmbeddingConfig extends RestAPIConfig { public static final String EMBEDDING_KEY = "embeddingKey"; public static final String METADATA_KEY = "metadataKey"; public static final String SCORE_KEY = "scoreKey"; + public static final String TEXT_KEY = "textKey"; public static final String ID_KEY = "idKey"; public static final String MAPPING_KEY = "mapping"; + + public static final String DEFAULT_ID = "id"; + public static final 
String DEFAULT_TEXT = "text"; + public static final String DEFAULT_EMBEDDING = "embedding"; + public static final String DEFAULT_METADATA = "metadata"; + public static final String DEFAULT_SCORE = "score"; private final String idKey; + private final String textKey; private final String embeddingKey; private final String metadataKey; private final String scoreKey; @@ -20,10 +28,11 @@ public class VectorEmbeddingConfig extends RestAPIConfig { public VectorEmbeddingConfig(Map config, Map additionalHeaders, Map additionalBodies) { super(config, additionalHeaders, additionalBodies); - this.embeddingKey = (String) config.getOrDefault(EMBEDDING_KEY, "embedding"); - this.metadataKey = (String) config.getOrDefault(METADATA_KEY, "metadata"); - this.scoreKey = (String) config.getOrDefault(SCORE_KEY, "score"); - this.idKey = (String) config.getOrDefault(ID_KEY, "id"); + this.embeddingKey = (String) config.getOrDefault(EMBEDDING_KEY, DEFAULT_EMBEDDING); + this.metadataKey = (String) config.getOrDefault(METADATA_KEY, DEFAULT_METADATA); + this.scoreKey = (String) config.getOrDefault(SCORE_KEY, DEFAULT_SCORE); + this.idKey = (String) config.getOrDefault(ID_KEY, DEFAULT_ID); + this.textKey = (String) config.getOrDefault(TEXT_KEY, DEFAULT_TEXT); this.mapping = new VectorMappingConfig((Map) config.getOrDefault(MAPPING_KEY, Map.of()));//.getOrDefault(MAPPING_KEY, Map.of()); } @@ -43,6 +52,10 @@ public String getScoreKey() { return scoreKey; } + public String getTextKey() { + return textKey; + } + public VectorMappingConfig getMapping() { return mapping; }