From 0faa8cdba9487923ed78a2375b1a20cdf87f9778 Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Tue, 22 Oct 2024 17:49:19 +0900 Subject: [PATCH 1/8] use int8 --- pgx/_src/games/chess.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pgx/_src/games/chess.py b/pgx/_src/games/chess.py index e16df4bf3..2d6c891db 100644 --- a/pgx/_src/games/chess.py +++ b/pgx/_src/games/chess.py @@ -25,7 +25,7 @@ # prepare precomputed values here (e.g., available moves, map to label, etc.) # index: a1: 0, a2: 1, ..., h8: 63 -INIT_BOARD = jnp.int32([4, 1, 0, 0, 0, 0, -1, -4, 2, 1, 0, 0, 0, 0, -1, -2, 3, 1, 0, 0, 0, 0, -1, -3, 5, 1, 0, 0, 0, 0, -1, -5, 6, 1, 0, 0, 0, 0, -1, -6, 3, 1, 0, 0, 0, 0, -1, -3, 2, 1, 0, 0, 0, 0, -1, -2, 4, 1, 0, 0, 0, 0, -1, -4]) # fmt: skip +INIT_BOARD = jnp.int8([4, 1, 0, 0, 0, 0, -1, -4, 2, 1, 0, 0, 0, 0, -1, -2, 3, 1, 0, 0, 0, 0, -1, -3, 5, 1, 0, 0, 0, 0, -1, -5, 6, 1, 0, 0, 0, 0, -1, -6, 3, 1, 0, 0, 0, 0, -1, -3, 2, 1, 0, 0, 0, 0, -1, -2, 4, 1, 0, 0, 0, 0, -1, -4]) # fmt: skip # 8 7 15 23 31 39 47 55 63 # 7 6 14 22 30 38 46 54 62 # 6 5 13 21 29 37 45 53 61 @@ -135,6 +135,7 @@ ZOBRIST_CASTLING = jax.random.randint(keys[2], shape=(4, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32) ZOBRIST_EN_PASSANT = jax.random.randint(keys[3], shape=(65, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32) INIT_ZOBRIST_HASH = jnp.uint32([1455170221, 1478960862]) +EMPTY, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING = [jnp.int8(i) for i in range(7)] # opponent: -1 * piece class GameState(NamedTuple): @@ -264,7 +265,7 @@ def _apply_move(state: GameState, a: Action) -> GameState: is_en_passant = (state.en_passant >= 0) & (piece == PAWN) & (state.en_passant == a.to) removed_pawn_pos = a.to - 1 state = state._replace( - board=state.board.at[removed_pawn_pos].set(lax.select(is_en_passant, EMPTY, state.board[removed_pawn_pos])) + board=state.board.at[removed_pawn_pos].set(lax.select(is_en_passant, EMPTY, jnp.int8(state.board[removed_pawn_pos]))) ) is_en_passant = (piece == PAWN) & (jnp.abs(a.to - a.from_) == 2) state = state._replace(en_passant=lax.select(is_en_passant, (a.to + a.from_) // 2, -1)) @@ -285,9 +286,9 @@ def _apply_move(state: GameState, a: Action) -> GameState: cond = jnp.bool_([[(a.from_ != 32) & (a.from_ != 0), (a.from_ != 32) & (a.from_ != 56)], [a.to != 7, a.to != 63]]) state = state._replace(castling_rights=state.castling_rights & cond) # promotion to queen - piece = lax.select((piece == PAWN) & (a.from_ % 8 == 6) & (a.underpromotion < 0), QUEEN, piece) + piece = lax.select((piece == PAWN) & (a.from_ % 8 == 6) & (a.underpromotion < 0), QUEEN, jnp.int8(piece)) # underpromotion - piece = lax.select(a.underpromotion < 0, piece, jnp.int32([ROOK, BISHOP, KNIGHT])[a.underpromotion]) + piece = lax.select(a.underpromotion < 0, piece, jnp.int8([ROOK, BISHOP, KNIGHT])[a.underpromotion]) # actually move state = state._replace(board=state.board.at[a.from_].set(EMPTY).at[a.to].set(piece)) # type: ignore return state From d75a100c64e77bceaab51a61c87cb6d4f379d632 Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Tue, 22 Oct 2024 17:50:27 +0900 Subject: [PATCH 2/8] . --- pgx/_src/games/chess.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgx/_src/games/chess.py b/pgx/_src/games/chess.py index 2d6c891db..e2f4e44f1 100644 --- a/pgx/_src/games/chess.py +++ b/pgx/_src/games/chess.py @@ -146,7 +146,7 @@ class GameState(NamedTuple): halfmove_count: Array = jnp.int32(0) # number of moves since the last piece capture or pawn move fullmove_count: Array = jnp.int32(1) # increase every black move hash_history: Array = jnp.zeros((MAX_TERMINATION_STEPS + 1, 2), dtype=jnp.uint32).at[0].set(INIT_ZOBRIST_HASH) - board_history: Array = jnp.zeros((8, 64), dtype=jnp.int32).at[0, :].set(INIT_BOARD) + board_history: Array = jnp.zeros((8, 64), dtype=jnp.int8).at[0, :].set(INIT_BOARD) legal_action_mask: Array = INIT_LEGAL_ACTION_MASK step_count: Array = jnp.int32(0) From 8942f04a9701e5174b19ab3303c92949cc7de9eb Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Tue, 22 Oct 2024 17:51:39 +0900 Subject: [PATCH 3/8] . --- pgx/_src/games/chess.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgx/_src/games/chess.py b/pgx/_src/games/chess.py index e2f4e44f1..8ba682f56 100644 --- a/pgx/_src/games/chess.py +++ b/pgx/_src/games/chess.py @@ -139,7 +139,7 @@ class GameState(NamedTuple): - color: Array = jnp.int32(0) # w: 0, b: 1 + color: Array = jnp.int8(0) # w: 0, b: 1 board: Array = INIT_BOARD # (64,) castling_rights: Array = jnp.ones([2, 2], dtype=jnp.bool_) # my queen, my king, opp queen, opp king en_passant: Array = jnp.int32(-1) From 5ed21cc8507d3d4543ef8069a99eab9a5d571489 Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Tue, 22 Oct 2024 18:34:20 +0900 Subject: [PATCH 4/8] enpassant --- pgx/_src/games/chess.py | 4 ++-- pgx/experimental/chess.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pgx/_src/games/chess.py b/pgx/_src/games/chess.py index 8ba682f56..9b1aeaef9 100644 --- a/pgx/_src/games/chess.py +++ b/pgx/_src/games/chess.py @@ -142,7 +142,7 @@ class GameState(NamedTuple): color: Array = jnp.int8(0) # w: 0, b: 1 board: Array = INIT_BOARD # (64,) castling_rights: Array = jnp.ones([2, 2], dtype=jnp.bool_) # my queen, my king, opp queen, opp king - en_passant: Array = jnp.int32(-1) + en_passant: Array = jnp.int8(-1) halfmove_count: Array = jnp.int32(0) # number of moves since the last piece capture or pawn move fullmove_count: Array = jnp.int32(1) # increase every black move hash_history: Array = jnp.zeros((MAX_TERMINATION_STEPS + 1, 2), dtype=jnp.uint32).at[0].set(INIT_ZOBRIST_HASH) @@ -268,7 +268,7 @@ def _apply_move(state: GameState, a: Action) -> GameState: board=state.board.at[removed_pawn_pos].set(lax.select(is_en_passant, EMPTY, jnp.int8(state.board[removed_pawn_pos]))) ) is_en_passant = (piece == PAWN) & (jnp.abs(a.to - a.from_) == 2) - state = state._replace(en_passant=lax.select(is_en_passant, (a.to + a.from_) // 2, -1)) + state = state._replace(en_passant=lax.select(is_en_passant, jnp.int8((a.to + a.from_) // 2), jnp.int8(-1))) # update counters captured = (state.board[a.to] < 0) | is_en_passant state = state._replace( diff --git a/pgx/experimental/chess.py b/pgx/experimental/chess.py index 36313d069..36cfef760 100644 --- a/pgx/experimental/chess.py +++ b/pgx/experimental/chess.py @@ -55,7 +55,7 @@ def from_fen(fen: str): mat = jnp.int32(arr).reshape(8, 8) if color == "b": mat = -jnp.flip(mat, axis=0) - ep = jnp.int32(-1) if en_passant == "-" else jnp.int32("abcdefgh".index(en_passant[0]) * 8 + int(en_passant[1]) - 1) + ep = jnp.int8(-1) if en_passant == "-" else jnp.int8("abcdefgh".index(en_passant[0]) * 8 + int(en_passant[1]) - 1) if color == "b" and ep >= 0: ep = _flip_pos(ep) x = GameState( From fd884b09791eeddd7c58b6198bc58d29bf4783a8 Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Tue, 22 Oct 2024 18:40:49 +0900 Subject: [PATCH 5/8] . --- pgx/_src/games/chess.py | 24 ++++++++++++------------ pgx/experimental/chess.py | 6 +++--- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/pgx/_src/games/chess.py b/pgx/_src/games/chess.py index 9b1aeaef9..0debc6f54 100644 --- a/pgx/_src/games/chess.py +++ b/pgx/_src/games/chess.py @@ -56,7 +56,7 @@ # 39 11 62 # 38 10 64 # 37 9 64 -FROM_PLANE = -np.ones((64, 73), dtype=np.int32) +FROM_PLANE = -np.ones((64, 73), dtype=np.int8) TO_PLANE = -np.ones((64, 64), dtype=np.int32) # ignores underpromotion zeros, seq, rseq = [0] * 7, list(range(1, 8)), list(range(-7, 0)) # down, up, left, right, down-left, down-right, up-right, up-left, knight, and knight @@ -82,9 +82,9 @@ ixs = [89, 90, 652, 656, 673, 674, 1257, 1258, 1841, 1842, 2425, 2426, 3009, 3010, 3572, 3576, 3593, 3594, 4177, 4178] INIT_LEGAL_ACTION_MASK[ixs] = True -LEGAL_DEST = -np.ones((7, 64, 27), np.int32) # LEGAL_DEST[0, :, :] == -1 -LEGAL_DEST_NEAR = -np.ones((64, 16), np.int32) -LEGAL_DEST_FAR = -np.ones((64, 19), np.int32) +LEGAL_DEST = -np.ones((7, 64, 27), np.int8) # LEGAL_DEST[0, :, :] == -1 +LEGAL_DEST_NEAR = -np.ones((64, 16), np.int8) +LEGAL_DEST_FAR = -np.ones((64, 19), np.int8) CAN_MOVE = np.zeros((7, 64, 64), dtype=np.bool_) for from_ in range(64): legal_dest = {p: [] for p in range(7)} @@ -112,7 +112,7 @@ dests = list(set(legal_dest[QUEEN]).difference(set(legal_dest[KING]))) LEGAL_DEST_FAR[from_, : len(dests)] = dests -BETWEEN = -np.ones((64, 64, 6), dtype=np.int32) +BETWEEN = -np.ones((64, 64, 6), dtype=np.int8) for from_ in range(64): for to in range(64): r0, c0, r1, c1 = from_ % 8, from_ // 8, to % 8, to // 8 @@ -152,18 +152,18 @@ class GameState(NamedTuple): class Action(NamedTuple): - from_: Array = jnp.int32(-1) - to: Array = jnp.int32(-1) - underpromotion: Array = jnp.int32(-1) # 0: rook, 1: bishop, 2: knight + from_: Array = jnp.int8(-1) + to: Array = jnp.int8(-1) + underpromotion: Array = jnp.int8(-1) # 0: rook, 1: bishop, 2: knight @staticmethod def _from_label(label: Array): - from_, plane = label // 73, label % 73 - underpromotion = lax.select(plane >= 9, -1, plane // 3) + from_, plane = jnp.int8(label // 73), jnp.int8(label % 73) + underpromotion = lax.select(plane >= 9, jnp.int8(-1), plane // 3) return Action(from_=from_, to=FROM_PLANE[from_, plane], underpromotion=underpromotion) def _to_label(self): - return self.from_ * 73 + TO_PLANE[self.from_, self.to] + return jnp.int32(self.from_) * 73 + TO_PLANE[self.from_, self.to] class Game: @@ -362,7 +362,7 @@ def legal_labels(label): can_castle_queen_side &= (b[0] == ROOK) & (b[8] == EMPTY) & (b[16] == EMPTY) & (b[24] == EMPTY) & (b[32] == KING) can_castle_king_side = state.castling_rights[0, 1] can_castle_king_side &= (b[32] == KING) & (b[40] == EMPTY) & (b[48] == EMPTY) & (b[56] == ROOK) - not_checked = ~jax.vmap(_is_attacked, in_axes=(None, 0))(state, jnp.int32([16, 24, 32, 40, 48])) + not_checked = ~jax.vmap(_is_attacked, in_axes=(None, 0))(state, jnp.int8([16, 24, 32, 40, 48])) mask = mask.at[2364].set(mask[2364] | (can_castle_queen_side & not_checked[:3].all())) mask = mask.at[2367].set(mask[2367] | (can_castle_king_side & not_checked[2:].all())) diff --git a/pgx/experimental/chess.py b/pgx/experimental/chess.py index 36cfef760..55b01fdc6 100644 --- a/pgx/experimental/chess.py +++ b/pgx/experimental/chess.py @@ -52,7 +52,7 @@ def from_fen(fen: str): castling_rights = jnp.bool_([["Q" in castling, "K" in castling], ["q" in castling, "k" in castling]]) if color == "b": castling_rights = castling_rights[::-1] - mat = jnp.int32(arr).reshape(8, 8) + mat = jnp.int8(arr).reshape(8, 8) if color == "b": mat = -jnp.flip(mat, axis=0) ep = jnp.int8(-1) if en_passant == "-" else jnp.int8("abcdefgh".index(en_passant[0]) * 8 + int(en_passant[1]) - 1) @@ -60,7 +60,7 @@ def from_fen(fen: str): ep = _flip_pos(ep) x = GameState( board=jnp.rot90(mat, k=3).flatten(), - color=jnp.int32(0) if color == "w" else jnp.int32(1), + color=jnp.int8(0) if color == "w" else jnp.int8(1), castling_rights=castling_rights, en_passant=ep, halfmove_count=jnp.int32(halfmove_cnt), @@ -70,7 +70,7 @@ def from_fen(fen: str): x = x._replace(legal_action_mask=legal_action_mask) x = _update_history(x) - player_order = jnp.int32([0, 1]) + player_order = jnp.int8([0, 1]) state = State( _player_order=player_order, _x=x, From b9706bc285d4515675ae65e6999a49eabb5cc059 Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Tue, 22 Oct 2024 18:54:21 +0900 Subject: [PATCH 6/8] . --- pgx/_src/games/chess.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgx/_src/games/chess.py b/pgx/_src/games/chess.py index 0debc6f54..c4e87a2b6 100644 --- a/pgx/_src/games/chess.py +++ b/pgx/_src/games/chess.py @@ -268,7 +268,7 @@ def _apply_move(state: GameState, a: Action) -> GameState: board=state.board.at[removed_pawn_pos].set(lax.select(is_en_passant, EMPTY, jnp.int8(state.board[removed_pawn_pos]))) ) is_en_passant = (piece == PAWN) & (jnp.abs(a.to - a.from_) == 2) - state = state._replace(en_passant=lax.select(is_en_passant, jnp.int8((a.to + a.from_) // 2), jnp.int8(-1))) + state = state._replace(en_passant=lax.select(is_en_passant, (a.to + a.from_) // 2, jnp.int8(-1))) # update counters captured = (state.board[a.to] < 0) | is_en_passant state = state._replace( From 15096bd0fdf78e6d08e6f948aa274cb1db3aa594 Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Tue, 22 Oct 2024 21:00:59 +0900 Subject: [PATCH 7/8] tidy --- pgx/_src/games/chess.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/pgx/_src/games/chess.py b/pgx/_src/games/chess.py index c4e87a2b6..a07f3a00b 100644 --- a/pgx/_src/games/chess.py +++ b/pgx/_src/games/chess.py @@ -19,7 +19,7 @@ import numpy as np from jax import Array, lax -EMPTY, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING = tuple(range(7)) # opponent: -1 * piece +EMPTY, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING = [jnp.int8(i) for i in range(7)] # opponent: -1 * piece MAX_TERMINATION_STEPS = 512 # from AlphaZero paper # prepare precomputed values here (e.g., available moves, map to label, etc.) @@ -93,23 +93,23 @@ continue r0, c0, r1, c1 = from_ % 8, from_ // 8, to % 8, to // 8 if (r1 - r0 == 1 and abs(c1 - c0) <= 1) or ((r0, r1) == (1, 3) and abs(c1 - c0) == 0): - legal_dest[PAWN].append(to) + legal_dest[PAWN.item()].append(to) if (abs(r1 - r0) == 1 and abs(c1 - c0) == 2) or (abs(r1 - r0) == 2 and abs(c1 - c0) == 1): - legal_dest[KNIGHT].append(to) + legal_dest[KNIGHT.item()].append(to) if abs(r1 - r0) == abs(c1 - c0): - legal_dest[BISHOP].append(to) + legal_dest[BISHOP.item()].append(to) if abs(r1 - r0) == 0 or abs(c1 - c0) == 0: - legal_dest[ROOK].append(to) + legal_dest[ROOK.item()].append(to) if (abs(r1 - r0) == 0 or abs(c1 - c0) == 0) or (abs(r1 - r0) == abs(c1 - c0)): - legal_dest[QUEEN].append(to) + legal_dest[QUEEN.item()].append(to) if from_ != to and abs(r1 - r0) <= 1 and abs(c1 - c0) <= 1: - legal_dest[KING].append(to) + legal_dest[KING.item()].append(to) for p in range(1, 7): LEGAL_DEST[p, from_, : len(legal_dest[p])] = legal_dest[p] CAN_MOVE[p, from_, legal_dest[p]] = True - dests = list(set(legal_dest[KING]) | set(legal_dest[KNIGHT])) + dests = list(set(legal_dest[KING.item()]) | set(legal_dest[KNIGHT.item()])) LEGAL_DEST_NEAR[from_, : len(dests)] = dests - dests = list(set(legal_dest[QUEEN]).difference(set(legal_dest[KING]))) + dests = list(set(legal_dest[QUEEN.item()]).difference(set(legal_dest[KING.item()]))) LEGAL_DEST_FAR[from_, : len(dests)] = dests BETWEEN = -np.ones((64, 64, 6), dtype=np.int8) @@ -135,7 +135,6 @@ ZOBRIST_CASTLING = jax.random.randint(keys[2], shape=(4, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32) ZOBRIST_EN_PASSANT = jax.random.randint(keys[3], shape=(65, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32) INIT_ZOBRIST_HASH = jnp.uint32([1455170221, 1478960862]) -EMPTY, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING = [jnp.int8(i) for i in range(7)] # opponent: -1 * piece class GameState(NamedTuple): From a1bdfbda7e83bd924f9e0cc2e91d94ffe239979e Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Tue, 22 Oct 2024 21:09:40 +0900 Subject: [PATCH 8/8] halfmovecnt --- pgx/_src/games/chess.py | 4 ++-- pgx/experimental/chess.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pgx/_src/games/chess.py b/pgx/_src/games/chess.py index a07f3a00b..1c17ab61b 100644 --- a/pgx/_src/games/chess.py +++ b/pgx/_src/games/chess.py @@ -142,7 +142,7 @@ class GameState(NamedTuple): board: Array = INIT_BOARD # (64,) castling_rights: Array = jnp.ones([2, 2], dtype=jnp.bool_) # my queen, my king, opp queen, opp king en_passant: Array = jnp.int8(-1) - halfmove_count: Array = jnp.int32(0) # number of moves since the last piece capture or pawn move + halfmove_count: Array = jnp.int8(0) # number of moves since the last piece capture or pawn move fullmove_count: Array = jnp.int32(1) # increase every black move hash_history: Array = jnp.zeros((MAX_TERMINATION_STEPS + 1, 2), dtype=jnp.uint32).at[0].set(INIT_ZOBRIST_HASH) board_history: Array = jnp.zeros((8, 64), dtype=jnp.int8).at[0, :].set(INIT_BOARD) @@ -271,7 +271,7 @@ def _apply_move(state: GameState, a: Action) -> GameState: # update counters captured = (state.board[a.to] < 0) | is_en_passant state = state._replace( - halfmove_count=lax.select(captured | (piece == PAWN), 0, state.halfmove_count + 1), + halfmove_count=lax.select(captured | (piece == PAWN), jnp.int8(0), state.halfmove_count + 1), fullmove_count=state.fullmove_count + jnp.int32(state.color == 1), ) # castling diff --git a/pgx/experimental/chess.py b/pgx/experimental/chess.py index 55b01fdc6..1a2d3e1c7 100644 --- a/pgx/experimental/chess.py +++ b/pgx/experimental/chess.py @@ -63,7 +63,7 @@ def from_fen(fen: str): color=jnp.int8(0) if color == "w" else jnp.int8(1), castling_rights=castling_rights, en_passant=ep, - halfmove_count=jnp.int32(halfmove_cnt), + halfmove_count=jnp.int8(halfmove_cnt), fullmove_count=jnp.int32(fullmove_cnt), ) legal_action_mask = jax.jit(_legal_action_mask)(x)