translated the algorithm code from cpp to rust

This commit is contained in:
cqql 2024-10-19 01:19:38 +02:00
commit da18a7e6a7
5 changed files with 251 additions and 0 deletions

1
.gitignore vendored Normal file
View file

@ -0,0 +1 @@
/target

7
Cargo.lock generated Normal file
View file

@ -0,0 +1,7 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "connect4"
version = "0.1.0"

6
Cargo.toml Normal file
View file

@ -0,0 +1,6 @@
[package]
name = "connect4"
version = "0.1.0"
edition = "2021"
[dependencies]

205
src/connect4.rs Normal file
View file

@ -0,0 +1,205 @@
// code based fully on https://github.com/balkarjun/ConnectFourAI/blob/main/main.cpp
// rewritten from cpp to rust
/*
6 13 20 27 34 41 48 55 62
---------------------
| 5 12 19 26 33 40 47 | 54 61
| 4 11 18 25 32 39 46 | 53 60
| 3 10 17 24 31 38 45 | 52 59
| 2 9 16 23 30 37 44 | 51 58
| 1 8 15 22 29 36 43 | 50 57
| 0 7 14 21 28 35 42 | 49 56 63
---------------------
Bitboard Representation
*/
const NROWS: usize = 6;
const NCOLS: usize = 7;
// bit indices for the top padding row
const PAD_TOP: [i64; NCOLS] = [6, 13, 20, 27, 34, 41, 48];
// simple heuristic, where a position's value is the number of possible 4-in-a-rows from that position
const SCOREMAP: [i64; NCOLS * (NROWS + 1)] = [
3, 4, 5, 5, 4, 3, 0, 4, 6, 8, 8, 6, 4, 0, 5, 8, 11, 11, 8, 5, 0, 7, 9, 13, 13, 9, 7, 0, 5, 8,
11, 11, 8, 5, 0, 4, 6, 8, 8, 6, 4, 0, 3, 4, 5, 5, 4, 3, 0,
];
pub struct GameState {
pub bitboards: [i64; 2],
pub counter: i64,
pub heights: [i64; NCOLS],
}
impl GameState {
pub fn new() -> Self {
let bitboards: [i64; 2] = [0, 0];
// number of moves made
let counter: i64 = 0;
// bit indices where next move should be made
let heights: [i64; NCOLS] = [0, 7, 14, 21, 28, 35, 42];
GameState {
bitboards,
counter,
heights,
}
}
pub fn move_is_valid(&self, icol: usize) -> bool {
return self.heights[icol] < PAD_TOP[icol];
}
/// assumes that the move is valid
pub fn make_move(&mut self, icol: usize) {
// flip the appropriate bit (0 -> 1)
self.bitboards[(self.counter & 1) as usize] ^= 1 << self.heights[icol];
self.counter += 1;
self.heights[icol] += 1;
}
/// assumes it is called right after a move is made
pub fn undo_move(&mut self, icol: usize) {
self.counter -= 1;
self.heights[icol] -= 1;
// flip the appropriate bit (1 -> 0)
self.bitboards[(self.counter & 1) as usize] ^= 1 << self.heights[icol];
}
pub fn check_win(&self) -> i64 {
for idx in 0..=1 {
let board: i64 = self.bitboards[idx];
if (board & (board >> 7) & (board >> 14) & (board >> 21)) != 0 {
return idx as i64;
} // horizontal
if (board & (board >> 1) & (board >> 2) & (board >> 3)) != 0 {
return idx as i64;
} // vertical
if (board & (board >> 8) & (board >> 16) & (board >> 24)) != 0 {
return idx as i64;
} // positive diagonal
if (board & (board >> 6) & (board >> 12) & (board >> 18)) != 0 {
return idx as i64;
} // negative diagonal
}
return -1;
}
/// returns -1 if not in an end state, or an appropriate score
pub fn end_state_reached(&self) -> i64 {
let winner = self.check_win();
if winner == 0 {
return 100000;
} // white won
if winner == 1 {
return -100000;
} // black won
if self.counter >= 42 {
return 0;
} // tie
return -1;
}
/// outputs a score based on the current state of the board
/// assumes that the board is not in an end state
pub fn evaluate(&self) -> i64 {
// for each cell, add score for white and subtract score for black
let mut score = 0;
for idx in 0..48 {
score += SCOREMAP[idx]
* (((self.bitboards[0] >> idx) & 1) - ((self.bitboards[1] >> idx) & 1));
}
return score;
}
pub fn display(&self) {
for irow in (0..NROWS).rev() {
print!("\n ");
for icol in 0..NCOLS {
let idx = irow + 7 * icol;
if (self.bitboards[0] >> idx) & 1 != 0 {
print!("");
} else if (self.bitboards[1] >> idx) & 1 != 0 {
print!("");
} else {
print!("· ");
}
}
}
print!("\n 1 2 3 4 5 6 7\n\n");
}
}
pub fn minimax(board: &mut GameState, mut alpha: i64, beta: i64, depth: i64) -> i64 {
let score = board.end_state_reached();
if score != -1 {
return score;
}
if depth == 0 {
return board.evaluate();
}
let sign: i64 = 1 - 2 * (board.counter & 1); // 1 if white, -1 if black
for icol in 0..NCOLS {
// skip filled columns
if !board.move_is_valid(icol) {
continue;
};
board.make_move(icol);
let score = minimax(board, beta, alpha, depth - 1);
board.undo_move(icol);
if sign * score > sign * alpha {
alpha = score
};
if sign * alpha >= sign * beta {
break;
};
}
return alpha;
}
pub fn get_minimax_move(board: &mut GameState, depth: i64) -> usize {
// +inf for black (1), -inf for white (0)
let mut alpha: i64 = if board.counter & 1 != 0 {
1000000
} else {
-1000000
};
let beta: i64 = -alpha;
let mut best_move: usize = 0;
let sign = 1 - 2 * (board.counter & 1); // 1 if white, -1 if black
for icol in 0..NCOLS {
// skip filled columns
if !board.move_is_valid(icol) {
continue;
};
board.make_move(icol);
let score = minimax(board, beta, alpha, depth - 1);
board.undo_move(icol);
if sign * score > sign * alpha {
alpha = score;
best_move = icol;
}
if sign * alpha >= sign * beta {
break;
};
}
return best_move;
}

32
src/main.rs Normal file
View file

@ -0,0 +1,32 @@
mod connect4;
fn main() {
let mut board = connect4::GameState::new();
let depth = 7;
let mut end_score: i64 = -1;
while end_score == -1 {
println!("Move {}:", board.counter);
board.display();
let best_move =
connect4::get_minimax_move(
&mut board,
depth
);
board.make_move(best_move);
end_score = board.end_state_reached();
}
println!("Move {}:", board.counter);
board.display();
if end_score == 0 {
println!("It's a Tie");
} else if end_score > 0 {
println!("White (●) Won!");
} else {
println!("Black (○) Won!");
}
}