//! Module containing the `Flattener` to process a program that is R1CS-able. //! //! @file flatten.rs //! @author Dennis Kuhnert //! @author Jacob Eberhardt //! @date 2017 mod utils; use self::utils::flat_expression_from_bits; use zokrates_ast::{ common::{expressions::ValueExpression, Span}, zir::{ canonicalizer::ZirCanonicalizer, ConditionalExpression, Expr, Folder, SelectExpression, ShouldReduce, UMetadata, ZirAssemblyStatement, ZirExpressionList, ZirProgram, }, }; use zokrates_interpreter::Interpreter; use std::collections::{ hash_map::{Entry, HashMap}, VecDeque, }; use std::ops::*; use zokrates_ast::common::embed::*; use zokrates_ast::common::FlatEmbed; use zokrates_ast::common::WithSpan; use zokrates_ast::common::{flat::Variable, RuntimeError}; use zokrates_ast::flat::*; use zokrates_ast::ir::Solver; use zokrates_ast::zir::types::{Type, UBitwidth}; use zokrates_ast::zir::{ BooleanExpression, Conditional, FieldElementExpression, Identifier, Parameter as ZirParameter, UExpression, UExpressionInner, Variable as ZirVariable, ZirExpression, ZirStatement, }; use zokrates_common::CompileConfig; use zokrates_field::Field; /// A container for statements produced during code generation /// New statements are registered with the span set in the container #[derive(Default)] pub struct FlatStatements<'ast, T> { span: Option, buffer: VecDeque>, } impl<'ast, T> FlatStatements<'ast, T> { fn push_back(&mut self, s: FlatStatement<'ast, T>) { self.buffer.push_back(s.span(self.span)) } fn pop_front(&mut self) -> Option> { self.buffer.pop_front() } fn extend(&mut self, i: impl IntoIterator>) { self.buffer.extend(i.into_iter().map(|s| s.span(self.span))) } fn set_span(&mut self, span: Option) { self.span = span; } fn is_empty(&self) -> bool { self.buffer.is_empty() } } impl<'ast, T> IntoIterator for FlatStatements<'ast, T> { type Item = FlatStatement<'ast, T>; type IntoIter = std::collections::vec_deque::IntoIter; fn into_iter(self) -> Self::IntoIter { self.buffer.into_iter() } } 
/// Flattens a function /// /// # Arguments /// * `funct` - `ZirFunction` that will be flattened pub fn from_program_and_config( prog: ZirProgram, config: CompileConfig, ) -> FlattenerIterator { let funct = prog.main; let mut flattener = Flattener::new(config); let mut statements_flattened = FlatStatements::default(); // push parameters let arguments_flattened = funct .arguments .into_iter() .map(|p| flattener.use_parameter(&p, &mut statements_flattened)) .collect(); FlattenerIterator { arguments: arguments_flattened, statements: FlattenerIteratorInner { statements: funct.statements.into(), statements_flattened, flattener, }, return_count: funct.signature.outputs.len(), module_map: prog.module_map, } } pub struct FlattenerIteratorInner<'ast, T> { pub statements: VecDeque>, pub statements_flattened: FlatStatements<'ast, T>, pub flattener: Flattener<'ast, T>, } pub type FlattenerIterator<'ast, T> = FlatProgIterator<'ast, T, FlattenerIteratorInner<'ast, T>>; impl<'ast, T: Field> Iterator for FlattenerIteratorInner<'ast, T> { type Item = FlatStatement<'ast, T>; fn next(&mut self) -> Option { while self.statements_flattened.is_empty() { match self.statements.pop_front() { Some(s) => { self.flattener .flatten_statement(&mut self.statements_flattened, s); } None => { break; } } } self.statements_flattened.pop_front() } } /// Flattener, computes flattened program. #[derive(Debug)] pub struct Flattener<'ast, T> { config: CompileConfig, /// Index of the next introduced variable while processing the program. 
next_var_idx: usize, /// `Variable`s corresponding to each `Identifier` layout: HashMap, Variable>, /// Cached bit decompositions to avoid re-generating them bits_cache: HashMap, Vec>>, } trait FlattenOutput: Sized { fn flat(self) -> FlatExpression; } impl FlattenOutput for FlatExpression { fn flat(self) -> FlatExpression { self } } impl FlattenOutput for FlatUExpression { fn flat(self) -> FlatExpression { self.get_field_unchecked() } } // We introduce a trait in order to make it possible to make flattening `e` generic over the type of `e` trait Flatten<'ast, T: Field>: From> + Conditional<'ast, T> { type Output: FlattenOutput; fn flatten( self, flattener: &mut Flattener<'ast, T>, statements_flattened: &mut FlatStatements<'ast, T>, ) -> Self::Output; } impl<'ast, T: Field> Flatten<'ast, T> for FieldElementExpression<'ast, T> { type Output = FlatExpression; fn flatten( self, flattener: &mut Flattener<'ast, T>, statements_flattened: &mut FlatStatements<'ast, T>, ) -> Self::Output { flattener.flatten_field_expression(statements_flattened, self) } } impl<'ast, T: Field> Flatten<'ast, T> for UExpression<'ast, T> { type Output = FlatUExpression; fn flatten( self, flattener: &mut Flattener<'ast, T>, statements_flattened: &mut FlatStatements<'ast, T>, ) -> Self::Output { flattener.flatten_uint_expression(statements_flattened, self) } } impl<'ast, T: Field> Flatten<'ast, T> for BooleanExpression<'ast, T> { type Output = FlatExpression; fn flatten( self, flattener: &mut Flattener<'ast, T>, statements_flattened: &mut FlatStatements<'ast, T>, ) -> Self::Output { flattener.flatten_boolean_expression(statements_flattened, self) } } #[derive(Clone, Debug)] struct FlatUExpression { field: Option>, bits: Option>>, } impl WithSpan for FlatUExpression { fn span(self, span: Option) -> Self { Self { field: self.field.map(|e| e.span(span)), bits: self .bits .map(|bits| bits.into_iter().map(|b| b.span(span)).collect()), } } fn get_span(&self) -> Option { let field_span = 
self.field.as_ref().map(|f| f.get_span()); field_span.unwrap_or_else(|| unimplemented!()) } } impl FlatUExpression { fn default() -> Self { FlatUExpression { field: None, bits: None, } } } impl FlatUExpression { fn field>>>(mut self, e: U) -> Self { self.field = e.into(); self } fn bits>>>>(mut self, e: U) -> Self { self.bits = e.into(); self } fn with_field>>>(e: U) -> Self { Self::default().field(e) } fn with_bits>>>>(e: U) -> Self { Self::default().bits(e) } } impl FlatUExpression { fn get_field_unchecked(self) -> FlatExpression { match self.field { Some(f) => f, None => match self.bits { Some(bits) => flat_expression_from_bits(bits), None => unreachable!(), }, } } } impl<'ast, T: Field> Flattener<'ast, T> { /// Returns a `Flattener` with fresh `layout`. fn new(config: CompileConfig) -> Flattener<'ast, T> { Flattener { config, next_var_idx: 0, layout: HashMap::new(), bits_cache: HashMap::new(), } } /// Flattens a definition, trying to avoid creating redundant variables fn define( &mut self, e: FlatExpression, statements_flattened: &mut FlatStatements<'ast, T>, ) -> Variable { match e { FlatExpression::Identifier(id) => id.id, e => { let res = self.use_sym(); statements_flattened.push_back(FlatStatement::definition(res, e)); res } } } /// Compute a range check between the bid endian decomposition of an expression and the /// big endian decomposition of a constant /// /// # Arguments /// * `a` - the big-endian bit decomposition of the expression to check against the range /// * `b` - the big-endian bit decomposition of the upper bound we're checking against /// /// # Returns /// * a vector of FlatExpression which all evaluate to `1` if `a <= b` and `0` otherwise /// /// # Notes /// /// Algorithm from [the sapling spec](https://github.com/zcash/zips/blob/master/protocol/sapling.pdf) A.3.2.2 /// /// Let's assume b = [1, 1, 1, 0] /// /// 1. Init `sizeUnknown = true` /// As long as `sizeUnknown` is `true` we don't yet know if a is <= than b. /// 2. 
Loop over `b`: /// * b[0] = 1 /// when `b` is 1 we check whether `a` is 0 in that particular run and update /// `sizeUnknown` accordingly: /// `sizeUnknown = sizeUnknown && a[0]` /// * b[1] = 1 /// `sizeUnknown = sizeUnknown && a[1]` /// * b[2] = 1 /// `sizeUnknown = sizeUnknown && a[2]` /// * b[3] = 0 /// we need to enforce that `a` is 0 in case `sizeUnknown`is still `true`, /// otherwise `a` can be {0,1}: /// `true == (!sizeUnknown || !a[3])` /// **true => a -> 0 /// sizeUnkown * /// **false => a -> {0,1} #[must_use] fn constant_le_check( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, a: &[FlatExpression], b: &[bool], ) -> Vec> { assert_eq!(a.len(), b.len()); let is_power_of_two_minus_one = b.iter().all(|b| *b); // if `b` is all ones, then the check is always verified because that's the maximum possible value if is_power_of_two_minus_one { let statements: Vec<_> = a .iter() .map(|e| { let e_id = self.define(e.clone(), statements_flattened); FlatStatement::condition( e_id.into(), FlatExpression::mul(e_id.into(), e_id.into()), RuntimeError::Bitness, ) }) .collect(); statements_flattened.extend(statements); return vec![]; } let len = b.len(); let mut is_not_smaller_run = vec![]; let mut size_unknown = vec![]; for _ in 0..len { is_not_smaller_run.push(self.use_sym()); size_unknown.push(self.use_sym()); } // init size_unknown = true statements_flattened.push_back(FlatStatement::definition( size_unknown[0], FlatExpression::value(T::from(1)), )); let mut res = vec![]; for (i, b) in b.iter().enumerate() { if *b { statements_flattened.push_back(FlatStatement::definition( is_not_smaller_run[i], a[i].clone(), )); // don't need to update size_unknown in the last round if i < len - 1 { statements_flattened.push_back(FlatStatement::definition( size_unknown[i + 1], FlatExpression::mul(size_unknown[i].into(), is_not_smaller_run[i].into()), )); } } else { // don't need to update size_unknown in the last round if i < len - 1 { statements_flattened.push_back( // 
sizeUnknown is not changing in this case // We sill have to assign the old value to the variable of the current run // This trivial definition will later be removed by the optimiser FlatStatement::definition(size_unknown[i + 1], size_unknown[i].into()), ); } let or_left = FlatExpression::sub(FlatExpression::value(T::from(1)), size_unknown[i].into()); let or_right: FlatExpression<_> = FlatExpression::sub(FlatExpression::value(T::from(1)), a[i].clone()); let and_name = self.use_sym(); let and = FlatExpression::mul(or_left.clone(), or_right.clone()); statements_flattened.push_back(FlatStatement::definition(and_name, and)); let or = FlatExpression::sub(FlatExpression::add(or_left, or_right), and_name.into()); res.push(or); } } res } /// Compute an equality check between two expressions /// /// # Arguments /// /// * `statements_flattened` - Vector where new flattened statements can be added. /// * `left - the first `FlatExpression` /// * `right` - the second `FlatExpression` /// /// # Returns /// * A FlatExpression which evaluates to `1` if `left == right`, `0` otherwise fn eq_check( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, left: FlatExpression, right: FlatExpression, ) -> FlatExpression { let left = self.define(left, statements_flattened); let right = self.define(right, statements_flattened); // Wanted: (Y = (X != 0) ? 
1 : 0) // X = a - b // # Y = if X == 0 then 0 else 1 fi // # M = if X == 0 then 1 else 1/X fi // Y == X * M // 0 == (1-Y) * X let x = FlatExpression::sub(left.into(), right.into()); let name_y = self.use_sym(); let name_m = self.use_sym(); statements_flattened.push_back(FlatStatement::directive( vec![name_y, name_m], Solver::ConditionEq, vec![x.clone()], )); statements_flattened.push_back(FlatStatement::condition( FlatExpression::identifier(name_y), FlatExpression::mul(x.clone(), FlatExpression::identifier(name_m)), RuntimeError::Equal, )); let res = FlatExpression::sub( FlatExpression::value(T::one()), FlatExpression::identifier(name_y), ); statements_flattened.push_back(FlatStatement::condition( FlatExpression::value(T::zero()), FlatExpression::mul(res.clone(), x), RuntimeError::Equal, )); res } /// Enforce a range check against a constant: the range check isn't verified iff a constraint will fail /// /// # Arguments /// /// * `statements_flattened` - Vector where new flattened statements can be added. /// * `a` - the big-endian bit decomposition of the expression we enforce to be in range /// * `b` - the big-endian bit decomposition of the upper bound of the range fn enforce_constant_le_check_bits( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, a: &[FlatExpression], c: &[bool], error: RuntimeError, ) { let conditions = self.constant_le_check(statements_flattened, a, c); let conditions_count = conditions.len(); let conditions_sum = conditions .into_iter() .fold(FlatExpression::from(T::zero()), |acc, e| { FlatExpression::add(acc, e) }); statements_flattened.push_back(FlatStatement::condition( FlatExpression::value(T::from(0)), FlatExpression::sub(conditions_sum, T::from(conditions_count).into()), error, )); } /// Enforce a range check against a constant: the range check isn't verified iff a constraint will fail /// /// # Arguments /// /// * `statements_flattened` - Vector where new flattened statements can be added. 
/// * `e` - the expression we enforce to be in range
/// * `c` - the constant upper bound of the range
/// * `error` - the runtime error attached to the emitted constraints
fn enforce_constant_le_check(
    &mut self,
    statements_flattened: &mut FlatStatements<'ast, T>,
    e: FlatExpression<T>,
    c: T,
    error: RuntimeError,
) {
    let c_bit_width = c.bits() as usize;
    let c_bits_be = c.to_bits_be();

    // we reduce `e` to `n` bits with `n` the bitwidth of `c`
    // if `e` does not fit in `n` bits, this will fail
    // but as we are asserting `e <= c`, `e` not fitting in `n` bits should indeed lead to an unsatisfied constraint
    let e_bits_be = self.get_bits_unchecked(
        &FlatUExpression::with_field(e),
        c_bit_width,
        c_bit_width,
        statements_flattened,
        error.clone(),
    );

    // only the `c_bit_width` least significant bits of `c` are compared
    self.enforce_constant_le_check_bits(
        statements_flattened,
        &e_bits_be,
        &c_bits_be[T::get_required_bits() - c_bit_width..],
        error,
    );
}

/// Enforce a range check against a constant: the range check isn't verified iff a constraint will fail
///
/// # Arguments
///
/// * `statements_flattened` - Vector where new flattened statements can be added.
/// * `e` - the expression we enforce to be in range /// * `c` - the constant upper bound of the range fn enforce_constant_lt_check( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, e: FlatExpression, c: T, error: RuntimeError, ) { // `e < 0` will always result in false value, so we constrain `0 == 1` if c == T::zero() { statements_flattened.push_back(FlatStatement::condition( T::zero().into(), T::one().into(), error, )); } else { self.enforce_constant_le_check(statements_flattened, e, c - T::one(), error) } } fn make_conditional( &mut self, statements: FlatStatements<'ast, T>, condition: FlatExpression, ) -> FlatStatements<'ast, T> { statements .into_iter() .flat_map(|s| match s { FlatStatement::Condition(s) => { let span = s.get_span(); let mut output = FlatStatements::default(); output.set_span(span); // we transform (a == b) into (c => (a == b)) which is (!c || (a == b)) // let's introduce new variables to make sure everything is linear let name_lin = self.define(s.lin, &mut output); let name_quad = self.define(s.quad, &mut output); // let's introduce an expression which is 1 iff `a == b` let y = FlatExpression::add( FlatExpression::sub(name_lin.into(), name_quad.into()), T::one().into(), ); // let's introduce !c let x = FlatExpression::sub(T::one().into(), condition.clone()); assert!(x.is_linear() && y.is_linear()); let name_x_or_y = self.use_sym(); output.push_back(FlatStatement::directive( vec![name_x_or_y], Solver::Or, vec![x.clone(), y.clone()], )); output.push_back(FlatStatement::condition( FlatExpression::add( x.clone(), FlatExpression::sub(y.clone(), name_x_or_y.into()), ), FlatExpression::mul(x, y), RuntimeError::BranchIsolation, )); output.push_back(FlatStatement::condition( name_x_or_y.into(), T::one().into(), s.error, )); output } s => { let mut v = FlatStatements::default(); v.push_back(s); v } }) .fold(FlatStatements::default(), |mut acc, s| { acc.push_back(s); acc }) } /// Flatten an if/else expression /// /// # Arguments /// /// * 
`statements_flattened` - Vector where new flattened statements can be added. /// * `condition` - the condition as a `BooleanExpression`. /// * `consequence` - the consequence of type U. /// * `alternative` - the alternative of type U. /// # Remarks /// * U is the type of the expression fn flatten_conditional_expression>( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, e: ConditionalExpression<'ast, T, U>, ) -> FlatUExpression { let condition = *e.condition; let consequence = *e.consequence; let alternative = *e.alternative; let condition_flat = self.flatten_boolean_expression(statements_flattened, condition.clone()); let condition_id = self.use_sym(); statements_flattened.push_back(FlatStatement::definition(condition_id, condition_flat)); let (consequence, alternative) = if self.config.isolate_branches { let mut consequence_statements = FlatStatements::default(); let consequence = consequence.flatten(self, &mut consequence_statements); let mut alternative_statements = FlatStatements::default(); let alternative = alternative.flatten(self, &mut alternative_statements); let consequence_statements = self.make_conditional(consequence_statements, condition_id.into()); let alternative_statements = self.make_conditional( alternative_statements, FlatExpression::sub(FlatExpression::value(T::one()), condition_id.into()), ); statements_flattened.extend(consequence_statements); statements_flattened.extend(alternative_statements); (consequence, alternative) } else { ( consequence.flatten(self, statements_flattened), alternative.flatten(self, statements_flattened), ) }; let consequence = consequence.flat(); let alternative = alternative.flat(); let consequence_id = self.use_sym(); statements_flattened.push_back(FlatStatement::definition(consequence_id, consequence)); let alternative_id = self.use_sym(); statements_flattened.push_back(FlatStatement::definition(alternative_id, alternative)); let term0_id = self.use_sym(); 
statements_flattened.push_back(FlatStatement::definition( term0_id, FlatExpression::mul(condition_id.into(), FlatExpression::from(consequence_id)), )); let term1_id = self.use_sym(); statements_flattened.push_back(FlatStatement::definition( term1_id, FlatExpression::mul( FlatExpression::sub(FlatExpression::value(T::one()), condition_id.into()), FlatExpression::from(alternative_id), ), )); let res = self.use_sym(); statements_flattened.push_back(FlatStatement::definition( res, FlatExpression::add( FlatExpression::from(term0_id), FlatExpression::from(term1_id), ), )); FlatUExpression { field: Some(FlatExpression::identifier(res)), bits: None, } } /// Compute a strict check against a constant /// # Arguments /// * `statements_flattened` - Vector where new flattened statements can be added. /// * `e` - the `FlatExpression` that's being checked against the range. /// * `c` - the constant strict upper bound of the range /// /// # Returns /// * a `FlatExpression` which evaluates to `1` if `0 <= e < c`, and to `0` otherwise fn constant_lt_check( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, e: FlatExpression, c: T, ) -> FlatExpression { if c == T::zero() { // this is the case c == 0, we return 0, aka false return T::zero().into(); } self.constant_field_le_check(statements_flattened, e, c - T::one()) } /// Compute a range check against a constant /// /// # Arguments /// /// * `statements_flattened` - Vector where new flattened statements can be added. /// * `e` - the `FlatExpression` that's being checked against the range. /// * `c` - the constant upper bound of the range /// /// # Returns /// * a `FlatExpression` which evaluates to `1` if `0 <= e <= c`, and to `0` otherwise fn constant_field_le_check( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, e: FlatExpression, c: T, ) -> FlatExpression { let bitwidth = T::get_required_bits(); // we want to reduce `e <= c` to 0 or 1, without ever throwing. 
We know the bitwidth of `c` and want to minimize the bitwidth we reduce `e` to. // we must use the maximum bitwidth, as otherwise, large enough values of `e` would lead `get_bits` to throw. let e_bits_be = self.get_bits_unchecked( &FlatUExpression::with_field(e), bitwidth, bitwidth, statements_flattened, RuntimeError::ConstantLtSum, ); // check that this decomposition does not overflow the field self.enforce_constant_le_check_bits( statements_flattened, &e_bits_be, &T::max_value().to_bits_be(), RuntimeError::Le, ); let conditions = self.constant_le_check(statements_flattened, &e_bits_be, &c.to_bits_be()); // return `len(conditions) == sum(conditions)` self.eq_check( statements_flattened, T::from(conditions.len()).into(), conditions .into_iter() .fold(FlatExpression::value(T::zero()), |acc, e| { FlatExpression::add(acc, e) }), ) } #[must_use] fn le_check( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, lhs_flattened: FlatExpression, rhs_flattened: FlatExpression, bit_width: usize, ) -> FlatExpression { FlatExpression::add( self.eq_check( statements_flattened, lhs_flattened.clone(), rhs_flattened.clone(), ), self.lt_check( statements_flattened, lhs_flattened, rhs_flattened, bit_width, ), ) } #[must_use] fn lt_check( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, lhs_flattened: FlatExpression, rhs_flattened: FlatExpression, bit_width: usize, ) -> FlatExpression { match (lhs_flattened, rhs_flattened) { (x, FlatExpression::Value(constant)) => { self.constant_lt_check(statements_flattened, x, constant.value) } // (c < x <= p - 1) <=> (0 <= p - 1 - x < p - 1 - c) (FlatExpression::Value(constant), x) => self.constant_lt_check( statements_flattened, FlatExpression::sub(T::max_value().into(), x), T::max_value() - constant.value, ), (lhs_flattened, rhs_flattened) => { let lhs_id = self.define(lhs_flattened, statements_flattened); let rhs_id = self.define(rhs_flattened, statements_flattened); // shifted_sub := 2**safe_width + lhs - rhs let 
shifted_sub = FlatExpression::add( FlatExpression::value(T::from(2).pow(bit_width)), FlatExpression::sub( FlatExpression::identifier(lhs_id), FlatExpression::identifier(rhs_id), ), ); let sub_width = bit_width + 1; let shifted_sub_bits_be = self.get_bits_unchecked( &FlatUExpression::with_field(shifted_sub), sub_width, sub_width, statements_flattened, RuntimeError::IncompleteDynamicRange, ); FlatExpression::sub( FlatExpression::value(T::one()), shifted_sub_bits_be[0].clone(), ) } } } /// Flattens a boolean expression /// /// # Arguments /// /// * `statements_flattened` - Vector where new flattened statements can be added. /// * `expression` - `BooleanExpression` that will be flattened. /// /// # Postconditions /// /// * `flatten_boolean_expressions` always returns a linear expression, /// * in order to preserve composability. fn flatten_boolean_expression( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, expression: BooleanExpression<'ast, T>, ) -> FlatExpression { let span = expression.get_span(); let span_backup = statements_flattened.span; statements_flattened.set_span(span); let res = match expression { BooleanExpression::Identifier(x) => { FlatExpression::identifier(*self.layout.get(&x.id).unwrap()) } BooleanExpression::Select(e) => self .flatten_select_expression(statements_flattened, e) .get_field_unchecked(), BooleanExpression::FieldLt(e) => { // Get the bit width to know the size of the binary decompositions for this Field let bit_width = T::get_required_bits(); let safe_width = bit_width - 2; // dynamic comparison is not complete, it only applies to field elements whose difference is strictly smaller than 2**(bitwidth - 2) let lhs_flattened = self.flatten_field_expression(statements_flattened, *e.left); let rhs_flattened = self.flatten_field_expression(statements_flattened, *e.right); self.lt_check( statements_flattened, lhs_flattened, rhs_flattened, safe_width, ) } BooleanExpression::BoolEq(e) => { // lhs and rhs are booleans, they flatten 
to 0 or 1 let x = self.flatten_boolean_expression(statements_flattened, *e.left); let y = self.flatten_boolean_expression(statements_flattened, *e.right); // Wanted: Not(X - Y)**2 which is an XNOR // We know that X and Y are [0, 1] // (X - Y) can become a negative values, which is why squaring the result is needed // Negating this returns correct result // Non-binary Truth table for logic of operation // +---+---+-------+---------------+ // | X | Y | X - Y | Not(X - Y)**2 | // +---+---+-------+---------------+ // | 1 | 1 | 0 | 1 | // | 1 | 0 | 1 | 0 | // | 0 | 1 | -1 | 0 | // | 0 | 0 | 0 | 1 | // +---+---+-------+---------------+ let x_sub_y = FlatExpression::sub(x, y); let name_x_mult_x = self.use_sym(); statements_flattened.push_back(FlatStatement::definition( name_x_mult_x, FlatExpression::mul(x_sub_y.clone(), x_sub_y), )); FlatExpression::sub( FlatExpression::value(T::one()), FlatExpression::identifier(name_x_mult_x), ) } BooleanExpression::FieldEq(e) => { let lhs = self.flatten_field_expression(statements_flattened, *e.left); let rhs = self.flatten_field_expression(statements_flattened, *e.right); self.eq_check(statements_flattened, lhs, rhs) } BooleanExpression::UintEq(e) => { // We reduce each side into range and apply the same approach as for field elements // Wanted: (Y = (X != 0) ? 
1 : 0) // X = a - b // # Y = if X == 0 then 0 else 1 fi // # M = if X == 0 then 1 else 1/X fi // Y == X * M // 0 == (1-Y) * X assert!(e.left.metadata.as_ref().unwrap().should_reduce.to_bool()); assert!(e.right.metadata.as_ref().unwrap().should_reduce.to_bool()); let lhs = self .flatten_uint_expression(statements_flattened, *e.left) .get_field_unchecked(); let rhs = self .flatten_uint_expression(statements_flattened, *e.right) .get_field_unchecked(); self.eq_check(statements_flattened, lhs, rhs) } BooleanExpression::FieldLe(e) => { let lt = self.flatten_boolean_expression( statements_flattened, BooleanExpression::field_lt(*e.left.clone(), *e.right.clone()).span(span), ); let eq = self.flatten_boolean_expression( statements_flattened, BooleanExpression::field_eq(*e.left, *e.right).span(span), ); FlatExpression::add(eq, lt) } BooleanExpression::UintLt(e) => { let bit_width = e.left.bitwidth.to_usize(); assert!(e.left.metadata.as_ref().unwrap().should_reduce.to_bool()); assert!(e.right.metadata.as_ref().unwrap().should_reduce.to_bool()); let lhs_flattened = self .flatten_uint_expression(statements_flattened, *e.left) .get_field_unchecked(); let rhs_flattened = self .flatten_uint_expression(statements_flattened, *e.right) .get_field_unchecked(); self.lt_check( statements_flattened, lhs_flattened, rhs_flattened, bit_width, ) } BooleanExpression::UintLe(e) => { let lt = self.flatten_boolean_expression( statements_flattened, BooleanExpression::uint_lt(*e.left.clone(), *e.right.clone()).span(span), ); let eq = self.flatten_boolean_expression( statements_flattened, BooleanExpression::uint_eq(*e.left, *e.right).span(span), ); FlatExpression::add(eq, lt) } BooleanExpression::Or(e) => { let x = self.flatten_boolean_expression(statements_flattened, *e.left); let y = self.flatten_boolean_expression(statements_flattened, *e.right); assert!(x.is_linear() && y.is_linear()); let name_x_or_y = self.use_sym(); statements_flattened.push_back(FlatStatement::directive( vec![name_x_or_y], 
Solver::Or, vec![x.clone(), y.clone()], )); statements_flattened.push_back(FlatStatement::condition( FlatExpression::add( x.clone(), FlatExpression::sub(y.clone(), name_x_or_y.into()), ), FlatExpression::mul(x, y), RuntimeError::Or, )); name_x_or_y.into() } BooleanExpression::And(e) => { let x = self.flatten_boolean_expression(statements_flattened, *e.left); let y = self.flatten_boolean_expression(statements_flattened, *e.right); let name_x_and_y = self.use_sym(); assert!(x.is_linear() && y.is_linear()); statements_flattened.push_back(FlatStatement::definition( name_x_and_y, FlatExpression::mul(x, y), )); FlatExpression::identifier(name_x_and_y) } BooleanExpression::Not(e) => { let x = self.flatten_boolean_expression(statements_flattened, *e.inner); FlatExpression::sub(FlatExpression::value(T::one()), x) } BooleanExpression::Value(b) => FlatExpression::value(match b.value { true => T::from(1), false => T::from(0), }), BooleanExpression::Conditional(e) => self .flatten_conditional_expression(statements_flattened, e) .get_field_unchecked(), } .span(span); statements_flattened.set_span(span_backup); res } fn u_to_bits( &mut self, expression: FlatUExpression, bitwidth: UBitwidth, ) -> Vec> { let bits = expression.bits.unwrap(); assert_eq!(bits.len(), bitwidth.to_usize()); bits.into_iter().map(FlatUExpression::with_field).collect() } fn bits_to_u( &mut self, bits: Vec>, bitwidth: UBitwidth, ) -> FlatUExpression { let bits: Vec<_> = bits.into_iter().map(|e| e.get_field_unchecked()).collect(); assert_eq!(bits.len(), bitwidth.to_usize()); FlatUExpression::with_bits(bits) } /// Flattens a function call /// /// # Arguments /// /// * `statements_flattened` - Vector where new flattened statements can be added. /// * `id` - `Identifier of the function. 
/// * `return_types` - Types of the return values of the function /// * `param_expressions` - Arguments of this call fn flatten_embed_call( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, embed: FlatEmbed, generics: Vec, param_expressions: Vec>, ) -> Vec> { let mut params: Vec<_> = param_expressions .into_iter() .map(|p| { if let ZirExpression::Uint(e) = &p { assert!(e.metadata.as_ref().unwrap().should_reduce.is_true()); } self.flatten_expression(statements_flattened, p) }) .collect(); match embed { FlatEmbed::FieldToBoolUnsafe => vec![params.pop().unwrap()], FlatEmbed::U8ToBits => self.u_to_bits(params.pop().unwrap(), 8.into()), FlatEmbed::U16ToBits => self.u_to_bits(params.pop().unwrap(), 16.into()), FlatEmbed::U32ToBits => self.u_to_bits(params.pop().unwrap(), 32.into()), FlatEmbed::U64ToBits => self.u_to_bits(params.pop().unwrap(), 64.into()), FlatEmbed::U8FromBits => { vec![self.bits_to_u(params, 8.into())] } FlatEmbed::U16FromBits => { vec![self.bits_to_u(params, 16.into())] } FlatEmbed::U32FromBits => { vec![self.bits_to_u(params, 32.into())] } FlatEmbed::U64FromBits => { vec![self.bits_to_u(params, 64.into())] } FlatEmbed::BitArrayLe => { // get the length of the bit arrays let len = generics[0]; // split the arguments into the two bit arrays of size `len` let (expressions, constants) = ( params[..len as usize].to_vec(), params[len as usize..].to_vec(), ); // define variables for the variable bits let variables: Vec<_> = expressions .into_iter() .map(|e| { self.define(e.get_field_unchecked(), statements_flattened) .into() }) .collect(); // get constants for the constant bits let constants: Vec<_> = constants .into_iter() .map(|e| match e.get_field_unchecked() { FlatExpression::Value(n) if n.value == T::one() => true, FlatExpression::Value(n) if n.value == T::zero() => false, _ => unreachable!(), }) .collect(); // get the list of conditions which must hold iff the `<=` relation holds let conditions = 
self.constant_le_check(statements_flattened, &variables, &constants); // return `len(conditions) == sum(conditions)` vec![FlatUExpression::with_field( self.eq_check( statements_flattened, T::from(conditions.len()).into(), conditions .into_iter() .fold(FlatExpression::value(T::zero()), |acc, e| { FlatExpression::add(acc, e) }), ), )] } funct => match funct { FlatEmbed::Unpack => self.flatten_embed_call_aux( statements_flattened, params, unpack_to_bitwidth(generics[0] as usize), ), #[cfg(feature = "bellman")] FlatEmbed::Sha256Round => { self.flatten_embed_call_aux(statements_flattened, params, sha256_round()) } #[cfg(feature = "ark")] FlatEmbed::SnarkVerifyBls12377 => self.flatten_embed_call_aux( statements_flattened, params, snark_verify_bls12_377::(generics[0] as usize), ), _ => unreachable!(), }, } } fn flatten_embed_call_aux( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, params: Vec>, funct: FlatFunctionIterator<'ast, T, impl IntoIterator>>, ) -> Vec> { let mut replacement_map = HashMap::new(); // Handle complex parameters and assign values: // Rename Parameters, assign them to values in call. Resolve complex expressions with definitions let params_flattened = params.into_iter().map(|e| e.get_field_unchecked()); let return_values = (0..funct.return_count).map(Variable::public); for (concrete_argument, formal_argument) in params_flattened.zip(funct.arguments) { let new_var = self.define(concrete_argument, statements_flattened); replacement_map.insert(formal_argument.id, new_var); } // Ensure renaming and correct returns: // add all flattened statements, adapt return statements let statements = funct.statements.into_iter().map(|stat| match stat { FlatStatement::Block(..) 
=> unreachable!(), FlatStatement::Definition(s) => { let new_var = self.use_sym(); replacement_map.insert(s.assignee, new_var); let new_rhs = s.rhs.apply_substitution(&replacement_map); FlatStatement::definition(new_var, new_rhs) } FlatStatement::Condition(s) => { let new_quad = s.quad.apply_substitution(&replacement_map); let new_lin = s.lin.apply_substitution(&replacement_map); FlatStatement::condition(new_lin, new_quad, s.error) } FlatStatement::Directive(d) => { let new_outputs = d .outputs .into_iter() .map(|o| { let new_o = self.use_sym(); replacement_map.insert(o, new_o); new_o }) .collect(); let new_inputs = d .inputs .into_iter() .map(|i| i.apply_substitution(&replacement_map)) .collect(); FlatStatement::directive(new_outputs, d.solver, new_inputs) } FlatStatement::Log(s) => FlatStatement::Log(LogStatement::new( s.format_string, s.expressions .into_iter() .map(|(t, e)| { ( t, e.into_iter() .map(|e| e.apply_substitution(&replacement_map)) .collect(), ) }) .collect(), )), }); statements_flattened.extend(statements); return_values .map(|x| FlatExpression::from(x).apply_substitution(&replacement_map)) .map(FlatUExpression::with_field) .collect() } /// Flattens an expression /// /// # Arguments /// /// * `statements_flattened` - Vector where new flattened statements can be added. /// * `expr` - `ZirExpression` that will be flattened. 
fn flatten_expression( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, expr: ZirExpression<'ast, T>, ) -> FlatUExpression { match expr { ZirExpression::FieldElement(e) => { FlatUExpression::with_field(self.flatten_field_expression(statements_flattened, e)) } ZirExpression::Boolean(e) => FlatUExpression::with_field( self.flatten_boolean_expression(statements_flattened, e), ), ZirExpression::Uint(e) => self.flatten_uint_expression(statements_flattened, e), } } fn default_xor( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, left: UExpression<'ast, T>, right: UExpression<'ast, T>, ) -> FlatUExpression { let left_flattened = self.flatten_uint_expression(statements_flattened, left); let right_flattened = self.flatten_uint_expression(statements_flattened, right); // `left` and `right` were reduced to the target bitwidth, hence their bits are available let left_bits = left_flattened.bits.unwrap(); let right_bits = right_flattened.bits.unwrap(); let xor: Vec> = left_bits .into_iter() .zip(right_bits.into_iter()) .map(|(x, y)| match (x, y) { (FlatExpression::Value(n), e) | (e, FlatExpression::Value(n)) => { if n.value == T::from(0) { self.define(e, statements_flattened).into() } else if n.value == T::from(1) { self.define( FlatExpression::sub(FlatExpression::value(T::from(1)), e), statements_flattened, ) .into() } else { unreachable!() } } (x, y) => { let name = self.use_sym(); statements_flattened.extend(vec![ FlatStatement::directive( vec![name], Solver::Xor, vec![x.clone(), y.clone()], ), FlatStatement::condition( FlatExpression::add( x.clone(), FlatExpression::sub(y.clone(), name.into()), ), FlatExpression::mul(FlatExpression::add(x.clone(), x), y), RuntimeError::Xor, ), ]); name.into() } }) .collect(); FlatUExpression::with_bits(xor) } fn euclidean_division( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, target_bitwidth: UBitwidth, left: UExpression<'ast, T>, right: UExpression<'ast, T>, ) -> (FlatExpression, FlatExpression) 
{ let left_flattened = self .flatten_uint_expression(statements_flattened, left) .get_field_unchecked(); let right_flattened = self .flatten_uint_expression(statements_flattened, right) .get_field_unchecked(); let n = if left_flattened.is_linear() { left_flattened } else { let id = self.use_sym(); statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); FlatExpression::identifier(id) }; let d = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); statements_flattened.push_back(FlatStatement::definition(id, right_flattened)); FlatExpression::identifier(id) }; // introduce the quotient and remainder let q = self.use_sym(); let r = self.use_sym(); statements_flattened.push_back(FlatStatement::directive( vec![q, r], Solver::EuclideanDiv, vec![n.clone(), d.clone()], )); let target_bitwidth = target_bitwidth.to_usize(); // q in range let _ = self.get_bits_unchecked( &FlatUExpression::with_field(FlatExpression::from(q)), target_bitwidth, target_bitwidth, statements_flattened, RuntimeError::Sum, ); // r in range let _ = self.get_bits_unchecked( &FlatUExpression::with_field(FlatExpression::from(r)), target_bitwidth, target_bitwidth, statements_flattened, RuntimeError::Sum, ); // r < d <=> r - d + 2**w < 2**w let _ = self.get_bits_unchecked( &FlatUExpression::with_field(FlatExpression::add( FlatExpression::sub(r.into(), d.clone()), FlatExpression::value(T::from(2_u128.pow(target_bitwidth as u32))), )), target_bitwidth, target_bitwidth, statements_flattened, RuntimeError::Sum, ); // q*d == n - r statements_flattened.push_back(FlatStatement::condition( FlatExpression::sub(n, r.into()), FlatExpression::mul(q.into(), d), RuntimeError::Euclidean, )); (q.into(), r.into()) } /// Flattens a uint expression /// /// # Arguments /// /// * `statements_flattened` - Vector where new flattened statements can be added. /// * `expr` - `UExpression` that will be flattened. 
fn flatten_uint_expression( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, expr: UExpression<'ast, T>, ) -> FlatUExpression { let span = expr.as_inner().get_span(); let span_backup = statements_flattened.span; statements_flattened.set_span(span); // the bitwidth for this type of uint (8, 16 or 32) let target_bitwidth = expr.bitwidth; let metadata = expr.metadata.clone().unwrap(); // the bitwidth on which this value is currently represented let actual_bitwidth = metadata.bitwidth() as usize; // whether this value should be reduced, for example if it is then used in a bitwidth operation let should_reduce = metadata.should_reduce; let should_reduce = should_reduce.to_bool(); let res = match expr.into_inner() { UExpressionInner::Value(x) => { FlatUExpression::with_field(FlatExpression::value(T::from(x.value))) } // force to be a field element UExpressionInner::Identifier(x) => { let field = FlatExpression::identifier(*self.layout.get(&x.id).unwrap()); let bits = self.bits_cache.get(&field).map(|bits| { assert_eq!(bits.len(), target_bitwidth.to_usize()); bits.clone() }); FlatUExpression::with_field(field).bits(bits) } UExpressionInner::Select(e) => self.flatten_select_expression(statements_flattened, e), UExpressionInner::Not(e) => { let e = self.flatten_uint_expression(statements_flattened, *e.inner); let e_bits = e.bits.unwrap(); assert_eq!(e_bits.len(), target_bitwidth.to_usize()); let name_not: Vec<_> = e_bits .into_iter() .map(|bit| { self.define( FlatExpression::sub(FlatExpression::value(T::from(1)), bit), statements_flattened, ) .into() }) .collect(); FlatUExpression::with_bits(name_not) } UExpressionInner::Add(e) => { let left_flattened = self .flatten_uint_expression(statements_flattened, *e.left) .get_field_unchecked(); let right_flattened = self .flatten_uint_expression(statements_flattened, *e.right) .get_field_unchecked(); let new_left = if left_flattened.is_linear() { left_flattened } else { let id = self.use_sym(); 
statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); FlatExpression::identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); statements_flattened.push_back(FlatStatement::definition(id, right_flattened)); FlatExpression::identifier(id) }; FlatUExpression::with_field(FlatExpression::add(new_left, new_right)) } UExpressionInner::Sub(e) => { // see uint optimizer for the reasoning here let offset = FlatExpression::value(T::from(2).pow(std::cmp::max( e.right.metadata.as_ref().unwrap().bitwidth() as usize, target_bitwidth as usize, ))); let left_flattened = self .flatten_uint_expression(statements_flattened, *e.left) .get_field_unchecked(); let right_flattened = self .flatten_uint_expression(statements_flattened, *e.right) .get_field_unchecked(); let new_left = if left_flattened.is_linear() { left_flattened } else { let id = self.use_sym(); statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); FlatExpression::identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); statements_flattened.push_back(FlatStatement::definition(id, right_flattened)); FlatExpression::identifier(id) }; FlatUExpression::with_field(FlatExpression::add( offset, FlatExpression::sub(new_left, new_right), )) } UExpressionInner::LeftShift(e) => { let by = match e.right.into_inner() { UExpressionInner::Value(v) => v.value as u32, _ => unreachable!(), }; let e = *e.left; let e = self.flatten_uint_expression(statements_flattened, e); let e_bits = e.bits.unwrap(); assert_eq!(e_bits.len(), target_bitwidth.to_usize()); FlatUExpression::with_bits( e_bits .into_iter() .skip(by as usize) .chain( (0..std::cmp::min(by as usize, target_bitwidth.to_usize())) .map(|_| FlatExpression::value(T::from(0))), ) .collect::>(), ) } UExpressionInner::RightShift(e) => { let by = match e.right.into_inner() { UExpressionInner::Value(v) => v.value as u32, _ => 
unreachable!(), }; let e = *e.left; let e = self.flatten_uint_expression(statements_flattened, e); let e_bits = e.bits.unwrap(); assert_eq!(e_bits.len(), target_bitwidth.to_usize()); FlatUExpression::with_bits( (0..std::cmp::min(by as usize, target_bitwidth.to_usize())) .map(|_| FlatExpression::value(T::from(0))) .chain(e_bits.into_iter().take( target_bitwidth.to_usize() - std::cmp::min(by as usize, target_bitwidth.to_usize()), )) .collect::>(), ) } UExpressionInner::Mult(e) => { let left_flattened = self .flatten_uint_expression(statements_flattened, *e.left) .get_field_unchecked(); let right_flattened = self .flatten_uint_expression(statements_flattened, *e.right) .get_field_unchecked(); let new_left = if left_flattened.is_linear() { left_flattened } else { let id = self.use_sym(); statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); FlatExpression::identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); statements_flattened.push_back(FlatStatement::definition(id, right_flattened)); FlatExpression::identifier(id) }; let res = self.use_sym(); statements_flattened.push_back(FlatStatement::definition( res, FlatExpression::mul(new_left, new_right), )); FlatUExpression::with_field(FlatExpression::identifier(res)) } UExpressionInner::Div(e) => { let (q, _) = self.euclidean_division( statements_flattened, target_bitwidth, *e.left, *e.right, ); FlatUExpression::with_field(q) } UExpressionInner::Rem(e) => { let (_, r) = self.euclidean_division( statements_flattened, target_bitwidth, *e.left, *e.right, ); FlatUExpression::with_field(r) } UExpressionInner::Conditional(e) => { self.flatten_conditional_expression(statements_flattened, e) } UExpressionInner::Xor(e) => { let left_metadata = e.left.metadata.clone().unwrap(); let right_metadata = e.right.metadata.clone().unwrap(); let left_span = e.left.get_span(); let right_span = e.right.get_span(); match (e.left.into_inner(), 
e.right.into_inner()) { (UExpressionInner::And(e), UExpressionInner::And(ee)) => { let a = *e.left; let b = *e.right; let aa = *ee.left; let c = *ee.right; if aa.as_inner() == UExpression::not(a.clone()).as_inner() { let a_flattened = self.flatten_uint_expression(statements_flattened, a); let b_flattened = self.flatten_uint_expression(statements_flattened, b); let c_flattened = self.flatten_uint_expression(statements_flattened, c); let a_bits = a_flattened.bits.unwrap(); let b_bits = b_flattened.bits.unwrap(); let c_bits = c_flattened.bits.unwrap(); let res: Vec> = a_bits .into_iter() .zip(b_bits.into_iter()) .zip(c_bits.into_iter()) .map(|((a, b), c)| { // a(b - c) = ch - c let ch = self.use_sym(); statements_flattened.extend(vec![ FlatStatement::directive( vec![ch], Solver::ShaCh, vec![a.clone(), b.clone(), c.clone()], ), FlatStatement::condition( FlatExpression::sub(ch.into(), c.clone()), a * (b - c), RuntimeError::ShaXor, ), ]); ch.into() }) .collect(); FlatUExpression::with_bits(res) } else { self.default_xor( statements_flattened, UExpression::and(a, b).metadata(left_metadata), UExpression::and(aa, c).metadata(right_metadata), ) } } (UExpressionInner::Xor(e), c) => { let a_metadata = e.left.metadata.clone().unwrap(); let b_metadata = e.right.metadata.clone().unwrap(); let a_span = e.left.get_span(); let b_span = e.right.get_span(); let c_span = right_span; match (e.left.into_inner(), e.right.into_inner(), c) { ( UExpressionInner::And(e0), UExpressionInner::And(e1), UExpressionInner::And(e2), ) => { let a = *e0.left; let b = *e0.right; let aa = *e1.left; let c = *e1.right; let bb = *e2.left; let cc = *e2.right; if (aa == a) && (bb == b) && (cc == c) { let a_flattened = self.flatten_uint_expression(statements_flattened, a); let b_flattened = self.flatten_uint_expression(statements_flattened, b); let c_flattened = self.flatten_uint_expression(statements_flattened, c); let a_bits = a_flattened.bits.unwrap(); let b_bits = b_flattened.bits.unwrap(); let c_bits = 
c_flattened.bits.unwrap(); let res: Vec> = a_bits .into_iter() .zip(b_bits.into_iter()) .zip(c_bits.into_iter()) .map(|((a, b), c)| { // (b) * (c) = (bc) // (2bc - b - c) * (a) = bc - maj let maj = self.use_sym(); let bc = self.use_sym(); statements_flattened.extend(vec![ FlatStatement::directive( vec![maj], Solver::ShaAndXorAndXorAnd, vec![a.clone(), b.clone(), c.clone()], ), FlatStatement::condition( bc.into(), FlatExpression::mul(b.clone(), c.clone()), RuntimeError::ShaXor, ), FlatStatement::condition( FlatExpression::sub(bc.into(), maj.into()), FlatExpression::mul( FlatExpression::sub( FlatExpression::add( bc.into(), bc.into(), ), FlatExpression::add(b, c), ), a, ), RuntimeError::ShaXor, ), ]); maj.into() }) .collect(); FlatUExpression::with_bits(res) } else { self.default_xor( statements_flattened, UExpression::xor( UExpression::and(a, b) .metadata(a_metadata) .span(a_span), UExpression::and(aa, c) .metadata(b_metadata) .span(b_span), ) .metadata(left_metadata) .span(left_span), UExpression::and(bb, cc) .metadata(right_metadata) .span(c_span), ) } } (a, b, c) => self.default_xor( statements_flattened, UExpression::xor( a.annotate(target_bitwidth) .metadata(a_metadata) .span(a_span), b.annotate(target_bitwidth) .metadata(b_metadata) .span(b_span), ) .metadata(left_metadata) .span(left_span), c.annotate(target_bitwidth) .metadata(right_metadata) .span(c_span), ), } } (left_i, right_i) => self.default_xor( statements_flattened, left_i .annotate(target_bitwidth) .metadata(left_metadata) .span(left_span), right_i .annotate(target_bitwidth) .metadata(right_metadata) .span(right_span), ), } } UExpressionInner::And(e) => { let left_flattened = self.flatten_uint_expression(statements_flattened, *e.left); let right_flattened = self.flatten_uint_expression(statements_flattened, *e.right); let left_bits = left_flattened.bits.unwrap(); let right_bits = right_flattened.bits.unwrap(); assert_eq!(left_bits.len(), target_bitwidth.to_usize()); assert_eq!(right_bits.len(), 
target_bitwidth.to_usize()); let and: Vec<_> = left_bits .into_iter() .zip(right_bits.into_iter()) .map(|(x, y)| match (x, y) { (FlatExpression::Value(n), e) | (e, FlatExpression::Value(n)) => { if n.value == T::from(0) { FlatExpression::value(T::from(0)) } else if n.value == T::from(1) { e } else { unreachable!(); } } (x, y) => self .define(FlatExpression::mul(x, y), statements_flattened) .into(), }) .collect(); FlatUExpression::with_bits(and) } UExpressionInner::Or(e) => { let left_flattened = self.flatten_uint_expression(statements_flattened, *e.left); let right_flattened = self.flatten_uint_expression(statements_flattened, *e.right); let left_bits = left_flattened.bits.unwrap(); let right_bits = right_flattened.bits.unwrap(); assert_eq!(left_bits.len(), target_bitwidth.to_usize()); assert_eq!(right_bits.len(), target_bitwidth.to_usize()); assert_eq!(left_bits.len(), target_bitwidth.to_usize()); assert_eq!(right_bits.len(), target_bitwidth.to_usize()); let or: Vec> = left_bits .into_iter() .zip(right_bits.into_iter()) .map(|(x, y)| match (x, y) { (FlatExpression::Value(n), e) | (e, FlatExpression::Value(n)) => { if n.value == T::from(0) { self.define(e, statements_flattened).into() } else if n.value == T::from(1) { FlatExpression::value(T::from(1)) } else { unreachable!() } } (x, y) => { let name = self.use_sym(); statements_flattened.extend(vec![ FlatStatement::directive( vec![name], Solver::Or, vec![x.clone(), y.clone()], ), FlatStatement::condition( FlatExpression::add( x.clone(), FlatExpression::sub(y.clone(), name.into()), ), FlatExpression::mul(x, y), RuntimeError::Or, ), ]); name.into() } }) .collect(); FlatUExpression::with_bits(or) } } .span(span); let res = match should_reduce { true => { let bits = self.get_bits_unchecked( &res, actual_bitwidth, target_bitwidth.to_usize(), statements_flattened, RuntimeError::Sum, ); let field = if actual_bitwidth > target_bitwidth.to_usize() { bits.iter().enumerate().fold( FlatExpression::value(T::from(0)), |acc, 
(index, bit)| { FlatExpression::add( acc, FlatExpression::mul( FlatExpression::value( T::from(2).pow(target_bitwidth.to_usize() - index - 1), ), bit.clone(), ), ) }, ) } else { res.get_field_unchecked() }; FlatUExpression::with_bits(bits).field(field) } false => res, }; statements_flattened.set_span(span_backup); res } /// Get the bits for a FlatUExpression /// /// # Arguments /// /// * `e` - the FlatUExpression to get the bits from /// * `from` - the original bitwidth of `e` /// * `to` - the bitwidth of the output representation /// * `statement_buffer` - a buffer to add statements to /// * `error` - the error to throw at runtime in case anything fails /// /// # Notes /// * `from` and `to` must be smaller or equal to `T::get_required_bits()`, the bitwidth of the prime field /// * the result is not checked to be in range. This is fine for `to < T::get_required_bits()`, but otherwise it is the caller's responsibility to add that check fn get_bits_unchecked( &mut self, e: &FlatUExpression, from: usize, to: usize, statements_flattened: &mut FlatStatements<'ast, T>, error: RuntimeError, ) -> Vec> { assert!(from <= T::get_required_bits()); assert!(to <= T::get_required_bits()); // constants do not require directives if let Some(FlatExpression::Value(ref x)) = e.field { let bits: Vec<_> = Interpreter::execute_solver(&Solver::bits(to), &[x.value], &[]) .unwrap() .into_iter() .map(FlatExpression::value) .collect(); assert_eq!(bits.len(), to); self.bits_cache .insert(e.field.clone().unwrap(), bits.clone()); return bits; }; e.bits.clone().unwrap_or_else(|| { // we are not reducing a constant, therefore the result should always have a smaller bitwidth: // `to` is the target bitwidth, and `from` cannot be smaller than that unless we're looking at a // constant let from = std::cmp::max(from, to); let res = match self.bits_cache.entry(e.field.clone().unwrap()) { Entry::Occupied(entry) => { let res: Vec<_> = entry.get().clone(); // if we already know a decomposition, its number 
of elements has to be smaller or equal to `to` assert!(res.len() <= to); // we then pad it with zeroes on the left (big endian) to return `to` bits if res.len() == to { res } else { (0..to - res.len()) .map(|_| FlatExpression::value(T::zero())) .chain(res) .collect() } } Entry::Vacant(_) => { let bits = (0..from).map(|_| self.use_sym()).collect::>(); statements_flattened.push_back(FlatStatement::directive( bits.clone(), Solver::Bits(from), vec![e.field.clone().unwrap()], )); let bits: Vec<_> = bits.into_iter().map(FlatExpression::identifier).collect(); // decompose to the actual bitwidth // bit checks statements_flattened.extend(bits.iter().take(from).map(|bit| { FlatStatement::condition( bit.clone(), FlatExpression::mul(bit.clone(), bit.clone()), RuntimeError::Bitness, ) })); let sum = flat_expression_from_bits(bits.clone()); // sum check statements_flattened.push_back(FlatStatement::condition( e.field.clone().unwrap(), sum.clone(), error, )); // truncate to the `to` lowest bits let bits = bits[from - to..].to_vec(); assert_eq!(bits.len(), to); self.bits_cache .insert(e.field.clone().unwrap(), bits.clone()); self.bits_cache.insert(sum, bits.clone()); bits } }; assert_eq!(res.len(), to); res }) } fn flatten_select_expression>( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, e: SelectExpression<'ast, T, U>, ) -> FlatUExpression { let span = e.get_span(); let array = e.array; let index = *e.index; let (range_check, result) = array .into_iter() .enumerate() .map(|(i, e)| { let condition = self.flatten_boolean_expression( statements_flattened, BooleanExpression::uint_eq( UExpression::value(i as u128) .annotate(UBitwidth::B32) .metadata(UMetadata { should_reduce: ShouldReduce::True, max: T::from(i), }) .span(span), index.clone(), ) .span(span), ); let element = e.flatten(self, statements_flattened); (condition, element) }) .collect::>() .into_iter() .fold( ( FlatExpression::value(T::zero()), FlatExpression::value(T::zero()), ), |(mut range_check, mut 
result), (condition, element)| { range_check = FlatExpression::add(range_check, condition.clone()).span(span); let conditional_element_id = self.use_sym(); statements_flattened.push_back( FlatStatement::definition( conditional_element_id, FlatExpression::mul(condition, element.flat()).span(span), ) .span(span), ); result = FlatExpression::add(result, conditional_element_id.into()).span(span); (range_check, result) }, ); statements_flattened.push_back(FlatStatement::condition( range_check, FlatExpression::value(T::one()), RuntimeError::SelectRangeCheck, )); FlatUExpression::with_field(result) } /// Flattens a field expression /// /// # Arguments /// /// * `statements_flattened` - Vector where new flattened statements can be added. /// * `expr` - `FieldElementExpression` that will be flattened. fn flatten_field_expression( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, expr: FieldElementExpression<'ast, T>, ) -> FlatExpression { let span = expr.get_span(); let span_backup = statements_flattened.span; statements_flattened.set_span(span); let res = match expr { FieldElementExpression::Value(x) => FlatExpression::Value(x), // force to be a field element FieldElementExpression::Identifier(x) => FlatExpression::identifier( *self.layout.get(&x.id).unwrap_or_else(|| panic!("{}", x)), ), FieldElementExpression::Select(e) => self .flatten_select_expression(statements_flattened, e) .get_field_unchecked(), FieldElementExpression::Add(e) => { let left_flattened = self.flatten_field_expression(statements_flattened, *e.left); let right_flattened = self.flatten_field_expression(statements_flattened, *e.right); let new_left = if left_flattened.is_linear() { left_flattened } else { let id = self.use_sym(); statements_flattened.push_back(FlatStatement::definition(id, left_flattened)); FlatExpression::identifier(id) }; let new_right = if right_flattened.is_linear() { right_flattened } else { let id = self.use_sym(); 
statements_flattened.push_back(FlatStatement::definition(id, right_flattened));
                FlatExpression::identifier(id)
            };
            FlatExpression::add(new_left, new_right)
        }
        FieldElementExpression::Sub(e) => {
            let left_flattened = self.flatten_field_expression(statements_flattened, *e.left);
            let right_flattened = self.flatten_field_expression(statements_flattened, *e.right);

            // non-linear operands are first assigned to a fresh variable
            let new_left = if left_flattened.is_linear() {
                left_flattened
            } else {
                let id = self.use_sym();
                statements_flattened.push_back(FlatStatement::definition(id, left_flattened));
                FlatExpression::identifier(id)
            };
            let new_right = if right_flattened.is_linear() {
                right_flattened
            } else {
                let id = self.use_sym();
                statements_flattened.push_back(FlatStatement::definition(id, right_flattened));
                FlatExpression::identifier(id)
            };

            FlatExpression::sub(new_left, new_right)
        }
        FieldElementExpression::Mult(e) => {
            let left_flattened = self.flatten_field_expression(statements_flattened, *e.left);
            let right_flattened = self.flatten_field_expression(statements_flattened, *e.right);
            // non-linear operands are first assigned to a fresh variable
            let new_left = if left_flattened.is_linear() {
                left_flattened
            } else {
                let id = self.use_sym();
                statements_flattened.push_back(FlatStatement::definition(id, left_flattened));
                FlatExpression::identifier(id)
            };
            let new_right = if right_flattened.is_linear() {
                right_flattened
            } else {
                let id = self.use_sym();
                statements_flattened.push_back(FlatStatement::definition(id, right_flattened));
                FlatExpression::identifier(id)
            };
            FlatExpression::mul(new_left, new_right)
        }
        FieldElementExpression::Div(e) => {
            let left_flattened = self.flatten_field_expression(statements_flattened, *e.left);
            let right_flattened = self.flatten_field_expression(statements_flattened, *e.right);
            // both operands are always assigned to a fresh variable
            let new_left: FlatExpression<T> = {
                let id = self.use_sym();
                statements_flattened.push_back(FlatStatement::definition(id, left_flattened));
                id.into()
            };
            let new_right: FlatExpression<T> = {
                let id = self.use_sym();
                statements_flattened.push_back(FlatStatement::definition(id, right_flattened));
                id.into()
            };

            // `right` is assumed to already be non-zero so this is an unchecked division
            // TODO: we could save one constraint here by reusing the inverse of `right` computed earlier

            let inverse = self.use_sym();

            // # c = a/b
            statements_flattened.push_back(FlatStatement::directive(
                vec![inverse],
                Solver::Div,
                vec![new_left.clone(), new_right.clone()],
            ));

            // assert(c * b == a)
            statements_flattened.push_back(FlatStatement::condition(
                new_left,
                FlatExpression::mul(new_right, inverse.into()),
                RuntimeError::Division,
            ));

            inverse.into()
        }
        FieldElementExpression::Pow(e) => {
            match e.right.into_inner() {
                UExpressionInner::Value(ref exp) => {
                    // flatten the base expression
                    let base_flattened =
                        self.flatten_field_expression(statements_flattened, *e.left.clone());

                    // we require from the base to be linear
                    // TODO change that
                    assert!(base_flattened.is_linear());

                    // convert the exponent to bytes, big endian
                    let ebytes_be = exp.value.to_be_bytes();

                    // convert the bytes to bits, remove leading zeroes (we only need powers up to the highest non-zero bit)
                    #[allow(clippy::needless_collect)]
                    // collecting is required as we then reverse
                    let ebits_be: Vec<_> = ebytes_be
                        .iter()
                        .flat_map(|byte| (0..8).rev().map(move |i| byte & (1 << i) != 0)) // byte to bit, big endian
                        .skip_while(|b| !b) // skip leading false bits
                        .collect();

                    // reverse to start with the lowest bit first
                    let ebits_le: Vec<_> = ebits_be.into_iter().rev().collect();

                    // calculate all powers base**(2**i) by squaring
                    let powers: Vec<FlatExpression<T>> = ebits_le
                        .iter()
                        .scan(None, |state, _| {
                            match state {
                                // the first element is the base
                                None => {
                                    *state = Some(base_flattened.clone());
                                    Some(base_flattened.clone())
                                }
                                // any subsequent element is the square of the previous one
                                Some(previous) => {
                                    // introduce a new variable
                                    let id = self.use_sym();
                                    // set it to the square of the previous one, stored in state
                                    statements_flattened.push_back(FlatStatement::definition(
                                        id,
                                        FlatExpression::mul(previous.clone(), previous.clone()),
                                    ));
                                    // store it in the state for later squaring
                                    *state = Some(FlatExpression::identifier(id));
                                    // return it for later use constructing the result
                                    Some(FlatExpression::identifier(id))
                                }
                            }
                        })
                        .collect();

                    // construct the result iterating through the bits, multiplying by the associated power iff the bit is true
                    ebits_le.into_iter().zip(powers).fold(
                        FlatExpression::value(T::from(1)), // initialise the result at 1. If we have no bits to iterate through, we're computing x**0 == 1
                        |acc, (bit, power)| match bit {
                            true => {
                                // update the result by introducing a new variable
                                let id = self.use_sym();
                                statements_flattened.push_back(FlatStatement::definition(
                                    id,
                                    FlatExpression::mul(acc, power), // set the new result to the current result times the current power
                                ));
                                FlatExpression::identifier(id)
                            }
                            false => acc, // this bit is false, keep the previous result
                        },
                    )
                }
                _ => panic!("Expected number as pow exponent"),
            }
        }
        FieldElementExpression::Conditional(e) => self
            .flatten_conditional_expression(statements_flattened, e)
            .get_field_unchecked(),
        _ => unreachable!(),
    };

    statements_flattened.set_span(span_backup);

    res
}

/// Flattens an assembly statement
///
/// An assignment becomes a directive running the canonicalized ZIR function, a constraint
/// becomes a condition between its two flattened sides.
///
/// # Arguments
///
/// * `statements_flattened` - Vector where new flattened statements can be added.
/// * `stat` - `ZirAssemblyStatement` that will be flattened.
fn flatten_assembly_statement(
    &mut self,
    statements_flattened: &mut FlatStatements<'ast, T>,
    stat: ZirAssemblyStatement<'ast, T>,
) {
    // statements created here carry the span of `stat`; the previous span is restored on exit
    let span = stat.get_span();
    let span_backup = statements_flattened.span;
    statements_flattened.set_span(span);

    match stat {
        ZirAssemblyStatement::Assignment(s) => {
            // the inputs are the flat variables of the arguments of the function
            let inputs: Vec<FlatExpression<T>> = s
                .expression
                .arguments
                .iter()
                .cloned()
                .map(|p| self.layout.get(&p.id.id).cloned().unwrap().into())
                .collect();

            let outputs: Vec<Variable> = s
                .assignee
                .into_iter()
                .map(|assignee| self.use_variable(&assignee))
                .collect();

            let mut canonicalizer = ZirCanonicalizer::default();
            let function = canonicalizer.fold_function(s.expression);

            let directive = FlatDirective::new(outputs, Solver::Zir(function), inputs);
            statements_flattened.push_back(FlatStatement::Directive(directive));
        }
        ZirAssemblyStatement::Constraint(s) => {
            let lhs = self.flatten_field_expression(statements_flattened,
s.left);
            let rhs = self.flatten_field_expression(statements_flattened, s.right);
            // both sides are constrained to be equal
            statements_flattened.push_back(FlatStatement::condition(
                lhs,
                rhs,
                RuntimeError::SourceAssemblyConstraint(s.metadata),
            ));
        }
    };

    statements_flattened.set_span(span_backup);
}

/// Flattens a statement
///
/// # Arguments
///
/// * `statements_flattened` - Vector where new flattened statements can be added.
/// * `s` - `ZirStatement` that will be flattened.
fn flatten_statement(
    &mut self,
    statements_flattened: &mut FlatStatements<'ast, T>,
    s: ZirStatement<'ast, T>,
) {
    // statements created here carry the span of `s`
    let span = s.get_span();
    let span_backup = statements_flattened.span;
    statements_flattened.set_span(span);

    match s {
        ZirStatement::Return(s) => {
            #[allow(clippy::needless_collect)]
            // clippy suggests to not collect here, but `statements_flattened` is borrowed in the iterator,
            // so we cannot borrow again when extending
            let flat_expressions: Vec<_> = s
                .inner
                .into_iter()
                .map(|expr| self.flatten_expression(statements_flattened, expr))
                .map(|x| x.get_field_unchecked())
                .collect();

            // the i-th returned expression is assigned to the i-th public variable
            statements_flattened.extend(
                flat_expressions
                    .into_iter()
                    .enumerate()
                    .map(|(index, e)| FlatStatement::definition(Variable::public(index), e)),
            );
        }
        ZirStatement::Assembly(s) => {
            // assembly statements are flattened to their own buffer and wrapped in a block
            let mut block_statements = FlatStatements::default();
            block_statements.set_span(s.get_span());

            for s in s.inner {
                self.flatten_assembly_statement(&mut block_statements, s);
            }
            statements_flattened
                .push_back(FlatStatement::block(block_statements.buffer.into()));
        }
        ZirStatement::IfElse(s) => {
            // the condition is assigned to a fresh variable
            let condition_flat =
                self.flatten_boolean_expression(statements_flattened, s.condition.clone());

            let condition_id = self.use_sym();
            statements_flattened
                .push_back(FlatStatement::definition(condition_id, condition_flat));

            if self.config.isolate_branches {
                // each branch is flattened to its own buffer, then passed to `make_conditional`
                // with `condition` for the consequence and `1 - condition` for the alternative
                let mut consequence_statements = FlatStatements::default();
                let mut alternative_statements = FlatStatements::default();

                s.consequence
                    .into_iter()
                    .for_each(|s| self.flatten_statement(&mut consequence_statements, s));
                s.alternative
                    .into_iter()
                    .for_each(|s| self.flatten_statement(&mut alternative_statements, s));

                let consequence_statements =
                    self.make_conditional(consequence_statements, condition_id.into());
                let alternative_statements = self.make_conditional(
                    alternative_statements,
                    FlatExpression::sub(FlatExpression::value(T::one()), condition_id.into()),
                );

                statements_flattened.extend(consequence_statements);
                statements_flattened.extend(alternative_statements);
            } else {
                // both branches are flattened directly to the current buffer
                s.consequence
                    .into_iter()
                    .for_each(|s| self.flatten_statement(statements_flattened, s));
                s.alternative
                    .into_iter()
                    .for_each(|s| self.flatten_statement(statements_flattened, s));
            }
        }
        ZirStatement::Definition(s) => {
            // define n variables with n the number of primitive types for v_type
            // assign them to the n primitive types for expr
            let assignee = s.assignee;
            let expr = s.rhs;

            let rhs = self.flatten_expression(statements_flattened, expr);

            let bits = rhs.bits.clone();

            let var = match rhs.get_field_unchecked() {
                // the right hand side is already a variable: no new definition is added
                FlatExpression::Identifier(id) => {
                    self.use_variable_with_existing(&assignee, id.id);
                    id.id
                }
                e => {
                    let var = self.use_variable(&assignee);

                    // handle return of function call
                    statements_flattened.push_back(FlatStatement::definition(var, e));

                    var
                }
            };

            // register bits
            if let Some(bits) = bits {
                self.bits_cache
                    .insert(FlatExpression::identifier(var), bits);
            }
        }
        ZirStatement::Assertion(s) => {
            let e = s.expression;
            let error = s.error;

            match e {
                // a conjunction is asserted one operand at a time
                BooleanExpression::And(..) => {
                    for boolean in e.into_conjunction_iterator() {
                        self.flatten_statement(
                            statements_flattened,
                            ZirStatement::assertion(boolean, error.clone()).span(span),
                        )
                    }
                }
                BooleanExpression::FieldEq(e) => {
                    let lhs = self.flatten_field_expression(statements_flattened, *e.left);

                    let rhs = self.flatten_field_expression(statements_flattened, *e.right);

                    self.flatten_equality_assertion(
                        statements_flattened,
                        lhs,
                        rhs,
                        error.into(),
                    )
                }
                BooleanExpression::FieldLt(e) => {
                    let lhs = self.flatten_field_expression(statements_flattened, *e.left);

                    let rhs = self.flatten_field_expression(statements_flattened, *e.right);

                    match (lhs, rhs) {
                        (e, FlatExpression::Value(c)) => self.enforce_constant_lt_check(
                            statements_flattened,
                            e,
                            c.value,
                            error.into(),
                        ),
                        // c < e <=> p - 1 - e < p - 1 - c
                        (FlatExpression::Value(c), e) => self.enforce_constant_lt_check(
                            statements_flattened,
                            FlatExpression::sub(T::max_value().into(), e),
                            T::max_value() - c.value,
                            error.into(),
                        ),
                        (lhs, rhs) => {
                            let bit_width = T::get_required_bits();
                            let safe_width = bit_width - 2; // dynamic comparison is not complete
                            let e = self.lt_check(statements_flattened, lhs, rhs, safe_width);
                            // the result of the comparison has to be one
                            statements_flattened.push_back(FlatStatement::condition(
                                e,
                                FlatExpression::value(T::one()),
                                error.into(),
                            ));
                        }
                    }
                }
                BooleanExpression::FieldLe(e) => {
                    let lhs = self.flatten_field_expression(statements_flattened, *e.left);

                    let rhs = self.flatten_field_expression(statements_flattened, *e.right);

                    match (lhs, rhs) {
                        (e, FlatExpression::Value(c)) => self.enforce_constant_le_check(
                            statements_flattened,
                            e,
                            c.value,
                            error.into(),
                        ),
                        // c <= e <=> p - 1 - e <= p - 1 - c
                        (FlatExpression::Value(c), e) => self.enforce_constant_le_check(
                            statements_flattened,
                            FlatExpression::sub(T::max_value().into(), e),
                            T::max_value() - c.value,
                            error.into(),
                        ),
                        (lhs, rhs) => {
                            let bit_width = T::get_required_bits();
                            let safe_width = bit_width - 2; // dynamic comparison is not complete
                            let e = self.le_check(statements_flattened, lhs, rhs, safe_width);
                            // the result of the comparison has to be one
                            statements_flattened.push_back(FlatStatement::condition(
                                e,
                                FlatExpression::value(T::one()),
                                error.into(),
                            ));
                        }
                    }
                }
                BooleanExpression::UintLe(e) => {
                    let lhs = self
                        .flatten_uint_expression(statements_flattened, *e.left)
                        .get_field_unchecked();

                    let rhs = self
                        .flatten_uint_expression(statements_flattened, *e.right)
                        .get_field_unchecked();

                    match (lhs, rhs) {
                        (e, FlatExpression::Value(c)) => self.enforce_constant_le_check(
                            statements_flattened,
                            e,
                            c.value,
                            error.into(),
                        ),
                        // c <= e <=> p - 1 - e <= p - 1 - c
                        (FlatExpression::Value(c), e) => self.enforce_constant_le_check(
                            statements_flattened,
                            FlatExpression::sub(T::max_value().into(), e),
                            T::max_value() - c.value,
                            error.into(),
                        ),
                        (lhs, rhs) => {
                            let bit_width = T::get_required_bits();
                            let safe_width = bit_width - 2; // dynamic comparison is not complete
                            let e = self.le_check(statements_flattened, lhs, rhs, safe_width);
                            // the result of the comparison has to be one
                            statements_flattened.push_back(FlatStatement::condition(
                                e,
                                FlatExpression::value(T::one()),
                                error.into(),
                            ));
                        }
                    }
                }
                BooleanExpression::UintEq(e) => {
                    let lhs = self
                        .flatten_uint_expression(statements_flattened, *e.left)
                        .get_field_unchecked();

                    let rhs = self
                        .flatten_uint_expression(statements_flattened, *e.right)
                        .get_field_unchecked();

                    self.flatten_equality_assertion(
                        statements_flattened,
                        lhs,
                        rhs,
                        error.into(),
                    )
                }
                BooleanExpression::BoolEq(e) => {
                    let lhs = self.flatten_boolean_expression(statements_flattened, *e.left);

                    let rhs = self.flatten_boolean_expression(statements_flattened, *e.right);

                    self.flatten_equality_assertion(
                        statements_flattened,
                        lhs,
                        rhs,
                        error.into(),
                    )
                }
                BooleanExpression::Not(u) => {
                    let inner_span = u.get_span();

                    match *u.inner {
                        // `!(0 == x)` and `!(x == 0)` are handled by `enforce_not_zero_assertion`
                        BooleanExpression::UintEq(b) => {
                            if let UExpressionInner::Value(ValueExpression {
                                value: 0, ..
                            }) = b.left.inner
                            {
                                let x = self
                                    .flatten_uint_expression(statements_flattened, *b.right)
                                    .get_field_unchecked();
                                self.enforce_not_zero_assertion(statements_flattened, x)
                            } else if let UExpressionInner::Value(ValueExpression {
                                value: 0,
                                ..
                            }) = b.right.inner
                            {
                                let x = self
                                    .flatten_uint_expression(statements_flattened, *b.left)
                                    .get_field_unchecked();
                                self.enforce_not_zero_assertion(statements_flattened, x)
                            } else {
                                self.enforce_naive_assertion(
                                    statements_flattened,
                                    BooleanExpression::not(BooleanExpression::UintEq(b)),
                                    error,
                                );
                            }
                        }
                        // same for field elements, when one side is the constant zero
                        BooleanExpression::FieldEq(b) => match (*b.left, *b.right) {
                            (
                                FieldElementExpression::Value(ValueExpression {
                                    value: zero, ..
                                }),
                                x,
                            )
                            | (
                                x,
                                FieldElementExpression::Value(ValueExpression {
                                    value: zero, ..
                                }),
                            ) if zero == T::from(0) => {
                                let x = self.flatten_field_expression(statements_flattened, x);
                                self.enforce_not_zero_assertion(statements_flattened, x)
                            }
                            (left, right) => self.enforce_naive_assertion(
                                statements_flattened,
                                BooleanExpression::not(
                                    BooleanExpression::field_eq(left, right).span(inner_span),
                                )
                                .span(span),
                                error,
                            ),
                        },
                        e => self.enforce_naive_assertion(
                            statements_flattened,
                            BooleanExpression::not(e.span(inner_span)).span(span),
                            error,
                        ),
                    }
                }
                // any other boolean expression goes through `enforce_naive_assertion`
                e => self.enforce_naive_assertion(statements_flattened, e, error),
            }
        }
        ZirStatement::MultipleDefinition(s) => {
            // flatten the right side to p = sum(var_i.type.primitive_count) expressions
            // define p new variables to the right side expressions

            match s.rhs {
                ZirExpressionList::EmbedCall(embed, generics, exprs) => {
                    let rhs_flattened = self.flatten_embed_call(
                        statements_flattened,
                        embed,
                        generics,
                        exprs.clone(),
                    );

                    let rhs = rhs_flattened.into_iter();

                    assert_eq!(s.assignees.len(), rhs.len());

                    // each assignee is bound to the matching returned expression
                    let assignees: Vec<_> = s
                        .assignees
                        .into_iter()
                        .zip(rhs)
                        .map(|(v, r)| match r.get_field_unchecked() {
                            FlatExpression::Identifier(id) => {
                                self.use_variable_with_existing(&v, id.id);
                                id.id
                            }
                            e => {
                                let id = self.use_variable(&v);
                                statements_flattened
                                    .push_back(FlatStatement::definition(id, e));
                                id
                            }
                        })
                        .collect();

                    match embed {
                        FlatEmbed::U64FromBits
                        | FlatEmbed::U32FromBits
                        | FlatEmbed::U16FromBits
                        | FlatEmbed::U8FromBits => {
                            let bits = exprs
                                .into_iter()
                                .map(|e| {
                                    self.flatten_expression(statements_flattened, e)
.get_field_unchecked()
                                    })
                                    .collect();

                                self.bits_cache.insert(assignees[0].into(), bits);
                            }
                            _ => {}
                        }
                    }
                }
            }
            ZirStatement::Log(s) => {
                let expressions = s
                    .expressions
                    .into_iter()
                    .map(|(t, e)| {
                        (
                            t,
                            e.into_iter()
                                .map(|e| {
                                    self.flatten_expression(statements_flattened, e)
                                        .get_field_unchecked()
                                })
                                .collect(),
                        )
                    })
                    .collect();

                statements_flattened.push_back(FlatStatement::Log(LogStatement::new(
                    s.format_string,
                    expressions,
                )));
            }
        };

        statements_flattened.set_span(span_backup);
    }

    /// Enforce a boolean assertion without any specialised handling
    ///
    /// The boolean is flattened to a single field element which is then constrained to be
    /// equal to 1.
    ///
    /// # Arguments
    ///
    /// * `statements_flattened` - `FlatStatements<'ast, T>` Vector where new flattened statements can be added.
    /// * `e` - `BooleanExpression` The boolean expression asserted to hold.
    /// * `error` - The error attached to the generated condition.
    fn enforce_naive_assertion(
        &mut self,
        statements_flattened: &mut FlatStatements<'ast, T>,
        e: BooleanExpression<'ast, T>,
        error: zokrates_ast::zir::RuntimeError,
    ) {
        let flattened = self.flatten_boolean_expression(statements_flattened, e);
        let one = FlatExpression::value(T::from(1));

        // the left side of the condition must be linear: when the flattened boolean is not,
        // the operands are swapped so that the constant ends up on the left
        let (left, right) = if flattened.is_linear() {
            (flattened, one)
        } else {
            (one, flattened)
        };

        statements_flattened.push_back(FlatStatement::condition(left, right, error.into()));
    }

    /// Enforce that x is not zero
    ///
    /// # Arguments
    ///
    /// * `statements_flattened` - `FlatStatements<'ast, T>` Vector where new flattened statements can be added.
    /// * `x` - `FlatExpression` The expression to be constrained to not be zero.
fn enforce_not_zero_assertion( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, x: FlatExpression, ) { // introduce intermediate variable let x_id = self.define(x, statements_flattened); // check that `x` is not 0 by giving its inverse let invx = self.use_sym(); // # invx = 1/x statements_flattened.push_back(FlatStatement::Directive(FlatDirective::new( vec![invx], Solver::Div, vec![FlatExpression::value(T::one()), x_id.into()], ))); // assert(invx * x == 1) statements_flattened.push_back(FlatStatement::condition( FlatExpression::value(T::one()), FlatExpression::mul(invx.into(), x_id.into()), RuntimeError::Inverse, )); } /// Flattens an equality assertion, enforcing it in the circuit. /// /// # Arguments /// /// * `statements_flattened` - `FlatStatements<'ast, T>` Vector where new flattened statements can be added. /// * `lhs` - `FlatExpression` Left-hand side of the equality expression. /// * `rhs` - `FlatExpression` Right-hand side of the equality expression. fn flatten_equality_assertion( &mut self, statements_flattened: &mut FlatStatements<'ast, T>, lhs: FlatExpression, rhs: FlatExpression, error: RuntimeError, ) { let (lhs, rhs) = match (lhs, rhs) { (FlatExpression::Mult(e), z) | (z, FlatExpression::Mult(e)) => ( self.identify_expression(z, statements_flattened), FlatExpression::mul( self.identify_expression(*e.left, statements_flattened), self.identify_expression(*e.right, statements_flattened), ), ), (x, z) => ( self.identify_expression(z, statements_flattened), FlatExpression::mul( self.identify_expression(x, statements_flattened), FlatExpression::value(T::from(1)), ), ), }; statements_flattened.push_back(FlatStatement::condition(lhs, rhs, error)); } /// Identifies a non-linear expression by assigning it to a new identifier. /// /// # Arguments /// /// * `e` - `FlatExpression` Expression to be assigned to an identifier. /// * `statements_flattened` - `FlatStatements<'ast, T>` Vector where new flattened statements can be added. 
fn identify_expression( &mut self, e: FlatExpression, statements_flattened: &mut FlatStatements<'ast, T>, ) -> FlatExpression { match e.is_linear() { true => e, false => { let sym = self.use_sym(); statements_flattened.push_back(FlatStatement::definition(sym, e)); FlatExpression::identifier(sym) } } } /// Returns a fresh Variable for a given Variable /// # Arguments /// /// * `variable` - a variable in the program being flattened fn use_variable(&mut self, variable: &ZirVariable<'ast>) -> Variable { let var = self.issue_new_variable(); self.layout.insert(variable.id.clone(), var); var } /// Reuse an existing variable for a new name /// # Arguments /// /// * `variable` - a variable in the program being flattened /// * `flat_variable` - an existing flat variable fn use_variable_with_existing( &mut self, variable: &ZirVariable<'ast>, flat_variable: Variable, ) { self.layout.insert(variable.id.clone(), flat_variable); } fn use_parameter( &mut self, parameter: &ZirParameter<'ast>, statements_flattened: &mut FlatStatements<'ast, T>, ) -> Parameter { let span = parameter.get_span(); let backup_span = statements_flattened.span; statements_flattened.set_span(span); let variable = self.use_variable(¶meter.id); match parameter.id.get_type() { Type::Uint(bitwidth) => { // to constrain unsigned integer inputs to be in range, we get their bit decomposition. 
// it will be cached self.get_bits_unchecked( &FlatUExpression::with_field(FlatExpression::identifier(variable)), bitwidth.to_usize(), bitwidth.to_usize(), statements_flattened, RuntimeError::Sum, ); } Type::Boolean => { statements_flattened.push_back(FlatStatement::condition( variable.into(), FlatExpression::mul(variable.into(), variable.into()), RuntimeError::ArgumentBitness, )); } Type::FieldElement => {} } statements_flattened.set_span(backup_span); Parameter::new(variable, parameter.private).span(span) } fn issue_new_variable(&mut self) -> Variable { let var = Variable::new(self.next_var_idx); self.next_var_idx += 1; var } // create an internal variable. We do not register it in the layout fn use_sym(&mut self) -> Variable { self.issue_new_variable() } } #[cfg(test)] mod tests { use super::*; use zokrates_ast::zir; use zokrates_ast::zir::types::Signature; use zokrates_ast::zir::types::Type; use zokrates_ast::zir::Id; use zokrates_ast::zir::ZirFunction; use zokrates_field::Bn128Field; fn flatten_function(f: ZirFunction) -> FlatProg { from_program_and_config( ZirProgram { main: f, module_map: Default::default(), }, CompileConfig::default(), ) .collect() } #[test] fn assertion_bool_eq() { // def main() { // bool x = true; // bool y = true; // assert(x == y); // return; // } // def main() { // _0 = 1; // _1 = 1; // _1 == (_0 * 1); // } let function = ZirFunction:: { arguments: vec![], statements: vec![ ZirStatement::definition( zir::Variable::boolean("x".into()), BooleanExpression::value(true).into(), ), ZirStatement::definition( zir::Variable::boolean("y".into()), BooleanExpression::value(true).into(), ), ZirStatement::assertion( BooleanExpression::bool_eq( BooleanExpression::identifier("x".into()), BooleanExpression::identifier("y".into()), ), zir::RuntimeError::mock(), ), ZirStatement::ret(vec![]), ], signature: Signature { inputs: vec![], outputs: vec![], }, }; let flat = flatten_function(function); let expected = FlatFunction { module_map: Default::default(), 
arguments: vec![],
            return_count: 0,
            statements: vec![
                FlatStatement::definition(
                    Variable::new(0),
                    FlatExpression::value(Bn128Field::from(1)),
                ),
                FlatStatement::definition(
                    Variable::new(1),
                    FlatExpression::value(Bn128Field::from(1)),
                ),
                FlatStatement::condition(
                    FlatExpression::identifier(Variable::new(1)),
                    FlatExpression::mul(
                        FlatExpression::identifier(Variable::new(0)),
                        FlatExpression::value(Bn128Field::from(1)),
                    ),
                    zir::RuntimeError::mock().into(),
                ),
            ],
        };

        assert_eq!(flat.collect(), expected);
    }

    /// Asserting `linear == identifier` multiplies the left side by 1 in the condition
    #[test]
    fn assertion_field_eq() {
        // def main() {
        //     field x = 1;
        //     field y = 2;
        //     assert(x + 1 == y);
        //     return;
        // }

        // def main() {
        //     _0 = 1;
        //     _1 = 2;
        //     _1 == ((_0 + 1) * 1);
        // }

        let function = ZirFunction {
            arguments: vec![],
            statements: vec![
                ZirStatement::definition(
                    zir::Variable::field_element("x"),
                    FieldElementExpression::value(Bn128Field::from(1)).into(),
                ),
                ZirStatement::definition(
                    zir::Variable::field_element("y"),
                    FieldElementExpression::value(Bn128Field::from(2)).into(),
                ),
                ZirStatement::assertion(
                    BooleanExpression::field_eq(
                        FieldElementExpression::add(
                            FieldElementExpression::identifier("x".into()),
                            FieldElementExpression::value(Bn128Field::from(1)),
                        ),
                        FieldElementExpression::identifier("y".into()),
                    ),
                    zir::RuntimeError::mock(),
                ),
                ZirStatement::ret(vec![]),
            ],
            signature: Signature {
                inputs: vec![],
                outputs: vec![],
            },
        };

        let flat = flatten_function(function);

        let expected = FlatFunction {
            module_map: Default::default(),
            arguments: vec![],
            return_count: 0,
            statements: vec![
                FlatStatement::definition(
                    Variable::new(0),
                    FlatExpression::value(Bn128Field::from(1)),
                ),
                FlatStatement::definition(
                    Variable::new(1),
                    FlatExpression::value(Bn128Field::from(2)),
                ),
                FlatStatement::condition(
                    FlatExpression::identifier(Variable::new(1)),
                    FlatExpression::mul(
                        FlatExpression::add(
                            FlatExpression::identifier(Variable::new(0)),
                            FlatExpression::value(Bn128Field::from(1)),
                        ),
                        FlatExpression::value(Bn128Field::from(1)),
                    ),
                    zir::RuntimeError::mock().into(),
                ),
            ],
        };

        assert_eq!(flat.collect(), expected);
    }

    /// Asserting the equality of a uint identifier and a uint constant yields a single condition
    #[test]
    fn assertion_uint_eq() {
        // def main() {
        //     u32 x = 42;
        //     assert(x == 42);
        //     return;
        // }

        // def main() {
        //     _0 = 42;
        //     42 == (_0 * 1);
        // }

        let metadata = UMetadata {
            max: 0xffffffff_u32.into(),
            should_reduce: ShouldReduce::True,
        };

        let function = ZirFunction::<Bn128Field> {
            arguments: vec![],
            statements: vec![
                ZirStatement::definition(
                    zir::Variable::uint("x".into(), 32),
                    ZirExpression::Uint(
                        UExpression::value(42)
                            .annotate(32)
                            .metadata(metadata.clone()),
                    ),
                ),
                ZirStatement::assertion(
                    BooleanExpression::uint_eq(
                        UExpression::identifier("x".into())
                            .annotate(32)
                            .metadata(metadata.clone()),
                        UExpression::value(42).annotate(32).metadata(metadata),
                    ),
                    zir::RuntimeError::mock(),
                ),
                ZirStatement::ret(vec![]),
            ],
            signature: Signature {
                inputs: vec![],
                outputs: vec![],
            },
        };

        let flat = flatten_function(function);

        let expected = FlatFunction {
            module_map: Default::default(),
            arguments: vec![],
            return_count: 0,
            statements: vec![
                FlatStatement::definition(
                    Variable::new(0),
                    FlatExpression::value(Bn128Field::from(42)),
                ),
                FlatStatement::condition(
                    FlatExpression::value(Bn128Field::from(42)),
                    FlatExpression::mul(
                        FlatExpression::identifier(Variable::new(0)),
                        FlatExpression::value(Bn128Field::from(1)),
                    ),
                    zir::RuntimeError::mock().into(),
                ),
            ],
        };

        assert_eq!(flat.collect(), expected);
    }

    /// Asserting the equality of two field identifiers yields a single condition
    #[test]
    fn assertion_ident_eq_ident() {
        // def main() {
        //     field x = 2;
        //     field y = 2;
        //     assert(x == y);
        //     return;
        // }

        // def main() {
        //     _0 = 2;
        //     _1 = 2;
        //     _1 == (_0 * 1);
        // }

        let function = ZirFunction {
            arguments: vec![],
            statements: vec![
                ZirStatement::definition(
                    zir::Variable::field_element("x"),
                    FieldElementExpression::value(Bn128Field::from(2)).into(),
                ),
                ZirStatement::definition(
                    zir::Variable::field_element("y"),
                    FieldElementExpression::value(Bn128Field::from(2)).into(),
                ),
                ZirStatement::assertion(
                    BooleanExpression::field_eq(
                        FieldElementExpression::identifier("x".into()),
                        FieldElementExpression::identifier("y".into()),
                    ),
                    zir::RuntimeError::mock(),
                ),
                ZirStatement::ret(vec![]),
            ],
            signature: Signature {
                inputs: vec![],
                outputs: vec![],
            },
        };

        let flat = flatten_function(function);

        let expected = FlatFunction {
            module_map: Default::default(),
            arguments: vec![],
            return_count: 0,
            statements: vec![
                FlatStatement::definition(
                    Variable::new(0),
                    FlatExpression::value(Bn128Field::from(2)),
                ),
                FlatStatement::definition(
                    Variable::new(1),
                    FlatExpression::value(Bn128Field::from(2)),
                ),
                FlatStatement::condition(
                    FlatExpression::identifier(Variable::new(1)),
                    FlatExpression::mul(
                        FlatExpression::identifier(Variable::new(0)),
                        FlatExpression::value(Bn128Field::from(1)),
                    ),
                    zir::RuntimeError::mock().into(),
                ),
            ],
        };

        assert_eq!(flat.collect(), expected);
    }

    /// A multiplication on the left side of the equality is used directly in the condition
    #[test]
    fn assertion_mult_eq_ident() {
        // def main() {
        //     field x = 2;
        //     field y = 2;
        //     field z = 4;
        //     assert(x * y == z);
        //     return;
        // }

        // def main() {
        //     _0 = 2;
        //     _1 = 2;
        //     _2 = 4;
        //     _2 == (_0 * _1);
        // }

        let function = ZirFunction {
            arguments: vec![],
            statements: vec![
                ZirStatement::definition(
                    zir::Variable::field_element("x"),
                    FieldElementExpression::value(Bn128Field::from(2)).into(),
                ),
                ZirStatement::definition(
                    zir::Variable::field_element("y"),
                    FieldElementExpression::value(Bn128Field::from(2)).into(),
                ),
                ZirStatement::definition(
                    zir::Variable::field_element("z"),
                    FieldElementExpression::value(Bn128Field::from(4)).into(),
                ),
                ZirStatement::assertion(
                    BooleanExpression::field_eq(
                        FieldElementExpression::mul(
                            FieldElementExpression::identifier("x".into()),
                            FieldElementExpression::identifier("y".into()),
                        ),
                        FieldElementExpression::identifier("z".into()),
                    ),
                    zir::RuntimeError::mock(),
                ),
                ZirStatement::ret(vec![]),
            ],
            signature: Signature {
                inputs: vec![],
                outputs: vec![],
            },
        };

        let flat = flatten_function(function);

        let expected = FlatFunction {
            module_map: Default::default(),
            arguments: vec![],
            return_count: 0,
            statements: vec![
                FlatStatement::definition(
                    Variable::new(0),
                    FlatExpression::value(Bn128Field::from(2)),
                ),
                FlatStatement::definition(
                    Variable::new(1),
                    FlatExpression::value(Bn128Field::from(2)),
                ),
                FlatStatement::definition(
                    Variable::new(2),
                    FlatExpression::value(Bn128Field::from(4)),
                ),
                FlatStatement::condition(
                    FlatExpression::identifier(Variable::new(2)),
                    FlatExpression::mul(
                        FlatExpression::identifier(Variable::new(0)),
                        FlatExpression::identifier(Variable::new(1)),
                    ),
                    zir::RuntimeError::mock().into(),
                ),
            ],
        };

        assert_eq!(flat.collect(), expected);
    }

    /// A multiplication on the right side of the equality yields the same condition as on the left
    #[test]
    fn assertion_ident_eq_mult() {
        // def main() {
        //     field x = 2;
        //     field y = 2;
        //     field z = 4;
        //     assert(z == x * y);
        //     return;
        // }

        // def main() {
        //     _0 = 2;
        //     _1 = 2;
        //     _2 = 4;
        //     _2 == (_0 * _1);
        // }

        let function = ZirFunction {
            arguments: vec![],
            statements: vec![
                ZirStatement::definition(
                    zir::Variable::field_element("x"),
                    FieldElementExpression::value(Bn128Field::from(2)).into(),
                ),
                ZirStatement::definition(
                    zir::Variable::field_element("y"),
                    FieldElementExpression::value(Bn128Field::from(2)).into(),
                ),
                ZirStatement::definition(
                    zir::Variable::field_element("z"),
                    FieldElementExpression::value(Bn128Field::from(4)).into(),
                ),
                ZirStatement::assertion(
                    BooleanExpression::field_eq(
                        FieldElementExpression::identifier("z".into()),
                        FieldElementExpression::mul(
                            FieldElementExpression::identifier("x".into()),
                            FieldElementExpression::identifier("y".into()),
                        ),
                    ),
                    zir::RuntimeError::mock(),
                ),
                ZirStatement::ret(vec![]),
            ],
            signature: Signature {
                inputs: vec![],
                outputs: vec![],
            },
        };

        let flat = flatten_function(function);

        let expected = FlatFunction {
            module_map: Default::default(),
            arguments: vec![],
            return_count: 0,
            statements: vec![
                FlatStatement::definition(
                    Variable::new(0),
                    FlatExpression::value(Bn128Field::from(2)),
                ),
                FlatStatement::definition(
                    Variable::new(1),
                    FlatExpression::value(Bn128Field::from(2)),
                ),
                FlatStatement::definition(
                    Variable::new(2),
                    FlatExpression::value(Bn128Field::from(4)),
                ),
                FlatStatement::condition(
                    FlatExpression::identifier(Variable::new(2)),
                    FlatExpression::mul(
                        FlatExpression::identifier(Variable::new(0)),
                        FlatExpression::identifier(Variable::new(1)),
                    ),
                    zir::RuntimeError::mock().into(),
                ),
            ],
        };

        assert_eq!(flat.collect(), expected);
    }

    /// With a multiplication on both sides, the right one is assigned to a fresh variable
    #[test]
    fn assertion_mult_eq_mult() {
        // def main() {
        //     field x = 4;
        //     field y = 4;
        //     field z = 8;
        //     field t = 2;
        //     assert(x * y == z * t);
        // }
        // (no `ret` statement is built for this function below)

        // def main() {
        //     _0 = 4;
        //     _1 = 4;
        //     _2 = 8;
        //     _3 = 2;
        //     _4 = (_2 * _3);
        //     _4 == (_0 * _1);
        // }

        let function = ZirFunction {
            arguments: vec![],
            statements: vec![
                ZirStatement::definition(
                    zir::Variable::field_element("x"),
                    FieldElementExpression::value(Bn128Field::from(4)).into(),
                ),
                ZirStatement::definition(
                    zir::Variable::field_element("y"),
                    FieldElementExpression::value(Bn128Field::from(4)).into(),
                ),
                ZirStatement::definition(
                    zir::Variable::field_element("z"),
                    FieldElementExpression::value(Bn128Field::from(8)).into(),
                ),
                ZirStatement::definition(
                    zir::Variable::field_element("t"),
                    FieldElementExpression::value(Bn128Field::from(2)).into(),
                ),
                ZirStatement::assertion(
                    BooleanExpression::field_eq(
                        FieldElementExpression::mul(
                            FieldElementExpression::identifier("x".into()),
                            FieldElementExpression::identifier("y".into()),
                        ),
                        FieldElementExpression::mul(
                            FieldElementExpression::identifier("z".into()),
                            FieldElementExpression::identifier("t".into()),
                        ),
                    ),
                    zir::RuntimeError::mock(),
                ),
            ],
            signature: Signature {
                inputs: vec![],
                outputs: vec![],
            },
        };

        let flat = flatten_function(function);

        let expected = FlatFunction {
            module_map: Default::default(),
            arguments: vec![],
            return_count: 0,
            statements: vec![
                FlatStatement::definition(
                    Variable::new(0),
                    FlatExpression::value(Bn128Field::from(4)),
                ),
                FlatStatement::definition(
                    Variable::new(1),
                    FlatExpression::value(Bn128Field::from(4)),
                ),
                FlatStatement::definition(
                    Variable::new(2),
                    FlatExpression::value(Bn128Field::from(8)),
                ),
                FlatStatement::definition(
                    Variable::new(3),
                    FlatExpression::value(Bn128Field::from(2)),
                ),
                FlatStatement::definition(
                    Variable::new(4),
                    FlatExpression::mul(
                        FlatExpression::identifier(Variable::new(2)),
                        FlatExpression::identifier(Variable::new(3)),
                    ),
                ),
                FlatStatement::condition(
                    FlatExpression::identifier(Variable::new(4)),
                    FlatExpression::mul(
                        FlatExpression::identifier(Variable::new(0)),
                        FlatExpression::identifier(Variable::new(1)),
                    ),
                    zir::RuntimeError::mock().into(),
                ),
            ],
        };

        assert_eq!(flat.collect(), expected);
    }

    /// Raising to the power 0 flattens to the constant 1
    #[test]
    fn powers_zero() {
        // def main() {
        //     field a = 7;
        //     field b = a**0;
        //     return b;
        // }

        // def main() {
        //     _0 = 7;
        //     _1 = 1;         // power flattening returns 1, definition introduces _1
        //     return _1;
        // }
        let function = ZirFunction {
            arguments: vec![],
            statements: vec![
                ZirStatement::definition(
                    zir::Variable::field_element("a"),
                    FieldElementExpression::value(Bn128Field::from(7)).into(),
                ),
                ZirStatement::definition(
                    zir::Variable::field_element("b"),
                    FieldElementExpression::pow(
                        FieldElementExpression::identifier("a".into()),
                        0u32.into(),
                    )
                    .into(),
                ),
                ZirStatement::ret(vec![FieldElementExpression::identifier("b".into()).into()]),
            ],
            signature: Signature {
                inputs: vec![],
                outputs: vec![Type::FieldElement],
            },
        };

        let expected = FlatFunction {
            module_map: Default::default(),
            arguments: vec![],
            return_count: 1,
            statements: vec![
                FlatStatement::definition(
                    Variable::new(0),
                    FlatExpression::value(Bn128Field::from(7)),
                ),
                FlatStatement::definition(
                    Variable::new(1),
                    FlatExpression::value(Bn128Field::from(1)),
                ),
                FlatStatement::definition(
                    Variable::public(0),
                    FlatExpression::identifier(Variable::new(1)),
                ),
            ],
        };

        let flattened = flatten_function(function);

        assert_eq!(flattened, expected);
    }

    /// Raising to the power 1 flattens to a multiplication by 1
    #[test]
    fn power_one() {
        // def main() {
        //     field a = 7;
        //     field b = a**1;
        //     return b;
        // }

        // def main() {
        //     _0 = 7;
        //     _1 = 1 * _0;     // x**1, power flattening returns _1 which the definition of b reuses
        //     return _1;
        // }
        let function = ZirFunction {
            arguments: vec![],
            statements: vec![
                ZirStatement::definition(
                    zir::Variable::field_element("a"),
                    FieldElementExpression::value(Bn128Field::from(7)).into(),
                ),
                ZirStatement::definition(
                    zir::Variable::field_element("b"),
                    FieldElementExpression::pow(
                        FieldElementExpression::identifier("a".into()),
                        1u32.into(),
                    )
                    .into(),
                ),
                ZirStatement::ret(vec![FieldElementExpression::identifier("b".into()).into()]),
            ],
            signature: Signature {
                inputs: vec![],
                outputs: vec![Type::FieldElement],
            },
        };

        let expected = FlatFunction {
            module_map: Default::default(),
            arguments: vec![],
            return_count: 1,
            statements: vec![
                FlatStatement::definition(
                    Variable::new(0),
                    FlatExpression::value(Bn128Field::from(7)),
                ),
                FlatStatement::definition(
                    Variable::new(1),
                    FlatExpression::mul(
                        FlatExpression::value(Bn128Field::from(1)),
                        FlatExpression::identifier(Variable::new(0)),
                    ),
                ),
                FlatStatement::definition(
                    Variable::public(0),
                    FlatExpression::identifier(Variable::new(1)),
                ),
            ],
        };

        let flattened = flatten_function(function);

        assert_eq!(flattened, expected);
    }

    /// Raising to the power 13 flattens to a sequence of squarings and multiplications
    #[test]
    fn power_13() {
        // def main() {
        //     field a = 7;
        //     field b = a**13;
        //     return b;
        // }

        // we apply double and add
        // 13 == 0b1101
        // a ** 13 == a**(2**0 + 2**2 + 2**3) == a**1 * a**4 * a**8

        // a_0 = a * a      // a**2
        // a_1 = a_0 * a_0  // a**4
        // a_2 = a_1 * a_1  // a**8

        // a_3 = a * a_1    // a * a**4 == a**5
        // a_4 = a_3 * a_2  // a**5 * a**8 == a**13

        // def main() {
        //     _0 = 7;
        //     _1 = (_0 * _0);  // a**2
        //     _2 = (_1 * _1);  // a**4
        //     _3 = (_2 * _2);  // a**8
        //
        //     _4 = 1 * _0;     // a
        //     _5 = _4 * _2;    // a**5
        //     _6 = _5 * _3;    // a**13
        //     return _6;
        // }

        let function = ZirFunction {
            arguments: vec![],
            statements: vec![
                ZirStatement::definition(
                    zir::Variable::field_element("a"),
                    FieldElementExpression::value(Bn128Field::from(7)).into(),
                ),
                ZirStatement::definition(
                    zir::Variable::field_element("b"),
                    FieldElementExpression::pow(
                        FieldElementExpression::identifier("a".into()),
                        13u32.into(),
                    )
                    .into(),
                ),
                ZirStatement::ret(vec![FieldElementExpression::identifier("b".into()).into()]),
            ],
            signature: Signature {
                inputs: vec![],
                outputs: vec![Type::FieldElement],
            },
        };

        let expected = FlatFunction {
            module_map: Default::default(),
            arguments: vec![],
            return_count: 1,
            statements: vec![
                FlatStatement::definition(
                    Variable::new(0),
                    FlatExpression::value(Bn128Field::from(7)),
                ),
                FlatStatement::definition(
                    Variable::new(1),
                    FlatExpression::mul(
                        FlatExpression::identifier(Variable::new(0)),
                        FlatExpression::identifier(Variable::new(0)),
                    ),
                ),
                FlatStatement::definition(
                    Variable::new(2),
                    FlatExpression::mul(
                        FlatExpression::identifier(Variable::new(1)),
                        FlatExpression::identifier(Variable::new(1)),
                    ),
                ),
                FlatStatement::definition(
                    Variable::new(3),
                    FlatExpression::mul(
                        FlatExpression::identifier(Variable::new(2)),
                        FlatExpression::identifier(Variable::new(2)),
                    ),
                ),
                FlatStatement::definition(
                    Variable::new(4),
                    FlatExpression::mul(
                        FlatExpression::value(Bn128Field::from(1)),
                        FlatExpression::identifier(Variable::new(0)),
                    ),
                ),
                FlatStatement::definition(
                    Variable::new(5),
                    FlatExpression::mul(
                        FlatExpression::identifier(Variable::new(4)),
                        FlatExpression::identifier(Variable::new(2)),
                    ),
                ),
                FlatStatement::definition(
                    Variable::new(6),
                    FlatExpression::mul(
                        FlatExpression::identifier(Variable::new(5)),
                        FlatExpression::identifier(Variable::new(3)),
                    ),
                ),
                FlatStatement::definition(
                    Variable::public(0),
                    FlatExpression::identifier(Variable::new(6)),
                ),
            ],
        };

        let flattened = flatten_function(function);

        assert_eq!(flattened, expected);
    }

    /// Smoke test: flattening a conditional field expression does not panic
    #[test]
    fn if_else() {
        let config = CompileConfig::default();
        let expression = FieldElementExpression::conditional(
            BooleanExpression::field_eq(
                FieldElementExpression::value(Bn128Field::from(32)),
                FieldElementExpression::value(Bn128Field::from(4)),
            ),
            FieldElementExpression::value(Bn128Field::from(12)),
            FieldElementExpression::value(Bn128Field::from(51)),
        );

        let mut flattener = Flattener::new(config);

        flattener.flatten_field_expression(&mut FlatStatements::default(), expression);
    }

    /// Smoke test: flattening a `<=` comparison of two constants does not panic
    #[test]
    fn geq_leq() {
        let config = CompileConfig::default();
        let mut flattener = Flattener::new(config);
        let expression_le = BooleanExpression::field_le(
            FieldElementExpression::value(Bn128Field::from(32)),
            FieldElementExpression::value(Bn128Field::from(4)),
        );
        flattener.flatten_boolean_expression(&mut FlatStatements::default(), expression_le);
    }

    /// Smoke test: flattening a conditional whose condition is a conjunction does not panic
    #[test]
    fn bool_and() {
        let config = CompileConfig::default();
        let mut flattener = Flattener::new(config);

        let expression = FieldElementExpression::conditional(
            BooleanExpression::bitand(
                BooleanExpression::field_eq(
                    FieldElementExpression::value(Bn128Field::from(4)),
                    FieldElementExpression::value(Bn128Field::from(4)),
                ),
                BooleanExpression::field_lt(
                    FieldElementExpression::value(Bn128Field::from(4)),
                    FieldElementExpression::value(Bn128Field::from(20)),
                ),
            ),
            FieldElementExpression::value(Bn128Field::from(12)),
            FieldElementExpression::value(Bn128Field::from(51)),
        );

        flattener.flatten_field_expression(&mut FlatStatements::default(), expression);
    }

    /// Nested divisions each produce a `Div` directive followed by a condition checking it
    #[test]
    fn div() {
        // a = 5 / b / b
        let config = CompileConfig::default();
        let mut flattener = Flattener::new(config);
        let mut statements_flattened = FlatStatements::default();

        let definition = ZirStatement::definition(
            zir::Variable::field_element("b"),
            FieldElementExpression::value(Bn128Field::from(42)).into(),
        );

        let statement = ZirStatement::definition(
            zir::Variable::field_element("a"),
            FieldElementExpression::div(
                FieldElementExpression::div(
                    FieldElementExpression::value(Bn128Field::from(5)),
                    FieldElementExpression::identifier("b".into()),
                ),
                FieldElementExpression::identifier("b".into()),
            )
            .into(),
        );

        flattener.flatten_statement(&mut statements_flattened, definition);

        flattener.flatten_statement(&mut statements_flattened, statement);

        // define b
        let b = Variable::new(0);
        // define new wires for members of Div
        let five = Variable::new(1);
        let b0 = Variable::new(2);
        // Define inverse
        let sym_0 = Variable::new(3);
        // Define result, which is first member to next Div
        let sym_1 = Variable::new(4);
        // Define second member
        let b1 = Variable::new(5);
        // Define inverse
        let sym_2 = Variable::new(6);

        assert_eq!(
            statements_flattened.buffer,
            vec![
                FlatStatement::definition(b, FlatExpression::value(Bn128Field::from(42))),
                // inputs to first div (5/b)
                FlatStatement::definition(five, FlatExpression::value(Bn128Field::from(5))),
                FlatStatement::definition(b0, b.into()),
                // execute div
                FlatStatement::Directive(FlatDirective::new(
                    vec![sym_0],
                    Solver::Div,
                    vec![five.into(), b0.into()]
                )),
                FlatStatement::condition(
                    five.into(),
                    FlatExpression::mul(b0.into(), sym_0.into()),
                    RuntimeError::Division
                ),
                // inputs to second div (res/b)
                FlatStatement::definition(sym_1, sym_0.into()),
                FlatStatement::definition(b1, b.into()),
                // execute div
                FlatStatement::Directive(FlatDirective::new(
                    vec![sym_2],
                    Solver::Div,
                    vec![sym_1.into(), b1.into()]
                )),
                FlatStatement::condition(
                    sym_1.into(),
                    FlatExpression::mul(b1.into(), sym_2.into()),
                    RuntimeError::Division
                ),
            ]
        );
    }
}