Skip to content

Commit a0cebf2

Browse files
author
Wish
committed
update template
1 parent 5604911 commit a0cebf2

File tree

2 files changed

+26
-44
lines changed

2 files changed

+26
-44
lines changed

src/tensorRT/common/trt_tensor.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -454,21 +454,25 @@ namespace TRT{
454454
return *this;
455455
}
456456

457-
int Tensor::offset(const std::vector<int>& index){
458-
459-
Assert(index.size() <= shape_.size());
457+
int Tensor::offset(size_t size, const int* index_array){
458+
459+
Assert(size <= shape_.size());
460460
int value = 0;
461461
for(int i = 0; i < shape_.size(); ++i){
462462

463-
if(i < index.size())
464-
value += index[i];
463+
if(i < size)
464+
value += index_array[i];
465465

466466
if(i + 1 < shape_.size())
467467
value *= shape_[i+1];
468468
}
469469
return value;
470470
}
471471

472+
int Tensor::offset(const std::vector<int>& index_array){
473+
return offset(index_array.size(), index_array.data());
474+
}
475+
472476
Tensor& Tensor::set_norm_mat(int n, const cv::Mat& image, float mean[3], float std[3]) {
473477

474478
Assert(image.channels() == 3 && !image.empty() && type() == DataType::Float);

src/tensorRT/common/trt_tensor.hpp

Lines changed: 17 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,9 @@ namespace TRT {
102102
bool empty();
103103

104104
template<typename ... _Args>
105-
Tensor& resize(int t, _Args&& ... args){
106-
resized_dim_.clear();
107-
return resize_impl(t, args...);
105+
Tensor& resize(int dim_size, _Args ... dim_size_args){
106+
const int dim_size_array[] = {dim_size, dim_size_args...};
107+
return resize(sizeof...(dim_size_args) + 1, dim_size_array);
108108
}
109109

110110
Tensor& resize(int ndims, const int* dims);
@@ -121,29 +121,30 @@ namespace TRT {
121121
inline void* gpu() const { ((Tensor*)this)->to_gpu(); return data_->gpu(); }
122122

123123
template<typename ... _Args>
124-
int offset(int t, _Args&& ... args){
125-
offset_index_.clear();
126-
return offset_impl(t, args...);
124+
int offset(int index, _Args ... index_args){
125+
const int index_array[] = {index, index_args...};
126+
return offset(sizeof...(index_args) + 1, index_array);
127127
}
128128

129129
int offset(const std::vector<int>& index);
130+
int offset(size_t size, const int* index_array);
130131

131-
template<typename DataT> inline const DataT* cpu() const { return (DataT*)cpu(); }
132-
template<typename DataT> inline DataT* cpu() { return (DataT*)cpu(); }
132+
template<typename DType> inline const DType* cpu() const { return (DType*)cpu(); }
133+
template<typename DType> inline DType* cpu() { return (DType*)cpu(); }
133134

134-
template<typename DataT, typename ... _Args>
135-
inline DataT* cpu(int t, _Args&& ... args) { return cpu<DataT>() + offset(t, args...); }
135+
template<typename DType, typename ... _Args>
136+
inline DType* cpu(int t, _Args&& ... args) { return cpu<DType>() + offset(t, args...); }
136137

137138

138-
template<typename DataT> inline const DataT* gpu() const { return (DataT*)gpu(); }
139-
template<typename DataT> inline DataT* gpu() { return (DataT*)gpu(); }
139+
template<typename DType> inline const DType* gpu() const { return (DType*)gpu(); }
140+
template<typename DType> inline DType* gpu() { return (DType*)gpu(); }
140141

141-
template<typename DataT, typename ... _Args>
142-
inline DataT* gpu(int t, _Args&& ... args) { return gpu<DataT>() + offset(t, args...); }
142+
template<typename DType, typename ... _Args>
143+
inline DType* gpu(int t, _Args&& ... args) { return gpu<DType>() + offset(t, args...); }
143144

144145

145-
template<typename DataT, typename ... _Args>
146-
inline DataT& at(int t, _Args&& ... args) { return *(cpu<DataT>() + offset(t, args...)); }
146+
template<typename DType, typename ... _Args>
147+
inline DType& at(int t, _Args&& ... args) { return *(cpu<DType>() + offset(t, args...)); }
147148

148149
std::shared_ptr<MixMemory> get_data() {return data_;}
149150
std::shared_ptr<MixMemory> get_workspace() {return workspace_;}
@@ -192,34 +193,11 @@ namespace TRT {
192193
bool save_to_file(const std::string& file);
193194

194195
private:
195-
Tensor& resize_impl(int value){
196-
resized_dim_.push_back(value);
197-
return resize(resized_dim_);
198-
}
199-
200-
template<typename ... _Args>
201-
Tensor& resize_impl(int t, _Args&& ... args){
202-
resized_dim_.push_back(t);
203-
return resize_impl(args...);
204-
}
205-
206-
int offset_impl(int value){
207-
offset_index_.push_back(value);
208-
return offset(offset_index_);
209-
}
210-
211-
template<typename ... _Args>
212-
int offset_impl(int t, _Args&& ... args){
213-
offset_index_.push_back(t);
214-
return offset_impl(args...);
215-
}
216-
217196
Tensor& compute_shape_string();
218197
Tensor& adajust_memory_by_update_dims_or_type();
219198
void setup_data(std::shared_ptr<MixMemory> data);
220199

221200
private:
222-
std::vector<int> resized_dim_, offset_index_;
223201
std::vector<int> shape_;
224202
std::vector<size_t> strides_;
225203
size_t bytes_ = 0;

0 commit comments

Comments
 (0)