@@ -702,56 +702,90 @@ pub fn extract_table_name_from_sql(sql: &str) -> Option<String> {
702702/// Extract table name and optional alias from SQL FROM clause
703703/// Returns (table_name, Option<alias>)
704704pub fn extract_table_and_alias_from_sql ( sql : & str ) -> Option < ( String , Option < String > ) > {
705- // Normalize whitespace to handle newlines/tabs in SQL
706- let normalized: String = sql
707- . chars ( )
708- . map ( |c| if c. is_whitespace ( ) { ' ' } else { c } )
709- . collect ( ) ;
710- let from_pos = find_top_level_keyword ( & normalized, "FROM" , 0 ) ?;
711- let after_from = & normalized[ from_pos..] ;
705+ fn skip_ws_and_comments ( sql : & str , mut idx : usize ) -> usize {
706+ let bytes = sql. as_bytes ( ) ;
707+ while idx < bytes. len ( ) {
708+ let c = bytes[ idx] as char ;
709+ if c. is_whitespace ( ) {
710+ idx += 1 ;
711+ continue ;
712+ }
713+ if c == '-' && idx + 1 < bytes. len ( ) && bytes[ idx + 1 ] as char == '-' {
714+ idx += 2 ;
715+ while idx < bytes. len ( ) {
716+ let ch = bytes[ idx] as char ;
717+ idx += 1 ;
718+ if ch == '\n' || ch == '\r' {
719+ break ;
720+ }
721+ }
722+ continue ;
723+ }
724+ if c == '/' && idx + 1 < bytes. len ( ) && bytes[ idx + 1 ] as char == '*' {
725+ idx += 2 ;
726+ while idx + 1 < bytes. len ( ) {
727+ let ch = bytes[ idx] as char ;
728+ if ch == '*' && bytes[ idx + 1 ] as char == '/' {
729+ idx += 2 ;
730+ break ;
731+ }
732+ idx += 1 ;
733+ }
734+ continue ;
735+ }
736+ break ;
737+ }
738+ idx
739+ }
712740
713- // Parse: FROM table_name [AS] [alias]
714- let ( rest, _) = tag_no_case :: < _ , _ , nom:: error:: Error < & str > > ( "FROM" ) ( after_from) . ok ( ) ?;
715- let ( rest, _) = multispace1 :: < _ , nom:: error:: Error < & str > > ( rest) . ok ( ) ?;
716- let ( rest, table) = identifier ( rest) . ok ( ) ?;
741+ let from_pos = find_top_level_keyword ( sql, "FROM" , 0 ) ?;
742+ let mut idx = from_pos + 4 ;
743+ idx = skip_ws_and_comments ( sql, idx) ;
744+ if idx >= sql. len ( ) {
745+ return None ;
746+ }
717747
718- // Check for alias (optional AS keyword followed by identifier)
719- let rest_trimmed = rest. trim_start ( ) ;
748+ let table_start = idx;
749+ while idx < sql. len ( ) && is_table_ident_char ( sql. as_bytes ( ) [ idx] as char ) {
750+ idx += 1 ;
751+ }
752+ if table_start == idx {
753+ return None ;
754+ }
755+ let table = sql[ table_start..idx] . to_string ( ) ;
720756
721- // Check for end of FROM clause (WHERE, GROUP, ORDER, etc.)
722- let rest_upper = rest_trimmed. to_uppercase ( ) ;
723- if rest_trimmed. is_empty ( )
724- || rest_upper. starts_with ( "WHERE" )
725- || rest_upper. starts_with ( "GROUP" )
726- || rest_upper. starts_with ( "ORDER" )
727- || rest_upper. starts_with ( "LIMIT" )
728- || rest_upper. starts_with ( "HAVING" )
729- || rest_upper. starts_with ( "JOIN" )
730- || rest_trimmed. starts_with ( ';' )
731- {
732- return Some ( ( table. to_string ( ) , None ) ) ;
757+ idx = skip_ws_and_comments ( sql, idx) ;
758+ if idx >= sql. len ( ) || sql. as_bytes ( ) [ idx] as char == ';' {
759+ return Some ( ( table, None ) ) ;
733760 }
734761
735- // Try to parse optional AS keyword
736- let after_as = if rest_upper. starts_with ( "AS " ) {
737- & rest_trimmed[ 3 ..]
738- } else {
739- rest_trimmed
740- } ;
762+ let rest = & sql[ idx..] ;
763+ let rest_upper = rest. to_uppercase ( ) ;
764+ let mut rest_after_as = rest;
765+ if rest_upper. starts_with ( "AS" ) {
766+ let after_as = & rest[ 2 ..] ;
767+ if after_as
768+ . chars ( )
769+ . next ( )
770+ . map_or ( false , |ch| ch. is_whitespace ( ) )
771+ {
772+ let mut as_idx = idx + 2 ;
773+ as_idx = skip_ws_and_comments ( sql, as_idx) ;
774+ rest_after_as = & sql[ as_idx..] ;
775+ }
776+ }
741777
742- // Parse the alias identifier
743- if let Ok ( ( _, alias) ) = identifier ( after_as. trim_start ( ) ) {
744- // Make sure alias isn't a keyword
778+ if let Ok ( ( _, alias) ) = identifier ( rest_after_as. trim_start ( ) ) {
745779 let alias_upper = alias. to_uppercase ( ) ;
746780 if matches ! (
747781 alias_upper. as_str( ) ,
748- "WHERE" | "GROUP" | "ORDER" | "LIMIT" | "HAVING" | "JOIN"
782+ "FROM" | " WHERE" | "GROUP" | "ORDER" | "LIMIT" | "HAVING" | "JOIN"
749783 ) {
750- return Some ( ( table. to_string ( ) , None ) ) ;
784+ return Some ( ( table, None ) ) ;
751785 }
752- Some ( ( table. to_string ( ) , Some ( alias. to_string ( ) ) ) )
786+ Some ( ( table, Some ( alias. to_string ( ) ) ) )
753787 } else {
754- Some ( ( table. to_string ( ) , None ) )
788+ Some ( ( table, None ) )
755789 }
756790}
757791
@@ -793,6 +827,10 @@ fn is_boundary_char(ch: Option<char>) -> bool {
793827 ch. map_or ( true , |c| !c. is_alphanumeric ( ) && c != '_' )
794828}
795829
/// Returns `true` when `ch` may appear inside a (possibly dot-qualified)
/// table reference: an alphanumeric character, `_`, or `.`.
fn is_table_ident_char(ch: char) -> bool {
    matches!(ch, '_' | '.') || ch.is_alphanumeric()
}
833+
796834fn find_top_level_keyword ( sql : & str , keyword : & str , start : usize ) -> Option < usize > {
797835 let upper = sql. to_uppercase ( ) ;
798836 let upper_bytes = upper. as_bytes ( ) ;
@@ -810,11 +848,31 @@ fn find_top_level_keyword(sql: &str, keyword: &str, start: usize) -> Option<usiz
810848 let mut in_double = false ;
811849 let mut in_backtick = false ;
812850 let mut in_bracket = false ;
851+ let mut in_line_comment = false ;
852+ let mut in_block_comment = false ;
813853
814854 let mut i = start;
815855 while i < bytes. len ( ) {
816856 let c = bytes[ i] as char ;
817857
858+ if in_line_comment {
859+ if c == '\n' || c == '\r' {
860+ in_line_comment = false ;
861+ }
862+ i += 1 ;
863+ continue ;
864+ }
865+
866+ if in_block_comment {
867+ if c == '*' && i + 1 < bytes. len ( ) && bytes[ i + 1 ] as char == '/' {
868+ in_block_comment = false ;
869+ i += 2 ;
870+ continue ;
871+ }
872+ i += 1 ;
873+ continue ;
874+ }
875+
818876 if in_single {
819877 if c == '\'' {
820878 if i + 1 < bytes. len ( ) && bytes[ i + 1 ] as char == '\'' {
@@ -851,6 +909,18 @@ fn find_top_level_keyword(sql: &str, keyword: &str, start: usize) -> Option<usiz
851909 continue ;
852910 }
853911
912+ if c == '-' && i + 1 < bytes. len ( ) && bytes[ i + 1 ] as char == '-' {
913+ in_line_comment = true ;
914+ i += 2 ;
915+ continue ;
916+ }
917+
918+ if c == '/' && i + 1 < bytes. len ( ) && bytes[ i + 1 ] as char == '*' {
919+ in_block_comment = true ;
920+ i += 2 ;
921+ continue ;
922+ }
923+
854924 match c {
855925 '\'' => {
856926 in_single = true ;
@@ -951,6 +1021,42 @@ fn find_top_level_keyword(sql: &str, keyword: &str, start: usize) -> Option<usiz
9511021 None
9521022}
9531023
1024+ fn insert_table_alias ( sql : & str , table_name : & str , alias : & str ) -> Option < String > {
1025+ let from_pos = find_top_level_keyword ( sql, "FROM" , 0 ) ?;
1026+ let bytes = sql. as_bytes ( ) ;
1027+ let mut idx = from_pos + 4 ;
1028+ while idx < bytes. len ( ) && bytes[ idx] . is_ascii_whitespace ( ) {
1029+ idx += 1 ;
1030+ }
1031+ if idx >= bytes. len ( ) {
1032+ return None ;
1033+ }
1034+
1035+ let table_start = idx;
1036+ while idx < bytes. len ( ) && is_table_ident_char ( bytes[ idx] as char ) {
1037+ idx += 1 ;
1038+ }
1039+ if table_start == idx {
1040+ return None ;
1041+ }
1042+
1043+ let table_token = & sql[ table_start..idx] ;
1044+ let table_simple = table_token
1045+ . split ( '.' )
1046+ . next_back ( )
1047+ . unwrap_or ( table_token) ;
1048+ if !table_simple. eq_ignore_ascii_case ( table_name) {
1049+ return None ;
1050+ }
1051+
1052+ let mut updated = String :: with_capacity ( sql. len ( ) + alias. len ( ) + 1 ) ;
1053+ updated. push_str ( & sql[ ..idx] ) ;
1054+ updated. push ( ' ' ) ;
1055+ updated. push_str ( alias) ;
1056+ updated. push_str ( & sql[ idx..] ) ;
1057+ Some ( updated)
1058+ }
1059+
9541060fn find_first_top_level_keyword ( sql : & str , start : usize , keywords : & [ & str ] ) -> Option < usize > {
9551061 keywords
9561062 . iter ( )
@@ -3826,10 +3932,12 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult {
38263932 Some ( pt. effective_name . clone ( ) )
38273933 } else {
38283934 // No alias on primary table, add _outer
3829- let from_pattern = format ! ( "FROM {}" , pt. name) ;
3830- let from_replacement = format ! ( "FROM {} _outer" , pt. name) ;
3831- result_sql = result_sql. replace ( & from_pattern, & from_replacement) ;
3832- Some ( "_outer" . to_string ( ) )
3935+ if let Some ( updated_sql) = insert_table_alias ( & result_sql, & pt. name , "_outer" ) {
3936+ result_sql = updated_sql;
3937+ Some ( "_outer" . to_string ( ) )
3938+ } else {
3939+ None
3940+ }
38333941 }
38343942 } else {
38353943 None
0 commit comments