AI玩游戏的一点尝试(4)—— 数字识别

50 阅读4分钟

前言

AI玩游戏的一点尝试(1)—— 架构设计与初步状态识别

AI玩游戏的一点尝试(2)—— 初探无监督学习与特征可视化

AI玩游戏的一点尝试(3)—— 图片去重

数据预处理

首先要从游戏截图上裁剪出需要识别的数字区域,这里写一个脚本可视化创建模板区域并记录:

template_path = Path("data/template.json")
template_data = {}
if template_path.exists():
    with open(template_path, 'r', encoding='utf-8') as f:
        template_data = json.load(f)

image = cv2.imread("data/source/1/20250527_161056.png")

window_name = "Choose Area"
cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
cv2.resizeWindow(window_name, 540, 960)
r = cv2.selectROI(window_name, image)
cv2.destroyWindow(window_name)

name = input("请输入区域名称: ").strip()

if name in template_data:
    overwrite = input(f"区域 '{name}' 已存在,是否覆盖?(y/n): ").strip().lower()
    if overwrite != 'y':
        log.info("已取消操作")
        exit()

template_data[name] = [int(r[0]), int(r[1]), int(r[2]), int(r[3])]
with open(template_path, 'w', encoding='utf-8') as f:
    json.dump(template_data, f, ensure_ascii=False, indent=2)
log.info(f"已保存区域 '{name}' 的模板信息")

image.png

创建模板信息后,需要对已经在养成状态界面的图片进行处理,批量生成截取后的数字图片:

def extract_and_preprocess_digits(img_array, template_data, target_width, target_height):
    digit_images = []
    attributes = ['速度', '耐力', '力量', '毅力', '智力', '技能点数']
    img = Image.fromarray(img_array) 
    for attribute in attributes:
        x, y, width, height = template_data[attribute]
        y -= 65
        cropped_img = img.crop((x, y, x + width, y + height))
        resized_img = cropped_img.resize((target_width, target_height), Image.Resampling.LANCZOS)
        processed_img_array = np.array(resized_img)
        digit_images.append((attribute, processed_img_array))
    return digit_images

但是只是生成图片还不够,我们还需要对数据进行标注,这里先使用ocr对图片进行识别,如果ocr识别失败再手动输入:

try:
    ocr_result = int(ocr.ocr.ocr_for_single_line(np.array(cropped_img))['text'].replace("o", "0"))
    if ocr_result <= 2000:
        digit = ocr_result
    else:
        cv2.imshow('Image', cv2.cvtColor(np.array(cropped_img), cv2.COLOR_RGB2BGR))
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        digit = int(input(f"识别到的数字 {ocr_result} 异常,请手动输入: "))
except Exception as e:
    logger.error(f"识别过程中出现异常: {e}")
    cv2.imshow('Image', cv2.cvtColor(np.array(cropped_img), cv2.COLOR_RGB2BGR))
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    digit = int(input("识别过程中出现异常,请手动输入: "))

output_path = os.path.join(output_dir, f"{digit}_{base_name}_{attribute}.png")
Image.fromarray(resized_img_array).save(output_path, 'PNG')

识别结果直接放在文件名前缀可以很清晰的看到哪个数字识别有问题,手工验证一遍后数据集就准备完成了。

image.png

模型准备

这个模型只计划用来识别养成过程中的数字部分,而养成数字都不会超过4位数,经过AI介绍选择定长识别+空位填充的方式:数字定长4位,每一位有11个输出节点,分别代表0-9和空。

for filename in os.listdir(image_dir):
    if not filename.endswith('.png'):
        continue
    label_str = filename.split('_')[0]
    label_digits = [int(c) for c in label_str if c.isdigit()]
    if len(label_digits) > max_digits:
        label_digits = label_digits[:max_digits]
    label_digits += [10] * (max_digits - len(label_digits))
    self.images.append(os.path.join(image_dir, filename))
    self.labels.append(label_digits)

模型使用卷积+全连接:

class DigitClassifier(nn.Module):
    def __init__(self, max_digits=4):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.max_digits = max_digits
        # 计算池化后特征图尺寸
        # 输入尺寸: (H, W) = (DIGIT_TARGET_HEIGHT, DIGIT_TARGET_WIDTH)
        # 经过3次2x2池化,高宽各除以8
        feature_h = DIGIT_TARGET_HEIGHT // 8
        feature_w = DIGIT_TARGET_WIDTH // 8
        self.fc1 = nn.Linear(256 * feature_h * feature_w, 256)
        self.fc2 = nn.Linear(256, max_digits * 11)
    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        x = x.view(-1, self.max_digits, 11)
        return x

为了增强模型的泛化能力,对数据集transform进行处理(一开始没有灰度和标准化操作,后来加上效果更好):

# 定义训练集的数据增强变换
train_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((DIGIT_TARGET_HEIGHT, DIGIT_TARGET_WIDTH)),
    transforms.RandomAffine(degrees=5, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=DIGIT_NORMALIZE_MEAN, std=DIGIT_NORMALIZE_STD),
])
# 定义验证集的变换(只进行Resize和ToTensor)
val_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((DIGIT_TARGET_HEIGHT, DIGIT_TARGET_WIDTH)),
    transforms.ToTensor(),
    transforms.Normalize(mean=DIGIT_NORMALIZE_MEAN, std=DIGIT_NORMALIZE_STD),
])

一开始使用40epoch进行训练,全部训练完成后发现过拟合:训练集和验证集正确率都很高,但提供数据集之外的图片识别经常出错。这里又写了一个随机识别单张图的脚本:

while True:
    image_path = random.choice(all_digit_images)
    logger.info(f"随机选择数字图片进行处理: {image_path}")

    img = Image.open(image_path).convert('RGB')
    img_array = np.array(img)

    resized_img_array = img_array

    processed_img_pil = Image.fromarray(resized_img_array).convert('RGB')
    image_tensor = predict_transform(processed_img_pil).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(image_tensor)
        predicted_indices = outputs.argmax(dim=2).squeeze(0)

    predicted_digits_str = ""
    for digit_index in predicted_indices:
        if digit_index.item() != 10:
            predicted_digits_str += str(digit_index.item())
    cv2.imshow('Predicted Digit', cv2.cvtColor(resized_img_array, cv2.COLOR_RGB2BGR))
    logger.info(f"预测的数字: {predicted_digits_str}")
    cv2.waitKey(0)
    cv2.destroyAllWindows()

改为在训练集loss没有明显减少的时候停止训练,准确度基本符合预期:

image.png

接入游戏

在之前的数据采集循环中增加状态的判断并进行识别:

if state in [1, 3] and not saved_in_current_state:
    saved_in_current_state = True
    digit_predictions, processed_images = process_and_predict_digits(image, template_data, digit_model, digit_transform, device, max_digits)

    logger.info("识别到数字:")
    for attribute, result in digit_predictions.items():
        logger.info(f"  {attribute}: {result['value']}")

image.png

下一步

目前训练的数字都是比较易读的类型,下一步计划在此基础上继续训练另一种显示的数字识别。

image.png