@@ -68,14 +68,16 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
6868                        proto  ==  "https"  and  port  !=  HTTPS_PORT 
6969                    ):
7070                        port_suffix  =  f":{ port }  
71+ 
7172                scope ["headers" ] =  self ._replace_header_value_by_name (
7273                    scope ,
7374                    "host" ,
7475                    f"{ domain } { port_suffix }  ,
7576                )
77+ 
7678        await  self .app (scope , receive , send )
7779
78-     def  _get_forwarded_url_parts (self , scope : Scope ) ->  Tuple [str ]:   # noqa: C901 
80+     def  _get_forwarded_url_parts (self , scope : Scope ) ->  Tuple [str ]:
7981        proto  =  scope .get ("scheme" , "http" )
8082        header_host  =  self ._get_header_value_by_name (scope , "host" )
8183        if  header_host  is  None :
@@ -87,35 +89,28 @@ def _get_forwarded_url_parts(self, scope: Scope) -> Tuple[str]:  # noqa: C901
8789            else :
8890                domain  =  header_host_parts [0 ]
8991                port  =  None 
90-         forwarded  =  self ._get_header_value_by_name (scope , "forwarded" )
91-         if  forwarded  is  not None :
92-             proxy_servers  =  forwarded .split ("," )  # values from the last server are used 
93-             for  proxy_server  in  proxy_servers :
94-                 parts  =  proxy_server .split (";" )
95-                 for  part  in  parts :
96-                     if  len (part ) >  0  and  re .search ("=" , part ):
97-                         key , value  =  part .split ("=" )
98-                         if  key  ==  "proto" :
99-                             proto  =  value 
100-                         elif  key  ==  "host" :
101-                             host_parts  =  value .split (":" )
102-                             domain  =  host_parts [0 ]
103-                             try :
104-                                 port  =  (
105-                                     int (host_parts [1 ]) if  len (host_parts ) ==  2  else  None 
106-                                 )
107-                             except  ValueError :
108-                                 # ignore ports that are not valid integers 
109-                                 pass 
92+ 
93+         if  forwarded  :=  self ._get_header_value_by_name (scope , "forwarded" ):
94+             for  proxy  in  forwarded .split ("," ):
95+                 if  (proto_expr  :=  re .search (r"proto=(?P<proto>http(s)?)" , proxy )) and  (
96+                     host_expr  :=  re .search (
97+                         r"host=(?P<host>[\w.-]+)(:(?P<port>\w+))?" , proxy 
98+                     )
99+                 ):
100+                     proto  =  proto_expr .groupdict ()["proto" ]
101+                     domain  =  host_expr .groupdict ()["host" ]
102+                     port_str  =  host_expr .groupdict ().get ("port" , None )
103+ 
110104        else :
111105            domain  =  self ._get_header_value_by_name (scope , "x-forwarded-host" , domain )
112106            proto  =  self ._get_header_value_by_name (scope , "x-forwarded-proto" , proto )
113107            port_str  =  self ._get_header_value_by_name (scope , "x-forwarded-port" , port )
114-             try :
115-                 port  =  int (port_str ) if  port_str  is  not None  else  None 
116-             except  ValueError :
117-                 # ignore ports that are not valid integers 
118-                 pass 
108+ 
109+         try :
110+             port  =  int (port_str ) if  port_str  is  not None  else  None 
111+         except  ValueError :
112+             # ignore ports that are not valid integers 
113+             pass 
119114
120115        return  (proto , domain , port )
121116
0 commit comments