Introduzione all’apprendimento per rinforzo. Parte 3: Q-Learning con Reti Neurali, Algoritmo DQN

Puoi eseguire tu stesso il codice TensorFlow in questo link (o una versione PyTorch in questo link), o continuare a leggere per vedere il codice senza eseguirlo. Poiché il codice è un po’ più lungo che nelle parti precedenti, mostrerò solo i pezzi più importanti qui. L’intero codice sorgente è disponibile seguendo il link sopra.

Ecco l’ambiente CartPole. Sto usando OpenAI Gym per visualizzare ed eseguire questo ambiente. L’obiettivo è quello di spostare il carrello a sinistra e a destra, al fine di mantenere il palo in una posizione verticale. Se l’inclinazione del palo è più di 15 gradi dall’asse verticale, l’episodio finirà e ricominceremo da capo. Il video 1 mostra un esempio di esecuzione di diversi episodi in questo ambiente compiendo azioni in modo casuale.

Video 1: Gioco casuale in ambiente CartPole.

Per implementare l’algoritmo DQN, inizieremo creando le DNN principale (main_nn) e target (target_nn). La rete di destinazione sarà una copia di quella principale, ma con una propria copia dei pesi. Avremo anche bisogno di un ottimizzatore e di una funzione di perdita.

Algoritmo 1. L’architettura delle Reti Neurali Profonde.

In seguito, creeremo l’experience replay buffer, per aggiungere l’esperienza al buffer e campionarla successivamente per l’allenamento.

Algoritmo 2. Experience Replay Buffer.

Scriveremo anche una funzione di aiuto per eseguire la politica ε-greedy, e per addestrare la rete principale usando i dati memorizzati nel buffer.

Algoritmo 3. Funzioni per la politica ε-greedy e per addestrare la rete neurale.

Definiremo anche gli iper-parametri necessari e addestreremo la rete neurale. Giocheremo un episodio usando la politica ε-greedy, memorizzeremo i dati nel buffer di riproduzione dell’esperienza, e addestreremo la rete principale dopo ogni passo. Una volta ogni 2000 passi, copieremo i pesi dalla rete principale nella rete target. Diminuiremo anche il valore di epsilon (ε) per iniziare con un’alta esplorazione e diminuire l’esplorazione nel tempo. Vedremo come l’algoritmo inizia ad imparare dopo ogni episodio.

Algoritmo 4. Main loop.

Questo è il risultato che verrà visualizzato:

Episode 0/1000. Epsilon: 0.99. Reward in last 100 episodes: 14.0 Episode 50/1000. Epsilon: 0.94. Reward in last 100 episodes: 22.2 Episode 100/1000. Epsilon: 0.89. Reward in last 100 episodes: 23.3 Episode 150/1000. Epsilon: 0.84. Reward in last 100 episodes: 23.4 Episode 200/1000. Epsilon: 0.79. Reward in last 100 episodes: 24.9 Episode 250/1000. Epsilon: 0.74. Reward in last 100 episodes: 30.4 Episode 300/1000. Epsilon: 0.69. Reward in last 100 episodes: 38.4 Episode 350/1000. Epsilon: 0.64. Reward in last 100 episodes: 51.4 Episode 400/1000. Epsilon: 0.59. Reward in last 100 episodes: 68.2 Episode 450/1000. Epsilon: 0.54. Reward in last 100 episodes: 82.4 Episode 500/1000. Epsilon: 0.49. Reward in last 100 episodes: 102.1 Episode 550/1000. Epsilon: 0.44. Reward in last 100 episodes: 129.7 Episode 600/1000. Epsilon: 0.39. Reward in last 100 episodes: 151.7 Episode 650/1000. Epsilon: 0.34. Reward in last 100 episodes: 173.0 Episode 700/1000. Epsilon: 0.29. Reward in last 100 episodes: 187.3 Episode 750/1000. Epsilon: 0.24. Reward in last 100 episodes: 190.9 Episode 800/1000. Epsilon: 0.19. Reward in last 100 episodes: 194.6 Episode 850/1000. Epsilon: 0.14. Reward in last 100 episodes: 195.9 Episode 900/1000. Epsilon: 0.09. Reward in last 100 episodes: 197.9 Episode 950/1000. Epsilon: 0.05. Reward in last 100 episodes: 200.0 Episode 1000/1000. Epsilon: 0.05. Reward in last 100 episodes: 200.0

Ora che l’agente ha imparato a massimizzare la ricompensa per l’ambiente CartPole, faremo interagire l’agente con l’ambiente un’altra volta, per visualizzare il risultato e vedere che ora è in grado di mantenere l’asta in equilibrio per 200 frame.

Video 2. Risultato della visualizzazione dell’agente DQN addestrato che interagisce con l’ambiente CartPole.

Puoi eseguire tu stesso il codice TensorFlow in questo link (o una versione PyTorch in questo link).

Lascia un commento

Il tuo indirizzo email non sarà pubblicato.