From 523b1c449e5cda7333c6afebaa43a4ff49693f30 Mon Sep 17 00:00:00 2001 From: Obed Kooijman Date: Mon, 9 Dec 2024 02:35:50 +0100 Subject: [PATCH 1/4] added NULLIF translations --- .../NpgsqlSqlTranslatingExpressionVisitor.cs | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs b/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs index 7e2fd3d47..bbbdc0afb 100644 --- a/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs @@ -77,6 +77,74 @@ public NpgsqlSqlTranslatingExpressionVisitor( _timestampTzMapping = _typeMappingSource.FindMapping("timestamp with time zone")!; } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitConditional(ConditionalExpression conditionalExpression) + { + var test = Visit(conditionalExpression.Test); + var ifTrue = Visit(conditionalExpression.IfTrue); + var ifFalse = Visit(conditionalExpression.IfFalse); + + if (TranslationFailed(conditionalExpression.Test, test, out var sqlTest) + || TranslationFailed(conditionalExpression.IfTrue, ifTrue, out var sqlIfTrue) + || TranslationFailed(conditionalExpression.IfFalse, ifFalse, out var sqlIfFalse)) + { + return QueryCompilationContext.NotTranslatedExpression; + } + + (SqlExpression, SqlExpression)? result = null; + + if (sqlTest is SqlBinaryExpression binary && sqlIfTrue is not null && sqlIfFalse is not null) + { + result = ConstructOperandsByOperator(binary, sqlIfTrue, sqlIfFalse); + } + + return result is not ({ } left, { } right) + ? _sqlExpressionFactory.Case([new CaseWhenClause(sqlTest!, sqlIfTrue!)], sqlIfFalse) + : _sqlExpressionFactory.Function("NULLIF", [left, right], true, [false, false], right.Type); + } + + private static (SqlExpression left, SqlExpression right)? ConstructOperandsByOperator( + SqlBinaryExpression binary, + SqlExpression ifTrue, + SqlExpression ifFalse) + { + (SqlExpression, SqlExpression)? result = null; + + if (binary.OperatorType is ExpressionType.Equal && ifTrue is SqlConstantExpression { Value: null }) + { + result = ConstructOperandsBySide(binary, ifFalse); + } + else if (binary.OperatorType is ExpressionType.NotEqual && ifFalse is SqlConstantExpression { Value: null }) + { + result = ConstructOperandsBySide(binary, ifTrue); + } + + return result; + } + + private static (SqlExpression left, SqlExpression right)? ConstructOperandsBySide( + SqlBinaryExpression expression, + SqlExpression sqlOnFalse) + { + (SqlExpression, SqlExpression)? operands = null; + if (expression.Left.Equals(sqlOnFalse)) + { + operands = (expression.Left, expression.Right); + } + + if (expression.Right.Equals(sqlOnFalse)) + { + operands = (expression.Right, expression.Left); + } + + return operands; + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in From e7ae8cba5585f720eb15630290644f29e6deb257 Mon Sep 17 00:00:00 2001 From: Obed Kooijman Date: Mon, 9 Dec 2024 02:39:46 +0100 Subject: [PATCH 2/4] condition adjustment --- .../Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs b/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs index bbbdc0afb..2c8f0715a 100644 --- a/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs @@ -136,8 +136,7 @@ private static (SqlExpression left, SqlExpression right)? ConstructOperandsBySid { operands = (expression.Left, expression.Right); } - - if (expression.Right.Equals(sqlOnFalse)) + else if (expression.Right.Equals(sqlOnFalse)) { operands = (expression.Right, expression.Left); } From 1bac34113ebb7d4585a701d5bf2b48ac16693831 Mon Sep 17 00:00:00 2001 From: Obed Kooijman Date: Mon, 9 Dec 2024 22:05:50 +0100 Subject: [PATCH 3/4] added tests --- .../NorthwindFunctionsQueryNpgsqlTest.cs | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/test/EFCore.PG.FunctionalTests/Query/NorthwindFunctionsQueryNpgsqlTest.cs b/test/EFCore.PG.FunctionalTests/Query/NorthwindFunctionsQueryNpgsqlTest.cs index 43c78f8c2..c87a047ff 100644 --- a/test/EFCore.PG.FunctionalTests/Query/NorthwindFunctionsQueryNpgsqlTest.cs +++ b/test/EFCore.PG.FunctionalTests/Query/NorthwindFunctionsQueryNpgsqlTest.cs @@ -843,6 +843,70 @@ GROUP BY o."ProductID" #endregion Statistics + #region NullIf + + [Theory] + [MemberData(nameof(IsAsyncData))] + public async Task NullIf_with_equality_left_sided(bool async) + { + await AssertQuery( + async, + cs => cs.Set().Select(x => x.OrderID == 1 ? (int?)null : x.OrderID)); + + AssertSql( + """ +SELECT NULLIF(o."OrderID", 1) +FROM "Orders" AS o +"""); + } + + [Theory] + [MemberData(nameof(IsAsyncData))] + public async Task NullIf_with_equality_right_sided(bool async) + { + await AssertQuery( + async, + cs => cs.Set().Select(x => 1 == x.OrderID ? (int?)null : x.OrderID)); + + AssertSql( + """ +SELECT NULLIF(o."OrderID", 1) +FROM "Orders" AS o +"""); + } + + [Theory] + [MemberData(nameof(IsAsyncData))] + public async Task NullIf_with_inequality_left_sided(bool async) + { + await AssertQuery( + async, + cs => cs.Set().Select(x => x.OrderID != 1 ? x.OrderID : (int?)null)); + + AssertSql( + """ +SELECT NULLIF(o."OrderID", 1) +FROM "Orders" AS o +"""); + } + + [Theory] + [MemberData(nameof(IsAsyncData))] + public async Task NullIf_with_inequality_right_sided(bool async) + { + await AssertQuery( + async, + cs => cs.Set().Select(x => 1 != x.OrderID ? x.OrderID : (int?)null)); + + AssertSql( + """ +SELECT NULLIF(o."OrderID", 1) +FROM "Orders" AS o +"""); + } + + #endregion + #region Unsupported // PostgreSQL does not have strpos with starting position From 96b846ee8df146fe960afc48b2cd4ccf733b5a13 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Wed, 11 Dec 2024 23:16:42 +0100 Subject: [PATCH 4/4] Clean up and correct some errors --- .../NpgsqlSqlTranslatingExpressionVisitor.cs | 66 ++++++++----------- 1 file changed, 29 insertions(+), 37 deletions(-) diff --git a/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs b/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs index 2c8f0715a..0ef5b7e9b 100644 --- a/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs @@ -96,52 +96,44 @@ protected override Expression VisitConditional(ConditionalExpression conditional return QueryCompilationContext.NotTranslatedExpression; } - (SqlExpression, SqlExpression)? result = null; - + // Translate: + // a == b ? null : a -> NULLIF(a, b) + // a != b ? a : null -> NULLIF(a, b) if (sqlTest is SqlBinaryExpression binary && sqlIfTrue is not null && sqlIfFalse is not null) { - result = ConstructOperandsByOperator(binary, sqlIfTrue, sqlIfFalse); + switch (binary.OperatorType) + { + case ExpressionType.Equal + when ifTrue is SqlConstantExpression { Value: null } && TryTranslateToNullIf(sqlIfFalse, out var nullIfTranslation): + case ExpressionType.NotEqual + when ifFalse is SqlConstantExpression { Value: null } && TryTranslateToNullIf(sqlIfTrue, out nullIfTranslation): + return nullIfTranslation; + } } - return result is not ({ } left, { } right) - ? _sqlExpressionFactory.Case([new CaseWhenClause(sqlTest!, sqlIfTrue!)], sqlIfFalse) - : _sqlExpressionFactory.Function("NULLIF", [left, right], true, [false, false], right.Type); - } - - private static (SqlExpression left, SqlExpression right)? ConstructOperandsByOperator( - SqlBinaryExpression binary, - SqlExpression ifTrue, - SqlExpression ifFalse) - { - (SqlExpression, SqlExpression)? result = null; + return _sqlExpressionFactory.Case([new CaseWhenClause(sqlTest!, sqlIfTrue!)], sqlIfFalse); - if (binary.OperatorType is ExpressionType.Equal && ifTrue is SqlConstantExpression { Value: null }) + bool TryTranslateToNullIf(SqlExpression conditionalResult, [NotNullWhen(true)] out Expression? nullIfTranslation) { - result = ConstructOperandsBySide(binary, ifFalse); - } - else if (binary.OperatorType is ExpressionType.NotEqual && ifFalse is SqlConstantExpression { Value: null }) - { - result = ConstructOperandsBySide(binary, ifTrue); - } + var (left, right) = (binary.Left, binary.Right); - return result; - } + if (left.Equals(conditionalResult)) + { + nullIfTranslation = _sqlExpressionFactory.Function( + "NULLIF", [left, right], true, [false, false], left.Type, left.TypeMapping); + return true; + } - private static (SqlExpression left, SqlExpression right)? ConstructOperandsBySide( - SqlBinaryExpression expression, - SqlExpression sqlOnFalse) - { - (SqlExpression, SqlExpression)? operands = null; - if (expression.Left.Equals(sqlOnFalse)) - { - operands = (expression.Left, expression.Right); - } - else if (expression.Right.Equals(sqlOnFalse)) - { - operands = (expression.Right, expression.Left); - } + if (right.Equals(conditionalResult)) + { + nullIfTranslation = _sqlExpressionFactory.Function( + "NULLIF", [right, left], true, [false, false], right.Type, right.TypeMapping); + return true; + } - return operands; + nullIfTranslation = null; + return false; + } } ///