diff --git a/internal/provider/client.go b/internal/provider/client.go index 14184bb..f148276 100644 --- a/internal/provider/client.go +++ b/internal/provider/client.go @@ -149,27 +149,27 @@ func (c *Client) doRequest(ctx context.Context, method, path string, body interf // Read the response body once and store it in a variable defer resp.Body.Close() - resp_body, _ := io.ReadAll(resp.Body) + respBody, _ := io.ReadAll(resp.Body) // Print the response body as JSON if available - if len(resp_body) > 0 { + if len(respBody) > 0 { var prettyBody map[string]interface{} - if err := json.Unmarshal(resp_body, &prettyBody); err == nil { + if err := json.Unmarshal(respBody, &prettyBody); err == nil { prettyJSON, _ := json.MarshalIndent(prettyBody, "", " ") tflog.Debug(ctx, fmt.Sprintf("[ZENML] Response body (JSON):\n%s", prettyJSON)) } else { - tflog.Debug(ctx, fmt.Sprintf("[ZENML] Response body:\n%s", string(resp_body))) + tflog.Debug(ctx, fmt.Sprintf("[ZENML] Response body:\n%s", string(respBody))) } } tflog.Info(ctx, fmt.Sprintf("[ZENML] Response status: %d", resp.StatusCode)) if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil, resp.StatusCode, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(resp_body)) + return nil, resp.StatusCode, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody)) } // Re-wrap the body so that the caller can still read it - resp.Body = io.NopCloser(bytes.NewReader(resp_body)) + resp.Body = io.NopCloser(bytes.NewReader(respBody)) return resp, resp.StatusCode, nil } @@ -251,19 +251,7 @@ func (c *Client) DeleteStack(ctx context.Context, id string) error { } func (c *Client) ListStacks(ctx context.Context, params *ListParams) (*Page[StackResponse], error) { - if params == nil { - params = &ListParams{ - Page: 1, - PageSize: 100, - } - } else { - if params.Page <= 0 { - params.Page = 1 - } - if params.PageSize <= 0 { - params.PageSize = 100 - } - } + params = initializeListParams(params) query := url.Values{} query.Add("page", fmt.Sprintf("%d", params.Page)) @@ -350,19 +338,7 @@ func (c *Client) DeleteComponent(ctx context.Context, id string) error { } func (c *Client) ListStackComponents(ctx context.Context, workspace string, params *ListParams) (*Page[ComponentResponse], error) { - if params == nil { - params = &ListParams{ - Page: 1, - PageSize: 100, - } - } else { - if params.Page <= 0 { - params.Page = 1 - } - if params.PageSize <= 0 { - params.PageSize = 100 - } - } + params = initializeListParams(params) query := url.Values{} query.Add("page", fmt.Sprintf("%d", params.Page)) @@ -463,19 +439,7 @@ func (c *Client) DeleteServiceConnector(ctx context.Context, id string) error { } func (c *Client) ListServiceConnectors(ctx context.Context, params *ListParams) (*Page[ServiceConnectorResponse], error) { - if params == nil { - params = &ListParams{ - Page: 1, - PageSize: 100, - } - } else { - if params.Page <= 0 { - params.Page = 1 - } - if params.PageSize <= 0 { - params.PageSize = 100 - } - } + params = initializeListParams(params) query := url.Values{} query.Add("page", fmt.Sprintf("%d", params.Page)) @@ -499,6 +463,19 @@ func (c *Client) ListServiceConnectors(ctx context.Context, params *ListParams) return &result, nil } +func initializeListParams(params *ListParams) *ListParams { + if params == nil { + params = &ListParams{} + } + if params.Page <= 0 { + params.Page = 1 + } + if params.PageSize <= 0 { + params.PageSize = 100 + } + return params +} + // Add this new method to the Client func (c *Client) GetServiceConnectorByName(ctx context.Context, workspace, name string) (*ServiceConnectorResponse, error) { params := &ListParams{ diff --git a/internal/provider/data_source_server.go b/internal/provider/data_source_server.go index a4eeee4..e609945 100644 --- a/internal/provider/data_source_server.go +++ b/internal/provider/data_source_server.go @@ -1,3 +1,4 @@ +// Package provider contains the implementation of the ZenML Terraform provider. package provider import ( @@ -8,6 +9,7 @@ import ( "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) +// dataSourceServer returns a Terraform resource schema for the ZenML server information. func dataSourceServer() *schema.Resource { return &schema.Resource{ Description: "Data source for global ZenML server information", @@ -55,9 +57,12 @@ func dataSourceServer() *schema.Resource { } } +// dataSourceServerRead reads the server information from the ZenML server and sets the corresponding fields in the Terraform state. func dataSourceServerRead(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics { - c := m.(*Client) - + c, ok := m.(*Client) + if !ok { + return diag.FromErr(fmt.Errorf("unexpected type for client: %T", m)) + } server, err := c.GetServerInfo(ctx) if err != nil { return diag.FromErr(fmt.Errorf("error fetching server info: %v", err)) diff --git a/internal/provider/data_source_service_connector.go b/internal/provider/data_source_service_connector.go index 61513dc..b3f8532 100644 --- a/internal/provider/data_source_service_connector.go +++ b/internal/provider/data_source_service_connector.go @@ -93,8 +93,8 @@ func dataSourceServiceConnectorRead(ctx context.Context, d *schema.ResourceData, name := d.Get("name").(string) id := d.Get("id").(string) - var err error = nil - var connector *ServiceConnectorResponse = nil + var err error + var connector *ServiceConnectorResponse if id != "" { connector, err = c.GetServiceConnector(ctx, id) @@ -120,20 +120,20 @@ func dataSourceServiceConnectorRead(ctx context.Context, d *schema.ResourceData, if connector.Body != nil { - connector_type := "" + connectorType := "" // Unmarshal the connector type, which can be either a string or a struct // Try to unmarshal as string - err = json.Unmarshal(connector.Body.ConnectorType, &connector_type) + err = json.Unmarshal(connector.Body.ConnectorType, &connectorType) if err != nil { - var type_struct ServiceConnectorType + var typeStruct ServiceConnectorType // Try to unmarshal as struct - if err = json.Unmarshal(connector.Body.ConnectorType, &type_struct); err == nil { - connector_type = type_struct.ConnectorType + if err = json.Unmarshal(connector.Body.ConnectorType, &typeStruct); err == nil { + connectorType = typeStruct.ConnectorType } } - if err := d.Set("type", connector_type); err != nil { + if err := d.Set("type", connectorType); err != nil { return diag.FromErr(err) } @@ -145,7 +145,9 @@ func dataSourceServiceConnectorRead(ctx context.Context, d *schema.ResourceData, if len(connector.Body.ResourceTypes) == 1 { d.Set("resource_type", connector.Body.ResourceTypes[0]) } else { - d.Set("resource_type", "") + if err := d.Set("resource_type", ""); err != nil { + return diag.FromErr(err) + } } if connector.Body.ResourceID != nil { diff --git a/internal/provider/data_source_stack.go b/internal/provider/data_source_stack.go index f814255..2e35c0f 100644 --- a/internal/provider/data_source_stack.go +++ b/internal/provider/data_source_stack.go @@ -84,8 +84,8 @@ func dataSourceStackRead(ctx context.Context, d *schema.ResourceData, m interfac workspace := d.Get("workspace").(string) name := d.Get("name").(string) - var stack *StackResponse = nil - var err error = nil + var stack *StackResponse + var err error if id != "" { // Get stack by ID @@ -93,7 +93,7 @@ func dataSourceStackRead(ctx context.Context, d *schema.ResourceData, m interfac if err != nil { return diag.FromErr(fmt.Errorf("error getting stack: %v", err)) } - } else if name == "" { + } else if name != "" { // List stacks with filter to find by name params := &ListParams{ Filter: map[string]string{ @@ -142,7 +142,7 @@ func dataSourceStackRead(ctx context.Context, d *schema.ResourceData, m interfac // Extract keys keys := make([]string, 0, len(stack.Metadata.Components)) - for key, _ := range stack.Metadata.Components { + for key := range stack.Metadata.Components { keys = append(keys, key) } @@ -154,13 +154,13 @@ func dataSourceStackRead(ctx context.Context, d *schema.ResourceData, m interfac componentList := stack.Metadata.Components[key] var componentData map[string]string for _, component := range componentList { + // Only take the first component of each type componentData = map[string]string{ "id": component.ID, "name": component.Name, "type": component.Body.Type, "flavor": component.Body.Flavor, } - // Only take the first component of each type break } components = append(components, componentData) diff --git a/internal/provider/provider.go b/internal/provider/provider.go index e7e1ff7..464821d 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -1,4 +1,5 @@ // provider.go +// This file defines the ZenML Terraform provider, including its schema, resources, and data sources. package provider import ( @@ -51,7 +52,7 @@ func providerConfigure(ctx context.Context, d *schema.ResourceData) (interface{} apiKey := d.Get("api_key").(string) apiToken := d.Get("api_token").(string) - // Should be handled by the schema + // The schema should ensure that the server_url is not empty if serverURL == "" { return nil, diag.Errorf("server_url must be configured") } diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go index 428d77d..20b4d7a 100644 --- a/internal/provider/provider_test.go +++ b/internal/provider/provider_test.go @@ -28,9 +28,12 @@ func testAccPreCheck(t *testing.T) { t.Fatal("ZENML_SERVER_URL must be set for acceptance tests") } creds := []string{"ZENML_API_KEY", "ZENML_API_TOKEN"} - v:="" + v := "" for _, cred := range creds { v = os.Getenv(cred) + if v != "" { + break + } } if v == "" { t.Fatal( diff --git a/internal/provider/resource_stack.go b/internal/provider/resource_stack.go index b7b4a55..65c871f 100644 --- a/internal/provider/resource_stack.go +++ b/internal/provider/resource_stack.go @@ -1,4 +1,5 @@ // resource_stack.go +// This file contains the implementation of the Terraform resource for managing ZenML stacks. package provider import ( @@ -35,9 +36,10 @@ func resourceStack() *schema.Resource { Type: schema.TypeString, }, Description: "Map of component types to component IDs", - // We cannot delete components while they are still in use - // by a stack, so we need to force new stacks when components - // are changed. + // Components cannot be deleted while they are still in use by a stack + // because the stack relies on these components to function correctly. + // Therefore, any change to the components requires creating a new stack + // to ensure that the existing stack remains consistent and operational. ForceNew: true, }, "labels": { @@ -80,7 +82,6 @@ func resourceStack() *schema.Resource { func resourceStackCreate(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics { client := m.(*Client) - // Get the workspace from schema instead of hardcoding workspace := d.Get("workspace").(string) stack := StackRequest{ @@ -96,7 +97,6 @@ func resourceStackCreate(ctx context.Context, d *schema.ResourceData, m interfac } stack.Components = components } - // Handle labels if v, ok := d.GetOk("labels"); ok { labels := make(map[string]string) @@ -105,7 +105,6 @@ func resourceStackCreate(ctx context.Context, d *schema.ResourceData, m interfac } stack.Labels = labels } - resp, err := client.CreateStack(ctx, workspace, stack) if err != nil { return diag.FromErr(fmt.Errorf("error creating stack: %w", err)) diff --git a/internal/provider/validation.go b/internal/provider/validation.go index e582d4d..7782f6d 100644 --- a/internal/provider/validation.go +++ b/internal/provider/validation.go @@ -3,15 +3,22 @@ package provider import ( "fmt" - "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "strings" + + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) +// Constants and variables used for validating connector types, resource types, +// authentication methods, and component types in the ZenML Terraform provider. // All validation constants and variables var ( validConnectorTypes = []string{ - "aws", "gcp", "azure", "kubernetes", - "docker", "hyperai", + "aws", + "gcp", + "azure", + "kubernetes", + "docker", + "hyperai", } validResourceTypes = map[string][]string{ @@ -83,7 +90,7 @@ var ( validComponentTypes = []string{ "alerter", - "annotator", + "annotator", "artifact_store", "container_registry", "data_validator", @@ -99,7 +106,7 @@ var ( func validateServiceConnector(d *schema.ResourceDiff) error { connectorType := d.Get("type").(string) - + // Validate connector type first validType := false for _, t := range validConnectorTypes { diff --git a/terraform-provider-zenml b/terraform-provider-zenml new file mode 100755 index 0000000..c980703 Binary files /dev/null and b/terraform-provider-zenml differ