@@ -657,30 +657,25 @@ func TestMinMaxTLSVersions(t *testing.T) {
657657
658658func TestTLSSettingValidate (t * testing.T ) {
659659 tests := []struct {
660- name string
661- minVersion string
662- maxVersion string
663- errorTxt string
660+ name string
661+ tlsConfig TLSSetting
662+ errorTxt string
664663 }{
665- {name : `TLS Config ["", ""] to be valid` , minVersion : "" , maxVersion : "" },
666- {name : `TLS Config ["", "1.3"] to be valid` , minVersion : "" , maxVersion : "1.3" },
667- {name : `TLS Config ["1.2", ""] to be valid` , minVersion : "1.2" , maxVersion : "" },
668- {name : `TLS Config ["1.3", "1.3"] to be valid` , minVersion : "1.3" , maxVersion : "1.3" },
669- {name : `TLS Config ["1.0", "1.1"] to be valid` , minVersion : "1.0" , maxVersion : "1.1" },
670- {name : `TLS Config ["asd", ""] to give [Error]` , minVersion : "asd" , maxVersion : "" , errorTxt : `invalid TLS min_version: unsupported TLS version: "asd"` },
671- {name : `TLS Config ["", "asd"] to give [Error]` , minVersion : "" , maxVersion : "asd" , errorTxt : `invalid TLS max_version: unsupported TLS version: "asd"` },
672- {name : `TLS Config ["0.4", ""] to give [Error]` , minVersion : "0.4" , maxVersion : "" , errorTxt : `invalid TLS min_version: unsupported TLS version: "0.4"` },
673- {name : `TLS Config ["1.2", "1.1"] to give [Error]` , minVersion : "1.2" , maxVersion : "1.1" , errorTxt : `invalid TLS configuration: min_version cannot be greater than max_version` },
664+ {name : `TLS Config ["", ""] to be valid` , tlsConfig : Config {MinVersion : "" , MaxVersion : "" }},
665+ {name : `TLS Config ["", "1.3"] to be valid` , tlsConfig : Config {MinVersion : "" , MaxVersion : "1.3" }},
666+ {name : `TLS Config ["1.2", ""] to be valid` , tlsConfig : Config {MinVersion : "1.2" , MaxVersion : "" }},
667+ {name : `TLS Config ["1.3", "1.3"] to be valid` , tlsConfig : Config {MinVersion : "1.3" , MaxVersion : "1.3" }},
668+ {name : `TLS Config ["1.0", "1.1"] to be valid` , tlsConfig : Config {MinVersion : "1.0" , MaxVersion : "1.1" }},
669+ {name : `TLS Config ["asd", ""] to give [Error]` , tlsConfig : Config {MinVersion : "asd" , MaxVersion : "" }, errorTxt : `invalid TLS min_version: unsupported TLS version: "asd"` },
670+ {name : `TLS Config ["", "asd"] to give [Error]` , tlsConfig : Config {MinVersion : "" , MaxVersion : "asd" }, errorTxt : `invalid TLS max_version: unsupported TLS version: "asd"` },
671+ {name : `TLS Config ["0.4", ""] to give [Error]` , tlsConfig : Config {MinVersion : "0.4" , MaxVersion : "" }, errorTxt : `invalid TLS min_version: unsupported TLS version: "0.4"` },
672+ {name : `TLS Config ["1.2", "1.1"] to give [Error]` , tlsConfig : Config {MinVersion : "1.2" , MaxVersion : "1.1" }, errorTxt : `invalid TLS configuration: min_version cannot be greater than max_version` },
673+ {name : `TLS Config with both CA File and PEM` , tlsConfig : Config {CAFile : "test" , CAPem : "test" }, errorTxt : `provide either a CA file or the PEM-encoded string, but not both` },
674674 }
675675
676676 for _ , test := range tests {
677677 t .Run (test .name , func (t * testing.T ) {
678- setting := TLSSetting {
679- MinVersion : test .minVersion ,
680- MaxVersion : test .maxVersion ,
681- }
682-
683- err := setting .Validate ()
678+ err := test .tlsConfig .Validate ()
684679
685680 if test .errorTxt == "" {
686681 assert .Nil (t , err )
@@ -741,6 +736,86 @@ invalid TLS cipher suite: "BAR"`,
741736}
742737
743738func TestSystemCertPool (t * testing.T ) {
739+ anError := errors .New ("my error" )
740+ tests := []struct {
741+ name string
742+ tlsSetting TLSSetting
743+ wantErr error
744+ systemCertFn func () (* x509.CertPool , error )
745+ }{
746+ {
747+ name : "not using system cert pool" ,
748+ tlsSetting : TLSSetting {
749+ IncludeSystemCACertsPool : false ,
750+ CAFile : filepath .Join ("testdata" , "ca-1.crt" ),
751+ },
752+ wantErr : nil ,
753+ systemCertFn : x509 .SystemCertPool ,
754+ },
755+ {
756+ name : "using system cert pool" ,
757+ tlsSetting : TLSSetting {
758+ IncludeSystemCACertsPool : true ,
759+ CAFile : filepath .Join ("testdata" , "ca-1.crt" ),
760+ },
761+ wantErr : nil ,
762+ systemCertFn : x509 .SystemCertPool ,
763+ },
764+ {
765+ name : "error loading system cert pool" ,
766+ tlsSetting : TLSSetting {
767+ IncludeSystemCACertsPool : true ,
768+ CAFile : filepath .Join ("testdata" , "ca-1.crt" ),
769+ },
770+ wantErr : anError ,
771+ systemCertFn : func () (* x509.CertPool , error ) {
772+ return nil , anError
773+ },
774+ },
775+ {
776+ name : "nil system cert pool" ,
777+ tlsSetting : TLSSetting {
778+ IncludeSystemCACertsPool : true ,
779+ CAFile : filepath .Join ("testdata" , "ca-1.crt" ),
780+ },
781+ wantErr : nil ,
782+ systemCertFn : func () (* x509.CertPool , error ) {
783+ return nil , nil
784+ },
785+ },
786+ }
787+ for _ , test := range tests {
788+ t .Run (test .name , func (t * testing.T ) {
789+ oldSystemCertPool := systemCertPool
790+ systemCertPool = test .systemCertFn
791+ defer func () {
792+ systemCertPool = oldSystemCertPool
793+ }()
794+
795+ serverConfig := ServerConfig {
796+ TLSSetting : test .tlsSetting ,
797+ }
798+ c , err := serverConfig .LoadTLSConfig ()
799+ if test .wantErr != nil {
800+ assert .ErrorContains (t , err , test .wantErr .Error ())
801+ } else {
802+ assert .NotNil (t , c .RootCAs )
803+ }
804+
805+ clientConfig := ClientConfig {
806+ TLSSetting : test .tlsSetting ,
807+ }
808+ c , err = clientConfig .LoadTLSConfig ()
809+ if test .wantErr != nil {
810+ assert .ErrorContains (t , err , test .wantErr .Error ())
811+ } else {
812+ assert .NotNil (t , c .RootCAs )
813+ }
814+ })
815+ }
816+ }
817+
818+ func TestSystemCertPool_loadCert (t * testing.T ) {
744819 anError := errors .New ("my error" )
745820 tests := []struct {
746821 name string
0 commit comments