diff --git a/src/logic/future_moves.rs b/src/logic/future_moves.rs index c5350a2..0ba82f4 100644 --- a/src/logic/future_moves.rs +++ b/src/logic/future_moves.rs @@ -3,8 +3,10 @@ use crate::{ repr::{Board, CoordPair, Piece, Winner}, }; use allocative::Allocative; -use indicatif::{ProgressIterator, ProgressStyle}; -use std::{collections::HashMap, hash::BuildHasherDefault, ops::ControlFlow}; +use indicatif::{ParallelProgressIterator, ProgressStyle}; +use rayon::iter::IntoParallelIterator; +use rayon::prelude::*; +use std::{collections::HashMap, hash::BuildHasherDefault}; #[derive(Allocative)] pub struct FutureMoves { @@ -158,29 +160,29 @@ impl FutureMoves { ) }; - let cf = self - .leaf_moves() + self.leaf_moves() .into_iter() .filter(|&i| self.depth_of(i) == self.current_depth) .collect::>() - .into_iter() + .into_par_iter() .progress_with_style(ProgressStyle::with_template(pstyle_inner).unwrap()) - .try_for_each(|node_idx| { - self.generate_children(node_idx); + .map(|parent_idx| (parent_idx, self.generate_children_raw(parent_idx))) + .collect::)>>() + .into_iter() + .for_each(|(parent_idx, moves)| { + let start_idx = self.arena.len(); + self.arena.extend(moves); - if self.arena_len() >= self.config.max_arena_size { - ControlFlow::Break(()) - } else { - ControlFlow::Continue(()) - } + let new_indices = start_idx..self.arena.len(); + self.arena[parent_idx].children.extend(new_indices); }); self.prune_bad_children(); - - if cf.is_break() { - return; - } self.current_depth += 1; + + if self.arena.len() >= self.config.max_arena_size { + break; + } } } @@ -200,10 +202,7 @@ impl FutureMoves { false } - /// Creates children for a parent (`parent_idx`) - /// Completely unchecked, the caller should be the one who tests to make sure child generation - /// hasn't already been tried on a parent - fn generate_children(&mut self, parent_idx: usize) { + fn generate_children_raw(&self, parent_idx: usize) -> Vec { let parent = &self.arena[parent_idx]; let new_color = !parent.color; @@ -230,14 +229,11 @@ impl FutureMoves { new.push(Move::new(None, parent_board, new_color, self.agent_color)); } - let start_idx = self.arena.len(); - self.arena.extend(new); - - let new_indices = start_idx..self.arena.len(); - - for child_idx in new_indices { - self.set_parent_child(parent_idx, child_idx); + for m in new.iter_mut() { + m.parent = Some(parent_idx); } + + new } /// Given an index from `self.arena`, what depth is it at? 0-indexed diff --git a/src/main.rs b/src/main.rs index 469475a..631bf0a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -32,7 +32,7 @@ fn main() { min_arena_depth: 14, top_k_children: 2, up_to_minus: 10, - max_arena_size: 100_000_000, + max_arena_size: 400_000_000, do_prune: true, print: true, children_eval_method: Default::default(),