diff --git a/src/Scaffolding/VS.Web.CG.Mvc/Minimal Api/MinimalApiGenerator.cs b/src/Scaffolding/VS.Web.CG.Mvc/Minimal Api/MinimalApiGenerator.cs index 850271b6a..a6d572926 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/Minimal Api/MinimalApiGenerator.cs +++ b/src/Scaffolding/VS.Web.CG.Mvc/Minimal Api/MinimalApiGenerator.cs @@ -154,21 +154,23 @@ internal async Task AddEndpointsMethod(string membersBlockText, string endpoints //Get class syntax node to add members to the class var docRoot = docEditor.OriginalRoot as CompilationUnitSyntax; //create CodeFile just to add usings - var usings = new List(); //add usings for DbContext related actins. if (!string.IsNullOrEmpty(templateModel.DbContextNamespace)) { - usings.Add(Constants.MicrosoftEntityFrameworkCorePackageName); usings.Add(templateModel.DbContextNamespace); } + if (!string.IsNullOrEmpty(templateModel.ContextTypeName)) + { + usings.Add(Constants.MicrosoftEntityFrameworkCorePackageName); + } + if (templateModel.OpenAPI) { usings.Add("Microsoft.AspNetCore.Http.HttpResults"); usings.Add("Microsoft.AspNetCore.OpenApi"); } - var endpointsCodeFile = new CodeFile { Usings = usings.ToArray()}; var docBuilder = new DocumentBuilder(docEditor, endpointsCodeFile, ConsoleLogger); var newRoot = docBuilder.AddUsings(new CodeChangeOptions()); @@ -179,11 +181,31 @@ internal async Task AddEndpointsMethod(string membersBlockText, string endpoints if (classNode != null && classNode is ClassDeclarationSyntax classDeclaration) { + SyntaxNode classParentSyntax = null; + //if class is not static, create a new class in the same file + if (!classDeclaration.Modifiers.Any(x => x.Text.Equals(SyntaxFactory.Token(SyntaxKind.StaticKeyword).Text))) + { + classParentSyntax = classDeclaration.Parent; + classDeclaration = SyntaxFactory.ClassDeclaration($"{templateModel.ModelType.Name}Endpoints") + .WithModifiers(SyntaxFactory.TokenList(SyntaxFactory.Token(SyntaxKind.PublicKeyword), SyntaxFactory.Token(SyntaxKind.StaticKeyword))) + .NormalizeWhitespace() + .WithLeadingTrivia(SyntaxFactory.CarriageReturnLineFeed, SyntaxFactory.CarriageReturnLineFeed); + } var modifiedClass = classDeclaration.AddMembers( - SyntaxFactory.GlobalStatement( - SyntaxFactory.ParseStatement(membersBlockText)) - .WithLeadingTrivia(SyntaxFactory.Tab)); - newRoot = newRoot.ReplaceNode(classNode, modifiedClass); + SyntaxFactory.GlobalStatement(SyntaxFactory.ParseStatement(membersBlockText)).WithLeadingTrivia(SyntaxFactory.Tab)); + + //modify class parent by adding class, classParentSyntax should be null if given class is static. + if (classParentSyntax != null) + { + classParentSyntax = classParentSyntax.InsertNodesAfter(classNode.Parent.ChildNodes().Last(), new List() { modifiedClass }); + newRoot = newRoot.ReplaceNode(classNode.Parent, classParentSyntax); + } + //modify given class + else + { + newRoot = newRoot.ReplaceNode(classNode, modifiedClass); + } + docEditor.ReplaceNode(docRoot, newRoot); var classFileSourceTxt = await docEditor.GetChangedDocument()?.GetTextAsync(); var classFileTxt = classFileSourceTxt?.ToString(); diff --git a/test/Scaffolding/Shared/MSBuildProjectStrings.cs b/test/Scaffolding/Shared/MSBuildProjectStrings.cs index 21226709e..0f2924b05 100644 --- a/test/Scaffolding/Shared/MSBuildProjectStrings.cs +++ b/test/Scaffolding/Shared/MSBuildProjectStrings.cs @@ -311,7 +311,7 @@ public static IWebHost BuildWebHost(string[] args) => .WithName(""DeleteCar""); } }"; - public const string EndpointsEmptyClass = @"namespace MinimalApiTest { class Endpoints { } } "; + public const string EndpointsEmptyClass = @"namespace MinimalApiTest { static class Endpoints { } } "; public const string MinimalProgramcsFile = @"var builder = WebApplication.CreateBuilder(args); // Add services to the container. @@ -417,8 +417,29 @@ public class Manufacturer } }"; -// Strings for 3 layered project - public const string WebProjectTxt = @" + public const string CarWithoutNamespaceTxt = @" +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; + +public class Car +{ + public string ID { get; set; } + public string Name { get; set; } + public int ManufacturerID { get; set; } + public Manufacturer Manufacturer { get; set; } + [DataType(DataType.MultilineText)] + public string Notes { get; set; } +} + +public class Manufacturer +{ + public int ID { get; set; } + public string Name { get; set; } + public virtual ICollection Cars { get; set; } +}"; + + // Strings for 3 layered project + public const string WebProjectTxt = @"