五子棋AI代码

397 阅读 · 4 分钟

基于极大极小搜索、α-β 剪枝和算杀的五子棋 AI 代码。

import java.util.*;

/**
 * Gomoku (five-in-a-row) AI: minimax search with alpha-beta pruning over a pattern-based
 * evaluation computed by an Aho-Corasick automaton.
 *
 * <p>The board is serialized once per call into {@link #boardStr}: every row, column and both
 * diagonal families, each line preceded by a '3' separator. In that string '1' is always our
 * stone, '2' the opponent's and '0' an empty point, so one automaton pass scores every line.
 *
 * <p>Instances are not thread-safe; {@link #getNextPoint} may be called repeatedly on the same
 * instance.
 */
public class Solution {
    /** Side length of the square board. */
    private static final int N = 15;
    /** Score of a decided position: +INF means we have five in a row, -INF the opponent does. */
    private static final int INF = 2000000000;
    /** Value of {@code g[i][j]} for an empty intersection. */
    private static final int BLANK = 3;
    /** Maximum search depth; the root (depth == D) is a maximizing layer. */
    private static final int D = 6;
    /** An empty point is a candidate only if a stone lies within L steps along one of the 4 lines. */
    private static final int L = 3;
    /** Number of best-scoring candidate moves explored at every search node. */
    private static final int MAX_CANDIDATES = 13;
    /** Unit steps (dx, dy) per line direction; index k matches pos[x][y][k]: —, |, \, /. */
    private static final int[][] DIRECTIONS = {{0, 1}, {1, 0}, {1, 1}, {1, -1}};

    // Shape patterns: "_AI" patterns use '1' (our stones), "_PLAYER" patterns use '2' (opponent).
    private final static String[] WIN_AI = {"11111"};
    private final static String[] WIN_PLAYER = {"22222"};
    private final static String[] HUO4_AI = {"011110"};
    private final static String[] HUO4_PLAYER = {"022220"};
    private final static String[] CHONG4_AI = {"011112", "211110", "10111", "11011", "11101"};
    private final static String[] CHONG4_PLAYER = {"022221", "122220", "20222", "22022", "22202"};
    private final static String[] HUO3_AI = {"001110", "011100", "010110", "011010"};
    private final static String[] HUO3_PLAYER = {"002220", "022200", "020220", "022020"};
    private final static String[] MIAN3_AI = {"001112", "010112", "011012", "011102", "211100", "211010", "210110", "201110", "00111", "10011", "10101", "10110", "01011", "10011", "11001", "11010", "01101", "10101", "11001", "11100",};
    private final static String[] MIAN3_PLAYER = {"002221", "020221", "022021", "022201", "122200", "122020", "120220", "102220", "00222", "20022", "20202", "20220", "02022", "20022", "22002", "22020", "02202", "20202", "22002", "22200",};
    private final static String[] HUO2_AI = {"000110", "001010", "001100", "001100", "010100", "011000", "000110", "010010", "010100", "001010", "010010", "011000",};
    private final static String[] HUO2_PLAYER = {"000220", "002020", "002200", "002200", "020200", "022000", "000220", "020020", "020200", "002020", "020020", "022000",};
    private final static String[] MIAN2_AI = {"000112", "001012", "010012", "10001", "2010102", "2011002", "211000", "210100", "210010", "2001102"};
    private final static String[] MIAN2_PLAYER = {"000221", "002021", "020021", "20002", "1020201", "1022001", "122000", "120200", "120020", "1002201"};
    private final static String[] OL1_AI = {"1"};
    private final static String[] OL1_PLAYER = {"2"};

    /** Pattern groups in trie-insertion order; GROUP_SCORES[k] weights every pattern of group k. */
    private static final String[][] PATTERN_GROUPS = {
            WIN_AI, HUO4_AI, CHONG4_AI, HUO3_AI, MIAN3_AI, HUO2_AI, MIAN2_AI, OL1_AI,
            WIN_PLAYER, HUO4_PLAYER, CHONG4_PLAYER, HUO3_PLAYER, MIAN3_PLAYER, HUO2_PLAYER,
            MIAN2_PLAYER, OL1_PLAYER};
    private static final int[] GROUP_SCORES = {
            INF, 50000, 400, 400, 20, 20, 1, 1,
            -INF, -100000, -100000, -8000, -50, -50, -50, -3};
    /** Upper bound on trie nodes (root + one per pattern character), so patterns can change freely. */
    private static final int TRIE_CAPACITY = 1 + totalPatternLength();

    private int[][] g;      // private copy of the board; the search temporarily writes stones here
    private int me, he;     // our and the opponent's stone colours (values stored in g)
    private final int[][][] pos = new int[N][N][4];  // pos[x][y][dir] = index of (x,y) in boardStr
    // Serialized board; 4*N*N cells + (6*N - 2) separators = 988 chars are used for N = 15.
    public char[] boardStr = new char[1000];
    private int len = 0;    // number of chars of boardStr in use
    // Aho-Corasick automaton: goto table, pattern multiplicity, weight, suffix (fail) links.
    private final int[][] tr = new int[TRIE_CAPACITY][4];
    private final int[] cnt = new int[TRIE_CAPACITY];
    private final int[] w = new int[TRIE_CAPACITY];
    private final int[] ne = new int[TRIE_CAPACITY];
    private int idx = 0;    // last allocated trie node; 0 means the automaton is not built yet

    private List<Integer> ans = new ArrayList<>(2);
    private Point bestMove;  // best root move found by the latest dfs(D, ...) call

    /** Sums the lengths of all patterns; used to size the trie arrays. */
    private static int totalPatternLength() {
        int total = 0;
        for (String[] group : PATTERN_GROUPS) {
            for (String s : group) {
                total += s.length();
            }
        }
        return total;
    }

    /**
     * Scores the current {@link #boardStr} from our point of view: the maximizing side prefers
     * high values, the opponent (minimizing side) prefers low values.
     *
     * @return INF / -INF as soon as a five of ours / theirs is found, otherwise the weighted sum
     *     of every pattern occurrence
     */
    public int evaluate() {
        int score = 0;
        int state = 0;
        for (int i = 0; i < len; i++) {
            state = tr[state][boardStr[i] - '0'];
            // Follow the suffix links so every pattern ending at position i is counted.
            for (int p = state; p != 0; p = ne[p]) {
                if (w[p] == INF) {
                    return INF;     // we already won; nothing else matters
                }
                if (w[p] == -INF) {
                    return -INF;    // the opponent already won
                }
                score += cnt[p] * w[p];
            }
        }
        return score;
    }

    /** Writes {@code c} into the four boardStr slots of (x, y). */
    private void setStone(int x, int y, char c) {
        for (int k = 0; k < 4; k++) {
            boardStr[pos[x][y][k]] = c;
        }
    }

    /** Places a stone of colour {@code who} on both the board copy and the serialized string. */
    private void play(int x, int y, int who, char stone) {
        g[x][y] = who;
        setStone(x, y, stone);
    }

    /** Reverts {@link #play}. */
    private void undo(int x, int y) {
        setStone(x, y, '0');
        g[x][y] = BLANK;
    }

    /**
     * Heuristic move generation: statically evaluates every candidate point for {@code who}.
     *
     * @return at most MAX_CANDIDATES points, best first for that side (highest score for us,
     *     lowest for the opponent)
     */
    private List<Point> getCandidates(int who) {
        char stone = who == me ? '1' : '2';
        List<Point> points = new ArrayList<>();
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < N; j++) {
                if (!isValid(i, j)) {
                    continue;
                }
                setStone(i, j, stone);
                int res = evaluate();
                setStone(i, j, '0');
                points.add(new Point(i, j, res));
            }
        }
        Collections.sort(points);
        if (who == me) {
            Collections.reverse(points);
        }
        return points.subList(0, Math.min(MAX_CANDIDATES, points.size()));
    }

    /**
     * Minimax with alpha-beta pruning. Layers at even distance from the root (depth D) maximize
     * for us, the others minimize for the opponent. When called with depth == D, the best root
     * move is stored in {@link #bestMove}.
     *
     * @return the value of the current position from our point of view
     */
    public int dfs(int depth, int alpha, int beta) {
        int t = evaluate();
        if (t == INF || t == -INF || depth == 0) {
            return t;   // decided position or search horizon reached
        }
        boolean maximizing = (D - depth) % 2 == 0;
        int mover = maximizing ? me : he;
        char stone = maximizing ? '1' : '2';
        for (Point p : getCandidates(mover)) {
            // g must be updated too, otherwise deeper plies would treat this cell as empty.
            play(p.x, p.y, mover, stone);
            int res = dfs(depth - 1, alpha, beta);
            undo(p.x, p.y);
            if (maximizing) {
                // Strict '>': a later child returning exactly alpha may only be an upper bound
                // (fail-low). Candidates come best-static-eval first, so the first move reaching
                // the maximum is also the intended tie-break.
                if (res > alpha) {
                    alpha = res;
                    if (depth == D) {
                        bestMove = new Point(p.x, p.y, res);
                    }
                }
            } else if (res < beta) {
                beta = res;
            }
            if (alpha >= beta) {
                break;  // cut-off
            }
        }
        return maximizing ? alpha : beta;
    }

    /**
     * Immediate tactics: take a winning point if we have one, otherwise block the opponent's.
     * On success the move is appended to {@link #ans}.
     */
    private boolean getKill() {
        List<Point> mine = getCandidates(me);
        if (!mine.isEmpty() && mine.get(0).val == INF) {
            ans.addAll(Arrays.asList(mine.get(0).x, mine.get(0).y));
            return true;
        }
        List<Point> his = getCandidates(he);
        if (!his.isEmpty() && his.get(0).val == -INF) {
            ans.addAll(Arrays.asList(his.get(0).x, his.get(0).y));
            return true;
        }
        return false;
    }

    /**
     * Entry point: chooses the next move.
     *
     * @param g  N x N board; cells hold {@code me}, {@code he} or BLANK (3). Not modified.
     * @param me our stone colour (0 = black, 1 = white)
     * @param he the opponent's stone colour
     * @return a new list {@code [row, col]}, or an empty list if the board has no empty point
     */
    public List<Integer> getNextPoint(int[][] g, int me, int he) {
        Objects.requireNonNull(g, "g");
        // Work on a private copy: the search temporarily writes stones into this.g.
        this.g = new int[N][];
        for (int i = 0; i < N; i++) {
            this.g[i] = g[i].clone();
        }
        this.me = me;
        this.he = he;
        ans = new ArrayList<>(2);

        if (isBoardEmpty()) {
            int center = N / 2;
            if (me == 0) {  // black: play the centre point
                ans.addAll(Arrays.asList(center, center));
            } else {        // white: a random point among the central 3x3
                Random random = new Random();
                ans.addAll(Arrays.asList(center - 1 + random.nextInt(3), center - 1 + random.nextInt(3)));
            }
            return ans;
        }

        if (idx == 0) {
            init();     // the automaton depends only on constants: build it once per instance
        }
        chessBoard2String();
        if (getKill()) {
            return ans;
        }
        bestMove = null;
        dfs(D, Integer.MIN_VALUE, Integer.MAX_VALUE);
        Point next = bestMove != null ? bestMove : fallbackMove();
        if (next != null) {
            ans.addAll(Arrays.asList(next.x, next.y));
        }
        return ans;
    }

    /** True if no stone lies on the N x N board. */
    private boolean isBoardEmpty() {
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < N; j++) {
                if (g[i][j] != BLANK) {
                    return false;
                }
            }
        }
        return true;
    }

    /** Used when the search yields no move (position already decided): best candidate, else any empty point. */
    private Point fallbackMove() {
        List<Point> candidates = getCandidates(me);
        if (!candidates.isEmpty()) {
            return candidates.get(0);
        }
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < N; j++) {
                if (g[i][j] == BLANK) {
                    return new Point(i, j);
                }
            }
        }
        return null;
    }

    /**
     * An empty point is a candidate if some stone lies within L steps of it along one of the
     * four lines (—, |, \, /) through it.
     */
    private boolean isValid(int x, int y) {
        if (g[x][y] != BLANK) {
            return false;
        }
        for (int[] d : DIRECTIONS) {
            for (int step = -L; step <= L; step++) {
                int i = x + step * d[0];
                int j = y + step * d[1];
                if (i >= 0 && i < N && j >= 0 && j < N && g[i][j] != BLANK) {
                    return true;
                }
            }
        }
        return false;
    }

    /** Maps a board cell value to its boardStr character. */
    private char cellChar(int v) {
        if (v == me) {
            return '1';
        }
        if (v == he) {
            return '2';
        }
        return '0';
    }

    /**
     * Appends one '3'-prefixed line starting at (x, y) and walking along DIRECTIONS[dir] to the
     * board edge, recording each cell's index in pos[.][.][dir].
     */
    private void appendLine(int x, int y, int dir) {
        boardStr[len++] = '3';  // separator so patterns never span two lines
        int dx = DIRECTIONS[dir][0];
        int dy = DIRECTIONS[dir][1];
        for (; x >= 0 && x < N && y >= 0 && y < N; x += dx, y += dy) {
            boardStr[len] = cellChar(g[x][y]);
            pos[x][y][dir] = len++;
        }
    }

    /** Serializes every row, column, "\" diagonal and "/" diagonal of g into boardStr. */
    private void chessBoard2String() {
        len = 0;
        for (int i = 0; i < N; i++) {
            appendLine(i, 0, 0);        // rows
        }
        for (int j = 0; j < N; j++) {
            appendLine(0, j, 1);        // columns
        }
        for (int k = 0; k < N; k++) {
            appendLine(0, k, 2);        // "\" diagonals starting on the top edge
        }
        for (int k = 1; k < N; k++) {
            appendLine(k, 0, 2);        // "\" diagonals starting on the left edge
        }
        for (int k = 0; k < N; k++) {
            appendLine(0, k, 3);        // "/" diagonals starting on the top edge
        }
        for (int k = 1; k < N; k++) {
            appendLine(k, N - 1, 3);    // "/" diagonals starting on the right edge
        }
    }

    /** Inserts pattern {@code s} with weight {@code score} into the trie. */
    private void insert(String s, int score) {
        int p = 0;
        for (int i = 0; i < s.length(); i++) {
            int t = s.charAt(i) - '0';
            if (tr[p][t] == 0) {
                tr[p][t] = ++idx;
            }
            p = tr[p][t];
        }
        cnt[p]++;   // duplicate patterns are counted once per occurrence in the tables
        w[p] = score;
    }

    /** BFS over the trie computing suffix links and completing the goto table. */
    private void build() {
        int[] q = new int[idx + 1];     // every non-root node is enqueued exactly once
        int hh = 0, tt = 0;
        for (int c = 0; c < 4; c++) {
            if (tr[0][c] != 0) {
                q[tt++] = tr[0][c];
            }
        }
        while (hh < tt) {
            int t = q[hh++];
            for (int c = 0; c < 4; c++) {
                int child = tr[t][c];
                if (child == 0) {
                    tr[t][c] = tr[ne[t]][c];
                } else {
                    ne[child] = tr[ne[t]][c];
                    q[tt++] = child;
                }
            }
        }
    }

    /** Inserts all pattern groups with their weights and builds the automaton. */
    private void init() {
        for (int k = 0; k < PATTERN_GROUPS.length; k++) {
            for (String s : PATTERN_GROUPS[k]) {
                insert(s, GROUP_SCORES[k]);
            }
        }
        build();
    }

    /** A board point, optionally carrying a score; ordered by ascending score. */
    public class Point implements Comparable<Point> {
        public int x;
        public int y;
        public int val;

        public Point(int x, int y) {
            this.x = x;
            this.y = y;
        }

        public Point(int x, int y, int val) {
            this(x, y);
            this.val = val;
        }

        @Override
        public int compareTo(Point o) {
            return Integer.compare(this.val, o.val);
        }

        @Override
        public String toString() {
            return "(" + x + ", " + y + ", " + val + ")";
        }
    }
}

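注:原代码在 / 斜线方向上有时判断不出候选点。原因是 isValid() 中反斜线(/)方向的循环写错:j 从 y+L 开始,却要求 j < y-L+1,因此该方向的邻域检查实际上从不执行。只有当横、竖或 \ 方向上恰好也有子时,该点才会被选为候选点。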