/**
 * A global aggregation whose only aggregate is {@code count(DISTINCT input)} is expected to be
 * rewritten into a plain {@code count(input)} (SINGLE step) over an inner SINGLE aggregation
 * grouped on {@code input}.
 */
@Test
public void testSingleAggregation()
{
    SingleDistinctAggregationToGroupBy rule = new SingleDistinctAggregationToGroupBy();
    tester().assertThat(rule)
            .on(planBuilder -> planBuilder.aggregation(agg -> {
                agg.globalGrouping();
                agg.addAggregation(planBuilder.symbol("output"), expression("count(DISTINCT input)"), ImmutableList.of(BIGINT));
                agg.source(planBuilder.values(planBuilder.symbol("input")));
            }))
            .matches(
                    aggregation(
                            globalAggregation(),
                            ImmutableMap.of(
                                    Optional.of("output"),
                                    functionCall("count", ImmutableList.of("input"))),
                            ImmutableMap.of(),
                            Optional.empty(),
                            SINGLE,
                            aggregation(
                                    singleGroupingSet("input"),
                                    ImmutableMap.of(),
                                    ImmutableMap.of(),
                                    Optional.empty(),
                                    SINGLE,
                                    values("input"))));
}
/**
 * A SINGLE-step global {@code count()} whose source is an enforceSingleRow node is expected to be
 * replaced by a values node exposing {@code count_1}.
 */
@Test
public void testFiresOnCountAggregateOverEnforceSingleRow()
{
    // Argument-less count(); built once up front since the AST node is not tied to the plan builder.
    FunctionCall countCall = new FunctionCall(QualifiedName.of("count"), ImmutableList.of());
    tester().assertThat(new PruneCountAggregationOverScalar())
            .on(planBuilder -> planBuilder.aggregation(agg -> agg
                    .addAggregation(planBuilder.symbol("count_1", BigintType.BIGINT), countCall, ImmutableList.of(BigintType.BIGINT))
                    .step(AggregationNode.Step.SINGLE)
                    .globalGrouping()
                    .source(planBuilder.enforceSingleRow(
                            planBuilder.tableScan(ImmutableList.of(), ImmutableMap.of())))))
            .matches(values(ImmutableMap.of("count_1", 0)));
}
/**
 * A global {@code count()} placed directly over a table scan (no enforceSingleRow, values or nested
 * aggregation in between) must be left untouched by the rule.
 */
@Test
public void testDoesNotFireOnNonNestedAggregate()
{
    FunctionCall countCall = new FunctionCall(QualifiedName.of("count"), ImmutableList.of());
    tester().assertThat(new PruneCountAggregationOverScalar())
            .on(planBuilder -> planBuilder.aggregation(agg -> agg
                    .globalGrouping()
                    .addAggregation(
                            planBuilder.symbol("count_1", BigintType.BIGINT),
                            countCall,
                            ImmutableList.of(BigintType.BIGINT))
                    .source(planBuilder.tableScan(ImmutableList.of(), ImmutableMap.of()))))
            .doesNotFire();
}
/**
 * A lateral join whose subquery is a projection over a global {@code sum(a)} aggregation is expected to
 * become: a projection over a {@code sum_1} aggregation fed by a LEFT join between the uniquely-tagged
 * outer values and a {@code non_null}-marking projection of the subquery's values.
 */
@Test
public void rewritesOnSubqueryWithProjection()
{
    TransformCorrelatedScalarAggregationToJoin rule =
            new TransformCorrelatedScalarAggregationToJoin(tester().getMetadata().getFunctionRegistry());
    tester().assertThat(rule)
            .on(planBuilder -> planBuilder.lateral(
                    ImmutableList.of(planBuilder.symbol("corr")),
                    planBuilder.values(planBuilder.symbol("corr")),
                    planBuilder.project(
                            Assignments.of(planBuilder.symbol("expr"), planBuilder.expression("sum + 1")),
                            planBuilder.aggregation(agg -> agg
                                    .source(planBuilder.values(planBuilder.symbol("a"), planBuilder.symbol("b")))
                                    .addAggregation(planBuilder.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT))
                                    .globalGrouping()))))
            .matches(
                    project(
                            ImmutableMap.of("corr", expression("corr"), "expr", expression("(\"sum_1\" + 1)")),
                            aggregation(
                                    ImmutableMap.of("sum_1", functionCall("sum", ImmutableList.of("a"))),
                                    join(
                                            JoinNode.Type.LEFT,
                                            ImmutableList.of(),
                                            assignUniqueId("unique", values(ImmutableMap.of("corr", 0))),
                                            project(
                                                    ImmutableMap.of("non_null", expression("true")),
                                                    values(ImmutableMap.of("a", 0, "b", 1)))))));
}
}
/**
 * A lateral join whose subquery is a bare global {@code sum(a)} aggregation is expected to become a
 * projection of {@code sum_1} and {@code corr} over an aggregation fed by a LEFT join between the
 * uniquely-tagged outer values and a {@code non_null}-marking projection of the subquery's values.
 */
@Test
public void rewritesOnSubqueryWithoutProjection()
{
    TransformCorrelatedScalarAggregationToJoin rule =
            new TransformCorrelatedScalarAggregationToJoin(tester().getMetadata().getFunctionRegistry());
    tester().assertThat(rule)
            .on(planBuilder -> planBuilder.lateral(
                    ImmutableList.of(planBuilder.symbol("corr")),
                    planBuilder.values(planBuilder.symbol("corr")),
                    planBuilder.aggregation(agg -> {
                        agg.source(planBuilder.values(planBuilder.symbol("a"), planBuilder.symbol("b")));
                        agg.addAggregation(planBuilder.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT));
                        agg.globalGrouping();
                    })))
            .matches(
                    project(
                            ImmutableMap.of("sum_1", expression("sum_1"), "corr", expression("corr")),
                            aggregation(
                                    ImmutableMap.of("sum_1", functionCall("sum", ImmutableList.of("a"))),
                                    join(
                                            JoinNode.Type.LEFT,
                                            ImmutableList.of(),
                                            assignUniqueId("unique", values(ImmutableMap.of("corr", 0))),
                                            project(
                                                    ImmutableMap.of("non_null", expression("true")),
                                                    values(ImmutableMap.of("a", 0, "b", 1)))))));
}
/**
 * MultipleDistinctAggregationToMarkDistinct must not fire when DISTINCT aggregates carry FILTER clauses,
 * whether every DISTINCT aggregate is filtered or only some of them are.
 */
@Test
public void testDistinctWithFilter()
{
    // Case 1: both DISTINCT aggregates are filtered.
    tester().assertThat(new MultipleDistinctAggregationToMarkDistinct())
            .on(planBuilder -> planBuilder.aggregation(agg -> agg
                    .globalGrouping()
                    .addAggregation(planBuilder.symbol("output1"), expression("count(DISTINCT input1) filter (where input2 > 0)"), ImmutableList.of(BIGINT))
                    .addAggregation(planBuilder.symbol("output2"), expression("count(DISTINCT input2) filter (where input1 > 0)"), ImmutableList.of(BIGINT))
                    .source(planBuilder.values(planBuilder.symbol("input1"), planBuilder.symbol("input2")))))
            .doesNotFire();

    // Case 2: one filtered and one unfiltered DISTINCT aggregate.
    tester().assertThat(new MultipleDistinctAggregationToMarkDistinct())
            .on(planBuilder -> planBuilder.aggregation(agg -> agg
                    .globalGrouping()
                    .addAggregation(planBuilder.symbol("output1"), expression("count(DISTINCT input1) filter (where input2 > 0)"), ImmutableList.of(BIGINT))
                    .addAggregation(planBuilder.symbol("output2"), expression("count(DISTINCT input2)"), ImmutableList.of(BIGINT))
                    .source(planBuilder.values(planBuilder.symbol("input1"), planBuilder.symbol("input2")))))
            .doesNotFire();
}
}
/**
 * A SINGLE-step global {@code count()} over another SINGLE-step global aggregation (which yields one row)
 * is expected to be replaced by a values node exposing {@code count_1}.
 */
@Test
public void testFiresOnNestedCountAggregate()
{
    FunctionCall countCall = new FunctionCall(QualifiedName.of("count"), ImmutableList.of());
    tester().assertThat(new PruneCountAggregationOverScalar())
            .on(planBuilder -> planBuilder.aggregation(outer -> outer
                    .addAggregation(planBuilder.symbol("count_1", BigintType.BIGINT), countCall, ImmutableList.of(BigintType.BIGINT))
                    .globalGrouping()
                    .step(AggregationNode.Step.SINGLE)
                    .source(planBuilder.aggregation(inner -> inner
                            .source(planBuilder.tableScan(ImmutableList.of(), ImmutableMap.of()))
                            .globalGrouping()
                            .step(AggregationNode.Step.SINGLE)))))
            .matches(values(ImmutableMap.of("count_1", 0)));
}
@Test public void testDoesNotFireOnNestedCountAggregateWithNonEmptyGroupBy() { tester().assertThat(new PruneCountAggregationOverScalar()) .on(p -> p.aggregation((a) -> a .addAggregation( p.symbol("count_1", BigintType.BIGINT), new FunctionCall(QualifiedName.of("count"), ImmutableList.of()), ImmutableList.of(BigintType.BIGINT)) .step(AggregationNode.Step.SINGLE) .globalGrouping() .source( p.aggregation(aggregationBuilder -> { aggregationBuilder .source(p.tableScan(ImmutableList.of(), ImmutableMap.of())).groupingSets(singleGroupingSet(ImmutableList.of(p.symbol("orderkey")))); aggregationBuilder .source(p.tableScan(ImmutableList.of(), ImmutableMap.of())); })))) .doesNotFire(); }
/**
 * The rewriter's aggregation rule is expected to turn the argument-less {@code count()} aggregate into a
 * {@code now()} call while keeping the output symbol {@code count_1}.
 * {@code functionCallRewriter} is a field of this test class, defined outside this method.
 */
@Test
public void testAggregationExpressionRewrite()
{
    tester().assertThat(functionCallRewriter.aggregationExpressionRewrite())
            .on(planBuilder -> planBuilder.aggregation(agg -> {
                agg.globalGrouping();
                agg.addAggregation(
                        planBuilder.symbol("count_1", BigintType.BIGINT),
                        new FunctionCall(QualifiedName.of("count"), ImmutableList.of()),
                        ImmutableList.of(BigintType.BIGINT));
                agg.source(planBuilder.values());
            }))
            .matches(PlanMatchPattern.aggregation(
                    ImmutableMap.of("count_1", functionCall("now", ImmutableList.of())),
                    values()));
}
/**
 * With {@code ENABLE_INTERMEDIATE_AGGREGATIONS} set to false (and task concurrency 4), the rule must not
 * insert intermediate aggregations between a FINAL and a PARTIAL aggregation separated by a remote
 * gathering exchange.
 */
@Test
public void testSessionDisable()
{
    AddIntermediateAggregations rule = new AddIntermediateAggregations();
    tester().assertThat(rule)
            .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "false")
            .setSystemProperty(TASK_CONCURRENCY, "4")
            .on(planBuilder -> planBuilder.aggregation(finalAgg -> finalAgg
                    .globalGrouping()
                    .step(AggregationNode.Step.FINAL)
                    .addAggregation(planBuilder.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT))
                    .source(planBuilder.gatheringExchange(
                            ExchangeNode.Scope.REMOTE,
                            planBuilder.aggregation(partialAgg -> partialAgg
                                    .globalGrouping()
                                    .step(AggregationNode.Step.PARTIAL)
                                    .addAggregation(planBuilder.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT))
                                    .source(planBuilder.values(planBuilder.symbol("a"))))))))
            .doesNotFire();
}
/**
 * The rewriter's aggregation rule must not fire on an aggregation whose aggregate is {@code nowCall}.
 * {@code nowCall} is a field of this test class; presumably it already holds the rewritten
 * {@code now()} call (confirm against its definition).
 */
@Test
public void testAggregationExpressionNotRewritten()
{
    tester().assertThat(functionCallRewriter.aggregationExpressionRewrite())
            .on(planBuilder -> planBuilder.aggregation(agg -> {
                agg.globalGrouping();
                agg.addAggregation(planBuilder.symbol("count_1", DateType.DATE), nowCall, ImmutableList.of());
                agg.source(planBuilder.values());
            }))
            .doesNotFire();
}
/**
 * A SINGLE-step global {@code count()} over a one-row values node (column {@code orderkey}, value 1) is
 * expected to be replaced by a values node exposing {@code count_1}.
 */
@Test
public void testFiresOnCountAggregateOverValues()
{
    FunctionCall countCall = new FunctionCall(QualifiedName.of("count"), ImmutableList.of());
    tester().assertThat(new PruneCountAggregationOverScalar())
            .on(planBuilder -> planBuilder.aggregation(agg -> agg
                    .addAggregation(planBuilder.symbol("count_1", BigintType.BIGINT), countCall, ImmutableList.of(BigintType.BIGINT))
                    .step(AggregationNode.Step.SINGLE)
                    .globalGrouping()
                    .source(planBuilder.values(
                            ImmutableList.of(planBuilder.symbol("orderkey")),
                            ImmutableList.of(planBuilder.expressions("1"))))))
            .matches(values(ImmutableMap.of("count_1", 0)));
}
/**
 * SingleDistinctAggregationToGroupBy must not fire when a DISTINCT aggregate ({@code output1}) is mixed
 * with a plain aggregate ({@code output2}).
 */
@Test
public void testMixedDistinctAndNonDistinct()
{
    SingleDistinctAggregationToGroupBy rule = new SingleDistinctAggregationToGroupBy();
    tester().assertThat(rule)
            .on(planBuilder -> planBuilder.aggregation(agg -> {
                agg.globalGrouping();
                agg.addAggregation(planBuilder.symbol("output1"), expression("count(DISTINCT input1)"), ImmutableList.of(BIGINT));
                agg.addAggregation(planBuilder.symbol("output2"), expression("count(input2)"), ImmutableList.of(BIGINT));
                agg.source(planBuilder.values(planBuilder.symbol("input1"), planBuilder.symbol("input2")));
            }))
            .doesNotFire();
}
/**
 * MultipleDistinctAggregationToMarkDistinct must not fire when every DISTINCT aggregate
 * ({@code count} and {@code sum}) uses the same argument, {@code input}.
 */
@Test
public void testMultipleAggregations()
{
    MultipleDistinctAggregationToMarkDistinct rule = new MultipleDistinctAggregationToMarkDistinct();
    tester().assertThat(rule)
            .on(planBuilder -> planBuilder.aggregation(agg -> agg
                    .globalGrouping()
                    .addAggregation(planBuilder.symbol("output1"), expression("count(DISTINCT input)"), ImmutableList.of(BIGINT))
                    .addAggregation(planBuilder.symbol("output2"), expression("sum(DISTINCT input)"), ImmutableList.of(BIGINT))
                    .source(planBuilder.values(planBuilder.symbol("input")))))
            .doesNotFire();
}
/**
 * SingleDistinctAggregationToGroupBy must not fire on an aggregation with two plain (non-DISTINCT)
 * aggregates.
 */
@Test
public void testNoDistinct()
{
    SingleDistinctAggregationToGroupBy rule = new SingleDistinctAggregationToGroupBy();
    tester().assertThat(rule)
            .on(planBuilder -> planBuilder.aggregation(agg -> {
                agg.globalGrouping();
                agg.addAggregation(planBuilder.symbol("output1"), expression("count(input1)"), ImmutableList.of(BIGINT));
                agg.addAggregation(planBuilder.symbol("output2"), expression("count(input2)"), ImmutableList.of(BIGINT));
                agg.source(planBuilder.values(planBuilder.symbol("input1"), planBuilder.symbol("input2")));
            }))
            .doesNotFire();
}
/**
 * SingleDistinctAggregationToGroupBy must not fire when there are DISTINCT aggregates over two
 * different arguments ({@code input1} and {@code input2}).
 */
@Test
public void testMultipleDistincts()
{
    SingleDistinctAggregationToGroupBy rule = new SingleDistinctAggregationToGroupBy();
    tester().assertThat(rule)
            .on(planBuilder -> planBuilder.aggregation(agg -> {
                agg.globalGrouping();
                agg.addAggregation(planBuilder.symbol("output1"), expression("count(DISTINCT input1)"), ImmutableList.of(BIGINT));
                agg.addAggregation(planBuilder.symbol("output2"), expression("count(DISTINCT input2)"), ImmutableList.of(BIGINT));
                agg.source(planBuilder.values(planBuilder.symbol("input1"), planBuilder.symbol("input2")));
            }))
            .doesNotFire();
}
/**
 * MultipleDistinctAggregationToMarkDistinct must not fire on an aggregation with a single plain
 * (non-DISTINCT) aggregate.
 *
 * <p>NOTE(review): this method mirrors {@code testSingleDistinct}, which targets
 * MultipleDistinctAggregationToMarkDistinct. It previously asserted against
 * SingleDistinctAggregationToGroupBy, which presumably was a copy-paste slip that left this class's rule
 * unexercised. Confirm the enclosing class is the MultipleDistinct test.
 */
@Test
public void testNoDistinct()
{
    tester().assertThat(new MultipleDistinctAggregationToMarkDistinct())
            .on(p -> p.aggregation(builder -> builder
                    .globalGrouping()
                    .addAggregation(p.symbol("output1"), expression("count(input1)"), ImmutableList.of(BIGINT))
                    .source(
                            p.values(
                                    p.symbol("input1"),
                                    p.symbol("input2")))))
            .doesNotFire();
}
/**
 * MultipleDistinctAggregationToMarkDistinct must not fire when the aggregation contains only one
 * DISTINCT aggregate.
 */
@Test
public void testSingleDistinct()
{
    MultipleDistinctAggregationToMarkDistinct rule = new MultipleDistinctAggregationToMarkDistinct();
    tester().assertThat(rule)
            .on(planBuilder -> planBuilder.aggregation(agg -> agg
                    .globalGrouping()
                    .addAggregation(planBuilder.symbol("output1"), expression("count(DISTINCT input1)"), ImmutableList.of(BIGINT))
                    .source(planBuilder.values(planBuilder.symbol("input1"), planBuilder.symbol("input2")))))
            .doesNotFire();
}
/**
 * SingleDistinctAggregationToGroupBy must not fire when its only DISTINCT aggregate carries a FILTER
 * clause.
 */
@Test
public void testDistinctWithFilter()
{
    SingleDistinctAggregationToGroupBy rule = new SingleDistinctAggregationToGroupBy();
    tester().assertThat(rule)
            .on(planBuilder -> planBuilder.aggregation(agg -> agg
                    .globalGrouping()
                    .addAggregation(planBuilder.symbol("output"), expression("count(DISTINCT input1) filter (where input2 > 0)"), ImmutableList.of(BIGINT))
                    .source(planBuilder.values(planBuilder.symbol("input1"), planBuilder.symbol("input2")))))
            .doesNotFire();
}
/**
 * The rule under test (obtained from {@code assertRuleApplication()}, defined elsewhere in this class)
 * must not fire on a FINAL-step global {@code spatial_partitioning(geometry, 10)} aggregation over values.
 */
@Test
public void testDoesNotFire()
{
    assertRuleApplication()
            .on(planBuilder -> planBuilder.aggregation(agg -> {
                agg.globalGrouping();
                agg.step(AggregationNode.Step.FINAL);
                agg.addAggregation(
                        planBuilder.symbol("sp"),
                        PlanBuilder.expression("spatial_partitioning(geometry, 10)"),
                        ImmutableList.of(GEOMETRY));
                agg.source(planBuilder.values(planBuilder.symbol("geometry")));
            }))
            .doesNotFire();
}