Skip to content

Commit 544d040

Browse files
authored
[jwe] Work with non X25519 ECDH encryption (#1442)
* WIP code * more WIP * [jwe] Work with non X25519 ECDH keys * go fmt * omit non-pointer test https://github.com/lestrrat-go/jwx/actions/runs/16959725199/job/48069344266 This test is not _that_ important. * reduce code duplication * typo * Update Changes
1 parent def5218 commit 544d040

6 files changed

Lines changed: 374 additions & 8 deletions

File tree

Changes

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ Changes
44
v3 has many incompatibilities with v2. To see the full list of differences between
55
v2 and v3, please read the Changes-v3.md file (https://github.com/lestrrat-go/jwx/blob/develop/v3/Changes-v3.md)
66

7+
v3.0.11 UNRELEASED
8+
* [jwe] Previously, ecdh.PrivateKey/ecdh.PublicKey were not properly handled
9+
when used for encryption, which has been fixed.
10+
711
v3.0.10 04 Aug 2025
812
* [jws/jwsbb] Add `jwsbb.ErrHeaderNotFound()` to return the same error type as when
913
a non-existent header is requested. via `HeaderGetXXX()` functions. Previously, this

internal/keyconv/keyconv.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ import (
55
"crypto/ecdh"
66
"crypto/ecdsa"
77
"crypto/ed25519"
8+
"crypto/elliptic"
89
"crypto/rsa"
910
"fmt"
11+
"math/big"
1012

1113
"github.com/lestrrat-go/blackmagic"
1214
"github.com/lestrrat-go/jwx/v3/jwk"
@@ -263,3 +265,90 @@ func ECDHPublicKey(dst, src any) error {
263265

264266
return blackmagic.AssignIfCompatible(dst, pubECDH)
265267
}
268+
269+
// ecdhCurveToElliptic maps ECDH curves to elliptic curves
270+
func ecdhCurveToElliptic(ecdhCurve ecdh.Curve) (elliptic.Curve, error) {
271+
switch ecdhCurve {
272+
case ecdh.P256():
273+
return elliptic.P256(), nil
274+
case ecdh.P384():
275+
return elliptic.P384(), nil
276+
case ecdh.P521():
277+
return elliptic.P521(), nil
278+
default:
279+
return nil, fmt.Errorf(`keyconv: unsupported ECDH curve: %v`, ecdhCurve)
280+
}
281+
}
282+
283+
// ecdhPublicKeyToECDSA converts an ECDH public key to an ECDSA public key
284+
func ecdhPublicKeyToECDSA(ecdhPubKey *ecdh.PublicKey) (*ecdsa.PublicKey, error) {
285+
curve, err := ecdhCurveToElliptic(ecdhPubKey.Curve())
286+
if err != nil {
287+
return nil, err
288+
}
289+
290+
pubBytes := ecdhPubKey.Bytes()
291+
292+
// Parse the uncompressed point format (0x04 prefix + X + Y coordinates)
293+
if len(pubBytes) == 0 || pubBytes[0] != 0x04 {
294+
return nil, fmt.Errorf(`keyconv: invalid ECDH public key format`)
295+
}
296+
297+
keyLen := (len(pubBytes) - 1) / 2
298+
if len(pubBytes) != 1+2*keyLen {
299+
return nil, fmt.Errorf(`keyconv: invalid ECDH public key length`)
300+
}
301+
302+
x := new(big.Int).SetBytes(pubBytes[1 : 1+keyLen])
303+
y := new(big.Int).SetBytes(pubBytes[1+keyLen:])
304+
305+
return &ecdsa.PublicKey{
306+
Curve: curve,
307+
X: x,
308+
Y: y,
309+
}, nil
310+
}
311+
312+
func ECDHToECDSA(dst, src any) error {
313+
// convert ecdh.PublicKey to ecdsa.PublicKey, ecdh.PrivateKey to ecdsa.PrivateKey
314+
315+
// First, handle value types by converting to pointers
316+
switch s := src.(type) {
317+
case ecdh.PrivateKey:
318+
src = &s
319+
case ecdh.PublicKey:
320+
src = &s
321+
}
322+
323+
var privBytes []byte
324+
var pubkey *ecdh.PublicKey
325+
// Now handle the actual conversion with pointer types
326+
switch src := src.(type) {
327+
case *ecdh.PrivateKey:
328+
pubkey = src.PublicKey()
329+
privBytes = src.Bytes()
330+
case *ecdh.PublicKey:
331+
pubkey = src
332+
default:
333+
return fmt.Errorf(`keyconv: expected ecdh.PrivateKey, *ecdh.PrivateKey, ecdh.PublicKey, or *ecdh.PublicKey, got %T`, src)
334+
}
335+
336+
// convert the public key
337+
ecdsaPubKey, err := ecdhPublicKeyToECDSA(pubkey)
338+
if err != nil {
339+
return fmt.Errorf(`keyconv.ECDHToECDSA: failed to convert ECDH public key to ECDSA public key: %w`, err)
340+
}
341+
342+
// return if we were being asked to convert *ecdh.PublicKey
343+
if privBytes == nil {
344+
return blackmagic.AssignIfCompatible(dst, ecdsaPubKey)
345+
}
346+
347+
// Then create the private key with the public key embedded
348+
ecdsaPrivKey := &ecdsa.PrivateKey{
349+
D: new(big.Int).SetBytes(privBytes),
350+
PublicKey: *ecdsaPubKey,
351+
}
352+
353+
return blackmagic.AssignIfCompatible(dst, ecdsaPrivKey)
354+
}

internal/keyconv/keyconv_test.go

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package keyconv_test
22

33
import (
4+
"crypto/ecdh"
45
"crypto/ecdsa"
6+
"crypto/rand"
57
"crypto/rsa"
68
"testing"
79

@@ -205,3 +207,126 @@ func TestKeyconv(t *testing.T) {
205207
})
206208
})
207209
}
210+
211+
func TestECDHToECDSA(t *testing.T) {
212+
curves := []struct {
213+
name string
214+
ecdhCurve ecdh.Curve
215+
jwaAlg jwa.EllipticCurveAlgorithm
216+
}{
217+
{"P256", ecdh.P256(), jwa.P256()},
218+
{"P384", ecdh.P384(), jwa.P384()},
219+
{"P521", ecdh.P521(), jwa.P521()},
220+
}
221+
222+
for _, curve := range curves {
223+
t.Run(curve.name, func(t *testing.T) {
224+
// Generate an ECDSA key for comparison
225+
ecdsaKey, err := jwxtest.GenerateEcdsaKey(curve.jwaAlg)
226+
require.NoError(t, err, `ecdsa.GenerateKey should succeed`)
227+
228+
// Convert ECDSA key to ECDH key
229+
ecdhPrivKey, err := ecdsaKey.ECDH()
230+
require.NoError(t, err, `ECDSA to ECDH conversion should succeed`)
231+
232+
ecdhPubKey := ecdhPrivKey.PublicKey()
233+
234+
t.Run("PrivateKey", func(t *testing.T) {
235+
testcases := []struct {
236+
name string
237+
src any
238+
error bool
239+
}{
240+
{"*ecdh.PrivateKey", ecdhPrivKey, false},
241+
{"invalid type", "not a key", true},
242+
}
243+
244+
for _, tc := range testcases {
245+
t.Run(tc.name, func(t *testing.T) {
246+
var dst *ecdsa.PrivateKey
247+
err := keyconv.ECDHToECDSA(&dst, tc.src)
248+
249+
if tc.error {
250+
require.Error(t, err, `ECDHToECDSA should fail for invalid input`)
251+
} else {
252+
require.NoError(t, err, `ECDHToECDSA should succeed`)
253+
require.NotNil(t, dst, `destination should not be nil`)
254+
255+
// Verify the converted key has the same curve
256+
require.Equal(t, ecdsaKey.Curve, dst.Curve, `curves should match`)
257+
258+
// Verify the private key values match
259+
require.Equal(t, ecdsaKey.D, dst.D, `private key values should match`)
260+
261+
// Verify the public key coordinates match
262+
require.Equal(t, ecdsaKey.PublicKey.X, dst.PublicKey.X, `X coordinates should match`)
263+
require.Equal(t, ecdsaKey.PublicKey.Y, dst.PublicKey.Y, `Y coordinates should match`)
264+
}
265+
})
266+
}
267+
})
268+
269+
t.Run("PublicKey", func(t *testing.T) {
270+
testcases := []struct {
271+
name string
272+
src any
273+
error bool
274+
}{
275+
{"*ecdh.PublicKey", ecdhPubKey, false},
276+
{"ecdh.PublicKey", *ecdhPubKey, false},
277+
{"invalid type", "not a key", true},
278+
}
279+
280+
for _, tc := range testcases {
281+
t.Run(tc.name, func(t *testing.T) {
282+
var dst *ecdsa.PublicKey
283+
err := keyconv.ECDHToECDSA(&dst, tc.src)
284+
285+
if tc.error {
286+
require.Error(t, err, `ECDHToECDSA should fail for invalid input`)
287+
} else {
288+
require.NoError(t, err, `ECDHToECDSA should succeed`)
289+
require.NotNil(t, dst, `destination should not be nil`)
290+
291+
// Verify the converted key has the same curve
292+
require.Equal(t, ecdsaKey.PublicKey.Curve, dst.Curve, `curves should match`)
293+
294+
// Verify the public key coordinates match
295+
require.Equal(t, ecdsaKey.PublicKey.X, dst.X, `X coordinates should match`)
296+
require.Equal(t, ecdsaKey.PublicKey.Y, dst.Y, `Y coordinates should match`)
297+
}
298+
})
299+
}
300+
})
301+
302+
t.Run("RoundTrip", func(t *testing.T) {
303+
// Test that ECDSA -> ECDH -> ECDSA produces the same key
304+
var convertedPrivKey *ecdsa.PrivateKey
305+
err := keyconv.ECDHToECDSA(&convertedPrivKey, ecdhPrivKey)
306+
require.NoError(t, err, `ECDHToECDSA should succeed`)
307+
308+
var convertedPubKey *ecdsa.PublicKey
309+
err = keyconv.ECDHToECDSA(&convertedPubKey, ecdhPubKey)
310+
require.NoError(t, err, `ECDHToECDSA should succeed`)
311+
312+
// Verify the keys are equivalent
313+
require.Equal(t, ecdsaKey.D, convertedPrivKey.D, `private key values should match`)
314+
require.Equal(t, ecdsaKey.PublicKey.X, convertedPrivKey.PublicKey.X, `private key X coordinates should match`)
315+
require.Equal(t, ecdsaKey.PublicKey.Y, convertedPrivKey.PublicKey.Y, `private key Y coordinates should match`)
316+
require.Equal(t, ecdsaKey.PublicKey.X, convertedPubKey.X, `public key X coordinates should match`)
317+
require.Equal(t, ecdsaKey.PublicKey.Y, convertedPubKey.Y, `public key Y coordinates should match`)
318+
})
319+
})
320+
}
321+
322+
t.Run("UnsupportedCurve", func(t *testing.T) {
323+
// Create a mock ECDH key with X25519 curve (not supported for ECDSA)
324+
x25519Key, err := ecdh.X25519().GenerateKey(rand.Reader)
325+
require.NoError(t, err, `X25519 key generation should succeed`)
326+
327+
var dst *ecdsa.PrivateKey
328+
err = keyconv.ECDHToECDSA(&dst, x25519Key)
329+
require.Error(t, err, `ECDHToECDSA should fail for unsupported curve`)
330+
require.Contains(t, err.Error(), "unsupported ECDH curve", `error should mention unsupported curve`)
331+
})
332+
}

jwe/encrypt.go

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,23 +96,48 @@ func (e *encrypter) EncryptKey(cek []byte) (keygen.ByteSource, error) {
9696
keyToUse = e.pubkey
9797
}
9898

99-
// Handle ecdsa.PublicKey by value - convert to pointer
100-
if pk, ok := keyToUse.(ecdsa.PublicKey); ok {
101-
keyToUse = &pk
99+
switch key := keyToUse.(type) {
100+
case *ecdsa.PublicKey:
101+
// no op
102+
case ecdsa.PublicKey:
103+
keyToUse = &key
104+
case *ecdsa.PrivateKey:
105+
keyToUse = &key.PublicKey
106+
case ecdsa.PrivateKey:
107+
keyToUse = &key.PublicKey
108+
case *ecdh.PublicKey:
109+
// no op
110+
case ecdh.PublicKey:
111+
keyToUse = &key
112+
case ecdh.PrivateKey:
113+
keyToUse = key.PublicKey()
114+
case *ecdh.PrivateKey:
115+
keyToUse = key.PublicKey()
102116
}
103117

104118
// Determine key type and call appropriate function
119+
switch key := keyToUse.(type) {
120+
case *ecdh.PublicKey:
121+
if key.Curve() == ecdh.X25519() {
122+
if !keywrap {
123+
return jwebb.KeyEncryptECDHESX25519(cek, e.keyalg.String(), e.apu, e.apv, key, keysize, e.ctalg.String())
124+
}
125+
return jwebb.KeyEncryptECDHESKeyWrapX25519(cek, e.keyalg.String(), e.apu, e.apv, key, keysize, e.ctalg.String())
126+
}
127+
128+
var ecdsaKey *ecdsa.PublicKey
129+
if err := keyconv.ECDHToECDSA(&ecdsaKey, key); err != nil {
130+
return nil, fmt.Errorf(`encrypt: failed to convert ECDH public key to ECDSA: %w`, err)
131+
}
132+
keyToUse = ecdsaKey
133+
}
134+
105135
switch key := keyToUse.(type) {
106136
case *ecdsa.PublicKey:
107137
if !keywrap {
108138
return jwebb.KeyEncryptECDHESECDSA(cek, e.keyalg.String(), e.apu, e.apv, key, keysize, e.ctalg.String())
109139
}
110140
return jwebb.KeyEncryptECDHESKeyWrapECDSA(cek, e.keyalg.String(), e.apu, e.apv, key, keysize, e.ctalg.String())
111-
case *ecdh.PublicKey:
112-
if !keywrap {
113-
return jwebb.KeyEncryptECDHESX25519(cek, e.keyalg.String(), e.apu, e.apv, key, keysize, e.ctalg.String())
114-
}
115-
return jwebb.KeyEncryptECDHESKeyWrapX25519(cek, e.keyalg.String(), e.apu, e.apv, key, keysize, e.ctalg.String())
116141
default:
117142
return nil, fmt.Errorf(`encrypt: unsupported key type for ECDH-ES: %T`, keyToUse)
118143
}

jwe/jwe_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,3 +1008,18 @@ func BenchmarkParseCompat(b *testing.B) {
10081008
}
10091009
}
10101010
}
1011+
1012+
func TestGH1434(t *testing.T) {
1013+
// Test if we can use ecdh.PrivateKey to encrypt and decrypt
1014+
1015+
key, err := ecdh.P256().GenerateKey(rand.Reader)
1016+
require.NoError(t, err, `ecdh.P256().GenerateKey should succeed`)
1017+
1018+
const payload = `hello, world!`
1019+
encrypted, err := jwe.Encrypt([]byte(payload), jwe.WithKey(jwa.ECDH_ES_A256KW(), key))
1020+
require.NoError(t, err, `jwe.Encrypt should succeed`)
1021+
1022+
decrypted, err := jwe.Decrypt(encrypted, jwe.WithKey(jwa.ECDH_ES_A256KW(), key))
1023+
require.NoError(t, err, `jwe.Decrypt should succeed`)
1024+
require.Equal(t, []byte(payload), decrypted, `decrypted payload should match original payload`)
1025+
}

0 commit comments

Comments
 (0)