diff --git a/py_ecc/utils.py b/py_ecc/utils.py index 8538d56a..69b294d5 100644 --- a/py_ecc/utils.py +++ b/py_ecc/utils.py @@ -22,6 +22,12 @@ def prime_field_inv(a: int, n: int) -> int: """ Extended euclidean algorithm to find modular inverses for integers """ + # To address a == n edge case. + # https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-09#section-4 + # inv0(x): This function returns the multiplicative inverse of x in + # F, extended to all of F by fixing inv0(0) == 0. + a %= n + if a == 0: return 0 lm, hm = 1, 0 diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..6d8da068 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,16 @@ +import pytest + +from py_ecc.utils import prime_field_inv + + +@pytest.mark.parametrize( + 'a,n,result', + [ + (0, 7, 0), + (7, 7, 0), + (2, 7, 4), + (10, 7, 5), + ] +) +def test_prime_field_inv(a, n, result): + assert prime_field_inv(a, n) % n == result