Skip to content

Commit e7bfc47

Browse files
authored
Merge pull request #79 from maximelucas/small_fixes
fixed #73 and #78
2 parents ec62732 + 5b52f88 commit e7bfc47

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

teneto/classes/network.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,12 @@ def network_from_df(self, df):
254254
"""
255255
self._check_input(df, 'df')
256256
# Ensure order of columns
257-
df = df[['i', 'j', 't']]
257+
if len(df.columns)==4 :
258+
df = df[['i', 'j', 't', 'weight']]
259+
elif len(df.columns)==3 :
260+
df = df[['i', 'j', 't']]
261+
else :
262+
print("Wrong number of columns in df")
258263
self.network = df
259264
self._update_network()
260265

@@ -319,7 +324,7 @@ def _calc_netshape(self):
319324
n_timepoints = int(self.network.shape[-1])
320325
self.netshape = (n_nodes, n_timepoints)
321326
else:
322-
n_nodes = self.network[['i', 'j']].max(axis=1).max()+1
327+
n_nodes = len(np.unique(self.network[['i', 'j']].values))
323328
n_timepoints = self.network['t'].max() - self.network['t'].min() + 1
324329
if self.N > n_nodes:
325330
n_nodes = self.N
@@ -573,4 +578,4 @@ def _check_input(self, datain, datatype):
573578
if ('i' and 'j' and 't') not in datain:
574579
raise ValueError('Columns must be \'i\' \'j\' and \'t\'')
575580
else:
576-
raise ValueError('Unknown datatype')
581+
raise ValueError('Unknown datatype')

0 commit comments

Comments
 (0)