@@ -69,11 +69,11 @@ class AuthHttpServer:
6969    def  __init__ (
7070        self ,
7171        uri : str ,
72-         redirect_uri : str ,
7372        buf_size : int  =  16384 ,
73+         redirect_uri : str  |  None  =  None ,
7474    ) ->  None :
7575        parsed_uri  =  urllib .parse .urlparse (uri )
76-         parsed_redirect  =  urllib .parse .urlparse (redirect_uri )
76+         parsed_redirect  =  urllib .parse .urlparse (redirect_uri )  if   redirect_uri   else   None 
7777        self ._socket  =  socket .socket (socket .AF_INET , socket .SOCK_STREAM )
7878        self .buf_size  =  buf_size 
7979        if  os .getenv ("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT" , "False" ).lower () ==  "true" :
@@ -84,10 +84,11 @@ def __init__(
8484            else :
8585                self ._socket .setsockopt (socket .SOL_SOCKET , socket .SO_REUSEPORT , 1 )
8686
87-         if  parsed_redirect . hostname   in  ( "localhost" ,  "127.0.0.1" ):
87+         if  parsed_redirect   and   self . _is_local_uri ( parsed_redirect ):
8888            port  =  parsed_redirect .port  or  0 
8989        else :
90-             port  =  parsed_uri .port  or  0 
90+             port  =  parsed_uri .port  if  parsed_uri  and  parsed_uri .port  else  0 
91+ 
9192        for  attempt  in  range (1 , self .DEFAULT_MAX_ATTEMPTS  +  1 ):
9293            try :
9394                self ._socket .bind (
@@ -128,27 +129,30 @@ def __init__(
128129            query = parsed_uri .query ,
129130            fragment = parsed_uri .fragment ,
130131        )
131-         if  (
132-             parsed_redirect .hostname  in  ("localhost" , "127.0.0.1" )
133-             and  port  !=  parsed_redirect .port 
134-         ):
135-             logger .debug (
136-                 f"Updating redirect port { parsed_redirect .port } { port }  
137-             )
138-             self ._redirect_uri  =  urllib .parse .ParseResult (
139-                 scheme = parsed_redirect .scheme ,
140-                 netloc = parsed_redirect .hostname  +  ":"  +  str (port ),
141-                 path = parsed_redirect .path ,
142-                 params = parsed_redirect .params ,
143-                 query = parsed_redirect .query ,
144-                 fragment = parsed_redirect .fragment ,
145-             )
146-         else :
147-             self ._redirect_uri  =  parsed_redirect 
132+         if  parsed_redirect :
133+             if  self ._is_local_uri (parsed_redirect ) and  port  !=  parsed_redirect .port :
134+                 logger .debug (
135+                     f"Updating redirect port { parsed_redirect .port } { port }  
136+                 )
137+                 self ._redirect_uri  =  urllib .parse .ParseResult (
138+                     scheme = parsed_redirect .scheme ,
139+                     netloc = parsed_redirect .hostname  +  ":"  +  str (port ),
140+                     path = parsed_redirect .path ,
141+                     params = parsed_redirect .params ,
142+                     query = parsed_redirect .query ,
143+                     fragment = parsed_redirect .fragment ,
144+                 )
145+             else :
146+                 self ._redirect_uri  =  parsed_redirect 
147+ 
148+     def  _is_local_uri (self , parsed_redirect ):
149+         return  parsed_redirect .hostname  in  ("localhost" , "127.0.0.1" )
148150
149151    @property  
150-     def  redirect_uri (self ) ->  str :
151-         return  self ._redirect_uri .geturl ()
152+     def  redirect_uri (self ) ->  str  |  None :
153+         if  self ._redirect_uri :
154+             return  self ._redirect_uri .geturl ()
155+         return  self .url 
152156
153157    @property  
154158    def  url (self ) ->  str :
0 commit comments