Skip to content

Commit 26bc5a8

Browse files
committed
Add support for vector bitwise ops (close #56)
1 parent d20958b commit 26bc5a8

10 files changed

Lines changed: 431 additions & 243 deletions

File tree

include/NZSL/Ast/ExpressionType.inl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -396,13 +396,15 @@ namespace nzsl::Ast
396396

397397
inline ExpressionType UnwrapExternalType(const ExpressionType& exprType)
398398
{
399-
if (IsStorageType(exprType))
400-
return std::get<StorageType>(exprType).containedType;
401-
else if (IsUniformType(exprType))
402-
return std::get<UniformType>(exprType).containedType;
403-
else if (IsArrayType(exprType))
399+
const ExpressionType& resolvedExprType = ResolveAlias(exprType);
400+
401+
if (IsStorageType(resolvedExprType))
402+
return std::get<StorageType>(resolvedExprType).containedType;
403+
else if (IsUniformType(resolvedExprType))
404+
return std::get<UniformType>(resolvedExprType).containedType;
405+
else if (IsArrayType(resolvedExprType))
404406
{
405-
const ArrayType& arrayType = std::get<ArrayType>(exprType);
407+
const ArrayType& arrayType = std::get<ArrayType>(resolvedExprType);
406408
if (arrayType.isWrapped)
407409
{
408410
ArrayType unwrappedArrayType;

include/NZSL/Ast/Transformations/TransformerContext.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ namespace nzsl::Ast
8585
{
8686
ExpressionType type;
8787
};
88-
88+
8989
TransformerContext();
9090

9191
void Reset();

include/NZSL/Ast/Utils.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ namespace nzsl::Ast
7878
NZSL_API StatementPtr Unscope(StatementPtr&& statement);
7979

8080
NZSL_API ExpressionType ValidateBinaryOp(BinaryType op, const ExpressionType& leftExprType, const ExpressionType& rightExprType, const SourceLocation& sourceLocation, const Stringifier& typeStringifier = {});
81+
NZSL_API void ValidateUnaryOp(UnaryType op, const ExpressionType& exprType, const SourceLocation& sourceLocation, const Stringifier& typeStringifier = {});
8182

8283
NZSL_API bool ValidateMatchingTypes(const ExpressionPtr& left, const ExpressionPtr& right);
8384
NZSL_API bool ValidateMatchingTypes(const ExpressionType& left, const ExpressionType& right);

src/NZSL/Ast/Transformations/ResolveTransformer.cpp

Lines changed: 2 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2038,7 +2038,7 @@ namespace nzsl::Ast
20382038
if (!rightExprType)
20392039
return DontVisitChildren{};
20402040

2041-
binaryExpression.cachedExpressionType = ValidateBinaryOp(binaryExpression.op, ResolveAlias(*leftExprType), ResolveAlias(*rightExprType), binaryExpression.sourceLocation, BuildStringifier(binaryExpression.sourceLocation));
2041+
binaryExpression.cachedExpressionType = ValidateBinaryOp(binaryExpression.op, *leftExprType, *rightExprType, binaryExpression.sourceLocation, BuildStringifier(binaryExpression.sourceLocation));
20422042
return DontVisitChildren{};
20432043
}
20442044

@@ -2381,55 +2381,6 @@ namespace nzsl::Ast
23812381
if (!exprType)
23822382
return DontVisitChildren{};
23832383

2384-
const ExpressionType& resolvedExprType = ResolveAlias(*exprType);
2385-
2386-
switch (node.op)
2387-
{
2388-
case UnaryType::BitwiseNot:
2389-
{
2390-
if (resolvedExprType != ExpressionType(PrimitiveType::Int32) && resolvedExprType != ExpressionType(PrimitiveType::UInt32))
2391-
throw CompilerUnaryUnsupportedError{ node.sourceLocation, ToString(*exprType, node.sourceLocation) };
2392-
2393-
break;
2394-
}
2395-
2396-
case UnaryType::LogicalNot:
2397-
{
2398-
if (resolvedExprType != ExpressionType(PrimitiveType::Boolean))
2399-
throw CompilerUnaryUnsupportedError{ node.sourceLocation, ToString(*exprType, node.sourceLocation) };
2400-
2401-
break;
2402-
}
2403-
2404-
case UnaryType::Minus:
2405-
case UnaryType::Plus:
2406-
{
2407-
PrimitiveType basicType;
2408-
if (IsPrimitiveType(resolvedExprType))
2409-
basicType = std::get<PrimitiveType>(resolvedExprType);
2410-
else if (IsVectorType(resolvedExprType))
2411-
basicType = std::get<VectorType>(resolvedExprType).type;
2412-
else
2413-
throw CompilerUnaryUnsupportedError{ node.sourceLocation, ToString(*exprType, node.sourceLocation) };
2414-
2415-
switch (basicType)
2416-
{
2417-
case PrimitiveType::Float32:
2418-
case PrimitiveType::Float64:
2419-
case PrimitiveType::Int32:
2420-
case PrimitiveType::UInt32:
2421-
case PrimitiveType::FloatLiteral:
2422-
case PrimitiveType::IntLiteral:
2423-
break;
2424-
2425-
default:
2426-
throw CompilerUnaryUnsupportedError{ node.sourceLocation, ToString(*exprType, node.sourceLocation) };
2427-
}
2428-
2429-
break;
2430-
}
2431-
}
2432-
24332384
node.cachedExpressionType = *exprType;
24342385
return DontVisitChildren{};
24352386
}
@@ -2847,7 +2798,7 @@ namespace nzsl::Ast
28472798
{
28482799
ConstantValue& optionValue = optionValueIt->second;
28492800
EnsureLiteralValue(optType, optionValue, declOption.sourceLocation);
2850-
2801+
28512802
declOption.optIndex = RegisterConstant(declOption.optName, TransformerContext::ConstantData{ m_states->currentModuleId, optionValue }, declOption.optIndex, declOption.sourceLocation);
28522803
}
28532804
else

src/NZSL/Ast/Transformations/ValidationTransformer.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ namespace nzsl::Ast
507507
{
508508
case AssignType::Simple:
509509
{
510-
if (!ValidateMatchingTypes(*leftExprType, UnwrapExternalType(ResolveAlias(*rightExprType))))
510+
if (!ValidateMatchingTypes(*leftExprType, UnwrapExternalType(*rightExprType)))
511511
throw CompilerUnmatchingTypesError{ node.sourceLocation, ToString(*leftExprType, node.sourceLocation), ToString(*rightExprType, node.sourceLocation) };
512512

513513
break;
@@ -524,7 +524,7 @@ namespace nzsl::Ast
524524

525525
if (binaryType)
526526
{
527-
ExpressionType expressionType = ValidateBinaryOp(*binaryType, ResolveAlias(*leftExprType), UnwrapExternalType(ResolveAlias(*rightExprType)), node.sourceLocation);
527+
ExpressionType expressionType = ValidateBinaryOp(*binaryType, *leftExprType, UnwrapExternalType(*rightExprType), node.sourceLocation);
528528
if (!ValidateMatchingTypes(UnwrapExternalType(*leftExprType), expressionType))
529529
throw CompilerUnmatchingTypesError{ node.sourceLocation, ToString(*leftExprType, node.sourceLocation), ToString(expressionType, node.sourceLocation) };
530530
}
@@ -1009,6 +1009,11 @@ namespace nzsl::Ast
10091009
{
10101010
HandleChildren(node);
10111011

1012+
const ExpressionType* exprType = GetExpressionType(MandatoryExpr(node.expression, node.sourceLocation));
1013+
if (!exprType)
1014+
return DontVisitChildren{};
1015+
1016+
ValidateUnaryOp(node.op, *exprType, node.sourceLocation, BuildStringifier(node.sourceLocation));
10121017
return DontVisitChildren{};
10131018
}
10141019

0 commit comments

Comments
 (0)