@@ -57,26 +57,43 @@ def get_value(self) -> list[str | int | None]:
5757class MutableDataFrame :
5858 """A DataFrame that can change values in-place."""
5959
60- def __init__ (self , dataframe : pl .DataFrame , hierarchical : bool ) -> None :
60+ def __init__ (
61+ self , dataframe : pl .DataFrame | pl .LazyFrame , hierarchical : bool
62+ ) -> None :
6163 """Initialize the class."""
62- self .dataset : pl .DataFrame | dict [str , Any ] = dataframe
64+ self .dataset : pl .DataFrame | dict [str , Any ] | pl . LazyFrame = dataframe
6365 self .matched_fields : dict [str , FieldMatch ] = {}
6466 self .matched_fields_metrics : dict [str , int ] | None = None
6567 self .hierarchical : bool = hierarchical
66- self .schema = dataframe .schema
68+ self .schema = (
69+ dataframe .schema
70+ if isinstance (dataframe , pl .DataFrame )
71+ else dataframe .collect_schema ()
72+ )
6773
6874 def match_rules (
6975 self , rules : list [PseudoRule ], target_rules : list [PseudoRule ] | None
7076 ) -> None :
7177 """Create references to all the columns that matches the given pseudo rules."""
7278 if self .hierarchical is False :
73- assert isinstance (self .dataset , pl .DataFrame )
79+ assert isinstance (self .dataset , pl .DataFrame ) or isinstance (
80+ self .dataset , pl .LazyFrame
81+ )
82+
83+ def extract_column_data (
84+ pattern : str , dataset : pl .DataFrame | pl .LazyFrame
85+ ) -> list [Any ]:
86+ if isinstance (dataset , pl .DataFrame ):
87+ return list (dataset .get_column (pattern ))
88+ elif isinstance (dataset , pl .LazyFrame ):
89+ return list (dataset .select (pattern ).collect ().to_series ())
90+
7491 self .matched_fields = {
7592 str (i ): FieldMatch (
7693 path = rule .pattern ,
7794 pattern = rule .pattern ,
7895 indexer = [],
79- col = list ( self . dataset . get_column ( rule .pattern ) ),
96+ col = extract_column_data ( rule .pattern , self . dataset ),
8097 wrapped_list = False ,
8198 func = rule .func ,
8299 target_func = target_rule .func if target_rule else None ,
@@ -109,7 +126,9 @@ def get_matched_fields(self) -> dict[str, FieldMatch]:
109126 def update (self , path : str , data : list [str | None ]) -> None :
110127 """Update a column with the given data."""
111128 if self .hierarchical is False :
112- assert isinstance (self .dataset , pl .DataFrame )
129+ assert isinstance (self .dataset , pl .DataFrame ) or isinstance (
130+ self .dataset , pl .LazyFrame
131+ )
113132 self .dataset = self .dataset .with_columns (pl .Series (data ).alias (path ))
114133 elif (field_match := self .matched_fields .get (path )) is not None :
115134 assert isinstance (self .dataset , dict )
@@ -122,10 +141,12 @@ def update(self, path: str, data: list[str | None]) -> None:
122141 data if field_match .wrapped_list is False else data [0 ]
123142 )
124143
125- def to_polars (self ) -> pl .DataFrame :
144+ def to_polars (self ) -> pl .DataFrame | pl . LazyFrame :
126145 """Convert to Polars DataFrame."""
127146 if self .hierarchical is False :
128- assert isinstance (self .dataset , pl .DataFrame )
147+ assert isinstance (self .dataset , pl .DataFrame ) or isinstance (
148+ self .dataset , pl .LazyFrame
149+ )
129150 return self .dataset
130151 else :
131152 assert isinstance (self .dataset , dict )
0 commit comments