1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

use field constants to avoid curve dependant code

This commit is contained in:
dark64 2021-07-14 14:10:16 +02:00
parent 6271c5b746
commit aba16cdc8b
9 changed files with 23 additions and 35 deletions

View file

@ -1,12 +1,11 @@
#pragma curve bn128
from "field" import FIELD_SIZE_IN_BITS
// we can compare numbers up to 2^(pbits - 2) - 1, ie any number which fits in (pbits - 2) bits
// It should not work for the maxvalue = 2^(pbits - 2) - 1 augmented by one
// /!\ should be called with a = 0
def main(field a) -> bool:
field pbits = 254
// maxvalue = 2**252 - 1
field maxvalue = a + 7237005577332262213973186563042994240829374041602535252466099000494570602496 - 1
// we added a = 0 to prevent the condition to be evaluated at compile time
return a < (maxvalue + 1)
u32 pbits = FIELD_SIZE_IN_BITS
// we added a = 0 to prevent the condition to be evaluated at compile time
field maxvalue = a + (2**(pbits - 2) - 1)
return a < (maxvalue + 1)

View file

@ -1,7 +1,9 @@
from "field" import FIELD_MAX
// as p - 1 is greater than p/2, comparing to it should fail
// /!\ should be called with a = 0
def main(field a) -> bool:
field p = 21888242871839275222246405745257275088548364400416034343698204186575808495616 + a
field p = FIELD_MAX + a
// we added a = 0 to prevent the condition to be evaluated at compile time
return a < p

View file

@ -1,10 +1,9 @@
#pragma curve bn128
from "field" import FIELD_SIZE_IN_BITS
// we can compare numbers up to 2^(pbits - 2) - 1, ie any number which fits in (pbits - 2) bits
// lt should work for the maxvalue = 2^(pbits - 2) - 1
def main(field a) -> bool:
field pbits = 254
// maxvalue = 2**252 - 1
field maxvalue = 7237005577332262213973186563042994240829374041602535252466099000494570602496 - 1
u32 pbits = FIELD_SIZE_IN_BITS
field maxvalue = 2**(pbits - 2) - 1
return 0 < maxvalue

View file

@ -1,5 +1,6 @@
from "utils/pack/bool/unpack.zok" import main as unpack
from "utils/casts/u32_to_bits" import main as u32_to_bits
from "field" import FIELD_MAX, FIELD_SIZE_IN_BITS
// this comparison works for any N smaller than the field size, which is the case in practice
def le<N>(bool[N] a_bits, bool[N] c_bits) -> bool:
@ -17,11 +18,9 @@ def le<N>(bool[N] a_bits, bool[N] c_bits) -> bool:
return verified_conditions == N // this checks that all conditions were verified
// this instanciates comparison starting from field elements
// this instantiates comparison starting from field elements
def le<N>(field a, field c) -> bool:
field MAX = 21888242871839275222246405745257275088548364400416034343698204186575808495616
bool[N] MAX_BITS = unpack::<N>(MAX)
bool[N] MAX_BITS = unpack::<N>(FIELD_MAX)
bool[N] a_bits = unpack(a)
assert(le(a_bits, MAX_BITS))
@ -38,10 +37,7 @@ def le(u32 a, u32 c) -> bool:
return le(a_bits, c_bits)
def main(field a, u32 b) -> (bool, bool):
u32 N = 254
field c = 42
u32 d = 42
return le::<N>(a, c), le(b, d)
return le::<FIELD_SIZE_IN_BITS>(a, c), le(b, d)

View file

@ -0,0 +1,6 @@
from "EMBED" import unpack
from "field" import FIELD_SIZE_IN_BITS
def main(field a) -> bool[FIELD_SIZE_IN_BITS]:
bool[FIELD_SIZE_IN_BITS] b = unpack(a)
return b

View file

@ -1,5 +1,5 @@
{
"entry_point": "./tests/tests/split_bls.zok",
"entry_point": "./tests/tests/split.zok",
"curves": ["Bls12_381"],
"tests": [
{

View file

@ -1,7 +0,0 @@
from "EMBED" import unpack
def main(field a) -> (bool[255]):
bool[255] b = unpack(a)
return b

View file

@ -1,5 +1,5 @@
{
"entry_point": "./tests/tests/split_bn.zok",
"entry_point": "./tests/tests/split.zok",
"curves": ["Bn128"],
"tests": [
{

View file

@ -1,7 +0,0 @@
from "EMBED" import unpack
def main(field a) -> (bool[254]):
bool[254] b = unpack(a)
return b