Skip to content

Commit 886ea8d

Browse files
Malay Bagfacebook-github-bot
authored andcommitted
Replace int(..) with torch.sym_int(...) to make it torch.export compatible (#3270)
Summary: Pull Request resolved: #3270 As title Reviewed By: angelayi Differential Revision: D79912672 fbshipit-source-id: b2586764082fdcb3665f486720a0e9ee97724d09
1 parent fc56fb0 commit 886ea8d

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1135,11 +1135,16 @@ def _use_segment_sum_csr(stride_per_key: List[int]) -> bool:
11351135
return False
11361136

11371137
elements_per_segment = sum(stride_per_key) / len(stride_per_key)
1138-
segment_threshold = int(
1138+
segment_threshold_float = (
11391139
1.39771
11401140
+ 0.0000312222 * elements_per_segment
11411141
+ 1.63949e-10 * elements_per_segment**2
11421142
)
1143+
if not torch.jit.is_scripting() and is_non_strict_exporting():
1144+
segment_threshold = torch.sym_int(segment_threshold_float)
1145+
else:
1146+
segment_threshold = int(segment_threshold_float)
1147+
11431148
return len(stride_per_key) >= segment_threshold
11441149

11451150

0 commit comments

Comments
 (0)