1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

allow variable shifts in witness assignment

This commit is contained in:
dark64 2022-11-17 19:08:56 +01:00
parent c61a481e0b
commit 063a815308
17 changed files with 281 additions and 86 deletions

View file

@ -54,9 +54,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for AssemblyTransformer {
false => {
let sub = FieldElementExpression::Sub(box lhs.clone(), box rhs.clone());
let mut lqc = LinQuadComb::try_from(sub.clone()).map_err(|_| {
Error(
"Found forbidden operation in user-defined constraint".to_string(),
)
Error("Non-quadratic constraints are not allowed".to_string())
})?;
let linear = lqc
@ -127,10 +125,9 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for AssemblyTransformer {
}),
box FieldElementExpression::identifier(id),
)),
_ => Err(Error(format!(
"Non-quadratic constraints are not allowed `{} == {}`",
lhs, rhs
))),
_ => Err(Error(
"Non-quadratic constraints are not allowed".to_string(),
)),
}?
} else {
lqc.quadratic

View file

@ -90,36 +90,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantArgumentChecker {
))),
}
}
FieldElementExpression::LeftShift(box e, box by) => {
let e = self.fold_field_expression(e)?;
let by = self.fold_uint_expression(by)?;
match by.as_inner() {
UExpressionInner::Value(_) => {
Ok(FieldElementExpression::LeftShift(box e, box by))
}
by => Err(Error(format!(
"Cannot shift by a variable value, found `{} << {}`",
e,
by.clone().annotate(UBitwidth::B32)
))),
}
}
FieldElementExpression::RightShift(box e, box by) => {
let e = self.fold_field_expression(e)?;
let by = self.fold_uint_expression(by)?;
match by.as_inner() {
UExpressionInner::Value(_) => {
Ok(FieldElementExpression::RightShift(box e, box by))
}
by => Err(Error(format!(
"Cannot shift by a variable value, found `{} << {}`",
e,
by.clone().annotate(UBitwidth::B32)
))),
}
}
e => fold_field_expression(self, e),
}
}

View file

@ -1,12 +1,12 @@
use std::collections::HashSet;
use std::collections::HashMap;
use std::convert::{TryFrom, TryInto};
use std::marker::PhantomData;
use zokrates_ast::typed::types::{ConcreteArrayType, IntoType, UBitwidth};
use zokrates_ast::typed::{self, Expr, Typed};
use zokrates_ast::zir::IntoType as ZirIntoType;
use zokrates_ast::zir::{self, Folder, Id, Select};
use zokrates_field::Field;
use std::convert::{TryFrom, TryInto};
#[derive(Default)]
pub struct Flattener<T: Field> {
phantom: PhantomData<T>,
@ -460,15 +460,11 @@ impl<'ast, T: Field> Flattener<T> {
#[derive(Default)]
pub struct ArgumentFinder<'ast, T> {
pub identifiers: HashSet<zir::Identifier<'ast>>,
pub identifiers: HashMap<zir::Identifier<'ast>, zir::Type>,
_phantom: PhantomData<T>,
}
impl<'ast, T: Field> Folder<'ast, T> for ArgumentFinder<'ast, T> {
fn fold_name(&mut self, n: zir::Identifier<'ast>) -> zir::Identifier<'ast> {
self.identifiers.insert(n.clone());
n
}
fn fold_statement(&mut self, s: zir::ZirStatement<'ast, T>) -> Vec<zir::ZirStatement<'ast, T>> {
match s {
zir::ZirStatement::Definition(assignee, expr) => {
@ -491,6 +487,16 @@ impl<'ast, T: Field> Folder<'ast, T> for ArgumentFinder<'ast, T> {
s => zir::folder::fold_statement(self, s),
}
}
fn fold_identifier_expression<E: zir::Expr<'ast, T> + Id<'ast, T>>(
&mut self,
ty: &E::Ty,
e: zir::IdentifierExpression<'ast, E>,
) -> zir::IdentifierOrExpression<'ast, T, E> {
self.identifiers
.insert(e.id.clone(), ty.clone().into_type());
zir::IdentifierOrExpression::Identifier(e)
}
}
fn fold_assembly_statement<'ast, T: Field>(
@ -515,15 +521,17 @@ fn fold_assembly_statement<'ast, T: Field>(
.collect();
statements_buffer.reverse();
let _ = dbg!(&finder.identifiers);
let function = zir::ZirFunction {
signature: zir::types::Signature::default()
.inputs(vec![zir::Type::FieldElement; finder.identifiers.len()])
.inputs(finder.identifiers.values().cloned().collect())
.outputs(a.iter().map(|a| a.get_type()).collect()),
arguments: finder
.identifiers
.into_iter()
.map(|id| zir::Parameter {
id: zir::Variable::field_element(id),
.map(|(id, ty)| zir::Parameter {
id: zir::Variable::with_id_and_type(id, ty),
private: false,
})
.collect(),
@ -1014,23 +1022,15 @@ fn fold_field_expression<'ast, T: Field>(
}
typed::FieldElementExpression::LeftShift(box e, box by) => {
let e = f.fold_field_expression(statements_buffer, e);
let by = f.fold_uint_expression(statements_buffer, by);
let by = match by.as_inner() {
typed::UExpressionInner::Value(by) => by,
_ => unreachable!("static analysis should have made sure that this is constant"),
};
zir::FieldElementExpression::LeftShift(box e, *by as u32)
zir::FieldElementExpression::LeftShift(box e, box by)
}
typed::FieldElementExpression::RightShift(box e, box by) => {
let e = f.fold_field_expression(statements_buffer, e);
let by = f.fold_uint_expression(statements_buffer, by);
let by = match by.as_inner() {
typed::UExpressionInner::Value(by) => by,
_ => unreachable!("static analysis should have made sure that this is constant"),
};
zir::FieldElementExpression::RightShift(box e, *by as u32)
zir::FieldElementExpression::RightShift(box e, box by)
}
typed::FieldElementExpression::Conditional(c) => f
.fold_conditional_expression(statements_buffer, c)

View file

@ -939,7 +939,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
}
FieldElementExpression::RightShift(box e, box by) => {
let e = self.fold_field_expression(e)?;
let by = dbg!(self.fold_uint_expression(by)?);
let by = self.fold_uint_expression(by)?;
match (e, by) {
(
e,

View file

@ -301,24 +301,50 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
(e1, e2) => Ok(FieldElementExpression::Or(box e1, box e2)),
}
}
FieldElementExpression::LeftShift(box e, by) => {
FieldElementExpression::LeftShift(box e, box by) => {
let e = self.fold_field_expression(e)?;
let by = self.fold_uint_expression(by)?;
match (e, by) {
(e, by) if by == 0u32 => Ok(e),
(FieldElementExpression::Number(n), by) => Ok(FieldElementExpression::Number(
(
e,
UExpression {
inner: UExpressionInner::Value(by),
..
},
) if by == 0 => Ok(e),
(
FieldElementExpression::Number(n),
UExpression {
inner: UExpressionInner::Value(by),
..
},
) => Ok(FieldElementExpression::Number(
T::try_from(n.to_biguint().shl(by as usize)).unwrap(),
)),
(e, by) => Ok(FieldElementExpression::LeftShift(box e, by)),
(e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)),
}
}
FieldElementExpression::RightShift(box e, by) => {
FieldElementExpression::RightShift(box e, box by) => {
let e = self.fold_field_expression(e)?;
let by = self.fold_uint_expression(by)?;
match (e, by) {
(e, by) if by == 0u32 => Ok(e),
(FieldElementExpression::Number(n), by) => Ok(FieldElementExpression::Number(
(
e,
UExpression {
inner: UExpressionInner::Value(by),
..
},
) if by == 0 => Ok(e),
(
FieldElementExpression::Number(n),
UExpression {
inner: UExpressionInner::Value(by),
..
},
) => Ok(FieldElementExpression::Number(
T::try_from(n.to_biguint().shr(by as usize)).unwrap(),
)),
(e, by) => Ok(FieldElementExpression::RightShift(box e, by)),
(e, by) => Ok(FieldElementExpression::RightShift(box e, box by)),
}
}
e => fold_field_expression(self, e),

View file

@ -282,15 +282,17 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
FieldElementExpression::Xor(box left, box right)
}
FieldElementExpression::LeftShift(box e, by) => {
FieldElementExpression::LeftShift(box e, box by) => {
let e = f.fold_field_expression(e);
let by = f.fold_uint_expression(by);
FieldElementExpression::LeftShift(box e, by)
FieldElementExpression::LeftShift(box e, box by)
}
FieldElementExpression::RightShift(box e, by) => {
FieldElementExpression::RightShift(box e, box by) => {
let e = f.fold_field_expression(e);
let by = f.fold_uint_expression(by);
FieldElementExpression::RightShift(box e, by)
FieldElementExpression::RightShift(box e, box by)
}
FieldElementExpression::Conditional(c) => {
match f.fold_conditional_expression(&Type::FieldElement, c) {

View file

@ -71,7 +71,7 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirFunction<'ast, T> {
writeln!(f)?;
}
writeln!(f, "}}")
write!(f, "}}")
}
}
@ -464,8 +464,14 @@ pub enum FieldElementExpression<'ast, T> {
Box<FieldElementExpression<'ast, T>>,
Box<FieldElementExpression<'ast, T>>,
),
LeftShift(Box<FieldElementExpression<'ast, T>>, u32),
RightShift(Box<FieldElementExpression<'ast, T>>, u32),
LeftShift(
Box<FieldElementExpression<'ast, T>>,
Box<UExpression<'ast, T>>,
),
RightShift(
Box<FieldElementExpression<'ast, T>>,
Box<UExpression<'ast, T>>,
),
Conditional(ConditionalExpression<'ast, T, FieldElementExpression<'ast, T>>),
}

View file

@ -306,15 +306,17 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
FieldElementExpression::Or(box left, box right)
}
FieldElementExpression::LeftShift(box e, by) => {
FieldElementExpression::LeftShift(box e, box by) => {
let e = f.fold_field_expression(e)?;
let by = f.fold_uint_expression(by)?;
FieldElementExpression::LeftShift(box e, by)
FieldElementExpression::LeftShift(box e, box by)
}
FieldElementExpression::RightShift(box e, by) => {
FieldElementExpression::RightShift(box e, box by) => {
let e = f.fold_field_expression(e)?;
let by = f.fold_uint_expression(by)?;
FieldElementExpression::RightShift(box e, by)
FieldElementExpression::RightShift(box e, box by)
}
FieldElementExpression::Conditional(c) => {
match f.fold_conditional_expression(&Type::FieldElement, c)? {

View file

@ -0,0 +1,46 @@
{
"curves": ["Bn128"],
"max_constraint_count": 2,
"tests": [
{
"input": {
"values": ["0", "0"]
},
"output": {
"Ok": {
"value": "0"
}
}
},
{
"input": {
"values": ["1", "0"]
},
"output": {
"Ok": {
"value": "0"
}
}
},
{
"input": {
"values": ["0", "1"]
},
"output": {
"Ok": {
"value": "0"
}
}
},
{
"input": {
"values": ["1", "1"]
},
"output": {
"Ok": {
"value": "1"
}
}
}
]
}

View file

@ -0,0 +1,7 @@
def main(field a, field b) -> field {
field mut c = 0;
asm {
c <== a * b;
}
return c;
}

View file

@ -0,0 +1,26 @@
{
"curves": ["Bn128"],
"max_constraint_count": 2,
"tests": [
{
"input": {
"values": ["0"]
},
"output": {
"Ok": {
"value": "1"
}
}
},
{
"input": {
"values": ["1"]
},
"output": {
"Ok": {
"value": "0"
}
}
}
]
}

View file

@ -0,0 +1,7 @@
def main(field inp) -> field {
field mut out = 0;
asm {
out <== 1 + inp - 2*inp;
}
return out;
}

View file

@ -0,0 +1,46 @@
{
"curves": ["Bn128"],
"max_constraint_count": 2,
"tests": [
{
"input": {
"values": ["0", "0"]
},
"output": {
"Ok": {
"value": "0"
}
}
},
{
"input": {
"values": ["1", "0"]
},
"output": {
"Ok": {
"value": "1"
}
}
},
{
"input": {
"values": ["0", "1"]
},
"output": {
"Ok": {
"value": "1"
}
}
},
{
"input": {
"values": ["1", "1"]
},
"output": {
"Ok": {
"value": "1"
}
}
}
]
}

View file

@ -0,0 +1,7 @@
def main(field a, field b) -> field {
field mut c = 0;
asm {
c <== a + b - a*b;
}
return c;
}

View file

@ -0,0 +1,36 @@
{
"curves": ["Bn128"],
"max_constraint_count": 3,
"tests": [
{
"input": {
"values": ["1", "1"]
},
"output": {
"Ok": {
"value": "1"
}
}
},
{
"input": {
"values": ["2", "4"]
},
"output": {
"Ok": {
"value": "0"
}
}
},
{
"input": {
"values": ["4", "2"]
},
"output": {
"Ok": {
"value": "0"
}
}
}
]
}

View file

@ -0,0 +1,5 @@
import "./is_zero.zok";
def main(field a, field b) -> field {
return is_zero(b - a);
}

View file

@ -4,6 +4,7 @@ use zokrates_abi::{Decode, Value};
use zokrates_ast::ir::{
LinComb, ProgIterator, QuadComb, RuntimeError, Solver, Statement, Variable, Witness,
};
use zokrates_ast::zir;
use zokrates_field::Field;
pub type ExecutionResult<T> = Result<Witness<T>, Error>;
@ -168,11 +169,6 @@ impl Interpreter {
let res = match solver {
Solver::Zir(func) => {
use zokrates_ast::zir::result_folder::ResultFolder;
assert!(func
.arguments
.iter()
.all(|a| a.id._type == zokrates_ast::zir::Type::FieldElement));
assert_eq!(func.arguments.len(), inputs.len());
let constants = func
@ -182,7 +178,23 @@ impl Interpreter {
.map(|(a, v)| {
(
a.id.id.clone(),
zokrates_ast::zir::FieldElementExpression::Number(v.clone()).into(),
match &a.id._type {
zir::Type::FieldElement => {
zokrates_ast::zir::FieldElementExpression::Number(v.clone())
.into()
}
zir::Type::Boolean => {
zokrates_ast::zir::BooleanExpression::Value(*v == T::from(1))
.into()
}
zir::Type::Uint(bitwidth) => {
zokrates_ast::zir::UExpressionInner::Value(
v.to_dec_string().parse::<u128>().unwrap(),
)
.annotate(*bitwidth)
.into()
}
},
)
})
.collect();