Introducere la învățarea prin întărire. Partea a 3-a: Q-Learning with Neural Networks, Algoritm DQN

Puteți rula singur codul TensorFlow în acest link (sau o versiune PyTorch în acest link), sau continuați să citiți pentru a vedea codul fără a-l rula. Deoarece codul este un pic mai lung decât în părțile anterioare, voi arăta aici doar cele mai importante bucăți. Întregul cod sursă este disponibil urmând link-ul de mai sus.

Iată mediul CartPole. Folosesc OpenAI Gym pentru a vizualiza și rula acest mediu. Scopul este de a muta căruciorul în stânga și în dreapta, pentru a menține stâlpul în poziție verticală. Dacă înclinarea stâlpului este mai mare de 15 grade față de axa verticală, episodul se va încheia și o vom lua de la capăt. Videoclipul 1 prezintă un exemplu de derulare a mai multor episoade în acest mediu prin efectuarea de acțiuni aleatorii.

Video 1: Joc aleatoriu în mediul CartPole.

Pentru a implementa algoritmul DQN, vom începe prin a crea DNN-urile principal (main_nn) și țintă (target_nn). Rețeaua țintă va fi o copie a celei principale, dar cu propria sa copie a ponderilor. Vom avea nevoie, de asemenea, de un optimizator și de o funcție de pierdere.

Algoritm 1. Arhitectura rețelelor neuronale profunde.

În continuare, vom crea bufferul de reluare a experienței, pentru a adăuga experiența în buffer și a o eșantiona ulterior pentru instruire.

Algoritm 2. Bufferul de reluare a experienței.

De asemenea, vom scrie o funcție ajutătoare pentru a rula politica ε-greedy și pentru a antrena rețeaua principală folosind datele stocate în buffer.

Algoritm 3. Funcții pentru politica ε-greedy și pentru antrenarea rețelei neuronale.

De asemenea, vom defini hiperparametrii necesari și vom antrena rețeaua neuronală. Vom reda un episod folosind politica ε-greedy, vom stoca datele în bufferul de reluare a experienței și vom antrena rețeaua principală după fiecare etapă. O dată la fiecare 2000 de pași, vom copia ponderile din rețeaua principală în rețeaua țintă. De asemenea, vom scădea valoarea lui epsilon (ε) pentru a începe cu o explorare ridicată și pentru a diminua explorarea în timp. Vom vedea cum algoritmul începe să învețe după fiecare episod.

Algoritm 4. Bucla principală.

Acesta este rezultatul care va fi afișat:

Episode 0/1000. Epsilon: 0.99. Reward in last 100 episodes: 14.0 Episode 50/1000. Epsilon: 0.94. Reward in last 100 episodes: 22.2 Episode 100/1000. Epsilon: 0.89. Reward in last 100 episodes: 23.3 Episode 150/1000. Epsilon: 0.84. Reward in last 100 episodes: 23.4 Episode 200/1000. Epsilon: 0.79. Reward in last 100 episodes: 24.9 Episode 250/1000. Epsilon: 0.74. Reward in last 100 episodes: 30.4 Episode 300/1000. Epsilon: 0.69. Reward in last 100 episodes: 38.4 Episode 350/1000. Epsilon: 0.64. Reward in last 100 episodes: 51.4 Episode 400/1000. Epsilon: 0.59. Reward in last 100 episodes: 68.2 Episode 450/1000. Epsilon: 0.54. Reward in last 100 episodes: 82.4 Episode 500/1000. Epsilon: 0.49. Reward in last 100 episodes: 102.1 Episode 550/1000. Epsilon: 0.44. Reward in last 100 episodes: 129.7 Episode 600/1000. Epsilon: 0.39. Reward in last 100 episodes: 151.7 Episode 650/1000. Epsilon: 0.34. Reward in last 100 episodes: 173.0 Episode 700/1000. Epsilon: 0.29. Reward in last 100 episodes: 187.3 Episode 750/1000. Epsilon: 0.24. Reward in last 100 episodes: 190.9 Episode 800/1000. Epsilon: 0.19. Reward in last 100 episodes: 194.6 Episode 850/1000. Epsilon: 0.14. Reward in last 100 episodes: 195.9 Episode 900/1000. Epsilon: 0.09. Reward in last 100 episodes: 197.9 Episode 950/1000. Epsilon: 0.05. Reward in last 100 episodes: 200.0 Episode 1000/1000. Epsilon: 0.05. Reward in last 100 episodes: 200.0

Acum că agentul a învățat să maximizeze recompensa pentru mediul CartPole, vom face agentul să interacționeze cu mediul încă o dată, pentru a vizualiza rezultatul și a vedea că acum este capabil să mențină polul în echilibru timp de 200 de cadre.

Video 2. Rezultatul vizualizării agentului DQN antrenat care interacționează cu mediul CartPole.

Puteți rula singur codul TensorFlow în acest link (sau o versiune PyTorch în acest link).

.

Lasă un răspuns

Adresa ta de email nu va fi publicată.