Decision Transformer代码深度剖析:理解关键模块的实现原理

Decision Transformer代码深度剖析:理解关键模块的实现原理

【免费下载链接】decision-transformer Official codebase for Decision Transformer: Reinforcement Learning via Sequence Modeling. 【免费下载链接】decision-transformer 项目地址: https://gitcode.com/gh_mirrors/de/decision-transformer

Decision Transformer是一个基于序列建模的强化学习框架,它将强化学习问题转化为序列预测任务,通过Transformer架构实现策略决策。本文将深入剖析其核心代码结构,帮助读者理解关键模块的实现原理。

核心架构概览

Decision Transformer的核心创新在于将强化学习中的状态、动作和回报信号转化为序列数据,通过Transformer模型进行建模。其整体架构如图所示:

Decision Transformer架构图

这个架构包含以下关键组件:

  • 嵌入层(embedding layer):将状态、动作和回报信号转化为向量表示
  • 因果Transformer(causal transformer):建模序列数据间的依赖关系
  • 线性解码器(linear decoder):预测下一个动作

核心模块实现分析

DecisionTransformer类定义

核心模型定义在gym/decision_transformer/models/decision_transformer.py文件中,该类继承自TrajectoryModel,是整个框架的核心实现。

class DecisionTransformer(TrajectoryModel):
    """
    This model uses GPT to model (Return_1, state_1, action_1, Return_2, state_2, ...)
    """

初始化方法

__init__方法中,模型进行了关键组件的初始化:

def __init__(
        self,
        state_dim,
        act_dim,
        hidden_size,
        max_length=None,
        max_ep_len=4096,
        action_tanh=True,
        **kwargs
):
    super().__init__(state_dim, act_dim, max_length=max_length)
    
    self.hidden_size = hidden_size
    config = transformers.GPT2Config(
        vocab_size=1,  # 不使用词汇表
        n_embd=hidden_size,
        **kwargs
    )
    
    self.transformer = GPT2Model(config)  # 使用自定义的GPT2模型
    
    # 定义各种嵌入层
    self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
    self.embed_return = torch.nn.Linear(1, hidden_size)
    self.embed_state = torch.nn.Linear(self.state_dim, hidden_size)
    self.embed_action = torch.nn.Linear(self.act_dim, hidden_size)
    
    self.embed_ln = nn.LayerNorm(hidden_size)
    
    # 预测头
    self.predict_state = torch.nn.Linear(hidden_size, self.state_dim)
    self.predict_action = nn.Sequential(
        *([nn.Linear(hidden_size, self.act_dim)] + ([nn.Tanh()] if action_tanh else []))
    )
    self.predict_return = torch.nn.Linear(hidden_size, 1)

前向传播方法

forward方法实现了模型的核心逻辑,将状态、动作和回报信号转化为序列输入并进行预测:

def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None):
    batch_size, seq_length = states.shape[0], states.shape[1]
    
    # 处理注意力掩码
    if attention_mask is None:
        attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
    
    # 对不同模态数据进行嵌入
    state_embeddings = self.embed_state(states)
    action_embeddings = self.embed_action(actions)
    returns_embeddings = self.embed_return(returns_to_go)
    time_embeddings = self.embed_timestep(timesteps)
    
    # 添加时间嵌入
    state_embeddings = state_embeddings + time_embeddings
    action_embeddings = action_embeddings + time_embeddings
    returns_embeddings = returns_embeddings + time_embeddings
    
    # 构建序列输入 (R_1, s_1, a_1, R_2, s_2, a_2, ...)
    stacked_inputs = torch.stack(
        (returns_embeddings, state_embeddings, action_embeddings), dim=1
    ).permute(0, 2, 1, 3).reshape(batch_size, 3*seq_length, self.hidden_size)
    stacked_inputs = self.embed_ln(stacked_inputs)
    
    # 处理注意力掩码
    stacked_attention_mask = torch.stack(
        (attention_mask, attention_mask, attention_mask), dim=1
    ).permute(0, 2, 1).reshape(batch_size, 3*seq_length)
    
    # Transformer前向传播
    transformer_outputs = self.transformer(
        inputs_embeds=stacked_inputs,
        attention_mask=stacked_attention_mask,
    )
    x = transformer_outputs['last_hidden_state']
    
    # 重塑输出并进行预测
    x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)
    
    return_preds = self.predict_return(x[:,2])  # 预测下一个回报
    state_preds = self.predict_state(x[:,2])    # 预测下一个状态
    action_preds = self.predict_action(x[:,1])  # 预测下一个动作
    
    return state_preds, action_preds, return_preds

动作生成方法

get_action方法实现了基于当前状态生成动作的逻辑:

def get_action(self, states, actions, rewards, returns_to_go, timesteps, **kwargs):
    # 处理输入形状
    states = states.reshape(1, -1, self.state_dim)
    actions = actions.reshape(1, -1, self.act_dim)
    returns_to_go = returns_to_go.reshape(1, -1, 1)
    timesteps = timesteps.reshape(1, -1)
    
    # 处理序列长度限制
    if self.max_length is not None:
        # 截断或填充序列到最大长度
        states = states[:,-self.max_length:]
        actions = actions[:,-self.max_length:]
        returns_to_go = returns_to_go[:,-self.max_length:]
        timesteps = timesteps[:,-self.max_length:]
        
        # 构建注意力掩码
        attention_mask = torch.cat([torch.zeros(self.max_length-states.shape[1]), torch.ones(states.shape[1])])
        attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
        
        # 填充序列
        states = torch.cat(
            [torch.zeros((states.shape[0], self.max_length-states.shape[1], self.state_dim), device=states.device), states],
            dim=1).to(dtype=torch.float32)
        # 对actions, returns_to_go和timesteps进行类似填充...
    else:
        attention_mask = None
    
    # 前向传播获取动作预测
    _, action_preds, return_preds = self.forward(
        states, actions, None, returns_to_go, timesteps, attention_mask=attention_mask, **kwargs)
    
    return action_preds[0,-1]

项目结构解析

Decision Transformer项目包含两个主要应用场景的实现:

  1. Atari游戏环境

  2. Gym环境

总结

Decision Transformer通过将强化学习问题转化为序列预测任务,创新性地将Transformer架构应用于决策制定。其核心思想是将状态、动作和回报信号组成序列,通过因果Transformer模型进行建模和预测。

核心实现亮点包括:

  • 多模态嵌入:分别对状态、动作和回报进行嵌入
  • 时间嵌入:为序列添加时间信息
  • 序列重组:将(R, s, a)三元组重组为序列输入
  • 自回归预测:基于历史信息预测下一个动作

通过这种设计,Decision Transformer能够利用Transformer强大的序列建模能力,在强化学习任务中取得优异性能。

要开始使用Decision Transformer,可通过以下命令克隆仓库:

git clone https://gitcode.com/gh_mirrors/de/decision-transformer

项目提供了Atari游戏和Gym环境的实现,可根据需求选择相应的运行脚本进行实验。

【免费下载链接】decision-transformer Official codebase for Decision Transformer: Reinforcement Learning via Sequence Modeling. 【免费下载链接】decision-transformer 项目地址: https://gitcode.com/gh_mirrors/de/decision-transformer

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值