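# 强化学习PPO算法解决：限制取出物品数量的01背包问题
# (Solving a 0/1 knapsack problem with a cap on the number of items taken,
#  using the PPO reinforcement-learning algorithm.)
#
# 150 阅读1分钟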
import gymnasium as gym  # 导入gym
from gymnasium import Env
from gymnasium.spaces import Discrete, Box, Dict, Tuple, MultiBinary, MultiDiscrete
import numpy as np
import random
import os
from stable_baselines3 import PPO, DQN ,A2C
from stable_baselines3.common.vec_env import VecFrameStack  # 堆叠操作,提高训练效率
from stable_baselines3.common.evaluation import evaluate_policy

# Path handed to `model.save` and `PPO.load` for the trained policy.
model_name: str = "PPO3"
# Environment definition.
class KnapsackEnv(Env):
    """0/1 knapsack with a cap on how many items may be taken, as a Gymnasium env.

    Items are visited one at a time in index order. At each step the agent
    either takes (action 1) or skips (action 0) the current item. The episode
    terminates once every item has been considered, or once ``limit`` items
    have actually been put in the knapsack.

    Rewards:
        * -100 when the agent tries to take an item that would push the total
          weight over ``capacity`` (the item is then left out of the knapsack);
        * ``current_value * 100`` added on the step that ends the episode;
        * 0 otherwise.

    Observation (dict):
        * ``"select_id"``: shape ``(1,)`` int64, index of the item to decide on
          next (equals ``n_items`` once all items have been visited);
        * ``"selected"``: shape ``(n_items,)`` int64 0/1 mask of the items
          currently in the knapsack.

    Args:
        weights: Per-item weights, indexed by item id.
        values: Per-item values, indexed by item id.
        capacity: Maximum total weight the knapsack may hold.
        limit: Maximum number of items that may be taken.
    """

    def __init__(self, weights, values, capacity, limit):
        super().__init__()
        self.weights = weights
        self.values = values
        self.capacity = capacity
        self.n_items = len(weights)
        self.limit_num = limit

        self.current_weight = 0
        self.current_value = 0
        self.current_index = 0  # mirrors state["select_id"][0]

        self.action_space = Discrete(2)  # 0: skip the current item, 1: take it
        self.observation_space = gym.spaces.Dict({
            "select_id": gym.spaces.Box(0, self.n_items, (1,), np.int64),
            "selected": gym.spaces.Box(0, 1, (self.n_items,), np.int64),
        })

    def _get_obs(self):
        """Return a copy of ``self.state`` so callers never alias the live arrays."""
        return {key: arr.copy() for key, arr in self.state.items()}

    def reset(self, seed=None, options=None):
        """Start a new episode at item 0 with an empty knapsack.

        Args:
            seed: Optional RNG seed, forwarded to ``Env.reset`` so that
                ``self.np_random`` is (re)seeded as Gymnasium expects.
            options: Unused; accepted for Gymnasium API compatibility.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)
        self.current_weight = 0
        self.current_value = 0
        self.current_index = 0
        # Explicit int64 so the arrays match the declared observation_space
        # dtype on every platform (plain `int` is int32 on Windows).
        self.state = {
            "select_id": np.zeros(1, dtype=np.int64),
            "selected": np.zeros(self.n_items, dtype=np.int64),
        }
        return self._get_obs(), {}

    def step(self, action):
        """Decide on the current item and advance to the next one.

        Args:
            action: 1 to take the current item, 0 to skip it. A numpy scalar or
                0-d array (as returned by ``model.predict``) is accepted.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``truncated`` is always False and ``info`` is an empty dict.
        """
        reward = 0
        truncated = False
        select_id = int(self.state["select_id"][0])

        if action == 1:
            if self.current_weight + self.weights[select_id] > self.capacity:
                # Over capacity: penalise, but do NOT mark the item as selected,
                # so the "selected" mask and the item-count limit only reflect
                # items that are really in the knapsack.
                reward = -100
            else:
                self.current_weight += self.weights[select_id]
                self.current_value += self.values[select_id]
                self.state["selected"][select_id] = 1
        # Skipped items keep their 0 entry from reset(); each item is visited once.

        select_id += 1
        self.current_index = select_id
        # `>=` (not `==`) so a limit of 0 is honoured as well.
        terminated = (
            select_id >= self.n_items
            or int(self.state["selected"].sum()) >= self.limit_num
        )
        if terminated:
            # Add rather than overwrite, so an over-capacity penalty incurred on
            # the final step is not silently discarded.
            reward += self.current_value * 100

        self.state["select_id"][0] = select_id
        return self._get_obs(), reward, terminated, truncated, {}

    def render(self, mode='human'):
        """No-op: this environment has no visual representation."""
        pass

# Toy instance: 4 items (weights, values), capacity 9, at most 2 items taken.
env = KnapsackEnv([1, 4, 6, 4], [3, 4, 10, 10], 9, 2)

# Train a PPO agent on the environment and persist it under `model_name`.
model = PPO("MultiInputPolicy", env, verbose=1)
model.learn(total_timesteps=200000)
model.save(model_name)

# Reload the saved policy and roll it out for a few evaluation episodes.
model = PPO.load(model_name)
episodes = 20
for episode in range(1, episodes + 1):
    obs, _ = env.reset()
    done = False
    truncated = False
    step = 0
    # Gymnasium episodes end on either `terminated` or `truncated`.
    while not (done or truncated):
        env.render()
        # Greedy action for evaluation; the default samples from the policy.
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = env.step(action)
        step += 1
    # "Score" is the total value of the items packed in this episode.
    print('Episode:{} Score:{} Step:{}'.format(episode, env.current_value, step))