Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ tilt_config.json
# Test reports
coverage.html
coverage.out
costmodel
61 changes: 61 additions & 0 deletions pkg/cloud/azure/pricesheetclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package azure

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"time"
Expand Down Expand Up @@ -122,3 +124,62 @@ func (client *PriceSheetClient) downloadByBillingPeriodCreateRequest(ctx context
req.Raw().Header["Accept"] = []string{"*/*"}
return req, nil
}

// billingPeriodsListTemplate is the URL template for listing billing periods.
const billingPeriodsListTemplate = "/providers/Microsoft.Billing/billingAccounts/%s/billingPeriods"

// billingPeriodsListResponse represents the response from the billing periods list API.
type billingPeriodsListResponse struct {
Value []billingPeriodEntry `json:"value"`
}

// billingPeriodEntry represents a single billing period entry.
type billingPeriodEntry struct {
Name string `json:"name"`
}

// GetCurrentBillingPeriod fetches the most recent billing period name from the
// Azure Billing Periods List API. This handles both EA accounts (which use
// "yyyyMM" format) and MCA accounts (which use "yyyyMM-1" format).
// See: https://learn.microsoft.com/en-us/rest/api/billing/billing-periods/list
func (client *PriceSheetClient) GetCurrentBillingPeriod(ctx context.Context) (string, error) {
if client.billingAccountID == "" {
return "", errors.New("parameter client.billingAccountID cannot be empty")
}
urlPath := fmt.Sprintf(billingPeriodsListTemplate, url.PathEscape(client.billingAccountID))
req, err := runtime.NewRequest(ctx, http.MethodGet, runtime.JoinPaths(client.host, urlPath))
if err != nil {
return "", fmt.Errorf("creating billing periods list request: %w", err)
}
reqQP := req.Raw().URL.Query()
reqQP.Set("api-version", "2020-05-01")
// The API returns billing periods in descending order by default, so $top=1 gives the most recent period.
reqQP.Set("$top", "1")
req.Raw().URL.RawQuery = reqQP.Encode()
req.Raw().Header["Accept"] = []string{"application/json"}

resp, err := client.pl.Do(req)
if err != nil {
return "", fmt.Errorf("executing billing periods list request: %w", err)
}
defer resp.Body.Close()
if !runtime.HasStatusCode(resp, http.StatusOK) {
return "", runtime.NewResponseError(resp)
}

body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("reading billing periods response body: %w", err)
}

var result billingPeriodsListResponse
if err := json.Unmarshal(body, &result); err != nil {
return "", fmt.Errorf("parsing billing periods response: %w", err)
}

if len(result.Value) == 0 {
return "", errors.New("no billing periods returned from API")
}

return result.Value[0].Name, nil
}
10 changes: 10 additions & 0 deletions pkg/cloud/azure/pricesheetclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,16 @@ func TestPriceSheetClient_URLConstruction(t *testing.T) {
}
}

func TestPriceSheetClient_GetCurrentBillingPeriod_Validation(t *testing.T) {
cred := &mockCredential{}
client, err := NewPriceSheetClient("", cred, nil)
require.NoError(t, err)

_, err = client.GetCurrentBillingPeriod(context.Background())
assert.Error(t, err)
assert.Contains(t, err.Error(), "parameter client.billingAccountID cannot be empty")
}

// TestPriceSheetClient_MethodRegression ensures the HTTP method fix doesn't regress
// This test would fail if someone accidentally changed POST back to GET
func TestPriceSheetClient_MethodRegression(t *testing.T) {
Expand Down
13 changes: 12 additions & 1 deletion pkg/cloud/azure/pricesheetdownloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,18 @@ func (d *PriceSheetDownloader) getDownloadURL(ctx context.Context) (string, erro
if err != nil {
return "", fmt.Errorf("creating pricesheet client: %w", err)
}
poller, err := client.BeginDownloadByBillingPeriod(ctx, currentBillingPeriod())

// Dynamically fetch the current billing period name from the API.
// This handles both EA accounts (yyyyMM) and MCA accounts (yyyyMM-1).
billingPeriod, err := client.GetCurrentBillingPeriod(ctx)
if err != nil {
// Fall back to the hardcoded format for backwards compatibility.
log.Warnf("failed to fetch billing period from API, falling back to default format: %s", err)
billingPeriod = currentBillingPeriod()
}
log.Infof("using billing period %q", billingPeriod)

poller, err := client.BeginDownloadByBillingPeriod(ctx, billingPeriod)
if err != nil {
return "", fmt.Errorf("beginning pricesheet download: %w", err)
}
Expand Down