Merge branch 'develop' into constraint-analyser
This commit is contained in:
commit
49297bf1db
3 changed files with 51 additions and 46 deletions
|
@ -203,7 +203,7 @@ fn use_variable(
|
|||
/// * the return value of the `FlatFunction` is not deterministic: as we decompose over log_2(p) + 1 bits, some
|
||||
/// elements can have multiple representations: For example, `unpack(0)` is `[0, ..., 0]` but also `unpack(p)`
|
||||
pub fn unpack<T: Field>() -> FlatFunction<T> {
|
||||
let nbits = T::get_required_bits();
|
||||
let bit_width = T::get_required_bits();
|
||||
|
||||
let mut counter = 0;
|
||||
|
||||
|
@ -221,23 +221,23 @@ pub fn unpack<T: Field>() -> FlatFunction<T> {
|
|||
format!("i0"),
|
||||
&mut counter,
|
||||
))];
|
||||
let directive_outputs: Vec<FlatVariable> = (0..T::get_required_bits())
|
||||
let directive_outputs: Vec<FlatVariable> = (0..bit_width)
|
||||
.map(|index| use_variable(&mut layout, format!("o{}", index), &mut counter))
|
||||
.collect();
|
||||
|
||||
let solver = Solver::bits();
|
||||
let solver = Solver::bits(bit_width);
|
||||
|
||||
let outputs = directive_outputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(index, _)| *index >= T::get_required_bits() - nbits)
|
||||
.filter(|(index, _)| *index >= T::get_required_bits() - bit_width)
|
||||
.map(|(_, o)| FlatExpression::Identifier(o.clone()))
|
||||
.collect();
|
||||
|
||||
// o253, o252, ... o{253 - (nbits - 1)} are bits
|
||||
let mut statements: Vec<FlatStatement<T>> = (0..nbits)
|
||||
// o253, o252, ... o{253 - (bit_width - 1)} are bits
|
||||
let mut statements: Vec<FlatStatement<T>> = (0..bit_width)
|
||||
.map(|index| {
|
||||
let bit = FlatExpression::Identifier(FlatVariable::new(T::get_required_bits() - index));
|
||||
let bit = FlatExpression::Identifier(FlatVariable::new(bit_width - index));
|
||||
FlatStatement::Condition(
|
||||
bit.clone(),
|
||||
FlatExpression::Mult(box bit.clone(), box bit.clone()),
|
||||
|
@ -245,14 +245,14 @@ pub fn unpack<T: Field>() -> FlatFunction<T> {
|
|||
})
|
||||
.collect();
|
||||
|
||||
// sum check: o253 + o252 * 2 + ... + o{253 - (nbits - 1)} * 2**(nbits - 1)
|
||||
// sum check: o253 + o252 * 2 + ... + o{253 - (bit_width - 1)} * 2**(bit_width - 1)
|
||||
let mut lhs_sum = FlatExpression::Number(T::from(0));
|
||||
|
||||
for i in 0..nbits {
|
||||
for i in 0..bit_width {
|
||||
lhs_sum = FlatExpression::Add(
|
||||
box lhs_sum,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(FlatVariable::new(T::get_required_bits() - i)),
|
||||
box FlatExpression::Identifier(FlatVariable::new(bit_width - i)),
|
||||
box FlatExpression::Number(T::from(2).pow(i)),
|
||||
),
|
||||
);
|
||||
|
@ -312,7 +312,7 @@ mod tests {
|
|||
(0..FieldPrime::get_required_bits())
|
||||
.map(|i| FlatVariable::new(i + 1))
|
||||
.collect(),
|
||||
Solver::bits(),
|
||||
Solver::bits(FieldPrime::get_required_bits()),
|
||||
vec![FlatVariable::new(0)]
|
||||
))
|
||||
);
|
||||
|
|
|
@ -566,8 +566,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
FlatExpression::Identifier(self.layout.get(&x).unwrap().clone()[0])
|
||||
}
|
||||
BooleanExpression::Lt(box lhs, box rhs) => {
|
||||
// Get the bitwidth to know the size of the binary decompsitions for this Field
|
||||
let bitwidth = T::get_required_bits();
|
||||
// Get the bit width to know the size of the binary decompositions for this Field
|
||||
let bit_width = T::get_required_bits();
|
||||
let safe_width = bit_width - 2; // making sure we don't overflow, assert here?
|
||||
|
||||
// We know from semantic checking that lhs and rhs have the same type
|
||||
// What the expression will flatten to depends on that type
|
||||
|
@ -587,22 +588,22 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
{
|
||||
// define variables for the bits
|
||||
let lhs_bits_be: Vec<FlatVariable> =
|
||||
(0..bitwidth).map(|_| self.use_sym()).collect();
|
||||
(0..safe_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
lhs_bits_be.clone(),
|
||||
Solver::bits(),
|
||||
Solver::bits(safe_width),
|
||||
vec![lhs_id],
|
||||
)));
|
||||
|
||||
// bitness checks
|
||||
for i in 0..bitwidth - 2 {
|
||||
for i in 0..safe_width {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(lhs_bits_be[i + 2]),
|
||||
FlatExpression::Identifier(lhs_bits_be[i]),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(lhs_bits_be[i + 2]),
|
||||
box FlatExpression::Identifier(lhs_bits_be[i + 2]),
|
||||
box FlatExpression::Identifier(lhs_bits_be[i]),
|
||||
box FlatExpression::Identifier(lhs_bits_be[i]),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
@ -610,12 +611,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// bit decomposition check
|
||||
let mut lhs_sum = FlatExpression::Number(T::from(0));
|
||||
|
||||
for i in 0..bitwidth - 2 {
|
||||
for i in 0..safe_width {
|
||||
lhs_sum = FlatExpression::Add(
|
||||
box lhs_sum,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(lhs_bits_be[i + 2]),
|
||||
box FlatExpression::Number(T::from(2).pow(bitwidth - 2 - i - 1)),
|
||||
box FlatExpression::Identifier(lhs_bits_be[i]),
|
||||
box FlatExpression::Number(T::from(2).pow(safe_width - i - 1)),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
@ -634,22 +635,22 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
{
|
||||
// define variables for the bits
|
||||
let rhs_bits_be: Vec<FlatVariable> =
|
||||
(0..bitwidth).map(|_| self.use_sym()).collect();
|
||||
(0..safe_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
rhs_bits_be.clone(),
|
||||
Solver::bits(),
|
||||
Solver::bits(safe_width),
|
||||
vec![rhs_id],
|
||||
)));
|
||||
|
||||
// bitness checks
|
||||
for i in 0..bitwidth - 2 {
|
||||
for i in 0..safe_width {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(rhs_bits_be[i + 2]),
|
||||
FlatExpression::Identifier(rhs_bits_be[i]),
|
||||
FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(rhs_bits_be[i + 2]),
|
||||
box FlatExpression::Identifier(rhs_bits_be[i + 2]),
|
||||
box FlatExpression::Identifier(rhs_bits_be[i]),
|
||||
box FlatExpression::Identifier(rhs_bits_be[i]),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
@ -657,12 +658,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// bit decomposition check
|
||||
let mut rhs_sum = FlatExpression::Number(T::from(0));
|
||||
|
||||
for i in 0..bitwidth - 2 {
|
||||
for i in 0..safe_width {
|
||||
rhs_sum = FlatExpression::Add(
|
||||
box rhs_sum,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(rhs_bits_be[i + 2]),
|
||||
box FlatExpression::Number(T::from(2).pow(bitwidth - 2 - i - 1)),
|
||||
box FlatExpression::Identifier(rhs_bits_be[i]),
|
||||
box FlatExpression::Number(T::from(2).pow(safe_width - i - 1)),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
@ -687,17 +688,17 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
// define variables for the bits
|
||||
let sub_bits_be: Vec<FlatVariable> =
|
||||
(0..bitwidth).map(|_| self.use_sym()).collect();
|
||||
(0..bit_width).map(|_| self.use_sym()).collect();
|
||||
|
||||
// add a directive to get the bits
|
||||
statements_flattened.push(FlatStatement::Directive(FlatDirective::new(
|
||||
sub_bits_be.clone(),
|
||||
Solver::bits(),
|
||||
Solver::bits(bit_width),
|
||||
vec![subtraction_result.clone()],
|
||||
)));
|
||||
|
||||
// bitness checks
|
||||
for i in 0..bitwidth {
|
||||
for i in 0..bit_width {
|
||||
statements_flattened.push(FlatStatement::Condition(
|
||||
FlatExpression::Identifier(sub_bits_be[i]),
|
||||
FlatExpression::Mult(
|
||||
|
@ -710,19 +711,19 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
// sum(sym_b{i} * 2**i)
|
||||
let mut expr = FlatExpression::Number(T::from(0));
|
||||
|
||||
for i in 0..bitwidth {
|
||||
for i in 0..bit_width {
|
||||
expr = FlatExpression::Add(
|
||||
box expr,
|
||||
box FlatExpression::Mult(
|
||||
box FlatExpression::Identifier(sub_bits_be[i]),
|
||||
box FlatExpression::Number(T::from(2).pow(bitwidth - i - 1)),
|
||||
box FlatExpression::Number(T::from(2).pow(bit_width - i - 1)),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
statements_flattened.push(FlatStatement::Condition(subtraction_result, expr));
|
||||
|
||||
FlatExpression::Identifier(sub_bits_be[bitwidth - 1])
|
||||
FlatExpression::Identifier(sub_bits_be[bit_width - 1])
|
||||
}
|
||||
BooleanExpression::BoolEq(box lhs, box rhs) => {
|
||||
// lhs and rhs are booleans, they flatten to 0 or 1
|
||||
|
|
|
@ -5,7 +5,7 @@ use zokrates_field::field::Field;
|
|||
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize, Hash, Eq)]
|
||||
pub enum Solver {
|
||||
ConditionEq,
|
||||
Bits,
|
||||
Bits(usize),
|
||||
Div,
|
||||
Sha256Round,
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ impl Signed for Solver {
|
|||
fn get_signature(&self) -> (usize, usize) {
|
||||
match self {
|
||||
Solver::ConditionEq => (1, 2),
|
||||
Solver::Bits => (1, 254),
|
||||
Solver::Bits(bit_width) => (1, *bit_width),
|
||||
Solver::Div => (2, 1),
|
||||
Solver::Sha256Round => (768, 26935),
|
||||
}
|
||||
|
@ -37,11 +37,11 @@ impl<T: Field> Executable<T> for Solver {
|
|||
true => vec![T::zero(), T::one()],
|
||||
false => vec![T::one(), T::one() / inputs[0].clone()],
|
||||
},
|
||||
Solver::Bits => {
|
||||
Solver::Bits(bit_width) => {
|
||||
let mut num = inputs[0].clone();
|
||||
let mut res = vec![];
|
||||
let bits = 254;
|
||||
for i in (0..bits).rev() {
|
||||
|
||||
for i in (0..*bit_width).rev() {
|
||||
if T::from(2).pow(i) <= num {
|
||||
num = num - T::from(2).pow(i);
|
||||
res.push(T::one());
|
||||
|
@ -73,8 +73,8 @@ impl<T: Field> Executable<T> for Solver {
|
|||
}
|
||||
|
||||
impl Solver {
|
||||
pub fn bits() -> Self {
|
||||
Solver::Bits
|
||||
pub fn bits(width: usize) -> Self {
|
||||
Solver::Bits(width)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -125,7 +125,9 @@ mod tests {
|
|||
#[test]
|
||||
fn bits_of_one() {
|
||||
let inputs = vec![FieldPrime::from(1)];
|
||||
let res = Solver::Bits.execute(&inputs).unwrap();
|
||||
let res = Solver::Bits(FieldPrime::get_required_bits())
|
||||
.execute(&inputs)
|
||||
.unwrap();
|
||||
assert_eq!(res[253], FieldPrime::from(1));
|
||||
for i in 0..252 {
|
||||
assert_eq!(res[i], FieldPrime::from(0));
|
||||
|
@ -135,7 +137,9 @@ mod tests {
|
|||
#[test]
|
||||
fn bits_of_42() {
|
||||
let inputs = vec![FieldPrime::from(42)];
|
||||
let res = Solver::Bits.execute(&inputs).unwrap();
|
||||
let res = Solver::Bits(FieldPrime::get_required_bits())
|
||||
.execute(&inputs)
|
||||
.unwrap();
|
||||
assert_eq!(res[253], FieldPrime::from(0));
|
||||
assert_eq!(res[252], FieldPrime::from(1));
|
||||
assert_eq!(res[251], FieldPrime::from(0));
|
||||
|
|
Loading…
Reference in a new issue