Merge pull request #1322 from Zokrates/fix-nested-assembly
Fix propagation in case of nested assembly
This commit is contained in:
commit
f900437d36
7 changed files with 132 additions and 13 deletions
1
changelogs/unreleased/1322-dark64
Normal file
1
changelogs/unreleased/1322-dark64
Normal file
|
@ -0,0 +1 @@
|
|||
Fix panic in case of nested assembly blocks
|
|
@ -510,6 +510,38 @@ pub struct ArgumentFinder<'ast, T> {
|
|||
}
|
||||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for ArgumentFinder<'ast, T> {
|
||||
fn fold_assembly_block(
|
||||
&mut self,
|
||||
s: zir::AssemblyBlockStatement<'ast, T>,
|
||||
) -> Vec<zir::ZirStatement<'ast, T>> {
|
||||
let mut statements: Vec<_> = s
|
||||
.inner
|
||||
.into_iter()
|
||||
.rev()
|
||||
.flat_map(|s| self.fold_assembly_statement(s))
|
||||
.collect();
|
||||
statements.reverse();
|
||||
vec![zir::ZirStatement::Assembly(
|
||||
zir::AssemblyBlockStatement::new(statements),
|
||||
)]
|
||||
}
|
||||
|
||||
fn fold_assembly_assignment(
|
||||
&mut self,
|
||||
s: zir::AssemblyAssignment<'ast, T>,
|
||||
) -> Vec<zir::ZirAssemblyStatement<'ast, T>> {
|
||||
let assignees: Vec<zir::ZirAssignee<'ast>> = s
|
||||
.assignee
|
||||
.into_iter()
|
||||
.map(|v| self.fold_assignee(v))
|
||||
.collect();
|
||||
let expr = self.fold_function(s.expression);
|
||||
for a in &assignees {
|
||||
self.identifiers.remove(&a.id);
|
||||
}
|
||||
vec![zir::ZirAssemblyStatement::assignment(assignees, expr)]
|
||||
}
|
||||
|
||||
fn fold_statement(&mut self, s: zir::ZirStatement<'ast, T>) -> Vec<zir::ZirStatement<'ast, T>> {
|
||||
match s {
|
||||
zir::ZirStatement::Definition(s) => {
|
||||
|
|
|
@ -84,6 +84,27 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
|||
})
|
||||
}
|
||||
|
||||
fn fold_assembly_block(
|
||||
&mut self,
|
||||
s: zokrates_ast::zir::AssemblyBlockStatement<'ast, T>,
|
||||
) -> Result<Vec<ZirStatement<'ast, T>>, Self::Error> {
|
||||
let block: Vec<_> = s
|
||||
.inner
|
||||
.into_iter()
|
||||
.map(|s| self.fold_assembly_statement(s))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect();
|
||||
|
||||
Ok(match block.is_empty() {
|
||||
true => vec![],
|
||||
false => vec![ZirStatement::Assembly(
|
||||
zokrates_ast::zir::AssemblyBlockStatement::new(block),
|
||||
)],
|
||||
})
|
||||
}
|
||||
|
||||
fn fold_assembly_assignment(
|
||||
&mut self,
|
||||
s: zokrates_ast::zir::AssemblyAssignment<'ast, T>,
|
||||
|
|
|
@ -9,11 +9,16 @@ pub struct ZirCanonicalizer<'ast> {
|
|||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for ZirCanonicalizer<'ast> {
|
||||
fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> {
|
||||
let id = match self.identifier_map.get(&p.id.id) {
|
||||
Some(v) => Identifier::internal(*v),
|
||||
None => {
|
||||
let new_id = self.identifier_map.len();
|
||||
self.identifier_map.insert(p.id.id.clone(), new_id);
|
||||
|
||||
Identifier::internal(new_id)
|
||||
}
|
||||
};
|
||||
Parameter {
|
||||
id: Variable::with_id_and_type(Identifier::internal(new_id), p.id.ty),
|
||||
id: Variable::with_id_and_type(id, p.id.ty),
|
||||
..p
|
||||
}
|
||||
}
|
||||
|
|
|
@ -63,8 +63,8 @@ pub struct ZirFunction<'ast, T> {
|
|||
pub type IdentifierOrExpression<'ast, T, E> =
|
||||
expressions::IdentifierOrExpression<Identifier<'ast>, E, <E as Expr<'ast, T>>::Inner>;
|
||||
|
||||
impl<'ast, T: fmt::Display> fmt::Display for ZirFunction<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
impl<'ast, T: fmt::Display> ZirFunction<'ast, T> {
|
||||
fn fmt_indented(&self, f: &mut fmt::Formatter, depth: usize) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"({}) -> ({}) {{",
|
||||
|
@ -82,11 +82,17 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirFunction<'ast, T> {
|
|||
)?;
|
||||
|
||||
for s in &self.statements {
|
||||
s.fmt_indented(f, 1)?;
|
||||
s.fmt_indented(f, depth + 1)?;
|
||||
writeln!(f)?;
|
||||
}
|
||||
|
||||
write!(f, "}}")
|
||||
write!(f, "{}}}", "\t".repeat(depth))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> fmt::Display for ZirFunction<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
self.fmt_indented(f, 0)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -174,19 +180,30 @@ impl<'ast, T> WithSpan for ZirAssemblyStatement<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> fmt::Display for ZirAssemblyStatement<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
impl<'ast, T: fmt::Display> ZirAssemblyStatement<'ast, T> {
|
||||
fn fmt_indented(&self, f: &mut fmt::Formatter, depth: usize) -> fmt::Result {
|
||||
write!(f, "{}", "\t".repeat(depth))?;
|
||||
match *self {
|
||||
ZirAssemblyStatement::Assignment(ref s) => {
|
||||
write!(
|
||||
f,
|
||||
"{} <-- {};",
|
||||
"{} <-- (",
|
||||
s.assignee
|
||||
.iter()
|
||||
.map(|a| a.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", "),
|
||||
.join(", ")
|
||||
)?;
|
||||
s.expression.fmt_indented(f, depth)?;
|
||||
write!(
|
||||
f,
|
||||
")({});",
|
||||
s.expression
|
||||
.arguments
|
||||
.iter()
|
||||
.map(|a| a.id.id.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ")
|
||||
)
|
||||
}
|
||||
ZirAssemblyStatement::Constraint(ref s) => {
|
||||
|
@ -196,6 +213,12 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirAssemblyStatement<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: fmt::Display> fmt::Display for ZirAssemblyStatement<'ast, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
self.fmt_indented(f, 0)
|
||||
}
|
||||
}
|
||||
|
||||
pub type DefinitionStatement<'ast, T> =
|
||||
common::expressions::DefinitionStatement<ZirAssignee<'ast>, ZirExpression<'ast, T>>;
|
||||
pub type AssertionStatement<'ast, T> =
|
||||
|
@ -462,7 +485,8 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> {
|
|||
ZirStatement::Assembly(s) => {
|
||||
writeln!(f, "asm {{")?;
|
||||
for s in &s.inner {
|
||||
writeln!(f, "{}{}", "\t".repeat(depth + 1), s)?;
|
||||
s.fmt_indented(f, depth + 1)?;
|
||||
writeln!(f)?;
|
||||
}
|
||||
write!(f, "{}}}", "\t".repeat(depth))
|
||||
}
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
"curves": ["Bn128"],
|
||||
"max_constraint_count": 1,
|
||||
"tests": [
|
||||
{
|
||||
"input": {
|
||||
"values": ["4"]
|
||||
},
|
||||
"output": {
|
||||
"Ok": {
|
||||
"value": "32"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
def bar(field a) -> field {
|
||||
field tmp = a * a;
|
||||
field mut b = 0;
|
||||
asm {
|
||||
b <-- tmp * 2;
|
||||
}
|
||||
return b;
|
||||
}
|
||||
|
||||
def foo(field a) -> field {
|
||||
field mut b = 0;
|
||||
asm {
|
||||
b <-- bar(a);
|
||||
}
|
||||
return b;
|
||||
}
|
||||
|
||||
def main(field a) -> field {
|
||||
return foo(a);
|
||||
}
|
Loading…
Reference in a new issue