Skip to content

Commit

Permalink
fix: require at least one auth method (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
everpcpc authored Dec 7, 2022
1 parent a95f19e commit 70fa12d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 11 deletions.
29 changes: 20 additions & 9 deletions restful.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (c *APIClient) makeURL(path string, args ...interface{}) string {
return fmt.Sprintf(format, args...)
}

func (c *APIClient) makeHeaders() http.Header {
func (c *APIClient) makeHeaders() (http.Header, error) {
headers := http.Header{}
if c.Tenant != "" {
headers.Set(DatabendTenantHeader, c.Tenant)
Expand All @@ -106,9 +106,11 @@ func (c *APIClient) makeHeaders() http.Header {
headers.Set(Authorization, fmt.Sprintf("Basic %s", encode(c.User, c.Password)))
} else if c.AccessToken != "" {
headers.Set(Authorization, fmt.Sprintf("Bearer %s", c.AccessToken))
} else {
return nil, errors.New("no user or access token")
}

return headers
return headers, nil
}

func encode(name string, key string) string {
Expand All @@ -127,7 +129,10 @@ var databendInsecureTransport = &http.Transport{
}

func (c *APIClient) DoQuery(ctx context.Context, query string, args []driver.Value) (*QueryResponse, error) {
headers := c.makeHeaders()
headers, err := c.makeHeaders()
if err != nil {
return nil, err
}
q, err := buildQuery(query, args)
if err != nil {
return nil, err
Expand Down Expand Up @@ -206,10 +211,13 @@ func (c *APIClient) QuerySync(ctx context.Context, query string, args []driver.V
}

func (c *APIClient) QueryPage(queryId, path string) (*QueryResponse, error) {
headers := c.makeHeaders()
headers, err := c.makeHeaders()
if err != nil {
return nil, err
}
headers.Set("queryID", queryId)
var result QueryResponse
err := retry.Do(
err = retry.Do(
func() error {
err := c.DoRequest("GET", path, headers, nil, &result)
if err != nil {
Expand Down Expand Up @@ -311,10 +319,13 @@ func (c *APIClient) uploadToStageByAPI(stage, fileName string) error {
url := c.makeURL(path)
httpReq, err := http.NewRequest("PUT", url, body)
if err != nil {
return err
return errors.Wrap(err, "failed to create http request")
}

httpReq.Header = c.makeHeaders()
httpReq.Header, err = c.makeHeaders()
if err != nil {
return errors.Wrap(err, "failed to make headers")
}
if len(c.Host) > 0 {
httpReq.Host = c.Host
}
Expand All @@ -326,13 +337,13 @@ func (c *APIClient) uploadToStageByAPI(stage, fileName string) error {
}
httpResp, err := httpClient.Do(httpReq)
if err != nil {
return fmt.Errorf("failed http do request: %w", err)
return errors.Wrap(err, "failed http do request")
}
defer httpResp.Body.Close()

httpRespBody, err := io.ReadAll(httpResp.Body)
if err != nil {
return fmt.Errorf("io read error: %w", err)
return errors.Wrap(err, "failed to read http response body")
}

if httpResp.StatusCode == http.StatusUnauthorized {
Expand Down
6 changes: 4 additions & 2 deletions restful_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ func TestMakeHeadersUserPassword(t *testing.T) {
Host: "localhost:8000",
Tenant: "default",
}
headers := c.makeHeaders()
headers, err := c.makeHeaders()
assert.Nil(t, err)
assert.Equal(t, headers["Authorization"], []string{"Basic cm9vdDpyb290"})
assert.Equal(t, headers["X-Databend-Tenant"], []string{"default"})
}
Expand All @@ -25,7 +26,8 @@ func TestMakeHeadersAccessToken(t *testing.T) {
AccessToken: "abc123",
Warehouse: "small-abc",
}
headers := c.makeHeaders()
headers, err := c.makeHeaders()
assert.Nil(t, err)
assert.Equal(t, headers["Authorization"], []string{"Bearer abc123"})
assert.Equal(t, headers["X-Databend-Tenant"], []string{"tn3ftqihs"})
assert.Equal(t, headers["X-Databend-Warehouse"], []string{"small-abc"})
Expand Down

0 comments on commit 70fa12d

Please sign in to comment.