1
0
Fork 0
mirror of synced 2025-09-23 20:28:36 +00:00

fix parser, fmt

This commit is contained in:
schaeff 2020-03-02 11:42:24 +01:00
parent f8521fd33c
commit 6d65801b9b
9 changed files with 192 additions and 76 deletions

View file

@ -5,6 +5,7 @@ use zokrates_pest_ast as pest;
impl<'ast, T: Field> From<pest::File<'ast>> for absy::Module<'ast, T> {
fn from(prog: pest::File<'ast>) -> absy::Module<T> {
println!("{:#?}", prog);
absy::Module::with_symbols(
prog.structs
.into_iter()

View file

@ -752,8 +752,11 @@ impl<'ast, T: Field> Flattener<'ast, T> {
statements_flattened: &mut Vec<FlatStatement<T>>,
expr: UExpression<'ast, T>,
) -> FlatUExpression<T> {
statements_flattened.push(FlatStatement::Log(format!(" {}{}", " ".repeat(self.depth), expr)));
statements_flattened.push(FlatStatement::Log(format!(
" {}{}",
" ".repeat(self.depth),
expr
)));
self.depth += 1;
@ -802,7 +805,9 @@ impl<'ast, T: Field> Flattener<'ast, T> {
},
));
FlatUExpression::with_bits(name_not.into_iter().map(|v| v.into()).collect::<Vec<_>>())
FlatUExpression::with_bits(
name_not.into_iter().map(|v| v.into()).collect::<Vec<_>>(),
)
}
UExpressionInner::Add(box left, box right) => {
let left_flattened = self
@ -1114,10 +1119,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
self.depth -= 1;
statements_flattened.push(FlatStatement::Log(format!(" {} DONE", " ".repeat(self.depth))));
statements_flattened.push(FlatStatement::Log(format!(
" {} DONE",
" ".repeat(self.depth)
)));
res
}
fn get_bits(

View file

@ -2070,7 +2070,11 @@ impl<'ast> Checker<'ast> {
(e1, e2) => Err(ErrorInner {
pos: Some(pos),
message: format!("cannot left-shift {} by {}", e1.get_type(), e2.get_type()),
message: format!(
"cannot left-shift {} by {}",
e1.get_type(),
e2.get_type()
),
}),
}
}
@ -2084,7 +2088,11 @@ impl<'ast> Checker<'ast> {
(e1, e2) => Err(ErrorInner {
pos: Some(pos),
message: format!("cannot right-shift {} by {}", e1.get_type(), e2.get_type()),
message: format!(
"cannot right-shift {} by {}",
e1.get_type(),
e2.get_type()
),
}),
}
}

View file

@ -87,7 +87,6 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
)),
// propagation to the defined variable if rhs is a constant
TypedStatement::Definition(TypedAssignee::Identifier(var), expr) => {
let expr = self.fold_expression(expr);
if is_constant(&expr) {
@ -137,11 +136,13 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
let expression_list = self.fold_expression_list(expression_list);
match expression_list {
TypedExpressionList::FunctionCall(key, arguments, types) => {
let arguments: Vec<_> = arguments.into_iter().map(|a| self.fold_expression(a)).collect();
let arguments: Vec<_> = arguments
.into_iter()
.map(|a| self.fold_expression(a))
.collect();
if arguments.iter().all(|a| is_constant(a)) {
let expr: TypedExpression<'ast, T> = if key.id == "_U32_FROM_BITS" {
assert_eq!(variables.len(), 1);
assert_eq!(arguments.len(), 1);
@ -152,25 +153,27 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
ArrayExpressionInner::Value(v) => {
assert_eq!(v.len(), 32);
UExpressionInner::Value(
v.into_iter()
.map(|v| match v {
TypedExpression::Boolean(
BooleanExpression::Value(v),
) => v,
_ => unreachable!(),
})
.enumerate()
.fold(0, |acc, (i, v)| {
if v {
acc + 2u128
.pow((32 - i - 1).try_into().unwrap())
} else {
acc
}
}),
)
.annotate(32)
.into()},
v.into_iter()
.map(|v| match v {
TypedExpression::Boolean(
BooleanExpression::Value(v),
) => v,
_ => unreachable!(),
})
.enumerate()
.fold(0, |acc, (i, v)| {
if v {
acc + 2u128.pow(
(32 - i - 1).try_into().unwrap(),
)
} else {
acc
}
}),
)
.annotate(32)
.into()
}
_ => unreachable!(),
},
_ => unreachable!(),
@ -195,7 +198,11 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
}
assert_eq!(num, 0);
ArrayExpressionInner::Value(res.into_iter().map(|v| BooleanExpression::Value(v).into()).collect())
ArrayExpressionInner::Value(
res.into_iter()
.map(|v| BooleanExpression::Value(v).into())
.collect(),
)
.annotate(Type::Boolean, 32)
.into()
}
@ -219,7 +226,9 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
false => None,
}
} else {
if self.verbose {println!("not constant!")};
if self.verbose {
println!("not constant!")
};
let l = TypedExpressionList::FunctionCall(key, arguments, types);
Some(TypedStatement::MultipleDefinition(variables, l))
}
@ -282,7 +291,9 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
) {
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
use std::convert::TryInto;
UExpressionInner::Value((v1.wrapping_sub(v2)) % 2_u128.pow(bitwidth.try_into().unwrap()))
UExpressionInner::Value(
(v1.wrapping_sub(v2)) % 2_u128.pow(bitwidth.try_into().unwrap()),
)
}
(e, UExpressionInner::Value(v)) | (UExpressionInner::Value(v), e) => match v {
0 => e,
@ -323,17 +334,17 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
let by_as_usize = by.to_dec_string().parse::<usize>().unwrap();
UExpressionInner::Value(v >> by_as_usize)
}
(e, FieldElementExpression::Number(by)) => {
UExpressionInner::RightShift(box e.annotate(bitwidth), box FieldElementExpression::Number(by))
}
(e, FieldElementExpression::Number(by)) => UExpressionInner::RightShift(
box e.annotate(bitwidth),
box FieldElementExpression::Number(by),
),
(_, e2) => unreachable!(format!(
"non-constant shift {} detected during static analysis",
e2
)),
}
},
}
UExpressionInner::LeftShift(box e, box by) => {
println!("LEFT?");
let e = self.fold_uint_expression(e);
@ -343,15 +354,16 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
let by_as_usize = by.to_dec_string().parse::<usize>().unwrap();
UExpressionInner::Value((v << by_as_usize) & 0xffffffff)
}
(e, FieldElementExpression::Number(by)) => {
UExpressionInner::LeftShift(box e.annotate(bitwidth), box FieldElementExpression::Number(by))
}
(e, FieldElementExpression::Number(by)) => UExpressionInner::LeftShift(
box e.annotate(bitwidth),
box FieldElementExpression::Number(by),
),
(_, e2) => unreachable!(format!(
"non-constant shift {} detected during static analysis",
e2
)),
}
},
}
UExpressionInner::Xor(box e1, box e2) => match (
self.fold_uint_expression(e1).into_inner(),
self.fold_uint_expression(e2).into_inner(),
@ -621,7 +633,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
},
None => ArrayExpressionInner::Identifier(id),
}
},
}
ArrayExpressionInner::Select(box array, box index) => {
let array = self.fold_array_expression(array);
let index = self.fold_field_expression(index);
@ -978,9 +990,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
),
}
}
(a, i) => {
BooleanExpression::Select(box a.annotate(inner_type, size), box i)
}
(a, i) => BooleanExpression::Select(box a.annotate(inner_type, size), box i),
}
}
BooleanExpression::Member(box s, m) => {

View file

@ -517,39 +517,68 @@ impl<'ast, T: Field> Folder<'ast, T> for UintOptimizer<'ast, T> {
})
.collect(),
)],
ZirStatement::MultipleDefinition(lhs, rhs) => {
match rhs {
ZirExpressionList::FunctionCall(key, arguments, ty) => {
match key.clone().id {
"_U32_FROM_BITS" => {
assert_eq!(lhs.len(), 1);
let expr = UExpressionInner::FunctionCall(key.clone(), arguments.clone().into_iter().map(|a| self.fold_expression(a)).collect()).annotate(32).metadata(UMetadata {
bitwidth: Some(32),
should_reduce: Some(true)
});
self.register(lhs[0].clone(), ZirExpression::Uint(expr));
vec![ZirStatement::MultipleDefinition(lhs, ZirExpressionList::FunctionCall(key, arguments, ty))]
},
"_U32_TO_BITS" => {
assert_eq!(lhs.len(), 32);
vec![ZirStatement::MultipleDefinition(lhs, ZirExpressionList::FunctionCall(key, arguments.into_iter().map(|e| self.fold_expression(e)).collect(), ty))]
},
_ => {
unimplemented!()
}
}
ZirStatement::MultipleDefinition(lhs, rhs) => match rhs {
ZirExpressionList::FunctionCall(key, arguments, ty) => match key.clone().id {
"_U32_FROM_BITS" => {
assert_eq!(lhs.len(), 1);
let expr = UExpressionInner::FunctionCall(
key.clone(),
arguments
.clone()
.into_iter()
.map(|a| self.fold_expression(a))
.collect(),
)
.annotate(32)
.metadata(UMetadata {
bitwidth: Some(32),
should_reduce: Some(true),
});
self.register(lhs[0].clone(), ZirExpression::Uint(expr));
vec![ZirStatement::MultipleDefinition(
lhs,
ZirExpressionList::FunctionCall(key, arguments, ty),
)]
}
}
"_U32_TO_BITS" => {
assert_eq!(lhs.len(), 32);
vec![ZirStatement::MultipleDefinition(
lhs,
ZirExpressionList::FunctionCall(
key,
arguments
.into_iter()
.map(|e| self.fold_expression(e))
.collect(),
ty,
),
)]
}
_ => unimplemented!(),
},
},
// we need to put back in range to assert
ZirStatement::Condition(lhs, rhs) => match (self.fold_expression(lhs), self.fold_expression(rhs)) {
(ZirExpression::Uint(lhs), ZirExpression::Uint(rhs)) => {
let lhs_metadata = lhs.metadata.clone().unwrap();
let rhs_metadata = rhs.metadata.clone().unwrap();
vec![ZirStatement::Condition(lhs.metadata(UMetadata { should_reduce: Some(true), ..lhs_metadata}).into(), rhs.metadata(UMetadata { should_reduce: Some(true), ..rhs_metadata}).into())]
},
(lhs, rhs) => vec![ZirStatement::Condition(lhs, rhs)]
},
ZirStatement::Condition(lhs, rhs) => {
match (self.fold_expression(lhs), self.fold_expression(rhs)) {
(ZirExpression::Uint(lhs), ZirExpression::Uint(rhs)) => {
let lhs_metadata = lhs.metadata.clone().unwrap();
let rhs_metadata = rhs.metadata.clone().unwrap();
vec![ZirStatement::Condition(
lhs.metadata(UMetadata {
should_reduce: Some(true),
..lhs_metadata
})
.into(),
rhs.metadata(UMetadata {
should_reduce: Some(true),
..rhs_metadata
})
.into(),
)]
}
(lhs, rhs) => vec![ZirStatement::Condition(lhs, rhs)],
}
}
s => fold_statement(self, s),
}
}

View file

@ -0,0 +1,5 @@
{
"entry_point": "./tests/tests/single_return.zok",
"tests": [
]
}

View file

@ -0,0 +1,6 @@
def foo() -> (field):
return 42
def main() -> ():
field a = foo()
return

View file

@ -133,6 +133,56 @@ mod tests {
};
}
#[test]
fn parse_field_def_to_multi() {
parses_to! {
parser: ZoKratesParser,
input: r#"field a = foo()
"#,
rule: Rule::statement,
tokens: [
statement(0, 28, [
multi_assignment_statement(0, 15, [
optionally_typed_identifier(0, 7, [
ty(0, 5, [
ty_basic(0, 5, [
ty_field(0, 5)
])
]),
identifier(6, 7)
]),
identifier(10, 13),
])
])
]
};
}
#[test]
fn parse_u8_def_to_multi() {
parses_to! {
parser: ZoKratesParser,
input: r#"u32 a = foo()
"#,
rule: Rule::statement,
tokens: [
statement(0, 26, [
multi_assignment_statement(0, 13, [
optionally_typed_identifier(0, 5, [
ty(0, 3, [
ty_basic(0, 3, [
ty_u32(0, 3)
])
]),
identifier(4, 5)
]),
identifier(8, 11),
])
])
]
};
}
#[test]
fn parse_invalid_identifier() {
fails_with! {

View file

@ -131,5 +131,5 @@ COMMENT = _{ ("/*" ~ (!"*/" ~ ANY)* ~ "*/") | ("//" ~ (!NEWLINE ~ ANY)*) }
// TODO: Order by alphabet
keyword = @{"as"|"bool"|"byte"|"def"|"do"|"else"|"endfor"|"export"|"false"|"field"|"for"|"if"|"then"|"fi"|"import"|
"in"|"private"|"public"|"return"|"struct"|"true"|"uint"
"in"|"private"|"public"|"return"|"struct"|"true"|"u8"|"u16"|"u32"
}