Merge pull request #1322 from Zokrates/fix-nested-assembly
Fix propagation in case of nested assembly
This commit is contained in:
commit
f900437d36
7 changed files with 132 additions and 13 deletions
1
changelogs/unreleased/1322-dark64
Normal file
1
changelogs/unreleased/1322-dark64
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix panic in case of nested assembly blocks
|
|
@ -510,6 +510,38 @@ pub struct ArgumentFinder<'ast, T> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ast, T: Field> Folder<'ast, T> for ArgumentFinder<'ast, T> {
|
impl<'ast, T: Field> Folder<'ast, T> for ArgumentFinder<'ast, T> {
|
||||||
|
fn fold_assembly_block(
|
||||||
|
&mut self,
|
||||||
|
s: zir::AssemblyBlockStatement<'ast, T>,
|
||||||
|
) -> Vec<zir::ZirStatement<'ast, T>> {
|
||||||
|
let mut statements: Vec<_> = s
|
||||||
|
.inner
|
||||||
|
.into_iter()
|
||||||
|
.rev()
|
||||||
|
.flat_map(|s| self.fold_assembly_statement(s))
|
||||||
|
.collect();
|
||||||
|
statements.reverse();
|
||||||
|
vec![zir::ZirStatement::Assembly(
|
||||||
|
zir::AssemblyBlockStatement::new(statements),
|
||||||
|
)]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fold_assembly_assignment(
|
||||||
|
&mut self,
|
||||||
|
s: zir::AssemblyAssignment<'ast, T>,
|
||||||
|
) -> Vec<zir::ZirAssemblyStatement<'ast, T>> {
|
||||||
|
let assignees: Vec<zir::ZirAssignee<'ast>> = s
|
||||||
|
.assignee
|
||||||
|
.into_iter()
|
||||||
|
.map(|v| self.fold_assignee(v))
|
||||||
|
.collect();
|
||||||
|
let expr = self.fold_function(s.expression);
|
||||||
|
for a in &assignees {
|
||||||
|
self.identifiers.remove(&a.id);
|
||||||
|
}
|
||||||
|
vec![zir::ZirAssemblyStatement::assignment(assignees, expr)]
|
||||||
|
}
|
||||||
|
|
||||||
fn fold_statement(&mut self, s: zir::ZirStatement<'ast, T>) -> Vec<zir::ZirStatement<'ast, T>> {
|
fn fold_statement(&mut self, s: zir::ZirStatement<'ast, T>) -> Vec<zir::ZirStatement<'ast, T>> {
|
||||||
match s {
|
match s {
|
||||||
zir::ZirStatement::Definition(s) => {
|
zir::ZirStatement::Definition(s) => {
|
||||||
|
|
|
@ -84,6 +84,27 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn fold_assembly_block(
|
||||||
|
&mut self,
|
||||||
|
s: zokrates_ast::zir::AssemblyBlockStatement<'ast, T>,
|
||||||
|
) -> Result<Vec<ZirStatement<'ast, T>>, Self::Error> {
|
||||||
|
let block: Vec<_> = s
|
||||||
|
.inner
|
||||||
|
.into_iter()
|
||||||
|
.map(|s| self.fold_assembly_statement(s))
|
||||||
|
.collect::<Result<Vec<_>, _>>()?
|
||||||
|
.into_iter()
|
||||||
|
.flatten()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(match block.is_empty() {
|
||||||
|
true => vec![],
|
||||||
|
false => vec![ZirStatement::Assembly(
|
||||||
|
zokrates_ast::zir::AssemblyBlockStatement::new(block),
|
||||||
|
)],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn fold_assembly_assignment(
|
fn fold_assembly_assignment(
|
||||||
&mut self,
|
&mut self,
|
||||||
s: zokrates_ast::zir::AssemblyAssignment<'ast, T>,
|
s: zokrates_ast::zir::AssemblyAssignment<'ast, T>,
|
||||||
|
|
|
@ -9,11 +9,16 @@ pub struct ZirCanonicalizer<'ast> {
|
||||||
|
|
||||||
impl<'ast, T: Field> Folder<'ast, T> for ZirCanonicalizer<'ast> {
|
impl<'ast, T: Field> Folder<'ast, T> for ZirCanonicalizer<'ast> {
|
||||||
fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> {
|
fn fold_parameter(&mut self, p: Parameter<'ast>) -> Parameter<'ast> {
|
||||||
let new_id = self.identifier_map.len();
|
let id = match self.identifier_map.get(&p.id.id) {
|
||||||
self.identifier_map.insert(p.id.id.clone(), new_id);
|
Some(v) => Identifier::internal(*v),
|
||||||
|
None => {
|
||||||
|
let new_id = self.identifier_map.len();
|
||||||
|
self.identifier_map.insert(p.id.id.clone(), new_id);
|
||||||
|
Identifier::internal(new_id)
|
||||||
|
}
|
||||||
|
};
|
||||||
Parameter {
|
Parameter {
|
||||||
id: Variable::with_id_and_type(Identifier::internal(new_id), p.id.ty),
|
id: Variable::with_id_and_type(id, p.id.ty),
|
||||||
..p
|
..p
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,8 +63,8 @@ pub struct ZirFunction<'ast, T> {
|
||||||
pub type IdentifierOrExpression<'ast, T, E> =
|
pub type IdentifierOrExpression<'ast, T, E> =
|
||||||
expressions::IdentifierOrExpression<Identifier<'ast>, E, <E as Expr<'ast, T>>::Inner>;
|
expressions::IdentifierOrExpression<Identifier<'ast>, E, <E as Expr<'ast, T>>::Inner>;
|
||||||
|
|
||||||
impl<'ast, T: fmt::Display> fmt::Display for ZirFunction<'ast, T> {
|
impl<'ast, T: fmt::Display> ZirFunction<'ast, T> {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt_indented(&self, f: &mut fmt::Formatter, depth: usize) -> fmt::Result {
|
||||||
writeln!(
|
writeln!(
|
||||||
f,
|
f,
|
||||||
"({}) -> ({}) {{",
|
"({}) -> ({}) {{",
|
||||||
|
@ -82,11 +82,17 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirFunction<'ast, T> {
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
for s in &self.statements {
|
for s in &self.statements {
|
||||||
s.fmt_indented(f, 1)?;
|
s.fmt_indented(f, depth + 1)?;
|
||||||
writeln!(f)?;
|
writeln!(f)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
write!(f, "}}")
|
write!(f, "{}}}", "\t".repeat(depth))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ast, T: fmt::Display> fmt::Display for ZirFunction<'ast, T> {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
self.fmt_indented(f, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -174,19 +180,30 @@ impl<'ast, T> WithSpan for ZirAssemblyStatement<'ast, T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ast, T: fmt::Display> fmt::Display for ZirAssemblyStatement<'ast, T> {
|
impl<'ast, T: fmt::Display> ZirAssemblyStatement<'ast, T> {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt_indented(&self, f: &mut fmt::Formatter, depth: usize) -> fmt::Result {
|
||||||
|
write!(f, "{}", "\t".repeat(depth))?;
|
||||||
match *self {
|
match *self {
|
||||||
ZirAssemblyStatement::Assignment(ref s) => {
|
ZirAssemblyStatement::Assignment(ref s) => {
|
||||||
write!(
|
write!(
|
||||||
f,
|
f,
|
||||||
"{} <-- {};",
|
"{} <-- (",
|
||||||
s.assignee
|
s.assignee
|
||||||
.iter()
|
.iter()
|
||||||
.map(|a| a.to_string())
|
.map(|a| a.to_string())
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join(", "),
|
.join(", ")
|
||||||
|
)?;
|
||||||
|
s.expression.fmt_indented(f, depth)?;
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
")({});",
|
||||||
s.expression
|
s.expression
|
||||||
|
.arguments
|
||||||
|
.iter()
|
||||||
|
.map(|a| a.id.id.to_string())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(", ")
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
ZirAssemblyStatement::Constraint(ref s) => {
|
ZirAssemblyStatement::Constraint(ref s) => {
|
||||||
|
@ -196,6 +213,12 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirAssemblyStatement<'ast, T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'ast, T: fmt::Display> fmt::Display for ZirAssemblyStatement<'ast, T> {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
self.fmt_indented(f, 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub type DefinitionStatement<'ast, T> =
|
pub type DefinitionStatement<'ast, T> =
|
||||||
common::expressions::DefinitionStatement<ZirAssignee<'ast>, ZirExpression<'ast, T>>;
|
common::expressions::DefinitionStatement<ZirAssignee<'ast>, ZirExpression<'ast, T>>;
|
||||||
pub type AssertionStatement<'ast, T> =
|
pub type AssertionStatement<'ast, T> =
|
||||||
|
@ -462,7 +485,8 @@ impl<'ast, T: fmt::Display> ZirStatement<'ast, T> {
|
||||||
ZirStatement::Assembly(s) => {
|
ZirStatement::Assembly(s) => {
|
||||||
writeln!(f, "asm {{")?;
|
writeln!(f, "asm {{")?;
|
||||||
for s in &s.inner {
|
for s in &s.inner {
|
||||||
writeln!(f, "{}{}", "\t".repeat(depth + 1), s)?;
|
s.fmt_indented(f, depth + 1)?;
|
||||||
|
writeln!(f)?;
|
||||||
}
|
}
|
||||||
write!(f, "{}}}", "\t".repeat(depth))
|
write!(f, "{}}}", "\t".repeat(depth))
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
{
|
||||||
|
"curves": ["Bn128"],
|
||||||
|
"max_constraint_count": 1,
|
||||||
|
"tests": [
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"values": ["4"]
|
||||||
|
},
|
||||||
|
"output": {
|
||||||
|
"Ok": {
|
||||||
|
"value": "32"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
|
@ -0,0 +1,20 @@
|
||||||
|
def bar(field a) -> field {
|
||||||
|
field tmp = a * a;
|
||||||
|
field mut b = 0;
|
||||||
|
asm {
|
||||||
|
b <-- tmp * 2;
|
||||||
|
}
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
def foo(field a) -> field {
|
||||||
|
field mut b = 0;
|
||||||
|
asm {
|
||||||
|
b <-- bar(a);
|
||||||
|
}
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
def main(field a) -> field {
|
||||||
|
return foo(a);
|
||||||
|
}
|
Loading…
Reference in a new issue