@@ -9,12 +9,12 @@ class AutoEmulateDataset(Dataset):
99
1010 def __init__ (
1111 self ,
12- data_path : str ,
13- n_steps_input : int ,
14- n_steps_output : int ,
12+ data_path : str | None ,
13+ data : dict | None = None ,
14+ n_steps_input : int = 1 ,
15+ n_steps_output : int = 1 ,
1516 stride : int = 1 ,
1617 # TODO: support for passing data from dict
17- # data: dict | None = None,
1818 input_channel_idxs : tuple [int , ...] | None = None ,
1919 output_channel_idxs : tuple [int , ...] | None = None ,
2020 ):
@@ -45,20 +45,8 @@ def __init__(
4545 self .input_channel_idxs = input_channel_idxs
4646 self .output_channel_idxs = output_channel_idxs
4747
48- # TODO: support passing as dict
49- # Load data
50- with h5py .File (data_path , "r" ) as f :
51- assert "data" in f , "HDF5 file must contain 'data' dataset"
52- self .data : TensorLike = torch .Tensor (f ["data" ][:]) # type: ignore # [N, T, W, H, C] # noqa: PGH003
53- print (f"Loaded data shape: { self .data .shape } " )
54- # TODO: add the constant scalars
55- self .constant_scalars = (
56- torch .Tensor (f ["constant_scalars" ][:]) # type: ignore # noqa: PGH003
57- if "constant_scalars" in f
58- else None
59- ) # [N, C]
60- # TODO: add the constant fields
61- # self.constant_fields = torch.Tensor(f['data'][:]) # [N, W, H, C]
48+ # Read or parse data
49+ self .read_data (data_path ) if data_path is not None else self .parse_data (data )
6250
6351 # Destructured here
6452 (
@@ -107,14 +95,45 @@ def __init__(
10795 print (f"Each input sample shape: { self .all_input_fields [0 ].shape } " )
10896 print (f"Each output sample shape: { self .all_output_fields [0 ].shape } " )
10997
98+ def read_data (self , data_path : str ):
99+ """Read data.
100+
101+ By default assumes HDF5 format in `data_path` with correct shape and fields.
102+ """
103+ # TODO: support passing as dict
104+ # Load data
105+ self .data_path = data_path
106+ with h5py .File (self .data_path , "r" ) as f :
107+ assert "data" in f , "HDF5 file must contain 'data' dataset"
108+ self .data : TensorLike = torch .Tensor (f ["data" ][:]) # type: ignore # [N, T, W, H, C] # noqa: PGH003
109+ print (f"Loaded data shape: { self .data .shape } " )
110+ # TODO: add the constant scalars
111+ self .constant_scalars = (
112+ torch .Tensor (f ["constant_scalars" ][:]) # type: ignore # noqa: PGH003
113+ if "constant_scalars" in f
114+ else None
115+ ) # [N, C]
116+ # TODO: add the constant fields
117+ # self.constant_fields = torch.Tensor(f['data'][:]) # [N, W, H, C]
118+
119+ def parse_data (self , data : dict | None ):
120+ """Parse data from a dictionary."""
121+ if data is not None :
122+ self .data = data ["data" ]
123+ self .constant_scalars = data .get ("constant_scalars" , None )
124+ self .constant_fields = data .get ("constant_fields" , None )
125+ return
126+ msg = "No data provided to parse."
127+ raise ValueError (msg )
128+
110129 def __len__ (self ): # noqa: D105
111130 return len (self .all_input_fields )
112131
113132 def __getitem__ (self , idx ): # noqa: D105
114133 return {
115134 "input_fields" : self .all_input_fields [idx ],
116135 "output_fields" : self .all_output_fields [idx ],
117- # "constant_scalars": self.all_constant_scalars[idx],
136+ "constant_scalars" : self .all_constant_scalars [idx ],
118137 # TODO: add this
119138 # "constant_fields": self.all_constant_fields[idx],
120139 }
0 commit comments