Skip to content

Commit

Permalink
Improve common package test coverage (databricks#1344)
Browse files Browse the repository at this point in the history
* Make client error messages friendlier
* Allow passing `io.Reader` as request body
* Increase test coverage for `common/http.go`
* Moved `TrimLeadingWhitespace` from `internal` to `commands`
  • Loading branch information
nfx committed May 30, 2022
1 parent f215220 commit 84146de
Show file tree
Hide file tree
Showing 29 changed files with 468 additions and 229 deletions.
3 changes: 2 additions & 1 deletion aws/resource_service_principal_role_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package aws

import (
"github.com/databrickslabs/terraform-provider-databricks/common"
"testing"

"github.com/databrickslabs/terraform-provider-databricks/common"

"github.com/databrickslabs/terraform-provider-databricks/scim"

"github.com/databrickslabs/terraform-provider-databricks/qa"
Expand Down
14 changes: 7 additions & 7 deletions catalog/resource_external_location.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ func NewExternalLocationsAPI(ctx context.Context, m interface{}) ExternalLocatio
}

type ExternalLocationInfo struct {
Name string `json:"name" tf:"force_new"`
URL string `json:"url"`
CredentialName string `json:"credential_name"`
Comment string `json:"comment,omitempty"`
SkipValidation bool `json:"skip_validation,omitempty"`
Owner string `json:"owner,omitempty" tf:"computed"`
MetastoreID string `json:"metastore_id,omitempty" tf:"computed"`
Name string `json:"name" tf:"force_new"`
URL string `json:"url"`
CredentialName string `json:"credential_name"`
Comment string `json:"comment,omitempty"`
SkipValidation bool `json:"skip_validation,omitempty"`
Owner string `json:"owner,omitempty" tf:"computed"`
MetastoreID string `json:"metastore_id,omitempty" tf:"computed"`
}

func (a ExternalLocationsAPI) create(el *ExternalLocationInfo) error {
Expand Down
3 changes: 1 addition & 2 deletions commands/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/databrickslabs/terraform-provider-databricks/clusters"
"github.com/databrickslabs/terraform-provider-databricks/common"
"github.com/databrickslabs/terraform-provider-databricks/internal"

"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
)
Expand Down Expand Up @@ -52,7 +51,7 @@ func (a CommandsAPI) Execute(clusterID, language, commandStr string) common.Comm
Summary: fmt.Sprintf("Cluster %s has to be running or resizing, but is %s", clusterID, cluster.State),
}
}
commandStr = internal.TrimLeadingWhitespace(commandStr)
commandStr = TrimLeadingWhitespace(commandStr)
log.Printf("[INFO] Executing %s command on %s:\n%s", language, clusterID, commandStr)
context, err := a.createContext(language, clusterID)
if err != nil {
Expand Down
5 changes: 3 additions & 2 deletions internal/utils.go → commands/leading_whitespace.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package internal
package commands

import (
"strings"
)

// TrimLeadingWhitespace removes leading whitespace
// TrimLeadingWhitespace removes leading whitespace, so that Python code blocks
// that are embedded into Go code still could be interpreted properly.
func TrimLeadingWhitespace(commandStr string) (newCommand string) {
lines := strings.Split(strings.ReplaceAll(commandStr, "\t", " "), "\n")
leadingWhitespace := 1<<31 - 1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package internal
package commands

import (
"testing"
Expand Down
2 changes: 2 additions & 0 deletions common/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,8 @@ func (c *DatabricksClient) niceAuthError(message string) error {
}
info = ". " + strings.Join(infos, ". ")
}
info = strings.TrimSuffix(info, ".")
message = strings.TrimSuffix(message, ".")
docUrl := "https://registry.terraform.io/providers/databrickslabs/databricks/latest/docs#authentication"
return fmt.Errorf("%s%s. Please check %s for details", message, info, docUrl)
}
Expand Down
142 changes: 82 additions & 60 deletions common/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"io"
"log"
"net/http"
"net/url"
Expand Down Expand Up @@ -179,7 +179,7 @@ func (c *DatabricksClient) commonErrorClarity(resp *http.Response) *APIError {
}

func (c *DatabricksClient) parseError(resp *http.Response) APIError {
body, err := ioutil.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return APIError{
Message: err.Error(),
Expand Down Expand Up @@ -345,16 +345,21 @@ func (c *DatabricksClient) completeUrl(r *http.Request) error {
return nil
}

// scimPathVisitorFactory is a separate method for the sake of unit tests
func (c *DatabricksClient) scimVisitor(r *http.Request) error {
r.Header.Set("Content-Type", "application/scim+json; charset=utf-8")
if c.isAccountsClient() && c.AccountID != "" {
// until `/preview` is there for workspace scim,
// `/api/2.0` is added by completeUrl visitor
r.URL.Path = strings.ReplaceAll(r.URL.Path, "/api/2.0/preview",
fmt.Sprintf("/api/2.0/accounts/%s", c.AccountID))
}
return nil
}

// Scim sets SCIM headers
func (c *DatabricksClient) Scim(ctx context.Context, method, path string, request interface{}, response interface{}) error {
body, err := c.authenticatedQuery(ctx, method, path, request, c.completeUrl, func(r *http.Request) error {
r.Header.Set("Content-Type", "application/scim+json; charset=utf-8")
if c.isAccountsClient() && c.AccountID != "" {
// until `/preview` is there for workspace scim
r.URL.Path = strings.ReplaceAll(path, "/preview", fmt.Sprintf("/api/2.0/accounts/%s", c.AccountID))
}
return nil
})
body, err := c.authenticatedQuery(ctx, method, path, request, c.completeUrl, c.scimVisitor)
if err != nil {
return err
}
Expand Down Expand Up @@ -402,7 +407,9 @@ func (c *DatabricksClient) redactedDump(body []byte) (res string) {
if len(body) == 0 {
return
}

if body[0] != '{' {
return fmt.Sprintf("[non-JSON document of %d bytes]", len(body))
}
var requestMap map[string]interface{}
err := json.Unmarshal(body, &requestMap)
if err != nil {
Expand Down Expand Up @@ -465,21 +472,21 @@ func (c *DatabricksClient) genericQuery(ctx context.Context, method, requestURL
return nil, fmt.Errorf("DatabricksClient is not configured")
}
if err = c.rateLimiter.Wait(ctx); err != nil {
return nil, err
return nil, fmt.Errorf("rate limited: %w", err)
}
requestBody, err := makeRequestBody(method, &requestURL, data, true)
requestBody, err := makeRequestBody(method, &requestURL, data)
if err != nil {
return nil, err
return nil, fmt.Errorf("request marshal: %w", err)
}
request, err := http.NewRequestWithContext(ctx, method, requestURL, bytes.NewBuffer(requestBody))
if err != nil {
return nil, err
return nil, fmt.Errorf("new request: %w", err)
}
request.Header.Set("User-Agent", c.userAgent(ctx))
for _, requestVisitor := range visitors {
err = requestVisitor(request)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed visitor: %w", err)
}
}
headers := c.createDebugHeaders(request.Header, c.Host)
Expand All @@ -488,78 +495,93 @@ func (c *DatabricksClient) genericQuery(ctx context.Context, method, requestURL

r, err := retryablehttp.FromRequest(request)
if err != nil {
return nil, err
return nil, err // no error invariants possible because of `makeRequestBody`
}
resp, err := c.httpClient.Do(r)
// retryablehttp library now returns only wrapped errors
var ae APIError
if errors.As(err, &ae) {
// don't re-wrap, as upper layers may depend on handling common.APIError
return nil, ae
}
if err != nil {
return nil, err
// i don't even know which errors in the real world would end up here.
// `retryablehttp` package nicely wraps _everything_ to `url.Error`.
return nil, fmt.Errorf("failed request: %w", err)
}
defer func() {
if ferr := resp.Body.Close(); ferr != nil {
err = ferr
err = fmt.Errorf("failed to close: %w", ferr)
}
}()
body, err = ioutil.ReadAll(resp.Body)
body, err = io.ReadAll(resp.Body)
if err != nil {
return nil, err
return nil, fmt.Errorf("response body: %w", err)
}
headers = c.createDebugHeaders(resp.Header, "")
log.Printf("[DEBUG] %s %s %s <- %s %s", resp.Status, headers, c.redactedDump(body), method, strings.ReplaceAll(request.URL.Path, "\n", ""))
return body, nil
}

func makeRequestBody(method string, requestURL *string, data interface{}, marshalJSON bool) ([]byte, error) {
func makeQueryString(data interface{}) (string, error) {
inputVal := reflect.ValueOf(data)
inputType := reflect.TypeOf(data)
if inputType.Kind() == reflect.Map {
s := []string{}
keys := inputVal.MapKeys()
// sort map keys by their string repr, so that tests can be deterministic
sort.Slice(keys, func(i, j int) bool {
return keys[i].String() < keys[j].String()
})
for _, k := range keys {
v := inputVal.MapIndex(k)
if v.IsZero() {
continue
}
s = append(s, fmt.Sprintf("%s=%s",
strings.Replace(url.QueryEscape(fmt.Sprintf("%v", k.Interface())), "+", "%20", -1),
strings.Replace(url.QueryEscape(fmt.Sprintf("%v", v.Interface())), "+", "%20", -1)))
}
return "?" + strings.Join(s, "&"), nil
}
if inputType.Kind() == reflect.Struct {
params, err := query.Values(data)
if err != nil {
return "", fmt.Errorf("cannot create query string: %w", err)
}
return "?" + params.Encode(), nil
}
return "", fmt.Errorf("unsupported query string data: %#v", data)
}

func makeRequestBody(method string, requestURL *string, data interface{}) ([]byte, error) {
var requestBody []byte
if data == nil && (method == "DELETE" || method == "GET") {
return requestBody, nil
}
if method == "GET" {
inputVal := reflect.ValueOf(data)
inputType := reflect.TypeOf(data)
switch inputType.Kind() {
case reflect.Map:
s := []string{}
keys := inputVal.MapKeys()
// sort map keys by their string repr, so that tests can be deterministic
sort.Slice(keys, func(i, j int) bool {
return keys[i].String() < keys[j].String()
})
for _, k := range keys {
v := inputVal.MapIndex(k)
if v.IsZero() {
continue
}
s = append(s, fmt.Sprintf("%s=%s",
strings.Replace(url.QueryEscape(fmt.Sprintf("%v", k.Interface())), "+", "%20", -1),
strings.Replace(url.QueryEscape(fmt.Sprintf("%v", v.Interface())), "+", "%20", -1)))
}
*requestURL += "?" + strings.Join(s, "&")
case reflect.Struct:
params, err := query.Values(data)
if err != nil {
return nil, err
}
*requestURL += "?" + params.Encode()
default:
return requestBody, fmt.Errorf("unsupported request data: %#v", data)
qs, err := makeQueryString(data)
if err != nil {
return nil, err
}
} else {
if marshalJSON {
bodyBytes, err := json.MarshalIndent(data, "", " ")
if err != nil {
return nil, err
}
requestBody = bodyBytes
} else {
requestBody = []byte(data.(string))
*requestURL += qs
return requestBody, nil
}
if reader, ok := data.(io.Reader); ok {
raw, err := io.ReadAll(reader)
if err != nil {
return nil, fmt.Errorf("failed to read from reader: %w", err)
}
return raw, nil
}
if str, ok := data.(string); ok {
return []byte(str), nil
}
bodyBytes, err := json.MarshalIndent(data, "", " ")
if err != nil {
return nil, fmt.Errorf("request marshal failure: %w", err)
}
return requestBody, nil
return bodyBytes, nil
}

func onlyNBytes(j string, numBytes int) string {
Expand Down
Loading

0 comments on commit 84146de

Please sign in to comment.