future_moves: use atomics for ensuring arena size

This commit is contained in:
Simon Gardling 2025-04-08 16:03:01 -04:00
parent ec3fc382d4
commit 788f88a509
Signed by: titaniumtown
GPG Key ID: 9AB28AC10ECE533D

View File

@ -6,7 +6,14 @@ use allocative::Allocative;
use indicatif::{ParallelProgressIterator, ProgressStyle};
use rayon::iter::IntoParallelIterator;
use rayon::prelude::*;
use std::{collections::HashMap, hash::BuildHasherDefault};
use std::{
collections::HashMap,
hash::BuildHasherDefault,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
},
};
#[derive(Allocative)]
pub struct FutureMoves {
@ -161,6 +168,7 @@ impl FutureMoves {
};
let allowed_size = self.config.max_arena_size - self.arena.len();
let curr_size = Arc::new(AtomicUsize::new(0));
let got = self
.leaf_moves()
.into_iter()
@ -169,9 +177,18 @@ impl FutureMoves {
.into_par_iter()
.progress_with_style(ProgressStyle::with_template(pstyle_inner).unwrap())
.map(|parent_idx| (parent_idx, self.generate_children_raw(parent_idx)))
.take_any(allowed_size)
.take_any_while(|(_, x)| {
if curr_size.load(Ordering::Relaxed) + x.len() > allowed_size {
false
} else {
curr_size.fetch_add(x.len(), Ordering::Relaxed);
true
}
})
.collect::<Vec<(usize, Vec<Move>)>>();
let got_len = got.len();
// get total # of generated boards
let got_len = curr_size.load(Ordering::Acquire);
got.into_iter().for_each(|(parent_idx, moves)| {
let start_idx = self.arena.len();
@ -184,7 +201,7 @@ impl FutureMoves {
self.prune_bad_children();
self.current_depth += 1;
if got_len == allowed_size {
// got.len() has hit the upper limit of size permitted
// arena has hit the upper limit of size permitted
break;
}
}