1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

improvements

This commit is contained in:
dark64 2022-11-25 17:54:21 +01:00
parent 1e07a90ed6
commit 8d7e5804df
24 changed files with 154 additions and 151 deletions

View file

@ -1,3 +1,7 @@
// A static analyser pass to transform user-defined constraints to the form `lin_comb === quad_comb`
// This pass can fail if a non-quadratic constraint is found which cannot be transformed to the expected form
use crate::ZirPropagator;
use std::fmt; use std::fmt;
use zokrates_ast::zir::lqc::LinQuadComb; use zokrates_ast::zir::lqc::LinQuadComb;
use zokrates_ast::zir::result_folder::{fold_field_expression, ResultFolder}; use zokrates_ast::zir::result_folder::{fold_field_expression, ResultFolder};
@ -17,8 +21,7 @@ pub struct AssemblyTransformer;
impl AssemblyTransformer { impl AssemblyTransformer {
pub fn transform<T: Field>(p: ZirProgram<T>) -> Result<ZirProgram<T>, Error> { pub fn transform<T: Field>(p: ZirProgram<T>) -> Result<ZirProgram<T>, Error> {
let mut f = AssemblyTransformer; AssemblyTransformer.fold_program(p)
f.fold_program(p)
} }
} }
@ -52,7 +55,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for AssemblyTransformer {
match is_quadratic { match is_quadratic {
true => Ok(ZirAssemblyStatement::Constraint(lhs, rhs)), true => Ok(ZirAssemblyStatement::Constraint(lhs, rhs)),
false => { false => {
let sub = FieldElementExpression::Sub(box lhs.clone(), box rhs.clone()); let sub = FieldElementExpression::Sub(box lhs, box rhs);
let mut lqc = LinQuadComb::try_from(sub.clone()).map_err(|_| { let mut lqc = LinQuadComb::try_from(sub.clone()).map_err(|_| {
Error("Non-quadratic constraints are not allowed".to_string()) Error("Non-quadratic constraints are not allowed".to_string())
})?; })?;
@ -60,72 +63,72 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for AssemblyTransformer {
let linear = lqc let linear = lqc
.linear .linear
.into_iter() .into_iter()
.filter_map(|(c, i)| match c { .map(|(c, i)| {
c if c == T::from(0) => None, FieldElementExpression::Mult(
c if c == T::from(1) => Some(FieldElementExpression::identifier(i)),
_ => Some(FieldElementExpression::Mult(
box FieldElementExpression::Number(c), box FieldElementExpression::Number(c),
box FieldElementExpression::identifier(i), box FieldElementExpression::identifier(i),
)), )
}) })
.reduce(|p, n| FieldElementExpression::Add(box p, box n)) .fold(FieldElementExpression::Number(T::from(0)), |acc, e| {
.unwrap_or_else(|| FieldElementExpression::Number(T::from(0))); FieldElementExpression::Add(box acc, box e)
});
let lhs = match lqc.constant { let lhs = FieldElementExpression::Add(
c if c == T::from(0) => linear, box FieldElementExpression::Number(lqc.constant),
c => FieldElementExpression::Add( box linear,
box FieldElementExpression::Number(c), );
box linear,
),
};
let rhs: FieldElementExpression<'ast, T> = if lqc.quadratic.len() > 1 { let rhs: FieldElementExpression<'ast, T> = if lqc.quadratic.len() > 1 {
let is_common_factor = |id: &Identifier<'ast>, let common_factors = lqc.quadratic.iter().fold(
q: &Vec<( None,
T, |acc: Option<Vec<Identifier>>, (_, a, b)| {
Identifier<'ast>, Some(
Identifier<'ast>, acc.map(|factors| {
)>| { factors
q.iter().all(|(_, i0, i1)| i0.eq(id) || i1.eq(id)) .into_iter()
}; .filter(|f| f == a || f == b)
.collect()
})
.unwrap_or_else(|| vec![a.clone(), b.clone()]),
)
},
);
let common_factor: Option<Identifier<'ast>> = match common_factors {
lqc.quadratic.iter().find_map(|(_, i0, i1)| { Some(factors) => Ok(FieldElementExpression::Mult(
if is_common_factor(i0, &lqc.quadratic) {
Some(i0.clone())
} else if is_common_factor(i1, &lqc.quadratic) {
Some(i1.clone())
} else {
None
}
});
match common_factor {
Some(id) => Ok(FieldElementExpression::Mult(
box lqc box lqc
.quadratic .quadratic
.into_iter() .into_iter()
.filter_map(|(c, i0, i1)| { .map(|(c, i0, i1)| {
let c = T::zero() - c; let c = T::zero() - c;
let id = if id.eq(&i0) { i1 } else { i0 }; let i0 = match factors.contains(&i0) {
match c { true => FieldElementExpression::Number(T::from(1)),
c if c == T::from(0) => None, false => FieldElementExpression::identifier(i0),
c if c == T::from(1) => { };
Some(FieldElementExpression::identifier(id)) let i1 = match factors.contains(&i1) {
} true => FieldElementExpression::Number(T::from(1)),
_ => Some(FieldElementExpression::Mult( false => FieldElementExpression::identifier(i1),
box FieldElementExpression::Number(c), };
box FieldElementExpression::identifier(id), FieldElementExpression::Mult(
)), box FieldElementExpression::Number(c),
} box FieldElementExpression::Mult(box i0, box i1),
)
}) })
.reduce(|p, n| FieldElementExpression::Add(box p, box n)) .fold(
.unwrap_or_else(|| { FieldElementExpression::Number(T::from(0)),
FieldElementExpression::Number(T::from(0)) |acc, e| FieldElementExpression::Add(box acc, box e),
}), ),
box FieldElementExpression::identifier(id), box factors.into_iter().fold(
FieldElementExpression::Number(T::from(1)),
|acc, id| {
FieldElementExpression::Mult(
box acc,
box FieldElementExpression::identifier(id),
)
},
),
)), )),
_ => Err(Error( None => Err(Error(
"Non-quadratic constraints are not allowed".to_string(), "Non-quadratic constraints are not allowed".to_string(),
)), )),
}? }?
@ -144,6 +147,15 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for AssemblyTransformer {
.unwrap_or_else(|| FieldElementExpression::Number(T::from(0))) .unwrap_or_else(|| FieldElementExpression::Number(T::from(0)))
}; };
let mut propagator = ZirPropagator::default();
let lhs = propagator
.fold_field_expression(lhs)
.map_err(|e| Error(e.to_string()))?;
let rhs = propagator
.fold_field_expression(rhs)
.map_err(|e| Error(e.to_string()))?;
Ok(ZirAssemblyStatement::Constraint(lhs, rhs)) Ok(ZirAssemblyStatement::Constraint(lhs, rhs))
} }
} }

View file

@ -458,6 +458,27 @@ impl<'ast, T: Field> Flattener<T> {
} }
} }
// This finder looks for identifiers that were not defined in some block of statements
// These identifiers are used as function arguments when moving witness assignment expression
// to a zir function.
//
// Example:
// def main(field a, field mut b) -> field {
// asm {
// b <== a * a;
// }
// return b;
// }
// is turned into
// def main(field a, field mut b) -> field {
// asm {
// b <-- (field a) -> field {
// return a * a;
// }
// b == a * a;
// }
// return b;
// }
#[derive(Default)] #[derive(Default)]
pub struct ArgumentFinder<'ast, T> { pub struct ArgumentFinder<'ast, T> {
pub identifiers: HashMap<zir::Identifier<'ast>, zir::Type>, pub identifiers: HashMap<zir::Identifier<'ast>, zir::Type>,

View file

@ -12,7 +12,7 @@ use num_bigint::BigUint;
use std::collections::HashMap; use std::collections::HashMap;
use std::convert::{TryFrom, TryInto}; use std::convert::{TryFrom, TryInto};
use std::fmt; use std::fmt;
use std::ops::{BitAnd, BitOr, BitXor, Mul, Shr, Sub}; use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr, Sub};
use zokrates_ast::common::FlatEmbed; use zokrates_ast::common::FlatEmbed;
use zokrates_ast::typed::result_folder::*; use zokrates_ast::typed::result_folder::*;
use zokrates_ast::typed::types::Type; use zokrates_ast::typed::types::Type;
@ -402,7 +402,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
true => { true => {
let r: Option<TypedExpression<'ast, T>> = match embed_call.embed { let r: Option<TypedExpression<'ast, T>> = match embed_call.embed {
FlatEmbed::FieldToBoolUnsafe => Ok(None), // todo FlatEmbed::FieldToBoolUnsafe => Ok(None), // todo
FlatEmbed::BoolToField => Ok(None), // todo
FlatEmbed::BitArrayLe => Ok(None), // todo FlatEmbed::BitArrayLe => Ok(None), // todo
FlatEmbed::U64FromBits => Ok(Some(process_u_from_bits( FlatEmbed::U64FromBits => Ok(Some(process_u_from_bits(
&embed_call.arguments, &embed_call.arguments,
@ -874,6 +873,12 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
T::try_from(n1.to_biguint().bitxor(n2.to_biguint())).unwrap(), T::try_from(n1.to_biguint().bitxor(n2.to_biguint())).unwrap(),
)) ))
} }
(FieldElementExpression::Number(n), e)
| (e, FieldElementExpression::Number(n))
if n == T::from(0) =>
{
Ok(e)
}
(e1, e2) if e1.eq(&e2) => Ok(FieldElementExpression::Number(T::from(0))), (e1, e2) if e1.eq(&e2) => Ok(FieldElementExpression::Number(T::from(0))),
(e1, e2) => Ok(FieldElementExpression::Xor(box e1, box e2)), (e1, e2) => Ok(FieldElementExpression::Xor(box e1, box e2)),
} }
@ -948,8 +953,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize); let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize);
Ok(FieldElementExpression::Number( Ok(FieldElementExpression::Number(
T::try_from(n.to_biguint().mul(two.pow(by as usize)).bitand(mask)) T::try_from(n.to_biguint().shl(by as usize).bitand(mask)).unwrap(),
.unwrap(),
)) ))
} }
(e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)), (e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)),

View file

@ -2,7 +2,7 @@ use num::traits::Pow;
use num_bigint::BigUint; use num_bigint::BigUint;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt; use std::fmt;
use std::ops::{BitAnd, BitOr, BitXor, Mul, Shr, Sub}; use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr, Sub};
use zokrates_ast::zir::types::UBitwidth; use zokrates_ast::zir::types::UBitwidth;
use zokrates_ast::zir::{ use zokrates_ast::zir::{
result_folder::*, Conditional, ConditionalExpression, ConditionalOrExpression, Expr, Id, result_folder::*, Conditional, ConditionalExpression, ConditionalOrExpression, Expr, Id,
@ -334,8 +334,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize); let mask: BigUint = two.pow(T::get_required_bits()).sub(1usize);
Ok(FieldElementExpression::Number( Ok(FieldElementExpression::Number(
T::try_from(n.to_biguint().mul(two.pow(by as usize)).bitand(mask)) T::try_from(n.to_biguint().shl(by as usize).bitand(mask)).unwrap(),
.unwrap(),
)) ))
} }
(e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)), (e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)),

View file

@ -32,7 +32,6 @@ cfg_if::cfg_if! {
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord, Serialize, Deserialize)]
pub enum FlatEmbed { pub enum FlatEmbed {
FieldToBoolUnsafe, FieldToBoolUnsafe,
BoolToField,
BitArrayLe, BitArrayLe,
Unpack, Unpack,
U8ToBits, U8ToBits,
@ -55,9 +54,6 @@ impl FlatEmbed {
FlatEmbed::FieldToBoolUnsafe => UnresolvedSignature::new() FlatEmbed::FieldToBoolUnsafe => UnresolvedSignature::new()
.inputs(vec![UnresolvedType::FieldElement.into()]) .inputs(vec![UnresolvedType::FieldElement.into()])
.output(UnresolvedType::Boolean.into()), .output(UnresolvedType::Boolean.into()),
FlatEmbed::BoolToField => UnresolvedSignature::new()
.inputs(vec![UnresolvedType::Boolean.into()])
.output(UnresolvedType::FieldElement.into()),
FlatEmbed::BitArrayLe => UnresolvedSignature::new() FlatEmbed::BitArrayLe => UnresolvedSignature::new()
.generics(vec![ConstantGenericNode::mock("N")]) .generics(vec![ConstantGenericNode::mock("N")])
.inputs(vec![ .inputs(vec![
@ -197,9 +193,6 @@ impl FlatEmbed {
FlatEmbed::FieldToBoolUnsafe => DeclarationSignature::new() FlatEmbed::FieldToBoolUnsafe => DeclarationSignature::new()
.inputs(vec![DeclarationType::FieldElement]) .inputs(vec![DeclarationType::FieldElement])
.output(DeclarationType::Boolean), .output(DeclarationType::Boolean),
FlatEmbed::BoolToField => DeclarationSignature::new()
.inputs(vec![DeclarationType::Boolean])
.output(DeclarationType::FieldElement),
FlatEmbed::BitArrayLe => DeclarationSignature::new() FlatEmbed::BitArrayLe => DeclarationSignature::new()
.generics(vec![Some(DeclarationConstant::Generic( .generics(vec![Some(DeclarationConstant::Generic(
GenericIdentifier::with_name("N").with_index(0), GenericIdentifier::with_name("N").with_index(0),
@ -307,7 +300,6 @@ impl FlatEmbed {
pub fn id(&self) -> &'static str { pub fn id(&self) -> &'static str {
match self { match self {
FlatEmbed::FieldToBoolUnsafe => "_FIELD_TO_BOOL_UNSAFE", FlatEmbed::FieldToBoolUnsafe => "_FIELD_TO_BOOL_UNSAFE",
FlatEmbed::BoolToField => "_BOOL_TO_FIELD",
FlatEmbed::BitArrayLe => "_BIT_ARRAY_LT", FlatEmbed::BitArrayLe => "_BIT_ARRAY_LT",
FlatEmbed::Unpack => "_UNPACK", FlatEmbed::Unpack => "_UNPACK",
FlatEmbed::U8ToBits => "_U8_TO_BITS", FlatEmbed::U8ToBits => "_U8_TO_BITS",

View file

@ -3,7 +3,7 @@ use std::fmt;
#[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)] #[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)]
pub enum RuntimeError { pub enum RuntimeError {
UserConstraint, SourceAssemblyConstraint,
BellmanConstraint, BellmanConstraint,
BellmanOneBinding, BellmanOneBinding,
BellmanInputBinding, BellmanInputBinding,
@ -50,7 +50,7 @@ impl RuntimeError {
!matches!( !matches!(
self, self,
UserConstraint SourceAssemblyConstraint
| SourceAssertion(_) | SourceAssertion(_)
| Inverse | Inverse
| SelectRangeCheck | SelectRangeCheck
@ -65,7 +65,7 @@ impl fmt::Display for RuntimeError {
use RuntimeError::*; use RuntimeError::*;
let msg = match self { let msg = match self {
UserConstraint => "User constraint is unsatisfied", SourceAssemblyConstraint => "Source constraint is unsatisfied",
BellmanConstraint => "Bellman constraint is unsatisfied", BellmanConstraint => "Bellman constraint is unsatisfied",
BellmanOneBinding => "Bellman ~one binding is unsatisfied", BellmanOneBinding => "Bellman ~one binding is unsatisfied",
BellmanInputBinding => "Bellman input binding is unsatisfied", BellmanInputBinding => "Bellman input binding is unsatisfied",

View file

@ -20,7 +20,7 @@ pub enum Solver<'ast, T> {
SnarkVerifyBls12377(usize), SnarkVerifyBls12377(usize),
} }
impl<'ast, T: fmt::Debug + fmt::Display> fmt::Display for Solver<'ast, T> { impl<'ast, T: fmt::Debug> fmt::Display for Solver<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self { match self {
Solver::Zir(_) => write!(f, "Zir(..)"), Solver::Zir(_) => write!(f, "Zir(..)"),

View file

@ -1,14 +1,17 @@
use crate::zir::{FieldElementExpression, Identifier}; use crate::zir::{FieldElementExpression, Identifier};
use zokrates_field::Field; use zokrates_field::Field;
pub type LinearTerm<'ast, T> = (T, Identifier<'ast>);
pub type QuadraticTerm<'ast, T> = (T, Identifier<'ast>, Identifier<'ast>);
#[derive(Clone, PartialEq, Hash, Eq, Debug, Default)] #[derive(Clone, PartialEq, Hash, Eq, Debug, Default)]
pub struct LinQuadComb<'ast, T> { pub struct LinQuadComb<'ast, T> {
// the constant terms // the constant terms
pub constant: T, pub constant: T,
// the linear terms // the linear terms
pub linear: Vec<(T, Identifier<'ast>)>, pub linear: Vec<LinearTerm<'ast, T>>,
// the quadratic terms // the quadratic terms
pub quadratic: Vec<(T, Identifier<'ast>, Identifier<'ast>)>, pub quadratic: Vec<QuadraticTerm<'ast, T>>,
} }
impl<'ast, T: Field> std::ops::Add for LinQuadComb<'ast, T> { impl<'ast, T: Field> std::ops::Add for LinQuadComb<'ast, T> {

View file

@ -166,7 +166,8 @@ pub enum ZirStatement<'ast, T> {
FormatString, FormatString,
Vec<(ConcreteType, Vec<ZirExpression<'ast, T>>)>, Vec<(ConcreteType, Vec<ZirExpression<'ast, T>>)>,
), ),
Assembly(#[serde(borrow)] Vec<ZirAssemblyStatement<'ast, T>>), #[serde(borrow)]
Assembly(Vec<ZirAssemblyStatement<'ast, T>>),
} }
impl<'ast, T: fmt::Display> fmt::Display for ZirStatement<'ast, T> { impl<'ast, T: fmt::Display> fmt::Display for ZirStatement<'ast, T> {

View file

@ -1,8 +1,6 @@
from "EMBED" import bool_to_field;
def main(bool x) -> field { def main(bool x) -> field {
// `x` is constrained by the compiler automatically so we can safely // `x` is constrained by the compiler automatically so we can safely
// treat it as `field` with no extra cost // convert to `field` with no extra cost
field out = bool_to_field(x); field out = x ? 1 : 0;
return out; return out;
} }

View file

@ -1,7 +1,11 @@
def main(field a, field b) { def main(field a, field b) -> field {
field mut c = 0; field mut c = 0;
field mut invb = 0;
asm { asm {
c <-- a / b; invb <-- b == 0 ? 0 : 1 / b;
invb * b === 1;
c <-- invb * a;
a === b * c; a === b * c;
} }
return c;
} }

View file

@ -5,7 +5,8 @@ def main(field x) -> bool {
asm { asm {
x * (x - 1) === 0; x * (x - 1) === 0;
} }
// we can treat `x` as `bool` afterwards, as we constrained it properly // we can convert `x` to `bool` afterwards, as we constrained it properly
// if we failed to constrain `x` to `0` or `1`, the call to `field_to_bool_unsafe` introduces undefined behavior
// `field_to_bool_unsafe` call does not produce any extra constraints // `field_to_bool_unsafe` call does not produce any extra constraints
bool out = field_to_bool_unsafe(x); bool out = field_to_bool_unsafe(x);
return out; return out;

View file

@ -259,7 +259,7 @@ mod tests {
let interpreter = zokrates_interpreter::Interpreter::default(); let interpreter = zokrates_interpreter::Interpreter::default();
let res = interpreter.execute(artifacts.prog(), &[Bn128Field::from(0u32)]); let res = interpreter.execute(artifacts.prog(), &[Bn128Field::from(0)]);
assert!(res.is_err()); assert!(res.is_err());
} }

View file

@ -1050,7 +1050,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
match embed { match embed {
FlatEmbed::FieldToBoolUnsafe => vec![params.pop().unwrap()], FlatEmbed::FieldToBoolUnsafe => vec![params.pop().unwrap()],
FlatEmbed::BoolToField => vec![params.pop().unwrap()],
FlatEmbed::U8ToBits => self.u_to_bits(params.pop().unwrap(), 8.into()), FlatEmbed::U8ToBits => self.u_to_bits(params.pop().unwrap(), 8.into()),
FlatEmbed::U16ToBits => self.u_to_bits(params.pop().unwrap(), 16.into()), FlatEmbed::U16ToBits => self.u_to_bits(params.pop().unwrap(), 16.into()),
FlatEmbed::U32ToBits => self.u_to_bits(params.pop().unwrap(), 32.into()), FlatEmbed::U32ToBits => self.u_to_bits(params.pop().unwrap(), 32.into()),
@ -2255,7 +2254,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
statements_flattened.push_back(FlatStatement::Condition( statements_flattened.push_back(FlatStatement::Condition(
lhs, lhs,
rhs, rhs,
RuntimeError::UserConstraint, RuntimeError::SourceAssemblyConstraint,
)); ));
} }
} }

View file

@ -151,10 +151,6 @@ impl Importer {
id: symbol.get_alias(), id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::FieldToBoolUnsafe), symbol: Symbol::Flat(FlatEmbed::FieldToBoolUnsafe),
}, },
"bool_to_field" => SymbolDeclaration {
id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::BoolToField),
},
"bit_array_le" => SymbolDeclaration { "bit_array_le" => SymbolDeclaration {
id: symbol.get_alias(), id: symbol.get_alias(),
symbol: Symbol::Flat(FlatEmbed::BitArrayLe), symbol: Symbol::Flat(FlatEmbed::BitArrayLe),

View file

@ -1804,10 +1804,9 @@ impl<'ast, T: Field> Checker<'ast, T> {
let e = match checked_e { let e = match checked_e {
TypedExpression::FieldElement(e) => Ok(e), TypedExpression::FieldElement(e) => Ok(e),
TypedExpression::Int(e) => Ok(FieldElementExpression::try_from_int(e).unwrap()), TypedExpression::Int(e) => Ok(FieldElementExpression::try_from_int(e).unwrap()),
_ => Err(ErrorInner { e => Err(ErrorInner {
pos: Some(pos), pos: Some(pos),
message: "Only field element expressions are allowed in the assembly block" message: format!("The right hand side of an assembly assignment must be of type field, found {}", e.get_type())
.to_string(),
}), }),
}?; }?;
@ -1819,9 +1818,9 @@ impl<'ast, T: Field> Checker<'ast, T> {
TypedAssemblyStatement::Assignment(assignee.clone(), e.clone()), TypedAssemblyStatement::Assignment(assignee.clone(), e.clone()),
TypedAssemblyStatement::Constraint(assignee.into(), e), TypedAssemblyStatement::Constraint(assignee.into(), e),
]), ]),
_ => Err(ErrorInner { ty => Err(ErrorInner {
pos: Some(pos), pos: Some(pos),
message: "Assignee must be of type `field`".to_string(), message: format!("Assignee must be of type field, found {}", ty),
}), }),
} }
} }
@ -1849,10 +1848,13 @@ impl<'ast, T: Field> Checker<'ast, T> {
FieldElementExpression::try_from_int(lhs).unwrap(), FieldElementExpression::try_from_int(lhs).unwrap(),
FieldElementExpression::try_from_int(rhs).unwrap(), FieldElementExpression::try_from_int(rhs).unwrap(),
)), )),
_ => Err(ErrorInner { (e1, e2) => Err(ErrorInner {
pos: Some(pos), pos: Some(pos),
message: "Only field element expressions are allowed in the assembly block" message: format!(
.to_string(), "Assembly constraint expected expressions of type field, found {}, {}",
e1.get_type(),
e2.get_type()
),
}), }),
}?; }?;
@ -1873,14 +1875,12 @@ impl<'ast, T: Field> Checker<'ast, T> {
Statement::Assembly(statements) => { Statement::Assembly(statements) => {
let mut checked_statements = vec![]; let mut checked_statements = vec![];
for s in statements { for s in statements {
checked_statements.push( checked_statements.extend(
self.check_assembly_statement(s, module_id, types) self.check_assembly_statement(s, module_id, types)
.map_err(|e| vec![e])?, .map_err(|e| vec![e])?,
); );
} }
Ok(TypedStatement::Assembly( Ok(TypedStatement::Assembly(checked_statements))
checked_statements.into_iter().flatten().collect(),
))
} }
Statement::Log(l, expressions) => { Statement::Log(l, expressions) => {
let l = FormatString::from(l); let l = FormatString::from(l);

View file

@ -29,7 +29,7 @@
"output": { "output": {
"Err": { "Err": {
"UnsatisfiedConstraint": { "UnsatisfiedConstraint": {
"error": "UserConstraint" "error": "SourceAssemblyConstraint"
} }
} }
} }

View file

@ -48,7 +48,7 @@
"output": { "output": {
"Err": { "Err": {
"UnsatisfiedConstraint": { "UnsatisfiedConstraint": {
"error": "UserConstraint" "error": "SourceAssemblyConstraint"
} }
} }
} }

View file

@ -1,6 +1,6 @@
{ {
"curves": ["Bn128"], "curves": ["Bn128"],
"max_constraint_count": 2, "max_constraint_count": 3,
"tests": [ "tests": [
{ {
"input": { "input": {
@ -14,11 +14,13 @@
}, },
{ {
"input": { "input": {
"values": ["4", "0"] "values": ["0", "0"]
}, },
"output": { "output": {
"Err": { "Err": {
"Solver": "Assertion failed: `Division by zero`" "UnsatisfiedConstraint": {
"error": "SourceAssemblyConstraint"
}
} }
} }
} }

View file

@ -1,8 +1,11 @@
def main(field x, field y) -> field { def main(field a, field b) -> field {
field mut z = 0; field mut c = 0;
field mut invb = 0;
asm { asm {
z <-- x / y; invb <-- b == 0 ? 0 : 1 / b;
z * y === x; invb * b === 1;
c <-- invb * a;
a === b * c;
} }
return z; return c;
} }

View file

@ -1,26 +0,0 @@
{
"curves": ["Bn128"],
"max_constraint_count": 2,
"tests": [
{
"input": {
"values": [false]
},
"output": {
"Ok": {
"value": "0"
}
}
},
{
"input": {
"values": [true]
},
"output": {
"Ok": {
"value": "1"
}
}
}
]
}

View file

@ -1,6 +0,0 @@
from "EMBED" import bool_to_field;
def main(bool x) -> field {
field out = bool_to_field(x);
return out;
}

View file

@ -29,7 +29,7 @@
"output": { "output": {
"Err": { "Err": {
"UnsatisfiedConstraint": { "UnsatisfiedConstraint": {
"error": "UserConstraint" "error": "SourceAssemblyConstraint"
} }
} }
} }

View file

@ -42,7 +42,7 @@ fn lt_field() {
assert!(interpreter assert!(interpreter
.execute( .execute(
res.prog(), res.prog(),
&[Bn128Field::from(10000u32), Bn128Field::from(5555u32)] &[Bn128Field::from(10000), Bn128Field::from(5555)]
) )
.is_err()); .is_err());
} }
@ -78,7 +78,7 @@ fn lt_uint() {
assert!(interpreter assert!(interpreter
.execute( .execute(
res.prog(), res.prog(),
&[Bn128Field::from(10000u32), Bn128Field::from(5555u32)] &[Bn128Field::from(10000), Bn128Field::from(5555)]
) )
.is_err()); .is_err());
} }
@ -123,7 +123,7 @@ fn unpack256() {
let interpreter = Interpreter::try_out_of_range(); let interpreter = Interpreter::try_out_of_range();
assert!(interpreter assert!(interpreter
.execute(res.prog(), &[Bn128Field::from(0u32)]) .execute(res.prog(), &[Bn128Field::from(0)])
.is_err()); .is_err());
} }
@ -167,6 +167,6 @@ fn unpack256_unchecked() {
let interpreter = Interpreter::try_out_of_range(); let interpreter = Interpreter::try_out_of_range();
assert!(interpreter assert!(interpreter
.execute(res.prog(), &[Bn128Field::from(0u32)]) .execute(res.prog(), &[Bn128Field::from(0)])
.is_ok()); .is_ok());
} }