/**
 * Builds an aggregation grouped on {@code key} that computes two aggregates, both masked by
 * {@code mask} and both declared with the type list {@code [BIGINT]}:
 * {@code avg(input order by input)} into {@code avg} and
 * {@code array_agg(input order by input)} into {@code array_agg}.
 * The node uses {@code keyHash} as its hash symbol. Its source is a values node with no rows
 * whose outputs are {@code input, key, keyHash, mask}.
 */
private AggregationNode buildAggregation(PlanBuilder planBuilder)
{
    Symbol avgOutput = planBuilder.symbol("avg");
    Symbol arrayAggOutput = planBuilder.symbol("array_agg");
    Symbol inputSymbol = planBuilder.symbol("input");
    Symbol groupingKey = planBuilder.symbol("key");
    Symbol groupingKeyHash = planBuilder.symbol("keyHash");
    Symbol maskSymbol = planBuilder.symbol("mask");
    // output columns of the row-less source, in declaration order
    List<Symbol> sourceOutputs = ImmutableList.of(inputSymbol, groupingKey, groupingKeyHash, maskSymbol);

    return planBuilder.aggregation(builder -> builder
            .singleGroupingSet(groupingKey)
            .addAggregation(
                    avgOutput,
                    planBuilder.expression("avg(input order by input)"),
                    ImmutableList.of(BIGINT),
                    maskSymbol)
            .addAggregation(
                    arrayAggOutput,
                    planBuilder.expression("array_agg(input order by input)"),
                    ImmutableList.of(BIGINT),
                    maskSymbol)
            .hashSymbol(groupingKeyHash)
            .source(planBuilder.values(sourceOutputs, ImmutableList.of())));
}
}
@Test public void rewritesScalarSubquery() { tester().assertThat(rule) .on(p -> p.lateral( ImmutableList.of(p.symbol("corr")), p.values(p.symbol("corr")), p.enforceSingleRow( p.filter( p.expression("1 = a"), // TODO use correlated predicate, it requires support for correlated subqueries in plan matchers p.values(ImmutableList.of(p.symbol("a")), ONE_ROW))))) .matches( lateral( ImmutableList.of("corr"), values("corr"), filter( "1 = a", values("a")))); }
private Function<PlanBuilder, PlanNode> crossJoinAndJoin(JoinNode.Type secondJoinType) { return p -> { Symbol axSymbol = p.symbol("axSymbol"); Symbol bySymbol = p.symbol("bySymbol"); Symbol cxSymbol = p.symbol("cxSymbol"); Symbol cySymbol = p.symbol("cySymbol"); // (a inner join b) inner join c on c.x = a.x and c.y = b.y return p.join(INNER, p.join(secondJoinType, p.values(axSymbol), p.values(bySymbol)), p.values(cxSymbol, cySymbol), new EquiJoinClause(cxSymbol, axSymbol), new EquiJoinClause(cySymbol, bySymbol)); }; }
/**
 * Sets the plan under test. The plan is built by applying {@code planProvider} to a new
 * {@link PlanBuilder} that shares this assertion's id allocator and metadata. The types
 * reported by that builder are then stored.
 *
 * @param planProvider callback that builds the plan to test
 * @return this assertion, for chaining
 * @throws IllegalArgumentException (via {@code checkArgument}) if a plan was already set
 */
public RuleAssert on(Function<PlanBuilder, PlanNode> planProvider)
{
    checkArgument(plan == null, "plan has already been set");

    PlanBuilder planBuilder = new PlanBuilder(idAllocator, metadata);
    this.plan = planProvider.apply(planBuilder);
    this.types = planBuilder.getTypes();

    return this;
}
/**
 * Builds {@code project(filter(b > 5, values(a, b)))}. The projection contains identity
 * assignments only for those of {@code a} and {@code b} that satisfy {@code projectionFilter}.
 *
 * @param projectionFilter selects which source symbols the projection passes through
 */
private ProjectNode buildProjectedFilter(PlanBuilder planBuilder, Predicate<Symbol> projectionFilter)
{
    Symbol a = planBuilder.symbol("a");
    Symbol b = planBuilder.symbol("b");

    // identity assignments restricted to the caller-selected symbols
    Assignments keptColumns = Assignments.identity(
            Stream.of(a, b)
                    .filter(projectionFilter)
                    .collect(toImmutableSet()));

    return planBuilder.project(
            keptColumns,
            planBuilder.filter(planBuilder.expression("b > 5"), planBuilder.values(a, b)));
}
}
/**
 * Runs {@code pickTableLayoutForPredicate} on a filter over a nation table scan. The scan
 * already carries {@code nationTableLayoutHandle} and unconstrained ({@code TupleDomain.all()})
 * domains. The filter applies modulo predicates on {@code nationkey}. The rule is expected not
 * to fire. Presumably these predicates cannot be turned into a tighter domain, so the plan
 * would be unchanged (inferred from the test name; confirm against the rule).
 */
@Test
public void doesNotFireIfRuleNotChangePlan()
{
    String predicate = "nationkey % 17 = BIGINT '44' AND nationkey % 15 = BIGINT '43'";

    tester().assertThat(pickTableLayout.pickTableLayoutForPredicate())
            .on(planBuilder -> planBuilder.filter(
                    expression(predicate),
                    planBuilder.tableScan(
                            nationTableHandle,
                            ImmutableList.of(planBuilder.symbol("nationkey", BIGINT)),
                            ImmutableMap.of(
                                    planBuilder.symbol("nationkey", BIGINT),
                                    new TpchColumnHandle("nationkey", BIGINT)),
                            Optional.of(nationTableLayoutHandle),
                            TupleDomain.all(),
                            TupleDomain.all())))
            .doesNotFire();
}
/**
 * Builds an aggregation grouped on {@code COL1} that computes {@code sum(COL1)}. Its source is a
 * LEFT join between single-row values nodes for {@code COL1} and {@code COL2}, joined on
 * {@code COL1 = COL2}. {@link PushAggregationThroughOuterJoin} is expected not to fire on it.
 */
@Test
public void testDoesNotFireWhenAggregationDoesNotHaveSymbols()
{
    Symbol col1 = new Symbol("COL1");
    Symbol col2 = new Symbol("COL2");

    tester().assertThat(new PushAggregationThroughOuterJoin())
            .on(p -> p.aggregation(builder -> builder
                    .source(p.join(
                            JoinNode.Type.LEFT,
                            p.values(ImmutableList.of(p.symbol("COL1")), ImmutableList.of(expressions("10"))),
                            p.values(ImmutableList.of(p.symbol("COL2")), ImmutableList.of(expressions("20"))),
                            ImmutableList.of(new JoinNode.EquiJoinClause(col1, col2)),
                            ImmutableList.of(col1, col2),
                            Optional.empty(),
                            Optional.empty(),
                            Optional.empty()))
                    .addAggregation(new Symbol("SUM"), PlanBuilder.expression("sum(COL1)"), ImmutableList.of(DOUBLE))
                    .singleGroupingSet(col1)))
            .doesNotFire();
}
}
/**
 * Sets the session's join distribution type to PARTITIONED. Even so, an INNER join whose right
 * side is wrapped in {@code enforceSingleRow} is expected to come out with REPLICATED
 * distribution.
 */
@Test
public void testReplicateScalar()
{
    assertDetermineJoinDistributionType()
            .on(planBuilder -> planBuilder.join(
                    INNER,
                    planBuilder.values(
                            ImmutableList.of(planBuilder.symbol("A1")),
                            ImmutableList.of(expressions("10"), expressions("11"))),
                    planBuilder.enforceSingleRow(
                            planBuilder.values(
                                    ImmutableList.of(planBuilder.symbol("B1")),
                                    ImmutableList.of(expressions("50"), expressions("11")))),
                    ImmutableList.of(new JoinNode.EquiJoinClause(
                            planBuilder.symbol("A1", BIGINT),
                            planBuilder.symbol("B1", BIGINT))),
                    ImmutableList.of(
                            planBuilder.symbol("A1", BIGINT),
                            planBuilder.symbol("B1", BIGINT)),
                    Optional.empty()))
            .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.PARTITIONED.name())
            .matches(join(
                    INNER,
                    ImmutableList.of(equiJoinClause("A1", "B1")),
                    Optional.empty(),
                    Optional.of(DistributionType.REPLICATED),
                    values(ImmutableMap.of("A1", 0)),
                    enforceSingleRow(values(ImmutableMap.of("B1", 0)))));
}
/**
 * Applies {@link PruneCountAggregationOverScalar} to a SINGLE-step global {@code count()} over a
 * one-row values node. The result is expected to be a values node exposing {@code count_1}.
 */
@Test
public void testFiresOnCountAggregateOverValues()
{
    tester().assertThat(new PruneCountAggregationOverScalar())
            .on(planBuilder -> planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
                    .addAggregation(
                            planBuilder.symbol("count_1", BigintType.BIGINT),
                            new FunctionCall(QualifiedName.of("count"), ImmutableList.of()),
                            ImmutableList.of(BigintType.BIGINT))
                    .step(AggregationNode.Step.SINGLE)
                    .globalGrouping()
                    .source(planBuilder.values(
                            ImmutableList.of(planBuilder.symbol("orderkey")),
                            ImmutableList.of(planBuilder.expressions("1"))))))
            .matches(values(ImmutableMap.of("count_1", 0)));
}
/**
 * Runs {@code zeroRewriter}'s values-expression rewrite on a values node whose only row holds
 * the literal {@code 0}. The rule is expected not to fire (presumably the rewriter leaves this
 * expression unchanged; confirm against the {@code zeroRewriter} definition).
 */
@Test
public void testValueExpressionNotRewritten()
{
    tester().assertThat(zeroRewriter.valuesExpressionRewrite())
            .on(planBuilder -> planBuilder.values(
                    ImmutableList.<Symbol>of(planBuilder.symbol("a")),
                    ImmutableList.of(ImmutableList.of(PlanBuilder.expression("0")))))
            .doesNotFire();
}
/**
 * Expects {@link RemoveTrivialFilters} not to fire on a filter with predicate {@code 1 = 1} over
 * an empty values node. Presumably the rule only recognises literal boolean predicates, not
 * expressions that merely evaluate to true; confirm against the rule.
 */
@Test
public void testDoesNotFire()
{
    RemoveTrivialFilters ruleUnderTest = new RemoveTrivialFilters();

    tester().assertThat(ruleUnderTest)
            .on(planBuilder -> planBuilder.filter(planBuilder.expression("1 = 1"), planBuilder.values()))
            .doesNotFire();
}
/**
 * Checks that {@code validatePlan} accepts two SINGLE-step aggregations:
 * <ol>
 *   <li>one grouped on {@code nationkey} over the nation table scan, with no pre-grouped symbols;</li>
 *   <li>one grouped on {@code (unique, nationkey)} and declaring both as pre-grouped, over
 *       {@code assignUniqueId(unique)} applied to the nation table scan.</li>
 * </ol>
 */
@Test
public void testValidateSuccessful()
{
    validatePlan(
            planBuilder -> planBuilder.aggregation(
                    aggregation -> aggregation.step(SINGLE)
                            .singleGroupingSet(planBuilder.symbol("nationkey"))
                            .source(
                                    planBuilder.tableScan(
                                            nationTableHandle,
                                            ImmutableList.of(planBuilder.symbol("nationkey", BIGINT)),
                                            ImmutableMap.of(
                                                    planBuilder.symbol("nationkey", BIGINT),
                                                    new TpchColumnHandle("nationkey", BIGINT)),
                                            Optional.of(nationTableLayoutHandle)))));

    validatePlan(
            planBuilder -> planBuilder.aggregation(
                    aggregation -> aggregation.step(SINGLE)
                            .singleGroupingSet(planBuilder.symbol("unique"), planBuilder.symbol("nationkey"))
                            .preGroupedSymbols(planBuilder.symbol("unique"), planBuilder.symbol("nationkey"))
                            .source(
                                    planBuilder.assignUniqueId(
                                            planBuilder.symbol("unique"),
                                            planBuilder.tableScan(
                                                    nationTableHandle,
                                                    ImmutableList.of(planBuilder.symbol("nationkey", BIGINT)),
                                                    ImmutableMap.of(
                                                            planBuilder.symbol("nationkey", BIGINT),
                                                            new TpchColumnHandle("nationkey", BIGINT)),
                                                    Optional.of(nationTableLayoutHandle))))));
}
/**
 * Builds a FINAL aggregation over {@code join(INNER, partialAggregation, values())}, where the
 * PARTIAL aggregation reads {@code tableScanNode}. Both use the same
 * {@code groupingSets(...)} configuration. No exchange separates the FINAL step from the
 * PARTIAL step, so {@code validatePlan(root, true)} is expected to throw
 * {@link IllegalArgumentException} with the message in the annotation.
 */
@Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Final aggregation with default value not separated from partial aggregation by local hash exchange")
public void testWithPartialAggregationBelowJoinWithoutSeparatingExchange()
{
    Symbol symbol = new Symbol("symbol");

    PlanNode partialAggregation = builder.aggregation(partial -> partial
            .step(PARTIAL)
            .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0)))
            .source(tableScanNode));
    PlanNode join = builder.join(INNER, partialAggregation, builder.values());
    PlanNode root = builder.aggregation(finalStep -> finalStep
            .step(FINAL)
            .groupingSets(groupingSets(ImmutableList.of(symbol), 2, ImmutableSet.of(0)))
            .source(join));

    validatePlan(root, true);
}
/**
 * Applies {@link PruneCountAggregationOverScalar} to a SINGLE-step global {@code count()} whose
 * source is {@code enforceSingleRow} over a table scan with no columns. The result is expected
 * to be a values node exposing {@code count_1}.
 */
@Test
public void testFiresOnCountAggregateOverEnforceSingleRow()
{
    tester().assertThat(new PruneCountAggregationOverScalar())
            .on(planBuilder -> planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
                    .addAggregation(
                            planBuilder.symbol("count_1", BigintType.BIGINT),
                            new FunctionCall(QualifiedName.of("count"), ImmutableList.of()),
                            ImmutableList.of(BigintType.BIGINT))
                    .step(AggregationNode.Step.SINGLE)
                    .globalGrouping()
                    .source(planBuilder.enforceSingleRow(
                            planBuilder.tableScan(ImmutableList.of(), ImmutableMap.of())))))
            .matches(values(ImmutableMap.of("count_1", 0)));
}
/**
 * Runs {@code functionCallRewriter}'s aggregation-expression rewrite on a global aggregation.
 * The aggregation computes {@code nowCall} into {@code count_1} (typed DATE) over an empty
 * values node. The rule is expected not to fire.
 */
@Test
public void testAggregationExpressionNotRewritten()
{
    tester().assertThat(functionCallRewriter.aggregationExpressionRewrite())
            .on(planBuilder -> planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
                    .globalGrouping()
                    .addAggregation(
                            planBuilder.symbol("count_1", DateType.DATE),
                            nowCall,
                            ImmutableList.of())
                    .source(planBuilder.values())))
            .doesNotFire();
}
/**
 * Builds a SINGLE-step aggregation that declares {@code nationkey} as pre-grouped while reading
 * directly from the nation table scan. {@code validatePlan} is expected to reject it with
 * {@link IllegalArgumentException} ("Streaming aggregation with input not grouped on the
 * grouping keys").
 */
@Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Streaming aggregation with input not grouped on the grouping keys")
public void testValidateFailed()
{
    validatePlan(
            planBuilder -> planBuilder.aggregation(
                    aggregation -> aggregation.step(SINGLE)
                            .singleGroupingSet(planBuilder.symbol("nationkey"))
                            .preGroupedSymbols(planBuilder.symbol("nationkey"))
                            .source(
                                    planBuilder.tableScan(
                                            nationTableHandle,
                                            ImmutableList.of(planBuilder.symbol("nationkey", BIGINT)),
                                            ImmutableMap.of(
                                                    planBuilder.symbol("nationkey", BIGINT),
                                                    new TpchColumnHandle("nationkey", BIGINT)),
                                            Optional.of(nationTableLayoutHandle)))));
}