reduce propagation memory usage
This commit is contained in:
parent
e46ad845d0
commit
b397e8a993
5 changed files with 117 additions and 58 deletions
1
changelogs/unreleased/1296-schaeff
Normal file
1
changelogs/unreleased/1296-schaeff
Normal file
|
@ -0,0 +1 @@
|
|||
Reduce memory usage of compilation
|
|
@ -32,6 +32,7 @@ pub enum Error {
|
|||
AssertionFailed(RuntimeError),
|
||||
InvalidValue(String),
|
||||
OutOfBounds(u128, u128),
|
||||
VariableLength(String),
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
|
@ -45,6 +46,7 @@ impl fmt::Display for Error {
|
|||
"Out of bounds index ({} >= {}) found during static analysis",
|
||||
index, size
|
||||
),
|
||||
Error::VariableLength(message) => write!(f, "{}", message),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -61,6 +63,10 @@ impl<'ast, T: Field> Propagator<'ast, T> {
|
|||
Propagator::default().fold_program(p)
|
||||
}
|
||||
|
||||
pub fn clear_call_frame(&mut self, frame: usize) {
|
||||
self.constants.retain(|id, _| id.id.frame != frame);
|
||||
}
|
||||
|
||||
// get a mutable reference to the constant corresponding to a given assignee if any, otherwise
|
||||
// return the identifier at the root of this assignee
|
||||
fn try_get_constant_mut<'b>(
|
||||
|
@ -1153,59 +1159,93 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> {
|
|||
_: &E::Ty,
|
||||
e: SelectExpression<'ast, T, E>,
|
||||
) -> Result<SelectOrExpression<'ast, T, E>, Self::Error> {
|
||||
let array = self.fold_array_expression(*e.array)?;
|
||||
let index = self.fold_uint_expression(*e.index)?;
|
||||
let array = *e.array;
|
||||
|
||||
let inner_type = array.inner_type().clone();
|
||||
let size = array.size();
|
||||
let ty = self.fold_array_type(*array.ty)?;
|
||||
let size = match ty.size.as_inner() {
|
||||
UExpressionInner::Value(v) => Ok(v),
|
||||
_ => unreachable!("array size was checked when folding array type"),
|
||||
}?;
|
||||
|
||||
match size.into_inner() {
|
||||
UExpressionInner::Value(size) => match (array.into_inner(), index.into_inner()) {
|
||||
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
|
||||
if n < size {
|
||||
Ok(SelectOrExpression::Expression(
|
||||
v.expression_at::<E>(n.value as usize).unwrap().into_inner(),
|
||||
))
|
||||
} else {
|
||||
Err(Error::OutOfBounds(n.value, size.value))
|
||||
}
|
||||
}
|
||||
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
|
||||
match self.constants.get(&id.id) {
|
||||
Some(a) => match a {
|
||||
match (array.inner, index.into_inner()) {
|
||||
// special case if the array is an identifier: check the cache and only clone the element, not the whole array
|
||||
(ArrayExpressionInner::Identifier(id), UExpressionInner::Value(n)) => {
|
||||
match self.constants.get(&id.id) {
|
||||
Some(v) => {
|
||||
// get the constant array. it was guaranteed to be a value when it was inserted
|
||||
let v = match v {
|
||||
TypedExpression::Array(a) => match a.as_inner() {
|
||||
ArrayExpressionInner::Value(v) => {
|
||||
Ok(SelectOrExpression::Expression(
|
||||
v.expression_at::<E>(n.value as usize)
|
||||
.unwrap()
|
||||
.into_inner(),
|
||||
))
|
||||
}
|
||||
_ => unreachable!("should be an array value"),
|
||||
ArrayExpressionInner::Value(v) => v,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
_ => unreachable!("should be an array expression"),
|
||||
},
|
||||
None => Ok(SelectOrExpression::Expression(
|
||||
E::select(
|
||||
ArrayExpressionInner::Identifier(id)
|
||||
.annotate(ArrayType::new(inner_type, size.value as u32)),
|
||||
UExpressionInner::Value(n).annotate(UBitwidth::B32),
|
||||
)
|
||||
.into_inner(),
|
||||
)),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
// sanity check that the value does not contain spreads
|
||||
assert!(v
|
||||
.value
|
||||
.iter()
|
||||
.all(|e| matches!(e, TypedExpressionOrSpread::Expression(_))));
|
||||
|
||||
if n.value < size.value {
|
||||
Ok(SelectOrExpression::Expression(
|
||||
// clone only the element
|
||||
match v.value[n.value as usize].clone() {
|
||||
TypedExpressionOrSpread::Expression(e) => {
|
||||
E::try_from(e).unwrap().into_inner()
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
))
|
||||
} else {
|
||||
Err(Error::OutOfBounds(n.value, size.value))
|
||||
}
|
||||
}
|
||||
_ => Ok(SelectOrExpression::Select(SelectExpression::new(
|
||||
ArrayExpressionInner::Identifier(id).annotate(ty),
|
||||
UExpressionInner::Value(n).annotate(UBitwidth::B32),
|
||||
))),
|
||||
}
|
||||
(a, i) => Ok(SelectOrExpression::Select(SelectExpression::new(
|
||||
a.annotate(ArrayType::new(inner_type, size.value as u32)),
|
||||
i.annotate(UBitwidth::B32),
|
||||
))),
|
||||
},
|
||||
_ => Ok(SelectOrExpression::Select(SelectExpression::new(
|
||||
array, index,
|
||||
))),
|
||||
}
|
||||
(array, index) => {
|
||||
let array = self.fold_array_expression_inner(&ty, array)?;
|
||||
|
||||
match (array, index) {
|
||||
(ArrayExpressionInner::Value(v), UExpressionInner::Value(n)) => {
|
||||
if n.value < size.value {
|
||||
Ok(SelectOrExpression::Expression(
|
||||
v.expression_at::<E>(n.value as usize).into_inner(),
|
||||
))
|
||||
} else {
|
||||
Err(Error::OutOfBounds(n.value, size.value))
|
||||
}
|
||||
}
|
||||
(a, i) => Ok(SelectOrExpression::Select(SelectExpression::new(
|
||||
a.annotate(ty),
|
||||
i.annotate(UBitwidth::B32),
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn fold_array_type(
|
||||
&mut self,
|
||||
t: ArrayType<'ast, T>,
|
||||
) -> Result<ArrayType<'ast, T>, Self::Error> {
|
||||
let size = self.fold_uint_expression(*t.size)?;
|
||||
|
||||
if !size.is_constant() {
|
||||
return Err(Error::VariableLength(format!(
|
||||
"Array length should be fixed, found {}",
|
||||
size
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(ArrayType::new(self.fold_type(*t.ty)?, size))
|
||||
}
|
||||
|
||||
fn fold_array_expression_cases(
|
||||
&mut self,
|
||||
ty: &ArrayType<'ast, T>,
|
||||
|
|
|
@ -106,6 +106,15 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> {
|
|||
program,
|
||||
}
|
||||
}
|
||||
|
||||
fn push_call_frame(&mut self) {
|
||||
self.ssa.push_call_frame();
|
||||
}
|
||||
|
||||
fn pop_call_frame(&mut self) {
|
||||
self.propagator.clear_call_frame(self.ssa.latest_frame);
|
||||
self.ssa.pop_call_frame();
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
||||
|
@ -159,7 +168,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
self.ssa.push_call_frame();
|
||||
self.push_call_frame();
|
||||
|
||||
let res = inline_call::<_, E>(&e.function_key, &generics, &arguments, ty, self.program);
|
||||
|
||||
|
@ -236,7 +245,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> {
|
|||
}
|
||||
};
|
||||
|
||||
self.ssa.pop_call_frame();
|
||||
self.pop_call_frame();
|
||||
|
||||
res
|
||||
}
|
||||
|
|
|
@ -1731,28 +1731,31 @@ impl<'ast, T> std::iter::FromIterator<TypedExpressionOrSpread<'ast, T>>
|
|||
|
||||
impl<'ast, T: Field> ArrayValueExpression<'ast, T> {
|
||||
fn expression_at_aux<
|
||||
'a,
|
||||
U: Select<'ast, T> + From<TypedExpression<'ast, T>> + Into<TypedExpression<'ast, T>>,
|
||||
>(
|
||||
v: TypedExpressionOrSpread<'ast, T>,
|
||||
) -> Vec<Option<U>> {
|
||||
) -> Vec<TypedExpression<'ast, T>> {
|
||||
match v {
|
||||
TypedExpressionOrSpread::Expression(e) => vec![Some(e.into())],
|
||||
TypedExpressionOrSpread::Expression(e) => vec![e],
|
||||
TypedExpressionOrSpread::Spread(s) => match s.array.size().into_inner() {
|
||||
UExpressionInner::Value(size) => {
|
||||
let array_ty = s.array.ty().clone();
|
||||
|
||||
match s.array.into_inner() {
|
||||
ArrayExpressionInner::Value(v) => {
|
||||
v.into_iter().flat_map(Self::expression_at_aux).collect()
|
||||
}
|
||||
ArrayExpressionInner::Value(v) => v
|
||||
.value
|
||||
.into_iter()
|
||||
.flat_map(Self::expression_at_aux::<U>)
|
||||
.collect(),
|
||||
a => (0..size.value)
|
||||
.map(|i| {
|
||||
Some(U::select(a.clone().annotate(array_ty.clone()), i as u32))
|
||||
U::select(a.clone().annotate(array_ty.clone()), i as u32).into()
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
_ => vec![None],
|
||||
_ => unreachable!(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -1760,14 +1763,16 @@ impl<'ast, T: Field> ArrayValueExpression<'ast, T> {
|
|||
pub fn expression_at<
|
||||
U: Select<'ast, T> + From<TypedExpression<'ast, T>> + Into<TypedExpression<'ast, T>>,
|
||||
>(
|
||||
&self,
|
||||
self,
|
||||
index: usize,
|
||||
) -> Option<U> {
|
||||
self.iter()
|
||||
.flat_map(|v| Self::expression_at_aux(v.clone()))
|
||||
.take_while(|e| e.is_some())
|
||||
.map(|e| e.unwrap())
|
||||
) -> U {
|
||||
self.into_iter()
|
||||
.flat_map(|v| Self::expression_at_aux::<U>(v))
|
||||
.nth(index)
|
||||
.unwrap()
|
||||
.clone()
|
||||
.try_into()
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
def main(u32 N) {
|
||||
field[N] a = [0; N];
|
||||
return;
|
||||
}
|
Loading…
Reference in a new issue