Skip to content

Commit 1f918f5

Browse files
author
Yichen Gu
committed
Added normalization functions (for future use)
1 parent 2bec96b commit 1f918f5

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

velovae/model/model_util.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,62 @@ def pred_su(tau, u0, s0, alpha, beta, gamma):
123123
Generalizing RNA velocity to transient cell states through dynamical modeling.
124124
Nature biotechnology, 38(12), 1408-1414.
125125
"""
126+
def scale_by_gene(U,S,train_idx=None,mode='scale_u'):
127+
#mode
128+
# 'auto' means to scale the one with a smaller range
129+
# 'scale_u' means to match std(u) with std(s)
130+
# 'scale_s' means to match std(s) with std(u)
131+
G = U.shape[1]
132+
scaling_u = np.ones((G))
133+
scaling_s = np.ones((G))
134+
std_u, std_s = np.ones((G)),np.ones((G))
135+
for i in range(G):
136+
if(train_idx is None):
137+
si, ui = S[:,i], U[:,i]
138+
else:
139+
si, ui = S[train_idx,i], U[train_idx,i]
140+
sfilt, ufilt = si[(si>0) & (ui>0)], ui[(si>0) & (ui>0)] #Use only nonzero data points
141+
if(len(sfilt)>3 and len(ufilt)>3):
142+
std_u[i] = np.std(ufilt)
143+
std_s[i] = np.std(sfilt)
144+
mask_u, mask_s = (std_u==0), (std_s==0)
145+
std_u = std_u + (mask_u & (~mask_s))*std_s + (mask_u & mask_s)*1
146+
std_s = std_s + ((~mask_u) & mask_s)*std_u + (mask_u & mask_s)*1
147+
if(mode=='auto'):
148+
scaling_u = np.max(np.stack([scaling_u,(std_u/std_s)]),0)
149+
scaling_s = np.max(np.stack([scaling_s,(std_s/std_u)]),0)
150+
elif(mode=='scale_u'):
151+
scaling_u = std_u/std_s
152+
elif(mode=='scale_s'):
153+
scaling_s = std_s/std_u
154+
return U/scaling_u, S/scaling_s, scaling_u, scaling_s
155+
156+
def scale_by_cell(U,S,train_idx=None,separate_us_scale=True):
157+
N = U.shape[0]
158+
nu, ns = U.sum(1, keepdims=True), S.sum(1, keepdims=True)
159+
if(separate_us_scale):
160+
norm_count = (np.median(nu), np.median(ns)) if train_idx is None else (np.median(nu[train_idx]), np.median(ns[train_idx]))
161+
lu = nu/norm_count[0]
162+
ls = ns/norm_count[1]
163+
else:
164+
norm_count = np.median(nu+ns) if train_idx is None else np.median(nu[train_idx]+ns[train_idx])
165+
lu = (nu+ns)/norm_count
166+
ls = lu
167+
return U/lu, S/ls, lu, ls
168+
169+
def get_cell_scale(U,S,train_idx=None,separate_us_scale=True):
170+
N = U.shape[0]
171+
nu, ns = U.sum(1, keepdims=True), S.sum(1, keepdims=True)
172+
if(separate_us_scale):
173+
norm_count = (np.median(nu), np.median(ns)) if train_idx is None else (np.median(nu[train_idx]), np.median(ns[train_idx]))
174+
lu = nu/norm_count[0]
175+
ls = ns/norm_count[1]
176+
else:
177+
norm_count = np.median(nu+ns) if train_idx is None else np.median(nu[train_idx]+ns[train_idx])
178+
lu = (nu+ns)/norm_count
179+
ls = lu
180+
return lu, ls
181+
126182
def linreg(u, s):
127183
q = np.sum(s*s)
128184
r = np.sum(u*s)
@@ -131,6 +187,7 @@ def linreg(u, s):
131187
k = 1.0+np.random.rand()
132188
return k
133189

190+
134191
def init_gene(s,u,percent,fit_scaling=False,Ntype=None):
135192
#Adopted from scvelo
136193

0 commit comments

Comments
 (0)