wip
This commit is contained in:
parent
676f6b844f
commit
0ffcb31392
3 changed files with 289 additions and 141 deletions
3
test.zok
Normal file
3
test.zok
Normal file
|
@ -0,0 +1,3 @@
|
|||
def main(field x):
|
||||
assert(x < 3)
|
||||
return
|
|
@ -191,7 +191,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// `sizeUnknown` accordingly:
|
||||
// `sizeUnknown = sizeUnknown && a[0]`
|
||||
// * b[1] = 1
|
||||
// `sizrUnknown = sizeUnknown && a[1]`
|
||||
// `sizeUnknown = sizeUnknown && a[1]`
|
||||
// * b[2] = 1
|
||||
// `sizeUnknown = sizeUnknown && a[2]`
|
||||
// * b[3] = 0
|
||||
|
@ -203,12 +203,16 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// sizeUnkown *
|
||||
// **false => a -> {0,1}
|
||||
// ```
|
||||
//
|
||||
// # Returns
|
||||
//
|
||||
// * a vector of FlatExpression which all evaluate to 1 if a <= b and 0 otherwise
|
||||
fn strict_le_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
a: &[FlatVariable],
|
||||
b: &[bool],
|
||||
a: Vec<FlatVariable>,
|
||||
) {
|
||||
) -> Vec<FlatExpression<T>> {
|
||||
let len = b.len();
|
||||
assert_eq!(a.len(), T::get_required_bits());
|
||||
assert_eq!(a.len(), b.len());
|
||||
|
@ -227,6 +231,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
FlatExpression::Number(T::from(1)),
|
||||
));
|
||||
|
||||
let mut res = vec![];
|
||||
|
||||
for (i, b) in b.iter().enumerate() {
|
||||
if *b {
|
||||
statements_flattened.push(FlatStatement::Definition(
|
||||
|
@ -275,12 +281,25 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
box and_name.into(),
|
||||
);
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Number(T::from(1)),
|
||||
or,
|
||||
));
|
||||
res.push(or);
|
||||
}
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
fn enforce_strict_le_check(
|
||||
&mut self,
|
||||
statements_flattened: &mut FlatStatements<T>,
|
||||
a: &[FlatVariable],
|
||||
b: &[bool],
|
||||
) {
|
||||
let statements: Vec<_> = self
|
||||
.strict_le_check(statements_flattened, a, b)
|
||||
.into_iter()
|
||||
.map(|c| FlatStatement::Condition(FlatExpression::Number(T::from(1)), c))
|
||||
.collect();
|
||||
statements_flattened.extend(statements);
|
||||
}
|
||||
|
||||
/// Flatten an if/else expression
|
||||
|
@ -392,157 +411,268 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
let rhs_flattened =
|
||||
self.flatten_field_expression(symbols, statements_flattened, rhs);
|
||||
|
||||
// lhs
|
||||
let lhs_id = self.define(lhs_flattened, statements_flattened);
|
||||
match (lhs_flattened, rhs_flattened) {
|
||||
(lhs_flattened, FlatExpression::Number(constant)) => {
|
||||
// decompose lhs to bits
|
||||
let lhs_id = self.define(lhs_flattened, statements_flattened);
|
||||
|
||||
// check that lhs and rhs are within the right range, i.e., their higher two bits are zero. We use big-endian so they are at positions 0 and 1
|
||||
// define variables for the bits
|
||||
let lhs_bits_be: Vec<FlatVariable> =
|
||||
(0..bit_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// lhs
|
||||
{
|
||||
// define variables for the bits
|
||||
let lhs_bits_be: Vec<FlatVariable> =
|
||||
(0..safe_width).map(|_| self.use_sym()).collect();
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
lhs_bits_be.clone(),
|
||||
Solver::bits(bit_width),
|
||||
vec![lhs_id],
|
||||
)));
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
lhs_bits_be.clone(),
|
||||
Solver::bits(safe_width),
|
||||
vec![lhs_id],
|
||||
)));
|
||||
// bitness checks
|
||||
for i in 0..bit_width {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(lhs_bits_be[i]),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(lhs_bits_be[i]),
|
||||
box FlatExpression::Identifier(lhs_bits_be[i]),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// bit decomposition check
|
||||
let mut lhs_sum = FlatExpression::Number(T::from(0));
|
||||
|
||||
for i in 0..bit_width {
|
||||
lhs_sum = FlatExpression::Add(
|
||||
box lhs_sum,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(lhs_bits_be[i]),
|
||||
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
// bitness checks
|
||||
for i in 0..safe_width {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(lhs_bits_be[i]),
|
||||
FlatExpression::Identifier(lhs_id),
|
||||
lhs_sum,
|
||||
));
|
||||
|
||||
// check that this decomposition does not overflow the field
|
||||
self.enforce_strict_le_check(
|
||||
statements_flattened,
|
||||
&lhs_bits_be,
|
||||
&T::max_value().bit_vector_be(),
|
||||
);
|
||||
|
||||
let conditions = self.strict_le_check(
|
||||
statements_flattened,
|
||||
&lhs_bits_be,
|
||||
&constant.bit_vector_be(),
|
||||
);
|
||||
|
||||
// define the condition as `conditions.len() - sum(conditions)`
|
||||
let condition = FlatExpression::Sub(
|
||||
box FlatExpression::Number(T::from(conditions.len())),
|
||||
box conditions
|
||||
.into_iter()
|
||||
.fold(FlatExpression::Number(T::zero()), |acc, e| {
|
||||
FlatExpression::Add(box acc, box e)
|
||||
}),
|
||||
);
|
||||
|
||||
// return `condition == 0`
|
||||
|
||||
// copy pasted the fieldeq flattening code
|
||||
let name_y = self.use_sym();
|
||||
let name_m = self.use_sym();
|
||||
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
vec![name_y, name_m],
|
||||
Solver::ConditionEq,
|
||||
vec![condition.clone()],
|
||||
)));
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(name_y),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(lhs_bits_be[i]),
|
||||
box FlatExpression::Identifier(lhs_bits_be[i]),
|
||||
box condition.clone(),
|
||||
box FlatExpression::Identifier(name_m),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// bit decomposition check
|
||||
let mut lhs_sum = FlatExpression::Number(T::from(0));
|
||||
|
||||
for i in 0..safe_width {
|
||||
lhs_sum = FlatExpression::Add(
|
||||
box lhs_sum,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(lhs_bits_be[i]),
|
||||
box FlatExpression::Number(T::from(2).pow(safe_width - i - 1)),
|
||||
),
|
||||
let res = FlatExpression::Sub(
|
||||
box FlatExpression::Number(T::one()),
|
||||
box FlatExpression::Identifier(name_y),
|
||||
);
|
||||
}
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(lhs_id),
|
||||
lhs_sum,
|
||||
));
|
||||
}
|
||||
|
||||
// rhs
|
||||
let rhs_id = self.define(rhs_flattened, statements_flattened);
|
||||
|
||||
// rhs
|
||||
{
|
||||
// define variables for the bits
|
||||
let rhs_bits_be: Vec<FlatVariable> =
|
||||
(0..safe_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
rhs_bits_be.clone(),
|
||||
Solver::bits(safe_width),
|
||||
vec![rhs_id],
|
||||
)));
|
||||
|
||||
// bitness checks
|
||||
for i in 0..safe_width {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(rhs_bits_be[i]),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(rhs_bits_be[i]),
|
||||
box FlatExpression::Identifier(rhs_bits_be[i]),
|
||||
),
|
||||
FlatExpression::Number(T::zero()),
|
||||
FlatExpression::Mult(box res.clone(), box condition),
|
||||
));
|
||||
|
||||
res
|
||||
}
|
||||
(lhs_flattened, rhs_flattened) => {
|
||||
// lhs
|
||||
let lhs_id = self.define(lhs_flattened, statements_flattened);
|
||||
|
||||
// bit decomposition check
|
||||
let mut rhs_sum = FlatExpression::Number(T::from(0));
|
||||
// check that lhs and rhs are within the right range, i.e., their higher two bits are zero. We use big-endian so they are at positions 0 and 1
|
||||
|
||||
for i in 0..safe_width {
|
||||
rhs_sum = FlatExpression::Add(
|
||||
box rhs_sum,
|
||||
// lhs
|
||||
{
|
||||
// define variables for the bits
|
||||
let lhs_bits_be: Vec<FlatVariable> =
|
||||
(0..safe_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(
|
||||
FlatDirective::new(
|
||||
lhs_bits_be.clone(),
|
||||
Solver::bits(safe_width),
|
||||
vec![lhs_id],
|
||||
),
|
||||
));
|
||||
|
||||
// bitness checks
|
||||
for i in 0..safe_width {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(lhs_bits_be[i]),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(lhs_bits_be[i]),
|
||||
box FlatExpression::Identifier(lhs_bits_be[i]),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// bit decomposition check
|
||||
let mut lhs_sum = FlatExpression::Number(T::from(0));
|
||||
|
||||
for i in 0..safe_width {
|
||||
lhs_sum = FlatExpression::Add(
|
||||
box lhs_sum,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(lhs_bits_be[i]),
|
||||
box FlatExpression::Number(
|
||||
T::from(2).pow(safe_width - i - 1),
|
||||
),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(lhs_id),
|
||||
lhs_sum,
|
||||
));
|
||||
}
|
||||
|
||||
// rhs
|
||||
let rhs_id = self.define(rhs_flattened, statements_flattened);
|
||||
|
||||
// rhs
|
||||
{
|
||||
// define variables for the bits
|
||||
let rhs_bits_be: Vec<FlatVariable> =
|
||||
(0..safe_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(
|
||||
FlatDirective::new(
|
||||
rhs_bits_be.clone(),
|
||||
Solver::bits(safe_width),
|
||||
vec![rhs_id],
|
||||
),
|
||||
));
|
||||
|
||||
// bitness checks
|
||||
for i in 0..safe_width {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(rhs_bits_be[i]),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(rhs_bits_be[i]),
|
||||
box FlatExpression::Identifier(rhs_bits_be[i]),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// bit decomposition check
|
||||
let mut rhs_sum = FlatExpression::Number(T::from(0));
|
||||
|
||||
for i in 0..safe_width {
|
||||
rhs_sum = FlatExpression::Add(
|
||||
box rhs_sum,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(rhs_bits_be[i]),
|
||||
box FlatExpression::Number(
|
||||
T::from(2).pow(safe_width - i - 1),
|
||||
),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(rhs_id),
|
||||
rhs_sum,
|
||||
));
|
||||
}
|
||||
|
||||
// sym := (lhs * 2) - (rhs * 2)
|
||||
let subtraction_result = FlatExpression::Sub(
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(rhs_bits_be[i]),
|
||||
box FlatExpression::Number(T::from(2).pow(safe_width - i - 1)),
|
||||
box FlatExpression::Number(T::from(2)),
|
||||
box FlatExpression::Identifier(lhs_id),
|
||||
),
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Number(T::from(2)),
|
||||
box FlatExpression::Identifier(rhs_id),
|
||||
),
|
||||
);
|
||||
|
||||
// define variables for the bits
|
||||
let sub_bits_be: Vec<FlatVariable> =
|
||||
(0..bit_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
sub_bits_be.clone(),
|
||||
Solver::bits(bit_width),
|
||||
vec![subtraction_result.clone()],
|
||||
)));
|
||||
|
||||
// bitness checks
|
||||
for i in 0..bit_width {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(sub_bits_be[i]),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(sub_bits_be[i]),
|
||||
box FlatExpression::Identifier(sub_bits_be[i]),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// check that the decomposition is in the field with a strict `< p` checks
|
||||
self.enforce_strict_le_check(
|
||||
statements_flattened,
|
||||
&sub_bits_be,
|
||||
&T::max_value().bit_vector_be(),
|
||||
);
|
||||
|
||||
// sum(sym_b{i} * 2**i)
|
||||
let mut expr = FlatExpression::Number(T::from(0));
|
||||
|
||||
for i in 0..bit_width {
|
||||
expr = FlatExpression::Add(
|
||||
box expr,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(sub_bits_be[i]),
|
||||
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
statements_flattened
|
||||
.push(FlatStatement::Condition(subtraction_result, expr));
|
||||
|
||||
FlatExpression::Identifier(sub_bits_be[bit_width - 1])
|
||||
}
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(rhs_id),
|
||||
rhs_sum,
|
||||
));
|
||||
}
|
||||
|
||||
// sym := (lhs * 2) - (rhs * 2)
|
||||
let subtraction_result = FlatExpression::Sub(
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Number(T::from(2)),
|
||||
box FlatExpression::Identifier(lhs_id),
|
||||
),
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Number(T::from(2)),
|
||||
box FlatExpression::Identifier(rhs_id),
|
||||
),
|
||||
);
|
||||
|
||||
// define variables for the bits
|
||||
let sub_bits_be: Vec<FlatVariable> =
|
||||
(0..bit_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
sub_bits_be.clone(),
|
||||
Solver::bits(bit_width),
|
||||
vec![subtraction_result.clone()],
|
||||
)));
|
||||
|
||||
// bitness checks
|
||||
for i in 0..bit_width {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(sub_bits_be[i]),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(sub_bits_be[i]),
|
||||
box FlatExpression::Identifier(sub_bits_be[i]),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// check that the decomposition is in the field with a strict `< p` checks
|
||||
self.strict_le_check(
|
||||
statements_flattened,
|
||||
&T::max_value_bit_vector_be(),
|
||||
sub_bits_be.clone(),
|
||||
);
|
||||
|
||||
// sum(sym_b{i} * 2**i)
|
||||
let mut expr = FlatExpression::Number(T::from(0));
|
||||
|
||||
for i in 0..bit_width {
|
||||
expr = FlatExpression::Add(
|
||||
box expr,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(sub_bits_be[i]),
|
||||
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(subtraction_result, expr));
|
||||
|
||||
FlatExpression::Identifier(sub_bits_be[bit_width - 1])
|
||||
}
|
||||
BooleanExpression::BoolEq(box lhs, box rhs) => {
|
||||
// lhs and rhs are booleans, they flatten to 0 or 1
|
||||
|
|
|
@ -106,7 +106,8 @@ pub trait Field:
|
|||
/// Gets the number of bits
|
||||
fn bits(&self) -> u32;
|
||||
/// Returns this `Field`'s largest value as a big-endian bit vector
|
||||
fn max_value_bit_vector_be() -> Vec<bool> {
|
||||
/// Always returns `Self::get_required_bits()` elements
|
||||
fn bit_vector_be(&self) -> Vec<bool> {
|
||||
fn bytes_to_bits(bytes: &[u8]) -> Vec<bool> {
|
||||
bytes
|
||||
.iter()
|
||||
|
@ -114,13 +115,27 @@ pub trait Field:
|
|||
.collect()
|
||||
}
|
||||
|
||||
let field_bytes_le = Self::into_byte_vector(&Self::max_value());
|
||||
let field_bytes_le = Self::into_byte_vector(&self);
|
||||
// reverse for big-endianess
|
||||
let field_bytes_be = field_bytes_le.into_iter().rev().collect::<Vec<u8>>();
|
||||
let field_bits_be = bytes_to_bits(&field_bytes_be);
|
||||
|
||||
let field_bits_be = &field_bits_be[field_bits_be.len() - Self::get_required_bits()..];
|
||||
field_bits_be.to_vec()
|
||||
let field_bits_be: Vec<_> = (0..Self::get_required_bits()
|
||||
.checked_sub(field_bits_be.len())
|
||||
.unwrap_or(0))
|
||||
.map(|_| &false)
|
||||
.chain(
|
||||
&field_bits_be[field_bits_be
|
||||
.len()
|
||||
.checked_sub(Self::get_required_bits())
|
||||
.unwrap_or(0)..],
|
||||
)
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
assert_eq!(field_bits_be.len(), Self::get_required_bits());
|
||||
|
||||
field_bits_be
|
||||
}
|
||||
/// Returns the value as a BigUint
|
||||
fn to_biguint(&self) -> BigUint;
|
||||
|
|
Loading…
Reference in a new issue