alphabeta_search: replace recursion with explicit stack

This commit is contained in:
2025-12-21 00:12:52 +01:00
parent 97b6934024
commit ef33b23f70

433
engine.h
View File

@@ -7,7 +7,7 @@
#ifndef NDEBUG #ifndef NDEBUG
#define assert(expr) \ #define assert(expr) \
((expr) ? 0 : (__builtin_trap(), 0)) ((expr) ? 0 : (__builtin_trap(), 0))
#else #else
#define assert(...) (void)0 #define assert(...) (void)0
#endif #endif
@@ -1520,7 +1520,7 @@ static inline void tt_insert(struct tt* tt, uint64_t hash, struct search_option
enum move_result { enum move_result {
MR_NORMAL, MR_NORMAL,
MR_CHECK, MR_CHECK,
MR_REPEATS, /* this board state has been observed before */ MR_REPEATS, /* this board state has been observed before */
MR_STALEMATE, MR_STALEMATE,
MR_CHECKMATE, MR_CHECKMATE,
}; };
@@ -1666,7 +1666,7 @@ static enum move_result board_move(struct pos* restrict pos,
pos->halfmoves += 1; pos->halfmoves += 1;
assert(hist->length < 64); assert(hist->length < 64);
int repetitions = 0; int repetitions = 0;
for (size_t i = 0; i < hist->length; ++i) { for (size_t i = 0; i < hist->length; ++i) {
_Static_assert(sizeof *pos == sizeof hist->items[i]); _Static_assert(sizeof *pos == sizeof hist->items[i]);
if (!my_memcmp(&hist->items[i].pieces, &pos->pieces, sizeof pos->pieces) if (!my_memcmp(&hist->items[i].pieces, &pos->pieces, sizeof pos->pieces)
@@ -1674,26 +1674,26 @@ static enum move_result board_move(struct pos* restrict pos,
&& hist->items[i].player == pos->player && hist->items[i].player == pos->player
&& hist->items[i].ep_targets == pos->ep_targets) && hist->items[i].ep_targets == pos->ep_targets)
{ {
repetitions += 1; repetitions += 1;
} }
} }
hist->items[hist->length++] = *pos; hist->items[hist->length++] = *pos;
if (repetitions >= 3 || pos->halfmoves > 50) { if (repetitions >= 3 || pos->halfmoves > 50) {
return MR_STALEMATE;
}
else if (repetitions > 0) {
return MR_REPEATS;
}
else if (pos->halfmoves > 50) {
return MR_STALEMATE; return MR_STALEMATE;
} }
else if (attacks_to(pos, pos->pieces[them][PIECE_KING], 0ULL, 0ULL) else if (repetitions > 0) {
& ~pos->occupied[them]) { return MR_REPEATS;
}
else if (pos->halfmoves > 50) {
return MR_STALEMATE;
}
else if (attacks_to(pos, pos->pieces[them][PIECE_KING], 0ULL, 0ULL)
& ~pos->occupied[them]) {
return MR_CHECK; return MR_CHECK;
} }
else { else {
return MR_NORMAL; return MR_NORMAL;
} }
} }
@@ -1814,7 +1814,7 @@ double quiesce(struct pos const* pos,
size_t move_count = 0; size_t move_count = 0;
struct move moves[MOVE_MAX]; struct move moves[MOVE_MAX];
all_moves(pos, us, &move_count, moves); all_moves(pos, us, &move_count, moves);
if (move_count == 0) { if (move_count == 0) {
@@ -1833,7 +1833,7 @@ double quiesce(struct pos const* pos,
} }
struct pos poscpy = *pos; struct pos poscpy = *pos;
enum piece mailbox_cpy[SQ_INDEX_COUNT]; enum piece mailbox_cpy[SQ_INDEX_COUNT];
my_memcpy(mailbox_cpy, mailbox, sizeof (enum piece[SQ_INDEX_COUNT])); my_memcpy(mailbox_cpy, mailbox, sizeof (enum piece[SQ_INDEX_COUNT]));
@@ -1859,6 +1859,67 @@ double quiesce(struct pos const* pos,
return highscore; return highscore;
} }
/* Stage of one explicit-stack search frame inside alphabeta_search().
 *
 * ST_WAIT_CHILD does double duty in the driver loop: it is set both when a
 * frame has pushed a child and is waiting for its score, and when a frame's
 * own `result` is final and the frame is ready to be popped. */
enum search_stage {
    ST_INIT = 0,
    ST_LOOP = 1,
    ST_WAIT_CHILD = 2,
};

/* One node of the alpha-beta search, kept on an explicit stack instead of the
 * C call stack. Each frame carries the state that used to live in the
 * arguments and locals of one recursive alphabeta_search() invocation. */
struct ab_frame {
    enum search_stage stage;            /* where the driver loop resumes this frame */

    /* --- inputs of the node (former function arguments) --- */
    struct pos pos;                     /* position searched by this frame (own copy) */
    enum piece mailbox[SQ_INDEX_COUNT]; /* mailbox board belonging to `pos` (own copy) */
    enum player us;                     /* side to move at this node */
    int8_t depth;                       /* remaining search depth */
    uint64_t mattr_filter;              /* if non-zero, only moves whose attr intersects it are searched */
    double alpha;                       /* current lower bound; raised as moves are scored */
    double beta;                        /* current upper bound */

    /* --- working state (former locals) --- */
    double alpha_orig;                  /* alpha as it was on entry, before TT bounds were applied;
                                           used to decide between TT_EXACT and TT_UPPER */
    struct search_option tte;           /* transposition-table entry fetched for pos.hash */
    struct move moves[MOVE_MAX];        /* moves generated by all_moves() for this node */
    size_t move_count;                  /* number of entries in `moves` still to be tried
                                           (presumably decremented by moves_linear_search) */
    double best_score;                  /* best score seen so far among searched moves */
    struct move best_move;              /* move that produced best_score */

    /* --- bookkeeping across a child search --- */
    size_t old_hist_length;             /* hist->length before board_move(); restored afterwards */
    struct move pending_move;           /* move whose child frame is currently being searched */

    struct search_option result;        /* final result of the node; result.init marks it as set */
};
static inline void ab_update_parent_after_score(struct ab_frame *parent,
struct tt *tt,
struct move m,
double score)
{
if (score > parent->best_score) {
parent->best_score = score;
parent->best_move = m;
}
if (score > parent->alpha) {
parent->alpha = score;
}
if (parent->alpha >= parent->beta) {
struct search_option out = {
.score = parent->alpha,
.move = parent->best_move,
.depth = parent->depth,
.hash = parent->pos.hash,
.init = true,
.flag = TT_LOWER,
};
tt_insert(tt, parent->pos.hash, out);
parent->result = out;
/* mark parent as complete so it will be popped in st_wait_child. */
parent->stage = ST_WAIT_CHILD;
}
}
static static
struct search_option alphabeta_search(struct pos const* pos, struct search_option alphabeta_search(struct pos const* pos,
@@ -1871,139 +1932,231 @@ struct search_option alphabeta_search(struct pos const* pos,
double alpha, double alpha,
double beta) double beta)
{ {
if (depth <= 0) { struct {
return (struct search_option) { struct ab_frame stack[50];
.score = quiesce(pos, mailbox, us, alpha, beta, 0), size_t len;
/*.score = board_score_heuristic(pos),*/ } ab_stack = {0};
.move = (struct move){0},
.depth = 0,
.hash = pos->hash,
.init = true,
.flag = TT_EXACT,
};
}
double const alpha_orig = alpha; #define STACK_PUSH(s, x) \
do { \
assert(s.len < sizeof s.stack / sizeof s.stack[0]); \
(s.stack)[(s.len)++] = (x); \
} while (0)
#define STACK_TOP(s) &(s.stack)[(s.len) - 1]
#define STACK_POP(s) (assert(s.len > 0), s.stack[--(s.len)]);
#define STACK_EMPTY(s) (s.len == 0)
struct search_option tte = tt_get(tt, pos->hash); struct ab_frame root = {
.stage = ST_INIT,
.pos = *pos,
.us = us,
.depth = depth,
.mattr_filter = mattr_filter,
.alpha = alpha,
.beta = beta
};
my_memcpy(root.mailbox, mailbox, sizeof root.mailbox);
if (tte.init && tte.hash == pos->hash && tte.depth >= depth) { STACK_PUSH(ab_stack, root);
if (tte.flag == TT_EXACT) {
return tte;
} else if (tte.flag == TT_LOWER) {
if (tte.score > alpha) alpha = tte.score;
} else if (tte.flag == TT_UPPER) {
if (tte.score < beta) beta = tte.score;
}
if (alpha >= beta) {
return tte;
}
}
struct move moves[MOVE_MAX];
size_t move_count = 0ULL;
all_moves(pos, us, &move_count, moves);
if (move_count == 0) { while (1) {
/* TODO: reusing mate distances correctly needs ply normalization */ struct ab_frame *fr = STACK_TOP(ab_stack);
double score = 0;
if (attacks_to(pos, pos->pieces[us][PIECE_KING], 0ULL, 0ULL) != 0ULL) {
score = -(999.0 + (double)depth);
}
return (struct search_option) {
.score = score,
.move = (struct move){0},
.depth = depth,
.hash = pos->hash,
.init = true,
.flag = TT_EXACT,
};
}
for (size_t i = 0; i < move_count; ++i) { switch (fr->stage) {
move_compute_appeal(&moves[i], pos, us, mailbox);
}
/* if TT had a best move for this position, search it first. */ case ST_INIT: {
if (tte.init && tte.hash == pos->hash) { if (fr->depth <= 0) {
for (size_t i = 0; i < move_count; ++i) { fr->result = (struct search_option) {
if (moves[i].from == tte.move.from && moves[i].to == tte.move.to) { .score = quiesce(&fr->pos, fr->mailbox, fr->us,
moves[i].appeal = APPEAL_MAX; fr->alpha, fr->beta, 0),
.move = (struct move){0},
.depth = 0,
.hash = fr->pos.hash,
.init = true,
.flag = TT_EXACT,
};
fr->stage = ST_WAIT_CHILD;
break; break;
} }
fr->alpha_orig = fr->alpha;
fr->tte = tt_get(tt, fr->pos.hash);
if (fr->tte.init && fr->tte.hash == fr->pos.hash && fr->tte.depth >= fr->depth) {
if (fr->tte.flag == TT_EXACT) {
fr->result = fr->tte;
fr->stage = ST_WAIT_CHILD;
break;
} else if (fr->tte.flag == TT_LOWER) {
if (fr->tte.score > fr->alpha) {
fr->alpha = fr->tte.score;
}
} else if (fr->tte.flag == TT_UPPER) {
if (fr->tte.score < fr->beta) {
fr->beta = fr->tte.score;
}
}
if (fr->alpha >= fr->beta) {
fr->result = fr->tte;
fr->stage = ST_WAIT_CHILD;
break;
}
}
fr->move_count = 0;
all_moves(&fr->pos, fr->us, &fr->move_count, fr->moves);
/* checkmate or stalemate */
if (fr->move_count == 0) {
double score = 0.0;
if (attacks_to(&fr->pos,
fr->pos.pieces[fr->us][PIECE_KING],
0ULL, 0ULL) != 0ULL) {
score = -(999.0 + (double)fr->depth);
}
fr->result = (struct search_option) {
.score = score,
.move = (struct move){0},
.depth = fr->depth,
.hash = fr->pos.hash,
.init = true,
.flag = TT_EXACT,
};
fr->stage = ST_WAIT_CHILD;
break;
}
for (size_t i = 0; i < fr->move_count; ++i) {
move_compute_appeal(&fr->moves[i], &fr->pos, fr->us, fr->mailbox);
}
/* put existing TT entry first */
if (fr->tte.init && fr->tte.hash == fr->pos.hash) {
for (size_t i = 0; i < fr->move_count; ++i) {
if (fr->moves[i].from == fr->tte.move.from &&
fr->moves[i].to == fr->tte.move.to) {
fr->moves[i].appeal = APPEAL_MAX;
break;
}
}
}
fr->best_score = -1e300;
fr->best_move = fr->moves[0];
fr->stage = ST_LOOP;
} break;
case ST_LOOP: {
if (fr->result.init) {
fr->stage = ST_WAIT_CHILD;
break;
}
if (fr->move_count == 0) {
enum tt_flag flag = TT_EXACT;
if (fr->best_score <= fr->alpha_orig) flag = TT_UPPER;
fr->result = (struct search_option) {
.score = fr->best_score,
.move = fr->best_move,
.depth = fr->depth,
.hash = fr->pos.hash,
.init = true,
.flag = flag,
};
tt_insert(tt, fr->pos.hash, fr->result);
fr->stage = ST_WAIT_CHILD;
break;
}
struct move m = moves_linear_search(fr->moves, &fr->move_count);
if (fr->mattr_filter && !(m.attr & fr->mattr_filter)) {
break;
}
fr->old_hist_length = hist->length;
struct pos child_pos = fr->pos;
enum piece child_mailbox[SQ_INDEX_COUNT];
my_memcpy(child_mailbox, fr->mailbox, sizeof child_mailbox);
enum move_result const r = board_move(&child_pos, hist, child_mailbox, m);
if (r == MR_STALEMATE || r == MR_REPEATS) {
hist->length = fr->old_hist_length;
ab_update_parent_after_score(fr, tt, m, 0.0);
break;
}
if (fr->depth - 1 <= 0) {
double score = -quiesce(&child_pos,
child_mailbox,
opposite_player(fr->us),
-fr->beta,
-fr->alpha,
0);
hist->length = fr->old_hist_length;
ab_update_parent_after_score(fr, tt, m, score);
break;
}
fr->pending_move = m;
fr->stage = ST_WAIT_CHILD;
struct ab_frame child = {0};
child.stage = ST_INIT;
child.pos = child_pos;
my_memcpy(child.mailbox, child_mailbox, sizeof child.mailbox);
child.us = opposite_player(fr->us);
child.depth = fr->depth - 1;
child.mattr_filter = fr->mattr_filter;
child.alpha = -fr->beta;
child.beta = -fr->alpha;
STACK_PUSH(ab_stack, child);
} break;
case ST_WAIT_CHILD: {
struct search_option out = fr->result;
STACK_POP(ab_stack);
if (STACK_EMPTY(ab_stack)) {
return out;
}
struct ab_frame *parent = STACK_TOP(ab_stack);
if (parent->stage == ST_WAIT_CHILD) {
/* parent is waiting on a child (this frame), propagate */
double score = -out.score;
struct move m = parent->pending_move;
hist->length = parent->old_hist_length;
/* resume parent move loop. */
parent->stage = ST_LOOP;
ab_update_parent_after_score(parent, tt, m, score);
} else {
/* parent wasn't waiting: treat as a completed frame result being
propagated through a completion chain. */
parent->result = out;
parent->stage = ST_WAIT_CHILD;
}
} break;
default: {
/* unreachable */
assert(0);
__builtin_unreachable();
} break;
} }
} }
double best_score = -1e300;
struct move best_move = moves[0];
while (move_count) {
struct move m = moves_linear_search(moves, &move_count);
if (mattr_filter && !(m.attr & mattr_filter)) {
continue;
}
/* TODO: make lean apply/undo mechanism instead of copying */
struct pos poscpy = *pos;
enum piece mailbox_cpy[SQ_INDEX_COUNT];
my_memcpy(mailbox_cpy, mailbox, sizeof mailbox_cpy);
size_t old_hist_length = hist->length;
enum move_result const r = board_move(&poscpy, hist, mailbox_cpy, m);
double score;
if (r == MR_STALEMATE || r == MR_REPEATS) {
score = 0.0;
} else {
score = -alphabeta_search(&poscpy,
hist,
tt,
mailbox_cpy,
opposite_player(us),
depth - 1,
mattr_filter,
-beta,
-alpha).score;
}
hist->length = old_hist_length;
if (score > best_score) {
best_score = score;
best_move = m;
}
if (score > alpha) {
alpha = score;
}
if (alpha >= beta) {
struct search_option out = {
.score = alpha,
.move = best_move,
.depth = depth,
.hash = pos->hash,
.init = true,
.flag = TT_LOWER,
};
tt_insert(tt, pos->hash, out);
return out;
}
}
enum tt_flag flag = TT_EXACT;
if (best_score <= alpha_orig) flag = TT_UPPER;
struct search_option out = {
.score = best_score,
.move = best_move,
.depth = depth,
.hash = pos->hash,
.init = true,
.flag = flag,
};
tt_insert(tt, pos->hash, out);
return out;
} }
static struct search_result {struct move move; double score;} static struct search_result {struct move move; double score;}
@@ -2015,7 +2168,7 @@ search(struct board* b, enum player us, int8_t max_depth)
SYS_PROT_READ | SYS_PROT_WRITE, SYS_PROT_READ | SYS_PROT_WRITE,
SYS_MADV_RANDOM); SYS_MADV_RANDOM);
if (b->tt.entries == NULL) { if (b->tt.entries == NULL) {
__builtin_trap(); __builtin_trap();
} }
} }
#endif #endif