/**
 * Registering a symbol as equal to itself ({@code a1 = a1}) must be rejected by the builder
 * with an {@link IllegalArgumentException}.
 */
@Test(expectedExceptions = IllegalArgumentException.class)
public void testInvalidEqualityExpression3()
{
    addEquality("a1", "a1", new EqualityInference.Builder());
}
/**
 * Registers every conjunct of {@code expression} that passes {@code isInferenceCandidate()} as an
 * equality. The conjuncts come from {@code extractConjuncts(expression)`}. Conjuncts that are not
 * inference candidates are ignored.
 *
 * @param expression the (possibly compound) predicate to mine for equalities
 * @return this builder, for chaining
 */
public Builder extractInferenceCandidates(Expression expression)
{
    Iterable<Expression> inferenceCandidates = filter(extractConjuncts(expression), isInferenceCandidate());
    return addAllEqualities(inferenceCandidates);
}
/**
 * Registers a single equality expression with this builder.
 * <p>
 * The expression is first passed through {@code normalizeInPredicateToEquality}. The normalized
 * result must satisfy {@code isInferenceCandidate()`}; it is then treated as a
 * {@link ComparisonExpression}, and its left and right operands are recorded as equal.
 *
 * @param expression the equality to add
 * @return this builder, for chaining
 * @throws IllegalArgumentException if the normalized expression is not an inference candidate
 */
public Builder addEquality(Expression expression)
{
    Expression normalized = normalizeInPredicateToEquality(expression);
    // Use the message-template overload so the (possibly large) expression tree is rendered to a
    // string only when the check fails, not on every successful call.
    checkArgument(isInferenceCandidate().apply(normalized), "Expression must be a simple equality: %s", normalized);
    ComparisonExpression comparison = (ComparisonExpression) normalized;
    addEquality(comparison.getLeft(), comparison.getRight());
    return this;
}
@Test public void testEqualityPartitionGeneration() EqualityInference.Builder builder = new EqualityInference.Builder(); builder.addEquality(nameReference("a1"), nameReference("b1")); builder.addEquality(add("a1", "a1"), multiply(nameReference("a1"), number(2))); builder.addEquality(nameReference("b1"), nameReference("c1")); builder.addEquality(add("a1", "a1"), nameReference("c1")); builder.addEquality(add("a1", "b1"), nameReference("c1")); EqualityInference inference = builder.build(); EqualityInference newInference = new EqualityInference.Builder() .addAllEqualities(equalityPartition.getScopeEqualities()) .addAllEqualities(equalityPartition.getScopeComplementEqualities()) .addAllEqualities(equalityPartition.getScopeStraddlingEqualities()) .build();
@Test public void testEqualityPartitionGeneration() EqualityInference.Builder builder = new EqualityInference.Builder(); builder.addEquality(nameReference("a1"), nameReference("b1")); builder.addEquality(add("a1", "a1"), multiply(nameReference("a1"), number(2))); builder.addEquality(nameReference("b1"), nameReference("c1")); builder.addEquality(add("a1", "a1"), nameReference("c1")); builder.addEquality(add("a1", "b1"), nameReference("c1")); EqualityInference inference = builder.build(); EqualityInference newInference = new EqualityInference.Builder() .addAllEqualities(equalityPartition.getScopeEqualities()) .addAllEqualities(equalityPartition.getScopeComplementEqualities()) .addAllEqualities(equalityPartition.getScopeStraddlingEqualities()) .build();
@Test public void testMultipleEqualitySetsPredicateGeneration() EqualityInference.Builder builder = new EqualityInference.Builder(); addEquality("a1", "b1", builder); addEquality("b1", "c1", builder); addEquality("c2", "d2", builder); EqualityInference inference = builder.build(); EqualityInference newInference = new EqualityInference.Builder() .addAllEqualities(equalityPartition.getScopeEqualities()) .addAllEqualities(equalityPartition.getScopeComplementEqualities()) .addAllEqualities(equalityPartition.getScopeStraddlingEqualities()) .build();
@Test public void testTransitivity() EqualityInference.Builder builder = new EqualityInference.Builder(); addEquality("a1", "b1", builder); addEquality("b1", "c1", builder); addEquality("c2", "d2", builder); EqualityInference inference = builder.build();
@Test public void testTransitivity() EqualityInference.Builder builder = new EqualityInference.Builder(); addEquality("a1", "b1", builder); addEquality("b1", "c1", builder); addEquality("c2", "d2", builder); EqualityInference inference = builder.build();
@Test public void testMultipleEqualitySetsPredicateGeneration() EqualityInference.Builder builder = new EqualityInference.Builder(); addEquality("a1", "b1", builder); addEquality("b1", "c1", builder); addEquality("c2", "d2", builder); EqualityInference inference = builder.build(); EqualityInference newInference = new EqualityInference.Builder() .addAllEqualities(equalityPartition.getScopeEqualities()) .addAllEqualities(equalityPartition.getScopeComplementEqualities()) .addAllEqualities(equalityPartition.getScopeStraddlingEqualities()) .build();
@Test public void testExpressionsThatMayReturnNullOnNonNullInput() { List<Expression> candidates = ImmutableList.of( new Cast(nameReference("b"), "BIGINT", true), // try_cast new FunctionCall(QualifiedName.of("try"), ImmutableList.of(nameReference("b"))), new NullIfExpression(nameReference("b"), number(1)), new IfExpression(nameReference("b"), number(1), new NullLiteral()), new DereferenceExpression(nameReference("b"), identifier("x")), new InPredicate(nameReference("b"), new InListExpression(ImmutableList.of(new NullLiteral()))), new SearchedCaseExpression(ImmutableList.of(new WhenClause(new IsNotNullPredicate(nameReference("b")), new NullLiteral())), Optional.empty()), new SimpleCaseExpression(nameReference("b"), ImmutableList.of(new WhenClause(number(1), new NullLiteral())), Optional.empty()), new SubscriptExpression(new ArrayConstructor(ImmutableList.of(new NullLiteral())), nameReference("b"))); for (Expression candidate : candidates) { EqualityInference.Builder builder = new EqualityInference.Builder(); builder.extractInferenceCandidates(equals(nameReference("b"), nameReference("x"))); builder.extractInferenceCandidates(equals(nameReference("a"), candidate)); EqualityInference inference = builder.build(); List<Expression> equalities = inference.generateEqualitiesPartitionedBy(matchesSymbols("b")).getScopeStraddlingEqualities(); assertEquals(equalities.size(), 1); assertTrue(equalities.get(0).equals(equals(nameReference("x"), nameReference("b"))) || equalities.get(0).equals(equals(nameReference("b"), nameReference("x")))); } }
@Test public void testExpressionsThatMayReturnNullOnNonNullInput() { List<Expression> candidates = ImmutableList.of( new Cast(nameReference("b"), "BIGINT", true), // try_cast new FunctionCall(QualifiedName.of("try"), ImmutableList.of(nameReference("b"))), new NullIfExpression(nameReference("b"), number(1)), new IfExpression(nameReference("b"), number(1), new NullLiteral()), new DereferenceExpression(nameReference("b"), identifier("x")), new InPredicate(nameReference("b"), new InListExpression(ImmutableList.of(new NullLiteral()))), new SearchedCaseExpression(ImmutableList.of(new WhenClause(new IsNotNullPredicate(nameReference("b")), new NullLiteral())), Optional.empty()), new SimpleCaseExpression(nameReference("b"), ImmutableList.of(new WhenClause(number(1), new NullLiteral())), Optional.empty()), new SubscriptExpression(new ArrayConstructor(ImmutableList.of(new NullLiteral())), nameReference("b"))); for (Expression candidate : candidates) { EqualityInference.Builder builder = new EqualityInference.Builder(); builder.extractInferenceCandidates(equals(nameReference("b"), nameReference("x"))); builder.extractInferenceCandidates(equals(nameReference("a"), candidate)); EqualityInference inference = builder.build(); List<Expression> equalities = inference.generateEqualitiesPartitionedBy(matchesSymbols("b")).getScopeStraddlingEqualities(); assertEquals(equalities.size(), 1); assertTrue(equalities.get(0).equals(equals(nameReference("x"), nameReference("b"))) || equalities.get(0).equals(equals(nameReference("b"), nameReference("x")))); } }
@Test public void testConstantEqualities() { EqualityInference.Builder builder = new EqualityInference.Builder(); addEquality("a1", "b1", builder); addEquality("b1", "c1", builder); builder.addEquality(nameReference("c1"), number(1)); EqualityInference inference = builder.build(); // Should always prefer a constant if available (constant is part of all scopes) assertEquals(inference.rewriteExpression(nameReference("a1"), matchesSymbols("a1", "b1")), number(1)); // All scope equalities should utilize the constant if possible EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("a1", "b1")); assertEquals(equalitiesAsSets(equalityPartition.getScopeEqualities()), set(set(nameReference("a1"), number(1)), set(nameReference("b1"), number(1)))); assertEquals(equalitiesAsSets(equalityPartition.getScopeComplementEqualities()), set(set(nameReference("c1"), number(1)))); // There should be no scope straddling equalities as the full set of equalities should be already represented by the scope and inverse scope assertTrue(equalityPartition.getScopeStraddlingEqualities().isEmpty()); }
@Test public void testConstantEqualities() { EqualityInference.Builder builder = new EqualityInference.Builder(); addEquality("a1", "b1", builder); addEquality("b1", "c1", builder); builder.addEquality(nameReference("c1"), number(1)); EqualityInference inference = builder.build(); // Should always prefer a constant if available (constant is part of all scopes) assertEquals(inference.rewriteExpression(nameReference("a1"), matchesSymbols("a1", "b1")), number(1)); // All scope equalities should utilize the constant if possible EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("a1", "b1")); assertEquals(equalitiesAsSets(equalityPartition.getScopeEqualities()), set(set(nameReference("a1"), number(1)), set(nameReference("b1"), number(1)))); assertEquals(equalitiesAsSets(equalityPartition.getScopeComplementEqualities()), set(set(nameReference("c1"), number(1)))); // There should be no scope straddling equalities as the full set of equalities should be already represented by the scope and inverse scope assertTrue(equalityPartition.getScopeStraddlingEqualities().isEmpty()); }
@Test public void testSubExpressionRewrites() { EqualityInference.Builder builder = new EqualityInference.Builder(); builder.addEquality(nameReference("a1"), add("b", "c")); // a1 = b + c builder.addEquality(nameReference("a2"), multiply(nameReference("b"), add("b", "c"))); // a2 = b * (b + c) builder.addEquality(nameReference("a3"), multiply(nameReference("a1"), add("b", "c"))); // a3 = a1 * (b + c) EqualityInference inference = builder.build(); // Expression (b + c) should get entirely rewritten as a1 assertEquals(inference.rewriteExpression(add("b", "c"), symbolBeginsWith("a")), nameReference("a1")); // Only the sub-expression (b + c) should get rewritten in terms of a* assertEquals(inference.rewriteExpression(multiply(nameReference("ax"), add("b", "c")), symbolBeginsWith("a")), multiply(nameReference("ax"), nameReference("a1"))); // To be compliant, could rewrite either the whole expression, or just the sub-expression. Rewriting larger expressions are preferred assertEquals(inference.rewriteExpression(multiply(nameReference("a1"), add("b", "c")), symbolBeginsWith("a")), nameReference("a3")); }
@Test public void testSubExpressionRewrites() { EqualityInference.Builder builder = new EqualityInference.Builder(); builder.addEquality(nameReference("a1"), add("b", "c")); // a1 = b + c builder.addEquality(nameReference("a2"), multiply(nameReference("b"), add("b", "c"))); // a2 = b * (b + c) builder.addEquality(nameReference("a3"), multiply(nameReference("a1"), add("b", "c"))); // a3 = a1 * (b + c) EqualityInference inference = builder.build(); // Expression (b + c) should get entirely rewritten as a1 assertEquals(inference.rewriteExpression(add("b", "c"), symbolBeginsWith("a")), nameReference("a1")); // Only the sub-expression (b + c) should get rewritten in terms of a* assertEquals(inference.rewriteExpression(multiply(nameReference("ax"), add("b", "c")), symbolBeginsWith("a")), multiply(nameReference("ax"), nameReference("a1"))); // To be compliant, could rewrite either the whole expression, or just the sub-expression. Rewriting larger expressions are preferred assertEquals(inference.rewriteExpression(multiply(nameReference("a1"), add("b", "c")), symbolBeginsWith("a")), nameReference("a3")); }
/**
 * Test helper: wraps both symbol names with {@code nameReference} and registers
 * {@code symbol1 = symbol2} with {@code builder}.
 */
private static void addEquality(String symbol1, String symbol2, EqualityInference.Builder builder)
{
    Expression left = nameReference(symbol1);
    Expression right = nameReference(symbol2);
    builder.addEquality(left, right);
}
/**
 * Test helper: wraps both symbol names with {@code nameReference} and registers
 * {@code symbol1 = symbol2} with {@code builder}.
 */
private static void addEquality(String symbol1, String symbol2, EqualityInference.Builder builder)
{
    Expression left = nameReference(symbol1);
    Expression right = nameReference(symbol2);
    builder.addEquality(left, right);
}
/**
 * Adds every conjunct of {@code expression} accepted by {@code isInferenceCandidate()} as an
 * equality. The conjuncts come from {@code extractConjuncts(expression)}. Non-candidate
 * conjuncts are skipped.
 *
 * @param expression the predicate whose conjuncts are scanned for equalities
 * @return this builder, for chaining
 */
public Builder extractInferenceCandidates(Expression expression)
{
    Iterable<Expression> equalityConjuncts = filter(extractConjuncts(expression), isInferenceCandidate());
    return addAllEqualities(equalityConjuncts);
}
/**
 * Adds each expression, in iteration order, via the single-argument {@code addEquality}.
 * Any exception thrown by {@code addEquality} for a rejected expression propagates to the caller
 * and stops iteration.
 *
 * @param expressions the equality expressions to add
 * @return this builder, for chaining
 */
public Builder addAllEqualities(Iterable<Expression> expressions)
{
    expressions.forEach(candidate -> addEquality(candidate));
    return this;
}
/**
 * Builds the {@link EqualityInference}.
 * <p>
 * First calls {@code generateMoreEquivalences()} so that derived equivalences are included.
 * Then constructs the result from {@code equalities.getEquivalentClasses()} and
 * {@code derivedExpressions}.
 *
 * @return a new {@link EqualityInference} over the accumulated equalities
 */
public EqualityInference build() { generateMoreEquivalences(); return new EqualityInference(equalities.getEquivalentClasses(), derivedExpressions); } } // NOTE(review): the final '}' closes the enclosing scope — presumably the Builder class; confirm