@@ -580,14 +580,34 @@ def check_op(op):
580580
581581@singledispatch
582582def axis_mul_or_truediv (
583- X : sparse .spmatrix ,
583+ X : np .ndarray ,
584+ scaling_array : np .ndarray ,
585+ axis : Literal [0 , 1 ],
586+ op : Callable [[Any , Any ], Any ],
587+ * ,
588+ allow_divide_by_zero : bool = True ,
589+ out : np .ndarray | None = None ,
590+ ) -> np .ndarray :
591+ check_op (op )
592+ scaling_array = broadcast_axis (scaling_array , axis )
593+ if op is mul :
594+ return np .multiply (X , scaling_array , out = out )
595+ if not allow_divide_by_zero :
596+ scaling_array = scaling_array .copy () + (scaling_array == 0 )
597+ return np .true_divide (X , scaling_array , out = out )
598+
599+
600+ @axis_mul_or_truediv .register (sparse .csr_matrix )
601+ @axis_mul_or_truediv .register (sparse .csc_matrix )
602+ def _ (
603+ X : sparse .csr_matrix | sparse .csc_matrix ,
584604 scaling_array ,
585605 axis : Literal [0 , 1 ],
586606 op : Callable [[Any , Any ], Any ],
587607 * ,
588608 allow_divide_by_zero : bool = True ,
589- out : sparse .spmatrix | None = None ,
590- ) -> sparse .spmatrix :
609+ out : sparse .csr_matrix | sparse . csc_matrix | None = None ,
610+ ) -> sparse .csr_matrix | sparse . csc_matrix :
591611 check_op (op )
592612 if out is not None :
593613 if X .data is not out .data :
@@ -629,25 +649,6 @@ def new_data_op(x):
629649 ).T
630650
631651
632- @axis_mul_or_truediv .register (np .ndarray )
633- def _ (
634- X : np .ndarray ,
635- scaling_array : np .ndarray ,
636- axis : Literal [0 , 1 ],
637- op : Callable [[Any , Any ], Any ],
638- * ,
639- allow_divide_by_zero : bool = True ,
640- out : np .ndarray | None = None ,
641- ) -> np .ndarray :
642- check_op (op )
643- scaling_array = broadcast_axis (scaling_array , axis )
644- if op is mul :
645- return np .multiply (X , scaling_array , out = out )
646- if not allow_divide_by_zero :
647- scaling_array = scaling_array .copy () + (scaling_array == 0 )
648- return np .true_divide (X , scaling_array , out = out )
649-
650-
651652def make_axis_chunks (
652653 X : DaskArray , axis : Literal [0 , 1 ], pad = True
653654) -> tuple [tuple [int ], tuple [int ]]:
0 commit comments