Sunday, 18 February 2018

Reinforcement learning on "cart pole"

First, install Python and Keras on your computer.
Installing Keras on Linux

Install the required packages:
$ python3.6 -m pip install keras-rl
$ python3.6 -m pip install gym
$ python3.6 -m pip install gym[all]
$ sudo apt-get install libav-tools
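
Before going further, it can help to confirm that the packages are visible to the interpreter you will use to run the script. The following is a minimal sketch using only the standard library. The list of module names is an assumption based on the imports in the script further down ("rl" is the module that keras-rl provides), and find_missing is a helper name made up for this example.

```python
import importlib.util

# Top-level modules the training script imports.
# Assumption: "rl" is the module installed by the keras-rl package.
REQUIRED_MODULES = ["numpy", "keras", "rl", "gym"]


def find_missing(modules):
    """Return the module names that cannot be located in this environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = find_missing(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules were found.")
```

Run it with the same interpreter as the pip commands above (python3.6 here), since packages installed for one Python version are not visible to another.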

Create a text file and name it "test.py". (If the file extension is not visible, change your settings to show it.) Copy and paste the following into the text file:
import numpy as np
import gym

from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam
import os.path
from rl.agents.dqn import DQNAgent
from rl.policy import BoltzmannQPolicy
from rl.memory import SequentialMemory


ENV_NAME = 'CartPole-v0'


# Get the environment and extract the number of actions.
env = gym.make(ENV_NAME)
np.random.seed(123)
env.seed(123)
nb_actions = env.action_space.n
weights_filename = 'dqn_{}_weights.h5f'.format(ENV_NAME)
# Next, we build a very simple model.
model = Sequential()
model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(nb_actions))
model.add(Activation('linear'))
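# Reuse weights saved by a previous run, if any, so training resumes instead of restarting.
if os.path.isfile(weights_filename):
    print('\n\n\n\nSaved parameters found. I will use this file...\n'+ weights_filename +'\n\n\n\n')
    model.load_weights(weights_filename)
else:
    print('\n\n\n\nSaved parameters not found. Creating new ones...\n\n\n\n')
# summary() prints the table itself; wrapping it in print() would also print "None".
model.summary()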

# Finally, we configure and compile our agent. You can use every built-in Keras optimizer and
# even the metrics!
memory = SequentialMemory(limit=50000, window_length=1)
policy = BoltzmannQPolicy()
dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=10,
               target_model_update=1e-2, policy=policy)
dqn.compile(Adam(lr=1e-3), metrics=['mae'])

# Okay, now it's time to learn something! We visualize the training here for show, but this
# slows down training quite a lot. You can always safely abort the training prematurely using
# Ctrl + C.
dqn.fit(env, nb_steps=50000, visualize=True, verbose=2)

# After training is done, we save the final weights.
dqn.save_weights(weights_filename, overwrite=True)

# Finally, evaluate our algorithm for 5 episodes.
dqn.test(env, nb_episodes=5, visualize=True)
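
The agent in the script is created with target_model_update=1e-2. As far as I understand keras-rl, a value below 1 is used as a soft update rate: after each training step, the target network's weights move a small fraction of the way toward the online network's weights. The sketch below illustrates that rule on plain Python lists. It is not the library's code; soft_update and the example numbers are made up for illustration.

```python
def soft_update(target_weights, online_weights, tau):
    """Move each target weight a fraction tau toward the matching online weight."""
    return [tau * w + (1.0 - tau) * t for t, w in zip(target_weights, online_weights)]


# Illustrative values only: two "weights" per network.
target = [0.0, 0.0]
online = [1.0, -2.0]

# Apply the update 100 times with the same rate as in the script (1e-2).
for _ in range(100):
    target = soft_update(target, online, tau=1e-2)

print(target)
```

With a rate this small, the target lags well behind the online weights even after 100 updates, which is the point: the values the agent learns toward change slowly, and that tends to stabilise training.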

Save the file and run it:
$ python3.6 test.py
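
The first thing the script prints is the model summary, including the number of trainable parameters. The sketch below works that number out by hand so you can check it against the output. It assumes that CartPole-v0 observations have 4 values and that there are 2 actions (push left, push right); dense_params is a helper name made up for this example.

```python
def dense_params(n_inputs, n_units):
    """Parameters of a fully connected layer: one weight per input per unit, plus one bias per unit."""
    return n_inputs * n_units + n_units


obs_size = 4      # assumption: CartPole-v0 observation has 4 values
nb_actions = 2    # assumption: CartPole-v0 has 2 discrete actions

# Three hidden Dense(16) layers followed by the Dense(nb_actions) output layer.
layer_sizes = [16, 16, 16, nb_actions]

total = 0
n_inputs = obs_size  # Flatten turns the (1, 4) input into 4 values and adds no parameters
for n_units in layer_sizes:
    total += dense_params(n_inputs, n_units)
    n_inputs = n_units

print(total)  # 80 + 272 + 272 + 34 = 658
```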

Reinforcement learning on "cart pole" then begins. Because the script passes visualize=True, a window shows the cart and pole while training runs.
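
During training, the script chooses actions with BoltzmannQPolicy. The idea behind Boltzmann exploration is to turn the network's Q-values into probabilities with a softmax and then sample an action, so better-looking actions are picked more often while the others still get tried. The sketch below shows that idea with the standard library only. It is not keras-rl's implementation, and boltzmann_probs, sample_action and the tau parameter name are chosen here for illustration.

```python
import math
import random


def boltzmann_probs(q_values, tau=1.0):
    """Softmax over Q-values; tau is the temperature (higher means more random)."""
    scaled = [q / tau for q in q_values]
    highest = max(scaled)
    # Subtracting the maximum avoids overflow without changing the result.
    exps = [math.exp(s - highest) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_action(q_values, tau=1.0, rng=random):
    """Draw an action index with probability given by boltzmann_probs."""
    probs = boltzmann_probs(q_values, tau)
    return rng.choices(range(len(q_values)), weights=probs, k=1)[0]


if __name__ == "__main__":
    q = [1.0, 2.0]  # made-up Q-values for the two cart-pole actions
    print(boltzmann_probs(q))
    print(sample_action(q))
```

With equal Q-values both actions are equally likely. As the gap between them grows, or as tau shrinks, the policy behaves more and more like always picking the best action.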

Other examples:
https://github.com/matthiasplappert/keras-rl/tree/master/examples


References

matthiasplappert/keras-rl examples (cartpole)
https://github.com/matthiasplappert/keras-rl/tree/master/examples
