-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrl_llm_intro.html
More file actions
614 lines (570 loc) · 39.4 KB
/
rl_llm_intro.html
File metadata and controls
614 lines (570 loc) · 39.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>强化学习在LLM训练中的应用</title>
<style>
body {
/* 使用更现代、跨平台的字体栈 */
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol";
line-height: 1.7; /* 略微增加行高以提高可读性 */
margin: 20px;
background-color: #f4f4f4;
color: #333;
}
.container {
max-width: 900px;
margin: auto;
background: #fff;
padding: 30px;
border-radius: 8px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}
h1, h2, h3 {
color: #0056b3;
}
h1 {
text-align: center;
border-bottom: 2px solid #0056b3;
padding-bottom: 10px;
}
h2 {
margin-top: 30px;
border-left: 4px solid #0056b3;
padding-left: 10px;
}
.concept, .algorithm {
margin-bottom: 20px;
padding: 15px;
background-color: #e9f5ff;
border-radius: 5px;
border: 1px solid #b3d7ff;
}
.example {
margin-top: 10px;
padding: 10px;
background-color: #fff9e6;
border: 1px dashed #ffe58f;
border-radius: 4px;
}
.example strong {
color: #d46b08;
}
code {
background-color: #f0f0f0; /* 稍微调整背景色 */
padding: 3px 5px; /* 略微增加内边距 */
border-radius: 4px;
/* 使用更常见的等宽字体栈 */
font-family: Consolas, Monaco, 'Andale Mono', 'Ubuntu Mono', monospace;
font-size: 0.9em; /* 调整代码字体大小 */
margin: 0 2px; /* 为英文术语添加间距 */
word-spacing: 0.2em; /* 增加单词间距 */
}
/* 为英文术语添加特殊样式 */
.english-term {
font-family: inherit; /* 继承正文字体 */
font-weight: normal; /* 取消加粗 */
margin: 0 2px; /* 添加间距 */
}
.diagram {
text-align: center;
margin: 20px 0;
padding: 15px;
border: 1px solid #ccc;
background-color: #f9f9f9;
border-radius: 5px;
}
.diagram img, .diagram svg {
max-width: 100%;
height: auto;
}
table {
width: 100%;
border-collapse: collapse;
margin-top: 15px;
}
th, td {
border: 1px solid #ddd;
padding: 8px;
text-align: left;
}
th {
background-color: #f2f2f2;
color: #333;
}
.comparison-table th, .comparison-table td {
text-align: center;
}
</style>
<script>
MathJax = {
tex: {
inlineMath: [['$', '$'], ['\(', '\)']],
displayMath: [['$$', '$$'], ['\[', '\]']]
},
svg: {
fontCache: 'global'
}
};
</script>
<script type="text/javascript" id="MathJax-script" async
src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-svg.js">
</script>
</head>
<body>
<div class="container">
<h1>强化学习 (RL) 在大语言模型 (LLM) 训练中的应用</h1>
<section id="rl-basics">
<h2>1. 强化学习范式与基本要素</h2>
<p>强化学习是一种机器学习范式,智能体 (Agent) 通过与环境 (Environment) 交互来学习如何做出决策以最大化累积奖励 (Reward)。</p>
<div class="diagram">
<p><strong>图示:智能体与环境交互循环</strong></p>
<svg width="450" height="220" xmlns="http://www.w3.org/2000/svg" style="display: block; margin: auto;">
<!-- Agent Box -->
<rect x="50" y="70" width="120" height="80" rx="5" ry="5" fill="#e9f5ff" stroke="#0056b3" stroke-width="1.5"/>
<text x="110" y="115" text-anchor="middle" font-size="14" fill="#333">智能体 (Agent)</text>
<text x="110" y="135" text-anchor="middle" font-size="12" fill="#555">(LLM)</text>
<!-- Environment Box -->
<rect x="280" y="70" width="120" height="80" rx="5" ry="5" fill="#fff9e6" stroke="#d46b08" stroke-width="1.5"/>
<text x="340" y="115" text-anchor="middle" font-size="14" fill="#333">环境 (Environment)</text>
<text x="340" y="135" text-anchor="middle" font-size="12" fill="#555">(用户/任务)</text>
<!-- Arrow Definitions -->
<defs>
<marker id="arrowhead-interaction" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="#333" />
</marker>
</defs>
<!-- Action Arrow -->
<path d="M 170 90 Q 225 60 280 90" stroke="#333" stroke-width="1.5" fill="none" marker-end="url(#arrowhead-interaction)"/>
<text x="225" y="70" text-anchor="middle" font-size="13">动作 a (生成 token)</text>
<!-- State/Reward Arrow -->
<path d="M 280 130 Q 225 160 170 130" stroke="#333" stroke-width="1.5" fill="none" marker-end="url(#arrowhead-interaction)"/>
<text x="225" y="170" text-anchor="middle" font-size="13">新状态 s', 奖励 r</text>
</svg>
</div>
<div class="concept">
<h3>基本要素:</h3>
<ul>
<li><strong>智能体 (Agent):</strong> 学习者和决策者。在LLM场景下,Agent 就是正在被训练的语言模型本身。</li>
<li><strong>环境 (Environment):</strong> Agent 交互的对象。对于LLM文本生成,环境可以理解为当前的对话上下文、用户的输入、或者需要完成的任务(如摘要、翻译)。</li>
<li><strong>状态 (State s):</strong> 对环境当前情况的描述。在LLM中,状态通常是到目前为止生成的文本序列。</li>
<li><strong>动作 (Action a):</strong> Agent 可以采取的操作。对于LLM,动作通常是选择下一个要生成的词 (token)。</li>
<li><strong>奖励 (Reward r):</strong> Agent 在某个状态下采取某个动作后,环境反馈的标量信号,表示该动作的好坏。奖励的设计是RL应用于LLM的关键,例如可以是生成文本的流畅度、相关性、安全性评分,或者是否符合人类偏好。</li>
</ul>
</div>
<div class="example">
<strong>LLM 文本生成示例:</strong>
<ul>
<li><strong>Agent:</strong> GPT-4 模型</li>
<li><strong>环境:</strong> 用户请求写一首关于猫的诗</li>
<li><strong>状态 (s):</strong> 当前已生成的诗句 "The cat sat on the"</li>
<li><strong>动作 (a):</strong> 从词汇表中选择下一个词,例如 "mat", "chair", "roof", "table"</li>
<li><strong>奖励 (r):</strong> 如果选择 "mat",生成的句子更连贯、符合主题,可能获得 +1 的奖励;如果选择 "banana",则可能获得 -1 的奖励。奖励可以由另一个评价模型或者人类反馈给出。</li>
</ul>
</div>
</section>
<section id="rl-concepts">
<h2>2. 强化学习核心概念</h2>
<div class="concept">
<h3>状态价值函数 V(s)</h3>
<p><strong>定义:</strong> 从状态 <code>s</code> 出发,遵循当前策略 <code>π</code>,预期未来能获得的累积奖励的总和(通常带折扣)。它衡量了当前状态 <code>s</code> 有多好。</p>
<p><strong>公式:</strong> $V^{\pi}(s) = \mathbb{E}_{\pi}[\sum_{k=0}^{\infty} \gamma^k r_{t+k+1} | S_t=s]$ (其中 $\gamma$ 是折扣因子)</p>
<p>或者,它可以表示为策略下所有动作价值的期望:$V^{\pi}(s) = \sum_{a \in A} \pi(a|s) Q^{\pi}(s, a)$</p>
<div class="example">
<strong>LLM 示例:</strong> 对于状态 <code>s</code> = "The cat sat on the",<code>V(s)</code> 表示从这个句子片段开始,模型继续生成文本,预期能获得的平均总分数(例如,基于流畅度、相关性等)。一个高的 <code>V(s)</code> 意味着这个开头很有潜力生成高质量的后续文本。
</div>
</div>
<div class="concept">
<h3>动作价值函数 Q(s, a)</h3>
<p><strong>定义:</strong> 在状态 <code>s</code> 下,采取动作 <code>a</code>,然后遵循当前策略 <code>π</code>,预期未来能获得的累积奖励的总和。它衡量了在状态 <code>s</code> 下采取动作 <code>a</code> 有多好。</p>
<p><strong>公式:</strong> $Q^{\pi}(s, a) = \mathbb{E}_{\pi}[\sum_{k=0}^{\infty} \gamma^k r_{t+k+1} | S_t=s, A_t=a]$</p>
<div class="example">
<strong>LLM 示例:</strong> 对于状态 <code>s</code> = "The cat sat on the",<code>Q(s, a="mat")</code> 表示如果模型选择生成 "mat" 这个词,然后继续生成,预期的平均总分数。比较 <code>Q(s, a="mat")</code> 和 <code>Q(s, a="chair")</code> 可以帮助模型决定哪个词更好。
</div>
</div>
<div class="concept">
<h3>优势函数 A(s, a)</h3>
<p><strong>定义:</strong> 在状态 <code>s</code> 下,采取动作 <code>a</code> 相对于遵循当前策略的平均表现(即 V(s))有多大的优势。</p>
<p><strong>公式:</strong> $A^{\pi}(s, a) = Q^{\pi}(s, a) - V^{\pi}(s)$</p>
<p>优势函数衡量了采取特定动作 <code>a</code> 比随机选择一个遵循当前策略的动作要好多少。正优势意味着该动作优于平均水平,负优势则劣于平均水平。这在策略梯度等算法中非常重要。</p>
<div class="example">
<strong>LLM 示例:</strong> 如果 <code>V(s) = 7.0</code>,<code>Q(s, a="mat") = 8.5</code>,<code>Q(s, a="roof") = 6.0</code>。那么:
<ul>
<li><code>A(s, a="mat") = 8.5 - 7.0 = +1.5</code> (选择 "mat" 比平均动作好)</li>
<li><code>A(s, a="roof") = 6.0 - 7.0 = -1.0</code> (选择 "roof" 比平均动作差)</li>
</ul>
模型会倾向于增加选择 "mat" 的概率,减少选择 "roof" 的概率。
</div>
<div class="diagram">
<p><strong>图示:Q值、V值与优势函数 (LLM 示例)</strong></p>
<p>状态 s = "The cat sat on the"</p>
<svg width="600" height="350" xmlns="http://www.w3.org/2000/svg" style="display: block; margin: auto;">
<style>
.bar { fill: skyblue; stroke: navy; stroke-width: 0.5; }
.bar.pos-adv { fill: lightgreen; }
.bar.neg-adv { fill: salmon; }
.axis { stroke: black; stroke-width: 1.5; }
.grid-line { stroke: #ccc; stroke-width: 0.5; stroke-dasharray: 2,2; }
.label { font-size: 12px; text-anchor: middle; }
.value-label { font-size: 10px; text-anchor: middle; }
.title { font-size: 14px; text-anchor: middle; font-weight: bold; }
.v-line { stroke: red; stroke-width: 1.5; stroke-dasharray: 4,2; }
</style>
<!-- Data (from python example) -->
<defs>
<data id="q_data">
<item token="mat" q="8.5" adv="1.5"/>
<item token="chair" q="7.5" adv="0.5"/>
<item token="roof" q="6.0" adv="-1.0"/>
<item token="table" q="7.8" adv="0.8"/>
</data>
<variable id="v_s" value="7.0"/>
</defs>
<!-- Q-Value Chart -->
<g transform="translate(50, 40)">
<text x="125" y="-10" class="title">Q(s, a)</text>
<!-- Y Axis -->
<line x1="0" y1="0" x2="0" y2="200" class="axis"/>
<text x="-25" y="100" transform="rotate(-90, -25, 100)" class="label">预期回报</text>
<line x1="0" y1="0" x2="250" y2="0" class="axis"/> <!-- Top line -->
<line x1="0" y1="200" x2="250" y2="200" class="axis"/> <!-- X Axis -->
<!-- Y Scale -->
<text x="-10" y="5" class="label" text-anchor="end">10</text>
<line x1="-5" y1="0" x2="0" y2="0" class="axis"/>
<text x="-10" y="105" class="label" text-anchor="end">5</text>
<line x1="-5" y1="100" x2="250" y2="100" class="grid-line"/>
<line x1="-5" y1="100" x2="0" y2="100" class="axis"/>
<text x="-10" y="205" class="label" text-anchor="end">0</text>
<line x1="-5" y1="200" x2="0" y2="200" class="axis"/>
<!-- Bars -->
<rect x="20" y="30" width="40" height="170" class="bar" data-value="8.5"/>
<text x="40" y="215" class="label">mat</text>
<text x="40" y="25" class="value-label">8.5</text>
<rect x="70" y="50" width="40" height="150" class="bar" data-value="7.5"/>
<text x="90" y="215" class="label">chair</text>
<text x="90" y="45" class="value-label">7.5</text>
<rect x="120" y="80" width="40" height="120" class="bar" data-value="6.0"/>
<text x="140" y="215" class="label">roof</text>
<text x="140" y="75" class="value-label">6.0</text>
<rect x="170" y="44" width="40" height="156" class="bar" data-value="7.8"/>
<text x="190" y="215" class="label">table</text>
<text x="190" y="39" class="value-label">7.8</text>
<!-- V(s) Line -->
<line x1="0" y1="60" x2="250" y2="60" class="v-line"/>
<text x="260" y="65" class="label" fill="red">V(s) = 7.0</text>
</g>
<!-- Advantage Chart -->
<g transform="translate(350, 40)">
<text x="125" y="-10" class="title">A(s, a) = Q(s, a) - V(s)</text>
<!-- Y Axis -->
<line x1="0" y1="0" x2="0" y2="200" class="axis"/>
<text x="-25" y="100" transform="rotate(-90, -25, 100)" class="label">优势</text>
<line x1="0" y1="100" x2="250" y2="100" class="axis"/> <!-- X Axis (at y=0 for advantage) -->
<!-- Y Scale -->
<text x="-10" y="5" class="label" text-anchor="end">+2</text>
<line x1="-5" y1="0" x2="0" y2="0" class="axis"/>
<text x="-10" y="105" class="label" text-anchor="end">0</text>
<line x1="-5" y1="100" x2="0" y2="100" class="axis"/>
<text x="-10" y="205" class="label" text-anchor="end">-2</text>
<line x1="-5" y1="200" x2="0" y2="200" class="axis"/>
<!-- Bars -->
<rect x="20" y="70" width="40" height="30" class="bar pos-adv" data-value="1.5"/>
<text x="40" y="215" class="label">mat</text>
<text x="40" y="65" class="value-label">+1.5</text>
<rect x="70" y="90" width="40" height="10" class="bar pos-adv" data-value="0.5"/>
<text x="90" y="215" class="label">chair</text>
<text x="90" y="85" class="value-label">+0.5</text>
<rect x="120" y="100" width="40" height="20" class="bar neg-adv" data-value="-1.0"/>
<text x="140" y="215" class="label">roof</text>
<text x="140" y="125" class="value-label">-1.0</text>
<rect x="170" y="84" width="40" height="16" class="bar pos-adv" data-value="0.8"/>
<text x="190" y="215" class="label">table</text>
<text x="190" y="79" class="value-label">+0.8</text>
</g>
<!-- Legend -->
<g transform="translate(150, 300)">
<rect x="0" y="0" width="15" height="10" fill="lightgreen"/>
<text x="20" y="10" class="label">正优势 (优于平均)</text>
<rect x="150" y="0" width="15" height="10" fill="salmon"/>
<text x="170" y="10" class="label">负优势 (劣于平均)</text>
</g>
</svg>
</div>
</div>
</section>
<section id="policy-estimation-improvement">
<h2>3. 策略评估与策略改进</h2>
<p>强化学习的核心目标是找到一个最优策略 $\pi^*$,使得累积奖励最大化。这通常通过一个迭代过程实现:评估当前策略的好坏(策略评估),然后根据评估结果改进策略(策略改进)。</p>
<div class="concept">
<h3>策略评估 (Policy Evaluation)</h3>
<p><strong>目标:</strong> 给定一个策略 $\pi$,计算该策略下的状态价值函数 $V^{\pi}(s)$ 或动作价值函数 $Q^{\pi}(s, a)$。</p>
<p><strong>方法:</strong>
<ul>
<li><strong>蒙特卡洛 (Monte Carlo) 方法:</strong> 通过运行完整的 episode 来收集回报 (return),然后对每个状态或状态-动作对的回报进行平均,以估计其价值。适用于 episodic 任务。</li>
<li><strong>时序差分 (Temporal Difference, TD) 学习:</strong> 不需要等待 episode 结束,而是使用当前的回报和下一个状态的估计价值来更新当前状态的价值估计(自举 Bootstrapping)。例如 TD(0), Q-learning, SARSA。</li>
</ul>
</p>
<div class="example">
<strong>LLM 示例:</strong> 假设我们有一个初步的文本生成策略 $\pi$。策略评估的目标是估计出,遵循这个策略,从某个部分生成的文本(状态 s)开始,最终能得到多高的平均分数(例如,流畅度+相关性)。或者,在状态 s 时选择生成特定词 a(动作 a),最终能得到多高的平均分数。
</div>
</div>
<div class="concept">
<h3>策略改进 (Policy Improvement)</h3>
<p><strong>目标:</strong> 基于当前策略 $\pi$ 的价值函数 ($V^{\pi}$ 或 $Q^{\pi}$),找到一个更好的策略 $\pi'$,使得 $V^{\pi'}(s) \ge V^{\pi}(s)$ 对所有状态 $s$ 成立。</p>
<p><strong>方法:</strong>
<ul>
<li><strong>贪心策略:</strong> 对于每个状态 $s$,选择能够最大化 $Q^{\pi}(s, a)$ 的动作 $a$。即 $\pi'(s) = \arg\max_a Q^{\pi}(s, a)$。</li>
<li><strong>$\epsilon$-贪心策略:</strong> 大部分时间选择贪心动作,但以小概率 $\epsilon$ 随机选择一个动作,以保证探索。</li>
<li><strong>策略梯度方法:</strong> 直接参数化策略 $\pi_\theta(a|s)$,然后沿着能使预期回报增加的梯度方向更新参数 $\theta$。</li>
</ul>
</p>
<p>策略评估和策略改进通常交替进行,这个过程称为<strong>策略迭代 (Policy Iteration)</strong> 或 <strong>价值迭代 (Value Iteration)</strong> 的变种,直至策略收敛到最优。 </p>
<div class="example">
<strong>LLM 示例:</strong> 通过策略评估,我们知道了在状态 "The cat sat on the" 下,选择 "mat" 的 Q 值 (8.5) 高于选择 "roof" 的 Q 值 (6.0)。策略改进会调整模型参数,使得模型在遇到这个状态时,选择 "mat" 的概率增加,选择 "roof" 的概率减少。
</div>
</div>
</section>
<section id="rl-algorithms">
<h2>3. 强化学习在 LLM 中的常用算法</h2>
<p>为了让 LLM 的输出更符合人类偏好、更安全或在特定任务上表现更好,研究者们将强化学习算法应用于 LLM 的微调阶段。以下是一些常用算法:</p>
<div class="algorithm">
<h3>PPO (Proximal Policy Optimization)</h3>
<p>PPO 是一种策略梯度方法,旨在在更新策略时限制更新步长,以避免策略崩溃,提高训练稳定性。它通过优化一个带截断 (clipping) 的目标函数来实现。</p>
<p><strong>在 LLM 中的应用 (RLHF - Reinforcement Learning from Human Feedback):</strong></p>
<ol>
<li><strong>收集偏好数据:</strong> 人类标注者对模型生成的多个回答进行排序或评分。</li>
<li><strong>训练奖励模型 (Reward Model):</strong> 使用偏好数据训练一个模型,该模型能预测给定文本序列的“好坏”程度(即奖励)。</li>
<li><strong>使用 PPO 微调 LLM:</strong> 将 LLM 作为策略网络,奖励模型提供奖励信号,使用 PPO 算法优化 LLM,使其生成能获得更高奖励(即更符合人类偏好)的文本。同时会加入一个 KL 散度惩罚项,防止微调后的模型偏离原始模型太远。</li>
</ol>
<h4>详细训练步骤 (结合 LLM 示例):</h4>
<ol>
<li><strong>初始化:</strong> 从预训练好的 SFT (Supervised Fine-Tuning) 模型 $\pi_{\text{SFT}}$ 开始,将其作为初始策略 $\pi_0$。同时,加载训练好的奖励模型 (RM) $r_\phi$。</li>
<li><strong>采样阶段 (Rollout):</strong>
<ul>
<li>从数据集中抽取一批 prompt $x$。</li>
<li>使用当前策略 $\pi_k$ (第 k 次迭代的策略) 对每个 prompt $x$ 生成回答 $y$。得到 $(x, y)$ 对。</li>
<li>对于每个生成的 $(x, y)$,使用奖励模型 $r_\phi$ 计算奖励 $r = r_\phi(x, y)$。</li>
<li>计算 KL 散度惩罚项:$r_{\text{KL}} = -\beta \log(\pi_k(y|x) / \pi_{\text{SFT}}(y|x))$。这个惩罚项防止 $\pi_k$ 偏离初始 SFT 模型太远。</li>
<li>总奖励为 $R(x, y) = r + r_{\text{KL}}$。</li>
</ul>
</li>
<li><strong>优势估计 (Advantage Estimation):</strong>
<ul>
<li>使用价值函数 $V_\psi(x)$ (通常也需要训练一个价值模型) 来估计在状态 $x$ 下的预期回报。</li>
<li>计算优势函数 $\hat{A}_k(x, y) = R(x, y) - V_\psi(x)$。优势表示当前动作(生成 $y$)相对于平均水平的好坏程度。可以使用 GAE (Generalized Advantage Estimation) 等更高级的方法来减少方差。</li>
</ul>
</li>
<li><strong>策略更新 (Policy Update):</strong>
<ul>
<li>使用收集到的 $(x, y, \hat{A}_k)$ 数据,通过优化 PPO 的目标函数来更新策略参数 $\theta$ (从 $\theta_k$ 到 $\theta_{k+1}$)。PPO 的目标函数通常包含一个截断项 (clipping),限制新旧策略的比率,防止更新步长过大:
$$L^{\text{CLIP}}(\theta) = \mathbb{E}_{x,y \sim \pi_k} [\min(r_t(\theta) \hat{A}_k, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_k)]$$
其中 $r_t(\theta) = \frac{\pi_\theta(y|x)}{\pi_{\theta_k}(y|x)}$ 是新旧策略的概率比率,$\epsilon$ 是截断超参数。</li>
<li>通常会进行多次小批量的梯度更新。</li>
</ul>
</li>
<li><strong>价值函数更新 (Value Function Update):</strong> (如果使用价值模型)
<ul>
<li>使用收集到的数据 $(x, R(x,y))$ 更新价值函数 $V_\psi(x)$,通常通过最小化均方误差 $(R(x,y) - V_\psi(x))^2$。</li>
</ul>
</li>
<li><strong>重复:</strong> 重复步骤 2-5,直到策略收敛或达到最大迭代次数。</li>
</ol>
<p><strong>图示:RLHF (PPO) 流程</strong></p>
<svg width="650" height="480" xmlns="http://www.w3.org/2000/svg" style="display: block; margin: auto; font-family: sans-serif;">
<!-- Add white background -->
<rect width="100%" height="100%" fill="white"/>
<style>
.box {
fill: #e9f5ff;
stroke: #000; /* Changed from #0056b3 */
stroke-width: 1.5;
rx: 8;
ry: 8;
transition: fill 0.2s ease;
}
.box:hover { fill: #cce7ff; }
.box-value {
fill: #fff9e6;
stroke: #000; /* Changed from #d46b08 */
}
.box-value:hover { fill: #ffeccc; }
.label {
font-size: 13px;
text-anchor: middle;
fill: #333;
pointer-events: none; /* Allow hover on box */
}
.label-title {
font-weight: bold;
font-size: 14px;
}
.arrow {
stroke: #000; /* Changed from #555 */
stroke-width: 2;
fill: none;
marker-end: url(#arrowhead-ppo);
}
.loop-arrow {
stroke: #000; /* Changed from #007bff */
stroke-width: 2;
fill: none;
marker-end: url(#arrowhead-ppo-loop);
stroke-dasharray: 5, 5;
}
.step-label {
font-size: 16px;
text-anchor: middle;
font-weight: bold;
fill: #004085;
}
</style>
<defs>
<marker id="arrowhead-ppo" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="#000" />
</marker>
<marker id="arrowhead-ppo-loop" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="#000" />
</marker>
</defs>
<text x="325" y="40" class="step-label">PPO 训练循环 (PPO Training Loop)</text>
<!-- Step 1: Collect Data -->
<rect x="225" y="70" width="200" height="75" class="box"/>
<text x="325" y="95" class="label label-title">1. 收集数据 (Collect Data)</text>
<text x="325" y="115" class="label">使用当前策略 $\pi_{\theta_{old}}$</text>
<text x="325" y="130" class="label">与环境交互,采样轨迹 $\tau$</text>
<!-- Arrow 1 -> 2 -->
<path d="M 325 145 V 175" class="arrow"/>
<!-- Step 2: Compute Advantages -->
<rect x="225" y="175" width="200" height="65" class="box"/>
<text x="325" y="200" class="label label-title">2. 计算优势估计</text>
<text x="325" y="220" class="label">$\hat{A}_t = Q(s_t, a_t) - V(s_t)$ (e.g., GAE)</text>
<!-- Arrow 2 -> 3 -->
<path d="M 325 240 V 270" class="arrow"/>
<!-- Step 3: Optimize Policy -->
<rect x="225" y="270" width="200" height="85" class="box"/>
<text x="325" y="295" class="label label-title">3. 优化策略 (Optimize Policy)</text>
<text x="325" y="315" class="label">最大化 PPO 截断目标:</text>
<text x="325" y="335" class="label">$L^{CLIP}(\theta) = \hat{\mathbb{E}}_t [\min(r_t(\theta)\hat{A}_t, \dots)] $</text>
<!-- Arrow 2 -> 4 -->
<path d="M 425 207.5 H 460" class="arrow"/>
<!-- Step 4: Update Value Function (Optional) -->
<rect x="460" y="175" width="160" height="65" class="box box-value"/>
<text x="540" y="200" class="label label-title">4. 更新价值函数 (Opt.)</text>
<text x="540" y="220" class="label">最小化 $L^{VF} = (V_\phi(s_t) - R_t)^2$</text>
<!-- Loop Back Arrow: 4 -> 1 (Removed) -->
<!-- Loop Back Arrow: 3 -> 1 -->
<path d="M 225 312.5 H 150 Q 100 200 150 87.5 H 225" class="loop-arrow"/>
</svg>
</div>
<div class="algorithm">
<h3>DPO (Direct Preference Optimization)</h3>
<p>DPO 是一种更新颖的方法,它绕过了显式训练奖励模型的步骤,直接从偏好数据中优化语言模型。它推导出了一种直接从偏好数据计算最优策略的方法,简化了 RLHF 流程。</p>
<p><strong>核心思想:</strong> DPO 表明,可以通过一个简单的目标函数,直接使用偏好数据(哪个回答更好)来优化策略(LLM),而不需要先拟合一个奖励模型。它隐式地定义和优化了奖励函数。</p>
<p><strong>优点:</strong> 相比 PPO,DPO 通常更简单、更稳定,且效果相当甚至更好。</p>
<h4>详细训练步骤:</h4>
<ol>
<li><strong>初始化:</strong> 从预训练好的 SFT 模型 $\pi_{\text{ref}}$ (参考策略,通常是 SFT 模型本身) 开始。</li>
<li><strong>准备偏好数据:</strong> 收集人类偏好数据,形式为 $(x, y_w, y_l)$,其中 $x$ 是 prompt,$y_w$ 是更受偏好的回答 (winner),$y_l$ 是不太受偏好的回答 (loser)。</li>
<li><strong>计算隐式奖励:</strong> DPO 的关键在于它推导出可以直接从策略概率计算隐式奖励差。对于一对偏好 $(y_w, y_l)$,模型 $\pi_\theta$ 和参考模型 $\pi_{\text{ref}}$ 的对数概率差定义为:
$$\hat{r}_\theta(x, y) = \beta \log \frac{\pi_\theta(y|x)}{\pi_{\text{ref}}(y|x)}$$
其中 $\beta$ 是一个超参数,控制与参考策略的偏差程度。</li>
<li><strong>定义损失函数:</strong> DPO 的目标是最大化模型 $\pi_\theta$ 对偏好回答 $y_w$ 的概率,同时最小化对 $y_l$ 的概率。损失函数通常采用类似二元交叉熵的形式,基于隐式奖励差:
$$L_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) = -\mathbb{E}_{(x, y_w, y_l)} \left[ \log \sigma(\hat{r}_\theta(x, y_w) - \hat{r}_\theta(x, y_l)) \right]$$
$$= -\mathbb{E}_{(x, y_w, y_l)} \left[ \log \sigma\left(\beta \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)}\right) \right]$$
其中 $\sigma$ 是 sigmoid 函数。这个损失函数鼓励模型 $\pi_\theta$ 使得 $y_w$ 的隐式奖励高于 $y_l$ 的隐式奖励。</li>
<li><strong>模型更新:</strong> 使用梯度下降法最小化 $L_{\text{DPO}}$,更新模型 $\pi_\theta$ 的参数。</li>
<li><strong>重复:</strong> 使用新的数据或重复使用现有数据进行多轮更新,直到模型收敛。</li>
</ol>
<h4>与 PPO 的主要不同:</h4>
<ul>
<li><strong>无需显式奖励模型 (RM):</strong> DPO 最显著的特点是不需要预先训练一个独立的奖励模型。它直接从偏好对 $(y_w, y_l)$ 中学习,隐式地定义了奖励。</li>
<li><strong>优化目标不同:</strong> PPO 优化的是期望累积奖励(加上 KL 惩罚),而 DPO 直接优化一个基于偏好对概率的损失函数。</li>
<li><strong>实现复杂度:</strong> DPO 通常比 PPO 实现起来更简单,因为它省去了训练和维护奖励模型以及价值模型的步骤,也避免了复杂的采样和优势估计过程。</li>
<li><strong>稳定性:</strong> DPO 通常被认为比 PPO 更稳定,不易出现模式崩溃等问题。</li>
</ul>
<div class="diagram">
<p><strong>图示:DPO 流程</strong></p>
<svg width="500" height="250" xmlns="http://www.w3.org/2000/svg" style="display: block; margin: auto;">
<style>
.box { fill: #fffbe6; stroke: #faad14; stroke-width: 1.5; rx: 5; ry: 5; }
.arrow { stroke: #333; stroke-width: 1.5; fill: none; marker-end: url(#arrowhead-dpo); }
.label { font-size: 13px; text-anchor: middle; }
.step-label { font-size: 14px; text-anchor: middle; font-weight: bold; fill: #d46b08; }
</style>
<defs>
<marker id="arrowhead-dpo" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill="#333" />
</marker>
</defs>
<!-- Input: SFT Model -->
<rect x="40" y="80" width="120" height="60" class="box"/>
<text x="100" y="110" class="label">SFT 模型</text>
<text x="100" y="125" class="label">(参考策略 π_ref)</text>
<!-- Input: Preference Data -->
<rect x="190" y="20" width="120" height="60" class="box"/>
<text x="250" y="50" class="label">人类偏好数据</text>
<text x="250" y="65" class="label">(y_w, y_l)</text>
<!-- DPO Optimization -->
<circle cx="250" cy="150" r="50" fill="#fff0f6" stroke="#eb2f96" stroke-width="1.5"/>
<text x="250" y="155" class="label">DPO 优化</text>
<text x="250" y="170" class="label">(直接优化策略)</text>
<!-- Output: Optimized LLM -->
<rect x="340" y="80" width="120" height="60" class="box" fill="#d9f7be" stroke="#52c41a"/>
<text x="400" y="110" class="label">优化后的 LLM</text>
<text x="400" y="125" class="label">(策略 π_θ)</text>
<!-- Arrows -->
<path d="M 160 110 H 195 C 200 110 200 140 200 150" class="arrow"/> <!-- SFT to DPO -->
<path d="M 250 80 V 100" class="arrow"/> <!-- Preference Data to DPO -->
<path d="M 300 150 C 300 140 305 110 340 110" class="arrow"/> <!-- DPO to Optimized LLM -->
</svg>
<p style="text-align:center; font-size: 12px; margin-top: 5px;">DPO 使用偏好对 (赢家 y_w, 输家 y_l) 直接更新 SFT 模型,无需显式奖励模型。</p>
</div>
</div>
<h3>算法比较</h3>
<table class="comparison-table">
<thead>
<tr>
<th>特性</th>
<th>PPO (RLHF)</th>
<th>DPO</th>
<th>GRPO</th>
</tr>
</thead>
<tbody>
<tr>
<td><strong>需要奖励模型</strong></td>
<td>是 (显式训练)</td>
<td>否 (隐式)</td>
<td>否 (隐式,但更广义)</td>
</tr>
<tr>
<td><strong>优化目标</strong></td>
<td>最大化奖励 + KL 惩罚</td>
<td>直接优化偏好概率</td>
<td>直接优化广义偏好概率</td>
</tr>
<tr>
<td><strong>实现复杂度</strong></td>
<td>高 (需要训练 RM 和 RL)</td>
<td>中等</td>
<td>中等到高 (取决于具体形式)</td>
</tr>
<tr>
<td><strong>稳定性</strong></td>
<td>中等 (对超参敏感)</td>
<td>较高</td>
<td>较高 (理论上)</td>
</tr>
<tr>
<td><strong>灵活性</strong></td>
<td>中等</td>
<td>较低 (主要处理 pairwise 偏好)</td>
<td>高 (可处理复杂偏好)</td>
</tr>
</tbody>
</table>
</section>
</div>
</body>
</html>