@@ -20,10 +20,11 @@ limitations under the License.
2020namespace xllm ::hf {
2121static std::atomic<bool > g_executeOk (true );
2222
23- ATBBase::ATBBase (const Context & context)
23+ ATBBase::ATBBase (const ModelContext & context)
2424 : device_(context.get_tensor_options().device()),
2525 name_ (" " ),
2626 parallel_args_(context.get_parallel_args()) {
27+ context_ = const_cast <atb::Context*>(context.get_atb_context ());
2728 auto quant_args = context.get_quant_args ();
2829 if (!quant_args.quantize_type ().empty ()) {
2930 quantize_type_ = quant_args.quantize_type ();
@@ -39,6 +40,8 @@ ATBBase::ATBBase(const Context& context)
3940 CHECK_EQ (parallel_args_.world_size (), dp_size_ * dp_local_tp_size_);
4041 dp_local_tp_rank_ = parallel_args_.rank () % dp_local_tp_size_;
4142
43+ work_space_ = AtbWorkspace (device_);
44+
4245 runTaskFunc_ = std::bind (
4346 &ATBBase::run_task, this , std::placeholders::_1, std::placeholders::_2);
4447}
@@ -195,8 +198,6 @@ void ATBBase::run_task(std::string taskName, std::function<int()> task) const {
195198}
196199
197200atb::Status ATBBase::execute_node (atb_speed::Model::Node& node,
198- atb::Context* context,
199- AtbWorkspace& workspace,
200201 int nodeId,
201202 aclrtEvent* event,
202203 std::atomic<bool >* event_flag) {
@@ -208,7 +209,7 @@ atb::Status ATBBase::execute_node(atb_speed::Model::Node& node,
208209 << std::endl;
209210 throw std::runtime_error (ss.str ());
210211 }
211- context_ = context;
212+
212213 atb::Status st =
213214 node.operation ->Setup (node.variantPack , node.workspaceSize , context_);
214215 if (st != 0 ) {
@@ -217,7 +218,7 @@ atb::Status ATBBase::execute_node(atb_speed::Model::Node& node,
217218 }
218219
219220 if (node.workspaceSize > 0 ) {
220- node.workspace = workspace .GetWorkspaceBuffer (node.workspaceSize );
221+ node.workspace = work_space_ .GetWorkspaceBuffer (node.workspaceSize );
221222 }
222223
223224 runTaskFunc_ (name_ + std::to_string (nodeId), [=]() {
0 commit comments