@@ -466,29 +466,54 @@ def _clamp_y_intercept(
466466    then applies various constraints to ensure the clamping conditions are respected. 
467467    """ 
468468
469+     # Calculate slopes and y-intercepts for bounding boxes 
469470    a , b  =  _get_slope_and_intercept (bounding_boxes )
470471    a1 , a2 , a3 , a4  =  a .unbind (- 1 )
471472    b1 , b2 , b3 , b4  =  b .unbind (- 1 )
472473
473-     # Clamp y-intercepts (soft clamping) 
474+     # Get y-intercepts from original bounding boxes 
475+     _ , bm  =  _get_slope_and_intercept (original_bounding_boxes )
476+     b1m , b2m , b3m , b4m  =  bm .unbind (- 1 )
477+ 
478+     # Soft clamping: Clamp y-intercepts within canvas boundaries 
474479    b1  =  b2 .clamp (b1 , b3 ).clamp (0 , canvas_size [0 ])
475480    b4  =  b3 .clamp (b2 , b4 ).clamp (0 , canvas_size [0 ])
476481
477-     if  clamping_mode  is  not   None  and  clamping_mode  ==  "hard" :
478-         # Get y-intercepts from original bounding boxes 
479-         _ , b  =  _get_slope_and_intercept (original_bounding_boxes )
480-         _ , b2 , b3 , _  =  b .unbind (- 1 )
481- 
482-         # Set b1 and b4 to the average of their clamped values 
483-         b1  =  b4  =  (b1 .clamp (0 , canvas_size [0 ]) +  b4 .clamp (0 , canvas_size [0 ])) /  2 
482+     if  clamping_mode  ==  "hard" :
483+         # Hard clamping: Average b1 and b4, and adjust b2 and b3 for maximum area 
484+         b1  =  b4  =  (b1  +  b4 ) /  2 
485+ 
486+         # Calculate candidate values for b2 based on geometric constraints 
487+         b2_candidates  =  torch .stack (
488+             [
489+                 b1  *  a2  /  a1 ,  # Constraint at y=0 
490+                 b3  *  a2  /  a3 ,  # Constraint at y=0 
491+                 (a1  -  a2 ) *  canvas_size [1 ] +  b1 ,  # Constraint at x=canvas_width 
492+                 (a3  -  a2 ) *  canvas_size [1 ] +  b3 ,  # Constraint at x=canvas_width 
493+             ],
494+             dim = 1 ,
495+         )
496+         # Take maximum value that doesn't exceed original b2 
497+         b2  =  torch .max (b2_candidates , dim = 1 )[0 ].clamp (max = b2 )
498+ 
499+         # Calculate candidate values for b3 based on geometric constraints 
500+         b3_candidates  =  torch .stack (
501+             [
502+                 canvas_size [0 ] *  (1  -  a3  /  a4 ) +  b4  *  a3  /  a4 ,  # Constraint at y=canvas_height 
503+                 canvas_size [0 ] *  (1  -  a3  /  a2 ) +  b2  *  a3  /  a2 ,  # Constraint at y=canvas_height 
504+                 (a2  -  a3 ) *  canvas_size [1 ] +  b2 ,  # Constraint at x=canvas_width 
505+                 (a4  -  a3 ) *  canvas_size [1 ] +  b4 ,  # Constraint at x=canvas_width 
506+             ],
507+             dim = 1 ,
508+         )
509+         # Take minimum value that doesn't go below original b3 
510+         b3  =  torch .min (b3_candidates , dim = 1 )[0 ].clamp (min = b3 )
484511
485-         # Ensure b2 and b3 defined the box of maximum area after clamping b1 and b4 
486-         b2 .clamp_ (b1  *  a2  /  a1 , b4 ).clamp_ ((a1  -  a2 ) *  canvas_size [1 ] +  b1 )
487-         b2 .clamp_ (b3  *  a2  /  a3 , b4 ).clamp_ ((a3  -  a2 ) *  canvas_size [1 ] +  b3 )
488-         b3 .clamp_ (max = canvas_size [0 ] *  (1  -  a3  /  a4 ) +  b4  *  a3  /  a4 )
489-         b3 .clamp_ (max = canvas_size [0 ] *  (1  -  a3  /  a2 ) +  b2  *  a3  /  a2 )
490-         b3 .clamp_ (b1 , (a2  -  a3 ) *  canvas_size [1 ] +  b2 )
491-         b3 .clamp_ (b1 , (a4  -  a3 ) *  canvas_size [1 ] +  b4 )
512+     # Final clamping to ensure y-intercepts are within original box bounds 
513+     b1 .clamp_ (b1m , b3m )
514+     b3 .clamp_ (b1m , b3m )
515+     b2 .clamp_ (b2m , b4m )
516+     b4 .clamp_ (b2m , b4m )
492517
493518    return  torch .stack ([b1 , b2 , b3 , b4 ], dim = - 1 )
494519
@@ -549,7 +574,7 @@ def _clamp_along_y_axis(
549574        [case_a , case_b , case_c ],
550575    ):
551576        bounding_boxes  =  torch .where (cond .unsqueeze (1 ).repeat (1 , 8 ), case .reshape (- 1 , 8 ), bounding_boxes )
552-     if  clamping_mode  is   not   None   and   clamping_mode   ==  "hard" :
577+     if  clamping_mode  ==  "hard" :
553578        bounding_boxes [..., 0 ].clamp_ (0 )  # Clamp x1 to 0 
554579
555580    if  need_cast :
0 commit comments