Skip to content

Commit ed437a6

Browse files
authored
feat: support parse dit cache config from runtime flags. (jd-opensource#473)
1 parent 25e16fa commit ed437a6

File tree

7 files changed

+63
-37
lines changed

7 files changed

+63
-37
lines changed

xllm/core/common/global_flags.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,4 +436,25 @@ DEFINE_bool(
436436
enable_dp_balance,
437437
false,
438438
"Whether to enable dp load balance, if true, sequences within a single "
439-
"dp batch will be shuffled.");
439+
"dp batch will be shuffled.");
440+
441+
// --- dit cache config ---
442+
443+
// Name of the DiT cache policy. parse_dit_cache_from_flags() in
// xllm/core/runtime/dit_worker.cpp compares this string with == against
// "FBCache", "TaylorSeer", "FBCacheTaylorSeer" and "None".
DEFINE_string(dit_cache_policy,
444+
"TaylorSeer",
445+
"The policy of dit cache(e.g. None, FBCache, TaylorSeer, "
446+
"FBCacheTaylorSeer).");
447+
448+
// Forwarded as warmup_steps to the selected policy (FBCache, TaylorSeer or
// FBCacheTaylorSeer); each policy's init() CHECKs that it is >= 0.
DEFINE_int64(dit_cache_warmup_steps, 0, "The number of warmup steps.");
449+
450+
// Forwarded as n_derivatives when the policy is TaylorSeer or
// FBCacheTaylorSeer; their init() CHECKs that it is >= 0.
DEFINE_int64(dit_cache_n_derivatives,
451+
3,
452+
"The number of derivatives to use in TaylorSeer.");
453+
454+
// Forwarded as skip_interval_steps only when the policy is TaylorSeer;
// TaylorSeer::init() CHECKs that it is >= 0.
DEFINE_int64(dit_cache_skip_interval_steps,
455+
3,
456+
"The interval steps to skip for derivative calculation.");
457+
458+
// Residual-difference threshold forwarded to the FBCache and
// FBCacheTaylorSeer policies; both CHECK at init() that it is >= 0.
// The default is spelled as a double literal on purpose: a float literal
// (0.09f) is widened to roughly 0.0900000036 when stored in a double flag,
// so the effective default would not be exactly 0.09.
DEFINE_double(dit_cache_residual_diff_threshold,
              0.09,
              "The residual difference threshold for cache reuse.");

xllm/core/common/global_flags.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,13 @@ DECLARE_bool(enable_prefetch_weight);
214214
DECLARE_int32(flashinfer_workspace_buffer_size);
215215

216216
DECLARE_bool(enable_dp_balance);
217+
218+
// DiT cache configuration flags. They are defined in
// xllm/core/common/global_flags.cpp and read by parse_dit_cache_from_flags()
// in xllm/core/runtime/dit_worker.cpp.
DECLARE_string(dit_cache_policy);
219+
220+
DECLARE_int64(dit_cache_warmup_steps);
221+
222+
DECLARE_int64(dit_cache_n_derivatives);
223+
224+
DECLARE_int64(dit_cache_skip_interval_steps);
225+
226+
DECLARE_double(dit_cache_residual_diff_threshold);

xllm/core/framework/dit_cache/dit_cache_config.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,6 @@ enum class PolicyType {
2525
};
2626

2727
struct DiTBaseCacheOptions {
28-
// the number of inference steps.
29-
int num_inference_steps = 25;
30-
3128
// the number of warmup steps.
3229
int warmup_steps = 0;
3330
};

xllm/core/framework/dit_cache/fbcache.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,9 @@ namespace xllm {
2020
void FBCache::init(const DiTCacheConfig& cfg) {
2121
CHECK_GE(cfg.fbcache.residual_diff_threshold, 0.0)
2222
<< "residual_diff_threshold must be >= 0";
23-
CHECK_GT(cfg.fbcache.num_inference_steps, 0)
24-
<< "num_inference_steps must be > 0";
2523
CHECK_GE(cfg.fbcache.warmup_steps, 0) << "warmup_steps must be >= 0";
26-
CHECK_LE(cfg.fbcache.warmup_steps, cfg.fbcache.num_inference_steps)
27-
<< "warmup_steps cannot exceed num_inference_steps";
2824

2925
residual_diff_threshold_ = cfg.fbcache.residual_diff_threshold;
30-
num_inference_steps_ = cfg.fbcache.num_inference_steps;
3126
warmup_steps_ = cfg.fbcache.warmup_steps;
3227
current_step_ = 0;
3328
use_cache_ = false;

xllm/core/framework/dit_cache/fbcache_taylorseer.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,12 @@ namespace xllm {
2020
void FBCacheTaylorSeer::init(const DiTCacheConfig& cfg) {
2121
CHECK_GE(cfg.fbcachetaylorseer.residual_diff_threshold, 0.0)
2222
<< "residual_diff_threshold must be >= 0";
23-
CHECK_GT(cfg.fbcachetaylorseer.num_inference_steps, 0)
24-
<< "num_inference_steps must be > 0";
2523
CHECK_GE(cfg.fbcachetaylorseer.warmup_steps, 0)
2624
<< "warmup_steps must be >= 0";
2725
CHECK_GE(cfg.fbcachetaylorseer.n_derivatives, 0)
2826
<< "n_derivatives must be >= 0";
29-
CHECK_LE(cfg.fbcachetaylorseer.warmup_steps,
30-
cfg.fbcachetaylorseer.num_inference_steps)
31-
<< "warmup_steps cannot exceed num_inference_steps";
3227

3328
residual_diff_threshold_ = cfg.fbcachetaylorseer.residual_diff_threshold;
34-
num_inference_steps_ = cfg.fbcachetaylorseer.num_inference_steps;
3529
warmup_steps_ = cfg.fbcachetaylorseer.warmup_steps;
3630

3731
if (!taylorseer) {
@@ -40,8 +34,6 @@ void FBCacheTaylorSeer::init(const DiTCacheConfig& cfg) {
4034

4135
DiTCacheConfig ts_cfg;
4236
ts_cfg.taylorseer.n_derivatives = cfg.fbcachetaylorseer.n_derivatives;
43-
ts_cfg.taylorseer.num_inference_steps =
44-
cfg.fbcachetaylorseer.num_inference_steps;
4537
taylorseer->init(ts_cfg);
4638
}
4739

xllm/core/framework/dit_cache/taylorseer.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,11 @@ double factorial(int k) { return std::tgamma(static_cast<double>(k) + 1.0); }
2222
} // namespace
2323

2424
void TaylorSeer::init(const DiTCacheConfig& cfg) {
25-
CHECK_GT(cfg.taylorseer.num_inference_steps, 0)
26-
<< "num_inference_steps must be > 0";
2725
CHECK_GE(cfg.taylorseer.warmup_steps, 0) << "warmup_steps must be >= 0";
2826
CHECK_GE(cfg.taylorseer.skip_interval_steps, 0)
2927
<< "skip_interval_steps must be >= 0";
3028
CHECK_GE(cfg.taylorseer.n_derivatives, 0) << "n_derivatives must be >= 0";
3129

32-
CHECK_LT(cfg.taylorseer.skip_interval_steps,
33-
cfg.taylorseer.num_inference_steps)
34-
<< "skip_interval_steps must be less than num_inference_steps";
35-
CHECK_LT(cfg.taylorseer.warmup_steps, cfg.taylorseer.num_inference_steps)
36-
<< "warmup_steps must be less than num_inference_steps";
37-
38-
num_inference_steps_ = cfg.taylorseer.num_inference_steps;
3930
warmup_steps_ = cfg.taylorseer.warmup_steps;
4031
skip_interval_steps_ = cfg.taylorseer.skip_interval_steps;
4132
n_derivatives_ = cfg.taylorseer.n_derivatives;

xllm/core/runtime/dit_worker.cpp

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,35 @@ limitations under the License.
3838
#include "util/utils.h"
3939

4040
namespace xllm {
41+
42+
namespace {
43+
// Builds the DiTCacheConfig handed to DiTCache::init() from the
// --dit_cache_* runtime flags.
//
// FLAGS_dit_cache_policy is compared verbatim (case-sensitive ==) against
// "FBCache", "TaylorSeer", "FBCacheTaylorSeer" and "None". Only the options
// belonging to the selected policy are filled in; every other field keeps
// the value it has in a default-constructed DiTCacheConfig.
//
// A policy string that matches none of the names above is not rejected: the
// default-constructed config is returned unchanged.
DiTCacheConfig parse_dit_cache_from_flags() {
  DiTCacheConfig cache_config;
  const auto& policy = FLAGS_dit_cache_policy;

  if (policy == "FBCache") {
    cache_config.selected_policy = PolicyType::FBCache;
    cache_config.fbcache.warmup_steps = FLAGS_dit_cache_warmup_steps;
    cache_config.fbcache.residual_diff_threshold =
        FLAGS_dit_cache_residual_diff_threshold;
  } else if (policy == "TaylorSeer") {
    cache_config.selected_policy = PolicyType::TaylorSeer;
    cache_config.taylorseer.n_derivatives = FLAGS_dit_cache_n_derivatives;
    cache_config.taylorseer.skip_interval_steps =
        FLAGS_dit_cache_skip_interval_steps;
    cache_config.taylorseer.warmup_steps = FLAGS_dit_cache_warmup_steps;
  } else if (policy == "FBCacheTaylorSeer") {
    cache_config.selected_policy = PolicyType::FBCacheTaylorSeer;
    cache_config.fbcachetaylorseer.n_derivatives =
        FLAGS_dit_cache_n_derivatives;
    cache_config.fbcachetaylorseer.warmup_steps = FLAGS_dit_cache_warmup_steps;
    cache_config.fbcachetaylorseer.residual_diff_threshold =
        FLAGS_dit_cache_residual_diff_threshold;
  } else if (policy == "None") {
    // "None" must not select an active policy: picking TaylorSeer here would
    // make --dit_cache_policy=None behave like --dit_cache_policy=TaylorSeer.
    // NOTE(review): assumes PolicyType declares a `None` enumerator (the enum
    // body is not visible from this file) -- confirm.
    cache_config.selected_policy = PolicyType::None;
  }

  return cache_config;
}
68+
} // namespace
69+
4170
DiTWorker::DiTWorker(const ParallelArgs& parallel_args,
4271
const torch::Device& device,
4372
const runtime::Options& options)
@@ -65,17 +94,8 @@ bool DiTWorker::init_model(const std::string& model_weights_path) {
6594
dit_model_executor_ =
6695
std::make_unique<DiTExecutor>(dit_model_.get(), options_);
6796

68-
DiTCacheConfig cache_config_;
69-
70-
// TODO: Optimize ditcache configuration initialization.
71-
cache_config_.selected_policy = PolicyType::TaylorSeer;
72-
cache_config_.taylorseer.n_derivatives = 3;
73-
cache_config_.taylorseer.skip_interval_steps = 3;
74-
cache_config_.taylorseer.num_inference_steps = 25;
75-
cache_config_.taylorseer.warmup_steps = 0;
76-
77-
bool success = DiTCache::get_instance().init(cache_config_);
78-
CHECK(success) << "DiTCache init failed";
97+
DiTCacheConfig cache_config = parse_dit_cache_from_flags();
98+
DiTCache::get_instance().init(cache_config);
7999

80100
return true;
81101
}

0 commit comments

Comments
 (0)