1
0
Fork 0
mirror of synced 2025-09-24 04:40:05 +00:00

small tweaks, simplify and improve propagation

This commit is contained in:
schaeff 2019-09-09 13:37:47 +02:00
parent 5b4f581a74
commit c8080e9656
4 changed files with 106 additions and 82 deletions

View file

@ -153,8 +153,6 @@ pub fn compile<T: Field, R: BufRead, S: BufRead, E: Into<imports::Error>>(
// analyse (unroll and constant propagation)
let typed_ast = typed_ast.analyse();
println!("{}", typed_ast);
// flatten input program
let program_flattened = Flattener::flatten(typed_ast);

View file

@ -601,8 +601,6 @@ impl<'ast> Checker<'ast> {
&mut self,
assignee: AssigneeNode<'ast, T>,
) -> Result<TypedAssignee<'ast, T>, Error> {
println!("{:?}", assignee.value);
let pos = assignee.pos();
// check that the assignee is declared
match assignee.value {
@ -889,7 +887,7 @@ impl<'ast> Checker<'ast> {
let f = &candidates[0];
// the return count has to be 1
match f.signature.outputs.len() {
1 => match f.signature.outputs[0].clone() {
1 => match &f.signature.outputs[0] {
Type::FieldElement => Ok(FieldElementExpression::FunctionCall(
FunctionKey {
id: f.id.clone(),
@ -898,15 +896,17 @@ impl<'ast> Checker<'ast> {
arguments_checked,
)
.into()),
Type::Array(ty, size) => Ok(ArrayExpressionInner::FunctionCall(
FunctionKey {
id: f.id.clone(),
signature: f.signature.clone(),
},
arguments_checked,
)
.annotate(*ty, size)
.into()),
Type::Array(box ty, size) => {
Ok(ArrayExpressionInner::FunctionCall(
FunctionKey {
id: f.id.clone(),
signature: f.signature.clone(),
},
arguments_checked,
)
.annotate(ty.clone(), size.clone())
.into())
}
_ => unimplemented!(),
},
n => Err(Error {
@ -1085,7 +1085,14 @@ impl<'ast> Checker<'ast> {
.into()),
}
}
_ => unreachable!(""),
e => Err(Error {
pos: Some(pos),
message: format!(
"Cannot access slice of expression {} of type {}",
e,
e.get_type(),
),
}),
},
RangeOrExpression::Expression(e) => match (array, self.check_expression(e)?) {
(TypedExpression::Array(a), TypedExpression::FieldElement(i)) => {

View file

@ -27,6 +27,18 @@ impl<'ast, T: Field> Propagator<'ast, T> {
}
}
fn is_constant<'ast, T: Field>(e: &TypedExpression<'ast, T>) -> bool {
match e {
TypedExpression::FieldElement(FieldElementExpression::Number(..)) => true,
TypedExpression::Boolean(BooleanExpression::Value(..)) => true,
TypedExpression::Array(a) => match a.as_inner() {
ArrayExpressionInner::Value(v) => v.iter().all(|e| is_constant(e)),
_ => false,
},
_ => false,
}
}
impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> {
self.constants = HashMap::new();
@ -44,56 +56,16 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
)),
// propagation to the defined variable if rhs is a constant
TypedStatement::Definition(TypedAssignee::Identifier(var), expr) => {
match self.fold_expression(expr) {
e @ TypedExpression::Boolean(BooleanExpression::Value(..))
| e @ TypedExpression::FieldElement(FieldElementExpression::Number(..)) => {
self.constants.insert(TypedAssignee::Identifier(var), e);
None
}
TypedExpression::Array(e) => {
let ty = e.inner_type().clone();
let size = e.size();
let expr = self.fold_expression(expr);
match e.into_inner() {
ArrayExpressionInner::Value(array) => {
let array: Vec<_> =
array.into_iter().map(|e| self.fold_expression(e)).collect();
match array.iter().all(|e| match e {
TypedExpression::FieldElement(
FieldElementExpression::Number(..),
) => true,
TypedExpression::Boolean(BooleanExpression::Value(..)) => true,
_ => false,
}) {
true => {
// all elements of the array are constants
self.constants.insert(
TypedAssignee::Identifier(var),
ArrayExpressionInner::Value(array)
.annotate(ty, size)
.into(),
);
None
}
false => Some(TypedStatement::Definition(
TypedAssignee::Identifier(var),
ArrayExpressionInner::Value(array)
.annotate(ty, size)
.into(),
)),
}
}
e => Some(TypedStatement::Definition(
TypedAssignee::Identifier(var),
TypedExpression::Array(e.annotate(ty, size)),
)),
}
}
e => Some(TypedStatement::Definition(
if is_constant(&expr) {
self.constants.insert(TypedAssignee::Identifier(var), expr);
None
} else {
Some(TypedStatement::Definition(
TypedAssignee::Identifier(var),
e,
)),
expr,
))
}
}
TypedStatement::Definition(TypedAssignee::Select(..), _) => {
@ -109,7 +81,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
}
// we unrolled for loops in the previous step
TypedStatement::For(..) => {
panic!("for loop is unexpected, it should have been unrolled")
unreachable!("for loop is unexpected, it should have been unrolled")
}
TypedStatement::MultipleDefinition(variables, expression_list) => {
let expression_list = self.fold_expression_list(expression_list);
@ -138,9 +110,9 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
))) {
Some(e) => match e {
TypedExpression::FieldElement(e) => e.clone(),
_ => {
panic!("constant stored for a field element should be a field element")
}
_ => unreachable!(
"constant stored for a field element should be a field element"
),
},
None => FieldElementExpression::Identifier(id),
}
@ -194,7 +166,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
(e1, FieldElementExpression::Number(n2)) => {
FieldElementExpression::Pow(box e1, box FieldElementExpression::Number(n2))
}
(_, e2) => panic!(format!(
(_, e2) => unreachable!(format!(
"non-constant exponent {} detected during static analysis",
e2
)),
@ -222,10 +194,10 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
if n_as_usize < size {
FieldElementExpression::try_from(v[n_as_usize].clone()).unwrap()
} else {
panic!(format!(
unreachable!(
"out of bounds index ({} >= {}) found during static analysis",
n_as_usize, size
));
);
}
}
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
@ -239,7 +211,7 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
)) {
Some(e) => match e {
TypedExpression::FieldElement(e) => e.clone(),
_ => panic!(""),
_ => unreachable!(""),
},
None => FieldElementExpression::Select(
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
@ -278,6 +250,49 @@ impl<'ast, T: Field> Folder<'ast, T> for Propagator<'ast, T> {
None => ArrayExpressionInner::Identifier(id),
}
}
ArrayExpressionInner::Select(box array, box index) => {
let array = self.fold_array_expression(array);
let index = self.fold_field_expression(index);
let inner_type = array.inner_type().clone();
let size = array.size();
match (array.into_inner(), index) {
(ArrayExpressionInner::Value(v), FieldElementExpression::Number(n)) => {
let n_as_usize = n.to_dec_string().parse::<usize>().unwrap();
if n_as_usize < size {
ArrayExpression::try_from(v[n_as_usize].clone())
.unwrap()
.into_inner()
} else {
unreachable!(
"out of bounds index ({} >= {}) found during static analysis",
n_as_usize, size
);
}
}
(ArrayExpressionInner::Identifier(id), FieldElementExpression::Number(n)) => {
match self.constants.get(&TypedAssignee::Select(
box TypedAssignee::Identifier(Variable::array(
id.clone(),
inner_type.clone(),
size,
)),
box FieldElementExpression::Number(n.clone()).into(),
)) {
Some(e) => match e {
TypedExpression::Array(e) => e.clone().into_inner(),
_ => unreachable!(""),
},
None => ArrayExpressionInner::Select(
box ArrayExpressionInner::Identifier(id).annotate(inner_type, size),
box FieldElementExpression::Number(n),
),
}
}
(a, i) => ArrayExpressionInner::Select(box a.annotate(inner_type, size), box i),
}
}
ArrayExpressionInner::IfElse(box condition, box consequence, box alternative) => {
let consequence = self.fold_array_expression(consequence);
let alternative = self.fold_array_expression(alternative);

View file

@ -81,17 +81,21 @@ pub struct TypedModule<'ast, T: Field> {
impl<'ast> fmt::Display for Identifier<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}_{}_{}",
self.stack
.iter()
.map(|(name, sig, count)| format!("{}_{}_{}", name, sig.to_slug(), count))
.collect::<Vec<_>>()
.join("_"),
self.id,
self.version
)
if self.stack.len() == 0 && self.version == 0 {
write!(f, "{}", self.id)
} else {
write!(
f,
"{}_{}_{}",
self.stack
.iter()
.map(|(name, sig, count)| format!("{}_{}_{}", name, sig.to_slug(), count))
.collect::<Vec<_>>()
.join("_"),
self.id,
self.version
)
}
}
}