allow variable shifts in witness assignment
This commit is contained in:
parent
c61a481e0b
commit
063a815308
17 changed files with 281 additions and 86 deletions
|
@ -54,9 +54,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for AssemblyTransformer {
|
|||
false => {
|
||||
let sub = FieldElementExpression::Sub(box lhs.clone(), box rhs.clone());
|
||||
let mut lqc = LinQuadComb::try_from(sub.clone()).map_err(|_| {
|
||||
Error(
|
||||
"Found forbidden operation in user-defined constraint".to_string(),
|
||||
)
|
||||
Error("Non-quadratic constraints are not allowed".to_string())
|
||||
})?;
|
||||
|
||||
let linear = lqc
|
||||
|
@ -127,10 +125,9 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for AssemblyTransformer {
|
|||
}),
|
||||
box FieldElementExpression::identifier(id),
|
||||
)),
|
||||
_ => Err(Error(format!(
|
||||
"Non-quadratic constraints are not allowed `{} == {}`",
|
||||
lhs, rhs
|
||||
))),
|
||||
_ => Err(Error(
|
||||
"Non-quadratic constraints are not allowed".to_string(),
|
||||
)),
|
||||
}?
|
||||
} else {
|
||||
lqc.quadratic
|
||||
|
|
|
@ -90,36 +90,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantArgumentChecker {
|
|||
))),
|
||||
}
|
||||
}
|
||||
FieldElementExpression::LeftShift(box e, box by) => {
|
||||
let e = self.fold_field_expression(e)?;
|
||||
let by = self.fold_uint_expression(by)?;
|
||||
|
||||
match by.as_inner() {
|
||||
UExpressionInner::Value(_) => {
|
||||
Ok(FieldElementExpression::LeftShift(box e, box by))
|
||||
}
|
||||
by => Err(Error(format!(
|
||||
"Cannot shift by a variable value, found `{} << {}`",
|
||||
e,
|
||||
by.clone().annotate(UBitwidth::B32)
|
||||
))),
|
||||
}
|
||||
}
|
||||
FieldElementExpression::RightShift(box e, box by) => {
|
||||
let e = self.fold_field_expression(e)?;
|
||||
let by = self.fold_uint_expression(by)?;
|
||||
|
||||
match by.as_inner() {
|
||||
UExpressionInner::Value(_) => {
|
||||
Ok(FieldElementExpression::RightShift(box e, box by))
|
||||
}
|
||||
by => Err(Error(format!(
|
||||
"Cannot shift by a variable value, found `{} << {}`",
|
||||
e,
|
||||
by.clone().annotate(UBitwidth::B32)
|
||||
))),
|
||||
}
|
||||
}
|
||||
e => fold_field_expression(self, e),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
use std::collections::HashSet;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use std::marker::PhantomData;
|
||||
use zokrates_ast::typed::types::{ConcreteArrayType, IntoType, UBitwidth};
|
||||
use zokrates_ast::typed::{self, Expr, Typed};
|
||||
use zokrates_ast::zir::IntoType as ZirIntoType;
|
||||
use zokrates_ast::zir::{self, Folder, Id, Select};
|
||||
use zokrates_field::Field;
|
||||
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Flattener<T: Field> {
|
||||
phantom: PhantomData<T>,
|
||||
|
@ -460,15 +460,11 @@ impl<'ast, T: Field> Flattener<T> {
|
|||
|
||||
#[derive(Default)]
|
||||
pub struct ArgumentFinder<'ast, T> {
|
||||
pub identifiers: HashSet<zir::Identifier<'ast>>,
|
||||
pub identifiers: HashMap<zir::Identifier<'ast>, zir::Type>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for ArgumentFinder<'ast, T> {
|
||||
fn fold_name(&mut self, n: zir::Identifier<'ast>) -> zir::Identifier<'ast> {
|
||||
self.identifiers.insert(n.clone());
|
||||
n
|
||||
}
|
||||
fn fold_statement(&mut self, s: zir::ZirStatement<'ast, T>) -> Vec<zir::ZirStatement<'ast, T>> {
|
||||
match s {
|
||||
zir::ZirStatement::Definition(assignee, expr) => {
|
||||
|
@ -491,6 +487,16 @@ impl<'ast, T: Field> Folder<'ast, T> for ArgumentFinder<'ast, T> {
|
|||
s => zir::folder::fold_statement(self, s),
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_identifier_expression<E: zir::Expr<'ast, T> + Id<'ast, T>>(
|
||||
&mut self,
|
||||
ty: &E::Ty,
|
||||
e: zir::IdentifierExpression<'ast, E>,
|
||||
) -> zir::IdentifierOrExpression<'ast, T, E> {
|
||||
self.identifiers
|
||||
.insert(e.id.clone(), ty.clone().into_type());
|
||||
zir::IdentifierOrExpression::Identifier(e)
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_assembly_statement<'ast, T: Field>(
|
||||
|
@ -515,15 +521,17 @@ fn fold_assembly_statement<'ast, T: Field>(
|
|||
.collect();
|
||||
statements_buffer.reverse();
|
||||
|
||||
let _ = dbg!(&finder.identifiers);
|
||||
|
||||
let function = zir::ZirFunction {
|
||||
signature: zir::types::Signature::default()
|
||||
.inputs(vec![zir::Type::FieldElement; finder.identifiers.len()])
|
||||
.inputs(finder.identifiers.values().cloned().collect())
|
||||
.outputs(a.iter().map(|a| a.get_type()).collect()),
|
||||
arguments: finder
|
||||
.identifiers
|
||||
.into_iter()
|
||||
.map(|id| zir::Parameter {
|
||||
id: zir::Variable::field_element(id),
|
||||
.map(|(id, ty)| zir::Parameter {
|
||||
id: zir::Variable::with_id_and_type(id, ty),
|
||||
private: false,
|
||||
})
|
||||
.collect(),
|
||||
|
@ -1014,23 +1022,15 @@ fn fold_field_expression<'ast, T: Field>(
|
|||
}
|
||||
typed::FieldElementExpression::LeftShift(box e, box by) => {
|
||||
let e = f.fold_field_expression(statements_buffer, e);
|
||||
let by = f.fold_uint_expression(statements_buffer, by);
|
||||
|
||||
let by = match by.as_inner() {
|
||||
typed::UExpressionInner::Value(by) => by,
|
||||
_ => unreachable!("static analysis should have made sure that this is constant"),
|
||||
};
|
||||
|
||||
zir::FieldElementExpression::LeftShift(box e, *by as u32)
|
||||
zir::FieldElementExpression::LeftShift(box e, box by)
|
||||
}
|
||||
typed::FieldElementExpression::RightShift(box e, box by) => {
|
||||
let e = f.fold_field_expression(statements_buffer, e);
|
||||
let by = f.fold_uint_expression(statements_buffer, by);
|
||||
|
||||
let by = match by.as_inner() {
|
||||
typed::UExpressionInner::Value(by) => by,
|
||||
_ => unreachable!("static analysis should have made sure that this is constant"),
|
||||
};
|
||||
|
||||
zir::FieldElementExpression::RightShift(box e, *by as u32)
|
||||
zir::FieldElementExpression::RightShift(box e, box by)
|
||||
}
|
||||
typed::FieldElementExpression::Conditional(c) => f
|
||||
.fold_conditional_expression(statements_buffer, c)
|
||||
|
|
|
@ -939,7 +939,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
|
|||
}
|
||||
FieldElementExpression::RightShift(box e, box by) => {
|
||||
let e = self.fold_field_expression(e)?;
|
||||
let by = dbg!(self.fold_uint_expression(by)?);
|
||||
let by = self.fold_uint_expression(by)?;
|
||||
match (e, by) {
|
||||
(
|
||||
e,
|
||||
|
|
|
@ -301,24 +301,50 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
|||
(e1, e2) => Ok(FieldElementExpression::Or(box e1, box e2)),
|
||||
}
|
||||
}
|
||||
FieldElementExpression::LeftShift(box e, by) => {
|
||||
FieldElementExpression::LeftShift(box e, box by) => {
|
||||
let e = self.fold_field_expression(e)?;
|
||||
let by = self.fold_uint_expression(by)?;
|
||||
match (e, by) {
|
||||
(e, by) if by == 0u32 => Ok(e),
|
||||
(FieldElementExpression::Number(n), by) => Ok(FieldElementExpression::Number(
|
||||
(
|
||||
e,
|
||||
UExpression {
|
||||
inner: UExpressionInner::Value(by),
|
||||
..
|
||||
},
|
||||
) if by == 0 => Ok(e),
|
||||
(
|
||||
FieldElementExpression::Number(n),
|
||||
UExpression {
|
||||
inner: UExpressionInner::Value(by),
|
||||
..
|
||||
},
|
||||
) => Ok(FieldElementExpression::Number(
|
||||
T::try_from(n.to_biguint().shl(by as usize)).unwrap(),
|
||||
)),
|
||||
(e, by) => Ok(FieldElementExpression::LeftShift(box e, by)),
|
||||
(e, by) => Ok(FieldElementExpression::LeftShift(box e, box by)),
|
||||
}
|
||||
}
|
||||
FieldElementExpression::RightShift(box e, by) => {
|
||||
FieldElementExpression::RightShift(box e, box by) => {
|
||||
let e = self.fold_field_expression(e)?;
|
||||
let by = self.fold_uint_expression(by)?;
|
||||
match (e, by) {
|
||||
(e, by) if by == 0u32 => Ok(e),
|
||||
(FieldElementExpression::Number(n), by) => Ok(FieldElementExpression::Number(
|
||||
(
|
||||
e,
|
||||
UExpression {
|
||||
inner: UExpressionInner::Value(by),
|
||||
..
|
||||
},
|
||||
) if by == 0 => Ok(e),
|
||||
(
|
||||
FieldElementExpression::Number(n),
|
||||
UExpression {
|
||||
inner: UExpressionInner::Value(by),
|
||||
..
|
||||
},
|
||||
) => Ok(FieldElementExpression::Number(
|
||||
T::try_from(n.to_biguint().shr(by as usize)).unwrap(),
|
||||
)),
|
||||
(e, by) => Ok(FieldElementExpression::RightShift(box e, by)),
|
||||
(e, by) => Ok(FieldElementExpression::RightShift(box e, box by)),
|
||||
}
|
||||
}
|
||||
e => fold_field_expression(self, e),
|
||||
|
|
|
@ -282,15 +282,17 @@ pub fn fold_field_expression<'ast, T: Field, F: Folder<'ast, T>>(
|
|||
|
||||
FieldElementExpression::Xor(box left, box right)
|
||||
}
|
||||
FieldElementExpression::LeftShift(box e, by) => {
|
||||
FieldElementExpression::LeftShift(box e, box by) => {
|
||||
let e = f.fold_field_expression(e);
|
||||
let by = f.fold_uint_expression(by);
|
||||
|
||||
FieldElementExpression::LeftShift(box e, by)
|
||||
FieldElementExpression::LeftShift(box e, box by)
|
||||
}
|
||||
FieldElementExpression::RightShift(box e, by) => {
|
||||
FieldElementExpression::RightShift(box e, box by) => {
|
||||
let e = f.fold_field_expression(e);
|
||||
let by = f.fold_uint_expression(by);
|
||||
|
||||
FieldElementExpression::RightShift(box e, by)
|
||||
FieldElementExpression::RightShift(box e, box by)
|
||||
}
|
||||
FieldElementExpression::Conditional(c) => {
|
||||
match f.fold_conditional_expression(&Type::FieldElement, c) {
|
||||
|
|
|
@ -71,7 +71,7 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirFunction<'ast, T> {
|
|||
writeln!(f)?;
|
||||
}
|
||||
|
||||
writeln!(f, "}}")
|
||||
write!(f, "}}")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -464,8 +464,14 @@ pub enum FieldElementExpression<'ast, T> {
|
|||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
),
|
||||
LeftShift(Box<FieldElementExpression<'ast, T>>, u32),
|
||||
RightShift(Box<FieldElementExpression<'ast, T>>, u32),
|
||||
LeftShift(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<UExpression<'ast, T>>,
|
||||
),
|
||||
RightShift(
|
||||
Box<FieldElementExpression<'ast, T>>,
|
||||
Box<UExpression<'ast, T>>,
|
||||
),
|
||||
Conditional(ConditionalExpression<'ast, T, FieldElementExpression<'ast, T>>),
|
||||
}
|
||||
|
||||
|
|
|
@ -306,15 +306,17 @@ pub fn fold_field_expression<'ast, T: Field, F: ResultFolder<'ast, T>>(
|
|||
|
||||
FieldElementExpression::Or(box left, box right)
|
||||
}
|
||||
FieldElementExpression::LeftShift(box e, by) => {
|
||||
FieldElementExpression::LeftShift(box e, box by) => {
|
||||
let e = f.fold_field_expression(e)?;
|
||||
let by = f.fold_uint_expression(by)?;
|
||||
|
||||
FieldElementExpression::LeftShift(box e, by)
|
||||
FieldElementExpression::LeftShift(box e, box by)
|
||||
}
|
||||
FieldElementExpression::RightShift(box e, by) => {
|
||||
FieldElementExpression::RightShift(box e, box by) => {
|
||||
let e = f.fold_field_expression(e)?;
|
||||
let by = f.fold_uint_expression(by)?;
|
||||
|
||||
FieldElementExpression::RightShift(box e, by)
|
||||
FieldElementExpression::RightShift(box e, box by)
|
||||
}
|
||||
FieldElementExpression::Conditional(c) => {
|
||||
match f.fold_conditional_expression(&Type::FieldElement, c)? {
|
||||
|
|
46
zokrates_core_test/tests/tests/assembly/gates/and.json
Normal file
46
zokrates_core_test/tests/tests/assembly/gates/and.json
Normal file
|
@ -0,0 +1,46 @@
|
|||
{
|
||||
"curves": ["Bn128"],
|
||||
"max_constraint_count": 2,
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": ["0", "0"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"value": "0"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["1", "0"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"value": "0"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["0", "1"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"value": "0"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["1", "1"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"value": "1"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
7
zokrates_core_test/tests/tests/assembly/gates/and.zok
Normal file
7
zokrates_core_test/tests/tests/assembly/gates/and.zok
Normal file
|
@ -0,0 +1,7 @@
|
|||
def main(field a, field b) -> field {
|
||||
field mut c = 0;
|
||||
asm {
|
||||
c <== a * b;
|
||||
}
|
||||
return c;
|
||||
}
|
26
zokrates_core_test/tests/tests/assembly/gates/not.json
Normal file
26
zokrates_core_test/tests/tests/assembly/gates/not.json
Normal file
|
@ -0,0 +1,26 @@
|
|||
{
|
||||
"curves": ["Bn128"],
|
||||
"max_constraint_count": 2,
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": ["0"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"value": "1"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["1"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"value": "0"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
7
zokrates_core_test/tests/tests/assembly/gates/not.zok
Normal file
7
zokrates_core_test/tests/tests/assembly/gates/not.zok
Normal file
|
@ -0,0 +1,7 @@
|
|||
def main(field inp) -> field {
|
||||
field mut out = 0;
|
||||
asm {
|
||||
out <== 1 + inp - 2*inp;
|
||||
}
|
||||
return out;
|
||||
}
|
46
zokrates_core_test/tests/tests/assembly/gates/or.json
Normal file
46
zokrates_core_test/tests/tests/assembly/gates/or.json
Normal file
|
@ -0,0 +1,46 @@
|
|||
{
|
||||
"curves": ["Bn128"],
|
||||
"max_constraint_count": 2,
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": ["0", "0"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"value": "0"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["1", "0"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"value": "1"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["0", "1"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"value": "1"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["1", "1"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"value": "1"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
7
zokrates_core_test/tests/tests/assembly/gates/or.zok
Normal file
7
zokrates_core_test/tests/tests/assembly/gates/or.zok
Normal file
|
@ -0,0 +1,7 @@
|
|||
def main(field a, field b) -> field {
|
||||
field mut c = 0;
|
||||
asm {
|
||||
c <== a + b - a*b;
|
||||
}
|
||||
return c;
|
||||
}
|
36
zokrates_core_test/tests/tests/assembly/is_equal.json
Normal file
36
zokrates_core_test/tests/tests/assembly/is_equal.json
Normal file
|
@ -0,0 +1,36 @@
|
|||
{
|
||||
"curves": ["Bn128"],
|
||||
"max_constraint_count": 3,
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": ["1", "1"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"value": "1"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["2", "4"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"value": "0"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"input": {
|
||||
"values": ["4", "2"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"value": "0"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
5
zokrates_core_test/tests/tests/assembly/is_equal.zok
Normal file
5
zokrates_core_test/tests/tests/assembly/is_equal.zok
Normal file
|
@ -0,0 +1,5 @@
|
|||
import "./is_zero.zok";
|
||||
|
||||
def main(field a, field b) -> field {
|
||||
return is_zero(b - a);
|
||||
}
|
|
@ -4,6 +4,7 @@ use zokrates_abi::{Decode, Value};
|
|||
use zokrates_ast::ir::{
|
||||
LinComb, ProgIterator, QuadComb, RuntimeError, Solver, Statement, Variable, Witness,
|
||||
};
|
||||
use zokrates_ast::zir;
|
||||
use zokrates_field::Field;
|
||||
|
||||
pub type ExecutionResult<T> = Result<Witness<T>, Error>;
|
||||
|
@ -168,11 +169,6 @@ impl Interpreter {
|
|||
let res = match solver {
|
||||
Solver::Zir(func) => {
|
||||
use zokrates_ast::zir::result_folder::ResultFolder;
|
||||
|
||||
assert!(func
|
||||
.arguments
|
||||
.iter()
|
||||
.all(|a| a.id._type == zokrates_ast::zir::Type::FieldElement));
|
||||
assert_eq!(func.arguments.len(), inputs.len());
|
||||
|
||||
let constants = func
|
||||
|
@ -182,7 +178,23 @@ impl Interpreter {
|
|||
.map(|(a, v)| {
|
||||
(
|
||||
a.id.id.clone(),
|
||||
zokrates_ast::zir::FieldElementExpression::Number(v.clone()).into(),
|
||||
match &a.id._type {
|
||||
zir::Type::FieldElement => {
|
||||
zokrates_ast::zir::FieldElementExpression::Number(v.clone())
|
||||
.into()
|
||||
}
|
||||
zir::Type::Boolean => {
|
||||
zokrates_ast::zir::BooleanExpression::Value(*v == T::from(1))
|
||||
.into()
|
||||
}
|
||||
zir::Type::Uint(bitwidth) => {
|
||||
zokrates_ast::zir::UExpressionInner::Value(
|
||||
v.to_dec_string().parse::<u128>().unwrap(),
|
||||
)
|
||||
.annotate(*bitwidth)
|
||||
.into()
|
||||
}
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
|
Loading…
Reference in a new issue