summaryrefslogtreecommitdiffstatshomepage
path: root/py/mpz.c
diff options
context:
space:
mode:
Diffstat (limited to 'py/mpz.c')
-rw-r--r--py/mpz.c87
1 files changed, 71 insertions, 16 deletions
diff --git a/py/mpz.c b/py/mpz.c
index 8e6aecbcae..186229569b 100644
--- a/py/mpz.c
+++ b/py/mpz.c
@@ -37,7 +37,9 @@
#if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ
#define DIG_SIZE (MPZ_DIG_SIZE)
-#define DIG_MASK ((1 << DIG_SIZE) - 1)
+#define DIG_MASK ((1L << DIG_SIZE) - 1)
+#define DIG_MSB (1L << (DIG_SIZE - 1))
+#define DIG_BASE (1L << DIG_SIZE)
/*
mpz is an arbitrary precision integer type with a public API.
@@ -61,7 +63,7 @@ STATIC mp_int_t mpn_cmp(const mpz_dig_t *idig, mp_uint_t ilen, const mpz_dig_t *
if (ilen > jlen) { return 1; }
for (idig += ilen, jdig += ilen; ilen > 0; --ilen) {
- mp_int_t cmp = *(--idig) - *(--jdig);
+ mpz_dbl_dig_signed_t cmp = (mpz_dbl_dig_t)*(--idig) - (mpz_dbl_dig_t)*(--jdig);
if (cmp < 0) { return -1; }
if (cmp > 0) { return 1; }
}
@@ -127,7 +129,7 @@ STATIC mp_uint_t mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mp_ui
for (mp_uint_t i = jlen; i > 0; i--, idig++, jdig++) {
mpz_dbl_dig_t d = *jdig;
if (i > 1) {
- d |= jdig[1] << DIG_SIZE;
+ d |= (mpz_dbl_dig_t)jdig[1] << DIG_SIZE;
}
d >>= n_part;
*idig = d & DIG_MASK;
@@ -152,7 +154,7 @@ STATIC mp_uint_t mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
jlen -= klen;
for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) {
- carry += *jdig + *kdig;
+ carry += (mpz_dbl_dig_t)*jdig + (mpz_dbl_dig_t)*kdig;
*idig = carry & DIG_MASK;
carry >>= DIG_SIZE;
}
@@ -182,7 +184,7 @@ STATIC mp_uint_t mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,
jlen -= klen;
for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) {
- borrow += *jdig - *kdig;
+ borrow += (mpz_dbl_dig_t)*jdig - (mpz_dbl_dig_t)*kdig;
*idig = borrow & DIG_MASK;
borrow >>= DIG_SIZE;
}
@@ -301,7 +303,7 @@ STATIC mp_uint_t mpn_mul_dig_add_dig(mpz_dig_t *idig, mp_uint_t ilen, mpz_dig_t
mpz_dbl_dig_t carry = dadd;
for (; ilen > 0; --ilen, ++idig) {
- carry += *idig * dmul; // will never overflow so long as DIG_SIZE <= WORD_SIZE / 2
+ carry += (mpz_dbl_dig_t)*idig * (mpz_dbl_dig_t)dmul; // will never overflow so long as DIG_SIZE <= 8*sizeof(mpz_dbl_dig_t)/2
*idig = carry & DIG_MASK;
carry >>= DIG_SIZE;
}
@@ -328,7 +330,7 @@ STATIC mp_uint_t mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mpz_d
mp_uint_t jl = jlen;
for (mpz_dig_t *jd = jdig; jl > 0; --jl, ++jd, ++id) {
- carry += *id + *jd * *kdig; // will never overflow so long as DIG_SIZE <= WORD_SIZE / 2
+ carry += (mpz_dbl_dig_t)*id + (mpz_dbl_dig_t)*jd * (mpz_dbl_dig_t)*kdig; // will never overflow so long as DIG_SIZE <= 8*sizeof(mpz_dbl_dig_t)/2
*id = carry & DIG_MASK;
carry >>= DIG_SIZE;
}
@@ -375,7 +377,7 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
// count number of leading zeros in leading digit of denominator
{
mpz_dig_t d = den_dig[den_len - 1];
- while ((d & (1 << (DIG_SIZE - 1))) == 0) {
+ while ((d & DIG_MSB) == 0) {
d <<= 1;
++norm_shift;
}
@@ -412,21 +414,36 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
// keep going while we have enough digits to divide
while (*num_len > den_len) {
- mpz_dbl_dig_t quo = (*num_dig << DIG_SIZE) | num_dig[-1];
+ mpz_dbl_dig_t quo = ((mpz_dbl_dig_t)*num_dig << DIG_SIZE) | num_dig[-1];
// get approximate quotient
quo /= lead_den_digit;
- // multiply quo by den and subtract from num get remainder
- {
+ // Multiply quo by den and subtract from num to get remainder.
+ // We have different code here to handle different compile-time
+ // configurations of mpz:
+ //
+ // 1. DIG_SIZE is stricly less than half the number of bits
+ // available in mpz_dbl_dig_t. In this case we can use a
+ // slightly more optimal (in time and space) routine that
+ // uses the extra bits in mpz_dbl_dig_signed_t to store a
+ // sign bit.
+ //
+ // 2. DIG_SIZE is exactly half the number of bits available in
+ // mpz_dbl_dig_t. In this (common) case we need to be careful
+ // not to overflow the borrow variable. And the shifting of
+ // borrow needs some special logic (it's a shift right with
+ // round up).
+
+ if (DIG_SIZE < 8 * sizeof(mpz_dbl_dig_t) / 2) {
mpz_dbl_dig_signed_t borrow = 0;
for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
- borrow += *n - quo * *d; // will overflow if DIG_SIZE >= 16
+ borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)*d; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
*n = borrow & DIG_MASK;
borrow >>= DIG_SIZE;
}
- borrow += *num_dig; // will overflow if DIG_SIZE >= 16
+ borrow += *num_dig; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
*num_dig = borrow & DIG_MASK;
borrow >>= DIG_SIZE;
@@ -434,7 +451,7 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
for (; borrow != 0; --quo) {
mpz_dbl_dig_t carry = 0;
for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
- carry += *n + *d;
+ carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d;
*n = carry & DIG_MASK;
carry >>= DIG_SIZE;
}
@@ -444,6 +461,44 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
borrow += carry;
}
+ } else { // DIG_SIZE == 8 * sizeof(mpz_dbl_dig_t) / 2
+ mpz_dbl_dig_t borrow = 0;
+
+ for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
+ mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)(*d);
+ if (x >= *n || *n - x <= borrow) {
+ borrow += (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)*n;
+ *n = (-borrow) & DIG_MASK;
+ borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up
+ } else {
+ *n = ((mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)borrow) & DIG_MASK;
+ borrow = 0;
+ }
+ }
+ if (borrow >= *num_dig) {
+ borrow -= (mpz_dbl_dig_t)*num_dig;
+ *num_dig = (-borrow) & DIG_MASK;
+ borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up
+ } else {
+ *num_dig = (*num_dig - borrow) & DIG_MASK;
+ borrow = 0;
+ }
+
+ // adjust quotient if it is too big
+ for (; borrow != 0; --quo) {
+ mpz_dbl_dig_t carry = 0;
+ for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
+ carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d;
+ *n = carry & DIG_MASK;
+ carry >>= DIG_SIZE;
+ }
+ carry += (mpz_dbl_dig_t)*num_dig;
+ *num_dig = carry & DIG_MASK;
+ carry >>= DIG_SIZE;
+
+ //assert(borrow >= carry); // enable this to check the logic
+ borrow -= carry;
+ }
}
// store this digit of the quotient
@@ -1256,7 +1311,7 @@ bool mpz_as_uint_checked(const mpz_t *i, mp_uint_t *value) {
mpz_dig_t *d = i->dig + i->len;
while (--d >= i->dig) {
- if (val > ((~0) >> DIG_SIZE)) {
+ if (val > (~(WORD_MSBIT_HIGH) >> (DIG_SIZE - 1))) {
// will overflow
return false;
}
@@ -1273,7 +1328,7 @@ mp_float_t mpz_as_float(const mpz_t *i) {
mpz_dig_t *d = i->dig + i->len;
while (--d >= i->dig) {
- val = val * (1 << DIG_SIZE) + *d;
+ val = val * DIG_BASE + *d;
}
if (i->neg != 0) {