11from abc import ABC , abstractmethod
2- from typing import Any , Iterable
2+ from dataclasses import dataclass
3+ from typing import Any , Iterable , Optional
34
45import pandas as pd
56from pydantic import BaseModel , ConfigDict
67
78from siapy .core .exceptions import InvalidInputError
9+ from siapy .entities import Signatures
810
911from .helpers import generate_classification_target , generate_regression_target
1012
@@ -88,7 +90,7 @@ def reset_index(self) -> "ClassificationTarget":
8890
8991class RegressionTarget (Target ):
9092 value : pd .Series
91- name : str
93+ name : str = "value"
9294
9395 def __getitem__ (self , indices : Any ) -> "RegressionTarget" :
9496 value = self .value .iloc [indices ]
@@ -105,7 +107,7 @@ def from_iterable(cls, data: Iterable[Any]) -> "RegressionTarget":
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "RegressionTarget":
    """Build a regression target from its dictionary form.

    ``data["value"]`` is wrapped in a ``pd.Series`` named ``"value"``.
    ``data["name"]`` is optional; when the key is absent the name
    falls back to ``"value"``.
    """
    series = pd.Series(data["value"], name="value")
    return cls(value=series, name=data.get("name", "value"))
110112
111113 def to_dict (self ) -> dict [str , Any ]:
@@ -121,42 +123,38 @@ def reset_index(self) -> "RegressionTarget":
121123 return RegressionTarget (value = self .value .reset_index (drop = True ), name = self .name )
122124
123125
124- class TabularDatasetData (BaseModel ):
125- model_config = ConfigDict (arbitrary_types_allowed = True )
126- pixels : pd .DataFrame
127- signals : pd .DataFrame
126+ @dataclass
127+ class TabularDatasetData :
128+ signatures : Signatures
128129 metadata : pd .DataFrame
129130 target : Target | None = None
130131
131- def __init__ (self , * args : Any , ** kwargs : Any ):
132- super ().__init__ (* args , ** kwargs )
133- self ._validate_lengths ()
def __len__(self) -> int:
    """Return the dataset size, defined as ``len(self.signatures)``."""
    size = len(self.signatures)
    return size
134134
135- def __setattr__ (self , name : str , value : Any ) -> None :
136- super ().__setattr__ (name , value )
137- if name in self .model_fields .keys ():
138- self ._validate_lengths ()
def __repr__(self) -> str:
    """Return a string showing the signatures, metadata and target fields."""
    rendered_fields = (
        f"signatures={self.signatures}",
        f"metadata={self.metadata}",
        f"target={self.target}",
    )
    return "TabularDatasetData(" + ", ".join(rendered_fields) + ")"
139137
def __getitem__(self, indices: Any) -> "TabularDatasetData":
    """Select entries and return them as a new ``TabularDatasetData``.

    ``indices`` is forwarded unchanged to ``self.signatures[...]``,
    ``self.metadata.iloc[...]`` and, when a target is set, to the target's
    own item access.
    """
    picked_signatures = self.signatures[indices]

    picked_metadata = self.metadata.iloc[indices]
    if isinstance(picked_metadata, pd.Series):
        # ``iloc`` handed back a Series instead of a frame; transpose it
        # into a single-row DataFrame so metadata stays tabular.
        picked_metadata = pd.DataFrame(picked_metadata).T

    if self.target is None:
        picked_target = None
    else:
        picked_target = self.target[indices]

    return TabularDatasetData(
        signatures=picked_signatures,
        metadata=picked_metadata,
        target=picked_target,
    )
146145
147- def __len__ (self ) -> int :
148- return len ( self .pixels )
def __post_init__(self) -> None:
    """Dataclass post-init hook: check field lengths via ``_validate_lengths``."""
    self._validate_lengths()
149148
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "TabularDatasetData":
    """Build a dataset from its dictionary form.

    Reads the required keys ``"pixels"``, ``"signals"`` and ``"metadata"``
    and the optional key ``"target"``. Pixels and signals are handed to
    ``Signatures.from_dict``; the target entry (or ``None`` when missing)
    goes through ``target_from_dict``.
    """
    signature_payload = {key: data[key] for key in ("pixels", "signals")}
    return cls(
        signatures=Signatures.from_dict(signature_payload),
        metadata=pd.DataFrame(data["metadata"]),
        target=TabularDatasetData.target_from_dict(data.get("target")),
    )
157155
158156 @staticmethod
159- def target_from_dict (data : dict [str , Any ] | None ) -> Target | None :
157+ def target_from_dict (data : dict [str , Any ] | None = None ) -> Optional [ Target ] :
160158 if data is None :
161159 return None
162160
@@ -172,14 +170,13 @@ def target_from_dict(data: dict[str, Any] | None) -> Target | None:
172170 raise InvalidInputError (data , "Invalid target dict." )
173171
174172 def _validate_lengths (self ) -> None :
175- if not ( len (self .pixels ) == len (self .signals ) == len ( self . metadata ) ):
173+ if len (self .signatures ) != len (self .metadata ):
176174 raise InvalidInputError (
177175 {
178- "pixels_length" : len (self .pixels ),
179- "signals_length" : len (self .signals ),
176+ "signatures_length" : len (self .signatures ),
180177 "metadata_length" : len (self .metadata ),
181178 },
182- "Lengths of pixels, signals, and metadata must be equal" ,
179+ "Lengths of signatures and metadata must be equal" ,
183180 )
184181 if self .target is not None and len (self .target ) != len (self ):
185182 raise InvalidInputError (
@@ -190,25 +187,77 @@ def _validate_lengths(self) -> None:
190187 "Target length must be equal to the length of the dataset." ,
191188 )
192189
def set_attributes(
    self,
    *,
    signatures: Signatures | None = None,
    metadata: pd.DataFrame | None = None,
    target: Target | None = None,
) -> "TabularDatasetData":
    """Return a new dataset with the supplied fields replaced.

    Every argument left as ``None`` is filled in from a copy of this
    instance (``self.copy()``). Consequently ``target=None`` keeps the
    current target; it cannot be used to clear it.
    """
    snapshot = self.copy()
    if signatures is None:
        signatures = snapshot.signatures
    if metadata is None:
        metadata = snapshot.metadata
    if target is None:
        target = snapshot.target
    return TabularDatasetData(signatures=signatures, metadata=metadata, target=target)
202+
def to_dict(self) -> dict[str, Any]:
    """Serialise the dataset into a plain dictionary.

    The result has the keys ``"pixels"`` and ``"signals"`` (taken from
    ``self.signatures.to_dict()``), ``"metadata"`` and ``"target"``;
    ``"target"`` is ``None`` when no target is set.
    """
    signatures_dict = self.signatures.to_dict()
    serialized: dict[str, Any] = {
        "pixels": signatures_dict["pixels"],
        "signals": signatures_dict["signals"],
    }
    serialized["metadata"] = self.metadata.to_dict()
    if self.target is None:
        serialized["target"] = None
    else:
        serialized["target"] = self.target.to_dict()
    return serialized
200211
def to_dataframe(self) -> pd.DataFrame:
    """Concatenate signatures, metadata and (if set) the target column-wise."""
    frames = [self.signatures.to_dataframe(), self.metadata]
    combined = pd.concat(frames, axis=1)
    if self.target is None:
        return combined
    # The target is appended in a second concat, after signatures + metadata.
    return pd.concat([combined, self.target.to_dataframe()], axis=1)
207218
def to_dataframe_multiindex(self) -> pd.DataFrame:
    """Return the dataset as one frame with two-level column labels.

    Column levels are named ``"category"`` and ``"field"``. Metadata columns
    are placed under the ``"metadata"`` category and target columns under
    ``"target"``; the signature columns come from
    ``self.signatures.to_dataframe_multiindex()``.

    Metadata and target frames are rebuilt from their ``.values``, so their
    original row index is not carried over.

    Raises:
        InvalidInputError: If the target is neither a ``ClassificationTarget``
            nor a ``RegressionTarget``.
    """
    level_names = ["category", "field"]

    def grouped_columns(category: str, fields: Any) -> pd.MultiIndex:
        # Pair every field with its category to form the two-level header.
        return pd.MultiIndex.from_tuples(
            [(category, field) for field in fields], names=level_names
        )

    signatures_df = self.signatures.to_dataframe_multiindex()
    metadata_df = pd.DataFrame(
        self.metadata.values,
        columns=grouped_columns("metadata", self.metadata.columns),
    )
    combined_df = pd.concat([signatures_df, metadata_df], axis=1)

    if self.target is None:
        return combined_df

    target_df = self.target.to_dataframe()
    if isinstance(self.target, ClassificationTarget):
        # Classification targets keep the column names of their own frame.
        target_fields = target_df.columns
    elif isinstance(self.target, RegressionTarget):
        # Regression targets are labelled with the target's ``name``.
        target_fields = [self.target.name]
    else:
        raise InvalidInputError(
            self.target,
            "Invalid target type. Expected ClassificationTarget or RegressionTarget.",
        )

    target_df = pd.DataFrame(
        target_df.values, columns=grouped_columns("target", target_fields)
    )
    return pd.concat([combined_df, target_df], axis=1)
250+
def reset_index(self) -> "TabularDatasetData":
    """Return a new dataset whose parts have had their index reset.

    Signatures and the target (when set) are reset through their own
    ``reset_index()``; the metadata frame drops its old index
    (``drop=True``).
    """
    fresh_signatures = self.signatures.reset_index()
    fresh_metadata = self.metadata.reset_index(drop=True)
    fresh_target = None if self.target is None else self.target.reset_index()
    return TabularDatasetData(
        signatures=fresh_signatures,
        metadata=fresh_metadata,
        target=fresh_target,
    )
257+
def copy(self) -> "TabularDatasetData":
    """Return a new dataset built from copies of all fields.

    Signatures and metadata are duplicated through their own ``copy()``.
    The target, when set, is copied with ``model_copy(deep=True)``: a plain
    ``model_copy()`` is shallow, which would leave the new dataset's target
    sharing its underlying data with this one.
    """
    if self.target is None:
        target_copy = None
    else:
        # NOTE(review): assumes Target is a pydantic model (it exposes
        # ``model_copy``); deep=True keeps the copy independent.
        target_copy = self.target.model_copy(deep=True)
    return TabularDatasetData(
        signatures=self.signatures.copy(),
        metadata=self.metadata.copy(),
        target=target_copy,
    )
0 commit comments