Decision Transformer代码深度剖析:理解关键模块的实现原理
Decision Transformer是一个基于序列建模的强化学习框架,它将强化学习问题转化为序列预测任务,通过Transformer架构实现策略决策。本文将深入剖析其核心代码结构,帮助读者理解关键模块的实现原理。
核心架构概览
Decision Transformer的核心创新在于将强化学习中的状态、动作和回报信号转化为序列数据,通过Transformer模型进行建模。其整体架构如图所示:
这个架构包含以下关键组件:
- 嵌入层(embedding layer):将状态、动作和回报信号转化为向量表示
- 因果Transformer(causal transformer):建模序列数据间的依赖关系
- 线性解码器(linear decoder):预测下一个动作
核心模块实现分析
DecisionTransformer类定义
核心模型定义在gym/decision_transformer/models/decision_transformer.py文件中,该类继承自TrajectoryModel,是整个框架的核心实现。
class DecisionTransformer(TrajectoryModel):
"""
This model uses GPT to model (Return_1, state_1, action_1, Return_2, state_2, ...)
"""
初始化方法
在__init__方法中,模型进行了关键组件的初始化:
def __init__(
self,
state_dim,
act_dim,
hidden_size,
max_length=None,
max_ep_len=4096,
action_tanh=True,
**kwargs
):
super().__init__(state_dim, act_dim, max_length=max_length)
self.hidden_size = hidden_size
config = transformers.GPT2Config(
vocab_size=1, # 不使用词汇表
n_embd=hidden_size,
**kwargs
)
self.transformer = GPT2Model(config) # 使用自定义的GPT2模型
# 定义各种嵌入层
self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
self.embed_return = torch.nn.Linear(1, hidden_size)
self.embed_state = torch.nn.Linear(self.state_dim, hidden_size)
self.embed_action = torch.nn.Linear(self.act_dim, hidden_size)
self.embed_ln = nn.LayerNorm(hidden_size)
# 预测头
self.predict_state = torch.nn.Linear(hidden_size, self.state_dim)
self.predict_action = nn.Sequential(
*([nn.Linear(hidden_size, self.act_dim)] + ([nn.Tanh()] if action_tanh else []))
)
self.predict_return = torch.nn.Linear(hidden_size, 1)
前向传播方法
forward方法实现了模型的核心逻辑,将状态、动作和回报信号转化为序列输入并进行预测:
def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None):
batch_size, seq_length = states.shape[0], states.shape[1]
# 处理注意力掩码
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
# 对不同模态数据进行嵌入
state_embeddings = self.embed_state(states)
action_embeddings = self.embed_action(actions)
returns_embeddings = self.embed_return(returns_to_go)
time_embeddings = self.embed_timestep(timesteps)
# 添加时间嵌入
state_embeddings = state_embeddings + time_embeddings
action_embeddings = action_embeddings + time_embeddings
returns_embeddings = returns_embeddings + time_embeddings
# 构建序列输入 (R_1, s_1, a_1, R_2, s_2, a_2, ...)
stacked_inputs = torch.stack(
(returns_embeddings, state_embeddings, action_embeddings), dim=1
).permute(0, 2, 1, 3).reshape(batch_size, 3*seq_length, self.hidden_size)
stacked_inputs = self.embed_ln(stacked_inputs)
# 处理注意力掩码
stacked_attention_mask = torch.stack(
(attention_mask, attention_mask, attention_mask), dim=1
).permute(0, 2, 1).reshape(batch_size, 3*seq_length)
# Transformer前向传播
transformer_outputs = self.transformer(
inputs_embeds=stacked_inputs,
attention_mask=stacked_attention_mask,
)
x = transformer_outputs['last_hidden_state']
# 重塑输出并进行预测
x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)
return_preds = self.predict_return(x[:,2]) # 预测下一个回报
state_preds = self.predict_state(x[:,2]) # 预测下一个状态
action_preds = self.predict_action(x[:,1]) # 预测下一个动作
return state_preds, action_preds, return_preds
动作生成方法
get_action方法实现了基于当前状态生成动作的逻辑:
def get_action(self, states, actions, rewards, returns_to_go, timesteps, **kwargs):
# 处理输入形状
states = states.reshape(1, -1, self.state_dim)
actions = actions.reshape(1, -1, self.act_dim)
returns_to_go = returns_to_go.reshape(1, -1, 1)
timesteps = timesteps.reshape(1, -1)
# 处理序列长度限制
if self.max_length is not None:
# 截断或填充序列到最大长度
states = states[:,-self.max_length:]
actions = actions[:,-self.max_length:]
returns_to_go = returns_to_go[:,-self.max_length:]
timesteps = timesteps[:,-self.max_length:]
# 构建注意力掩码
attention_mask = torch.cat([torch.zeros(self.max_length-states.shape[1]), torch.ones(states.shape[1])])
attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
# 填充序列
states = torch.cat(
[torch.zeros((states.shape[0], self.max_length-states.shape[1], self.state_dim), device=states.device), states],
dim=1).to(dtype=torch.float32)
# 对actions, returns_to_go和timesteps进行类似填充...
else:
attention_mask = None
# 前向传播获取动作预测
_, action_preds, return_preds = self.forward(
states, actions, None, returns_to_go, timesteps, attention_mask=attention_mask, **kwargs)
return action_preds[0,-1]
项目结构解析
Decision Transformer项目包含两个主要应用场景的实现:
-
Atari游戏环境:
- 代码路径:atari/
- 主要文件:atari/run_dt_atari.py(运行脚本)、atari/mingpt/model_atari.py(Atari专用模型)
-
Gym环境:
- 代码路径:gym/
- 决策Transformer核心实现:gym/decision_transformer/
- 模型定义:gym/decision_transformer/models/
- 训练代码:gym/decision_transformer/training/
- 评估代码:gym/decision_transformer/evaluation/
总结
Decision Transformer通过将强化学习问题转化为序列预测任务,创新性地将Transformer架构应用于决策制定。其核心思想是将状态、动作和回报信号组成序列,通过因果Transformer模型进行建模和预测。
核心实现亮点包括:
- 多模态嵌入:分别对状态、动作和回报进行嵌入
- 时间嵌入:为序列添加时间信息
- 序列重组:将(R, s, a)三元组重组为序列输入
- 自回归预测:基于历史信息预测下一个动作
通过这种设计,Decision Transformer能够利用Transformer强大的序列建模能力,在强化学习任务中取得优异性能。
要开始使用Decision Transformer,可通过以下命令克隆仓库:
git clone https://gitcode.com/gh_mirrors/de/decision-transformer
项目提供了Atari游戏和Gym环境的实现,可根据需求选择相应的运行脚本进行实验。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考




