Skip to content

Commit 5455880

Browse files
authored
Merge pull request #67 from cocoide/feature/github-login
GithubのSSO機能
2 parents 0d40b91 + c0ac1e5 commit 5455880

14 files changed

+1376
-154
lines changed

cmd/login.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
Copyright © 2023 NAME HERE <EMAIL ADDRESS>
3+
*/
4+
package cmd
5+
6+
import (
7+
"fmt"
8+
"sync"
9+
10+
"github.com/cocoide/commitify/internal/gateway"
11+
"github.com/cocoide/commitify/internal/usecase"
12+
"github.com/fatih/color"
13+
"github.com/spf13/cobra"
14+
)
15+
16+
const (
17+
DeviceActivateURL = "https://github.com/login/device"
18+
)
19+
20+
var loginCmd = &cobra.Command{
21+
Use: "login",
22+
Short: "login by github",
23+
Long: `by login you can use auto pull request feature`,
24+
Run: func(cmd *cobra.Command, args []string) {
25+
httpClient := gateway.NewHttpClient()
26+
u := usecase.NewLoginCmdUsecase(httpClient)
27+
res, err := u.BeginGithubSSO()
28+
if err != nil {
29+
fmt.Printf("ログイン中にエラーが発生: %v", err)
30+
}
31+
32+
var wg sync.WaitGroup
33+
wg.Add(1)
34+
35+
errChan := make(chan error, 1)
36+
37+
go func() {
38+
defer wg.Done()
39+
40+
req := &usecase.ScheduleVerifyAuthRequest{
41+
DeviceCode: res.DeviceCode, Interval: res.Interval, ExpiresIn: res.ExpiresIn}
42+
err := u.ScheduleVerifyAuth(req)
43+
errChan <- err
44+
}()
45+
fmt.Printf("以下のページで認証コード『%s』を入力して下さい。\n", res.UserCode)
46+
fmt.Printf(color.HiCyanString("➡️ %s\n"), DeviceActivateURL)
47+
wg.Wait()
48+
err = <-errChan
49+
if err != nil {
50+
fmt.Printf("🚨認証エラーが発生: %v", err)
51+
} else {
52+
fmt.Printf("**🎉認証が正常に完了**\n")
53+
}
54+
},
55+
}
56+
57+
func init() {
58+
rootCmd.AddCommand(loginCmd)
59+
}

cmd/suggest.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ func NewSuggestModel() *suggestModel {
129129
if err != nil {
130130
log.Fatalf("設定ファイルの読み込みができませんでした")
131131
}
132-
switch config.WithGptRequestLocation() {
132+
switch config.GptRequestLocation() {
133133
case entity.Client:
134134
nlp := gateway.NewOpenAIGateway(context.Background())
135135
commitMessageService = gateway.NewClientCommitMessageGateway(nlp)

internal/entity/config.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package entity
33
import (
44
"encoding/json"
55
"fmt"
6-
"github.com/cocoide/commitify-grpc-server/pkg/pb"
6+
pb "github.com/cocoide/commitify/proto/gen"
77
"github.com/spf13/viper"
88
"os"
99
)
@@ -38,6 +38,7 @@ type Config struct {
3838
UseLanguage int `json:"UseLanguage"`
3939
CommitFormat int `json:"CommitFormat"`
4040
AISource int `json:"AISource"`
41+
GithubToken string `json:"GithubToken"`
4142
}
4243

4344
func (c *Config) Config2PbVars() (pb.CodeFormatType, pb.LanguageType) {
@@ -81,7 +82,7 @@ func ReadConfig() (Config, error) {
8182
return result, nil
8283
}
8384

84-
func WriteConfig(config Config) error {
85+
func (c Config) WriteConfig() error {
8586
homePath, err := os.UserHomeDir()
8687
if err != nil {
8788
return err
@@ -91,7 +92,7 @@ func WriteConfig(config Config) error {
9192
viper.SetConfigName("config")
9293
viper.SetConfigType("yaml")
9394
configMap := make(map[string]interface{})
94-
configBytes, err := json.Marshal(config)
95+
configBytes, err := json.Marshal(c)
9596
if err != nil {
9697
return fmt.Errorf("error marshalling config: %s", err.Error())
9798
}
@@ -108,6 +109,11 @@ func WriteConfig(config Config) error {
108109
return nil
109110
}
110111

112+
func (c *Config) WithGithubToken(token string) *Config {
113+
c.GithubToken = token
114+
return c
115+
}
116+
111117
func SaveConfig(configIndex, updateConfigParamInt int, updateConfigParamStr string) error {
112118
currentConfig, err := ReadConfig()
113119
if err != nil {
@@ -125,15 +131,15 @@ func SaveConfig(configIndex, updateConfigParamInt int, updateConfigParamStr stri
125131
currentConfig.AISource = updateConfigParamInt
126132
}
127133

128-
err = WriteConfig(currentConfig)
134+
err = currentConfig.WriteConfig()
129135
if err != nil {
130136
return err
131137
}
132138

133139
return nil
134140
}
135141

136-
func (c *Config) WithGptRequestLocation() GptRequestLocation {
142+
func (c *Config) GptRequestLocation() GptRequestLocation {
137143
switch c.AISource {
138144
case 0:
139145
return Server

internal/gateway/grpc.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ package gateway
22

33
import (
44
"crypto/tls"
5-
"github.com/cocoide/commitify-grpc-server/pkg/pb"
65
"github.com/cocoide/commitify/internal/entity"
76
"github.com/cocoide/commitify/internal/service"
7+
pb "github.com/cocoide/commitify/proto/gen"
88
"golang.org/x/net/context"
99
"google.golang.org/grpc"
1010
"google.golang.org/grpc/credentials"

internal/gateway/http_client.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
package gateway
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"net/http"
7+
"strconv"
8+
)
9+
10+
type HttpClient struct {
11+
client *http.Client
12+
endpoint string
13+
headers map[string]string
14+
params map[string]interface{}
15+
body io.Reader
16+
}
17+
18+
func NewHttpClient() *HttpClient {
19+
return &HttpClient{
20+
client: &http.Client{},
21+
headers: make(map[string]string),
22+
params: make(map[string]interface{}),
23+
}
24+
}
25+
26+
func (h *HttpClient) WithBaseURL(baseURL string) *HttpClient {
27+
h.endpoint = baseURL
28+
return h
29+
}
30+
31+
func (h *HttpClient) WithBearerToken(token string) *HttpClient {
32+
h.headers["Authorization"] = fmt.Sprintf("Bearer %s", token)
33+
return h
34+
}
35+
36+
func (h *HttpClient) WithPath(path string) *HttpClient {
37+
h.endpoint = h.endpoint + "/" + path
38+
return h
39+
}
40+
41+
func (h *HttpClient) WithParam(key string, value interface{}) *HttpClient {
42+
h.params[key] = value
43+
return h
44+
}
45+
46+
type HttpMethod int
47+
48+
const (
49+
GET HttpMethod = iota + 1
50+
POST
51+
DELTE
52+
PUT
53+
)
54+
55+
func (h *HttpClient) Execute(method HttpMethod) ([]byte, error) {
56+
var methodName string
57+
switch method {
58+
case GET:
59+
methodName = "GET"
60+
case POST:
61+
methodName = "POST"
62+
case DELTE:
63+
methodName = "DELETE"
64+
case PUT:
65+
methodName = "PUT"
66+
}
67+
client := h.client
68+
69+
req, err := http.NewRequest(methodName, h.endpoint, h.body)
70+
if err != nil {
71+
return nil, err
72+
}
73+
74+
for k, v := range h.headers {
75+
req.Header.Add(k, v)
76+
}
77+
78+
query := req.URL.Query()
79+
for key, value := range h.params {
80+
switch v := value.(type) {
81+
case string:
82+
query.Add(key, v)
83+
case int:
84+
query.Add(key, strconv.Itoa(v))
85+
case bool:
86+
query.Add(key, strconv.FormatBool(v))
87+
default:
88+
return nil, fmt.Errorf("Failed to parse param value: %v", value)
89+
}
90+
}
91+
req.URL.RawQuery = query.Encode()
92+
resp, err := client.Do(req)
93+
if err != nil {
94+
return nil, err
95+
}
96+
defer resp.Body.Close()
97+
98+
body, err := io.ReadAll(resp.Body)
99+
if err != nil {
100+
return nil, err
101+
}
102+
return body, nil
103+
}

internal/usecase/login_cmd.go

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package usecase
2+
3+
import (
4+
"fmt"
5+
"github.com/cocoide/commitify/internal/entity"
6+
"github.com/cocoide/commitify/internal/gateway"
7+
"net/url"
8+
"strconv"
9+
"time"
10+
)
11+
12+
const (
13+
GithubClientID = "b27d87c28752d2363922"
14+
GithubScope = "repo"
15+
GrantType = "urn:ietf:params:oauth:grant-type:device_code"
16+
)
17+
18+
type LoginCmdUsecase struct {
19+
http *gateway.HttpClient
20+
}
21+
22+
func NewLoginCmdUsecase(http *gateway.HttpClient) *LoginCmdUsecase {
23+
http.WithBaseURL("https://github.com/login")
24+
return &LoginCmdUsecase{http: http}
25+
}
26+
27+
type BeginGithubSSOResponse struct {
28+
DeviceCode string
29+
UserCode string
30+
Interval int
31+
ExpiresIn int
32+
}
33+
34+
func (u *LoginCmdUsecase) BeginGithubSSO() (*BeginGithubSSOResponse, error) {
35+
b, err := u.http.WithPath("device/code").
36+
WithParam("client_id", GithubClientID).
37+
WithParam("scope", GithubScope).
38+
Execute(gateway.POST)
39+
if err != nil {
40+
return nil, err
41+
}
42+
values, err := url.ParseQuery(string(b))
43+
if err != nil {
44+
return nil, err
45+
}
46+
deviceCode := values.Get("device_code")
47+
userCode := values.Get("user_code")
48+
expiresIn, err := strconv.Atoi(values.Get("expires_in"))
49+
if err != nil {
50+
return nil, err
51+
}
52+
interval, err := strconv.Atoi(values.Get("interval"))
53+
if err != nil {
54+
return nil, err
55+
}
56+
if deviceCode == "" || userCode == "" {
57+
return nil, fmt.Errorf("failed to parse code")
58+
}
59+
return &BeginGithubSSOResponse{
60+
DeviceCode: deviceCode,
61+
UserCode: userCode,
62+
ExpiresIn: expiresIn,
63+
Interval: interval,
64+
}, nil
65+
}
66+
67+
type ScheduleVerifyAuthRequest struct {
68+
DeviceCode string
69+
Interval int
70+
ExpiresIn int
71+
}
72+
73+
func (u *LoginCmdUsecase) ScheduleVerifyAuth(req *ScheduleVerifyAuthRequest) error {
74+
u.http = gateway.NewHttpClient().
75+
WithBaseURL("https://github.com/login").
76+
WithPath("oauth/access_token").
77+
WithParam("client_id", GithubClientID).
78+
WithParam("device_code", req.DeviceCode).
79+
WithParam("grant_type", GrantType)
80+
81+
timeout := time.After(time.Duration(req.ExpiresIn) * time.Second)
82+
ticker := time.NewTicker(time.Duration(req.Interval) * time.Second)
83+
defer ticker.Stop()
84+
85+
for {
86+
select {
87+
case <-timeout:
88+
return fmt.Errorf("認証プロセスがタイムアウトしました")
89+
case <-ticker.C:
90+
b, err := u.http.Execute(gateway.POST)
91+
if err != nil {
92+
return err
93+
}
94+
values, err := url.ParseQuery(string(b))
95+
if err != nil {
96+
return err
97+
}
98+
accessToken := values.Get("access_token")
99+
if accessToken != "" {
100+
config, err := entity.ReadConfig()
101+
if err != nil {
102+
return err
103+
}
104+
config.WithGithubToken(accessToken)
105+
if err := config.WriteConfig(); err != nil {
106+
return err
107+
}
108+
return nil
109+
}
110+
if newIntervalStr := values.Get("interval"); newIntervalStr != "" {
111+
newInterval, err := strconv.Atoi(newIntervalStr)
112+
if err != nil {
113+
return err
114+
}
115+
ticker.Stop()
116+
ticker = time.NewTicker(time.Duration(newInterval) * time.Second)
117+
}
118+
}
119+
}
120+
}

0 commit comments

Comments
 (0)