improvements
This commit is contained in:
parent
1e07a90ed6
commit
8d7e5804df
24 changed files with 154 additions and 151 deletions
|
@ -1,3 +1,7 @@
|
||||||
|
// A static analyser pass to transform user-defined constraints to the form `lin_comb === quad_comb`
|
||||||
|
// This pass can fail if a non-quadratic constraint is found which cannot be transformed to the expected form
|
||||||
|
|
||||||
|
use crate::ZirPropagator;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use zokrates_ast::zir::lqc::LinQuadComb;
|
use zokrates_ast::zir::lqc::LinQuadComb;
|
||||||
use zokrates_ast::zir::result_folder::{fold_field_expression, ResultFolder};
|
use zokrates_ast::zir::result_folder::{fold_field_expression, ResultFolder};
|
||||||
|
@ -17,8 +21,7 @@ pub struct AssemblyTransformer;
|
||||||
|
|
||||||
impl AssemblyTransformer {
|
impl AssemblyTransformer {
|
||||||
pub fn transform<T: Field>(p: ZirProgram<T>) -> Result<ZirProgram<T>, Error> {
|
pub fn transform<T: Field>(p: ZirProgram<T>) -> Result<ZirProgram<T>, Error> {
|
||||||
let mut f = AssemblyTransformer;
|
AssemblyTransformer.fold_program(p)
|
||||||
f.fold_program(p)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,7 +55,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for AssemblyTransformer {
|
||||||
match is_quadratic {
|
match is_quadratic {
|
||||||
true => Ok(ZirAssemblyStatement::Constraint(lhs, rhs)),
|
true => Ok(ZirAssemblyStatement::Constraint(lhs, rhs)),
|
||||||
false => {
|
false => {
|
||||||
let sub = FieldElementExpression::Sub(box lhs.clone(), box rhs.clone());
|
let sub = FieldElementExpression::Sub(box lhs, box rhs);
|
||||||
let mut lqc = LinQuadComb::try_from(sub.clone()).map_err(|_| {
|
let mut lqc = LinQuadComb::try_from(sub.clone()).map_err(|_| {
|
||||||
Error("Non-quadratic constraints are not allowed".to_string())
|
Error("Non-quadratic constraints are not allowed".to_string())
|
||||||
})?;
|
})?;
|
||||||
|
@ -60,72 +63,72 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for AssemblyTransformer {
|
||||||
let linear = lqc
|
let linear = lqc
|
||||||
.linear
|
.linear
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter_map(|(c, i)| match c {
|
.map(|(c, i)| {
|
||||||
c if c == T::from(0) => None,
|
FieldElementExpression::Mult(
|
||||||
c if c == T::from(1) => Some(FieldElementExpression::identifier(i)),
|
|
||||||
_ => Some(FieldElementExpression::Mult(
|
|
||||||
box FieldElementExpression::Number(c),
|
box FieldElementExpression::Number(c),
|
||||||
box FieldElementExpression::identifier(i),
|
box FieldElementExpression::identifier(i),
|
||||||
)),
|
)
|
||||||
})
|
})
|
||||||
.reduce(|p, n| FieldElementExpression::Add(box p, box n))
|
.fold(FieldElementExpression::Number(T::from(0)), |acc, e| {
|
||||||
.unwrap_or_else(|| FieldElementExpression::Number(T::from(0)));
|
FieldElementExpression::Add(box acc, box e)
|
||||||
|
});
|
||||||
|
|
||||||
let lhs = match lqc.constant {
|
let lhs = FieldElementExpression::Add(
|
||||||
c if c == T::from(0) => linear,
|
box FieldElementExpression::Number(lqc.constant),
|
||||||
c => FieldElementExpression::Add(
|
box linear,
|
||||||
box FieldElementExpression::Number(c),
|
);
|
||||||
box linear,
|
|
||||||
),
|
|
||||||
};
|
|
||||||
|
|
||||||
let rhs: FieldElementExpression<'ast, T> = if lqc.quadratic.len() > 1 {
|
let rhs: FieldElementExpression<'ast, T> = if lqc.quadratic.len() > 1 {
|
||||||
let is_common_factor = |id: &Identifier<'ast>,
|
let common_factors = lqc.quadratic.iter().fold(
|
||||||
q: &Vec<(
|
None,
|
||||||
T,
|
|acc: Option<Vec<Identifier>>, (_, a, b)| {
|
||||||
Identifier<'ast>,
|
Some(
|
||||||
Identifier<'ast>,
|
acc.map(|factors| {
|
||||||
)>| {
|
factors
|
||||||
q.iter().all(|(_, i0, i1)| i0.eq(id) || i1.eq(id))
|
.into_iter()
|
||||||
};
|
.filter(|f| f == a || f == b)
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.unwrap_or_else(|| vec![a.clone(), b.clone()]),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
let common_factor: Option<Identifier<'ast>> =
|
match common_factors {
|
||||||
lqc.quadratic.iter().find_map(|(_, i0, i1)| {
|
Some(factors) => Ok(FieldElementExpression::Mult(
|
||||||
if is_common_factor(i0, &lqc.quadratic) {
|
|
||||||
Some(i0.clone())
|
|
||||||
} else if is_common_factor(i1, &lqc.quadratic) {
|
|
||||||
Some(i1.clone())
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
match common_factor {
|
|
||||||
Some(id) => Ok(FieldElementExpression::Mult(
|
|
||||||
box lqc
|
box lqc
|
||||||
.quadratic
|
.quadratic
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter_map(|(c, i0, i1)| {
|
.map(|(c, i0, i1)| {
|
||||||
let c = T::zero() - c;
|
let c = T::zero() - c;
|
||||||
let id = if id.eq(&i0) { i1 } else { i0 };
|
let i0 = match factors.contains(&i0) {
|
||||||
match c {
|
true => FieldElementExpression::Number(T::from(1)),
|
||||||
c if c == T::from(0) => None,
|
false => FieldElementExpression::identifier(i0),
|
||||||
c if c == T::from(1) => {
|
};
|
||||||
Some(FieldElementExpression::identifier(id))
|
let i1 = match factors.contains(&i1) {
|
||||||
}
|
true => FieldElementExpression::Number(T::from(1)),
|
||||||
_ => Some(FieldElementExpression::Mult(
|
false => FieldElementExpression::identifier(i1),
|
||||||
box FieldElementExpression::Number(c),
|
};
|
||||||
box FieldElementExpression::identifier(id),
|
FieldElementExpression::Mult(
|
||||||
)),
|
box FieldElementExpression::Number(c),
|
||||||
}
|
box FieldElementExpression::Mult(box i0, box i1),
|
||||||
|
)
|
||||||
})
|
})
|
||||||
.reduce(|p, n| FieldElementExpression::Add(box p, box n))
|
.fold(
|
||||||
.unwrap_or_else(|| {
|
FieldElementExpression::Number(T::from(0)),
|
||||||
FieldElementExpression::Number(T::from(0))
|
|acc, e| FieldElementExpression::Add(box acc, box e),
|
||||||
}),
|
),
|
||||||
box FieldElementExpression::identifier(id),
|
box factors.into_iter().fold(
|
||||||
|
FieldElementExpression::Number(T::from(1)),
|
||||||
|
|acc, id| {
|
||||||
|
FieldElementExpression::Mult(
|
||||||
|
box acc,
|
||||||
|
box FieldElementExpression::identifier(id),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
)),
|
)),
|
||||||
_ => Err(Error(
|
None => Err(Error(
|
||||||
"Non-quadratic constraints are not allowed".to_string(),
|
"Non-quadratic constraints are not allowed".to_string(),
|
||||||
)),
|
)),
|
||||||
}?
|
}?
|
||||||
|
@ -144,6 +147,15 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for AssemblyTransformer {
|
||||||
.unwrap_or_else(|| FieldElementExpression::Number(T::from(0)))
|
.unwrap_or_else(|| FieldElementExpression::Number(T::from(0)))
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mut propagator = ZirPropagator::default();
|
||||||
|
let lhs = propagator
|
||||||
|
.fold_field_expression(lhs)
|
||||||
|
.map_err(|e| Error(e.to_string()))?;
|
||||||
|
|
||||||
|
let rhs = propagator
|
||||||
|
.fold_field_expression(rhs)
|
||||||
|
.map_err(|e| Error(e.to_string()))?;
|
||||||
|
|
||||||
Ok(ZirAssemblyStatement::Constraint(lhs, rhs))
|
Ok(ZirAssemblyStatement::Constraint(lhs, rhs))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -458,6 +458,27 @@ impl<'ast, T: Field> Flattener<T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This finder looks for identifiers that were not defined in some block of statements
|
||||||
|
// These identifiers are used as function arguments when moving witness assignment expression
|
||||||
|
// to a zir function.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// def main(field a, field mut b) -> field {
|
||||||
|
// asm {
|
||||||
|
// b <== a * a;
|
||||||
|
// }
|
||||||
|
// return b;
|
||||||
|
// }
|
||||||
|
// is turned into
|
||||||
|
// def main(field a, field mut b) -> field {
|
||||||
|
// asm {
|
||||||
|
// b <-- (field a) -> field {
|
||||||
|
// return a * a;
|
||||||
|
// }
|
||||||
|
// b == a * a;
|
||||||
|
// }
|
||||||
|
// return b;
|
||||||
|
// }
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub struct ArgumentFinder<'ast, T> {
|
pub struct ArgumentFinder<'ast, T> {
|
||||||
pub identifiers: HashMap<zir::Identifier<'ast>, zir::Type>,
|
pub identifiers: HashMap<zir::Identifier<'ast>, zir::Type>,
|
||||||
|
|
|
@ -12,7 +12,7 @@ use num_bigint::BigUint;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::convert::{TryFrom, TryInto};
|
use std::convert::{TryFrom, TryInto};
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::ops::{BitAnd, BitOr, BitXor, Mul, Shr, Sub};
|
use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr, Sub};
|
||||||
use zokrates_ast::common::FlatEmbed;
|
use zokrates_ast::common::FlatEmbed;
|
||||||
use zokrates_ast::typed::result_folder::*;
|
use zokrates_ast::typed::result_folder::*;
|
||||||
use zokrates_ast::typed::types::Type;
|
use zokrates_ast::typed::types::Type;
|
||||||
|
@ -402,7 +402,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
||||||
true => {
|
true => {
|
||||||
let r: Option<TypedExpression<'ast, T>> = match embed_call.embed {
|
let r: Option<TypedExpression<'ast, T>> = match embed_call.embed {
|
||||||
FlatEmbed::FieldToBoolUnsafe => Ok(None), // todo
|
FlatEmbed::FieldToBoolUnsafe => Ok(None), // todo
|
||||||
FlatEmbed::BoolToField => Ok(None), // todo
|
|
||||||
FlatEmbed::BitArrayLe => Ok(None), // todo
|
FlatEmbed::BitArrayLe => Ok(None), // todo
|
||||||
FlatEmbed::U64FromBits => Ok(Some(process_u_from_bits(
|
FlatEmbed::U64FromBits => Ok(Some(process_u_from_bits(
|
||||||
&embed_call.arguments,
|
&embed_call.arguments,
|
||||||
|
@ -874,6 +873,12 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
||||||
T::try_from(n1.to_biguint().bitxor(n2.to_biguint())).unwrap(),
|
T::try_from(n1.to_biguint().bitxor(n2.to_biguint())).unwrap(),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
(FieldElementExpression::Number(n), e)
|
||||||
|
| (e, FieldElementExpression::Number(n))
|
||||||
|
if n == T::from(0) =>
|
||||||
|
{
|
||||||
|
Ok(e)
|
||||||
|
}
|
||||||
(e1, e2) if e1.eq(&e2) => Ok(FieldElementExpression::Number(T::from(0))),
|
(e1, e2) if e1.eq(&e2) => Ok(FieldElementExpression::Number(T::from(0))),
|
||||||
(e1, e2) => Ok(FieldElementExpression::Xor(box e1, box e2)),
|
(e1, e2) => Ok(FieldElementExpression::Xor(box e1, box e2)),
|
||||||
}
|
}
|
||||||
|
@ -948,8 +953,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
||||||
let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize);
|
let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize);
|
||||||
|
|
||||||
Ok(FieldElementExpression::Number(
|
Ok(FieldElementExpression::Number(
|
||||||
T::try_from(n.to_biguint().mul(two.pow(by as usize)).bitand(mask))
|
T::try_from(n.to_biguint().shl(by as usize).bitand(mask)).unwrap(),
|
||||||
.unwrap(),
|
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
(e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)),
|
(e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)),
|
||||||
|
|
|
@ -2,7 +2,7 @@ use num::traits::Pow;
|
||||||
use num_bigint::BigUint;
|
use num_bigint::BigUint;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::ops::{BitAnd, BitOr, BitXor, Mul, Shr, Sub};
|
use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr, Sub};
|
||||||
use zokrates_ast::zir::types::UBitwidth;
|
use zokrates_ast::zir::types::UBitwidth;
|
||||||
use zokrates_ast::zir::{
|
use zokrates_ast::zir::{
|
||||||
result_folder::*, Conditional, ConditionalExpression, ConditionalOrExpression, Expr, Id,
|
result_folder::*, Conditional, ConditionalExpression, ConditionalOrExpression, Expr, Id,
|
||||||
|
@ -334,8 +334,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
||||||
let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize);
|
let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize);
|
||||||
|
|
||||||
Ok(FieldElementExpression::Number(
|
Ok(FieldElementExpression::Number(
|
||||||
T::try_from(n.to_biguint().mul(two.pow(by as usize)).bitand(mask))
|
T::try_from(n.to_biguint().shl(by as usize).bitand(mask)).unwrap(),
|
||||||
.unwrap(),
|
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
(e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)),
|
(e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)),
|
||||||
|
|
|
@ -32,7 +32,6 @@ cfg_if::cfg_if! {
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord, Serialize, Deserialize)]
|
||||||
pub enum FlatEmbed {
|
pub enum FlatEmbed {
|
||||||
FieldToBoolUnsafe,
|
FieldToBoolUnsafe,
|
||||||
BoolToField,
|
|
||||||
BitArrayLe,
|
BitArrayLe,
|
||||||
Unpack,
|
Unpack,
|
||||||
U8ToBits,
|
U8ToBits,
|
||||||
|
@ -55,9 +54,6 @@ impl FlatEmbed {
|
||||||
FlatEmbed::FieldToBoolUnsafe => UnresolvedSignature::new()
|
FlatEmbed::FieldToBoolUnsafe => UnresolvedSignature::new()
|
||||||
.inputs(vec![UnresolvedType::FieldElement.into()])
|
.inputs(vec![UnresolvedType::FieldElement.into()])
|
||||||
.output(UnresolvedType::Boolean.into()),
|
.output(UnresolvedType::Boolean.into()),
|
||||||
FlatEmbed::BoolToField => UnresolvedSignature::new()
|
|
||||||
.inputs(vec![UnresolvedType::Boolean.into()])
|
|
||||||
.output(UnresolvedType::FieldElement.into()),
|
|
||||||
FlatEmbed::BitArrayLe => UnresolvedSignature::new()
|
FlatEmbed::BitArrayLe => UnresolvedSignature::new()
|
||||||
.generics(vec![ConstantGenericNode::mock("N")])
|
.generics(vec![ConstantGenericNode::mock("N")])
|
||||||
.inputs(vec![
|
.inputs(vec![
|
||||||
|
@ -197,9 +193,6 @@ impl FlatEmbed {
|
||||||
FlatEmbed::FieldToBoolUnsafe => DeclarationSignature::new()
|
FlatEmbed::FieldToBoolUnsafe => DeclarationSignature::new()
|
||||||
.inputs(vec![DeclarationType::FieldElement])
|
.inputs(vec![DeclarationType::FieldElement])
|
||||||
.output(DeclarationType::Boolean),
|
.output(DeclarationType::Boolean),
|
||||||
FlatEmbed::BoolToField => DeclarationSignature::new()
|
|
||||||
.inputs(vec![DeclarationType::Boolean])
|
|
||||||
.output(DeclarationType::FieldElement),
|
|
||||||
FlatEmbed::BitArrayLe => DeclarationSignature::new()
|
FlatEmbed::BitArrayLe => DeclarationSignature::new()
|
||||||
.generics(vec![Some(DeclarationConstant::Generic(
|
.generics(vec![Some(DeclarationConstant::Generic(
|
||||||
GenericIdentifier::with_name("N").with_index(0),
|
GenericIdentifier::with_name("N").with_index(0),
|
||||||
|
@ -307,7 +300,6 @@ impl FlatEmbed {
|
||||||
pub fn id(&self) -> &'static str {
|
pub fn id(&self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
FlatEmbed::FieldToBoolUnsafe => "_FIELD_TO_BOOL_UNSAFE",
|
FlatEmbed::FieldToBoolUnsafe => "_FIELD_TO_BOOL_UNSAFE",
|
||||||
FlatEmbed::BoolToField => "_BOOL_TO_FIELD",
|
|
||||||
FlatEmbed::BitArrayLe => "_BIT_ARRAY_LT",
|
FlatEmbed::BitArrayLe => "_BIT_ARRAY_LT",
|
||||||
FlatEmbed::Unpack => "_UNPACK",
|
FlatEmbed::Unpack => "_UNPACK",
|
||||||
FlatEmbed::U8ToBits => "_U8_TO_BITS",
|
FlatEmbed::U8ToBits => "_U8_TO_BITS",
|
||||||
|
|
|
@ -3,7 +3,7 @@ use std::fmt;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)]
|
#[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)]
|
||||||
pub enum RuntimeError {
|
pub enum RuntimeError {
|
||||||
UserConstraint,
|
SourceAssemblyConstraint,
|
||||||
BellmanConstraint,
|
BellmanConstraint,
|
||||||
BellmanOneBinding,
|
BellmanOneBinding,
|
||||||
BellmanInputBinding,
|
BellmanInputBinding,
|
||||||
|
@ -50,7 +50,7 @@ impl RuntimeError {
|
||||||
|
|
||||||
!matches!(
|
!matches!(
|
||||||
self,
|
self,
|
||||||
UserConstraint
|
SourceAssemblyConstraint
|
||||||
| SourceAssertion(_)
|
| SourceAssertion(_)
|
||||||
| Inverse
|
| Inverse
|
||||||
| SelectRangeCheck
|
| SelectRangeCheck
|
||||||
|
@ -65,7 +65,7 @@ impl fmt::Display for RuntimeError {
|
||||||
use RuntimeError::*;
|
use RuntimeError::*;
|
||||||
|
|
||||||
let msg = match self {
|
let msg = match self {
|
||||||
UserConstraint => "User constraint is unsatisfied",
|
SourceAssemblyConstraint => "Source constraint is unsatisfied",
|
||||||
BellmanConstraint => "Bellman constraint is unsatisfied",
|
BellmanConstraint => "Bellman constraint is unsatisfied",
|
||||||
BellmanOneBinding => "Bellman ~one binding is unsatisfied",
|
BellmanOneBinding => "Bellman ~one binding is unsatisfied",
|
||||||
BellmanInputBinding => "Bellman input binding is unsatisfied",
|
BellmanInputBinding => "Bellman input binding is unsatisfied",
|
||||||
|
|
|
@ -20,7 +20,7 @@ pub enum Solver<'ast, T> {
|
||||||
SnarkVerifyBls12377(usize),
|
SnarkVerifyBls12377(usize),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ast, T: fmt::Debug + fmt::Display> fmt::Display for Solver<'ast, T> {
|
impl<'ast, T: fmt::Debug> fmt::Display for Solver<'ast, T> {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
Solver::Zir(_) => write!(f, "Zir(..)"),
|
Solver::Zir(_) => write!(f, "Zir(..)"),
|
||||||
|
|
|
@ -1,14 +1,17 @@
|
||||||
use crate::zir::{FieldElementExpression, Identifier};
|
use crate::zir::{FieldElementExpression, Identifier};
|
||||||
use zokrates_field::Field;
|
use zokrates_field::Field;
|
||||||
|
|
||||||
|
pub type LinearTerm<'ast, T> = (T, Identifier<'ast>);
|
||||||
|
pub type QuadraticTerm<'ast, T> = (T, Identifier<'ast>, Identifier<'ast>);
|
||||||
|
|
||||||
#[derive(Clone, PartialEq, Hash, Eq, Debug, Default)]
|
#[derive(Clone, PartialEq, Hash, Eq, Debug, Default)]
|
||||||
pub struct LinQuadComb<'ast, T> {
|
pub struct LinQuadComb<'ast, T> {
|
||||||
// the constant terms
|
// the constant terms
|
||||||
pub constant: T,
|
pub constant: T,
|
||||||
// the linear terms
|
// the linear terms
|
||||||
pub linear: Vec<(T, Identifier<'ast>)>,
|
pub linear: Vec<LinearTerm<'ast, T>>,
|
||||||
// the quadratic terms
|
// the quadratic terms
|
||||||
pub quadratic: Vec<(T, Identifier<'ast>, Identifier<'ast>)>,
|
pub quadratic: Vec<QuadraticTerm<'ast, T>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ast, T: Field> std::ops::Add for LinQuadComb<'ast, T> {
|
impl<'ast, T: Field> std::ops::Add for LinQuadComb<'ast, T> {
|
||||||
|
|
|
@ -166,7 +166,8 @@ pub enum ZirStatement<'ast, T> {
|
||||||
FormatString,
|
FormatString,
|
||||||
Vec<(ConcreteType, Vec<ZirExpression<'ast, T>>)>,
|
Vec<(ConcreteType, Vec<ZirExpression<'ast, T>>)>,
|
||||||
),
|
),
|
||||||
Assembly(#[serde(borrow)] Vec<ZirAssemblyStatement<'ast, T>>),
|
#[serde(borrow)]
|
||||||
|
Assembly(Vec<ZirAssemblyStatement<'ast, T>>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ast, T: fmt::Display> fmt::Display for ZirStatement<'ast, T> {
|
impl<'ast, T: fmt::Display> fmt::Display for ZirStatement<'ast, T> {
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
from "EMBED" import bool_to_field;
|
|
||||||
|
|
||||||
def main(bool x) -> field {
|
def main(bool x) -> field {
|
||||||
// `x` is constrained by the compiler automatically so we can safely
|
// `x` is constrained by the compiler automatically so we can safely
|
||||||
// treat it as `field` with no extra cost
|
// convert to `field` with no extra cost
|
||||||
field out = bool_to_field(x);
|
field out = x ? 1 : 0;
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
|
@ -1,7 +1,11 @@
|
||||||
def main(field a, field b) {
|
def main(field a, field b) -> field {
|
||||||
field mut c = 0;
|
field mut c = 0;
|
||||||
|
field mut invb = 0;
|
||||||
asm {
|
asm {
|
||||||
c <-- a / b;
|
invb <-- b == 0 ? 0 : 1 / b;
|
||||||
|
invb * b === 1;
|
||||||
|
c <-- invb * a;
|
||||||
a === b * c;
|
a === b * c;
|
||||||
}
|
}
|
||||||
|
return c;
|
||||||
}
|
}
|
|
@ -5,7 +5,8 @@ def main(field x) -> bool {
|
||||||
asm {
|
asm {
|
||||||
x * (x - 1) === 0;
|
x * (x - 1) === 0;
|
||||||
}
|
}
|
||||||
// we can treat `x` as `bool` afterwards, as we constrained it properly
|
// we can convert `x` to `bool` afterwards, as we constrained it properly
|
||||||
|
// if we failed to constrain `x` to `0` or `1`, the call to `field_to_bool_unsafe` introduces undefined behavior
|
||||||
// `field_to_bool_unsafe` call does not produce any extra constraints
|
// `field_to_bool_unsafe` call does not produce any extra constraints
|
||||||
bool out = field_to_bool_unsafe(x);
|
bool out = field_to_bool_unsafe(x);
|
||||||
return out;
|
return out;
|
||||||
|
|
|
@ -259,7 +259,7 @@ mod tests {
|
||||||
|
|
||||||
let interpreter = zokrates_interpreter::Interpreter::default();
|
let interpreter = zokrates_interpreter::Interpreter::default();
|
||||||
|
|
||||||
let res = interpreter.execute(artifacts.prog(), &[Bn128Field::from(0u32)]);
|
let res = interpreter.execute(artifacts.prog(), &[Bn128Field::from(0)]);
|
||||||
|
|
||||||
assert!(res.is_err());
|
assert!(res.is_err());
|
||||||
}
|
}
|
||||||
|
|
|
@ -1050,7 +1050,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
|
|
||||||
match embed {
|
match embed {
|
||||||
FlatEmbed::FieldToBoolUnsafe => vec![params.pop().unwrap()],
|
FlatEmbed::FieldToBoolUnsafe => vec![params.pop().unwrap()],
|
||||||
FlatEmbed::BoolToField => vec![params.pop().unwrap()],
|
|
||||||
FlatEmbed::U8ToBits => self.u_to_bits(params.pop().unwrap(), 8.into()),
|
FlatEmbed::U8ToBits => self.u_to_bits(params.pop().unwrap(), 8.into()),
|
||||||
FlatEmbed::U16ToBits => self.u_to_bits(params.pop().unwrap(), 16.into()),
|
FlatEmbed::U16ToBits => self.u_to_bits(params.pop().unwrap(), 16.into()),
|
||||||
FlatEmbed::U32ToBits => self.u_to_bits(params.pop().unwrap(), 32.into()),
|
FlatEmbed::U32ToBits => self.u_to_bits(params.pop().unwrap(), 32.into()),
|
||||||
|
@ -2255,7 +2254,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
||||||
statements_flattened.push_back(FlatStatement::Condition(
|
statements_flattened.push_back(FlatStatement::Condition(
|
||||||
lhs,
|
lhs,
|
||||||
rhs,
|
rhs,
|
||||||
RuntimeError::UserConstraint,
|
RuntimeError::SourceAssemblyConstraint,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -151,10 +151,6 @@ impl Importer {
|
||||||
id: symbol.get_alias(),
|
id: symbol.get_alias(),
|
||||||
symbol: Symbol::Flat(FlatEmbed::FieldToBoolUnsafe),
|
symbol: Symbol::Flat(FlatEmbed::FieldToBoolUnsafe),
|
||||||
},
|
},
|
||||||
"bool_to_field" => SymbolDeclaration {
|
|
||||||
id: symbol.get_alias(),
|
|
||||||
symbol: Symbol::Flat(FlatEmbed::BoolToField),
|
|
||||||
},
|
|
||||||
"bit_array_le" => SymbolDeclaration {
|
"bit_array_le" => SymbolDeclaration {
|
||||||
id: symbol.get_alias(),
|
id: symbol.get_alias(),
|
||||||
symbol: Symbol::Flat(FlatEmbed::BitArrayLe),
|
symbol: Symbol::Flat(FlatEmbed::BitArrayLe),
|
||||||
|
|
|
@ -1804,10 +1804,9 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
||||||
let e = match checked_e {
|
let e = match checked_e {
|
||||||
TypedExpression::FieldElement(e) => Ok(e),
|
TypedExpression::FieldElement(e) => Ok(e),
|
||||||
TypedExpression::Int(e) => Ok(FieldElementExpression::try_from_int(e).unwrap()),
|
TypedExpression::Int(e) => Ok(FieldElementExpression::try_from_int(e).unwrap()),
|
||||||
_ => Err(ErrorInner {
|
e => Err(ErrorInner {
|
||||||
pos: Some(pos),
|
pos: Some(pos),
|
||||||
message: "Only field element expressions are allowed in the assembly block"
|
message: format!("The right hand side of an assembly assignment must be of type field, found {}", e.get_type())
|
||||||
.to_string(),
|
|
||||||
}),
|
}),
|
||||||
}?;
|
}?;
|
||||||
|
|
||||||
|
@ -1819,9 +1818,9 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
||||||
TypedAssemblyStatement::Assignment(assignee.clone(), e.clone()),
|
TypedAssemblyStatement::Assignment(assignee.clone(), e.clone()),
|
||||||
TypedAssemblyStatement::Constraint(assignee.into(), e),
|
TypedAssemblyStatement::Constraint(assignee.into(), e),
|
||||||
]),
|
]),
|
||||||
_ => Err(ErrorInner {
|
ty => Err(ErrorInner {
|
||||||
pos: Some(pos),
|
pos: Some(pos),
|
||||||
message: "Assignee must be of type `field`".to_string(),
|
message: format!("Assignee must be of type field, found {}", ty),
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1849,10 +1848,13 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
||||||
FieldElementExpression::try_from_int(lhs).unwrap(),
|
FieldElementExpression::try_from_int(lhs).unwrap(),
|
||||||
FieldElementExpression::try_from_int(rhs).unwrap(),
|
FieldElementExpression::try_from_int(rhs).unwrap(),
|
||||||
)),
|
)),
|
||||||
_ => Err(ErrorInner {
|
(e1, e2) => Err(ErrorInner {
|
||||||
pos: Some(pos),
|
pos: Some(pos),
|
||||||
message: "Only field element expressions are allowed in the assembly block"
|
message: format!(
|
||||||
.to_string(),
|
"Assembly constraint expected expressions of type field, found {}, {}",
|
||||||
|
e1.get_type(),
|
||||||
|
e2.get_type()
|
||||||
|
),
|
||||||
}),
|
}),
|
||||||
}?;
|
}?;
|
||||||
|
|
||||||
|
@ -1873,14 +1875,12 @@ impl<'ast, T: Field> Checker<'ast, T> {
|
||||||
Statement::Assembly(statements) => {
|
Statement::Assembly(statements) => {
|
||||||
let mut checked_statements = vec![];
|
let mut checked_statements = vec![];
|
||||||
for s in statements {
|
for s in statements {
|
||||||
checked_statements.push(
|
checked_statements.extend(
|
||||||
self.check_assembly_statement(s, module_id, types)
|
self.check_assembly_statement(s, module_id, types)
|
||||||
.map_err(|e| vec![e])?,
|
.map_err(|e| vec![e])?,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
Ok(TypedStatement::Assembly(
|
Ok(TypedStatement::Assembly(checked_statements))
|
||||||
checked_statements.into_iter().flatten().collect(),
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
Statement::Log(l, expressions) => {
|
Statement::Log(l, expressions) => {
|
||||||
let l = FormatString::from(l);
|
let l = FormatString::from(l);
|
||||||
|
|
|
@ -29,7 +29,7 @@
|
||||||
"output": {
|
"output": {
|
||||||
"Err": {
|
"Err": {
|
||||||
"UnsatisfiedConstraint": {
|
"UnsatisfiedConstraint": {
|
||||||
"error": "UserConstraint"
|
"error": "SourceAssemblyConstraint"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,7 +48,7 @@
|
||||||
"output": {
|
"output": {
|
||||||
"Err": {
|
"Err": {
|
||||||
"UnsatisfiedConstraint": {
|
"UnsatisfiedConstraint": {
|
||||||
"error": "UserConstraint"
|
"error": "SourceAssemblyConstraint"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
{
|
{
|
||||||
"curves": ["Bn128"],
|
"curves": ["Bn128"],
|
||||||
"max_constraint_count": 2,
|
"max_constraint_count": 3,
|
||||||
"tests": [
|
"tests": [
|
||||||
{
|
{
|
||||||
"input": {
|
"input": {
|
||||||
|
@ -14,11 +14,13 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"input": {
|
"input": {
|
||||||
"values": ["4", "0"]
|
"values": ["0", "0"]
|
||||||
},
|
},
|
||||||
"output": {
|
"output": {
|
||||||
"Err": {
|
"Err": {
|
||||||
"Solver": "Assertion failed: `Division by zero`"
|
"UnsatisfiedConstraint": {
|
||||||
|
"error": "SourceAssemblyConstraint"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
def main(field x, field y) -> field {
|
def main(field a, field b) -> field {
|
||||||
field mut z = 0;
|
field mut c = 0;
|
||||||
|
field mut invb = 0;
|
||||||
asm {
|
asm {
|
||||||
z <-- x / y;
|
invb <-- b == 0 ? 0 : 1 / b;
|
||||||
z * y === x;
|
invb * b === 1;
|
||||||
|
c <-- invb * a;
|
||||||
|
a === b * c;
|
||||||
}
|
}
|
||||||
return z;
|
return c;
|
||||||
}
|
}
|
|
@ -1,26 +0,0 @@
|
||||||
{
|
|
||||||
"curves": ["Bn128"],
|
|
||||||
"max_constraint_count": 2,
|
|
||||||
"tests": [
|
|
||||||
{
|
|
||||||
"input": {
|
|
||||||
"values": [false]
|
|
||||||
},
|
|
||||||
"output": {
|
|
||||||
"Ok": {
|
|
||||||
"value": "0"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"input": {
|
|
||||||
"values": [true]
|
|
||||||
},
|
|
||||||
"output": {
|
|
||||||
"Ok": {
|
|
||||||
"value": "1"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
|
@ -1,6 +0,0 @@
|
||||||
from "EMBED" import bool_to_field;
|
|
||||||
|
|
||||||
def main(bool x) -> field {
|
|
||||||
field out = bool_to_field(x);
|
|
||||||
return out;
|
|
||||||
}
|
|
|
@ -29,7 +29,7 @@
|
||||||
"output": {
|
"output": {
|
||||||
"Err": {
|
"Err": {
|
||||||
"UnsatisfiedConstraint": {
|
"UnsatisfiedConstraint": {
|
||||||
"error": "UserConstraint"
|
"error": "SourceAssemblyConstraint"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,7 +42,7 @@ fn lt_field() {
|
||||||
assert!(interpreter
|
assert!(interpreter
|
||||||
.execute(
|
.execute(
|
||||||
res.prog(),
|
res.prog(),
|
||||||
&[Bn128Field::from(10000u32), Bn128Field::from(5555u32)]
|
&[Bn128Field::from(10000), Bn128Field::from(5555)]
|
||||||
)
|
)
|
||||||
.is_err());
|
.is_err());
|
||||||
}
|
}
|
||||||
|
@ -78,7 +78,7 @@ fn lt_uint() {
|
||||||
assert!(interpreter
|
assert!(interpreter
|
||||||
.execute(
|
.execute(
|
||||||
res.prog(),
|
res.prog(),
|
||||||
&[Bn128Field::from(10000u32), Bn128Field::from(5555u32)]
|
&[Bn128Field::from(10000), Bn128Field::from(5555)]
|
||||||
)
|
)
|
||||||
.is_err());
|
.is_err());
|
||||||
}
|
}
|
||||||
|
@ -123,7 +123,7 @@ fn unpack256() {
|
||||||
let interpreter = Interpreter::try_out_of_range();
|
let interpreter = Interpreter::try_out_of_range();
|
||||||
|
|
||||||
assert!(interpreter
|
assert!(interpreter
|
||||||
.execute(res.prog(), &[Bn128Field::from(0u32)])
|
.execute(res.prog(), &[Bn128Field::from(0)])
|
||||||
.is_err());
|
.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -167,6 +167,6 @@ fn unpack256_unchecked() {
|
||||||
let interpreter = Interpreter::try_out_of_range();
|
let interpreter = Interpreter::try_out_of_range();
|
||||||
|
|
||||||
assert!(interpreter
|
assert!(interpreter
|
||||||
.execute(res.prog(), &[Bn128Field::from(0u32)])
|
.execute(res.prog(), &[Bn128Field::from(0)])
|
||||||
.is_ok());
|
.is_ok());
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue