Skip to content

Commit e50f7e3

Browse files
authored
fix: Make mirror node security protocol rely on port number (#1507)
* feat: add GetMirrorBaseUrl func Signed-off-by: Ivan Ivanov <[email protected]> * fix: use port 5551 for localhost Signed-off-by: Ivan Ivanov <[email protected]> * test: unit Signed-off-by: Ivan Ivanov <[email protected]> * test: unit Signed-off-by: Ivan Ivanov <[email protected]> * chore: refactor Signed-off-by: Ivan Ivanov <[email protected]> --------- Signed-off-by: Ivan Ivanov <[email protected]>
1 parent 87f7a3c commit e50f7e3

11 files changed

+760
-44
lines changed

sdk/account_id.go

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -342,29 +342,18 @@ func (id *AccountID) _MirrorNodeRequest(client *Client, populateType string) (ma
342342
return nil, errors.New("mirror node is not set")
343343
}
344344

345-
mirrorUrl := client.GetMirrorNetwork()[0]
346-
index := strings.Index(mirrorUrl, ":")
347-
if index == -1 {
348-
return nil, errors.New("invalid mirrorUrl format")
349-
}
350-
mirrorUrl = mirrorUrl[:index]
351-
352-
var url string
353-
protocol := "https"
354-
port := ""
355-
356-
if client.GetLedgerID() == nil {
357-
protocol = "http"
358-
port = ":5551"
345+
mirrorUrl, err := client.GetMirrorRestApiBaseUrl()
346+
if err != nil {
347+
return nil, err
359348
}
360349

361350
if populateType == "account" {
362-
url = fmt.Sprintf("%s://%s%s/api/v1/accounts/%s", protocol, mirrorUrl, port, hex.EncodeToString(*id.AliasEvmAddress))
351+
mirrorUrl = fmt.Sprintf("%s/accounts/%s", mirrorUrl, hex.EncodeToString(*id.AliasEvmAddress))
363352
} else {
364-
url = fmt.Sprintf("%s://%s%s/api/v1/accounts/%s", protocol, mirrorUrl, port, id.String())
353+
mirrorUrl = fmt.Sprintf("%s/accounts/%s", mirrorUrl, id.String())
365354
}
366355

367-
resp, err := http.Get(url) // #nosec
356+
resp, err := http.Get(mirrorUrl) // #nosec
368357
if err != nil {
369358
return nil, err
370359
}

sdk/account_id_unit_test.go

Lines changed: 126 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@ package hiero
66
// SPDX-License-Identifier: Apache-2.0
77

88
import (
9+
"encoding/hex"
10+
"encoding/json"
11+
"net/http"
12+
"net/http/httptest"
913
"strings"
1014
"testing"
1115

1216
"github.com/stretchr/testify/assert"
13-
14-
"encoding/hex"
15-
1617
"github.com/stretchr/testify/require"
1718
)
1819

@@ -276,3 +277,125 @@ func TestUnitAccountIDToEvmAddress(t *testing.T) {
276277
id = AccountID{Shard: 1, Realm: 1, AliasEvmAddress: &bytes}
277278
require.Equal(t, expected, id.ToEvmAddress())
278279
}
280+
281+
func TestUnitAccountIDPopulateWithDifferentPorts(t *testing.T) {
282+
// Note: Not running in parallel since we modify global http.DefaultTransport
283+
284+
tests := []struct {
285+
name string
286+
domain string
287+
expectedScheme string
288+
description string
289+
}{
290+
{
291+
name: "port 80 uses HTTP",
292+
domain: "mirror80.example.com:80",
293+
expectedScheme: "http",
294+
description: "Port 80 should use HTTP scheme",
295+
},
296+
{
297+
name: "port 443 uses HTTPS",
298+
domain: "mirror443.example.com:443",
299+
expectedScheme: "https",
300+
description: "Port 443 should use HTTPS scheme",
301+
},
302+
{
303+
name: "port 8443 uses HTTPS",
304+
domain: "mirror8443.example.com:8443",
305+
expectedScheme: "https",
306+
description: "Other ports should use HTTPS scheme for security",
307+
},
308+
{
309+
name: "port 9999 uses HTTPS",
310+
domain: "mirror9999.example.com:9999",
311+
expectedScheme: "https",
312+
description: "Any non-standard port should use HTTPS scheme",
313+
},
314+
}
315+
316+
for _, test := range tests {
317+
t.Run(test.name, func(t *testing.T) {
318+
t.Run("PopulateAccount", func(t *testing.T) {
319+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
320+
assert.Contains(t, r.URL.Path, "accounts")
321+
322+
response := map[string]interface{}{
323+
"account": "0.0.12345",
324+
}
325+
w.Header().Set("Content-Type", "application/json")
326+
err := json.NewEncoder(w).Encode(response)
327+
require.NoError(t, err)
328+
}))
329+
defer server.Close()
330+
331+
// Setup mock transport
332+
cleanup := SetupMockTransportForDomain(test.domain, server.URL)
333+
defer cleanup()
334+
335+
// Setup client with the test domain as the mirror network
336+
client, err := _NewMockClient()
337+
require.NoError(t, err)
338+
client.SetLedgerID(*NewLedgerIDTestnet())
339+
client.SetMirrorNetwork([]string{test.domain})
340+
341+
// Create an account ID with EVM address
342+
evmAddressBytes, err := hex.DecodeString(evmAddress)
343+
require.NoError(t, err)
344+
accountID := AccountID{
345+
Shard: 0,
346+
Realm: 0,
347+
Account: 0,
348+
AliasEvmAddress: &evmAddressBytes,
349+
}
350+
351+
// Test PopulateAccount
352+
err = accountID.PopulateAccount(client)
353+
require.NoError(t, err, "PopulateAccount should succeed for %s", test.description)
354+
assert.Equal(t, uint64(12345), accountID.Account)
355+
})
356+
357+
// Test PopulateEvmAddress
358+
t.Run("PopulateEvmAddress", func(t *testing.T) {
359+
// Create a mock server that responds with EVM address data
360+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
361+
// Verify the request path
362+
assert.Contains(t, r.URL.Path, "accounts/0.0.789")
363+
364+
response := map[string]interface{}{
365+
"evm_address": "0x" + evmAddress,
366+
}
367+
w.Header().Set("Content-Type", "application/json")
368+
err := json.NewEncoder(w).Encode(response)
369+
require.NoError(t, err)
370+
}))
371+
defer server.Close()
372+
373+
// Setup mock transport
374+
cleanup := SetupMockTransportForDomain(test.domain, server.URL)
375+
defer cleanup()
376+
377+
// Setup client with the test domain as the mirror network
378+
client, err := _NewMockClient()
379+
require.NoError(t, err)
380+
client.SetLedgerID(*NewLedgerIDTestnet())
381+
client.SetMirrorNetwork([]string{test.domain})
382+
383+
// Create an account ID with account number
384+
accountID := AccountID{
385+
Shard: 0,
386+
Realm: 0,
387+
Account: 789,
388+
}
389+
390+
// Test PopulateEvmAddress
391+
err = accountID.PopulateEvmAddress(client)
392+
require.NoError(t, err, "PopulateEvmAddress should succeed for %s", test.description)
393+
require.NotNil(t, accountID.AliasEvmAddress)
394+
395+
expectedBytes, err := hex.DecodeString(evmAddress)
396+
require.NoError(t, err)
397+
assert.Equal(t, expectedBytes, *accountID.AliasEvmAddress)
398+
})
399+
})
400+
}
401+
}

sdk/client.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,3 +788,11 @@ func (client *Client) SetLogLevel(level LogLevel) *Client {
788788
client.logger.SetLevel(level)
789789
return client
790790
}
791+
792+
func (client *Client) GetMirrorRestApiBaseUrl() (string, error) {
793+
mirrorNode, err := client.mirrorNetwork._GetNextMirrorNode()
794+
if err != nil {
795+
return "", err
796+
}
797+
return mirrorNode.getBaseRestUrl()
798+
}

sdk/client_unit_test.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package hiero
77

88
import (
99
"bytes"
10+
"net/url"
1011
"testing"
1112
"time"
1213

@@ -370,3 +371,94 @@ func TestUnitClientForNetworkV2(t *testing.T) {
370371
require.Error(t, err)
371372
assert.Equal(t, err.Error(), "network is empty")
372373
}
374+
375+
func TestUnitClientGetMirrorRestApiBaseUrl(t *testing.T) {
376+
t.Parallel()
377+
378+
tests := []struct {
379+
name string
380+
domain string
381+
expectedScheme string
382+
}{
383+
{
384+
name: "HTTP for port 80",
385+
domain: "mirror.example.com:80",
386+
expectedScheme: "http",
387+
},
388+
{
389+
name: "HTTPS for port 443",
390+
domain: "mirror.example.com:443",
391+
expectedScheme: "https",
392+
},
393+
{
394+
name: "HTTPS for custom port",
395+
domain: "mirror.example.com:8080",
396+
expectedScheme: "https",
397+
},
398+
}
399+
400+
for _, test := range tests {
401+
t.Run(test.name, func(t *testing.T) {
402+
client, err := _NewMockClient()
403+
require.NoError(t, err)
404+
client.SetLedgerID(*NewLedgerIDTestnet())
405+
client.SetMirrorNetwork([]string{test.domain})
406+
407+
baseURL, err := client.GetMirrorRestApiBaseUrl()
408+
require.NoError(t, err)
409+
410+
parsedURL, err := url.Parse(baseURL)
411+
require.NoError(t, err)
412+
assert.Equal(t, test.expectedScheme, parsedURL.Scheme)
413+
414+
assert.Equal(t, test.domain, parsedURL.Host)
415+
assert.Equal(t, "/api/v1", parsedURL.Path)
416+
})
417+
}
418+
}
419+
420+
func TestUnitClientGetMirrorRestApiBaseUrlLocalHost(t *testing.T) {
421+
t.Parallel()
422+
423+
tests := []struct {
424+
name string
425+
domain string
426+
expectedScheme string
427+
expectedPort string
428+
}{
429+
{
430+
name: "HTTP local host",
431+
domain: "localhost:80",
432+
expectedScheme: "http",
433+
},
434+
{
435+
name: "HTTP 127.0.0.1",
436+
domain: "127.0.0.1:8080",
437+
expectedScheme: "http",
438+
},
439+
}
440+
441+
for _, test := range tests {
442+
t.Run(test.name, func(t *testing.T) {
443+
client, err := _NewMockClient()
444+
require.NoError(t, err)
445+
client.SetLedgerID(*NewLedgerIDTestnet())
446+
client.SetMirrorNetwork([]string{test.domain})
447+
448+
baseURL, err := client.GetMirrorRestApiBaseUrl()
449+
require.NoError(t, err)
450+
451+
parsedURL, err := url.Parse(baseURL)
452+
require.NoError(t, err)
453+
assert.Equal(t, test.expectedScheme, parsedURL.Scheme)
454+
455+
if test.domain == "localhost:80" {
456+
assert.Equal(t, "localhost:5551", parsedURL.Host)
457+
} else {
458+
assert.Equal(t, "127.0.0.1:5551", parsedURL.Host)
459+
}
460+
461+
assert.Equal(t, "/api/v1", parsedURL.Path)
462+
})
463+
}
464+
}

sdk/contract_id.go

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -171,16 +171,11 @@ func (id *ContractID) PopulateContract(client *Client) error {
171171
if client.mirrorNetwork == nil || len(client.GetMirrorNetwork()) == 0 {
172172
return errors.New("mirror node is not set")
173173
}
174-
mirrorUrl := client.GetMirrorNetwork()[0]
175-
index := strings.Index(mirrorUrl, ":")
176-
if index == -1 {
177-
return errors.New("invalid mirrorUrl format")
178-
}
179-
mirrorUrl = mirrorUrl[:index]
180-
url := fmt.Sprintf("https://%s/api/v1/contracts/%s", mirrorUrl, hex.EncodeToString(id.EvmAddress))
181-
if client.GetLedgerID() == nil {
182-
url = fmt.Sprintf("http://%s:5551/api/v1/contracts/%s", mirrorUrl, hex.EncodeToString(id.EvmAddress))
174+
mirrorUrl, err := client.GetMirrorRestApiBaseUrl()
175+
if err != nil {
176+
return err
183177
}
178+
url := fmt.Sprintf("%s/contracts/%s", mirrorUrl, hex.EncodeToString(id.EvmAddress))
184179

185180
resp, err := http.Get(url) // #nosec
186181
if err != nil {

0 commit comments

Comments
 (0)