Skip to content

Commit

Permalink
wip - rollup
Browse files Browse the repository at this point in the history
  • Loading branch information
vga91 committed Apr 8, 2024
1 parent 92d515d commit 62bcf27
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 20 deletions.
100 changes: 84 additions & 16 deletions extended/src/main/java/apoc/agg/MultiStats.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package apoc.agg;

import apoc.Extended;
import org.jetbrains.annotations.NotNull;
import org.neo4j.graphdb.Entity;
import org.neo4j.kernel.impl.util.ValueUtils;
import org.neo4j.procedure.Description;
Expand All @@ -20,6 +21,68 @@
@Extended
public class MultiStats {


/**
 * Factory for the aggregator registered under the name {@code apoc.agg.rollup}.
 *
 * @return a new {@link RollupFunction} instance on every call
 */
@UserAggregationFunction("apoc.agg.rollup")
@Description("Return a multi-dimensional aggregation")
public RollupFunction rollup() {
    final RollupFunction aggregator = new RollupFunction();
    return aggregator;
}

/**
 * Work-in-progress aggregator behind {@code apoc.agg.rollup}.
 *
 * <p>State is kept in the {@code result} map. The class currently has only an update
 * step; no result method is defined here, and the update logic below is unfinished.
 */
public static class RollupFunction {
// Only referenced from a commented-out line below. Presumably the key for the
// rolled-up (subtotal) bucket, mirroring the NULL of SQL "WITH ROLLUP" — TODO confirm.
private static final String NULL_ROLLUP = "NULL";
private final Map<String, Object> result = new HashMap<>();
// private final Map<String, Map<String, Map<String, NumberValue>>> result = new HashMap<>();

/**
 * Aggregation update step.
 *
 * @param value     the item to aggregate; it is cast to {@link Entity} without a type check
 * @param groupKeys property names to group by; an empty list makes this call a no-op
 * @param aggKeys   not read by the current implementation
 */
@UserAggregationUpdate
public void aggregate(
@Name("value") Object value,
@Name(value = "groupKeys") List<String> groupKeys,
@Name(value = "aggKeys") List<String> aggKeys) {
// Unchecked cast: a non-Entity value fails here with ClassCastException.
Entity entity = (Entity) value;

// Nothing to group by.
if (groupKeys.isEmpty()) {
return;
}

// NOTE(review): this returns when the entity DOES have the first group key, so only
// entities lacking it reach the code below. The condition looks inverted — confirm.
if (entity.hasProperty(groupKeys.get(0))) {
return;
}

// NOTE(review): unfinished and not compilable as written — neither lambda returns a value.
// Also to resolve before this is usable:
//  - groupKeys.get(1) is read without checking that the list has a second element;
//  - the inner compute modifies the same map the outer compute is updating, which the
//    Map.compute contract forbids for the remapping function;
//  - the map is keyed by property names here, not by the entity's property values.
result.compute(groupKeys.get(0), (i, v) -> {
result.compute(groupKeys.get(1), (i2, v2) -> {

});
});


// result.compute(NULL_ROLLUP, ()


// first compute
// inner compute
//
// second compute
// third compute
// `NULL`
}
}


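/*
Reference: MySQL session showing GROUP BY ... WITH ROLLUP on the Northwind `Products` table, presumably kept
as a sketch of the result shape apoc.agg.rollup should produce (to be confirmed).
The first statement fails because `Unit` is neither grouped nor aggregated; the second one succeeds.
Rows with CategoryID = NULL are the per-SupplierID subtotals added by ROLLUP. The table below is truncated.

mysql> SELECT SupplierID, CategoryID, sum(Price), avg(Price), Unit FROM Products GROUP BY SupplierID, CategoryID WITH ROLLUP;
ERROR 1055 (42000): Expression #5 of SELECT list is not in GROUP BY clause and contains nonaggregated column 'Northwind.Products.Unit' which is not functionally dependent on columns in GROUP BY clause; this is incompatible with sql_mode=only_full_group_by
mysql> SELECT SupplierID, CategoryID, sum(Price), avg(Price) FROM Products GROUP BY SupplierID, CategoryID WITH ROLLUP;
+------------+------------+------------+------------+
| SupplierID | CategoryID | sum(Price) | avg(Price) |
+------------+------------+------------+------------+
| 1 | 1 | 37 | 18.5000 |
| 1 | 2 | 10 | 10.0000 |
| 1 | NULL | 47 | 15.6667 |
| 2 | 2 | 81 | 20.2500 |
| 2 | NULL | 81 | 20.2500 |
*/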


@UserAggregationFunction("apoc.agg.multiStats")
@Description("Return a multi-dimensional aggregation")
public MultiStatsFunction multiStats() {
Expand All @@ -46,22 +109,7 @@ public void aggregate(

map.compute(property.toString(), (propKey, propVal) -> {

Map<String, NumberValue> propMap = Objects.requireNonNullElseGet(propVal, HashMap::new);

NumberValue count = propMap.compute("count",
((subKey, subVal) -> (NumberValue) ValueUtils.of(subVal == null ? 1 : subVal.longValue() + 1)) );

AnyValue neo4jValue = ValueUtils.of(property);

if (neo4jValue instanceof NumberValue numberValue) {
NumberValue sum = propMap.compute("sum",
((subKey, subVal) -> subVal == null ? numberValue : ValueMath.overflowSafeAdd(subVal, numberValue)));

propMap.compute("avg",
((subKey, subVal) -> subVal == null ? ValueUtils.asDoubleValue(numberValue.doubleValue()) : sum.dividedBy(count.doubleValue()) ));
}

return propMap;
return getStringNumberValueMap(property, propVal);
});

return map;
Expand All @@ -77,4 +125,24 @@ public Map<String, Map<String, Map<String, NumberValue>>> result() {
return result;
}
}


/**
 * Updates the running statistics kept for a single property value.
 *
 * <p>The {@code "count"} entry is incremented on every call. When the value converts to a
 * {@link NumberValue}, the {@code "sum"} and {@code "avg"} entries are updated as well;
 * otherwise they are left untouched.
 *
 * @param property the raw property value being aggregated
 * @param propVal  the statistics collected so far for this value, or {@code null} if none exist yet
 * @return {@code propVal} after the update, or a newly created {@link HashMap} when {@code propVal} was {@code null}
 */
private static Map<String, NumberValue> getStringNumberValueMap(Object property, Map<String, NumberValue> propVal) {
    final Map<String, NumberValue> stats = propVal != null ? propVal : new HashMap<>();

    // Occurrence counter: 1 on first sight, previous + 1 afterwards.
    final NumberValue occurrences = stats.compute("count", (key, previous) -> {
        final long next = previous == null ? 1 : previous.longValue() + 1;
        return (NumberValue) ValueUtils.of(next);
    });

    // Non-numeric values only contribute to the counter.
    if (!(ValueUtils.of(property) instanceof NumberValue number)) {
        return stats;
    }

    final NumberValue total = stats.compute("sum", (key, previous) -> {
        if (previous == null) {
            return number;
        }
        return ValueMath.overflowSafeAdd(previous, number);
    });

    // The first numeric value seeds the average; later ones recompute it as sum / count.
    stats.compute("avg", (key, previous) -> {
        if (previous == null) {
            return ValueUtils.asDoubleValue(number.doubleValue());
        }
        return total.dividedBy(occurrences.doubleValue());
    });

    return stats;
}
}
24 changes: 20 additions & 4 deletions extended/src/test/java/apoc/agg/MultiStatsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public static void tearDown() {
// similar to https://community.neo4j.com/t/listing-the-community-size-of-different-community-detection-algorithms-already-calculated/42895
@Test
public void testMultiStatsComparedWithCypherMultiAggregation() {
List multiAggregationResult = db.executeTransactionally("""
String multiAggregationResult = db.executeTransactionally("""
MATCH (p:Person)
WITH p
CALL {
Expand Down Expand Up @@ -72,9 +72,24 @@ RETURN sum(p.lpa) AS sumLpa, avg(p.lpa) AS avgLpa, count(p.lpa) AS countLpa
sumWcc, avgWcc, countWcc,
sumAnother, avgAnother, countAnother,
sumLpa, avgLpa, countLpa""", Map.of(),
Iterators::asList);
result -> result.resultAsString());

List multiStatsResult = db.executeTransactionally("""
/*
[ {key1: val1, key2: val2, key2: val3, <AGGR>} ]
*/


/*
- row 1
- row 2
----
-
*/

String multiStatsResult = db.executeTransactionally("""
match (p:Person)
with apoc.agg.multiStats(p, ["lpa","wcc","louvain", "another"]) as data
match (p:Person)
Expand All @@ -91,8 +106,9 @@ RETURN sum(p.lpa) AS sumLpa, avg(p.lpa) AS avgLpa, count(p.lpa) AS countLpa
data.wcc[toString(p.wcc)].sum AS sumWcc,
data.louvain[toString(p.louvain)].sum AS sumLouvain,
data.lpa[toString(p.lpa)].sum AS sumLpa
""", Map.of(), Iterators::asList);
""", Map.of(), r -> r.resultAsString());

System.out.println("multiStatsResult = \n" + multiStatsResult);
assertEquals(multiAggregationResult, multiStatsResult);

}
Expand Down

0 comments on commit 62bcf27

Please sign in to comment.