diff options
-rw-r--r-- | Include/cpython/longintrepr.h | 6 | ||||
-rw-r--r-- | Lib/test/test_pow.py | 22 | ||||
-rw-r--r-- | Objects/longobject.c | 108 |
3 files changed, 106 insertions, 30 deletions
diff --git a/Include/cpython/longintrepr.h b/Include/cpython/longintrepr.h index ff4155f9656..68dbf9c4382 100644 --- a/Include/cpython/longintrepr.h +++ b/Include/cpython/longintrepr.h @@ -21,8 +21,6 @@ extern "C" { PyLong_SHIFT. The majority of the code doesn't care about the precise value of PyLong_SHIFT, but there are some notable exceptions: - - long_pow() requires that PyLong_SHIFT be divisible by 5 - - PyLong_{As,From}ByteArray require that PyLong_SHIFT be at least 8 - long_hash() requires that PyLong_SHIFT is *strictly* less than the number @@ -63,10 +61,6 @@ typedef long stwodigits; /* signed variant of twodigits */ #define PyLong_BASE ((digit)1 << PyLong_SHIFT) #define PyLong_MASK ((digit)(PyLong_BASE - 1)) -#if PyLong_SHIFT % 5 != 0 -#error "longobject.c requires that PyLong_SHIFT be divisible by 5" -#endif - /* Long integer representation. The absolute value of a number is equal to SUM(for i=0 through abs(ob_size)-1) ob_digit[i] * 2**(SHIFT*i) diff --git a/Lib/test/test_pow.py b/Lib/test/test_pow.py index 660ff80bbf5..5cea9ceb20f 100644 --- a/Lib/test/test_pow.py +++ b/Lib/test/test_pow.py @@ -93,6 +93,28 @@ class PowTest(unittest.TestCase): pow(int(i),j,k) ) + def test_big_exp(self): + import random + self.assertEqual(pow(2, 50000), 1 << 50000) + # Randomized modular tests, checking the identities + # a**(b1 + b2) == a**b1 * a**b2 + # a**(b1 * b2) == (a**b1)**b2 + prime = 1000000000039 # for speed, relatively small prime modulus + for i in range(10): + a = random.randrange(1000, 1000000) + bpower = random.randrange(1000, 50000) + b = random.randrange(1 << (bpower - 1), 1 << bpower) + b1 = random.randrange(1, b) + b2 = b - b1 + got1 = pow(a, b, prime) + got2 = pow(a, b1, prime) * pow(a, b2, prime) % prime + if got1 != got2: + self.fail(f"{a=:x} {b1=:x} {b2=:x} {got1=:x} {got2=:x}") + got3 = pow(a, b1 * b2, prime) + got4 = pow(pow(a, b1, prime), b2, prime) + if got3 != got4: + self.fail(f"{a=:x} {b1=:x} {b2=:x} {got3=:x} {got4=:x}") + def 
test_bug643260(self): class TestRpow: def __rpow__(self, other): diff --git a/Objects/longobject.c b/Objects/longobject.c index 09ae9455c5b..b5648fca7dc 100644 --- a/Objects/longobject.c +++ b/Objects/longobject.c @@ -74,12 +74,34 @@ maybe_small_long(PyLongObject *v) #define KARATSUBA_CUTOFF 70 #define KARATSUBA_SQUARE_CUTOFF (2 * KARATSUBA_CUTOFF) -/* For exponentiation, use the binary left-to-right algorithm - * unless the exponent contains more than FIVEARY_CUTOFF digits. - * In that case, do 5 bits at a time. The potential drawback is that - * a table of 2**5 intermediate results is computed. +/* For exponentiation, use the binary left-to-right algorithm unless the + * exponent contains more than HUGE_EXP_CUTOFF bits. In that case, do + * (no more than) EXP_WINDOW_SIZE bits at a time. The potential drawback is + * that a table of 2**(EXP_WINDOW_SIZE - 1) intermediate results is + * precomputed. */ -#define FIVEARY_CUTOFF 8 +#define EXP_WINDOW_SIZE 5 +#define EXP_TABLE_LEN (1 << (EXP_WINDOW_SIZE - 1)) +/* Suppose the exponent has bit length e. All ways of doing this + * need e squarings. The binary method also needs a multiply for + * each bit set. In a k-ary method with window width w, a multiply + * for each non-zero window, so at worst (and likely!) + * ceiling(e/w). The k-ary sliding window method has the same + * worst case, but the window slides so it can sometimes skip + * over an all-zero window that the fixed-window method can't + * exploit. In addition, the windowing methods need multiplies + * to precompute a table of small powers. + * + * For the sliding window method with width 5, 16 precomputation + * multiplies are needed. Assuming about half the exponent bits + * are set, then, the binary method needs about e/2 extra mults + * and the window method about 16 + e/5. + * + * The latter is smaller for e > 53 1/3. 
We don't have direct + * access to the bit length, though, so call it 60, which is a + * multiple of a long digit's max bit length (15 or 30 so far). + */ +#define HUGE_EXP_CUTOFF 60 #define SIGCHECK(PyTryBlock) \ do { \ @@ -4172,14 +4194,15 @@ long_pow(PyObject *v, PyObject *w, PyObject *x) int negativeOutput = 0; /* if x<0 return negative output */ PyLongObject *z = NULL; /* accumulated result */ - Py_ssize_t i, j, k; /* counters */ + Py_ssize_t i, j; /* counters */ PyLongObject *temp = NULL; + PyLongObject *a2 = NULL; /* may temporarily hold a**2 % c */ - /* 5-ary values. If the exponent is large enough, table is - * precomputed so that table[i] == a**i % c for i in range(32). + /* k-ary values. If the exponent is large enough, table is + * precomputed so that table[i] == a**(2*i+1) % c for i in + * range(EXP_TABLE_LEN). */ - PyLongObject *table[32] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, - 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; + PyLongObject *table[EXP_TABLE_LEN] = {0}; /* a, b, c = v, w, x */ CHECK_BINOP(v, w); @@ -4332,7 +4355,7 @@ long_pow(PyObject *v, PyObject *w, PyObject *x) } /* else bi is 0, and z==1 is correct */ } - else if (i <= FIVEARY_CUTOFF) { + else if (i <= HUGE_EXP_CUTOFF / PyLong_SHIFT ) { /* Left-to-right binary exponentiation (HAC Algorithm 14.79) */ /* http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf */ @@ -4366,23 +4389,59 @@ long_pow(PyObject *v, PyObject *w, PyObject *x) } } else { - /* Left-to-right 5-ary exponentiation (HAC Algorithm 14.82) */ - Py_INCREF(z); /* still holds 1L */ - table[0] = z; - for (i = 1; i < 32; ++i) - MULT(table[i-1], a, table[i]); + /* Left-to-right k-ary sliding window exponentiation + * (Handbook of Applied Cryptography (HAC) Algorithm 14.85) + */ + Py_INCREF(a); + table[0] = a; + MULT(a, a, a2); + /* table[i] == a**(2*i + 1) % c */ + for (i = 1; i < EXP_TABLE_LEN; ++i) + MULT(table[i-1], a2, table[i]); + Py_CLEAR(a2); + + /* Repeatedly extract the next (no more than) EXP_WINDOW_SIZE bits + * into `pending`, 
starting with the next 1 bit. The current bit + * length of `pending` is `blen`. + */ + int pending = 0, blen = 0; +#define ABSORB_PENDING do { \ + int ntz = 0; /* number of trailing zeroes in `pending` */ \ + assert(pending && blen); \ + assert(pending >> (blen - 1)); \ + assert(pending >> blen == 0); \ + while ((pending & 1) == 0) { \ + ++ntz; \ + pending >>= 1; \ + } \ + assert(ntz < blen); \ + blen -= ntz; \ + do { \ + MULT(z, z, z); \ + } while (--blen); \ + MULT(z, table[pending >> 1], z); \ + while (ntz-- > 0) \ + MULT(z, z, z); \ + assert(blen == 0); \ + pending = 0; \ + } while(0) for (i = Py_SIZE(b) - 1; i >= 0; --i) { const digit bi = b->ob_digit[i]; - - for (j = PyLong_SHIFT - 5; j >= 0; j -= 5) { - const int index = (bi >> j) & 0x1f; - for (k = 0; k < 5; ++k) + for (j = PyLong_SHIFT - 1; j >= 0; --j) { + const int bit = (bi >> j) & 1; + pending = (pending << 1) | bit; + if (pending) { + ++blen; + if (blen == EXP_WINDOW_SIZE) + ABSORB_PENDING; + } + else /* absorb strings of 0 bits */ MULT(z, z, z); - if (index) - MULT(z, table[index], z); } } + if (pending) + ABSORB_PENDING; } if (negativeOutput && (Py_SIZE(z) != 0)) { @@ -4399,13 +4458,14 @@ long_pow(PyObject *v, PyObject *w, PyObject *x) Py_CLEAR(z); /* fall through */ Done: - if (Py_SIZE(b) > FIVEARY_CUTOFF) { - for (i = 0; i < 32; ++i) + if (Py_SIZE(b) > HUGE_EXP_CUTOFF / PyLong_SHIFT) { + for (i = 0; i < EXP_TABLE_LEN; ++i) Py_XDECREF(table[i]); } Py_DECREF(a); Py_DECREF(b); Py_XDECREF(c); + Py_XDECREF(a2); Py_XDECREF(temp); return (PyObject *)z; } |