Skip to content

Commit 0c53b19

Browse files
authored
Merge pull request #5 from codingpot/refactor-1
Implement PaperGet() method
2 parents baeaf63 + d687e3b commit 0c53b19

File tree

11 files changed

+182
-89
lines changed

11 files changed

+182
-89
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414
# Dependency directories (remove the comment below to include it)
1515
# vendor/
1616
/.idea/
17+
/.env

client.go

Lines changed: 2 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import (
55
"encoding/json"
66
"errors"
77
"fmt"
8-
"github.com/codingpot/paperswithcode-go/models"
8+
"github.com/codingpot/paperswithcode-go/internal/transport"
99
"net/http"
1010
"net/url"
1111
"time"
@@ -22,6 +22,7 @@ type ClientOption func(*Client)
2222
func WithAPIToken(apiToken string) ClientOption {
2323
return func(client *Client) {
2424
client.apiToken = apiToken
25+
client.HTTPClient.Transport = transport.NewTransportWithAuthHeader(apiToken)
2526
}
2627
}
2728

@@ -52,37 +53,6 @@ type errorResponse struct {
5253
Message string `json:"message"`
5354
}
5455

55-
// Paper represents a specific Paper's Information by the title
56-
type Paper struct {
57-
ID string `json:"id"`
58-
ArxivID string `json:"arxiv_id,omitempty"`
59-
URLAbs string `json:"url_abs"`
60-
URLPDF string `json:"url_pdf"`
61-
Title string `json:"title"`
62-
Abstract string `json:"abstract"`
63-
Authors []string `json:"authors"`
64-
Published string `json:"published"`
65-
}
66-
67-
func (c *Client) GetPaper(ctx context.Context, id string) (*Paper, error) {
68-
fmt.Println(id)
69-
url := fmt.Sprintf("%s/papers/%s/", c.BaseURL, url.QueryEscape(id))
70-
req, err := http.NewRequest("GET", url, nil)
71-
72-
if err != nil {
73-
return nil, err
74-
}
75-
76-
req = req.WithContext(ctx)
77-
78-
res := Paper{}
79-
if err := c.sendRequest(req, &res); err != nil {
80-
return nil, err
81-
}
82-
83-
return &res, nil
84-
}
85-
8656
// Method list used by Paper's ID
8757
type MethodList struct {
8858
Count int `json:"count"`
@@ -101,7 +71,6 @@ func (c *Client) GetMethodList(ctx context.Context, id string) (*MethodList, err
10171
fmt.Println(id)
10272
url := fmt.Sprintf("%s/papers/%s/methods", c.BaseURL, url.QueryEscape(id))
10373
req, err := http.NewRequest("GET", url, nil)
104-
10574
if err != nil {
10675
return nil, err
10776
}
@@ -177,48 +146,3 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
177146

178147
return nil
179148
}
180-
181-
type PaperListParams struct {
182-
Query string
183-
Page int
184-
Limit int
185-
}
186-
187-
func (p PaperListParams) Build() string {
188-
if p.Query == "" {
189-
return fmt.Sprintf("items_per_page=%d&page=%d", p.Limit, p.Page)
190-
}
191-
return fmt.Sprintf("q=%s&items_per_page=%d&page=%d", url.QueryEscape(p.Query), p.Limit, p.Page)
192-
}
193-
194-
func (c *Client) PaperList(params PaperListParams) (*models.PaperListResult, error) {
195-
papersListURL := c.BaseURL + "/papers?" + params.Build()
196-
197-
request, err := http.NewRequest(http.MethodGet, papersListURL, nil /*body*/)
198-
if err != nil {
199-
return nil, err
200-
}
201-
request.Header.Set("Authorization", "Token "+c.apiToken)
202-
203-
response, err := c.HTTPClient.Get(papersListURL)
204-
if err != nil {
205-
return nil, err
206-
}
207-
208-
var paperListResult models.PaperListResult
209-
210-
err = json.NewDecoder(response.Body).Decode(&paperListResult)
211-
if err != nil {
212-
return nil, err
213-
}
214-
215-
return &paperListResult, nil
216-
}
217-
218-
func PaperListParamsDefault() PaperListParams {
219-
return PaperListParams{
220-
Query: "",
221-
Page: 1,
222-
Limit: 50,
223-
}
224-
}

client_test.go

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package paperswithcode_go
22

33
import (
4-
"os"
4+
"github.com/codingpot/paperswithcode-go/internal/testutils"
5+
"net/http"
6+
"net/http/httptest"
57
"testing"
68

79
"github.com/stretchr/testify/assert"
@@ -10,16 +12,27 @@ import (
1012
var apiToken string
1113

1214
func init() {
13-
var ok bool
14-
apiToken, ok = os.LookupEnv("PAPERSWITHCODE_API_TOKEN")
15-
16-
if !ok {
17-
panic("expected PAPERSWITHCODE_API_TOKEN environment variable")
18-
}
15+
apiToken = testutils.MustExtractAPITokenFromEnv()
1916
}
2017

21-
func TestClient_PaperList(t *testing.T) {
22-
client := NewClient(WithAPIToken(apiToken))
23-
_, err := client.PaperList(PaperListParamsDefault())
18+
func TestWithAPIToken(t *testing.T) {
19+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
20+
w.WriteHeader(http.StatusOK)
21+
_, err := w.Write([]byte("ok"))
22+
assert.NoError(t, err)
23+
}))
24+
defer server.Close()
25+
26+
c := NewClient(WithAPIToken("MY_TOKEN"))
27+
emptyRequest, err := http.NewRequest(http.MethodGet, server.URL, nil)
28+
2429
assert.NoError(t, err)
30+
_, err = c.HTTPClient.Transport.RoundTrip(emptyRequest)
31+
assert.NoError(t, err)
32+
assert.Equal(t, "Token MY_TOKEN", emptyRequest.Header.Get("Authorization"))
33+
}
34+
35+
func TestTransportIsNotProvidedWhenNoAPIIsProvided(t *testing.T) {
36+
c := NewClient()
37+
assert.Nil(t, c.HTTPClient.Transport)
2538
}

cmd/client/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ func main() {
1919
paper := "generative adversarial networks"
2020
paper = strings.ReplaceAll(paper, " ", "-")
2121

22-
paperList, _ := c.GetPaper(ctx, paper)
22+
paperList, _ := c.PaperGet(paper)
2323
fmt.Println(paperList)
2424
fmt.Println()
2525
fmt.Println()

dummy/paper_get_response.json

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
{
2+
"id": "generative-adversarial-networks",
3+
"arxiv_id": "1406.2661",
4+
"nips_id": null,
5+
"url_abs": "https://arxiv.org/abs/1406.2661v1",
6+
"url_pdf": "https://arxiv.org/pdf/1406.2661v1.pdf",
7+
"title": "Generative Adversarial Networks",
8+
"abstract": "We propose a new framework for estimating generative models via an adversarial process, in which we simultaneously train two models: a generative model G that captures the data distribution, and a discriminative model D that estimates the probability that a sample came from the training data rather than G. The training procedure for G is to maximize the probability of D making a mistake. This framework corresponds to a minimax two-player game. In the space of arbitrary functions G and D, a unique solution exists, with G recovering the training data distribution and D equal to 1/2 everywhere. In the case where G and D are defined by multilayer perceptrons, the entire system can be trained with backpropagation. There is no need for any Markov chains or unrolled approximate inference networks during either training or generation of samples. Experiments demonstrate the potential of the framework through qualitative and quantitative evaluation of the generated samples.",
9+
"authors": [
10+
"Ian J. Goodfellow",
11+
"Jean Pouget-Abadie",
12+
"Mehdi Mirza",
13+
"Bing Xu",
14+
"David Warde-Farley",
15+
"Sherjil Ozair",
16+
"Aaron Courville",
17+
"Yoshua Bengio"
18+
],
19+
"published": "2014-06-10",
20+
"conference": "generative-adversarial-networks-1",
21+
"conference_url_abs": null,
22+
"conference_url_pdf": null,
23+
"proceeding": "proceedings-of-the-27th-international"
24+
}

internal/testutils/testutils.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package testutils
2+
3+
import "os"
4+
5+
func MustExtractAPITokenFromEnv() string {
6+
apiToken, ok := os.LookupEnv("PAPERSWITHCODE_API_TOKEN")
7+
8+
if !ok {
9+
panic("expected PAPERSWITHCODE_API_TOKEN environment variable")
10+
}
11+
return apiToken
12+
}

internal/transport/transport.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package transport
2+
3+
import "net/http"
4+
5+
type transportWithAuthHeader struct {
6+
r http.RoundTripper
7+
token string
8+
}
9+
10+
var _ http.RoundTripper = transportWithAuthHeader{}
11+
12+
func (t transportWithAuthHeader) RoundTrip(request *http.Request) (*http.Response, error) {
13+
request.Header.Set("Authorization", "Token "+t.token)
14+
return t.r.RoundTrip(request)
15+
}
16+
17+
func NewTransportWithAuthHeader(token string) http.RoundTripper {
18+
return transportWithAuthHeader{
19+
r: http.DefaultTransport,
20+
token: token,
21+
}
22+
}

paperget.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package paperswithcode_go
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"github.com/codingpot/paperswithcode-go/models"
7+
"net/url"
8+
)
9+
10+
// PaperGet returns a single paper. Note that paperID is hyphen cased (e.g., generative-adversarial-networks).
11+
func (c *Client) PaperGet(paperID string) (*models.PaperListResultItem, error) {
12+
paperGetURL := fmt.Sprintf("%s/papers/%s/", c.BaseURL, url.QueryEscape(paperID))
13+
response, err := c.HTTPClient.Get(paperGetURL)
14+
if err != nil {
15+
return nil, err
16+
}
17+
18+
var paperGetResult models.PaperListResultItem
19+
err = json.NewDecoder(response.Body).Decode(&paperGetResult)
20+
if err != nil {
21+
return nil, err
22+
}
23+
24+
return &paperGetResult, nil
25+
}

paperget_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package paperswithcode_go
2+
3+
import (
4+
"github.com/stretchr/testify/assert"
5+
"testing"
6+
)
7+
8+
func TestClient_PaperGet(t *testing.T) {
9+
c := NewClient(WithAPIToken(apiToken))
10+
got, err := c.PaperGet("generative-adversarial-networks")
11+
assert.NoError(t, err)
12+
assert.Equal(t, "generative-adversarial-networks", got.ID)
13+
}

paperlist.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package paperswithcode_go
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"github.com/codingpot/paperswithcode-go/models"
7+
"net/url"
8+
)
9+
10+
func (c *Client) PaperList(params PaperListParams) (*models.PaperListResult, error) {
11+
papersListURL := c.BaseURL + "/papers?" + params.Build()
12+
13+
response, err := c.HTTPClient.Get(papersListURL)
14+
if err != nil {
15+
return nil, err
16+
}
17+
18+
var paperListResult models.PaperListResult
19+
20+
err = json.NewDecoder(response.Body).Decode(&paperListResult)
21+
if err != nil {
22+
return nil, err
23+
}
24+
25+
return &paperListResult, nil
26+
}
27+
28+
type PaperListParams struct {
29+
Query string
30+
Page int
31+
Limit int
32+
}
33+
34+
func (p PaperListParams) Build() string {
35+
if p.Query == "" {
36+
return fmt.Sprintf("items_per_page=%d&page=%d", p.Limit, p.Page)
37+
}
38+
return fmt.Sprintf("q=%s&items_per_page=%d&page=%d", url.QueryEscape(p.Query), p.Limit, p.Page)
39+
}
40+
41+
func PaperListParamsDefault() PaperListParams {
42+
return PaperListParams{
43+
Query: "",
44+
Page: 1,
45+
Limit: 50,
46+
}
47+
}

0 commit comments

Comments
 (0)