1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00
This commit is contained in:
dark64 2022-11-09 18:57:30 +01:00
parent 33c6bfce9b
commit bccb08c836
8 changed files with 299 additions and 109 deletions

View file

@ -0,0 +1,117 @@
use std::fmt;
use zokrates_ast::zir::lqc::LinQuadComb;
use zokrates_ast::zir::result_folder::{fold_field_expression, ResultFolder};
use zokrates_ast::zir::{FieldElementExpression, ZirAssemblyStatement, ZirProgram};
use zokrates_field::Field;
#[derive(Debug)]
pub struct Error(String);
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.0)
}
}
pub struct AssemblyTransformer;
impl AssemblyTransformer {
pub fn transform<T: Field>(p: ZirProgram<T>) -> Result<ZirProgram<T>, Error> {
let mut f = AssemblyTransformer;
f.fold_program(p)
}
}
impl<'ast, T: Field> ResultFolder<'ast, T> for AssemblyTransformer {
type Error = Error;
fn fold_assembly_statement(
&mut self,
s: ZirAssemblyStatement<'ast, T>,
) -> Result<ZirAssemblyStatement<'ast, T>, Self::Error> {
match s {
ZirAssemblyStatement::Assignment(_, _) => Ok(s),
ZirAssemblyStatement::Constraint(lhs, rhs) => {
let lhs = self.fold_field_expression(lhs)?;
let rhs = self.fold_field_expression(rhs)?;
let sub = FieldElementExpression::Sub(box lhs, box rhs);
// let sub = match (lhs, rhs) {
// (FieldElementExpression::Number(n), e)
// | (e, FieldElementExpression::Number(n)) => {
// FieldElementExpression::Sub(box FieldElementExpression::Number(n), box e)
// }
// (lhs, rhs) => FieldElementExpression::Sub(box lhs, box rhs),
// };
let mut lqc = LinQuadComb::try_from(sub.clone()).map_err(|_| {
Error("Found forbidden operation in user-defined constraint".to_string())
})?;
println!("{:#?}", lqc);
if lqc.quadratic.len() > 1 {
return Err(Error(
"Non-quadratic constraints are not allowed".to_string(),
));
}
let linear = lqc
.linear
.into_iter()
.filter_map(|(c, i)| match c {
c if c == T::from(0) => None,
c if c == T::from(1) => Some(FieldElementExpression::Identifier(i)),
_ => Some(FieldElementExpression::Mult(
box FieldElementExpression::Number(c),
box FieldElementExpression::Identifier(i),
)),
})
.reduce(|p, n| FieldElementExpression::Add(box p, box n))
.unwrap();
let lhs = match lqc.constant {
c if c == T::from(0) => linear,
c => FieldElementExpression::Add(
box FieldElementExpression::Number(c),
box linear,
),
};
let rhs: FieldElementExpression<'ast, T> = lqc
.quadratic
.pop()
.map(|(c, i0, i1)| {
FieldElementExpression::Mult(
box FieldElementExpression::Mult(
box FieldElementExpression::Number(T::zero() - c),
box FieldElementExpression::Identifier(i0),
),
box FieldElementExpression::Identifier(i1),
)
})
.unwrap_or_else(|| FieldElementExpression::Number(T::from(0)));
println!("{} == {}", lhs, rhs);
Ok(ZirAssemblyStatement::Constraint(lhs, rhs))
}
}
}
fn fold_field_expression(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
match e {
FieldElementExpression::And(_, _)
| FieldElementExpression::Or(_, _)
| FieldElementExpression::Xor(_, _)
| FieldElementExpression::LeftShift(_, _)
| FieldElementExpression::RightShift(_, _) => Err(Error(
format!("Found bitwise operation in expression `{}` of type `field` (only allowed in assembly assignment statement)", e)
)),
e => fold_field_expression(self, e),
}
}
}

View file

@ -6,6 +6,7 @@
//! @author Thibaut Schaeffer <thibaut@schaeff.fr>
//! @date 2018
mod assembly_transformer;
mod branch_isolator;
mod condition_redefiner;
mod constant_argument_checker;
@ -22,7 +23,6 @@ mod struct_concretizer;
mod uint_optimizer;
mod variable_write_remover;
mod zir_propagation;
mod zir_validator;
use self::branch_isolator::Isolator;
use self::condition_redefiner::ConditionRedefiner;
@ -35,11 +35,11 @@ use self::reducer::reduce_program;
use self::struct_concretizer::StructConcretizer;
use self::uint_optimizer::UintOptimizer;
use self::variable_write_remover::VariableWriteRemover;
use crate::assembly_transformer::AssemblyTransformer;
use crate::constant_resolver::ConstantResolver;
use crate::dead_code::DeadCodeEliminator;
use crate::panic_extractor::PanicExtractor;
pub use crate::zir_propagation::ZirPropagator;
use crate::zir_validator::ZirValidator;
use std::fmt;
use zokrates_ast::typed::{abi::Abi, TypedProgram};
use zokrates_ast::zir::ZirProgram;
@ -53,7 +53,7 @@ pub enum Error {
ZirPropagation(self::zir_propagation::Error),
NonConstantArgument(self::constant_argument_checker::Error),
OutOfBounds(self::out_of_bounds::Error),
Assembly(self::zir_validator::Error),
Assembly(self::assembly_transformer::Error),
}
impl From<reducer::Error> for Error {
@ -86,8 +86,8 @@ impl From<constant_argument_checker::Error> for Error {
}
}
impl From<zir_validator::Error> for Error {
fn from(e: zir_validator::Error) -> Self {
impl From<assembly_transformer::Error> for Error {
fn from(e: assembly_transformer::Error) -> Self {
Error::Assembly(e)
}
}
@ -202,8 +202,8 @@ pub fn analyse<'ast, T: Field>(
log::trace!("\n{}", zir);
// validate zir
log::debug!("Static analyser: Validate zir");
let zir = ZirValidator::validate(zir).map_err(Error::from)?;
log::debug!("Static analyser: Apply constraint transformations in assembly");
let zir = AssemblyTransformer::transform(zir).map_err(Error::from)?;
Ok((zir, abi))
}

View file

@ -1,54 +0,0 @@
use std::fmt;
use zokrates_ast::zir::result_folder::{
fold_assembly_statement, fold_field_expression, ResultFolder,
};
use zokrates_ast::zir::{FieldElementExpression, ZirAssemblyStatement, ZirProgram};
use zokrates_field::Field;
#[derive(Debug)]
pub struct Error(String);
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.0)
}
}
pub struct ZirValidator;
impl ZirValidator {
pub fn validate<T: Field>(p: ZirProgram<T>) -> Result<ZirProgram<T>, Error> {
let mut checker = ZirValidator;
checker.fold_program(p)
}
}
impl<'ast, T: Field> ResultFolder<'ast, T> for ZirValidator {
type Error = Error;
fn fold_assembly_statement(
&mut self,
s: ZirAssemblyStatement<'ast, T>,
) -> Result<ZirAssemblyStatement<'ast, T>, Self::Error> {
match s {
ZirAssemblyStatement::Assignment(_, _) => Ok(s),
s => fold_assembly_statement(self, s),
}
}
fn fold_field_expression(
&mut self,
e: FieldElementExpression<'ast, T>,
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
match e {
FieldElementExpression::And(_, _)
| FieldElementExpression::Or(_, _)
| FieldElementExpression::Xor(_, _)
| FieldElementExpression::LeftShift(_, _)
| FieldElementExpression::RightShift(_, _) => Err(Error(
format!("Found bitwise operation in expression `{}` of type `field` (only allowed in assembly assignment statement)", e)
)),
e => fold_field_expression(self, e),
}
}
}

View file

@ -1334,44 +1334,28 @@ impl<'ast, T> FieldElementExpression<'ast, T> {
pub fn pow(self, other: UExpression<'ast, T>) -> Self {
FieldElementExpression::Pow(box self, box other)
}
pub fn is_quadratic(&self) -> bool {
match self {
FieldElementExpression::Mult(box left, box right) => {
left.is_linear() && right.is_linear()
}
_ => false,
}
}
fn is_linear(&self) -> bool {
// This is used for early detection in semantics but it is not completely accurate
// Deeper analysis is done in a separate step after semantic checks
pub fn is_non_quadratic(&self) -> bool {
match self {
FieldElementExpression::Block(_) => false,
FieldElementExpression::Number(_) => true,
FieldElementExpression::Identifier(_) => true,
FieldElementExpression::Number(_) => false,
FieldElementExpression::Identifier(_) => false,
FieldElementExpression::Add(box left, box right) => {
left.is_linear() && right.is_linear()
left.is_non_quadratic() || right.is_non_quadratic()
}
FieldElementExpression::Sub(box left, box right) => {
left.is_linear() && right.is_linear()
left.is_non_quadratic() || right.is_non_quadratic()
}
FieldElementExpression::Mult(box left, box right) => matches!(
(left, right),
(FieldElementExpression::Number(_), _) | (_, FieldElementExpression::Number(_))
),
FieldElementExpression::Div(_, _) => false,
FieldElementExpression::Pow(_, _) => false,
FieldElementExpression::And(_, _) => false,
FieldElementExpression::Or(_, _) => false,
FieldElementExpression::Xor(_, _) => false,
FieldElementExpression::LeftShift(_, _) => false,
FieldElementExpression::RightShift(_, _) => false,
FieldElementExpression::Conditional(_) => false,
FieldElementExpression::Neg(_) => true,
FieldElementExpression::Pos(_) => true,
FieldElementExpression::FunctionCall(_) => false,
FieldElementExpression::Member(_) => true,
FieldElementExpression::Select(_) => true,
FieldElementExpression::Element(_) => true,
FieldElementExpression::Mult(box left, box right) => {
left.is_non_quadratic() || right.is_non_quadratic()
}
FieldElementExpression::Neg(_) => false,
FieldElementExpression::Pos(_) => false,
FieldElementExpression::Member(_) => false,
FieldElementExpression::Select(_) => false,
FieldElementExpression::Element(_) => false,
_ => true,
}
}
}

132
zokrates_ast/src/zir/lqc.rs Normal file
View file

@ -0,0 +1,132 @@
use crate::zir::{FieldElementExpression, Identifier};
use zokrates_field::Field;
#[derive(Clone, PartialEq, Hash, Eq, Debug, Default)]
pub struct LinQuadComb<'ast, T> {
// the constant terms
pub constant: T,
// the linear terms
pub linear: Vec<(T, Identifier<'ast>)>,
// the quadratic terms
pub quadratic: Vec<(T, Identifier<'ast>, Identifier<'ast>)>,
}
impl<'ast, T: Field> std::ops::Add for LinQuadComb<'ast, T> {
type Output = Self;
fn add(self, mut other: Self) -> Self::Output {
Self {
constant: self.constant + other.constant,
linear: {
let mut l = self.linear;
l.append(&mut other.linear);
l
},
quadratic: {
let mut q = self.quadratic;
q.append(&mut other.quadratic);
q
},
}
}
}
impl<'ast, T: Field> std::ops::Sub for LinQuadComb<'ast, T> {
type Output = Self;
fn sub(self, mut other: Self) -> Self::Output {
Self {
constant: self.constant - other.constant,
linear: {
let mut l = self.linear;
other.linear.iter_mut().for_each(|(c, _)| {
*c = T::zero() - &*c;
});
l.append(&mut other.linear);
l
},
quadratic: {
let mut q = self.quadratic;
other.quadratic.iter_mut().for_each(|(c, _, _)| {
*c = T::zero() - &*c;
});
q.append(&mut other.quadratic);
q
},
}
}
}
impl<'ast, T: Field> LinQuadComb<'ast, T> {
fn try_mul(self, rhs: Self) -> Result<Self, ()> {
// fail if the result has degree higher than 2
if !(self.quadratic.is_empty() || rhs.quadratic.is_empty()) {
return Err(());
}
Ok(Self {
constant: self.constant.clone() * rhs.constant.clone(),
linear: {
// lin0 * const1 + lin1 * const0
self.linear
.clone()
.into_iter()
.map(|(c, i)| (c * rhs.constant.clone(), i))
.chain(
rhs.linear
.clone()
.into_iter()
.map(|(c, i)| (c * self.constant.clone(), i)),
)
.collect()
},
quadratic: {
// quad0 * const1 + quad1 * const0 + lin0 * lin1
self.quadratic
.into_iter()
.map(|(c, i0, i1)| (c * rhs.constant.clone(), i0, i1))
.chain(
rhs.quadratic
.into_iter()
.map(|(c, i0, i1)| (c * self.constant.clone(), i0, i1)),
)
.chain(self.linear.iter().flat_map(|(cl, l)| {
rhs.linear
.iter()
.map(|(cr, r)| (cl.clone() * cr.clone(), l.clone(), r.clone()))
}))
.collect()
},
})
}
}
impl<'ast, T: Field> TryFrom<FieldElementExpression<'ast, T>> for LinQuadComb<'ast, T> {
type Error = ();
fn try_from(e: FieldElementExpression<'ast, T>) -> Result<Self, Self::Error> {
match e {
FieldElementExpression::Number(v) => Ok(Self {
constant: v,
..Self::default()
}),
FieldElementExpression::Identifier(id) => Ok(Self {
linear: vec![(T::one(), id)],
..Self::default()
}),
FieldElementExpression::Add(box left, box right) => {
Ok(Self::try_from(left)? + Self::try_from(right)?)
}
FieldElementExpression::Sub(box left, box right) => {
Ok(Self::try_from(left)? - Self::try_from(right)?)
}
FieldElementExpression::Mult(box left, box right) => {
let left = Self::try_from(left)?;
let right = Self::try_from(right)?;
left.try_mul(right)
}
_ => Err(()),
}
}
}

View file

@ -1,6 +1,7 @@
pub mod folder;
mod from_typed;
mod identifier;
pub mod lqc;
mod parameter;
pub mod result_folder;
pub mod types;

View file

@ -1801,6 +1801,13 @@ impl<'ast, T: Field> Checker<'ast, T> {
match constrained {
true => {
// early non-quadratic detection
if e.is_non_quadratic() {
return Err(ErrorInner {
pos: Some(pos),
message: "Non-quadratic constraints are not allowed".to_string(),
});
}
let e = FieldElementExpression::block(vec![], e);
match assignee.get_type() {
Type::FieldElement => Ok(vec![
@ -1822,34 +1829,36 @@ impl<'ast, T: Field> Checker<'ast, T> {
AssemblyStatement::Constraint(lhs, rhs) => {
let lhs = self.check_expression(lhs, module_id, types)?;
let rhs = self.check_expression(rhs, module_id, types)?;
match (lhs, rhs) {
let (lhs, rhs) = match (lhs, rhs) {
(TypedExpression::FieldElement(lhs), TypedExpression::FieldElement(rhs)) => {
Ok(vec![TypedAssemblyStatement::Constraint(lhs, rhs)])
Ok((lhs, rhs))
}
(TypedExpression::FieldElement(lhs), TypedExpression::Int(rhs)) => {
Ok(vec![TypedAssemblyStatement::Constraint(
lhs,
FieldElementExpression::try_from_int(rhs).unwrap(),
)])
Ok((lhs, FieldElementExpression::try_from_int(rhs).unwrap()))
}
(TypedExpression::Int(lhs), TypedExpression::FieldElement(rhs)) => {
Ok(vec![TypedAssemblyStatement::Constraint(
FieldElementExpression::try_from_int(lhs).unwrap(),
rhs,
)])
}
(TypedExpression::Int(lhs), TypedExpression::Int(rhs)) => {
Ok(vec![TypedAssemblyStatement::Constraint(
FieldElementExpression::try_from_int(lhs).unwrap(),
FieldElementExpression::try_from_int(rhs).unwrap(),
)])
Ok((FieldElementExpression::try_from_int(lhs).unwrap(), rhs))
}
(TypedExpression::Int(lhs), TypedExpression::Int(rhs)) => Ok((
FieldElementExpression::try_from_int(lhs).unwrap(),
FieldElementExpression::try_from_int(rhs).unwrap(),
)),
_ => Err(ErrorInner {
pos: Some(pos),
message: "Only field element expressions are allowed in the assembly block"
.to_string(),
}),
}?;
if lhs.is_non_quadratic() || rhs.is_non_quadratic() {
return Err(ErrorInner {
pos: Some(pos),
message: "Non-quadratic constraints are not allowed".to_string(),
});
}
Ok(vec![TypedAssemblyStatement::Constraint(lhs, rhs)])
}
}
}

View file

@ -1,5 +1,6 @@
{
"curves": ["Bn128"],
"max_constraint_count": 2,
"tests": [
{
"input": {