From bccb08c836f8f61a083ae5101d8b88a8ab9525d1 Mon Sep 17 00:00:00 2001 From: dark64 Date: Wed, 9 Nov 2022 18:57:30 +0100 Subject: [PATCH] wip --- zokrates_analysis/src/assembly_transformer.rs | 117 ++++++++++++++++ zokrates_analysis/src/lib.rs | 14 +- zokrates_analysis/src/zir_validator.rs | 54 ------- zokrates_ast/src/typed/mod.rs | 48 +++---- zokrates_ast/src/zir/lqc.rs | 132 ++++++++++++++++++ zokrates_ast/src/zir/mod.rs | 1 + zokrates_core/src/semantics.rs | 41 +++--- .../tests/tests/assembly/division.json | 1 + 8 files changed, 299 insertions(+), 109 deletions(-) create mode 100644 zokrates_analysis/src/assembly_transformer.rs delete mode 100644 zokrates_analysis/src/zir_validator.rs create mode 100644 zokrates_ast/src/zir/lqc.rs diff --git a/zokrates_analysis/src/assembly_transformer.rs b/zokrates_analysis/src/assembly_transformer.rs new file mode 100644 index 00000000..aa2f2595 --- /dev/null +++ b/zokrates_analysis/src/assembly_transformer.rs @@ -0,0 +1,117 @@ +use std::fmt; +use zokrates_ast::zir::lqc::LinQuadComb; +use zokrates_ast::zir::result_folder::{fold_field_expression, ResultFolder}; +use zokrates_ast::zir::{FieldElementExpression, ZirAssemblyStatement, ZirProgram}; +use zokrates_field::Field; + +#[derive(Debug)] +pub struct Error(String); + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +pub struct AssemblyTransformer; + +impl AssemblyTransformer { + pub fn transform(p: ZirProgram) -> Result, Error> { + let mut f = AssemblyTransformer; + f.fold_program(p) + } +} + +impl<'ast, T: Field> ResultFolder<'ast, T> for AssemblyTransformer { + type Error = Error; + + fn fold_assembly_statement( + &mut self, + s: ZirAssemblyStatement<'ast, T>, + ) -> Result, Self::Error> { + match s { + ZirAssemblyStatement::Assignment(_, _) => Ok(s), + ZirAssemblyStatement::Constraint(lhs, rhs) => { + let lhs = self.fold_field_expression(lhs)?; + let rhs = self.fold_field_expression(rhs)?; + let sub = FieldElementExpression::Sub(box lhs, box rhs); + + // let sub = match (lhs, rhs) { + // (FieldElementExpression::Number(n), e) + // | (e, FieldElementExpression::Number(n)) => { + // FieldElementExpression::Sub(box FieldElementExpression::Number(n), box e) + // } + // (lhs, rhs) => FieldElementExpression::Sub(box lhs, box rhs), + // }; + + let mut lqc = LinQuadComb::try_from(sub.clone()).map_err(|_| { + Error("Found forbidden operation in user-defined constraint".to_string()) + })?; + + println!("{:#?}", lqc); + + if lqc.quadratic.len() > 1 { + return Err(Error( + "Non-quadratic constraints are not allowed".to_string(), + )); + } + + let linear = lqc + .linear + .into_iter() + .filter_map(|(c, i)| match c { + c if c == T::from(0) => None, + c if c == T::from(1) => Some(FieldElementExpression::Identifier(i)), + _ => Some(FieldElementExpression::Mult( + box FieldElementExpression::Number(c), + box FieldElementExpression::Identifier(i), + )), + }) + .reduce(|p, n| FieldElementExpression::Add(box p, box n)) + .unwrap(); + + let lhs = match lqc.constant { + c if c == T::from(0) => linear, + c => FieldElementExpression::Add( + box FieldElementExpression::Number(c), + box linear, + ), + }; + + let rhs: FieldElementExpression<'ast, T> = lqc + .quadratic + .pop() + .map(|(c, i0, i1)| { + FieldElementExpression::Mult( + box FieldElementExpression::Mult( + box FieldElementExpression::Number(T::zero() - c), + box FieldElementExpression::Identifier(i0), + ), + box FieldElementExpression::Identifier(i1), + ) + }) + .unwrap_or_else(|| FieldElementExpression::Number(T::from(0))); + + println!("{} == {}", lhs, rhs); + + Ok(ZirAssemblyStatement::Constraint(lhs, rhs)) + } + } + } + + fn fold_field_expression( + &mut self, + e: FieldElementExpression<'ast, T>, + ) -> Result, Self::Error> { + match e { + FieldElementExpression::And(_, _) + | FieldElementExpression::Or(_, _) + | FieldElementExpression::Xor(_, _) + | FieldElementExpression::LeftShift(_, _) + | FieldElementExpression::RightShift(_, _) => Err(Error( + format!("Found bitwise operation in expression `{}` of type `field` (only allowed in assembly assignment statement)", e) + )), + e => fold_field_expression(self, e), + } + } +} diff --git a/zokrates_analysis/src/lib.rs b/zokrates_analysis/src/lib.rs index 79efe9f8..552b5db8 100644 --- a/zokrates_analysis/src/lib.rs +++ b/zokrates_analysis/src/lib.rs @@ -6,6 +6,7 @@ //! @author Thibaut Schaeffer //! @date 2018 +mod assembly_transformer; mod branch_isolator; mod condition_redefiner; mod constant_argument_checker; @@ -22,7 +23,6 @@ mod struct_concretizer; mod uint_optimizer; mod variable_write_remover; mod zir_propagation; -mod zir_validator; use self::branch_isolator::Isolator; use self::condition_redefiner::ConditionRedefiner; @@ -35,11 +35,11 @@ use self::reducer::reduce_program; use self::struct_concretizer::StructConcretizer; use self::uint_optimizer::UintOptimizer; use self::variable_write_remover::VariableWriteRemover; +use crate::assembly_transformer::AssemblyTransformer; use crate::constant_resolver::ConstantResolver; use crate::dead_code::DeadCodeEliminator; use crate::panic_extractor::PanicExtractor; pub use crate::zir_propagation::ZirPropagator; -use crate::zir_validator::ZirValidator; use std::fmt; use zokrates_ast::typed::{abi::Abi, TypedProgram}; use zokrates_ast::zir::ZirProgram; @@ -53,7 +53,7 @@ pub enum Error { ZirPropagation(self::zir_propagation::Error), NonConstantArgument(self::constant_argument_checker::Error), OutOfBounds(self::out_of_bounds::Error), - Assembly(self::zir_validator::Error), + Assembly(self::assembly_transformer::Error), } impl From for Error { @@ -86,8 +86,8 @@ impl From for Error { } } -impl From for Error { - fn from(e: zir_validator::Error) -> Self { +impl From for Error { + fn from(e: assembly_transformer::Error) -> Self { Error::Assembly(e) } } @@ -202,8 +202,8 @@ pub fn analyse<'ast, T: Field>( log::trace!("\n{}", zir); // validate zir - log::debug!("Static analyser: Validate zir"); - let zir = ZirValidator::validate(zir).map_err(Error::from)?; + log::debug!("Static analyser: Apply constraint transformations in assembly"); + let zir = AssemblyTransformer::transform(zir).map_err(Error::from)?; Ok((zir, abi)) } diff --git a/zokrates_analysis/src/zir_validator.rs b/zokrates_analysis/src/zir_validator.rs deleted file mode 100644 index ea9e6150..00000000 --- a/zokrates_analysis/src/zir_validator.rs +++ /dev/null @@ -1,54 +0,0 @@ -use std::fmt; -use zokrates_ast::zir::result_folder::{ - fold_assembly_statement, fold_field_expression, ResultFolder, -}; -use zokrates_ast::zir::{FieldElementExpression, ZirAssemblyStatement, ZirProgram}; -use zokrates_field::Field; - -#[derive(Debug)] -pub struct Error(String); - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -pub struct ZirValidator; - -impl ZirValidator { - pub fn validate(p: ZirProgram) -> Result, Error> { - let mut checker = ZirValidator; - checker.fold_program(p) - } -} - -impl<'ast, T: Field> ResultFolder<'ast, T> for ZirValidator { - type Error = Error; - - fn fold_assembly_statement( - &mut self, - s: ZirAssemblyStatement<'ast, T>, - ) -> Result, Self::Error> { - match s { - ZirAssemblyStatement::Assignment(_, _) => Ok(s), - s => fold_assembly_statement(self, s), - } - } - - fn fold_field_expression( - &mut self, - e: FieldElementExpression<'ast, T>, - ) -> Result, Self::Error> { - match e { - FieldElementExpression::And(_, _) - | FieldElementExpression::Or(_, _) - | FieldElementExpression::Xor(_, _) - | FieldElementExpression::LeftShift(_, _) - | FieldElementExpression::RightShift(_, _) => Err(Error( - format!("Found bitwise operation in expression `{}` of type `field` (only allowed in assembly assignment statement)", e) - )), - e => fold_field_expression(self, e), - } - } -} diff --git a/zokrates_ast/src/typed/mod.rs b/zokrates_ast/src/typed/mod.rs index 47eb190f..fbde3813 100644 --- a/zokrates_ast/src/typed/mod.rs +++ b/zokrates_ast/src/typed/mod.rs @@ -1334,44 +1334,28 @@ impl<'ast, T> FieldElementExpression<'ast, T> { pub fn pow(self, other: UExpression<'ast, T>) -> Self { FieldElementExpression::Pow(box self, box other) } - pub fn is_quadratic(&self) -> bool { - match self { - FieldElementExpression::Mult(box left, box right) => { - left.is_linear() && right.is_linear() - } - _ => false, - } - } - fn is_linear(&self) -> bool { + // This is used for early detection in semantics but it is not completely accurate + // Deeper analysis is done in a separate step after semantic checks + pub fn is_non_quadratic(&self) -> bool { match self { - FieldElementExpression::Block(_) => false, - FieldElementExpression::Number(_) => true, - FieldElementExpression::Identifier(_) => true, + FieldElementExpression::Number(_) => false, + FieldElementExpression::Identifier(_) => false, FieldElementExpression::Add(box left, box right) => { - left.is_linear() && right.is_linear() + left.is_non_quadratic() || right.is_non_quadratic() } FieldElementExpression::Sub(box left, box right) => { - left.is_linear() && right.is_linear() + left.is_non_quadratic() || right.is_non_quadratic() } - FieldElementExpression::Mult(box left, box right) => matches!( - (left, right), - (FieldElementExpression::Number(_), _) | (_, FieldElementExpression::Number(_)) - ), - FieldElementExpression::Div(_, _) => false, - FieldElementExpression::Pow(_, _) => false, - FieldElementExpression::And(_, _) => false, - FieldElementExpression::Or(_, _) => false, - FieldElementExpression::Xor(_, _) => false, - FieldElementExpression::LeftShift(_, _) => false, - FieldElementExpression::RightShift(_, _) => false, - FieldElementExpression::Conditional(_) => false, - FieldElementExpression::Neg(_) => true, - FieldElementExpression::Pos(_) => true, - FieldElementExpression::FunctionCall(_) => false, - FieldElementExpression::Member(_) => true, - FieldElementExpression::Select(_) => true, - FieldElementExpression::Element(_) => true, + FieldElementExpression::Mult(box left, box right) => { + left.is_non_quadratic() || right.is_non_quadratic() + } + FieldElementExpression::Neg(_) => false, + FieldElementExpression::Pos(_) => false, + FieldElementExpression::Member(_) => false, + FieldElementExpression::Select(_) => false, + FieldElementExpression::Element(_) => false, + _ => true, } } } diff --git a/zokrates_ast/src/zir/lqc.rs b/zokrates_ast/src/zir/lqc.rs new file mode 100644 index 00000000..54f16548 --- /dev/null +++ b/zokrates_ast/src/zir/lqc.rs @@ -0,0 +1,132 @@ +use crate::zir::{FieldElementExpression, Identifier}; +use zokrates_field::Field; + +#[derive(Clone, PartialEq, Hash, Eq, Debug, Default)] +pub struct LinQuadComb<'ast, T> { + // the constant terms + pub constant: T, + // the linear terms + pub linear: Vec<(T, Identifier<'ast>)>, + // the quadratic terms + pub quadratic: Vec<(T, Identifier<'ast>, Identifier<'ast>)>, +} + +impl<'ast, T: Field> std::ops::Add for LinQuadComb<'ast, T> { + type Output = Self; + + fn add(self, mut other: Self) -> Self::Output { + Self { + constant: self.constant + other.constant, + linear: { + let mut l = self.linear; + l.append(&mut other.linear); + l + }, + quadratic: { + let mut q = self.quadratic; + q.append(&mut other.quadratic); + q + }, + } + } +} + +impl<'ast, T: Field> std::ops::Sub for LinQuadComb<'ast, T> { + type Output = Self; + + fn sub(self, mut other: Self) -> Self::Output { + Self { + constant: self.constant - other.constant, + linear: { + let mut l = self.linear; + other.linear.iter_mut().for_each(|(c, _)| { + *c = T::zero() - &*c; + }); + l.append(&mut other.linear); + l + }, + quadratic: { + let mut q = self.quadratic; + other.quadratic.iter_mut().for_each(|(c, _, _)| { + *c = T::zero() - &*c; + }); + q.append(&mut other.quadratic); + q + }, + } + } +} + +impl<'ast, T: Field> LinQuadComb<'ast, T> { + fn try_mul(self, rhs: Self) -> Result { + // fail if the result has degree higher than 2 + if !(self.quadratic.is_empty() || rhs.quadratic.is_empty()) { + return Err(()); + } + + Ok(Self { + constant: self.constant.clone() * rhs.constant.clone(), + linear: { + // lin0 * const1 + lin1 * const0 + self.linear + .clone() + .into_iter() + .map(|(c, i)| (c * rhs.constant.clone(), i)) + .chain( + rhs.linear + .clone() + .into_iter() + .map(|(c, i)| (c * self.constant.clone(), i)), + ) + .collect() + }, + quadratic: { + // quad0 * const1 + quad1 * const0 + lin0 * lin1 + self.quadratic + .into_iter() + .map(|(c, i0, i1)| (c * rhs.constant.clone(), i0, i1)) + .chain( + rhs.quadratic + .into_iter() + .map(|(c, i0, i1)| (c * self.constant.clone(), i0, i1)), + ) + .chain(self.linear.iter().flat_map(|(cl, l)| { + rhs.linear + .iter() + .map(|(cr, r)| (cl.clone() * cr.clone(), l.clone(), r.clone())) + })) + .collect() + }, + }) + } +} + +impl<'ast, T: Field> TryFrom> for LinQuadComb<'ast, T> { + type Error = (); + + fn try_from(e: FieldElementExpression<'ast, T>) -> Result { + match e { + FieldElementExpression::Number(v) => Ok(Self { + constant: v, + ..Self::default() + }), + FieldElementExpression::Identifier(id) => Ok(Self { + linear: vec![(T::one(), id)], + ..Self::default() + }), + FieldElementExpression::Add(box left, box right) => { + Ok(Self::try_from(left)? + Self::try_from(right)?) + } + FieldElementExpression::Sub(box left, box right) => { + Ok(Self::try_from(left)? - Self::try_from(right)?) + } + FieldElementExpression::Mult(box left, box right) => { + let left = Self::try_from(left)?; + let right = Self::try_from(right)?; + + left.try_mul(right) + } + _ => Err(()), + } + } +} diff --git a/zokrates_ast/src/zir/mod.rs b/zokrates_ast/src/zir/mod.rs index 6f3bb81f..28408346 100644 --- a/zokrates_ast/src/zir/mod.rs +++ b/zokrates_ast/src/zir/mod.rs @@ -1,6 +1,7 @@ pub mod folder; mod from_typed; mod identifier; +pub mod lqc; mod parameter; pub mod result_folder; pub mod types; diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index d70c71dd..d66ffb7b 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -1801,6 +1801,13 @@ impl<'ast, T: Field> Checker<'ast, T> { match constrained { true => { + // early non-quadratic detection + if e.is_non_quadratic() { + return Err(ErrorInner { + pos: Some(pos), + message: "Non-quadratic constraints are not allowed".to_string(), + }); + } let e = FieldElementExpression::block(vec![], e); match assignee.get_type() { Type::FieldElement => Ok(vec![ @@ -1822,34 +1829,36 @@ impl<'ast, T: Field> Checker<'ast, T> { AssemblyStatement::Constraint(lhs, rhs) => { let lhs = self.check_expression(lhs, module_id, types)?; let rhs = self.check_expression(rhs, module_id, types)?; - match (lhs, rhs) { + + let (lhs, rhs) = match (lhs, rhs) { (TypedExpression::FieldElement(lhs), TypedExpression::FieldElement(rhs)) => { - Ok(vec![TypedAssemblyStatement::Constraint(lhs, rhs)]) + Ok((lhs, rhs)) } (TypedExpression::FieldElement(lhs), TypedExpression::Int(rhs)) => { - Ok(vec![TypedAssemblyStatement::Constraint( - lhs, - FieldElementExpression::try_from_int(rhs).unwrap(), - )]) + Ok((lhs, FieldElementExpression::try_from_int(rhs).unwrap())) } (TypedExpression::Int(lhs), TypedExpression::FieldElement(rhs)) => { - Ok(vec![TypedAssemblyStatement::Constraint( - FieldElementExpression::try_from_int(lhs).unwrap(), - rhs, - )]) - } - (TypedExpression::Int(lhs), TypedExpression::Int(rhs)) => { - Ok(vec![TypedAssemblyStatement::Constraint( - FieldElementExpression::try_from_int(lhs).unwrap(), - FieldElementExpression::try_from_int(rhs).unwrap(), - )]) + Ok((FieldElementExpression::try_from_int(lhs).unwrap(), rhs)) } + (TypedExpression::Int(lhs), TypedExpression::Int(rhs)) => Ok(( + FieldElementExpression::try_from_int(lhs).unwrap(), + FieldElementExpression::try_from_int(rhs).unwrap(), + )), _ => Err(ErrorInner { pos: Some(pos), message: "Only field element expressions are allowed in the assembly block" .to_string(), }), + }?; + + if lhs.is_non_quadratic() || rhs.is_non_quadratic() { + return Err(ErrorInner { + pos: Some(pos), + message: "Non-quadratic constraints are not allowed".to_string(), + }); } + + Ok(vec![TypedAssemblyStatement::Constraint(lhs, rhs)]) } } } diff --git a/zokrates_core_test/tests/tests/assembly/division.json b/zokrates_core_test/tests/tests/assembly/division.json index f0da6cc0..7478326f 100644 --- a/zokrates_core_test/tests/tests/assembly/division.json +++ b/zokrates_core_test/tests/tests/assembly/division.json @@ -1,5 +1,6 @@ { "curves": ["Bn128"], + "max_constraint_count": 2, "tests": [ { "input": {