use crate::{ logic::r#move::Move, repr::{Board, Coord, Piece, Winner}, }; use indicatif::{ProgressIterator, ProgressStyle}; use std::{collections::HashMap, hash::BuildHasherDefault}; pub struct FutureMoves { /// Arena containing all [`Move`] arena: Vec, /// Index of the [`Move`] tree's root node current_root: Option, /// Current generated depth of the Arena current_depth: usize, /// Color w.r.t agent_color: Piece, config: FutureMoveConfig, } #[derive(Default)] pub struct FutureMoveConfig { /// Max depth of that we should try and traverse pub max_depth: usize, /// subtract this value from FutureMove.max_depth /// and that would be the min depth an arena should fill for /// pruning to happen pub min_arena_depth_sub: usize, /// when pruning, keep the top_k # of children pub top_k_children: usize, // the lower the value, the more conservative the pruning is, what level to stop pruning at? // a lower value allows more possible paths pub up_to_minus: usize, /// Max size of the arena, will not generate more if /// the arena is of that size or bigger pub max_arena_size: usize, } impl FutureMoves { pub const fn new(agent_color: Piece, config: FutureMoveConfig) -> Self { Self { arena: Vec::new(), current_root: None, current_depth: 0, agent_color, config, } } /// Return the length of the Arena pub fn arena_len(&self) -> usize { self.arena.len() } /// Generate children for all children of `nodes` /// only `pub` for the sake of benchmarking pub fn extend_layers(&mut self) { for i in (self.current_depth + 1)..=self.config.max_depth { if self.arena_len() >= self.config.max_arena_size { dbg!("extend_layers: early break ({})", self.arena_len()); break; } (0..self.arena.len()) // we want to select all nodes that don't have children, or are lazy (need to maybe be regenerated) .filter(|&idx| { let got = &self.arena[idx]; !got.is_trimmed && got.winner == Winner::None && !got.tried_children }) .filter(|&idx| self.is_connected_to_root(idx)) .collect::>() .into_iter() .progress_with_style( ProgressStyle::with_template(&format!( "Generating children (depth: {}/{}): ({{pos}}/{{len}}) {{per_sec}}", i, self.config.max_depth )) .unwrap(), ) .for_each(|node_idx| { self.generate_children(node_idx).last(); self.arena[node_idx].tried_children = true; }); self.prune_bad_children(); self.current_depth += 1; } } /// Determines if a [`Move`] at index `idx` is connected to `self.current_root` /// Returns `false` if `self.current_root` is None fn is_connected_to_root(&self, idx: usize) -> bool { if let Some(root) = self.current_root { let mut current = Some(idx); while let Some(parent_idx) = current { if parent_idx == root { return true; } current = self.arena[parent_idx].parent; } } false } /// Creates children for a parent (`parent`), returns an iterator it's children's indexes /// Completely unchecked, the caller should be the one who tests to make sure child generation /// hasn't already been tried on a parent fn generate_children(&mut self, parent_idx: usize) -> impl Iterator { let parent = &self.arena[parent_idx]; let new_color = !parent.color; // use [`Board::all_positions`] here instead of [`Board::possible_moves`] // because we use [`Board::what_if`] later and we want to reduce calls to [`Board::propegate_from_dry`] let mut new: Vec = Board::all_positions() .flat_map(|(i, j)| { parent .board .what_if(i, j, new_color) .map(move |x| (i, j, x)) }) .map(|(i, j, new_board)| { Move::new(Some((i, j)), new_board, new_color, self.agent_color) }) .collect(); if new.is_empty() { new.push(Move::new(None, parent.board, new_color, self.agent_color)); } let start_idx = self.arena.len(); self.arena.extend(new); let new_indices = start_idx..self.arena.len(); for child_idx in new_indices.clone() { self.set_parent_child(parent_idx, child_idx); } new_indices } /// Given an index from `self.arena`, what depth is it at? 0-indexed fn depth_of(&self, node_idx: usize) -> usize { let mut depth = 0; let mut current = Some(node_idx); while let Some(parent_idx) = current { depth += 1; current = self.arena[parent_idx].parent; } depth - 1 } //// PERF! pre-organize all indexes based on what depth they're at /// previously, I did a lookup map based on if a node was visited, still resulted in a full /// O(n) iteration each depth fn by_depth(&self, indexes: impl Iterator) -> Vec<(usize, Vec)> { let mut by_depth: HashMap< usize, Vec, BuildHasherDefault>, > = HashMap::with_hasher(BuildHasherDefault::default()); for idx in indexes { let depth = self.depth_of(idx); if let Some(got) = by_depth.get_mut(&depth) { got.push(idx); } else { by_depth.insert(depth, vec![idx]); } } let mut by_depth_vec: Vec<(usize, Vec)> = by_depth.into_iter().collect(); by_depth_vec.sort_by_key(|x| x.0); by_depth_vec } /// Compute `Move.value`, propegating upwards from the furthest out Moves /// in the Arena. fn compute_values(&mut self, indexes: impl Iterator) { let by_depth_vec = self.by_depth(indexes); // reversed so we build up the value of the closest (in time) moves from the future for (depth, nodes) in by_depth_vec.into_iter().rev() { for idx in nodes { // TODO! impl dynamic sorting based on children's states, maybe it propegates // upwards using the `parent` field // let mut parent_copy = self.arena[idx].clone(); // parent_copy.sort_children(self.arena.as_mut_slice()); // self.arena[idx] = parent_copy; let children_value = self.arena[idx] .children .iter() .map(|&child| self.arena[child].value.expect("child has no value??")) .sum::(); // we use `depth` and divided `self_value` by it, idk if this is worth it // we should really setup some sort of ELO rating for each commit, playing them against // each other or something, could be cool to benchmark these more subjective things, not // just performance (cycles/time wise) self.arena[idx].value = Some( (self.arena[idx].self_value as i128 + children_value) / (depth + 1) as i128, ); } } } /// Return the best move which is a child of `self.current_root` pub fn best_move(&self) -> Option<(Coord, Coord)> { self.current_root .and_then(|x| { self.arena[x] .children .iter() .max_by_key(|&&idx| self.arena[idx].value) }) .inspect(|&&x| { assert_eq!( self.arena[x].color, self.agent_color, "selected move color should be the same as the color of the agent" ); }) .and_then(|&x| self.arena[x].coord) } /// Updates `FutureMoves` based on the current state of the board /// The board is supposed to be after the opposing move pub fn update_from_board(&mut self, board: &Board) { let curr_board = self .arena .iter() .enumerate() .find(|(_, m)| &m.board == board && (m.parent == self.current_root)) .map(|(idx, _)| idx) .filter(|_| self.current_root.is_some()); if let Some(curr_board_idx) = curr_board { self.set_root_idx_raw(curr_board_idx); } else { dbg!("regenerating arena from board"); self.set_root_from_board(*board); } } /// Clear the arena and create and set a root which contains a Board pub fn set_root_from_board(&mut self, board: Board) { self.arena.clear(); self.arena .push(Move::new(None, board, !self.agent_color, self.agent_color)); // because we have to regenerate root from a [`Board`] // we need to reset the current_depth (fixes `skip_move_recovery`) self.current_depth = 0; self.set_root_idx_raw(0); } /// Update the root based on the coordinate of the move /// Returns a boolean, `true` if the operation was successful, false if not #[must_use = "You must check if the root was properly set"] pub fn update_root_coord(&mut self, i: Coord, j: Coord) -> bool { // check to make sure current_root is some so we dont // have to do that in the iterator if self.current_root.is_none() { return false; } self.arena .iter() .enumerate() .find(|(_, node)| node.parent == self.current_root && node.coord == Some((i, j))) .map(|x| x.0) // do raw set so we can prune it on the next move (in `update`) .inspect(|&root| self.update_root_idx_raw(root)) .is_some() } /// Update current root without modifying or pruning the Arena fn update_root_idx_raw(&mut self, idx: usize) { self.current_root = Some(idx); self.current_depth -= self.depth_of(idx); } /// Update current root index while pruning and extending the tree (also recalculate values) fn set_root_idx_raw(&mut self, idx: usize) { self.update_root_idx_raw(idx); self.refocus_tree(); self.extend_layers(); self.compute_values(0..self.arena.len()); // check arena's consistancy assert_eq!(self.check_arena().join("\n"), ""); } pub fn set_parent_child(&mut self, parent: usize, child: usize) { self.arena[parent].children.push(child); self.arena[child].parent = Some(parent); } /// Checks the consistancy of the Arena (parents and children) /// returns a vector of errors ([`String`]) pub fn check_arena(&self) -> Vec { let mut errors = vec![]; for idx in 0..self.arena.len() { let m = &self.arena[idx]; if let Some(parent) = m.parent { if !(0..self.arena.len()).contains(&parent) { errors.push(format!("{}: parent is out of range ({})", idx, parent)); } if !self.arena[parent].children.contains(&idx) { errors.push(format!( "{}: parent ({}) doesn't list {} as child", idx, parent, idx )); } } for &child_idx in &m.children { if !(0..self.arena.len()).contains(&child_idx) { errors.push(format!("{}: parent is out of range ({})", idx, child_idx)); } if self.arena[child_idx].parent != Some(idx) { errors.push(format!( "{}: child ({}) does not list self as parent", idx, child_idx )); } } } errors } fn prune_bad_children(&mut self) { // values are needed in order to prune and see what's best self.compute_values(0..self.arena_len()); let by_depth = self.by_depth(0..self.arena.len()); if self .config .max_depth .saturating_sub(self.config.min_arena_depth_sub) > self.current_depth { return; } for (depth, indexes) in by_depth { // TODO! maybe update by_depth every iteration or something? if depth > self.current_depth.saturating_sub(self.config.up_to_minus) { return; } // only prune moves of the agent if indexes.first().map(|&i| self.arena[i].color) != Some(self.agent_color) { continue; } for idx in indexes { let mut m = self.arena[idx].clone(); if m.is_trimmed { continue; } m.is_trimmed = true; m.sort_children(&self.arena); if m.children.len() > self.config.top_k_children { let drained = m.children.drain(self.config.top_k_children..); for idx in drained { self.arena[idx].parent = None; } } self.arena[idx] = m; } } // rebuild tree to exclude the things that were pruned self.refocus_tree(); } /// Rebuilds the Arena based on `self.current_root`, prunes unrelated nodes fn refocus_tree(&mut self) { let Some(root) = self.current_root else { return; }; // make sure `root` doesn't reference another node self.arena[root].parent = None; let mut retain = vec![false; self.arena.len()]; // stack is going to be AT MAXIMUM, the size of the array, // so lets just pre-allocate it let mut stack: Vec = Vec::with_capacity(self.arena.len()); stack.push(root); // traverse children of the current root while let Some(idx) = stack.pop() { retain[idx] = true; stack.extend(self.arena[idx].children.iter()); } let mut index_map = vec![None; self.arena.len()]; let new_start: Vec<(usize, usize, Move)> = retain .into_iter() .enumerate() // old_idx .zip(self.arena.drain(..)) .filter(|&((_, keep), _)| keep) // filter out un-related nodes .map(|((old_idx, _), node)| (old_idx, node)) .enumerate() // new_idx .map(|(a, (b, c))| (a, b, c)) .collect(); for &(new_idx, old_idx, _) in &new_start { index_map[old_idx] = Some(new_idx); } self.arena = new_start .into_iter() .map(|(_, _, mut node)| { if let Some(parent) = node.parent.as_mut() { if let Some(new_parent) = index_map[*parent] { *parent = new_parent; } else { // make sure we don't have dangling parents node.parent = None; } } for c in node.children.as_mut_slice() { debug_assert!( index_map.get(*c).unwrap().is_some(), "index_map should contain the child's index" ); *c = unsafe { index_map.get_unchecked(*c).unwrap_unchecked() }; } node }) .collect(); self.current_root = index_map[root]; } } #[cfg(test)] mod tests { use super::*; const FUTURE_MOVES_CONFIG: FutureMoveConfig = FutureMoveConfig { max_depth: 1, min_arena_depth_sub: 2, top_k_children: 2, up_to_minus: 0, max_arena_size: 100, }; #[test] fn prune_tree_test() { let mut futm = FutureMoves::new(Piece::Black, FUTURE_MOVES_CONFIG); futm.arena.push(Move { coord: None, board: Board::new(), winner: Winner::None, parent: None, children: Vec::new(), value: None, self_value: 0, color: Piece::Black, is_trimmed: false, tried_children: false, }); futm.update_root_idx_raw(0); // child 1 futm.arena .push(Move::new(None, Board::new(), Piece::White, Piece::Black)); futm.set_parent_child(0, 1); // dummy (2) futm.arena.push(Move::new( Some((123, 123)), Board::new(), Piece::White, Piece::Black, )); // 3 futm.arena .push(Move::new(None, Board::new(), Piece::White, Piece::Black)); futm.set_parent_child(0, 3); // 4 futm.arena .push(Move::new(None, Board::new(), Piece::White, Piece::Black)); futm.set_parent_child(0, 4); assert_eq!(futm.arena_len(), 5); futm.refocus_tree(); assert_eq!(futm.arena_len(), 4); assert_eq!(futm.arena[0].children.len(), 3); assert_ne!( futm.arena[2].coord, Some((123, 123)), "dummy value still exists" ); } #[test] fn expand_layer_test() { let mut futm = FutureMoves::new(Piece::Black, FUTURE_MOVES_CONFIG); futm.config.max_depth = 1; futm.arena.push(Move::new( None, Board::new().starting_pos(), Piece::Black, Piece::Black, )); futm.update_root_idx_raw(0); futm.extend_layers(); assert_eq!(futm.arena_len(), 5); // move to a child futm.update_root_idx_raw(1); futm.refocus_tree(); assert_eq!(futm.arena_len(), 1); // make sure current_root is properly updated assert_eq!(futm.current_root, Some(0)); futm.extend_layers(); assert!( futm.arena_len() > 1, "extend_layer didn't grow arena after refocus" ); } #[test] fn depth_of_test() { let mut futm = FutureMoves::new(Piece::Black, FUTURE_MOVES_CONFIG); futm.arena.push(Move { coord: None, board: Board::new(), winner: Winner::None, parent: None, children: vec![], value: None, self_value: 0, color: Piece::Black, is_trimmed: false, tried_children: false, }); futm.update_root_idx_raw(0); // child 1 futm.arena .push(Move::new(None, Board::new(), Piece::White, Piece::Black)); futm.set_parent_child(0, 1); // dummy futm.arena .push(Move::new(None, Board::new(), Piece::White, Piece::Black)); futm.arena .push(Move::new(None, Board::new(), Piece::White, Piece::Black)); futm.set_parent_child(1, 3); futm.arena .push(Move::new(None, Board::new(), Piece::White, Piece::Black)); futm.set_parent_child(0, 4); assert_eq!(futm.depth_of(3), 2); } #[test] fn by_depth_test() { let mut futm = FutureMoves::new(Piece::Black, FUTURE_MOVES_CONFIG); futm.arena.push(Move { coord: None, board: Board::new(), winner: Winner::None, parent: None, children: vec![1], value: None, self_value: 0, color: Piece::Black, is_trimmed: false, tried_children: false, }); futm.update_root_idx_raw(0); // child 1 futm.arena .push(Move::new(None, Board::new(), Piece::White, Piece::Black)); futm.set_parent_child(0, 1); // dummy futm.arena .push(Move::new(None, Board::new(), Piece::White, Piece::Black)); futm.arena .push(Move::new(None, Board::new(), Piece::White, Piece::Black)); futm.set_parent_child(1, 3); assert_eq!( futm.by_depth(0..futm.arena.len()), vec![(0, vec![0, 2]), (1, vec![1]), (2, vec![3])] ); } /// tests whether or not FutureMoves can recover from multiple skips and then manually regenerating the arena #[test] fn skip_move_recovery() { let mut futm = FutureMoves::new(Piece::Black, FUTURE_MOVES_CONFIG); let mut board = Board::new().starting_pos(); // replay of a test I did // TODO! make this as small of a test as possible let moves = vec![ (Some((5, 4)), Piece::Black), (Some((5, 5)), Piece::White), (Some((5, 6)), Piece::Black), (Some((6, 4)), Piece::White), (Some((7, 3)), Piece::Black), (Some((7, 4)), Piece::White), (Some((7, 5)), Piece::Black), (Some((2, 4)), Piece::White), (Some((1, 4)), Piece::Black), (Some((1, 5)), Piece::White), (Some((1, 6)), Piece::Black), (Some((0, 6)), Piece::White), (Some((3, 2)), Piece::Black), (Some((1, 7)), Piece::White), (None, Piece::Black), // black skips a move (Some((0, 4)), Piece::White), (None, Piece::Black), // black skips a move (Some((4, 2)), Piece::White), ]; for (coords, color) in moves { if color == futm.agent_color { // my turn futm.update_from_board(&board); let best_move = futm.best_move(); if coords.is_none() { assert_eq!(best_move, None); } else { assert_ne!(best_move, None); } } if let Some((i, j)) = coords { board.place(i, j, color).unwrap(); } } } }