diff --git a/crypto/stark/src/fri/fri_functions.rs b/crypto/stark/src/fri/fri_functions.rs index 8bd355ec4..e3940aa1a 100644 --- a/crypto/stark/src/fri/fri_functions.rs +++ b/crypto/stark/src/fri/fri_functions.rs @@ -16,16 +16,42 @@ pub fn fold_evaluations_in_place, E: IsField>( evals: &mut Vec>, zeta: &FieldElement, inv_twiddles: &[FieldElement], -) { +) where + FieldElement: Send + Sync, + FieldElement: Sync, +{ let half = evals.len() / 2; - for j in 0..half { - let lo = &evals[2 * j]; - let hi = &evals[2 * j + 1]; - let sum = lo + hi; - let diff = lo - hi; - evals[j] = &sum + &(&inv_twiddles[j] * &(zeta * &diff)); + + #[cfg(feature = "parallel")] + { + use rayon::prelude::*; + // Parallel fold: split evals into pairs, compute folded value for each. + // Write results into a new Vec to avoid aliasing (evals[j] overlaps evals[2*j]). + let folded: Vec> = (0..half) + .into_par_iter() + .map(|j| { + let lo = &evals[2 * j]; + let hi = &evals[2 * j + 1]; + let sum = lo + hi; + let diff = lo - hi; + &sum + &(&inv_twiddles[j] * &(zeta * &diff)) + }) + .collect(); + evals.truncate(half); + evals[..half].clone_from_slice(&folded); + } + + #[cfg(not(feature = "parallel"))] + { + for j in 0..half { + let lo = &evals[2 * j]; + let hi = &evals[2 * j + 1]; + let sum = lo + hi; + let diff = lo - hi; + evals[j] = &sum + &(&inv_twiddles[j] * &(zeta * &diff)); + } + evals.truncate(half); } - evals.truncate(half); } /// Compute inverse twiddle factors for evaluation-form FRI folding.