/**
 * Type-checks a quantified comparison between a value expression and a subquery.
 *
 * <p>Both operands are processed, then coerced to a single type through
 * {@code coerceToSingleType}. The resulting type must be orderable for the
 * inequality operators and comparable for {@code EQUAL}/{@code NOT_EQUAL}.
 *
 * @param node the quantified comparison being analyzed
 * @param context the visitor context passed on to every nested {@code process} call
 * @return the result of {@code setExpressionType(node, BOOLEAN)}
 * @throws SemanticException with {@code TYPE_MISMATCH} when the unified type does not
 *         support the requested operator
 * @throws IllegalStateException when the operator is none of the six handled cases
 */
@Override
protected Type visitQuantifiedComparisonExpression(QuantifiedComparisonExpression node, StackableAstVisitorContext<Context> context)
{
    Expression left = node.getValue();
    process(left, context);

    Expression right = node.getSubquery();
    process(right, context);

    Type commonType = coerceToSingleType(
            context,
            node,
            "Value expression and result of subquery must be of the same type for quantified comparison: %s vs %s",
            left,
            right);

    switch (node.getOperator()) {
        case EQUAL:
        case NOT_EQUAL:
            // Equality only needs the type to be comparable.
            if (!commonType.isComparable()) {
                throw new SemanticException(TYPE_MISMATCH, node, "Type [%s] must be comparable in order to be used in quantified comparison", commonType);
            }
            break;
        case LESS_THAN:
        case LESS_THAN_OR_EQUAL:
        case GREATER_THAN:
        case GREATER_THAN_OR_EQUAL:
            // Inequalities additionally need an ordering.
            if (!commonType.isOrderable()) {
                throw new SemanticException(TYPE_MISMATCH, node, "Type [%s] must be orderable in order to be used in quantified comparison", commonType);
            }
            break;
        default:
            throw new IllegalStateException(format("Unexpected comparison type: %s", node.getOperator()));
    }

    return setExpressionType(node, BOOLEAN);
}
process(expression, context); Type type = getExpressionType(expression); if (!type.isComparable()) { process(sortItem.getSortKey(), context); Type type = getExpressionType(sortItem.getSortKey()); if (!type.isOrderable()) { Type type = process(frame.getStart().getValue().get(), context); if (!type.equals(INTEGER) && !type.equals(BIGINT)) { throw new SemanticException(TYPE_MISMATCH, node, "Window frame start value type must be INTEGER or BIGINT(actual %s)", type); Type type = process(frame.getEnd().get().getValue().get(), context); if (!type.equals(INTEGER) && !type.equals(BIGINT)) { throw new SemanticException(TYPE_MISMATCH, node, "Window frame end value type must be INTEGER or BIGINT (actual %s)", type); process(expression, context); argumentTypesBuilder.add(new TypeSignatureProvider(process(expression, context).getTypeSignature())); Type sortKeyType = process(sortItem.getSortKey(), context); if (!sortKeyType.isOrderable()) { throw new SemanticException(TYPE_MISMATCH, node, "ORDER BY can only be applied to orderable types (actual: %s)", sortKeyType.getDisplayName()); process(expression, new StackableAstVisitorContext<>(context.getContext().expectingLambda(expectedFunctionType.getArgumentTypes()))); coerceType(expression, actualType, expectedType, format("Function %s argument %d", function, i));
@Override protected Type visitArithmeticUnary(ArithmeticUnaryExpression node, StackableAstVisitorContext<Context> context) { switch (node.getSign()) { case PLUS: Type type = process(node.getValue(), context); if (!type.equals(DOUBLE) && !type.equals(REAL) && !type.equals(BIGINT) && !type.equals(INTEGER) && !type.equals(SMALLINT) && !type.equals(TINYINT)) { // TODO: figure out a type-agnostic way of dealing with this. Maybe add a special unary operator // that types can chose to implement, or piggyback on the existence of the negation operator throw new SemanticException(TYPE_MISMATCH, node, "Unary '+' operator cannot by applied to %s type", type); } return setExpressionType(node, type); case MINUS: return getOperator(context, node, OperatorType.NEGATION, node.getValue()); } throw new UnsupportedOperationException("Unsupported unary operator: " + node.getSign()); }
process(expression, context); Type type = expressionTypes.get(expression); if (!type.isComparable()) { process(sortItem.getSortKey(), context); Type type = expressionTypes.get(sortItem.getSortKey()); if (!type.isOrderable()) { Type type = process(frame.getStart().getValue().get(), context); if (!type.equals(BIGINT)) { throw new SemanticException(TYPE_MISMATCH, node, "Window frame start value type must be BIGINT (actual %s)", type); Type type = process(frame.getEnd().get().getValue().get(), context); if (!type.equals(BIGINT)) { throw new SemanticException(TYPE_MISMATCH, node, "Window frame end value type must be BIGINT (actual %s)", type); argumentTypes.add(process(expression, context).getTypeSignature()); throw new SemanticException(TYPE_MISMATCH, node, "DISTINCT can only be applied to comparable types (actual: %s)", type); coerceType(context, expression, type, format("Function %s argument %d", function, i));
assertColumnPrefix(qualifiedName, node); Type baseType = process(node.getBase(), context); if (!(baseType instanceof RowType)) { throw new SemanticException(TYPE_MISMATCH, node.getBase(), "Expression %s is not of type ROW", node.getBase()); throw createMissingAttributeException(node);
private Type coerceToSingleType(StackableAstVisitorContext<AnalysisContext> context, Node node, String message, Expression first, Expression second) { Type firstType = null; if (first != null) { firstType = process(first, context); } Type secondType = null; if (second != null) { secondType = process(second, context); } if (firstType == null) { return secondType; } if (secondType == null) { return firstType; } if (firstType.equals(secondType)) { return firstType; } // coerce types if possible if (canCoerce(firstType, secondType)) { expressionCoercions.put(first, secondType); return secondType; } if (canCoerce(secondType, firstType)) { expressionCoercions.put(second, firstType); return firstType; } throw new SemanticException(TYPE_MISMATCH, node, message, firstType, secondType); }
private Type coerceToSingleType(StackableAstVisitorContext<AnalysisContext> context, String message, List<Expression> expressions) { // determine super type Type superType = UNKNOWN; for (Expression expression : expressions) { Optional<Type> newSuperType = typeManager.getCommonSuperType(superType, process(expression, context)); if (!newSuperType.isPresent()) { throw new SemanticException(TYPE_MISMATCH, expression, message, superType); } superType = newSuperType.get(); } // verify all expressions can be coerced to the superType for (Expression expression : expressions) { Type type = process(expression, context); if (!type.equals(superType)) { if (!canCoerce(type, superType)) { throw new SemanticException(TYPE_MISMATCH, expression, message, superType); } expressionCoercions.put(expression, superType); } } return superType; } }
/**
 * Determines the type of a CAST expression and checks that the cast is possible.
 *
 * <p>The target type is looked up through {@code typeManager}; it must exist and must
 * not be {@code UNKNOWN}. Unless the operand's type is {@code UNKNOWN} or the cast is
 * type-only, {@code functionRegistry.getCoercion} is consulted to confirm a coercion
 * from the operand type to the target type exists.
 *
 * @param node the cast expression
 * @param context the visitor context used to process the operand
 * @return the target type, which is also recorded in {@code expressionTypes}
 * @throws SemanticException with {@code TYPE_MISMATCH} when the target type is unknown
 *         to the type manager, is {@code UNKNOWN}, or no coercion exists
 */
@Override
public Type visitCast(Cast node, StackableAstVisitorContext<AnalysisContext> context)
{
    Type type = typeManager.getType(parseTypeSignature(node.getType()));
    if (type == null) {
        // Pass the type text as a format argument rather than splicing it into the
        // format string, so a '%' in the text cannot be read as a format specifier.
        throw new SemanticException(TYPE_MISMATCH, node, "Unknown type: %s", node.getType());
    }

    if (type.equals(UNKNOWN)) {
        throw new SemanticException(TYPE_MISMATCH, node, "UNKNOWN is not a valid type");
    }

    Type value = process(node.getExpression(), context);
    if (!value.equals(UNKNOWN) && !node.isTypeOnly()) {
        try {
            // Only the existence of the coercion matters; the result is discarded.
            functionRegistry.getCoercion(value, type);
        }
        catch (OperatorNotFoundException e) {
            throw new SemanticException(TYPE_MISMATCH, node, "Cannot cast %s to %s", value, type);
        }
    }

    expressionTypes.put(node, type);
    return type;
}
/**
 * Resolves an operator for the given arguments and returns its result type.
 *
 * <p>Each argument is processed to obtain its type, the operator signature is resolved
 * through {@code functionRegistry.resolveOperator}, and every argument is then coerced
 * to the corresponding parameter type of that signature.
 *
 * @param context the visitor context used to process the arguments
 * @param node the expression the operator belongs to; its type is recorded in
 *        {@code expressionTypes}
 * @param operatorType the operator to resolve
 * @param arguments the operand expressions, in signature order
 * @return the operator's return type
 * @throws SemanticException with {@code TYPE_MISMATCH} when no matching operator exists
 */
private Type getOperator(StackableAstVisitorContext<AnalysisContext> context, Expression node, OperatorType operatorType, Expression... arguments)
{
    ImmutableList.Builder<Type> operandTypes = ImmutableList.builder();
    for (Expression argument : arguments) {
        operandTypes.add(process(argument, context));
    }
    ImmutableList<Type> resolvedOperandTypes = operandTypes.build();

    Signature signature;
    try {
        signature = functionRegistry.resolveOperator(operatorType, resolvedOperandTypes);
    }
    catch (OperatorNotFoundException e) {
        throw new SemanticException(TYPE_MISMATCH, node, "%s", e.getMessage());
    }

    // Coerce each operand to the parameter type the resolved signature expects.
    int position = 0;
    for (Expression argument : arguments) {
        Type parameterType = typeManager.getType(signature.getArgumentTypes().get(position));
        coerceType(context, argument, parameterType, format("Operator %s argument %d", signature, position));
        position++;
    }

    Type returnType = typeManager.getType(signature.getReturnType());
    expressionTypes.put(node, returnType);
    return returnType;
}
/**
 * Type-checks an IN predicate and records its type as BOOLEAN.
 *
 * <p>For an explicit value list, the tested value and all list items are unified to a
 * single type. For a subquery, the tested value and the subquery are unified pairwise.
 * Any other kind of value list is processed but not coerced.
 *
 * @param node the IN predicate
 * @param context the visitor context used to process the operands
 * @return {@code BOOLEAN}
 */
@Override
protected Type visitInPredicate(InPredicate node, StackableAstVisitorContext<AnalysisContext> context)
{
    Expression needle = node.getValue();
    process(needle, context);

    Expression haystack = node.getValueList();
    process(haystack, context);

    if (haystack instanceof InListExpression) {
        InListExpression items = (InListExpression) haystack;
        ImmutableList.Builder<Expression> operands = ImmutableList.<Expression>builder();
        operands.add(needle);
        operands.addAll(items.getValues());
        coerceToSingleType(context, "IN value and list items must be the same type: %s", operands.build());
    }
    else if (haystack instanceof SubqueryExpression) {
        coerceToSingleType(context, node, "value and result of subquery must be of the same type for IN expression: %s vs %s", needle, haystack);
    }

    expressionTypes.put(node, BOOLEAN);
    return BOOLEAN;
}
/**
 * Determines the type of an AT TIME ZONE expression.
 *
 * <p>The value must be TIME or TIMESTAMP, with or without time zone. TIME maps to
 * TIME WITH TIME ZONE and TIMESTAMP maps to TIMESTAMP WITH TIME ZONE; values that
 * already carry a time zone keep their type.
 *
 * @param node the AT TIME ZONE expression
 * @param context the visitor context used to process the value and the time zone
 * @return the resulting type, which is also recorded in {@code expressionTypes}
 * @throws SemanticException with {@code TYPE_MISMATCH}, reported on the value, when the
 *         value has any other type
 */
@Override
protected Type visitAtTimeZone(AtTimeZone node, StackableAstVisitorContext<AnalysisContext> context)
{
    Type inputType = process(node.getValue(), context);
    process(node.getTimeZone(), context);

    boolean temporal = inputType.equals(TIME_WITH_TIME_ZONE)
            || inputType.equals(TIMESTAMP_WITH_TIME_ZONE)
            || inputType.equals(TIME)
            || inputType.equals(TIMESTAMP);
    if (!temporal) {
        throw new SemanticException(TYPE_MISMATCH, node.getValue(), "Type of value must be a time or timestamp with or without time zone (actual %s)", inputType);
    }

    Type outputType;
    if (inputType.equals(TIME)) {
        outputType = TIME_WITH_TIME_ZONE;
    }
    else if (inputType.equals(TIMESTAMP)) {
        outputType = TIMESTAMP_WITH_TIME_ZONE;
    }
    else {
        // Already a zoned type: the result type is unchanged.
        outputType = inputType;
    }

    expressionTypes.put(node, outputType);
    return outputType;
}
@Override protected Type visitArithmeticUnary(ArithmeticUnaryExpression node, StackableAstVisitorContext<AnalysisContext> context) { switch (node.getSign()) { case PLUS: Type type = process(node.getValue(), context); if (!type.equals(BIGINT) && !type.equals(DOUBLE)) { // TODO: figure out a type-agnostic way of dealing with this. Maybe add a special unary operator // that types can chose to implement, or piggyback on the existence of the negation operator throw new SemanticException(TYPE_MISMATCH, node, "Unary '+' operator cannot by applied to %s type", type); } expressionTypes.put(node, type); return type; case MINUS: return getOperator(context, node, OperatorType.NEGATION, node.getValue()); } throw new UnsupportedOperationException("Unsupported unary operator: " + node.getSign()); }
/**
 * Type-checks a searched CASE expression (CASE WHEN cond THEN result ...).
 *
 * <p>Every WHEN operand is coerced to BOOLEAN, all result expressions (including the
 * default) are unified to a single type, and each WHEN clause is then recorded in
 * {@code expressionTypes} with the type of its own result expression.
 *
 * @param node the searched CASE expression
 * @param context the visitor context used to process nested expressions
 * @return the unified result type of the CASE expression
 */
@Override
protected Type visitSearchedCaseExpression(SearchedCaseExpression node, StackableAstVisitorContext<AnalysisContext> context)
{
    for (WhenClause clause : node.getWhenClauses()) {
        coerceType(context, clause.getOperand(), BOOLEAN, "CASE WHEN clause");
    }

    Type resultType = coerceToSingleType(
            context,
            "All CASE results must be the same type: %s",
            getCaseResultExpressions(node.getWhenClauses(), node.getDefaultValue()));
    expressionTypes.put(node, resultType);

    // Each WHEN clause node carries the type of its result expression.
    for (WhenClause clause : node.getWhenClauses()) {
        Type clauseType = process(clause.getResult(), context);
        requireNonNull(clauseType, format("Expression types does not contain an entry for %s", clause));
        expressionTypes.put(clause, clauseType);
    }

    return resultType;
}
/**
 * Type-checks a simple CASE expression (CASE operand WHEN value THEN result ...).
 *
 * <p>The CASE operand is unified pairwise with each WHEN operand, all result
 * expressions (including the default) are unified to a single type, and each WHEN
 * clause is then recorded in {@code expressionTypes} with the type of its own result
 * expression.
 *
 * @param node the simple CASE expression
 * @param context the visitor context used to process nested expressions
 * @return the unified result type of the CASE expression
 */
@Override
protected Type visitSimpleCaseExpression(SimpleCaseExpression node, StackableAstVisitorContext<AnalysisContext> context)
{
    for (WhenClause clause : node.getWhenClauses()) {
        coerceToSingleType(
                context,
                clause,
                "CASE operand type does not match WHEN clause operand type: %s vs %s",
                node.getOperand(),
                clause.getOperand());
    }

    Type resultType = coerceToSingleType(
            context,
            "All CASE results must be the same type: %s",
            getCaseResultExpressions(node.getWhenClauses(), node.getDefaultValue()));
    expressionTypes.put(node, resultType);

    // Each WHEN clause node carries the type of its result expression.
    for (WhenClause clause : node.getWhenClauses()) {
        Type clauseType = process(clause.getResult(), context);
        requireNonNull(clauseType, format("Expression types does not contain an entry for %s", clause));
        expressionTypes.put(clause, clauseType);
    }

    return resultType;
}
/**
 * Type-checks an EXTRACT expression and records its type as BIGINT.
 *
 * <p>The argument must satisfy {@code isDateTimeType}. When the extracted field is
 * TIMEZONE_HOUR or TIMEZONE_MINUTE, the argument must additionally be
 * TIME WITH TIME ZONE or TIMESTAMP WITH TIME ZONE.
 *
 * @param node the EXTRACT expression
 * @param context the visitor context used to process the argument
 * @return {@code BIGINT}
 * @throws SemanticException with {@code TYPE_MISMATCH}, reported on the argument, when
 *         either check fails
 */
@Override
protected Type visitExtract(Extract node, StackableAstVisitorContext<AnalysisContext> context)
{
    Type sourceType = process(node.getExpression(), context);
    if (!isDateTimeType(sourceType)) {
        throw new SemanticException(TYPE_MISMATCH, node.getExpression(), "Type of argument to extract must be DATE, TIME, TIMESTAMP, or INTERVAL (actual %s)", sourceType);
    }

    Extract.Field requested = node.getField();
    if (requested == TIMEZONE_HOUR || requested == TIMEZONE_MINUTE) {
        // Time zone fields only exist on the zoned types.
        if (!sourceType.equals(TIME_WITH_TIME_ZONE) && !sourceType.equals(TIMESTAMP_WITH_TIME_ZONE)) {
            throw new SemanticException(TYPE_MISMATCH, node.getExpression(), "Type of argument to extract time zone field must have a time zone (actual %s)", sourceType);
        }
    }

    expressionTypes.put(node, BIGINT);
    return BIGINT;
}
/**
 * Type-checks a LIKE predicate and records its type as BOOLEAN.
 *
 * <p>The value, the pattern and (when present) the escape expression are each coerced
 * to the type returned for them by {@code getVarcharType}.
 *
 * @param node the LIKE predicate
 * @param context the visitor context used to process the operands
 * @return {@code BOOLEAN}
 */
@Override
protected Type visitLikePredicate(LikePredicate node, StackableAstVisitorContext<AnalysisContext> context)
{
    // Both target types are computed before either coercion is applied.
    Type subjectTarget = getVarcharType(node.getValue(), context);
    Type patternTarget = getVarcharType(node.getPattern(), context);

    coerceType(context, node.getValue(), subjectTarget, "Left side of LIKE expression");
    coerceType(context, node.getPattern(), patternTarget, "Pattern for LIKE expression");

    if (node.getEscape() != null) {
        Type escapeTarget = getVarcharType(node.getEscape(), context);
        coerceType(context, node.getEscape(), escapeTarget, "Escape for LIKE expression");
    }

    expressionTypes.put(node, BOOLEAN);
    return BOOLEAN;
}
/**
 * Type-checks an IF expression.
 *
 * <p>The condition is coerced to BOOLEAN. When a false branch is present, both
 * branches are unified to a single type; otherwise the expression takes the type of
 * the true branch.
 *
 * @param node the IF expression
 * @param context the visitor context used to process the condition and branches
 * @return the result type, which is also recorded in {@code expressionTypes}
 */
@Override
protected Type visitIfExpression(IfExpression node, StackableAstVisitorContext<AnalysisContext> context)
{
    coerceType(context, node.getCondition(), BOOLEAN, "IF condition");

    Type resultType = node.getFalseValue().isPresent()
            ? coerceToSingleType(context, node, "Result types for IF must be the same: %s vs %s", node.getTrueValue(), node.getFalseValue().get())
            : process(node.getTrueValue(), context);

    expressionTypes.put(node, resultType);
    return resultType;
}
/**
 * Ensures an expression evaluates to the expected type, recording a coercion if needed.
 *
 * <p>Nothing is recorded when the expression already has the expected type. Otherwise
 * {@code canCoerce} decides whether a coercion to the expected type is allowed; if so
 * it is stored in {@code expressionCoercions}.
 *
 * @param context the visitor context used to process the expression
 * @param expression the expression to check
 * @param expectedType the type the expression must evaluate to
 * @param message already-rendered description of the expression's role, used as the
 *        prefix of the error message
 * @throws SemanticException with {@code TYPE_MISMATCH}, reported on the expression, when
 *         its type differs from the expected type and cannot be coerced to it
 */
private void coerceType(StackableAstVisitorContext<AnalysisContext> context, Expression expression, Type expectedType, String message)
{
    Type actualType = process(expression, context);
    if (actualType.equals(expectedType)) {
        return;
    }

    if (!canCoerce(actualType, expectedType)) {
        // The prefix is passed as a format argument instead of being concatenated into
        // the format string: callers build it from signatures and names, and a '%' in
        // that text must not be interpreted as a format specifier.
        throw new SemanticException(TYPE_MISMATCH, expression, "%s must evaluate to a %s (actual: %s)", message, expectedType, actualType);
    }

    expressionCoercions.put(expression, expectedType);
}
/**
 * Determines the type of a comparison expression by resolving its operator.
 *
 * <p>IS DISTINCT FROM is resolved with the EQUAL operator; every other comparison type
 * is mapped to the {@code OperatorType} constant of the same name.
 *
 * @param node the comparison expression
 * @param context the visitor context used to process both operands
 * @return the return type of the resolved operator
 */
@Override
protected Type visitComparisonExpression(ComparisonExpression node, StackableAstVisitorContext<AnalysisContext> context)
{
    boolean distinctFrom = node.getType() == ComparisonExpression.Type.IS_DISTINCT_FROM;
    OperatorType resolved = distinctFrom
            ? OperatorType.EQUAL
            : OperatorType.valueOf(node.getType().name());

    return getOperator(context, node, resolved, node.getLeft(), node.getRight());
}
/**
 * Type-checks a logical NOT expression and records its type as BOOLEAN.
 *
 * @param node the NOT expression
 * @param context the visitor context used to process the operand
 * @return {@code BOOLEAN}
 */
@Override
protected Type visitNotExpression(NotExpression node, StackableAstVisitorContext<AnalysisContext> context)
{
    Expression operand = node.getValue();
    // The operand must be, or be coercible to, BOOLEAN.
    coerceType(context, operand, BOOLEAN, "Value of logical NOT expression");

    Type resultType = BOOLEAN;
    expressionTypes.put(node, resultType);
    return resultType;
}