1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681
| from dataclasses import dataclass from typing import Dict
import numpy as np import torch import torch.nn as nn from torch.distributions import Categorical from torch.utils.data import BatchSampler, SubsetRandomSampler
# ============================================================ # 1. 超参数配置 # ============================================================
@dataclass class PPOConfig: # 环境维度 state_dim: int action_dim: int
# 强化学习超参数 gamma: float = 0.99 # 折扣因子 γ gae_lambda: float = 0.95 # GAE 中的 λ clip_eps: float = 0.2 # PPO clipping 范围 ε
# 损失权重 value_coef: float = 0.5 # Critic loss 权重 entropy_coef: float = 0.01 # 熵奖励权重,鼓励探索
# 优化相关 learning_rate: float = 3e-4 update_epochs: int = 10 # 一批 rollout 重复更新多少轮 mini_batch_size: int = 64 max_grad_norm: float = 0.5 # 梯度裁剪
# 数据收集相关 rollout_steps: int = 2048 # 收集多少步后执行一次 update
device: str = "cuda" if torch.cuda.is_available() else "cpu"
# ============================================================ # 2. Rollout Buffer # ============================================================
class RolloutBuffer: """ PPO 是 on-policy 算法。
每轮训练流程为: 1. 使用当前策略与环境交互; 2. 将交互数据保存到 buffer; 3. 使用这一批数据更新策略若干轮; 4. 清空 buffer; 5. 使用更新后的策略重新采样。
注意: - rewards 保存的是环境即时返回的 r_t; - returns 与 advantages 在 rollout 收集完成后再计算; - old_log_probs 保存"采样时策略"对动作的概率, 后续 PPO 计算新旧策略概率比时会使用它。 """
def __init__(self): self.states = [] self.actions = [] self.rewards = [] self.dones = []
self.values = [] self.old_log_probs = []
self.advantages = [] self.returns = []
def add( self, state: np.ndarray, action: int, reward: float, done: bool, value: float, log_prob: float, ) -> None: """ 每执行一次 env.step(action),就调用一次 add()。
参数对应: state = s_t action = a_t reward = r_t,由环境返回 done = 当前 episode 是否结束 value = V_old(s_t) log_prob = log π_old(a_t | s_t) """ self.states.append(state) self.actions.append(action) self.rewards.append(reward) self.dones.append(done)
self.values.append(value) self.old_log_probs.append(log_prob)
def compute_returns_and_advantages( self, last_value: float, gamma: float, gae_lambda: float, ) -> None: """ 根据已经采样好的 rollout,计算: 1. TD Error: δ_t = r_t + γ V(s_{t+1}) - V(s_t)
2. GAE Advantage: A_t = δ_t + γλ δ_{t+1} + (γλ)^2 δ_{t+2} + ...
3. Return / Value Target: R_t = A_t + V(s_t)
这里的 last_value 表示: 如果 rollout 最后一步还没有真正结束, 就需要用 Critic 预测 V(s_{T+1}) 来 bootstrap。
如果最后一步刚好 terminal, 则 last_value = 0。 """
num_steps = len(self.rewards)
self.advantages = [0.0 for _ in range(num_steps)] self.returns = [0.0 for _ in range(num_steps)]
gae = 0.0
# 从后向前计算,因为当前优势依赖未来的 TD Error for t in reversed(range(num_steps)):
if t == num_steps - 1: # rollout 的最后一个位置 next_value = last_value else: # 中间位置的下一状态价值,已在采样时保存 next_value = self.values[t + 1]
# 如果当前 transition 已经到达终止状态, # 那么之后没有未来收益,不应该 bootstrap。 non_terminal = 1.0 - float(self.dones[t])
# TD Error: # δ_t = r_t + γV(s_{t+1}) - V(s_t) td_delta = ( self.rewards[t] + gamma * next_value * non_terminal - self.values[t] )
# GAE: # A_t = δ_t + γλ A_{t+1} gae = ( td_delta + gamma * gae_lambda * non_terminal * gae )
self.advantages[t] = gae
# Critic 的训练目标: # R_t = A_t + V(s_t) self.returns[t] = gae + self.values[t]
def clear(self) -> None: """一次 PPO update 完成后,清空旧 rollout。""" self.states.clear() self.actions.clear() self.rewards.clear() self.dones.clear()
self.values.clear() self.old_log_probs.clear()
self.advantages.clear() self.returns.clear()
def __len__(self) -> int: return len(self.rewards)
# ============================================================ # 3. Actor-Critic Network # ============================================================
class ActorCritic(nn.Module): """ Actor: 输入状态 s_t; 输出动作分布 π(a | s)。
Critic: 输入状态 s_t; 输出状态价值 V(s_t)。
这里使用两个独立 MLP,便于理解 Actor 与 Critic 的分工。 """
def __init__(self, state_dim: int, action_dim: int): super().__init__()
# ---------------------------------------------------- # Actor:输出每个离散动作的 logits # ---------------------------------------------------- self.actor = nn.Sequential( nn.Linear(state_dim, 64), nn.Tanh(), nn.Linear(64, 64), nn.Tanh(), nn.Linear(64, action_dim), )
# ---------------------------------------------------- # Critic:输出一个标量 V(s) # ---------------------------------------------------- self.critic = nn.Sequential( nn.Linear(state_dim, 64), nn.Tanh(), nn.Linear(64, 64), nn.Tanh(), nn.Linear(64, 1), )
def get_value(self, state: torch.Tensor) -> torch.Tensor: """ 仅计算 Critic 价值: V(s) """ return self.critic(state).squeeze(-1)
def get_action_and_value( self, state: torch.Tensor, action: torch.Tensor | None = None, ): """ 统一完成: 1. 根据 Actor 构造动作分布; 2. 采样动作,或者评价给定动作; 3. 计算 log_prob; 4. 计算 entropy; 5. 计算 Critic value。
两种使用场景:
场景 A:与环境交互时 action=None 网络会采样动作 a_t。
场景 B:更新 PPO 时 action 为 buffer 中保存的旧动作 网络会计算"新策略下这些旧动作的概率"。 """
logits = self.actor(state)
# 使用 logits 而不是手动 Softmax: # Categorical 内部会稳定地处理 softmax 与 log_prob。 dist = Categorical(logits=logits)
if action is None: action = dist.sample()
log_prob = dist.log_prob(action) entropy = dist.entropy() value = self.get_value(state)
return action, log_prob, entropy, value
# ============================================================ # 4. PPO Agent # ============================================================
class PPOAgent: """ 核心职责: 1. act(): 使用当前策略采样动作; 2. value(): 估计状态价值; 3. update(): 使用 buffer 中的数据执行 PPO 更新。 """
def __init__(self, config: PPOConfig): self.config = config self.device = torch.device(config.device)
self.network = ActorCritic( state_dim=config.state_dim, action_dim=config.action_dim, ).to(self.device)
self.optimizer = torch.optim.Adam( self.network.parameters(), lr=config.learning_rate, )
@torch.no_grad() def act(self, state: np.ndarray): """ 与环境交互时调用。
输入: state = 当前环境状态 s_t
输出: action = 从 π_old(a | s_t) 中采样得到的动作 log_prob = log π_old(a_t | s_t) value = V_old(s_t)
由于该阶段只是采样数据,因此不需要保存计算图。 """
state_tensor = torch.tensor( state, dtype=torch.float32, device=self.device, ).unsqueeze(0)
action, log_prob, _, value = self.network.get_action_and_value( state_tensor )
return ( action.item(), log_prob.item(), value.item(), )
@torch.no_grad() def value(self, state: np.ndarray) -> float: """ 仅预测某个状态的价值 V(s)。
主要用于: rollout 结束但 episode 尚未结束时, 估计最后一个状态之后的未来收益。 """
state_tensor = torch.tensor( state, dtype=torch.float32, device=self.device, ).unsqueeze(0)
value = self.network.get_value(state_tensor)
return value.item()
def update(self, buffer: RolloutBuffer) -> Dict[str, float]: """ 使用一整个 rollout buffer 更新 Actor 与 Critic。
更新步骤: 1. 将 buffer 转成 tensor; 2. 标准化 advantage; 3. 多次 epoch 遍历这批数据; 4. 对每个 mini-batch 计算: - 新旧策略概率比 ratio - clipped actor loss - critic value loss - entropy bonus 5. 反向传播更新参数。 """
data = buffer.to_tensors(self.device)
states = data["states"] actions = data["actions"] old_log_probs = data["old_log_probs"] old_values = data["old_values"] advantages = data["advantages"] returns = data["returns"]
# ---------------------------------------------------- # 标准化 Advantage # ---------------------------------------------------- # PPO 中通常会标准化 advantage,以降低训练数值波动。 advantages = ( advantages - advantages.mean() ) / ( advantages.std() + 1e-8 )
num_samples = len(buffer)
actor_losses = [] critic_losses = [] entropy_values = [] approx_kls = [] clip_fractions = []
# ---------------------------------------------------- # 同一批 rollout 数据更新多个 epoch # ---------------------------------------------------- for epoch in range(self.config.update_epochs):
sampler = BatchSampler( SubsetRandomSampler(range(num_samples)), batch_size=self.config.mini_batch_size, drop_last=False, )
for batch_indices in sampler:
batch_states = states[batch_indices] batch_actions = actions[batch_indices] batch_old_log_probs = old_log_probs[batch_indices] batch_advantages = advantages[batch_indices] batch_returns = returns[batch_indices] batch_old_values = old_values[batch_indices]
# ------------------------------------------------ # 用当前正在更新的新策略,重新评价旧动作 # ------------------------------------------------ _, new_log_probs, entropy, new_values = ( self.network.get_action_and_value( batch_states, batch_actions, ) )
# ------------------------------------------------ # PPO 概率比值 # # ratio = # π_new(a_t | s_t) / π_old(a_t | s_t) # # 使用 log_prob 计算更稳定: # exp(log π_new - log π_old) # ------------------------------------------------ log_ratio = new_log_probs - batch_old_log_probs ratio = torch.exp(log_ratio)
# ------------------------------------------------ # PPO Actor Loss # # unclipped: # ratio * advantage # # clipped: # clip(ratio, 1-ε, 1+ε) * advantage # # 目标是最大化二者较小值。 # 由于优化器执行最小化,因此前面添加负号。 # ------------------------------------------------ surrogate_1 = ratio * batch_advantages
surrogate_2 = torch.clamp( ratio, 1.0 - self.config.clip_eps, 1.0 + self.config.clip_eps, ) * batch_advantages
actor_loss = -torch.min( surrogate_1, surrogate_2, ).mean()
# ------------------------------------------------ # Critic Loss # # Critic 希望满足: # V(s_t) ≈ Return_t # # Return_t = Advantage_t + V_old(s_t) # ------------------------------------------------ critic_loss = 0.5 * ( new_values - batch_returns ).pow(2).mean()
# ------------------------------------------------ # Entropy Bonus # # entropy 越大,策略越不确定,探索越充分。 # 因为总体目标是最小化 loss, # 所以写成 - entropy_coef * entropy。 # ------------------------------------------------ entropy_bonus = entropy.mean()
total_loss = ( actor_loss + self.config.value_coef * critic_loss - self.config.entropy_coef * entropy_bonus )
# ------------------------------------------------ # 梯度更新 # ------------------------------------------------ self.optimizer.zero_grad() total_loss.backward()
nn.utils.clip_grad_norm_( self.network.parameters(), self.config.max_grad_norm, )
self.optimizer.step()
# ------------------------------------------------ # 可选:记录一些训练指标 # ------------------------------------------------ with torch.no_grad():
# 近似 KL,用于观察新旧策略是否偏离过大 approx_kl = ( (ratio - 1.0) - log_ratio ).mean()
# 有多少比例的数据触发了 clipping clip_fraction = ( (torch.abs(ratio - 1.0) > self.config.clip_eps) .float() .mean() )
actor_losses.append(actor_loss.item()) critic_losses.append(critic_loss.item()) entropy_values.append(entropy_bonus.item()) approx_kls.append(approx_kl.item()) clip_fractions.append(clip_fraction.item())
return { "actor_loss": float(np.mean(actor_losses)), "critic_loss": float(np.mean(critic_losses)), "entropy": float(np.mean(entropy_values)), "approx_kl": float(np.mean(approx_kls)), "clip_fraction": float(np.mean(clip_fractions)), }
# ============================================================ # 5. Training Loop # ============================================================
def train_ppo(env, config: PPOConfig, total_steps: int): """ env 可以理解为类似 Gymnasium 环境:
state, info = env.reset()
next_state, reward, terminated, truncated, info = env.step(action)
对强化学习而言: - Actor 只负责输出 action; - reward 由 env.step(action) 返回; - buffer 在拿到 reward 之后写入 transition。 """
agent = PPOAgent(config) buffer = RolloutBuffer()
state, _ = env.reset()
episode_return = 0.0 episode_length = 0
for global_step in range(1, total_steps + 1):
# ---------------------------------------------------- # 第一步:Actor 根据当前状态选择动作 # # 得到: # a_t # log π_old(a_t | s_t) # V_old(s_t) # ---------------------------------------------------- action, old_log_prob, old_value = agent.act(state)
# ---------------------------------------------------- # 第二步:将动作交给环境执行 # # 环境返回: # s_{t+1} # r_t # 是否终止 # # 关键点: # reward 不是 Actor 或 Critic 输出的, # 而是环境执行动作后产生的反馈。 # ---------------------------------------------------- next_state, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated
# ---------------------------------------------------- # 第三步:将当前 transition 存入 buffer # # 此时一个 transition 的数据已经完整: # # s_t # a_t # r_t # done_t # V_old(s_t) # log π_old(a_t | s_t) # ---------------------------------------------------- buffer.add( state=state, action=action, reward=reward, done=done, value=old_value, log_prob=old_log_prob, )
episode_return += reward episode_length += 1
# 环境推进到下一状态 state = next_state
# ---------------------------------------------------- # 第四步:如果 episode 结束,重置环境 # ---------------------------------------------------- if done: print( f"step={global_step:7d} | " f"episode_return={episode_return:8.2f} | " f"episode_length={episode_length:4d}" )
state, _ = env.reset() episode_return = 0.0 episode_length = 0
# ---------------------------------------------------- # 第五步:收集满一批 rollout 后,执行 PPO 更新 # ---------------------------------------------------- if len(buffer) >= config.rollout_steps:
# ------------------------------------------------ # rollout 最后一步的 bootstrap value # # 情况 A:最后一个 transition 到达 terminal # 后面没有收益,所以 last_value = 0。 # # 情况 B:只是因为 rollout_steps 满了而暂停采样, # 当前 episode 仍在继续, # 则用 Critic 估计 V(s_{T+1})。 # ------------------------------------------------ if done: last_value = 0.0 else: last_value = agent.value(state)
# ------------------------------------------------ # 第六步:根据已有 reward 与 value 计算: # - TD Error # - Advantage # - Return # ------------------------------------------------ buffer.compute_returns_and_advantages( last_value=last_value, gamma=config.gamma, gae_lambda=config.gae_lambda, )
# ------------------------------------------------ # 第七步:PPO 更新 Actor 与 Critic # ------------------------------------------------ metrics = agent.update(buffer)
print( f"[update] step={global_step:7d} | " f"actor_loss={metrics['actor_loss']:+.4f} | " f"critic_loss={metrics['critic_loss']:.4f} | " f"entropy={metrics['entropy']:.4f} | " f"approx_kl={metrics['approx_kl']:.6f} | " f"clip_frac={metrics['clip_fraction']:.3f}" )
# ------------------------------------------------ # 第八步:清空旧数据 # # PPO 是 on-policy 算法。 # 当前策略已经被更新,因此旧 rollout 不应长期重复使用。 # ------------------------------------------------ buffer.clear()
return agent
|