Skip to content

Actor-Critic 算法

📅 发表于 2025/08/31
🔄 更新于 2025/08/31
👁️ -- 次访问
📝 0 字
0 分钟
rl-theory
#Actor
#Critic
#Q Actor-Critic
#A2C
#A3C
#Advantage
#Actor loss
#Critic Loss
#GAE
#λ-return
#单步TD
#MC估计

Actor-Critic 算法

核心思想

Actor-Critic 核心思想

核心思想

  • 结合策略梯度时序差分的强化学习算法。
  • 演员:策略函数πθ(as)
    • 学习一个策略,得到尽可能高的回报
    • 和环境交互采样
  • 评论员:价值函数Vπ(s)
    • 估计当前策略的价值,即评估演员的好坏
    • 输入轨迹,评估当前(s,a)的价值

优点

  • 兼顾策略梯度和时序差分的优点
  • 能缓解二者难以解决的高方差问题
    • 策略梯度算法:利用策略和环境交互采样估计策略梯度
    • 基于价值的算法:需要和环境采样来估计价值函数
  • 为何能缓解
    • Actor 只负责策略梯度采样
    • Critic 只负责策略的价值估计,带来了更稳定的估计

QAC:最简单的AC算法 Q Actor-Critic

Q Actor-Critic

核心思想

  • CriticQ函数Qϕ(st,at),作为策略梯度里的权重,代替MC策略梯度里的累计回报Gt
J(θ)=1Nn=1Nt=0TnQϕ(st,at)logpθ(atnstn)a
  • MC策略梯度使用累计回报Gt来作为权重
J(θ)=1Nn=1Nt=0TnGtnlogpθ(atnstn)a

A2C: 优势AC算法 / Advantage Actor Crictic

核心思想

A2C

核心思想

  • 引入优势函数=Q-V作为权重,动作是Q,基线是V
    • 优势表示当前(s,a)相比其他动作平均水平的优势
    • Vπ(st)同一状态下的基线, 而非所有状态的均值
Aπ(st,at)=Qπ(st,at)Vπ(st)
  • 同策略算法

TD误差近似优势函数

  • Q函数可用V函数做TD估计,直接去掉期望值是因为论文实验这个方法效果好。
Qπ(st,at)=E[rt+1+γVπ(st+1)]rt+1+γVπ(st+1)Aπ(st,at)=rt+1+γVπ(st+1)Vπ(st)TD
  • 优点
    • 只需估计V,无需估计Q
    • r的方差相比G的方差小很多

策略梯度

J(θ)=1Nn=1Nt=0TnAπ(stn,atn),logpθ(atnstn)aJ(θ)=1Nn=1Nt=0Tn(Qπ(stn,atn)Vπ(stn)),logpθ(atnstn)aJ(θ)=1Nn=1Nt=0Tn(rtn+γVπ(st+1n)Vπ(stn)),logpθ(atnstn)a

loss

  • Actor Loss:策略梯度
    • actor_loss= -(log_probs * advantages.detach()).mean()
loss=Aπ(st,at)logpθ(st,at)
  • Critic Loss:对优势使用TD error 均方误差,即使二者更加接近即可
    • critic_loss = advantages.pow(2).mean()

优点

  • 减去基线,降低了方差,详见降低方差数学推导过程

算法流程

A2C 算法流程
  • 初始Actor(策略π),Actor和环境交互,采样一些资料
  • 用采样资料结合TD方法,估计V函数
  • 基于V函数计算策略梯度,再更新参数
J(θ)=1Nn=1Nt=0Tn(rtn+γVπ(st+1n)Vπ(stn)),logpθ(atnstn)a

关键技巧

A2C 关键技巧

技巧1:估计V和Actor2个网络

  • Critic:Vπ(s),输入状态,输出标量
  • Actor:π(s),输入状态,输出动作分布,

技巧2:探索机制

  • π(s)输出设置约束,使分布的熵不要太小,希望不同动作的采样概率平均一些

技巧3:优势函数值域固定到[-1,1]

  • 做回报归一化,让优势函数更稳定,从而减小反差

伪代码

Actor、Critic 定义

python
# 分开定义
class Critic(nn.Module):
  ''' 估计状态价值,输出标量
  '''
    def __init__(self,state_dim):
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        value = self.fc3(x)
        return value

class Actor(nn.Module):
  ''' 采样动作,输出logits_p,动作概率
  '''
    def __init__(self, state_dim, action_dim):
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, action_dim)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits_p = F.softmax(self.fc3(x), dim=1)
        return logits_p

# 合在一起定义
class ActorCritic(nn.Module):
  	''' 输入状态,输出动作概率和状态价值
  	'''
    def __init__(self, state_dim, action_dim):
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.action_layer = nn.Linear(256, action_dim)
        self.value_layer = nn.Linear(256, 1)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits_p = F.softmax(self.action_layer(x), dim=1)
        value = self.value_layer(x)
        return logits_p, value

Agent:采样动作、计算优势函数、计算loss、策略更新

python
from torch.distributions import Categorical
class Agent:
    def __init__(self):
        self.model = ActorCritic(state_dim, action_dim)
    
    def sample_action(self,state):
        ''' 动作采样
        '''
        state = torch.tensor(state, device=self.device, dtype=torch.float32)
        # 策略网络输出动作概率分布
        logits_p, value = self.model(state)
        dist = Categorical(logits_p) 
        # 依概率采样一个分布
        action = dist.sample() 
        return action
    
    def _compute_returns(self, rewards, dones):
        ''' 计算回报,做归一化
        '''
        returns = []
        discounted_sum = 0
        for reward, done in zip(reversed(rewards), reversed(dones)):
            if done:
                discounted_sum = 0
            # 折扣回报
            discounted_sum = reward + (self.gamma * discounted_sum)
            returns.insert(0, discounted_sum)
        # 回报归一化
        returns = torch.tensor(returns, device=self.device, dtype=torch.float32).unsqueeze(dim=1)
        returns = (returns - returns.mean()) / (returns.std() + 1e-5) # 1e-5 to avoid division by zero
        return returns
    
    def compute_advantage(self):
        ''' 计算优势
        '''
        # 从经验池中采样数据:动作概率、状态、回报、结束
        logits_p, states, rewards, dones = self.memory.sample()
        # 根据rewards 计算回报
        returns = self._compute_returns(rewards, dones)
        states = torch.tensor(states, device=self.device, dtype=torch.float32)
        # 当前模型去估计状态价值 
        logits_p, values = self.model(states)
        # 实际回报 - 状态价值 作为优势
        advantages = returns - values
        return advantages
      
  	def compute_loss(self):
        '''计算损失函数
        '''
        # 采样数据
        logits_p, states, rewards, dones = self.memory.sample()
        returns = self._compute_returns(rewards, dones)
        states = torch.tensor(states, device=self.device, dtype=torch.float32)
        # 估计价值V、计算策略logits_p
        logits_p, values = self.model(states)
        # 计算advantages
        advantages = returns - values
        dist = Categorical(logits_p)
        # 计算log_prob
        log_probs = dist.log_prob(actions)
        # 注意这里策略损失反向传播时不需要优化优势函数,因此需要将其 detach 掉
        actor_loss = -(log_probs * advantages.detach()).mean() 
        # critic loss
        critic_loss = advantages.pow(2).mean()
        return actor_loss, critic_loss

A3C 异步优势AC算法/ Asynchronous Advantage Actor Critic

A3C:A2C算法的异步扩展版

背景

  • 强化学习很慢,需要加快速度
  • 可以像鸣人一样使用多个分身进行修行。

核心思想

  • 使用1个全局网络+多个worker并行探索训练
  • 每个进程单独训练,计算出梯度后回传中央控制中心,去更新原来的参数。
    • 原来的参数被别的进程覆盖掉怎么办?没关系,还是正常更新就好了
  • 是一种同策略算法,虽然看起来像异策略
    • 每个worker内,只用当前策略采样的数据计算梯度

优点

  • 不存储历史数据,通过平行探索保持训练稳定性

广义优势估计

背景

GAE 背景

背景

  • A2C引入优势函数来缓解了方差,但TD本质还是使用蒙特卡洛估计,还是会产生高方差。
  • MC:无偏估计,带来高方差。TD:有偏估计,缓解高方差。

λ-return

  • n个k步估计量(k从1到n)进行移动加权平均平衡了TD和MC方法,也平衡了偏差和方差
Gt:Tλ=(1λ)n=1Tt1λn1Gt:t+n+λTt1Gt

GAE 推导过程

GAE 推导过程

广义优势估计核心思想

  • 广义优势估计,类似于λ-return,使用n个估计量
At1=δt=V(st)+rt+γV(st+1)At2=δt+γδt+1=V(st)+rt+γV(st+1)+γ2V(st+2)At3=δt+γδt+1+γ2δt+2=V(st)+rt+γV(st+1)+γ2V(st+2)+γ3V(st+3)Atk=l=0k1γlδt+l=V(st)+rt+γV(st+1)+γ2V(st+l)+γkV(st+k)At=l=0γlδt+l=V(st)+l=0γlrt+l
  • GAE对这n个k步估计量,k从1到n,进行加权平均
AtGAE(γ,λ)(st,at)=(1λ)(At(1)+λAt(2)+λ2At(3)+)=(1λ)(δt+λ(δt+γδt+1)+λ2(δt+γδt+1+γ2δt+2)+)=(1λ)(δt(1+λ+λ2+)+γδt+1(λ+λ2+)+γ2δt+2(λ2+λ3+)+)=(1λ)(δt11λ+γδt+1λ1λ+γ2δt+2λ21λ+)=l=0(γλ)lδt+l=l=0(γλ)l(rt+l+γV(st+l+1)V(st+l))

GAE 总结

GAE 总结

总结

  • δt+l:时步t+lTD误差
δt+l=rt+l+γV(st+l+1)V(st+l)
  • 广义优势估计
AtGAE(γ,λ)(st,at)=l=0(γλ)lδt+lAtGAE(γ,λ)(st,at)=l=0(γλ)l(rt+l+γV(st+l+1)V(st+l))
  • λ=0,GAE退化为单步TD
AtGAE(γ,λ)(st,at)=δt=rt+l+γV(st+1)V(st)
  • λ=1,GAE退化为MC估计
AtGAE(γ,λ)(st,at)=l=0(γ)l(rt+l+γV(st+l+1)V(st+l))

优点

  • 平衡了MC和TD,平衡了方差和偏差。
总访客数:   ·   总访问量:
PLM's Blog @ 2016 - 2025