Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 22 additions & 45 deletions internal/provider/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,27 +149,27 @@ func (c *Client) doRequest(ctx context.Context, method, path string, body interf

// Read the response body once and store it in a variable
defer resp.Body.Close()
resp_body, _ := io.ReadAll(resp.Body)
respBody, _ := io.ReadAll(resp.Body)

// Print the response body as JSON if available
if len(resp_body) > 0 {
if len(respBody) > 0 {
var prettyBody map[string]interface{}
if err := json.Unmarshal(resp_body, &prettyBody); err == nil {
if err := json.Unmarshal(respBody, &prettyBody); err == nil {
prettyJSON, _ := json.MarshalIndent(prettyBody, "", " ")
tflog.Debug(ctx, fmt.Sprintf("[ZENML] Response body (JSON):\n%s", prettyJSON))
} else {
tflog.Debug(ctx, fmt.Sprintf("[ZENML] Response body:\n%s", string(resp_body)))
tflog.Debug(ctx, fmt.Sprintf("[ZENML] Response body:\n%s", string(respBody)))
}
}

tflog.Info(ctx, fmt.Sprintf("[ZENML] Response status: %d", resp.StatusCode))

if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, resp.StatusCode, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(resp_body))
return nil, resp.StatusCode, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody))
}

// Re-wrap the body so that the caller can still read it
resp.Body = io.NopCloser(bytes.NewReader(resp_body))
resp.Body = io.NopCloser(bytes.NewReader(respBody))

return resp, resp.StatusCode, nil
}
Expand Down Expand Up @@ -251,19 +251,7 @@ func (c *Client) DeleteStack(ctx context.Context, id string) error {
}

func (c *Client) ListStacks(ctx context.Context, params *ListParams) (*Page[StackResponse], error) {
if params == nil {
params = &ListParams{
Page: 1,
PageSize: 100,
}
} else {
if params.Page <= 0 {
params.Page = 1
}
if params.PageSize <= 0 {
params.PageSize = 100
}
}
params = initializeListParams(params)

query := url.Values{}
query.Add("page", fmt.Sprintf("%d", params.Page))
Expand Down Expand Up @@ -350,19 +338,7 @@ func (c *Client) DeleteComponent(ctx context.Context, id string) error {
}

func (c *Client) ListStackComponents(ctx context.Context, workspace string, params *ListParams) (*Page[ComponentResponse], error) {
if params == nil {
params = &ListParams{
Page: 1,
PageSize: 100,
}
} else {
if params.Page <= 0 {
params.Page = 1
}
if params.PageSize <= 0 {
params.PageSize = 100
}
}
params = initializeListParams(params)

query := url.Values{}
query.Add("page", fmt.Sprintf("%d", params.Page))
Expand Down Expand Up @@ -463,19 +439,7 @@ func (c *Client) DeleteServiceConnector(ctx context.Context, id string) error {
}

func (c *Client) ListServiceConnectors(ctx context.Context, params *ListParams) (*Page[ServiceConnectorResponse], error) {
if params == nil {
params = &ListParams{
Page: 1,
PageSize: 100,
}
} else {
if params.Page <= 0 {
params.Page = 1
}
if params.PageSize <= 0 {
params.PageSize = 100
}
}
params = initializeListParams(params)

query := url.Values{}
query.Add("page", fmt.Sprintf("%d", params.Page))
Expand All @@ -499,6 +463,19 @@ func (c *Client) ListServiceConnectors(ctx context.Context, params *ListParams)
return &result, nil
}

func initializeListParams(params *ListParams) *ListParams {
if params == nil {
params = &ListParams{}
}
if params.Page <= 0 {
params.Page = 1
}
if params.PageSize <= 0 {
params.PageSize = 100
}
return params
}

// Add this new method to the Client
func (c *Client) GetServiceConnectorByName(ctx context.Context, workspace, name string) (*ServiceConnectorResponse, error) {
params := &ListParams{
Expand Down
9 changes: 7 additions & 2 deletions internal/provider/data_source_server.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// Package provider contains the implementation of the ZenML Terraform provider.
package provider

import (
Expand All @@ -8,6 +9,7 @@ import (
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
)

// dataSourceServer returns a Terraform resource schema for the ZenML server information.
func dataSourceServer() *schema.Resource {
return &schema.Resource{
Description: "Data source for global ZenML server information",
Expand Down Expand Up @@ -55,9 +57,12 @@ func dataSourceServer() *schema.Resource {
}
}

// dataSourceServerRead reads the server information from the ZenML server and sets the corresponding fields in the Terraform state.
func dataSourceServerRead(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics {
c := m.(*Client)

c, ok := m.(*Client)
if !ok {
return diag.FromErr(fmt.Errorf("unexpected type for client: %T", m))
}
server, err := c.GetServerInfo(ctx)
if err != nil {
return diag.FromErr(fmt.Errorf("error fetching server info: %v", err))
Expand Down
20 changes: 11 additions & 9 deletions internal/provider/data_source_service_connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ func dataSourceServiceConnectorRead(ctx context.Context, d *schema.ResourceData,
name := d.Get("name").(string)
id := d.Get("id").(string)

var err error = nil
var connector *ServiceConnectorResponse = nil
var err error
var connector *ServiceConnectorResponse

if id != "" {
connector, err = c.GetServiceConnector(ctx, id)
Expand All @@ -120,20 +120,20 @@ func dataSourceServiceConnectorRead(ctx context.Context, d *schema.ResourceData,

if connector.Body != nil {

connector_type := ""
connectorType := ""

// Unmarshal the connector type, which can be either a string or a struct
// Try to unmarshal as string
err = json.Unmarshal(connector.Body.ConnectorType, &connector_type)
err = json.Unmarshal(connector.Body.ConnectorType, &connectorType)
if err != nil {
var type_struct ServiceConnectorType
var typeStruct ServiceConnectorType
// Try to unmarshal as struct
if err = json.Unmarshal(connector.Body.ConnectorType, &type_struct); err == nil {
connector_type = type_struct.ConnectorType
if err = json.Unmarshal(connector.Body.ConnectorType, &typeStruct); err == nil {
connectorType = typeStruct.ConnectorType
}
}

if err := d.Set("type", connector_type); err != nil {
if err := d.Set("type", connectorType); err != nil {
return diag.FromErr(err)
}

Expand All @@ -145,7 +145,9 @@ func dataSourceServiceConnectorRead(ctx context.Context, d *schema.ResourceData,
if len(connector.Body.ResourceTypes) == 1 {
d.Set("resource_type", connector.Body.ResourceTypes[0])
} else {
d.Set("resource_type", "")
if err := d.Set("resource_type", ""); err != nil {
return diag.FromErr(err)
}
}

if connector.Body.ResourceID != nil {
Expand Down
10 changes: 5 additions & 5 deletions internal/provider/data_source_stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,16 @@ func dataSourceStackRead(ctx context.Context, d *schema.ResourceData, m interfac
workspace := d.Get("workspace").(string)
name := d.Get("name").(string)

var stack *StackResponse = nil
var err error = nil
var stack *StackResponse
var err error

if id != "" {
// Get stack by ID
stack, err = c.GetStack(ctx, id)
if err != nil {
return diag.FromErr(fmt.Errorf("error getting stack: %v", err))
}
} else if name == "" {
} else if name != "" {
// List stacks with filter to find by name
params := &ListParams{
Filter: map[string]string{
Expand Down Expand Up @@ -142,7 +142,7 @@ func dataSourceStackRead(ctx context.Context, d *schema.ResourceData, m interfac

// Extract keys
keys := make([]string, 0, len(stack.Metadata.Components))
for key, _ := range stack.Metadata.Components {
for key := range stack.Metadata.Components {
keys = append(keys, key)
}

Expand All @@ -154,13 +154,13 @@ func dataSourceStackRead(ctx context.Context, d *schema.ResourceData, m interfac
componentList := stack.Metadata.Components[key]
var componentData map[string]string
for _, component := range componentList {
// Only take the first component of each type
componentData = map[string]string{
"id": component.ID,
"name": component.Name,
"type": component.Body.Type,
"flavor": component.Body.Flavor,
}
// Only take the first component of each type
break
}
components = append(components, componentData)
Expand Down
3 changes: 2 additions & 1 deletion internal/provider/provider.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// provider.go
// This file defines the ZenML Terraform provider, including its schema, resources, and data sources.
package provider

import (
Expand Down Expand Up @@ -51,7 +52,7 @@ func providerConfigure(ctx context.Context, d *schema.ResourceData) (interface{}
apiKey := d.Get("api_key").(string)
apiToken := d.Get("api_token").(string)

// Should be handled by the schema
// The schema should ensure that the server_url is not empty
if serverURL == "" {
return nil, diag.Errorf("server_url must be configured")
}
Expand Down
5 changes: 4 additions & 1 deletion internal/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@ func testAccPreCheck(t *testing.T) {
t.Fatal("ZENML_SERVER_URL must be set for acceptance tests")
}
creds := []string{"ZENML_API_KEY", "ZENML_API_TOKEN"}
v:=""
v := ""
for _, cred := range creds {
v = os.Getenv(cred)
if v != "" {
break
}
}
if v == "" {
t.Fatal(
Expand Down
11 changes: 5 additions & 6 deletions internal/provider/resource_stack.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// resource_stack.go
// This file contains the implementation of the Terraform resource for managing ZenML stacks.
package provider

import (
Expand Down Expand Up @@ -35,9 +36,10 @@ func resourceStack() *schema.Resource {
Type: schema.TypeString,
},
Description: "Map of component types to component IDs",
// We cannot delete components while they are still in use
// by a stack, so we need to force new stacks when components
// are changed.
// Components cannot be deleted while they are still in use by a stack
// because the stack relies on these components to function correctly.
// Therefore, any change to the components requires creating a new stack
// to ensure that the existing stack remains consistent and operational.
ForceNew: true,
},
"labels": {
Expand Down Expand Up @@ -80,7 +82,6 @@ func resourceStack() *schema.Resource {
func resourceStackCreate(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics {
client := m.(*Client)

// Get the workspace from schema instead of hardcoding
workspace := d.Get("workspace").(string)

stack := StackRequest{
Expand All @@ -96,7 +97,6 @@ func resourceStackCreate(ctx context.Context, d *schema.ResourceData, m interfac
}
stack.Components = components
}

// Handle labels
if v, ok := d.GetOk("labels"); ok {
labels := make(map[string]string)
Expand All @@ -105,7 +105,6 @@ func resourceStackCreate(ctx context.Context, d *schema.ResourceData, m interfac
}
stack.Labels = labels
}

resp, err := client.CreateStack(ctx, workspace, stack)
if err != nil {
return diag.FromErr(fmt.Errorf("error creating stack: %w", err))
Expand Down
17 changes: 12 additions & 5 deletions internal/provider/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,22 @@ package provider

import (
"fmt"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"strings"

"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
)

// Constants and variables used for validating connector types, resource types,
// authentication methods, and component types in the ZenML Terraform provider.
// All validation constants and variables
var (
validConnectorTypes = []string{
"aws", "gcp", "azure", "kubernetes",
"docker", "hyperai",
"aws",
"gcp",
"azure",
"kubernetes",
"docker",
"hyperai",
}

validResourceTypes = map[string][]string{
Expand Down Expand Up @@ -83,7 +90,7 @@ var (

validComponentTypes = []string{
"alerter",
"annotator",
"annotator",
"artifact_store",
"container_registry",
"data_validator",
Expand All @@ -99,7 +106,7 @@ var (

func validateServiceConnector(d *schema.ResourceDiff) error {
connectorType := d.Get("type").(string)

// Validate connector type first
validType := false
for _, t := range validConnectorTypes {
Expand Down
Binary file added terraform-provider-zenml
Binary file not shown.