diff --git a/eng/Versions.props b/eng/Versions.props index 1626fcd14..b82dc386c 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -51,7 +51,7 @@ 7.0.1 - 7.0.1 + 7.0.2 rtm False true diff --git a/scripts/install-aspnet-codegenerator.cmd b/scripts/install-aspnet-codegenerator.cmd index 806f87096..d6a1ee8f6 100644 --- a/scripts/install-aspnet-codegenerator.cmd +++ b/scripts/install-aspnet-codegenerator.cmd @@ -1,4 +1,4 @@ -set VERSION=7.0.0-dev +set VERSION=8.0.0-dev set DEFAULT_NUPKG_PATH=%userprofile%\.nuget\packages set SRC_DIR=%cd% set NUPKG=artifacts/packages/Debug/Shipping/ diff --git a/src/MSIdentityScaffolding/Microsoft.DotNet.MSIdentity/DeveloperCredentials/MsalTokenCredential.cs b/src/MSIdentityScaffolding/Microsoft.DotNet.MSIdentity/DeveloperCredentials/MsalTokenCredential.cs index 620ae32b8..9817594cc 100644 --- a/src/MSIdentityScaffolding/Microsoft.DotNet.MSIdentity/DeveloperCredentials/MsalTokenCredential.cs +++ b/src/MSIdentityScaffolding/Microsoft.DotNet.MSIdentity/DeveloperCredentials/MsalTokenCredential.cs @@ -84,65 +84,106 @@ private async Task GetOrCreateApp() public override async ValueTask GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken) { var app = await GetOrCreateApp(); - AuthenticationResult? result = null; var accounts = await app.GetAccountsAsync()!; - IAccount? account; + IAccount? account = string.IsNullOrEmpty(Username) + ? accounts.FirstOrDefault() + : accounts.FirstOrDefault(account => string.Equals(account.Username, Username, StringComparison.OrdinalIgnoreCase)); - if (!string.IsNullOrEmpty(Username)) - { - account = accounts.FirstOrDefault(account => account.Username == Username); - } - else + AuthenticationResult? result = account is null + ? await GetAuthenticationWithoutAccount(requestContext.Scopes, app, cancellationToken) + : await GetAuthenticationWithAccount(requestContext.Scopes, app, account, cancellationToken); + + if (result is null || result.AccessToken is null) { - account = accounts.FirstOrDefault(); + _consoleLogger.LogFailureAndExit(Resources.FailedToAcquireToken); } + + // Note: In the future, the token type *could* be POP instead of Bearer + return new AccessToken(result!.AccessToken!, result.ExpiresOn); + } + + private async Task GetAuthenticationWithAccount(string[] scopes, IPublicClientApplication app, IAccount? account, CancellationToken cancellationToken) + { + AuthenticationResult? result = null; try { - result = await app.AcquireTokenSilent(requestContext.Scopes, account) + result = await app.AcquireTokenSilent(scopes, account) .WithAuthority(Instance, TenantId) .ExecuteAsync(cancellationToken); } catch (MsalUiRequiredException ex) { - if (account == null && !string.IsNullOrEmpty(Username)) + try { - _consoleLogger.LogFailureAndExit( - $"No valid tokens found in the cache.\n" + - $"Please sign-in to Visual Studio with this account: {Username}.\n\n" + - $"After signing-in, re-run the tool."); + result = await app.AcquireTokenInteractive(scopes) + .WithAccount(account) + .WithClaims(ex.Claims) + .WithAuthority(Instance, TenantId) + .WithUseEmbeddedWebView(false) + .ExecuteAsync(cancellationToken); + } + catch (Exception e) + { + _consoleLogger.LogFailureAndExit(string.Join(Environment.NewLine, Resources.SignInError, e.Message)); } - result = await app.AcquireTokenInteractive(requestContext.Scopes) - .WithAccount(account) - .WithClaims(ex.Claims) - .WithAuthority(Instance, TenantId) - .ExecuteAsync(cancellationToken); } catch (MsalServiceException ex) { // AAD error codes: https://learn.microsoft.com/en-us/azure/active-directory/develop/reference-aadsts-error-codes - if (ex.Message.Contains("AADSTS70002")) // "The client does not exist or is not enabled for consumers" - { - // We want to exit here because this is probably an MSA without an AAD tenant. - _consoleLogger.LogFailureAndExit( - "An Azure AD tenant, and a user in that tenant, " + - "needs to be created for this account before an application can be created. " + - "See https://aka.ms/ms-identity-app/create-a-tenant. "); - } + var errorMessage = ex.Message.Contains("AADSTS70002") // "The client does not exist or is not enabled for consumers" + ? Resources.ClientDoesNotExist + : string.Join(Environment.NewLine, Resources.SignInError, ex.Message); // we want to exit here. Re-sign in will not resolve the issue. - _consoleLogger.LogFailureAndExit(string.Join(Environment.NewLine, Resources.SignInError, ex.Message)); + _consoleLogger.LogFailureAndExit(errorMessage); } catch (Exception ex) { _consoleLogger.LogFailureAndExit(string.Join(Environment.NewLine, Resources.SignInError, ex.Message)); } - if (result is null) + return result; + } + + /// + /// If there are no matching accounts in the msal cache, we need to make a call to AcquireTokenInteractive in order to populate the cache. + /// + /// + /// + /// + /// + /// + private async Task GetAuthenticationWithoutAccount(string[] scopes, IPublicClientApplication app, CancellationToken cancellationToken) + { + AuthenticationResult? result = null; + try { - _consoleLogger.LogFailureAndExit(Resources.FailedToAcquireToken); + result = await app.AcquireTokenInteractive(scopes) + .WithAuthority(Instance, TenantId) + .WithUseEmbeddedWebView(false) + .ExecuteAsync(cancellationToken); + } + catch (MsalUiRequiredException ex) // Need to get Claims, hence the nested try/catch + { + try + { + result = await app.AcquireTokenInteractive(scopes) + .WithClaims(ex.Claims) + .WithAuthority(Instance, TenantId) + .WithUseEmbeddedWebView(false) + .ExecuteAsync(cancellationToken); + } + catch (Exception e) + { + _consoleLogger.LogFailureAndExit(string.Join(Environment.NewLine, Resources.SignInError, e.Message)); + } + } + catch (Exception e) + { + _consoleLogger.LogFailureAndExit(string.Join(Environment.NewLine, Resources.SignInError, e.Message)); } - return new AccessToken(result!.AccessToken, result.ExpiresOn); + return result; } } } diff --git a/src/MSIdentityScaffolding/Microsoft.DotNet.MSIdentity/Properties/Resources.Designer.cs b/src/MSIdentityScaffolding/Microsoft.DotNet.MSIdentity/Properties/Resources.Designer.cs index b07820705..35724925e 100644 --- a/src/MSIdentityScaffolding/Microsoft.DotNet.MSIdentity/Properties/Resources.Designer.cs +++ b/src/MSIdentityScaffolding/Microsoft.DotNet.MSIdentity/Properties/Resources.Designer.cs @@ -298,6 +298,15 @@ internal static string AuthNotEnabled { } } + /// + /// Looks up a localized string similar to An Azure AD tenant, and a user in that tenant, needs to be created for this account before an application can be created. See https://aka.ms/ms-identity-app/create-a-tenant.. + /// + internal static string ClientDoesNotExist { + get { + return ResourceManager.GetString("ClientDoesNotExist", resourceCulture); + } + } + /// /// Looks up a localized string similar to Client secret - {0}.. /// diff --git a/src/MSIdentityScaffolding/Microsoft.DotNet.MSIdentity/Properties/Resources.resx b/src/MSIdentityScaffolding/Microsoft.DotNet.MSIdentity/Properties/Resources.resx index 0dab9f7de..ad159bf5e 100644 --- a/src/MSIdentityScaffolding/Microsoft.DotNet.MSIdentity/Properties/Resources.resx +++ b/src/MSIdentityScaffolding/Microsoft.DotNet.MSIdentity/Properties/Resources.resx @@ -159,6 +159,9 @@ Authentication is not enabled yet in this project. An app registration will be created, but the tool does not add the code yet (work in progress). + + An Azure AD tenant, and a user in that tenant, needs to be created for this account before an application can be created. See https://aka.ms/ms-identity-app/create-a-tenant. + Client secret - {0}. diff --git a/src/Scaffolding/VS.Web.CG.EFCore/ConnectionStringsWriter.cs b/src/Scaffolding/VS.Web.CG.EFCore/ConnectionStringsWriter.cs index ed9aec6c3..fb6132d75 100644 --- a/src/Scaffolding/VS.Web.CG.EFCore/ConnectionStringsWriter.cs +++ b/src/Scaffolding/VS.Web.CG.EFCore/ConnectionStringsWriter.cs @@ -3,6 +3,7 @@ using System; using System.IO; +using Microsoft.DotNet.Scaffolding.Shared; using Microsoft.VisualStudio.Web.CodeGeneration.DotNet; using Newtonsoft.Json.Linq; @@ -10,9 +11,6 @@ namespace Microsoft.VisualStudio.Web.CodeGeneration.EntityFrameworkCore { public class ConnectionStringsWriter : IConnectionStringsWriter { - private const string SQLConnectionStringFormat = "Server=(localdb)\\mssqllocaldb;Database={0};Trusted_Connection=True;MultipleActiveResultSets=true"; - private const string SQLiteConnectionStringFormat = "Data Source={0}.db"; - private IApplicationInfo _applicationInfo; private IFileSystem _fileSystem; @@ -30,6 +28,11 @@ internal ConnectionStringsWriter( } public void AddConnectionString(string connectionStringName, string dataBaseName, bool useSqlite) + { + AddConnectionString(connectionStringName, dataBaseName, useSqlite ? DbProvider.SQLite : DbProvider.SqlServer); + } + + public void AddConnectionString(string connectionStringName, string databaseName, DbProvider databaseProvider) { var appSettingsFile = Path.Combine(_applicationInfo.ApplicationBasePath, "appsettings.json"); JObject content; @@ -55,13 +58,18 @@ public void AddConnectionString(string connectionStringName, string dataBaseName if (content[connectionStringNodeName][connectionStringName] == null) { - var connectionString = string.Format( - useSqlite ? SQLiteConnectionStringFormat : SQLConnectionStringFormat, - dataBaseName); - writeContent = true; - content[connectionStringNodeName][connectionStringName] = connectionString; + if (EfConstants.ConnectionStringsDict.TryGetValue(databaseProvider, out var connectionString)) + { + if (!databaseProvider.Equals(DbProvider.CosmosDb)) + { + connectionString = string.Format(connectionString, databaseName); + } + + writeContent = true; + content[connectionStringNodeName][connectionStringName] = connectionString; + } } - + // Json.Net loses comments so the above code if requires any changes loses // comments in the file. The writeContent bool is for saving // a specific case without losing comments - when no changes are needed. @@ -71,4 +79,4 @@ public void AddConnectionString(string connectionStringName, string dataBaseName } } } -} \ No newline at end of file +} diff --git a/src/Scaffolding/VS.Web.CG.EFCore/DbContextEditorServices.cs b/src/Scaffolding/VS.Web.CG.EFCore/DbContextEditorServices.cs index 760563691..b349c681c 100644 --- a/src/Scaffolding/VS.Web.CG.EFCore/DbContextEditorServices.cs +++ b/src/Scaffolding/VS.Web.CG.EFCore/DbContextEditorServices.cs @@ -11,9 +11,11 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; +using Microsoft.DotNet.Scaffolding.Shared; using Microsoft.DotNet.Scaffolding.Shared.CodeModifier; using Microsoft.DotNet.Scaffolding.Shared.Project; using Microsoft.DotNet.Scaffolding.Shared.ProjectModel; +using Microsoft.VisualBasic; using Microsoft.VisualStudio.Web.CodeGeneration.DotNet; using Microsoft.VisualStudio.Web.CodeGeneration.Templating; using Newtonsoft.Json.Linq; @@ -101,45 +103,7 @@ private async Task AddNewContextItemsInternal(string templateName, N public EditSyntaxTreeResult AddModelToContext(ModelType dbContext, ModelType modelType, bool nullableEnabled) { - if (!IsModelPropertyExists(dbContext.TypeSymbol, modelType.FullName)) - { - // Todo : Consider using DeclaringSyntaxtReference - var sourceLocation = dbContext.TypeSymbol.Locations.Where(l => l.IsInSource).FirstOrDefault(); - if (sourceLocation != null) - { - var syntaxTree = sourceLocation.SourceTree; - var rootNode = syntaxTree.GetRoot(); - var dbContextNode = rootNode.FindNode(sourceLocation.SourceSpan); - var lastNode = dbContextNode.ChildNodes().Last(); - - var safeModelName = GetSafeModelName(modelType.Name, dbContext.TypeSymbol); - var nullabilityClause = nullableEnabled ? " = default!;" : ""; - // Todo : Need pluralization for property name below. - // It is not always safe to just use DbSet as there can be multiple class names in different namespaces. - var dbSetProperty = "public DbSet<" + modelType.FullName + "> " + safeModelName + " { get; set; }" + nullabilityClause + Environment.NewLine; - var propertyDeclarationWrapper = CSharpSyntaxTree.ParseText(dbSetProperty); - - var newNode = rootNode.InsertNodesAfter(lastNode, - propertyDeclarationWrapper.GetRoot().WithTriviaFrom(lastNode).ChildNodes()); - - newNode = RoslynCodeEditUtilities.AddUsingDirectiveIfNeeded("Microsoft.EntityFrameworkCore", newNode as CompilationUnitSyntax); //DbSet namespace - newNode = RoslynCodeEditUtilities.AddUsingDirectiveIfNeeded(modelType.Namespace, newNode as CompilationUnitSyntax); - - var modifiedTree = syntaxTree.WithRootAndOptions(newNode, syntaxTree.Options); - - return new EditSyntaxTreeResult() - { - Edited = true, - OldTree = syntaxTree, - NewTree = modifiedTree - }; - } - } - - return new EditSyntaxTreeResult() - { - Edited = false - }; + return AddModelToContext(dbContext, modelType, new Dictionary() { { nameof(nullableEnabled), nullableEnabled.ToString() }}); } private string GetSafeModelName(string name, ITypeSymbol dbContext) @@ -165,147 +129,35 @@ public EditSyntaxTreeResult EditStartupForNewContext( bool useTopLevelStatements) { Contract.Assert(startUp != null && startUp.TypeSymbol != null); - Contract.Assert(!String.IsNullOrEmpty(dbContextTypeName)); - Contract.Assert(!String.IsNullOrEmpty(dataBaseName)); + Contract.Assert(!string.IsNullOrEmpty(dbContextTypeName)); + Contract.Assert(!string.IsNullOrEmpty(dataBaseName)); - var declarationReference = startUp.TypeSymbol.DeclaringSyntaxReferences.FirstOrDefault(); - if (declarationReference != null) + var parameters = new Dictionary { - var sourceTree = declarationReference.SyntaxTree; - var rootNode = sourceTree.GetRoot(); - - var startUpClassNode = rootNode.FindNode(declarationReference.Span); - - var configRootProperty = TryGetIConfigurationRootProperty(startUp.TypeSymbol); - //if using Startup.cs, the ConfigureServices method should exist. - if (startUpClassNode.ChildNodes() - .FirstOrDefault(n => - n is MethodDeclarationSyntax syntax && - syntax.Identifier.ToString() == ConfigureServices) - is MethodDeclarationSyntax configServicesMethod && configRootProperty != null) - { - var servicesParam = configServicesMethod.ParameterList.Parameters - .FirstOrDefault(p => p.Type.ToString().Equals(IServiceCollection)); - - var statementLeadingTrivia = configServicesMethod.Body.OpenBraceToken.LeadingTrivia.ToString() + " "; - if (servicesParam != null) - { - string textToAddAtEnd = AddDbContextString(minimalHostingTemplate: false, useSqlite, statementLeadingTrivia); - _connectionStringsWriter.AddConnectionString(dbContextTypeName, dataBaseName, useSqlite: useSqlite); - if (configServicesMethod.Body.Statements.Any()) - { - textToAddAtEnd = Environment.NewLine + textToAddAtEnd; - } - - //string.Empty instead of InvalidOperationException for legacy scenarios. - var expression = SyntaxFactory.ParseStatement(string.Format(textToAddAtEnd, - servicesParam.Identifier, - dbContextTypeName, - configRootProperty.Name, - string.Empty)); - - MethodDeclarationSyntax newConfigServicesMethod = configServicesMethod.AddBodyStatements(expression); - - var newRoot = rootNode.ReplaceNode(configServicesMethod, newConfigServicesMethod); - - var namespacesToAdd = new[] { "Microsoft.EntityFrameworkCore", "Microsoft.Extensions.DependencyInjection", dbContextNamespace }; - foreach (var namespaceName in namespacesToAdd) - { - newRoot = RoslynCodeEditUtilities.AddUsingDirectiveIfNeeded(namespaceName, newRoot as CompilationUnitSyntax); - } - - return new EditSyntaxTreeResult() - { - Edited = true, - OldTree = sourceTree, - NewTree = sourceTree.WithRootAndOptions(newRoot, sourceTree.Options) - }; - } - } - //minimal hosting scenario - else - { - var statementLeadingTrivia = string.Empty; - StatementSyntax dbContextExpression = null; - var compilationSyntax = rootNode as CompilationUnitSyntax; - if (!useTopLevelStatements) - { - MethodDeclarationSyntax methodSyntax = DocumentBuilder.GetMethodFromSyntaxRoot(compilationSyntax, Main); - dbContextExpression = GetAddDbContextStatement(methodSyntax.Body, dbContextTypeName, dbContextNamespace, useSqlite); - } - else if(useTopLevelStatements) - { - dbContextExpression = GetAddDbContextStatement(compilationSyntax, dbContextTypeName, dbContextNamespace, useSqlite); - } - - if (statementLeadingTrivia != null && dbContextExpression != null) - { - var newRoot = compilationSyntax; - //add additional namespaces - var namespacesToAdd = new[] { "Microsoft.EntityFrameworkCore", "Microsoft.Extensions.DependencyInjection", dbContextNamespace }; - foreach (var namespaceName in namespacesToAdd) - { - newRoot = RoslynCodeEditUtilities.AddUsingDirectiveIfNeeded(namespaceName, newRoot); - } - if (!useTopLevelStatements) - { - MethodDeclarationSyntax methodSyntax = DocumentBuilder.GetMethodFromSyntaxRoot(newRoot, Main); - var modifiedBlock = methodSyntax.Body; - var statementToInsertAround = methodSyntax.Body.ChildNodes().Where(st => st.ToString().Contains(AddRazorPages)).FirstOrDefault(); - if (statementToInsertAround == null) - { - statementToInsertAround = methodSyntax.Body.ChildNodes().Where(st => st.ToString().Contains(CreateBuilder)).FirstOrDefault(); - modifiedBlock = methodSyntax.Body.InsertNodesAfter(statementToInsertAround, new List() { dbContextExpression }); - } - else - { - modifiedBlock = methodSyntax.Body.InsertNodesBefore(statementToInsertAround, new List() { dbContextExpression }); - } - var modifiedMethod = methodSyntax.WithBody(modifiedBlock); - newRoot = newRoot.ReplaceNode(methodSyntax, modifiedMethod); - } - else - { - var statementToInsertAfter = newRoot.Members.Where(st => st.ToString().Contains(AddRazorPages)).FirstOrDefault(); - if (statementToInsertAfter == null) - { - statementToInsertAfter = newRoot.Members.Where(st => st.ToString().Contains(CreateBuilder)).FirstOrDefault(); - } - - newRoot = newRoot.InsertNodesAfter(statementToInsertAfter, new List() { SyntaxFactory.GlobalStatement(dbContextExpression) }); - } - - return new EditSyntaxTreeResult() - { - Edited = true, - OldTree = sourceTree, - NewTree = sourceTree.WithRootAndOptions(newRoot, sourceTree.Options) - }; - } - } - } - - return new EditSyntaxTreeResult() - { - Edited = false + { nameof(NewDbContextTemplateModel.DbContextTypeName), dbContextTypeName }, + { nameof(NewDbContextTemplateModel.DbContextNamespace), dbContextNamespace }, + { "dataBaseName", dataBaseName}, + { "databaseProvider", useSqlite ? EfConstants.SQLite : EfConstants.SqlServer }, + { "useTopLevelStatements", useTopLevelStatements.ToString() } }; + + return EditStartupForNewContext(startUp, parameters); } /// /// Get the StatementSyntax that adds the db context to the WebApplicationBuilder. /// - /// Using the base class to allow this var to be either CompilationUnitSyntax or a MethodBodySyntax - /// To get the WebApplicationBuilder variable name + /// Using the base class to allow this var to be either CompilationUnitSyntax or a MethodBodySynta To get the WebApplicationBuilder variable name /// /// /// - /// - internal StatementSyntax GetAddDbContextStatement(SyntaxNode rootNode, string dbContextTypeName, string dataBaseName, bool useSqlite) + /// + internal StatementSyntax GetAddDbContextStatement(SyntaxNode rootNode, string dbContextTypeName, string dataBaseName, DbProvider dataContextTypeString) { //get leading trivia. there should be atleast one member var statementLeadingTrivia = classSyntax.ChildNodes() var statementLeadingTrivia = rootNode.ChildNodes().First()?.GetLeadingTrivia().ToString(); - string textToAddAtEnd = AddDbContextString(minimalHostingTemplate: true, useSqlite, statementLeadingTrivia); - _connectionStringsWriter.AddConnectionString(dbContextTypeName, dataBaseName, useSqlite: useSqlite); + string textToAddAtEnd = AddDbContextString(minimalHostingTemplate: true, statementLeadingTrivia, dataContextTypeString); + _connectionStringsWriter.AddConnectionString(dbContextTypeName, dataBaseName, dataContextTypeString); textToAddAtEnd = Environment.NewLine + textToAddAtEnd; //get builder identifier string, should exist @@ -345,24 +197,41 @@ private string GetBuilderIdentifier(MemberDeclarationSyntax builderMember) } return "builder"; } - - private string AddDbContextString(bool minimalHostingTemplate, bool useSqlite, string statementLeadingTrivia) + + internal string AddDbContextString(bool minimalHostingTemplate, string statementLeadingTrivia, DbProvider databaseProvider) { - string textToAddAtEnd; + string textToAddAtEnd = string.Empty; string additionalNewline = Environment.NewLine; string additionalLeadingTrivia = minimalHostingTemplate ? string.Empty : " "; string leadingTrivia = minimalHostingTemplate ? string.Empty : statementLeadingTrivia; - if (useSqlite) - { - textToAddAtEnd = - leadingTrivia + "{0}.AddDbContext<{1}>(options =>" + additionalNewline + - statementLeadingTrivia + additionalLeadingTrivia + " options.UseSqlite({2}.GetConnectionString(\"{1}\"){3}));" + Environment.NewLine; - } - else + switch (databaseProvider) { - textToAddAtEnd = - leadingTrivia + "{0}.AddDbContext<{1}>(options =>" + additionalNewline + - statementLeadingTrivia + additionalLeadingTrivia + " options.UseSqlServer({2}.GetConnectionString(\"{1}\"){3}));" + Environment.NewLine; + case DbProvider.SQLite: + textToAddAtEnd = + leadingTrivia + "{0}.AddDbContext<{1}>(options =>" + additionalNewline + + statementLeadingTrivia + additionalLeadingTrivia + " options.UseSqlite({2}.GetConnectionString(\"{1}\"){3}));" + Environment.NewLine; + break; + + case DbProvider.SqlServer: + textToAddAtEnd = + leadingTrivia + "{0}.AddDbContext<{1}>(options =>" + additionalNewline + + statementLeadingTrivia + additionalLeadingTrivia + " options.UseSqlServer({2}.GetConnectionString(\"{1}\"){3}));" + Environment.NewLine; + break; + + case DbProvider.CosmosDb: + textToAddAtEnd = + leadingTrivia + "{0}.AddDbContext<{1}>(options =>" + additionalNewline + + statementLeadingTrivia + additionalLeadingTrivia + " options.UseCosmos({2}.GetConnectionString(\"{1}\"), \"DATABASE_NAME\"));" + Environment.NewLine; + break; + + case DbProvider.Postgres: + textToAddAtEnd = + leadingTrivia + "{0}.AddDbContext<{1}>(options =>" + additionalNewline + + statementLeadingTrivia + additionalLeadingTrivia + " options.UseNpgsql({2}.GetConnectionString(\"{1}\"){3}));" + Environment.NewLine; + break; + + default: + break; } return textToAddAtEnd; } @@ -467,6 +336,193 @@ private bool IsModelPropertyExistsOnSymbol(ITypeSymbol dbContext, string modelTy return false; } + public EditSyntaxTreeResult AddModelToContext(ModelType dbContext, ModelType modelType, IDictionary parameters) + { + if (!IsModelPropertyExists(dbContext.TypeSymbol, modelType.FullName)) + { + // Todo : Consider using DeclaringSyntaxtReference + var sourceLocation = dbContext.TypeSymbol.Locations.Where(l => l.IsInSource).FirstOrDefault(); + if (sourceLocation != null) + { + var syntaxTree = sourceLocation.SourceTree; + var rootNode = syntaxTree.GetRoot(); + var dbContextNode = rootNode.FindNode(sourceLocation.SourceSpan); + var lastNode = dbContextNode.ChildNodes().Last(); + + var safeModelName = GetSafeModelName(modelType.Name, dbContext.TypeSymbol); + parameters.TryGetValue("nullableEnabled", out var nullableEnabled); + var nullableEnabledBool = string.IsNullOrEmpty(nullableEnabled) ? true : nullableEnabled.Equals(bool.TrueString, StringComparison.OrdinalIgnoreCase); + var nullabilityClause = nullableEnabledBool ? " = default!;" : ""; + // Todo : Need pluralization for property name below. + // It is not always safe to just use DbSet as there can be multiple class names in different namespaces. + var dbSetProperty = "public DbSet<" + modelType.FullName + "> " + safeModelName + " { get; set; }" + nullabilityClause + Environment.NewLine; + var propertyDeclarationWrapper = CSharpSyntaxTree.ParseText(dbSetProperty); + + var newNode = rootNode.InsertNodesAfter(lastNode, + propertyDeclarationWrapper.GetRoot().WithTriviaFrom(lastNode).ChildNodes()); + + newNode = RoslynCodeEditUtilities.AddUsingDirectiveIfNeeded("Microsoft.EntityFrameworkCore", newNode as CompilationUnitSyntax); //DbSet namespace + newNode = RoslynCodeEditUtilities.AddUsingDirectiveIfNeeded(modelType.Namespace, newNode as CompilationUnitSyntax); + + var modifiedTree = syntaxTree.WithRootAndOptions(newNode, syntaxTree.Options); + + return new EditSyntaxTreeResult() + { + Edited = true, + OldTree = syntaxTree, + NewTree = modifiedTree + }; + } + } + + return new EditSyntaxTreeResult() + { + Edited = false + }; + } + + public EditSyntaxTreeResult EditStartupForNewContext(ModelType startUp, IDictionary parameters) + { + var declarationReference = startUp.TypeSymbol.DeclaringSyntaxReferences.FirstOrDefault(); + if (declarationReference != null) + { + //get all params + parameters.TryGetValue(nameof(NewDbContextTemplateModel.DbContextTypeName), out var dbContextTypeName); + parameters.TryGetValue("dataBaseName", out var dataBaseName); + parameters.TryGetValue("databaseProvider", out var dataContextTypeString); + DbProvider dataContextType = DbProvider.SqlServer; + if (Enum.TryParse(typeof(DbProvider), dataContextTypeString, ignoreCase:true, out var dataContextTypeObj)) + { + dataContextType = (DbProvider)dataContextTypeObj; + } + parameters.TryGetValue("useTopLevelStatements", out var useTopLevelStatementsString); + var useTopLevelStatements = useTopLevelStatementsString.Equals(bool.TrueString, StringComparison.OrdinalIgnoreCase); + parameters.TryGetValue(nameof(NewDbContextTemplateModel.DbContextNamespace), out var dbContextNamespace); + parameters.TryGetValue("dataBaseName", out var databaseName); + Contract.Assert(!string.IsNullOrEmpty(dbContextTypeName)); + Contract.Assert(!string.IsNullOrEmpty(dataBaseName)); + + //continue if got the prerequistite variables + var sourceTree = declarationReference.SyntaxTree; + var rootNode = sourceTree.GetRoot(); + + var startUpClassNode = rootNode.FindNode(declarationReference.Span); + + var configRootProperty = TryGetIConfigurationRootProperty(startUp.TypeSymbol); + //if using Startup.cs, the ConfigureServices method should exist. + if (startUpClassNode.ChildNodes() + .FirstOrDefault(n => + n is MethodDeclarationSyntax syntax && + syntax.Identifier.ToString() == ConfigureServices) + is MethodDeclarationSyntax configServicesMethod && configRootProperty != null) + { + var servicesParam = configServicesMethod.ParameterList.Parameters + .FirstOrDefault(p => p.Type.ToString().Equals(IServiceCollection)); + + var statementLeadingTrivia = configServicesMethod.Body.OpenBraceToken.LeadingTrivia.ToString() + " "; + if (servicesParam != null) + { + string textToAddAtEnd = AddDbContextString(minimalHostingTemplate: false, statementLeadingTrivia, dataContextType); + _connectionStringsWriter.AddConnectionString(dbContextTypeName, databaseName, dataContextType); + if (configServicesMethod.Body.Statements.Any()) + { + textToAddAtEnd = Environment.NewLine + textToAddAtEnd; + } + + //string.Empty instead of InvalidOperationException for legacy scenarios. + var expression = SyntaxFactory.ParseStatement(string.Format(textToAddAtEnd, + servicesParam.Identifier, + dbContextTypeName, + configRootProperty.Name, + string.Empty)); + + MethodDeclarationSyntax newConfigServicesMethod = configServicesMethod.AddBodyStatements(expression); + + var newRoot = rootNode.ReplaceNode(configServicesMethod, newConfigServicesMethod); + + var namespacesToAdd = new[] { "Microsoft.EntityFrameworkCore", "Microsoft.Extensions.DependencyInjection", dbContextNamespace }; + foreach (var namespaceName in namespacesToAdd) + { + newRoot = RoslynCodeEditUtilities.AddUsingDirectiveIfNeeded(namespaceName, newRoot as CompilationUnitSyntax); + } + + return new EditSyntaxTreeResult() + { + Edited = true, + OldTree = sourceTree, + NewTree = sourceTree.WithRootAndOptions(newRoot, sourceTree.Options) + }; + } + } + //minimal hosting scenario + else + { + var statementLeadingTrivia = string.Empty; + StatementSyntax dbContextExpression = null; + var compilationSyntax = rootNode as CompilationUnitSyntax; + if (!useTopLevelStatements) + { + MethodDeclarationSyntax methodSyntax = DocumentBuilder.GetMethodFromSyntaxRoot(compilationSyntax, Main); + dbContextExpression = GetAddDbContextStatement(methodSyntax.Body, dbContextTypeName, databaseName, dataContextType); + } + else if (useTopLevelStatements) + { + dbContextExpression = GetAddDbContextStatement(compilationSyntax, dbContextTypeName, databaseName, dataContextType); + } + + if (statementLeadingTrivia != null && dbContextExpression != null) + { + var newRoot = compilationSyntax; + //add additional namespaces + var namespacesToAdd = new[] { "Microsoft.EntityFrameworkCore", "Microsoft.Extensions.DependencyInjection", dbContextNamespace }; + foreach (var namespaceName in namespacesToAdd) + { + newRoot = RoslynCodeEditUtilities.AddUsingDirectiveIfNeeded(namespaceName, newRoot); + } + if (!useTopLevelStatements) + { + MethodDeclarationSyntax methodSyntax = DocumentBuilder.GetMethodFromSyntaxRoot(newRoot, Main); + var modifiedBlock = methodSyntax.Body; + var statementToInsertAround = methodSyntax.Body.ChildNodes().Where(st => st.ToString().Contains(AddRazorPages)).FirstOrDefault(); + if (statementToInsertAround == null) + { + statementToInsertAround = methodSyntax.Body.ChildNodes().Where(st => st.ToString().Contains(CreateBuilder)).FirstOrDefault(); + modifiedBlock = methodSyntax.Body.InsertNodesAfter(statementToInsertAround, new List() { dbContextExpression }); + } + else + { + modifiedBlock = methodSyntax.Body.InsertNodesBefore(statementToInsertAround, new List() { dbContextExpression }); + } + var modifiedMethod = methodSyntax.WithBody(modifiedBlock); + newRoot = newRoot.ReplaceNode(methodSyntax, modifiedMethod); + } + else + { + var statementToInsertAfter = newRoot.Members.Where(st => st.ToString().Contains(AddRazorPages)).FirstOrDefault(); + if (statementToInsertAfter == null) + { + statementToInsertAfter = newRoot.Members.Where(st => st.ToString().Contains(CreateBuilder)).FirstOrDefault(); + } + + newRoot = newRoot.InsertNodesAfter(statementToInsertAfter, new List() { SyntaxFactory.GlobalStatement(dbContextExpression) }); + } + + return new EditSyntaxTreeResult() + { + Edited = true, + OldTree = sourceTree, + NewTree = sourceTree.WithRootAndOptions(newRoot, sourceTree.Options) + }; + } + } + } + + return new EditSyntaxTreeResult() + { + Edited = false + }; + } + private IEnumerable TemplateFolders { get diff --git a/src/Scaffolding/VS.Web.CG.EFCore/EntityFrameworkModelProcessor.cs b/src/Scaffolding/VS.Web.CG.EFCore/EntityFrameworkModelProcessor.cs index 9854fea86..f0781e03c 100644 --- a/src/Scaffolding/VS.Web.CG.EFCore/EntityFrameworkModelProcessor.cs +++ b/src/Scaffolding/VS.Web.CG.EFCore/EntityFrameworkModelProcessor.cs @@ -8,7 +8,6 @@ using System.IO; using System.Linq; using System.Reflection; -using System.Text.RegularExpressions; using System.Threading.Tasks; using Microsoft.CodeAnalysis; using Microsoft.EntityFrameworkCore; @@ -18,7 +17,6 @@ using Microsoft.DotNet.Scaffolding.Shared.ProjectModel; using Microsoft.VisualStudio.Web.CodeGeneration.DotNet; using Microsoft.DotNet.Scaffolding.Shared.Project; -using System.Collections; namespace Microsoft.VisualStudio.Web.CodeGeneration.EntityFrameworkCore { @@ -27,7 +25,7 @@ internal class EntityFrameworkModelProcessor private const string EFSqlServerPackageName = "Microsoft.EntityFrameworkCore.SqlServer"; private const string MySqlException = nameof(MySqlException); private const string NewDbContextFolderName = "Data"; - private bool _useSqlite; + private DbProvider _databaseProvider; private string _dbContextFullTypeName; private ModelType _modelTypeSymbol; private string _areaName; @@ -50,7 +48,7 @@ public EntityFrameworkModelProcessor ( string dbContextFullTypeName, ModelType modelTypeSymbol, string areaName, - bool useSqlite, + DbProvider databaseProvider, ICodeGenAssemblyLoadContext loader, IDbContextEditorServices dbContextEditorServices, IModelTypesLocator modelTypesLocator, @@ -76,8 +74,7 @@ public EntityFrameworkModelProcessor ( _applicationInfo = applicationInfo; _fileSystem = fileSystem; _workspace = workspace; - _useSqlite = useSqlite; - + _databaseProvider = databaseProvider; _assemblyAttributeGenerator = GetAssemblyAttributeGenerator(); } @@ -101,7 +98,6 @@ public async Task Process() if (!dbContextSymbols.Any()) { - //add nullable properties await GenerateNewDbContextAndRegisterProgramFile(programType, _applicationInfo); } else if (TryGetDbContextSymbolInWebProject(dbContextSymbols, out dbContextSymbolInWebProject)) @@ -303,7 +299,7 @@ private async Task EnsureDbContextInLibraryIsValid(ModelType dbContextSymbol) private async Task AddModelTypeToExistingDbContextIfNeeded(ModelType dbContextSymbol, IApplicationInfo appInfo) { bool nullabledEnabled = "enable".Equals(_projectContext.Nullable, StringComparison.OrdinalIgnoreCase); - var addResult = _dbContextEditorServices.AddModelToContext(dbContextSymbol, _modelTypeSymbol, nullabledEnabled); + var addResult = _dbContextEditorServices.AddModelToContext(dbContextSymbol, _modelTypeSymbol, new Dictionary { { "nullableEnabled", nullabledEnabled.ToString()} }); var projectCompilation = await _workspace.CurrentSolution.Projects .First(project => project.AssemblyName == _projectContext.AssemblyName) .GetCompilationAsync(); @@ -371,10 +367,9 @@ private async Task GenerateNewDbContextAndRegisterProgramFile(ModelType programT Edited = false }; - if (!_useSqlite) - { - ValidateEFSqlServerDependency(); - } + // Validate for necessary ef packages (based on database type) + EFValidationUtil.ValidateEFDependencies(_projectContext.PackageDependencies, _databaseProvider); + // Create a new Context _logger.LogMessage(string.Format(MessageStrings.GeneratingDbContext, _dbContextFullTypeName)); bool nullabledEnabled = "enable".Equals(_projectContext.Nullable, StringComparison.OrdinalIgnoreCase); @@ -385,13 +380,16 @@ private async Task GenerateNewDbContextAndRegisterProgramFile(ModelType programT if (programType != null) { - _programEditResult = _dbContextEditorServices.EditStartupForNewContext( - programType, - dbContextTemplateModel.DbContextTypeName, - dbContextTemplateModel.DbContextNamespace, - dataBaseName: dbContextTemplateModel.DbContextTypeName + "-" + Guid.NewGuid().ToString(), - _useSqlite, - useTopLevelsStatements); + var parameters = new Dictionary + { + { nameof(NewDbContextTemplateModel.DbContextTypeName), dbContextTemplateModel.DbContextTypeName }, + { nameof(NewDbContextTemplateModel.DbContextNamespace), dbContextTemplateModel.DbContextNamespace }, + { "dataBaseName", dbContextTemplateModel.DbContextTypeName + "-" + Guid.NewGuid().ToString()}, + { "databaseProvider", _databaseProvider.ToString() }, + { "useTopLevelStatements", useTopLevelsStatements.ToString() } + }; + + _programEditResult = _dbContextEditorServices.EditStartupForNewContext(programType, parameters); } if (!_programEditResult.Edited) @@ -440,10 +438,9 @@ private async Task GenerateNewDbContextAndRegister(ModelType startupType, ModelT Edited = false }; - if (!_useSqlite) - { - ValidateEFSqlServerDependency(); - } + // Validate for necessary ef packages (based on database type) + EFValidationUtil.ValidateEFDependencies(_projectContext.PackageDependencies, _databaseProvider); + // Create a new Context _logger.LogMessage(string.Format(MessageStrings.GeneratingDbContext, _dbContextFullTypeName)); bool nullabledEnabled = "enable".Equals(_projectContext.Nullable, StringComparison.OrdinalIgnoreCase); @@ -454,12 +451,16 @@ private async Task GenerateNewDbContextAndRegister(ModelType startupType, ModelT bool useTopLevelsStatements = await ProjectModifierHelper.IsUsingTopLevelStatements(_modelTypesLocator); if (startupType != null) { - _startupEditResult = _dbContextEditorServices.EditStartupForNewContext(startupType, - dbContextTemplateModel.DbContextTypeName, - dbContextTemplateModel.DbContextNamespace, - dataBaseName: dbContextTemplateModel.DbContextTypeName + "-" + Guid.NewGuid().ToString(), - _useSqlite, - useTopLevelsStatements); + var parameters = new Dictionary + { + { nameof(NewDbContextTemplateModel.DbContextTypeName), dbContextTemplateModel.DbContextTypeName }, + { nameof(NewDbContextTemplateModel.DbContextNamespace), dbContextTemplateModel.DbContextNamespace }, + { "dataBaseName", dbContextTemplateModel.DbContextTypeName + "-" + Guid.NewGuid().ToString()}, + { "databaseProvider", _databaseProvider.ToString() }, + { "useTopLevelStatements", useTopLevelsStatements.ToString() } + }; + + _startupEditResult = _dbContextEditorServices.EditStartupForNewContext(startupType, parameters); } if (!_startupEditResult.Edited) @@ -512,7 +513,7 @@ private ModelMetadata GetModelMetadata(Type dbContextType, Type modelType, Type } DbContext dbContextInstance = TryCreateContextUsingAppCode(dbContextType, dbContextType); - + Console.WriteLine($"\nUsing database provider '{dbContextInstance.Database.ProviderName}'!\n"); if (dbContextInstance == null) { throw new InvalidOperationException(string.Format( @@ -575,14 +576,6 @@ private DbContext TryCreateContextUsingAppCode(Type dbContextType, Type startupT } } - private void ValidateEFSqlServerDependency() - { - if (_projectContext.GetPackage(EFSqlServerPackageName) == null && CalledFromCommandline) - { - throw new InvalidOperationException(MessageStrings.EFSqlServerPackageNotAvailable); - } - } - private AssemblyAttributeGenerator GetAssemblyAttributeGenerator() { var originalAssembly = _loader.LoadFromName( diff --git a/src/Scaffolding/VS.Web.CG.EFCore/EntityFrameworkServices.cs b/src/Scaffolding/VS.Web.CG.EFCore/EntityFrameworkServices.cs index bbe8cfbea..59fb372e3 100644 --- a/src/Scaffolding/VS.Web.CG.EFCore/EntityFrameworkServices.cs +++ b/src/Scaffolding/VS.Web.CG.EFCore/EntityFrameworkServices.cs @@ -102,6 +102,11 @@ public EntityFrameworkServices( } public async Task GetModelMetadata(string dbContextFullTypeName, ModelType modelTypeSymbol, string areaName, bool useSqlite) + { + return await GetModelMetadata(dbContextFullTypeName, modelTypeSymbol, areaName, useSqlite ? DbProvider.SQLite : DbProvider.SqlServer); + } + + public async Task GetModelMetadata(string dbContextFullTypeName, ModelType modelTypeSymbol, string areaName, DbProvider databaseProvider) { if (string.IsNullOrEmpty(dbContextFullTypeName)) { @@ -111,7 +116,7 @@ public async Task GetModelMetadata(string dbContextFull var processor = new EntityFrameworkModelProcessor(dbContextFullTypeName, modelTypeSymbol, areaName, - useSqlite, + databaseProvider, _loader, _dbContextEditorServices, _modelTypesLocator, diff --git a/src/Scaffolding/VS.Web.CG.EFCore/IConnectionStringsWriter.cs b/src/Scaffolding/VS.Web.CG.EFCore/IConnectionStringsWriter.cs index 4831484df..c9236c832 100644 --- a/src/Scaffolding/VS.Web.CG.EFCore/IConnectionStringsWriter.cs +++ b/src/Scaffolding/VS.Web.CG.EFCore/IConnectionStringsWriter.cs @@ -1,10 +1,14 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; +using Microsoft.DotNet.Scaffolding.Shared; namespace Microsoft.VisualStudio.Web.CodeGeneration.EntityFrameworkCore { public interface IConnectionStringsWriter { + [Obsolete] void AddConnectionString(string connectionStringName, string dataBaseName, bool useSqlite); + void AddConnectionString(string connectionStringName, string databaseName, DbProvider databaseProvider); } } diff --git a/src/Scaffolding/VS.Web.CG.EFCore/IDbContextEditorServices.cs b/src/Scaffolding/VS.Web.CG.EFCore/IDbContextEditorServices.cs index b116889e2..83ccc5bc5 100644 --- a/src/Scaffolding/VS.Web.CG.EFCore/IDbContextEditorServices.cs +++ b/src/Scaffolding/VS.Web.CG.EFCore/IDbContextEditorServices.cs @@ -13,8 +13,13 @@ public interface IDbContextEditorServices { Task AddNewContext(NewDbContextTemplateModel dbContextTemplateModel); + [Obsolete] EditSyntaxTreeResult AddModelToContext(ModelType dbContext, ModelType modelType, bool nullableEnabled); + [Obsolete] EditSyntaxTreeResult EditStartupForNewContext(ModelType startup, string dbContextTypeName, string dbContextNamespace, string dataBaseName, bool useSqlite, bool useTopLevelStatements); + + EditSyntaxTreeResult AddModelToContext(ModelType dbContext, ModelType modelType, IDictionary parameters); + EditSyntaxTreeResult EditStartupForNewContext(ModelType startup, IDictionary parameters); } } diff --git a/src/Scaffolding/VS.Web.CG.EFCore/IEntityFrameworkService.cs b/src/Scaffolding/VS.Web.CG.EFCore/IEntityFrameworkService.cs index a03038e81..52d224249 100644 --- a/src/Scaffolding/VS.Web.CG.EFCore/IEntityFrameworkService.cs +++ b/src/Scaffolding/VS.Web.CG.EFCore/IEntityFrameworkService.cs @@ -1,14 +1,18 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using Microsoft.DotNet.Scaffolding.Shared.Project; +using System; using System.Threading.Tasks; +using Microsoft.DotNet.Scaffolding.Shared; +using Microsoft.DotNet.Scaffolding.Shared.Project; namespace Microsoft.VisualStudio.Web.CodeGeneration.EntityFrameworkCore { public interface IEntityFrameworkService { - /// + [Obsolete] + Task GetModelMetadata(string dbContextFullTypeName, ModelType modelTypeName, string areaName, bool useSqlite); + /// /// Gets the EF metadata for given context and model. /// Method takes in full type name of context and if there is no context with that name, /// attempts to create one. When creating a context, the method also tries to modify Startup @@ -23,8 +27,8 @@ public interface IEntityFrameworkService /// Full name (including namespace) of the context class. /// Model type for which the EF metadata has to be returned. /// Name of the area on which scaffolding is being run. Used for generating path for new DbContext. - /// flag for using sqlite instead of sqlserver + /// enum DbProvider (default DbProvider.SqlServer) /// Returns . - Task GetModelMetadata(string dbContextFullTypeName, ModelType modelTypeName, string areaName, bool useSqlite); + Task GetModelMetadata(string dbContextFullTypeName, ModelType modelTypeName, string areaName, DbProvider databaseProvider); } } diff --git a/src/Scaffolding/VS.Web.CG.Mvc/Areas/AreaGenerator.cs b/src/Scaffolding/VS.Web.CG.Mvc/Areas/AreaGenerator.cs index 1459186a9..1d0f206b0 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/Areas/AreaGenerator.cs +++ b/src/Scaffolding/VS.Web.CG.Mvc/Areas/AreaGenerator.cs @@ -32,30 +32,10 @@ public AreaGenerator(IApplicationInfo applicationInfo, IModelTypesLocator modelTypesLocator, ILogger logger) { - if(serviceProvider == null) - { - throw new ArgumentNullException(nameof(serviceProvider)); - } - - if(applicationInfo == null) - { - throw new ArgumentNullException(nameof(applicationInfo)); - } - - if(logger == null) - { - throw new ArgumentNullException(nameof(logger)); - } - - if(modelTypesLocator == null) - { - throw new ArgumentNullException(nameof(modelTypesLocator)); - } - - _serviceProvider = serviceProvider; - _logger = logger; - _appInfo = applicationInfo; - _modelTypesLocator = modelTypesLocator; + _serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider)); + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + _appInfo = applicationInfo ?? throw new ArgumentNullException(nameof(applicationInfo)); + _modelTypesLocator = modelTypesLocator ?? throw new ArgumentNullException(nameof(modelTypesLocator)); } public async Task GenerateCode(AreaGeneratorCommandLine model) diff --git a/src/Scaffolding/VS.Web.CG.Mvc/Common/CommonCommandLineModel.cs b/src/Scaffolding/VS.Web.CG.Mvc/Common/CommonCommandLineModel.cs index 6c0ac270b..437ff5d91 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/Common/CommonCommandLineModel.cs +++ b/src/Scaffolding/VS.Web.CG.Mvc/Common/CommonCommandLineModel.cs @@ -2,7 +2,9 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using Microsoft.DotNet.Scaffolding.Shared; using Microsoft.VisualStudio.Web.CodeGeneration.CommandLine; +using Microsoft.VisualStudio.Web.CodeGenerators.Mvc.MinimalApi; namespace Microsoft.VisualStudio.Web.CodeGenerators.Mvc { @@ -15,9 +17,14 @@ public abstract class CommonCommandLineModel [Option(Name = "dataContext", ShortName = "dc", Description = "DbContext class to use")] public string DataContextClass { get; set; } + [Obsolete("Use --databaseProvider or -dbProvider to configure database type instead")] [Option(Name = "useSqlite", ShortName ="sqlite", Description = "Flag to specify if DbContext should use SQLite instead of SQL Server.")] public bool UseSqlite { get; set; } + [Option(Name = "databaseProvider", ShortName = "dbProvider", Description = "Database provider to use. Options include 'sqlserver' (default), 'sqlite', 'cosmos', 'postgres'.")] + public string DatabaseProviderString { get; set; } + public DbProvider DatabaseProvider { get; set; } + [Option(Name = "referenceScriptLibraries", ShortName = "scripts", Description = "Switch to specify whether to reference script libraries in the generated views")] public bool ReferenceScriptLibraries { get; set; } @@ -52,7 +59,32 @@ protected CommonCommandLineModel(CommonCommandLineModel copyFrom) Force = copyFrom.Force; RelativeFolderPath = copyFrom.RelativeFolderPath; ControllerNamespace = copyFrom.ControllerNamespace; - UseSqlite = copyFrom.UseSqlite; + DatabaseProvider = copyFrom.DatabaseProvider; + } + } + + public static class CommonCommandLineModelExtensions + { + public static void ValidateCommandline(this CommonCommandLineModel model, ILogger logger) + { + if (model == null) + { + throw new ArgumentNullException(nameof(model)); + } + +#pragma warning disable CS0618 // Type or member is obsolete + if (model.UseSqlite) + { +#pragma warning restore CS0618 // Type or member is obsolete + //instead of throwing an error, letting the devs know that its obsolete. + logger.LogMessage(MessageStrings.SqliteObsoleteOption, LogMessageLevel.Information); + //Setting DatabaseProvider to SQLite if --databaseProvider|-dbProvider is not provided. + if (string.IsNullOrEmpty(model.DatabaseProviderString)) + { + model.DatabaseProvider = DbProvider.SQLite; + model.DatabaseProviderString = EfConstants.SQLite; + } + } } } -} \ No newline at end of file +} diff --git a/src/Scaffolding/VS.Web.CG.Mvc/Common/EFValidationUtil.cs b/src/Scaffolding/VS.Web.CG.Mvc/Common/EFValidationUtil.cs deleted file mode 100644 index 1fce22d67..000000000 --- a/src/Scaffolding/VS.Web.CG.Mvc/Common/EFValidationUtil.cs +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.Linq; -using Microsoft.DotNet.Scaffolding.Shared.ProjectModel; - -namespace Microsoft.VisualStudio.Web.CodeGenerators.Mvc -{ - internal static class EFValidationUtil - { - const string EfDesignPackageName = "Microsoft.EntityFrameworkCore.Design"; - const string SqlServerPackageName = "Microsoft.EntityFrameworkCore.SqlServer"; - const string SqlitePackageName = "Microsoft.EntityFrameworkCore.Sqlite"; - - internal static void ValidateEFDependencies(IEnumerable dependencies, bool useSqlite) - { - var isEFDesignPackagePresent = dependencies - .Any(package => package.Name.Equals(EfDesignPackageName, StringComparison.OrdinalIgnoreCase)); - - if (!isEFDesignPackagePresent) - { - throw new InvalidOperationException( - string.Format(MessageStrings.InstallEfPackages, $"{EfDesignPackageName}")); - } - if (useSqlite) - { - ValidateSqliteDependency(dependencies); - } - else - { - ValidateSqlServerDependency(dependencies); - } - - } - - internal static void ValidateSqlServerDependency(IEnumerable dependencies) - { - var isSqlServerPackagePresent = dependencies - .Any(package => package.Name.Equals(SqlServerPackageName, StringComparison.OrdinalIgnoreCase)); - - if (!isSqlServerPackagePresent) - { - throw new InvalidOperationException( - string.Format(MessageStrings.InstallSqlPackage, $"{SqlServerPackageName}.")); - } - } - - internal static void ValidateSqliteDependency(IEnumerable dependencies) - { - var isSqlServerPackagePresent = dependencies - .Any(package => package.Name.Equals(SqlitePackageName, StringComparison.OrdinalIgnoreCase)); - - if (!isSqlServerPackagePresent) - { - throw new InvalidOperationException( - string.Format(MessageStrings.InstallSqlPackage, $"{SqlitePackageName}.")); - } - } - } -} diff --git a/src/Scaffolding/VS.Web.CG.Mvc/Common/ModelTypeAndContextModel.cs b/src/Scaffolding/VS.Web.CG.Mvc/Common/ModelTypeAndContextModel.cs index 10c25e794..6df79559b 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/Common/ModelTypeAndContextModel.cs +++ b/src/Scaffolding/VS.Web.CG.Mvc/Common/ModelTypeAndContextModel.cs @@ -1,8 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - +using Microsoft.DotNet.Scaffolding.Shared; using Microsoft.DotNet.Scaffolding.Shared.Project; -using Microsoft.VisualStudio.Web.CodeGeneration; using Microsoft.VisualStudio.Web.CodeGeneration.EntityFrameworkCore; namespace Microsoft.VisualStudio.Web.CodeGenerators.Mvc @@ -15,6 +14,6 @@ public class ModelTypeAndContextModel public string DbContextFullName { get; set; } - public bool UseSqlite { get; set; } + public DbProvider DatabaseProvider { get; set; } } } diff --git a/src/Scaffolding/VS.Web.CG.Mvc/Controller/CommandLineGenerator.cs b/src/Scaffolding/VS.Web.CG.Mvc/Controller/CommandLineGenerator.cs index ab6bcd9ef..7c3b51a56 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/Controller/CommandLineGenerator.cs +++ b/src/Scaffolding/VS.Web.CG.Mvc/Controller/CommandLineGenerator.cs @@ -6,6 +6,7 @@ using Microsoft.VisualStudio.Web.CodeGeneration; using Microsoft.VisualStudio.Web.CodeGeneration.CommandLine; using Microsoft.Extensions.DependencyInjection; +using Microsoft.DotNet.Scaffolding.Shared; namespace Microsoft.VisualStudio.Web.CodeGenerators.Mvc.Controller { @@ -16,12 +17,7 @@ public class CommandLineGenerator : ICodeGenerator public CommandLineGenerator(IServiceProvider serviceProvider) { - if (serviceProvider == null) - { - throw new ArgumentNullException(nameof(serviceProvider)); - } - - _serviceProvider = serviceProvider; + _serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider)); } public async Task GenerateCode(CommandLineGeneratorModel model) diff --git a/src/Scaffolding/VS.Web.CG.Mvc/Controller/CommandLineGeneratorModel.cs b/src/Scaffolding/VS.Web.CG.Mvc/Controller/CommandLineGeneratorModel.cs index 688a46d1e..87b399886 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/Controller/CommandLineGeneratorModel.cs +++ b/src/Scaffolding/VS.Web.CG.Mvc/Controller/CommandLineGeneratorModel.cs @@ -1,7 +1,9 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - +using System; +using System.Reflection; using Microsoft.VisualStudio.Web.CodeGeneration.CommandLine; +using Microsoft.DotNet.Scaffolding.Shared; namespace Microsoft.VisualStudio.Web.CodeGenerators.Mvc.Controller { diff --git a/src/Scaffolding/VS.Web.CG.Mvc/Controller/ControllerWithContextGenerator.cs b/src/Scaffolding/VS.Web.CG.Mvc/Controller/ControllerWithContextGenerator.cs index 239d6fa59..5f195215f 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/Controller/ControllerWithContextGenerator.cs +++ b/src/Scaffolding/VS.Web.CG.Mvc/Controller/ControllerWithContextGenerator.cs @@ -57,12 +57,12 @@ public ControllerWithContextGenerator( public override async Task Generate(CommandLineGeneratorModel controllerGeneratorModel) { - Contract.Assert(!String.IsNullOrEmpty(controllerGeneratorModel.ModelClass)); + Contract.Assert(!string.IsNullOrEmpty(controllerGeneratorModel.ModelClass)); ValidateNameSpaceName(controllerGeneratorModel); - + controllerGeneratorModel.ValidateCommandline(Logger); if (CalledFromCommandline) { - EFValidationUtil.ValidateEFDependencies(ProjectContext.PackageDependencies, controllerGeneratorModel.UseSqlite); + EFValidationUtil.ValidateEFDependencies(ProjectContext.PackageDependencies, controllerGeneratorModel.DatabaseProvider); } string outputPath = ValidateAndGetOutputPath(controllerGeneratorModel); @@ -72,6 +72,7 @@ public override async Task Generate(CommandLineGeneratorModel controllerGenerato controllerGeneratorModel, EntityFrameworkService, ModelTypesLocator, + Logger, _areaName); if (string.IsNullOrEmpty(controllerGeneratorModel.ControllerName)) diff --git a/src/Scaffolding/VS.Web.CG.Mvc/Identity/IdentityGenerator.cs b/src/Scaffolding/VS.Web.CG.Mvc/Identity/IdentityGenerator.cs index 05fb1fe7d..7a25c6b25 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/Identity/IdentityGenerator.cs +++ b/src/Scaffolding/VS.Web.CG.Mvc/Identity/IdentityGenerator.cs @@ -156,8 +156,8 @@ private string GetTemplateFolderRootForContentVersion(IdentityGeneratorTemplateM new[] { relativePath }, - _projectContext - ).First(); + _projectContext) + .First(); } public IdentityGenerator(IApplicationInfo applicationInfo, @@ -222,7 +222,7 @@ await EditProgramCsForIdentity( templateModel.DbContextClass, templateModel.UserClass, templateModel.DbContextNamespace, - templateModel.UseSQLite); + templateModel.DatabaseProvider); } await AddTemplateFiles(templateModel); @@ -253,14 +253,14 @@ private string GetIdentityCodeModifierConfig() /// For injecting the DbContext class in statements. /// For injecting the IdentityUser class in statements. /// For injecting the namespace for DbContext class in statements. - /// To opt between injecting UseSqlite or UseSqlServer + /// "Database type to use : DbProvider.SqlServer or DbProvider.SQLite" /// internal async Task EditProgramCsForIdentity( IModelTypesLocator modelTypesLocator, string dbContextClassName, string identityUserClassName, string dbContextNamespace, - bool useSqlite = false) + DbProvider databaseProvider) { var jsonText = GetIdentityCodeModifierConfig(); CodeModifierConfig identityProgramFileConfig = JsonSerializer.Deserialize(jsonText); @@ -289,7 +289,7 @@ internal async Task EditProgramCsForIdentity( filteredChanges = ProjectModifierHelper.UpdateVariables(filteredChanges, oldValue, newValue); } - filteredChanges = ApplyIdentityChanges(filteredChanges, dbContextClassName, identityUserClassName, useSqlite, useTopLevelsStatements); + filteredChanges = ApplyIdentityChanges(filteredChanges, dbContextClassName, identityUserClassName, databaseProvider, useTopLevelsStatements); if (useTopLevelsStatements) { @@ -317,18 +317,18 @@ internal async Task EditProgramCsForIdentity( } } - private CodeSnippet[] ApplyIdentityChanges(CodeSnippet[] filteredChanges, string dbContextClassName, string identityUserClassName, bool useSqlite, bool useTopLevelsStatements) + private CodeSnippet[] ApplyIdentityChanges(CodeSnippet[] filteredChanges, string dbContextClassName, string identityUserClassName, DbProvider databaseProvider, bool useTopLevelsStatements) { foreach (var codeChange in filteredChanges) { codeChange.LeadingTrivia = codeChange.LeadingTrivia ?? new Formatting(); - codeChange.Block = EditIdentityStrings(codeChange.Block, dbContextClassName, identityUserClassName, useSqlite, codeChange?.LeadingTrivia?.NumberOfSpaces); + codeChange.Block = EditIdentityStrings(codeChange.Block, dbContextClassName, identityUserClassName, databaseProvider, codeChange?.LeadingTrivia?.NumberOfSpaces); } return filteredChanges; } - internal static string EditIdentityStrings(string stringToModify, string dbContextClassName, string identityUserClassName, bool isSqlite, int? spaces) + internal static string EditIdentityStrings(string stringToModify, string dbContextClassName, string identityUserClassName, DbProvider databaseProvider, int? spaces) { if (string.IsNullOrEmpty(stringToModify)) { @@ -351,8 +351,7 @@ internal static string EditIdentityStrings(string stringToModify, string dbConte if (stringToModify.Contains(OptionsUseConnectionString)) { modifiedString = modifiedString.Replace("options.{0}", - isSqlite ? $"options.{UseSqlite}" : - $"options.{UseSqlServer}"); + databaseProvider.Equals(DbProvider.SQLite) ? $"options.{UseSqlite}" : $"options.{UseSqlServer}"); } if (stringToModify.Contains(GetConnectionString)) { @@ -430,8 +429,8 @@ await _codegeneratorActionService.AddFileFromTemplateAsync( { _connectionStringsWriter.AddConnectionString( connectionStringName: $"{templateModel.DbContextClass}Connection", - dataBaseName: templateModel.ApplicationName, - useSqlite: templateModel.UseSQLite); + databaseName: templateModel.ApplicationName, + templateModel.DatabaseProvider.Equals(DbProvider.SQLite) ? DbProvider.SQLite : DbProvider.SqlServer); } } diff --git a/src/Scaffolding/VS.Web.CG.Mvc/Identity/IdentityGeneratorCommandLineModel.cs b/src/Scaffolding/VS.Web.CG.Mvc/Identity/IdentityGeneratorCommandLineModel.cs index 501ffa25b..997e73314 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/Identity/IdentityGeneratorCommandLineModel.cs +++ b/src/Scaffolding/VS.Web.CG.Mvc/Identity/IdentityGeneratorCommandLineModel.cs @@ -1,4 +1,6 @@ -using Microsoft.VisualStudio.Web.CodeGeneration.CommandLine; +using System; +using Microsoft.DotNet.Scaffolding.Shared; +using Microsoft.VisualStudio.Web.CodeGeneration.CommandLine; namespace Microsoft.VisualStudio.Web.CodeGenerators.Mvc.Identity { @@ -7,9 +9,14 @@ public class IdentityGeneratorCommandLineModel [Option(Name = "rootNamespace", ShortName = "rn", Description = "Root namesapce to use for generating identity code." )] public string RootNamespace { get; set; } - [Option(Name = "useSqLite", ShortName ="sqlite", Description = "Flag to specify if DbContext should use SQLite instead of SQL Server.")] + [Obsolete("Use --databaseProvider or -dbProvider to configure database type instead")] + [Option(Name = "useSqlite", ShortName = "sqlite", Description = "Flag to specify if DbContext should use SQLite instead of SQL Server.")] public bool UseSqlite { get; set; } + [Option(Name = "databaseProvider", ShortName = "dbProvider", Description = "Database provider to use. Options include 'sqlserver' (default), 'sqlite', 'cosmos', 'postgres'.")] + public string DatabaseProviderString { get; set; } + public DbProvider DatabaseProvider { get; set; } + [Option(Name = "dbContext", ShortName = "dc", Description = "Name of the DbContext to use, or generate (if it does not exist).")] public string DbContext { get; set; } diff --git a/src/Scaffolding/VS.Web.CG.Mvc/Identity/IdentityGeneratorTemplateModel.cs b/src/Scaffolding/VS.Web.CG.Mvc/Identity/IdentityGeneratorTemplateModel.cs index 99ca13433..22026801c 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/Identity/IdentityGeneratorTemplateModel.cs +++ b/src/Scaffolding/VS.Web.CG.Mvc/Identity/IdentityGeneratorTemplateModel.cs @@ -1,6 +1,8 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using Microsoft.DotNet.Scaffolding.Shared; + namespace Microsoft.VisualStudio.Web.CodeGenerators.Mvc.Identity { public class IdentityGeneratorTemplateModel @@ -11,7 +13,7 @@ public class IdentityGeneratorTemplateModel public string UserClassNamespace { get; set; } public string DbContextClass { get; set; } public string DbContextNamespace { get; set; } - public bool UseSQLite { get; set; } + public DbProvider DatabaseProvider { get; set; } public bool IsUsingExistingDbContext { get; set; } public bool IsGenerateCustomUser { get; set; } public IdentityGeneratorFile[] FilesToGenerate { get; set; } @@ -22,4 +24,4 @@ public class IdentityGeneratorTemplateModel public string SupportFileLocation { get; set; } public bool HasExistingNonEmptyWwwRoot { get; set; } } -} \ No newline at end of file +} diff --git a/src/Scaffolding/VS.Web.CG.Mvc/Identity/IdentityGeneratorTemplateModelBuilder.cs b/src/Scaffolding/VS.Web.CG.Mvc/Identity/IdentityGeneratorTemplateModelBuilder.cs index c785d66c4..8533ece60 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/Identity/IdentityGeneratorTemplateModelBuilder.cs +++ b/src/Scaffolding/VS.Web.CG.Mvc/Identity/IdentityGeneratorTemplateModelBuilder.cs @@ -12,7 +12,6 @@ using Microsoft.VisualStudio.Web.CodeGeneration; using Microsoft.VisualStudio.Web.CodeGeneration.DotNet; using Microsoft.VisualStudio.Web.CodeGeneration.EntityFrameworkCore; -using System.Diagnostics; namespace Microsoft.VisualStudio.Web.CodeGenerators.Mvc.Identity { @@ -37,60 +36,23 @@ public IdentityGeneratorTemplateModelBuilder( IFileSystem fileSystem, ILogger logger) { - if (commandlineModel == null) - { - throw new ArgumentNullException(nameof(commandlineModel)); - } - - if (applicationInfo == null) - { - throw new ArgumentNullException(nameof(applicationInfo)); - } - - if (projectContext == null) - { - throw new ArgumentNullException(nameof(projectContext)); - } - - if (workspace == null) - { - throw new ArgumentNullException(nameof(workspace)); - } - - if (loader == null) - { - throw new ArgumentNullException(nameof(loader)); - } - - if (fileSystem == null) - { - throw new ArgumentNullException(nameof(fileSystem)); - } - - if (logger == null) - { - throw new ArgumentNullException(nameof(logger)); - } - - _commandlineModel = commandlineModel; - _applicationInfo = applicationInfo; - _projectContext = projectContext; - _workspace = workspace; - _loader = loader; - _fileSystem = fileSystem; - _logger = logger; + _commandlineModel = commandlineModel ?? throw new ArgumentNullException(nameof(commandlineModel)); + _applicationInfo = applicationInfo ?? throw new ArgumentNullException(nameof(applicationInfo)); + _projectContext = projectContext ?? throw new ArgumentNullException(nameof(projectContext)); + _workspace = workspace ?? throw new ArgumentNullException(nameof(workspace)); + _loader = loader ?? throw new ArgumentNullException(nameof(loader)); ; + _fileSystem = fileSystem ?? throw new ArgumentNullException(nameof(fileSystem)); + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); } internal bool IsFilesSpecified => !string.IsNullOrEmpty(_commandlineModel.Files); internal bool IsExcludeSpecificed => !string.IsNullOrEmpty(_commandlineModel.ExcludeFiles); internal bool IsDbContextSpecified => !string.IsNullOrEmpty(_commandlineModel.DbContext); internal bool IsUsingExistingDbContext { get; set; } - - private Type _userType; - internal string UserClass { get; private set; } internal string UserClassNamespace { get; private set; } + private Type _userType; internal Type UserType { get @@ -104,7 +66,7 @@ internal Type UserType UserClassNamespace = _userType?.Namespace; } } - + internal DbProvider DatabaseProvider { get; set; } internal string DbContextClass { get; private set; } internal string DbContextNamespace { get; private set; } internal string RootNamespace { get; private set; } @@ -119,7 +81,7 @@ public async Task ValidateAndBuild() ? _projectContext.RootNamespace : _commandlineModel.RootNamespace; - ValidateRequiredDependencies(_commandlineModel.UseSqlite); + ValidateRequiredDependencies(); var defaultDbContextNamespace = $"{RootNamespace}.Areas.Identity.Data"; @@ -133,6 +95,7 @@ public async Task ValidateAndBuild() DbContextClass = GetClassNameFromTypeName(_commandlineModel.DbContext); DbContextNamespace = GetNamespaceFromTypeName(_commandlineModel.DbContext) ?? defaultDbContextNamespace; + DatabaseProvider = ModelMetadataUtilities.ValidateDatabaseProvider(_commandlineModel.DatabaseProviderString, _logger); } else { @@ -149,6 +112,7 @@ public async Task ValidateAndBuild() // --dbContext paramter was not specified. So we need to generate one using convention. DbContextClass = GetDefaultDbContextName(); DbContextNamespace = defaultDbContextNamespace; + DatabaseProvider = ModelMetadataUtilities.ValidateDatabaseProvider(_commandlineModel.DatabaseProviderString, _logger); } // if an existing user class was determined from the DbContext, don't try to get it from here. @@ -196,7 +160,7 @@ public async Task ValidateAndBuild() DbContextNamespace = DbContextNamespace, UserClass = UserClass, UserClassNamespace = UserClassNamespace, - UseSQLite = _commandlineModel.UseSqlite, + DatabaseProvider = DatabaseProvider, IsUsingExistingDbContext = IsUsingExistingDbContext, Namespace = RootNamespace, IsGenerateCustomUser = IsGenerateCustomUser, @@ -666,9 +630,9 @@ private Type FindUserTypeFromDbContext(Type existingDbContext) var usersProperty = existingDbContext.GetProperties() .FirstOrDefault(p => p.Name == "Users"); - if (usersProperty == null - || !usersProperty.PropertyType.IsGenericType - || usersProperty.PropertyType.GetGenericArguments().Count() != 1) + if (usersProperty == null || + !usersProperty.PropertyType.IsGenericType || + usersProperty.PropertyType.GetGenericArguments().Count() != 1) { // The IdentityDbContext has DbSet Users property. // The only case this would happen is if the user hides the inherited property. @@ -696,8 +660,8 @@ private async Task FindExistingType(string type) _loader, _logger); - if (_reflectedTypesProvider.GetCompilationErrors() != null - && _reflectedTypesProvider.GetCompilationErrors().Any()) + if (_reflectedTypesProvider.GetCompilationErrors() != null && + _reflectedTypesProvider.GetCompilationErrors().Any()) { // Failed to build the project. throw new InvalidOperationException( @@ -726,6 +690,27 @@ private void ValidateCommandLine(IdentityGeneratorCommandLineModel model) errorStrings.Add(string.Format(MessageStrings.InvalidDbContextClassName, model.DbContext)); } +#pragma warning disable CS0618 // Type or member is obsolete + if (model.UseSqlite) + { +#pragma warning restore CS0618 // Type or member is obsolete + //instead of throwing an error, letting the devs know that its obsolete. + _logger.LogMessage(MessageStrings.SqliteObsoleteOption, LogMessageLevel.Information); + //Setting DatabaseProvider to SQLite if --databaseProvider|-dbProvider is not provided. + if (string.IsNullOrEmpty(model.DatabaseProviderString)) + { + model.DatabaseProvider = DbProvider.SQLite; + model.DatabaseProviderString = EfConstants.SQLite; + } + } + + if (!string.IsNullOrEmpty(model.DatabaseProviderString) && !EfConstants.IdentityDbProviders.Contains(model.DatabaseProviderString, StringComparer.OrdinalIgnoreCase)) + { + string dbList = $"'{string.Join("', ", EfConstants.IdentityDbProviders.ToArray(), 0, EfConstants.IdentityDbProviders.Count - 1)}' and '{EfConstants.IdentityDbProviders.LastOrDefault()}'"; + errorStrings.Add(string.Format(MessageStrings.InvalidDatabaseProvider, model.DatabaseProviderString)); + errorStrings.Add($"Supported database providers include : {dbList}"); + } + if (!string.IsNullOrEmpty(model.RootNamespace) && !RoslynUtilities.IsValidNamespace(model.RootNamespace)) { errorStrings.Add(string.Format(MessageStrings.InvalidNamespaceName, model.RootNamespace)); @@ -747,7 +732,7 @@ private void ValidateCommandLine(IdentityGeneratorCommandLineModel model) } } - private void ValidateRequiredDependencies(bool useSqlite) + private void ValidateRequiredDependencies() { var dependencies = new HashSet() { @@ -760,11 +745,6 @@ private void ValidateRequiredDependencies(bool useSqlite) .PackageDependencies .Any(package => package.Name.Equals(EfDesignPackageName, StringComparison.OrdinalIgnoreCase)); - if (!useSqlite) - { - dependencies.Add("Microsoft.EntityFrameworkCore.SqlServer"); - } - var missingPackages = dependencies.Where(d => !_projectContext.PackageDependencies.Any(p => p.Name.Equals(d, StringComparison.OrdinalIgnoreCase))); if (CalledFromCommandline && missingPackages.Any()) { diff --git a/src/Scaffolding/VS.Web.CG.Mvc/MessageStrings.Designer.cs b/src/Scaffolding/VS.Web.CG.Mvc/MessageStrings.Designer.cs index 61ef9839a..d6ff3c152 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/MessageStrings.Designer.cs +++ b/src/Scaffolding/VS.Web.CG.Mvc/MessageStrings.Designer.cs @@ -240,6 +240,24 @@ internal static string InvalidClassName { } } + /// + /// Looks up a localized string similar to Invalid database provider '{0}' found.. + /// + internal static string InvalidDatabaseProvider { + get { + return ResourceManager.GetString("InvalidDatabaseProvider", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Invalid database type '{0}'. + /// + internal static string InvalidDatabaseType { + get { + return ResourceManager.GetString("InvalidDatabaseType", resourceCulture); + } + } + /// /// Looks up a localized string similar to Value of --dbContext '{0}' is not a valid class name.. /// @@ -339,6 +357,15 @@ internal static string NamespaceOptionDesc { } } + /// + /// Looks up a localized string similar to No database provider found. Using 'SqlServer' by default for new DbContext creation!. + /// + internal static string NoDbProviderFound { + get { + return ResourceManager.GetString("NoDbProviderFound", resourceCulture); + } + } + /// /// Looks up a localized string similar to Specify the relative output folder path from project where the file needs to be generated, if not specified, file will be generated in the project folder.. /// @@ -411,6 +438,15 @@ internal static string ScriptsOptionDesc { } } + /// + /// Looks up a localized string similar to --useSqlite|-sqlite option is obsolete now. Use --databaseProvider|-dbProvider instead in the future.. + /// + internal static string SqliteObsoleteOption { + get { + return ResourceManager.GetString("SqliteObsoleteOption", resourceCulture); + } + } + /// /// Looks up a localized string similar to Failed to generate readme file at '{0}'.. /// diff --git a/src/Scaffolding/VS.Web.CG.Mvc/MessageStrings.resx b/src/Scaffolding/VS.Web.CG.Mvc/MessageStrings.resx index 4a7be8f14..18b7596fe 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/MessageStrings.resx +++ b/src/Scaffolding/VS.Web.CG.Mvc/MessageStrings.resx @@ -258,4 +258,16 @@ The class name '{0}' is not valid. + + Invalid database type '{0}' + + + No database provider found. Using 'SqlServer' by default for new DbContext creation! + + + --useSqlite|-sqlite option is obsolete now. Use --databaseProvider|-dbProvider instead in the future. + + + Invalid database provider '{0}' found. + \ No newline at end of file diff --git a/src/Scaffolding/VS.Web.CG.Mvc/Minimal Api/MinimalApiGenerator.cs b/src/Scaffolding/VS.Web.CG.Mvc/Minimal Api/MinimalApiGenerator.cs index c384a8bc3..827496a59 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/Minimal Api/MinimalApiGenerator.cs +++ b/src/Scaffolding/VS.Web.CG.Mvc/Minimal Api/MinimalApiGenerator.cs @@ -66,17 +66,19 @@ public MinimalApiGenerator(IApplicationInfo applicationInfo, /// public async Task GenerateCode(MinimalApiGeneratorCommandLineModel model) { + model.ValidateCommandline(Logger); var namespaceName = NameSpaceUtilities.GetSafeNameSpaceFromPath(model.RelativeFolderPath, AppInfo.ApplicationName); //get model and dbcontext var modelTypeAndContextModel = await ModelMetadataUtilities.GetModelEFMetadataMinimalAsync( model, EntityFrameworkService, ModelTypesLocator, + Logger, areaName : string.Empty); if (!string.IsNullOrEmpty(modelTypeAndContextModel.DbContextFullName) && CalledFromCommandline) { - EFValidationUtil.ValidateEFDependencies(ProjectContext.PackageDependencies, useSqlite: model.UseSqlite); + EFValidationUtil.ValidateEFDependencies(ProjectContext.PackageDependencies, model.DatabaseProvider); } if (model.OpenApi) @@ -92,7 +94,7 @@ public async Task GenerateCode(MinimalApiGeneratorCommandLineModel model) NullableEnabled = "enable".Equals(ProjectContext?.Nullable, StringComparison.OrdinalIgnoreCase), OpenAPI = model.OpenApi, MethodName = $"Map{modelTypeAndContextModel.ModelType.Name}Endpoints", - UseSqlite = model.UseSqlite, + DatabaseProvider = model.DatabaseProvider, UseTypedResults = !model.NoTypedResults }; diff --git a/src/Scaffolding/VS.Web.CG.Mvc/Minimal Api/MinimalApiGeneratorCommandLineModel.cs b/src/Scaffolding/VS.Web.CG.Mvc/Minimal Api/MinimalApiGeneratorCommandLineModel.cs index 9206fb289..1ec8fdb2e 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/Minimal Api/MinimalApiGeneratorCommandLineModel.cs +++ b/src/Scaffolding/VS.Web.CG.Mvc/Minimal Api/MinimalApiGeneratorCommandLineModel.cs @@ -1,3 +1,5 @@ +using System; +using Microsoft.DotNet.Scaffolding.Shared; using Microsoft.VisualStudio.Web.CodeGeneration.CommandLine; namespace Microsoft.VisualStudio.Web.CodeGenerators.Mvc.MinimalApi @@ -22,9 +24,14 @@ public class MinimalApiGeneratorCommandLineModel [Option(Name = "endpointsNamespace", ShortName = "namespace", Description = "Specify the name of the namespace to use for the generated controller")] public string EndpointsNamespace { get; set; } + [Obsolete("Use --databaseProvider or -dbProvider to configure database type instead")] [Option(Name = "useSqlite", ShortName = "sqlite", Description = "Flag to specify if DbContext should use SQLite instead of SQL Server.")] public bool UseSqlite { get; set; } + [Option(Name = "databaseProvider", ShortName = "dbProvider", Description = "Database provider to use. Options include 'sqlserver' (default), 'sqlite', 'cosmos', 'postgres'.")] + public string DatabaseProviderString { get; set; } + public DbProvider DatabaseProvider { get; set; } + [Option(Name = "noTypedResults", ShortName = "ntr", Description = "Flag to not use TypedResults for minimal apis.")] public bool NoTypedResults { get; set; } @@ -39,7 +46,7 @@ protected MinimalApiGeneratorCommandLineModel(MinimalApiGeneratorCommandLineMode RelativeFolderPath = copyFrom.RelativeFolderPath; OpenApi = copyFrom.OpenApi; EndpointsNamespace = copyFrom.EndpointsNamespace; - UseSqlite = copyFrom.UseSqlite; + DatabaseProvider = copyFrom.DatabaseProvider; NoTypedResults = copyFrom.NoTypedResults; } @@ -48,4 +55,29 @@ public MinimalApiGeneratorCommandLineModel Clone() return new MinimalApiGeneratorCommandLineModel(this); } } + + public static class MinimalApiGeneratorCommandLineModelExtensions + { + public static void ValidateCommandline(this MinimalApiGeneratorCommandLineModel model, ILogger logger) + { + if (model == null) + { + throw new ArgumentNullException(nameof(model)); + } + +#pragma warning disable CS0618 // Type or member is obsolete + if (model.UseSqlite) + { +#pragma warning restore CS0618 // Type or member is obsolete + //instead of throwing an error, letting the devs know that its obsolete. + logger.LogMessage(MessageStrings.SqliteObsoleteOption, LogMessageLevel.Information); + //Setting DatabaseProvider to SQLite if --databaseProvider|-dbProvider is not provided. + if (string.IsNullOrEmpty(model.DatabaseProviderString)) + { + model.DatabaseProvider = DbProvider.SQLite; + model.DatabaseProviderString = EfConstants.SQLite; + } + } + } + } } diff --git a/src/Scaffolding/VS.Web.CG.Mvc/Minimal Api/MinimalApiModel.cs b/src/Scaffolding/VS.Web.CG.Mvc/Minimal Api/MinimalApiModel.cs index 9c906616b..38fc1f839 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/Minimal Api/MinimalApiModel.cs +++ b/src/Scaffolding/VS.Web.CG.Mvc/Minimal Api/MinimalApiModel.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using Microsoft.DotNet.Scaffolding.Shared; using Microsoft.DotNet.Scaffolding.Shared.Project; using Microsoft.VisualStudio.Web.CodeGeneration.EntityFrameworkCore; @@ -37,6 +38,8 @@ public MinimalApiModel( //Endpoints class name public string EndpointsName { get; set; } + //Database type eg. SQL Server, SQLite, Cosmos DB, Postgres and more later. + public DbProvider DatabaseProvider { get; set; } public bool NullableEnabled { get; set; } //If CRUD endpoints support Open API diff --git a/src/Scaffolding/VS.Web.CG.Mvc/ModelMetadataUtilities.cs b/src/Scaffolding/VS.Web.CG.Mvc/ModelMetadataUtilities.cs index 9cb20aea0..21d58022b 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/ModelMetadataUtilities.cs +++ b/src/Scaffolding/VS.Web.CG.Mvc/ModelMetadataUtilities.cs @@ -2,12 +2,11 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; using System.Diagnostics.Contracts; using System.Linq; using System.Threading.Tasks; +using Microsoft.DotNet.Scaffolding.Shared; using Microsoft.DotNet.Scaffolding.Shared.Project; -using Microsoft.VisualStudio.Web.CodeGeneration; using Microsoft.VisualStudio.Web.CodeGeneration.EntityFrameworkCore; using Microsoft.VisualStudio.Web.CodeGenerators.Mvc.MinimalApi; @@ -36,6 +35,7 @@ internal static async Task ValidateModelAndGetEFMetada CommonCommandLineModel commandLineModel, IEntityFrameworkService entityFrameworkService, IModelTypesLocator modelTypesLocator, + ILogger logger, string areaName) { ModelType model = ValidationUtil.ValidateType(commandLineModel.ModelClass, "model", modelTypesLocator); @@ -46,18 +46,23 @@ internal static async Task ValidateModelAndGetEFMetada var dbContextFullName = dataContext != null ? dataContext.FullName : commandLineModel.DataContextClass; + if (dataContext == null) + { + commandLineModel.DatabaseProvider = ValidateDatabaseProvider(commandLineModel.DatabaseProviderString, logger); + } + var modelMetadata = await entityFrameworkService.GetModelMetadata( dbContextFullName, model, areaName, - commandLineModel.UseSqlite); + commandLineModel.DatabaseProvider); return new ModelTypeAndContextModel() { ModelType = model, DbContextFullName = dbContextFullName, ContextProcessingResult = modelMetadata, - UseSqlite = commandLineModel.UseSqlite + DatabaseProvider = commandLineModel.DatabaseProvider }; } @@ -65,6 +70,7 @@ internal static async Task GetModelEFMetadataMinimalAs MinimalApiGeneratorCommandLineModel commandLineModel, IEntityFrameworkService entityFrameworkService, IModelTypesLocator modelTypesLocator, + ILogger logger, string areaName) { ModelType model = ValidationUtil.ValidateType(commandLineModel.ModelClass, "model", modelTypesLocator); @@ -83,12 +89,16 @@ internal static async Task GetModelEFMetadataMinimalAs { dataContext = ValidationUtil.ValidateType(commandLineModel.DataContextClass, "dataContext", modelTypesLocator, throwWhenNotFound: false); dbContextFullName = dataContext != null ? dataContext.FullName : commandLineModel.DataContextClass; - + if (dataContext == null) + { + commandLineModel.DatabaseProvider = ValidateDatabaseProvider(commandLineModel.DatabaseProviderString, logger); + } + modelMetadata = await entityFrameworkService.GetModelMetadata( dbContextFullName, model, areaName, - useSqlite: commandLineModel.UseSqlite); + databaseProvider: commandLineModel.DatabaseProvider); } return new ModelTypeAndContextModel() @@ -96,8 +106,26 @@ internal static async Task GetModelEFMetadataMinimalAs ModelType = model, DbContextFullName = dbContextFullName, ContextProcessingResult = modelMetadata, - UseSqlite = false + DatabaseProvider = commandLineModel.DatabaseProvider }; } + + internal static DbProvider ValidateDatabaseProvider(string databaseProviderString, ILogger logger) + { + if (string.IsNullOrEmpty(databaseProviderString)) + { + logger.LogMessage(MessageStrings.NoDbProviderFound, LogMessageLevel.Information); + return DbProvider.SqlServer; + } + else if (Enum.TryParse(databaseProviderString, ignoreCase: true, out DbProvider dbProvider)) + { + return dbProvider; + } + else + { + string dbList = $"'{string.Join("', ", EfConstants.AllDbProviders.ToArray(), 0, EfConstants.AllDbProviders.Count - 1)} and '{EfConstants.AllDbProviders.LastOrDefault()}'"; + throw new InvalidOperationException($"Invalid database provider '{databaseProviderString}'.\nSupported database providers include : {dbList}"); + } + } } } diff --git a/src/Scaffolding/VS.Web.CG.Mvc/ParameterDefinitions/controller.json b/src/Scaffolding/VS.Web.CG.Mvc/ParameterDefinitions/controller.json index f3e52fece..6bb6ef5ad 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/ParameterDefinitions/controller.json +++ b/src/Scaffolding/VS.Web.CG.Mvc/ParameterDefinitions/controller.json @@ -69,11 +69,11 @@ "Name": "controllerNamespace", "ShortName": "namespace", "Description": "Specify the name of the namespace to use for the generated controller" - }, - { - "Name" : "useSqlite", - "ShortName" : "sqlite", - "Description": "Flag to specify if DbContext should use SQLite instead of SQL Server." - } + }, + { + "Name": "databaseProvider", + "ShortName": "dbProvider", + "Description": "Database provider to use. Options include 'sqlserver' (default), 'sqlite', 'cosmos', 'postgres'." + } ] -} \ No newline at end of file +} diff --git a/src/Scaffolding/VS.Web.CG.Mvc/ParameterDefinitions/identity.json b/src/Scaffolding/VS.Web.CG.Mvc/ParameterDefinitions/identity.json index bb477bb0e..d9c88cd9c 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/ParameterDefinitions/identity.json +++ b/src/Scaffolding/VS.Web.CG.Mvc/ParameterDefinitions/identity.json @@ -24,9 +24,9 @@ "Description": "Name of the User class to generate." }, { - "Name": "useSqLite", - "ShortName": "sqlite", - "Description": "Flag to specify if DbContext should use SQLite instead of SQL Server." + "Name": "databaseProvider", + "ShortName": "dbProvider", + "Description": "Database provider to use. Options include 'sqlserver' (default), 'sqlite', 'cosmos', 'postgres'." }, { "Name": "force", diff --git a/src/Scaffolding/VS.Web.CG.Mvc/ParameterDefinitions/minimalapi.json b/src/Scaffolding/VS.Web.CG.Mvc/ParameterDefinitions/minimalapi.json index 172d98b8f..763262395 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/ParameterDefinitions/minimalapi.json +++ b/src/Scaffolding/VS.Web.CG.Mvc/ParameterDefinitions/minimalapi.json @@ -25,6 +25,10 @@ { "Name": "-namespace|--endpointsNamespace", "Description": "Specify the name of the namespace to use for the generated Endpoints file" + }, + { + "Name" : "databaseProvider", + "Description": "Database provider to use. Options include 'sqlserver' (default), 'sqlite', 'cosmos', 'postgres'." } ], "Options": [] diff --git a/src/Scaffolding/VS.Web.CG.Mvc/ParameterDefinitions/razorpage.json b/src/Scaffolding/VS.Web.CG.Mvc/ParameterDefinitions/razorpage.json index ce26327dd..15067cb3c 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/ParameterDefinitions/razorpage.json +++ b/src/Scaffolding/VS.Web.CG.Mvc/ParameterDefinitions/razorpage.json @@ -62,10 +62,10 @@ "ShortName": "npm", "Description": "Switch to not generate a PageModel class for Empty template" }, - { - "Name" : "useSqlite", - "ShortName" : "sqlite", - "Description": "Flag to specify if DbContext should use SQLite instead of SQL Server." - } + { + "Name": "databaseProvider", + "ShortName": "dbProvider", + "Description": "Database provider to use. Options include 'sqlserver' (default), 'sqlite', 'cosmos', 'postgres'." + } ] -} \ No newline at end of file +} diff --git a/src/Scaffolding/VS.Web.CG.Mvc/ParameterDefinitions/view.json b/src/Scaffolding/VS.Web.CG.Mvc/ParameterDefinitions/view.json index db1eb8b47..cefcde065 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/ParameterDefinitions/view.json +++ b/src/Scaffolding/VS.Web.CG.Mvc/ParameterDefinitions/view.json @@ -57,10 +57,10 @@ "ShortName": "partial", "Description": "Generate a partial view, other layout options (-l and -udl) are ignored if this is specified" }, - { - "Name" : "useSqlite", - "ShortName" : "sqlite", - "Description": "Flag to specify if DbContext should use SQLite instead of SQL Server." - } + { + "Name": "databaseProvider", + "ShortName": "dbProvider", + "Description": "Database provider to use. Options include 'sqlserver' (default), 'sqlite', 'cosmos', 'postgres'." + } ] -} \ No newline at end of file +} diff --git a/src/Scaffolding/VS.Web.CG.Mvc/RazorPage/EFModelBasedRazorPageScaffolder.cs b/src/Scaffolding/VS.Web.CG.Mvc/RazorPage/EFModelBasedRazorPageScaffolder.cs index 52ab04ac1..2a768ce05 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/RazorPage/EFModelBasedRazorPageScaffolder.cs +++ b/src/Scaffolding/VS.Web.CG.Mvc/RazorPage/EFModelBasedRazorPageScaffolder.cs @@ -77,17 +77,19 @@ public override async Task GenerateCode(RazorPageGeneratorModel razorGeneratorMo throw new ArgumentException(MessageStrings.PageModelFlagNotSupported); } + razorGeneratorModel.ValidateCommandline(_logger); var outputPath = ValidateAndGetOutputPath(razorGeneratorModel, outputFileName: razorGeneratorModel.RazorPageName + Constants.ViewExtension); if (CalledFromCommandline) { - EFValidationUtil.ValidateEFDependencies(_projectContext.PackageDependencies, razorGeneratorModel.UseSqlite); + EFValidationUtil.ValidateEFDependencies(_projectContext.PackageDependencies, razorGeneratorModel.DatabaseProvider); } ModelTypeAndContextModel modelTypeAndContextModel = await ModelMetadataUtilities.ValidateModelAndGetEFMetadata( razorGeneratorModel, _entityFrameworkService, _modelTypesLocator, + _logger, string.Empty); TemplateModel = GetRazorPageWithContextTemplateModel(razorGeneratorModel, modelTypeAndContextModel); @@ -139,13 +141,14 @@ internal async Task GenerateViews(RazorPageGeneratorModel razorPageGeneratorMode if (CalledFromCommandline) { - EFValidationUtil.ValidateEFDependencies(_projectContext.PackageDependencies, razorPageGeneratorModel.UseSqlite); + EFValidationUtil.ValidateEFDependencies(_projectContext.PackageDependencies, razorPageGeneratorModel.DatabaseProvider); } modelTypeAndContextModel = await ModelMetadataUtilities.ValidateModelAndGetEFMetadata( razorPageGeneratorModel, _entityFrameworkService, _modelTypesLocator, + _logger, string.Empty); TemplateModel = GetRazorPageWithContextTemplateModel(razorPageGeneratorModel, modelTypeAndContextModel); diff --git a/src/Scaffolding/VS.Web.CG.Mvc/RazorPage/RazorPageScaffolderBase.cs b/src/Scaffolding/VS.Web.CG.Mvc/RazorPage/RazorPageScaffolderBase.cs index 7ca9496d4..9514640d4 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/RazorPage/RazorPageScaffolderBase.cs +++ b/src/Scaffolding/VS.Web.CG.Mvc/RazorPage/RazorPageScaffolderBase.cs @@ -288,7 +288,7 @@ protected RazorPageWithContextTemplateModel GetRazorPageWithContextTemplateModel JQueryVersion = "1.10.2", //Todo BootstrapVersion = razorGeneratorModel.BootstrapVersion, ContentVersion = DetermineContentVersion(razorGeneratorModel), - UseSqlite = razorGeneratorModel.UseSqlite, + DatabaseProvider = razorGeneratorModel.DatabaseProvider, NullableEnabled = _projectContext.Nullable }; diff --git a/src/Scaffolding/VS.Web.CG.Mvc/RazorPage/RazorPageWithContextTemplateModel.cs b/src/Scaffolding/VS.Web.CG.Mvc/RazorPage/RazorPageWithContextTemplateModel.cs index 6c1cbfa96..731dea921 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/RazorPage/RazorPageWithContextTemplateModel.cs +++ b/src/Scaffolding/VS.Web.CG.Mvc/RazorPage/RazorPageWithContextTemplateModel.cs @@ -1,7 +1,7 @@ using System; using System.Collections.Generic; +using Microsoft.DotNet.Scaffolding.Shared; using Microsoft.DotNet.Scaffolding.Shared.Project; -using Microsoft.VisualStudio.Web.CodeGeneration; namespace Microsoft.VisualStudio.Web.CodeGenerators.Mvc.Razor { @@ -27,7 +27,7 @@ public RazorPageWithContextTemplateModel(ModelType modelType, string dbContextFu DbContextNamespace = classNameModel.NamespaceName; } - public bool UseSqlite { get; set; } + public DbProvider DatabaseProvider { get; set; } public string ViewDataTypeName { get; set; } diff --git a/src/Scaffolding/VS.Web.CG.Mvc/View/EFModelBasedViewScaffolder.cs b/src/Scaffolding/VS.Web.CG.Mvc/View/EFModelBasedViewScaffolder.cs index 80831327b..d4e30b745 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/View/EFModelBasedViewScaffolder.cs +++ b/src/Scaffolding/VS.Web.CG.Mvc/View/EFModelBasedViewScaffolder.cs @@ -18,7 +18,9 @@ public class EFModelBasedViewScaffolder : ViewScaffolderBase { private IEntityFrameworkService _entityFrameworkService; private IModelTypesLocator _modelTypesLocator; - + private IFileSystem _fileSystem; + private bool CalledFromCommandline => !(_fileSystem is SimulationModeFileSystem); + public EFModelBasedViewScaffolder( IProjectContext projectContext, IApplicationInfo applicationInfo, @@ -26,21 +28,13 @@ public EFModelBasedViewScaffolder( IEntityFrameworkService entityFrameworkService, ICodeGeneratorActionsService codeGeneratorActionsService, IServiceProvider serviceProvider, - ILogger logger) + ILogger logger, + IFileSystem fileSystem) : base(projectContext, applicationInfo, codeGeneratorActionsService, serviceProvider, logger) { - if (modelTypesLocator == null) - { - throw new ArgumentNullException(nameof(modelTypesLocator)); - } - - if (entityFrameworkService == null) - { - throw new ArgumentNullException(nameof(entityFrameworkService)); - } - - _modelTypesLocator = modelTypesLocator; - _entityFrameworkService = entityFrameworkService; + _modelTypesLocator = modelTypesLocator ?? throw new ArgumentNullException(nameof(modelTypesLocator)); + _entityFrameworkService = entityFrameworkService ?? throw new ArgumentNullException(nameof(entityFrameworkService)); + _fileSystem = fileSystem ?? throw new ArgumentNullException(nameof(fileSystem)); } public override async Task GenerateCode(ViewGeneratorModel viewGeneratorModel) @@ -59,13 +53,12 @@ public override async Task GenerateCode(ViewGeneratorModel viewGeneratorModel) { throw new ArgumentException(MessageStrings.TemplateNameRequired); } - + viewGeneratorModel.ValidateCommandline(_logger); ModelTypeAndContextModel modelTypeAndContextModel = null; var outputPath = ValidateAndGetOutputPath(viewGeneratorModel, outputFileName: viewGeneratorModel.ViewName + Constants.ViewExtension); - - if (!string.IsNullOrEmpty(_projectContext.TargetFrameworkMoniker)) + if (!string.IsNullOrEmpty(_projectContext.TargetFrameworkMoniker) && CalledFromCommandline) { - EFValidationUtil.ValidateEFDependencies(_projectContext.PackageDependencies, viewGeneratorModel.UseSqlite); + EFValidationUtil.ValidateEFDependencies(_projectContext.PackageDependencies, viewGeneratorModel.DatabaseProvider); } @@ -73,6 +66,7 @@ public override async Task GenerateCode(ViewGeneratorModel viewGeneratorModel) viewGeneratorModel, _entityFrameworkService, _modelTypesLocator, + _logger, string.Empty); await GenerateView(viewGeneratorModel, modelTypeAndContextModel, outputPath); diff --git a/src/Scaffolding/VS.Web.CG.Mvc/baseline.netcore.json b/src/Scaffolding/VS.Web.CG.Mvc/baseline.netcore.json index 3939ee3cc..867e9e273 100644 --- a/src/Scaffolding/VS.Web.CG.Mvc/baseline.netcore.json +++ b/src/Scaffolding/VS.Web.CG.Mvc/baseline.netcore.json @@ -2348,19 +2348,19 @@ }, { "Kind": "Method", - "Name": "get_UseSQLite", + "Name": "get_DatabaseProvider", "Parameters": [], - "ReturnType": "System.Boolean", + "ReturnType": "System.String", "Visibility": "Public", "GenericParameter": [] }, { "Kind": "Method", - "Name": "set_UseSQLite", + "Name": "set_DatabaseProvider", "Parameters": [ { "Name": "value", - "Type": "System.Boolean" + "Type": "System.String" } ], "ReturnType": "System.Void", diff --git a/src/Shared/Microsoft.DotNet.Scaffolding.Shared/EFValidationUtil.cs b/src/Shared/Microsoft.DotNet.Scaffolding.Shared/EFValidationUtil.cs new file mode 100644 index 000000000..1dbc5e51c --- /dev/null +++ b/src/Shared/Microsoft.DotNet.Scaffolding.Shared/EFValidationUtil.cs @@ -0,0 +1,41 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.DotNet.Scaffolding.Shared.ProjectModel; + +namespace Microsoft.DotNet.Scaffolding.Shared +{ + internal static class EFValidationUtil + { + internal static void ValidateEFDependencies(IEnumerable dependencies, DbProvider dataContextType) + { + var isEFDesignPackagePresent = dependencies + .Any(package => package.Name.Equals(EfConstants.EfDesignPackageName, StringComparison.OrdinalIgnoreCase)); + + if (!isEFDesignPackagePresent) + { + throw new InvalidOperationException( + string.Format(MessageStrings.InstallEfPackages, $"{EfConstants.EfDesignPackageName}")); + } + if (EfConstants.EfPackagesDict.TryGetValue(dataContextType, out var dbProviderPackageName)) + { + ValidateDependency(dbProviderPackageName, dependencies); + } + } + + internal static void ValidateDependency(string packageName, IEnumerable dependencies) + { + var isPackagePresent = dependencies + .Any(package => package.Name.Equals(packageName, StringComparison.OrdinalIgnoreCase)); + + if (!isPackagePresent) + { + throw new InvalidOperationException( + string.Format(MessageStrings.InstallSqlPackage, $"{packageName}.")); + } + } + } +} diff --git a/src/Shared/Microsoft.DotNet.Scaffolding.Shared/MessageStrings.Designer.cs b/src/Shared/Microsoft.DotNet.Scaffolding.Shared/MessageStrings.Designer.cs index aa64595e9..a16822fab 100644 --- a/src/Shared/Microsoft.DotNet.Scaffolding.Shared/MessageStrings.Designer.cs +++ b/src/Shared/Microsoft.DotNet.Scaffolding.Shared/MessageStrings.Designer.cs @@ -114,6 +114,24 @@ internal static string EndFileSystemChangeToken { } } + /// + /// Looks up a localized string similar to To scaffold controllers and views using models, install Entity Framework core packages and try again: {0}. + /// + internal static string InstallEfPackages { + get { + return ResourceManager.GetString("InstallEfPackages", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to To scaffold, install the following Entity Framework core packages and try again: {0}. + /// + internal static string InstallSqlPackage { + get { + return ResourceManager.GetString("InstallSqlPackage", resourceCulture); + } + } + /// /// Looks up a localized string similar to Invalid FileSystemChange message.. /// diff --git a/src/Shared/Microsoft.DotNet.Scaffolding.Shared/MessageStrings.resx b/src/Shared/Microsoft.DotNet.Scaffolding.Shared/MessageStrings.resx index 51060f4f4..d374bafe6 100644 --- a/src/Shared/Microsoft.DotNet.Scaffolding.Shared/MessageStrings.resx +++ b/src/Shared/Microsoft.DotNet.Scaffolding.Shared/MessageStrings.resx @@ -135,6 +135,12 @@ :::End FileSystemChange::: + + To scaffold controllers and views using models, install Entity Framework core packages and try again: {0} + + + To scaffold, install the following Entity Framework core packages and try again: {0} + Invalid FileSystemChange message. @@ -150,4 +156,4 @@ :::Start FileSystemChange::: - + \ No newline at end of file diff --git a/src/Shared/Microsoft.DotNet.Scaffolding.Shared/SharedConstants.cs b/src/Shared/Microsoft.DotNet.Scaffolding.Shared/SharedConstants.cs new file mode 100644 index 000000000..c12d9317f --- /dev/null +++ b/src/Shared/Microsoft.DotNet.Scaffolding.Shared/SharedConstants.cs @@ -0,0 +1,44 @@ +using System.Collections.Generic; + +namespace Microsoft.DotNet.Scaffolding.Shared +{ + public enum DbProvider + { + SqlServer, SQLite, CosmosDb, Postgres, Existing + } + + public static class EfConstants + { + public static string SqlServer = DbProvider.SqlServer.ToString(); + public static string SQLite = DbProvider.SQLite.ToString(); + public static string CosmosDb = DbProvider.CosmosDb.ToString(); + public static string Postgres = DbProvider.Postgres.ToString(); + public const string EfDesignPackageName = "Microsoft.EntityFrameworkCore.Design"; + public const string SqlServerPackageName = "Microsoft.EntityFrameworkCore.SqlServer"; + public const string SqlitePackageName = "Microsoft.EntityFrameworkCore.Sqlite"; + public const string CosmosPakcageName = "Microsoft.EntityFrameworkCore.Cosmos"; + public const string PostgresPackageName = "Npgsql.EntityFrameworkCore.PostgreSQL"; + public const string SQLConnectionStringFormat = "Server=(localdb)\\mssqllocaldb;Database={0};Trusted_Connection=True;MultipleActiveResultSets=true"; + public const string SQLiteConnectionStringFormat = "Data Source={0}.db"; + public const string CosmosDbConnectionStringFormat = "AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="; + public const string PostgresConnectionStringFormat = "server=localhost;username=postgres;database={0}"; + public static readonly IDictionary ConnectionStringsDict = new Dictionary + { + { DbProvider.SqlServer, SQLConnectionStringFormat }, + { DbProvider.SQLite, SQLiteConnectionStringFormat }, + { DbProvider.CosmosDb, CosmosDbConnectionStringFormat }, + { DbProvider.Postgres, PostgresConnectionStringFormat } + }; + + public static readonly IDictionary EfPackagesDict = new Dictionary + { + { DbProvider.SqlServer, SqlServerPackageName }, + { DbProvider.SQLite, SqlitePackageName }, + { DbProvider.CosmosDb, CosmosPakcageName }, + { DbProvider.Postgres, PostgresPackageName } + }; + + public static readonly IList IdentityDbProviders = new List { SqlServer, SQLite }; + public static readonly IList AllDbProviders = new List { SqlServer, SQLite, CosmosDb, Postgres }; + } +} diff --git a/test/Scaffolding/TestApps/ModelTypesLocatorTestWebApp/Properties/launchSettings.json b/test/Scaffolding/TestApps/ModelTypesLocatorTestWebApp/Properties/launchSettings.json index 6bce4b235..165046527 100644 --- a/test/Scaffolding/TestApps/ModelTypesLocatorTestWebApp/Properties/launchSettings.json +++ b/test/Scaffolding/TestApps/ModelTypesLocatorTestWebApp/Properties/launchSettings.json @@ -18,10 +18,10 @@ "ModelTypesLocatorTestWebApp": { "commandName": "Project", "launchBrowser": true, + "applicationUrl": "https://localhost:5001;http://localhost:5000", "environmentVariables": { "ASPNETCORE_ENVIRONMENT": "Development" - }, - "applicationUrl": "https://localhost:5001;http://localhost:5000" + } } } } \ No newline at end of file diff --git a/test/Scaffolding/VS.Web.CG.EFCore.Test/ConnectionStringsWriterTests.cs b/test/Scaffolding/VS.Web.CG.EFCore.Test/ConnectionStringsWriterTests.cs index 5152ffc96..8717e7e1f 100644 --- a/test/Scaffolding/VS.Web.CG.EFCore.Test/ConnectionStringsWriterTests.cs +++ b/test/Scaffolding/VS.Web.CG.EFCore.Test/ConnectionStringsWriterTests.cs @@ -6,6 +6,7 @@ using System.Reflection; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using Microsoft.DotNet.Scaffolding.Shared; using Microsoft.DotNet.Scaffolding.Shared.ProjectModel; using Microsoft.VisualStudio.Web.CodeGeneration.DotNet; using Microsoft.VisualStudio.Web.CodeGeneration.EntityFrameworkCore; @@ -34,17 +35,30 @@ public void AddConnectionString_Creates_App_Settings_File() var fs = new MockFileSystem(); var testObj = GetTestObject(fs); - //Act + //Act, test obsolete AddConnectionString testObj.AddConnectionString("MyDbContext", "MyDbContext-NewGuid", false); - + //test SqlServer + testObj.AddConnectionString("MyDbContext2", "MyDbContext-SqlServerDb", DbProvider.SqlServer); + //test SqlServer + testObj.AddConnectionString("MyDbContext3", "MyDbContext-SqliteDb", DbProvider.SQLite); + //test SqlServer + testObj.AddConnectionString("MyDbContext4", "MyDbContext-CosmosDb", DbProvider.CosmosDb); + //test SqlServer + testObj.AddConnectionString("MyDbContext5", "MyDbContext-PostgresDb", DbProvider.Postgres); //Assert string expected = @"{ ""ConnectionStrings"": { - ""MyDbContext"": ""Server=(localdb)\\mssqllocaldb;Database=MyDbContext-NewGuid;Trusted_Connection=True;MultipleActiveResultSets=true"" + ""MyDbContext"": ""Server=(localdb)\\mssqllocaldb;Database=MyDbContext-NewGuid;Trusted_Connection=True;MultipleActiveResultSets=true"", + ""MyDbContext2"": ""Server=(localdb)\\mssqllocaldb;Database=MyDbContext-SqlServerDb;Trusted_Connection=True;MultipleActiveResultSets=true"", + ""MyDbContext3"": ""Data Source=MyDbContext-SqliteDb.db"", + ""MyDbContext4"": ""AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="", + ""MyDbContext5"": ""server=localhost;username=postgres;database=MyDbContext-PostgresDb"" } }"; + var appSettingsPath = Path.Combine(AppBase, "appsettings.json"); fs.FileExists(appSettingsPath); + var appsettingsstring = fs.ReadAllText(appSettingsPath); Assert.Equal(expected, fs.ReadAllText(appSettingsPath), ignoreCase: false, ignoreLineEndingDifferences: true); } diff --git a/test/Scaffolding/VS.Web.CG.EFCore.Test/DbContextEditorServicesTests.cs b/test/Scaffolding/VS.Web.CG.EFCore.Test/DbContextEditorServicesTests.cs index 35e92febc..3f10e36b2 100644 --- a/test/Scaffolding/VS.Web.CG.EFCore.Test/DbContextEditorServicesTests.cs +++ b/test/Scaffolding/VS.Web.CG.EFCore.Test/DbContextEditorServicesTests.cs @@ -1,14 +1,17 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - +using System; +using System.Collections.Generic; using System.IO; using System.Linq; using System.Reflection; using System.Threading.Tasks; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using Microsoft.DotNet.Scaffolding.Shared; using Microsoft.DotNet.Scaffolding.Shared.Project; using Microsoft.DotNet.Scaffolding.Shared.ProjectModel; +using Microsoft.EntityFrameworkCore; using Microsoft.VisualStudio.Web.CodeGeneration.DotNet; using Microsoft.VisualStudio.Web.CodeGeneration.EntityFrameworkCore; using Microsoft.VisualStudio.Web.CodeGeneration.EntityFrameworkCore.Test; @@ -71,7 +74,7 @@ public void AddModelToContext_Adds_Model_From_Same_Project_To_Context(string bef var modelType = ModelType.FromITypeSymbol(types.Where(ts => ts.Name == "MyModel").First()); var contextType = ModelType.FromITypeSymbol(types.Where(ts => ts.Name == "MyContext").First()); - var result = testObj.AddModelToContext(contextType, modelType, nullableEnabled: false); + var result = testObj.AddModelToContext(contextType, modelType, new Dictionary() { { "nullableEnabled", bool.FalseString } }); Assert.True(result.Edited); @@ -100,7 +103,7 @@ public void AddModelToContext_Adds_Model_From_Same_Project_To_Context_With_Nulla var modelType = ModelType.FromITypeSymbol(types.Where(ts => ts.Name == "MyModel").First()); var contextType = ModelType.FromITypeSymbol(types.Where(ts => ts.Name == "MyContext").First()); - var result = testObj.AddModelToContext(contextType, modelType, nullableEnabled: true); + var result = testObj.AddModelToContext(contextType, modelType, new Dictionary() { { "nullableEnabled", bool.TrueString } }); Assert.True(result.Edited); Assert.Equal(afterDbContextText, result.NewTree.GetText().ToString(), ignoreCase: false, ignoreLineEndingDifferences: true); @@ -147,7 +150,7 @@ public async Task GetAddDbContextStatementTests() DbContextEditorServices testObj = GetTestObject(); var syntaxTree = CSharpSyntaxTree.ParseText(MinimalProgramCsFile); var root = await syntaxTree.GetRootAsync(); - var dbContextExpression = testObj.GetAddDbContextStatement(root, "DbContextName", "DatabaseName", useSqlite: false); + var dbContextExpression = testObj.GetAddDbContextStatement(root, "DbContextName", "DatabaseName", DbProvider.SqlServer); var correctDbContextString = "builder.Services.AddDbContext(options => options.UseSqlServer(builder.Configuration.GetConnectionString(\"DbContextName\") ?? throw new InvalidOperationException(\"Connection string 'DbContextName' not found.\")));"; var trimmedDbContextString = ProjectModifierHelper.TrimStatement(dbContextExpression.ToString()); @@ -155,6 +158,13 @@ public async Task GetAddDbContextStatementTests() Assert.Equal(trimmedCorrectDbContextString, trimmedDbContextString); } + [Theory] + [MemberData(nameof(AddDbContextStringData))] + public void AddDbContextStringTests(bool minimalHostingTemplate, string statementLeadingTrivia, DbProvider databaseProvider, string optionsExpected) + { + string dbContextString = GetTestObject().AddDbContextString(minimalHostingTemplate, statementLeadingTrivia, databaseProvider); + Assert.True(dbContextString.Contains(optionsExpected, StringComparison.OrdinalIgnoreCase)); + } private DbContextEditorServices GetTestObject(MockFileSystem fs = null) { @@ -169,6 +179,20 @@ private DbContextEditorServices GetTestObject(MockFileSystem fs = null) fs != null ? fs : new MockFileSystem()); } + public static IEnumerable AddDbContextStringData => + new [] + { + new object[] { true, string.Empty, DbProvider.SqlServer, "options.UseSqlServer" }, + new object[] { false, string.Empty, DbProvider.SqlServer, "options.UseSqlServer" }, + new object[] { true, string.Empty, DbProvider.SQLite, "options.UseSqlite" }, + new object[] { false, string.Empty, DbProvider.SQLite, "options.UseSqlite" }, + new object[] { true, string.Empty, DbProvider.CosmosDb, "options.UseCosmos" }, + new object[] { false, string.Empty, DbProvider.CosmosDb, "options.UseCosmos" }, + new object[] { true, string.Empty, DbProvider.Postgres, "options.UseNpgsql" }, + new object[] { false, string.Empty, DbProvider.Postgres, "options.UseNpgsql" }, + new object[] { true, null, DbProvider.SqlServer, "options.UseSqlServer" }, + new object[] { true, null, null, string.Empty } + }; private static readonly string AppBase = "AppBase"; } } diff --git a/test/Scaffolding/VS.Web.CG.EFCore.Test/EntityFrameworkServicesTests.cs b/test/Scaffolding/VS.Web.CG.EFCore.Test/EntityFrameworkServicesTests.cs index e11140d42..877f93ae2 100644 --- a/test/Scaffolding/VS.Web.CG.EFCore.Test/EntityFrameworkServicesTests.cs +++ b/test/Scaffolding/VS.Web.CG.EFCore.Test/EntityFrameworkServicesTests.cs @@ -2,6 +2,8 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections; +using System.Collections.Generic; using System.IO; using System.Linq; using Microsoft.DotNet.Scaffolding.Shared; @@ -52,16 +54,11 @@ private EntityFrameworkServices GetEfServices(string path, string applicationNam Edited = true }; - dbContextMock.Setup(db => db.EditStartupForNewContext(It.IsAny(), - It.IsAny(), - It.IsAny(), - It.IsAny(), - useSqlite, - false)) + dbContextMock.Setup(db => db.EditStartupForNewContext(It.IsAny(), It.IsAny>())) .Returns(editSyntaxTreeResult); var connectionStringsWriter = new Mock(); - connectionStringsWriter.Setup(c => c.AddConnectionString(It.IsAny(), It.IsAny(), It.IsAny())); + connectionStringsWriter.Setup(c => c.AddConnectionString(It.IsAny(), It.IsAny(), It.IsAny())); var filesLocator = new FilesLocator(); var compilationService = new RoslynCompilationService(_appInfo, _loader, _projectContext); diff --git a/test/Scaffolding/VS.Web.CG.Mvc.Test/IdentityGeneratorTemplateModelBuilderTests.cs b/test/Scaffolding/VS.Web.CG.Mvc.Test/IdentityGeneratorTemplateModelBuilderTests.cs index bdc48e5c3..eca2a2d6a 100644 --- a/test/Scaffolding/VS.Web.CG.Mvc.Test/IdentityGeneratorTemplateModelBuilderTests.cs +++ b/test/Scaffolding/VS.Web.CG.Mvc.Test/IdentityGeneratorTemplateModelBuilderTests.cs @@ -46,8 +46,7 @@ public async Task TestValidateAndBuild() var commandLineModel = new IdentityGeneratorCommandLineModel() { - RootNamespace = "Test.Namespace", - UseSqlite = false + RootNamespace = "Test.Namespace" }; var applicationInfo = new ApplicationInfo("TestApp", "Sample"); diff --git a/test/Scaffolding/VS.Web.CG.Mvc.Test/IdentityGeneratorTests.cs b/test/Scaffolding/VS.Web.CG.Mvc.Test/IdentityGeneratorTests.cs index 0f6c4bed8..c89f95090 100644 --- a/test/Scaffolding/VS.Web.CG.Mvc.Test/IdentityGeneratorTests.cs +++ b/test/Scaffolding/VS.Web.CG.Mvc.Test/IdentityGeneratorTests.cs @@ -36,7 +36,7 @@ public void ReplaceIdentityStringsTests( { for (int i = 0; i < unModifiedStrings.Count; i++) { - string editResult = IdentityGenerator.EditIdentityStrings(unModifiedStrings[i], dbContext, identityUser, false, 0); + string editResult = IdentityGenerator.EditIdentityStrings(unModifiedStrings[i], dbContext, identityUser, DotNet.Scaffolding.Shared.DbProvider.SqlServer, 0); Assert.Contains(editResult, modifiedStrings); } } diff --git a/test/Scaffolding/VS.Web.CG.Mvc.Test/ModelMetadataUtilitiesTest.cs b/test/Scaffolding/VS.Web.CG.Mvc.Test/ModelMetadataUtilitiesTest.cs index a61d45326..5f16a9fe0 100644 --- a/test/Scaffolding/VS.Web.CG.Mvc.Test/ModelMetadataUtilitiesTest.cs +++ b/test/Scaffolding/VS.Web.CG.Mvc.Test/ModelMetadataUtilitiesTest.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; +using Microsoft.DotNet.Scaffolding.Shared; using Microsoft.DotNet.Scaffolding.Shared.Project; using Microsoft.VisualStudio.Web.CodeGeneration; using Microsoft.VisualStudio.Web.CodeGeneration.EntityFrameworkCore; @@ -21,6 +22,7 @@ public class ModelMetadataUtilitiesTest private CommonCommandLineModel model; private Mock modelTypesLocator; private Mock modelTypesLocatorWithoutContext; + private Mock logger; public ModelMetadataUtilitiesTest() { @@ -28,6 +30,7 @@ public ModelMetadataUtilitiesTest() modelTypesLocator = new Mock(); modelTypesLocatorWithoutContext = new Mock(); codeModelService = new Mock(); + logger= new Mock(); } [Fact] @@ -114,7 +117,7 @@ public async void TestValidateModelAndGetModelMetadata() ModelMetadata = null }; - efService.Setup(e => e.GetModelMetadata(model.DataContextClass, modelType, string.Empty, false)) + efService.Setup(e => e.GetModelMetadata(model.DataContextClass, modelType, string.Empty, DbProvider.SqlServer)) .Returns(Task.FromResult(contextProcessingResult)); //Act @@ -122,6 +125,7 @@ public async void TestValidateModelAndGetModelMetadata() model, efService.Object, modelTypesLocator.Object, + logger.Object, string.Empty); //Assert @@ -136,7 +140,7 @@ public async void TestValidateModelAndGetModelMetadata() FullName = "A.B.C.SampleDataContext" }; dataContextTypes.Add(dataContextType); - efService.Setup(e => e.GetModelMetadata(dataContextType.FullName, modelType, string.Empty, false)) + efService.Setup(e => e.GetModelMetadata(dataContextType.FullName, modelType, string.Empty, DbProvider.SqlServer)) .Returns(Task.FromResult(contextProcessingResult)); //Act @@ -144,6 +148,7 @@ public async void TestValidateModelAndGetModelMetadata() model, efService.Object, modelTypesLocator.Object, + logger.Object, string.Empty); //Assert @@ -155,7 +160,6 @@ public async void TestValidateModelAndGetModelMetadata() [Fact] public async void TestGetModelEFMetadataMinimalAsync() { - var modelTypes = new List(); var dataContextTypes = new List(); //Arrange @@ -182,6 +186,7 @@ public async void TestGetModelEFMetadataMinimalAsync() minimalApiModelWithContext, efService.Object, modelTypesLocator.Object, + logger.Object, areaName: string.Empty)); Exception exWithoutContext = await Assert.ThrowsAsync( @@ -189,6 +194,7 @@ public async void TestGetModelEFMetadataMinimalAsync() minimalApiModelWithoutContext, efService.Object, modelTypesLocatorWithoutContext.Object, + logger.Object, areaName: string.Empty)); Assert.Equal("A type with the name SampleModel does not exist", ex.Message); @@ -216,7 +222,7 @@ public async void TestGetModelEFMetadataMinimalAsync() ModelMetadata = null }; - efService.Setup(e => e.GetModelMetadata(minimalApiModelWithContext.DataContextClass, modelType, string.Empty, false)) + efService.Setup(e => e.GetModelMetadata(minimalApiModelWithContext.DataContextClass, modelType, string.Empty, DbProvider.SqlServer)) .Returns(Task.FromResult(contextProcessingResult)); //Act @@ -224,12 +230,14 @@ public async void TestGetModelEFMetadataMinimalAsync() minimalApiModelWithContext, efService.Object, modelTypesLocator.Object, + logger.Object, string.Empty); var resultWithoutContext = await ModelMetadataUtilities.GetModelEFMetadataMinimalAsync( minimalApiModelWithoutContext, efService.Object, modelTypesLocatorWithoutContext.Object, + logger.Object, string.Empty); //Assert scenario with DbContext @@ -249,15 +257,16 @@ public async void TestGetModelEFMetadataMinimalAsync() FullName = "A.B.C.SampleDataContext" }; dataContextTypes.Add(dataContextType); - efService.Setup(e => e.GetModelMetadata(dataContextType.FullName, modelType, string.Empty, false)) + efService.Setup(e => e.GetModelMetadata(dataContextType.FullName, modelType, string.Empty, DbProvider.SqlServer)) .Returns(Task.FromResult(contextProcessingResult)); //Act result = await ModelMetadataUtilities.GetModelEFMetadataMinimalAsync( - minimalApiModelWithContext, - efService.Object, - modelTypesLocator.Object, - string.Empty); + minimalApiModelWithContext, + efService.Object, + modelTypesLocator.Object, + logger.Object, + string.Empty); //Assert Assert.Equal(contextProcessingResult.ContextProcessingStatus, result.ContextProcessingResult.ContextProcessingStatus);