Skip to content

Commit e9b34e5

Browse files
committed
rename
1 parent 4165e47 commit e9b34e5

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

pyhealth/models/stagenet_mha.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,11 +275,12 @@ def forward(
275275
return last_output, output, distance
276276

277277

278-
class StageNet(BaseModel):
278+
class StageAttentionNet(BaseModel):
279279
"""StageNet model.
280280
281281
Paper: Junyi Gao et al. Stagenet: Stage-aware neural networks for health
282-
risk prediction. WWW 2020.
282+
risk prediction. WWW 2020. But with Multi-Head Attention (MHA) between
283+
the SA-LSTM and the SA-CNN.
283284
284285
This model uses the StageNetProcessor which expects inputs in the format:
285286
{"value": [...], "time": [...]}
@@ -376,7 +377,7 @@ def __init__(
376377
levels: int = 3,
377378
**kwargs,
378379
):
379-
super(StageNet, self).__init__(
380+
super(StageAttentionNet, self).__init__(
380381
dataset=dataset,
381382
)
382383
self.embedding_dim = embedding_dim
@@ -689,7 +690,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
689690
train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
690691

691692
# model
692-
model = StageNet(dataset=dataset)
693+
model = StageAttentionNet(dataset=dataset)
693694

694695
# data batch
695696
data_batch = next(iter(train_loader))

tests/core/test_stagenet_mha.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33

44
from pyhealth.datasets import SampleDataset, get_dataloader
5-
from pyhealth.models.stagenet_mha import StageNet as StageNetMHA
5+
from pyhealth.models.stagenet_mha import StageAttentionNet as StageNetMHA
66

77

88
class TestStageNetMHA(unittest.TestCase):

0 commit comments

Comments
 (0)