目录
前言
本文将介绍Actor-Critic算法基本原理和实现步骤,以及基于pytorch框架的代码实现详解。
本文代码实现的编程环境如下:
编辑器:VS Code
编译器:python 3.9.10
依赖:
gym==0.26.2
gym-notices==0.0.8
gymnasium==0.29.1
matplotlib==3.8.3
matplotlib-inline==0.1.6
numpy==1.24.1
opencv-contrib-python==3.4.16.59
opencv-python==3.4.16.59
pygame==2.5.2
torch==2.1.2+cu121
torchaudio==2.1.2+cu121
torchvision==0.16.2+cu121
一、简介
Actor-Critic算法是一种强化学习算法,它结合了策略优化方法(如Policy Gradient方法)和值函数方法(如Q-Learning)的优点。
AC算法主要有两个部分:Actor和Critic
Actor负责选择动作。Actor网络的输入是状态信息,输出是所有动作的选择概率,随后便可以根据输出的动作选择概率进行随机抽取。
Critic负责评估Actor选择的动作的好坏。Critic网络的输入是状态信息,输出是价值。
通过训练这两个部分来让智能体完成工作便是这个算法的目的。
二、基本原理
2.1 核心步骤概述
AC算法的基本工作流程分为以下几步:
1、策略网络根据游戏环境状态信息给出动作概率。
2、根据动作概率随机抽取一个动作作为这一步要做出的动作。
3、游戏环境根据动作做出相应动作,使得游戏环境进入下一个状态。
4、计算策略损失和价值损失,回传梯度进行更新。

注:以上是最基本的也是最核心的步骤,其余细节后续再考虑。
2.2 核心步骤基本原理
2.2.1 决策动作
决策动作有两步,一是要通过网络获得动作概率分布,二是要根据概率分布抽取动作。
第一步需要一个决策用的神经网络,该网络可以自定义,也可以参考其他项目中采用的网络。且要注意输入输出的张量维度大小要与游戏环境要求的一致。记该网络为actor,当前状态为state。那么决策动作的过程就是将state当做输入放入网络actor当中运行一遍,即actor(state)。其返回的结果(输出),就是动作的概率分布。
第二步则是根据动作概率分布来进行随机动作。例如向左走和向右走的概率分别为[0.3, 0.7],这就意味着本次选择向左走的概率为30%,向右走的概率为70%。抽取结束后,便完成了决策动作。
2.2.2 游戏更新
将上一步选出的动作传入游戏当中,在游戏中操作出这一步,那么游戏环境就会从上一个状态迈向下一个状态,此即游戏状态更新。记下一个状态为next_state,后续的计算将会用到它。
2.2.3 损失回传
由于有两个网络需要更新,所以损失的计算分为两个部分。在AC算法中,这两个网络所采用的损失函数并不相同,需要分开计算。在计算完两个损失函数后,便可以调用backward()函数进行回传更新网络权重。
2.2.3.1 Actor网络损失函数
策略网络的损失函数如下:
loss = − ln π ( a ∣ s ) ⋅ A ( s , a ) \text{loss} = -\ln\pi(a|s) \cdot A(s, a) loss=−lnπ(a∣s)⋅A(s,a)
π ( a ∣ s ) π(a∣s) π(a∣s)就是策略函数在s状态下做出a动作的概率。一般来说,这里的策略函数就是指策略网络。而 ln π ( a ∣ s ) \ln\pi(a|s) lnπ(a∣s)就是将这个概率求对数。
A ( s , a ) A(s, a) A(s,a)是优势函数,其公式是一个时序差分,是真实价值和预测价值的差。具体公式如下:
A ( s , a ) = Q ( s , a ) − V ( s ) = R + γ ⋅ V ( s ′ ) − V ( s ) A(s, a) = Q(s, a) - V(s) = R + \gamma \cdot V(s') - V(s) A(s,a)=Q(s,a)−V(s)=R+γ⋅V(s′)−V(s)
Q Q Q表示真实价值,也是 s s s状态下做出动作 a a a的价值,根据上述等式可以轻易发现 Q ( s , a ) = R + γ ⋅ V ( s ′ ) Q(s, a) = R + \gamma \cdot V(s') Q(s,a)=R+γ⋅V(s′)。
公式当中的符号表示的意义:
1、 R R R表示 s s s状态下做出 a a a动作后得到的奖励。
2、 γ \gamma γ表示折扣因子,用于计算累积折扣回报,即未来的预期收益。在无改进的AC算法中,这里的未来回报只计算后一步的。折扣因子越大,表示越重视未来的回报,一般取0.9~0.99。
3、 V ( s ′ ) V(s') V(s′)表示 s s s做出 a a a后的下一个状态 s ′ s' s′的价值。
V V V表示预测价值。预测价值表示的价值函数计算得出的价值。一般来说,这里的价值函数就是指价值网络。 V ( s ) V(s) V(s)就表示 s s s状态时,价值网络输出的价值。
注:价值函数通常被定义为在给定状态下,智能体按照其策略执行动作所能获得的预期回报。因为它的意义在于表示进入该状态后准备去决策动作的价值,而不是进入该状态并选择了特定动作的价值,所以价值函数的输入是不需要将决策的动作也一并放入的。也即状态的价值,而非状态下动作的价值。
将公式合并得最终公式:
loss = − ln π ( a ∣ s ) ⋅ ( R + γ ⋅ V ( s ′ ) − V ( s ) ) \text{loss} = -\ln\pi(a|s) \cdot (R + \gamma \cdot V(s') - V(s)) loss=−lnπ(a∣s)⋅(R+γ⋅V(s′)−V(s))
在使用该公式时,需要额外注意真实价值部分的 V ( s ′ ) V(s') V(s′)。当游戏在 s s s状态时便已经结束,那么 V ( s ′ ) V(s') V(s′)应当为0。
2.3.3.2 Critic网络损失函数
价值网络的损失函数如下:
loss = M S E ( Q ( s , a ) , V ( s ) ) \text{loss} = MSE(Q(s, a), V(s)) loss=MSE(Q(s,a),V(s))
该公式的意思是计算真实价值和预测价值的均方差。均方差公式如下:
M S E ( Y , Y ^ ) = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 MSE(Y, \hat{Y}) = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 MSE(Y,Y^)=n1i=1∑n(yi−y


4578

被折叠的 条评论
为什么被折叠?



