diff --git a/speedtest.go b/speedtest.go index 021336c..ebb62ef 100644 --- a/speedtest.go +++ b/speedtest.go @@ -3,12 +3,14 @@ package main import ( "context" "fmt" - "gopkg.in/alecthomas/kingpin.v2" "os" + "slices" "strconv" "strings" "time" + "gopkg.in/alecthomas/kingpin.v2" + "github.com/showwin/speedtest-go/speedtest" ) @@ -31,6 +33,7 @@ var ( noUpload = kingpin.Flag("no-upload", "Disable upload test.").Bool() pingMode = kingpin.Flag("ping-mode", "Select a method for Ping. (support icmp/tcp/http)").Default("http").String() debug = kingpin.Flag("debug", "Enable debug mode.").Short('d').Bool() + countryCode = kingpin.Flag("filter-cc", "Filter servers by Country Code(s).").Strings() ) func main() { @@ -96,7 +99,16 @@ func main() { } else { servers, err = speedtestClient.FetchServers() task.CheckError(err) - task.Printf("Found %d Public Servers", len(servers)) + // cc filter auto attach + if slices.Contains(*countryCode, "auto") { + *countryCode = append(*countryCode, speedtestClient.User.Country) + } + if len(*countryCode) > 0 { + servers = servers.CC(*countryCode) + task.Printf("Found %d Public Servers with Country Code[%v]", len(servers), strings.Join(*countryCode, ",")) + } else { + task.Printf("Found %d Public Servers", len(servers)) + } if *showList { task.Complete() task.manager.Reset() diff --git a/speedtest/server.go b/speedtest/server.go index 0a592a4..7e76581 100644 --- a/speedtest/server.go +++ b/speedtest/server.go @@ -9,6 +9,7 @@ import ( "math" "net/http" "net/url" + "slices" "sort" "strconv" "strings" @@ -51,6 +52,7 @@ type Server struct { DLSpeed float64 `json:"dl_speed"` ULSpeed float64 `json:"ul_speed"` TestDuration TestDuration `json:"test_duration"` + CC string `json:"cc"` Context *Speedtest `json:"-"` } @@ -133,6 +135,28 @@ func (servers Servers) Swap(i, j int) { servers[i], servers[j] = servers[j], servers[i] } +// Filter filter by filterFunc +func (servers Servers) Filter(filterFunc func(server *Server) bool) Servers { + var retServers Servers + for i := range servers { + if filterFunc(servers[i]) { + retServers = append(retServers, servers[i]) + } + } + return retServers +} + +// CC filter by Country Code +func (servers Servers) CC(cc []string) Servers { + var upperCC []string + for i := range cc { + upperCC = append(upperCC, strings.ToUpper(cc[i])) + } + return servers.Filter(func(server *Server) bool { + return slices.Contains(upperCC, server.CC) + }) +} + // Less compares the distance. For sorting servers. func (b ByDistance) Less(i, j int) bool { return b.Servers[i].Distance < b.Servers[j].Distance @@ -212,6 +236,7 @@ func (s *Speedtest) FetchServerListContext(ctx context.Context) (Servers, error) query.Set("lat", strconv.FormatFloat(s.config.Location.Lat, 'f', -1, 64)) query.Set("lon", strconv.FormatFloat(s.config.Location.Lon, 'f', -1, 64)) } + u.RawQuery = query.Encode() dbg.Printf("Retrieving servers: %s\n", u.String()) req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) diff --git a/speedtest/server_test.go b/speedtest/server_test.go index 0274eff..0f12ea7 100644 --- a/speedtest/server_test.go +++ b/speedtest/server_test.go @@ -204,3 +204,18 @@ func TestTotalDurationCount(t *testing.T) { t.Error("addition in testDurationTotalCount didn't work") } } + +func TestCityFlag(t *testing.T) { + testCC := "YISHUN" + testData := Servers{ + {CC: "YISHUN"}, + {CC: "TOKYO"}, + {CC: "YISHUN"}, + {CC: "TEST"}, + } + + tmpServers := testData.CC([]string{testCC}) + if tmpServers.Len() != 2 && tmpServers[0].CC != testCC { + t.Fatalf("not match: %s", testCC) + } +} diff --git a/speedtest/user.go b/speedtest/user.go index 1bb0ae0..03a2915 100644 --- a/speedtest/user.go +++ b/speedtest/user.go @@ -12,10 +12,11 @@ const speedTestConfigUrl = "https://www.speedtest.net/speedtest-config.php" // User represents information determined about the caller by speedtest.net type User struct { - IP string `xml:"ip,attr"` - Lat string `xml:"lat,attr"` - Lon string `xml:"lon,attr"` - Isp string `xml:"isp,attr"` + IP string `xml:"ip,attr"` + Lat string `xml:"lat,attr"` + Lon string `xml:"lon,attr"` + Isp string `xml:"isp,attr"` + Country string `xml:"country,attr"` } // Users for decode xml @@ -61,6 +62,9 @@ func (s *Speedtest) FetchUserInfoContext(ctx context.Context) (*User, error) { } s.User = &users.Users[0] + if s.config.Location != nil && len(s.config.Location.CC) > 0 { + s.User.Country = s.config.Location.CC + } return s.User, nil }