16
16
import numpy as np
17
17
import tabulate
18
18
import torch
19
-
20
19
from fbgemm_gpu .split_embedding_configs import SparseType
21
20
from fbgemm_gpu .split_table_batched_embeddings_ops_common import (
22
21
BoundsCheckMode ,
@@ -99,46 +98,72 @@ def generate_requests(
99
98
return rs
100
99
101
100
102
- # pyre-fixme[3]: Return type must be annotated.
103
101
def _get_random_tensor(
    num_ads: int,
    embedding_dimension: int,
    ads_tables: int,
    data_type: str,
    gpu_idx: int,
    include_quantization: bool,
    use_pitched: bool,
    alignment: int = 256,  # alignment in bytes
) -> torch.Tensor:
    """Allocate a random per-GPU input tensor for the benchmark.

    One row per ad; the row width depends on ``data_type``:
      * "FP16" (or any type when ``include_quantization``): ``embedding_dimension * ads_tables``
        fp16 values drawn from ``randn``.
      * "INT8": ``(embedding_dimension + 4) * ads_tables`` uint8 values — the
        +4 bytes per table hold the two fp16 scale/bias values.
      * "INT4": ``(embedding_dimension // 2 + 4) * ads_tables`` uint8 values
        (two int4 elements packed per byte, plus the 4-byte scale/bias).

    When ``use_pitched`` is set, each row's backing storage is padded up to a
    multiple of ``alignment`` bytes (as with ``cudaMallocPitch``) and the
    returned tensor is the logical (unpadded) view into that storage — so the
    result is non-contiguous with stride equal to the pitch.

    Raises:
        ValueError: for any unrecognized ``data_type``.
    """
    device = torch.device(f"cuda:{gpu_idx}")

    def _pitched(width_elems: int, elem_size: int, fill) -> torch.Tensor:
        # Round the row width up to the next multiple of `alignment` bytes
        # using exact integer ceil-division (the previous float-based
        # int(np.ceil(w / a) * a) is equivalent but needlessly lossy),
        # fill the padded storage, and return the logical unpadded view.
        width_bytes = width_elems * elem_size
        pitch_bytes = -(-width_bytes // alignment) * alignment
        pitch_elems = pitch_bytes // elem_size
        storage = fill((num_ads, pitch_elems))
        return storage[:, :width_elems]  # logical view over pitched storage

    if data_type == "FP16" or include_quantization:
        dtype = torch.float16
        width_elems = embedding_dimension * ads_tables
        elem_size = torch.finfo(dtype).bits // 8

        def fill(size):
            return torch.randn(size, dtype=dtype, device=device)

        # Fix: the pitched path previously allocated with torch.empty,
        # leaving the buffer uninitialized (possible NaN/Inf/denormals that
        # skew downstream quantization timings), while the non-pitched path
        # used randn. Fill the pitched storage with randn as well, matching
        # how the INT8/INT4 pitched paths fill theirs with randint.
        if use_pitched:
            result_tensor = _pitched(width_elems, elem_size, fill)
        else:
            result_tensor = fill((num_ads, width_elems))

    elif data_type == "INT8":
        assert embedding_dimension % 2 == 0, "needs to align to 2 bytes for INT8"
        # 2 FP16 numbers for scale and bias, total of 4 bytes overhead
        width_elems = (embedding_dimension + 4) * ads_tables

        def fill(size):
            return torch.randint(0, 255, size, dtype=torch.uint8, device=device)

        if use_pitched:
            result_tensor = _pitched(width_elems, 1, fill)
        else:
            result_tensor = fill((num_ads, width_elems))

    elif data_type == "INT4":
        assert embedding_dimension % 4 == 0, "needs to align to 2 bytes for INT4"
        # Using torch.uint8 for int4 storage: two int4 elements per byte,
        # plus the 4-byte scale/bias overhead per table.
        width_elems = (embedding_dimension // 2 + 4) * ads_tables

        def fill(size):
            return torch.randint(0, 255, size, dtype=torch.uint8, device=device)

        if use_pitched:
            result_tensor = _pitched(width_elems, 1, fill)
        else:
            result_tensor = fill((num_ads, width_elems))

    else:
        raise ValueError

    # NOTE(review): the return falls just outside the visible diff hunk; the
    # branches above all produce `result_tensor`, so returning it here is the
    # only coherent completion — confirm against the full file.
    return result_tensor
144
169
@@ -253,6 +278,7 @@ def benchmark( # noqa C901
253
278
num_ads : int ,
254
279
embedding_dimension : int ,
255
280
ads_tables : int ,
281
+ use_pitched : bool ,
256
282
iters : int = 10 ,
257
283
p2p_bw : bool = False ,
258
284
dst_device : int = 0 ,
@@ -298,6 +324,7 @@ def benchmark( # noqa C901
298
324
data_type ,
299
325
gpu_idx ,
300
326
include_quantization ,
327
+ use_pitched ,
301
328
)
302
329
for gpu_idx in range (num_gpus )
303
330
]
@@ -485,6 +512,7 @@ def pool_func_with_quantization(
485
512
@click .option ("--num_of_embeddings" , default = 100000 , type = int )
486
513
@click .option ("--pooling_factor" , default = 25 , type = int )
487
514
@click .option ("--sweep" , is_flag = True , default = False )
515
+ @click .option ("--use_pitched" , is_flag = True , default = False )
488
516
def cli (
489
517
all_to_one_only : bool ,
490
518
sum_reduce_to_one_only : bool ,
@@ -500,6 +528,7 @@ def cli(
500
528
num_of_embeddings : int ,
501
529
pooling_factor : int ,
502
530
sweep : bool ,
531
+ use_pitched : bool ,
503
532
) -> None :
504
533
csv_header = (
505
534
"mode, data_type, num_ads, embedding_dimension, ads_tables, num_gpus, dst_device, all_to_one_only, "
@@ -534,6 +563,7 @@ def handler(signum, frame):
534
563
num_ads ,
535
564
embedding_dimension ,
536
565
ads_tables ,
566
+ use_pitched ,
537
567
iters ,
538
568
p2p_bw ,
539
569
dst_device ,
@@ -558,6 +588,7 @@ def handler(signum, frame):
558
588
num_ads ,
559
589
embedding_dimension ,
560
590
ads_tables ,
591
+ use_pitched ,
561
592
iters ,
562
593
p2p_bw ,
563
594
dst_device ,
0 commit comments