Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1601,9 +1601,15 @@ boolean onDone(TDSReader tdsReader) throws SQLServerException {
if (null != procedureName)
return false;

//For Insert, we must fetch additional TDS_DONE token that comes with the actual update count
if (doneToken.cmdIsInsert() && (-1 != doneToken.getUpdateCount()) && EXECUTE == executeMethod) {
return true;
}

// Always return all update counts from statements executed through Statement.execute()
if (EXECUTE == executeMethod)
return false;
if (EXECUTE == executeMethod) {
return false;
}

// Statement.executeUpdate() may or may not return this update count depending on the
// setting of the lastUpdateCount connection property:
Expand Down
6 changes: 5 additions & 1 deletion src/main/java/com/microsoft/sqlserver/jdbc/StreamDone.java
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,12 @@ final long getUpdateCount() {
}
}

final boolean cmdIsInsert() {
return (CMD_INSERT == curCmd);
}

final boolean cmdIsDMLOrDDL() {
switch (curCmd) {
switch (curCmd) {
case CMD_INSERT:
case CMD_BULKINSERT:
case CMD_DELETE:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import java.sql.Statement;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.Executors;
Expand Down Expand Up @@ -2692,4 +2694,323 @@ public void terminate() throws Exception {
}
}
}


@Nested
public class TCGenKeys {
private final String tableName = AbstractSQLGenerator
.escapeIdentifier(RandomUtil.getIdentifier("TCInsertWithGenKeys"));
private final String idTableName = AbstractSQLGenerator
.escapeIdentifier(RandomUtil.getIdentifier("TCInsertWithGenKeysIDs"));

private final String triggerName = AbstractSQLGenerator.escapeIdentifier("Trigger");
private final int NUM_ROWS = 3;

@BeforeEach
public void setup() throws Exception {
try (Connection con = getConnection()) {
con.setAutoCommit(false);
try (Statement stmt = con.createStatement()) {
TestUtils.dropTriggerIfExists(triggerName, stmt);
stmt.executeUpdate("CREATE TABLE " + tableName + " (ID int NOT NULL IDENTITY(1,1) PRIMARY KEY, NAME varchar(32));");

stmt.executeUpdate("CREATE TABLE " + idTableName + "(ID int NOT NULL IDENTITY(1,1) PRIMARY KEY);");

stmt.executeUpdate("CREATE TRIGGER " + triggerName + " ON " + tableName + " FOR INSERT AS INSERT INTO " + idTableName + " DEFAULT VALUES;");

for (int i = 0; i < NUM_ROWS; i++) {
stmt.executeUpdate("INSERT INTO " + tableName + " (NAME) VALUES ('test')");
}

}
con.commit();
}
}

/**
* Tests executeUpdate for Insert followed by getGenerateKeys
*
* @throws Exception
*/
@Test
public void testExecuteUpdateInsertAndGenKeys() throws Exception {
try (Connection con = getConnection()) {
try(Statement stmt = con.createStatement()) {
String sql = "INSERT INTO " + tableName + " (NAME) VALUES('test')";
List<String> lst = Arrays.asList("ID");
String[] arr = lst.toArray(new String[0]);
stmt.executeUpdate(sql, arr);
try (ResultSet generatedKeys = stmt.getGeneratedKeys()) {
if (generatedKeys.next()) {
int id = generatedKeys.getInt(1);
assertEquals(id, 4, "id should have been 4, but received : " + id);
}
}
}
} catch (SQLException e) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
}

/**
* Tests execute for Insert followed by getGenerateKeys
*
* @throws Exception
*/
@Test
public void testExecuteInsertAndGenKeys() throws Exception {
try (Connection con = getConnection()) {
try(Statement stmt = con.createStatement()) {
String sql = "INSERT INTO " + tableName + " (NAME) VALUES('test')";
List<String> lst = Arrays.asList("ID");
String[] arr = lst.toArray(new String[0]);
stmt.execute(sql, arr);
try (ResultSet generatedKeys = stmt.getGeneratedKeys()) {
if (generatedKeys.next()) {
int id = generatedKeys.getInt(1);
assertEquals(id, 4, "generated key should have been 4");
}
}
}
} catch (SQLException e) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
}

/**
* Tests execute for Insert followed by select
*
* @throws Exception
*/
@Test
public void testExecuteInsertAndSelect() throws Exception {

try (Connection con = getConnection()) {
try(Statement stmt = con.createStatement()) {
boolean retval = stmt.execute("INSERT INTO " + tableName +" (NAME) VALUES('test') SELECT NAME FROM " + tableName + " WHERE ID = 1");
do {
if (retval == false) {
int count = stmt.getUpdateCount();
if (count == -1) {
// no more results
break;
} else {
assertEquals(count, 1, "update count should have been 1");
}
} else {
// process ResultSet
try (ResultSet rs = stmt.getResultSet()) {
if (rs.next()) {
String val = rs.getString(1);
assertEquals(val, "test", "read value should have been 'test'");
}
}
}
retval = stmt.getMoreResults();
} while (true);
}
} catch (SQLException e) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
}


/**
* Tests execute for Merge followed by select
*
* @throws Exception
*/
@Test
public void testExecuteMergeAndSelect() throws Exception {
try (Connection con = getConnection()) {
try(Statement stmt = con.createStatement()) {
boolean retval = stmt.execute("MERGE INTO " + tableName + " AS target USING (VALUES ('test1')) AS source (name) ON target.name = source.name WHEN NOT MATCHED THEN INSERT (name) VALUES ('test1'); SELECT NAME FROM " + tableName + " WHERE ID = 1");
do {
if (retval == false) {
int count = stmt.getUpdateCount();
if (count == -1) {
// no more results
break;
} else {
assertEquals(count, 1, "update count should have been 1");
}
} else {
// process ResultSet
try (ResultSet rs = stmt.getResultSet()) {
if (rs.next()) {
String val = rs.getString(1);
assertEquals(val, "test", "read value should have been 'test'");
}
}

}
retval = stmt.getMoreResults();
} while (true);
}
} catch (SQLException e) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
}

/**
* Tests execute for Insert multiple rows followed by select
*
* @throws Exception
*/
@Test
public void testExecuteInsertManyRowsAndSelect() throws Exception {
try (Connection con = getConnection()) {
try (Statement stmt = con.createStatement()) {
boolean retval = stmt.execute("INSERT INTO " + tableName + " SELECT NAME FROM " + tableName + " SELECT NAME FROM " + tableName + " WHERE ID = 1");
do {
if (retval == false) {
int count = stmt.getUpdateCount();
if (count == -1) {
// no more results
break;
} else {
assertEquals(count, 3, "update count should have been 6");
}
} else {
// process ResultSet
try (ResultSet rs = stmt.getResultSet()) {
if (rs.next()) {
String val = rs.getString(1);
assertEquals(val, "test", "read value should have been 'test'");
}
}

}
retval = stmt.getMoreResults();
} while (true);
}
} catch (SQLException e) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
}

/**
* Tests execute two Inserts followed by select
*
* @throws Exception
*/
@Test
public void testExecuteTwoInsertsRowsAndSelect() throws Exception {
try (Connection con = getConnection()) {
try(Statement stmt = con.createStatement()) {
boolean retval = stmt.execute("INSERT INTO " + tableName + " (NAME) VALUES('test') INSERT INTO " + tableName + " (NAME) VALUES('test') SELECT NAME from " + tableName + " WHERE ID = 1");
do {
if (retval == false) {
int count = stmt.getUpdateCount();
if (count == -1) {
// no more results
break;
} else {
assertEquals(count, 1, "update count should have been 2");
}
} else {
// process ResultSet
try (ResultSet rs = stmt.getResultSet()) {
if (rs.next()) {
String val = rs.getString(1);
assertEquals(val, "test", "read value should have been 'test'");
}
}

}
retval = stmt.getMoreResults();
} while (true);
}
} catch (SQLException e) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
}


/**
* Tests execute for Update followed by select
*
* @throws Exception
*/
@Test
public void testExecuteUpdAndSelect() throws Exception {
try (Connection con = getConnection()) {
try(Statement stmt = con.createStatement()) {
boolean retval = stmt.execute("UPDATE " + tableName +" SET NAME = 'test' SELECT NAME FROM " + tableName + " WHERE ID = 1");
do {
if (retval == false) {
int count = stmt.getUpdateCount();
if (count == -1) {
// no more results
break;
} else {
assertEquals(count, 3, "update count should have been 3");
}
} else {
// process ResultSet
try (ResultSet rs = stmt.getResultSet()) {
if (rs.next()) {
String val = rs.getString(1);
assertEquals(val, "test", "read value should have been 'test'");
}
}
}
retval = stmt.getMoreResults();
} while (true);
}
} catch (SQLException e) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
}

/**
* Tests execute for Update followed by select
*
* @throws Exception
*/
@Test
public void testExecuteDelAndSelect() throws Exception {
try (Connection con = getConnection()) {
try(Statement stmt = con.createStatement()) {
boolean retval = stmt.execute("DELETE FROM " + tableName +" WHERE ID = 1 SELECT NAME FROM " + tableName + " WHERE ID = 2");
do {
if (retval == false) {
int count = stmt.getUpdateCount();
if (count == -1) {
// no more results
break;
} else {
assertEquals(count, 1, "update count should have been 1");
}
} else {
// process ResultSet
try (ResultSet rs = stmt.getResultSet()) {
if (rs.next()) {
String val = rs.getString(1);
assertEquals(val, "test", "read value should have been 'test'");
}
}
}
retval = stmt.getMoreResults();
} while (true);
}
} catch (SQLException e) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
}

@AfterEach
public void terminate() throws Exception {
try (Connection con = getConnection(); Statement stmt = con.createStatement()) {
try {
TestUtils.dropTriggerIfExists(triggerName, stmt);
TestUtils.dropTableIfExists(idTableName, stmt);
TestUtils.dropTableIfExists(tableName, stmt);
} catch (SQLException e) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
}
}
}

}
Loading