Commit e686a10

refactor: simplify atb context and workspace call chain. (jd-opensource#141)
1 parent 11cd05a commit e686a10

67 files changed: +386 −565 lines changed

xllm/core/framework/CMakeLists.txt

Lines changed: 12 additions & 0 deletions

@@ -65,3 +65,15 @@ cc_library(
     :tokenizer
     torch
 )
+
+cc_library(
+  NAME
+    model_context
+  HDRS
+    model_context.h
+  SRCS
+    model_context.cpp
+  DEPS
+    torch
+    $<$<BOOL:${USE_NPU}>:torch_npu>
+)
xllm/core/framework/model_context.cpp

Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "core/framework/model_context.h"
+
+#include <torch/torch.h>
+#if defined(USE_NPU)
+#ifdef TORCH_HIGHER_THAN_PTA6
+// #include <torch_npu/csrc/core/npu/NPUFormat.h>
+#include <torch_npu/csrc/framework/OpCommand.h>
+#else
+#include <torch_npu/csrc/aten/NPUNativeFunctions.h>
+#include <torch_npu/csrc/framework/utils/OpPreparation.h>
+#endif
+#include <torch_npu/csrc/libs/init_npu.h>
+#endif
+
+namespace xllm {
+ModelContext::ModelContext(const ParallelArgs& input_parallel_args,
+                           const ModelArgs& model_args,
+                           const QuantArgs& quant_args,
+                           const torch::TensorOptions& tensor_options)
+    : parallel_args_(input_parallel_args),
+      model_args_(model_args),
+      quant_args_(quant_args),
+      tensor_options_(tensor_options) {
+#if defined(USE_NPU)
+  int32_t device_id = tensor_options.device().index();
+  void* stream = c10_npu::getCurrentNPUStream(device_id).stream();
+  atb::CreateContext(&context_);
+  context_->SetExecuteStream(stream);
+  context_->SetAsyncTilingCopyStatus(true);
+#endif
+}
+}  // namespace xllm
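
A minimal sketch (not part of this commit) of how the new ModelContext is expected to be built once per device and then shared with layers; the helper name and the ParallelArgs values are placeholders for illustration, mirroring the default constructor in the header below.

#include <torch/torch.h>
#include "core/framework/model_context.h"

// Hypothetical helper; argument values are illustrative only.
void build_model_context_example(const xllm::ModelArgs& model_args,
                                 const xllm::QuantArgs& quant_args,
                                 const torch::TensorOptions& options) {
  // Placeholder rank/world_size, mirroring ModelContext's default constructor.
  xllm::ParallelArgs parallel_args(1, 1, nullptr);
  xllm::ModelContext context(parallel_args, model_args, quant_args, options);
#if defined(USE_NPU)
  // On NPU builds the constructor creates the atb::Context and binds it to
  // the device's current stream, so layers no longer receive one per call.
  const atb::Context* atb_ctx = context.get_atb_context();
  (void)atb_ctx;
#endif
}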
xllm/core/framework/model_context.h

Lines changed: 29 additions & 22 deletions

@@ -15,6 +15,10 @@ limitations under the License.
 
 #pragma once
 
+#if defined(USE_NPU)
+#include <acl/acl.h>
+#endif
+
 #include <memory>
 
 #include "core/framework/model/model_args.h"
@@ -23,38 +27,41 @@ limitations under the License.
 
 namespace xllm {
 
-class Context {
+class ModelContext {
  public:
-  Context(const ParallelArgs& input_parallel_args)
-      : parallel_args(input_parallel_args) {}
+  ModelContext() : parallel_args_(1, 1, nullptr) {};
 
-  const ModelArgs& get_model_args() const { return model_args; }
-  void set_model_args(const ModelArgs& model_args) {
-    this->model_args = model_args;
-  }
+  ModelContext(const ParallelArgs& input_parallel_args,
+               const ModelArgs& model_args,
+               const QuantArgs& quant_args,
+               const torch::TensorOptions& tensor_options);
 
-  const QuantArgs& get_quant_args() const { return quant_args; }
-  void set_quant_args(const QuantArgs& quant_args) {
-    this->quant_args = quant_args;
-  }
+  const ModelArgs& get_model_args() const { return model_args_; }
+
+  const QuantArgs& get_quant_args() const { return quant_args_; }
 
-  const ParallelArgs& get_parallel_args() const { return parallel_args; }
-  // void set_paralle_args(const ParallelArgs& parallel_args) {
-  //   this->parallel_args = parallel_args;
-  // }
+  const ParallelArgs& get_parallel_args() const { return parallel_args_; }
 
   const torch::TensorOptions& get_tensor_options() const {
-    return tensor_options;
+    return tensor_options_;
   }
-  void set_tensor_options(const torch::TensorOptions& tensor_options) {
-    this->tensor_options = tensor_options;
+
+  const atb::Context* get_atb_context() const { return context_; }
+
+  void set_image_embedding_mode(bool image_embedding_mode) {
+    model_args_.image_embedding_mode() = image_embedding_mode;
   }
 
  private:
-  ModelArgs model_args;
-  QuantArgs quant_args;
-  ParallelArgs parallel_args;
-  torch::TensorOptions tensor_options;
+  ModelArgs model_args_;
+  QuantArgs quant_args_;
+  ParallelArgs parallel_args_;
+  torch::TensorOptions tensor_options_;
+
+#if defined(USE_NPU)
+  // used for npu atb
+  atb::Context* context_;
+#endif
 };
 
 }  // namespace xllm
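
With the set_* methods removed, the arguments are fixed at construction time; the one remaining mutator is set_image_embedding_mode(). A short sketch of a hypothetical multimodal call site (not in this commit):

// Hypothetical call site; the function name is illustrative only.
void enable_image_embedding(xllm::ModelContext& context) {
  // Everything else on ModelContext is fixed when it is constructed.
  context.set_image_embedding_mode(true);
}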

xllm/core/layers/npu/atb_base.cpp

Lines changed: 6 additions & 5 deletions

@@ -20,10 +20,11 @@ limitations under the License.
 namespace xllm::hf {
 static std::atomic<bool> g_executeOk(true);
 
-ATBBase::ATBBase(const Context& context)
+ATBBase::ATBBase(const ModelContext& context)
     : device_(context.get_tensor_options().device()),
       name_(""),
       parallel_args_(context.get_parallel_args()) {
+  context_ = const_cast<atb::Context*>(context.get_atb_context());
   auto quant_args = context.get_quant_args();
   if (!quant_args.quantize_type().empty()) {
     quantize_type_ = quant_args.quantize_type();
@@ -39,6 +40,8 @@ ATBBase::ATBBase(const Context& context)
   CHECK_EQ(parallel_args_.world_size(), dp_size_ * dp_local_tp_size_);
   dp_local_tp_rank_ = parallel_args_.rank() % dp_local_tp_size_;
 
+  work_space_ = AtbWorkspace(device_);
+
   runTaskFunc_ = std::bind(
       &ATBBase::run_task, this, std::placeholders::_1, std::placeholders::_2);
 }
@@ -195,8 +198,6 @@ void ATBBase::run_task(std::string taskName, std::function<int()> task) const {
 }
 
 atb::Status ATBBase::execute_node(atb_speed::Model::Node& node,
-                                  atb::Context* context,
-                                  AtbWorkspace& workspace,
                                   int nodeId,
                                   aclrtEvent* event,
                                   std::atomic<bool>* event_flag) {
@@ -208,7 +209,7 @@ atb::Status ATBBase::execute_node(atb_speed::Model::Node& node,
        << std::endl;
     throw std::runtime_error(ss.str());
   }
-  context_ = context;
+
   atb::Status st =
       node.operation->Setup(node.variantPack, node.workspaceSize, context_);
   if (st != 0) {
@@ -217,7 +218,7 @@ atb::Status ATBBase::execute_node(atb_speed::Model::Node& node,
   }
 
   if (node.workspaceSize > 0) {
-    node.workspace = workspace.GetWorkspaceBuffer(node.workspaceSize);
+    node.workspace = work_space_.GetWorkspaceBuffer(node.workspaceSize);
   }
 
   runTaskFunc_(name_ + std::to_string(nodeId), [=]() {
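
The net effect on derived layers, shown as a rough sketch (the ExampleLayer class, node_, and out_ members are invented for illustration): the ATB context and workspace are cached by the ATBBase constructor, so execute_node() no longer takes them as parameters.

#include "atb_base.h"  // sketch assumes the layer lives next to ATBBase

namespace xllm::hf {
// Hypothetical layer built on ATBBase after this change.
class ExampleLayer : public ATBBase {
 public:
  explicit ExampleLayer(const ModelContext& context) : ATBBase(context) {}

  torch::Tensor forward(const torch::Tensor& input, int nodeId) {
    build_variant_pack(input);
    // context_ and work_space_ were captured in the ATBBase constructor,
    // so they no longer travel through the call chain.
    atb::Status st = execute_node(node_, nodeId);
    LOG_IF(FATAL, st != 0) << "execute example node fail, error code: " << st;
    return out_;
  }

 private:
  void build_variant_pack(const torch::Tensor& input);
  atb_speed::Model::Node node_;
  torch::Tensor out_;
};
}  // namespace xllm::hf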

xllm/core/layers/npu/atb_base.h

Lines changed: 3 additions & 4 deletions

@@ -29,9 +29,9 @@ limitations under the License.
 
 #include "atb/atb_infer.h"
 #include "buffer/atb_workspace.h"
-#include "framework/context.h"
 #include "framework/kv_cache/kv_cache.h"
 #include "framework/model/model_input_params.h"
+#include "framework/model_context.h"
 #include "framework/state_dict/state_dict.h"
 #include "pytorch/adapter/utils/utils.h"
 #include "pytorch/adapter/workspace/workspace.h"
@@ -97,7 +97,7 @@ enum class LinearTypeV2 : int {
 
 class ATBBase {
  public:
-  ATBBase(const Context& context);
+  ATBBase(const ModelContext& context);
   virtual ~ATBBase() {};
 
   using Task = std::function<int()>;
@@ -132,8 +132,6 @@ class ATBBase {
   // void get_sharded(at::Tensor weight_tensor,int dim);
 
   atb::Status execute_node(atb_speed::Model::Node& node,
-                           atb::Context* context,
-                           AtbWorkspace& workspace,
                            int nodeId = 0,
                            aclrtEvent* event = nullptr,
                            std::atomic<bool>* event_flag = nullptr);
@@ -152,6 +150,7 @@ class ATBBase {
 
  protected:
   atb::Context* context_;
+  AtbWorkspace work_space_;
   std::vector<at::Tensor> at_weight_tensors_;
   std::vector<atb::Tensor> atb_weight_tensors_;
   at::Device device_;

xllm/core/layers/npu/atb_head_impl.cpp

Lines changed: 2 additions & 14 deletions

@@ -67,7 +67,7 @@ void AtbLmHeadImpl::param_from_args(atb_speed::common::LmHeadParam& param,
   }
 }
 
-AtbLmHeadImpl::AtbLmHeadImpl(const Context& context) : ATBBase(context) {
+AtbLmHeadImpl::AtbLmHeadImpl(const ModelContext& context) : ATBBase(context) {
   param_from_args(llm_head_param_prefill_,
                   context.get_model_args(),
                   context.get_parallel_args(),
@@ -161,22 +161,10 @@ int64_t AtbLmHeadImpl::init_node(atb_speed::Model::Node& node,
 
 torch::Tensor AtbLmHeadImpl::forward(const torch::Tensor& hidden_states,
                                      const torch::Tensor& seleted_idxes,
-                                     atb::Context* context,
-                                     AtbWorkspace& workspace,
                                      int nodeId) {
   atb::Status st;
   build_node_variant_pack(llm_head_node_prefill_, hidden_states, seleted_idxes);
-  st = execute_node(llm_head_node_prefill_, context, workspace, nodeId);
-  // if (is_prefill) {
-  //   build_node_variant_pack(llm_head_node_prefill_,
-  //   hidden_states,seleted_idxes); st = execute_node(llm_head_node_prefill_,
-  //   context, workspace ,nodeId);
-  // } else {
-  //   build_node_variant_pack(llm_head_node_decode_,
-  //   hidden_states,seleted_idxes); st = execute_node(llm_head_node_decode_,
-  //   context, workspace ,nodeId);
-  // }
-  // c10_npu::NPUCachingAllocator::emptyCache();
+  st = execute_node(llm_head_node_prefill_, nodeId);
   LOG_IF(FATAL, st != 0) << model_name_
                          << "execute llmhead node fail, error code: " << st;
   return atOutTensors_[0];
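
A hedged sketch of a caller of the trimmed lm-head interface (the helper name and tensor names are illustrative; only the new three-argument signature is taken from the diff):

#include "atb_head_impl.h"  // assumed include for this sketch

// Hypothetical call site: atb::Context* and AtbWorkspace& no longer appear.
torch::Tensor compute_logits_example(xllm::hf::AtbLmHeadImpl& lm_head,
                                     const torch::Tensor& hidden_states,
                                     const torch::Tensor& selected_idxes) {
  return lm_head.forward(hidden_states, selected_idxes, /*nodeId=*/0);
}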

xllm/core/layers/npu/atb_head_impl.h

Lines changed: 2 additions & 4 deletions

@@ -28,8 +28,8 @@ limitations under the License.
 
 #include "atb/atb_infer.h"
 #include "atb_base.h"
-#include "framework/context.h"
 #include "framework/model/model_input_params.h"
+#include "framework/model_context.h"
 #include "layers/npu/llm_head.h"
 #include "nlohmann/json.hpp"
 #include "pytorch/adapter/utils/utils.h"
@@ -47,7 +47,7 @@ class AtbLmHeadImpl : public LlmHeadImpl, public ATBBase {
   using RunTaskFunc =
       std::function<void(const std::string& taskName, Task task)>;
 
-  explicit AtbLmHeadImpl(const Context& context);
+  explicit AtbLmHeadImpl(const ModelContext& context);
 
   ~AtbLmHeadImpl() {};
 
@@ -66,8 +66,6 @@ class AtbLmHeadImpl : public LlmHeadImpl, public ATBBase {
 
   torch::Tensor forward(const torch::Tensor& hidden_states,
                         const torch::Tensor& seleted_idxes,
-                        atb::Context* context,
-                        AtbWorkspace& workspace,
                         int nodeId) override;
 
   // void build_node_variant_pack(atb_speed::Model::Node& node, torch::Tensor&

xllm/core/layers/npu/atb_linear.cpp

Lines changed: 6 additions & 8 deletions

@@ -23,11 +23,12 @@ limitations under the License.
 #include "xllm_kernels/operations/fusion/utils.h"
 
 namespace xllm::hf {
-std::shared_ptr<AtbLinearImpl> create_atb_linear_layer(const Context& context) {
+std::shared_ptr<AtbLinearImpl> create_atb_linear_layer(
+    const ModelContext& context) {
   return std::make_shared<AtbLinearImpl>(context);
 }
 
-AtbLinearImpl::AtbLinearImpl(const Context& context) : ATBBase(context) {
+AtbLinearImpl::AtbLinearImpl(const ModelContext& context) : ATBBase(context) {
   at_weight_tensors_.resize(1);
   atb_weight_tensors_.resize(1);
   at_out_tensors_.resize(1);
@@ -103,14 +104,11 @@ int64_t AtbLinearImpl::init_node(atb_speed::Model::Node& node) {
   return atb::NO_ERROR;
 }
 
-torch::Tensor AtbLinearImpl::forward(const torch::Tensor& input,
-                                     atb::Context* context,
-                                     AtbWorkspace& workspace,
-                                     int nodeId) {
+torch::Tensor AtbLinearImpl::forward(const torch::Tensor& input, int nodeId) {
   atb::Status st;
 
   build_node_variant_pack(linear_node_, input);
-  st = execute_node(linear_node_, context, workspace, nodeId);
+  st = execute_node(linear_node_, nodeId);
   LOG_IF(FATAL, st != 0) << model_name_
                          << "infer shape fail, error code: " << st;
 
@@ -156,7 +154,7 @@ void AtbLinearImpl::build_node_variant_pack(atb_speed::Model::Node& node,
       atb_speed::Utils::AtTensor2Tensor(at_out_tensors_.at(0));
 }
 
-AtbLinear::AtbLinear(const Context& context)
+AtbLinear::AtbLinear(const ModelContext& context)
     : ModuleHolder(create_atb_linear_layer(context)) {}
 
 }  // namespace xllm::hf

xllm/core/layers/npu/atb_linear.h

Lines changed: 6 additions & 8 deletions

@@ -29,9 +29,9 @@ limitations under the License.
 
 #include "atb/atb_infer.h"
 #include "atb_base.h"
-#include "framework/context.h"
 #include "framework/model/model_input_params.h"
+#include "framework/model_context.h"
 #include "framework/state_dict/state_dict.h"
 #include "nlohmann/json.hpp"
 #include "pytorch/adapter/utils/utils.h"
@@ -47,7 +47,7 @@ class AtbLinearImpl : public torch::nn::Module, public ATBBase {
   using RunTaskFunc =
       std::function<void(const std::string& taskName, Task task)>;
 
-  explicit AtbLinearImpl(const Context& context);
+  explicit AtbLinearImpl(const ModelContext& context);
 
   ~AtbLinearImpl() {};
 
@@ -59,10 +59,7 @@ class AtbLinearImpl : public torch::nn::Module, public ATBBase {
 
   int64_t init_layer();
 
-  torch::Tensor forward(const torch::Tensor& input,
-                        atb::Context* context,
-                        AtbWorkspace& workspace,
-                        int nodeId);
+  torch::Tensor forward(const torch::Tensor& input, int nodeId);
 
   void build_node_variant_pack(atb_speed::Model::Node& node,
                                const torch::Tensor& input);
@@ -83,9 +80,10 @@ class AtbLinear : public torch::nn::ModuleHolder<AtbLinearImpl> {
   using torch::nn::ModuleHolder<AtbLinearImpl>::ModuleHolder;
   using Impl __attribute__((__unused__)) = AtbLinearImpl;
 
-  AtbLinear(const Context& context);
+  AtbLinear(const ModelContext& context);
 };
 
-std::shared_ptr<AtbLinearImpl> create_atb_linear_layer(const Context& context);
+std::shared_ptr<AtbLinearImpl> create_atb_linear_layer(
+    const ModelContext& context);
 
 }  // namespace xllm::hf
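
A minimal usage sketch for the updated linear layer (the helper name and nodeId value are illustrative; the constructor and forward() signatures come from the diff above):

#include "atb_linear.h"  // assumed include for this sketch

// Hypothetical call site: the layer is built from a ModelContext and invoked
// without passing an atb::Context or AtbWorkspace.
void linear_example(const xllm::ModelContext& context,
                    const torch::Tensor& input) {
  xllm::hf::AtbLinear linear(context);  // ModuleHolder over AtbLinearImpl
  torch::Tensor out = linear->forward(input, /*nodeId=*/0);
  (void)out;
}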
