diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3ac28af..df61297 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -49,3 +49,26 @@ jobs: gofmt -s -l . | grep -v vendor exit 1 fi + + - name: Run tests with coverage + run: go test -short -race -coverprofile=coverage.out ./... + + - name: Extract coverage percentage + if: github.ref == 'refs/heads/main' && matrix.go-version == '1.24' + id: coverage + run: | + COVERAGE=$(go tool cover -func=coverage.out | grep total | awk '{print $3}' | tr -d '%') + echo "percentage=$COVERAGE" >> "$GITHUB_OUTPUT" + + - name: Update coverage badge + if: github.ref == 'refs/heads/main' && matrix.go-version == '1.24' + uses: schneegans/dynamic-badges-action@v1.7.0 + with: + auth: ${{ secrets.GIST_TOKEN }} + gistID: 2c608589294aed9aa900256daeec0fd4 + filename: coverage.json + label: coverage + message: ${{ steps.coverage.outputs.percentage }}% + valColorRange: ${{ steps.coverage.outputs.percentage }} + minColorRange: 40 + maxColorRange: 90 diff --git a/README.md b/README.md index c787501..c683ca8 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ [![Go Reference](https://pkg.go.dev/badge/github.com/tirthpatell/threads-go.svg)](https://pkg.go.dev/github.com/tirthpatell/threads-go) [![Go Report Card](https://goreportcard.com/badge/github.com/tirthpatell/threads-go)](https://goreportcard.com/report/github.com/tirthpatell/threads-go) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) +[![Coverage](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/tirthpatell/2c608589294aed9aa900256daeec0fd4/raw/coverage.json)](https://github.com/tirthpatell/threads-go/actions) Production-ready Go client for the Threads API with complete endpoint coverage, OAuth 2.0 authentication, rate limiting, and comprehensive error handling. diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..fea3c77 --- /dev/null +++ b/auth_test.go @@ -0,0 +1,192 @@ +package threads + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestExchangeCodeForToken_Success(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + _, _ = w.Write([]byte(`{ + "access_token": "new_token_123", + "token_type": "bearer", + "expires_in": 3600, + "user_id": 99999 + }`)) + } + + server := httptest.NewServer(http.HandlerFunc(handler)) + t.Cleanup(server.Close) + + config := &Config{ + ClientID: "test-id", + ClientSecret: "test-secret", + RedirectURI: "https://example.com/callback", + } + config.SetDefaults() + config.BaseURL = server.URL + + client, err := NewClient(config) + if err != nil { + t.Fatal(err) + } + + err = client.ExchangeCodeForToken(context.Background(), "auth_code_123") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !client.IsAuthenticated() { + t.Error("expected client to be authenticated") + } + tokenInfo := client.GetTokenInfo() + if tokenInfo.AccessToken != "new_token_123" { + t.Errorf("expected new_token_123, got %s", tokenInfo.AccessToken) + } + if tokenInfo.UserID != "99999" { + t.Errorf("expected user ID 99999, got %s", tokenInfo.UserID) + } +} + +func TestExchangeCodeForToken_EmptyCode(t *testing.T) { + config := &Config{ + ClientID: "test-id", + ClientSecret: "test-secret", + RedirectURI: "https://example.com/callback", + } + config.SetDefaults() + client, _ := NewClient(config) + + err := client.ExchangeCodeForToken(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty code") + } +} + +func TestGetLongLivedToken_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{ + "access_token": "long_lived_token", + "token_type": "bearer", + "expires_in": 5184000 + }`)) + + err := client.GetLongLivedToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tokenInfo := client.GetTokenInfo() + if tokenInfo.AccessToken != "long_lived_token" { + t.Errorf("expected long_lived_token, got %s", tokenInfo.AccessToken) + } +} + +func TestRefreshToken_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{ + "access_token": "refreshed_token", + "token_type": "bearer", + "expires_in": 5184000 + }`)) + + err := client.RefreshToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tokenInfo := client.GetTokenInfo() + if tokenInfo.AccessToken != "refreshed_token" { + t.Errorf("expected refreshed_token, got %s", tokenInfo.AccessToken) + } +} + +func TestRefreshToken_NoToken(t *testing.T) { + config := &Config{ + ClientID: "test-id", + ClientSecret: "test-secret", + RedirectURI: "https://example.com/callback", + } + config.SetDefaults() + client, _ := NewClient(config) + + err := client.RefreshToken(context.Background()) + if err == nil { + t.Fatal("expected error when no token") + } +} + +func TestDebugToken_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{ + "data": { + "type": "USER", + "application": "Test App", + "is_valid": true, + "expires_at": 1735689600, + "issued_at": 1735603200, + "user_id": "12345", + "scopes": ["threads_basic"] + } + }`)) + + resp, err := client.DebugToken(context.Background(), "test-token") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !resp.Data.IsValid { + t.Error("expected valid token") + } + if resp.Data.UserID != "12345" { + t.Errorf("expected user ID 12345, got %s", resp.Data.UserID) + } +} + +func TestGetAuthURL_ContainsRequiredParams(t *testing.T) { + config := &Config{ + ClientID: "my-app-id", + ClientSecret: "secret", + RedirectURI: "https://example.com/callback", + } + config.SetDefaults() + client, _ := NewClient(config) + + authURL := client.GetAuthURL([]string{"threads_basic"}) + if authURL == "" { + t.Fatal("expected non-empty auth URL") + } + for _, param := range []string{"client_id=my-app-id", "response_type=code", "scope=threads_basic"} { + if !strings.Contains(authURL, param) { + t.Errorf("expected auth URL to contain %q, got %s", param, authURL) + } + } +} + +func TestTokenExpiration(t *testing.T) { + config := &Config{ + ClientID: "test-id", + ClientSecret: "test-secret", + RedirectURI: "https://example.com/callback", + } + config.SetDefaults() + client, _ := NewClient(config) + + _ = client.SetTokenInfo(&TokenInfo{ + AccessToken: "expired", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(-time.Hour), + UserID: "12345", + CreatedAt: time.Now().Add(-2 * time.Hour), + }) + + if !client.IsTokenExpired() { + t.Error("expected token to be expired") + } + if !client.IsTokenExpiringSoon(time.Hour) { + t.Error("expected token to be expiring soon") + } +} diff --git a/client_test.go b/client_test.go index 9a31399..55974bb 100644 --- a/client_test.go +++ b/client_test.go @@ -682,13 +682,6 @@ func TestCreateErrorFromResponseParsesIsTransient(t *testing.T) { } } -type noopLogger struct{} - -func (n *noopLogger) Debug(msg string, fields ...any) {} -func (n *noopLogger) Info(msg string, fields ...any) {} -func (n *noopLogger) Warn(msg string, fields ...any) {} -func (n *noopLogger) Error(msg string, fields ...any) {} - func TestIsRetryableErrorWithTransientAPIError(t *testing.T) { h := &HTTPClient{ logger: &noopLogger{}, diff --git a/client_utils.go b/client_utils.go index 3741740..09dc989 100644 --- a/client_utils.go +++ b/client_utils.go @@ -8,12 +8,11 @@ import ( // getUserID extracts user ID from token info func (c *Client) getUserID() string { + c.mu.RLock() + defer c.mu.RUnlock() if c.tokenInfo != nil && c.tokenInfo.UserID != "" { return c.tokenInfo.UserID } - - // If user ID is not in token info, we might need to call /me endpoint - // For now, return empty string to trigger an error return "" } diff --git a/http_client.go b/http_client.go index db37128..09da6e2 100644 --- a/http_client.go +++ b/http_client.go @@ -58,13 +58,23 @@ func NewHTTPClient(config *Config, rateLimiter *RateLimiter) *HTTPClient { Timeout: config.HTTPTimeout, } + baseURL := config.BaseURL + if baseURL == "" { + baseURL = "https://graph.threads.net" + } + + userAgent := config.UserAgent + if userAgent == "" { + userAgent = DefaultUserAgent + } + return &HTTPClient{ client: httpClient, logger: config.Logger, retryConfig: config.RetryConfig, rateLimiter: rateLimiter, - baseURL: "https://graph.threads.net", - userAgent: DefaultUserAgent, + baseURL: baseURL, + userAgent: userAgent, } } diff --git a/http_client_test.go b/http_client_test.go new file mode 100644 index 0000000..5715cc1 --- /dev/null +++ b/http_client_test.go @@ -0,0 +1,122 @@ +package threads + +import ( + "context" + "net/http" + "sync/atomic" + "testing" + "time" +) + +func TestHTTPClient_RetryOnServerError(t *testing.T) { + var attempts int32 + handler := func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt32(&attempts, 1) + if count < 3 { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(500) + _, _ = w.Write([]byte(`{"error":{"message":"Internal error","type":"OAuthException","code":2,"is_transient":true}}`)) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"ok":true}`)) + } + + httpClient := newTestHTTPClient(t, http.HandlerFunc(handler), &RetryConfig{ + MaxRetries: 3, + InitialDelay: 10 * time.Millisecond, + MaxDelay: 50 * time.Millisecond, + BackoffFactor: 2.0, + }) + + resp, err := httpClient.Do(&RequestOptions{Method: "GET", Path: "/test"}, "token") + if err != nil { + t.Fatalf("expected success after retries, got: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + if atomic.LoadInt32(&attempts) != 3 { + t.Errorf("expected 3 attempts, got %d", atomic.LoadInt32(&attempts)) + } +} + +func TestHTTPClient_NoRetryOnValidationError(t *testing.T) { + var attempts int32 + handler := func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&attempts, 1) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(400) + _, _ = w.Write([]byte(`{"error":{"message":"Bad request","type":"OAuthException","code":100}}`)) + } + + httpClient := newTestHTTPClient(t, http.HandlerFunc(handler), &RetryConfig{ + MaxRetries: 3, + InitialDelay: 10 * time.Millisecond, + MaxDelay: 50 * time.Millisecond, + BackoffFactor: 2.0, + }) + + _, err := httpClient.Do(&RequestOptions{Method: "GET", Path: "/test"}, "token") + if err == nil { + t.Fatal("expected error for 400") + } + if atomic.LoadInt32(&attempts) != 1 { + t.Errorf("expected 1 attempt (no retry for 400), got %d", atomic.LoadInt32(&attempts)) + } +} + +func TestHTTPClient_ContextCancellation(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + select { + case <-r.Context().Done(): + case <-time.After(5 * time.Second): + } + w.WriteHeader(200) + } + + httpClient := newTestHTTPClient(t, http.HandlerFunc(handler), &RetryConfig{ + MaxRetries: 0, + InitialDelay: time.Second, + MaxDelay: time.Second, + BackoffFactor: 1.0, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, err := httpClient.Do(&RequestOptions{Method: "GET", Path: "/slow", Context: ctx}, "token") + if err == nil { + t.Fatal("expected error from context cancellation") + } +} + +func TestHTTPClient_ParseRateLimitHeaders(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-RateLimit-Limit", "100") + w.Header().Set("X-RateLimit-Remaining", "42") + w.Header().Set("X-RateLimit-Reset", "1735689600") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + _, _ = w.Write([]byte(`{}`)) + } + + httpClient := newTestHTTPClient(t, http.HandlerFunc(handler), &RetryConfig{ + MaxRetries: 0, InitialDelay: time.Second, MaxDelay: time.Second, BackoffFactor: 1.0, + }) + + resp, err := httpClient.Do(&RequestOptions{Method: "GET", Path: "/test"}, "token") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.RateLimit == nil { + t.Fatal("expected rate limit info") + } + if resp.RateLimit.Limit != 100 { + t.Errorf("expected limit 100, got %d", resp.RateLimit.Limit) + } + if resp.RateLimit.Remaining != 42 { + t.Errorf("expected remaining 42, got %d", resp.RateLimit.Remaining) + } +} diff --git a/insights_test.go b/insights_test.go new file mode 100644 index 0000000..02e6008 --- /dev/null +++ b/insights_test.go @@ -0,0 +1,48 @@ +package threads + +import ( + "context" + "testing" +) + +func TestGetPostInsights_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{ + "data": [ + {"name": "views", "period": "lifetime", "values": [{"value": 100}]}, + {"name": "likes", "period": "lifetime", "values": [{"value": 25}]} + ] + }`)) + + resp, err := client.GetPostInsights(context.Background(), ConvertToPostID("post_1"), []string{"views", "likes"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(resp.Data) != 2 { + t.Errorf("expected 2 metrics, got %d", len(resp.Data)) + } + if resp.Data[0].Name != "views" { + t.Errorf("expected 'views', got %s", resp.Data[0].Name) + } +} + +func TestGetPostInsights_InvalidPostID(t *testing.T) { + client := testClient(t, jsonHandler(200, `{}`)) + _, err := client.GetPostInsights(context.Background(), PostID(""), []string{"views"}) + if err == nil { + t.Fatal("expected error for empty post ID") + } +} + +func TestGetAccountInsights_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{ + "data": [{"name": "followers_count", "period": "day", "values": [{"value": 500}]}] + }`)) + + resp, err := client.GetAccountInsights(context.Background(), ConvertToUserID("12345"), []string{"followers_count"}, "day") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(resp.Data) != 1 { + t.Errorf("expected 1 metric, got %d", len(resp.Data)) + } +} diff --git a/location_test.go b/location_test.go new file mode 100644 index 0000000..c36b7e4 --- /dev/null +++ b/location_test.go @@ -0,0 +1,57 @@ +package threads + +import ( + "context" + "testing" +) + +func TestSearchLocations_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{ + "data": [ + {"id": "loc1", "name": "Coffee Shop", "city": "San Francisco"}, + {"id": "loc2", "name": "Coffee House", "city": "San Francisco"} + ] + }`)) + + resp, err := client.SearchLocations(context.Background(), "coffee", nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(resp.Data) != 2 { + t.Errorf("expected 2 locations, got %d", len(resp.Data)) + } +} + +func TestSearchLocations_EmptyQuery(t *testing.T) { + client := testClient(t, jsonHandler(200, `{}`)) + _, err := client.SearchLocations(context.Background(), "", nil, nil) + if err == nil { + t.Fatal("expected error for empty query") + } +} + +func TestGetLocation_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{ + "id": "loc1", + "name": "Golden Gate Park", + "city": "San Francisco", + "latitude": 37.7694, + "longitude": -122.4862 + }`)) + + loc, err := client.GetLocation(context.Background(), ConvertToLocationID("loc1")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if loc.Name != "Golden Gate Park" { + t.Errorf("expected Golden Gate Park, got %s", loc.Name) + } +} + +func TestGetLocation_InvalidID(t *testing.T) { + client := testClient(t, jsonHandler(200, `{}`)) + _, err := client.GetLocation(context.Background(), LocationID("")) + if err == nil { + t.Fatal("expected error for empty location ID") + } +} diff --git a/pagination_test.go b/pagination_test.go new file mode 100644 index 0000000..95a9d9a --- /dev/null +++ b/pagination_test.go @@ -0,0 +1,86 @@ +package threads + +import ( + "context" + "net/http" + "sync/atomic" + "testing" +) + +func TestPostIterator_MultiplePages(t *testing.T) { + var callCount int32 + handler := func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt32(&callCount, 1) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + switch count { + case 1: + _, _ = w.Write([]byte(`{"data":[{"id":"1"},{"id":"2"}],"paging":{"cursors":{"after":"page2"}}}`)) + case 2: + _, _ = w.Write([]byte(`{"data":[{"id":"3"}],"paging":{}}`)) + default: + _, _ = w.Write([]byte(`{"data":[],"paging":{}}`)) + } + } + + client := testClient(t, http.HandlerFunc(handler)) + iter := NewPostIterator(client, ConvertToUserID("12345"), &PostsOptions{Limit: 2}) + posts, err := iter.Collect(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(posts) != 3 { + t.Errorf("expected 3 posts, got %d", len(posts)) + } +} + +func TestPostIterator_EmptyResult(t *testing.T) { + client := testClient(t, jsonHandler(200, `{"data":[],"paging":{}}`)) + iter := NewPostIterator(client, ConvertToUserID("12345"), nil) + posts, err := iter.Collect(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(posts) != 0 { + t.Errorf("expected 0 posts, got %d", len(posts)) + } +} + +func TestPostIterator_Reset(t *testing.T) { + client := testClient(t, jsonHandler(200, `{"data":[{"id":"1"}],"paging":{}}`)) + iter := NewPostIterator(client, ConvertToUserID("12345"), nil) + + posts1, _ := iter.Collect(context.Background()) + if len(posts1) != 1 { + t.Errorf("expected 1 post, got %d", len(posts1)) + } + if iter.HasNext() { + t.Error("expected iterator to be done") + } + + iter.Reset() + if !iter.HasNext() { + t.Error("expected iterator to have next after reset") + } +} + +func TestSearchIterator_Keyword(t *testing.T) { + client := testClient(t, jsonHandler(200, `{"data":[{"id":"1","text":"match"}],"paging":{}}`)) + iter := NewSearchIterator(client, "test", "keyword", nil) + posts, err := iter.Collect(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(posts) != 1 { + t.Errorf("expected 1 result, got %d", len(posts)) + } +} + +func TestSearchIterator_InvalidType(t *testing.T) { + client := testClient(t, jsonHandler(200, `{}`)) + iter := NewSearchIterator(client, "test", "invalid", nil) + _, err := iter.Next(context.Background()) + if err == nil { + t.Fatal("expected error for invalid search type") + } +} diff --git a/posts_create_test.go b/posts_create_test.go new file mode 100644 index 0000000..d96645c --- /dev/null +++ b/posts_create_test.go @@ -0,0 +1,256 @@ +package threads + +import ( + "context" + "net/http" + "strings" + "sync/atomic" + "testing" +) + +func TestCreateTextPost_Success(t *testing.T) { + var callCount int32 + handler := func(w http.ResponseWriter, r *http.Request) { + count := atomic.AddInt32(&callCount, 1) + w.Header().Set("Content-Type", "application/json") + + switch { + case r.Method == "POST" && strings.HasPrefix(r.URL.Path, "/12345/threads_publish"): + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"id":"post_1"}`)) + case r.Method == "POST" && strings.HasPrefix(r.URL.Path, "/12345/threads"): + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"id":"container_1"}`)) + case r.Method == "GET" && strings.HasPrefix(r.URL.Path, "/container_1"): + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"id":"container_1","status":"FINISHED"}`)) + case r.Method == "GET" && strings.HasPrefix(r.URL.Path, "/post_1"): + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"id":"post_1","text":"Hello","media_type":"TEXT","permalink":"https://threads.net/p/1"}`)) + default: + t.Logf("call %d: unexpected request: %s %s", count, r.Method, r.URL.Path) + http.NotFound(w, r) + } + } + + client := testClient(t, http.HandlerFunc(handler)) + + post, err := client.CreateTextPost(context.Background(), &TextPostContent{ + Text: "Hello", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if post.ID != "post_1" { + t.Errorf("expected post ID post_1, got %s", post.ID) + } +} + +func TestCreateTextPost_EmptyText(t *testing.T) { + client := testClient(t, jsonHandler(200, `{}`)) + + _, err := client.CreateTextPost(context.Background(), &TextPostContent{ + Text: "", + }) + if err == nil { + t.Fatal("expected error for empty text") + } +} + +func TestCreateTextPost_TextTooLong(t *testing.T) { + client := testClient(t, jsonHandler(200, `{}`)) + + longText := make([]byte, MaxTextLength+1) + for i := range longText { + longText[i] = 'a' + } + + _, err := client.CreateTextPost(context.Background(), &TextPostContent{ + Text: string(longText), + }) + if err == nil { + t.Fatal("expected error for text too long") + } + if !IsValidationError(err) { + t.Errorf("expected ValidationError, got %T", err) + } +} + +func TestCreateTextPost_AutoPublish(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + switch { + case r.Method == "POST" && strings.HasPrefix(r.URL.Path, "/12345/threads"): + if err := r.ParseForm(); err != nil { + t.Errorf("failed to parse form: %v", err) + } + if r.PostForm.Get("auto_publish_text") != "true" { + t.Error("expected auto_publish_text=true") + } + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"id":"post_auto"}`)) + case r.Method == "GET" && strings.HasPrefix(r.URL.Path, "/post_auto"): + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"id":"post_auto","text":"Auto","media_type":"TEXT"}`)) + default: + http.NotFound(w, r) + } + } + + client := testClient(t, http.HandlerFunc(handler)) + + post, err := client.CreateTextPost(context.Background(), &TextPostContent{ + Text: "Auto", + AutoPublishText: true, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if post.ID != "post_auto" { + t.Errorf("expected post ID post_auto, got %s", post.ID) + } +} + +func TestCreateImagePost_Success(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch { + case r.Method == "POST" && strings.HasPrefix(r.URL.Path, "/12345/threads_publish"): + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"id":"img_post"}`)) + case r.Method == "POST" && strings.HasPrefix(r.URL.Path, "/12345/threads"): + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"id":"img_container"}`)) + case r.Method == "GET" && strings.HasPrefix(r.URL.Path, "/img_container"): + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"id":"img_container","status":"FINISHED"}`)) + case r.Method == "GET" && strings.HasPrefix(r.URL.Path, "/img_post"): + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"id":"img_post","media_type":"IMAGE","media_url":"https://example.com/img.jpg"}`)) + default: + http.NotFound(w, r) + } + } + + client := testClient(t, http.HandlerFunc(handler)) + + post, err := client.CreateImagePost(context.Background(), &ImagePostContent{ + ImageURL: "https://example.com/img.jpg", + Text: "Check this out", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if post.ID != "img_post" { + t.Errorf("expected img_post, got %s", post.ID) + } +} + +func TestCreateImagePost_MissingURL(t *testing.T) { + client := testClient(t, jsonHandler(200, `{}`)) + + _, err := client.CreateImagePost(context.Background(), &ImagePostContent{ImageURL: ""}) + if err == nil { + t.Fatal("expected error for missing image URL") + } +} + +func TestCreateVideoPost_Success(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch { + case r.Method == "POST" && strings.HasPrefix(r.URL.Path, "/12345/threads_publish"): + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"id":"vid_post"}`)) + case r.Method == "POST" && strings.HasPrefix(r.URL.Path, "/12345/threads"): + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"id":"vid_container"}`)) + case r.Method == "GET" && strings.HasPrefix(r.URL.Path, "/vid_container"): + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"id":"vid_container","status":"FINISHED"}`)) + case r.Method == "GET" && strings.HasPrefix(r.URL.Path, "/vid_post"): + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"id":"vid_post","media_type":"VIDEO"}`)) + default: + http.NotFound(w, r) + } + } + + client := testClient(t, http.HandlerFunc(handler)) + + post, err := client.CreateVideoPost(context.Background(), &VideoPostContent{ + VideoURL: "https://example.com/vid.mp4", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if post.ID != "vid_post" { + t.Errorf("expected vid_post, got %s", post.ID) + } +} + +func TestCreateVideoPost_MissingURL(t *testing.T) { + client := testClient(t, jsonHandler(200, `{}`)) + + _, err := client.CreateVideoPost(context.Background(), &VideoPostContent{VideoURL: ""}) + if err == nil { + t.Fatal("expected error for missing video URL") + } +} + +func TestRepostPost_Success(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch { + case r.Method == "POST" && strings.HasPrefix(r.URL.Path, "/original_post/repost"): + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"id":"repost_1"}`)) + case r.Method == "GET" && strings.HasPrefix(r.URL.Path, "/repost_1"): + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"id":"repost_1","media_type":"TEXT"}`)) + default: + http.NotFound(w, r) + } + } + + client := testClient(t, http.HandlerFunc(handler)) + + post, err := client.RepostPost(context.Background(), ConvertToPostID("original_post")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if post.ID != "repost_1" { + t.Errorf("expected repost_1, got %s", post.ID) + } +} + +func TestRepostPost_InvalidID(t *testing.T) { + client := testClient(t, jsonHandler(200, `{}`)) + + _, err := client.RepostPost(context.Background(), PostID("")) + if err == nil { + t.Fatal("expected error for empty post ID") + } +} + +func TestGetContainerStatus_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{"id":"container_1","status":"FINISHED"}`)) + + status, err := client.GetContainerStatus(context.Background(), ConvertToContainerID("container_1")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if status.Status != "FINISHED" { + t.Errorf("expected FINISHED, got %s", status.Status) + } +} + +func TestGetContainerStatus_InvalidID(t *testing.T) { + client := testClient(t, jsonHandler(200, `{}`)) + + _, err := client.GetContainerStatus(context.Background(), ContainerID("")) + if err == nil { + t.Fatal("expected error for empty container ID") + } +} diff --git a/posts_delete_test.go b/posts_delete_test.go new file mode 100644 index 0000000..fbe4b1f --- /dev/null +++ b/posts_delete_test.go @@ -0,0 +1,50 @@ +package threads + +import ( + "context" + "net/http" + "testing" +) + +func TestDeletePost_Success(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + if r.Method != "DELETE" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"id":"post_1","owner":{"id":"12345"}}`)) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"success":true}`)) + } + + client := testClient(t, http.HandlerFunc(handler)) + + err := client.DeletePost(context.Background(), ConvertToPostID("post_1")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestDeletePost_InvalidID(t *testing.T) { + client := testClient(t, jsonHandler(200, `{}`)) + + err := client.DeletePost(context.Background(), PostID("")) + if err == nil { + t.Fatal("expected error for empty post ID") + } +} + +func TestDeletePost_NotFound(t *testing.T) { + client := testClient(t, jsonHandler(404, `{"error":{"message":"not found","type":"OAuthException","code":100}}`)) + client.config.RetryConfig.MaxRetries = 0 + + err := client.DeletePost(context.Background(), ConvertToPostID("nonexistent")) + if err == nil { + t.Fatal("expected error for 404") + } + if !IsAPIError(err) { + t.Errorf("expected APIError, got %T", err) + } +} diff --git a/posts_read_test.go b/posts_read_test.go new file mode 100644 index 0000000..bcfd980 --- /dev/null +++ b/posts_read_test.go @@ -0,0 +1,177 @@ +package threads + +import ( + "context" + "net/http" + "testing" +) + +func TestGetPost_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{ + "id": "123456", + "text": "Hello world", + "media_type": "TEXT", + "permalink": "https://threads.net/@user/post/123456", + "username": "testuser", + "timestamp": "2026-01-15T10:30:00+0000" + }`)) + + post, err := client.GetPost(context.Background(), ConvertToPostID("123456")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if post.ID != "123456" { + t.Errorf("expected ID 123456, got %s", post.ID) + } + if post.Text != "Hello world" { + t.Errorf("expected text 'Hello world', got %s", post.Text) + } + if post.Username != "testuser" { + t.Errorf("expected username 'testuser', got %s", post.Username) + } +} + +func TestGetPost_InvalidID(t *testing.T) { + client := testClient(t, jsonHandler(200, `{}`)) + + _, err := client.GetPost(context.Background(), PostID("")) + if err == nil { + t.Fatal("expected error for empty post ID") + } + if !IsValidationError(err) { + t.Errorf("expected ValidationError, got %T", err) + } +} + +func TestGetPost_NotFound(t *testing.T) { + client := testClient(t, jsonHandler(404, `{"error":{"message":"Object does not exist","type":"OAuthException","code":100}}`)) + + _, err := client.GetPost(context.Background(), ConvertToPostID("nonexistent")) + if err == nil { + t.Fatal("expected error for 404") + } + if !IsAPIError(err) { + t.Errorf("expected APIError, got %T", err) + } +} + +func TestGetPost_ServerError(t *testing.T) { + client := testClient(t, jsonHandler(500, `{"error":{"message":"Internal error","type":"OAuthException","code":2}}`)) + client.config.RetryConfig.MaxRetries = 0 + + _, err := client.GetPost(context.Background(), ConvertToPostID("123")) + if err == nil { + t.Fatal("expected error for 500") + } + if !IsAPIError(err) { + t.Errorf("expected APIError, got %T", err) + } +} + +func TestGetPost_AuthenticationRequired(t *testing.T) { + client := testClient(t, jsonHandler(200, `{}`)) + _ = client.ClearToken() + + _, err := client.GetPost(context.Background(), ConvertToPostID("123")) + if err == nil { + t.Fatal("expected error when not authenticated") + } + if !IsAuthenticationError(err) { + t.Errorf("expected AuthenticationError, got %T", err) + } +} + +func TestGetUserPosts_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{ + "data": [ + {"id": "1", "text": "Post 1"}, + {"id": "2", "text": "Post 2"} + ], + "paging": {"cursors": {"after": "cursor123"}} + }`)) + + resp, err := client.GetUserPosts(context.Background(), ConvertToUserID("12345"), nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(resp.Data) != 2 { + t.Errorf("expected 2 posts, got %d", len(resp.Data)) + } + if resp.Paging.Cursors == nil || resp.Paging.Cursors.After != "cursor123" { + t.Error("expected paging cursor") + } +} + +func TestGetUserPosts_InvalidUserID(t *testing.T) { + client := testClient(t, jsonHandler(200, `{}`)) + + _, err := client.GetUserPosts(context.Background(), UserID(""), nil) + if err == nil { + t.Fatal("expected error for empty user ID") + } + if !IsValidationError(err) { + t.Errorf("expected ValidationError, got %T", err) + } +} + +func TestGetPublishingLimits_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{ + "data": [{ + "quota_usage": 5, + "config": {"quota_total": 250, "quota_duration": 86400}, + "reply_quota_usage": 10, + "reply_config": {"quota_total": 1000, "quota_duration": 86400} + }] + }`)) + + limits, err := client.GetPublishingLimits(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if limits.QuotaUsage != 5 { + t.Errorf("expected quota_usage 5, got %d", limits.QuotaUsage) + } + if limits.Config.QuotaTotal != 250 { + t.Errorf("expected quota_total 250, got %d", limits.Config.QuotaTotal) + } +} + +func TestGetUserMentions_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{ + "data": [{"id": "1", "text": "@user mentioned you"}], + "paging": {} + }`)) + + resp, err := client.GetUserMentions(context.Background(), ConvertToUserID("12345"), nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(resp.Data) != 1 { + t.Errorf("expected 1 mention, got %d", len(resp.Data)) + } +} + +func TestGetUserGhostPosts_Success(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + fields := r.URL.Query().Get("fields") + if fields != GhostPostFields { + t.Errorf("expected ghost post fields, got %s", fields) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + _, _ = w.Write([]byte(`{ + "data": [{"id": "1", "text": "Ghost!", "ghost_post_status": "active"}], + "paging": {} + }`)) + } + + client := testClient(t, http.HandlerFunc(handler)) + + resp, err := client.GetUserGhostPosts(context.Background(), ConvertToUserID("12345"), nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(resp.Data) != 1 { + t.Errorf("expected 1 ghost post, got %d", len(resp.Data)) + } +} diff --git a/ratelimit.go b/ratelimit.go index 9a7d513..67cc0f7 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -60,13 +60,8 @@ func NewRateLimiter(config *RateLimiterConfig) *RateLimiter { // ShouldWait returns true if we should wait before making a request // Only returns true if we've been explicitly rate limited by the API func (rl *RateLimiter) ShouldWait() bool { - rl.mu.Lock() - defer rl.mu.Unlock() - - // Clear rate limited flag if the window has reset - if time.Now().After(rl.resetTime) { - rl.rateLimited = false - } + rl.mu.RLock() + defer rl.mu.RUnlock() // Only wait if we've been rate limited and the rate limit hasn't reset yet return rl.rateLimited && time.Now().Before(rl.resetTime) @@ -75,7 +70,6 @@ func (rl *RateLimiter) ShouldWait() bool { // Wait blocks until it's safe to make a request, only when actually rate limited func (rl *RateLimiter) Wait(ctx context.Context) error { rl.mu.Lock() - defer rl.mu.Unlock() // Check if rate limit window has reset if time.Now().After(rl.resetTime) { @@ -83,12 +77,14 @@ func (rl *RateLimiter) Wait(ctx context.Context) error { rl.resetTime = time.Now().Add(time.Hour) // Reset to 1 hour from now rl.rateLimited = false // Clear rate limited flag rl.logRateLimitReset() + rl.mu.Unlock() return nil // No need to wait if window has reset } // Only wait if we've been explicitly rate limited if !rl.rateLimited { rl.lastRequestTime = time.Now() + rl.mu.Unlock() return nil } @@ -105,14 +101,31 @@ func (rl *RateLimiter) Wait(ctx context.Context) error { rl.logRateLimitWait(waitTime) - // Wait for either the context to be cancelled or the wait time to elapse - select { - case <-ctx.Done(): - return ctx.Err() - case <-time.After(waitTime): - // After waiting, clear the rate limited flag + // Capture resetTime before releasing lock to detect if MarkRateLimited() + // was called with a later reset time while we were sleeping + originalResetTime := rl.resetTime + + // Release lock before sleeping so other goroutines aren't blocked + rl.mu.Unlock() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(waitTime): + } + + rl.mu.Lock() + if rl.resetTime.After(originalResetTime) { + // Rate limit was extended while we slept; recalculate and loop + originalResetTime = rl.resetTime + waitTime = time.Until(rl.resetTime) + rl.mu.Unlock() + continue + } rl.rateLimited = false rl.lastRequestTime = time.Now() + rl.mu.Unlock() return nil } } @@ -270,6 +283,7 @@ func (rl *RateLimiter) Reset() { rl.remaining = rl.limit rl.resetTime = time.Now().Add(time.Hour) rl.lastRequestTime = time.Time{} + rl.rateLimited = false // Drain the queue for len(rl.requestQueue) > 0 { diff --git a/ratelimit_test.go b/ratelimit_test.go new file mode 100644 index 0000000..5d14fe5 --- /dev/null +++ b/ratelimit_test.go @@ -0,0 +1,90 @@ +package threads + +import ( + "context" + "testing" + "time" +) + +func TestRateLimiter_NotRateLimitedByDefault(t *testing.T) { + rl := NewRateLimiter(&RateLimiterConfig{InitialLimit: 100}) + if rl.ShouldWait() { + t.Error("should not wait by default") + } + if rl.IsRateLimited() { + t.Error("should not be rate limited by default") + } +} + +func TestRateLimiter_MarkRateLimited(t *testing.T) { + rl := NewRateLimiter(&RateLimiterConfig{InitialLimit: 100}) + rl.MarkRateLimited(time.Now().Add(30 * time.Second)) + if !rl.IsRateLimited() { + t.Error("expected to be rate limited after marking") + } + if !rl.ShouldWait() { + t.Error("expected to should wait after being rate limited") + } +} + +func TestRateLimiter_WaitRespectsContext(t *testing.T) { + rl := NewRateLimiter(&RateLimiterConfig{InitialLimit: 100}) + rl.MarkRateLimited(time.Now().Add(10 * time.Second)) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + err := rl.Wait(ctx) + if err == nil { + t.Fatal("expected context timeout error") + } +} + +func TestRateLimiter_UpdateFromHeaders(t *testing.T) { + rl := NewRateLimiter(&RateLimiterConfig{InitialLimit: 100}) + rl.UpdateFromHeaders(&RateLimitInfo{ + Limit: 200, + Remaining: 150, + Reset: time.Now().Add(time.Hour), + }) + + status := rl.GetStatus() + if status.Limit != 200 { + t.Errorf("expected limit 200, got %d", status.Limit) + } + if status.Remaining != 150 { + t.Errorf("expected remaining 150, got %d", status.Remaining) + } +} + +func TestRateLimiter_IsNearLimit(t *testing.T) { + rl := NewRateLimiter(&RateLimiterConfig{InitialLimit: 100}) + rl.UpdateFromHeaders(&RateLimitInfo{Limit: 100, Remaining: 10}) + + if !rl.IsNearLimit(0.8) { + t.Error("expected near limit at 80% threshold") + } + if rl.IsNearLimit(0.95) { + t.Error("expected not near limit at 95% threshold") + } +} + +func TestRateLimiter_Reset(t *testing.T) { + rl := NewRateLimiter(&RateLimiterConfig{InitialLimit: 100}) + rl.MarkRateLimited(time.Now().Add(time.Hour)) + rl.Reset() + if rl.IsRateLimited() { + t.Error("expected not rate limited after reset") + } +} + +func TestRateLimiter_QueueRequest(t *testing.T) { + rl := NewRateLimiter(&RateLimiterConfig{InitialLimit: 100, QueueSize: 2}) + _ = rl.QueueRequest(context.Background()) + _ = rl.QueueRequest(context.Background()) + + err := rl.QueueRequest(context.Background()) + if err == nil { + t.Fatal("expected error when queue is full") + } +} diff --git a/replies_test.go b/replies_test.go new file mode 100644 index 0000000..abaf573 --- /dev/null +++ b/replies_test.go @@ -0,0 +1,102 @@ +package threads + +import ( + "context" + "testing" +) + +func TestGetReplies_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{ + "data": [ + {"id": "reply_1", "text": "Great post!", "is_reply": true}, + {"id": "reply_2", "text": "Thanks!", "is_reply": true} + ], + "paging": {"cursors": {"after": "next_cursor"}} + }`)) + + resp, err := client.GetReplies(context.Background(), ConvertToPostID("post_1"), nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(resp.Data) != 2 { + t.Errorf("expected 2 replies, got %d", len(resp.Data)) + } +} + +func TestGetReplies_InvalidPostID(t *testing.T) { + client := testClient(t, jsonHandler(200, `{}`)) + _, err := client.GetReplies(context.Background(), PostID(""), nil) + if err == nil { + t.Fatal("expected error for empty post ID") + } +} + +func TestGetConversation_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{ + "data": [{"id": "msg_1", "text": "Thread message"}], + "paging": {} + }`)) + + resp, err := client.GetConversation(context.Background(), ConvertToPostID("post_1"), nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(resp.Data) != 1 { + t.Errorf("expected 1 message, got %d", len(resp.Data)) + } +} + +func TestHideReply_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{"success":true}`)) + err := client.HideReply(context.Background(), ConvertToPostID("reply_1")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestHideReply_InvalidID(t *testing.T) { + client := testClient(t, jsonHandler(200, `{}`)) + err := client.HideReply(context.Background(), PostID("")) + if err == nil { + t.Fatal("expected error for empty reply ID") + } +} + +func TestUnhideReply_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{"success":true}`)) + err := client.UnhideReply(context.Background(), ConvertToPostID("reply_1")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestGetPendingReplies_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{ + "data": [{"id": "pending_1", "text": "Awaiting approval", "reply_approval_status": "pending"}], + "paging": {} + }`)) + + resp, err := client.GetPendingReplies(context.Background(), ConvertToPostID("post_1"), nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(resp.Data) != 1 { + t.Errorf("expected 1 pending reply, got %d", len(resp.Data)) + } +} + +func TestApprovePendingReply_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{"success":true}`)) + err := client.ApprovePendingReply(context.Background(), ConvertToPostID("pending_1")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestIgnorePendingReply_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{"success":true}`)) + err := client.IgnorePendingReply(context.Background(), ConvertToPostID("pending_1")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/search_test.go b/search_test.go new file mode 100644 index 0000000..06ef9b3 --- /dev/null +++ b/search_test.go @@ -0,0 +1,53 @@ +package threads + +import ( + "context" + "net/http" + "testing" +) + +func TestKeywordSearch_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{ + "data": [{"id": "1", "text": "Go programming"}, {"id": "2", "text": "Golang tips"}], + "paging": {} + }`)) + + resp, err := client.KeywordSearch(context.Background(), "golang", nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(resp.Data) != 2 { + t.Errorf("expected 2 results, got %d", len(resp.Data)) + } +} + +func TestKeywordSearch_EmptyQuery(t *testing.T) { + client := testClient(t, jsonHandler(200, `{}`)) + _, err := client.KeywordSearch(context.Background(), "", nil) + if err == nil { + t.Fatal("expected error for empty query") + } +} + +func TestKeywordSearch_WithOptions(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + if q.Get("q") != "test" { + t.Errorf("expected q=test, got q=%s", q.Get("q")) + } + if q.Get("search_type") != "TOP" { + t.Errorf("expected search_type=TOP, got %s", q.Get("search_type")) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"data":[],"paging":{}}`)) + } + + client := testClient(t, http.HandlerFunc(handler)) + _, err := client.KeywordSearch(context.Background(), "test", &SearchOptions{ + SearchType: SearchTypeTop, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/test_helpers_test.go b/test_helpers_test.go new file mode 100644 index 0000000..bc8b9b4 --- /dev/null +++ b/test_helpers_test.go @@ -0,0 +1,74 @@ +package threads + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" +) + +// testClient creates a *Client whose HTTP requests go to the given handler. +func testClient(t *testing.T, handler http.Handler) *Client { + t.Helper() + server := httptest.NewServer(handler) + t.Cleanup(server.Close) + + config := &Config{ + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + RedirectURI: "https://example.com/callback", + } + config.SetDefaults() + config.BaseURL = server.URL + + client, err := NewClient(config) + if err != nil { + t.Fatalf("testClient: %v", err) + } + + // Set a valid token so methods that require auth work + err = client.SetTokenInfo(&TokenInfo{ + AccessToken: "test-access-token", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(24 * time.Hour), + UserID: "12345", + CreatedAt: time.Now(), + }) + if err != nil { + t.Fatalf("testClient SetTokenInfo: %v", err) + } + + return client +} + +// jsonHandler returns an http.HandlerFunc that responds with the given status and JSON body. +func jsonHandler(status int, body string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _, _ = w.Write([]byte(body)) + } +} + +// newTestHTTPClient creates an HTTPClient pointed at a test server with the given handler and retry config. +func newTestHTTPClient(t *testing.T, handler http.Handler, retryConfig *RetryConfig) *HTTPClient { + t.Helper() + server := httptest.NewServer(handler) + t.Cleanup(server.Close) + + config := &Config{ + HTTPTimeout: 5 * time.Second, + Logger: &noopLogger{}, + RetryConfig: retryConfig, + BaseURL: server.URL, + } + return NewHTTPClient(config, nil) +} + +// noopLogger is a no-op Logger implementation for tests. +type noopLogger struct{} + +func (n *noopLogger) Debug(msg string, fields ...any) {} +func (n *noopLogger) Info(msg string, fields ...any) {} +func (n *noopLogger) Warn(msg string, fields ...any) {} +func (n *noopLogger) Error(msg string, fields ...any) {} diff --git a/users_test.go b/users_test.go new file mode 100644 index 0000000..f0349ac --- /dev/null +++ b/users_test.go @@ -0,0 +1,93 @@ +package threads + +import ( + "context" + "net/http" + "testing" +) + +func TestGetUser_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{ + "id": "12345", + "username": "testuser", + "name": "Test User", + "followers_count": 100 + }`)) + + user, err := client.GetUser(context.Background(), ConvertToUserID("12345")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if user.Username != "testuser" { + t.Errorf("expected testuser, got %s", user.Username) + } +} + +func TestGetUser_InvalidID(t *testing.T) { + client := testClient(t, jsonHandler(200, `{}`)) + _, err := client.GetUser(context.Background(), UserID("")) + if err == nil { + t.Fatal("expected error for empty user ID") + } +} + +func TestGetMe_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{"id":"12345","username":"me","name":"My Name"}`)) + + user, err := client.GetMe(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if user.Username != "me" { + t.Errorf("expected 'me', got %s", user.Username) + } +} + +func TestLookupPublicProfile_Success(t *testing.T) { + client := testClient(t, jsonHandler(200, `{ + "username": "publicuser", + "name": "Public User", + "is_verified": true, + "follower_count": 5000 + }`)) + + user, err := client.LookupPublicProfile(context.Background(), "publicuser") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if user.Username != "publicuser" { + t.Errorf("expected publicuser, got %s", user.Username) + } + if !user.IsVerified { + t.Error("expected verified user") + } +} + +func TestLookupPublicProfile_EmptyUsername(t *testing.T) { + client := testClient(t, jsonHandler(200, `{}`)) + _, err := client.LookupPublicProfile(context.Background(), "") + if err == nil { + t.Fatal("expected error for empty username") + } +} + +func TestGetUserFields_Success(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + fields := r.URL.Query().Get("fields") + if fields == "" { + t.Error("expected fields parameter") + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + _, _ = w.Write([]byte(`{"id":"12345","username":"testuser"}`)) + } + + client := testClient(t, http.HandlerFunc(handler)) + user, err := client.GetUserFields(context.Background(), ConvertToUserID("12345"), []string{"id", "username"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if user.ID != "12345" { + t.Errorf("expected 12345, got %s", user.ID) + } +}