fix parser, fmt
This commit is contained in:
parent
f8521fd33c
commit
6d65801b9b
9 changed files with 192 additions and 76 deletions
|
@ -5,6 +5,7 @@ use zokrates_pest_ast as pest;
|
|||
|
||||
impl<'ast, T: Field> From<pest::File<'ast>> for absy::Module<'ast, T> {
|
||||
fn from(prog: pest::File<'ast>) -> absy::Module<T> {
|
||||
println!("{:#?}", prog);
|
||||
absy::Module::with_symbols(
|
||||
prog.structs
|
||||
.into_iter()
|
||||
|
|
|
@ -752,8 +752,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
statements_flattened: &mut Vec<FlatStatement<T>>,
|
||||
expr: UExpression<'ast, T>,
|
||||
) -> FlatUExpression<T> {
|
||||
|
||||
statements_flattened.push(FlatStatement::Log(format!(" {}{}", " ".repeat(self.depth), expr)));
|
||||
statements_flattened.push(FlatStatement::Log(format!(
|
||||
" {}{}",
|
||||
" ".repeat(self.depth),
|
||||
expr
|
||||
)));
|
||||
|
||||
self.depth += 1;
|
||||
|
||||
|
@ -802,7 +805,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
},
|
||||
));
|
||||
|
||||
FlatUExpression::with_bits(name_not.into_iter().map(|v| v.into()).collect::<Vec<_>>())
|
||||
FlatUExpression::with_bits(
|
||||
name_not.into_iter().map(|v| v.into()).collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
||||
UExpressionInner::Add(box left, box right) => {
|
||||
let left_flattened = self
|
||||
|
@ -1114,10 +1119,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
|
|||
|
||||
self.depth -= 1;
|
||||
|
||||
statements_flattened.push(FlatStatement::Log(format!(" {} DONE", " ".repeat(self.depth))));
|
||||
statements_flattened.push(FlatStatement::Log(format!(
|
||||
" {} DONE",
|
||||
" ".repeat(self.depth)
|
||||
)));
|
||||
|
||||
res
|
||||
|
||||
}
|
||||
|
||||
fn get_bits(
|
||||
|
|
|
@ -2070,7 +2070,11 @@ impl<'ast> Checker<'ast> {
|
|||
(e1, e2) => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
|
||||
message: format!("cannot left-shift {} by {}", e1.get_type(), e2.get_type()),
|
||||
message: format!(
|
||||
"cannot left-shift {} by {}",
|
||||
e1.get_type(),
|
||||
e2.get_type()
|
||||
),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
@ -2084,7 +2088,11 @@ impl<'ast> Checker<'ast> {
|
|||
(e1, e2) => Err(ErrorInner {
|
||||
pos: Some(pos),
|
||||
|
||||
message: format!("cannot right-shift {} by {}", e1.get_type(), e2.get_type()),
|
||||
message: format!(
|
||||
"cannot right-shift {} by {}",
|
||||
e1.get_type(),
|
||||
e2.get_type()
|
||||
),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -87,7 +87,6 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
)),
|
||||
// propagation to the defined variable if rhs is a constant
|
||||
TypedStatement::Definition(TypedAssignee::Identifier(var), expr) => {
|
||||
|
||||
let expr = self.fold_expression(expr);
|
||||
|
||||
if is_constant(&expr) {
|
||||
|
@ -137,11 +136,13 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
let expression_list = self.fold_expression_list(expression_list);
|
||||
match expression_list {
|
||||
TypedExpressionList::FunctionCall(key, arguments, types) => {
|
||||
let arguments: Vec<_> = arguments.into_iter().map(|a| self.fold_expression(a)).collect();
|
||||
let arguments: Vec<_> = arguments
|
||||
.into_iter()
|
||||
.map(|a| self.fold_expression(a))
|
||||
.collect();
|
||||
|
||||
if arguments.iter().all(|a| is_constant(a)) {
|
||||
let expr: TypedExpression<'ast, T> = if key.id == "_U32_FROM_BITS" {
|
||||
|
||||
assert_eq!(variables.len(), 1);
|
||||
assert_eq!(arguments.len(), 1);
|
||||
|
||||
|
@ -152,25 +153,27 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
ArrayExpressionInner::Value(v) => {
|
||||
assert_eq!(v.len(), 32);
|
||||
UExpressionInner::Value(
|
||||
v.into_iter()
|
||||
.map(|v| match v {
|
||||
TypedExpression::Boolean(
|
||||
BooleanExpression::Value(v),
|
||||
) => v,
|
||||
_ => unreachable!(),
|
||||
})
|
||||
.enumerate()
|
||||
.fold(0, |acc, (i, v)| {
|
||||
if v {
|
||||
acc + 2u128
|
||||
.pow((32 - i - 1).try_into().unwrap())
|
||||
} else {
|
||||
acc
|
||||
}
|
||||
}),
|
||||
)
|
||||
.annotate(32)
|
||||
.into()},
|
||||
v.into_iter()
|
||||
.map(|v| match v {
|
||||
TypedExpression::Boolean(
|
||||
BooleanExpression::Value(v),
|
||||
) => v,
|
||||
_ => unreachable!(),
|
||||
})
|
||||
.enumerate()
|
||||
.fold(0, |acc, (i, v)| {
|
||||
if v {
|
||||
acc + 2u128.pow(
|
||||
(32 - i - 1).try_into().unwrap(),
|
||||
)
|
||||
} else {
|
||||
acc
|
||||
}
|
||||
}),
|
||||
)
|
||||
.annotate(32)
|
||||
.into()
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
_ => unreachable!(),
|
||||
|
@ -195,7 +198,11 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
}
|
||||
assert_eq!(num, 0);
|
||||
|
||||
ArrayExpressionInner::Value(res.into_iter().map(|v| BooleanExpression::Value(v).into()).collect())
|
||||
ArrayExpressionInner::Value(
|
||||
res.into_iter()
|
||||
.map(|v| BooleanExpression::Value(v).into())
|
||||
.collect(),
|
||||
)
|
||||
.annotate(Type::Boolean, 32)
|
||||
.into()
|
||||
}
|
||||
|
@ -219,7 +226,9 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
false => None,
|
||||
}
|
||||
} else {
|
||||
if self.verbose {println!("not constant!")};
|
||||
if self.verbose {
|
||||
println!("not constant!")
|
||||
};
|
||||
let l = TypedExpressionList::FunctionCall(key, arguments, types);
|
||||
Some(TypedStatement::MultipleDefinition(variables, l))
|
||||
}
|
||||
|
@ -282,7 +291,9 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
) {
|
||||
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
|
||||
use std::convert::TryInto;
|
||||
UExpressionInner::Value((v1.wrapping_sub(v2)) % 2_u128.pow(bitwidth.try_into().unwrap()))
|
||||
UExpressionInner::Value(
|
||||
(v1.wrapping_sub(v2)) % 2_u128.pow(bitwidth.try_into().unwrap()),
|
||||
)
|
||||
}
|
||||
(e, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), e) => match v {
|
||||
0 => e,
|
||||
|
@ -323,17 +334,17 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
let by_as_usize = by.to_dec_string().parse::<usize>().unwrap();
|
||||
UExpressionInner::Value(v >> by_as_usize)
|
||||
}
|
||||
(e, FieldElementExpression::Number(by)) => {
|
||||
UExpressionInner::RightShift(box e.annotate(bitwidth), box FieldElementExpression::Number(by))
|
||||
}
|
||||
(e, FieldElementExpression::Number(by)) => UExpressionInner::RightShift(
|
||||
box e.annotate(bitwidth),
|
||||
box FieldElementExpression::Number(by),
|
||||
),
|
||||
(_, e2) => unreachable!(format!(
|
||||
"non-constant shift {} detected during static analysis",
|
||||
e2
|
||||
)),
|
||||
}
|
||||
},
|
||||
}
|
||||
UExpressionInner::LeftShift(box e, box by) => {
|
||||
|
||||
println!("LEFT?");
|
||||
|
||||
let e = self.fold_uint_expression(e);
|
||||
|
@ -343,15 +354,16 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
let by_as_usize = by.to_dec_string().parse::<usize>().unwrap();
|
||||
UExpressionInner::Value((v << by_as_usize) & 0xffffffff)
|
||||
}
|
||||
(e, FieldElementExpression::Number(by)) => {
|
||||
UExpressionInner::LeftShift(box e.annotate(bitwidth), box FieldElementExpression::Number(by))
|
||||
}
|
||||
(e, FieldElementExpression::Number(by)) => UExpressionInner::LeftShift(
|
||||
box e.annotate(bitwidth),
|
||||
box FieldElementExpression::Number(by),
|
||||
),
|
||||
(_, e2) => unreachable!(format!(
|
||||
"non-constant shift {} detected during static analysis",
|
||||
e2
|
||||
)),
|
||||
}
|
||||
},
|
||||
}
|
||||
UExpressionInner::Xor(box e1, box e2) => match (
|
||||
self.fold_uint_expression(e1).into_inner(),
|
||||
self.fold_uint_expression(e2).into_inner(),
|
||||
|
@ -621,7 +633,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
},
|
||||
None => ArrayExpressionInner::Identifier(id),
|
||||
}
|
||||
},
|
||||
}
|
||||
ArrayExpressionInner::Select(box array, box index) => {
|
||||
let array = self.fold_array_expression(array);
|
||||
let index = self.fold_field_expression(index);
|
||||
|
@ -978,9 +990,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
),
|
||||
}
|
||||
}
|
||||
(a, i) => {
|
||||
BooleanExpression::Select(box a.annotate(inner_type, size), box i)
|
||||
}
|
||||
(a, i) => BooleanExpression::Select(box a.annotate(inner_type, size), box i),
|
||||
}
|
||||
}
|
||||
BooleanExpression::Member(box s, m) => {
|
||||
|
|
|
@ -517,39 +517,68 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
|
|||
})
|
||||
.collect(),
|
||||
)],
|
||||
ZirStatement::MultipleDefinition(lhs, rhs) => {
|
||||
match rhs {
|
||||
ZirExpressionList::FunctionCall(key, arguments, ty) => {
|
||||
match key.clone().id {
|
||||
"_U32_FROM_BITS" => {
|
||||
assert_eq!(lhs.len(), 1);
|
||||
let expr = UExpressionInner::FunctionCall(key.clone(), arguments.clone().into_iter().map(|a| self.fold_expression(a)).collect()).annotate(32).metadata(UMetadata {
|
||||
bitwidth: Some(32),
|
||||
should_reduce: Some(true)
|
||||
});
|
||||
self.register(lhs[0].clone(), ZirExpression::Uint(expr));
|
||||
vec![ZirStatement::MultipleDefinition(lhs, ZirExpressionList::FunctionCall(key, arguments, ty))]
|
||||
},
|
||||
"_U32_TO_BITS" => {
|
||||
assert_eq!(lhs.len(), 32);
|
||||
vec![ZirStatement::MultipleDefinition(lhs, ZirExpressionList::FunctionCall(key, arguments.into_iter().map(|e| self.fold_expression(e)).collect(), ty))]
|
||||
},
|
||||
_ => {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
ZirStatement::MultipleDefinition(lhs, rhs) => match rhs {
|
||||
ZirExpressionList::FunctionCall(key, arguments, ty) => match key.clone().id {
|
||||
"_U32_FROM_BITS" => {
|
||||
assert_eq!(lhs.len(), 1);
|
||||
let expr = UExpressionInner::FunctionCall(
|
||||
key.clone(),
|
||||
arguments
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|a| self.fold_expression(a))
|
||||
.collect(),
|
||||
)
|
||||
.annotate(32)
|
||||
.metadata(UMetadata {
|
||||
bitwidth: Some(32),
|
||||
should_reduce: Some(true),
|
||||
});
|
||||
self.register(lhs[0].clone(), ZirExpression::Uint(expr));
|
||||
vec![ZirStatement::MultipleDefinition(
|
||||
lhs,
|
||||
ZirExpressionList::FunctionCall(key, arguments, ty),
|
||||
)]
|
||||
}
|
||||
}
|
||||
"_U32_TO_BITS" => {
|
||||
assert_eq!(lhs.len(), 32);
|
||||
vec![ZirStatement::MultipleDefinition(
|
||||
lhs,
|
||||
ZirExpressionList::FunctionCall(
|
||||
key,
|
||||
arguments
|
||||
.into_iter()
|
||||
.map(|e| self.fold_expression(e))
|
||||
.collect(),
|
||||
ty,
|
||||
),
|
||||
)]
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
},
|
||||
// we need to put back in range to assert
|
||||
ZirStatement::Condition(lhs, rhs) => match (self.fold_expression(lhs), self.fold_expression(rhs)) {
|
||||
(ZirExpression::Uint(lhs), ZirExpression::Uint(rhs)) => {
|
||||
let lhs_metadata = lhs.metadata.clone().unwrap();
|
||||
let rhs_metadata = rhs.metadata.clone().unwrap();
|
||||
vec![ZirStatement::Condition(lhs.metadata(UMetadata { should_reduce: Some(true), ..lhs_metadata}).into(), rhs.metadata(UMetadata { should_reduce: Some(true), ..rhs_metadata}).into())]
|
||||
},
|
||||
(lhs, rhs) => vec![ZirStatement::Condition(lhs, rhs)]
|
||||
},
|
||||
ZirStatement::Condition(lhs, rhs) => {
|
||||
match (self.fold_expression(lhs), self.fold_expression(rhs)) {
|
||||
(ZirExpression::Uint(lhs), ZirExpression::Uint(rhs)) => {
|
||||
let lhs_metadata = lhs.metadata.clone().unwrap();
|
||||
let rhs_metadata = rhs.metadata.clone().unwrap();
|
||||
vec![ZirStatement::Condition(
|
||||
lhs.metadata(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..lhs_metadata
|
||||
})
|
||||
.into(),
|
||||
rhs.metadata(UMetadata {
|
||||
should_reduce: Some(true),
|
||||
..rhs_metadata
|
||||
})
|
||||
.into(),
|
||||
)]
|
||||
}
|
||||
(lhs, rhs) => vec![ZirStatement::Condition(lhs, rhs)],
|
||||
}
|
||||
}
|
||||
s => fold_statement(self, s),
|
||||
}
|
||||
}
|
||||
|
|
5
zokrates_core_test/tests/tests/single_return.json
Normal file
5
zokrates_core_test/tests/tests/single_return.json
Normal file
|
@ -0,0 +1,5 @@
|
|||
{
|
||||
"entry_point": "./tests/tests/single_return.zok",
|
||||
"tests": [
|
||||
]
|
||||
}
|
6
zokrates_core_test/tests/tests/single_return.zok
Normal file
6
zokrates_core_test/tests/tests/single_return.zok
Normal file
|
@ -0,0 +1,6 @@
|
|||
def foo() -> (field):
|
||||
return 42
|
||||
|
||||
def main() -> ():
|
||||
field a = foo()
|
||||
return
|
|
@ -133,6 +133,56 @@ mod tests {
|
|||
};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_field_def_to_multi() {
|
||||
parses_to! {
|
||||
parser: ZoKratesParser,
|
||||
input: r#"field a = foo()
|
||||
"#,
|
||||
rule: Rule::statement,
|
||||
tokens: [
|
||||
statement(0, 28, [
|
||||
multi_assignment_statement(0, 15, [
|
||||
optionally_typed_identifier(0, 7, [
|
||||
ty(0, 5, [
|
||||
ty_basic(0, 5, [
|
||||
ty_field(0, 5)
|
||||
])
|
||||
]),
|
||||
identifier(6, 7)
|
||||
]),
|
||||
identifier(10, 13),
|
||||
])
|
||||
])
|
||||
]
|
||||
};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_u8_def_to_multi() {
|
||||
parses_to! {
|
||||
parser: ZoKratesParser,
|
||||
input: r#"u32 a = foo()
|
||||
"#,
|
||||
rule: Rule::statement,
|
||||
tokens: [
|
||||
statement(0, 26, [
|
||||
multi_assignment_statement(0, 13, [
|
||||
optionally_typed_identifier(0, 5, [
|
||||
ty(0, 3, [
|
||||
ty_basic(0, 3, [
|
||||
ty_u32(0, 3)
|
||||
])
|
||||
]),
|
||||
identifier(4, 5)
|
||||
]),
|
||||
identifier(8, 11),
|
||||
])
|
||||
])
|
||||
]
|
||||
};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_invalid_identifier() {
|
||||
fails_with! {
|
||||
|
|
|
@ -131,5 +131,5 @@ COMMENT = _{ ("/*" ~ (!"*/" ~ ANY)* ~ "*/") | ("//" ~ (!NEWLINE ~ ANY)*) }
|
|||
|
||||
// TODO: Order by alphabet
|
||||
keyword = @{"as"|"bool"|"byte"|"def"|"do"|"else"|"endfor"|"export"|"false"|"field"|"for"|"if"|"then"|"fi"|"import"|
|
||||
"in"|"private"|"public"|"return"|"struct"|"true"|"uint"
|
||||
"in"|"private"|"public"|"return"|"struct"|"true"|"u8"|"u16"|"u32"
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue