diff --git a/docs/function-types.md b/docs/function-types.md new file mode 100644 index 000000000..7926035a3 --- /dev/null +++ b/docs/function-types.md @@ -0,0 +1,73 @@ +# Function Types + +BrighterScript allows specific definitions of functions that are passed as parameters. This enables code with callbacks to fully type-checked, and for any return types of callbacks to be propagated through the code. + +## Syntax + +Function types support the same syntax as anonymous function definitions: parameter names, optional parameters, and parameter and return types. + +```brighterscript +sub useCallBack(callback as function(name as string, num as integer) as string) + print "Result is: " + callback("hello", 7) +end sub +``` + +transpiles to + +```BrightScript +sub useCallBack(callback as function) + print "Result is: " + callback("hello", 7) +end sub +``` + +## Use with Type Statements + +Including full function signatures inside another function signature can be a bit verbose. Function types can be used in type statements to make a shorthand for easier reading. + +```brighterscript +type TextChangeHandler = function(oldName as string, newName as string) as boolean + +sub checkText(callback as TextChangeHandler) + if callback(m.oldName, m.newName) + print "Text Change OK" + end if +end sub +``` + +transpiles to + +```BrightScript +sub checkText(callback as function) + if callback(m.oldName, m.newName) + print "Text Change OK" + end if +end sub +``` + +## Validation + +Both the function type itself and the arguments when it is called and return values are validated. + +Validating the function type: + +```brighterscript +sub useCallback(callback as sub(input as string)) + callback("hello") +end sub + +sub testCallback() + ' This is a validation error: "sub(input as integer)" is NOT compatible with "sub(input as string)" + useCallback(sub(input as integer) + print input + 1 + end sub) +end sub +``` + +Validating the function call: + +```brighterscript +sub useCallback(callback as sub(input as string)) + ' This is a validation error: "integer" is NOT compatible with "string" + callback(123) +end sub +``` diff --git a/docs/readme.md b/docs/readme.md index 267c7e733..d07e6cec4 100644 --- a/docs/readme.md +++ b/docs/readme.md @@ -79,6 +79,14 @@ catch ' look, no exception variable! end try ``` +## [Function Types](function-types.md) + +```brighterscript +function useCallback(callback as function(input as string) as integer) as integer + return callback("test") +end function +``` + ## [Imports](imports.md) ```brighterscript diff --git a/src/DiagnosticMessages.ts b/src/DiagnosticMessages.ts index d394f93b2..e96d194dd 100644 --- a/src/DiagnosticMessages.ts +++ b/src/DiagnosticMessages.ts @@ -1084,6 +1084,9 @@ export const defaultMaximumTruncationLength = 160; export function typeCompatibilityMessage(actualTypeString: string, expectedTypeString: string, data: TypeCompatibilityData) { let message = ''; + if (!data) { + return message; + } actualTypeString = data?.actualType?.toString() ?? actualTypeString; expectedTypeString = data?.expectedType?.toString() ?? expectedTypeString; @@ -1104,14 +1107,16 @@ export function typeCompatibilityMessage(actualTypeString: string, expectedTypeS partBuilder: (x) => `\n member "${x.name}" should be '${x.expectedType}' but is '${x.actualType}'`, maxLength: defaultMaximumTruncationLength }); + } else if (data?.actualParamCount !== data?.expectedParamCount) { + message = `. Type '${expectedTypeString}' requires ${data.expectedParamCount} parameter${data.expectedParamCount === 1 ? '' : 's'} but '${actualTypeString}' has ${data.actualParamCount}`; } else if (data?.parameterMismatches?.length > 0) { message = '. ' + util.truncate({ leadingText: `Type '${actualTypeString}' has incompatible parameters:`, items: data.parameterMismatches, itemSeparator: '', partBuilder: (x) => { - let pExpected = x.data?.expectedType.toString() ?? 'dynamic'; - let pActual = x.data?.actualType.toString() ?? 'dynamic'; + let pExpected = x.data?.expectedType?.toString() ?? 'dynamic'; + let pActual = x.data?.actualType?.toString() ?? 'dynamic'; if (x.expectedOptional !== x.actualOptional) { pExpected += x.expectedOptional ? '?' : ''; diff --git a/src/Scope.spec.ts b/src/Scope.spec.ts index 74e33a3e7..6ef95c266 100644 --- a/src/Scope.spec.ts +++ b/src/Scope.spec.ts @@ -27,7 +27,7 @@ import { InterfaceType } from './types/InterfaceType'; import { ComponentType } from './types/ComponentType'; import { WalkMode, createVisitor } from './astUtils/visitors'; import type { BinaryExpression, CallExpression, DottedGetExpression, FunctionExpression } from './parser/Expression'; -import { ObjectType, UninitializedType } from './types'; +import { ObjectType, TypedFunctionType, UninitializedType } from './types'; import undent from 'undent'; import * as fsExtra from 'fs-extra'; import { InlineInterfaceType } from './types/InlineInterfaceType'; @@ -4617,6 +4617,159 @@ describe('Scope', () => { expectTypeToBe(memberType, StringType); }); }); + + describe('typed function expressions', () => { + it('correctly types parameters with typed function expressions', () => { + const file = program.setFile('source/test.bs', ` + sub testFunc(callback as function(s as string) as integer) + value = callback("hello") + print value + end sub + `); + program.validate(); + expectZeroDiagnostics(program); + const funcExprBody = file.parser.ast.findChildren(isFunctionExpression, { walkMode: WalkMode.visitAllRecursive })[0].body; + const table = funcExprBody.getSymbolTable(); + const valueSymbol = table.getSymbol('value', SymbolTypeFlag.runtime)[0]; + expectTypeToBe(valueSymbol.type, IntegerType); + }); + + it('correctly types parameters with typed function expressions with multiple parameters', () => { + const file = program.setFile('source/test.bs', ` + sub testFunc(callback as function(s as string, i as integer) as integer) + value = callback("hello", 123) + print value + end sub + `); + program.validate(); + expectZeroDiagnostics(program); + const funcExprBody = file.parser.ast.findChildren(isFunctionExpression, { walkMode: WalkMode.visitAllRecursive })[0].body; + const table = funcExprBody.getSymbolTable(); + const callbackSymbol = table.getSymbol('callback', SymbolTypeFlag.runtime)[0]; + expectTypeToBe(callbackSymbol.type, TypedFunctionType); + const callbackType = callbackSymbol.type as TypedFunctionType; + expect(callbackType.params.length).to.equal(2); + expectTypeToBe(callbackType.params[0].type, StringType); + expectTypeToBe(callbackType.params[1].type, IntegerType); + const valueSymbol = table.getSymbol('value', SymbolTypeFlag.runtime)[0]; + expectTypeToBe(valueSymbol.type, IntegerType); + }); + + it('can use type statement to define a typed function type', () => { + const file = program.setFile('source/test.bs', ` + type MyCallback = function(s as string) as integer + + sub testFunc(callback as MyCallback) + value = callback("hello") + print value + end sub + `); + program.validate(); + expectZeroDiagnostics(program); + const funcExprBody = file.parser.ast.findChildren(isFunctionExpression, { walkMode: WalkMode.visitAllRecursive })[0].body; + const table = funcExprBody.getSymbolTable(); + const valueSymbol = table.getSymbol('value', SymbolTypeFlag.runtime)[0]; + expectTypeToBe(valueSymbol.type, IntegerType); + }); + + it('can use a union of typed function types', () => { + const file = program.setFile('source/test.bs', ` + type CallbackA = function(s as string) as integer + type CallbackB = function(i as integer) as string + + sub testFunc(input, callback as CallbackA or CallbackB) + value = callback(input) + print value + end sub + `); + program.validate(); + expectZeroDiagnostics(program); + const funcExprBody = file.parser.ast.findChildren(isFunctionExpression, { walkMode: WalkMode.visitAllRecursive })[0].body; + const table = funcExprBody.getSymbolTable(); + const valueSymbol = table.getSymbol('value', SymbolTypeFlag.runtime)[0]; + expectTypeToBe(valueSymbol.type, UnionType); + const valueType = valueSymbol.type as UnionType; + expect(valueType.types).to.include(IntegerType.instance); + expect(valueType.types).to.include(StringType.instance); + }); + + it('function return types are greedy in the parser', () => { + const file = program.setFile('source/test.bs', ` + sub testFunc(callback as function() as integer or string) + value = callback() + print value + end sub + `); + program.validate(); + expectZeroDiagnostics(program); + const funcExprBody = file.parser.ast.findChildren(isFunctionExpression, { walkMode: WalkMode.visitAllRecursive })[0].body; + const table = funcExprBody.getSymbolTable(); + const valueSymbol = table.getSymbol('value', SymbolTypeFlag.runtime)[0]; + expectTypeToBe(valueSymbol.type, UnionType); + const valueType = valueSymbol.type as UnionType; + expect(valueType.types).to.include(IntegerType.instance); + expect(valueType.types).to.include(StringType.instance); + }); + + it('brackets in function types do not cause issues', () => { + const file = program.setFile('source/test.bs', ` + sub testFunc(callback as function(s as string) as (integer or string)) + value = callback("hello") + print value + end sub + `); + program.validate(); + expectZeroDiagnostics(program); + const funcExprBody = file.parser.ast.findChildren(isFunctionExpression, { walkMode: WalkMode.visitAllRecursive })[0].body; + const table = funcExprBody.getSymbolTable(); + const valueSymbol = table.getSymbol('value', SymbolTypeFlag.runtime)[0]; + expectTypeToBe(valueSymbol.type, UnionType); + const valueType = valueSymbol.type as UnionType; + expect(valueType.types).to.include(IntegerType.instance); + expect(valueType.types).to.include(StringType.instance); + }); + + it('function types with unions in parameters work', () => { + const file = program.setFile('source/test.bs', ` + sub testFunc(callback as function(s as string or integer) as integer) + value = callback("hello") + print value + end sub + `); + program.validate(); + expectZeroDiagnostics(program); + const funcExprBody = file.parser.ast.findChildren(isFunctionExpression, { walkMode: WalkMode.visitAllRecursive })[0].body; + const table = funcExprBody.getSymbolTable(); + const callbackSymbol = table.getSymbol('callback', SymbolTypeFlag.runtime)[0]; + expectTypeToBe(callbackSymbol.type, TypedFunctionType); + const callbackType = callbackSymbol.type as TypedFunctionType; + expect(callbackType.params.length).to.equal(1); + expectTypeToBe(callbackType.params[0].type, UnionType); + const paramType = callbackType.params[0].type as UnionType; + expect(paramType.types).to.include(StringType.instance); + expect(paramType.types).to.include(IntegerType.instance); + const valueSymbol = table.getSymbol('value', SymbolTypeFlag.runtime)[0]; + expectTypeToBe(valueSymbol.type, IntegerType); + }); + + it('args can be union of function types', () => { + const file = program.setFile('source/test.bs', ` + sub testFunc(callback as (function(s as string) as string) or (function(s as string) as integer)) + value = callback("hello") + print value + end sub + `); + program.validate(); + expectZeroDiagnostics(program); + const funcExprBody = file.parser.ast.findChildren(isFunctionExpression, { walkMode: WalkMode.visitAllRecursive })[0].body; + const table = funcExprBody.getSymbolTable(); + const valueSymbol = table.getSymbol('value', SymbolTypeFlag.runtime)[0]; + expectTypeToBe(valueSymbol.type, UnionType); + const valueType = valueSymbol.type as UnionType; + expect(valueType.types).to.include(IntegerType.instance); + expect(valueType.types).to.include(StringType.instance); + }); + }); }); describe('symbol tables with pocket tables', () => { diff --git a/src/astUtils/reflection.ts b/src/astUtils/reflection.ts index d7e3aba0b..3efc27aa7 100644 --- a/src/astUtils/reflection.ts +++ b/src/astUtils/reflection.ts @@ -1,5 +1,5 @@ import type { Body, AssignmentStatement, Block, ExpressionStatement, FunctionStatement, IfStatement, IncrementStatement, PrintStatement, GotoStatement, LabelStatement, ReturnStatement, EndStatement, StopStatement, ForStatement, ForEachStatement, WhileStatement, DottedSetStatement, IndexedSetStatement, LibraryStatement, NamespaceStatement, ImportStatement, ClassStatement, InterfaceFieldStatement, InterfaceMethodStatement, InterfaceStatement, EnumStatement, EnumMemberStatement, TryCatchStatement, CatchStatement, ThrowStatement, MethodStatement, FieldStatement, ConstStatement, ContinueStatement, DimStatement, TypecastStatement, AliasStatement, AugmentedAssignmentStatement, ConditionalCompileConstStatement, ConditionalCompileErrorStatement, ConditionalCompileStatement, ExitStatement, TypeStatement } from '../parser/Statement'; -import type { LiteralExpression, BinaryExpression, CallExpression, FunctionExpression, DottedGetExpression, XmlAttributeGetExpression, IndexedGetExpression, GroupingExpression, EscapedCharCodeLiteralExpression, ArrayLiteralExpression, AALiteralExpression, UnaryExpression, VariableExpression, SourceLiteralExpression, NewExpression, CallfuncExpression, TemplateStringQuasiExpression, TemplateStringExpression, TaggedTemplateStringExpression, AnnotationExpression, FunctionParameterExpression, AAMemberExpression, TernaryExpression, NullCoalescingExpression, PrintSeparatorExpression, TypecastExpression, TypedArrayExpression, TypeExpression, InlineInterfaceMemberExpression, InlineInterfaceExpression } from '../parser/Expression'; +import type { LiteralExpression, BinaryExpression, CallExpression, FunctionExpression, DottedGetExpression, XmlAttributeGetExpression, IndexedGetExpression, GroupingExpression, EscapedCharCodeLiteralExpression, ArrayLiteralExpression, AALiteralExpression, UnaryExpression, VariableExpression, SourceLiteralExpression, NewExpression, CallfuncExpression, TemplateStringQuasiExpression, TemplateStringExpression, TaggedTemplateStringExpression, AnnotationExpression, FunctionParameterExpression, AAMemberExpression, TernaryExpression, NullCoalescingExpression, PrintSeparatorExpression, TypecastExpression, TypedArrayExpression, TypeExpression, InlineInterfaceMemberExpression, InlineInterfaceExpression, TypedFunctionTypeExpression } from '../parser/Expression'; import type { BrsFile } from '../files/BrsFile'; import type { XmlFile } from '../files/XmlFile'; import type { BsDiagnostic, TypedefProvider } from '../interfaces'; @@ -333,6 +333,9 @@ export function isInlineInterfaceExpression(element: any): element is InlineInte export function isInlineInterfaceMemberExpression(element: any): element is InlineInterfaceMemberExpression { return element?.kind === AstNodeKind.InlineInterfaceMemberExpression; } +export function isTypedFunctionTypeExpression(element: any): element is TypedFunctionTypeExpression { + return element?.kind === AstNodeKind.TypedFunctionTypeExpression; +} // BscType reflection export function isStringType(value: any): value is StringType { @@ -349,6 +352,10 @@ export function isTypedFunctionType(value: any): value is TypedFunctionType { return value?.kind === BscTypeKind.TypedFunctionType; } +export function isTypedFunctionTypeLike(value: any): value is TypedFunctionType | TypeStatementType | UnionType { + return isTypedFunctionType(value) || isTypeStatementTypeOf(value, isTypedFunctionTypeLike) || isUnionTypeOf(value, isTypedFunctionTypeLike); +} + export function isFunctionType(value: any): value is FunctionType { return value?.kind === BscTypeKind.FunctionType; } @@ -492,7 +499,12 @@ export function isCallFuncableType(target): target is CallFuncableType { } export function isCallableType(target): target is BaseFunctionType { - return isFunctionTypeLike(target) || isTypedFunctionType(target) || isObjectType(target) || (isDynamicType(target) && !isAnyReferenceType(target)); + return isFunctionTypeLike(target) || + isTypedFunctionTypeLike(target) || + isTypeStatementTypeOf(target, isCallableType) || + isUnionTypeOf(target, isCallableType) || + isObjectType(target) || + (isDynamicType(target) && !isAnyReferenceType(target)); } export function isAnyReferenceType(target): target is AnyReferenceType { diff --git a/src/bscPlugin/hover/HoverProcessor.spec.ts b/src/bscPlugin/hover/HoverProcessor.spec.ts index 6cbb21c3b..40443f9d3 100644 --- a/src/bscPlugin/hover/HoverProcessor.spec.ts +++ b/src/bscPlugin/hover/HoverProcessor.spec.ts @@ -818,6 +818,101 @@ describe('HoverProcessor', () => { hover = program.getHover(file.srcPath, util.createPosition(6, 29))[0]; expect(hover?.contents).to.eql([`${fence('input as {name as string}')}${commentSep}from doc comment`]); }); + + it('should show the expansion of a type statement type', () => { + const file = program.setFile('source/main.bs', ` + type MyType = { + name as string + age as integer + } + + sub fooFunc(input as MyType) + print input.name + end sub + `); + program.validate(); + + // print input.na|me + let hover = program.getHover(file.srcPath, util.createPosition(7, 35))[0]; + expect(hover?.contents).to.eql([fence('MyType.name as string')]); + // print in|put.name + hover = program.getHover(file.srcPath, util.createPosition(7, 29))[0]; + expect(hover?.contents).to.eql([fence('input as MyType')]); + }); + + it('should show type statement type when hovering on the type name in the parameter list', () => { + const file = program.setFile('source/main.bs', ` + type MyType = { + name as string + age as integer + } + + sub fooFunc(input as MyType) + print input.name + end sub + `); + program.validate(); + + // sub fooFunc(input as MyT|ype) + let hover = program.getHover(file.srcPath, util.createPosition(6, 41))[0]; + expect(hover?.contents).to.eql([fence('type MyType')]); + }); + + it('should show a hover on a typed function type parameter', () => { + const file = program.setFile('source/main.bs', ` + sub fooFunc(callback as function(name as string) as integer) + callback("hello") + end sub + `); + program.validate(); + + // callback as func|tion(name as string) as integer + let hover = program.getHover(file.srcPath, util.createPosition(1, 45))[0]; + expect(hover?.contents).to.eql([fence('function (name as string) as integer')]); + // call|back as function(name as string) as integer + hover = program.getHover(file.srcPath, util.createPosition(1, 33))[0]; + expect(hover?.contents).to.eql([fence('function callback(name as string) as integer')]); + }); + + it('should show a proper hover on inline anon functions', () => { + const file = program.setFile('source/main.bs', ` + sub fooFunc() + callback = function(arg as string) as integer + return arg.len() + end function + print callback("hello") + end sub + `); + program.validate(); + + // callback = fun|ction(arg as string) as integer + let hover = program.getHover(file.srcPath, util.createPosition(2, 34))[0]; + expect(hover?.contents).to.eql([fence('function (arg as string) as integer')]); + // print call|back("hello") + hover = program.getHover(file.srcPath, util.createPosition(5, 30))[0]; + expect(hover?.contents).to.eql([fence('function callback(arg as string) as integer')]); + }); + + it('should show a proper hover on inline anon functions', () => { + const file = program.setFile('source/main.bs', ` + sub useFunc(myFunc as function(num as integer) as string) + print myFunc(123) + end sub + + sub otherFunc() + useFunc(function(a, b = invalid) as string + print a + print b + return "" + end function) + end sub + `); + program.validate(); + + // useFunc(fun|ction(a, b = invalid) as string + let hover = program.getHover(file.srcPath, util.createPosition(6, 34))[0]; + expect(hover?.contents).to.eql([fence('function (a as dynamic, b? as dynamic) as string')]); + }); }); describe('callFunc', () => { diff --git a/src/bscPlugin/hover/HoverProcessor.ts b/src/bscPlugin/hover/HoverProcessor.ts index b53fb51d1..a65cda2a0 100644 --- a/src/bscPlugin/hover/HoverProcessor.ts +++ b/src/bscPlugin/hover/HoverProcessor.ts @@ -1,4 +1,4 @@ -import { isAssignmentStatement, isBrsFile, isCallfuncExpression, isClassStatement, isDottedGetExpression, isEnumMemberStatement, isEnumStatement, isEnumType, isForStatement, isInheritableType, isInterfaceStatement, isMemberField, isNamespaceStatement, isNamespaceType, isNewExpression, isTypedFunctionType, isXmlFile } from '../../astUtils/reflection'; +import { isAssignmentStatement, isBrsFile, isCallfuncExpression, isClassStatement, isDottedGetExpression, isEnumMemberStatement, isEnumStatement, isEnumType, isForStatement, isInheritableType, isInterfaceStatement, isMemberField, isNamespaceStatement, isNamespaceType, isNewExpression, isTypedFunctionType, isTypeStatement, isTypeStatementType, isXmlFile } from '../../astUtils/reflection'; import type { BrsFile } from '../../files/BrsFile'; import type { XmlFile } from '../../files/XmlFile'; import type { ExtraSymbolData, Hover, ProvideHoverEvent, TypeChainEntry } from '../../interfaces'; @@ -110,6 +110,10 @@ export class HoverProcessor { firstToken = extraData.definingNode.tokens.enum; exprTypeString = extraData.definingNode.fullName; declarationText = firstToken?.text ?? TokenKind.Enum; + } else if (isTypeStatementType(expressionType) && isTypeStatement(extraData.definingNode)) { + firstToken = extraData.definingNode.tokens.type; + declarationText = (firstToken?.text ?? 'type'); + exprTypeString = expressionType.toString(); } } const innerText = `${declarationText} ${exprTypeString}`.trim(); @@ -175,11 +179,14 @@ export class HoverProcessor { const processedTypeChain = util.processTypeChain(typeChain); const fullName = processedTypeChain.fullNameOfItem || token.text; // if the type chain has dynamic in it, then just say the token text - const exprNameString = !processedTypeChain.containsDynamic ? fullName : token.text; + let exprNameString = !processedTypeChain.containsDynamic ? fullName : token.text; + if (isTypedFunctionType(exprType)) { + exprNameString = processedTypeChain.fullNameOfItem; + } const useCustomTypeHover = isInTypeExpression || expression?.findAncestor(isNewExpression); let hoverContent = ''; let descriptionNode; - if (useCustomTypeHover && isInheritableType(exprType)) { + if (useCustomTypeHover && (isInheritableType(exprType) || isTypeStatementType(exprType))) { hoverContent = this.getCustomTypeHover(exprType, extraData); } else if (isMemberField(expression)) { hoverContent = this.getMemberHover(expression, exprType); diff --git a/src/bscPlugin/validation/BrsFileValidator.ts b/src/bscPlugin/validation/BrsFileValidator.ts index 95704e41a..ced36f5c9 100644 --- a/src/bscPlugin/validation/BrsFileValidator.ts +++ b/src/bscPlugin/validation/BrsFileValidator.ts @@ -1,4 +1,4 @@ -import { isAliasStatement, isBlock, isBody, isClassStatement, isConditionalCompileConstStatement, isConditionalCompileErrorStatement, isConditionalCompileStatement, isConstStatement, isDottedGetExpression, isDottedSetStatement, isEnumStatement, isForEachStatement, isForStatement, isFunctionExpression, isFunctionStatement, isIfStatement, isImportStatement, isIndexedGetExpression, isIndexedSetStatement, isInterfaceStatement, isInvalidType, isLibraryStatement, isLiteralExpression, isMethodStatement, isNamespaceStatement, isTypecastExpression, isTypecastStatement, isTypeStatement, isUnaryExpression, isVariableExpression, isVoidType, isWhileStatement } from '../../astUtils/reflection'; +import { isAliasStatement, isBlock, isBody, isClassStatement, isConditionalCompileConstStatement, isConditionalCompileErrorStatement, isConditionalCompileStatement, isConstStatement, isDottedGetExpression, isDottedSetStatement, isEnumStatement, isForEachStatement, isForStatement, isFunctionExpression, isFunctionStatement, isIfStatement, isImportStatement, isIndexedGetExpression, isIndexedSetStatement, isInterfaceStatement, isInvalidType, isLibraryStatement, isLiteralExpression, isMethodStatement, isNamespaceStatement, isTypecastExpression, isTypecastStatement, isTypedFunctionTypeExpression, isTypeStatement, isUnaryExpression, isVariableExpression, isVoidType, isWhileStatement } from '../../astUtils/reflection'; import { createVisitor, WalkMode } from '../../astUtils/visitors'; import { DiagnosticMessages } from '../../DiagnosticMessages'; import type { BrsFile } from '../../files/BrsFile'; @@ -196,6 +196,9 @@ export class BrsFileValidator { this.validateFunctionParameterCount(node); }, FunctionParameterExpression: (node) => { + if (isTypedFunctionTypeExpression(node.parent)) { + return; + } const paramName = node.tokens?.name?.text; if (!paramName) { return; diff --git a/src/bscPlugin/validation/ScopeValidator.spec.ts b/src/bscPlugin/validation/ScopeValidator.spec.ts index 5cea29b49..7d716cbda 100644 --- a/src/bscPlugin/validation/ScopeValidator.spec.ts +++ b/src/bscPlugin/validation/ScopeValidator.spec.ts @@ -7,7 +7,7 @@ import type { TypeCompatibilityData } from '../../interfaces'; import { IntegerType } from '../../types/IntegerType'; import { StringType } from '../../types/StringType'; import type { BrsFile } from '../../files/BrsFile'; -import { FloatType, InterfaceType, TypedFunctionType, VoidType, BooleanType } from '../../types'; +import { FloatType, InterfaceType, TypedFunctionType, VoidType, BooleanType, ArrayType } from '../../types'; import { SymbolTypeFlag } from '../../SymbolTypeFlag'; import { AssociativeArrayType } from '../../types/AssociativeArrayType'; import undent from 'undent'; @@ -297,6 +297,31 @@ describe('ScopeValidator', () => { DiagnosticMessages.mismatchArgumentCount(2, 1).message ]); }); + + it('validates against typed functions types', () => { + program.setFile('source/main.bs', ` + sub main(cb as function(num as integer, name as string) as void) + cb(1) + end sub + `); + program.validate(); + expectDiagnostics(program, [ + DiagnosticMessages.mismatchArgumentCount(2, 1).message + ]); + }); + + it('validates against typed functions types from type statements', () => { + program.setFile('source/main.bs', ` + type Callback = function(num as integer, name as string) as void + sub main(cb as Callback) + cb(1) + end sub + `); + program.validate(); + expectDiagnostics(program, [ + DiagnosticMessages.mismatchArgumentCount(2, 1).message + ]); + }); }); describe('argumentTypeMismatch', () => { @@ -2252,6 +2277,223 @@ describe('ScopeValidator', () => { ]); }); }); + + describe('typed function type expressions', () => { + it('allows using typed function type expressions correctly', () => { + program.setFile('source/main.bs', ` + sub main(myFunc as function(num as integer) as string) + print myFunc(123) + end sub + `); + program.validate(); + expectZeroDiagnostics(program); + }); + + it('validates using typed function type expressions incorrectly', () => { + program.setFile('source/main.bs', ` + sub main(myFunc as function(num as integer) as string) + print myFunc("123") + end sub + `); + program.validate(); + expectDiagnostics(program, [ + DiagnosticMessages.argumentTypeMismatch('string', 'integer').message + ]); + }); + + it('validates using typed function type expressions from type statements', () => { + program.setFile('source/main.bs', ` + type MyFuncType = function(num as integer) as string + + sub main(myFunc as MyFuncType) + print myFunc("123") + end sub + `); + program.validate(); + expectDiagnostics(program, [ + DiagnosticMessages.argumentTypeMismatch('string', 'integer').message + ]); + }); + + it('validates using typed function type expressions from type statements', () => { + program.setFile('source/main.bs', ` + type MyFuncType = function(num as integer) as string + + sub main(myFunc as MyFuncType) + print myFunc("123") + end sub + `); + program.validate(); + expectDiagnostics(program, [ + DiagnosticMessages.argumentTypeMismatch('string', 'integer').message + ]); + }); + + it('validates using typed function type expressions returned from other functions', () => { + program.setFile('source/main.bs', ` + function getFunc() as function(num as integer) as string + return function(x as integer) as string + return "hello " + x.toStr() + end function + end function + + sub main() + myFunc = getFunc() + print myFunc("123") + end sub + `); + program.validate(); + expectDiagnostics(program, [ + DiagnosticMessages.argumentTypeMismatch('string', 'integer').message + ]); + }); + + it('validates using typed function type expressions with complex arguments', () => { + program.setFile('source/main.bs', ` + type MyFuncType = function(arg1 as {id as integer}, arg2 as IFace) as string + + interface IFace + name as string + data as float[] + end interface + + sub main(myFunc as MyFuncType) + print myFunc({id: "123"}, {name: false}) + end sub + `); + program.validate(); + expectDiagnostics(program, [ + DiagnosticMessages.argumentTypeMismatch('roAssociativeArray', '{id as integer}', { + fieldMismatches: [ + { name: 'id', expectedType: IntegerType.instance, actualType: StringType.instance } + ] + }).message, + DiagnosticMessages.argumentTypeMismatch('roAssociativeArray', 'IFace', { + fieldMismatches: [ + { name: 'name', expectedType: StringType.instance, actualType: BooleanType.instance } + ], + missingFields: [ + { name: 'data', expectedType: new ArrayType(FloatType.instance) } + ] + }).message + ]); + }); + + it('validates passing an incompatible typed function type expression as an argument', () => { + program.setFile('source/main.bs', ` + sub useFunc(myFunc as function(num as integer) as string) + print myFunc(123) + end sub + + sub otherFunc() + useFunc(function(a, b) as void + print a + print b + end function) + end sub + `); + program.validate(); + expectDiagnostics(program, [ + DiagnosticMessages.argumentTypeMismatch('function (a as dynamic, b as dynamic) as void', 'function (num as integer) as string', { + expectedParamCount: 1, + actualParamCount: 2 + }).message + ]); + }); + + it('allows passing a function with additional optional parameters', () => { + program.setFile('source/main.bs', ` + type MyFuncType1 = function(num as integer) as string + sub useFunc(myFunc as MyFuncType1) + print myFunc(123) + end sub + sub otherFunc() + useFunc(function(a as integer, b = "" as string) as string + print a + print b + return "hello" + end function) + end sub + `); + program.validate(); + expectZeroDiagnostics(program); + }); + + it('check if a passed in function has the right number and type of optional parameters', () => { + program.setFile('source/main.bs', ` + type MyFuncType1 = function(num as integer, s = "" as string) as string + + sub useFunc(myFunc as MyFuncType1) + print myFunc(123) + print myFunc(123, "test") + end sub + + sub otherFunc() + useFunc(function(a as integer) as string + print a + return "hello" + end function) + end sub + `); + program.validate(); + expectDiagnostics(program, [ + DiagnosticMessages.argumentTypeMismatch('function (a as integer) as string', 'MyFuncType1', { + expectedParamCount: 2, + actualParamCount: 1 + }).message + ]); + }); + + it('allows a passed in function that has the right number and type of optional parameters', () => { + program.setFile('source/main.bs', ` + type MyFuncType1 = function(num as integer, s = "" as string) as string + + sub useFunc(myFunc as MyFuncType1) + print myFunc(123) + print myFunc(123, "test") + end sub + + sub otherFunc() + useFunc(function(a as integer, s = "hello" as string) as string + print a + return "hello" + end function) + end sub + `); + program.validate(); + expectZeroDiagnostics(program); + }); + + it('allows passing a function that itself has function-typed parameters', () => { + program.setFile('source/main.bs', ` + type CallbackType = function(x as integer) as void + + function main(fnWrap as function(callback as CallbackType) as void) + myCallback = function(x as integer) as void + print x + end function + fnWrap(myCallback) + end function + `); + program.validate(); + expectZeroDiagnostics(program); + }); + + + it('handles recursive function type expressions', () => { + program.setFile('source/main.bs', ` + type RecursiveFuncType = function(x as integer) as RecursiveFuncType + function main(fn as RecursiveFuncType) + if m.x > 0 + nextFn = fn(m.x - 1) + main(nextFn) + end if + end function + `); + program.validate(); + expectZeroDiagnostics(program); + }); + }); }); describe('cannotFindName', () => { @@ -3642,6 +3884,53 @@ describe('ScopeValidator', () => { ]); }); }); + + describe('function return types', () => { + it('should allow returning a function that matches the return type', () => { + program.setFile('source/main.bs', ` + function getFunc() as function() as string + return function() as string + return "hello" + end function + end function + `); + program.validate(); + expectZeroDiagnostics(program); + }); + + + it('should not allow returning a function that does not match the return type', () => { + program.setFile('source/main.bs', ` + function getFunc() as function() as string + return function() as integer + return 123 + end function + end function + `); + program.validate(); + expectDiagnostics(program, [ + DiagnosticMessages.returnTypeMismatch('function () as integer', 'function () as string', { + returnTypeMismatch: { + expectedType: StringType.instance, + actualType: IntegerType.instance + } + }).message + ]); + }); + + it('can have a function type as a parameter to a function type', () => { + program.setFile('source/main.bs', ` + function test(func as function(arg1 as function() as integer) as integer) as integer + return func(function() as integer + return 123 + end function) + end function + `); + program.validate(); + expectZeroDiagnostics(program); + }); + + }); }); describe('returnTypeCoercionMismatch', () => { diff --git a/src/bscPlugin/validation/ScopeValidator.ts b/src/bscPlugin/validation/ScopeValidator.ts index 9891ed1ca..efee4d18b 100644 --- a/src/bscPlugin/validation/ScopeValidator.ts +++ b/src/bscPlugin/validation/ScopeValidator.ts @@ -1,5 +1,5 @@ import { DiagnosticTag, type Range } from 'vscode-languageserver'; -import { isAliasStatement, isArrayType, isAssignmentStatement, isAssociativeArrayType, isBinaryExpression, isBooleanTypeLike, isBrsFile, isCallExpression, isCallFuncableTypeLike, isCallableType, isCallfuncExpression, isClassStatement, isClassType, isComponentType, isDottedGetExpression, isDynamicType, isEnumMemberType, isEnumType, isFunctionExpression, isFunctionParameterExpression, isIterableType, isLiteralExpression, isNamespaceStatement, isNamespaceType, isNewExpression, isNumberTypeLike, isObjectType, isPrimitiveType, isReferenceType, isReturnStatement, isStringTypeLike, isTypedFunctionType, isUnionType, isVariableExpression, isVoidType, isXmlScope } from '../../astUtils/reflection'; +import { isAliasStatement, isArrayType, isAssignmentStatement, isAssociativeArrayType, isBinaryExpression, isBooleanTypeLike, isBrsFile, isCallExpression, isCallFuncableTypeLike, isCallableType, isCallfuncExpression, isClassStatement, isClassType, isComponentType, isCompoundType, isDottedGetExpression, isDynamicType, isEnumMemberType, isEnumType, isFunctionExpression, isFunctionParameterExpression, isIterableType, isLiteralExpression, isNamespaceStatement, isNamespaceType, isNewExpression, isNumberTypeLike, isObjectType, isPrimitiveType, isReferenceType, isReturnStatement, isStringTypeLike, isTypeStatementType, isTypedFunctionType, isUnionType, isVariableExpression, isVoidType, isXmlScope } from '../../astUtils/reflection'; import type { DiagnosticInfo } from '../../DiagnosticMessages'; import { DiagnosticMessages } from '../../DiagnosticMessages'; import type { BrsFile } from '../../files/BrsFile'; @@ -558,7 +558,10 @@ export class ScopeValidator { * Detect calls to functions with the incorrect number of parameters, or wrong types of arguments */ private validateFunctionCall(file: BrsFile, callee: Expression, funcType: BscType, callErrorLocation: Location, args: Expression[], argOffset = 0) { - if (!funcType?.isResolvable() || !isCallableType(funcType)) { + while (isTypeStatementType(funcType)) { + funcType = funcType.wrappedType; + } + if (!funcType?.isResolvable() || !isCallableType(funcType) || isCompoundType(funcType)) { const funcName = util.getAllDottedGetPartsAsString(callee, ParseMode.BrighterScript, isCallfuncExpression(callee) ? '@.' : '.'); if (isUnionType(funcType)) { if (!util.isUnionOfFunctions(funcType) && !isCallfuncExpression(callee)) { diff --git a/src/parser/AstNode.ts b/src/parser/AstNode.ts index d5f5ddee3..021e5d1ce 100644 --- a/src/parser/AstNode.ts +++ b/src/parser/AstNode.ts @@ -361,5 +361,6 @@ export enum AstNodeKind { PrintSeparatorExpression = 'PrintSeparatorExpression', InlineInterfaceExpression = 'InlineInterfaceExpression', InlineInterfaceMemberExpression = 'InlineInterfaceMemberExpression', - TypeStatement = 'TypeStatement' + TypeStatement = 'TypeStatement', + TypedFunctionTypeExpression = 'TypedFunctionTypeExpression' } diff --git a/src/parser/Expression.ts b/src/parser/Expression.ts index 85b1c93ef..57d6a4736 100644 --- a/src/parser/Expression.ts +++ b/src/parser/Expression.ts @@ -11,7 +11,7 @@ import * as fileUrl from 'file-url'; import type { WalkOptions, WalkVisitor } from '../astUtils/visitors'; import { WalkMode } from '../astUtils/visitors'; import { walk, InternalWalkMode, walkArray } from '../astUtils/visitors'; -import { isAALiteralExpression, isAAMemberExpression, isArrayLiteralExpression, isArrayType, isCallableType, isCallExpression, isCallfuncExpression, isDottedGetExpression, isEscapedCharCodeLiteralExpression, isFunctionExpression, isFunctionStatement, isIntegerType, isInterfaceMethodStatement, isInvalidType, isLiteralBoolean, isLiteralExpression, isLiteralNumber, isLiteralString, isLongIntegerType, isMethodStatement, isNamespaceStatement, isNativeType, isNewExpression, isPrimitiveType, isReferenceType, isStringType, isTemplateStringExpression, isTypecastExpression, isUnaryExpression, isVariableExpression, isVoidType } from '../astUtils/reflection'; +import { isAALiteralExpression, isAAMemberExpression, isArrayLiteralExpression, isArrayType, isCallableType, isCallExpression, isCallfuncExpression, isDottedGetExpression, isEscapedCharCodeLiteralExpression, isFunctionExpression, isFunctionStatement, isIntegerType, isInterfaceMethodStatement, isInvalidType, isLiteralBoolean, isLiteralExpression, isLiteralNumber, isLiteralString, isLongIntegerType, isMethodStatement, isNamespaceStatement, isNativeType, isNewExpression, isPrimitiveType, isReferenceType, isStringType, isTemplateStringExpression, isTypecastExpression, isTypeStatementType, isUnaryExpression, isVariableExpression, isVoidType } from '../astUtils/reflection'; import type { GetTypeOptions, TranspileResult, TypedefProvider } from '../interfaces'; import { TypeChainEntry } from '../interfaces'; import { VoidType } from '../types/VoidType'; @@ -197,7 +197,10 @@ export class CallExpression extends Expression { } getType(options: GetTypeOptions) { - const calleeType = this.callee.getType(options); + let calleeType = this.callee.getType(options); + while (isTypeStatementType(calleeType)) { + calleeType = calleeType.wrappedType; + } if (options.ignoreCall) { return calleeType; } @@ -2902,6 +2905,94 @@ export class InlineInterfaceMemberExpression extends Expression { } } +export class TypedFunctionTypeExpression extends Expression { + constructor(options: { + functionType?: Token; + leftParen?: Token; + params?: FunctionParameterExpression[]; + rightParen?: Token; + as?: Token; + returnType?: TypeExpression; + + }) { + super(); + this.tokens = { + functionType: options.functionType, + leftParen: options.leftParen, + rightParen: options.rightParen, + as: options.as + }; + this.params = options.params; + this.returnType = options.returnType; + this.location = util.createBoundingLocation( + this.tokens.functionType, + this.tokens.leftParen, + ...this.params, + this.tokens.rightParen, + this.tokens.as, + this.returnType + ); + } + + public readonly kind = AstNodeKind.TypedFunctionTypeExpression; + + public readonly tokens: { + readonly functionType?: Token; + readonly leftParen?: Token; + readonly rightParen?: Token; + readonly as?: Token; + }; + + public readonly params: FunctionParameterExpression[]; + public readonly returnType?: TypeExpression; + + public readonly location: Location; + + public transpile(state: BrsTranspileState): TranspileResult { + return [this.getType({ flags: SymbolTypeFlag.typetime }).toTypeString()]; + } + + public walk(visitor: WalkVisitor, options: WalkOptions) { + if (options.walkMode & InternalWalkMode.walkExpressions) { + walkArray(this.params, visitor, options, this); + walk(this, 'returnType', visitor, options); + } + } + + public getType(options: GetTypeOptions): BscType { + const returnType = this.returnType?.getType({ ...options, typeChain: undefined }) ?? DynamicType.instance; + const functionType = new TypedFunctionType(returnType); + for (const param of this.params) { + functionType.addParameter(param.tokens.name.text, param.getType({ ...options, typeChain: undefined }), !!param.defaultValue); + } + functionType.setSub(this.tokens.functionType?.kind === TokenKind.Sub); + + options.typeChain?.push(new TypeChainEntry({ + name: '', + type: functionType, + astNode: this, + data: options.data + })); + + return functionType; + } + + public clone() { + return this.finalizeClone( + new TypedFunctionTypeExpression({ + functionType: util.cloneToken(this.tokens.functionType), + leftParen: util.cloneToken(this.tokens.leftParen), + params: this.params?.map(x => x?.clone()), + rightParen: util.cloneToken(this.tokens.rightParen), + as: util.cloneToken(this.tokens.as), + returnType: this.returnType?.clone() + }), + ['params', 'returnType'] + ); + } +} + + /** * A list of names of functions that are restricted from being stored to a * variable, property, or passed as an argument. (i.e. `type` or `createobject`). diff --git a/src/parser/Parser.spec.ts b/src/parser/Parser.spec.ts index 2e253827c..76fef134f 100644 --- a/src/parser/Parser.spec.ts +++ b/src/parser/Parser.spec.ts @@ -3005,6 +3005,139 @@ describe('parser', () => { expectZeroDiagnostics(diagnostics); }); }); + + describe('typed functions as types', () => { + it('disallowed in brightscript mode', () => { + let { diagnostics } = parse(` + function test(func as function()) + return func() + end function + `, ParseMode.BrightScript); + expectDiagnosticsIncludes(diagnostics, [ + DiagnosticMessages.bsFeatureNotSupportedInBrsFiles('typed function types') + ]); + }); + + it('can be passed as param types', () => { + let { diagnostics } = parse(` + function test(func as function()) + return func() + end function + `, ParseMode.BrighterScript); + expectZeroDiagnostics(diagnostics); + }); + + it('can have a return type', () => { + let { diagnostics } = parse(` + function test(func as sub() as integer) as integer + return func() + end function + `, ParseMode.BrighterScript); + expectZeroDiagnostics(diagnostics); + }); + + it('can use sub or function', () => { + let { diagnostics } = parse(` + function test(func as sub() as integer) as integer + return func() + end function + + function test2(func as function() as integer) as integer + return func() + end function + `, ParseMode.BrighterScript); + expectZeroDiagnostics(diagnostics); + }); + + it('can have primitive parameters', () => { + let { diagnostics } = parse(` + function test(func as function(name as string, num as integer) as integer) as integer + return func("hello", 123) + end function + `, ParseMode.BrighterScript); + expectZeroDiagnostics(diagnostics); + }); + + it('can have complex parameters', () => { + let { diagnostics } = parse(` + interface IFace + name as string + end interface + + function test(func as function(thing as IFace) as integer) as integer + return func({name: "hello"}) + end function + `, ParseMode.BrighterScript); + expectZeroDiagnostics(diagnostics); + }); + + it('can have compound parameters', () => { + let { diagnostics } = parse(` + interface IFace + name as string + end interface + + function test(func as function(arg1 as string or integer, arg2 as IFace) as integer) as integer + return func("hello", {name: "hello"}) + end function + `, ParseMode.BrighterScript); + expectZeroDiagnostics(diagnostics); + }); + + it('invalid syntax inside the function type causes errors', () => { + let { diagnostics } = parse(` + function test(func as function(arg2 noAsBeforeType) as integer) as integer + return func({name: "hello"}) + end function + `, ParseMode.BrighterScript); + expectDiagnosticsIncludes(diagnostics, [ + DiagnosticMessages.unmatchedLeftToken('(', 'function type expression').message + ]); + }); + + it('can be used as return types', () => { + let { diagnostics } = parse(` + function test() as function() as integer + return function() as integer + return 123 + end function + end function + `, ParseMode.BrighterScript); + expectZeroDiagnostics(diagnostics); + }); + + it('can have a union as return type', () => { + let { diagnostics } = parse(` + type foo = function() as integer or string + function test() as foo + return function() as integer + return 123 + end function + end function + `, ParseMode.BrighterScript); + expectZeroDiagnostics(diagnostics); + }); + + it('can have optional parameters', () => { + let { diagnostics } = parse(` + function test(func as function(arg1 as string, arg2 = 0 as integer) as integer) as integer + return func("hello") + end function + `, ParseMode.BrighterScript); + expectZeroDiagnostics(diagnostics); + }); + + it('can have a function type as a parameter to a function type', () => { + let { diagnostics } = parse(` + function test(func as function(arg1 as function() as integer) as integer) as integer + return func(function() as integer + return 123 + end function) + end function + `, ParseMode.BrighterScript); + expectZeroDiagnostics(diagnostics); + }); + }); }); export function parse(text: string, mode?: ParseMode, bsConsts: Record = {}) { diff --git a/src/parser/Parser.ts b/src/parser/Parser.ts index 4b5eab050..3a5f78ba2 100644 --- a/src/parser/Parser.ts +++ b/src/parser/Parser.ts @@ -96,7 +96,8 @@ import { XmlAttributeGetExpression, PrintSeparatorExpression, InlineInterfaceExpression, - InlineInterfaceMemberExpression + InlineInterfaceMemberExpression, + TypedFunctionTypeExpression } from './Expression'; import type { Range } from 'vscode-languageserver'; import type { Logger } from '../logging'; @@ -3108,9 +3109,13 @@ export class Parser { * @returns an expression that was successfully parsed */ private getTypeExpressionPart(changedTokens: { token: Token; oldKind: TokenKind }[]) { - let expr: VariableExpression | DottedGetExpression | TypedArrayExpression | InlineInterfaceExpression | GroupingExpression; + let expr: VariableExpression | DottedGetExpression | TypedArrayExpression | InlineInterfaceExpression | GroupingExpression | TypedFunctionTypeExpression; - if (this.checkAny(...DeclarableTypes)) { + if (this.checkAny(TokenKind.Sub, TokenKind.Function) && this.checkNext(TokenKind.LeftParen)) { + // this is a tyyed function type expression, eg. "function(type1, type2) as type3" + this.warnIfNotBrighterScriptMode('typed function types'); + expr = this.typedFunctionTypeExpression(); + } else if (this.checkAny(...DeclarableTypes)) { // if this is just a type, just use directly expr = new VariableExpression({ name: this.advance() as Identifier }); } else { @@ -3166,6 +3171,46 @@ export class Parser { return expr; } + private typedFunctionTypeExpression() { + const funcOrSub = this.advance(); + const openParen = this.consume(DiagnosticMessages.expectedToken(TokenKind.LeftParen), TokenKind.LeftParen); + const params: FunctionParameterExpression[] = []; + + if (!this.check(TokenKind.RightParen)) { + do { + if (params.length >= CallExpression.MaximumArguments) { + this.diagnostics.push({ + ...DiagnosticMessages.tooManyCallableParameters(params.length, CallExpression.MaximumArguments), + location: this.peek().location + }); + } + + params.push(this.functionParameter()); + } while (this.match(TokenKind.Comma)); + } + + const closeParen = this.consume( + DiagnosticMessages.unmatchedLeftToken(openParen.text, 'function type expression'), + TokenKind.RightParen + ); + + let asToken: Token; + let returnType: TypeExpression; + + if (this.check(TokenKind.As)) { + [asToken, returnType] = this.consumeAsTokenAndTypeExpression(); + } + return new TypedFunctionTypeExpression({ + functionType: funcOrSub, + rightParen: openParen, + params: params, + leftParen: closeParen, + as: asToken, + returnType: returnType + }); + + } + private inlineInterface() { let expr: InlineInterfaceExpression; diff --git a/src/types/ArrayType.ts b/src/types/ArrayType.ts index 5eb50029a..ffcaa9f47 100644 --- a/src/types/ArrayType.ts +++ b/src/types/ArrayType.ts @@ -1,6 +1,6 @@ import { SymbolTypeFlag } from '../SymbolTypeFlag'; -import { isArrayType, isDynamicType, isEnumMemberType, isInvalidType, isObjectType } from '../astUtils/reflection'; +import { isArrayType, isDynamicType, isEnumMemberType, isInvalidType, isObjectType, isTypeStatementType } from '../astUtils/reflection'; import type { TypeCompatibilityData } from '../interfaces'; import { BscType } from './BscType'; import { BscTypeKind } from './BscTypeKind'; @@ -48,7 +48,9 @@ export class ArrayType extends BscType { } public isTypeCompatible(targetType: BscType, data?: TypeCompatibilityData) { - + while (isTypeStatementType(targetType)) { + targetType = targetType.wrappedType; + } if (isDynamicType(targetType)) { return true; } else if (isObjectType(targetType)) { diff --git a/src/types/BooleanType.ts b/src/types/BooleanType.ts index fe06c08e1..11d53a7df 100644 --- a/src/types/BooleanType.ts +++ b/src/types/BooleanType.ts @@ -1,4 +1,4 @@ -import { isBooleanType, isDynamicType, isObjectType } from '../astUtils/reflection'; +import { isBooleanType, isBooleanTypeLike, isDynamicType, isObjectType } from '../astUtils/reflection'; import { BscType } from './BscType'; import { BscTypeKind } from './BscTypeKind'; import { isNativeInterfaceCompatible, isUnionTypeCompatible } from './helpers'; @@ -13,7 +13,7 @@ export class BooleanType extends BscType { public isTypeCompatible(targetType: BscType, data?: TypeCompatibilityData) { return ( - isBooleanType(targetType) || + isBooleanTypeLike(targetType) || isDynamicType(targetType) || isObjectType(targetType) || isUnionTypeCompatible(this, targetType) || diff --git a/src/types/ClassType.ts b/src/types/ClassType.ts index 691b9defc..8b9a19aea 100644 --- a/src/types/ClassType.ts +++ b/src/types/ClassType.ts @@ -1,5 +1,5 @@ import { SymbolTable } from '../SymbolTable'; -import { isClassType, isDynamicType, isInvalidType, isObjectType } from '../astUtils/reflection'; +import { isClassType, isDynamicType, isInvalidType, isObjectType, isTypeStatementType } from '../astUtils/reflection'; import type { TypeCompatibilityData } from '../interfaces'; import type { BscType } from './BscType'; import { BscTypeKind } from './BscTypeKind'; @@ -17,6 +17,9 @@ export class ClassType extends InheritableType { public readonly kind = BscTypeKind.ClassType; public isTypeCompatible(targetType: BscType, data?: TypeCompatibilityData) { + while (isTypeStatementType(targetType)) { + targetType = targetType.wrappedType; + } if (this.isEqual(targetType, data)) { return true; } else if ( diff --git a/src/types/ComponentType.ts b/src/types/ComponentType.ts index e2c46c5f9..59f47ae55 100644 --- a/src/types/ComponentType.ts +++ b/src/types/ComponentType.ts @@ -1,6 +1,6 @@ import { SymbolTypeFlag } from '../SymbolTypeFlag'; import { SymbolTable } from '../SymbolTable'; -import { isComponentType, isDynamicType, isInvalidType, isObjectType, isReferenceType } from '../astUtils/reflection'; +import { isComponentType, isDynamicType, isInvalidType, isObjectType, isReferenceType, isTypeStatementType } from '../astUtils/reflection'; import type { TypeCompatibilityData } from '../interfaces'; import type { BscType } from './BscType'; import { BscTypeKind } from './BscTypeKind'; @@ -22,6 +22,9 @@ export class ComponentType extends CallFuncableType { } public isTypeCompatible(targetType: BscType, data?: TypeCompatibilityData) { + while (isTypeStatementType(targetType)) { + targetType = targetType.wrappedType; + } if (this.isEqual(targetType)) { return true; } else if (isInvalidType(targetType) || diff --git a/src/types/EnumType.ts b/src/types/EnumType.ts index f279cbc34..cda5c018e 100644 --- a/src/types/EnumType.ts +++ b/src/types/EnumType.ts @@ -1,5 +1,5 @@ import { SymbolTypeFlag } from '../SymbolTypeFlag'; -import { isDynamicType, isEnumMemberType, isEnumType, isObjectType } from '../astUtils/reflection'; +import { isDynamicType, isEnumMemberType, isEnumType, isObjectType, isTypeStatementType } from '../astUtils/reflection'; import type { TypeCompatibilityData } from '../interfaces'; import { BscType } from './BscType'; import { BscTypeKind } from './BscTypeKind'; @@ -20,6 +20,9 @@ export class EnumType extends BscType { public readonly kind = BscTypeKind.EnumType; public isTypeCompatible(targetType: BscType, data?: TypeCompatibilityData) { + while (isTypeStatementType(targetType)) { + targetType = targetType.wrappedType; + } return ( isDynamicType(targetType) || isObjectType(targetType) || diff --git a/src/types/IntersectionType.ts b/src/types/IntersectionType.ts index 3ac3df35b..9498aedaf 100644 --- a/src/types/IntersectionType.ts +++ b/src/types/IntersectionType.ts @@ -1,5 +1,5 @@ import type { GetTypeOptions, TypeCompatibilityData } from '../interfaces'; -import { isDynamicType, isIntersectionType, isObjectType, isTypedFunctionType } from '../astUtils/reflection'; +import { isDynamicType, isIntersectionType, isObjectType, isTypedFunctionType, isTypeStatementType } from '../astUtils/reflection'; import { BscType } from './BscType'; import { ReferenceTypeWithDefault, ReferenceType } from './ReferenceType'; import { addAssociatedTypesTableAsSiblingToMemberTable, getAllTypesFromCompoundType, isEnumTypeCompatible, isTypeWithPotentialDefaultDynamicMember, joinTypesString, reduceTypesForIntersectionType } from './helpers'; @@ -158,6 +158,9 @@ export class IntersectionType extends BscType { isTypeCompatible(targetType: BscType, data?: TypeCompatibilityData): boolean { + while (isTypeStatementType(targetType)) { + targetType = targetType.wrappedType; + } if (isDynamicType(targetType) || isObjectType(targetType) || this === targetType) { return true; } diff --git a/src/types/TypeStatementType.ts b/src/types/TypeStatementType.ts index 87f5b217d..7816f1105 100644 --- a/src/types/TypeStatementType.ts +++ b/src/types/TypeStatementType.ts @@ -1,6 +1,7 @@ import { BscType } from './BscType'; import type { GetTypeOptions, TypeCompatibilityData } from '../interfaces'; import { BscTypeKind } from './BscTypeKind'; +import { isCallableType } from '../astUtils/reflection'; export class TypeStatementType extends BscType { @@ -11,9 +12,12 @@ export class TypeStatementType extends BscType { } public isTypeCompatible(targetType: BscType, data?: TypeCompatibilityData) { - return ( - this.wrappedType.isTypeCompatible(targetType, data) - ); + data = data || {}; + if (!data.expectedType) { + data.expectedType = this; + } + + return this.wrappedType.isTypeCompatible(targetType, data); } public toString() { @@ -52,5 +56,12 @@ export class TypeStatementType extends BscType { return this.wrappedType.getCallFuncTable(); } + get returnType() { + if (isCallableType(this.wrappedType)) { + return this.wrappedType.returnType; + } + return undefined; + } + } diff --git a/src/types/TypedFunctionType.ts b/src/types/TypedFunctionType.ts index 3e29ccfa8..1e82a8beb 100644 --- a/src/types/TypedFunctionType.ts +++ b/src/types/TypedFunctionType.ts @@ -1,4 +1,4 @@ -import { isDynamicType, isObjectType, isTypedFunctionType } from '../astUtils/reflection'; +import { isDynamicType, isObjectType, isTypedFunctionType, isTypeStatementType } from '../astUtils/reflection'; import { BaseFunctionType } from './BaseFunctionType'; import type { BscType } from './BscType'; import { BscTypeKind } from './BscTypeKind'; @@ -52,6 +52,13 @@ export class TypedFunctionType extends BaseFunctionType { } public isTypeCompatible(targetType: BscType, data: TypeCompatibilityData = {}) { + data = data || {}; + if (!data.actualType) { + data.actualType = targetType; + } + while (isTypeStatementType(targetType)) { + targetType = targetType.wrappedType; + } if ( isDynamicType(targetType) || isObjectType(targetType) || @@ -60,6 +67,9 @@ export class TypedFunctionType extends BaseFunctionType { return true; } if (isTypedFunctionType(targetType)) { + if (this === targetType) { + return true; + } return this.checkParamsAndReturnValue(targetType, true, (t1, t2, d) => t1.isTypeCompatible(t2, d), data); } return false; @@ -116,12 +126,12 @@ export class TypedFunctionType extends BaseFunctionType { paramTypeData.expectedType = paramTypeData.expectedType ?? myParam?.type; paramTypeData.actualType = paramTypeData.actualType ?? targetParam?.type; if (!targetParam || !myParam) { - data.expectedParamCount = this.params.filter(p => !p.isOptional).length; - data.actualParamCount = targetType.params.filter(p => !p.isOptional).length; + data.expectedParamCount = this.params.length; + data.actualParamCount = targetType.params.length; } data.parameterMismatches.push({ index: i, data: paramTypeData }); - data.expectedType = this; - data.actualType = targetType; + data.expectedType = data.expectedType ?? this; + data.actualType = data.actualType ?? targetType; return false; } if ((!allowOptionalParamDifferences && myParam.isOptional !== targetParam.isOptional) || @@ -150,8 +160,8 @@ export class TypedFunctionType extends BaseFunctionType { data = data ?? {}; data.expectedVariadic = this.isVariadic; data.actualVariadic = targetType.isVariadic; - data.expectedType = this; - data.actualType = targetType; + data.expectedType = data.expectedType ?? this; + data.actualType = data.actualType ?? targetType; return false; } //made it here, all params and return type pass predicate diff --git a/src/types/UnionType.ts b/src/types/UnionType.ts index 0c657a3ea..45a764128 100644 --- a/src/types/UnionType.ts +++ b/src/types/UnionType.ts @@ -1,5 +1,5 @@ import type { GetTypeOptions, TypeCompatibilityData } from '../interfaces'; -import { isDynamicType, isObjectType, isTypedFunctionType, isUnionType } from '../astUtils/reflection'; +import { isDynamicType, isObjectType, isTypedFunctionType, isTypeStatementType, isUnionType } from '../astUtils/reflection'; import { BscType } from './BscType'; import { ReferenceType } from './ReferenceType'; import { addAssociatedTypesTableAsSiblingToMemberTable, findTypeUnion, findTypeUnionDeepCheck, getAllTypesFromCompoundType, getUniqueType, isEnumTypeCompatible, joinTypesString } from './helpers'; @@ -117,6 +117,9 @@ export class UnionType extends BscType { isTypeCompatible(targetType: BscType, data?: TypeCompatibilityData): boolean { + while (isTypeStatementType(targetType)) { + targetType = targetType.wrappedType; + } if (isDynamicType(targetType) || isObjectType(targetType) || this === targetType) { return true; } diff --git a/src/types/helpers.ts b/src/types/helpers.ts index 0fc9af8b9..3d98a6f64 100644 --- a/src/types/helpers.ts +++ b/src/types/helpers.ts @@ -1,5 +1,5 @@ import type { TypeCompatibilityData } from '../interfaces'; -import { isAnyReferenceType, isArrayDefaultTypeReferenceType, isAssociativeArrayTypeLike, isCompoundType, isDynamicType, isEnumMemberType, isEnumType, isInheritableType, isInterfaceType, isIntersectionType, isObjectType, isReferenceType, isTypePropertyReferenceType, isUnionType, isUnionTypeOf, isVoidType } from '../astUtils/reflection'; +import { isAnyReferenceType, isArrayDefaultTypeReferenceType, isAssociativeArrayTypeLike, isCompoundType, isDynamicType, isEnumMemberType, isEnumType, isInheritableType, isInterfaceType, isIntersectionType, isObjectType, isReferenceType, isTypePropertyReferenceType, isTypeStatementType, isUnionType, isUnionTypeOf, isVoidType } from '../astUtils/reflection'; import type { BscType } from './BscType'; import type { UnionType } from './UnionType'; import type { SymbolTable } from '../SymbolTable'; @@ -42,10 +42,16 @@ export function getUniqueTypesFromArray(types: BscType[], allowNameEquality = tr if (!currentType) { return false; } + while (isTypeStatementType(currentType)) { + currentType = currentType.wrappedType; + } if ((isTypePropertyReferenceType(currentType) || isArrayDefaultTypeReferenceType(currentType)) && !currentType.isResolvable()) { return true; } const latestIndex = types.findIndex((checkType) => { + while (isTypeStatementType(checkType)) { + checkType = checkType.wrappedType; + } return currentType.isEqual(checkType, { allowNameEquality: allowNameEquality }); }); // the index that was found is the index we're checking --- there are no equal types after this diff --git a/src/util.ts b/src/util.ts index 20c2c6228..01e9bb45a 100644 --- a/src/util.ts +++ b/src/util.ts @@ -25,7 +25,7 @@ import type { CallExpression, CallfuncExpression, DottedGetExpression, FunctionP import { LogLevel, createLogger } from './logging'; import { isToken, type Identifier, type Token } from './lexer/Token'; import { TokenKind } from './lexer/TokenKind'; -import { isAnyReferenceType, isBinaryExpression, isBooleanTypeLike, isBrsFile, isCallExpression, isCallableType, isCallfuncExpression, isClassType, isCompoundType, isComponentType, isDottedGetExpression, isDoubleTypeLike, isDynamicType, isEnumMemberType, isExpression, isFloatTypeLike, isIndexedGetExpression, isIntegerTypeLike, isIntersectionType, isInvalidTypeLike, isLiteralString, isLongIntegerTypeLike, isNamespaceStatement, isNamespaceType, isNewExpression, isNumberTypeLike, isObjectType, isPrimitiveType, isReferenceType, isStatement, isStringTypeLike, isTypeExpression, isTypedArrayExpression, isTypedFunctionType, isUninitializedType, isUnionType, isVariableExpression, isVoidType, isXmlAttributeGetExpression, isXmlFile, isArrayType, isAssociativeArrayTypeLike, isBuiltInType } from './astUtils/reflection'; +import { isAnyReferenceType, isBinaryExpression, isBooleanTypeLike, isBrsFile, isCallExpression, isCallableType, isCallfuncExpression, isClassType, isCompoundType, isComponentType, isDottedGetExpression, isDoubleTypeLike, isDynamicType, isEnumMemberType, isExpression, isFloatTypeLike, isIndexedGetExpression, isIntegerTypeLike, isIntersectionType, isInvalidTypeLike, isLiteralString, isLongIntegerTypeLike, isNamespaceStatement, isNamespaceType, isNewExpression, isNumberTypeLike, isObjectType, isPrimitiveType, isReferenceType, isStatement, isStringTypeLike, isTypeExpression, isTypedArrayExpression, isTypedFunctionType, isUninitializedType, isUnionType, isVariableExpression, isVoidType, isXmlAttributeGetExpression, isXmlFile, isArrayType, isAssociativeArrayTypeLike, isBuiltInType, isTypedFunctionTypeLike } from './astUtils/reflection'; import { WalkMode } from './astUtils/visitors'; import { SourceNode } from 'source-map'; import * as requireRelative from 'require-relative'; @@ -2590,7 +2590,7 @@ export class Util { public getReturnTypeOfUnionOfFunctions(type: UnionType): BscType { if (this.isUnionOfFunctions(type, true)) { - const typedFuncsInUnion = type.types.filter(t => isTypedFunctionType(t) || isReferenceType(t)) as TypedFunctionType[]; + const typedFuncsInUnion = type.types.filter(t => isTypedFunctionTypeLike(t) || isReferenceType(t)) as TypedFunctionType[]; if (typedFuncsInUnion.length < type.types.length) { // is non-typedFuncs in union return DynamicType.instance;