563 lines
17 KiB
Rust
563 lines
17 KiB
Rust
use crate::{
|
|
logic::r#move::Move,
|
|
repr::{Board, Piece, Winner},
|
|
};
|
|
use indicatif::{ProgressIterator, ProgressStyle};
|
|
use std::collections::HashMap;
|
|
|
|
pub struct FutureMoves {
|
|
/// Arena containing all [`Move`]
|
|
arena: Vec<Move>,
|
|
|
|
/// Index of the [`Move`] tree's root node
|
|
current_root: Option<usize>,
|
|
|
|
/// Current generated depth of the Arena
|
|
current_depth: usize,
|
|
|
|
/// Target depth of children to generate
|
|
max_depth: usize,
|
|
|
|
/// How many deep should the lazy children status expire?
|
|
lazy_expire: usize,
|
|
|
|
/// Color w.r.t
|
|
agent_color: Piece,
|
|
}
|
|
|
|
impl FutureMoves {
|
|
pub const fn new(agent_color: Piece, max_depth: usize, lazy_expire: usize) -> Self {
|
|
Self {
|
|
arena: Vec::new(),
|
|
current_root: None,
|
|
current_depth: 0,
|
|
max_depth,
|
|
agent_color,
|
|
lazy_expire,
|
|
}
|
|
}
|
|
|
|
/// Return the length of the Arena
|
|
pub fn arena_len(&self) -> usize {
|
|
self.arena.len()
|
|
}
|
|
|
|
/// Generate children for all children of `nodes`
|
|
/// only `pub` for the sake of benchmarking
|
|
pub fn extend_layers(&mut self) {
|
|
let mut next_nodes: Vec<usize> = (0..self.arena.len())
|
|
// we want to select all nodes that don't have children, or are lazy (need to maybe be regenerated)
|
|
.filter(|&idx| {
|
|
let got = &self.arena[idx];
|
|
got.lazy_children || got.children.is_empty()
|
|
})
|
|
.filter(|&idx| self.is_connected_to_root(idx)) // put here so this will not extend needlessly before prunes
|
|
.collect();
|
|
|
|
for i in (self.current_depth + 1)..=self.max_depth {
|
|
next_nodes = next_nodes
|
|
.into_iter()
|
|
.progress_with_style(
|
|
ProgressStyle::with_template(&format!(
|
|
"Generating children (depth: {}/{}): ({{pos}}/{{len}}) {{per_sec}}",
|
|
i, self.max_depth
|
|
))
|
|
.unwrap(),
|
|
)
|
|
.flat_map(|node_idx| {
|
|
self.generate_children(
|
|
node_idx,
|
|
if self.arena[node_idx].lazy_children {
|
|
self.depth_of(node_idx)
|
|
} else {
|
|
i
|
|
} > self.lazy_expire,
|
|
)
|
|
})
|
|
.flatten()
|
|
.collect();
|
|
}
|
|
self.current_depth = self.max_depth;
|
|
}
|
|
|
|
/// Determines if a [`Move`] at index `idx` is connected to `self.current_root`
|
|
/// Returns `false` if `self.current_root` is None
|
|
fn is_connected_to_root(&self, idx: usize) -> bool {
|
|
if let Some(root) = self.current_root {
|
|
let mut current = Some(idx);
|
|
while let Some(parent_idx) = current {
|
|
if parent_idx == root {
|
|
return true;
|
|
}
|
|
current = self.arena[parent_idx].parent;
|
|
}
|
|
}
|
|
|
|
false
|
|
}
|
|
|
|
/// Creates children for a parent (`parent`), returns an iterator it's children's indexes
|
|
fn generate_children(
|
|
&mut self,
|
|
parent_idx: usize,
|
|
lazy_children: bool,
|
|
) -> Option<impl Iterator<Item = usize>> {
|
|
let parent = &self.arena[parent_idx];
|
|
|
|
// early-exit if a winner for the parent already exists
|
|
if parent.winner != Winner::None {
|
|
return None;
|
|
}
|
|
|
|
let new_color = !parent.color;
|
|
|
|
// use [`Board::all_positions`] here instead of [`Board::possible_moves`]
|
|
// because we use [`Board::what_if`] later and we want to reduce calls to [`Board::propegate_from_dry`]
|
|
let new: Vec<Move> = Board::all_positions()
|
|
.flat_map(|(i, j)| {
|
|
parent
|
|
.board
|
|
.what_if(i, j, new_color)
|
|
.map(move |x| (i, j, x))
|
|
})
|
|
.map(|(i, j, new_board)| {
|
|
Move::new(
|
|
i,
|
|
j,
|
|
new_board,
|
|
new_color,
|
|
lazy_children,
|
|
self.agent_color,
|
|
Some(parent_idx),
|
|
)
|
|
})
|
|
.collect();
|
|
|
|
// keep the TOP_K children of their magnitude
|
|
const TOP_K_CHILDREN: usize = 10;
|
|
|
|
let start_idx = self.arena.len();
|
|
self.arena.extend(new);
|
|
|
|
let new_indices = start_idx..self.arena.len();
|
|
|
|
self.arena[parent_idx].children.extend(new_indices.clone());
|
|
let mut parent_copy = self.arena[parent_idx].clone();
|
|
parent_copy.sort_children(self.arena.as_mut_slice());
|
|
self.arena[parent_idx] = parent_copy;
|
|
|
|
if lazy_children && new_indices.clone().count() > TOP_K_CHILDREN {
|
|
for i in new_indices.clone().skip(TOP_K_CHILDREN) {
|
|
self.arena[i].lazy_children = true;
|
|
}
|
|
}
|
|
|
|
Some(new_indices)
|
|
}
|
|
|
|
/// Given an index from `self.arena`, what depth is it at? 0-indexed
|
|
fn depth_of(&self, node_idx: usize) -> usize {
|
|
let mut depth = 0;
|
|
let mut current = Some(node_idx);
|
|
while let Some(parent_idx) = current {
|
|
depth += 1;
|
|
current = self.arena[parent_idx].parent;
|
|
}
|
|
|
|
depth - 1
|
|
}
|
|
|
|
/// Compute `Move.value`, propegating upwards from the furthest out Moves
|
|
/// in the Arena.
|
|
fn compute_values(&mut self, indexes: impl Iterator<Item = usize>) {
|
|
// PERF! pre-organize all indexes based on what depth they're at
|
|
// previously, I did a lookup map based on if a node was visited, still resulted in a full
|
|
// O(n) iteration each depth
|
|
let mut by_depth: HashMap<usize, Vec<usize>> = HashMap::new();
|
|
for idx in indexes {
|
|
let depth = self.depth_of(idx);
|
|
if let Some(got) = by_depth.get_mut(&depth) {
|
|
got.push(idx);
|
|
} else {
|
|
by_depth.insert(depth, vec![idx]);
|
|
}
|
|
}
|
|
|
|
let mut by_depth_vec: Vec<(usize, Vec<usize>)> = by_depth.into_iter().collect();
|
|
by_depth_vec.sort_by_key(|x| x.0);
|
|
|
|
for (depth, nodes) in by_depth_vec {
|
|
for idx in nodes {
|
|
// TODO! impl dynamic sorting based on children's states, maybe it propegates
|
|
// upwards using the `parent` field
|
|
// SAFETY! the sort_by_key function should not modify anything
|
|
unsafe { &mut (*(self as *mut Self)) }.arena[idx]
|
|
.children
|
|
// negative because we want the largest value in the first index
|
|
// abs so we get the most extreme solutions
|
|
// but base on `.value` for recursive behavior
|
|
.sort_by_key(|&x| -self.arena[x].value.abs());
|
|
|
|
let children_value = self.arena[idx]
|
|
.children
|
|
.iter()
|
|
.rev() // rev then reverse so we get an index starting from the back
|
|
.enumerate()
|
|
// since children are sorted by value, we should weight the first one more
|
|
.map(|(i, &child)| self.arena[child].value * ((i + 1) as i128))
|
|
.sum::<i128>();
|
|
|
|
// previously we used `depth` and divided `self_value` by it, idk if this is worth it
|
|
// we should really setup some sort of ELO rating for each commit, playing them against
|
|
// each other or something, could be cool to benchmark these more subjective things, not
|
|
// just performance
|
|
self.arena[idx].value = self.arena[idx].self_value as i128 + children_value;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Return the best move which is a child of `self.current_root`
|
|
pub fn best_move(&self) -> Option<(usize, usize)> {
|
|
self.current_root
|
|
.and_then(|x| {
|
|
self.arena[x]
|
|
.children
|
|
.iter()
|
|
.max_by_key(|&&idx| self.arena[idx].value)
|
|
})
|
|
.inspect(|&&x| {
|
|
assert_eq!(
|
|
self.arena[x].color, self.agent_color,
|
|
"selected move color should be the same as the color of the agent"
|
|
);
|
|
})
|
|
.map(|&x| self.arena[x].coords())
|
|
}
|
|
|
|
/// Updates `FutureMoves` based on the current state of the board
|
|
/// The board is supposed to be after the opposing move
|
|
pub fn update_from_board(&mut self, board: &Board) {
|
|
let curr_board = self
|
|
.arena
|
|
.iter()
|
|
.enumerate()
|
|
.find(|(_, m)| &m.board == board && (m.parent == self.current_root))
|
|
.map(|(idx, _)| idx)
|
|
.filter(|_| self.current_root.is_some());
|
|
|
|
if let Some(curr_board_idx) = curr_board {
|
|
self.set_root_idx_raw(curr_board_idx);
|
|
} else {
|
|
self.set_root_from_board(*board);
|
|
}
|
|
}
|
|
|
|
/// Clear the arena and create and set a root which contains a Board
|
|
pub fn set_root_from_board(&mut self, board: Board) {
|
|
self.arena.clear();
|
|
self.arena.push(Move::new(
|
|
0, // dummy
|
|
0, // dummy
|
|
board,
|
|
!self.agent_color,
|
|
false,
|
|
self.agent_color,
|
|
None,
|
|
));
|
|
self.set_root_idx_raw(0);
|
|
}
|
|
|
|
/// Update the root based on the coordinate of the move
|
|
/// Returns a boolean, `true` if the operation was successful, false if not
|
|
#[must_use = "You must check if the root was properly set"]
|
|
pub fn update_root_coord(&mut self, i: usize, j: usize) -> bool {
|
|
self.arena
|
|
.iter()
|
|
.enumerate()
|
|
.find(|(_, node)| {
|
|
node.parent == self.current_root
|
|
&& self.current_root.is_some()
|
|
&& node.coords() == (i, j)
|
|
})
|
|
.map(|x| x.0)
|
|
// do raw set so we can prune it on the next move (in `update`)
|
|
.inspect(|&root| self.update_root_idx_raw(root))
|
|
.is_some()
|
|
}
|
|
|
|
/// Update current root without modifying or pruning the Arena
|
|
fn update_root_idx_raw(&mut self, idx: usize) {
|
|
self.current_root = Some(idx);
|
|
self.current_depth -= self.depth_of(idx);
|
|
}
|
|
|
|
/// Update current root index while pruning and extending the tree (also recalculate values)
|
|
fn set_root_idx_raw(&mut self, idx: usize) {
|
|
self.update_root_idx_raw(idx);
|
|
|
|
// self.prune_bad_children();
|
|
self.refocus_tree();
|
|
self.extend_layers();
|
|
self.compute_values(0..self.arena.len());
|
|
}
|
|
|
|
fn prune_bad_children(&mut self) {
|
|
const BOTTOM_PERC: f32 = 20.0;
|
|
let Some(root) = self.current_root else {
|
|
return;
|
|
};
|
|
|
|
let mut children = self.arena[root].children.clone();
|
|
|
|
children.sort_by_key(|&i| -self.arena[i].value);
|
|
let start_len = ((children.len()) as f32 * (1.0 - BOTTOM_PERC)) as usize;
|
|
let drained = children.drain(start_len..);
|
|
println!("{}", drained.len());
|
|
|
|
for i in drained {
|
|
self.arena[i].parent = None;
|
|
}
|
|
|
|
self.arena[root].children = children;
|
|
}
|
|
|
|
/// Rebuilds the Arena based on `self.current_root`, prunes unrelated nodes
|
|
fn refocus_tree(&mut self) {
|
|
let Some(root) = self.current_root else {
|
|
return;
|
|
};
|
|
|
|
// make sure `root` doesn't reference another node
|
|
self.arena[root].parent = None;
|
|
|
|
let mut retain = vec![false; self.arena.len()];
|
|
|
|
// stack is going to be AT MAXIMUM, the size of the array,
|
|
// so lets just pre-allocate it
|
|
let mut stack: Vec<usize> = Vec::with_capacity(self.arena.len());
|
|
stack.push(root);
|
|
|
|
// traverse children of the current root
|
|
while let Some(idx) = stack.pop() {
|
|
retain[idx] = true;
|
|
stack.extend(self.arena[idx].children.iter());
|
|
}
|
|
|
|
let mut index_map = vec![None; self.arena.len()];
|
|
|
|
let new_start: Vec<(usize, usize, Move)> = retain
|
|
.into_iter()
|
|
.enumerate() // old_idx
|
|
.zip(self.arena.drain(..))
|
|
.filter(|&((_, keep), _)| keep) // filter out un-related nodes
|
|
.map(|((old_idx, _), node)| (old_idx, node))
|
|
.enumerate() // new_idx
|
|
.map(|(a, (b, c))| (a, b, c))
|
|
.collect();
|
|
|
|
for &(new_idx, old_idx, _) in &new_start {
|
|
index_map[old_idx] = Some(new_idx);
|
|
}
|
|
|
|
self.arena = new_start
|
|
.into_iter()
|
|
.map(|(_, _, mut node)| {
|
|
if let Some(parent) = node.parent.as_mut() {
|
|
if let Some(new_parent) = index_map[*parent] {
|
|
*parent = new_parent;
|
|
} else {
|
|
// make sure we don't have dangling parents
|
|
node.parent = None;
|
|
}
|
|
}
|
|
|
|
for c in node.children.as_mut_slice() {
|
|
debug_assert!(
|
|
index_map.get(*c).unwrap().is_some(),
|
|
"index_map should contain the child's index"
|
|
);
|
|
*c = unsafe { index_map.get_unchecked(*c).unwrap_unchecked() };
|
|
}
|
|
|
|
node
|
|
})
|
|
.collect();
|
|
|
|
self.current_root = index_map[root];
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Black root with no parent, zeroed values and exactly the given `children`,
    /// built field by field so tests control the links directly.
    fn root_with_children(children: Vec<usize>) -> Move {
        Move {
            i: 0,
            j: 0,
            board: Board::new(),
            winner: Winner::None,
            parent: None,
            children,
            value: 0,
            self_value: 0,
            color: Piece::Black,
            lazy_children: false,
        }
    }

    /// Non-lazy White move at (0, 0) on an empty board, for a Black agent.
    fn white_move(parent: Option<usize>) -> Move {
        Move::new(0, 0, Board::new(), Piece::White, false, Piece::Black, parent)
    }

    /// Parentless move at (1234, 1234): unreachable from any root and easy to spot.
    fn unrelated_dummy() -> Move {
        Move::new(
            1234,
            1234,
            Board::new(),
            Piece::White,
            false,
            Piece::Black,
            None,
        )
    }

    #[test]
    fn prune_tree_test() {
        let mut futm = FutureMoves::new(Piece::Black, 0, 0);
        futm.arena.push(root_with_children(vec![1, 3, 4]));
        futm.update_root_idx_raw(0);

        futm.arena.push(white_move(Some(0))); // 1: child of the root
        futm.arena.push(unrelated_dummy()); // 2: must be pruned
        futm.arena.push(white_move(Some(0))); // 3: child of the root
        futm.arena.push(white_move(Some(0))); // 4: child of the root
        assert_eq!(futm.arena_len(), 5);

        futm.refocus_tree();
        assert_eq!(futm.arena_len(), 4);
        assert_eq!(futm.arena[0].children.len(), 3);
        assert_ne!(futm.arena[2].i, 1234, "dummy value still exists");
    }

    #[test]
    fn expand_layer_test() {
        let mut futm = FutureMoves::new(Piece::Black, 1, 1);
        let start = Move::new(
            0,
            0,
            Board::new().starting_pos(),
            Piece::Black,
            false,
            Piece::Black,
            None,
        );
        futm.arena.push(start);
        futm.update_root_idx_raw(0);

        futm.extend_layers();
        assert_eq!(futm.arena_len(), 5);

        // descend into the first child and drop everything else
        futm.update_root_idx_raw(1);
        futm.refocus_tree();
        assert_eq!(futm.arena_len(), 1);

        // the surviving root must have been remapped to index 0
        assert_eq!(futm.current_root, Some(0));

        futm.extend_layers();
        assert!(
            futm.arena_len() > 1,
            "extend_layer didn't grow arena after refocus"
        );
    }

    #[test]
    fn depth_of_test() {
        let mut futm = FutureMoves::new(Piece::Black, 0, 0);
        futm.arena.push(root_with_children(vec![1]));
        futm.update_root_idx_raw(0);

        // 1: child of the root and parent of node 3
        futm.arena.push(white_move(Some(0)));
        futm.arena[1].parent = Some(0);
        futm.arena[1].children = vec![3];

        // 2: unrelated node
        futm.arena.push(unrelated_dummy());

        // 3: re-parented under node 1, making it a grandchild of the root
        futm.arena.push(white_move(Some(0)));
        futm.arena[3].parent = Some(1);

        // 4: plain child of the root
        futm.arena.push(white_move(Some(0)));

        assert_eq!(futm.depth_of(3), 2);
    }
}
|