/*
 * Decompiled with CFR 0.152.
 */
package mage.player.ai;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import mage.abilities.Ability;
import mage.abilities.PlayLandAbility;
import mage.abilities.common.PassAbility;
import mage.cards.Card;
import mage.constants.PhaseStep;
import mage.constants.Zone;
import mage.game.Game;
import mage.game.combat.Combat;
import mage.game.turn.Step;
import mage.player.ai.MCTSNextActionFactory;
import mage.player.ai.MCTSPlayer;
import mage.player.ai.SimulatedPlayerMCTS;
import mage.players.Player;
import mage.util.RandomUtil;
import org.apache.log4j.Logger;

public class MCTSNode {
    public static final boolean USE_ACTION_CACHE = false;
    private static final double selectionCoefficient = Math.sqrt(2.0);
    private static final double passRatioTolerance = 0.0;
    private static final Logger logger = Logger.getLogger(MCTSNode.class);
    private int visits = 0;
    private int wins = 0;
    private MCTSNode parent;
    private final List<MCTSNode> children = new ArrayList<MCTSNode>();
    private Ability action;
    private Game game;
    private Combat combat;
    private final String stateValue;
    private final String fullStateValue;
    private UUID playerId;
    private boolean terminal = false;
    private UUID targetPlayer;
    private static int nodeCount;
    private static final ConcurrentHashMap<String, List<Ability>> playablesCache;
    private static final ConcurrentHashMap<String, List<List<UUID>>> attacksCache;
    private static final ConcurrentHashMap<String, List<List<List<UUID>>>> blocksCache;
    private static long playablesHit;
    private static long playablesMiss;
    private static long attacksHit;
    private static long attacksMiss;
    private static long blocksHit;
    private static long blocksMiss;

    public MCTSNode(UUID targetPlayer, Game game) {
        this.targetPlayer = targetPlayer;
        this.game = game;
        this.stateValue = game.getState().getValue(game, targetPlayer);
        this.fullStateValue = game.getState().getValue(true, game);
        this.terminal = game.checkIfGameIsOver();
        this.setPlayer();
        nodeCount = 1;
    }

    protected MCTSNode(MCTSNode parent, Game game, Ability action) {
        this.targetPlayer = parent.targetPlayer;
        this.game = game;
        this.stateValue = game.getState().getValue(game, this.targetPlayer);
        this.fullStateValue = game.getState().getValue(true, game);
        this.terminal = game.checkIfGameIsOver();
        this.parent = parent;
        this.action = action;
        this.setPlayer();
        ++nodeCount;
    }

    protected MCTSNode(MCTSNode parent, Game game, Combat combat) {
        this.targetPlayer = parent.targetPlayer;
        this.game = game;
        this.combat = combat;
        this.stateValue = game.getState().getValue(game, this.targetPlayer);
        this.fullStateValue = game.getState().getValue(true, game);
        this.terminal = game.checkIfGameIsOver();
        this.parent = parent;
        this.setPlayer();
        ++nodeCount;
    }

    private void setPlayer() {
        this.playerId = this.game.getStep().getStepPart() == Step.StepPart.PRIORITY ? this.game.getPriorityPlayerId() : (this.game.getTurnStepType() == PhaseStep.DECLARE_BLOCKERS ? (UUID)this.game.getCombat().getDefenders().iterator().next() : this.game.getActivePlayerId());
    }

    public MCTSNode select(UUID targetPlayerId) {
        double bestValue = Double.NEGATIVE_INFINITY;
        boolean isTarget = this.playerId.equals(targetPlayerId);
        MCTSNode bestChild = null;
        if (this.children.size() == 1) {
            return this.children.get(0);
        }
        for (MCTSNode node : this.children) {
            double uct = node.visits > 0 ? (isTarget ? (double)(node.wins / node.visits) + selectionCoefficient * Math.sqrt(Math.log(this.visits) / (double)node.visits) : (double)((node.visits - node.wins) / node.visits) + selectionCoefficient * Math.sqrt(Math.log(this.visits) / (double)node.visits)) : 10000.0 + 1000.0 * RandomUtil.nextDouble();
            if (!(uct > bestValue)) continue;
            bestChild = node;
            bestValue = uct;
        }
        return bestChild;
    }

    public void expand() {
        MCTSPlayer player = (MCTSPlayer)this.game.getPlayer(this.playerId);
        if (player.getNextAction() == null) {
            logger.fatal((Object)"next action is null");
        }
        this.children.addAll(MCTSNextActionFactory.createNextAction(player.getNextAction()).performNextAction(this, player, this.game, this.fullStateValue));
        this.game = null;
    }

    public int simulate(UUID playerId) {
        Game sim = this.createSimulation(this.game, playerId);
        sim.resume();
        int retVal = -1;
        for (Player simPlayer : sim.getPlayers().values()) {
            if (!simPlayer.getId().equals(playerId) || !simPlayer.hasWon()) continue;
            retVal = 1;
        }
        return retVal;
    }

    public void backpropagate(int result) {
        if (result == 0) {
            return;
        }
        if (result == 1) {
            ++this.wins;
        }
        ++this.visits;
        if (this.parent != null) {
            this.parent.backpropagate(result);
        }
    }

    public boolean isLeaf() {
        return this.children.isEmpty();
    }

    public MCTSNode bestChild() {
        if (this.children.size() == 1) {
            return this.children.get(0);
        }
        double bestCount = -1.0;
        double bestRatio = 0.0;
        boolean bestIsPass = false;
        MCTSNode bestChild = null;
        for (MCTSNode node : this.children) {
            double ratio;
            if ((double)node.visits > bestCount) {
                if (bestIsPass && (ratio = (double)node.wins / ((double)node.visits * 1.0)) < bestRatio + 0.0) continue;
                bestChild = node;
                bestCount = node.visits;
                bestRatio = (double)node.wins / ((double)node.visits * 1.0);
                bestIsPass = false;
                continue;
            }
            if (!(node.action instanceof PassAbility) || node.visits <= 10 || bestChild.action instanceof PlayLandAbility || !((ratio = (double)node.wins / ((double)node.visits * 1.0)) > bestRatio - 0.0)) continue;
            logger.info((Object)("choosing pass over " + bestChild.getAction()));
            bestChild = node;
            bestCount = node.visits;
            bestRatio = ratio;
            bestIsPass = true;
        }
        return bestChild;
    }

    public void emancipate() {
        if (this.parent != null) {
            this.parent.children.remove(this);
            this.parent = null;
        }
    }

    public Ability getAction() {
        return this.action;
    }

    public int getNumChildren() {
        return this.children.size();
    }

    public MCTSNode getParent() {
        return this.parent;
    }

    public Combat getCombat() {
        return this.combat;
    }

    public int getNodeCount() {
        return nodeCount;
    }

    public String getStateValue() {
        return this.stateValue;
    }

    public double getWinRatio() {
        if (this.visits > 0) {
            return (double)this.wins / ((double)this.visits * 1.0);
        }
        return -1.0;
    }

    public int getVisits() {
        return this.visits;
    }

    protected Game createSimulation(Game game, UUID playerId) {
        Game sim = game.createSimulationForAI();
        for (Player oldPlayer : sim.getState().getPlayers().values()) {
            Player origPlayer = ((Player)game.getState().getPlayers().get((Object)oldPlayer.getId())).copy();
            SimulatedPlayerMCTS newPlayer = new SimulatedPlayerMCTS(oldPlayer, true);
            newPlayer.restore(origPlayer);
            sim.getState().getPlayers().put((Object)oldPlayer.getId(), (Object)newPlayer);
        }
        this.randomizePlayers(sim, playerId);
        return sim;
    }

    protected void randomizePlayers(Game game, UUID playerId) {
        for (Player player : game.getState().getPlayers().values()) {
            if (!player.getId().equals(playerId)) {
                int handSize = player.getHand().size();
                player.getLibrary().addAll(player.getHand().getCards(game), game);
                player.getHand().clear();
                player.getLibrary().shuffle();
                for (int i = 0; i < handSize; ++i) {
                    Card card = player.getLibrary().drawFromTop(game);
                    card.setZone(Zone.HAND, game);
                    player.getHand().add(card);
                }
                continue;
            }
            player.getLibrary().shuffle();
        }
    }

    public boolean isTerminal() {
        return this.terminal;
    }

    public boolean isWinner(UUID playerId) {
        Player player;
        return this.game != null && (player = this.game.getPlayer(playerId)) != null && player.hasWon();
    }

    public MCTSNode getMatchingState(String state) {
        ArrayDeque<MCTSNode> queue = new ArrayDeque<MCTSNode>();
        queue.add(this);
        while (!queue.isEmpty()) {
            MCTSNode current = (MCTSNode)queue.remove();
            if (current.stateValue.equals(state)) {
                return current;
            }
            for (MCTSNode child : current.children) {
                queue.add(child);
            }
        }
        return null;
    }

    public void merge(MCTSNode merge) {
        if (!this.stateValue.equals(merge.stateValue)) {
            logger.info((Object)"mismatched merge states at root");
            return;
        }
        this.visits += merge.visits;
        this.wins += merge.wins;
        int mismatchCount = 0;
        ArrayList<MCTSNode> mergeChildren = new ArrayList<MCTSNode>();
        for (MCTSNode child : merge.children) {
            mergeChildren.add(child);
        }
        block1: for (MCTSNode child : this.children) {
            for (MCTSNode mergeChild : mergeChildren) {
                if (mergeChild.action != null && child.action != null) {
                    if (!mergeChild.action.toString().equals(child.action.toString())) continue;
                    if (!mergeChild.stateValue.equals(child.stateValue)) {
                        ++mismatchCount;
                        continue block1;
                    }
                    child.merge(mergeChild);
                    mergeChildren.remove(mergeChild);
                    continue block1;
                }
                if (!mergeChild.combat.getValue().equals(child.combat.getValue())) continue;
                if (!mergeChild.stateValue.equals(child.stateValue)) {
                    ++mismatchCount;
                    continue block1;
                }
                child.merge(mergeChild);
                mergeChildren.remove(mergeChild);
                continue block1;
            }
        }
        if (!mergeChildren.isEmpty()) {
            for (MCTSNode child : mergeChildren) {
                child.parent = this;
                this.children.add(child);
            }
        }
    }

    public int size() {
        int num = 1;
        for (MCTSNode child : this.children) {
            num += child.size();
        }
        return num;
    }

    protected static List<Ability> getPlayables(MCTSPlayer player, String state, Game game) {
        if (playablesCache.containsKey(state)) {
            ++playablesHit;
            return playablesCache.get(state);
        }
        ++playablesMiss;
        List<Ability> abilities = player.getPlayableOptions(game);
        playablesCache.put(state, abilities);
        return abilities;
    }

    protected static List<List<UUID>> getAttacks(MCTSPlayer player, String state, Game game) {
        if (attacksCache.containsKey(state)) {
            ++attacksHit;
            return attacksCache.get(state);
        }
        ++attacksMiss;
        List<List<UUID>> attacks = player.getAttacks(game);
        attacksCache.put(state, attacks);
        return attacks;
    }

    protected static List<List<List<UUID>>> getBlocks(MCTSPlayer player, String state, Game game) {
        if (blocksCache.containsKey(state)) {
            ++blocksHit;
            return blocksCache.get(state);
        }
        ++blocksMiss;
        List<List<List<UUID>>> blocks = player.getBlocks(game);
        blocksCache.put(state, blocks);
        return blocks;
    }

    public static int cleanupCache(int turnNum) {
        Set playablesKeys = playablesCache.keySet();
        Iterator playablesIterator = playablesKeys.iterator();
        int count = 0;
        while (playablesIterator.hasNext()) {
            String next = (String)playablesIterator.next();
            int cacheTurn = Integer.parseInt(next.split(":", 2)[0].substring(1));
            if (cacheTurn >= turnNum) continue;
            playablesIterator.remove();
            ++count;
        }
        Set attacksKeys = attacksCache.keySet();
        Iterator attacksIterator = attacksKeys.iterator();
        while (attacksIterator.hasNext()) {
            int cacheTurn = Integer.parseInt(((String)attacksIterator.next()).split(":", 2)[0].substring(1));
            if (cacheTurn >= turnNum) continue;
            attacksIterator.remove();
            ++count;
        }
        Set blocksKeys = blocksCache.keySet();
        Iterator blocksIterator = blocksKeys.iterator();
        while (blocksIterator.hasNext()) {
            int cacheTurn = Integer.parseInt(((String)blocksIterator.next()).split(":", 2)[0].substring(1));
            if (cacheTurn >= turnNum) continue;
            blocksIterator.remove();
            ++count;
        }
        return count;
    }

    public static void logHitMiss() {
    }

    static {
        playablesCache = new ConcurrentHashMap();
        attacksCache = new ConcurrentHashMap();
        blocksCache = new ConcurrentHashMap();
        playablesHit = 0L;
        playablesMiss = 0L;
        attacksHit = 0L;
        attacksMiss = 0L;
        blocksHit = 0L;
        blocksMiss = 0L;
    }
}

