diff --git a/packages/bolt-connection/src/bolt/bolt-protocol-util.js b/packages/bolt-connection/src/bolt/bolt-protocol-util.js index a7802e982..311eacc03 100644 --- a/packages/bolt-connection/src/bolt/bolt-protocol-util.js +++ b/packages/bolt-connection/src/bolt/bolt-protocol-util.js @@ -57,4 +57,25 @@ function assertDatabaseIsEmpty (database, onProtocolError = () => {}, observer) } } -export { assertDatabaseIsEmpty, assertTxConfigIsEmpty } +/** + * Asserts that the passed-in impersonated user is empty + * @param {string} impersonatedUser + * @param {function (err:Error)} onProtocolError Called when it does have impersonated user set + * @param {any} observer + */ +function assertImpersonatedUserIsEmpty (impersonatedUser, onProtocolError = () => {}, observer) { + if (impersonatedUser) { + const error = newError( + 'Driver is connected to the database that does not support user impersonation. ' + + 'Please upgrade to neo4j 4.4.0 or later in order to use this functionality. ' + + `Trying to impersonate ${impersonatedUser}.` + ) + + // unsupported API was used, consider this a fatal error for the current connection + onProtocolError(error.message) + observer.onError(error) + throw error + } +} + +export { assertDatabaseIsEmpty, assertTxConfigIsEmpty, assertImpersonatedUserIsEmpty } diff --git a/packages/bolt-connection/src/bolt/bolt-protocol-v1.js b/packages/bolt-connection/src/bolt/bolt-protocol-v1.js index c48f328b9..4d3e2e711 100644 --- a/packages/bolt-connection/src/bolt/bolt-protocol-v1.js +++ b/packages/bolt-connection/src/bolt/bolt-protocol-v1.js @@ -18,7 +18,8 @@ */ import { assertDatabaseIsEmpty, - assertTxConfigIsEmpty + assertTxConfigIsEmpty, + assertImpersonatedUserIsEmpty } from './bolt-protocol-util' import { Chunker } from '../chunking' import { v1 } from '../packstream' @@ -143,6 +144,7 @@ export default class BoltProtocol { * @param {TxConfig} param.txConfig the configuration. * @param {string} param.database the target database name. * @param {string} param.mode the access mode. + * @param {string} param.impersonatedUser the impersonated user * @param {function(err: Error)} param.beforeError the callback to invoke before handling the error. * @param {function(err: Error)} param.afterError the callback to invoke after handling the error. * @param {function()} param.beforeComplete the callback to invoke before handling the completion. @@ -154,6 +156,7 @@ export default class BoltProtocol { txConfig, database, mode, + impersonatedUser, beforeError, afterError, beforeComplete, @@ -167,6 +170,7 @@ export default class BoltProtocol { txConfig: txConfig, database, mode, + impersonatedUser, beforeError, afterError, beforeComplete, @@ -248,6 +252,7 @@ export default class BoltProtocol { * @param {Bookmark} param.bookmark the bookmark. * @param {TxConfig} param.txConfig the transaction configuration. * @param {string} param.database the target database name. + * @param {string} param.impersonatedUser the impersonated user * @param {string} param.mode the access mode. * @param {function(keys: string[])} param.beforeKeys the callback to invoke before handling the keys. * @param {function(keys: string[])} param.afterKeys the callback to invoke after handling the keys. @@ -266,6 +271,7 @@ export default class BoltProtocol { txConfig, database, mode, + impersonatedUser, beforeKeys, afterKeys, beforeError, @@ -289,6 +295,8 @@ export default class BoltProtocol { assertTxConfigIsEmpty(txConfig, this._onProtocolError, observer) // passing in a database name on this protocol version throws an error assertDatabaseIsEmpty(database, this._onProtocolError, observer) + // passing impersonated user on this protocol version throws an error + assertImpersonatedUserIsEmpty(impersonatedUser, this._onProtocolError, observer) this.write(RequestMessage.run(query, parameters), observer, false) this.write(RequestMessage.pullAll(), observer, flush) diff --git a/packages/bolt-connection/src/bolt/bolt-protocol-v3.js b/packages/bolt-connection/src/bolt/bolt-protocol-v3.js index a3425e48b..680ec84b9 100644 --- a/packages/bolt-connection/src/bolt/bolt-protocol-v3.js +++ b/packages/bolt-connection/src/bolt/bolt-protocol-v3.js @@ -18,7 +18,7 @@ */ import BoltProtocolV2 from './bolt-protocol-v2' import RequestMessage from './request-message' -import { assertDatabaseIsEmpty } from './bolt-protocol-util' +import { assertDatabaseIsEmpty, assertImpersonatedUserIsEmpty } from './bolt-protocol-util' import { StreamObserver, LoginObserver, @@ -78,6 +78,7 @@ export default class BoltProtocol extends BoltProtocolV2 { bookmark, txConfig, database, + impersonatedUser, mode, beforeError, afterError, @@ -95,6 +96,8 @@ export default class BoltProtocol extends BoltProtocolV2 { // passing in a database name on this protocol version throws an error assertDatabaseIsEmpty(database, this._onProtocolError, observer) + // passing impersonated user on this protocol version throws an error + assertImpersonatedUserIsEmpty(impersonatedUser, this._onProtocolError, observer) this.write( RequestMessage.begin({ bookmark, txConfig, mode }), @@ -152,6 +155,7 @@ export default class BoltProtocol extends BoltProtocolV2 { bookmark, txConfig, database, + impersonatedUser, mode, beforeKeys, afterKeys, @@ -174,6 +178,8 @@ export default class BoltProtocol extends BoltProtocolV2 { // passing in a database name on this protocol version throws an error assertDatabaseIsEmpty(database, this._onProtocolError, observer) + // passing impersonated user on this protocol version throws an error + assertImpersonatedUserIsEmpty(impersonatedUser, this._onProtocolError, observer) this.write( RequestMessage.runWithMetadata(query, parameters, { diff --git a/packages/bolt-connection/src/bolt/bolt-protocol-v4x0.js b/packages/bolt-connection/src/bolt/bolt-protocol-v4x0.js index a3f7e01b1..5521d43f7 100644 --- a/packages/bolt-connection/src/bolt/bolt-protocol-v4x0.js +++ b/packages/bolt-connection/src/bolt/bolt-protocol-v4x0.js @@ -18,6 +18,7 @@ */ import BoltProtocolV3 from './bolt-protocol-v3' import RequestMessage, { ALL } from './request-message' +import { assertImpersonatedUserIsEmpty } from './bolt-protocol-util' import { ResultStreamObserver, ProcedureRouteObserver @@ -44,6 +45,7 @@ export default class BoltProtocol extends BoltProtocolV3 { bookmark, txConfig, database, + impersonatedUser, mode, beforeError, afterError, @@ -59,6 +61,9 @@ export default class BoltProtocol extends BoltProtocolV3 { }) observer.prepareToHandleSingleResponse() + // passing impersonated user on this protocol version throws an error + assertImpersonatedUserIsEmpty(impersonatedUser, this._onProtocolError, observer) + this.write( RequestMessage.begin({ bookmark, txConfig, database, mode }), observer, @@ -75,6 +80,7 @@ export default class BoltProtocol extends BoltProtocolV3 { bookmark, txConfig, database, + impersonatedUser, mode, beforeKeys, afterKeys, @@ -101,6 +107,9 @@ export default class BoltProtocol extends BoltProtocolV3 { afterComplete }) + // passing impersonated user on this protocol version throws an error + assertImpersonatedUserIsEmpty(impersonatedUser, this._onProtocolError, observer) + const flushRun = reactive this.write( RequestMessage.runWithMetadata(query, parameters, { diff --git a/packages/bolt-connection/src/bolt/bolt-protocol-v4x4.js b/packages/bolt-connection/src/bolt/bolt-protocol-v4x4.js new file mode 100644 index 000000000..860567ffb --- /dev/null +++ b/packages/bolt-connection/src/bolt/bolt-protocol-v4x4.js @@ -0,0 +1,153 @@ +/** + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import BoltProtocolV43 from './bolt-protocol-v4x3' + +import { internal } from 'neo4j-driver-core' +import RequestMessage, { ALL } from './request-message' +import { RouteObserver, ResultStreamObserver } from './stream-observers' + +const { + constants: { BOLT_PROTOCOL_V4_4 }, + bookmark: { Bookmark }, +} = internal + +export default class BoltProtocol extends BoltProtocolV43 { + get version() { + return BOLT_PROTOCOL_V4_4 + } + + /** + * Request routing information + * + * @param {Object} param - + * @param {object} param.routingContext The routing context used to define the routing table. + * Multi-datacenter deployments is one of its use cases + * @param {string} param.databaseName The database name + * @param {Bookmark} params.sessionContext.bookmark The bookmark used for request the routing table + * @param {function(err: Error)} param.onError + * @param {function(RawRoutingTable)} param.onCompleted + * @returns {RouteObserver} the route observer + */ + requestRoutingInformation ({ + routingContext = {}, + databaseName = null, + impersonatedUser = null, + sessionContext = {}, + onError, + onCompleted + }) { + const observer = new RouteObserver({ + onProtocolError: this._onProtocolError, + onError, + onCompleted + }) + const bookmark = sessionContext.bookmark || Bookmark.empty() + this.write( + RequestMessage.routeV4x4(routingContext, bookmark.values(), { databaseName, impersonatedUser }), + observer, + true + ) + + return observer + } + + run ( + query, + parameters, + { + bookmark, + txConfig, + database, + mode, + impersonatedUser, + beforeKeys, + afterKeys, + beforeError, + afterError, + beforeComplete, + afterComplete, + flush = true, + reactive = false, + fetchSize = ALL + } = {} + ) { + const observer = new ResultStreamObserver({ + server: this._server, + reactive: reactive, + fetchSize: fetchSize, + moreFunction: this._requestMore.bind(this), + discardFunction: this._requestDiscard.bind(this), + beforeKeys, + afterKeys, + beforeError, + afterError, + beforeComplete, + afterComplete + }) + + const flushRun = reactive + this.write( + RequestMessage.runWithMetadata(query, parameters, { + bookmark, + txConfig, + database, + mode, + impersonatedUser + }), + observer, + flushRun && flush + ) + + if (!reactive) { + this.write(RequestMessage.pull({ n: fetchSize }), observer, flush) + } + + return observer + } + + beginTransaction ({ + bookmark, + txConfig, + database, + mode, + impersonatedUser, + beforeError, + afterError, + beforeComplete, + afterComplete + } = {}) { + const observer = new ResultStreamObserver({ + server: this._server, + beforeError, + afterError, + beforeComplete, + afterComplete + }) + observer.prepareToHandleSingleResponse() + + this.write( + RequestMessage.begin({ bookmark, txConfig, database, mode, impersonatedUser }), + observer, + true + ) + + return observer + } + +} diff --git a/packages/bolt-connection/src/bolt/create.js b/packages/bolt-connection/src/bolt/create.js index 0f8721b25..a8ab0d0d6 100644 --- a/packages/bolt-connection/src/bolt/create.js +++ b/packages/bolt-connection/src/bolt/create.js @@ -25,6 +25,7 @@ import BoltProtocolV4x0 from './bolt-protocol-v4x0' import BoltProtocolV4x1 from './bolt-protocol-v4x1' import BoltProtocolV4x2 from './bolt-protocol-v4x2' import BoltProtocolV4x3 from './bolt-protocol-v4x3' +import BoltProtocolV4x4 from './bolt-protocol-v4x4' import { Chunker, Dechunker } from '../channel' import ResponseHandler from './response-handler' @@ -164,6 +165,16 @@ function createProtocol ( onProtocolError, serversideRouting ) + case 4.4: + return new BoltProtocolV4x4( + server, + chunker, + packingConfig, + createResponseHandler, + log, + onProtocolError, + serversideRouting + ) default: throw newError('Unknown Bolt protocol version: ' + version) } diff --git a/packages/bolt-connection/src/bolt/handshake.js b/packages/bolt-connection/src/bolt/handshake.js index b875a28e3..bbed15022 100644 --- a/packages/bolt-connection/src/bolt/handshake.js +++ b/packages/bolt-connection/src/bolt/handshake.js @@ -76,7 +76,7 @@ function parseNegotiatedResponse (buffer) { */ function newHandshakeBuffer () { return createHandshakeMessage([ - [version(4, 3), version(4, 2)], + [version(4, 4), version(4, 2)], version(4, 1), version(4, 0), version(3, 0) diff --git a/packages/bolt-connection/src/bolt/request-message.js b/packages/bolt-connection/src/bolt/request-message.js index 39919fcc8..11b29852f 100644 --- a/packages/bolt-connection/src/bolt/request-message.js +++ b/packages/bolt-connection/src/bolt/request-message.js @@ -124,10 +124,11 @@ export default class RequestMessage { * @param {TxConfig} txConfig the configuration. * @param {string} database the database name. * @param {string} mode the access mode. + * @param {string} impersonatedUser the impersonated user. * @return {RequestMessage} new BEGIN message. */ - static begin ({ bookmark, txConfig, database, mode } = {}) { - const metadata = buildTxMetadata(bookmark, txConfig, database, mode) + static begin ({ bookmark, txConfig, database, mode, impersonatedUser } = {}) { + const metadata = buildTxMetadata(bookmark, txConfig, database, mode, impersonatedUser) return new RequestMessage( BEGIN, [metadata], @@ -159,14 +160,15 @@ export default class RequestMessage { * @param {TxConfig} txConfig the configuration. * @param {string} database the database name. * @param {string} mode the access mode. + * @param {string} impersonatedUser the impersonated user. * @return {RequestMessage} new RUN message with additional metadata. */ static runWithMetadata ( query, parameters, - { bookmark, txConfig, database, mode } = {} + { bookmark, txConfig, database, mode, impersonatedUser } = {} ) { - const metadata = buildTxMetadata(bookmark, txConfig, database, mode) + const metadata = buildTxMetadata(bookmark, txConfig, database, mode, impersonatedUser) return new RequestMessage( RUN, [query, parameters, metadata], @@ -237,6 +239,37 @@ export default class RequestMessage { )} ${databaseName}` ) } + + /** + * Generate the ROUTE message, this message is used to fetch the routing table from the server + * + * @param {object} routingContext The routing context used to define the routing table. Multi-datacenter deployments is one of its use cases + * @param {string[]} bookmarks The list of the bookmark should be used + * @param {object} databaseContext The context inforamtion of the database to get the routing table for. + * @param {string} databaseContext.databaseName The name of the database to get the routing table. + * @param {string} databaseContext.impersonatedUser The name of the user to impersonation when getting the routing table. + * @return {RequestMessage} the ROUTE message. + */ + static routeV4x4 (routingContext = {}, bookmarks = [], databaseContext = {}) { + const dbContext = {} + + if ( databaseContext.databaseName ) { + dbContext.db = databaseContext.databaseName + } + + if ( databaseContext.impersonatedUser ) { + dbContext.imp_user = databaseContext.impersonatedUser + } + + return new RequestMessage( + ROUTE, + [routingContext, bookmarks, dbContext], + () => + `ROUTE ${json.stringify(routingContext)} ${json.stringify( + bookmarks + )} ${json.stringify(dbContext)}` + ) + } } /** @@ -245,9 +278,10 @@ export default class RequestMessage { * @param {TxConfig} txConfig the configuration. * @param {string} database the database name. * @param {string} mode the access mode. + * @param {string} impersonatedUser the impersonated user mode. * @return {Object} a metadata object. */ -function buildTxMetadata (bookmark, txConfig, database, mode) { +function buildTxMetadata (bookmark, txConfig, database, mode, impersonatedUser) { const metadata = {} if (!bookmark.isEmpty()) { metadata.bookmarks = bookmark.values() @@ -261,6 +295,9 @@ function buildTxMetadata (bookmark, txConfig, database, mode) { if (database) { metadata.db = assertString(database, 'database') } + if (impersonatedUser) { + metadata.imp_user = assertString(impersonatedUser, 'impersonatedUser') + } if (mode === ACCESS_MODE_READ) { metadata.mode = READ_MODE } diff --git a/packages/bolt-connection/src/bolt/routing-table-raw.js b/packages/bolt-connection/src/bolt/routing-table-raw.js index e8f77e675..d8d50ae80 100644 --- a/packages/bolt-connection/src/bolt/routing-table-raw.js +++ b/packages/bolt-connection/src/bolt/routing-table-raw.js @@ -64,6 +64,15 @@ export default class RawRoutingTable { throw new Error('Not implemented') } + /** + * Get raw db + * + * @returns {string?} The database name + */ + get db () { + throw new Error('Not implemented') + } + /** * * @typedef {Object} ServerRole @@ -103,6 +112,10 @@ class ResponseRawRoutingTable extends RawRoutingTable { return this._response.rt.servers } + get db () { + return this._response.rt.db + } + get isNull () { return this._response === null } @@ -134,6 +147,10 @@ class RecordRawRoutingTable extends RawRoutingTable { return this._record.get('servers') } + get db () { + return this._record.has('db') ? this._record.get('db') : null + } + get isNull () { return this._record === null } diff --git a/packages/bolt-connection/src/connection-provider/connection-provider-direct.js b/packages/bolt-connection/src/connection-provider/connection-provider-direct.js index 76f0f408b..1fa74947a 100644 --- a/packages/bolt-connection/src/connection-provider/connection-provider-direct.js +++ b/packages/bolt-connection/src/connection-provider/connection-provider-direct.js @@ -26,7 +26,7 @@ import { import { internal, error } from 'neo4j-driver-core' const { - constants: { BOLT_PROTOCOL_V4_0, BOLT_PROTOCOL_V3 } + constants: { BOLT_PROTOCOL_V3, BOLT_PROTOCOL_V4_0, BOLT_PROTOCOL_V4_4 } } = internal const { SERVICE_UNAVAILABLE, newError } = error @@ -97,4 +97,10 @@ export default class DirectConnectionProvider extends PooledConnectionProvider { version => version >= BOLT_PROTOCOL_V3 ) } + + async supportsUserImpersonation () { + return await this._hasProtocolVersion( + version => version >= BOLT_PROTOCOL_V4_4 + ) + } } diff --git a/packages/bolt-connection/src/connection-provider/connection-provider-routing.js b/packages/bolt-connection/src/connection-provider/connection-provider-routing.js index fcd6b25ba..7646abd6c 100644 --- a/packages/bolt-connection/src/connection-provider/connection-provider-routing.js +++ b/packages/bolt-connection/src/connection-provider/connection-provider-routing.js @@ -36,7 +36,8 @@ const { ACCESS_MODE_READ: READ, ACCESS_MODE_WRITE: WRITE, BOLT_PROTOCOL_V3, - BOLT_PROTOCOL_V4_0 + BOLT_PROTOCOL_V4_0, + BOLT_PROTOCOL_V4_4 } } = internal @@ -123,22 +124,30 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider * See {@link ConnectionProvider} for more information about this method and * its arguments. */ - async acquireConnection ({ accessMode, database, bookmarks } = {}) { + async acquireConnection ({ accessMode, database, bookmarks, impersonatedUser, onDatabaseNameResolved } = {}) { let name let address + const context = { database: database || DEFAULT_DB_NAME } const databaseSpecificErrorHandler = new ConnectionErrorHandler( SESSION_EXPIRED, - (error, address) => this._handleUnavailability(error, address, database), - (error, address) => this._handleWriteFailure(error, address, database), + (error, address) => this._handleUnavailability(error, address, context.database), + (error, address) => this._handleWriteFailure(error, address, context.database), (error, address) => - this._handleAuthorizationExpired(error, address, database) + this._handleAuthorizationExpired(error, address, context.database) ) const routingTable = await this._freshRoutingTable({ accessMode, - database: database || DEFAULT_DB_NAME, - bookmark: bookmarks + database: context.database, + bookmark: bookmarks, + impersonatedUser, + onDatabaseNameResolved: (databaseName) => { + context.database = context.database || databaseName + if (onDatabaseNameResolved) { + onDatabaseNameResolved(databaseName) + } + } }) // select a target server based on specified access mode @@ -224,6 +233,12 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider ) } + async supportsUserImpersonation () { + return await this._hasProtocolVersion( + version => version >= BOLT_PROTOCOL_V4_4 + ) + } + forget (address, database) { this._routingTableRegistry.apply(database, { applyWhenExists: routingTable => routingTable.forget(address) @@ -235,7 +250,7 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider } forgetWriter (address, database) { - this._routingTableRegistry.apply(database, { + this._routingTableRegistry.apply( database, { applyWhenExists: routingTable => routingTable.forgetWriter(address) }) } @@ -244,7 +259,7 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider return this._connectionPool.acquire(address) } - _freshRoutingTable ({ accessMode, database, bookmark } = {}) { + _freshRoutingTable ({ accessMode, database, bookmark, impersonatedUser, onDatabaseNameResolved } = {}) { const currentRoutingTable = this._routingTableRegistry.get( database, () => new RoutingTable({ database }) @@ -256,30 +271,36 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider this._log.info( `Routing table is stale for database: "${database}" and access mode: "${accessMode}": ${currentRoutingTable}` ) - return this._refreshRoutingTable(currentRoutingTable, bookmark) + return this._refreshRoutingTable(currentRoutingTable, bookmark, impersonatedUser, onDatabaseNameResolved) } - _refreshRoutingTable (currentRoutingTable, bookmark) { + _refreshRoutingTable (currentRoutingTable, bookmark, impersonatedUser, onDatabaseNameResolved) { const knownRouters = currentRoutingTable.routers if (this._useSeedRouter) { return this._fetchRoutingTableFromSeedRouterFallbackToKnownRouters( knownRouters, currentRoutingTable, - bookmark + bookmark, + impersonatedUser, + onDatabaseNameResolved ) } return this._fetchRoutingTableFromKnownRoutersFallbackToSeedRouter( knownRouters, currentRoutingTable, - bookmark + bookmark, + impersonatedUser, + onDatabaseNameResolved ) } async _fetchRoutingTableFromSeedRouterFallbackToKnownRouters ( knownRouters, currentRoutingTable, - bookmark + bookmark, + impersonatedUser, + onDatabaseNameResolved ) { // we start with seed router, no routers were probed before const seenRouters = [] @@ -287,7 +308,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider seenRouters, this._seedRouter, currentRoutingTable, - bookmark + bookmark, + impersonatedUser ) if (newRoutingTable) { @@ -297,25 +319,30 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider newRoutingTable = await this._fetchRoutingTableUsingKnownRouters( knownRouters, currentRoutingTable, - bookmark + bookmark, + impersonatedUser ) } return await this._applyRoutingTableIfPossible( currentRoutingTable, - newRoutingTable + newRoutingTable, + onDatabaseNameResolved ) } async _fetchRoutingTableFromKnownRoutersFallbackToSeedRouter ( knownRouters, currentRoutingTable, - bookmark + bookmark, + impersonatedUser, + onDatabaseNameResolved ) { let newRoutingTable = await this._fetchRoutingTableUsingKnownRouters( knownRouters, currentRoutingTable, - bookmark + bookmark, + impersonatedUser ) if (!newRoutingTable) { @@ -324,25 +351,29 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider knownRouters, this._seedRouter, currentRoutingTable, - bookmark + bookmark, + impersonatedUser ) } return await this._applyRoutingTableIfPossible( currentRoutingTable, - newRoutingTable + newRoutingTable, + onDatabaseNameResolved ) } async _fetchRoutingTableUsingKnownRouters ( knownRouters, currentRoutingTable, - bookmark + bookmark, + impersonatedUser ) { const newRoutingTable = await this._fetchRoutingTable( knownRouters, currentRoutingTable, - bookmark + bookmark, + impersonatedUser ) if (newRoutingTable) { @@ -366,7 +397,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider seenRouters, seedRouter, routingTable, - bookmark + bookmark, + impersonatedUser ) { const resolvedAddresses = await this._resolveSeedRouter(seedRouter) @@ -375,7 +407,7 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider address => seenRouters.indexOf(address) < 0 ) - return await this._fetchRoutingTable(newAddresses, routingTable, bookmark) + return await this._fetchRoutingTable(newAddresses, routingTable, bookmark, impersonatedUser) } async _resolveSeedRouter (seedRouter) { @@ -387,7 +419,7 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider return [].concat.apply([], dnsResolvedAddresses) } - _fetchRoutingTable (routerAddresses, routingTable, bookmark) { + _fetchRoutingTable (routerAddresses, routingTable, bookmark, impersonatedUser) { return routerAddresses.reduce( async (refreshedTablePromise, currentRouter, currentIndex) => { const newRoutingTable = await refreshedTablePromise @@ -409,14 +441,16 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider // try next router const session = await this._createSessionForRediscovery( currentRouter, - bookmark + bookmark, + impersonatedUser ) if (session) { try { return await this._rediscovery.lookupRoutingTableOnRouter( session, routingTable.database, - currentRouter + currentRouter, + impersonatedUser ) } catch (error) { if (error && error.code === DATABASE_NOT_FOUND_ERROR_CODE) { @@ -440,7 +474,7 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider ) } - async _createSessionForRediscovery (routerAddress, bookmark) { + async _createSessionForRediscovery (routerAddress, bookmark, impersonatedUser) { try { const connection = await this._connectionPool.acquire(routerAddress) const connectionProvider = new SingleConnectionProvider(connection) @@ -458,7 +492,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider mode: READ, database: SYSTEM_DB_NAME, bookmark, - connectionProvider + connectionProvider, + impersonatedUser }) } catch (error) { // unable to acquire connection towards the given router @@ -471,7 +506,7 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider } } - async _applyRoutingTableIfPossible (currentRoutingTable, newRoutingTable) { + async _applyRoutingTableIfPossible (currentRoutingTable, newRoutingTable, onDatabaseNameResolved) { if (!newRoutingTable) { // none of routing servers returned valid routing table, throw exception throw newError( @@ -486,19 +521,21 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider this._useSeedRouter = true } - await this._updateRoutingTable(newRoutingTable) + await this._updateRoutingTable(newRoutingTable, onDatabaseNameResolved) return newRoutingTable } - async _updateRoutingTable (newRoutingTable) { + async _updateRoutingTable (newRoutingTable, onDatabaseNameResolved) { // close old connections to servers not present in the new routing table await this._connectionPool.keepAll(newRoutingTable.allServers()) this._routingTableRegistry.removeExpired() this._routingTableRegistry.register( - newRoutingTable.database, newRoutingTable ) + + onDatabaseNameResolved(newRoutingTable.database) + this._log.info(`Updated routing table ${newRoutingTable}`) } @@ -526,12 +563,11 @@ class RoutingTableRegistry { /** * Put a routing table in the registry * - * @param {string} database The database name * @param {RoutingTable} table The routing table * @returns {RoutingTableRegistry} this */ - register (database, table) { - this._tables.set(database, table) + register (table) { + this._tables.set(table.database, table) return this } @@ -558,12 +594,14 @@ class RoutingTableRegistry { /** * Retrieves a routing table from a given database name + * + * @param {string|impersonatedUser} impersonatedUser The impersonated User * @param {string} database The database name * @param {function()|RoutingTable} defaultSupplier The routing table supplier, if it's not a function or not exists, it will return itself as default value * @returns {RoutingTable} The routing table for the respective database */ get (database, defaultSupplier) { - if (this._tables.has(database)) { + if (this._tables.has(database) ) { return this._tables.get(database) } return typeof defaultSupplier === 'function' diff --git a/packages/bolt-connection/src/rediscovery/rediscovery.js b/packages/bolt-connection/src/rediscovery/rediscovery.js index e9a5c5bc6..cdd32cbb1 100644 --- a/packages/bolt-connection/src/rediscovery/rediscovery.js +++ b/packages/bolt-connection/src/rediscovery/rediscovery.js @@ -38,15 +38,17 @@ export default class Rediscovery { * @param {Session} session the session to use. * @param {string} database the database for which to lookup routing table. * @param {ServerAddress} routerAddress the URL of the router. + * @param {string} impersonatedUser The impersonated user * @return {Promise} promise resolved with new routing table or null when connection error happened. */ - lookupRoutingTableOnRouter (session, database, routerAddress) { + lookupRoutingTableOnRouter (session, database, routerAddress, impersonatedUser) { return session._acquireConnection(connection => { return this._requestRawRoutingTable( connection, session, database, - routerAddress + routerAddress, + impersonatedUser ).then(rawRoutingTable => { if (rawRoutingTable.isNull) { return null @@ -60,11 +62,12 @@ export default class Rediscovery { }) } - _requestRawRoutingTable (connection, session, database, routerAddress) { + _requestRawRoutingTable (connection, session, database, routerAddress, impersonatedUser) { return new Promise((resolve, reject) => { connection.protocol().requestRoutingInformation({ routingContext: this._routingContext, databaseName: database, + impersonatedUser, sessionContext: { bookmark: session._lastBookmark, mode: session._mode, diff --git a/packages/bolt-connection/src/rediscovery/routing-table.js b/packages/bolt-connection/src/rediscovery/routing-table.js index c6918eb87..0565a757c 100644 --- a/packages/bolt-connection/src/rediscovery/routing-table.js +++ b/packages/bolt-connection/src/rediscovery/routing-table.js @@ -45,7 +45,7 @@ export default class RoutingTable { expirationTime, ttl } = {}) { - this.database = database + this.database = database || null this.databaseName = database || 'default database' this.routers = routers || [] this.readers = readers || [] @@ -137,7 +137,7 @@ function removeFromArray (array, element) { /** * Create a valid routing table from a raw object * - * @param {string} database the database name. It is used for logging purposes + * @param {string} db the database name. It is used for logging purposes * @param {ServerAddress} routerAddress The router address, it is used for loggin purposes * @param {RawRoutingTable} rawRoutingTable Method used to get the raw routing table to be processed * @param {RoutingTable} The valid Routing Table @@ -158,7 +158,7 @@ export function createValidRoutingTable ( assertNonEmpty(readers, 'readers', routerAddress) return new RoutingTable({ - database, + database: database || rawRoutingTable.db, routers, readers, writers, diff --git a/packages/bolt-connection/test/bolt/bolt-protocol-v1.test.js b/packages/bolt-connection/test/bolt/bolt-protocol-v1.test.js index 9c4205f57..70702c9c9 100644 --- a/packages/bolt-connection/test/bolt/bolt-protocol-v1.test.js +++ b/packages/bolt-connection/test/bolt/bolt-protocol-v1.test.js @@ -277,6 +277,47 @@ describe('#unit BoltProtocolV1', () => { }) }) + describe('Bolt v4.4', () => { + /** + * @param {string} impersonatedUser The impersonated user. + * @param {function(protocol: BoltProtocolV1)} fn + */ + function verifyImpersonationNotSupportedErrror (impersonatedUser, fn) { + const recorder = new utils.MessageRecordingConnection() + const protocol = new BoltProtocolV1(recorder, null, false) + + expect(() => fn(protocol)).toThrowError( + 'Driver is connected to the database that does not support user impersonation. ' + + 'Please upgrade to neo4j 4.4.0 or later in order to use this functionality. ' + + `Trying to impersonate ${impersonatedUser}.` + ) + } + + describe('beginTransaction', () => { + function verifyBeginTransaction(impersonatedUser) { + verifyImpersonationNotSupportedErrror( + impersonatedUser, + protocol => protocol.beginTransaction({ impersonatedUser })) + } + + it('should throw error when impersonatedUser is set', () => { + verifyBeginTransaction('test') + }) + }) + + describe('run', () => { + function verifyRun (impersonatedUser) { + verifyImpersonationNotSupportedErrror( + impersonatedUser, + protocol => protocol.run('query', {}, { impersonatedUser })) + } + + it('should throw error when impersonatedUser is set', () => { + verifyRun('test') + }) + }) + }) + describe('unpacker configuration', () => { test.each([ [false, false], diff --git a/packages/bolt-connection/test/bolt/bolt-protocol-v2.test.js b/packages/bolt-connection/test/bolt/bolt-protocol-v2.test.js index c409f8026..c8d5596a2 100644 --- a/packages/bolt-connection/test/bolt/bolt-protocol-v2.test.js +++ b/packages/bolt-connection/test/bolt/bolt-protocol-v2.test.js @@ -51,4 +51,45 @@ describe('#unit BoltProtocolV2', () => { } ) }) + + describe('Bolt v4.4', () => { + /** + * @param {string} impersonatedUser The impersonated user. + * @param {function(protocol: BoltProtocolV2)} fn + */ + function verifyImpersonationNotSupportedErrror (impersonatedUser, fn) { + const recorder = new utils.MessageRecordingConnection() + const protocol = new BoltProtocolV2(recorder, null, false) + + expect(() => fn(protocol)).toThrowError( + 'Driver is connected to the database that does not support user impersonation. ' + + 'Please upgrade to neo4j 4.4.0 or later in order to use this functionality. ' + + `Trying to impersonate ${impersonatedUser}.` + ) + } + + describe('beginTransaction', () => { + function verifyBeginTransaction(impersonatedUser) { + verifyImpersonationNotSupportedErrror( + impersonatedUser, + protocol => protocol.beginTransaction({ impersonatedUser })) + } + + it('should throw error when impersonatedUser is set', () => { + verifyBeginTransaction('test') + }) + }) + + describe('run', () => { + function verifyRun (impersonatedUser) { + verifyImpersonationNotSupportedErrror( + impersonatedUser, + protocol => protocol.run('query', {}, { impersonatedUser })) + } + + it('should throw error when impersonatedUser is set', () => { + verifyRun('test') + }) + }) + }) }) diff --git a/packages/bolt-connection/test/bolt/bolt-protocol-v3.test.js b/packages/bolt-connection/test/bolt/bolt-protocol-v3.test.js index 543e06a52..513886451 100644 --- a/packages/bolt-connection/test/bolt/bolt-protocol-v3.test.js +++ b/packages/bolt-connection/test/bolt/bolt-protocol-v3.test.js @@ -234,6 +234,47 @@ describe('#unit BoltProtocolV3', () => { }) }) + describe('Bolt v4.4', () => { + /** + * @param {string} impersonatedUser The impersonated user. + * @param {function(protocol: BoltProtocolV3)} fn + */ + function verifyImpersonationNotSupportedErrror (impersonatedUser, fn) { + const recorder = new utils.MessageRecordingConnection() + const protocol = new BoltProtocolV3(recorder, null, false) + + expect(() => fn(protocol)).toThrowError( + 'Driver is connected to the database that does not support user impersonation. ' + + 'Please upgrade to neo4j 4.4.0 or later in order to use this functionality. ' + + `Trying to impersonate ${impersonatedUser}.` + ) + } + + describe('beginTransaction', () => { + function verifyBeginTransaction(impersonatedUser) { + verifyImpersonationNotSupportedErrror( + impersonatedUser, + protocol => protocol.beginTransaction({ impersonatedUser })) + } + + it('should throw error when impersonatedUser is set', () => { + verifyBeginTransaction('test') + }) + }) + + describe('run', () => { + function verifyRun (impersonatedUser) { + verifyImpersonationNotSupportedErrror( + impersonatedUser, + protocol => protocol.run('query', {}, { impersonatedUser })) + } + + it('should throw error when impersonatedUser is set', () => { + verifyRun('test') + }) + }) + }) + describe('unpacker configuration', () => { test.each([ [false, false], diff --git a/packages/bolt-connection/test/bolt/bolt-protocol-v4x0.test.js b/packages/bolt-connection/test/bolt/bolt-protocol-v4x0.test.js index b34efe86c..662f220f4 100644 --- a/packages/bolt-connection/test/bolt/bolt-protocol-v4x0.test.js +++ b/packages/bolt-connection/test/bolt/bolt-protocol-v4x0.test.js @@ -152,6 +152,47 @@ describe('#unit BoltProtocolV4x0', () => { { ...sessionContext, txConfig: TxConfig.empty() } ]) }) + + describe('Bolt v4.4', () => { + /** + * @param {string} impersonatedUser The impersonated user. + * @param {function(protocol: BoltProtocolV4x0)} fn + */ + function verifyImpersonationNotSupportedErrror (impersonatedUser, fn) { + const recorder = new utils.MessageRecordingConnection() + const protocol = new BoltProtocolV4x0(recorder, null, false) + + expect(() => fn(protocol)).toThrowError( + 'Driver is connected to the database that does not support user impersonation. ' + + 'Please upgrade to neo4j 4.4.0 or later in order to use this functionality. ' + + `Trying to impersonate ${impersonatedUser}.` + ) + } + + describe('beginTransaction', () => { + function verifyBeginTransaction(impersonatedUser) { + verifyImpersonationNotSupportedErrror( + impersonatedUser, + protocol => protocol.beginTransaction({ impersonatedUser })) + } + + it('should throw error when impersonatedUser is set', () => { + verifyBeginTransaction('test') + }) + }) + + describe('run', () => { + function verifyRun (impersonatedUser) { + verifyImpersonationNotSupportedErrror( + impersonatedUser, + protocol => protocol.run('query', {}, { impersonatedUser })) + } + + it('should throw error when impersonatedUser is set', () => { + verifyRun('test') + }) + }) + }) describe('unpacker configuration', () => { test.each([ diff --git a/packages/bolt-connection/test/bolt/bolt-protocol-v4x1.test.js b/packages/bolt-connection/test/bolt/bolt-protocol-v4x1.test.js index e4d56bb4b..c70ed33c4 100644 --- a/packages/bolt-connection/test/bolt/bolt-protocol-v4x1.test.js +++ b/packages/bolt-connection/test/bolt/bolt-protocol-v4x1.test.js @@ -18,8 +18,50 @@ */ import BoltProtocolV4x1 from '../../src/bolt/bolt-protocol-v4x1' +import utils from '../test-utils' describe('#unit BoltProtocolV4x1', () => { + describe('Bolt v4.4', () => { + /** + * @param {string} impersonatedUser The impersonated user. + * @param {function(protocol: BoltProtocolV4x1)} fn + */ + function verifyImpersonationNotSupportedErrror (impersonatedUser, fn) { + const recorder = new utils.MessageRecordingConnection() + const protocol = new BoltProtocolV4x1(recorder, null, false) + + expect(() => fn(protocol)).toThrowError( + 'Driver is connected to the database that does not support user impersonation. ' + + 'Please upgrade to neo4j 4.4.0 or later in order to use this functionality. ' + + `Trying to impersonate ${impersonatedUser}.` + ) + } + + describe('beginTransaction', () => { + function verifyBeginTransaction(impersonatedUser) { + verifyImpersonationNotSupportedErrror( + impersonatedUser, + protocol => protocol.beginTransaction({ impersonatedUser })) + } + + it('should throw error when impersonatedUser is set', () => { + verifyBeginTransaction('test') + }) + }) + + describe('run', () => { + function verifyRun (impersonatedUser) { + verifyImpersonationNotSupportedErrror( + impersonatedUser, + protocol => protocol.run('query', {}, { impersonatedUser })) + } + + it('should throw error when impersonatedUser is set', () => { + verifyRun('test') + }) + }) + }) + describe('unpacker configuration', () => { test.each([ [false, false], diff --git a/packages/bolt-connection/test/bolt/bolt-protocol-v4x2.test.js b/packages/bolt-connection/test/bolt/bolt-protocol-v4x2.test.js index 48d829c13..b0cbcdde8 100644 --- a/packages/bolt-connection/test/bolt/bolt-protocol-v4x2.test.js +++ b/packages/bolt-connection/test/bolt/bolt-protocol-v4x2.test.js @@ -18,8 +18,49 @@ */ import BoltProtocolV4x2 from '../../src/bolt/bolt-protocol-v4x2' +import utils from '../test-utils' describe('#unit BoltProtocolV4x2', () => { + describe('Bolt v4.4', () => { + /** + * @param {string} impersonatedUser The impersonated user. + * @param {function(protocol: BoltProtocolV4x2)} fn + */ + function verifyImpersonationNotSupportedErrror (impersonatedUser, fn) { + const recorder = new utils.MessageRecordingConnection() + const protocol = new BoltProtocolV4x2(recorder, null, false) + + expect(() => fn(protocol)).toThrowError( + 'Driver is connected to the database that does not support user impersonation. ' + + 'Please upgrade to neo4j 4.4.0 or later in order to use this functionality. ' + + `Trying to impersonate ${impersonatedUser}.` + ) + } + + describe('beginTransaction', () => { + function verifyBeginTransaction(impersonatedUser) { + verifyImpersonationNotSupportedErrror( + impersonatedUser, + protocol => protocol.beginTransaction({ impersonatedUser })) + } + + it('should throw error when impersonatedUser is set', () => { + verifyBeginTransaction('test') + }) + }) + + describe('run', () => { + function verifyRun (impersonatedUser) { + verifyImpersonationNotSupportedErrror( + impersonatedUser, + protocol => protocol.run('query', {}, { impersonatedUser })) + } + + it('should throw error when impersonatedUser is set', () => { + verifyRun('test') + }) + }) + }) describe('unpacker configuration', () => { test.each([ [false, false], diff --git a/packages/bolt-connection/test/bolt/bolt-protocol-v4x3.test.js b/packages/bolt-connection/test/bolt/bolt-protocol-v4x3.test.js index 8fd70c254..41c5f4de1 100644 --- a/packages/bolt-connection/test/bolt/bolt-protocol-v4x3.test.js +++ b/packages/bolt-connection/test/bolt/bolt-protocol-v4x3.test.js @@ -238,6 +238,47 @@ describe('#unit BoltProtocolV4x3', () => { expect(protocol.flushes).toEqual([true]) }) + describe('Bolt v4.4', () => { + /** + * @param {string} impersonatedUser The impersonated user. + * @param {function(protocol: BoltProtocolV4x3)} fn + */ + function verifyImpersonationNotSupportedErrror (impersonatedUser, fn) { + const recorder = new utils.MessageRecordingConnection() + const protocol = new BoltProtocolV4x3(recorder, null, false) + + expect(() => fn(protocol)).toThrowError( + 'Driver is connected to the database that does not support user impersonation. ' + + 'Please upgrade to neo4j 4.4.0 or later in order to use this functionality. ' + + `Trying to impersonate ${impersonatedUser}.` + ) + } + + describe('beginTransaction', () => { + function verifyBeginTransaction(impersonatedUser) { + verifyImpersonationNotSupportedErrror( + impersonatedUser, + protocol => protocol.beginTransaction({ impersonatedUser })) + } + + it('should throw error when impersonatedUser is set', () => { + verifyBeginTransaction('test') + }) + }) + + describe('run', () => { + function verifyRun (impersonatedUser) { + verifyImpersonationNotSupportedErrror( + impersonatedUser, + protocol => protocol.run('query', {}, { impersonatedUser })) + } + + it('should throw error when impersonatedUser is set', () => { + verifyRun('test') + }) + }) + }) + describe('unpacker configuration', () => { test.each([ [false, false], diff --git a/packages/bolt-connection/test/bolt/bolt-protocol-v4x4.test.js b/packages/bolt-connection/test/bolt/bolt-protocol-v4x4.test.js new file mode 100644 index 000000000..2314922ed --- /dev/null +++ b/packages/bolt-connection/test/bolt/bolt-protocol-v4x4.test.js @@ -0,0 +1,335 @@ +/** + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import BoltProtocolV4x4 from '../../src/bolt/bolt-protocol-v4x4' +import RequestMessage from '../../src/bolt/request-message' +import utils from '../test-utils' +import { RouteObserver } from '../../src/bolt/stream-observers' +import { internal } from 'neo4j-driver-core' + +const WRITE = 'WRITE' + +const { + txConfig: { TxConfig }, + bookmark: { Bookmark } +} = internal + +describe('#unit BoltProtocolV4x4', () => { + beforeEach(() => { + expect.extend(utils.matchers) + }) + + it('should request routing information', () => { + const recorder = new utils.MessageRecordingConnection() + const protocol = new BoltProtocolV4x4(recorder, null, false) + utils.spyProtocolWrite(protocol) + const routingContext = { someContextParam: 'value' } + const databaseName = 'name' + + const observer = protocol.requestRoutingInformation({ + routingContext, + databaseName + }) + + protocol.verifyMessageCount(1) + expect(protocol.messages[0]).toBeMessage( + RequestMessage.routeV4x4(routingContext, [], { databaseName, impersonatedUser: null }) + ) + expect(protocol.observers).toEqual([observer]) + expect(observer).toEqual(expect.any(RouteObserver)) + expect(protocol.flushes).toEqual([true]) + }) + + it('should request routing information sending bookmarks', () => { + const recorder = new utils.MessageRecordingConnection() + const protocol = new BoltProtocolV4x4(recorder, null, false) + utils.spyProtocolWrite(protocol) + const routingContext = { someContextParam: 'value' } + const listOfBookmarks = ['a', 'b', 'c'] + const bookmark = new Bookmark(listOfBookmarks) + const databaseName = 'name' + + const observer = protocol.requestRoutingInformation({ + routingContext, + databaseName, + sessionContext: { bookmark } + }) + + protocol.verifyMessageCount(1) + expect(protocol.messages[0]).toBeMessage( + RequestMessage.routeV4x4(routingContext, listOfBookmarks, { databaseName, impersonatedUser: null}) + ) + expect(protocol.observers).toEqual([observer]) + expect(observer).toEqual(expect.any(RouteObserver)) + expect(protocol.flushes).toEqual([true]) + }) + + it('should run a query', () => { + const database = 'testdb' + const bookmark = new Bookmark([ + 'neo4j:bookmark:v1:tx1', + 'neo4j:bookmark:v1:tx2' + ]) + const txConfig = new TxConfig({ + timeout: 5000, + metadata: { x: 1, y: 'something' } + }) + const recorder = new utils.MessageRecordingConnection() + const protocol = new BoltProtocolV4x4(recorder, null, false) + utils.spyProtocolWrite(protocol) + + const query = 'RETURN $x, $y' + const parameters = { x: 'x', y: 'y' } + + const observer = protocol.run(query, parameters, { + bookmark, + txConfig, + database, + mode: WRITE + }) + + protocol.verifyMessageCount(2) + + expect(protocol.messages[0]).toBeMessage( + RequestMessage.runWithMetadata(query, parameters, { + bookmark, + txConfig, + database, + mode: WRITE + }) + ) + expect(protocol.messages[1]).toBeMessage(RequestMessage.pull()) + expect(protocol.observers).toEqual([observer, observer]) + expect(protocol.flushes).toEqual([false, true]) + }) + + it('should run a with impersonated user', () => { + const database = 'testdb' + const impersonatedUser = 'the impostor' + const bookmark = new Bookmark([ + 'neo4j:bookmark:v1:tx1', + 'neo4j:bookmark:v1:tx2' + ]) + const txConfig = new TxConfig({ + timeout: 5000, + metadata: { x: 1, y: 'something' } + }) + const recorder = new utils.MessageRecordingConnection() + const protocol = new BoltProtocolV4x4(recorder, null, false) + utils.spyProtocolWrite(protocol) + + const query = 'RETURN $x, $y' + const parameters = { x: 'x', y: 'y' } + + const observer = protocol.run(query, parameters, { + bookmark, + txConfig, + database, + mode: WRITE, + impersonatedUser + }) + + protocol.verifyMessageCount(2) + + expect(protocol.messages[0]).toBeMessage( + RequestMessage.runWithMetadata(query, parameters, { + bookmark, + txConfig, + database, + mode: WRITE, + impersonatedUser + }) + ) + expect(protocol.messages[1]).toBeMessage(RequestMessage.pull()) + expect(protocol.observers).toEqual([observer, observer]) + expect(protocol.flushes).toEqual([false, true]) + }) + + it('should begin a transaction', () => { + const database = 'testdb' + const bookmark = new Bookmark([ + 'neo4j:bookmark:v1:tx1', + 'neo4j:bookmark:v1:tx2' + ]) + const txConfig = new TxConfig({ + timeout: 5000, + metadata: { x: 1, y: 'something' } + }) + const recorder = new utils.MessageRecordingConnection() + const protocol = new BoltProtocolV4x4(recorder, null, false) + utils.spyProtocolWrite(protocol) + + const observer = protocol.beginTransaction({ + bookmark, + txConfig, + database, + mode: WRITE + }) + + protocol.verifyMessageCount(1) + expect(protocol.messages[0]).toBeMessage( + RequestMessage.begin({ bookmark, txConfig, database, mode: WRITE }) + ) + expect(protocol.observers).toEqual([observer]) + expect(protocol.flushes).toEqual([true]) + }) + + it('should begin a transaction with impersonated user', () => { + const database = 'testdb' + const impersonatedUser = 'the impostor' + const bookmark = new Bookmark([ + 'neo4j:bookmark:v1:tx1', + 'neo4j:bookmark:v1:tx2' + ]) + const txConfig = new TxConfig({ + timeout: 5000, + metadata: { x: 1, y: 'something' } + }) + const recorder = new utils.MessageRecordingConnection() + const protocol = new BoltProtocolV4x4(recorder, null, false) + utils.spyProtocolWrite(protocol) + + const observer = protocol.beginTransaction({ + bookmark, + txConfig, + database, + mode: WRITE, + impersonatedUser + }) + + protocol.verifyMessageCount(1) + expect(protocol.messages[0]).toBeMessage( + RequestMessage.begin({ bookmark, txConfig, database, mode: WRITE, impersonatedUser }) + ) + expect(protocol.observers).toEqual([observer]) + expect(protocol.flushes).toEqual([true]) + }) + + it('should return correct bolt version number', () => { + const protocol = new BoltProtocolV4x4(null, null, false) + + expect(protocol.version).toBe(4.4) + }) + + it('should update metadata', () => { + const metadata = { t_first: 1, t_last: 2, db_hits: 3, some_other_key: 4 } + const protocol = new BoltProtocolV4x4(null, null, false) + + const transformedMetadata = protocol.transformMetadata(metadata) + + expect(transformedMetadata).toEqual({ + result_available_after: 1, + result_consumed_after: 2, + db_hits: 3, + some_other_key: 4 + }) + }) + + it('should initialize connection', () => { + const recorder = new utils.MessageRecordingConnection() + const protocol = new BoltProtocolV4x4(recorder, null, false) + utils.spyProtocolWrite(protocol) + + const clientName = 'js-driver/1.2.3' + const authToken = { username: 'neo4j', password: 'secret' } + + const observer = protocol.initialize({ userAgent: clientName, authToken }) + + protocol.verifyMessageCount(1) + expect(protocol.messages[0]).toBeMessage( + RequestMessage.hello(clientName, authToken) + ) + expect(protocol.observers).toEqual([observer]) + expect(protocol.flushes).toEqual([true]) + }) + + it('should begin a transaction', () => { + const bookmark = new Bookmark([ + 'neo4j:bookmark:v1:tx1', + 'neo4j:bookmark:v1:tx2' + ]) + const txConfig = new TxConfig({ + timeout: 5000, + metadata: { x: 1, y: 'something' } + }) + const recorder = new utils.MessageRecordingConnection() + const protocol = new BoltProtocolV4x4(recorder, null, false) + utils.spyProtocolWrite(protocol) + + const observer = protocol.beginTransaction({ + bookmark, + txConfig, + mode: WRITE + }) + + protocol.verifyMessageCount(1) + expect(protocol.messages[0]).toBeMessage( + RequestMessage.begin({ bookmark, txConfig, mode: WRITE }) + ) + expect(protocol.observers).toEqual([observer]) + expect(protocol.flushes).toEqual([true]) + }) + + it('should commit', () => { + const recorder = new utils.MessageRecordingConnection() + const protocol = new BoltProtocolV4x4(recorder, null, false) + utils.spyProtocolWrite(protocol) + + const observer = protocol.commitTransaction() + + protocol.verifyMessageCount(1) + expect(protocol.messages[0]).toBeMessage(RequestMessage.commit()) + expect(protocol.observers).toEqual([observer]) + expect(protocol.flushes).toEqual([true]) + }) + + it('should rollback', () => { + const recorder = new utils.MessageRecordingConnection() + const protocol = new BoltProtocolV4x4(recorder, null, false) + utils.spyProtocolWrite(protocol) + + const observer = protocol.rollbackTransaction() + + protocol.verifyMessageCount(1) + expect(protocol.messages[0]).toBeMessage(RequestMessage.rollback()) + expect(protocol.observers).toEqual([observer]) + expect(protocol.flushes).toEqual([true]) + }) + + describe('unpacker configuration', () => { + test.each([ + [false, false], + [false, true], + [true, false], + [true, true] + ])( + 'should create unpacker with disableLosslessIntegers=%p and useBigInt=%p', + (disableLosslessIntegers, useBigInt) => { + const protocol = new BoltProtocolV4x4(null, null, { + disableLosslessIntegers, + useBigInt + }) + expect(protocol._unpacker._disableLosslessIntegers).toBe( + disableLosslessIntegers + ) + expect(protocol._unpacker._useBigInt).toBe(useBigInt) + } + ) + }) +}) diff --git a/packages/bolt-connection/test/bolt/index.test.js b/packages/bolt-connection/test/bolt/index.test.js index 9926af724..838b46be6 100644 --- a/packages/bolt-connection/test/bolt/index.test.js +++ b/packages/bolt-connection/test/bolt/index.test.js @@ -29,6 +29,7 @@ import BoltProtocolV4x0 from '../../src/bolt/bolt-protocol-v4x0' import BoltProtocolV4x1 from '../../src/bolt/bolt-protocol-v4x1' import BoltProtocolV4x2 from '../../src/bolt/bolt-protocol-v4x2' import BoltProtocolV4x3 from '../../src/bolt/bolt-protocol-v4x3' +import BoltProtocolV4x4 from '../../src/bolt/bolt-protocol-v4x4' const { logger: { Logger } @@ -42,13 +43,13 @@ describe('#unit Bolt', () => { const writtenBuffer = channel.written[0] const boltMagicPreamble = '60 60 b0 17' - const protocolVersion4x3to4x2 = '00 01 03 04' + const protocolVersion4x4to4x2 = '00 02 04 04' const protocolVersion4x1 = '00 00 01 04' const protocolVersion4x0 = '00 00 00 04' const protocolVersion3 = '00 00 00 03' expect(writtenBuffer.toHex()).toEqual( - `${boltMagicPreamble} ${protocolVersion4x3to4x2} ${protocolVersion4x1} ${protocolVersion4x0} ${protocolVersion3}` + `${boltMagicPreamble} ${protocolVersion4x4to4x2} ${protocolVersion4x1} ${protocolVersion4x0} ${protocolVersion3}` ) }) @@ -301,7 +302,8 @@ describe('#unit Bolt', () => { v(4.0, BoltProtocolV4x0), v(4.1, BoltProtocolV4x1), v(4.2, BoltProtocolV4x2), - v(4.3, BoltProtocolV4x3) + v(4.3, BoltProtocolV4x3), + v(4.4, BoltProtocolV4x4) ] availableProtocols.forEach(lambda) diff --git a/packages/bolt-connection/test/bolt/request-message.test.js b/packages/bolt-connection/test/bolt/request-message.test.js index c6ae4ebea..91765684c 100644 --- a/packages/bolt-connection/test/bolt/request-message.test.js +++ b/packages/bolt-connection/test/bolt/request-message.test.js @@ -175,7 +175,7 @@ describe('#unit RequestMessage', () => { }) describe('BoltV4', () => { - function verify (message, signature, metadata, name) { + function verify(message, signature, metadata, name) { expect(message.signature).toEqual(signature) expect(message.fields).toEqual([metadata]) expect(message.toString()).toEqual(`${name} ${json.stringify(metadata)}`) @@ -266,4 +266,167 @@ describe('#unit RequestMessage', () => { ) }) }) + + describe('BoltV4.4', () => { + it('should create ROUTE message', () => { + const requestContext = { someValue: '1234' } + const bookmarks = ['a', 'b'] + const databaseName = 'user_db' + const impersonatedUser = "user" + + const message = RequestMessage.routeV4x4(requestContext, bookmarks, { databaseName, impersonatedUser }) + + expect(message.signature).toEqual(0x66) + expect(message.fields).toEqual([requestContext, bookmarks, { db: databaseName, imp_user: impersonatedUser }]) + expect(message.toString()).toEqual( + `ROUTE ${json.stringify(requestContext)} ${json.stringify( + bookmarks + )} ${json.stringify({ db: databaseName, imp_user: impersonatedUser })}` + ) + }) + + it('should create ROUTE message with default values', () => { + const message = RequestMessage.routeV4x4() + + expect(message.signature).toEqual(0x66) + expect(message.fields).toEqual([{}, [], {}]) + expect(message.toString()).toEqual( + `ROUTE ${json.stringify({})} ${json.stringify([])} ${json.stringify({})}` + ) + }) + + it('should create BEGIN message with impersonated user', () => { + ;[READ, WRITE].forEach(mode => { + const bookmark = new Bookmark([ + 'neo4j:bookmark:v1:tx1', + 'neo4j:bookmark:v1:tx10' + ]) + const impersonatedUser = 'the impostor' + const txConfig = new TxConfig({ timeout: 42, metadata: { key: 42 } }) + + const message = RequestMessage.begin({ bookmark, txConfig, mode, impersonatedUser }) + + const expectedMetadata = { + bookmarks: bookmark.values(), + tx_timeout: int(42), + tx_metadata: { key: 42 }, + imp_user: impersonatedUser + } + if (mode === READ) { + expectedMetadata.mode = 'r' + } + + expect(message.signature).toEqual(0x11) + expect(message.fields).toEqual([expectedMetadata]) + expect(message.toString()).toEqual( + `BEGIN ${json.stringify(expectedMetadata)}` + ) + }) + }) + + it('should create BEGIN message without impersonated user if it is not supplied or null', () => { + ;[undefined, null].forEach(impersonatedUser => { + const bookmark = new Bookmark([ + 'neo4j:bookmark:v1:tx1', + 'neo4j:bookmark:v1:tx10' + ]) + const mode = WRITE + const txConfig = new TxConfig({ timeout: 42, metadata: { key: 42 } }) + + const message = RequestMessage.begin({ bookmark, txConfig, mode, impersonatedUser }) + + const expectedMetadata = { + bookmarks: bookmark.values(), + tx_timeout: int(42), + tx_metadata: { key: 42 } + } + + expect(message.signature).toEqual(0x11) + expect(message.fields).toEqual([expectedMetadata]) + expect(message.toString()).toEqual( + `BEGIN ${json.stringify(expectedMetadata)}` + ) + }) + }) + + it('should create RUN message with the impersonated user', () => { + ;[READ, WRITE].forEach(mode => { + const query = 'RETURN $x' + const parameters = { x: 42 } + const bookmark = new Bookmark([ + 'neo4j:bookmark:v1:tx1', + 'neo4j:bookmark:v1:tx10', + 'neo4j:bookmark:v1:tx100' + ]) + const txConfig = new TxConfig({ + timeout: 999, + metadata: { a: 'a', b: 'b' } + }) + const impersonatedUser = 'the impostor' + + const message = RequestMessage.runWithMetadata(query, parameters, { + bookmark, + txConfig, + mode, + impersonatedUser + }) + + const expectedMetadata = { + bookmarks: bookmark.values(), + tx_timeout: int(999), + tx_metadata: { a: 'a', b: 'b' }, + imp_user: impersonatedUser + } + if (mode === READ) { + expectedMetadata.mode = 'r' + } + + expect(message.signature).toEqual(0x10) + expect(message.fields).toEqual([query, parameters, expectedMetadata]) + expect(message.toString()).toEqual( + `RUN ${query} ${json.stringify(parameters)} ${json.stringify( + expectedMetadata + )}` + ) + }) + }) + + it('should create RUN message without impersonated user if it is not supplied or null', () => { + ;[undefined, null].forEach(impersonatedUser => { + const mode = WRITE + const query = 'RETURN $x' + const parameters = { x: 42 } + const bookmark = new Bookmark([ + 'neo4j:bookmark:v1:tx1', + 'neo4j:bookmark:v1:tx10', + 'neo4j:bookmark:v1:tx100' + ]) + const txConfig = new TxConfig({ + timeout: 999, + metadata: { a: 'a', b: 'b' } + }) + + const message = RequestMessage.runWithMetadata(query, parameters, { + bookmark, + txConfig, + mode, + impersonatedUser + }) + + const expectedMetadata = { + bookmarks: bookmark.values(), + tx_timeout: int(999), + tx_metadata: { a: 'a', b: 'b' } + } + + expect(message.signature).toEqual(0x10) + expect(message.fields).toEqual([query, parameters, expectedMetadata]) + expect(message.toString()).toEqual( + `RUN ${query} ${json.stringify(parameters)} ${json.stringify( + expectedMetadata + )}` + ) + }) + }) + }) }) diff --git a/packages/bolt-connection/test/bolt/routing-table-raw.test.js b/packages/bolt-connection/test/bolt/routing-table-raw.test.js index 353b070a0..d9bd2ff62 100644 --- a/packages/bolt-connection/test/bolt/routing-table-raw.test.js +++ b/packages/bolt-connection/test/bolt/routing-table-raw.test.js @@ -29,11 +29,12 @@ describe('#unit RawRoutingTable', () => { shouldReturnNullRawRoutingTable(() => RawRoutingTable.ofRecord(null)) }) - describe('when record has servers and ttl', () => { + describe('when record has servers, db and ttl', () => { it('should return isNull equals false', () => { const record = newRecord({ ttl: 123, - servers: [{ role: 'READ', addresses: ['127.0.0.1'] }] + servers: [{ role: 'READ', addresses: ['127.0.0.1'] }], + db: 'homedb' }) const result = RawRoutingTable.ofRecord(record) expect(result.isNull).toEqual(false) @@ -42,22 +43,34 @@ describe('#unit RawRoutingTable', () => { it('should return the ttl', () => { const record = newRecord({ ttl: 123, - servers: [{ role: 'READ', addresses: ['127.0.0.1'] }] + servers: [{ role: 'READ', addresses: ['127.0.0.1'] }], + db: 'homedb' }) const result = RawRoutingTable.ofRecord(record) expect(result.ttl).toEqual(123) }) - it('should return the ttl', () => { + it('should return the servers', () => { const record = newRecord({ ttl: 123, - servers: [{ role: 'READ', addresses: ['127.0.0.1'] }] + servers: [{ role: 'READ', addresses: ['127.0.0.1'] }], + db: 'homedb' }) const result = RawRoutingTable.ofRecord(record) expect(result.servers).toEqual([ { role: 'READ', addresses: ['127.0.0.1'] } ]) }) + + it('should return the db', () => { + const record = newRecord({ + ttl: 123, + servers: [{ role: 'READ', addresses: ['127.0.0.1'] }], + db: 'homedb' + }) + const result = RawRoutingTable.ofRecord(record) + expect(result.db).toEqual('homedb') + }) }) describe('when record has servers and but no ttl', () => { @@ -119,6 +132,18 @@ describe('#unit RawRoutingTable', () => { expect(() => result.servers).toThrow() }) }) + + describe('when record does not have db name', () => { + it('should return db equals null', () => { + const record = newRecord({ + ttl: 123, + noServers: [{ role: 'READ', addresses: ['127.0.0.1'] }] + }) + const result = RawRoutingTable.ofRecord(record) + expect(result.db).toEqual(null) + }) + + }) }) describe('ofMessageResponse', () => { @@ -144,7 +169,7 @@ describe('#unit RawRoutingTable', () => { expect(result.ttl).toEqual(123) }) - it('should return the ttl', () => { + it('should return the servers', () => { const response = newResponse({ ttl: 123, servers: [{ role: 'READ', addresses: ['127.0.0.1'] }] @@ -154,6 +179,16 @@ describe('#unit RawRoutingTable', () => { { role: 'READ', addresses: ['127.0.0.1'] } ]) }) + + it('should return the db', () => { + const response = newResponse({ + ttl: 123, + servers: [{ role: 'READ', addresses: ['127.0.0.1'] }], + db: 'homedb' + }) + const result = RawRoutingTable.ofMessageResponse(response) + expect(result.db).toEqual('homedb') + }) }) function shouldReturnNullRawRoutingTable (subject) { @@ -176,6 +211,13 @@ describe('#unit RawRoutingTable', () => { fail(`it should not return ${servers}`) }).toThrow(new Error('Not implemented')) }) + + it('should not implement db', () => { + expect(() => { + const db = subject().db + fail(`it should not return ${db}`) + }).toThrow(new Error('Not implemented')) + }) } function newRecord (params = {}) { diff --git a/packages/bolt-connection/test/connection-provider/connection-provider-routing.test.js b/packages/bolt-connection/test/connection-provider/connection-provider-routing.test.js index 89bec5feb..509f5c8e8 100644 --- a/packages/bolt-connection/test/connection-provider/connection-provider-routing.test.js +++ b/packages/bolt-connection/test/connection-provider/connection-provider-routing.test.js @@ -70,6 +70,12 @@ describe('#unit RoutingConnectionProvider', () => { const serverEE = ServerAddress.fromUrl('serverEE') const serverABC = ServerAddress.fromUrl('serverABC') + + const usersDataSet = [ + [null], + [undefined], + ['the-impostor'] + ] it('can forget address', () => { const connectionProvider = newRoutingConnectionProvider([ @@ -182,7 +188,7 @@ describe('#unit RoutingConnectionProvider', () => { ) }, 10000) - it('acquires connection and returns a DelegateConnection', async () => { + it.each(usersDataSet)('acquires connection and returns a DelegateConnection [user=%s]', async (user) => { const pool = newPool() const connectionProvider = newRoutingConnectionProvider( [ @@ -198,18 +204,20 @@ describe('#unit RoutingConnectionProvider', () => { const conn1 = await connectionProvider.acquireConnection({ accessMode: READ, - database: null + database: null, + impersonatedUser: user }) expect(conn1 instanceof DelegateConnection).toBeTruthy() const conn2 = await connectionProvider.acquireConnection({ accessMode: WRITE, - database: null + database: null, + impersonatedUser: user }) expect(conn2 instanceof DelegateConnection).toBeTruthy() }, 10000) - it('acquires read connection with up-to-date routing table', done => { + it.each(usersDataSet)('acquires read connection with up-to-date routing table [user=%s]', (user, done) => { const pool = newPool() const connectionProvider = newRoutingConnectionProvider( [ @@ -224,7 +232,7 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .then(connection => { expect(connection.address).toEqual(server3) expect(pool.has(server3)).toBeTruthy() @@ -240,7 +248,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('acquires write connection with up-to-date routing table', done => { + it.each(usersDataSet)('acquires write connection with up-to-date routing table [user=%s]', (user, done) => { const pool = newPool() const connectionProvider = newRoutingConnectionProvider( [ @@ -255,13 +263,13 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .then(connection => { expect(connection.address).toEqual(server5) expect(pool.has(server5)).toBeTruthy() connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .then(connection => { expect(connection.address).toEqual(server6) expect(pool.has(server6)).toBeTruthy() @@ -271,7 +279,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('throws for illegal access mode', done => { + it.each(usersDataSet)('throws for illegal access mode [user=%s]', (user, done) => { const connectionProvider = newRoutingConnectionProvider([ newRoutingTable( null, @@ -282,14 +290,14 @@ describe('#unit RoutingConnectionProvider', () => { ]) connectionProvider - .acquireConnection({ accessMode: 'WRONG', database: null }) + .acquireConnection({ accessMode: 'WRONG', database: null, impersonatedUser: user }) .catch(error => { expect(error.message).toEqual('Illegal mode WRONG') done() }) }, 10000) - it('refreshes stale routing table to get read connection', done => { + it.each(usersDataSet)('refreshes stale routing table to get read connection [user=%s]', (user, done) => { const pool = newPool() const updatedRoutingTable = newRoutingTable( null, @@ -312,13 +320,13 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .then(connection => { expect(connection.address).toEqual(serverC) expect(pool.has(serverC)).toBeTruthy() connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .then(connection => { expect(connection.address).toEqual(serverD) expect(pool.has(serverD)).toBeTruthy() @@ -328,7 +336,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('refreshes stale routing table to get write connection', done => { + it.each(usersDataSet)('refreshes stale routing table to get write connection [user=%s]', (user, done) => { const pool = newPool() const updatedRoutingTable = newRoutingTable( null, @@ -351,13 +359,13 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .then(connection => { expect(connection.address).toEqual(serverE) expect(pool.has(serverE)).toBeTruthy() connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .then(connection => { expect(connection.address).toEqual(serverF) expect(pool.has(serverF)).toBeTruthy() @@ -367,7 +375,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('refreshes stale routing table to get read connection when one router fails', done => { + it.each(usersDataSet)('refreshes stale routing table to get read connection when one router fails [user=%s]', (user, done) => { const pool = newPool() const updatedRoutingTable = newRoutingTable( null, @@ -395,13 +403,13 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .then(connection => { expect(connection.address).toEqual(serverC) expect(pool.has(serverC)).toBeTruthy() connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .then(connection => { expect(connection.address).toEqual(serverD) expect(pool.has(serverD)).toBeTruthy() @@ -411,7 +419,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('refreshes stale routing table to get write connection when one router fails', done => { + it.each(usersDataSet)('refreshes stale routing table to get write connection when one router fails [user=%s]', (user, done) => { const pool = newPool() const updatedRoutingTable = newRoutingTable( null, @@ -439,13 +447,13 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .then(connection => { expect(connection.address).toEqual(serverE) expect(pool.has(serverE)).toBeTruthy() connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .then(connection => { expect(connection.address).toEqual(serverF) expect(pool.has(serverF)).toBeTruthy() @@ -455,7 +463,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('refreshes routing table without readers to get read connection', done => { + it.each(usersDataSet)('refreshes routing table without readers to get read connection [user=%s]', (user, done) => { const pool = newPool() const updatedRoutingTable = newRoutingTable( null, @@ -483,13 +491,13 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .then(connection => { expect(connection.address).toEqual(serverC) expect(pool.has(serverC)).toBeTruthy() connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .then(connection => { expect(connection.address).toEqual(serverD) expect(pool.has(serverD)).toBeTruthy() @@ -499,7 +507,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('refreshes routing table without writers to get write connection', done => { + it.each(usersDataSet)('refreshes routing table without writers to get write connection [user=%s]', (user, done) => { const pool = newPool() const updatedRoutingTable = newRoutingTable( null, @@ -527,13 +535,13 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .then(connection => { expect(connection.address).toEqual(serverE) expect(pool.has(serverE)).toBeTruthy() connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .then(connection => { expect(connection.address).toEqual(serverF) expect(pool.has(serverF)).toBeTruthy() @@ -543,7 +551,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('throws when all routers return nothing while getting read connection', done => { + it.each(usersDataSet)('throws when all routers return nothing while getting read connection [user=%s]', (user, done) => { const connectionProvider = newRoutingConnectionProvider( [ newRoutingTable( @@ -564,14 +572,14 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .catch(error => { expect(error.code).toEqual(SERVICE_UNAVAILABLE) done() }) }, 10000) - it('throws when all routers return nothing while getting write connection', done => { + it.each(usersDataSet)('throws when all routers return nothing while getting write connection [user=%s]', (user, done) => { const connectionProvider = newRoutingConnectionProvider( [ newRoutingTable( @@ -592,14 +600,14 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .catch(error => { expect(error.code).toEqual(SERVICE_UNAVAILABLE) done() }) }, 10000) - it('throws when all routers return routing tables without readers while getting read connection', done => { + it.each(usersDataSet)('throws when all routers return routing tables without readers while getting read connection', (user, done) => { const updatedRoutingTable = newRoutingTable( null, [serverA, serverB], @@ -626,14 +634,14 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .catch(error => { expect(error.code).toEqual(SESSION_EXPIRED) done() }) }, 10000) - it('throws when all routers return routing tables without writers while getting write connection', done => { + it.each(usersDataSet)('throws when all routers return routing tables without writers while getting write connection [user=%s]', (user, done) => { const updatedRoutingTable = newRoutingTable( null, [serverA, serverB], @@ -660,14 +668,14 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .catch(error => { expect(error.code).toEqual(SESSION_EXPIRED) done() }) }, 10000) - it('throws when stale routing table without routers while getting read connection', done => { + it.each(usersDataSet)('throws when stale routing table without routers while getting read connection [user=%s]', (user, done) => { const connectionProvider = newRoutingConnectionProvider( [ newRoutingTable( @@ -682,14 +690,14 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .catch(error => { expect(error.code).toEqual(SERVICE_UNAVAILABLE) done() }) }, 10000) - it('throws when stale routing table without routers while getting write connection', done => { + it.each(usersDataSet)('throws when stale routing table without routers while getting write connection [user=%s]', (user, done) => { const connectionProvider = newRoutingConnectionProvider( [ newRoutingTable( @@ -704,14 +712,14 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .catch(error => { expect(error.code).toEqual(SERVICE_UNAVAILABLE) done() }) }, 10000) - it('updates routing table after refresh', done => { + it.each(usersDataSet)('updates routing table after refresh [user=%s]', (user, done) => { const pool = newPool() const updatedRoutingTable = newRoutingTable( null, @@ -738,7 +746,7 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .then(() => { expectRoutingTable( connectionProvider, @@ -759,7 +767,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('forgets all routers when they fail while acquiring read connection', done => { + it.each(usersDataSet)('forgets all routers when they fail while acquiring read connection [user=%s]', (user, done) => { const connectionProvider = newRoutingConnectionProvider( [ newRoutingTable( @@ -774,7 +782,7 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .catch(error => { expect(error.code).toEqual(SERVICE_UNAVAILABLE) expectRoutingTable( @@ -788,7 +796,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('forgets all routers when they fail while acquiring write connection', done => { + it.each(usersDataSet)('forgets all routers when they fail while acquiring write connection [user=%s]', (user, done) => { const connectionProvider = newRoutingConnectionProvider( [ newRoutingTable( @@ -803,7 +811,7 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .catch(error => { expect(error.code).toEqual(SERVICE_UNAVAILABLE) expectRoutingTable( @@ -817,7 +825,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('uses seed router address when all existing routers fail', done => { + it.each(usersDataSet)('uses seed router address when all existing routers fail [user=%s]', (user, done) => { const updatedRoutingTable = newRoutingTable( null, [serverA, serverB, serverC], @@ -848,12 +856,12 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .then(connection1 => { expect(connection1.address).toEqual(serverD) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .then(connection2 => { expect(connection2.address).toEqual(serverF) @@ -869,7 +877,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('uses resolved seed router address when all existing routers fail', done => { + it.each(usersDataSet)('uses resolved seed router address when all existing routers fail [user=%s]', (user, done) => { const updatedRoutingTable = newRoutingTable( null, [serverA, serverB], @@ -900,12 +908,12 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, user }) .then(connection1 => { expect(connection1.address).toEqual(serverE) connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, user }) .then(connection2 => { expect(connection2.address).toEqual(serverC) @@ -921,7 +929,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('uses resolved seed router address that returns correct routing table when all existing routers fail', done => { + it.each(usersDataSet)('uses resolved seed router address that returns correct routing table when all existing routers fail [user=%s]', (user, done) => { const updatedRoutingTable = newRoutingTable( null, [serverA, serverB], @@ -952,12 +960,12 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .then(connection1 => { expect(connection1.address).toEqual(serverD) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .then(connection2 => { expect(connection2.address).toEqual(serverE) @@ -973,7 +981,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('fails when both existing routers and seed router fail to return a routing table', done => { + it.each(usersDataSet)('fails when both existing routers and seed router fail to return a routing table [user=%s]', (user, done) => { const connectionProvider = newRoutingConnectionProviderWithSeedRouter( server0, [server0], // seed router address resolves just to itself @@ -997,7 +1005,7 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .catch(error => { expect(error.code).toEqual(SERVICE_UNAVAILABLE) @@ -1010,7 +1018,7 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .catch(error => { expect(error.code).toEqual(SERVICE_UNAVAILABLE) @@ -1027,7 +1035,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('fails when both existing routers and resolved seed router fail to return a routing table', done => { + it.each(usersDataSet)('fails when both existing routers and resolved seed router fail to return a routing table [user=%s]', (user, done) => { const connectionProvider = newRoutingConnectionProviderWithSeedRouter( server0, [server01], // seed router address resolves to a different one @@ -1050,7 +1058,7 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .catch(error => { expect(error.code).toEqual(SERVICE_UNAVAILABLE) @@ -1063,7 +1071,7 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .catch(error => { expect(error.code).toEqual(SERVICE_UNAVAILABLE) @@ -1080,7 +1088,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('fails when both existing routers and all resolved seed routers fail to return a routing table', done => { + it.each(usersDataSet)('fails when both existing routers and all resolved seed routers fail to return a routing table [user=%s]', (user, done) => { const connectionProvider = newRoutingConnectionProviderWithSeedRouter( server0, [server02, server01], // seed router address resolves to 2 different addresses @@ -1105,7 +1113,7 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .catch(error => { expect(error.code).toEqual(SERVICE_UNAVAILABLE) @@ -1118,7 +1126,7 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .catch(error => { expect(error.code).toEqual(SERVICE_UNAVAILABLE) @@ -1135,7 +1143,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('uses seed router when no existing routers', done => { + it.each(usersDataSet)('uses seed router when no existing routers [user=%s]', (user, done) => { const updatedRoutingTable = newRoutingTable( null, [serverA, serverB], @@ -1163,12 +1171,12 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .then(connection1 => { expect(connection1.address).toEqual(serverD) connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .then(connection2 => { expect(connection2.address).toEqual(serverC) @@ -1184,7 +1192,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('uses resolved seed router when no existing routers', done => { + it.each(usersDataSet)('uses resolved seed router when no existing routers [user=%s]', (user, done) => { const updatedRoutingTable = newRoutingTable( null, [serverA, serverB], @@ -1212,12 +1220,12 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .then(connection1 => { expect(connection1.address).toEqual(serverC) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .then(connection2 => { expect(connection2.address).toEqual(serverF) @@ -1233,7 +1241,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('uses resolved seed router that returns routing table when no existing routers exist', done => { + it.each(usersDataSet)('uses resolved seed router that returns routing table when no existing routers exist [user=%s]', (user, done) => { const updatedRoutingTable = newRoutingTable( null, [serverA, serverB, serverC], @@ -1263,12 +1271,12 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .then(connection1 => { expect(connection1.address).toEqual(serverF) connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .then(connection2 => { expect(connection2.address).toEqual(serverD) @@ -1284,7 +1292,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('ignores already probed routers after seed router resolution', done => { + it.each(usersDataSet)('ignores already probed routers after seed router resolution [user=%s]', (user, done) => { const updatedRoutingTable = newRoutingTable( null, [serverA, serverB], @@ -1323,12 +1331,12 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .then(connection1 => { expect(connection1.address).toEqual(serverC) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .then(connection2 => { expect(connection2.address).toEqual(serverE) @@ -1351,7 +1359,7 @@ describe('#unit RoutingConnectionProvider', () => { }) }, 10000) - it('throws session expired when refreshed routing table has no readers', done => { + it.each(usersDataSet)('throws session expired when refreshed routing table has no readers [user=%s]', (user, done) => { const pool = newPool() const updatedRoutingTable = newRoutingTable( null, @@ -1378,14 +1386,14 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .catch(error => { expect(error.code).toEqual(SESSION_EXPIRED) done() }) }, 10000) - it('throws session expired when refreshed routing table has no writers', done => { + it.each(usersDataSet)('throws session expired when refreshed routing table has no writers [user=%s]', (user, done) => { const pool = newPool() const updatedRoutingTable = newRoutingTable( null, @@ -1412,14 +1420,14 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .catch(error => { expect(error.code).toEqual(SESSION_EXPIRED) done() }) }, 10000) - it('should purge connections for address when AuthorizationExpired happens', async () => { + it.each(usersDataSet)('should purge connections for address when AuthorizationExpired happens [user=%s]', async (user) => { const pool = newPool() jest.spyOn(pool, 'purge') @@ -1443,12 +1451,14 @@ describe('#unit RoutingConnectionProvider', () => { const server2Connection = await connectionProvider.acquireConnection({ accessMode: 'WRITE', - database: null + database: null, + impersonatedUser: user }) const server3Connection = await connectionProvider.acquireConnection({ accessMode: 'READ', - database: null + database: null, + impersonatedUser: user }) server3Connection.handleAndTransformError(error, server3) @@ -1458,7 +1468,7 @@ describe('#unit RoutingConnectionProvider', () => { expect(pool.purge).toHaveBeenCalledWith(server2) }) - it('should purge not change error when AuthorizationExpired happens', async () => { + it.each(usersDataSet)('should purge not change error when AuthorizationExpired happens [user=%s]', async (user) => { const pool = newPool() jest.spyOn(pool, 'purge') @@ -1482,7 +1492,8 @@ describe('#unit RoutingConnectionProvider', () => { const server2Connection = await connectionProvider.acquireConnection({ accessMode: 'WRITE', - database: null + database: null, + impersonatedUser: user }) const error = server2Connection.handleAndTransformError( @@ -1493,7 +1504,7 @@ describe('#unit RoutingConnectionProvider', () => { expect(error).toBe(expectedError) }) - it('should purge connections for address when TokenExpired happens', async () => { + it.each(usersDataSet)('should purge connections for address when TokenExpired happens [user=%s]', async (user) => { const pool = newPool() jest.spyOn(pool, 'purge') @@ -1517,12 +1528,14 @@ describe('#unit RoutingConnectionProvider', () => { const server2Connection = await connectionProvider.acquireConnection({ accessMode: 'WRITE', - database: null + database: null, + impersonatedUser: user }) const server3Connection = await connectionProvider.acquireConnection({ accessMode: 'READ', - database: null + database: null, + impersonatedUser: user }) server3Connection.handleAndTransformError(error, server3) @@ -1532,7 +1545,7 @@ describe('#unit RoutingConnectionProvider', () => { expect(pool.purge).toHaveBeenCalledWith(server2) }) - it('should not change error when TokenExpired happens', async () => { + it.each(usersDataSet)('should not change error when TokenExpired happens [user=%s]', async (user) => { const pool = newPool() jest.spyOn(pool, 'purge') @@ -1556,7 +1569,8 @@ describe('#unit RoutingConnectionProvider', () => { const server2Connection = await connectionProvider.acquireConnection({ accessMode: 'WRITE', - database: null + database: null, + impersonatedUser: user }) const error = server2Connection.handleAndTransformError( @@ -1567,7 +1581,7 @@ describe('#unit RoutingConnectionProvider', () => { expect(error).toBe(expectedError) }) - it('should use resolved seed router after accepting table with no writers', done => { + it.each(usersDataSet)('should use resolved seed router after accepting table with no writers [user=%s]', (user, done) => { const routingTable1 = newRoutingTable( null, [serverA, serverB], @@ -1607,12 +1621,12 @@ describe('#unit RoutingConnectionProvider', () => { connectionProvider._useSeedRouter = false connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .then(connection1 => { expect(connection1.address).toEqual(serverC) connectionProvider - .acquireConnection({ accessMode: READ, database: null }) + .acquireConnection({ accessMode: READ, database: null, impersonatedUser: user }) .then(connection2 => { expect(connection2.address).toEqual(serverD) @@ -1625,7 +1639,7 @@ describe('#unit RoutingConnectionProvider', () => { ) connectionProvider - .acquireConnection({ accessMode: WRITE, database: null }) + .acquireConnection({ accessMode: WRITE, database: null, impersonatedUser: user }) .then(connection3 => { expect(connection3.address).toEqual(serverEE) @@ -1644,7 +1658,7 @@ describe('#unit RoutingConnectionProvider', () => { }, 10000) describe('multi-database', () => { - it('should acquire read connection from correct routing table', async () => { + it.each(usersDataSet)('should acquire read connection from correct routing table [user=%s]', async (user) => { const pool = newPool() const connectionProvider = newRoutingConnectionProvider( [ @@ -1661,20 +1675,22 @@ describe('#unit RoutingConnectionProvider', () => { const conn1 = await connectionProvider.acquireConnection({ accessMode: READ, - database: 'databaseA' + database: 'databaseA', + impersonatedUser: user }) expect(conn1 instanceof DelegateConnection).toBeTruthy() expect(conn1.address).toBe(server1) const conn2 = await connectionProvider.acquireConnection({ accessMode: READ, - database: 'databaseB' + database: 'databaseB', + impersonatedUser: user }) expect(conn2 instanceof DelegateConnection).toBeTruthy() expect(conn2.address).toBe(serverA) }, 10000) - it('should purge connections for address when AuthorizationExpired happens', async () => { + it.each(usersDataSet)('should purge connections for address when AuthorizationExpired happens [user=%s]', async (user) => { const pool = newPool() jest.spyOn(pool, 'purge') @@ -1699,12 +1715,14 @@ describe('#unit RoutingConnectionProvider', () => { const server2Connection = await connectionProvider.acquireConnection({ accessMode: 'WRITE', - database: 'databaseA' + database: 'databaseA', + impersonatedUser: user }) const serverAConnection = await connectionProvider.acquireConnection({ accessMode: 'READ', - database: 'databaseB' + database: 'databaseB', + impersonatedUser: user }) serverAConnection.handleAndTransformError(error, serverA) @@ -1714,7 +1732,7 @@ describe('#unit RoutingConnectionProvider', () => { expect(pool.purge).toHaveBeenCalledWith(server2) }) - it('should purge not change error when AuthorizationExpired happens', async () => { + it.each(usersDataSet)('should purge not change error when AuthorizationExpired happens [user=%s]', async (user) => { const pool = newPool() const connectionProvider = newRoutingConnectionProvider( @@ -1737,7 +1755,8 @@ describe('#unit RoutingConnectionProvider', () => { const server2Connection = await connectionProvider.acquireConnection({ accessMode: 'WRITE', - database: 'databaseA' + database: 'databaseA', + impersonatedUser: user }) const error = server2Connection.handleAndTransformError( @@ -1748,7 +1767,7 @@ describe('#unit RoutingConnectionProvider', () => { expect(error).toBe(expectedError) }) - it('should purge connections for address when TokenExpired happens', async () => { + it.each(usersDataSet)('should purge connections for address when TokenExpired happens [user=%s]', async (user) => { const pool = newPool() jest.spyOn(pool, 'purge') @@ -1773,12 +1792,14 @@ describe('#unit RoutingConnectionProvider', () => { const server2Connection = await connectionProvider.acquireConnection({ accessMode: 'WRITE', - database: 'databaseA' + database: 'databaseA', + impersonatedUser: user }) const serverAConnection = await connectionProvider.acquireConnection({ accessMode: 'READ', - database: 'databaseB' + database: 'databaseB', + impersonatedUser: user }) serverAConnection.handleAndTransformError(error, serverA) @@ -1788,7 +1809,7 @@ describe('#unit RoutingConnectionProvider', () => { expect(pool.purge).toHaveBeenCalledWith(server2) }) - it('should not change error when TokenExpired happens', async () => { + it.each(usersDataSet)('should not change error when TokenExpired happens [user=%s]', async (user) => { const pool = newPool() const connectionProvider = newRoutingConnectionProvider( @@ -1811,7 +1832,8 @@ describe('#unit RoutingConnectionProvider', () => { const server2Connection = await connectionProvider.acquireConnection({ accessMode: 'WRITE', - database: 'databaseA' + database: 'databaseA', + impersonatedUser: user }) const error = server2Connection.handleAndTransformError( @@ -1822,7 +1844,7 @@ describe('#unit RoutingConnectionProvider', () => { expect(error).toBe(expectedError) }) - it('should acquire write connection from correct routing table', async () => { + it.each(usersDataSet)('should acquire write connection from correct routing table [user=%s]', async (user) => { const pool = newPool() const connectionProvider = newRoutingConnectionProvider( [ @@ -1839,20 +1861,22 @@ describe('#unit RoutingConnectionProvider', () => { const conn1 = await connectionProvider.acquireConnection({ accessMode: WRITE, - database: 'databaseA' + database: 'databaseA', + impersonatedUser: user }) expect(conn1 instanceof DelegateConnection).toBeTruthy() expect(conn1.address).toBe(server2) const conn2 = await connectionProvider.acquireConnection({ accessMode: WRITE, - database: 'databaseB' + database: 'databaseB', + impersonatedUser: user }) expect(conn2 instanceof DelegateConnection).toBeTruthy() expect(conn2.address).toBe(serverB) }, 10000) - it('should fail connection acquisition if database is not known', async () => { + it.each(usersDataSet)('should fail connection acquisition if database is not known [user=%s]', async (user) => { const pool = newPool() const connectionProvider = newRoutingConnectionProvider( [ @@ -1864,7 +1888,8 @@ describe('#unit RoutingConnectionProvider', () => { try { await connectionProvider.acquireConnection({ accessMode: WRITE, - database: 'databaseX' + database: 'databaseX', + impersonatedUser: user }) } catch (error) { expect(error instanceof Neo4jError).toBeTruthy() @@ -1878,7 +1903,7 @@ describe('#unit RoutingConnectionProvider', () => { expect(false).toBeTruthy('exception expected') }, 10000) - it('should forget read server from correct routing table on availability error', async () => { + it.each(usersDataSet)('should forget read server from correct routing table on availability error [user=%s]', async (user) => { const pool = newPool() const connectionProvider = newRoutingConnectionProvider( [ @@ -1900,7 +1925,8 @@ describe('#unit RoutingConnectionProvider', () => { const conn1 = await connectionProvider.acquireConnection({ accessMode: READ, - database: 'databaseB' + database: 'databaseB', + impersonatedUser: user }) // when @@ -1925,7 +1951,7 @@ describe('#unit RoutingConnectionProvider', () => { ) }, 10000) - it('should forget write server from correct routing table on availability error', async () => { + it.each(usersDataSet)('should forget write server from correct routing table on availability error [user=%s]', async (user) => { const pool = newPool() const connectionProvider = newRoutingConnectionProvider( [ @@ -1947,7 +1973,8 @@ describe('#unit RoutingConnectionProvider', () => { const conn1 = await connectionProvider.acquireConnection({ accessMode: WRITE, - database: 'databaseB' + database: 'databaseB', + impersonatedUser: user }) // when @@ -1972,7 +1999,7 @@ describe('#unit RoutingConnectionProvider', () => { ) }, 10000) - it('should forget write server from the default database routing table on availability error', async () => { + it.each(usersDataSet)('should forget write server from the default database routing table on availability error [user=%s]', async (user) => { const pool = newPool() const connectionProvider = newRoutingConnectionProvider( [ @@ -1994,7 +2021,8 @@ describe('#unit RoutingConnectionProvider', () => { const conn1 = await connectionProvider.acquireConnection({ accessMode: WRITE, - database: null + database: null, + impersonatedUser: user }) // when @@ -2019,7 +2047,7 @@ describe('#unit RoutingConnectionProvider', () => { ) }) - it('should forget write server from the default database routing table on availability error when db not informed', async () => { + it.each(usersDataSet)('should forget write server from the default database routing table on availability error when db not informed [user=%s]', async (user) => { const pool = newPool() const connectionProvider = newRoutingConnectionProvider( [ @@ -2040,7 +2068,8 @@ describe('#unit RoutingConnectionProvider', () => { ) const conn1 = await connectionProvider.acquireConnection({ - accessMode: WRITE + accessMode: WRITE, + impersonatedUser: user }) // when @@ -2065,7 +2094,7 @@ describe('#unit RoutingConnectionProvider', () => { ) }) - it('should forget write server from correct routing table on write error', async () => { + it.each(usersDataSet)('should forget write server from correct routing table on write error [user=%s]', async (user) => { const pool = newPool() const connectionProvider = newRoutingConnectionProvider( [ @@ -2087,7 +2116,8 @@ describe('#unit RoutingConnectionProvider', () => { const conn1 = await connectionProvider.acquireConnection({ accessMode: WRITE, - database: 'databaseB' + database: 'databaseB', + impersonatedUser: user }) // when @@ -2112,7 +2142,7 @@ describe('#unit RoutingConnectionProvider', () => { ) }, 10000) - it('should purge expired routing tables after specified duration on update', async () => { + it.each(usersDataSet)('should purge expired routing tables after specified duration on update [user=%s]', async (user) => { var originalDateNow = Date.now Date.now = () => 50000 try { @@ -2172,7 +2202,8 @@ describe('#unit RoutingConnectionProvider', () => { // force a routing table update for databaseC const conn1 = await connectionProvider.acquireConnection({ accessMode: WRITE, - database: 'databaseC' + database: 'databaseC', + impersonatedUser: user }) expect(conn1).not.toBeNull() expect(conn1.address).toBe(server1) @@ -2192,11 +2223,256 @@ describe('#unit RoutingConnectionProvider', () => { [server2, server3], [server1] ) - expectNoRoutingTable(connectionProvider, 'databaseB') + expectNoRoutingTable(connectionProvider, 'databaseB', user) } finally { Date.now = originalDateNow } }, 10000) + + it.each(usersDataSet)('should resolve the home database name for the user=%s', async (user) => { + const pool = newPool() + const connectionProvider = newRoutingConnectionProvider( + [], + pool, + { + null: { + 'server-non-existing-seed-router:7687': newRoutingTableWithUser( + { + database: null, + routers: [server1, server2, server3], + readers: [server1, server2], + writers: [server3], + user, + routingTableDatabase: 'homedb' + } + ) + } + } + ) + + const connection = await connectionProvider.acquireConnection({ impersonatedUser: user, accessMode: READ }) + + expect(connection.address).toEqual(server1) + + expectRoutingTable( + connectionProvider, + 'homedb', + [server1, server2, server3], + [server1, server2], + [server3] + ) + }) + + it.each(usersDataSet)('should acquire the non default database name for the user=%s with the informed name', async (user) => { + const pool = newPool() + const connectionProvider = newRoutingConnectionProvider( + [], + pool, + { + 'databaseA': { + 'server-non-existing-seed-router:7687': newRoutingTableWithUser( + { + database: 'databaseA', + routers: [server1, server3], + readers: [server1], + writers: [server3], + user, + routingTableDatabase: 'homedb' + } + ) + }, + 'databaseB': { + 'server-non-existing-seed-router:7687': newRoutingTableWithUser( + { + database: 'homedb', + routers: [server2, server3], + readers: [server2], + writers: [server3], + user, + routingTableDatabase: 'homedb' + } + ) + } + } + ) + + const connection = await connectionProvider.acquireConnection({ database: 'databaseA', impersonatedUser: user, accessMode: READ }) + + expect(connection.address).toEqual(server1) + + expectRoutingTable( + connectionProvider, + 'databaseA', + [server1, server3], + [server1], + [server3] + ) + }) + + it.each(usersDataSet)('should be able to acquire connection for homedb using it name', async (user) => { + const pool = newPool() + const connectionProvider = newRoutingConnectionProvider( + [], + pool, + { + null: { + 'server-non-existing-seed-router:7687': newRoutingTableWithUser({ + database: null, + routers: [server1, server2, server3], + readers: [server1, server2], + writers: [server3], + user, + routingTableDatabase: 'homedb' + }) + } + } + ) + + await connectionProvider.acquireConnection({ accessMode: READ, impersonatedUser: user, accessMode: WRITE }) + const connection = await connectionProvider.acquireConnection({ accessMode: READ, database: 'homedb', impersonatedUser: user }) + + expect(connection.address).toEqual(server1) + expect(pool.has(server1)).toBeTruthy() + }) + + + it('should be to acquire connection other users homedb using it name', async () => { + const user1 = 'the-impostor-number-1' + const user2 = 'the-impostor-number-2' + const defaultUser = undefined + const expirationTime = int(Date.now()).add(60000) + + const pool = newPool() + const connectionProvider = newRoutingConnectionProvider( + [], + pool, + { + null: { + 'server-non-existing-seed-router:7687': [ + newRoutingTableWithUser( + { + database: null, + routers: [server1], + readers: [server1], + writers: [server1], + user: user1, + expirationTime, + routingTableDatabase: 'homedb1' + } + ), + + newRoutingTableWithUser( + { + database: null, + routers: [server2], + readers: [server2], + writers: [server2], + expirationTime, + user: user2, + routingTableDatabase: 'homedb2' + } + ), + + newRoutingTableWithUser( + { + database: null, + routers: [server3], + readers: [server3], + writers: [server3], + expirationTime, + user: defaultUser, + routingTableDatabase: 'default-home-db' + } + ) + ] + }, + "kakakaka": {} + }, + ) + + await connectionProvider.acquireConnection({ accessMode: WRITE, impersonatedUser: user2 }) + await connectionProvider.acquireConnection({ accessMode: WRITE, impersonatedUser: user1 }) + await connectionProvider.acquireConnection({ accessMode: WRITE }) + + + const defaultConnToHomeDb1 = await connectionProvider.acquireConnection({ accessMode: READ, database: 'homedb1' }) + expect(defaultConnToHomeDb1.address).toEqual(server1) + expect(pool.has(server1)).toBeTruthy() + + const defaultConnToHomeDb2 = await connectionProvider.acquireConnection({ accessMode: READ, database: 'homedb2' }) + expect(defaultConnToHomeDb2.address).toEqual(server2) + expect(pool.has(server2)).toBeTruthy() + + const user1ConnToDefaultHomeDb = await connectionProvider.acquireConnection({ accessMode: READ, database: 'default-home-db', impersonatedUser: user1 }) + expect(user1ConnToDefaultHomeDb.address).toEqual(server3) + expect(pool.has(server3)).toBeTruthy() + + const user1ConnToHomeDb2 = await connectionProvider.acquireConnection({ accessMode: READ, database: 'homedb2', impersonatedUser: user1 }) + expect(user1ConnToHomeDb2.address).toEqual(server2) + expect(pool.has(server2)).toBeTruthy() + + const user2ConnToDefaultHomeDb = await connectionProvider.acquireConnection({ accessMode: READ, database: 'default-home-db', impersonatedUser: user2 }) + expect(user2ConnToDefaultHomeDb.address).toEqual(server3) + expect(pool.has(server3)).toBeTruthy() + + const user2ConnToHomeDb1 = await connectionProvider.acquireConnection({ accessMode: READ, database: 'homedb1', impersonatedUser: user2 }) + expect(user2ConnToHomeDb1.address).toEqual(server1) + expect(pool.has(server1)).toBeTruthy() + + }) + + + it.each(usersDataSet)('should call onDatabaseNameResolved with the resolved db acquiring home db [user=%s]', async (user) => { + const pool = newPool() + const connectionProvider = newRoutingConnectionProvider( + [], + pool, + { + null: { + 'server-non-existing-seed-router:7687': newRoutingTableWithUser({ + database: null, + routers: [server1, server2, server3], + readers: [server1, server2], + writers: [server3], + user, + routingTableDatabase: 'homedb' + }) + } + } + ) + const onDatabaseNameResolved = jest.fn() + + await connectionProvider.acquireConnection({ accessMode: READ, impersonatedUser: user, onDatabaseNameResolved }) + + expect(onDatabaseNameResolved).toHaveBeenCalledWith('homedb') + }) + + it.each(usersDataSet)('should call onDatabaseNameResolved with the resolved db acquiring named db [user=%s]', async (user) => { + const pool = newPool() + const connectionProvider = newRoutingConnectionProvider( + [], + pool, + { + 'databaseA': { + 'server-non-existing-seed-router:7687': newRoutingTableWithUser({ + database: 'databaseA', + routers: [server1, server2, server3], + readers: [server1, server2], + writers: [server3], + user, + routingTableDatabase: 'databaseB' + }) + } + } + ) + + const onDatabaseNameResolved = jest.fn() + + await connectionProvider.acquireConnection({ accessMode: READ, impersonatedUser: user, onDatabaseNameResolved, database: 'databaseA' }) + + expect(onDatabaseNameResolved).toHaveBeenCalledWith('databaseA') + }) + }) }) @@ -2235,7 +2511,7 @@ function newRoutingConnectionProviderWithSeedRouter ( }) connectionProvider._connectionPool = pool routingTables.forEach(r => { - connectionProvider._routingTableRegistry.register(r.database, r) + connectionProvider._routingTableRegistry.register(r) }) connectionProvider._rediscovery = new FakeRediscovery(routerToRoutingTable) connectionProvider._hostNameResolver = new FakeDnsResolver(seedRouterResolved) @@ -2245,20 +2521,36 @@ function newRoutingConnectionProviderWithSeedRouter ( return connectionProvider } +function newRoutingTableWithUser ({ + database, + routers, + readers, + writers, + expirationTime = Integer.MAX_VALUE, + routingTableDatabase, + user +}) { + const routingTable = newRoutingTable(database, routers, readers, writers, expirationTime, routingTableDatabase) + routingTable.user = user + return routingTable +} + function newRoutingTable ( database, routers, readers, writers, - expirationTime = Integer.MAX_VALUE + expirationTime = Integer.MAX_VALUE, + routingTableDatabase ) { - return new RoutingTable({ - database, + var routingTable = new RoutingTable({ + database: database || routingTableDatabase, routers, readers, writers, expirationTime }) + return routingTable } function setupRoutingConnectionProviderToRememberRouters ( @@ -2342,10 +2634,15 @@ class FakeRediscovery { this._routerToRoutingTable = routerToRoutingTable } - lookupRoutingTableOnRouter (ignored, database, router) { + lookupRoutingTableOnRouter (ignored, database, router, user) { const table = this._routerToRoutingTable[database || null] if (table) { - return Promise.resolve(table[router.asKey()]) + let routingTables = table[router.asKey()] + let routingTable = routingTables + if (routingTables instanceof Array) { + routingTable = routingTables.find(rt => rt.user === user) + } + return Promise.resolve(routingTable) } return Promise.resolve(null) } diff --git a/packages/bolt-connection/test/rediscovery/rediscovery.test.js b/packages/bolt-connection/test/rediscovery/rediscovery.test.js index 12b34b926..7318755f6 100644 --- a/packages/bolt-connection/test/rediscovery/rediscovery.test.js +++ b/packages/bolt-connection/test/rediscovery/rediscovery.test.js @@ -48,6 +48,7 @@ describe('#unit Rediscovery', () => { const expectedRoutingTable = new RoutingTable({ database: 'db', + ttl, expirationTime: calculateExpirationTime(Date.now(), ttl), routers: [ServerAddress.fromUrl('bolt://localhost:7687')], writers: [ServerAddress.fromUrl('bolt://localhost:7686')], diff --git a/packages/core/src/connection-provider.ts b/packages/core/src/connection-provider.ts index ee0e32e0a..8880ba672 100644 --- a/packages/core/src/connection-provider.ts +++ b/packages/core/src/connection-provider.ts @@ -20,6 +20,7 @@ import Connection from './connection' import { bookmark } from './internal' + /** * Inteface define a common way to acquire a connection * @@ -36,14 +37,18 @@ class ConnectionProvider { * synchronize on creation of databases and is never used in direct drivers. * * @param {object} param - object parameter - * @param {string} param.accessMode - the access mode for the to-be-acquired connection - * @param {string} param.database - the target database for the to-be-acquired connection - * @param {Bookmark} param.bookmarks - the bookmarks to send to routing discovery + * @property {string} param.accessMode - the access mode for the to-be-acquired connection + * @property {string} param.database - the target database for the to-be-acquired connection + * @property {Bookmark} param.bookmarks - the bookmarks to send to routing discovery + * @property {string} param.impersonatedUser - the impersonated user + * @property {function (databaseName:string?)} param.onDatabaseNameResolved - Callback called when the database name get resolved */ - acquireConnection(params?: { + acquireConnection(param?: { accessMode?: string database?: string - bookmarks: bookmark.Bookmark + bookmarks: bookmark.Bookmark, + impersonatedUser?: string, + onDatabaseNameResolved?: (databaseName?: string) => void }): Promise { throw Error('Not implemented') } @@ -68,6 +73,16 @@ class ConnectionProvider { throw Error('Not implemented') } + /** + * This method checks whether the backend database supports transaction config functionality + * by checking protocol handshake result. + * + * @returns {Promise} + */ + supportsUserImpersonation(): Promise { + throw Error('Not implemented') + } + /** * Closes this connection provider along with its internals (connections, pools, etc.) * diff --git a/packages/core/src/driver.ts b/packages/core/src/driver.ts index e68190973..3e6910ec8 100644 --- a/packages/core/src/driver.ts +++ b/packages/core/src/driver.ts @@ -78,6 +78,17 @@ type CreateConnectionProvider = ( hostNameResolver: ConfiguredCustomResolver ) => ConnectionProvider +type CreateSession = (args: { + mode: SessionMode + connectionProvider: ConnectionProvider + bookmark?: Bookmark + database: string + config: any + reactive: boolean + fetchSize: number, + impersonatedUser?: string +}) => Session + interface DriverConfig { encrypted?: EncryptionLevel | boolean trust?: TrustStrategy @@ -102,6 +113,7 @@ class Driver { private readonly _log: Logger private readonly _createConnectionProvider: CreateConnectionProvider private _connectionProvider: ConnectionProvider | null + private readonly _createSession: CreateSession /** * You should not be calling this directly, instead use {@link driver}. @@ -110,11 +122,13 @@ class Driver { * @param {Object} meta Metainformation about the driver * @param {Object} config * @param {function(id: number, config:Object, log:Logger, hostNameResolver: ConfiguredCustomResolver): ConnectionProvider } createConnectonProvider Creates the connection provider - */ + * @param {function(args): Session } createSession Creates the a session + */ constructor( meta: MetaInfo, config: DriverConfig = {}, - createConnectonProvider: CreateConnectionProvider + createConnectonProvider: CreateConnectionProvider, + createSession: CreateSession = args => new Session(args) ) { sanitizeConfig(config) validateConfig(config) @@ -124,6 +138,7 @@ class Driver { this._config = config this._log = Logger.create(config) this._createConnectionProvider = createConnectonProvider + this._createSession = createSession /** * Reference to the connection provider. Initialized lazily by {@link _getOrCreateConnectionProvider}. @@ -177,6 +192,19 @@ class Driver { return connectionProvider.supportsTransactionConfig() } + /** + * Returns whether the server supports user impersonation capabilities based on the protocol + * version negotiated via handshake. + * + * Note that this function call _always_ causes a round-trip to the server. + * + * @returns {Promise} promise resolved with a boolean or rejected with error. + */ + supportsUserImpersonation(): Promise { + const connectionProvider = this._getOrCreateConnectionProvider() + return connectionProvider.supportsUserImpersonation() + } + /** * @protected * @returns {boolean} @@ -224,17 +252,20 @@ class Driver { * @param {number} param.fetchSize - The record fetch size of each batch of this session. * Use {@link FETCH_ALL} to always pull all records in one batch. This will override the config value set on driver config. * @param {string} param.database - The database this session will operate on. + * @param {string} param.impersonatedUser - The username which the user wants to impersonate for the duration of the session. * @return {Session} new session. */ session({ defaultAccessMode = WRITE, bookmarks: bookmarkOrBookmarks, database = '', + impersonatedUser, fetchSize }: { defaultAccessMode?: SessionMode bookmarks?: string | string[] database?: string + impersonatedUser?: string fetchSize?: number } = {}): Session { return this._newSession({ @@ -242,6 +273,7 @@ class Driver { bookmarkOrBookmarks, database, reactive: false, + impersonatedUser, fetchSize: validateFetchSizeValue(fetchSize, this._config.fetchSize!!) }) } @@ -277,12 +309,14 @@ class Driver { bookmarkOrBookmarks, database, reactive, + impersonatedUser, fetchSize }: { defaultAccessMode: SessionMode bookmarkOrBookmarks?: string | string[] database: string reactive: boolean + impersonatedUser?: string fetchSize: number }) { const sessionMode = Session._validateSessionMode(defaultAccessMode) @@ -290,13 +324,14 @@ class Driver { const bookmark = bookmarkOrBookmarks ? new Bookmark(bookmarkOrBookmarks) : Bookmark.empty() - return new Session({ + return this._createSession({ mode: sessionMode, database: database || '', connectionProvider, bookmark, config: this._config, reactive, + impersonatedUser, fetchSize }) } diff --git a/packages/core/src/internal/connection-holder.ts b/packages/core/src/internal/connection-holder.ts index 4345e3aaa..466805409 100644 --- a/packages/core/src/internal/connection-holder.ts +++ b/packages/core/src/internal/connection-holder.ts @@ -82,31 +82,42 @@ class ConnectionHolder implements ConnectionHolderInterface { private _connectionProvider?: ConnectionProvider private _referenceCount: number private _connectionPromise: Promise + private _impersonatedUser?: string + private _onDatabaseNameResolved?: (databaseName?: string) => void /** * @constructor - * @param {string} mode - the access mode for new connection holder. - * @param {string} database - the target database name. - * @param {Bookmark} bookmark - the last bookmark - * @param {ConnectionProvider} connectionProvider - the connection provider to acquire connections from. + * @param {object} params + * @property {string} params.mode - the access mode for new connection holder. + * @property {string} params.database - the target database name. + * @property {Bookmark} params.bookmark - the last bookmark + * @property {ConnectionProvider} params.connectionProvider - the connection provider to acquire connections from. + * @property {string?} params.impersonatedUser - the user which will be impersonated + * @property {function(databaseName:string)} params.onDatabaseNameResolved - callback called when the database name is resolved */ constructor({ mode = ACCESS_MODE_WRITE, database = '', bookmark, - connectionProvider + connectionProvider, + impersonatedUser, + onDatabaseNameResolved }: { mode?: string database?: string bookmark?: Bookmark - connectionProvider?: ConnectionProvider + connectionProvider?: ConnectionProvider, + impersonatedUser?: string, + onDatabaseNameResolved?: (databaseName?: string) => void } = {}) { this._mode = mode this._database = database ? assertString(database, 'database') : '' this._bookmark = bookmark || Bookmark.empty() this._connectionProvider = connectionProvider + this._impersonatedUser = impersonatedUser this._referenceCount = 0 this._connectionPromise = Promise.resolve() + this._onDatabaseNameResolved = onDatabaseNameResolved } mode(): string | undefined { @@ -117,6 +128,10 @@ class ConnectionHolder implements ConnectionHolderInterface { return this._database } + setDatabase(database?: string) { + this._database = database + } + bookmark(): Bookmark { return this._bookmark } @@ -134,7 +149,9 @@ class ConnectionHolder implements ConnectionHolderInterface { this._connectionPromise = this._connectionProvider.acquireConnection({ accessMode: this._mode, database: this._database, - bookmarks: this._bookmark + bookmarks: this._bookmark, + impersonatedUser: this._impersonatedUser, + onDatabaseNameResolved: this._onDatabaseNameResolved }) } else { this._referenceCount++ @@ -288,6 +305,6 @@ class EmptyConnectionHolder extends ConnectionHolder { const EMPTY_CONNECTION_HOLDER: EmptyConnectionHolder = new EmptyConnectionHolder() // eslint-disable-next-line handle-callback-err -function ignoreError(error: Error) {} +function ignoreError(error: Error) { } export { ConnectionHolder, ReadOnlyConnectionHolder, EMPTY_CONNECTION_HOLDER } diff --git a/packages/core/src/internal/constants.ts b/packages/core/src/internal/constants.ts index 134c901af..3c6c1a589 100644 --- a/packages/core/src/internal/constants.ts +++ b/packages/core/src/internal/constants.ts @@ -31,6 +31,7 @@ const BOLT_PROTOCOL_V4_0: number = 4.0 const BOLT_PROTOCOL_V4_1: number = 4.1 const BOLT_PROTOCOL_V4_2: number = 4.2 const BOLT_PROTOCOL_V4_3: number = 4.3 +const BOLT_PROTOCOL_V4_4: number = 4.4 export { FETCH_ALL, @@ -44,5 +45,6 @@ export { BOLT_PROTOCOL_V4_0, BOLT_PROTOCOL_V4_1, BOLT_PROTOCOL_V4_2, - BOLT_PROTOCOL_V4_3 + BOLT_PROTOCOL_V4_3, + BOLT_PROTOCOL_V4_4 } diff --git a/packages/core/src/session.ts b/packages/core/src/session.ts index 7c6f5f91a..5aff5d06b 100644 --- a/packages/core/src/session.ts +++ b/packages/core/src/session.ts @@ -57,7 +57,9 @@ class Session { private _hasTx: boolean private _lastBookmark: Bookmark private _transactionExecutor: TransactionExecutor + private _impersonatedUser?: string private _onComplete: (meta: any) => void + private _databaseNameResolved: boolean /** * @constructor @@ -70,6 +72,7 @@ class Session { * @param {Object} args.config={} - This driver configuration. * @param {boolean} args.reactive - Whether this session should create reactive streams * @param {number} args.fetchSize - Defines how many records is pulled in each pulling batch + * @param {string} args.impersonatedUser - The username which the user wants to impersonate for the duration of the session. */ constructor({ mode, @@ -78,7 +81,8 @@ class Session { database, config, reactive, - fetchSize + fetchSize, + impersonatedUser }: { mode: SessionMode connectionProvider: ConnectionProvider @@ -86,29 +90,37 @@ class Session { database: string config: any reactive: boolean - fetchSize: number + fetchSize: number, + impersonatedUser?: string }) { this._mode = mode this._database = database this._reactive = reactive this._fetchSize = fetchSize + this._onDatabaseNameResolved = this._onDatabaseNameResolved.bind(this) this._readConnectionHolder = new ConnectionHolder({ mode: ACCESS_MODE_READ, database, bookmark, - connectionProvider + connectionProvider, + impersonatedUser, + onDatabaseNameResolved: this._onDatabaseNameResolved }) this._writeConnectionHolder = new ConnectionHolder({ mode: ACCESS_MODE_WRITE, database, bookmark, - connectionProvider + connectionProvider, + impersonatedUser, + onDatabaseNameResolved: this._onDatabaseNameResolved }) this._open = true this._hasTx = false + this._impersonatedUser = impersonatedUser this._lastBookmark = bookmark || Bookmark.empty() this._transactionExecutor = _createTransactionExecutor(config) this._onComplete = this._onCompleteCallback.bind(this) + this._databaseNameResolved = this._database !== '' } /** @@ -142,6 +154,7 @@ class Session { txConfig: autoCommitTxConfig, mode: this._mode, database: this._database, + impersonatedUser: this._impersonatedUser, afterComplete: this._onComplete, reactive: this._reactive, fetchSize: this._fetchSize @@ -251,6 +264,7 @@ class Session { const tx = new Transaction({ connectionHolder, + impersonatedUser: this._impersonatedUser, onClose: this._transactionClosed.bind(this), onBookmark: this._updateBookmark.bind(this), onConnection: this._assertSessionIsOpen.bind(this), @@ -343,6 +357,20 @@ class Session { ) } + /** + * @private + * @param {string|undefined} database The resolved database name + */ + _onDatabaseNameResolved(database?: string): void{ + if (!this._databaseNameResolved) { + const normalizedDatabase = database || '' + this._database = normalizedDatabase + this._readConnectionHolder.setDatabase(normalizedDatabase) + this._writeConnectionHolder.setDatabase(normalizedDatabase) + this._databaseNameResolved = true + } + } + /** * Update value of the last bookmark. * @private diff --git a/packages/core/src/transaction.ts b/packages/core/src/transaction.ts index 7cfa10fd6..f48a63a11 100644 --- a/packages/core/src/transaction.ts +++ b/packages/core/src/transaction.ts @@ -52,6 +52,7 @@ class Transaction { private _onComplete: (metadata: any) => void private _fetchSize: number private _results: any[] + private _impersonatedUser?: string /** * @constructor @@ -62,6 +63,7 @@ class Transaction { * is not yet released. * @param {boolean} reactive whether this transaction generates reactive streams * @param {number} fetchSize - the record fetch size in each pulling batch. + * @param {string} impersonatedUser - The name of the user which should be impersonated for the duration of the session. */ constructor({ connectionHolder, @@ -69,7 +71,8 @@ class Transaction { onBookmark, onConnection, reactive, - fetchSize + fetchSize, + impersonatedUser }: { connectionHolder: ConnectionHolder onClose: () => void @@ -77,6 +80,7 @@ class Transaction { onConnection: () => void reactive: boolean fetchSize: number + impersonatedUser?: string }) { this._connectionHolder = connectionHolder this._reactive = reactive @@ -88,6 +92,7 @@ class Transaction { this._onComplete = this._onCompleteCallback.bind(this) this._fetchSize = fetchSize this._results = [] + this._impersonatedUser = impersonatedUser } /** @@ -107,6 +112,7 @@ class Transaction { txConfig: txConfig, mode: this._connectionHolder.mode(), database: this._connectionHolder.database(), + impersonatedUser: this._impersonatedUser, beforeError: this._onError, afterComplete: this._onComplete }) @@ -289,7 +295,7 @@ const _states = { onComplete, onConnection, reactive, - fetchSize + fetchSize, }: StateTransitionParams ): any => { // RUN in explicit transaction can't contain bookmarks and transaction configuration @@ -305,7 +311,7 @@ const _states = { beforeError: onError, afterComplete: onComplete, reactive: reactive, - fetchSize: fetchSize + fetchSize: fetchSize, }) } else { throw newError('No connection available') diff --git a/packages/core/test/driver.test.ts b/packages/core/test/driver.test.ts index 6dde10230..ab3071dad 100644 --- a/packages/core/test/driver.test.ts +++ b/packages/core/test/driver.test.ts @@ -16,14 +16,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -import { ConnectionProvider, newError } from '../src' +import { ConnectionProvider, newError, Session } from '../src' import Driver from '../src/driver' +import { Bookmark } from '../src/internal/bookmark' import { Logger } from '../src/internal/logger' import { ConfiguredCustomResolver } from '../src/internal/resolver' describe('Driver', () => { let driver: Driver | null let connectionProvider: ConnectionProvider + let createSession: any const META_INFO = { routing: false, typename: '', @@ -34,10 +36,12 @@ describe('Driver', () => { beforeEach(() => { connectionProvider = new ConnectionProvider() connectionProvider.close = jest.fn(() => Promise.resolve()) + createSession = jest.fn(args => new Session(args)) driver = new Driver( META_INFO, CONFIG, - mockCreateConnectonProvider(connectionProvider) + mockCreateConnectonProvider(connectionProvider), + createSession ) }) @@ -48,6 +52,26 @@ describe('Driver', () => { } }) + + describe('.session()', () => { + it('should create the session with impersonated user', () => { + const impersonatedUser = 'the impostor' + + const session = driver!.session({ impersonatedUser }) + + expect(session).not.toBeUndefined() + expect(createSession).toHaveBeenCalledWith(expectedSessionParams({ impersonatedUser })) + }) + + + it('should create the session without impersonated user', () => { + const session = driver!.session() + + expect(session).not.toBeUndefined() + expect(createSession).toHaveBeenCalledWith(expectedSessionParams()) + }) + }) + it.each([ ['Promise.resolve(true)', Promise.resolve(true)], ['Promise.resolve(false)', Promise.resolve(false)], @@ -84,6 +108,25 @@ describe('Driver', () => { promise.catch(_ => 'Do nothing').finally(() => {}) }) + it.each([ + ['Promise.resolve(true)', Promise.resolve(true)], + ['Promise.resolve(false)', Promise.resolve(false)], + [ + "Promise.reject(newError('something went wrong'))", + Promise.reject(newError('something went wrong')) + ] + ])('.supportsUserImpersonation() => %s', (_, expectedPromise) => { + connectionProvider.supportsUserImpersonation = jest.fn( + () => expectedPromise + ) + + const promise: Promise = driver!.supportsUserImpersonation() + + expect(promise).toBe(expectedPromise) + + promise.catch(_ => 'Do nothing').finally(() => {}) + }) + function mockCreateConnectonProvider(connectionProvider: ConnectionProvider) { return ( id: number, @@ -92,4 +135,23 @@ describe('Driver', () => { hostNameResolver: ConfiguredCustomResolver ) => connectionProvider } + + function expectedSessionParams(extra: any = {}) { + return { + bookmark: Bookmark.empty(), + config: { + connectionAcquisitionTimeout: 60000, + fetchSize: 1000, + maxConnectionLifetime: 3600000, + maxConnectionPoolSize: 100, + }, + connectionProvider, + database: '', + fetchSize: 1000, + mode: "WRITE", + reactive: false, + impersonatedUser: undefined, + ...extra + } + } }) diff --git a/packages/neo4j-driver-lite/test/unit/index.test.ts b/packages/neo4j-driver-lite/test/unit/index.test.ts index 809095131..c6e868397 100644 --- a/packages/neo4j-driver-lite/test/unit/index.test.ts +++ b/packages/neo4j-driver-lite/test/unit/index.test.ts @@ -217,7 +217,8 @@ describe('index', () => { acquireConnection: () => Promise.reject(Error('something wrong')), close: () => Promise.resolve(), supportsMultiDb: () => Promise.resolve(true), - supportsTransactionConfig: () => Promise.resolve(true) + supportsTransactionConfig: () => Promise.resolve(true), + supportsUserImpersonation: () => Promise.resolve(true) } }) expect(session).toBeDefined() diff --git a/packages/neo4j-driver/src/driver.js b/packages/neo4j-driver/src/driver.js index c09026209..cb20bb3dd 100644 --- a/packages/neo4j-driver/src/driver.js +++ b/packages/neo4j-driver/src/driver.js @@ -54,19 +54,22 @@ class Driver extends CoreDriver { * @param {string|string[]} param.bookmarks - The initial reference or references to some previous transactions. Value is optional and * absence indicates that the bookmarks do not exist or are unknown. * @param {string} param.database - The database this session will operate on. + * @param {string} param.impersonatedUser - The name of the user which should be impersonated for the duration of the session. * @returns {RxSession} new reactive session. */ rxSession ({ defaultAccessMode = WRITE, bookmarks, database = '', - fetchSize + fetchSize, + impersonatedUser } = {}) { return new RxSession({ session: this._newSession({ defaultAccessMode, bookmarks, database, + impersonatedUser, reactive: true, fetchSize: validateFetchSizeValue(fetchSize, this._config.fetchSize) }), diff --git a/packages/neo4j-driver/test/internal/node/direct.driver.boltkit.test.js b/packages/neo4j-driver/test/internal/node/direct.driver.boltkit.test.js index a9d877e06..12ac605d2 100644 --- a/packages/neo4j-driver/test/internal/node/direct.driver.boltkit.test.js +++ b/packages/neo4j-driver/test/internal/node/direct.driver.boltkit.test.js @@ -158,6 +158,44 @@ describe('#stub-direct direct driver with stub server', () => { }, 60000) }) + describe('should report whether user impersonation is supported', () => { + async function verifySupportsUserImpersonation (version, expected) { + if (!boltStub.supported) { + return + } + + const server = await boltStub.start( + `./test/resources/boltstub/${version}/supports_protocol_version.script`, + 9001 + ) + + const driver = boltStub.newDriver('bolt://127.0.0.1:9001') + + await expectAsync(driver.supportsUserImpersonation()).toBeResolvedTo( + expected + ) + + await driver.close() + await server.exit() + } + + it('v3', () => verifySupportsUserImpersonation('v3', false), 60000) + it('v4', () => verifySupportsUserImpersonation('v4', false), 60000) + it('v4.2', () => verifySupportsUserImpersonation('v4.2', false), 60000) + it('v4.4', () => verifySupportsUserImpersonation('v4.4', true), 60000) + it('on error', async () => { + const driver = boltStub.newDriver('bolt://127.0.0.1:9001') + + await expectAsync(driver.supportsUserImpersonation()).toBeRejectedWith( + jasmine.objectContaining({ + code: SERVICE_UNAVAILABLE + }) + ) + + await driver.close() + }, 60000) + }) + describe('should cancel stream with result summary method', () => { async function verifyFailureOnCommit (version) { if (!boltStub.supported) { diff --git a/packages/neo4j-driver/test/internal/routing-table.test.js b/packages/neo4j-driver/test/internal/routing-table.test.js index f30738325..d88ba12cf 100644 --- a/packages/neo4j-driver/test/internal/routing-table.test.js +++ b/packages/neo4j-driver/test/internal/routing-table.test.js @@ -278,6 +278,38 @@ describe('#unit RoutingTable', () => { expect(result.expirationTime).toEqual(Integer.MAX_VALUE) }) + ;[ + [undefined, undefined, null], + [undefined, null, null], + [undefined, 'homedb2', 'homedb2'], + [null, undefined, null], + [null, null, null], + [null, 'homedb2', 'homedb2'], + ['homedb', undefined, 'homedb'], + ['homedb', null, 'homedb'], + ['homedb', 'homedb2', 'homedb'] + ].forEach(([database, databaseInMetadata, expected]) => { + it(`should return resolve correctly the database [database=${database}, databaseInMetadata=${databaseInMetadata}]`, () => { + const routers = ['router:7699'] + const readers = ['reader1:7699', 'reader2:7699'] + const writers = ['writer1:7693', 'writer2:7692', 'writer3:7629'] + + const result = RoutingTable.fromRawRoutingTable( + database, + ServerAddress.fromUrl('localhost:7687'), + RawRoutingTable.ofMessageResponse( + newMetadata({ + routers, + readers, + writers, + database: databaseInMetadata + }) + ) + ) + + expect(result.database).toEqual(expected) + }) + }) it('should return Integer.MAX_VALUE as expirationTime when ttl is negative', async () => { const ttl = int(-2) @@ -479,7 +511,8 @@ describe('#unit RoutingTable', () => { routers = [], readers = [], writers = [], - extra = [] + extra = [], + database = undefined } = {}) { const routersField = { role: 'ROUTE', @@ -496,7 +529,8 @@ describe('#unit RoutingTable', () => { return { rt: { ttl, - servers: [routersField, readersField, writersField, ...extra] + servers: [routersField, readersField, writersField, ...extra], + db: database } } } diff --git a/packages/neo4j-driver/test/resources/boltstub/v4.4/supports_protocol_version.script b/packages/neo4j-driver/test/resources/boltstub/v4.4/supports_protocol_version.script new file mode 100644 index 000000000..fb807b2ad --- /dev/null +++ b/packages/neo4j-driver/test/resources/boltstub/v4.4/supports_protocol_version.script @@ -0,0 +1,2 @@ +!: AUTO GOODBYE +!: BOLT 4.4 diff --git a/packages/testkit-backend/src/request-handlers.js b/packages/testkit-backend/src/request-handlers.js index 1b6180ac9..0754cc24f 100644 --- a/packages/testkit-backend/src/request-handlers.js +++ b/packages/testkit-backend/src/request-handlers.js @@ -63,7 +63,7 @@ export function DriverClose (context, data, wire) { } export function NewSession (context, data, wire) { - let { driverId, accessMode, bookmarks, database, fetchSize } = data + let { driverId, accessMode, bookmarks, database, fetchSize, impersonatedUser } = data switch (accessMode) { case 'r': accessMode = neo4j.session.READ @@ -80,7 +80,8 @@ export function NewSession (context, data, wire) { defaultAccessMode: accessMode, bookmarks, database, - fetchSize + fetchSize, + impersonatedUser }) const id = context.addSession(session) wire.writeResponse('Session', { id }) @@ -265,7 +266,9 @@ export function GetFeatures (_context, _params, wire) { 'Feature:Auth:Kerberos', 'Feature:Auth:Bearer', 'AuthorizationExpiredTreatment', - 'ConfHint:connection.recv_timeout_seconds' + 'ConfHint:connection.recv_timeout_seconds', + 'Feature:Bolt:4.4', + 'Feature:Impersonation' ] }) } @@ -304,7 +307,7 @@ export function GetRoutingTable (context, { driverId, database }, wire) { driver && driver._getOrCreateConnectionProvider() && driver._getOrCreateConnectionProvider()._routingTableRegistry && - driver._getOrCreateConnectionProvider()._routingTableRegistry.get(database) + driver._getOrCreateConnectionProvider()._routingTableRegistry.get(null, database) if (routingTable) { wire.writeResponse('RoutingTable', {