1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00
This commit is contained in:
dark64 2023-06-16 12:19:28 +02:00
parent e23f480651
commit a518affccd

View file

@ -34,4 +34,137 @@ def split2<K, N>(field[K] mut limbs, field[K] mut carry, field[K][3] split) -> (
carry[i] = sum_and_carry[1]; carry[i] = sum_and_carry[1];
} }
return (limbs, carry); return (limbs, carry);
}
// 1 if true, 0 if false
def long_gt<K>(field[100] a, field[200] b) -> field {
field mut result = 0;
for u32 i in 0..K {
u32 j = K - i - 1;
result = a[j] > b[j] ? 1 : result;
}
return result;
}
// n bits per register
// a has k registers
// b has k registers
// a >= b
def long_sub<N, K>(field[200] a, field[200] b) -> field[100] {
field[100] mut diff = [0; 100];
field[100] mut borrow = [0; 100];
for u32 i in 0..K {
field[2] tmp = if i == 0 {
(a[i] >= b[i]) ? [a[i] - b[i], 0] : [a[i] - b[i] + (1 << N), 1]
} else {
(a[i] >= b[i] + borrow[i - 1]) ? [a[i] - b[i] - borrow[i - 1], 0] : [(1 << N) + a[i] - b[i] - borrow[i - 1], 1]
};
diff[i] = tmp[0];
borrow[i] = tmp[1];
}
return diff;
}
// a is a n-bit scalar
// b has k registers
def long_scalar_mult<N, K>(field a, field[K] b) -> field[100] {
field[100] mut out = [0; 100];
for u32 i in 0..K {
field temp = out[i] + (a * b[i]);
out[i] = temp % (1 << N);
out[i + 1] = out[i + 1] + temp \ (1 << N);
}
return out;
}
def short_div_norm_f1<N, K>(field[100] mut mult, field[200] a, field[200] b, field qhat) -> field {
mult = long_sub::<N, K>([...mult, ...[0; 100]], b);
return (long_gt::<K>(mult, a) == 1) ? qhat - 2 : qhat - 1;
}
// n bits per register
// a has k + 1 registers
// b has k registers
// assumes leading digit of b is at least 2 ** (n - 1)
// 0 <= a < (2**n) * b
def short_div_norm<N, K>(field[200] a, field[200] b) -> field {
field mut qhat = (a[K] * (1 << N) + a[K - 1]) \ b[K - 1];
qhat = (qhat > (1 << N) - 1) ? (1 << N) - 1 : qhat;
field[100] mult = long_scalar_mult::<N, K>(qhat, b[..K]);
u32 K1 = K + 1;
qhat = (long_gt::<K1>(mult, a) == 1) ? short_div_norm_f1::<N, K1>(mult, a, b, qhat) : qhat;
return qhat;
}
// n bits per register
// a has k + 1 registers
// b has k registers
// assumes leading digit of b is non-zero
// 0 <= a < (2**n) * b
def short_div<N, K>(field[200] a, field[K] b) -> field {
field scale = (1 << N) \ (1 + b[K - 1]);
u32 K1 = K + 1;
// k + 2 registers now
field[200] norm_a = [...long_scalar_mult::<N, K1>(scale, a[..K1]), ...[0; 100]];
// k + 1 registers now
field[200] norm_b = [...long_scalar_mult::<N, K>(scale, b), ...[0; 100]];
field ret = norm_b[K] != 0 ? short_div_norm::<N, K1>(norm_a, norm_b) : short_div_norm::<N, K>(norm_a, norm_b);
return ret;
}
def long_div_f1<K, M>(field[200] mut dividend, field[200] remainder) -> field[200] {
dividend[K] = 0;
for u32 i in 0..K {
u32 j = K - i - 1;
dividend[j] = remainder[j + M];
}
return dividend;
}
def long_div_f2<K>(field[200] mut dividend, field[200] remainder, u32 n) -> field[200] {
for u32 i in 0..K+1 {
u32 j = K - i;
dividend[j] = remainder[j + n];
}
return dividend;
}
def long_div<N, K, M, P>(field[P] a, field[K] b) -> field[2][100] {
assert(P == K + M);
field[200] mut remainder = [0; 200];
for u32 i in 0..P {
remainder[i] = a[i];
}
field[2][100] mut out = [[0; 100]; 2];
field[200] mut mult = [0; 200];
field[200] mut dividend = [0; 200];
for u32 j in 0..M+1 {
u32 i = M - j;
dividend = i == M ? long_div_f1::<K, M>(dividend, remainder) : long_div_f2::<K>(dividend, remainder, i);
out[0][i] = short_div::<N, K>(dividend, b);
field[100] mult_shift = long_scalar_mult::<N, K>(out[0][i], b);
field[200] mut subtrahend = [0; 200];
for u32 k in 0..K+1 {
subtrahend[i + k] = i + k < P ? mult_shift[k] : subtrahend[i + k];
}
remainder = [...long_sub::<N, P>(remainder, subtrahend), ...[0; 100]];
}
for u32 i in 0..K {
out[1][i] = remainder[i];
}
out[1][K] = 0;
return out;
} }