@@ -108,11 +108,21 @@ def _msd_gaps(traj, mpp, fps, max_lagtime=100, detail=False, pos_columns=None):
108108
109109 result = pd .DataFrame (_msd_iter (pos .values , lagtimes ),
110110 columns = result_columns , index = lagtimes )
111- result ['msd' ] = result [result_columns [- len (pos_columns ):]].sum (1 )
111+ result ['msd' ] = result [result_columns [- len (pos_columns ):]].sum (1 , skipna = False )
112112 if detail :
113113 # effective number of measurements
114114 # approximately corrected with number of gaps
115115 result ['N' ] = _msd_N (len (pos ), lagtimes ) * len (traj ) / len (pos )
116+ desired_total_N = result ['N' ].sum ()
117+
118+ # If MSD is nan that's because there were zero datapoints. Reset N to 0.
119+ result ['N' ] = np .where (result ['msd' ].isna (), 0 , result ['N' ])
120+ current_total_N = result ['N' ].sum ()
121+
122+ if current_total_N != desired_total_N :
123+ # scale up N for the rest of the column
124+ result ['N' ] = result ['N' ] * desired_total_N / current_total_N
125+
116126 result ['lagt' ] = result .index .values / float (fps )
117127 result .index .name = 'lagt'
118128 return result
@@ -232,6 +242,10 @@ def emsd(traj, mpp, fps, max_lagtime=100, detail=False, pos_columns=None):
232242 msds .append (msd (ptraj , mpp , fps , max_lagtime , True , pos_columns ))
233243 ids .append (pid )
234244 msds = pandas_concat (msds , keys = ids , names = ['particle' , 'frame' ])
245+
246+ # remove np.nan because it would make the rest of the calculation break
247+ msds ['msd' ] = np .where (msds ['msd' ].isna (), 0 , msds ['msd' ])
248+
235249 results = msds .mul (msds ['N' ], axis = 0 ).groupby (level = 1 ).mean () # weighted average
236250 results = results .div (msds ['N' ].groupby (level = 1 ).mean (), axis = 0 ) # weights normalized
237251 # Above, lagt is lumped in with the rest for simplicity and speed.
0 commit comments