Skip to content

Commit 1941acb

Browse files
committed
Added support for no top-level statements (#1895)
* added helpers, most cases without top lvl work * minor update, still doesn't work. * missed in merge * final fixes and tests.
1 parent 1c83458 commit 1941acb

17 files changed

Lines changed: 967 additions & 84 deletions

File tree

src/MSIdentityScaffolding/Microsoft.DotNet.MSIdentity/CodeReaderWriter/ProjectModifier.cs

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ internal class ProjectModifier
2626
private readonly IEnumerable<string> _files;
2727
private readonly IConsoleLogger _consoleLogger;
2828
private PropertyInfo? _codeModifierConfigPropertyInfo;
29-
29+
private const string Main = nameof(Main);
3030
public ProjectModifier(ProvisioningToolOptions toolOptions, IEnumerable<string> files, IConsoleLogger consoleLogger)
3131
{
3232
_toolOptions = toolOptions ?? throw new ArgumentNullException(nameof(toolOptions));
@@ -75,12 +75,14 @@ public async Task AddAuthCodeAsync()
7575
return;
7676
}
7777

78-
var isMinimalApp = await ProjectModifierHelper.IsMinimalApp(project);
78+
var isMinimalApp = await ProjectModifierHelper.IsMinimalApp(project.Documents.ToList());
79+
var useTopLevelsStatements = await ProjectModifierHelper.IsUsingTopLevelStatements(project.Documents.ToList());
7980
CodeChangeOptions options = new CodeChangeOptions
8081
{
8182
MicrosoftGraph = _toolOptions.CallsGraph,
8283
DownstreamApi = _toolOptions.CallsDownstreamApi,
83-
IsMinimalApp = isMinimalApp
84+
IsMinimalApp = isMinimalApp,
85+
UsingTopLevelsStatements = useTopLevelsStatements
8486
};
8587

8688
// Go through all the files, make changes using DocumentBuilder.
@@ -225,7 +227,7 @@ internal async Task ModifyCsFile(CodeFile file, CodeAnalysis.Project project, Co
225227
if (file.FileName.Equals("Startup.cs"))
226228
{
227229
// Startup class file name may be different
228-
file.FileName = await ProjectModifierHelper.GetStartupClass(project) ?? file.FileName;
230+
file.FileName = await ProjectModifierHelper.GetStartupClass(project.Documents.ToList()) ?? file.FileName;
229231
}
230232

231233
var fileDoc = project.Documents.Where(d => d.Name.Equals(file.FileName)).FirstOrDefault();
@@ -260,8 +262,7 @@ internal async Task ModifyCsFile(CodeFile file, CodeAnalysis.Project project, Co
260262
private static SyntaxNode? ModifyRoot(DocumentBuilder documentBuilder, CodeChangeOptions options, CodeFile file)
261263
{
262264
var root = documentBuilder.AddUsings(options);
263-
if (file.FileName.Equals("Program.cs") && file.Methods.TryGetValue("Global", out var globalChanges)
264-
&& root.Members.Any(node => node.IsKind(SyntaxKind.GlobalStatement)))
265+
if (file.FileName.Equals("Program.cs") && file.Methods.TryGetValue("Global", out var globalChanges))
265266
{
266267
var filteredChanges = ProjectModifierHelper.FilterCodeSnippets(globalChanges.CodeChanges, options);
267268
var updatedIdentifer = ProjectModifierHelper.GetBuilderVariableIdentifierTransformation(root.Members);
@@ -270,9 +271,20 @@ internal async Task ModifyCsFile(CodeFile file, CodeAnalysis.Project project, Co
270271
(string oldValue, string newValue) = updatedIdentifer.Value;
271272
filteredChanges = ProjectModifierHelper.UpdateVariables(filteredChanges, oldValue, newValue);
272273
}
273-
274-
var updatedRoot = DocumentBuilder.ApplyChangesToMethod(root, filteredChanges);
275-
return updatedRoot;
274+
if (!options.UsingTopLevelsStatements)
275+
{
276+
var mainMethod = root?.ChildNodes().FirstOrDefault(n => n is MethodDeclarationSyntax
277+
&& ((MethodDeclarationSyntax)n).Identifier.ToString().Equals(Main, StringComparison.OrdinalIgnoreCase));
278+
if (mainMethod != null)
279+
{
280+
var updatedMethod = DocumentBuilder.ApplyChangesToMethod(mainMethod, filteredChanges);
281+
return root?.ReplaceNode(mainMethod, updatedMethod);
282+
}
283+
}
284+
else if (root.Members.Any(node => node.IsKind(SyntaxKind.GlobalStatement)))
285+
{
286+
return DocumentBuilder.ApplyChangesToMethod(root, filteredChanges);
287+
}
276288
}
277289
else
278290
{

src/Scaffolding/VS.Web.CG.EFCore/DbContextEditorServices.cs

Lines changed: 88 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using Microsoft.CodeAnalysis.CSharp;
1212
using Microsoft.CodeAnalysis.CSharp.Syntax;
1313
using Microsoft.CodeAnalysis.Text;
14+
using Microsoft.DotNet.Scaffolding.Shared.CodeModifier;
1415
using Microsoft.DotNet.Scaffolding.Shared.Project;
1516
using Microsoft.DotNet.Scaffolding.Shared.ProjectModel;
1617
using Microsoft.VisualStudio.Web.CodeGeneration.DotNet;
@@ -35,7 +36,7 @@ public class DbContextEditorServices : IDbContextEditorServices
3536
private const string WebApplicationCreateBuilder = "WebApplication.CreateBuilder";
3637
private const string AddRazorPages = "Services.AddRazorPages()";
3738
private const string CreateBuilder = "CreateBuilder(args)";
38-
39+
private const string Main = nameof(Main);
3940

4041
public DbContextEditorServices(
4142
IProjectContext projectContext,
@@ -155,7 +156,13 @@ private string GetSafeModelName(string name, ITypeSymbol dbContext)
155156
return safeName;
156157
}
157158

158-
public EditSyntaxTreeResult EditStartupForNewContext(ModelType startUp, string dbContextTypeName, string dbContextNamespace, string dataBaseName, bool useSqlite)
159+
public EditSyntaxTreeResult EditStartupForNewContext(
160+
ModelType startUp,
161+
string dbContextTypeName,
162+
string dbContextNamespace,
163+
string dataBaseName,
164+
bool useSqlite,
165+
bool useTopLevelStatements)
159166
{
160167
Contract.Assert(startUp != null && startUp.TypeSymbol != null);
161168
Contract.Assert(!String.IsNullOrEmpty(dbContextTypeName));
@@ -169,12 +176,13 @@ public EditSyntaxTreeResult EditStartupForNewContext(ModelType startUp, string d
169176

170177
var startUpClassNode = rootNode.FindNode(declarationReference.Span);
171178

172-
var configServicesMethod = startUpClassNode.ChildNodes()
173-
.FirstOrDefault(n => n is MethodDeclarationSyntax
174-
&& ((MethodDeclarationSyntax)n).Identifier.ToString() == ConfigureServices) as MethodDeclarationSyntax;
175179
var configRootProperty = TryGetIConfigurationRootProperty(startUp.TypeSymbol);
176180
//if using Startup.cs, the ConfigureServices method should exist.
177-
if (configServicesMethod != null && configRootProperty != null)
181+
if (startUpClassNode.ChildNodes()
182+
.FirstOrDefault(n =>
183+
n is MethodDeclarationSyntax syntax &&
184+
syntax.Identifier.ToString() == ConfigureServices)
185+
is MethodDeclarationSyntax configServicesMethod && configRootProperty != null)
178186
{
179187
var servicesParam = configServicesMethod.ParameterList.Parameters
180188
.FirstOrDefault(p => p.Type.ToString().Equals(IServiceCollection));
@@ -217,46 +225,56 @@ public EditSyntaxTreeResult EditStartupForNewContext(ModelType startUp, string d
217225
//minimal hosting scenario
218226
else
219227
{
220-
CompilationUnitSyntax classSyntax = startUpClassNode as CompilationUnitSyntax;
221-
if (classSyntax != null)
228+
var statementLeadingTrivia = string.Empty;
229+
StatementSyntax dbContextExpression = null;
230+
var compilationSyntax = rootNode as CompilationUnitSyntax;
231+
if (!useTopLevelStatements)
222232
{
223-
//get leading trivia. there should be atleast one member
224-
var statementLeadingTrivia = classSyntax.Members.First()?.GetLeadingTrivia().ToString();
225-
226-
string textToAddAtEnd = AddDbContextString(minimalHostingTemplate: true, useSqlite, statementLeadingTrivia);
227-
_connectionStringsWriter.AddConnectionString(dbContextTypeName, dataBaseName, useSqlite: useSqlite);
228-
textToAddAtEnd = Environment.NewLine + textToAddAtEnd;
229-
230-
//get builder identifier string, should exist
231-
var builderExpression = classSyntax.Members.Where(st => st.ToString().Contains(WebApplicationCreateBuilder)).FirstOrDefault();
232-
var builderIdentifierString = GetBuilderIdentifier(builderExpression);
233-
234-
//create syntax expression that adds DbContext
235-
//added InvalidOperationExceptino if Configuration.GetConnectionString returns null.
236-
var expression = SyntaxFactory.ParseStatement(string.Format(textToAddAtEnd,
237-
string.Format("{0}.Services", builderIdentifierString),
238-
dbContextTypeName,
239-
string.Format("{0}.Configuration", builderIdentifierString),
240-
string.Format(" ?? throw new InvalidOperationException(\"Connection string '{0}' not found.\")", dbContextTypeName)));
241-
var dbContextExpression = SyntaxFactory.GlobalStatement(expression);
242-
243-
//get global statement to insert after (different for web app vs web api)
244-
var statementToInsertAfter = classSyntax.Members.Where(st => st.ToString().Contains(AddRazorPages)).FirstOrDefault();
245-
if (statementToInsertAfter == null)
246-
{
247-
statementToInsertAfter = classSyntax.Members.Where(st => st.ToString().Contains(CreateBuilder)).FirstOrDefault();
248-
}
249-
250-
var newClassSyntax = classSyntax.InsertNodesAfter(statementToInsertAfter, new List<GlobalStatementSyntax>() { dbContextExpression });
251-
var newRoot = rootNode.ReplaceNode(classSyntax, newClassSyntax);
233+
MethodDeclarationSyntax methodSyntax = DocumentBuilder.GetMethodFromSyntaxRoot(compilationSyntax, Main);
234+
dbContextExpression = GetAddDbContextStatement(methodSyntax.Body, dbContextTypeName, dbContextNamespace, useSqlite);
235+
}
236+
else if(useTopLevelStatements)
237+
{
238+
dbContextExpression = GetAddDbContextStatement(compilationSyntax, dbContextTypeName, dbContextNamespace, useSqlite);
239+
}
252240

241+
if (statementLeadingTrivia != null && dbContextExpression != null)
242+
{
243+
var newRoot = compilationSyntax;
253244
//add additional namespaces
254245
var namespacesToAdd = new[] { "Microsoft.EntityFrameworkCore", "Microsoft.Extensions.DependencyInjection", dbContextNamespace };
255246
foreach (var namespaceName in namespacesToAdd)
256247
{
257-
newRoot = RoslynCodeEditUtilities.AddUsingDirectiveIfNeeded(namespaceName, newRoot as CompilationUnitSyntax);
248+
newRoot = RoslynCodeEditUtilities.AddUsingDirectiveIfNeeded(namespaceName, newRoot);
249+
}
250+
if (!useTopLevelStatements)
251+
{
252+
MethodDeclarationSyntax methodSyntax = DocumentBuilder.GetMethodFromSyntaxRoot(newRoot, Main);
253+
var modifiedBlock = methodSyntax.Body;
254+
var statementToInsertAround = methodSyntax.Body.ChildNodes().Where(st => st.ToString().Contains(AddRazorPages)).FirstOrDefault();
255+
if (statementToInsertAround == null)
256+
{
257+
statementToInsertAround = methodSyntax.Body.ChildNodes().Where(st => st.ToString().Contains(CreateBuilder)).FirstOrDefault();
258+
modifiedBlock = methodSyntax.Body.InsertNodesAfter(statementToInsertAround, new List<StatementSyntax>() { dbContextExpression });
259+
}
260+
else
261+
{
262+
modifiedBlock = methodSyntax.Body.InsertNodesBefore(statementToInsertAround, new List<StatementSyntax>() { dbContextExpression });
263+
}
264+
var modifiedMethod = methodSyntax.WithBody(modifiedBlock);
265+
newRoot = newRoot.ReplaceNode(methodSyntax, modifiedMethod);
258266
}
267+
else
268+
{
269+
var statementToInsertAfter = newRoot.Members.Where(st => st.ToString().Contains(AddRazorPages)).FirstOrDefault();
270+
if (statementToInsertAfter == null)
271+
{
272+
statementToInsertAfter = newRoot.Members.Where(st => st.ToString().Contains(CreateBuilder)).FirstOrDefault();
273+
}
259274

275+
newRoot = newRoot.InsertNodesAfter(statementToInsertAfter, new List<GlobalStatementSyntax>() { SyntaxFactory.GlobalStatement(dbContextExpression) });
276+
}
277+
260278
return new EditSyntaxTreeResult()
261279
{
262280
Edited = true,
@@ -273,6 +291,38 @@ public EditSyntaxTreeResult EditStartupForNewContext(ModelType startUp, string d
273291
};
274292
}
275293

294+
/// <summary>
295+
/// Get the StatementSyntax that adds the db context to the WebApplicationBuilder.
296+
/// </summary>
297+
/// <param name="rootNode">Using the base class to allow this var to be either CompilationUnitSyntax or a MethodBodySyntax
298+
/// To get the WebApplicationBuilder variable name
299+
/// </param>
300+
/// <param name="dbContextTypeName"></param>
301+
/// <param name="dataBaseName"></param>
302+
/// <param name="useSqlite"></param>
303+
internal StatementSyntax GetAddDbContextStatement(SyntaxNode rootNode, string dbContextTypeName, string dataBaseName, bool useSqlite)
304+
{
305+
//get leading trivia. there should be atleast one member var statementLeadingTrivia = classSyntax.ChildNodes()
306+
var statementLeadingTrivia = rootNode.ChildNodes().First()?.GetLeadingTrivia().ToString();
307+
string textToAddAtEnd = AddDbContextString(minimalHostingTemplate: true, useSqlite, statementLeadingTrivia);
308+
_connectionStringsWriter.AddConnectionString(dbContextTypeName, dataBaseName, useSqlite: useSqlite);
309+
textToAddAtEnd = Environment.NewLine + textToAddAtEnd;
310+
311+
//get builder identifier string, should exist
312+
var builderExpression = rootNode.ChildNodes().Where(st => st.ToString().Contains(WebApplicationCreateBuilder)).FirstOrDefault() as MemberDeclarationSyntax;
313+
var builderIdentifierString = GetBuilderIdentifier(builderExpression);
314+
315+
//create syntax expression that adds DbContext
316+
//added InvalidOperationExceptino if Configuration.GetConnectionString returns null.
317+
var expression = SyntaxFactory.ParseStatement(string.Format(textToAddAtEnd,
318+
string.Format("{0}.Services", builderIdentifierString),
319+
dbContextTypeName,
320+
string.Format("{0}.Configuration", builderIdentifierString),
321+
string.Format(" ?? throw new InvalidOperationException(\"Connection string '{0}' not found.\")", dbContextTypeName))).WithLeadingTrivia(SyntaxFactory.Whitespace(statementLeadingTrivia));
322+
323+
return expression;
324+
}
325+
276326
private string GetBuilderIdentifier(MemberDeclarationSyntax builderMember)
277327
{
278328
if (builderMember != null)

src/Scaffolding/VS.Web.CG.EFCore/EntityFrameworkModelProcessor.cs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ public async Task Process()
103103
{
104104
throw new InvalidOperationException(string.Format(MessageStrings.ModelTypeNotFound, "Program"));
105105
}
106-
106+
107107
if (!dbContextSymbols.Any())
108108
{
109109
//add nullable properties
@@ -367,6 +367,7 @@ private ReflectedTypesProvider GetReflectedTypesProvider(Compilation projectComp
367367
_loader,
368368
_logger);
369369
}
370+
370371
private async Task GenerateNewDbContextAndRegisterProgramFile(ModelType programType, IApplicationInfo applicationInfo)
371372
{
372373
AssemblyAttributeGenerator assemblyAttributeGenerator = GetAssemblyAttributeGenerator();
@@ -382,17 +383,20 @@ private async Task GenerateNewDbContextAndRegisterProgramFile(ModelType programT
382383
// Create a new Context
383384
_logger.LogMessage(string.Format(MessageStrings.GeneratingDbContext, _dbContextFullTypeName));
384385
bool nullabledEnabled = "enable".Equals(applicationInfo?.WorkspaceHelper?.GetMsBuildProperty("Nullable"), StringComparison.OrdinalIgnoreCase);
386+
bool useTopLevelsStatements = await ProjectModifierHelper.IsUsingTopLevelStatements(_modelTypesLocator);
385387
var dbContextTemplateModel = new NewDbContextTemplateModel(_dbContextFullTypeName, _modelTypeSymbol, programType, nullabledEnabled);
386388
_dbContextSyntaxTree = await _dbContextEditorServices.AddNewContext(dbContextTemplateModel);
387389
ContextProcessingStatus = ContextProcessingStatus.ContextAdded;
388390

389391
if (programType != null)
390392
{
391-
_programEditResult = _dbContextEditorServices.EditStartupForNewContext(programType,
393+
_programEditResult = _dbContextEditorServices.EditStartupForNewContext(
394+
programType,
392395
dbContextTemplateModel.DbContextTypeName,
393396
dbContextTemplateModel.DbContextNamespace,
394397
dataBaseName: dbContextTemplateModel.DbContextTypeName + "-" + Guid.NewGuid().ToString(),
395-
_useSqlite);
398+
_useSqlite,
399+
useTopLevelsStatements);
396400
}
397401

398402
if (!_programEditResult.Edited)
@@ -452,14 +456,15 @@ private async Task GenerateNewDbContextAndRegister(ModelType startupType, ModelT
452456

453457
_dbContextSyntaxTree = await _dbContextEditorServices.AddNewContext(dbContextTemplateModel);
454458
ContextProcessingStatus = ContextProcessingStatus.ContextAdded;
455-
459+
bool useTopLevelsStatements = await ProjectModifierHelper.IsUsingTopLevelStatements(_modelTypesLocator);
456460
if (startupType != null)
457461
{
458462
_startupEditResult = _dbContextEditorServices.EditStartupForNewContext(startupType,
459463
dbContextTemplateModel.DbContextTypeName,
460464
dbContextTemplateModel.DbContextNamespace,
461465
dataBaseName: dbContextTemplateModel.DbContextTypeName + "-" + Guid.NewGuid().ToString(),
462-
_useSqlite);
466+
_useSqlite,
467+
useTopLevelsStatements);
463468
}
464469

465470
if (!_startupEditResult.Edited)

src/Scaffolding/VS.Web.CG.EFCore/IDbContextEditorServices.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// Copyright (c) .NET Foundation. All rights reserved.
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

4+
using System;
5+
using System.Collections.Generic;
46
using System.Threading.Tasks;
57
using Microsoft.CodeAnalysis;
68
using Microsoft.DotNet.Scaffolding.Shared.Project;
@@ -13,6 +15,6 @@ public interface IDbContextEditorServices
1315

1416
EditSyntaxTreeResult AddModelToContext(ModelType dbContext, ModelType modelType, bool nullableEnabled);
1517

16-
EditSyntaxTreeResult EditStartupForNewContext(ModelType startup, string dbContextTypeName, string dbContextNamespace, string dataBaseName, bool useSqlite);
18+
EditSyntaxTreeResult EditStartupForNewContext(ModelType startup, string dbContextTypeName, string dbContextNamespace, string dataBaseName, bool useSqlite, bool useTopLevelStatements);
1719
}
1820
}

0 commit comments

Comments
 (0)