@@ -254,7 +254,12 @@ def network_from_df(self, df):
254254 """
255255 self._check_input(df, 'df')
256256 # Ensure order of columns
257- df = df[['i', 'j', 't']]
257+ if len(df.columns)==4 :
258+ df = df[['i', 'j', 't', 'weight']]
259+ elif len(df.columns)==3 :
260+ df = df[['i', 'j', 't']]
261+ else :
262+ print("Wrong number of columns in df")
258263 self.network = df
259264 self._update_network()
260265
@@ -319,7 +324,7 @@ def _calc_netshape(self):
319324 n_timepoints = int(self.network.shape[-1])
320325 self.netshape = (n_nodes, n_timepoints)
321326 else:
322- n_nodes = self.network[['i', 'j']].max(axis=1).max()+1
327+ n_nodes = len(np.unique( self.network[['i', 'j']].values))
323328 n_timepoints = self.network['t'].max() - self.network['t'].min() + 1
324329 if self.N > n_nodes:
325330 n_nodes = self.N
@@ -573,4 +578,4 @@ def _check_input(self, datain, datatype):
573578 if ('i' and 'j' and 't') not in datain:
574579 raise ValueError('Columns must be \'i\' \'j\' and \'t\'')
575580 else:
576- raise ValueError('Unknown datatype')
581+ raise ValueError('Unknown datatype')
0 commit comments