1111using Microsoft . CodeAnalysis . CSharp ;
1212using Microsoft . CodeAnalysis . CSharp . Syntax ;
1313using Microsoft . CodeAnalysis . Text ;
14+ using Microsoft . DotNet . Scaffolding . Shared . CodeModifier ;
1415using Microsoft . DotNet . Scaffolding . Shared . Project ;
1516using Microsoft . DotNet . Scaffolding . Shared . ProjectModel ;
1617using Microsoft . VisualStudio . Web . CodeGeneration . DotNet ;
@@ -35,7 +36,7 @@ public class DbContextEditorServices : IDbContextEditorServices
3536 private const string WebApplicationCreateBuilder = "WebApplication.CreateBuilder" ;
3637 private const string AddRazorPages = "Services.AddRazorPages()" ;
3738 private const string CreateBuilder = "CreateBuilder(args)" ;
38-
39+ private const string Main = nameof ( Main ) ;
3940
4041 public DbContextEditorServices (
4142 IProjectContext projectContext ,
@@ -155,7 +156,13 @@ private string GetSafeModelName(string name, ITypeSymbol dbContext)
155156 return safeName ;
156157 }
157158
158- public EditSyntaxTreeResult EditStartupForNewContext ( ModelType startUp , string dbContextTypeName , string dbContextNamespace , string dataBaseName , bool useSqlite )
159+ public EditSyntaxTreeResult EditStartupForNewContext (
160+ ModelType startUp ,
161+ string dbContextTypeName ,
162+ string dbContextNamespace ,
163+ string dataBaseName ,
164+ bool useSqlite ,
165+ bool useTopLevelStatements )
159166 {
160167 Contract . Assert ( startUp != null && startUp . TypeSymbol != null ) ;
161168 Contract . Assert ( ! String . IsNullOrEmpty ( dbContextTypeName ) ) ;
@@ -169,12 +176,13 @@ public EditSyntaxTreeResult EditStartupForNewContext(ModelType startUp, string d
169176
170177 var startUpClassNode = rootNode . FindNode ( declarationReference . Span ) ;
171178
172- var configServicesMethod = startUpClassNode . ChildNodes ( )
173- . FirstOrDefault ( n => n is MethodDeclarationSyntax
174- && ( ( MethodDeclarationSyntax ) n ) . Identifier . ToString ( ) == ConfigureServices ) as MethodDeclarationSyntax ;
175179 var configRootProperty = TryGetIConfigurationRootProperty ( startUp . TypeSymbol ) ;
176180 //if using Startup.cs, the ConfigureServices method should exist.
177- if ( configServicesMethod != null && configRootProperty != null )
181+ if ( startUpClassNode . ChildNodes ( )
182+ . FirstOrDefault ( n =>
183+ n is MethodDeclarationSyntax syntax &&
184+ syntax . Identifier . ToString ( ) == ConfigureServices )
185+ is MethodDeclarationSyntax configServicesMethod && configRootProperty != null )
178186 {
179187 var servicesParam = configServicesMethod . ParameterList . Parameters
180188 . FirstOrDefault ( p => p . Type . ToString ( ) . Equals ( IServiceCollection ) ) ;
@@ -217,46 +225,56 @@ public EditSyntaxTreeResult EditStartupForNewContext(ModelType startUp, string d
217225 //minimal hosting scenario
218226 else
219227 {
220- CompilationUnitSyntax classSyntax = startUpClassNode as CompilationUnitSyntax ;
221- if ( classSyntax != null )
228+ var statementLeadingTrivia = string . Empty ;
229+ StatementSyntax dbContextExpression = null ;
230+ var compilationSyntax = rootNode as CompilationUnitSyntax ;
231+ if ( ! useTopLevelStatements )
222232 {
223- //get leading trivia. there should be atleast one member
224- var statementLeadingTrivia = classSyntax . Members . First ( ) ? . GetLeadingTrivia ( ) . ToString ( ) ;
225-
226- string textToAddAtEnd = AddDbContextString ( minimalHostingTemplate : true , useSqlite , statementLeadingTrivia ) ;
227- _connectionStringsWriter . AddConnectionString ( dbContextTypeName , dataBaseName , useSqlite : useSqlite ) ;
228- textToAddAtEnd = Environment . NewLine + textToAddAtEnd ;
229-
230- //get builder identifier string, should exist
231- var builderExpression = classSyntax . Members . Where ( st => st . ToString ( ) . Contains ( WebApplicationCreateBuilder ) ) . FirstOrDefault ( ) ;
232- var builderIdentifierString = GetBuilderIdentifier ( builderExpression ) ;
233-
234- //create syntax expression that adds DbContext
235- //added InvalidOperationExceptino if Configuration.GetConnectionString returns null.
236- var expression = SyntaxFactory . ParseStatement ( string . Format ( textToAddAtEnd ,
237- string . Format ( "{0}.Services" , builderIdentifierString ) ,
238- dbContextTypeName ,
239- string . Format ( "{0}.Configuration" , builderIdentifierString ) ,
240- string . Format ( " ?? throw new InvalidOperationException(\" Connection string '{0}' not found.\" )" , dbContextTypeName ) ) ) ;
241- var dbContextExpression = SyntaxFactory . GlobalStatement ( expression ) ;
242-
243- //get global statement to insert after (different for web app vs web api)
244- var statementToInsertAfter = classSyntax . Members . Where ( st => st . ToString ( ) . Contains ( AddRazorPages ) ) . FirstOrDefault ( ) ;
245- if ( statementToInsertAfter == null )
246- {
247- statementToInsertAfter = classSyntax . Members . Where ( st => st . ToString ( ) . Contains ( CreateBuilder ) ) . FirstOrDefault ( ) ;
248- }
249-
250- var newClassSyntax = classSyntax . InsertNodesAfter ( statementToInsertAfter , new List < GlobalStatementSyntax > ( ) { dbContextExpression } ) ;
251- var newRoot = rootNode . ReplaceNode ( classSyntax , newClassSyntax ) ;
233+ MethodDeclarationSyntax methodSyntax = DocumentBuilder . GetMethodFromSyntaxRoot ( compilationSyntax , Main ) ;
234+ dbContextExpression = GetAddDbContextStatement ( methodSyntax . Body , dbContextTypeName , dbContextNamespace , useSqlite ) ;
235+ }
236+ else if ( useTopLevelStatements )
237+ {
238+ dbContextExpression = GetAddDbContextStatement ( compilationSyntax , dbContextTypeName , dbContextNamespace , useSqlite ) ;
239+ }
252240
241+ if ( statementLeadingTrivia != null && dbContextExpression != null )
242+ {
243+ var newRoot = compilationSyntax ;
253244 //add additional namespaces
254245 var namespacesToAdd = new [ ] { "Microsoft.EntityFrameworkCore" , "Microsoft.Extensions.DependencyInjection" , dbContextNamespace } ;
255246 foreach ( var namespaceName in namespacesToAdd )
256247 {
257- newRoot = RoslynCodeEditUtilities . AddUsingDirectiveIfNeeded ( namespaceName , newRoot as CompilationUnitSyntax ) ;
248+ newRoot = RoslynCodeEditUtilities . AddUsingDirectiveIfNeeded ( namespaceName , newRoot ) ;
249+ }
250+ if ( ! useTopLevelStatements )
251+ {
252+ MethodDeclarationSyntax methodSyntax = DocumentBuilder . GetMethodFromSyntaxRoot ( newRoot , Main ) ;
253+ var modifiedBlock = methodSyntax . Body ;
254+ var statementToInsertAround = methodSyntax . Body . ChildNodes ( ) . Where ( st => st . ToString ( ) . Contains ( AddRazorPages ) ) . FirstOrDefault ( ) ;
255+ if ( statementToInsertAround == null )
256+ {
257+ statementToInsertAround = methodSyntax . Body . ChildNodes ( ) . Where ( st => st . ToString ( ) . Contains ( CreateBuilder ) ) . FirstOrDefault ( ) ;
258+ modifiedBlock = methodSyntax . Body . InsertNodesAfter ( statementToInsertAround , new List < StatementSyntax > ( ) { dbContextExpression } ) ;
259+ }
260+ else
261+ {
262+ modifiedBlock = methodSyntax . Body . InsertNodesBefore ( statementToInsertAround , new List < StatementSyntax > ( ) { dbContextExpression } ) ;
263+ }
264+ var modifiedMethod = methodSyntax . WithBody ( modifiedBlock ) ;
265+ newRoot = newRoot . ReplaceNode ( methodSyntax , modifiedMethod ) ;
258266 }
267+ else
268+ {
269+ var statementToInsertAfter = newRoot . Members . Where ( st => st . ToString ( ) . Contains ( AddRazorPages ) ) . FirstOrDefault ( ) ;
270+ if ( statementToInsertAfter == null )
271+ {
272+ statementToInsertAfter = newRoot . Members . Where ( st => st . ToString ( ) . Contains ( CreateBuilder ) ) . FirstOrDefault ( ) ;
273+ }
259274
275+ newRoot = newRoot . InsertNodesAfter ( statementToInsertAfter , new List < GlobalStatementSyntax > ( ) { SyntaxFactory . GlobalStatement ( dbContextExpression ) } ) ;
276+ }
277+
260278 return new EditSyntaxTreeResult ( )
261279 {
262280 Edited = true ,
@@ -273,6 +291,38 @@ public EditSyntaxTreeResult EditStartupForNewContext(ModelType startUp, string d
273291 } ;
274292 }
275293
294+ /// <summary>
295+ /// Get the StatementSyntax that adds the db context to the WebApplicationBuilder.
296+ /// </summary>
297+ /// <param name="rootNode">Using the base class to allow this var to be either CompilationUnitSyntax or a MethodBodySyntax
298+ /// To get the WebApplicationBuilder variable name
299+ /// </param>
300+ /// <param name="dbContextTypeName"></param>
301+ /// <param name="dataBaseName"></param>
302+ /// <param name="useSqlite"></param>
303+ internal StatementSyntax GetAddDbContextStatement ( SyntaxNode rootNode , string dbContextTypeName , string dataBaseName , bool useSqlite )
304+ {
305+ //get leading trivia. there should be atleast one member var statementLeadingTrivia = classSyntax.ChildNodes()
306+ var statementLeadingTrivia = rootNode . ChildNodes ( ) . First ( ) ? . GetLeadingTrivia ( ) . ToString ( ) ;
307+ string textToAddAtEnd = AddDbContextString ( minimalHostingTemplate : true , useSqlite , statementLeadingTrivia ) ;
308+ _connectionStringsWriter . AddConnectionString ( dbContextTypeName , dataBaseName , useSqlite : useSqlite ) ;
309+ textToAddAtEnd = Environment . NewLine + textToAddAtEnd ;
310+
311+ //get builder identifier string, should exist
312+ var builderExpression = rootNode . ChildNodes ( ) . Where ( st => st . ToString ( ) . Contains ( WebApplicationCreateBuilder ) ) . FirstOrDefault ( ) as MemberDeclarationSyntax ;
313+ var builderIdentifierString = GetBuilderIdentifier ( builderExpression ) ;
314+
315+ //create syntax expression that adds DbContext
316+ //added InvalidOperationExceptino if Configuration.GetConnectionString returns null.
317+ var expression = SyntaxFactory . ParseStatement ( string . Format ( textToAddAtEnd ,
318+ string . Format ( "{0}.Services" , builderIdentifierString ) ,
319+ dbContextTypeName ,
320+ string . Format ( "{0}.Configuration" , builderIdentifierString ) ,
321+ string . Format ( " ?? throw new InvalidOperationException(\" Connection string '{0}' not found.\" )" , dbContextTypeName ) ) ) . WithLeadingTrivia ( SyntaxFactory . Whitespace ( statementLeadingTrivia ) ) ;
322+
323+ return expression ;
324+ }
325+
276326 private string GetBuilderIdentifier ( MemberDeclarationSyntax builderMember )
277327 {
278328 if ( builderMember != null )
0 commit comments