/**
 * A FINAL aggregation stacked directly on its PARTIAL aggregation (no exchange in between) must be
 * rejected with the "remote hash exchange" message when {@code validatePlan} is called with
 * {@code false}.
 */
@Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Final aggregation with default value not separated from partial aggregation by remote hash exchange")
public void testGloballyDistributedFinalAggregationInTheSameStageAsPartialAggregation()
{
    // Lower half of the plan: PARTIAL step reading straight from the table scan.
    PlanNode partial = builder.aggregation(partialBuilder -> partialBuilder
            .step(PARTIAL)
            .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0)))
            .source(tableScanNode));

    // Upper half: FINAL step consuming the partial node directly.
    PlanNode root = builder.aggregation(finalBuilder -> finalBuilder
            .step(FINAL)
            .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0)))
            .source(partial));

    validatePlan(root, false);
}
/**
 * Same plan shape as the globally distributed case, but validated with {@code validatePlan(root, true)}:
 * the FINAL-over-PARTIAL aggregation over the table scan must be rejected with the
 * "local hash exchange" message.
 */
@Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Final aggregation with default value not separated from partial aggregation by local hash exchange")
public void testSingleNodeFinalAggregationInTheSameStageAsPartialAggregation()
{
    PlanNode partial = builder.aggregation(partialBuilder -> partialBuilder
            .step(PARTIAL)
            .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0)))
            .source(tableScanNode));

    PlanNode root = builder.aggregation(finalBuilder -> finalBuilder
            .step(FINAL)
            .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0)))
            .source(partial));

    validatePlan(root, true);
}
/**
 * FINAL directly over PARTIAL is accepted (no exception expected) when the partial aggregation
 * reads from a VALUES node rather than the table scan, validated with {@code validatePlan(root, true)}.
 */
@Test
public void testSingleThreadFinalAggregationInTheSameStageAsPartialAggregation()
{
    PlanNode partial = builder.aggregation(partialBuilder -> partialBuilder
            .step(PARTIAL)
            .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0)))
            .source(builder.values()));

    PlanNode root = builder.aggregation(finalBuilder -> finalBuilder
            .step(FINAL)
            .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0)))
            .source(partial));

    validatePlan(root, true);
}
/**
 * A FINAL aggregation is accepted ({@code validatePlan(root, false)}, no exception expected) when a
 * REMOTE REPARTITION exchange, hash-partitioned on the grouping symbol, separates it from the
 * PARTIAL aggregation.
 *
 * <p>This is the positive counterpart of the "in the same stage" test: the plans differ only by
 * the exchange, so both use the shared {@code symbol} field rather than a shadowing local.
 */
@Test
public void testGloballyDistributedFinalAggregationSeparatedFromPartialAggregationByRemoteHashExchange()
{
    // Previously a local `Symbol symbol = new Symbol("symbol")` shadowed the field used by the
    // sibling tests, so the partial aggregation grouped on a symbol unrelated to tableScanNode.
    PlanNode partial = builder.aggregation(partialBuilder -> partialBuilder
            .step(PARTIAL)
            .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0)))
            .source(tableScanNode));

    // Remote hash repartitioning on the grouping symbol between the two aggregation steps.
    PlanNode exchange = builder.exchange(e -> e
            .type(REPARTITION)
            .scope(REMOTE)
            .fixedHashDistributionParitioningScheme(ImmutableList.of(symbol), ImmutableList.of(symbol))
            .addInputsSet(symbol)
            .addSource(partial));

    PlanNode root = builder.aggregation(finalBuilder -> finalBuilder
            .step(FINAL)
            .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0)))
            .source(exchange));

    validatePlan(root, false);
}
/**
 * Grouping on (y, z), each with 50 distinct values, could yield up to 50 * 50 groups; the estimated
 * output row count is expected to be capped at the 100 source rows.
 */
@Test
public void testAggregationStatsCappedToInputRows()
{
    SymbolStatsEstimate fiftyDistinctValues = SymbolStatsEstimate.builder()
            .setDistinctValuesCount(50)
            .build();
    PlanNodeStatsEstimate sourceStats = PlanNodeStatsEstimate.builder()
            .setOutputRowCount(100)
            .addSymbolStatistics(new Symbol("y"), fiftyDistinctValues)
            .addSymbolStatistics(new Symbol("z"), fiftyDistinctValues)
            .build();

    tester()
            .assertStatsFor(planBuilder -> planBuilder.aggregation(aggregation -> aggregation
                    .addAggregation(planBuilder.symbol("count_on_x", BIGINT), expression("count(x)"), ImmutableList.of(BIGINT))
                    .singleGroupingSet(planBuilder.symbol("y", BIGINT), planBuilder.symbol("z", BIGINT))
                    .source(planBuilder.values(
                            planBuilder.symbol("x", BIGINT),
                            planBuilder.symbol("y", BIGINT),
                            planBuilder.symbol("z", BIGINT)))))
            .withSourceStats(sourceStats)
            .check(check -> check.outputRowsCount(100));
}
}
/**
 * A PARTIAL aggregation hidden below an INNER join (left input) with no exchange between it and the
 * FINAL aggregation must be rejected with the "local hash exchange" message under
 * {@code validatePlan(root, true)}.
 *
 * <p>Uses the shared {@code symbol} field, like the single-node "same stage" test, so the two plans
 * differ only by the join.
 */
@Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Final aggregation with default value not separated from partial aggregation by local hash exchange")
public void testWithPartialAggregationBelowJoinWithoutSeparatingExchange()
{
    // Previously a local `Symbol symbol = new Symbol("symbol")` shadowed the class field.
    PlanNode partial = builder.aggregation(partialBuilder -> partialBuilder
            .step(PARTIAL)
            .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0)))
            .source(tableScanNode));

    PlanNode join = builder.join(INNER, partial, builder.values());

    PlanNode root = builder.aggregation(finalBuilder -> finalBuilder
            .step(FINAL)
            .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0)))
            .source(join));

    validatePlan(root, true);
}
/**
 * Builds {@code project(aggregation(values(key)))} where the aggregation groups on {@code key} and
 * computes {@code a = count()} and {@code b = count()}; the project keeps (as identity assignments)
 * only those of {@code a}, {@code b} accepted by {@code projectionFilter}.
 */
private ProjectNode buildProjectedAggregation(PlanBuilder planBuilder, Predicate<Symbol> projectionFilter)
{
    Symbol a = planBuilder.symbol("a");
    Symbol b = planBuilder.symbol("b");
    Symbol key = planBuilder.symbol("key");

    Assignments projections = Assignments.identity(
            ImmutableList.of(a, b).stream()
                    .filter(projectionFilter)
                    .collect(toImmutableSet()));

    return planBuilder.project(
            projections,
            planBuilder.aggregation(aggregation -> aggregation
                    .source(planBuilder.values(key))
                    .singleGroupingSet(key)
                    .addAggregation(a, planBuilder.expression("count()"), ImmutableList.of())
                    .addAggregation(b, planBuilder.expression("count()"), ImmutableList.of())));
}
}
/**
 * The aggregation-expression rewrite from {@code functionCallRewriter} must not fire on a global
 * aggregation whose call is {@code nowCall} (output typed DATE) over an empty VALUES node.
 */
@Test
public void testAggregationExpressionNotRewritten()
{
    tester()
            .assertThat(functionCallRewriter.aggregationExpressionRewrite())
            .on(planBuilder -> planBuilder.aggregation(aggregation -> aggregation
                    .globalGrouping()
                    .addAggregation(planBuilder.symbol("count_1", DateType.DATE), nowCall, ImmutableList.of())
                    .source(planBuilder.values())))
            .doesNotFire();
}
/**
 * A SINGLE-step global {@code count()} over {@code enforceSingleRow(tableScan)} is replaced by
 * PruneCountAggregationOverScalar with a VALUES node producing {@code count_1 = 0}.
 */
@Test
public void testFiresOnCountAggregateOverEnforceSingleRow()
{
    FunctionCall countCall = new FunctionCall(QualifiedName.of("count"), ImmutableList.of());

    tester()
            .assertThat(new PruneCountAggregationOverScalar())
            .on(planBuilder -> planBuilder.aggregation(aggregation -> aggregation
                    .addAggregation(
                            planBuilder.symbol("count_1", BigintType.BIGINT),
                            countCall,
                            ImmutableList.of(BigintType.BIGINT))
                    .step(AggregationNode.Step.SINGLE)
                    .globalGrouping()
                    .source(planBuilder.enforceSingleRow(
                            planBuilder.tableScan(ImmutableList.of(), ImmutableMap.of())))))
            .matches(values(ImmutableMap.of("count_1", 0)));
}
/**
 * A SINGLE-step global {@code count()} over a one-row VALUES node (orderkey = 1) is replaced by
 * PruneCountAggregationOverScalar with a VALUES node producing {@code count_1 = 0}.
 */
@Test
public void testFiresOnCountAggregateOverValues()
{
    FunctionCall countCall = new FunctionCall(QualifiedName.of("count"), ImmutableList.of());

    tester()
            .assertThat(new PruneCountAggregationOverScalar())
            .on(planBuilder -> planBuilder.aggregation(aggregation -> aggregation
                    .addAggregation(
                            planBuilder.symbol("count_1", BigintType.BIGINT),
                            countCall,
                            ImmutableList.of(BigintType.BIGINT))
                    .step(AggregationNode.Step.SINGLE)
                    .globalGrouping()
                    .source(planBuilder.values(
                            ImmutableList.of(planBuilder.symbol("orderkey")),
                            ImmutableList.of(planBuilder.expressions("1"))))))
            .matches(values(ImmutableMap.of("count_1", 0)));
}
/**
 * An aggregation that declares {@code nationkey} as pre-grouped, while reading directly from the
 * nation table scan, must be rejected as a streaming aggregation over ungrouped input.
 */
@Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Streaming aggregation with input not grouped on the grouping keys")
public void testValidateFailed()
{
    validatePlan(planBuilder -> planBuilder.aggregation(aggregation -> aggregation
            .step(SINGLE)
            .singleGroupingSet(planBuilder.symbol("nationkey"))
            .preGroupedSymbols(planBuilder.symbol("nationkey"))
            .source(planBuilder.tableScan(
                    nationTableHandle,
                    ImmutableList.of(planBuilder.symbol("nationkey", BIGINT)),
                    ImmutableMap.of(planBuilder.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)),
                    Optional.of(nationTableLayoutHandle)))));
}
/**
 * TransformCorrelatedScalarAggregationToJoin must not fire when the lateral subquery is a grouped
 * (non-scalar) aggregation: {@code sum(a)} grouped by {@code b}.
 */
@Test
public void doesNotFireOnCorrelatedWithNonScalarAggregation()
{
    TransformCorrelatedScalarAggregationToJoin rule =
            new TransformCorrelatedScalarAggregationToJoin(tester().getMetadata().getFunctionRegistry());

    tester()
            .assertThat(rule)
            .on(planBuilder -> planBuilder.lateral(
                    ImmutableList.of(planBuilder.symbol("corr")),
                    planBuilder.values(planBuilder.symbol("corr")),
                    planBuilder.aggregation(aggregation -> aggregation
                            .source(planBuilder.values(planBuilder.symbol("a"), planBuilder.symbol("b")))
                            .addAggregation(planBuilder.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT))
                            .singleGroupingSet(planBuilder.symbol("b")))))
            .doesNotFire();
}
/**
 * SingleDistinctAggregationToGroupBy must not fire when a DISTINCT aggregation is mixed with a
 * non-DISTINCT one.
 */
@Test
public void testMixedDistinctAndNonDistinct()
{
    SingleDistinctAggregationToGroupBy rule = new SingleDistinctAggregationToGroupBy();

    tester()
            .assertThat(rule)
            .on(planBuilder -> planBuilder.aggregation(aggregation -> aggregation
                    .globalGrouping()
                    .addAggregation(planBuilder.symbol("output1"), expression("count(DISTINCT input1)"), ImmutableList.of(BIGINT))
                    .addAggregation(planBuilder.symbol("output2"), expression("count(input2)"), ImmutableList.of(BIGINT))
                    .source(planBuilder.values(planBuilder.symbol("input1"), planBuilder.symbol("input2")))))
            .doesNotFire();
}
/**
 * SingleDistinctAggregationToGroupBy must not fire when none of the aggregations is DISTINCT.
 */
@Test
public void testNoDistinct()
{
    SingleDistinctAggregationToGroupBy rule = new SingleDistinctAggregationToGroupBy();

    tester()
            .assertThat(rule)
            .on(planBuilder -> planBuilder.aggregation(aggregation -> aggregation
                    .globalGrouping()
                    .addAggregation(planBuilder.symbol("output1"), expression("count(input1)"), ImmutableList.of(BIGINT))
                    .addAggregation(planBuilder.symbol("output2"), expression("count(input2)"), ImmutableList.of(BIGINT))
                    .source(planBuilder.values(planBuilder.symbol("input1"), planBuilder.symbol("input2")))))
            .doesNotFire();
}
/**
 * SingleDistinctAggregationToGroupBy must not fire when DISTINCT is applied to two different
 * inputs.
 */
@Test
public void testMultipleDistincts()
{
    SingleDistinctAggregationToGroupBy rule = new SingleDistinctAggregationToGroupBy();

    tester()
            .assertThat(rule)
            .on(planBuilder -> planBuilder.aggregation(aggregation -> aggregation
                    .globalGrouping()
                    .addAggregation(planBuilder.symbol("output1"), expression("count(DISTINCT input1)"), ImmutableList.of(BIGINT))
                    .addAggregation(planBuilder.symbol("output2"), expression("count(DISTINCT input2)"), ImmutableList.of(BIGINT))
                    .source(planBuilder.values(planBuilder.symbol("input1"), planBuilder.symbol("input2")))))
            .doesNotFire();
}
/**
 * MultipleDistinctAggregationToMarkDistinct must not fire when every DISTINCT aggregation
 * (count and sum) uses the same single input.
 */
@Test
public void testMultipleAggregations()
{
    MultipleDistinctAggregationToMarkDistinct rule = new MultipleDistinctAggregationToMarkDistinct();

    tester()
            .assertThat(rule)
            .on(planBuilder -> planBuilder.aggregation(aggregation -> aggregation
                    .globalGrouping()
                    .addAggregation(planBuilder.symbol("output1"), expression("count(DISTINCT input)"), ImmutableList.of(BIGINT))
                    .addAggregation(planBuilder.symbol("output2"), expression("sum(DISTINCT input)"), ImmutableList.of(BIGINT))
                    .source(planBuilder.values(planBuilder.symbol("input")))))
            .doesNotFire();
}
/**
 * MultipleDistinctAggregationToMarkDistinct must not fire on a global aggregation that has no
 * DISTINCT argument at all.
 *
 * <p>This previously asserted against SingleDistinctAggregationToGroupBy (a copy-paste slip), so
 * the rule this class tests was never exercised for the no-DISTINCT case.
 */
@Test
public void testNoDistinct()
{
    tester()
            .assertThat(new MultipleDistinctAggregationToMarkDistinct())
            .on(planBuilder -> planBuilder.aggregation(aggregation -> aggregation
                    .globalGrouping()
                    .addAggregation(planBuilder.symbol("output1"), expression("count(input1)"), ImmutableList.of(BIGINT))
                    .source(planBuilder.values(planBuilder.symbol("input1"), planBuilder.symbol("input2")))))
            .doesNotFire();
}
/**
 * The rule under {@code assertRuleApplication()} must not fire on a FINAL-step global aggregation
 * computing {@code spatial_partitioning(geometry, 10)}.
 */
@Test
public void testDoesNotFire()
{
    assertRuleApplication()
            .on(planBuilder -> planBuilder.aggregation(aggregation -> aggregation
                    .globalGrouping()
                    .step(AggregationNode.Step.FINAL)
                    .addAggregation(
                            planBuilder.symbol("sp"),
                            PlanBuilder.expression("spatial_partitioning(geometry, 10)"),
                            ImmutableList.of(GEOMETRY))
                    .source(planBuilder.values(planBuilder.symbol("geometry")))))
            .doesNotFire();
}
/**
 * MultipleDistinctAggregationToMarkDistinct must not fire when there is only one DISTINCT
 * aggregation.
 */
@Test
public void testSingleDistinct()
{
    MultipleDistinctAggregationToMarkDistinct rule = new MultipleDistinctAggregationToMarkDistinct();

    tester()
            .assertThat(rule)
            .on(planBuilder -> planBuilder.aggregation(aggregation -> aggregation
                    .globalGrouping()
                    .addAggregation(planBuilder.symbol("output1"), expression("count(DISTINCT input1)"), ImmutableList.of(BIGINT))
                    .source(planBuilder.values(planBuilder.symbol("input1"), planBuilder.symbol("input2")))))
            .doesNotFire();
}
/**
 * SingleDistinctAggregationToGroupBy must not fire on a DISTINCT aggregation that carries a
 * FILTER clause.
 */
@Test
public void testDistinctWithFilter()
{
    SingleDistinctAggregationToGroupBy rule = new SingleDistinctAggregationToGroupBy();

    tester()
            .assertThat(rule)
            .on(planBuilder -> planBuilder.aggregation(aggregation -> aggregation
                    .globalGrouping()
                    .addAggregation(
                            planBuilder.symbol("output"),
                            expression("count(DISTINCT input1) filter (where input2 > 0)"),
                            ImmutableList.of(BIGINT))
                    .source(planBuilder.values(planBuilder.symbol("input1"), planBuilder.symbol("input2")))))
            .doesNotFire();
}