1
0
Fork 0
mirror of synced 2025-09-23 12:18:44 +00:00

Merge pull request #855 from Zokrates/fix-constant-struct-member-mutation

Fix constant assignment to constant struct member panic
This commit is contained in:
Thibaut Schaeffer 2021-05-10 15:49:49 +02:00 committed by GitHub
commit dab975b7f7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 38 additions and 9 deletions

View file

@ -0,0 +1 @@
Fix crash when updating a constant struct member to another constant

View file

@ -152,13 +152,16 @@ fn is_constant<T: Field>(e: &TypedExpression<T>) -> bool {
} }
} }
fn remove_spreads<T: Field>(e: TypedExpression<T>) -> TypedExpression<T> { // in the constant map, we only want canonical constants: [0; 3] -> [0, 0, 0], [...[1], 2] -> [1, 2], etc
fn remove_spreads_aux<T: Field>(e: TypedExpressionOrSpread<T>) -> Vec<TypedExpression<T>> { fn to_canonical_constant<T: Field>(e: TypedExpression<T>) -> TypedExpression<T> {
fn to_canonical_constant_aux<T: Field>(
e: TypedExpressionOrSpread<T>,
) -> Vec<TypedExpression<T>> {
match e { match e {
TypedExpressionOrSpread::Expression(e) => vec![e], TypedExpressionOrSpread::Expression(e) => vec![e],
TypedExpressionOrSpread::Spread(s) => match s.array.into_inner() { TypedExpressionOrSpread::Spread(s) => match s.array.into_inner() {
ArrayExpressionInner::Value(v) => { ArrayExpressionInner::Value(v) => {
v.into_iter().flat_map(remove_spreads_aux).collect() v.into_iter().flat_map(to_canonical_constant_aux).collect()
} }
_ => unimplemented!(), _ => unimplemented!(),
}, },
@ -172,7 +175,7 @@ fn remove_spreads<T: Field>(e: TypedExpression<T>) -> TypedExpression<T> {
match a.into_inner() { match a.into_inner() {
ArrayExpressionInner::Value(v) => ArrayExpressionInner::Value( ArrayExpressionInner::Value(v) => ArrayExpressionInner::Value(
v.into_iter() v.into_iter()
.flat_map(remove_spreads_aux) .flat_map(to_canonical_constant_aux)
.map(|e| e.into()) .map(|e| e.into())
.collect::<Vec<_>>() .collect::<Vec<_>>()
.into(), .into(),
@ -197,7 +200,7 @@ fn remove_spreads<T: Field>(e: TypedExpression<T>) -> TypedExpression<T> {
ArrayExpressionInner::Value( ArrayExpressionInner::Value(
v.into_iter() v.into_iter()
.flat_map(remove_spreads_aux) .flat_map(to_canonical_constant_aux)
.map(|e| e.into()) .map(|e| e.into())
.enumerate() .enumerate()
.filter(|(index, _)| index >= &from && index < &to) .filter(|(index, _)| index >= &from && index < &to)
@ -214,7 +217,7 @@ fn remove_spreads<T: Field>(e: TypedExpression<T>) -> TypedExpression<T> {
_ => unreachable!("should be a uint value"), _ => unreachable!("should be a uint value"),
}; };
let e = remove_spreads(e); let e = to_canonical_constant(e);
ArrayExpressionInner::Value( ArrayExpressionInner::Value(
vec![TypedExpressionOrSpread::Expression(e); count].into(), vec![TypedExpressionOrSpread::Expression(e); count].into(),
@ -225,6 +228,18 @@ fn remove_spreads<T: Field>(e: TypedExpression<T>) -> TypedExpression<T> {
_ => unreachable!(), _ => unreachable!(),
} }
} }
TypedExpression::Struct(s) => {
let struct_ty = s.ty().clone();
match s.into_inner() {
StructExpressionInner::Value(expressions) => StructExpressionInner::Value(
expressions.into_iter().map(to_canonical_constant).collect(),
)
.annotate(struct_ty)
.into(),
_ => unreachable!(),
}
}
e => e, e => e,
} }
} }
@ -300,7 +315,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
if is_constant(&expr) { if is_constant(&expr) {
match assignee { match assignee {
TypedAssignee::Identifier(var) => { TypedAssignee::Identifier(var) => {
let expr = remove_spreads(expr); let expr = to_canonical_constant(expr);
assert!(self.constants.insert(var.id, expr).is_none()); assert!(self.constants.insert(var.id, expr).is_none());
@ -308,7 +323,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
} }
assignee => match self.try_get_constant_mut(&assignee) { assignee => match self.try_get_constant_mut(&assignee) {
Ok((_, c)) => { Ok((_, c)) => {
*c = remove_spreads(expr); *c = to_canonical_constant(expr);
Ok(vec![]) Ok(vec![])
} }
Err(v) => match self.constants.remove(&v.id) { Err(v) => match self.constants.remove(&v.id) {
@ -374,7 +389,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> {
let argument = arguments.pop().unwrap(); let argument = arguments.pop().unwrap();
let argument = remove_spreads(argument); let argument = to_canonical_constant(argument);
match ArrayExpression::try_from(argument) match ArrayExpression::try_from(argument)
.unwrap() .unwrap()

View file

@ -0,0 +1,5 @@
{
"entry_point": "./tests/tests/structs/constant.zok",
"curves": ["Bn128"],
"tests": []
}

View file

@ -0,0 +1,8 @@
struct State {
u32[16] memory
}
def main():
State s = State { memory: [0; 16] }
s.memory[0] = 0x00000001
return