/**
 * Adds an aggregation whose input rows are filtered by {@code mask}.
 *
 * <p>The mask is wrapped with {@link Optional#of}, so it must be non-null.
 */
public AggregationBuilder addAggregation(Symbol output, Expression expression, List<Type> inputTypes, Symbol mask) {
    Optional<Symbol> maskSymbol = Optional.of(mask);
    return addAggregation(output, expression, inputTypes, maskSymbol);
}
/**
 * Adds an aggregation without a mask; it is forwarded with an empty mask.
 */
public AggregationBuilder addAggregation(Symbol output, Expression expression, List<Type> inputTypes) {
    Optional<Symbol> noMask = Optional.empty();
    return addAggregation(output, expression, inputTypes, noMask);
}
/**
 * Builds a projection over a grouped aggregation that computes two {@code count()} outputs
 * ({@code a} and {@code b}) grouped by {@code key}.
 *
 * <p>Only the aggregation outputs accepted by {@code projectionFilter} are kept, as identity
 * assignments, in the projection.
 */
private ProjectNode buildProjectedAggregation(PlanBuilder planBuilder, Predicate<Symbol> projectionFilter) {
    Symbol countA = planBuilder.symbol("a");
    Symbol countB = planBuilder.symbol("b");
    Symbol groupKey = planBuilder.symbol("key");
    Assignments projections = Assignments.identity(
            ImmutableList.of(countA, countB).stream()
                    .filter(projectionFilter)
                    .collect(toImmutableSet()));
    return planBuilder.project(
            projections,
            planBuilder.aggregation(agg -> agg
                    .source(planBuilder.values(groupKey))
                    .singleGroupingSet(groupKey)
                    .addAggregation(countA, planBuilder.expression("count()"), ImmutableList.of())
                    .addAggregation(countB, planBuilder.expression("count()"), ImmutableList.of())));
}
}
/**
 * LEFT join of single-column VALUES on COL1 = COL2. The aggregate above it reads and
 * groups on the outer-side column COL1 only; the rule is expected not to fire.
 */
@Test
public void testDoesNotFireWhenAggregationDoesNotHaveSymbols() {
    PushAggregationThroughOuterJoin rule = new PushAggregationThroughOuterJoin();
    Symbol outerColumn = new Symbol("COL1");
    Symbol innerColumn = new Symbol("COL2");
    tester().assertThat(rule)
            .on(plan -> plan.aggregation(builder -> builder
                    .source(plan.join(
                            JoinNode.Type.LEFT,
                            plan.values(ImmutableList.of(plan.symbol("COL1")), ImmutableList.of(expressions("10"))),
                            plan.values(ImmutableList.of(plan.symbol("COL2")), ImmutableList.of(expressions("20"))),
                            ImmutableList.of(new JoinNode.EquiJoinClause(outerColumn, innerColumn)),
                            ImmutableList.of(outerColumn, innerColumn),
                            Optional.empty(),
                            Optional.empty(),
                            Optional.empty()))
                    .addAggregation(new Symbol("SUM"), PlanBuilder.expression("sum(COL1)"), ImmutableList.of(DOUBLE))
                    .singleGroupingSet(outerColumn)))
            .doesNotFire();
}
}
/**
 * The grouping set includes COL3, which is produced only by the inner (right) side of
 * the LEFT join; the rule is expected not to fire.
 */
@Test
public void testDoesNotFireWhenGroupingOnInner() {
    PushAggregationThroughOuterJoin rule = new PushAggregationThroughOuterJoin();
    Symbol col1 = new Symbol("COL1");
    Symbol col2 = new Symbol("COL2");
    Symbol col3 = new Symbol("COL3");
    tester().assertThat(rule)
            .on(plan -> plan.aggregation(builder -> builder
                    .source(plan.join(
                            JoinNode.Type.LEFT,
                            plan.values(ImmutableList.of(plan.symbol("COL1")), ImmutableList.of(expressions("10"))),
                            plan.values(col2, col3),
                            ImmutableList.of(new JoinNode.EquiJoinClause(col1, col2)),
                            ImmutableList.of(col1, col2),
                            Optional.empty(),
                            Optional.empty(),
                            Optional.empty()))
                    .addAggregation(new Symbol("AVG"), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE))
                    .singleGroupingSet(col1, col3)))
            .doesNotFire();
}
/**
 * A global count() whose source is a plain table scan (not a single-row source) must be
 * left alone.
 */
@Test
public void testDoesNotFireOnNonNestedAggregate() {
    PruneCountAggregationOverScalar rule = new PruneCountAggregationOverScalar();
    tester().assertThat(rule)
            .on(plan -> plan.aggregation(agg -> agg
                    .globalGrouping()
                    .addAggregation(
                            plan.symbol("count_1", BigintType.BIGINT),
                            new FunctionCall(QualifiedName.of("count"), ImmutableList.of()),
                            ImmutableList.of(BigintType.BIGINT))
                    .source(plan.tableScan(ImmutableList.of(), ImmutableMap.of()))))
            .doesNotFire();
}
/**
 * Grouping on y and z (50 distinct values each) has an NDV product of 2500, which exceeds
 * the 100 source rows; the expected output row estimate is 100.
 */
@Test
public void testAggregationStatsCappedToInputRows() {
    PlanNodeStatsEstimate sourceStats = PlanNodeStatsEstimate.builder()
            .setOutputRowCount(100)
            .addSymbolStatistics(new Symbol("y"), SymbolStatsEstimate.builder().setDistinctValuesCount(50).build())
            .addSymbolStatistics(new Symbol("z"), SymbolStatsEstimate.builder().setDistinctValuesCount(50).build())
            .build();
    tester().assertStatsFor(plan -> plan
            .aggregation(agg -> agg
                    .addAggregation(plan.symbol("count_on_x", BIGINT), expression("count(x)"), ImmutableList.of(BIGINT))
                    .singleGroupingSet(plan.symbol("y", BIGINT), plan.symbol("z", BIGINT))
                    .source(plan.values(plan.symbol("x", BIGINT), plan.symbol("y", BIGINT), plan.symbol("z", BIGINT)))))
            .withSourceStats(sourceStats)
            .check(stats -> stats.outputRowsCount(100));
}
}
/**
 * A global aggregation using the {@code nowCall} expression under the function-call
 * rewriter's aggregation rule; the rule is expected not to fire.
 */
@Test
public void testAggregationExpressionNotRewritten() {
    tester().assertThat(functionCallRewriter.aggregationExpressionRewrite())
            .on(plan -> plan.aggregation(agg -> agg
                    .globalGrouping()
                    .addAggregation(plan.symbol("count_1", DateType.DATE), nowCall, ImmutableList.of())
                    .source(plan.values())))
            .doesNotFire();
}
/**
 * A SINGLE-step global count() over an EnforceSingleRow source is expected to be rewritten
 * into a VALUES node exposing {@code count_1}.
 */
@Test
public void testFiresOnCountAggregateOverEnforceSingleRow() {
    PruneCountAggregationOverScalar rule = new PruneCountAggregationOverScalar();
    tester().assertThat(rule)
            .on(plan -> plan.aggregation(agg -> agg
                    .addAggregation(
                            plan.symbol("count_1", BigintType.BIGINT),
                            new FunctionCall(QualifiedName.of("count"), ImmutableList.of()),
                            ImmutableList.of(BigintType.BIGINT))
                    .step(AggregationNode.Step.SINGLE)
                    .globalGrouping()
                    .source(plan.enforceSingleRow(plan.tableScan(ImmutableList.of(), ImmutableMap.of())))))
            .matches(values(ImmutableMap.of("count_1", 0)));
}
/**
 * A SINGLE-step global count() over a one-row VALUES source is expected to be rewritten
 * into a VALUES node exposing {@code count_1}.
 */
@Test
public void testFiresOnCountAggregateOverValues() {
    PruneCountAggregationOverScalar rule = new PruneCountAggregationOverScalar();
    tester().assertThat(rule)
            .on(plan -> plan.aggregation(agg -> agg
                    .addAggregation(
                            plan.symbol("count_1", BigintType.BIGINT),
                            new FunctionCall(QualifiedName.of("count"), ImmutableList.of()),
                            ImmutableList.of(BigintType.BIGINT))
                    .step(AggregationNode.Step.SINGLE)
                    .globalGrouping()
                    .source(plan.values(
                            ImmutableList.of(plan.symbol("orderkey")),
                            ImmutableList.of(plan.expressions("1"))))))
            .matches(values(ImmutableMap.of("count_1", 0)));
}
/**
 * The lateral subquery aggregates with a grouping key ({@code b}), so it is not a scalar
 * aggregation; the rule is expected not to fire.
 */
@Test
public void doesNotFireOnCorrelatedWithNonScalarAggregation() {
    TransformCorrelatedScalarAggregationToJoin rule =
            new TransformCorrelatedScalarAggregationToJoin(tester().getMetadata().getFunctionRegistry());
    tester().assertThat(rule)
            .on(plan -> plan.lateral(
                    ImmutableList.of(plan.symbol("corr")),
                    plan.values(plan.symbol("corr")),
                    plan.aggregation(agg -> agg
                            .source(plan.values(plan.symbol("a"), plan.symbol("b")))
                            .addAggregation(plan.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT))
                            .singleGroupingSet(plan.symbol("b")))))
            .doesNotFire();
}
/**
 * A global aggregation mixing one DISTINCT and one plain count(); the rule is expected
 * not to fire.
 */
@Test
public void testMixedDistinctAndNonDistinct() {
    SingleDistinctAggregationToGroupBy rule = new SingleDistinctAggregationToGroupBy();
    tester().assertThat(rule)
            .on(plan -> plan.aggregation(agg -> agg
                    .globalGrouping()
                    .addAggregation(plan.symbol("output1"), expression("count(DISTINCT input1)"), ImmutableList.of(BIGINT))
                    .addAggregation(plan.symbol("output2"), expression("count(input2)"), ImmutableList.of(BIGINT))
                    .source(plan.values(plan.symbol("input1"), plan.symbol("input2")))))
            .doesNotFire();
}
/**
 * Two plain (non-DISTINCT) count() aggregates; the rule is expected not to fire.
 */
@Test
public void testNoDistinct() {
    SingleDistinctAggregationToGroupBy rule = new SingleDistinctAggregationToGroupBy();
    tester().assertThat(rule)
            .on(plan -> plan.aggregation(agg -> agg
                    .globalGrouping()
                    .addAggregation(plan.symbol("output1"), expression("count(input1)"), ImmutableList.of(BIGINT))
                    .addAggregation(plan.symbol("output2"), expression("count(input2)"), ImmutableList.of(BIGINT))
                    .source(plan.values(plan.symbol("input1"), plan.symbol("input2")))))
            .doesNotFire();
}
/**
 * Two DISTINCT aggregates over different inputs; the single-distinct rule is expected
 * not to fire.
 */
@Test
public void testMultipleDistincts() {
    SingleDistinctAggregationToGroupBy rule = new SingleDistinctAggregationToGroupBy();
    tester().assertThat(rule)
            .on(plan -> plan.aggregation(agg -> agg
                    .globalGrouping()
                    .addAggregation(plan.symbol("output1"), expression("count(DISTINCT input1)"), ImmutableList.of(BIGINT))
                    .addAggregation(plan.symbol("output2"), expression("count(DISTINCT input2)"), ImmutableList.of(BIGINT))
                    .source(plan.values(plan.symbol("input1"), plan.symbol("input2")))))
            .doesNotFire();
}
/**
 * A single plain (non-DISTINCT) count() aggregate; MultipleDistinctAggregationToMarkDistinct
 * is expected not to fire.
 *
 * <p>NOTE(review): this method's name collides with another {@code testNoDistinct}, so it
 * lives in a different test class from the SingleDistinctAggregationToGroupBy suite.
 * Its single-aggregate shape matches the MultipleDistinctAggregationToMarkDistinct tests.
 * It previously asserted against SingleDistinctAggregationToGroupBy, which looks like a
 * copy-paste slip that left the intended rule untested. Confirm the enclosing class.
 */
@Test
public void testNoDistinct() {
    MultipleDistinctAggregationToMarkDistinct rule = new MultipleDistinctAggregationToMarkDistinct();
    tester().assertThat(rule)
            .on(plan -> plan.aggregation(agg -> agg
                    .globalGrouping()
                    .addAggregation(plan.symbol("output1"), expression("count(input1)"), ImmutableList.of(BIGINT))
                    .source(plan.values(plan.symbol("input1"), plan.symbol("input2")))))
            .doesNotFire();
}
/**
 * Two DISTINCT aggregates (count and sum) over the same input column; the rule is
 * expected not to fire.
 */
@Test
public void testMultipleAggregations() {
    MultipleDistinctAggregationToMarkDistinct rule = new MultipleDistinctAggregationToMarkDistinct();
    tester().assertThat(rule)
            .on(plan -> plan.aggregation(agg -> agg
                    .globalGrouping()
                    .addAggregation(plan.symbol("output1"), expression("count(DISTINCT input)"), ImmutableList.of(BIGINT))
                    .addAggregation(plan.symbol("output2"), expression("sum(DISTINCT input)"), ImmutableList.of(BIGINT))
                    .source(plan.values(plan.symbol("input")))))
            .doesNotFire();
}
/**
 * Resolves {@code expression} as a function call using {@code inputTypes}, then adds the
 * resulting aggregation under {@code output}, optionally restricted by {@code mask}.
 *
 * @param output symbol that receives the aggregation result
 * @param expression the aggregate call; must be a {@link FunctionCall}
 * @param inputTypes argument types used to resolve the function signature
 * @param mask optional mask symbol attached to the aggregation
 * @return this builder, as returned by the delegated {@code addAggregation}
 * @throws IllegalArgumentException if {@code expression} is not a {@link FunctionCall}
 */
private AggregationBuilder addAggregation(Symbol output, Expression expression, List<Type> inputTypes, Optional<Symbol> mask) {
    // Report the offending expression instead of failing with an empty message.
    checkArgument(expression instanceof FunctionCall, "aggregation expression must be a FunctionCall, but was: %s", expression);
    FunctionCall aggregation = (FunctionCall) expression;
    // The signature is resolved from the explicitly supplied input types.
    Signature signature = metadata.getFunctionRegistry().resolveFunction(aggregation.getName(), TypeSignatureProvider.fromTypes(inputTypes));
    return addAggregation(output, new Aggregation(aggregation, signature, mask));
}
/**
 * A single DISTINCT count() aggregate; the multiple-distinct rule is expected not to fire.
 */
@Test
public void testSingleDistinct() {
    MultipleDistinctAggregationToMarkDistinct rule = new MultipleDistinctAggregationToMarkDistinct();
    tester().assertThat(rule)
            .on(plan -> plan.aggregation(agg -> agg
                    .globalGrouping()
                    .addAggregation(plan.symbol("output1"), expression("count(DISTINCT input1)"), ImmutableList.of(BIGINT))
                    .source(plan.values(plan.symbol("input1"), plan.symbol("input2")))))
            .doesNotFire();
}
/**
 * A DISTINCT count() with a FILTER clause; the rule is expected not to fire.
 */
@Test
public void testDistinctWithFilter() {
    SingleDistinctAggregationToGroupBy rule = new SingleDistinctAggregationToGroupBy();
    tester().assertThat(rule)
            .on(plan -> plan.aggregation(agg -> agg
                    .globalGrouping()
                    .addAggregation(
                            plan.symbol("output"),
                            expression("count(DISTINCT input1) filter (where input2 > 0)"),
                            ImmutableList.of(BIGINT))
                    .source(plan.values(plan.symbol("input1"), plan.symbol("input2")))))
            .doesNotFire();
}
/**
 * A FINAL-step global spatial_partitioning aggregation over a VALUES source; the rule is
 * expected not to fire.
 */
@Test
public void testDoesNotFire() {
    assertRuleApplication()
            .on(plan -> plan.aggregation(agg -> agg
                    .globalGrouping()
                    .step(AggregationNode.Step.FINAL)
                    .addAggregation(
                            plan.symbol("sp"),
                            PlanBuilder.expression("spatial_partitioning(geometry, 10)"),
                            ImmutableList.of(GEOMETRY))
                    .source(plan.values(plan.symbol("geometry")))))
            .doesNotFire();
}