@@ -123,6 +123,62 @@ def pred_su(tau, u0, s0, alpha, beta, gamma):
123123Generalizing RNA velocity to transient cell states through dynamical modeling.
124124Nature biotechnology, 38(12), 1408-1414.
125125"""
126+ def scale_by_gene (U ,S ,train_idx = None ,mode = 'scale_u' ):
127+ #mode
128+ # 'auto' means to scale the one with a smaller range
129+ # 'scale_u' means to match std(u) with std(s)
130+ # 'scale_s' means to match std(s) with std(u)
131+ G = U .shape [1 ]
132+ scaling_u = np .ones ((G ))
133+ scaling_s = np .ones ((G ))
134+ std_u , std_s = np .ones ((G )),np .ones ((G ))
135+ for i in range (G ):
136+ if (train_idx is None ):
137+ si , ui = S [:,i ], U [:,i ]
138+ else :
139+ si , ui = S [train_idx ,i ], U [train_idx ,i ]
140+ sfilt , ufilt = si [(si > 0 ) & (ui > 0 )], ui [(si > 0 ) & (ui > 0 )] #Use only nonzero data points
141+ if (len (sfilt )> 3 and len (ufilt )> 3 ):
142+ std_u [i ] = np .std (ufilt )
143+ std_s [i ] = np .std (sfilt )
144+ mask_u , mask_s = (std_u == 0 ), (std_s == 0 )
145+ std_u = std_u + (mask_u & (~ mask_s ))* std_s + (mask_u & mask_s )* 1
146+ std_s = std_s + ((~ mask_u ) & mask_s )* std_u + (mask_u & mask_s )* 1
147+ if (mode == 'auto' ):
148+ scaling_u = np .max (np .stack ([scaling_u ,(std_u / std_s )]),0 )
149+ scaling_s = np .max (np .stack ([scaling_s ,(std_s / std_u )]),0 )
150+ elif (mode == 'scale_u' ):
151+ scaling_u = std_u / std_s
152+ elif (mode == 'scale_s' ):
153+ scaling_s = std_s / std_u
154+ return U / scaling_u , S / scaling_s , scaling_u , scaling_s
155+
156+ def scale_by_cell (U ,S ,train_idx = None ,separate_us_scale = True ):
157+ N = U .shape [0 ]
158+ nu , ns = U .sum (1 , keepdims = True ), S .sum (1 , keepdims = True )
159+ if (separate_us_scale ):
160+ norm_count = (np .median (nu ), np .median (ns )) if train_idx is None else (np .median (nu [train_idx ]), np .median (ns [train_idx ]))
161+ lu = nu / norm_count [0 ]
162+ ls = ns / norm_count [1 ]
163+ else :
164+ norm_count = np .median (nu + ns ) if train_idx is None else np .median (nu [train_idx ]+ ns [train_idx ])
165+ lu = (nu + ns )/ norm_count
166+ ls = lu
167+ return U / lu , S / ls , lu , ls
168+
169+ def get_cell_scale (U ,S ,train_idx = None ,separate_us_scale = True ):
170+ N = U .shape [0 ]
171+ nu , ns = U .sum (1 , keepdims = True ), S .sum (1 , keepdims = True )
172+ if (separate_us_scale ):
173+ norm_count = (np .median (nu ), np .median (ns )) if train_idx is None else (np .median (nu [train_idx ]), np .median (ns [train_idx ]))
174+ lu = nu / norm_count [0 ]
175+ ls = ns / norm_count [1 ]
176+ else :
177+ norm_count = np .median (nu + ns ) if train_idx is None else np .median (nu [train_idx ]+ ns [train_idx ])
178+ lu = (nu + ns )/ norm_count
179+ ls = lu
180+ return lu , ls
181+
126182def linreg (u , s ):
127183 q = np .sum (s * s )
128184 r = np .sum (u * s )
@@ -131,6 +187,7 @@ def linreg(u, s):
131187 k = 1.0 + np .random .rand ()
132188 return k
133189
190+
134191def init_gene (s ,u ,percent ,fit_scaling = False ,Ntype = None ):
135192 #Adopted from scvelo
136193
0 commit comments