@@ -38,6 +38,35 @@ limitations under the License.
3838#include " util/utils.h"
3939
4040namespace xllm {
41+
42+ namespace {
43+ DiTCacheConfig parse_dit_cache_from_flags () {
44+ DiTCacheConfig cache_config;
45+ if (FLAGS_dit_cache_policy == " FBCache" ) {
46+ cache_config.selected_policy = PolicyType::FBCache;
47+ cache_config.fbcache .warmup_steps = FLAGS_dit_cache_warmup_steps;
48+ cache_config.fbcache .residual_diff_threshold =
49+ FLAGS_dit_cache_residual_diff_threshold;
50+ } else if (FLAGS_dit_cache_policy == " TaylorSeer" ) {
51+ cache_config.selected_policy = PolicyType::TaylorSeer;
52+ cache_config.taylorseer .n_derivatives = FLAGS_dit_cache_n_derivatives;
53+ cache_config.taylorseer .skip_interval_steps =
54+ FLAGS_dit_cache_skip_interval_steps;
55+ cache_config.taylorseer .warmup_steps = FLAGS_dit_cache_warmup_steps;
56+ } else if (FLAGS_dit_cache_policy == " FBCacheTaylorSeer" ) {
57+ cache_config.selected_policy = PolicyType::FBCacheTaylorSeer;
58+ cache_config.fbcachetaylorseer .n_derivatives =
59+ FLAGS_dit_cache_n_derivatives;
60+ cache_config.fbcachetaylorseer .warmup_steps = FLAGS_dit_cache_warmup_steps;
61+ cache_config.fbcachetaylorseer .residual_diff_threshold =
62+ FLAGS_dit_cache_residual_diff_threshold;
63+ } else if (FLAGS_dit_cache_policy == " None" ) {
64+ cache_config.selected_policy = PolicyType::TaylorSeer;
65+ }
66+ return cache_config;
67+ }
68+ } // namespace
69+
4170DiTWorker::DiTWorker (const ParallelArgs& parallel_args,
4271 const torch::Device& device,
4372 const runtime::Options& options)
@@ -65,17 +94,8 @@ bool DiTWorker::init_model(const std::string& model_weights_path) {
6594 dit_model_executor_ =
6695 std::make_unique<DiTExecutor>(dit_model_.get (), options_);
6796
68- DiTCacheConfig cache_config_;
69-
70- // TODO: Optimize ditcache configuration initialization.
71- cache_config_.selected_policy = PolicyType::TaylorSeer;
72- cache_config_.taylorseer .n_derivatives = 3 ;
73- cache_config_.taylorseer .skip_interval_steps = 3 ;
74- cache_config_.taylorseer .num_inference_steps = 25 ;
75- cache_config_.taylorseer .warmup_steps = 0 ;
76-
77- bool success = DiTCache::get_instance ().init (cache_config_);
78- CHECK (success) << " DiTCache init failed" ;
97+ DiTCacheConfig cache_config = parse_dit_cache_from_flags ();
98+ DiTCache::get_instance ().init (cache_config);
7999
80100 return true ;
81101}
0 commit comments