/**
 * Builds {@code project(aggregation(values(key)))}.
 * <p>
 * The aggregation groups on {@code key} and produces two {@code count()} outputs, {@code a} and {@code b}.
 * The projection is an identity projection over whichever of {@code a} and {@code b} are accepted by
 * {@code projectionFilter}.
 *
 * @param planBuilder builder used to create the symbols and plan nodes
 * @param projectionFilter selects which aggregation outputs the projection keeps
 * @return the project node sitting on top of the aggregation
 */
private ProjectNode buildProjectedAggregation(PlanBuilder planBuilder, Predicate<Symbol> projectionFilter)
{
    Symbol a = planBuilder.symbol("a");
    Symbol b = planBuilder.symbol("b");
    Symbol key = planBuilder.symbol("key");

    // Identity assignments only for the aggregation outputs accepted by the filter.
    Assignments keptOutputs = Assignments.identity(
            ImmutableList.of(a, b).stream()
                    .filter(projectionFilter)
                    .collect(toImmutableSet()));

    return planBuilder.project(
            keptOutputs,
            planBuilder.aggregation(builder -> builder
                    .source(planBuilder.values(key))
                    .singleGroupingSet(key)
                    .addAggregation(a, planBuilder.expression("count()"), ImmutableList.of())
                    .addAggregation(b, planBuilder.expression("count()"), ImmutableList.of())));
}
}
Optional.empty())) .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) .singleGroupingSet(p.symbol("COL1")))) .matches( project(ImmutableMap.of(
Optional.empty())) .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) .singleGroupingSet(p.symbol("COL1")))) .matches( project(ImmutableMap.of(
.addAggregation(pb.symbol("count", BIGINT), expression("count()"), ImmutableList.of()) .addAggregation(pb.symbol("count_on_x", BIGINT), expression("count(x)"), ImmutableList.of(BIGINT)) .singleGroupingSet(pb.symbol("y", BIGINT), pb.symbol("z", BIGINT)) .source(pb.values(pb.symbol("x", BIGINT), pb.symbol("y", BIGINT), pb.symbol("z", BIGINT))))) .withSourceStats(PlanNodeStatsEstimate.builder()
Optional.empty())) .addAggregation(new Symbol("AVG"), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) .singleGroupingSet(new Symbol("COL1")))) .doesNotFire(); .build(), p.aggregation(builder -> builder.singleGroupingSet(p.symbol("COL1"), p.symbol("unused")) .source( p.values( Optional.empty())) .addAggregation(p.symbol("AVG", DOUBLE), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE)) .singleGroupingSet(p.symbol("COL1")))) .doesNotFire();
Optional.of(p.symbol("RIGHT_HASH")))) .addAggregation(p.symbol("AVG", DOUBLE), expression("AVG(LEFT_AGGR)"), ImmutableList.of(DOUBLE)) .singleGroupingSet(p.symbol("LEFT_GROUP_BY"), p.symbol("RIGHT_GROUP_BY")) .step(PARTIAL))) .matches(project(ImmutableMap.of(
/**
 * Checks that {@code validatePlan} accepts (does not throw for) two SINGLE-step aggregation plans.
 * <ul>
 * <li>An aggregation grouped on {@code nationkey} over a nation table scan, with no pre-grouped symbols.</li>
 * <li>An aggregation grouped on {@code (unique, nationkey)} that declares both symbols as pre-grouped.
 * Its source is an {@code assignUniqueId} node that assigns {@code unique}, on top of the same scan.
 * Presumably accepted because the unique id makes the input grouped on these keys; TODO confirm
 * against the validator.</li>
 * </ul>
 */
@Test
public void testValidateSuccessful()
{
    validatePlan(planBuilder -> planBuilder.aggregation(agg -> agg
            .step(SINGLE)
            .singleGroupingSet(planBuilder.symbol("nationkey"))
            .source(planBuilder.tableScan(
                    nationTableHandle,
                    ImmutableList.of(planBuilder.symbol("nationkey", BIGINT)),
                    ImmutableMap.of(planBuilder.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)),
                    Optional.of(nationTableLayoutHandle)))));

    validatePlan(planBuilder -> planBuilder.aggregation(agg -> agg
            .step(SINGLE)
            .singleGroupingSet(planBuilder.symbol("unique"), planBuilder.symbol("nationkey"))
            .preGroupedSymbols(planBuilder.symbol("unique"), planBuilder.symbol("nationkey"))
            .source(planBuilder.assignUniqueId(
                    planBuilder.symbol("unique"),
                    planBuilder.tableScan(
                            nationTableHandle,
                            ImmutableList.of(planBuilder.symbol("nationkey", BIGINT)),
                            ImmutableMap.of(planBuilder.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)),
                            Optional.of(nationTableLayoutHandle))))));
}
/**
 * Builds a FINAL aggregation grouped on {@code c}. Its source is a REMOTE gathering exchange over a
 * PARTIAL aggregation grouped on {@code b}.
 * <p>
 * With {@code ENABLE_INTERMEDIATE_AGGREGATIONS} set to "true" and {@code TASK_CONCURRENCY} set to "4",
 * {@link AddIntermediateAggregations} is expected not to fire.
 * <p>
 * NOTE(review): {@code c} is both the FINAL grouping key and that aggregation's output symbol, and
 * {@code b} is likewise both for the PARTIAL step. This looks unusual; confirm it is intentional.
 */
@Test
public void testWithGroups()
{
    tester().assertThat(new AddIntermediateAggregations())
            .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true")
            .setSystemProperty(TASK_CONCURRENCY, "4")
            .on(plan -> plan.aggregation(finalAgg -> finalAgg
                    .singleGroupingSet(plan.symbol("c"))
                    .step(AggregationNode.Step.FINAL)
                    .addAggregation(plan.symbol("c"), expression("count(b)"), ImmutableList.of(BIGINT))
                    .source(plan.gatheringExchange(
                            ExchangeNode.Scope.REMOTE,
                            plan.aggregation(partialAgg -> partialAgg
                                    .singleGroupingSet(plan.symbol("b"))
                                    .step(AggregationNode.Step.PARTIAL)
                                    .addAggregation(plan.symbol("b"), expression("count(a)"), ImmutableList.of(BIGINT))
                                    .source(plan.values(plan.symbol("a"))))))))
            .doesNotFire();
}
/**
 * Builds an aggregation grouped on {@code key}, with {@code keyHash} as its hash symbol.
 * It computes {@code avg(input)}, masked by {@code mask}, into {@code avg}.
 * <p>
 * The source is a {@code values} node with no rows. Its columns are those of
 * {@code input, key, keyHash, mask, unused} that {@code sourceSymbolFilter} accepts.
 *
 * @param planBuilder builder used to create the symbols and plan nodes
 * @param sourceSymbolFilter selects which candidate columns the source {@code values} node exposes
 * @return the aggregation node
 */
private AggregationNode buildAggregation(PlanBuilder planBuilder, Predicate<Symbol> sourceSymbolFilter)
{
    Symbol avg = planBuilder.symbol("avg");
    Symbol input = planBuilder.symbol("input");
    Symbol key = planBuilder.symbol("key");
    Symbol keyHash = planBuilder.symbol("keyHash");
    Symbol mask = planBuilder.symbol("mask");
    Symbol unused = planBuilder.symbol("unused");

    return planBuilder.aggregation(builder -> {
        List<Symbol> sourceColumns = ImmutableList.of(input, key, keyHash, mask, unused).stream()
                .filter(sourceSymbolFilter)
                .collect(toImmutableList());
        builder.singleGroupingSet(key)
                .addAggregation(avg, planBuilder.expression("avg(input)"), ImmutableList.of(BIGINT), mask)
                .hashSymbol(keyHash)
                .source(planBuilder.values(sourceColumns, ImmutableList.of()));
    });
}
}
/**
 * Builds an aggregation over a LEFT join whose aggregate ({@code sum(COL1)}) and grouping key
 * ({@code COL1}) reference only {@code COL1}, a column produced by the outer (left) {@code values} input.
 * It uses no symbol from the inner (right) side. {@link PushAggregationThroughOuterJoin} is expected
 * not to fire.
 */
@Test
public void testDoesNotFireWhenAggregationDoesNotHaveSymbols()
{
    tester().assertThat(new PushAggregationThroughOuterJoin())
            .on(plan -> {
                Symbol col1 = new Symbol("COL1");
                Symbol col2 = new Symbol("COL2");
                return plan.aggregation(agg -> agg
                        .source(plan.join(
                                JoinNode.Type.LEFT,
                                plan.values(ImmutableList.of(plan.symbol("COL1")), ImmutableList.of(expressions("10"))),
                                plan.values(ImmutableList.of(plan.symbol("COL2")), ImmutableList.of(expressions("20"))),
                                ImmutableList.of(new JoinNode.EquiJoinClause(col1, col2)),
                                ImmutableList.of(col1, col2),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty()))
                        .addAggregation(new Symbol("SUM"), PlanBuilder.expression("sum(COL1)"), ImmutableList.of(DOUBLE))
                        .singleGroupingSet(col1));
            })
            .doesNotFire();
}
}
/**
 * Builds an aggregation grouped on {@code key}, with {@code keyHash} as its hash symbol.
 * It has two ordered aggregates, both masked by {@code mask}:
 * {@code avg(input order by input)} into {@code avg}, and
 * {@code array_agg(input order by input)} into {@code array_agg}.
 * <p>
 * The source is a {@code values} node with columns {@code input, key, keyHash, mask} and no rows.
 *
 * @param planBuilder builder used to create the symbols and plan nodes
 * @return the aggregation node
 */
private AggregationNode buildAggregation(PlanBuilder planBuilder)
{
    Symbol avg = planBuilder.symbol("avg");
    Symbol arrayAgg = planBuilder.symbol("array_agg");
    Symbol input = planBuilder.symbol("input");
    Symbol key = planBuilder.symbol("key");
    Symbol keyHash = planBuilder.symbol("keyHash");
    Symbol mask = planBuilder.symbol("mask");
    List<Symbol> columns = ImmutableList.of(input, key, keyHash, mask);

    return planBuilder.aggregation(builder -> {
        builder.singleGroupingSet(key);
        builder.addAggregation(avg, planBuilder.expression("avg(input order by input)"), ImmutableList.of(BIGINT), mask);
        builder.addAggregation(arrayAgg, planBuilder.expression("array_agg(input order by input)"), ImmutableList.of(BIGINT), mask);
        builder.hashSymbol(keyHash);
        builder.source(planBuilder.values(columns, ImmutableList.of()));
    });
}
}
/**
 * Builds an aggregation over a LEFT join that groups on {@code COL1} and {@code COL3}.
 * {@code COL3} is produced by the inner (right) {@code values} input of the join.
 * {@link PushAggregationThroughOuterJoin} is expected not to fire.
 */
@Test
public void testDoesNotFireWhenGroupingOnInner()
{
    tester().assertThat(new PushAggregationThroughOuterJoin())
            .on(plan -> {
                Symbol col1 = new Symbol("COL1");
                Symbol col2 = new Symbol("COL2");
                Symbol col3 = new Symbol("COL3");
                return plan.aggregation(agg -> agg
                        .source(plan.join(
                                JoinNode.Type.LEFT,
                                plan.values(ImmutableList.of(plan.symbol("COL1")), ImmutableList.of(expressions("10"))),
                                plan.values(col2, col3),
                                ImmutableList.of(new JoinNode.EquiJoinClause(col1, col2)),
                                ImmutableList.of(col1, col2),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty()))
                        .addAggregation(new Symbol("AVG"), PlanBuilder.expression("avg(COL2)"), ImmutableList.of(DOUBLE))
                        .singleGroupingSet(col1, col3));
            })
            .doesNotFire();
}
/**
 * Checks the output-row estimate of an aggregation grouped on {@code (y, z)}.
 * The source stats declare 100 rows, and {@code y} and {@code z} each have 50 distinct values.
 * <p>
 * The asserted estimate equals the input row count. Presumably the per-key distinct counts would
 * otherwise combine to more than 100 groups, so the estimate is capped. TODO confirm against the
 * aggregation stats rule.
 */
@Test
public void testAggregationStatsCappedToInputRows()
{
    tester().assertStatsFor(plan -> plan.aggregation(agg -> {
                Symbol countOnX = plan.symbol("count_on_x", BIGINT);
                Symbol y = plan.symbol("y", BIGINT);
                Symbol z = plan.symbol("z", BIGINT);
                Symbol x = plan.symbol("x", BIGINT);
                agg.addAggregation(countOnX, expression("count(x)"), ImmutableList.of(BIGINT))
                        .singleGroupingSet(y, z)
                        .source(plan.values(x, y, z));
            }))
            .withSourceStats(PlanNodeStatsEstimate.builder()
                    .setOutputRowCount(100)
                    .addSymbolStatistics(new Symbol("y"), SymbolStatsEstimate.builder().setDistinctValuesCount(50).build())
                    .addSymbolStatistics(new Symbol("z"), SymbolStatsEstimate.builder().setDistinctValuesCount(50).build())
                    .build())
            .check(estimate -> estimate.outputRowsCount(100));
}
}
/**
 * Builds a SINGLE-step aggregation grouped on {@code nationkey` that declares {@code nationkey} as
 * pre-grouped, directly over a nation table scan. {@code validatePlan} is expected to reject it with an
 * {@link IllegalArgumentException} whose message matches the regular expression in the annotation.
 */
@Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Streaming aggregation with input not grouped on the grouping keys")
public void testValidateFailed()
{
    validatePlan(planBuilder -> planBuilder.aggregation(agg -> agg
            .step(SINGLE)
            .singleGroupingSet(planBuilder.symbol("nationkey"))
            .preGroupedSymbols(planBuilder.symbol("nationkey"))
            .source(planBuilder.tableScan(
                    nationTableHandle,
                    ImmutableList.of(planBuilder.symbol("nationkey", BIGINT)),
                    ImmutableMap.of(planBuilder.symbol("nationkey", BIGINT), new TpchColumnHandle("nationkey", BIGINT)),
                    Optional.of(nationTableLayoutHandle)))));
}
/**
 * Builds a lateral join correlated on {@code corr}. Its subquery is an aggregation with a grouping key:
 * {@code sum(a)} grouped by {@code b}.
 * {@code TransformCorrelatedScalarAggregationToJoin} is expected not to fire for it.
 */
@Test
public void doesNotFireOnCorrelatedWithNonScalarAggregation()
{
    tester().assertThat(new TransformCorrelatedScalarAggregationToJoin(tester().getMetadata().getFunctionRegistry()))
            .on(plan -> plan.lateral(
                    ImmutableList.of(plan.symbol("corr")),
                    plan.values(plan.symbol("corr")),
                    plan.aggregation(agg -> agg
                            .source(plan.values(plan.symbol("a"), plan.symbol("b")))
                            .addAggregation(plan.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT))
                            .singleGroupingSet(plan.symbol("b")))))
            .doesNotFire();
}