66 "fmt"
77 "net/http"
88 "net/http/httptest"
9+ "strings"
910 "testing"
1011
1112 "github.com/stretchr/testify/assert"
@@ -976,6 +977,151 @@ func TestThatAntiCSRFCheckIsSkippedIfSessionRequiredIsFalseAndNoAccessTokenIsPas
976977 assert .Equal (t , res .StatusCode , 200 )
977978}
978979
980+ func TestThatResponseHeadersAreCorrectWhenUsingCookies (t * testing.T ) {
981+ configValue := supertokens.TypeInput {
982+ Supertokens : & supertokens.ConnectionInfo {
983+ ConnectionURI : "http://localhost:8080" ,
984+ },
985+ AppInfo : supertokens.AppInfo {
986+ AppName : "SuperTokens" ,
987+ WebsiteDomain : "supertokens.io" ,
988+ APIDomain : "api.supertokens.io" ,
989+ },
990+ RecipeList : []supertokens.Recipe {
991+ Init (& sessmodels.TypeInput {}),
992+ },
993+ }
994+ BeforeEach ()
995+ unittesting .StartUpST ("localhost" , "8080" )
996+ defer AfterEach ()
997+ err := supertokens .Init (configValue )
998+ if err != nil {
999+ t .Error (err .Error ())
1000+ }
1001+
1002+ app := getTestApp ([]typeTestEndpoint {})
1003+ defer app .Close ()
1004+
1005+ sessionCookies := createSessionWithCookies (app , map [string ]interface {}{})
1006+ print (sessionCookies )
1007+
1008+ var accessToken string
1009+
1010+ for _ , cookie := range sessionCookies {
1011+ if cookie .Name == "sAccessToken" {
1012+ accessToken = cookie .Value
1013+ }
1014+ }
1015+
1016+ assert .NotNil (t , accessToken )
1017+
1018+ req , err := http .NewRequest (http .MethodGet , app .URL + "/merge-payload" , nil )
1019+
1020+ if err != nil {
1021+ t .Error (err .Error ())
1022+ }
1023+
1024+ req .Header .Add ("Cookie" , "sAccessToken=" + accessToken )
1025+
1026+ res , err := http .DefaultClient .Do (req )
1027+
1028+ if err != nil {
1029+ t .Error (err .Error ())
1030+ }
1031+
1032+ cookiesHeaderValues := res .Header .Values ("Set-Cookie" )
1033+ accessTokenCount := 0
1034+
1035+ for _ , cookieValue := range cookiesHeaderValues {
1036+ if strings .Contains (cookieValue , "sAccessToken" ) {
1037+ accessTokenCount += 1
1038+ }
1039+ }
1040+
1041+ assert .Equal (t , accessTokenCount , 1 )
1042+
1043+ accessAllowHeaderValues := strings .Split (res .Header .Get ("Access-Control-Expose-Headers" ), "," )
1044+ frontTokenCount := 0
1045+
1046+ for _ , value := range accessAllowHeaderValues {
1047+ if strings .Contains (value , "front-token" ) {
1048+ frontTokenCount += 1
1049+ }
1050+ }
1051+
1052+ assert .Equal (t , frontTokenCount , 1 )
1053+ /**
1054+ Goland does not realise that the test passed because the start and end prints happen in the same line
1055+
1056+ This extra print adds a break line with the "--- PASS:" line being added in a new line making it clear
1057+ to the IDE that the test passed.
1058+
1059+ Leaving this in to avoid future confusion. Weirdly this happens intermittently for all tests.
1060+ */
1061+ fmt .Println ("" )
1062+ }
1063+
1064+ func TestThatResponseHeadersAreCorrectWhenUsingHeaders (t * testing.T ) {
1065+ configValue := supertokens.TypeInput {
1066+ Supertokens : & supertokens.ConnectionInfo {
1067+ ConnectionURI : "http://localhost:8080" ,
1068+ },
1069+ AppInfo : supertokens.AppInfo {
1070+ AppName : "SuperTokens" ,
1071+ WebsiteDomain : "supertokens.io" ,
1072+ APIDomain : "api.supertokens.io" ,
1073+ },
1074+ RecipeList : []supertokens.Recipe {
1075+ Init (& sessmodels.TypeInput {}),
1076+ },
1077+ }
1078+ BeforeEach ()
1079+ unittesting .StartUpST ("localhost" , "8080" )
1080+ defer AfterEach ()
1081+ err := supertokens .Init (configValue )
1082+ if err != nil {
1083+ t .Error (err .Error ())
1084+ }
1085+
1086+ app := getTestApp ([]typeTestEndpoint {})
1087+ defer app .Close ()
1088+
1089+ headers := createSessionWithHeaders (app , map [string ]interface {}{})
1090+ accessToken := headers .Get ("st-access-token" )
1091+
1092+ assert .NotNil (t , accessToken )
1093+
1094+ req , err := http .NewRequest (http .MethodGet , app .URL + "/merge-payload" , nil )
1095+
1096+ if err != nil {
1097+ t .Error (err .Error ())
1098+ }
1099+
1100+ req .Header .Add ("Authorization" , "Bearer " + accessToken )
1101+
1102+ res , err := http .DefaultClient .Do (req )
1103+
1104+ if err != nil {
1105+ t .Error (err .Error ())
1106+ }
1107+
1108+ accessTokenHeaderValues := res .Header .Values ("st-access-token" )
1109+ accessTokenCount := len (accessTokenHeaderValues )
1110+
1111+ assert .Equal (t , accessTokenCount , 1 )
1112+
1113+ accessAllowHeaderValues := strings .Split (res .Header .Get ("Access-Control-Expose-Headers" ), "," )
1114+ frontTokenCount := 0
1115+
1116+ for _ , value := range accessAllowHeaderValues {
1117+ if strings .Contains (value , "front-token" ) {
1118+ frontTokenCount += 1
1119+ }
1120+ }
1121+
1122+ assert .Equal (t , frontTokenCount , 1 )
1123+ }
1124+
9791125type typeTestEndpoint struct {
9801126 path string
9811127 overrideGlobalClaimValidators func (globalClaimValidators []claims.SessionClaimValidator , sessionContainer sessmodels.SessionContainer , userContext supertokens.UserContext ) ([]claims.SessionClaimValidator , error )
@@ -993,6 +1139,46 @@ func createSession(app *httptest.Server, body map[string]interface{}) []*http.Co
9931139 return res .Cookies ()
9941140}
9951141
1142+ func createSessionWithCookies (app * httptest.Server , body map [string ]interface {}) []* http.Cookie {
1143+ bodyBytes := []byte ("{}" )
1144+ if body != nil {
1145+ bodyBytes , _ = json .Marshal (body )
1146+ }
1147+ req , err := http .NewRequest (http .MethodPost , app .URL + "/create" , bytes .NewBuffer (bodyBytes ))
1148+ if err != nil {
1149+ return nil
1150+ }
1151+
1152+ req .Header .Set ("st-auth-mode" , "cookie" )
1153+
1154+ res , err := http .DefaultClient .Do (req )
1155+ if err != nil {
1156+ return nil
1157+ }
1158+
1159+ return res .Cookies ()
1160+ }
1161+
1162+ func createSessionWithHeaders (app * httptest.Server , body map [string ]interface {}) http.Header {
1163+ bodyBytes := []byte ("{}" )
1164+ if body != nil {
1165+ bodyBytes , _ = json .Marshal (body )
1166+ }
1167+ req , err := http .NewRequest (http .MethodPost , app .URL + "/create" , bytes .NewBuffer (bodyBytes ))
1168+ if err != nil {
1169+ return nil
1170+ }
1171+
1172+ req .Header .Set ("st-auth-mode" , "header" )
1173+
1174+ res , err := http .DefaultClient .Do (req )
1175+ if err != nil {
1176+ return nil
1177+ }
1178+
1179+ return res .Header
1180+ }
1181+
9961182func getTestApp (endpoints []typeTestEndpoint ) * httptest.Server {
9971183 mux := http .NewServeMux ()
9981184
@@ -1035,6 +1221,37 @@ func getTestApp(endpoints []typeTestEndpoint) *httptest.Server {
10351221 GetSession (r , rw , & sessmodels.VerifySessionOptions {})
10361222 }))
10371223
1224+ mux .HandleFunc ("/merge-payload" , VerifySession (& sessmodels.VerifySessionOptions {}, func (rw http.ResponseWriter , r * http.Request ) {
1225+ session , err := GetSession (r , rw , & sessmodels.VerifySessionOptions {})
1226+
1227+ if err != nil {
1228+ rw .WriteHeader (500 )
1229+ return
1230+ }
1231+
1232+ session .MergeIntoAccessTokenPayload (map [string ]interface {}{
1233+ "lastUpdate" : "123" ,
1234+ })
1235+ session .MergeIntoAccessTokenPayload (map [string ]interface {}{
1236+ "lastUpdate" : "456" ,
1237+ })
1238+ session .MergeIntoAccessTokenPayload (map [string ]interface {}{
1239+ "lastUpdate" : "789" ,
1240+ })
1241+
1242+ resp := map [string ]interface {}{
1243+ "status" : "OK" ,
1244+ }
1245+ respBytes , err := json .Marshal (resp )
1246+ if err != nil {
1247+ return
1248+ }
1249+ rw .Header ().Set ("Content-Type" , "application/json" )
1250+ rw .Header ().Set ("Content-Length" , fmt .Sprintf ("%d" , (len (respBytes ))))
1251+ rw .WriteHeader (http .StatusOK )
1252+ rw .Write (respBytes )
1253+ }))
1254+
10381255 mux .HandleFunc ("/default-claims" , VerifySession (nil , func (w http.ResponseWriter , r * http.Request ) {
10391256 sessionContainer := GetSessionFromRequestContext (r .Context ())
10401257 resp := map [string ]interface {}{
0 commit comments