From 5af047c2326ed010793c0a6f201645bb8e57916a Mon Sep 17 00:00:00 2001 From: schaeff Date: Fri, 21 Jan 2022 14:24:12 +0100 Subject: [PATCH] wip --- zokrates_core/src/flatten/mod.rs | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/zokrates_core/src/flatten/mod.rs b/zokrates_core/src/flatten/mod.rs index c92d15a5..ddfa8321 100644 --- a/zokrates_core/src/flatten/mod.rs +++ b/zokrates_core/src/flatten/mod.rs @@ -16,8 +16,10 @@ use crate::flat_absy::{RuntimeError, *}; use crate::solvers::Solver; use crate::zir::types::{Type, UBitwidth}; use crate::zir::*; -use std::collections::hash_map::Entry; -use std::collections::{hash_map::HashMap, VecDeque}; +use std::collections::{ + hash_map::{Entry, HashMap}, + VecDeque, +}; use std::convert::TryFrom; use zokrates_field::Field; @@ -440,6 +442,9 @@ impl<'ast, T: Field> Flattener<'ast, T> { let c_bit_width = c.bits() as usize; let c_bits_be = c.to_bits_be(); + // we reduce e `n` bits with `n` the bitwidth of `c` + // if `e` does not fit in `n` bits, this will fail + // but as we are asserting `e < c`, `e` not fitting in `n` bits should indeed lead to an unsatisfied constraint let e_bits_be = self.get_bits( &FlatUExpression::with_field(e), c_bit_width, @@ -667,6 +672,8 @@ impl<'ast, T: Field> Flattener<'ast, T> { let e: FlatExpression = self.define(e, statements_flattened).into(); let bitwidth = T::get_required_bits(); + // we want to reduce `e <= c` to 0 or 1, without ever throwing. We know the bitwidth of `c` and want to minimize the bitwidth we reduce `e` to. + // we must use the maximum bitwidth, as otherwise, large enough values of `e` would lead `get_bits` to throw. let e_bits_be = self.get_bits( &FlatUExpression::with_field(e), bitwidth, @@ -675,14 +682,6 @@ impl<'ast, T: Field> Flattener<'ast, T> { RuntimeError::ConstantLtSum, ); - // check that this decomposition does not overflow the field - self.enforce_constant_le_check_bits( - statements_flattened, - &e_bits_be, - &T::max_value().to_bits_be(), - RuntimeError::Le, - ); - let c_bits_be: Vec = c.to_bits_be(); let conditions = self.constant_le_check(statements_flattened, &e_bits_be, &c_bits_be); @@ -1960,9 +1959,19 @@ impl<'ast, T: Field> Flattener<'ast, T> { statements_flattened.push_back(FlatStatement::Condition( e.field.clone().unwrap(), sum.clone(), - error, + error.clone(), )); + // if the result is not unique, check that this decomposition does not overflow the field + if to == T::get_required_bits() { + self.enforce_constant_le_check_bits( + statements_flattened, + &bits, + &T::max_value().to_bits_be(), + error, + ); + } + // truncate to the `to` lowest bits let bits = bits[from - to..].to_vec();