Skip to content

Commit afbeced

Browse files
committed
init commit attempt for multihot embeddings, and other basic modeling tools that people would typically use to start an ML project
1 parent a75de8c commit afbeced

File tree

7 files changed

+674
-2
lines changed

7 files changed

+674
-2
lines changed

pyhealth/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .deepr import Deepr, DeeprLayer
88
from .embedding import EmbeddingModel
99
from .gamenet import GAMENet, GAMENetLayer
10+
from .logistic_regression import LogisticRegression
1011
from .gan import GAN
1112
from .gnn import GAT, GCN
1213
from .graph_torchvision_model import Graph_TorchvisionModel

pyhealth/models/embedding.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,46 @@
44
import torch.nn as nn
55

66
from ..datasets import SampleDataset
7-
from ..processors import SequenceProcessor, TimeseriesProcessor, TensorProcessor
7+
from ..processors import (
8+
MultiHotProcessor,
9+
SequenceProcessor,
10+
TensorProcessor,
11+
TimeseriesProcessor,
12+
)
813
from .base_model import BaseModel
914

1015

1116
class EmbeddingModel(BaseModel):
1217
"""
1318
EmbeddingModel is responsible for creating embedding layers for different types of input data.
1419
20+
This model automatically creates appropriate embedding transformations based on the processor type:
21+
22+
- SequenceProcessor: Creates nn.Embedding for categorical sequences (e.g., diagnosis codes)
23+
Input: (batch, seq_len) with integer indices
24+
Output: (batch, seq_len, embedding_dim)
25+
26+
- TimeseriesProcessor: Creates nn.Linear for time series features
27+
Input: (batch, seq_len, num_features)
28+
Output: (batch, seq_len, embedding_dim)
29+
30+
- TensorProcessor: Creates nn.Linear for fixed-size numerical features
31+
Input: (batch, feature_size)
32+
Output: (batch, embedding_dim)
33+
34+
- MultiHotProcessor: Creates nn.Linear for multi-hot encoded categorical features
35+
Input: (batch, num_categories) binary tensor
36+
Output: (batch, embedding_dim)
37+
Note: Converts sparse categorical representations to dense embeddings
38+
39+
- Other processors with size(): Creates nn.Linear if processor reports a positive size
40+
Input: (batch, size)
41+
Output: (batch, embedding_dim)
42+
1543
Attributes:
1644
dataset (SampleDataset): The dataset containing input processors.
1745
embedding_layers (nn.ModuleDict): A dictionary of embedding layers for each input field.
46+
embedding_dim (int): The target embedding dimension for all features.
1847
"""
1948

2049
def __init__(self, dataset: SampleDataset, embedding_dim: int = 128):
@@ -26,6 +55,7 @@ def __init__(self, dataset: SampleDataset, embedding_dim: int = 128):
2655
embedding_dim (int): The dimension of the embedding space. Default is 128.
2756
"""
2857
super().__init__(dataset)
58+
self.embedding_dim = embedding_dim
2959
self.embedding_layers = nn.ModuleDict()
3060
for field_name, processor in self.dataset.input_processors.items():
3161
if isinstance(processor, SequenceProcessor):
@@ -54,6 +84,36 @@ def __init__(self, dataset: SampleDataset, embedding_dim: int = 128):
5484
self.embedding_layers[field_name] = nn.Linear(
5585
in_features=input_size, out_features=embedding_dim
5686
)
87+
elif isinstance(processor, MultiHotProcessor):
88+
# MultiHotProcessor produces fixed-size binary vectors
89+
# Use processor.size() to get the vocabulary size (num_categories)
90+
num_categories = processor.size()
91+
self.embedding_layers[field_name] = nn.Linear(
92+
in_features=num_categories, out_features=embedding_dim
93+
)
94+
else:
95+
# Handle other processors with a size() method
96+
size_attr = getattr(processor, "size", None)
97+
if callable(size_attr):
98+
size_value = size_attr()
99+
else:
100+
size_value = size_attr
101+
102+
if isinstance(size_value, int) and size_value > 0:
103+
self.embedding_layers[field_name] = nn.Linear(
104+
in_features=size_value, out_features=embedding_dim
105+
)
106+
else:
107+
# No valid size() method found - raise an error
108+
raise ValueError(
109+
f"Processor for field '{field_name}' (type: {type(processor).__name__}) "
110+
f"does not have a valid size() method or it returned an invalid value. "
111+
f"To use this processor with EmbeddingModel, it must either:\n"
112+
f" 1. Be a recognized processor type (SequenceProcessor, TimeseriesProcessor, "
113+
f"TensorProcessor, MultiHotProcessor), or\n"
114+
f" 2. Implement a size() method that returns a positive integer representing "
115+
f"the feature dimension."
116+
)
57117

58118
def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
59119
"""
@@ -67,8 +127,8 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
67127
"""
68128
embedded = {}
69129
for field_name, tensor in inputs.items():
130+
tensor = tensor.to(self.device)
70131
if field_name in self.embedding_layers:
71-
tensor = tensor.to(self.device)
72132
embedded[field_name] = self.embedding_layers[field_name](tensor)
73133
else:
74134
embedded[field_name] = tensor # passthrough for continuous features
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
from typing import Dict
2+
3+
import torch
4+
import torch.nn as nn
5+
6+
from pyhealth.datasets import SampleDataset
7+
from pyhealth.models import BaseModel
8+
9+
from .embedding import EmbeddingModel
10+
11+
12+
class LogisticRegression(BaseModel):
    """Logistic/Linear regression baseline model.

    This model uses embeddings from different input features and applies a single
    linear transformation (no hidden layers or non-linearity) to produce predictions.

    - For classification tasks: acts as logistic regression
    - For regression tasks: acts as linear regression

    The model embeds every input feature through an ``EmbeddingModel``, pools any
    sequence dimension with a masked mean, concatenates all feature embeddings,
    and applies a final linear layer.

    Args:
        dataset: the dataset to train the model. It is used to query certain
            information such as the set of all tokens.
        embedding_dim: the embedding dimension. Default is 128.
        **kwargs: other parameters (accepted for compatibility and ignored).

    Examples:
        >>> from pyhealth.datasets import SampleDataset
        >>> samples = [
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-0",
        ...         "conditions": ["cond-33", "cond-86", "cond-80"],
        ...         "procedures": [1.0, 2.0, 3.5, 4],
        ...         "label": 0,
        ...     },
        ...     {
        ...         "patient_id": "patient-1",
        ...         "visit_id": "visit-1",
        ...         "conditions": ["cond-33", "cond-86", "cond-80"],
        ...         "procedures": [5.0, 2.0, 3.5, 4],
        ...         "label": 1,
        ...     },
        ... ]
        >>> input_schema = {"conditions": "sequence",
        ...                 "procedures": "tensor"}
        >>> output_schema = {"label": "binary"}
        >>> dataset = SampleDataset(samples=samples,
        ...                         input_schema=input_schema,
        ...                         output_schema=output_schema,
        ...                         dataset_name="test")
        >>>
        >>> from pyhealth.models import LogisticRegression
        >>> model = LogisticRegression(dataset=dataset)
        >>>
        >>> from pyhealth.datasets import get_dataloader
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>> data_batch = next(iter(train_loader))
        >>>
        >>> ret = model(**data_batch)
        >>> print(ret)
        {
            'loss': tensor(0.6931, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
            'y_prob': tensor([[0.5123],
                              [0.4987]], grad_fn=<SigmoidBackward0>),
            'y_true': tensor([[1.],
                              [0.]]),
            'logit': tensor([[0.0492],
                             [-0.0052]], grad_fn=<AddmmBackward0>)
        }
        >>>

    """

    def __init__(
        self,
        dataset: SampleDataset,
        embedding_dim: int = 128,
        **kwargs,
    ):
        super().__init__(dataset)
        self.embedding_dim = embedding_dim

        # NOTE: kept as an assert (not a raise) so the exception type seen by
        # existing callers does not change.
        assert len(self.label_keys) == 1, "Only one label key is supported"
        self.label_key = self.label_keys[0]

        # Use the EmbeddingModel to handle embedding logic
        self.embedding_model = EmbeddingModel(dataset, embedding_dim)

        # Single linear layer (no hidden layers, no activation).
        # The input width assumes every feature ends up as an
        # ``embedding_dim``-wide vector after embedding and pooling.
        output_size = self.get_output_size()
        self.fc = nn.Linear(len(self.feature_keys) * self.embedding_dim, output_size)

    @staticmethod
    def mean_pooling(x, mask):
        """Masked mean pooling over the middle dimension of the tensor.

        Positions whose mask entry is zero are excluded from both the sum and
        the count. A row whose mask is entirely zero yields a zero vector
        instead of NaN (the count is clamped to at least 1).

        Args:
            x: tensor of shape (batch_size, seq_len, embedding_dim)
            mask: tensor of shape (batch_size, seq_len); non-zero marks a
                valid position. Boolean masks are accepted.

        Returns:
            x: tensor of shape (batch_size, embedding_dim)

        Examples:
            >>> x.shape
            [128, 5, 32]
            >>> mean_pooling(x, mask).shape
            [128, 32]
        """
        weights = mask.to(dtype=x.dtype)
        # Zero out masked positions so they cannot leak into the numerator.
        summed = (x * weights.unsqueeze(-1)).sum(dim=1)
        # Clamp so that fully-masked rows divide by 1 rather than 0.
        counts = weights.sum(dim=1, keepdim=True).clamp(min=1)
        return summed / counts

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        Args:
            **kwargs: keyword arguments for the model. The keys must contain
                all the feature keys and the label key.

        Returns:
            Dict[str, torch.Tensor]: A dictionary with the following keys:
                - loss: a scalar tensor representing the loss.
                - y_prob: a tensor representing the predicted probabilities.
                - y_true: a tensor representing the true labels.
                - logit: a tensor representing the logits.
                - embed (optional): a tensor representing the patient
                    embeddings if requested.

        Raises:
            ValueError: if an embedded feature is neither 2-D nor 3-D.
        """
        # Preprocess inputs for EmbeddingModel
        processed_inputs = {}
        expanded_keys = set()  # features that received a temporary leading dim

        for feature_key in self.feature_keys:
            x = kwargs[feature_key]

            # Convert to tensor if not already
            if not isinstance(x, torch.Tensor):
                x = torch.tensor(x, device=self.device)
            else:
                x = x.to(self.device)

            if x.dim() == 3:
                # Handle 3D input: (patient, event, # of codes) -> flatten to 2D.
                # reshape (not view) so non-contiguous inputs are accepted too.
                batch_size, seq_len, inner_len = x.shape
                x = x.reshape(batch_size, seq_len * inner_len)
            elif x.dim() == 1:
                x = x.unsqueeze(0)
                expanded_keys.add(feature_key)

            processed_inputs[feature_key] = x

        # Pass through EmbeddingModel
        embedded = self.embedding_model(processed_inputs)

        patient_emb = []
        for feature_key in self.feature_keys:
            x = embedded[feature_key]

            # Undo the temporary leading dimension added above.
            if feature_key in expanded_keys and x.dim() > 1:
                x = x.squeeze(0)

            if x.dim() == 3:
                # Case: (batch, seq_len, embedding_dim) - apply mean pooling.
                # Positions whose embedding sums to exactly zero are treated
                # as padding. Pooling is done per row, so a fully-padded
                # sample yields a zero vector regardless of what else is in
                # the batch.
                mask = (x.sum(dim=-1) != 0).float()
                x = self.mean_pooling(x, mask)
            elif x.dim() != 2:
                # Case: (batch, embedding_dim) is already pooled and used as
                # is; anything else cannot be concatenated below.
                raise ValueError(f"Unsupported tensor dimension: {x.dim()}")

            patient_emb.append(x)

        # Concatenate all feature embeddings
        patient_emb = torch.cat(patient_emb, dim=1)

        # Apply single linear layer (no activation)
        logits = self.fc(patient_emb)

        # Obtain y_true, loss, y_prob (helpers are provided by the base class)
        y_true = kwargs[self.label_key].to(self.device)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)

        results = {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": y_true,
            "logit": logits,
        }
        if kwargs.get("embed", False):
            results["embed"] = patient_emb
        return results
211+
212+
213+
if __name__ == "__main__":
    # Smoke test: build a two-sample toy dataset, run one forward pass through
    # LogisticRegression, print the result and back-propagate the loss.
    from pyhealth.datasets import SampleDataset, get_dataloader

    # (procedures, label) for each toy patient; everything else is shared.
    toy_rows = [
        ([1.0, 2.0, 3.5, 4], 0),
        ([5.0, 2.0, 3.5, 4], 1),
    ]
    toy_samples = [
        {
            "patient_id": f"patient-{idx}",
            "visit_id": f"visit-{idx}",
            "conditions": ["cond-33", "cond-86", "cond-80"],
            "procedures": procedures,
            "label": label,
        }
        for idx, (procedures, label) in enumerate(toy_rows)
    ]

    # Dataset with a categorical sequence input, a numeric tensor input and a
    # binary label.
    toy_dataset = SampleDataset(
        samples=toy_samples,
        input_schema={
            "conditions": "sequence",  # sequence of condition codes
            "procedures": "tensor",  # tensor of procedure values
        },
        output_schema={"label": "binary"},  # binary classification
        dataset_name="test",
    )

    # Model under test.
    demo_model = LogisticRegression(dataset=toy_dataset)

    # Pull a single batch from a shuffled loader.
    loader = get_dataloader(toy_dataset, batch_size=2, shuffle=True)
    first_batch = next(iter(loader))

    # Forward pass, then make sure gradients flow.
    outputs = demo_model(**first_batch)
    print(outputs)
    outputs["loss"].backward()

pyhealth/processors/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def get_processor(name: str):
2525
MultiLabelProcessor,
2626
RegressionLabelProcessor,
2727
)
28+
from .multi_hot_processor import MultiHotProcessor
2829
from .raw_processor import RawProcessor
2930
from .sequence_processor import SequenceProcessor
3031
from .signal_processor import SignalProcessor
@@ -40,6 +41,7 @@ def get_processor(name: str):
4041
"TensorProcessor",
4142
"TimeseriesProcessor",
4243
"SignalProcessor",
44+
"MultiHotProcessor",
4345
"BinaryLabelProcessor",
4446
"MultiClassLabelProcessor",
4547
"MultiLabelProcessor",

0 commit comments

Comments
 (0)