极大极小搜索+αβ剪枝+算杀的五子棋AI代码
import java.util.*;
public class Solution {
private static final int N = 15; // 棋盘大小
private static final int INF = 2000000000;
private static final int BLANK = 3; // g[i][j] == BLANK,棋盘在该位置为空
private final int D = 6; // 递归的最大深度
private final int L = 3; // 在有子的周围扩展L个长度作为候选点,太远的就不考虑
private int[][] g; // 棋盘
private int me, he; // 我和对方的棋子颜色:0或是1
private final int[][] scoreMatrix = new int[N][N]; // 得分数组
private final int[][][] pos = new int[N][N][4]; // 构建boardStr的位置映射数组
public char[] boardStr = new char[1000]; // 棋盘转化的字符串
private int len = 0; // 棋盘转化的字符串的长度
// 以下是 AC自动机的数组,210 是打表得出的
private final int[][] tr = new int[210][4];
private final int[] cnt = new int[210];
private final int[] w = new int[210]; // 节点权值
private final int[] ne = new int[210];
// 构建 trie 时使用的下标
private int idx = 0;
// 常量
private final static String[] WIN_AI = {"11111"};
private final static String[] WIN_PLAYER = {"22222"};
private final static String[] HUO4_AI = {"011110"};
private final static String[] HUO4_PLAYER = {"022220"};
private final static String[] CHONG4_AI = {"011112", "211110", "10111", "11011", "11101"};
private final static String[] CHONG4_PLAYER = {"022221", "122220", "20222", "22022", "22202"};
private final static String[] HUO3_AI = {"001110", "011100", "010110", "011010"};
private final static String[] HUO3_PLAYER = {"002220", "022200", "020220", "022020"};
private final static String[] MIAN3_AI = {"001112", "010112", "011012", "011102", "211100", "211010", "210110", "201110", "00111", "10011", "10101", "10110", "01011", "10011", "11001", "11010", "01101", "10101", "11001", "11100",};
private final static String[] MIAN3_PLAYER = {"002221", "020221", "022021", "022201", "122200", "122020", "120220", "102220", "00222", "20022", "20202", "20220", "02022", "20022", "22002", "22020", "02202", "20202", "22002", "22200",};
private final static String[] HUO2_AI = {"000110", "001010", "001100", "001100", "010100", "011000", "000110", "010010", "010100", "001010", "010010", "011000",};
private final static String[] HUO2_PLAYER = {"000220", "002020", "002200", "002200", "020200", "022000", "000220", "020020", "020200", "002020", "020020", "022000",};
private final static String[] MIAN2_AI = {"000112", "001012", "010012", "10001", "2010102", "2011002", "211000", "210100", "210010", "2001102"};
private final static String[] MIAN2_PLAYER = {"000221", "002021", "020021", "20002", "1020201", "1022001", "122000", "120200", "120020", "1002201"};
private final static String[] OL1_AI = {"1"};
private final static String[] OL1_PLAYER = {"2"};
private List<Integer> ans = new ArrayList<>(2);
// 估价函数:永远估计当前局面对于我方的价值
// 如果这一步是我在下,那么我会选择这个值最大的局面
// 如果这一步轮到对方,那么对方会选择这个值最小的局面(因为要对我方不利,对对方就是有利的)
public int evaluate() {
// AC自动机计算得分
int score = 0;
for (int i = 0, j = 0; i < len; i++) {
int t = boardStr[i] - '0';
j = tr[j][t];
int p = j;
while (p != 0) {
score += cnt[p] * w[p];
// 如果已经有一方获胜了,就不用再往下找了
if (w[p] == INF)
return INF;
if (w[p] == -INF)
return -INF;
p = ne[p];
}
}
return score;
}
// 启发式搜索:计算候选点
private List<Point> getCandidates(int who) {
List<Point> points = new ArrayList<>();
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
if (!isValid(i, j))
continue;
// 设置这里的棋子
if (who == me) {
for (int k = 0; k < 4; k++)
boardStr[pos[i][j][k]] = '1';
} else {
for (int k = 0; k < 4; k++)
boardStr[pos[i][j][k]] = '2';
}
int res = evaluate();
// 恢复这里的位置为空
for (int k = 0; k < 4; k++)
boardStr[pos[i][j][k]] = '0';
points.add(new Point(i, j, res));
}
}
Collections.sort(points);
if (who == me) { // 如果轮到我落子,那么选得分最高的点
Collections.reverse(points);
return points.subList(0, Math.min(13, points.size()));
}
return points.subList(0, Math.min(13, points.size()));
}
// 递归求解:返回当前局面的得分
public int dfs(int depth, int alpha, int beta) {
// 如果已经成杀,直接返回
int t = evaluate();
if (t == -INF || t == INF)
return t;
if (depth == 0)
return evaluate();
if (depth % 2 == 0) { // 极大层,更新alpha,返回alpha
List<Point> candidates = getCandidates(me);
for (Point p : candidates) {
int x = p.x, y = p.y;
// 设置这里的棋子:我永远是'1',对面永远是'2'
for (int i = 0; i < 4; i++) {
boardStr[pos[x][y][i]] = '1';
}
// 递归
int res = dfs(depth - 1, alpha, beta);
// 回溯时恢复现场
for (int i = 0; i < 4; i++) {
boardStr[pos[x][y][i]] = '0';
}
if (res >= alpha) {
alpha = res;
if (depth == D)
scoreMatrix[x][y] = alpha;
}
if (alpha >= beta)
return alpha; // 剪枝
}
return alpha;
} else { // 极小层,更新beta,返回beta
List<Point> candidates = getCandidates(he);
for (Point p : candidates) {
int x = p.x, y = p.y;
// 设置这里的棋子:我还是对方(也就是极大还是极小)
for (int i = 0; i < 4; i++) {
boardStr[pos[x][y][i]] = '2';
}
// 递归
int res = dfs(depth - 1, alpha, beta);
// 回溯时恢复现场
for (int i = 0; i < 4; i++) {
boardStr[pos[x][y][i]] = '0';
}
beta = Math.min(beta, res);
if (alpha >= beta)
return beta; // 剪枝
}
return beta;
}
}
private Point getBestNextStep(int score) {
List<Point> points = new ArrayList<>();
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++)
if (scoreMatrix[i][j] == score)
points.add(new Point(i, j));
}
int mx = Integer.MIN_VALUE;
Point ans = null;
for (Point point : points) {
for (int i = 0; i < 4; i++)
boardStr[pos[point.x][point.y][i]] = '1';
int res = evaluate();
if (res > mx) {
mx = res;
ans = point;
}
for (int i = 0; i < 4; i++) // 还原棋盘
boardStr[pos[point.x][point.y][i]] = '0';
}
return ans;
}
// 寻找杀棋
private boolean getKill() {
List<Point> candidates = getCandidates(me);
if (candidates.get(0).val == INF) {
ans.addAll(Arrays.asList(candidates.get(0).x, candidates.get(0).y));
return true;
}
candidates = getCandidates(he);
if (candidates.get(0).val == -INF) {
ans.addAll(Arrays.asList(candidates.get(0).x, candidates.get(0).y));
return true;
}
return false;
}
/*
* 入口方法
* */
public List<Integer> getNextPoint(int[][] g, int me, int he) {
this.g = g;
this.me = me;
this.he = he;
// 判断棋盘是否为空,也就是是否是第一步走棋
int cnt = 0;
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++)
if (g[i][j] == BLANK)
cnt++;
}
// 如果AI是第一步走棋
if (cnt == N * N) {
if (me == 0) { // 黑棋,直接下在天元
ans.addAll(Arrays.asList(7, 7));
return ans;
} else { // 白棋,在最中间的9个位置随机选择一个位置
Random random = new Random();
for (int i = 0; i < 10; i++) {
int x = random.nextInt(3) + 6;
int y = random.nextInt(3) + 6;
if (g[x][y] == BLANK) {
ans.addAll(Arrays.asList(x, y));
return ans;
}
}
}
}
// 初始化trie树
init();
// 只需要在最开始将棋盘转化为字符串就可以
chessBoard2String();
// System.out.println(boardStr);
// System.out.println("start score: " + evaluate());
// 寻找杀棋
if (getKill())
return ans;
// 递归寻找最好的棋
int res = dfs(D, Integer.MIN_VALUE, Integer.MAX_VALUE);
// System.out.println("final score: " + res);
Point nextPoint = getBestNextStep(res);
System.out.println(nextPoint);
ans.addAll(Arrays.asList(nextPoint.x, nextPoint.y));
return ans;
}
// 控制落子点在已经落子的位置周围(2个单位)
private boolean isValid(int x, int y) {
if (g[x][y] != BLANK)
return false;
// 横向:——
for (int j = Math.max(0, y - L); j < Math.min(N, y + L + 1); j++) {
if (g[x][j] != BLANK)
return true;
}
// 纵向:|
for (int i = Math.max(0, x - L); i < Math.min(N, x + L + 1); i++) {
if (g[i][y] != BLANK)
return true;
}
// 正斜线:\
for (int i = Math.max(0, x - L), j = Math.max(0, y - L); i < Math.min(N, x + L + 1) && j < Math.min(N, y + L + 1); i++, j++) {
if (g[i][j] != BLANK)
return true;
}
// 反斜线:/
for (int i = Math.max(0, x - L), j = Math.min(N, y + L); i < Math.min(N, x + L + 1) && j < Math.max(0, y - L + 1); i++, j--) {
if (g[i][j] != BLANK)
return true;
}
return false;
}
// 将棋盘转化成一个字符串,不同行列之间用'3'相连,避免对字符串匹配做出影响
private void chessBoard2String() {
// 所有横向:——
for (int i = 0; i < N; i++) {
boardStr[len++] = '3';
for (int j = 0; j < N; j++) {
if (g[i][j] == me)
boardStr[len++] = '1';
else if (g[i][j] == he)
boardStr[len++] = '2';
else boardStr[len++] = '0';
pos[i][j][0] = len - 1;
}
}
// 所有纵向:|
for (int j = 0; j < N; j++) {
boardStr[len++] = '3';
for (int i = 0; i < N; i++) {
if (g[i][j] == me)
boardStr[len++] = '1';
else if (g[i][j] == he)
boardStr[len++] = '2';
else boardStr[len++] = '0';
pos[i][j][1] = len - 1;
}
}
// 所有正斜线:\
for (int k = 0; k < N; k++) { // 枚举起点
boardStr[len++] = '3';
for (int i = 0, j = k; i < N && j < N; i++, j++) {
if (g[i][j] == me)
boardStr[len++] = '1';
else if (g[i][j] == he)
boardStr[len++] = '2';
else boardStr[len++] = '0';
pos[i][j][2] = len - 1;
}
}
for (int k = 1; k < N; k++) {
boardStr[len++] = '3';
for (int i = k, j = 0; i < N && j < N; i++, j++) {
if (g[i][j] == me)
boardStr[len++] = '1';
else if (g[i][j] == he)
boardStr[len++] = '2';
else boardStr[len++] = '0';
pos[i][j][2] = len - 1;
}
}
// 所有反斜线:/
for (int k = 0; k < N; k++) {
boardStr[len++] = '3';
for (int i = 0, j = k; i < N && j >= 0; i++, j--) {
if (g[i][j] == me)
boardStr[len++] = '1';
else if (g[i][j] == he)
boardStr[len++] = '2';
else boardStr[len++] = '0';
pos[i][j][3] = len - 1;
}
}
for (int k = 1; k < N; k++) {
boardStr[len++] = '3';
for (int i = k, j = N - 1; i < N && j >= 0; i++, j--) {
if (g[i][j] == me)
boardStr[len++] = '1';
else if (g[i][j] == he)
boardStr[len++] = '2';
else boardStr[len++] = '0';
pos[i][j][3] = len - 1;
}
}
}
// 将字符串 s 插入到 trie 中
private void insert(String s, int score) {
int p = 0;
for (int i = 0; i < s.length(); i++) {
int t = s.charAt(i) - '0';
if (tr[p][t] == 0)
tr[p][t] = ++idx;
p = tr[p][t];
}
cnt[p]++;
w[p] = score;
}
// 构建 ne[] 数组
private void build() {
int[] q = new int[10000];
int hh = 0, tt = 0;
// 将根节点所有存在的子节点加入到队列中
q[tt++] = tr[0][0];
q[tt++] = tr[0][1];
q[tt++] = tr[0][2];
// 宽搜
while (hh < tt) {
int t = q[hh++];
for (int i = 0; i < 4; i++) {
int c = tr[t][i];
if (c == 0)
tr[t][i] = tr[ne[t]][i];
else {
ne[c] = tr[ne[t]][i];
q[tt++] = c;
}
}
}
}
// 初始化trie树和调用build()
private void init() {
// 构建 trie 树
// 我方积分
for (String s : WIN_AI) {
insert(s, INF);
}
for (String s : HUO4_AI) {
insert(s, 50000);
}
for (String s : CHONG4_AI) {
insert(s, 400);
}
for (String s : HUO3_AI) {
insert(s, 400);
}
for (String s : MIAN3_AI) {
insert(s, 20);
}
for (String s : HUO2_AI) {
insert(s, 20);
}
for (String s : MIAN2_AI) {
insert(s, 1);
}
for (String s : OL1_AI) {
insert(s, 1);
}
// 对方积分
for (String s : WIN_PLAYER) {
insert(s, -INF);
}
for (String s : HUO4_PLAYER) {
insert(s, -100000);
}
for (String s : CHONG4_PLAYER) {
insert(s, -100000);
}
for (String s : HUO3_PLAYER) {
insert(s, -8000);
}
for (String s : MIAN3_PLAYER) {
insert(s, -50);
}
for (String s : HUO2_PLAYER) {
insert(s, -50);
}
for (String s : MIAN2_PLAYER) {
insert(s, -50);
}
for (String s : OL1_PLAYER) {
insert(s, -3);
}
// 构建 ne[] 数组
build();
}
public class Point implements Comparable<Point> {
public int x;
public int y;
public int val;
public Point(int x, int y) {
this.x = x;
this.y = y;
}
public Point(int x, int y, int val) {
this(x, y);
this.val = val;
}
@Override
public int compareTo(Point o) { // 默认升序排序
if (this.val > o.val)
return 1;
else if (this.val < o.val)
return -1;
return 0;
}
}
}
代码在 / 斜线情况下有时候会判断不出来