1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00
This commit is contained in:
schaeff 2021-03-12 12:23:03 +01:00
parent 676f6b844f
commit 0ffcb31392
3 changed files with 289 additions and 141 deletions

3
test.zok Normal file
View file

@ -0,0 +1,3 @@
def main(field x):
assert(x < 3)
return

View file

@ -191,7 +191,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// `sizeUnknown` accordingly:
// `sizeUnknown = sizeUnknown && a[0]`
// * b[1] = 1
// `sizrUnknown = sizeUnknown && a[1]`
// `sizeUnknown = sizeUnknown && a[1]`
// * b[2] = 1
// `sizeUnknown = sizeUnknown && a[2]`
// * b[3] = 0
@ -203,12 +203,16 @@ impl<'ast, T: Field> Flattener<'ast, T> {
// sizeUnkown *
// **false => a -> {0,1}
// ```
//
// # Returns
//
// * a vector of FlatExpression which all evaluate to 1 if a <= b and 0 otherwise
fn strict_le_check(
&mut self,
statements_flattened: &mut FlatStatements<T>,
a: &[FlatVariable],
b: &[bool],
a: Vec<FlatVariable>,
) {
) -> Vec<FlatExpression<T>> {
let len = b.len();
assert_eq!(a.len(), T::get_required_bits());
assert_eq!(a.len(), b.len());
@ -227,6 +231,8 @@ impl<'ast, T: Field> Flattener<'ast, T> {
FlatExpression::Number(T::from(1)),
));
let mut res = vec![];
for (i, b) in b.iter().enumerate() {
if *b {
statements_flattened.push(FlatStatement::Definition(
@ -275,12 +281,25 @@ impl<'ast, T: Field> Flattener<'ast, T> {
box and_name.into(),
);
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Number(T::from(1)),
or,
));
res.push(or);
}
}
res
}
fn enforce_strict_le_check(
&mut self,
statements_flattened: &mut FlatStatements<T>,
a: &[FlatVariable],
b: &[bool],
) {
let statements: Vec<_> = self
.strict_le_check(statements_flattened, a, b)
.into_iter()
.map(|c| FlatStatement::Condition(FlatExpression::Number(T::from(1)), c))
.collect();
statements_flattened.extend(statements);
}
/// Flatten an if/else expression
@ -392,6 +411,106 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let rhs_flattened =
self.flatten_field_expression(symbols, statements_flattened, rhs);
match (lhs_flattened, rhs_flattened) {
(lhs_flattened, FlatExpression::Number(constant)) => {
// decompose lhs to bits
let lhs_id = self.define(lhs_flattened, statements_flattened);
// define variables for the bits
let lhs_bits_be: Vec<FlatVariable> =
(0..bit_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
lhs_bits_be.clone(),
Solver::bits(bit_width),
vec![lhs_id],
)));
// bitness checks
for i in 0..bit_width {
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(lhs_bits_be[i]),
FlatExpression::Mult(
box FlatExpression::Identifier(lhs_bits_be[i]),
box FlatExpression::Identifier(lhs_bits_be[i]),
),
));
}
// bit decomposition check
let mut lhs_sum = FlatExpression::Number(T::from(0));
for i in 0..bit_width {
lhs_sum = FlatExpression::Add(
box lhs_sum,
box FlatExpression::Mult(
box FlatExpression::Identifier(lhs_bits_be[i]),
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
),
);
}
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(lhs_id),
lhs_sum,
));
// check that this decomposition does not overflow the field
self.enforce_strict_le_check(
statements_flattened,
&lhs_bits_be,
&T::max_value().bit_vector_be(),
);
let conditions = self.strict_le_check(
statements_flattened,
&lhs_bits_be,
&constant.bit_vector_be(),
);
// define the condition as `conditions.len() - sum(conditions)`
let condition = FlatExpression::Sub(
box FlatExpression::Number(T::from(conditions.len())),
box conditions
.into_iter()
.fold(FlatExpression::Number(T::zero()), |acc, e| {
FlatExpression::Add(box acc, box e)
}),
);
// return `condition == 0`
// copy pasted the fieldeq flattening code
let name_y = self.use_sym();
let name_m = self.use_sym();
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
vec![name_y, name_m],
Solver::ConditionEq,
vec![condition.clone()],
)));
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Identifier(name_y),
FlatExpression::Mult(
box condition.clone(),
box FlatExpression::Identifier(name_m),
),
));
let res = FlatExpression::Sub(
box FlatExpression::Number(T::one()),
box FlatExpression::Identifier(name_y),
);
statements_flattened.push(FlatStatement::Condition(
FlatExpression::Number(T::zero()),
FlatExpression::Mult(box res.clone(), box condition),
));
res
}
(lhs_flattened, rhs_flattened) => {
// lhs
let lhs_id = self.define(lhs_flattened, statements_flattened);
@ -404,11 +523,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
(0..safe_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
statements_flattened.push(FlatStatement::Directive(
FlatDirective::new(
lhs_bits_be.clone(),
Solver::bits(safe_width),
vec![lhs_id],
)));
),
));
// bitness checks
for i in 0..safe_width {
@ -429,7 +550,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
box lhs_sum,
box FlatExpression::Mult(
box FlatExpression::Identifier(lhs_bits_be[i]),
box FlatExpression::Number(T::from(2).pow(safe_width - i - 1)),
box FlatExpression::Number(
T::from(2).pow(safe_width - i - 1),
),
),
);
}
@ -450,11 +573,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
(0..safe_width).map(|_| self.use_sym()).collect();
// add a directive to get the bits
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
statements_flattened.push(FlatStatement::Directive(
FlatDirective::new(
rhs_bits_be.clone(),
Solver::bits(safe_width),
vec![rhs_id],
)));
),
));
// bitness checks
for i in 0..safe_width {
@ -475,7 +600,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
box rhs_sum,
box FlatExpression::Mult(
box FlatExpression::Identifier(rhs_bits_be[i]),
box FlatExpression::Number(T::from(2).pow(safe_width - i - 1)),
box FlatExpression::Number(
T::from(2).pow(safe_width - i - 1),
),
),
);
}
@ -521,10 +648,10 @@ impl<'ast, T: Field> Flattener<'ast, T> {
}
// check that the decomposition is in the field with a strict `< p` checks
self.strict_le_check(
self.enforce_strict_le_check(
statements_flattened,
&T::max_value_bit_vector_be(),
sub_bits_be.clone(),
&sub_bits_be,
&T::max_value().bit_vector_be(),
);
// sum(sym_b{i} * 2**i)
@ -540,10 +667,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
);
}
statements_flattened.push(FlatStatement::Condition(subtraction_result, expr));
statements_flattened
.push(FlatStatement::Condition(subtraction_result, expr));
FlatExpression::Identifier(sub_bits_be[bit_width - 1])
}
}
}
BooleanExpression::BoolEq(box lhs, box rhs) => {
// lhs and rhs are booleans, they flatten to 0 or 1
let x = self.flatten_boolean_expression(symbols, statements_flattened, lhs);

View file

@ -106,7 +106,8 @@ pub trait Field:
/// Gets the number of bits
fn bits(&self) -> u32;
/// Returns this `Field`'s largest value as a big-endian bit vector
fn max_value_bit_vector_be() -> Vec<bool> {
/// Always returns `Self::get_required_bits()` elements
fn bit_vector_be(&self) -> Vec<bool> {
fn bytes_to_bits(bytes: &[u8]) -> Vec<bool> {
bytes
.iter()
@ -114,13 +115,27 @@ pub trait Field:
.collect()
}
let field_bytes_le = Self::into_byte_vector(&Self::max_value());
let field_bytes_le = Self::into_byte_vector(&self);
// reverse for big-endianess
let field_bytes_be = field_bytes_le.into_iter().rev().collect::<Vec<u8>>();
let field_bits_be = bytes_to_bits(&field_bytes_be);
let field_bits_be = &field_bits_be[field_bits_be.len() - Self::get_required_bits()..];
field_bits_be.to_vec()
let field_bits_be: Vec<_> = (0..Self::get_required_bits()
.checked_sub(field_bits_be.len())
.unwrap_or(0))
.map(|_| &false)
.chain(
&field_bits_be[field_bits_be
.len()
.checked_sub(Self::get_required_bits())
.unwrap_or(0)..],
)
.cloned()
.collect();
assert_eq!(field_bits_be.len(), Self::get_required_bits());
field_bits_be
}
/// Returns the value as a BigUint
fn to_biguint(&self) -> BigUint;