diff --git a/extended/src/main/java/apoc/agg/MultiStats.java b/extended/src/main/java/apoc/agg/MultiStats.java index 097c041fd0..34901b9267 100644 --- a/extended/src/main/java/apoc/agg/MultiStats.java +++ b/extended/src/main/java/apoc/agg/MultiStats.java @@ -1,6 +1,7 @@ package apoc.agg; import apoc.Extended; +import org.jetbrains.annotations.NotNull; import org.neo4j.graphdb.Entity; import org.neo4j.kernel.impl.util.ValueUtils; import org.neo4j.procedure.Description; @@ -20,6 +21,68 @@ @Extended public class MultiStats { + + @UserAggregationFunction("apoc.agg.rollup") + @Description("Return a multi-dimensional aggregation") + public RollupFunction rollup() { + return new RollupFunction(); + } + + public static class RollupFunction { + private static final String NULL_ROLLUP = "NULL"; + private final Map result = new HashMap<>(); +// private final Map>> result = new HashMap<>(); + + @UserAggregationUpdate + public void aggregate( + @Name("value") Object value, + @Name(value = "groupKeys") List groupKeys, + @Name(value = "aggKeys") List aggKeys) { + Entity entity = (Entity) value; + + if (groupKeys.isEmpty()) { + return; + } + + if (entity.hasProperty(groupKeys.get(0))) { + return; + } + + result.compute(groupKeys.get(0), (i, v) -> { + result.compute(groupKeys.get(1), (i2, v2) -> { + + }); + }); + + +// result.compute(NULL_ROLLUP, () + + + // primo compute + // inner compute + // + // secondo compute + // terzo compute + // `NULL` + } + } + + + /* + mysql> SELECT SupplierID, CategoryID, sum(Price), avg(Price), Unit FROM Products GROUP BY SupplierID, CategoryID WITH ROLLUP; +ERROR 1055 (42000): Expression #5 of SELECT list is not in GROUP BY clause and contains nonaggregated column 'Northwind.Products.Unit' which is not functionally dependent on columns in GROUP BY clause; this is incompatible with sql_mode=only_full_group_by +mysql> SELECT SupplierID, CategoryID, sum(Price), avg(Price) FROM Products GROUP BY SupplierID, CategoryID WITH ROLLUP; ++------------+------------+------------+------------+ +| SupplierID | CategoryID | sum(Price) | avg(Price) | ++------------+------------+------------+------------+ +| 1 | 1 | 37 | 18.5000 | +| 1 | 2 | 10 | 10.0000 | +| 1 | NULL | 47 | 15.6667 | +| 2 | 2 | 81 | 20.2500 | +| 2 | NULL | 81 | 20.2500 | + */ + + @UserAggregationFunction("apoc.agg.multiStats") @Description("Return a multi-dimensional aggregation") public MultiStatsFunction multiStats() { @@ -46,22 +109,7 @@ public void aggregate( map.compute(property.toString(), (propKey, propVal) -> { - Map propMap = Objects.requireNonNullElseGet(propVal, HashMap::new); - - NumberValue count = propMap.compute("count", - ((subKey, subVal) -> (NumberValue) ValueUtils.of(subVal == null ? 1 : subVal.longValue() + 1)) ); - - AnyValue neo4jValue = ValueUtils.of(property); - - if (neo4jValue instanceof NumberValue numberValue) { - NumberValue sum = propMap.compute("sum", - ((subKey, subVal) -> subVal == null ? numberValue : ValueMath.overflowSafeAdd(subVal, numberValue))); - - propMap.compute("avg", - ((subKey, subVal) -> subVal == null ? ValueUtils.asDoubleValue(numberValue.doubleValue()) : sum.dividedBy(count.doubleValue()) )); - } - - return propMap; + return getStringNumberValueMap(property, propVal); }); return map; @@ -77,4 +125,24 @@ public Map>> result() { return result; } } + + + private static Map getStringNumberValueMap(Object property, Map propVal) { + Map propMap = Objects.requireNonNullElseGet(propVal, HashMap::new); + + NumberValue count = propMap.compute("count", + ((subKey, subVal) -> (NumberValue) ValueUtils.of(subVal == null ? 1 : subVal.longValue() + 1)) ); + + AnyValue neo4jValue = ValueUtils.of(property); + + if (neo4jValue instanceof NumberValue numberValue) { + NumberValue sum = propMap.compute("sum", + ((subKey, subVal) -> subVal == null ? numberValue : ValueMath.overflowSafeAdd(subVal, numberValue))); + + propMap.compute("avg", + ((subKey, subVal) -> subVal == null ? ValueUtils.asDoubleValue(numberValue.doubleValue()) : sum.dividedBy(count.doubleValue()) )); + } + + return propMap; + } } diff --git a/extended/src/test/java/apoc/agg/MultiStatsTest.java b/extended/src/test/java/apoc/agg/MultiStatsTest.java index c9e167a750..cd71d14674 100644 --- a/extended/src/test/java/apoc/agg/MultiStatsTest.java +++ b/extended/src/test/java/apoc/agg/MultiStatsTest.java @@ -44,7 +44,7 @@ public static void tearDown() { // similar to https://community.neo4j.com/t/listing-the-community-size-of-different-community-detection-algorithms-already-calculated/42895 @Test public void testMultiStatsComparedWithCypherMultiAggregation() { - List multiAggregationResult = db.executeTransactionally(""" + String multiAggregationResult = db.executeTransactionally(""" MATCH (p:Person) WITH p CALL { @@ -72,9 +72,24 @@ RETURN sum(p.lpa) AS sumLpa, avg(p.lpa) AS avgLpa, count(p.lpa) AS countLpa sumWcc, avgWcc, countWcc, sumAnother, avgAnother, countAnother, sumLpa, avgLpa, countLpa""", Map.of(), - Iterators::asList); + result -> result.resultAsString()); - List multiStatsResult = db.executeTransactionally(""" + /* + [ {key1: val1, key2: val2, key2: val3, } ] + */ + + + /* + - riga 1 + - riga 2 + + ---- + + - + + */ + + String multiStatsResult = db.executeTransactionally(""" match (p:Person) with apoc.agg.multiStats(p, ["lpa","wcc","louvain", "another"]) as data match (p:Person) @@ -91,8 +106,9 @@ RETURN sum(p.lpa) AS sumLpa, avg(p.lpa) AS avgLpa, count(p.lpa) AS countLpa data.wcc[toString(p.wcc)].sum AS sumWcc, data.louvain[toString(p.louvain)].sum AS sumLouvain, data.lpa[toString(p.lpa)].sum AS sumLpa - """, Map.of(), Iterators::asList); + """, Map.of(), r -> r.resultAsString()); + System.out.println("multiStatsResult = \n" + multiStatsResult); assertEquals(multiAggregationResult, multiStatsResult); }