diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc index 01780f0efe2..ea3c289473b 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc @@ -214,8 +214,45 @@ SQLRETURN SQLFreeHandle(SQLSMALLINT type, SQLHANDLE handle) { SQLRETURN SQLFreeStmt(SQLHSTMT handle, SQLUSMALLINT option) { ARROW_LOG(DEBUG) << "SQLFreeStmt called with handle: " << handle << ", option: " << option; - // GH-47706 TODO: Implement SQLFreeStmt - return SQL_INVALID_HANDLE; + + switch (option) { + case SQL_CLOSE: { + using ODBC::ODBCStatement; + + return ODBCStatement::ExecuteWithDiagnostics(handle, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(handle); + + // Close cursor with suppressErrors set to true + statement->CloseCursor(true); + + return SQL_SUCCESS; + }); + } + + case SQL_DROP: { + return SQLFreeHandle(SQL_HANDLE_STMT, handle); + } + + case SQL_UNBIND: { + // GH-47716 TODO: Add tests for SQLBindCol unbinding + using ODBC::ODBCDescriptor; + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(handle, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(handle); + ODBCDescriptor* ard = statement->GetARD(); + // Unbind columns + ard->SetHeaderField(SQL_DESC_COUNT, (void*)0, 0); + return SQL_SUCCESS; + }); + } + + // SQLBindParameter is not supported + case SQL_RESET_PARAMS: { + return SQL_SUCCESS; + } + } + + return SQL_ERROR; } inline bool IsValidStringFieldArgs(SQLPOINTER diag_info_ptr, SQLSMALLINT buffer_length, diff --git a/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc index c5646b42bef..dc687ffa724 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc @@ -220,4 +220,46 @@ TEST(SQLSetEnvAttr, TestSQLSetEnvAttrNullValuePointer) { ASSERT_EQ(SQL_SUCCESS, SQLFreeEnv(env)); } +TYPED_TEST(ConnectionTest, TestSQLAllocFreeStmt) { + SQLHSTMT statement; + + // Allocate a statement using alloc statement + ASSERT_EQ(SQL_SUCCESS, SQLAllocStmt(this->conn, &statement)); + + SQLWCHAR sql_buffer[kOdbcBufferSize] = L"SELECT 1"; + ASSERT_EQ(SQL_SUCCESS, SQLExecDirect(statement, sql_buffer, SQL_NTS)); + + // Close statement handle + ASSERT_EQ(SQL_SUCCESS, SQLFreeStmt(statement, SQL_CLOSE)); + + // Free statement handle + ASSERT_EQ(SQL_SUCCESS, SQLFreeStmt(statement, SQL_DROP)); +} + +TYPED_TEST(ConnectionHandleTest, TestCloseConnectionWithOpenStatement) { + SQLHSTMT statement; + + // Connect string + std::string connect_str = this->GetConnectionString(); + ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, + arrow::util::UTF8ToWideString(connect_str)); + std::vector connect_str0(wconnect_str.begin(), wconnect_str.end()); + + SQLWCHAR out_str[kOdbcBufferSize] = L""; + SQLSMALLINT out_str_len; + + // Connecting to ODBC server. + ASSERT_EQ(SQL_SUCCESS, + SQLDriverConnect(this->conn, NULL, &connect_str0[0], + static_cast(connect_str0.size()), out_str, + kOdbcBufferSize, &out_str_len, SQL_DRIVER_NOPROMPT)) + << GetOdbcErrorMessage(SQL_HANDLE_DBC, this->conn); + + // Allocate a statement using alloc statement + ASSERT_EQ(SQL_SUCCESS, SQLAllocStmt(this->conn, &statement)); + + // Disconnect from ODBC without closing the statement first + ASSERT_EQ(SQL_SUCCESS, SQLDisconnect(this->conn)); +} + } // namespace arrow::flight::sql::odbc