66 "fmt"
77 "net/http"
88 "net/http/httptest"
9+ "strings"
910 "testing"
1011
1112 "github.com/stretchr/testify/assert"
@@ -976,6 +977,157 @@ func TestThatAntiCSRFCheckIsSkippedIfSessionRequiredIsFalseAndNoAccessTokenIsPas
976977 assert .Equal (t , res .StatusCode , 200 )
977978}
978979
980+ func TestThatResponseHeadersAreCorrectWhenUsingCookies (t * testing.T ) {
981+ configValue := supertokens.TypeInput {
982+ Supertokens : & supertokens.ConnectionInfo {
983+ ConnectionURI : "http://localhost:8080" ,
984+ },
985+ AppInfo : supertokens.AppInfo {
986+ AppName : "SuperTokens" ,
987+ WebsiteDomain : "supertokens.io" ,
988+ APIDomain : "api.supertokens.io" ,
989+ },
990+ RecipeList : []supertokens.Recipe {
991+ Init (& sessmodels.TypeInput {}),
992+ },
993+ }
994+ BeforeEach ()
995+ unittesting .StartUpST ("localhost" , "8080" )
996+ defer AfterEach ()
997+ err := supertokens .Init (configValue )
998+ if err != nil {
999+ t .Error (err .Error ())
1000+ }
1001+
1002+ app := getTestApp ([]typeTestEndpoint {})
1003+ defer app .Close ()
1004+
1005+ sessionCookies := createSessionWithCookies (app , map [string ]interface {}{})
1006+ print (sessionCookies )
1007+
1008+ var accessToken string
1009+
1010+ for _ , cookie := range sessionCookies {
1011+ if cookie .Name == "sAccessToken" {
1012+ accessToken = cookie .Value
1013+ }
1014+ }
1015+
1016+ assert .NotNil (t , accessToken )
1017+
1018+ req , err := http .NewRequest (http .MethodGet , app .URL + "/merge-payload" , nil )
1019+
1020+ if err != nil {
1021+ t .Error (err .Error ())
1022+ }
1023+
1024+ req .Header .Add ("Cookie" , "sAccessToken=" + accessToken )
1025+
1026+ res , err := http .DefaultClient .Do (req )
1027+
1028+ if err != nil {
1029+ t .Error (err .Error ())
1030+ }
1031+
1032+ cookiesHeaderValues := res .Header .Values ("Set-Cookie" )
1033+ accessTokenCount := 0
1034+
1035+ for _ , cookieValue := range cookiesHeaderValues {
1036+ if strings .Contains (cookieValue , "sAccessToken" ) {
1037+ accessTokenCount += 1
1038+ }
1039+ }
1040+
1041+ assert .Equal (t , accessTokenCount , 1 )
1042+
1043+ accessAllowHeaderValues := strings .Split (res .Header .Get ("Access-Control-Expose-Headers" ), "," )
1044+ frontTokenCount := 0
1045+
1046+ for _ , value := range accessAllowHeaderValues {
1047+ if strings .Contains (value , "front-token" ) {
1048+ frontTokenCount += 1
1049+ }
1050+ }
1051+
1052+ assert .Equal (t , frontTokenCount , 1 )
1053+ /**
1054+ Goland does not realise that the test passed because the start and end prints happen in the same line
1055+
1056+ This extra print adds a break line with the "--- PASS:" line being added in a new line making it clear
1057+ to the IDE that the test passed.
1058+
1059+ Leaving this in to avoid future confusion. Weirdly this happens intermittently for all tests.
1060+ */
1061+ fmt .Println ("" )
1062+ }
1063+
1064+ func TestThatResponseHeadersAreCorrectWhenUsingHeaders (t * testing.T ) {
1065+ configValue := supertokens.TypeInput {
1066+ Supertokens : & supertokens.ConnectionInfo {
1067+ ConnectionURI : "http://localhost:8080" ,
1068+ },
1069+ AppInfo : supertokens.AppInfo {
1070+ AppName : "SuperTokens" ,
1071+ WebsiteDomain : "supertokens.io" ,
1072+ APIDomain : "api.supertokens.io" ,
1073+ },
1074+ RecipeList : []supertokens.Recipe {
1075+ Init (& sessmodels.TypeInput {}),
1076+ },
1077+ }
1078+ BeforeEach ()
1079+ unittesting .StartUpST ("localhost" , "8080" )
1080+ defer AfterEach ()
1081+ err := supertokens .Init (configValue )
1082+ if err != nil {
1083+ t .Error (err .Error ())
1084+ }
1085+
1086+ app := getTestApp ([]typeTestEndpoint {})
1087+ defer app .Close ()
1088+
1089+ headers := createSessionWithHeaders (app , map [string ]interface {}{})
1090+ accessToken := headers .Get ("st-access-token" )
1091+
1092+ assert .NotNil (t , accessToken )
1093+
1094+ req , err := http .NewRequest (http .MethodGet , app .URL + "/merge-payload" , nil )
1095+
1096+ if err != nil {
1097+ t .Error (err .Error ())
1098+ }
1099+
1100+ req .Header .Add ("Authorization" , "Bearer " + accessToken )
1101+
1102+ res , err := http .DefaultClient .Do (req )
1103+
1104+ if err != nil {
1105+ t .Error (err .Error ())
1106+ }
1107+
1108+ accessTokenHeaderValues := res .Header .Values ("st-access-token" )
1109+ accessTokenCount := len (accessTokenHeaderValues )
1110+
1111+ assert .Equal (t , accessTokenCount , 1 )
1112+
1113+ accessAllowHeaderValues := strings .Split (res .Header .Get ("Access-Control-Expose-Headers" ), "," )
1114+ frontTokenCount := 0
1115+ stAccessTokenCount := 0
1116+
1117+ for _ , value := range accessAllowHeaderValues {
1118+ if strings .Contains (value , "front-token" ) {
1119+ frontTokenCount += 1
1120+ }
1121+
1122+ if strings .Contains (value , "st-access-token" ) {
1123+ stAccessTokenCount += 1
1124+ }
1125+ }
1126+
1127+ assert .Equal (t , frontTokenCount , 1 )
1128+ assert .Equal (t , stAccessTokenCount , 1 )
1129+ }
1130+
9791131type typeTestEndpoint struct {
9801132 path string
9811133 overrideGlobalClaimValidators func (globalClaimValidators []claims.SessionClaimValidator , sessionContainer sessmodels.SessionContainer , userContext supertokens.UserContext ) ([]claims.SessionClaimValidator , error )
@@ -993,6 +1145,46 @@ func createSession(app *httptest.Server, body map[string]interface{}) []*http.Co
9931145 return res .Cookies ()
9941146}
9951147
1148+ func createSessionWithCookies (app * httptest.Server , body map [string ]interface {}) []* http.Cookie {
1149+ bodyBytes := []byte ("{}" )
1150+ if body != nil {
1151+ bodyBytes , _ = json .Marshal (body )
1152+ }
1153+ req , err := http .NewRequest (http .MethodPost , app .URL + "/create" , bytes .NewBuffer (bodyBytes ))
1154+ if err != nil {
1155+ return nil
1156+ }
1157+
1158+ req .Header .Set ("st-auth-mode" , "cookie" )
1159+
1160+ res , err := http .DefaultClient .Do (req )
1161+ if err != nil {
1162+ return nil
1163+ }
1164+
1165+ return res .Cookies ()
1166+ }
1167+
1168+ func createSessionWithHeaders (app * httptest.Server , body map [string ]interface {}) http.Header {
1169+ bodyBytes := []byte ("{}" )
1170+ if body != nil {
1171+ bodyBytes , _ = json .Marshal (body )
1172+ }
1173+ req , err := http .NewRequest (http .MethodPost , app .URL + "/create" , bytes .NewBuffer (bodyBytes ))
1174+ if err != nil {
1175+ return nil
1176+ }
1177+
1178+ req .Header .Set ("st-auth-mode" , "header" )
1179+
1180+ res , err := http .DefaultClient .Do (req )
1181+ if err != nil {
1182+ return nil
1183+ }
1184+
1185+ return res .Header
1186+ }
1187+
9961188func getTestApp (endpoints []typeTestEndpoint ) * httptest.Server {
9971189 mux := http .NewServeMux ()
9981190
@@ -1035,6 +1227,37 @@ func getTestApp(endpoints []typeTestEndpoint) *httptest.Server {
10351227 GetSession (r , rw , & sessmodels.VerifySessionOptions {})
10361228 }))
10371229
1230+ mux .HandleFunc ("/merge-payload" , VerifySession (& sessmodels.VerifySessionOptions {}, func (rw http.ResponseWriter , r * http.Request ) {
1231+ session , err := GetSession (r , rw , & sessmodels.VerifySessionOptions {})
1232+
1233+ if err != nil {
1234+ rw .WriteHeader (500 )
1235+ return
1236+ }
1237+
1238+ session .MergeIntoAccessTokenPayload (map [string ]interface {}{
1239+ "lastUpdate" : "123" ,
1240+ })
1241+ session .MergeIntoAccessTokenPayload (map [string ]interface {}{
1242+ "lastUpdate" : "456" ,
1243+ })
1244+ session .MergeIntoAccessTokenPayload (map [string ]interface {}{
1245+ "lastUpdate" : "789" ,
1246+ })
1247+
1248+ resp := map [string ]interface {}{
1249+ "status" : "OK" ,
1250+ }
1251+ respBytes , err := json .Marshal (resp )
1252+ if err != nil {
1253+ return
1254+ }
1255+ rw .Header ().Set ("Content-Type" , "application/json" )
1256+ rw .Header ().Set ("Content-Length" , fmt .Sprintf ("%d" , (len (respBytes ))))
1257+ rw .WriteHeader (http .StatusOK )
1258+ rw .Write (respBytes )
1259+ }))
1260+
10381261 mux .HandleFunc ("/default-claims" , VerifySession (nil , func (w http.ResponseWriter , r * http.Request ) {
10391262 sessionContainer := GetSessionFromRequestContext (r .Context ())
10401263 resp := map [string ]interface {}{
0 commit comments