@@ -76,7 +76,7 @@ func (el *ExternalLogins) login() http.Handler {
76
76
}
77
77
78
78
next := http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
79
- sess , err := p .BeginAuth (reqID )
79
+ sess , err := p .BeginAuth (el . toState ( provider , reqID , conf . ID ) )
80
80
if err != nil {
81
81
http .Error (w , err .Error (), http .StatusInternalServerError )
82
82
return
@@ -102,13 +102,15 @@ func (el *ExternalLogins) login() http.Handler {
102
102
103
103
func (el * ExternalLogins ) callback () http.Handler {
104
104
return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
105
- provider := r .URL .Query ().Get ("provider" )
106
- reqID := r .URL .Query ().Get ("reqid" )
105
+ provider , reqID , baseID := el .fromState (el .getState (r ))
107
106
108
107
var conf internal.BaseConfig
109
108
if err := volatile .GetTyped ("oauth_" + reqID , & conf ); err != nil {
110
109
http .Error (w , err .Error (), http .StatusBadRequest )
111
110
return
111
+ } else if conf .ID != baseID {
112
+ http .Error (w , "invalid request" , http .StatusBadRequest )
113
+ return
112
114
}
113
115
114
116
customer , err := datastore .FindAccount (conf .CustomerID )
@@ -237,7 +239,7 @@ func (el *ExternalLogins) signIn(dbName, email string) (sessionToken string, err
237
239
func (el * ExternalLogins ) signUp (dbName , provider , email , accessToken string ) (sessionToken string , err error ) {
238
240
pw := fmt .Sprintf ("%s:%s" , provider , accessToken )
239
241
240
- b , _ , err := el .membership .createAccountAndUser (dbName , email , pw , 100 )
242
+ b , _ , err := el .membership .createAccountAndUser (dbName , email , pw , 0 )
241
243
if err != nil {
242
244
return
243
245
}
@@ -248,11 +250,8 @@ func (el *ExternalLogins) signUp(dbName, provider, email, accessToken string) (s
248
250
249
251
func (el * ExternalLogins ) getProvider (dbID , provider , reqID string , info internal.OAuthConfig ) (p goth.Provider , err error ) {
250
252
callbackURL := fmt .Sprintf (
251
- "%s/oauth/callback?provider=%s&reqid=%s&sbpk=%s " ,
253
+ "%s/oauth/callback" ,
252
254
config .Current .AppURL ,
253
- provider ,
254
- reqID ,
255
- dbID ,
256
255
)
257
256
258
257
if provider == OAuthProviderTwitter {
@@ -272,3 +271,19 @@ func (*ExternalLogins) getState(r *http.Request) string {
272
271
}
273
272
return params .Get ("state" )
274
273
}
274
+
275
+ func (* ExternalLogins ) toState (provider , reqID , baseID string ) string {
276
+ return fmt .Sprintf ("%s_%s_%s" , provider , reqID , baseID )
277
+ }
278
+
279
+ func (* ExternalLogins ) fromState (state string ) (provider , reqID , baseID string ) {
280
+ parts := strings .Split (state , "_" )
281
+ if len (parts ) != 3 {
282
+ return
283
+ }
284
+
285
+ provider = parts [0 ]
286
+ reqID = parts [1 ]
287
+ baseID = parts [2 ]
288
+ return
289
+ }
0 commit comments