File tree Expand file tree Collapse file tree 3 files changed +11
-8
lines changed
Expand file tree Collapse file tree 3 files changed +11
-8
lines changed Original file line number Diff line number Diff line change @@ -116,6 +116,9 @@ def plan_spec(
116116 Greedily place the spec in the first memory that can fit it.
117117 """
118118 for spec .mem_id in range (1 , self .get_num_memories ()):
119+ if placement_constraints .is_mem_id_in_blocklist (spec , spec .mem_id ):
120+ # Skip placement for blocked memory id.
121+ continue
119122 prev_offset , smallest_gap = 0 , float ("inf" )
120123 for allocated_spec in state .allocated_buffers [spec .mem_id ]:
121124 if not Verifier .lifetime_overlap (spec , allocated_spec ):
@@ -141,11 +144,11 @@ def plan_spec(
141144 )
142145 if spec .mem_offset is None :
143146 spec .mem_offset = prev_offset
144- if not self . is_valid_placement ( spec , placement_constraints ):
145- spec . mem_offset = None
146- continue
147- else :
148- spec . mem_offset = prev_offset
147+
148+ if not self . is_valid_placement ( spec , placement_constraints ):
149+ # Skip placement for invalid memory id.
150+ spec . mem_offset = None
151+ continue
149152
150153 state .place_spec (spec )
151154 # A data structure used for maintaining the tensor order
Original file line number Diff line number Diff line change @@ -204,7 +204,7 @@ def _place_memory_id_pinned_specs(
204204 for spec , c in spec_with_abs_constraint .items ()
205205 if c is not None and c .pinned_memory_id == mem_id and c .offset is None
206206 }
207- logging .error (f"Placing specs { mem_id_pinned_specs } for { mem_id = } " )
207+ logging .debug (f"Placing specs { mem_id_pinned_specs } for { mem_id = } " )
208208
209209 with self .block_memories_except (mem_id ):
210210 self .plan (
@@ -220,7 +220,7 @@ def _place_memory_id_pinned_specs(
220220 if constraint is None :
221221 continue
222222
223- logging .error (f"Placing spec { spec } with { constraint } " )
223+ logging .debug (f"Placing spec { spec } with { constraint } " )
224224
225225 if not state .is_placed (spec ):
226226 raise MemoryError (
Original file line number Diff line number Diff line change @@ -1044,7 +1044,7 @@ class DummyMemIdBlockConstraintGen(PassBase):
10441044 mul: blocks 1, 3
10451045 """
10461046
1047- def __init__ (self , memory_constraints : MemoryConfig ):
1047+ def __init__ (self , memory_constraints : MemConstraints ):
10481048 self .memory_constraints = memory_constraints
10491049
10501050 def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
You can’t perform that action at this time.
0 commit comments