diff --git a/extended/src/main/java/apoc/ml/Prompt.java b/extended/src/main/java/apoc/ml/Prompt.java index e35a82c792..3743271308 100644 --- a/extended/src/main/java/apoc/ml/Prompt.java +++ b/extended/src/main/java/apoc/ml/Prompt.java @@ -274,8 +274,9 @@ public Stream query(@Name("question") String question, @Procedure public Stream schema(@Name(value = "conf", defaultValue = "{}") Map conf) throws MalformedURLException, JsonProcessingException { + String schema = loadSchema(tx, conf); String schemaExplanation = prompt("Please explain the graph database schema to me and relate it to well known concepts and domains.", - EXPLAIN_SCHEMA_PROMPT, "This database schema ", loadSchema(tx, conf), conf, List.of()); + EXPLAIN_SCHEMA_PROMPT, "This database schema ", schema, conf, List.of()); return Stream.of(new StringResult(schemaExplanation)); } diff --git a/extended/src/test/java/apoc/ml/PromptIT.java b/extended/src/test/java/apoc/ml/PromptIT.java index 3930502173..b0de35c8a3 100644 --- a/extended/src/test/java/apoc/ml/PromptIT.java +++ b/extended/src/test/java/apoc/ml/PromptIT.java @@ -13,6 +13,7 @@ import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; +import org.neo4j.graphdb.Result; import org.neo4j.graphdb.Transaction; import org.neo4j.test.rule.DbmsRule; import org.neo4j.test.rule.ImpermanentDbmsRule; @@ -154,6 +155,52 @@ public void testCypher() { }); } + /* + TODO: + the loadSchema(tx, conf) seems to produce wrong results SOMETIMES??: + + nodes: + :Movie {released: INTEGER, tagline: STRING, title: STRING} +:Discipline {year: INTEGER, title: STRING} + relationships: + null + patterns: + + + */ + @Test + public void testCypherWithSchemaExplanation() { + long numOfQueries = 4L; + + String schema = db.executeTransactionally("CALL apoc.ml.schema({apiKey: $apiKey})", + Map.of("apiKey", OPENAI_KEY), Result::resultAsString); + System.out.println("schema = " + schema); + + // todo - il risultato รจ troppo generico e forse fa vedere altre cose, + // provare con la apoc.ml.cypher + + // todo --> https://kindo.ai/blog/8-tips-tricks-for-better-results-from-your-ai-prompts + + testResult(db, """ + CALL apoc.ml.cypher($query, {count: $numOfQueries, apiKey: $apiKey}) + """, + Map.of( + "query", "Who are the actors which also directed a movie?", + "numOfQueries", numOfQueries, + "apiKey", OPENAI_KEY + ), + (r) -> { + List> list = r.stream().toList(); + Assertions.assertThat(list).hasSize((int) numOfQueries); + Assertions.assertThat(list.stream() + .map(m -> m.get("query")) + .filter(Objects::nonNull) + .map(Object::toString) + .filter(StringUtils::isNotEmpty)) + .hasSize((int) numOfQueries); + }); + } + @Test public void testFromCypher() { testCall(db, """