Gumbel-MCTS-CLI: 10x Faster Tree Search in Pure Python

The Problem
Vanilla MCTS with UCB1 wastes simulation budget exploring actions that are clearly sub-optimal. As the tree grows, most simulations go to branches that will never be selected. The algorithm corrects itself eventually, but by then the time budget is gone.
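For reference, vanilla UCT picks the child that maximizes the UCB1 score. Here Q(s,a) is the mean return of action a, N(s) and N(s,a) are the parent and child visit counts, and c is an exploration constant. The value of c that VanillaMCTS uses is not stated here.

```latex
a^{*} = \arg\max_{a} \left[ Q(s,a) + c \sqrt{\frac{\ln N(s)}{N(s,a)}} \right]
```

The exploration term grows with ln N(s) for every child. An action whose Q(s,a) is clearly low therefore keeps receiving visits as the total count rises.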
NEO built Gumbel-MCTS-CLI to implement the fix from the ICLR 2022 paper by Danihelka, Guez, Schrittwieser, and Silver ("Policy improvement by planning with Gumbel"). The result runs approximately 10x faster than vanilla UCT at the same simulation budget, with equivalent or better search quality. It depends only on NumPy and needs no GPU.
Three Changes from Vanilla MCTS
Gumbel-MCTS makes three targeted changes to the standard algorithm. Together they eliminate the wasted-simulation problem.
Gumbel top-k sampling replaces UCB exploration at the root. The algorithm adds Gumbel(0,1) noise to the log-prior probability of each action and keeps the k highest-scoring actions as candidates. This is the Gumbel-Top-k trick: the result is a sample of k distinct actions drawn without replacement from the prior. High-prior actions are likely to be included, while low-prior actions still have a chance. Budget goes to promising candidates from round one.
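A minimal NumPy sketch of the root sampling step follows. The function name `gumbel_top_k` is hypothetical and is not taken from the project. The sketch also returns the noise vector, because in the paper the same Gumbel sample is reused when ranking candidates during sequential halving.

```python
import numpy as np


def gumbel_top_k(log_prior, k, rng):
    """Return (indices of k candidate actions, Gumbel noise) for a vector of log-priors.

    Keeping the top k of log_prior + g, with g ~ Gumbel(0, 1) drawn i.i.d. per action,
    samples k distinct actions without replacement from the prior.
    """
    log_prior = np.asarray(log_prior, dtype=float)
    g = rng.gumbel(loc=0.0, scale=1.0, size=log_prior.shape)
    scores = log_prior + g
    k = min(k, log_prior.size)
    top = np.argsort(scores)[::-1][:k]  # highest perturbed scores first
    return top, g


rng = np.random.default_rng(0)
prior = np.array([0.5, 0.2, 0.1, 0.1, 0.05, 0.05])
candidates, noise = gumbel_top_k(np.log(prior), k=4, rng=rng)
print(candidates)  # four distinct action indices; which ones depends on the seed
```

Illegal actions can be masked by giving them a log-prior of -inf, so they never enter the candidate set.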
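Sequential halving splits the simulation budget evenly across ceil(log2(k)) rounds. Within a round, the remaining candidates share that round's budget equally. At the end of each round, the lower-scoring half of the candidates is eliminated. In the paper, an action's score combines its Gumbel sample, its prior logit, and a monotone transform of its current Q-value estimate. After the final round, one action remains, and that is the action returned. Sequential halving is a best-arm identification method from the multi-armed bandit literature, related to Successive Rejects, adapted here to the root of the search tree.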
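The budget arithmetic can be sketched as two small helpers. Both names are hypothetical. The schedule uses the even per-round split described above; the project may round or redistribute leftover simulations differently.

```python
import math

import numpy as np


def halving_schedule(n_simulations, k):
    """Return (num_candidates, visits_per_candidate) for each sequential halving round."""
    rounds = max(1, math.ceil(math.log2(k)))
    per_round = n_simulations // rounds  # even split of the budget across rounds
    schedule = []
    m = k
    for _ in range(rounds):
        schedule.append((m, max(1, per_round // m)))
        m = max(1, m // 2)  # half of the candidates survive each round
    return schedule


def keep_top_half(candidates, scores):
    """Drop the lower-scoring half of the candidates, keeping at least one."""
    order = np.argsort(scores)[::-1]
    keep = max(1, len(candidates) // 2)
    return [candidates[i] for i in order[:keep]]


print(halving_schedule(200, 8))
print(keep_top_half(["a", "b", "c", "d"], [0.1, 0.9, 0.5, 0.2]))
```

With 200 simulations and k = 8, the sketch gives 8 candidates 8 visits each, then 4 candidates 16 each, then 2 candidates 33 each. Survivors therefore receive progressively deeper evaluation.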
Completed-Q values fill in a value estimate for actions that have not been visited yet, without running a rollout. For each unvisited action, the algorithm uses the mean Q-value of the visited sibling actions, weighted by their priors. The paper's version also mixes in the value estimate of the parent state. This makes early rounds cheap: every action has a reasonable value estimate before any simulation reaches it.
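A minimal sketch of the fill-in rule as described above: a prior-weighted mean over visited siblings. It does not include the paper's blend with a value-network estimate. The `default` parameter is a hypothetical fallback for the case where no sibling has been visited.

```python
import numpy as np


def completed_q(q, visits, prior, default=0.0):
    """Replace Q for unvisited actions with the prior-weighted mean Q of visited siblings."""
    q = np.asarray(q, dtype=float)
    visits = np.asarray(visits)
    prior = np.asarray(prior, dtype=float)
    visited = visits > 0
    if visited.any():
        fill = np.sum(prior[visited] * q[visited]) / np.sum(prior[visited])
    else:
        fill = default  # hypothetical fallback when nothing has been visited yet
    out = q.copy()
    out[~visited] = fill
    return out


# Unvisited actions 1 and 3 get (0.5 * 1.0 + 0.25 * 3.0) / 0.75, about 1.667
print(completed_q(q=[1.0, 0.0, 3.0, 0.0], visits=[2, 0, 1, 0], prior=[0.5, 0.2, 0.25, 0.05]))
```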
The benchmark supports the speedup claim. On the MaxTree environment with 200 simulations and 20 trials, vanilla MCTS takes 12.84 ms per call and Gumbel-MCTS takes 1.31 ms. That is a 9.80x speedup, with a quality ratio of 1.13 (Gumbel finds slightly better solutions on average).
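Per-call timings like these can be measured with a harness along the following lines. This is a hypothetical sketch, not the project's benchmark_mcts.py. The speedup check simply divides the two reported figures.

```python
import time
from typing import Callable


def ms_per_call(fn: Callable[[], object], trials: int = 20) -> float:
    """Average wall-clock milliseconds per call of fn over `trials` runs."""
    start = time.perf_counter()
    for _ in range(trials):
        fn()
    return (time.perf_counter() - start) * 1000.0 / trials


def speedup(baseline_ms: float, candidate_ms: float) -> float:
    """How many times faster the candidate is than the baseline."""
    return baseline_ms / candidate_ms


print(f"{speedup(12.84, 1.31):.2f}x")  # the reported MaxTree figures: 9.80x
```

In a real comparison, `fn` would wrap something like `lambda: agent.search(env, state)` for each agent, with the same environment and simulation budget.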
The Three APIs
The tool exposes three search functions with different tradeoffs.
search(env, state) is the simple API. Call it, get the best action. No extra data returned.
search_with_stats(env, state) returns three values: the best action, the root node, and the elapsed time in seconds. The root node exposes value (the Q-value at the root), visit_count (the total number of simulations run), and its full child tree. Use this when you need to inspect the search tree after the search finishes.
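One common way to inspect the returned tree is to follow the most-visited child from the root, which gives the principal variation. The `Node` class below is a hypothetical stand-in. The attribute names `value` and `visit_count` come from the description above, while `children` as a dict from action to child node is an assumption about how the project stores the child tree.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Node:
    """Hypothetical stand-in for the root node returned by search_with_stats."""
    value: float = 0.0
    visit_count: int = 0
    children: Dict[Any, "Node"] = field(default_factory=dict)  # assumed: action -> child


def principal_variation(root: Node, max_depth: int = 10) -> List[Any]:
    """Follow the most-visited child at each level and return the action path."""
    path, node = [], root
    for _ in range(max_depth):
        if not node.children:
            break
        action, node = max(node.children.items(), key=lambda kv: kv[1].visit_count)
        path.append(action)
    return path


root = Node(0.4, 25, {
    0: Node(0.1, 5),
    1: Node(0.5, 20, {2: Node(0.7, 12), 3: Node(0.2, 8)}),
})
print(principal_variation(root))  # [1, 2]
```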
search_anytime(env, state) is a generator that yields a snapshot after every sequential halving round. Each snapshot contains the current best action, round number, remaining candidates, elapsed time, and the full root node. Stop the generator at any point to get the best action found so far.
```python
# Budget-limited search: stop after 50ms
snapshot = None
for snapshot in agent.search_anytime(env, state):
    if snapshot['elapsed_sec'] > 0.050:
        break
best_action = snapshot['action']
```
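The anytime API is important for real-time applications where you have a fixed time budget rather than a fixed simulation count. Elapsed time is checked only when a snapshot is yielded, so the search can overrun the budget by up to one sequential halving round.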
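A toy generator shows the anytime pattern. It runs sequential halving on a Gaussian bandit and yields a snapshot after each round. This is a hypothetical illustration, not the project's search_anytime. The snapshot keys `action` and `elapsed_sec` match the usage above, while `round` and `candidates` are assumed names for the other fields the text lists.

```python
import math
import time

import numpy as np


def anytime_halving(arm_means, n_simulations=256, seed=0):
    """Toy anytime sequential halving on a Gaussian bandit (hypothetical)."""
    rng = np.random.default_rng(seed)
    k = len(arm_means)
    candidates = list(range(k))
    totals = np.zeros(k)
    counts = np.zeros(k)
    rounds = max(1, math.ceil(math.log2(k)))
    start = time.perf_counter()
    for r in range(rounds):
        per_arm = max(1, n_simulations // (rounds * len(candidates)))
        for a in candidates:
            # Each "simulation" is one noisy reward sample from arm a.
            totals[a] += rng.normal(arm_means[a], 1.0, size=per_arm).sum()
            counts[a] += per_arm
        q = totals[candidates] / counts[candidates]
        order = np.argsort(q)[::-1]  # best mean reward first
        candidates = [candidates[i] for i in order[: max(1, len(candidates) // 2)]]
        yield {
            "action": candidates[0],
            "round": r + 1,
            "candidates": list(candidates),
            "elapsed_sec": time.perf_counter() - start,
        }


snapshot = None
for snapshot in anytime_halving([0.0, 0.1, 0.2, 5.0, 0.3, 0.0, 0.1, 0.2]):
    if snapshot["elapsed_sec"] > 0.050:
        break
print(snapshot["round"], snapshot["action"])
```

Because each snapshot already holds the current best action, the caller can stop at any round boundary and still act.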
Tree Visualization
The tool ships an ASCII tree visualizer and an action statistics table. These are the primary debugging tools when the search is not finding the expected action.
```python
from gumbel_mcts.visualize import print_tree, format_action_table

best_action, root, elapsed = agent.search_with_stats(env, state)
print(print_tree(root, max_depth=3, max_children=5))
print(format_action_table(root, action_names={0: "UP", 1: "DOWN", 2: "LEFT", 3: "RIGHT"}))
```
The tree view shows the visit count and Q-value of each node, with a bar-chart indicator. The action table shows visits, visit share, Q-value, standard error, and prior for each root action. Suppose the model's prior is uniform (0.25 for each action) but visits are concentrated on action 1. Then sequential halving is working correctly: the prior expresses no preference, so the concentration comes from value estimates eliminating the weaker actions. If the Q-values are still nearly identical across actions after 200 simulations, the value estimates are flat and the search has no signal to exploit.
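The columns of such a table can be computed from per-action simulation returns, as in the sketch below. The `action_stats` function and its input format (a dict of return lists per action) are hypothetical and are not the project's format_action_table. Standard error here is the sample standard deviation divided by the square root of the visit count.

```python
import numpy as np


def action_stats(returns, priors):
    """Build one row per root action from its list of simulation returns."""
    total = max(1, sum(len(r) for r in returns.values()))
    rows = {}
    for action, samples in returns.items():
        x = np.asarray(samples, dtype=float)
        n = x.size
        q = float(x.mean()) if n else float("nan")
        se = float(x.std(ddof=1) / np.sqrt(n)) if n > 1 else float("nan")
        rows[action] = {"visits": n, "share": n / total, "q": q, "se": se, "prior": priors[action]}
    return rows


names = {0: "UP", 1: "DOWN"}
rows = action_stats({0: [1.0, 3.0, 2.0, 2.0], 1: [0.0, 1.0]}, {0: 0.5, 1: 0.5})
print(f"{'action':<8}{'visits':>7}{'share':>8}{'Q':>8}{'SE':>8}{'prior':>7}")
for a, r in rows.items():
    print(f"{names[a]:<8}{r['visits']:>7}{r['share']:>8.1%}{r['q']:>8.3f}{r['se']:>8.3f}{r['prior']:>7.2f}")
```

A large standard error next to a high Q-value usually means the action needs more visits before its estimate can be trusted.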
Three Benchmark Environments
All three environments implement the same four-method protocol: actions(), step(action), clone(), and random_rollout(). Any class implementing these four methods works as a drop-in environment for both GumbelMCTS and VanillaMCTS.
GridWorld is an 8x8 grid navigation task. The agent starts at (0, 0) and must reach the opposite corner, (7, 7). Actions are UP, DOWN, LEFT, RIGHT. This tests search in a space with a clear optimal path.
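A GridWorld-style class can implement the four-method protocol, as in this sketch. It is hypothetical and is not the project's GridWorld. The step penalty, goal reward, and wall behavior (moves into a wall leave the agent in place) are assumptions. The project's skeleton also passes an extra `env` argument to random_rollout; because what that argument carries is not documented here, this sketch takes only a depth and an optional random source.

```python
import copy
import random


class GridWorld:
    """Hypothetical 8x8 grid: start at (0, 0), goal at (size - 1, size - 1)."""

    MOVES = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}  # UP, DOWN, LEFT, RIGHT as (drow, dcol)

    def __init__(self, size=8):
        self.size = size
        self.pos = (0, 0)

    def actions(self):
        return list(self.MOVES)  # all four moves are always allowed; walls just block

    def step(self, action):
        dr, dc = self.MOVES[action]
        r = min(max(self.pos[0] + dr, 0), self.size - 1)
        c = min(max(self.pos[1] + dc, 0), self.size - 1)
        self.pos = (r, c)
        terminal = self.pos == (self.size - 1, self.size - 1)
        reward = 1.0 if terminal else -0.01  # assumed: goal reward and small step penalty
        return self.pos, reward, terminal

    def clone(self):
        return copy.deepcopy(self)

    def random_rollout(self, depth, rng=random):
        """Play random moves on a copy for up to `depth` steps and return the summed reward."""
        sim, total = self.clone(), 0.0
        for _ in range(depth):
            _, reward, terminal = sim.step(rng.choice(sim.actions()))
            total += reward
            if terminal:
                break
        return total
```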
MaxTree is a tree with depth 6 and branching factor 8. The agent selects actions to reach a leaf with the highest value. Since most leaves have low value, this tests the algorithm's ability to find the rare high-value path efficiently.
SequenceEnv builds a token sequence of length 5 from a vocabulary of 8 tokens. The score is a function of the sequence. This tests search in a combinatorial space without spatial structure.
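The sizes of these search spaces explain why search is needed rather than enumeration. The counts below follow directly from the stated parameters. Each random rollout ends in one leaf, so a 200-simulation budget reaches at most 200 of MaxTree's leaves.

```python
def leaf_count(branching: int, depth: int) -> int:
    """Number of distinct leaves (or full action sequences) in a uniform tree."""
    return branching ** depth


maxtree_leaves = leaf_count(8, 6)  # MaxTree: branching factor 8, depth 6
sequences = leaf_count(8, 5)       # SequenceEnv: vocabulary 8, length 5
gridworld_min_steps = 2 * (8 - 1)  # shortest path from (0, 0) to (7, 7)
fraction = 200 / maxtree_leaves    # share of MaxTree leaves 200 rollouts can reach

print(maxtree_leaves, sequences, gridworld_min_steps, f"{fraction:.3%}")
```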
Custom Environments
Any environment that implements the four required methods plugs in directly.
```python
class MyEnv:
    def actions(self):  # list of valid actions
        ...

    def step(self, action):  # (next_state, reward, terminal)
        ...

    def clone(self):  # deep copy for tree simulation
        ...

    def random_rollout(self, env, depth):  # float value estimate
        ...


from gumbel_mcts.gumbel_mcts import GumbelMCTS

agent = GumbelMCTS(n_simulations=400, max_considered_actions=8)
best = agent.search(MyEnv(), initial_state)
```
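The clone() method is required because each MCTS simulation plays moves forward from a node, mutating the environment. Cloning gives every simulation its own copy, so the state at the node is unchanged for the next simulation. The copy must be fully independent (a deep copy); otherwise moves made during one simulation leak into the tree.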
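The difference between a shallow and a deep copy matters as soon as the state holds a mutable object. The toy class below is hypothetical. Its state is a list of chosen tokens, loosely in the spirit of SequenceEnv.

```python
import copy


class ListState:
    """Toy environment whose state is a mutable list of chosen tokens (hypothetical)."""

    def __init__(self):
        self.tokens = []

    def step(self, action):
        self.tokens.append(action)
        return list(self.tokens), 0.0, len(self.tokens) >= 5

    def shallow_clone(self):
        return copy.copy(self)  # new object, but self.tokens is the same list

    def clone(self):
        return copy.deepcopy(self)  # new object and a new tokens list


root = ListState()
bad = root.shallow_clone()
bad.step(3)
print(root.tokens)  # [3]: the simulation leaked into the original state

root = ListState()
good = root.clone()
good.step(3)
print(root.tokens)  # []: the original state is untouched
```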
How to Build This with NEO
Open NEO in VS Code or Cursor and describe what you want to build. A good starting prompt for this project:
"Implement Gumbel-MCTS from the ICLR 2022 paper by Danihelka, Guez, Schrittwieser, and Silver in pure Python using NumPy only. The three key changes from vanilla UCT are: Gumbel top-k sampling at the root (add Gumbel(0,1) noise to log-prior probabilities and select top-k candidates), sequential halving (split simulation budget across log2(k) rounds, eliminating bottom half each round), and completed-Q values (fill unvisited action value estimates using prior-weighted sibling Q-values). Expose three APIs: search() returning the best action, search_with_stats() returning action/root node/elapsed time, and search_anytime() as a generator yielding snapshots after each sequential halving round for budget-limited stopping. Include three benchmark environments: 8x8 GridWorld, MaxTree (depth 6, branching factor 8), and SequenceEnv (5-token sequence, vocabulary 8). Any class implementing actions()/step()/clone()/random_rollout() works as a drop-in environment."
NEO generates the project structure and core implementation from that. From there you iterate: ask it to add the ASCII tree visualizer and action statistics table (visit counts, Q-values, and priors) for debugging search behavior, add the CLI benchmark script that runs Gumbel-MCTS and vanilla UCT side by side and prints a speedup table, or add the 88-test pytest suite covering all three APIs and environments. Each request builds on what's already there.
To run the finished project:
```bash
git clone https://github.com/dakshjain-1616/gumbel-mcts-cli
cd gumbel-mcts-cli
pip install -r requirements.txt
python benchmark_mcts.py --dry-run
```
The benchmark prints the full speedup table comparing Gumbel-MCTS and vanilla UCT side by side. The default MaxTree run shows approximately 10x faster search at equivalent or better solution quality.
NEO built a pure-Python Gumbel-MCTS implementation, based on the ICLR 2022 paper, that runs approximately 10x faster than vanilla UCT, needs no GPU, and uses a simple four-method protocol for custom environments. See what else NEO ships at heyneo.com.
Try NEO in Your IDE
Install the NEO extension to bring AI-powered development directly into your workflow:
- VS Code: NEO in VS Code
- Cursor: Install NEO for Cursor