1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

implement propagation for assembly blocks. wip

This commit is contained in:
schaeff 2022-11-29 17:26:27 +01:00
parent fbfc20c4e6
commit 33a3043fe4
23 changed files with 302 additions and 103 deletions

View file

@ -31,9 +31,9 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for AssemblyTransformer {
fn fold_assembly_statement(
&mut self,
s: ZirAssemblyStatement<'ast, T>,
) -> Result<ZirAssemblyStatement<'ast, T>, Self::Error> {
) -> Result<Vec<ZirAssemblyStatement<'ast, T>>, Self::Error> {
match s {
ZirAssemblyStatement::Assignment(_, _) => Ok(s),
ZirAssemblyStatement::Assignment(_, _) => Ok(vec![s]),
ZirAssemblyStatement::Constraint(lhs, rhs) => {
let lhs = self.fold_field_expression(lhs)?;
let rhs = self.fold_field_expression(rhs)?;
@ -53,7 +53,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for AssemblyTransformer {
};
match is_quadratic {
true => Ok(ZirAssemblyStatement::Constraint(lhs, rhs)),
true => Ok(vec![ZirAssemblyStatement::Constraint(lhs, rhs)]),
false => {
let sub = FieldElementExpression::Sub(box lhs, box rhs);
let mut lqc = LinQuadComb::try_from(sub.clone()).map_err(|_| {
@ -156,7 +156,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for AssemblyTransformer {
.fold_field_expression(rhs)
.map_err(|e| Error(e.to_string()))?;
Ok(ZirAssemblyStatement::Constraint(lhs, rhs))
Ok(vec![ZirAssemblyStatement::Constraint(lhs, rhs)])
}
}
}

View file

@ -528,7 +528,10 @@ fn fold_assembly_statement<'ast, T: Field>(
match s {
typed::TypedAssemblyStatement::Assignment(a, e) => {
let mut statements_buffer: Vec<zir::ZirStatement<'ast, T>> = vec![];
let a = f.fold_assignee(a);
let mut a = f.fold_assignee(a);
assert_eq!(a.len(), 1);
let a = a.pop().unwrap();
assert_eq!(a.get_type(), zir::Type::FieldElement);
let e = f.fold_field_expression(&mut statements_buffer, e);
statements_buffer.push(zir::ZirStatement::Return(vec![
zir::ZirExpression::FieldElement(e),
@ -545,7 +548,7 @@ fn fold_assembly_statement<'ast, T: Field>(
let function = zir::ZirFunction {
signature: zir::types::Signature::default()
.inputs(finder.identifiers.values().cloned().collect())
.outputs(a.iter().map(|a| a.get_type()).collect()),
.outputs(vec![a.get_type()]),
arguments: finder
.identifiers
.into_iter()

View file

@ -212,42 +212,102 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
)
}
fn fold_assembly_statement(
&mut self,
s: TypedAssemblyStatement<'ast, T>,
) -> Result<Vec<TypedAssemblyStatement<'ast, T>>, Self::Error> {
match s {
TypedAssemblyStatement::Assignment(assignee, expr) => {
let assignee = self.fold_assignee(assignee)?;
let expr = self.fold_field_expression(expr)?;
let expr = TypedExpression::from(expr);
if expr.is_constant() {
match assignee {
TypedAssignee::Identifier(var) => {
let expr = expr.into_canonical_constant();
assert!(self.constants.insert(var.id, expr).is_none());
Ok(vec![])
}
assignee => match self.try_get_constant_mut(&assignee) {
Ok((_, c)) => {
*c = expr.into_canonical_constant();
Ok(vec![])
}
Err(v) => match self.constants.remove(&v.id) {
// invalidate the cache for this identifier, and define the latest
// version of the constant in the program, if any
Some(c) => Ok(vec![
TypedAssemblyStatement::Assignment(v.clone().into(), c.into()),
TypedAssemblyStatement::Assignment(assignee, expr.into()),
]),
None => Ok(vec![TypedAssemblyStatement::Assignment(
assignee,
expr.into(),
)]),
},
},
}
} else {
// the expression being assigned is not constant, invalidate the cache
let v = self
.try_get_constant_mut(&assignee)
.map(|(v, _)| v)
.unwrap_or_else(|v| v);
match self.constants.remove(&v.id) {
Some(c) => Ok(vec![
TypedAssemblyStatement::Assignment(v.clone().into(), c.into()),
TypedAssemblyStatement::Assignment(assignee, expr.into()),
]),
None => Ok(vec![TypedAssemblyStatement::Assignment(
assignee,
expr.into(),
)]),
}
}
}
TypedAssemblyStatement::Constraint(left, right) => {
let left = self.fold_field_expression(left)?;
let right = self.fold_field_expression(right)?;
// a bit hacky, but we use a fake boolean expression to check this
let is_equal =
BooleanExpression::FieldEq(EqExpression::new(left.clone(), right.clone()));
let is_equal = self.fold_boolean_expression(is_equal)?;
match is_equal {
BooleanExpression::Value(true) => Ok(vec![]),
BooleanExpression::Value(false) => Err(Error::AssertionFailed(format!(
"In asm block: `{} !== {}`",
left, right
))),
_ => Ok(vec![TypedAssemblyStatement::Constraint(left, right)]),
}
}
}
}
fn fold_statement(
&mut self,
s: TypedStatement<'ast, T>,
) -> Result<Vec<TypedStatement<'ast, T>>, Error> {
match s {
TypedStatement::Assembly(statements) => {
let mut assembly_statement_buffer = vec![];
let mut statement_buffer = vec![];
for s in statements {
match self.fold_assembly_statement(s)? {
TypedAssemblyStatement::Assignment(assignee, expr) => {
// invalidate the cache
let v = self
.try_get_constant_mut(&assignee)
.map(|(v, _)| v)
.unwrap_or_else(|v| v);
match self.constants.remove(&v.id) {
Some(c) => {
statement_buffer.push(TypedStatement::Definition(
v.clone().into(),
c.into(),
));
}
None => {}
}
assembly_statement_buffer
.push(TypedAssemblyStatement::Assignment(assignee, expr));
}
s => assembly_statement_buffer.push(s),
}
let statements: Vec<_> = statements
.into_iter()
.map(|s| self.fold_assembly_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect();
match statements.len() {
0 => Ok(vec![]),
_ => Ok(vec![TypedStatement::Assembly(statements)]),
}
statement_buffer.push(TypedStatement::Assembly(assembly_statement_buffer));
Ok(statement_buffer)
}
// propagation to the defined variable if rhs is a constant
TypedStatement::Definition(assignee, DefinitionRhs::Expression(expr)) => {

View file

@ -127,7 +127,7 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> {
fn fold_assembly_statement(
&mut self,
s: TypedAssemblyStatement<'ast, T>,
) -> TypedAssemblyStatement<'ast, T> {
) -> Vec<TypedAssemblyStatement<'ast, T>> {
match s {
TypedAssemblyStatement::Assignment(a, e) => {
let e = self.fold_field_expression(e);
@ -138,7 +138,7 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> {
}
a => fold_assignee(self, a),
};
TypedAssemblyStatement::Assignment(a, e)
vec![TypedAssemblyStatement::Assignment(a, e)]
}
s => fold_assembly_statement(self, s),
}

View file

@ -60,18 +60,55 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
fn fold_assembly_statement(
&mut self,
s: ZirAssemblyStatement<'ast, T>,
) -> Result<ZirAssemblyStatement<'ast, T>, Self::Error> {
) -> Result<Vec<ZirAssemblyStatement<'ast, T>>, Self::Error> {
match s {
ZirAssemblyStatement::Assignment(assignees, function) => {
for a in &assignees {
self.constants.remove(&a.id);
ZirAssemblyStatement::Assignment(assignee, function) => {
let function = self.fold_function(function)?;
if function.statements.len() == 1 {
let value = match &function.statements.last().unwrap() {
ZirStatement::Return(values) => {
assert_eq!(values.len(), 1);
match values[0].clone() {
ZirExpression::FieldElement(FieldElementExpression::Number(v)) => {
Some(v)
}
_ => None,
}
}
_ => None,
};
match value {
Some(v) => {
self.constants
.insert(assignee.id, FieldElementExpression::Number(v).into());
Ok(vec![])
}
None => Ok(vec![ZirAssemblyStatement::Assignment(assignee, function)]),
}
} else {
Ok(vec![ZirAssemblyStatement::Assignment(assignee, function)])
}
}
ZirAssemblyStatement::Constraint(left, right) => {
let left = self.fold_field_expression(left)?;
let right = self.fold_field_expression(right)?;
// a bit hacky, but we use a fake boolean expression to check this
let is_equal = BooleanExpression::FieldEq(box left.clone(), box right.clone());
let is_equal = self.fold_boolean_expression(is_equal)?;
match is_equal {
BooleanExpression::Value(true) => Ok(vec![]),
BooleanExpression::Value(false) => {
Err(Error::AssertionFailed(RuntimeError::SourceAssertion(
format!("In asm block: `{} !== {}`", left, right),
)))
}
_ => Ok(vec![ZirAssemblyStatement::Constraint(left, right)]),
}
Ok(ZirAssemblyStatement::Assignment(
assignees,
self.fold_function(function)?,
))
}
s => fold_assembly_statement(self, s),
}
}
@ -147,6 +184,19 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
self.fold_expression_list(list)?,
)])
}
ZirStatement::Assembly(statements) => {
let statements: Vec<_> = statements
.into_iter()
.map(|s| self.fold_assembly_statement(s))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect();
match statements.len() {
0 => Ok(vec![]),
_ => Ok(vec![ZirStatement::Assembly(statements)]),
}
}
_ => fold_statement(self, s),
}
}

View file

@ -263,7 +263,7 @@ pub trait Folder<'ast, T: Field>: Sized {
fn fold_assembly_statement(
&mut self,
s: TypedAssemblyStatement<'ast, T>,
) -> TypedAssemblyStatement<'ast, T> {
) -> Vec<TypedAssemblyStatement<'ast, T>> {
fold_assembly_statement(self, s)
}
@ -525,15 +525,18 @@ pub fn fold_definition_rhs<'ast, T: Field, F: Folder<'ast, T>>(
pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
s: TypedAssemblyStatement<'ast, T>,
) -> TypedAssemblyStatement<'ast, T> {
) -> Vec<TypedAssemblyStatement<'ast, T>> {
match s {
TypedAssemblyStatement::Assignment(a, e) => {
TypedAssemblyStatement::Assignment(f.fold_assignee(a), f.fold_field_expression(e))
vec![TypedAssemblyStatement::Assignment(
f.fold_assignee(a),
f.fold_field_expression(e),
)]
}
TypedAssemblyStatement::Constraint(lhs, rhs) => TypedAssemblyStatement::Constraint(
TypedAssemblyStatement::Constraint(lhs, rhs) => vec![TypedAssemblyStatement::Constraint(
f.fold_field_expression(lhs),
f.fold_field_expression(rhs),
),
)],
}
}
@ -564,7 +567,7 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
TypedStatement::Assembly(statements) => TypedStatement::Assembly(
statements
.into_iter()
.map(|s| f.fold_assembly_statement(s))
.flat_map(|s| f.fold_assembly_statement(s))
.collect(),
),
s => s,

View file

@ -389,7 +389,7 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fn fold_assembly_statement(
&mut self,
s: TypedAssemblyStatement<'ast, T>,
) -> Result<TypedAssemblyStatement<'ast, T>, Self::Error> {
) -> Result<Vec<TypedAssemblyStatement<'ast, T>>, Self::Error> {
fold_assembly_statement(self, s)
}
@ -526,15 +526,18 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
s: TypedAssemblyStatement<'ast, T>,
) -> Result<TypedAssemblyStatement<'ast, T>, F::Error> {
) -> Result<Vec<TypedAssemblyStatement<'ast, T>>, F::Error> {
Ok(match s {
TypedAssemblyStatement::Assignment(a, e) => {
TypedAssemblyStatement::Assignment(f.fold_assignee(a)?, f.fold_field_expression(e)?)
vec![TypedAssemblyStatement::Assignment(
f.fold_assignee(a)?,
f.fold_field_expression(e)?,
)]
}
TypedAssemblyStatement::Constraint(lhs, rhs) => TypedAssemblyStatement::Constraint(
TypedAssemblyStatement::Constraint(lhs, rhs) => vec![TypedAssemblyStatement::Constraint(
f.fold_field_expression(lhs)?,
f.fold_field_expression(rhs)?,
),
)],
})
}
@ -572,7 +575,10 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
statements
.into_iter()
.map(|s| f.fold_assembly_statement(s))
.collect::<Result<Vec<_>, _>>()?,
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect(),
),
s => s,
};

View file

@ -59,7 +59,7 @@ pub trait Folder<'ast, T: Field>: Sized {
fn fold_assembly_statement(
&mut self,
s: ZirAssemblyStatement<'ast, T>,
) -> ZirAssemblyStatement<'ast, T> {
) -> Vec<ZirAssemblyStatement<'ast, T>> {
fold_assembly_statement(self, s)
}
@ -145,17 +145,17 @@ pub trait Folder<'ast, T: Field>: Sized {
pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>(
f: &mut F,
s: ZirAssemblyStatement<'ast, T>,
) -> ZirAssemblyStatement<'ast, T> {
) -> Vec<ZirAssemblyStatement<'ast, T>> {
match s {
ZirAssemblyStatement::Assignment(assignees, function) => {
let assignees = assignees.into_iter().map(|a| f.fold_assignee(a)).collect();
ZirAssemblyStatement::Assignment(assignee, function) => {
let assignees = f.fold_assignee(assignee);
let function = f.fold_function(function);
ZirAssemblyStatement::Assignment(assignees, function)
vec![ZirAssemblyStatement::Assignment(assignees, function)]
}
ZirAssemblyStatement::Constraint(lhs, rhs) => {
let lhs = f.fold_field_expression(lhs);
let rhs = f.fold_field_expression(rhs);
ZirAssemblyStatement::Constraint(lhs, rhs)
vec![ZirAssemblyStatement::Constraint(lhs, rhs)]
}
}
}
@ -201,7 +201,7 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>(
ZirStatement::Assembly(statements) => ZirStatement::Assembly(
statements
.into_iter()
.map(|s| f.fold_assembly_statement(s))
.flat_map(|s| f.fold_assembly_statement(s))
.collect(),
),
};

View file

@ -119,10 +119,7 @@ impl RuntimeError {
#[derive(Clone, PartialEq, Hash, Eq, Debug, Serialize, Deserialize)]
pub enum ZirAssemblyStatement<'ast, T> {
Assignment(
#[serde(borrow)] Vec<ZirAssignee<'ast>>,
ZirFunction<'ast, T>,
),
Assignment(#[serde(borrow)] ZirAssignee<'ast>, ZirFunction<'ast, T>),
Constraint(
FieldElementExpression<'ast, T>,
FieldElementExpression<'ast, T>,
@ -133,15 +130,7 @@ impl<'ast, T: fmt::Display> fmt::Display for ZirAssemblyStatement<'ast, T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
ZirAssemblyStatement::Assignment(ref lhs, ref rhs) => {
write!(
f,
"{} <-- {}",
lhs.iter()
.map(|a| a.id.to_string())
.collect::<Vec<String>>()
.join(", "),
rhs
)
write!(f, "{} <-- {}", lhs, rhs)
}
ZirAssemblyStatement::Constraint(ref lhs, ref rhs) => {
write!(f, "{} === {}", lhs, rhs)

View file

@ -64,7 +64,7 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
fn fold_assembly_statement(
&mut self,
s: ZirAssemblyStatement<'ast, T>,
) -> Result<ZirAssemblyStatement<'ast, T>, Self::Error> {
) -> Result<Vec<ZirAssemblyStatement<'ast, T>>, Self::Error> {
fold_assembly_statement(self, s)
}
@ -162,20 +162,17 @@ pub trait ResultFolder<'ast, T: Field>: Sized {
pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
f: &mut F,
s: ZirAssemblyStatement<'ast, T>,
) -> Result<ZirAssemblyStatement<'ast, T>, F::Error> {
) -> Result<Vec<ZirAssemblyStatement<'ast, T>>, F::Error> {
Ok(match s {
ZirAssemblyStatement::Assignment(assignees, function) => {
let assignees = assignees
.into_iter()
.map(|a| f.fold_assignee(a))
.collect::<Result<Vec<_>, _>>()?;
ZirAssemblyStatement::Assignment(assignee, function) => {
let assignee = f.fold_assignee(assignee)?;
let function = f.fold_function(function)?;
ZirAssemblyStatement::Assignment(assignees, function)
vec![ZirAssemblyStatement::Assignment(assignee, function)]
}
ZirAssemblyStatement::Constraint(lhs, rhs) => {
let lhs = f.fold_field_expression(lhs)?;
let rhs = f.fold_field_expression(rhs)?;
ZirAssemblyStatement::Constraint(lhs, rhs)
vec![ZirAssemblyStatement::Constraint(lhs, rhs)]
}
})
}
@ -238,7 +235,10 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>(
let statements = statements
.into_iter()
.map(|s| f.fold_assembly_statement(s))
.collect::<Result<Vec<_>, _>>()?;
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flatten()
.collect();
ZirStatement::Assembly(statements)
}
};

View file

@ -2229,16 +2229,12 @@ impl<'ast, T: Field> Flattener<'ast, T> {
stat: ZirAssemblyStatement<'ast, T>,
) {
match stat {
ZirAssemblyStatement::Assignment(assignees, function) => {
let outputs: Vec<Variable> = assignees
.iter()
.map(|a| {
self.layout
.get(&a.id)
.cloned()
.unwrap_or_else(|| self.use_variable(a))
})
.collect();
ZirAssemblyStatement::Assignment(assignee, function) => {
let outputs: Vec<Variable> = vec![self
.layout
.get(&assignee.id)
.cloned()
.unwrap_or_else(|| self.use_variable(&assignee))];
let inputs: Vec<FlatExpression<T>> = function
.arguments
.iter()

View file

@ -2135,7 +2135,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
let pos = assignee.pos();
// check that the assignee is declared
match assignee.value {
Assignee::Identifier(variable_name) => match self.scope.get(&*variable_name) {
Assignee::Identifier(variable_name) => match self.scope.get(variable_name) {
Some(info) => match info.is_mutable {
false => Err(ErrorInner {
pos: Some(assignee.pos()),
@ -2444,7 +2444,7 @@ impl<'ast, T: Field> Checker<'ast, T> {
Expression::BooleanConstant(b) => Ok(BooleanExpression::Value(b).into()),
Expression::Identifier(name) => {
// check that `id` is defined in the scope
match self.scope.get(&*name) {
match self.scope.get(name) {
Some(info) => {
let id = info.id;
match info.ty.clone() {

View file

@ -0,0 +1,17 @@
{
"curves": ["Bn128"],
"max_constraint_count": 0,
"tests": [
{
"input": {
"values": [
]
},
"output": {
"Ok": {
"value": ["0", "2"]
}
}
}
]
}

View file

@ -0,0 +1,9 @@
def main() -> field[2] {
field[2] mut a = [1, 2];
u32 i = 0;
asm {
a[i] <-- 0;
a[i] === 0;
}
return a;
}

View file

@ -0,0 +1,18 @@
{
"curves": ["Bn128"],
"max_constraint_count": 1,
"tests": [
{
"input": {
"values": [
"42"
]
},
"output": {
"Ok": {
"value": ["42", "2"]
}
}
}
]
}

View file

@ -0,0 +1,9 @@
def main(field v) -> field[2] {
field[2] mut a = [1, 2];
u32 i = 0;
asm {
a[i] <-- v;
a[i] === v;
}
return a;
}

View file

@ -0,0 +1,5 @@
{
"curves": ["Bn128"],
"max_constraint_count": 0,
"tests": []
}

View file

@ -0,0 +1,8 @@
def main() {
field mut a = 0;
asm {
a <-- 1;
a === 1;
}
return;
}

View file

@ -0,0 +1,5 @@
{
"curves": ["Bn128"],
"max_constraint_count": 0,
"tests": []
}

View file

@ -0,0 +1,5 @@
def main() {
asm {
}
return;
}

View file

@ -0,0 +1,6 @@
{
"curves": ["Bn128"],
"max_constraint_count": 0,
"tests": []
}

View file

@ -0,0 +1,10 @@
def main() {
field mut a = 0;
field mut b = 0;
asm {
a <-- 1;
b <-- a;
b === 1;
}
return;
}

View file

@ -433,18 +433,18 @@ mod ast {
pub span: Span<'ast>,
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[derive(Debug, FromPest, PartialEq, Eq, Clone)]
#[pest_ast(rule(Rule::op_asm))]
pub enum AssignmentOperator {
Assign(AssignOperator),
AssignConstrain(AssignConstrainOperator),
}
#[derive(Debug, FromPest, PartialEq, Clone)]
#[derive(Debug, FromPest, PartialEq, Eq, Clone)]
#[pest_ast(rule(Rule::op_asm_assign))]
pub struct AssignOperator;
#[derive(Debug, FromPest, PartialEq, Clone)]
#[derive(Debug, FromPest, PartialEq, Eq, Clone)]
#[pest_ast(rule(Rule::op_asm_assign_constrain))]
pub struct AssignConstrainOperator;