"""FastAPI app that finds Lichess puzzle positions containing a given set of pieces.

A prebuilt inverted index (loaded from ``chess_index.pkl``) maps each
``(piece_symbol, square_name)`` token to a bitmap of position ids, and
``metadata[pos_id]`` is a ``(puzzle_row, move_idx)`` pair pointing into the
Lichess puzzle dataset.
"""
import pickle
import time

import chess
from datasets import load_dataset
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pyroaring import BitMap


def board_to_tokens(board):
    """Return one ``(piece_symbol, square_name)`` token per occupied square.

    Squares are visited in ``chess.SQUARES`` order, e.g. a white rook on a1
    yields ``('R', 'a1')``.
    """
    pieces = board.piece_map()  # one lookup table instead of two piece_at calls per square
    return [
        (pieces[square].symbol(), chess.square_name(square))
        for square in chess.SQUARES
        if square in pieces
    ]


def get_puzzle_positions(fen, moves_uci):
    """Replay a puzzle's moves and return a copy of the board after each one.

    Args:
        fen: Starting position of the puzzle.
        moves_uci: Space-separated UCI moves applied in order to *fen*.

    Returns:
        A list where ``positions[i]`` is the board after the first ``i + 1``
        moves. The starting FEN itself is not included, and an empty move
        string yields an empty list.
    """
    board = chess.Board(fen)
    positions = []
    for move_uci in moves_uci.split():
        board.push_uci(move_uci)
        positions.append(board.copy())
    return positions


def load_index(path='chess_index.pkl'):
    """Load the prebuilt position index and return ``(index, metadata)``.

    The pickle file is expected to hold a dict with ``'index'`` and
    ``'metadata'`` entries.

    Security: ``pickle.load`` can execute arbitrary code, so only point *path*
    at an index file you built yourself or otherwise trust.
    """
    with open(path, 'rb') as f:
        data = pickle.load(f)
    return data['index'], data['metadata']


def query_positions(index, metadata, query_tokens):
    """Return ``(pos_id, metadata[pos_id])`` for positions containing every token.

    Args:
        index: Mapping from token to a ``BitMap`` of position ids.
        metadata: Indexable by position id.
        query_tokens: Tokens that must all be present (as produced by
            ``board_to_tokens``).

    Returns:
        A list (always a list, even when nothing matches). An empty
        *query_tokens* returns ``[]`` instead of raising. An empty board would
        otherwise match every indexed position.
    """
    if not query_tokens:
        return []
    # A token missing from the index means no position can contain it.
    if any(token not in index for token in query_tokens):
        return []
    # Intersect smallest-first so the running result shrinks as fast as possible.
    bitmaps = sorted((index[token] for token in query_tokens), key=len)
    result = bitmaps[0].copy()  # copy so the in-place &= never mutates the index
    for bitmap in bitmaps[1:]:
        if not result:
            break
        result &= bitmap
    return [(pos_id, metadata[pos_id]) for pos_id in result]


# Loaded once at import time and shared by all requests. Importing this module
# therefore loads the dataset (which may hit the network or cache) and reads
# the index file.
dset = load_dataset("Lichess/chess-puzzles", split="train")
index, metadata = load_index()

app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")


@app.get("/", response_class=HTMLResponse)
def read_root(request: Request):
    """Serve the search page rendered from ``templates/index.html``."""
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/search")
def search(data: dict):
    """Find puzzles that reach a position containing every piece of the given board.

    Declared with plain ``def`` (not ``async def``) because all of the work is
    blocking: dataset row access and move replay. FastAPI runs sync endpoints
    in a threadpool, so this no longer stalls the event loop.

    Args:
        data: JSON body; must contain a ``"fen"`` string.

    Returns:
        ``{"count", "results", "time_ms"}``. Each puzzle appears at most once,
        represented by the first matching position encountered.

    Raises:
        HTTPException: 400 if ``"fen"`` is missing or not a valid FEN.
    """
    start = time.perf_counter()

    fen = data.get('fen')
    if not isinstance(fen, str):
        raise HTTPException(status_code=400, detail="Request body must include a 'fen' string.")
    try:
        board = chess.Board(fen)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=f"Invalid FEN: {exc}") from exc

    query_tokens = board_to_tokens(board)
    # NOTE(review): results are unbounded. A sparse board (e.g. only two kings)
    # may match a large share of the index; consider adding a limit.
    matches = query_positions(index, metadata, query_tokens)

    # Keep only the first matching move index seen for each puzzle.
    first_move_idx = {}
    for _pos_id, (puzzle_row, move_idx) in matches:
        first_move_idx.setdefault(puzzle_row, move_idx)

    results = []
    for puzzle_row, move_idx in first_move_idx.items():
        row = dset[puzzle_row]
        matched_board = get_puzzle_positions(row['FEN'], row['Moves'])[move_idx]
        results.append({
            "PuzzleId": row['PuzzleId'],
            "FEN": matched_board.fen(),
            "Moves": row['Moves'],
            "Rating": row['Rating'],
            "Popularity": row['Popularity'],
            "Themes": row['Themes'],
            "MatchedMove": move_idx,
        })

    elapsed_ms = (time.perf_counter() - start) * 1000
    return {"count": len(results), "results": results, "time_ms": elapsed_ms}