diff --git a/.gitignore b/.gitignore index ff82b33a9..39cafdafd 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,4 @@ tilt_config.json # Test reports coverage.html coverage.out +costmodel diff --git a/pkg/cloud/azure/pricesheetclient.go b/pkg/cloud/azure/pricesheetclient.go index 2f7569b13..633b43cc7 100644 --- a/pkg/cloud/azure/pricesheetclient.go +++ b/pkg/cloud/azure/pricesheetclient.go @@ -2,8 +2,10 @@ package azure import ( "context" + "encoding/json" "errors" "fmt" + "io" "net/http" "net/url" "time" @@ -122,3 +124,62 @@ func (client *PriceSheetClient) downloadByBillingPeriodCreateRequest(ctx context req.Raw().Header["Accept"] = []string{"*/*"} return req, nil } + +// billingPeriodsListTemplate is the URL template for listing billing periods. +const billingPeriodsListTemplate = "/providers/Microsoft.Billing/billingAccounts/%s/billingPeriods" + +// billingPeriodsListResponse represents the response from the billing periods list API. +type billingPeriodsListResponse struct { + Value []billingPeriodEntry `json:"value"` +} + +// billingPeriodEntry represents a single billing period entry. +type billingPeriodEntry struct { + Name string `json:"name"` +} + +// GetCurrentBillingPeriod fetches the most recent billing period name from the +// Azure Billing Periods List API. This handles both EA accounts (which use +// "yyyyMM" format) and MCA accounts (which use "yyyyMM-1" format). +// See: https://learn.microsoft.com/en-us/rest/api/billing/billing-periods/list +func (client *PriceSheetClient) GetCurrentBillingPeriod(ctx context.Context) (string, error) { + if client.billingAccountID == "" { + return "", errors.New("parameter client.billingAccountID cannot be empty") + } + urlPath := fmt.Sprintf(billingPeriodsListTemplate, url.PathEscape(client.billingAccountID)) + req, err := runtime.NewRequest(ctx, http.MethodGet, runtime.JoinPaths(client.host, urlPath)) + if err != nil { + return "", fmt.Errorf("creating billing periods list request: %w", err) + } + reqQP := req.Raw().URL.Query() + reqQP.Set("api-version", "2020-05-01") + // The API returns billing periods in descending order by default, so $top=1 gives the most recent period. + reqQP.Set("$top", "1") + req.Raw().URL.RawQuery = reqQP.Encode() + req.Raw().Header["Accept"] = []string{"application/json"} + + resp, err := client.pl.Do(req) + if err != nil { + return "", fmt.Errorf("executing billing periods list request: %w", err) + } + defer resp.Body.Close() + if !runtime.HasStatusCode(resp, http.StatusOK) { + return "", runtime.NewResponseError(resp) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("reading billing periods response body: %w", err) + } + + var result billingPeriodsListResponse + if err := json.Unmarshal(body, &result); err != nil { + return "", fmt.Errorf("parsing billing periods response: %w", err) + } + + if len(result.Value) == 0 { + return "", errors.New("no billing periods returned from API") + } + + return result.Value[0].Name, nil +} diff --git a/pkg/cloud/azure/pricesheetclient_test.go b/pkg/cloud/azure/pricesheetclient_test.go index 49ebb8d21..99b42726b 100644 --- a/pkg/cloud/azure/pricesheetclient_test.go +++ b/pkg/cloud/azure/pricesheetclient_test.go @@ -161,6 +161,16 @@ func TestPriceSheetClient_URLConstruction(t *testing.T) { } } +func TestPriceSheetClient_GetCurrentBillingPeriod_Validation(t *testing.T) { + cred := &mockCredential{} + client, err := NewPriceSheetClient("", cred, nil) + require.NoError(t, err) + + _, err = client.GetCurrentBillingPeriod(context.Background()) + assert.Error(t, err) + assert.Contains(t, err.Error(), "parameter client.billingAccountID cannot be empty") +} + // TestPriceSheetClient_MethodRegression ensures the HTTP method fix doesn't regress // This test would fail if someone accidentally changed POST back to GET func TestPriceSheetClient_MethodRegression(t *testing.T) { diff --git a/pkg/cloud/azure/pricesheetdownloader.go b/pkg/cloud/azure/pricesheetdownloader.go index 622095496..c30fa459a 100644 --- a/pkg/cloud/azure/pricesheetdownloader.go +++ b/pkg/cloud/azure/pricesheetdownloader.go @@ -59,7 +59,18 @@ func (d *PriceSheetDownloader) getDownloadURL(ctx context.Context) (string, erro if err != nil { return "", fmt.Errorf("creating pricesheet client: %w", err) } - poller, err := client.BeginDownloadByBillingPeriod(ctx, currentBillingPeriod()) + + // Dynamically fetch the current billing period name from the API. + // This handles both EA accounts (yyyyMM) and MCA accounts (yyyyMM-1). + billingPeriod, err := client.GetCurrentBillingPeriod(ctx) + if err != nil { + // Fall back to the hardcoded format for backwards compatibility. + log.Warnf("failed to fetch billing period from API, falling back to default format: %s", err) + billingPeriod = currentBillingPeriod() + } + log.Infof("using billing period %q", billingPeriod) + + poller, err := client.BeginDownloadByBillingPeriod(ctx, billingPeriod) if err != nil { return "", fmt.Errorf("beginning pricesheet download: %w", err) }