Skip to content

Commit 25e134d

Browse files
committed
Optimized opening db lookup
dded transposition table
1 parent 12285b2 commit 25e134d

File tree

10 files changed

+123
-64
lines changed

10 files changed

+123
-64
lines changed

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@ clap = { version = "*", features = ["derive"] }
2626
flate2 = { version = "1.0", features = ["rust_backend"] }
2727
once_cell = "1.21.3"
2828
tar = "0.4.44"
29-
rand = "*"
29+
rand = "*"
30+
lru = "*"

notebooks/benchmark.ipynb

Lines changed: 51 additions & 27 deletions
Large diffs are not rendered by default.

src/api/get/get_eval.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
use axum::{ extract::{ State, Json }, response::IntoResponse, http::StatusCode };
22
use serde::{ Deserialize, Serialize };
3-
use std::{ str::FromStr, collections::HashMap, sync::Arc };
3+
use std::{ collections::HashMap, str::FromStr, sync::{ Arc, Mutex } };
44
use chess::Board;
5-
use crate::bot::{ algorithm::root::search, include::types::{ EngineState, ServerState } };
5+
use crate::bot::{
6+
algorithm::root::search,
7+
include::types::{ EngineState, ServerState, TT_TABLE_SIZE },
8+
};
9+
use lru::LruCache;
10+
use std::num::NonZeroUsize;
611

712
#[derive(Debug, Deserialize)]
813
pub struct EvalRequest {
@@ -49,12 +54,15 @@ pub async fn eval_position_handler(
4954
}
5055
}
5156

57+
let capacity = NonZeroUsize::new(TT_TABLE_SIZE).unwrap();
58+
let transposition_table = Arc::new(Mutex::new(LruCache::new(capacity)));
5259
let mut engine = EngineState {
5360
game_id: "eval_temp".to_string(),
5461
current_board,
5562
history,
5663
statistics: HashMap::new(),
5764
global_map: Arc::clone(&state.global_map),
65+
transposition_table,
5866
};
5967

6068
let board = engine.current_board.clone();

src/api/post/add_game.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
use axum::{ extract::State, response::IntoResponse, Json, http::StatusCode };
2+
use lru::LruCache;
23
use serde::{ Deserialize, Serialize };
3-
use std::{ collections::HashMap, str::FromStr, sync::Arc };
4-
use crate::bot::include::types::{ EngineState, ServerState };
4+
use std::{ collections::HashMap, num::NonZeroUsize, str::FromStr, sync::{ Arc, Mutex } };
5+
use crate::bot::include::types::{ EngineState, ServerState, TT_TABLE_SIZE };
56
use chess::Board;
67

78
#[derive(Debug, Deserialize)]
@@ -50,12 +51,15 @@ pub async fn new_game_handler(
5051
}
5152
}
5253

54+
let capacity = NonZeroUsize::new(TT_TABLE_SIZE).unwrap();
55+
let transposition_table = Arc::new(Mutex::new(LruCache::new(capacity)));
5356
let engine = EngineState {
5457
game_id: payload.game_id.clone(),
5558
current_board,
5659
history,
5760
statistics: HashMap::new(),
5861
global_map: Arc::clone(&state.global_map),
62+
transposition_table,
5963
};
6064

6165
state.engines.insert(payload.game_id.clone(), engine);

src/bot/algorithm/ab.rs

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,25 +52,38 @@ pub fn alpha_beta(
5252
}
5353

5454
*nodes += 1;
55-
56-
// Track maximum depth reached
5755
*max_depth_reached = (*max_depth_reached).max(current_depth);
5856

5957
let board_hash = board.get_hash();
6058
let repetition_count = engine_state.history.get(&board_hash).copied().unwrap_or(0);
6159
let prioritized_moves = get_prioritized_moves(board, is_noisy);
6260

61+
{
62+
let mut tt = engine_state.transposition_table.lock().unwrap();
63+
if let Some(cached_eval) = tt.get(&(board_hash, depth)) {
64+
return (None, *cached_eval);
65+
}
66+
}
67+
6368
match board.status() {
6469
chess::BoardStatus::Checkmate => {
6570
let eval = evaluate_board(board);
6671
let depth_weight = 20 - current_depth.min(20);
6772
let score = (eval * ((depth_weight as i32) + 1)).clamp(-500_000, 500_000);
73+
74+
// ✅ Save to TT
75+
engine_state.transposition_table.lock().unwrap().put((board_hash, depth), score);
76+
6877
return (None, score);
6978
}
7079
chess::BoardStatus::Stalemate => {
80+
engine_state.transposition_table.lock().unwrap().put((board_hash, depth), 0);
81+
7182
return (None, 0);
7283
}
7384
_ if repetition_count >= 3 => {
85+
engine_state.transposition_table.lock().unwrap().put((board_hash, depth), 0);
86+
7487
return (None, 0);
7588
}
7689
_ if !is_noisy && depth == 0 => {
@@ -82,14 +95,16 @@ pub fn alpha_beta(
8295
nodes,
8396
deadline,
8497
engine_state,
85-
current_depth-1,
98+
current_depth - 1,
8699
true,
87100
current_depth + 1,
88101
max_depth_reached
89102
);
90103
}
91104
_ if is_noisy && (prioritized_moves.is_empty() || depth == 0) => {
92105
let eval = evaluate_board(board);
106+
engine_state.transposition_table.lock().unwrap().put((board_hash, depth), eval);
107+
93108
return (None, eval);
94109
}
95110
_ => {}
@@ -148,5 +163,7 @@ pub fn alpha_beta(
148163
}
149164
}
150165

166+
engine_state.transposition_table.lock().unwrap().put((board_hash, depth), best_eval);
167+
151168
(best_move, best_eval)
152169
}

src/bot/algorithm/root.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ pub fn search(
4848
engine_state,
4949
depth,
5050
false,
51-
1,
51+
0,
5252
&mut max_depth
5353
);
5454

src/bot/include/map.rs

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,56 +2,49 @@ use std::{ fs, io::{ self, Write }, path::Path, sync::Arc };
22

33
use flate2::read::GzDecoder;
44
use once_cell::sync::Lazy;
5-
use serde_json::Value;
65
use tar::Archive;
76

8-
use crate::bot::include::types::GlobalMap;
7+
use crate::bot::include::types::{ GlobalMap, OpeningBook };
98

109
const COMPRESSED_OPENING_DB: &[u8] = include_bytes!("../../data/openings.tar.gz");
1110

12-
fn read_opening_db() -> Result<Value, io::Error> {
11+
pub fn read_opening_db() -> Result<OpeningBook, io::Error> {
1312
let output_dir = Path::new("./db");
1413
let compressed_path = output_dir.join("openings.tar.gz");
1514
let file_path = output_dir.join("openingDB.json");
1615

17-
// If file doesn't exist, extract from embedded archive
1816
if !file_path.exists() {
1917
println!("OpeningDB not found, extracting...");
2018

2119
fs::create_dir_all(output_dir)?;
2220

23-
// Write embedded tar.gz to disk
2421
{
2522
let mut file = fs::File::create(&compressed_path)?;
2623
file.write_all(COMPRESSED_OPENING_DB)?;
2724
}
2825

29-
// Extract tar.gz
3026
let tar_file = fs::File::open(&compressed_path)?;
3127
let tar = GzDecoder::new(tar_file);
3228
let mut archive = Archive::new(tar);
3329
archive.unpack(output_dir)?;
3430

35-
// Clean up
3631
fs::remove_file(&compressed_path)?;
3732
println!("OpeningDB extracted to {:?}", file_path);
3833
}
3934

40-
// Read and parse JSON
4135
let file_content = fs::read_to_string(&file_path)?;
42-
let json_data: Value = serde_json
36+
let db: OpeningBook = serde_json
4337
::from_str(&file_content)
4438
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
4539

46-
Ok(json_data)
40+
Ok(db)
4741
}
48-
49-
pub const OPENING_DB: Lazy<Arc<Value>> = Lazy::new(|| {
42+
pub static OPENING_DB: Lazy<Arc<OpeningBook>> = Lazy::new(|| {
5043
Arc::new(read_opening_db().expect("Failed to load opening DB"))
5144
});
5245

5346
impl GlobalMap {
54-
pub fn opening_db() -> Arc<Value> {
47+
pub fn opening_db() -> Arc<OpeningBook> {
5548
Arc::clone(&OPENING_DB)
5649
}
5750

src/bot/include/types.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
use chess::Board;
2+
use serde::Deserialize;
23
use std::collections::HashMap;
3-
use std::sync::Arc;
4+
use std::sync::{ Arc, Mutex };
45
use dashmap::DashMap;
6+
use lru::LruCache;
57

68
#[derive(Debug, Clone)]
79
pub struct Statistics {
@@ -16,6 +18,7 @@ pub struct EngineState {
1618
pub history: HashMap<u64, u32>,
1719
pub statistics: HashMap<u64, Statistics>,
1820
pub global_map: Arc<GlobalMap>,
21+
pub transposition_table: Arc<Mutex<LruCache<(u64, u8), i32>>>,
1922
}
2023

2124
#[derive(Debug)]
@@ -35,3 +38,9 @@ pub enum SpecialMove {
3538
Promotion,
3639
EnPassant,
3740
}
41+
42+
pub const TT_TABLE_SIZE: usize = 100_1000;
43+
44+
#[derive(Debug, Clone, Deserialize)]
45+
pub struct OpeningEntry(pub String, pub u32);
46+
pub type OpeningBook = HashMap<u64, Vec<OpeningEntry>>;

src/bot/util/board.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,13 @@ impl BoardExt for Board {
124124
tactical_bonus += 40;
125125
}
126126

127-
return tactical_bonus + capture_value_sum / 10;
127+
if let Some(piece) = self.piece_on(mv.get_source()) {
128+
if piece == Piece::Pawn {
129+
tactical_bonus += 10;
130+
}
131+
}
132+
133+
tactical_bonus + capture_value_sum / 10
128134
}
129135

130136
fn halfmove_clock(&self) -> u32 {

src/bot/util/lookup.rs

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,20 @@ use chess::{ Board, ChessMove };
22
use rand::seq::SliceRandom;
33
use rand::thread_rng;
44

5-
use crate::bot::include::types::GlobalMap;
5+
use crate::bot::include::map::OPENING_DB;
6+
use crate::bot::include::types::OpeningEntry;
67
use crate::bot::util::moves::parse_uci_move;
78

89
pub fn lookup_opening_db(board: &Board) -> Option<ChessMove> {
9-
let board_hash_str = board.get_hash().to_string();
10-
if let Some(opening_db) = GlobalMap::opening_db().as_object() {
11-
if let Some(entry_array) = opening_db.get(&board_hash_str).and_then(|v| v.as_array()) {
12-
let mut rng = thread_rng();
13-
if let Some(random_entry) = entry_array.choose(&mut rng) {
14-
if let Some(uci_str) = random_entry.get(0).and_then(|v| v.as_str()) {
15-
if let Some(chess_move) = parse_uci_move(uci_str, board) {
16-
println!("{}", chess_move);
17-
return Some(chess_move);
18-
}
19-
}
20-
}
10+
let board_hash = board.get_hash();
11+
let db = &OPENING_DB;
12+
13+
if let Some(entries) = db.get(&board_hash) {
14+
let mut rng = thread_rng();
15+
if let Some(OpeningEntry(uci_str, _weight)) = entries.choose(&mut rng) {
16+
return parse_uci_move(uci_str, board);
2117
}
2218
}
19+
2320
None
2421
}

0 commit comments

Comments
 (0)