From 788f88a509a8791f277beb97baaa282ceb825db9 Mon Sep 17 00:00:00 2001 From: Simon Gardling Date: Tue, 8 Apr 2025 16:03:01 -0400 Subject: [PATCH] future_moves: use atomics for ensuring arena size --- src/logic/future_moves.rs | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/logic/future_moves.rs b/src/logic/future_moves.rs index 3fec03a..ec9b96b 100644 --- a/src/logic/future_moves.rs +++ b/src/logic/future_moves.rs @@ -6,7 +6,14 @@ use allocative::Allocative; use indicatif::{ParallelProgressIterator, ProgressStyle}; use rayon::iter::IntoParallelIterator; use rayon::prelude::*; -use std::{collections::HashMap, hash::BuildHasherDefault}; +use std::{ + collections::HashMap, + hash::BuildHasherDefault, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, + }, +}; #[derive(Allocative)] pub struct FutureMoves { @@ -161,6 +168,7 @@ impl FutureMoves { }; let allowed_size = self.config.max_arena_size - self.arena.len(); + let curr_size = Arc::new(AtomicUsize::new(0)); let got = self .leaf_moves() .into_iter() @@ -169,9 +177,18 @@ impl FutureMoves { .into_par_iter() .progress_with_style(ProgressStyle::with_template(pstyle_inner).unwrap()) .map(|parent_idx| (parent_idx, self.generate_children_raw(parent_idx))) - .take_any(allowed_size) + .take_any_while(|(_, x)| { + if curr_size.load(Ordering::Relaxed) + x.len() > allowed_size { + false + } else { + curr_size.fetch_add(x.len(), Ordering::Relaxed); + true + } + }) .collect::)>>(); - let got_len = got.len(); + + // get total # of generated boards + let got_len = curr_size.load(Ordering::Acquire); got.into_iter().for_each(|(parent_idx, moves)| { let start_idx = self.arena.len(); @@ -184,7 +201,7 @@ impl FutureMoves { self.prune_bad_children(); self.current_depth += 1; if got_len == allowed_size { - // got.len() has hit the upper limit of size permitted + // arena has hit the upper limit of size permitted break; } }