small tweaks, simplify and improve propagation
This commit is contained in:
parent
5b4f581a74
commit
c8080e9656
4 changed files with 106 additions and 82 deletions
|
@ -153,8 +153,6 @@ pub fn compile<T: Field, R: BufRead, S: BufRead, E: Into<imports::Error>>(
|
|||
// analyse (unroll and constant propagation)
|
||||
let typed_ast = typed_ast.analyse();
|
||||
|
||||
println!("{}", typed_ast);
|
||||
|
||||
// flatten input program
|
||||
let program_flattened = Flattener::flatten(typed_ast);
|
||||
|
||||
|
|
|
@ -601,8 +601,6 @@ impl<'ast> Checker<'ast> {
|
|||
&mut self,
|
||||
assignee: AssigneeNode<'ast, T>,
|
||||
) -> Result<TypedAssignee<'ast, T>, Error> {
|
||||
println!("{:?}", assignee.value);
|
||||
|
||||
let pos = assignee.pos();
|
||||
// check that the assignee is declared
|
||||
match assignee.value {
|
||||
|
@ -889,7 +887,7 @@ impl<'ast> Checker<'ast> {
|
|||
let f = &candidates[0];
|
||||
// the return count has to be 1
|
||||
match f.signature.outputs.len() {
|
||||
1 => match f.signature.outputs[0].clone() {
|
||||
1 => match &f.signature.outputs[0] {
|
||||
Type::FieldElement => Ok(FieldElementExpression::FunctionCall(
|
||||
FunctionKey {
|
||||
id: f.id.clone(),
|
||||
|
@ -898,15 +896,17 @@ impl<'ast> Checker<'ast> {
|
|||
arguments_checked,
|
||||
)
|
||||
.into()),
|
||||
Type::Array(ty, size) => Ok(ArrayExpressionInner::FunctionCall(
|
||||
FunctionKey {
|
||||
id: f.id.clone(),
|
||||
signature: f.signature.clone(),
|
||||
},
|
||||
arguments_checked,
|
||||
)
|
||||
.annotate(*ty, size)
|
||||
.into()),
|
||||
Type::Array(box ty, size) => {
|
||||
Ok(ArrayExpressionInner::FunctionCall(
|
||||
FunctionKey {
|
||||
id: f.id.clone(),
|
||||
signature: f.signature.clone(),
|
||||
},
|
||||
arguments_checked,
|
||||
)
|
||||
.annotate(ty.clone(), size.clone())
|
||||
.into())
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
n => Err(Error {
|
||||
|
@ -1085,7 +1085,14 @@ impl<'ast> Checker<'ast> {
|
|||
.into()),
|
||||
}
|
||||
}
|
||||
_ => unreachable!(""),
|
||||
e => Err(Error {
|
||||
pos: Some(pos),
|
||||
message: format!(
|
||||
"Cannot access slice of expression {} of type {}",
|
||||
e,
|
||||
e.get_type(),
|
||||
),
|
||||
}),
|
||||
},
|
||||
RangeOrExpression::Expression(e) => match (array, self.check_expression(e)?) {
|
||||
(TypedExpression::Array(a), TypedExpression::FieldElement(i)) => {
|
||||
|
|
|
@ -27,6 +27,18 @@ impl<'ast, T: Field> Propagator<'ast, T> {
|
|||
}
|
||||
}
|
||||
|
||||
fn is_constant<'ast, T: Field>(e: &TypedExpression<'ast, T>) -> bool {
|
||||
match e {
|
||||
TypedExpression::FieldElement(FieldElementExpression::Number(..)) => true,
|
||||
TypedExpression::Boolean(BooleanExpression::Value(..)) => true,
|
||||
TypedExpression::Array(a) => match a.as_inner() {
|
||||
ArrayExpressionInner::Value(v) => v.iter().all(|e| is_constant(e)),
|
||||
_ => false,
|
||||
},
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
||||
fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> {
|
||||
self.constants = HashMap::new();
|
||||
|
@ -44,56 +56,16 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
)),
|
||||
// propagation to the defined variable if rhs is a constant
|
||||
TypedStatement::Definition(TypedAssignee::Identifier(var), expr) => {
|
||||
match self.fold_expression(expr) {
|
||||
e @ TypedExpression::Boolean(BooleanExpression::Value(..))
|
||||
| e @ TypedExpression::FieldElement(FieldElementExpression::Number(..)) => {
|
||||
self.constants.insert(TypedAssignee::Identifier(var), e);
|
||||
None
|
||||
}
|
||||
TypedExpression::Array(e) => {
|
||||
let ty = e.inner_type().clone();
|
||||
let size = e.size();
|
||||
let expr = self.fold_expression(expr);
|
||||
|
||||
match e.into_inner() {
|
||||
ArrayExpressionInner::Value(array) => {
|
||||
let array: Vec<_> =
|
||||
array.into_iter().map(|e| self.fold_expression(e)).collect();
|
||||
|
||||
match array.iter().all(|e| match e {
|
||||
TypedExpression::FieldElement(
|
||||
FieldElementExpression::Number(..),
|
||||
) => true,
|
||||
TypedExpression::Boolean(BooleanExpression::Value(..)) => true,
|
||||
_ => false,
|
||||
}) {
|
||||
true => {
|
||||
// all elements of the array are constants
|
||||
self.constants.insert(
|
||||
TypedAssignee::Identifier(var),
|
||||
ArrayExpressionInner::Value(array)
|
||||
.annotate(ty, size)
|
||||
.into(),
|
||||
);
|
||||
None
|
||||
}
|
||||
false => Some(TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(var),
|
||||
ArrayExpressionInner::Value(array)
|
||||
.annotate(ty, size)
|
||||
.into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
e => Some(TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(var),
|
||||
TypedExpression::Array(e.annotate(ty, size)),
|
||||
)),
|
||||
}
|
||||
}
|
||||
e => Some(TypedStatement::Definition(
|
||||
if is_constant(&expr) {
|
||||
self.constants.insert(TypedAssignee::Identifier(var), expr);
|
||||
None
|
||||
} else {
|
||||
Some(TypedStatement::Definition(
|
||||
TypedAssignee::Identifier(var),
|
||||
e,
|
||||
)),
|
||||
expr,
|
||||
))
|
||||
}
|
||||
}
|
||||
TypedStatement::Definition(TypedAssignee::Select(..), _) => {
|
||||
|
@ -109,7 +81,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
}
|
||||
// we unrolled for loops in the previous step
|
||||
TypedStatement::For(..) => {
|
||||
panic!("for loop is unexpected, it should have been unrolled")
|
||||
unreachable!("for loop is unexpected, it should have been unrolled")
|
||||
}
|
||||
TypedStatement::MultipleDefinition(variables, expression_list) => {
|
||||
let expression_list = self.fold_expression_list(expression_list);
|
||||
|
@ -138,9 +110,9 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
))) {
|
||||
Some(e) => match e {
|
||||
TypedExpression::FieldElement(e) => e.clone(),
|
||||
_ => {
|
||||
panic!("constant stored for a field element should be a field element")
|
||||
}
|
||||
_ => unreachable!(
|
||||
"constant stored for a field element should be a field element"
|
||||
),
|
||||
},
|
||||
None => FieldElementExpression::Identifier(id),
|
||||
}
|
||||
|
@ -194,7 +166,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
(e1, FieldElementExpression::Number(n2)) => {
|
||||
FieldElementExpression::Pow(box e1, box FieldElementExpression::Number(n2))
|
||||
}
|
||||
(_, e2) => panic!(format!(
|
||||
(_, e2) => unreachable!(format!(
|
||||
"non-constant exponent {} detected during static analysis",
|
||||
e2
|
||||
)),
|
||||
|
@ -222,10 +194,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
if n_as_usize < size {
|
||||
FieldElementExpression::try_from(v[n_as_usize].clone()).unwrap()
|
||||
} else {
|
||||
panic!(format!(
|
||||
unreachable!(
|
||||
"out of bounds index ({} >= {}) found during static analysis",
|
||||
n_as_usize, size
|
||||
));
|
||||
);
|
||||
}
|
||||
}
|
||||
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
|
||||
|
@ -239,7 +211,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
)) {
|
||||
Some(e) => match e {
|
||||
TypedExpression::FieldElement(e) => e.clone(),
|
||||
_ => panic!(""),
|
||||
_ => unreachable!(""),
|
||||
},
|
||||
None => FieldElementExpression::Select(
|
||||
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
|
||||
|
@ -278,6 +250,49 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
|
|||
None => ArrayExpressionInner::Identifier(id),
|
||||
}
|
||||
}
|
||||
ArrayExpressionInner::Select(box array, box index) => {
|
||||
let array = self.fold_array_expression(array);
|
||||
let index = self.fold_field_expression(index);
|
||||
|
||||
let inner_type = array.inner_type().clone();
|
||||
let size = array.size();
|
||||
|
||||
match (array.into_inner(), index) {
|
||||
(ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
|
||||
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
|
||||
if n_as_usize < size {
|
||||
ArrayExpression::try_from(v[n_as_usize].clone())
|
||||
.unwrap()
|
||||
.into_inner()
|
||||
} else {
|
||||
unreachable!(
|
||||
"out of bounds index ({} >= {}) found during static analysis",
|
||||
n_as_usize, size
|
||||
);
|
||||
}
|
||||
}
|
||||
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
|
||||
match self.constants.get(&TypedAssignee::Select(
|
||||
box TypedAssignee::Identifier(Variable::array(
|
||||
id.clone(),
|
||||
inner_type.clone(),
|
||||
size,
|
||||
)),
|
||||
box FieldElementExpression::Number(n.clone()).into(),
|
||||
)) {
|
||||
Some(e) => match e {
|
||||
TypedExpression::Array(e) => e.clone().into_inner(),
|
||||
_ => unreachable!(""),
|
||||
},
|
||||
None => ArrayExpressionInner::Select(
|
||||
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
|
||||
box FieldElementExpression::Number(n),
|
||||
),
|
||||
}
|
||||
}
|
||||
(a, i) => ArrayExpressionInner::Select(box a.annotate(inner_type, size), box i),
|
||||
}
|
||||
}
|
||||
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
|
||||
let consequence = self.fold_array_expression(consequence);
|
||||
let alternative = self.fold_array_expression(alternative);
|
||||
|
|
|
@ -81,17 +81,21 @@ pub struct TypedModule<'ast, T: Field> {
|
|||
|
||||
impl<'ast> fmt::Display for Identifier<'ast> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}_{}_{}",
|
||||
self.stack
|
||||
.iter()
|
||||
.map(|(name, sig, count)| format!("{}_{}_{}", name, sig.to_slug(), count))
|
||||
.collect::<Vec<_>>()
|
||||
.join("_"),
|
||||
self.id,
|
||||
self.version
|
||||
)
|
||||
if self.stack.len() == 0 && self.version == 0 {
|
||||
write!(f, "{}", self.id)
|
||||
} else {
|
||||
write!(
|
||||
f,
|
||||
"{}_{}_{}",
|
||||
self.stack
|
||||
.iter()
|
||||
.map(|(name, sig, count)| format!("{}_{}_{}", name, sig.to_slug(), count))
|
||||
.collect::<Vec<_>>()
|
||||
.join("_"),
|
||||
self.id,
|
||||
self.version
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue