We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent fc56fb0 commit 886ea8dCopy full SHA for 886ea8d
torchrec/sparse/jagged_tensor.py
@@ -1135,11 +1135,16 @@ def _use_segment_sum_csr(stride_per_key: List[int]) -> bool:
1135
return False
1136
1137
elements_per_segment = sum(stride_per_key) / len(stride_per_key)
1138
- segment_threshold = int(
+ segment_threshold_float = (
1139
1.39771
1140
+ 0.0000312222 * elements_per_segment
1141
+ 1.63949e-10 * elements_per_segment**2
1142
)
1143
+ if not torch.jit.is_scripting() and is_non_strict_exporting():
1144
+ segment_threshold = torch.sym_int(segment_threshold_float)
1145
+ else:
1146
+ segment_threshold = int(segment_threshold_float)
1147
+
1148
return len(stride_per_key) >= segment_threshold
1149
1150
0 commit comments