summaryrefslogtreecommitdiffstatshomepage
path: root/py/mpz.c
diff options
context:
space:
mode:
Diffstat (limited to 'py/mpz.c')
-rw-r--r--py/mpz.c67
1 files changed, 38 insertions, 29 deletions
diff --git a/py/mpz.c b/py/mpz.c
index 2c02699811..bb76479569 100644
--- a/py/mpz.c
+++ b/py/mpz.c
@@ -454,10 +454,8 @@ STATIC mp_uint_t mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mpz_d
assumes num_dig has enough memory to be extended by 1 digit
assumes quo_dig has enough memory (as many digits as num)
assumes quo_dig is filled with zeros
- modifies den_dig memory, but restors it to original state at end
*/
-
-STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, mp_uint_t den_len, mpz_dig_t *quo_dig, mp_uint_t *quo_len) {
+STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, const mpz_dig_t *den_dig, mp_uint_t den_len, mpz_dig_t *quo_dig, mp_uint_t *quo_len) {
mpz_dig_t *orig_num_dig = num_dig;
mpz_dig_t *orig_quo_dig = quo_dig;
mpz_dig_t norm_shift = 0;
@@ -478,6 +476,11 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
}
}
+ // We need to normalise the denominator (leading bit of leading digit is 1)
+ // so that the division routine works. Since the denominator memory is
+ // read-only we do the normalisation on the fly, each time a digit of the
+ // denominator is needed. We need to know is how many bits to shift by.
+
// count number of leading zeros in leading digit of denominator
{
mpz_dig_t d = den_dig[den_len - 1];
@@ -487,13 +490,6 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
}
}
- // normalise denomenator (leading bit of leading digit is 1)
- for (mpz_dig_t *den = den_dig, carry = 0; den < den_dig + den_len; ++den) {
- mpz_dig_t d = *den;
- *den = ((d << norm_shift) | carry) & DIG_MASK;
- carry = d >> (DIG_SIZE - norm_shift);
- }
-
// now need to shift numerator by same amount as denominator
// first, increase length of numerator in case we need more room to shift
num_dig[*num_len] = 0;
@@ -501,11 +497,14 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
for (mpz_dig_t *num = num_dig, carry = 0; num < num_dig + *num_len; ++num) {
mpz_dig_t n = *num;
*num = ((n << norm_shift) | carry) & DIG_MASK;
- carry = n >> (DIG_SIZE - norm_shift);
+ carry = (mpz_dbl_dig_t)n >> (DIG_SIZE - norm_shift);
}
// cache the leading digit of the denominator
- lead_den_digit = den_dig[den_len - 1];
+ lead_den_digit = (mpz_dbl_dig_t)den_dig[den_len - 1] << norm_shift;
+ if (den_len >= 2) {
+ lead_den_digit |= (mpz_dbl_dig_t)den_dig[den_len - 2] >> (DIG_SIZE - norm_shift);
+ }
// point num_dig to last digit in numerator
num_dig += *num_len - 1;
@@ -540,10 +539,13 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
// round up).
if (DIG_SIZE < 8 * sizeof(mpz_dbl_dig_t) / 2) {
+ const mpz_dig_t *d = den_dig;
+ mpz_dbl_dig_t d_norm = 0;
mpz_dbl_dig_signed_t borrow = 0;
- for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
- borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)*d; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
+ for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
+ d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
+ borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK); // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
*n = borrow & DIG_MASK;
borrow >>= DIG_SIZE;
}
@@ -553,9 +555,12 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
// adjust quotient if it is too big
for (; borrow != 0; --quo) {
+ d = den_dig;
+ d_norm = 0;
mpz_dbl_dig_t carry = 0;
- for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
- carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d;
+ for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
+ d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
+ carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK);
*n = carry & DIG_MASK;
carry >>= DIG_SIZE;
}
@@ -566,10 +571,13 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
borrow += carry;
}
} else { // DIG_SIZE == 8 * sizeof(mpz_dbl_dig_t) / 2
+ const mpz_dig_t *d = den_dig;
+ mpz_dbl_dig_t d_norm = 0;
mpz_dbl_dig_t borrow = 0;
- for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
- mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)(*d);
+ for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
+ d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
+ mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK);
if (x >= *n || *n - x <= borrow) {
borrow += (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)*n;
*n = (-borrow) & DIG_MASK;
@@ -590,9 +598,12 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
// adjust quotient if it is too big
for (; borrow != 0; --quo) {
+ d = den_dig;
+ d_norm = 0;
mpz_dbl_dig_t carry = 0;
- for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) {
- carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d;
+ for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
+ d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
+ carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK);
*n = carry & DIG_MASK;
carry >>= DIG_SIZE;
}
@@ -614,18 +625,11 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,
--(*num_len);
}
- // unnormalise denomenator
- for (mpz_dig_t *den = den_dig + den_len - 1, carry = 0; den >= den_dig; --den) {
- mpz_dig_t d = *den;
- *den = ((d >> norm_shift) | carry) & DIG_MASK;
- carry = d << (DIG_SIZE - norm_shift);
- }
-
// unnormalise numerator (remainder now)
for (mpz_dig_t *num = orig_num_dig + *num_len - 1, carry = 0; num >= orig_num_dig; --num) {
mpz_dig_t n = *num;
*num = ((n >> norm_shift) | carry) & DIG_MASK;
- carry = n << (DIG_SIZE - norm_shift);
+ carry = (mpz_dbl_dig_t)n << (DIG_SIZE - norm_shift);
}
// strip trailing zeros
@@ -1506,11 +1510,16 @@ void mpz_divmod_inpl(mpz_t *dest_quo, mpz_t *dest_rem, const mpz_t *lhs, const m
dest_quo->len = 0;
mpz_need_dig(dest_rem, lhs->len + 1); // +1 necessary?
mpz_set(dest_rem, lhs);
- //rhs->dig[rhs->len] = 0;
mpn_div(dest_rem->dig, &dest_rem->len, rhs->dig, rhs->len, dest_quo->dig, &dest_quo->len);
+ // check signs and do Python style modulo
if (lhs->neg != rhs->neg) {
dest_quo->neg = 1;
+ if (!mpz_is_zero(dest_rem)) {
+ mpz_t mpzone; mpz_init_from_int(&mpzone, -1);
+ mpz_add_inpl(dest_quo, dest_quo, &mpzone);
+ mpz_add_inpl(dest_rem, dest_rem, rhs);
+ }
}
}