Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore/improve client #1344

Merged
merged 4 commits into from
May 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion aws/resource_service_principal_role_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package aws

import (
"github.com/databrickslabs/terraform-provider-databricks/common"
"testing"

"github.com/databrickslabs/terraform-provider-databricks/common"

"github.com/databrickslabs/terraform-provider-databricks/scim"

"github.com/databrickslabs/terraform-provider-databricks/qa"
Expand Down
14 changes: 7 additions & 7 deletions catalog/resource_external_location.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ func NewExternalLocationsAPI(ctx context.Context, m interface{}) ExternalLocatio
}

type ExternalLocationInfo struct {
Name string `json:"name" tf:"force_new"`
URL string `json:"url"`
CredentialName string `json:"credential_name"`
Comment string `json:"comment,omitempty"`
SkipValidation bool `json:"skip_validation,omitempty"`
Owner string `json:"owner,omitempty" tf:"computed"`
MetastoreID string `json:"metastore_id,omitempty" tf:"computed"`
Name string `json:"name" tf:"force_new"`
URL string `json:"url"`
CredentialName string `json:"credential_name"`
Comment string `json:"comment,omitempty"`
SkipValidation bool `json:"skip_validation,omitempty"`
Owner string `json:"owner,omitempty" tf:"computed"`
MetastoreID string `json:"metastore_id,omitempty" tf:"computed"`
}

func (a ExternalLocationsAPI) create(el *ExternalLocationInfo) error {
Expand Down
3 changes: 1 addition & 2 deletions commands/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/databrickslabs/terraform-provider-databricks/clusters"
"github.com/databrickslabs/terraform-provider-databricks/common"
"github.com/databrickslabs/terraform-provider-databricks/internal"

"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
)
Expand Down Expand Up @@ -52,7 +51,7 @@ func (a CommandsAPI) Execute(clusterID, language, commandStr string) common.Comm
Summary: fmt.Sprintf("Cluster %s has to be running or resizing, but is %s", clusterID, cluster.State),
}
}
commandStr = internal.TrimLeadingWhitespace(commandStr)
commandStr = TrimLeadingWhitespace(commandStr)
log.Printf("[INFO] Executing %s command on %s:\n%s", language, clusterID, commandStr)
context, err := a.createContext(language, clusterID)
if err != nil {
Expand Down
5 changes: 3 additions & 2 deletions internal/utils.go → commands/leading_whitespace.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package internal
package commands

import (
"strings"
)

// TrimLeadingWhitespace removes leading whitespace
// TrimLeadingWhitespace removes leading whitespace, so that Python code blocks
// that are embedded into Go code still could be interpreted properly.
func TrimLeadingWhitespace(commandStr string) (newCommand string) {
lines := strings.Split(strings.ReplaceAll(commandStr, "\t", " "), "\n")
leadingWhitespace := 1<<31 - 1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package internal
package commands

import (
"testing"
Expand Down
2 changes: 2 additions & 0 deletions common/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,8 @@ func (c *DatabricksClient) niceAuthError(message string) error {
}
info = ". " + strings.Join(infos, ". ")
}
info = strings.TrimSuffix(info, ".")
message = strings.TrimSuffix(message, ".")
docUrl := "https://registry.terraform.io/providers/databrickslabs/databricks/latest/docs#authentication"
return fmt.Errorf("%s%s. Please check %s for details", message, info, docUrl)
}
Expand Down
142 changes: 82 additions & 60 deletions common/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"io"
"log"
"net/http"
"net/url"
Expand Down Expand Up @@ -179,7 +179,7 @@ func (c *DatabricksClient) commonErrorClarity(resp *http.Response) *APIError {
}

func (c *DatabricksClient) parseError(resp *http.Response) APIError {
body, err := ioutil.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return APIError{
Message: err.Error(),
Expand Down Expand Up @@ -345,16 +345,21 @@ func (c *DatabricksClient) completeUrl(r *http.Request) error {
return nil
}

// scimPathVisitorFactory is a separate method for the sake of unit tests
func (c *DatabricksClient) scimVisitor(r *http.Request) error {
r.Header.Set("Content-Type", "application/scim+json; charset=utf-8")
if c.isAccountsClient() && c.AccountID != "" {
// until `/preview` is there for workspace scim,
// `/api/2.0` is added by completeUrl visitor
r.URL.Path = strings.ReplaceAll(r.URL.Path, "/api/2.0/preview",
fmt.Sprintf("/api/2.0/accounts/%s", c.AccountID))
}
return nil
}

// Scim sets SCIM headers
func (c *DatabricksClient) Scim(ctx context.Context, method, path string, request interface{}, response interface{}) error {
body, err := c.authenticatedQuery(ctx, method, path, request, c.completeUrl, func(r *http.Request) error {
r.Header.Set("Content-Type", "application/scim+json; charset=utf-8")
if c.isAccountsClient() && c.AccountID != "" {
// until `/preview` is there for workspace scim
r.URL.Path = strings.ReplaceAll(path, "/preview", fmt.Sprintf("/api/2.0/accounts/%s", c.AccountID))
}
return nil
})
body, err := c.authenticatedQuery(ctx, method, path, request, c.completeUrl, c.scimVisitor)
if err != nil {
return err
}
Expand Down Expand Up @@ -402,7 +407,9 @@ func (c *DatabricksClient) redactedDump(body []byte) (res string) {
if len(body) == 0 {
return
}

if body[0] != '{' {
return fmt.Sprintf("[non-JSON document of %d bytes]", len(body))
}
var requestMap map[string]interface{}
err := json.Unmarshal(body, &requestMap)
if err != nil {
Expand Down Expand Up @@ -465,21 +472,21 @@ func (c *DatabricksClient) genericQuery(ctx context.Context, method, requestURL
return nil, fmt.Errorf("DatabricksClient is not configured")
}
if err = c.rateLimiter.Wait(ctx); err != nil {
return nil, err
return nil, fmt.Errorf("rate limited: %w", err)
}
requestBody, err := makeRequestBody(method, &requestURL, data, true)
requestBody, err := makeRequestBody(method, &requestURL, data)
if err != nil {
return nil, err
return nil, fmt.Errorf("request marshal: %w", err)
}
request, err := http.NewRequestWithContext(ctx, method, requestURL, bytes.NewBuffer(requestBody))
if err != nil {
return nil, err
return nil, fmt.Errorf("new request: %w", err)
}
request.Header.Set("User-Agent", c.userAgent(ctx))
for _, requestVisitor := range visitors {
err = requestVisitor(request)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed visitor: %w", err)
}
}
headers := c.createDebugHeaders(request.Header, c.Host)
Expand All @@ -488,78 +495,93 @@ func (c *DatabricksClient) genericQuery(ctx context.Context, method, requestURL

r, err := retryablehttp.FromRequest(request)
if err != nil {
return nil, err
return nil, err // no error invariants possible because of `makeRequestBody`
}
resp, err := c.httpClient.Do(r)
// retryablehttp library now returns only wrapped errors
var ae APIError
if errors.As(err, &ae) {
// don't re-wrap, as upper layers may depend on handling common.APIError
return nil, ae
}
if err != nil {
return nil, err
// i don't even know which errors in the real world would end up here.
// `retryablehttp` package nicely wraps _everything_ to `url.Error`.
return nil, fmt.Errorf("failed request: %w", err)
}
defer func() {
if ferr := resp.Body.Close(); ferr != nil {
err = ferr
err = fmt.Errorf("failed to close: %w", ferr)
}
}()
body, err = ioutil.ReadAll(resp.Body)
body, err = io.ReadAll(resp.Body)
if err != nil {
return nil, err
return nil, fmt.Errorf("response body: %w", err)
}
headers = c.createDebugHeaders(resp.Header, "")
log.Printf("[DEBUG] %s %s %s <- %s %s", resp.Status, headers, c.redactedDump(body), method, strings.ReplaceAll(request.URL.Path, "\n", ""))
return body, nil
}

func makeRequestBody(method string, requestURL *string, data interface{}, marshalJSON bool) ([]byte, error) {
func makeQueryString(data interface{}) (string, error) {
inputVal := reflect.ValueOf(data)
inputType := reflect.TypeOf(data)
if inputType.Kind() == reflect.Map {
s := []string{}
keys := inputVal.MapKeys()
// sort map keys by their string repr, so that tests can be deterministic
sort.Slice(keys, func(i, j int) bool {
return keys[i].String() < keys[j].String()
})
for _, k := range keys {
v := inputVal.MapIndex(k)
if v.IsZero() {
continue
}
s = append(s, fmt.Sprintf("%s=%s",
strings.Replace(url.QueryEscape(fmt.Sprintf("%v", k.Interface())), "+", "%20", -1),
strings.Replace(url.QueryEscape(fmt.Sprintf("%v", v.Interface())), "+", "%20", -1)))
}
return "?" + strings.Join(s, "&"), nil
}
if inputType.Kind() == reflect.Struct {
params, err := query.Values(data)
if err != nil {
return "", fmt.Errorf("cannot create query string: %w", err)
}
return "?" + params.Encode(), nil
}
return "", fmt.Errorf("unsupported query string data: %#v", data)
}

func makeRequestBody(method string, requestURL *string, data interface{}) ([]byte, error) {
var requestBody []byte
if data == nil && (method == "DELETE" || method == "GET") {
return requestBody, nil
}
if method == "GET" {
inputVal := reflect.ValueOf(data)
inputType := reflect.TypeOf(data)
switch inputType.Kind() {
case reflect.Map:
s := []string{}
keys := inputVal.MapKeys()
// sort map keys by their string repr, so that tests can be deterministic
sort.Slice(keys, func(i, j int) bool {
return keys[i].String() < keys[j].String()
})
for _, k := range keys {
v := inputVal.MapIndex(k)
if v.IsZero() {
continue
}
s = append(s, fmt.Sprintf("%s=%s",
strings.Replace(url.QueryEscape(fmt.Sprintf("%v", k.Interface())), "+", "%20", -1),
strings.Replace(url.QueryEscape(fmt.Sprintf("%v", v.Interface())), "+", "%20", -1)))
}
*requestURL += "?" + strings.Join(s, "&")
case reflect.Struct:
params, err := query.Values(data)
if err != nil {
return nil, err
}
*requestURL += "?" + params.Encode()
default:
return requestBody, fmt.Errorf("unsupported request data: %#v", data)
qs, err := makeQueryString(data)
if err != nil {
return nil, err
}
} else {
if marshalJSON {
bodyBytes, err := json.MarshalIndent(data, "", " ")
if err != nil {
return nil, err
}
requestBody = bodyBytes
} else {
requestBody = []byte(data.(string))
*requestURL += qs
return requestBody, nil
}
if reader, ok := data.(io.Reader); ok {
raw, err := io.ReadAll(reader)
if err != nil {
return nil, fmt.Errorf("failed to read from reader: %w", err)
}
return raw, nil
}
if str, ok := data.(string); ok {
return []byte(str), nil
}
bodyBytes, err := json.MarshalIndent(data, "", " ")
if err != nil {
return nil, fmt.Errorf("request marshal failure: %w", err)
}
return requestBody, nil
return bodyBytes, nil
}

func onlyNBytes(j string, numBytes int) string {
Expand Down
Loading