From 88645277c7c72b158c769e72771da970e32fc736 Mon Sep 17 00:00:00 2001 From: yumosx Date: Sun, 15 Jun 2025 14:46:26 +0800 Subject: [PATCH 01/24] add quota --- internal/repository/dao/quota.go | 33 ++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 internal/repository/dao/quota.go diff --git a/internal/repository/dao/quota.go b/internal/repository/dao/quota.go new file mode 100644 index 0000000..6f50751 --- /dev/null +++ b/internal/repository/dao/quota.go @@ -0,0 +1,33 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dao + +type TempQuota struct { + ID int64 `gorm:"primaryKey;autoIncrement;column:id"` + UID string `gorm:"column:uid"` + Amount int64 `gorm:"column:amount"` + StartTime int64 `gorm:"column:start_time"` + EndTime int64 `gorm:"column:end_time"` + Ctime int64 `gorm:"column:ctime"` + Utime int64 `gorm:"column:utime"` +} + +type Quota struct { + ID int64 `gorm:"primaryKey;autoIncrement;colum:id"` + UID string `gorm:"column:uid"` + Amount int64 `gorm:"colum:amount"` + Ctime int64 `gorm:"column:ctime"` + Utime int64 `gorm:"column:utime"` +} From f9b38d0edf6205b5fcad6a8b0afcc1ce2b7781d0 Mon Sep 17 00:00:00 2001 From: yumosx Date: Sun, 15 Jun 2025 16:24:57 +0800 Subject: [PATCH 02/24] fix typo --- internal/repository/dao/quota.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/repository/dao/quota.go b/internal/repository/dao/quota.go index 6f50751..362add0 100644 --- a/internal/repository/dao/quota.go +++ b/internal/repository/dao/quota.go @@ -25,7 +25,7 @@ type TempQuota struct { } type Quota struct { - ID int64 `gorm:"primaryKey;autoIncrement;colum:id"` + ID int64 `gorm:"primaryKey;autoIncrement;column:id"` UID string `gorm:"column:uid"` Amount int64 `gorm:"colum:amount"` Ctime int64 `gorm:"column:ctime"` From 33f17f8f3297256497e19a268646a142fb8e8970 Mon Sep 17 00:00:00 2001 From: yumosx Date: Fri, 27 Jun 2025 19:49:45 +0800 Subject: [PATCH 03/24] add deduct --- internal/domain/quota.go | 6 ++ internal/repository/dao/quota.go | 128 +++++++++++++++++++++++++++++-- internal/repository/quota.go | 19 +++++ 3 files changed, 148 insertions(+), 5 deletions(-) create mode 100644 internal/domain/quota.go create mode 100644 internal/repository/quota.go diff --git a/internal/domain/quota.go b/internal/domain/quota.go new file mode 100644 index 0000000..b9f7391 --- /dev/null +++ b/internal/domain/quota.go @@ -0,0 +1,6 @@ +package domain + +type Quota struct { + Amount int64 + Uid string +} diff --git a/internal/repository/dao/quota.go b/internal/repository/dao/quota.go index 362add0..22e68e4 100644 --- a/internal/repository/dao/quota.go +++ b/internal/repository/dao/quota.go @@ -14,6 +14,18 @@ package dao +import ( + "context" + "errors" + "time" + + "gorm.io/gorm" +) + +var ( + ErrNoAmount error = errors.New("余额不足") +) + type TempQuota struct { ID int64 `gorm:"primaryKey;autoIncrement;column:id"` UID string `gorm:"column:uid"` @@ -25,9 +37,115 @@ type TempQuota struct { } type Quota struct { - ID int64 `gorm:"primaryKey;autoIncrement;column:id"` - UID string `gorm:"column:uid"` - Amount int64 `gorm:"colum:amount"` - Ctime int64 `gorm:"column:ctime"` - Utime int64 `gorm:"column:utime"` + ID int64 `gorm:"primaryKey;autoIncrement;column:id"` + UID string `gorm:"column:uid"` + Amount int64 `gorm:"column:amount"` + LastClearTime int64 `gorm:"column:last_clear_time"` + Ctime int64 `gorm:"column:ctime"` + Utime int64 `gorm:"column:utime"` +} + +type QuotaDao struct { + db *gorm.DB +} + +func NewQuotaDao(db *gorm.DB) *QuotaDao { + return &QuotaDao{db: db} +} + +// CreateTempQuota 用来创建临时额度 +func (dao *QuotaDao) CreateTempQuota(ctx context.Context, quota TempQuota) error { + now := time.Now().Unix() + quota.Ctime = now + quota.Utime = now + return dao.db.WithContext(ctx).Create("a).Error +} + +// Create 用来创建对应的永久的额度 +func (dao *QuotaDao) Create(ctx context.Context, quota Quota) error { + now := time.Now().Unix() + quota.Ctime = now + quota.Ctime = now + return dao.db.WithContext(ctx).Create("a).Error +} + +func (dao *QuotaDao) GetQuotaByUid(ctx context.Context, uid string) error { + return dao.db.WithContext(ctx).Where("uid = ?", uid).Error +} + +func (dao *QuotaDao) GetTempQuotaByUid(ctx context.Context, uid string) error { + return dao.db.WithContext(ctx).Where("uid = ?", uid).Error +} + +// Deduct 扣减 +func (dao *QuotaDao) Deduct(ctx context.Context, uid string, amount int64) error { + now := time.Now().Unix() + // 首先扣除temp 的, 然后扣除 quota的 + err := dao.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + // 可能存在多个时间段 + var tempQuotas []TempQuota + err := tx.Where("uid = ? AND end_time >= ?", uid, now). + Order("end_time ASC"). + Find(&tempQuotas).Error + + if err != nil { + return err + } + + for i := range tempQuotas { + tq := tempQuotas[i] + if amount <= 0 { + break + } + deduct := int64(0) + // 如果大于需要直接扣, 小于就直接扣完 + if tq.Amount >= amount { + deduct = amount + amount = 0 + } else { + deduct = tq.Amount + amount -= deduct + } + tq.Amount -= deduct + tq.Utime = now + // 然后更新 + err = tx.Model(tq).Select("amount", "utime").Updates(tq).Error + if err != nil { + return err + } + } + + var quota Quota + err = tx.Where("uid = ?", uid).First("a).Error + if err != nil { + return err + } + + // 临时额度扣减完毕 + if amount <= 0 { + return nil + } + + // 扣完了发现还不够扣的, 从 quota 中扣 + if quota.Amount < amount { + return ErrNoAmount + } + quota.Amount -= amount + quota.Utime = now + quota.LastClearTime = now + + //更新 + err = tx.Model(&Quota{}).Updates(map[string]any{ + "amount": quota.Amount, + "utime": quota.Utime, + "last_clear_time": quota.LastClearTime, + }).Error + + if err != nil { + return ErrNoAmount + } + + return nil + }) + return err } diff --git a/internal/repository/quota.go b/internal/repository/quota.go new file mode 100644 index 0000000..05f47cf --- /dev/null +++ b/internal/repository/quota.go @@ -0,0 +1,19 @@ +package repository + +import ( + "context" + + "github.com/ecodeclub/ai-gateway-go/internal/repository/dao" +) + +type QuotaRepo struct { + quota *dao.Quota +} + +func NewQuotaRepo(quota *dao.Quota) *QuotaRepo { + return &QuotaRepo{quota: quota} +} + +func (q *QuotaRepo) Get(ctx context.Context) { + +} From 65cdc59e29f0662864d6452a948f60487a130827 Mon Sep 17 00:00:00 2001 From: yumosx Date: Fri, 27 Jun 2025 20:15:47 +0800 Subject: [PATCH 04/24] add the quota --- internal/repository/quota.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/internal/repository/quota.go b/internal/repository/quota.go index 05f47cf..1142ce7 100644 --- a/internal/repository/quota.go +++ b/internal/repository/quota.go @@ -14,6 +14,9 @@ func NewQuotaRepo(quota *dao.Quota) *QuotaRepo { return &QuotaRepo{quota: quota} } -func (q *QuotaRepo) Get(ctx context.Context) { +func (q *QuotaRepo) Get(ctx context.Context) error { + return nil +} +func (q *QuotaRepo) Deduct(ctx context.Context) error { } From 6c93dcc3897e45ac40478c4d8809f1c9a0d33d4b Mon Sep 17 00:00:00 2001 From: yumosx Date: Sat, 28 Jun 2025 16:04:56 +0800 Subject: [PATCH 05/24] handler: add quota --- Makefile | 10 +- errs/errs.go | 5 +- internal/domain/quota.go | 26 +++++- internal/repository/dao/quota.go | 81 +++++++++------- internal/repository/quota.go | 63 +++++++++++-- internal/service/quota.go | 54 +++++++++++ internal/web/quota.go | 153 +++++++++++++++++++++++++++++++ 7 files changed, 350 insertions(+), 42 deletions(-) create mode 100644 internal/service/quota.go create mode 100644 internal/web/quota.go diff --git a/Makefile b/Makefile index 6751fc1..f9ad393 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,10 @@ +GOFILES=$(shell find . -type f -name '*.go' \ + -not -path "./vendor/*" \ + -not -path "./third_party/*" \ + -not -path "./.idea/*" \ + -not -name '*.pb.go' \ + -not -name '*mock*.go') + .PHONY: bench bench: @go test -bench=. -benchmem ./... @@ -24,7 +31,8 @@ e2e: .PHONY: fmt fmt: - @goimports -l -w $$(find . -type f -name '*.go' -not -path "./.idea/*" -not -name '*.pb.go' -not -name '*mock*.go') + @goimports -l -w $(GOFILES) + @gofumpt -l -w $(GOFILES) .PHONY: lint lint: diff --git a/errs/errs.go b/errs/errs.go index f4b0fea..7d925d2 100644 --- a/errs/errs.go +++ b/errs/errs.go @@ -18,4 +18,7 @@ import ( "errors" ) -var ErrBizConfigNotFound = errors.New("biz config not found") +var ( + ErrNoAmount = errors.New("余额不足") + ErrBizConfigNotFound = errors.New("biz config not found") +) diff --git a/internal/domain/quota.go b/internal/domain/quota.go index b9f7391..31c5d76 100644 --- a/internal/domain/quota.go +++ b/internal/domain/quota.go @@ -1,6 +1,28 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package domain type Quota struct { - Amount int64 - Uid string + Amount int64 + Uid int64 + LastClearTime int64 +} + +type TempQuota struct { + Amount int64 + StartTime int64 + EndTime int64 + Uid int64 } diff --git a/internal/repository/dao/quota.go b/internal/repository/dao/quota.go index 22e68e4..3c9fc31 100644 --- a/internal/repository/dao/quota.go +++ b/internal/repository/dao/quota.go @@ -16,33 +16,29 @@ package dao import ( "context" - "errors" "time" + "github.com/ecodeclub/ai-gateway-go/errs" "gorm.io/gorm" ) -var ( - ErrNoAmount error = errors.New("余额不足") -) - type TempQuota struct { - ID int64 `gorm:"primaryKey;autoIncrement;column:id"` - UID string `gorm:"column:uid"` - Amount int64 `gorm:"column:amount"` - StartTime int64 `gorm:"column:start_time"` - EndTime int64 `gorm:"column:end_time"` - Ctime int64 `gorm:"column:ctime"` - Utime int64 `gorm:"column:utime"` + ID int64 `gorm:"primaryKey;autoIncrement;column:id"` + UID int64 `gorm:"column:uid"` + Amount int64 `gorm:"column:amount"` + StartTime int64 `gorm:"column:start_time"` + EndTime int64 `gorm:"column:end_time"` + Ctime int64 `gorm:"column:ctime"` + Utime int64 `gorm:"column:utime"` } type Quota struct { - ID int64 `gorm:"primaryKey;autoIncrement;column:id"` - UID string `gorm:"column:uid"` - Amount int64 `gorm:"column:amount"` - LastClearTime int64 `gorm:"column:last_clear_time"` - Ctime int64 `gorm:"column:ctime"` - Utime int64 `gorm:"column:utime"` + ID int64 `gorm:"primaryKey;autoIncrement;column:id"` + UID int64 `gorm:"column:uid"` + Amount int64 `gorm:"column:amount"` + LastClearTime int64 `gorm:"column:last_clear_time"` + Ctime int64 `gorm:"column:ctime"` + Utime int64 `gorm:"column:utime"` } type QuotaDao struct { @@ -65,20 +61,41 @@ func (dao *QuotaDao) CreateTempQuota(ctx context.Context, quota TempQuota) error func (dao *QuotaDao) Create(ctx context.Context, quota Quota) error { now := time.Now().Unix() quota.Ctime = now - quota.Ctime = now + quota.Utime = now return dao.db.WithContext(ctx).Create("a).Error } -func (dao *QuotaDao) GetQuotaByUid(ctx context.Context, uid string) error { - return dao.db.WithContext(ctx).Where("uid = ?", uid).Error +func (dao *QuotaDao) UpdateQuota(ctx context.Context, quota Quota) error { + now := time.Now().Unix() + quota.Utime = now + + return dao.db.WithContext(ctx).Where("uid = ?", quota.UID).Updates(map[string]any{ + "amount": quota.Amount, + "utime": quota.Utime, + }).Error +} + +func (dao *QuotaDao) GetQuotaByUid(ctx context.Context, uid int64) (Quota, error) { + var quota Quota + err := dao.db.WithContext(ctx).Where("uid = ? and end_time >= ?", uid).First("a).Error + if err != nil { + return Quota{}, err + } + return quota, nil } -func (dao *QuotaDao) GetTempQuotaByUid(ctx context.Context, uid string) error { - return dao.db.WithContext(ctx).Where("uid = ?", uid).Error +func (dao *QuotaDao) GetTempQuotaByUidAndTime(ctx context.Context, uid int64) ([]TempQuota, error) { + now := time.Now().Unix() + var quota []TempQuota + err := dao.db.WithContext(ctx).Where("uid = ? and end_time >= ?", uid, now).Find("a).Error + if err != nil { + return nil, err + } + return quota, nil } // Deduct 扣减 -func (dao *QuotaDao) Deduct(ctx context.Context, uid string, amount int64) error { +func (dao *QuotaDao) Deduct(ctx context.Context, uid int64, amount int64) error { now := time.Now().Unix() // 首先扣除temp 的, 然后扣除 quota的 err := dao.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { @@ -87,7 +104,6 @@ func (dao *QuotaDao) Deduct(ctx context.Context, uid string, amount int64) error err := tx.Where("uid = ? AND end_time >= ?", uid, now). Order("end_time ASC"). Find(&tempQuotas).Error - if err != nil { return err } @@ -107,9 +123,11 @@ func (dao *QuotaDao) Deduct(ctx context.Context, uid string, amount int64) error amount -= deduct } tq.Amount -= deduct - tq.Utime = now // 然后更新 - err = tx.Model(tq).Select("amount", "utime").Updates(tq).Error + err = tx.Model(&TempQuota{}).Where("uid = ?", uid).Updates(map[string]any{ + "amount": tq.Amount, + "utime": now, + }).Error if err != nil { return err } @@ -128,21 +146,20 @@ func (dao *QuotaDao) Deduct(ctx context.Context, uid string, amount int64) error // 扣完了发现还不够扣的, 从 quota 中扣 if quota.Amount < amount { - return ErrNoAmount + return errs.ErrNoAmount } quota.Amount -= amount quota.Utime = now quota.LastClearTime = now - //更新 - err = tx.Model(&Quota{}).Updates(map[string]any{ + // 更新 + err = tx.Model(&Quota{}).Where("uid = ?", uid).Updates(map[string]any{ "amount": quota.Amount, "utime": quota.Utime, "last_clear_time": quota.LastClearTime, }).Error - if err != nil { - return ErrNoAmount + return errs.ErrNoAmount } return nil diff --git a/internal/repository/quota.go b/internal/repository/quota.go index 1142ce7..496c77b 100644 --- a/internal/repository/quota.go +++ b/internal/repository/quota.go @@ -1,22 +1,73 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package repository import ( "context" + "github.com/ecodeclub/ai-gateway-go/internal/domain" "github.com/ecodeclub/ai-gateway-go/internal/repository/dao" + "github.com/ecodeclub/ekit/slice" ) type QuotaRepo struct { - quota *dao.Quota + dao *dao.QuotaDao +} + +func NewQuotaRepo(dao *dao.QuotaDao) *QuotaRepo { + return &QuotaRepo{dao: dao} +} + +func (q *QuotaRepo) CreateQuota(ctx context.Context, quota domain.Quota) error { + return q.dao.Create(ctx, dao.Quota{UID: quota.Uid, Amount: quota.Amount}) +} + +func (q *QuotaRepo) UpdateQuota(ctx context.Context, quota domain.Quota) error { + return q.dao.UpdateQuota(ctx, dao.Quota{UID: quota.Uid, Amount: quota.Amount}) +} + +func (q *QuotaRepo) CreateTempQuota(ctx context.Context, quota domain.TempQuota) error { + return q.dao.CreateTempQuota(ctx, dao.TempQuota{Amount: quota.Amount, StartTime: quota.StartTime, EndTime: quota.EndTime}) +} + +func (q *QuotaRepo) GetQuota(ctx context.Context, uid int64) (domain.Quota, error) { + quota, err := q.dao.GetQuotaByUid(ctx, uid) + if err != nil { + return domain.Quota{}, err + } + return domain.Quota{Amount: quota.Amount, Uid: uid}, nil } -func NewQuotaRepo(quota *dao.Quota) *QuotaRepo { - return &QuotaRepo{quota: quota} +func (q *QuotaRepo) GetTempQuota(ctx context.Context, uid int64) ([]domain.TempQuota, error) { + tempQuotaList, err := q.dao.GetTempQuotaByUidAndTime(ctx, uid) + if err != nil { + return nil, err + } + return q.toDomainTempQuota(tempQuotaList), nil } -func (q *QuotaRepo) Get(ctx context.Context) error { - return nil +func (q *QuotaRepo) Deduct(ctx context.Context, uid int64, amount int64) error { + return q.dao.Deduct(ctx, uid, amount) } -func (q *QuotaRepo) Deduct(ctx context.Context) error { +func (q *QuotaRepo) toDomainTempQuota(tmpQuotaList []dao.TempQuota) []domain.TempQuota { + return slice.Map[dao.TempQuota, domain.TempQuota](tmpQuotaList, func(idx int, src dao.TempQuota) domain.TempQuota { + return domain.TempQuota{ + Amount: src.Amount, + StartTime: src.StartTime, + EndTime: src.EndTime, + } + }) } diff --git a/internal/service/quota.go b/internal/service/quota.go new file mode 100644 index 0000000..b69d21b --- /dev/null +++ b/internal/service/quota.go @@ -0,0 +1,54 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + + "github.com/ecodeclub/ai-gateway-go/internal/domain" + "github.com/ecodeclub/ai-gateway-go/internal/repository" +) + +type QuotaService struct { + repo *repository.QuotaRepo +} + +func NewQuotaService(repo *repository.QuotaRepo) *QuotaService { + return &QuotaService{repo: repo} +} + +func (q *QuotaService) CreateQuota(ctx context.Context, quota domain.Quota) error { + return q.repo.CreateQuota(ctx, quota) +} + +func (q *QuotaService) CreateTempQuota(ctx context.Context, quota domain.TempQuota) error { + return q.repo.CreateTempQuota(ctx, quota) +} + +func (q *QuotaService) UpdateQuota(ctx context.Context, quota domain.Quota) error { + return q.repo.UpdateQuota(ctx, quota) +} + +func (q *QuotaService) GetTempQuota(ctx context.Context, uid int64) ([]domain.TempQuota, error) { + return q.repo.GetTempQuota(ctx, uid) +} + +func (q *QuotaService) GetQuota(ctx context.Context, uid int64) (domain.Quota, error) { + return q.repo.GetQuota(ctx, uid) +} + +func (q *QuotaService) Deduct(ctx context.Context, uid int64, amount int64) error { + return q.repo.Deduct(ctx, uid, amount) +} diff --git a/internal/web/quota.go b/internal/web/quota.go new file mode 100644 index 0000000..d8e01ef --- /dev/null +++ b/internal/web/quota.go @@ -0,0 +1,153 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package web + +import ( + "time" + + "github.com/ecodeclub/ai-gateway-go/internal/domain" + "github.com/ecodeclub/ai-gateway-go/internal/service" + "github.com/ecodeclub/ekit/slice" + "github.com/ecodeclub/ginx" + "github.com/ecodeclub/ginx/session" + "github.com/gin-gonic/gin" +) + +type QuotaHandler struct { + svc *service.QuotaService +} + +func NewQuotaHandler(svc *service.QuotaService) *QuotaHandler { + return &QuotaHandler{svc: svc} +} + +func (q *QuotaHandler) PrivateRoutes(_ *gin.Engine) {} + +func (q *QuotaHandler) PublicRoutes(server *gin.Engine) { + group := server.Group("/quota") + group.POST("/create", ginx.BS(q.CreateTempQuota)) + group.POST("/create_tmp", ginx.BS(q.CreateTempQuota)) + group.POST("/deduct", ginx.BS(q.Deduct)) + group.POST("/get", ginx.S(q.GetQuota)) + group.POST("/get_tmp", ginx.S(q.GetTempQuota)) +} + +func (q *QuotaHandler) CreateQuota(ctx *ginx.Context, req QuotaRequest, sess session.Session) (ginx.Result, error) { + uid := sess.Claims().Uid + err := q.svc.CreateQuota(ctx, domain.Quota{Amount: req.Amount, Uid: uid}) + if err != nil { + return systemErrorResult, nil + } + return ginx.Result{ + Msg: "OK", + }, nil +} + +func (q *QuotaHandler) CreateTempQuota(ctx *ginx.Context, req QuotaRequest, sess session.Session) (ginx.Result, error) { + uid := sess.Claims().Uid + + if req.StartTime == "" || req.EndTime == "" { + return systemErrorResult, nil + } + + start, _ := q.toTimestamp(req.StartTime) + end, _ := q.toTimestamp(req.EndTime) + + err := q.svc.CreateTempQuota(ctx, domain.TempQuota{Amount: req.Amount, Uid: uid, StartTime: start, EndTime: end}) + if err != nil { + return systemErrorResult, nil + } + return ginx.Result{ + Msg: "OK", + }, nil +} + +func (q *QuotaHandler) GetQuota(ctx *ginx.Context, sees session.Session) (ginx.Result, error) { + uid := sees.Claims().Uid + + quota, err := q.svc.GetQuota(ctx, uid) + if err != nil { + return systemErrorResult, err + } + return ginx.Result{ + Msg: "ok", + Data: QuotaResponse{Amount: quota.Amount}, + }, nil +} + +func (q *QuotaHandler) GetTempQuota(ctx *ginx.Context, sees session.Session) (ginx.Result, error) { + uid := sees.Claims().Uid + quotaList, err := q.svc.GetTempQuota(ctx, uid) + if err != nil { + return systemErrorResult, err + } + return ginx.Result{ + Msg: "ok", + Data: q.toQuotaResponse(quotaList), + }, nil +} + +func (q *QuotaHandler) UpdateQuota(ctx *ginx.Context, req QuotaRequest, sees session.Session) (ginx.Result, error) { + uid := sees.Claims().Uid + + err := q.svc.UpdateQuota(ctx, domain.Quota{Uid: uid, Amount: req.Amount}) + if err != nil { + return systemErrorResult, nil + } + + return ginx.Result{ + Msg: "OK", + }, nil +} + +func (q *QuotaHandler) Deduct(ctx *ginx.Context, req QuotaRequest, sees session.Session) (ginx.Result, error) { + uid := sees.Claims().Uid + err := q.svc.Deduct(ctx, uid, req.Amount) + if err != nil { + return systemErrorResult, nil + } + return ginx.Result{Msg: "OK"}, nil +} + +func (q *QuotaHandler) toTimestamp(timeStr string) (int64, error) { + const layout = "2006-01-02 15:04:05" + t, err := time.Parse(layout, timeStr) + if err != nil { + return 0, err + } + return t.Unix(), nil +} + +func (q *QuotaHandler) toQuotaResponse(tempQuotaList []domain.TempQuota) []QuotaResponse { + return slice.Map[domain.TempQuota, QuotaResponse](tempQuotaList, func(idx int, src domain.TempQuota) QuotaResponse { + return QuotaResponse{ + Amount: src.Amount, + StartTime: time.Unix(src.StartTime, 0).Format("2006-01-02 15:04:05"), + EndTime: time.Unix(src.EndTime, 0).Format("2006-01-02 15:04:05"), + } + }) +} + +type QuotaRequest struct { + Amount int64 `json:"amount,omitempty"` + StartTime string `json:"start_time,omitempty"` + EndTime string `json:"end_time,omitempty"` +} + +type QuotaResponse struct { + Amount int64 `json:"amount,omitempty"` + StartTime string `json:"start_time,omitempty"` + EndTime string `json:"end_time,omitempty"` +} From f05e5c21905f91d0016983abaff91ea605a409c3 Mon Sep 17 00:00:00 2001 From: yumosx Date: Sat, 28 Jun 2025 16:13:31 +0800 Subject: [PATCH 06/24] fix ci --- .github/workflows/go-fmt.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/go-fmt.yml b/.github/workflows/go-fmt.yml index eefc0d5..c6fff3f 100644 --- a/.github/workflows/go-fmt.yml +++ b/.github/workflows/go-fmt.yml @@ -35,8 +35,9 @@ jobs: go-version: "1.24.2" - name: Install goimports - run: go install golang.org/x/tools/cmd/goimports@latest - + run: | + go install golang.org/x/tools/cmd/goimports@latest + go install mvdan.cc/gofumpt@latest - name: Check run: | make check From e190998e976d0fc7ba0571f632830f10137a1954 Mon Sep 17 00:00:00 2001 From: yumosx Date: Sat, 5 Jul 2025 22:28:43 +0800 Subject: [PATCH 07/24] update --- internal/domain/quota.go | 5 + internal/repository/dao/quota.go | 210 ++++++++++++++++--------------- internal/repository/quota.go | 16 +-- internal/service/quota.go | 66 ++++++++-- internal/test/quota_test.go | 130 +++++++++++++++++++ internal/web/quota.go | 40 +++--- 6 files changed, 317 insertions(+), 150 deletions(-) create mode 100644 internal/test/quota_test.go diff --git a/internal/domain/quota.go b/internal/domain/quota.go index 31c5d76..8ccc932 100644 --- a/internal/domain/quota.go +++ b/internal/domain/quota.go @@ -16,13 +16,18 @@ package domain type Quota struct { Amount int64 + Key string Uid int64 LastClearTime int64 } type TempQuota struct { Amount int64 + Key string StartTime int64 EndTime int64 Uid int64 } + +type Record struct { +} diff --git a/internal/repository/dao/quota.go b/internal/repository/dao/quota.go index 3c9fc31..5ed5937 100644 --- a/internal/repository/dao/quota.go +++ b/internal/repository/dao/quota.go @@ -18,27 +18,50 @@ import ( "context" "time" - "github.com/ecodeclub/ai-gateway-go/errs" "gorm.io/gorm" + "gorm.io/gorm/clause" ) type TempQuota struct { - ID int64 `gorm:"primaryKey;autoIncrement;column:id"` - UID int64 `gorm:"column:uid"` - Amount int64 `gorm:"column:amount"` - StartTime int64 `gorm:"column:start_time"` - EndTime int64 `gorm:"column:end_time"` - Ctime int64 `gorm:"column:ctime"` - Utime int64 `gorm:"column:utime"` + ID int64 `gorm:"primaryKey;autoIncrement;column:id"` + UID int64 `gorm:"column:uid"` + Key string `gorm:"column:key;uniqueIndex;type:varchar(256)"` + Amount int64 `gorm:"column:amount"` + StartTime int64 `gorm:"column:start_time"` + EndTime int64 `gorm:"column:end_time"` + Ctime int64 `gorm:"column:ctime"` + Utime int64 `gorm:"column:utime"` +} + +func (TempQuota) TableName() string { + return "temp_quotas" +} + +type QuotaRecord struct { + ID int64 `gorm:"primaryKey;autoIncrement;column:id"` + Uid int64 `gorm:"column:uid;index"` + Key string `gorm:"column:key;uniqueIndex;type:varchar(256)"` + Amount int64 `gorm:"column:amount"` + Ctime int64 `gorm:"column:ctime"` + Utime int64 `gorm:"column:utime"` +} + +func (QuotaRecord) TableName() string { + return "quota_records" } type Quota struct { - ID int64 `gorm:"primaryKey;autoIncrement;column:id"` - UID int64 `gorm:"column:uid"` - Amount int64 `gorm:"column:amount"` - LastClearTime int64 `gorm:"column:last_clear_time"` - Ctime int64 `gorm:"column:ctime"` - Utime int64 `gorm:"column:utime"` + ID int64 `gorm:"primaryKey;autoIncrement;column:id"` + UID int64 `gorm:"column:uid"` + Key string `gorm:"column:key;uniqueIndex;type:varchar(256)"` + Amount int64 `gorm:"column:amount"` + LastClearTime int64 `gorm:"column:last_clear_time"` + Ctime int64 `gorm:"column:ctime"` + Utime int64 `gorm:"column:utime"` +} + +func (Quota) TableName() string { + return "quotas" } type QuotaDao struct { @@ -49,35 +72,80 @@ func NewQuotaDao(db *gorm.DB) *QuotaDao { return &QuotaDao{db: db} } -// CreateTempQuota 用来创建临时额度 -func (dao *QuotaDao) CreateTempQuota(ctx context.Context, quota TempQuota) error { +func (dao *QuotaDao) SaveTempQuota(ctx context.Context, quota TempQuota) error { now := time.Now().Unix() - quota.Ctime = now - quota.Utime = now - return dao.db.WithContext(ctx).Create("a).Error -} + return dao.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + record := QuotaRecord{ + Key: quota.Key, + Uid: quota.UID, + Amount: quota.Amount, + Ctime: now, + Utime: now, + } -// Create 用来创建对应的永久的额度 -func (dao *QuotaDao) Create(ctx context.Context, quota Quota) error { - now := time.Now().Unix() - quota.Ctime = now - quota.Utime = now - return dao.db.WithContext(ctx).Create("a).Error + result := tx.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "key"}}, + DoUpdates: clause.Assignments(map[string]interface{}{ + "amount": quota.Amount, + "utime": now, + }), + }).Create(&record) + if result.Error != nil { + return result.Error + } + quota.Ctime = now + quota.Utime = now + + return tx.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "key"}}, + DoUpdates: clause.Assignments(map[string]interface{}{ + "amount": quota.Amount, + "start_time": quota.StartTime, + "end_time": quota.EndTime, + "utime": now, + }), + }).Create("a).Error + }) } -func (dao *QuotaDao) UpdateQuota(ctx context.Context, quota Quota) error { +func (dao *QuotaDao) SaveQuota(ctx context.Context, quota Quota) error { now := time.Now().Unix() quota.Utime = now - return dao.db.WithContext(ctx).Where("uid = ?", quota.UID).Updates(map[string]any{ - "amount": quota.Amount, - "utime": quota.Utime, - }).Error + return dao.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + record := QuotaRecord{ + Key: quota.Key, + Uid: quota.UID, + Amount: quota.Amount, + Ctime: now, + Utime: now, + } + result := tx.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "key"}}, + DoUpdates: clause.Assignments(map[string]interface{}{ + "amount": quota.Amount, + "utime": now, + }), + }).Create(&record) + if result.Error != nil { + return result.Error + } + + return tx.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "key"}}, + DoUpdates: clause.Assignments(map[string]any{ + "amount": quota.Amount, + "utime": now, + }), + }).Create("a).Error + }) } func (dao *QuotaDao) GetQuotaByUid(ctx context.Context, uid int64) (Quota, error) { var quota Quota - err := dao.db.WithContext(ctx).Where("uid = ? and end_time >= ?", uid).First("a).Error + err := dao.db.WithContext(ctx). + Where("uid = ? and end_time >= ?", uid). + First("a).Error if err != nil { return Quota{}, err } @@ -87,82 +155,16 @@ func (dao *QuotaDao) GetQuotaByUid(ctx context.Context, uid int64) (Quota, error func (dao *QuotaDao) GetTempQuotaByUidAndTime(ctx context.Context, uid int64) ([]TempQuota, error) { now := time.Now().Unix() var quota []TempQuota - err := dao.db.WithContext(ctx).Where("uid = ? and end_time >= ?", uid, now).Find("a).Error + err := dao.db.WithContext(ctx). + Where("uid = ? and end_time >= ?", uid, now). + Order("end_time ASC"). + Find("a).Error if err != nil { return nil, err } return quota, nil } -// Deduct 扣减 -func (dao *QuotaDao) Deduct(ctx context.Context, uid int64, amount int64) error { - now := time.Now().Unix() - // 首先扣除temp 的, 然后扣除 quota的 - err := dao.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - // 可能存在多个时间段 - var tempQuotas []TempQuota - err := tx.Where("uid = ? AND end_time >= ?", uid, now). - Order("end_time ASC"). - Find(&tempQuotas).Error - if err != nil { - return err - } - - for i := range tempQuotas { - tq := tempQuotas[i] - if amount <= 0 { - break - } - deduct := int64(0) - // 如果大于需要直接扣, 小于就直接扣完 - if tq.Amount >= amount { - deduct = amount - amount = 0 - } else { - deduct = tq.Amount - amount -= deduct - } - tq.Amount -= deduct - // 然后更新 - err = tx.Model(&TempQuota{}).Where("uid = ?", uid).Updates(map[string]any{ - "amount": tq.Amount, - "utime": now, - }).Error - if err != nil { - return err - } - } - - var quota Quota - err = tx.Where("uid = ?", uid).First("a).Error - if err != nil { - return err - } - - // 临时额度扣减完毕 - if amount <= 0 { - return nil - } - - // 扣完了发现还不够扣的, 从 quota 中扣 - if quota.Amount < amount { - return errs.ErrNoAmount - } - quota.Amount -= amount - quota.Utime = now - quota.LastClearTime = now - - // 更新 - err = tx.Model(&Quota{}).Where("uid = ?", uid).Updates(map[string]any{ - "amount": quota.Amount, - "utime": quota.Utime, - "last_clear_time": quota.LastClearTime, - }).Error - if err != nil { - return errs.ErrNoAmount - } - - return nil - }) - return err +func InitQuotaTable(db *gorm.DB) error { + return db.AutoMigrate(&Quota{}, &TempQuota{}, &QuotaRecord{}) } diff --git a/internal/repository/quota.go b/internal/repository/quota.go index 496c77b..b0df5b3 100644 --- a/internal/repository/quota.go +++ b/internal/repository/quota.go @@ -30,16 +30,12 @@ func NewQuotaRepo(dao *dao.QuotaDao) *QuotaRepo { return &QuotaRepo{dao: dao} } -func (q *QuotaRepo) CreateQuota(ctx context.Context, quota domain.Quota) error { - return q.dao.Create(ctx, dao.Quota{UID: quota.Uid, Amount: quota.Amount}) +func (q *QuotaRepo) SaveQuota(ctx context.Context, quota domain.Quota) error { + return q.dao.SaveQuota(ctx, dao.Quota{UID: quota.Uid, Amount: quota.Amount, Key: quota.Key}) } -func (q *QuotaRepo) UpdateQuota(ctx context.Context, quota domain.Quota) error { - return q.dao.UpdateQuota(ctx, dao.Quota{UID: quota.Uid, Amount: quota.Amount}) -} - -func (q *QuotaRepo) CreateTempQuota(ctx context.Context, quota domain.TempQuota) error { - return q.dao.CreateTempQuota(ctx, dao.TempQuota{Amount: quota.Amount, StartTime: quota.StartTime, EndTime: quota.EndTime}) +func (q *QuotaRepo) SaveTempQuota(ctx context.Context, quota domain.TempQuota) error { + return q.dao.SaveTempQuota(ctx, dao.TempQuota{Amount: quota.Amount, StartTime: quota.StartTime, EndTime: quota.EndTime, Key: quota.Key}) } func (q *QuotaRepo) GetQuota(ctx context.Context, uid int64) (domain.Quota, error) { @@ -58,10 +54,6 @@ func (q *QuotaRepo) GetTempQuota(ctx context.Context, uid int64) ([]domain.TempQ return q.toDomainTempQuota(tempQuotaList), nil } -func (q *QuotaRepo) Deduct(ctx context.Context, uid int64, amount int64) error { - return q.dao.Deduct(ctx, uid, amount) -} - func (q *QuotaRepo) toDomainTempQuota(tmpQuotaList []dao.TempQuota) []domain.TempQuota { return slice.Map[dao.TempQuota, domain.TempQuota](tmpQuotaList, func(idx int, src dao.TempQuota) domain.TempQuota { return domain.TempQuota{ diff --git a/internal/service/quota.go b/internal/service/quota.go index b69d21b..e176b7a 100644 --- a/internal/service/quota.go +++ b/internal/service/quota.go @@ -16,7 +16,9 @@ package service import ( "context" + "fmt" + "github.com/ecodeclub/ai-gateway-go/errs" "github.com/ecodeclub/ai-gateway-go/internal/domain" "github.com/ecodeclub/ai-gateway-go/internal/repository" ) @@ -29,16 +31,12 @@ func NewQuotaService(repo *repository.QuotaRepo) *QuotaService { return &QuotaService{repo: repo} } -func (q *QuotaService) CreateQuota(ctx context.Context, quota domain.Quota) error { - return q.repo.CreateQuota(ctx, quota) +func (q *QuotaService) SaveQuota(ctx context.Context, quota domain.Quota) error { + return q.repo.SaveQuota(ctx, quota) } -func (q *QuotaService) CreateTempQuota(ctx context.Context, quota domain.TempQuota) error { - return q.repo.CreateTempQuota(ctx, quota) -} - -func (q *QuotaService) UpdateQuota(ctx context.Context, quota domain.Quota) error { - return q.repo.UpdateQuota(ctx, quota) +func (q *QuotaService) SaveTempQuota(ctx context.Context, quota domain.TempQuota) error { + return q.repo.SaveTempQuota(ctx, quota) } func (q *QuotaService) GetTempQuota(ctx context.Context, uid int64) ([]domain.TempQuota, error) { @@ -49,6 +47,54 @@ func (q *QuotaService) GetQuota(ctx context.Context, uid int64) (domain.Quota, e return q.repo.GetQuota(ctx, uid) } -func (q *QuotaService) Deduct(ctx context.Context, uid int64, amount int64) error { - return q.repo.Deduct(ctx, uid, amount) +func (q *QuotaService) Deduct(ctx context.Context, uid int64, amount int64, key string) error { + key1 := fmt.Sprintf("temp_%s", key) + key2 := fmt.Sprintf("quota_%s", key) + tempQuotaList, err := q.repo.GetTempQuota(ctx, uid) + if err != nil { + return err + } + // 1. 优先扣减临时表 + for i := range tempQuotaList { + tq := tempQuotaList[i] + if amount <= 0 { + break + } + deduct := int64(0) + if tq.Amount >= amount { + deduct = amount + amount = 0 + } else { + deduct = tq.Amount + amount -= deduct + } + tq.Amount -= deduct + err = q.SaveTempQuota(ctx, domain.TempQuota{ + Uid: uid, + Amount: tq.Amount, + Key: key1, + }) + + if err != nil { + return err + } + } + // 扣减完成了 + if amount <= 0 { + return nil + } + + quota, err := q.GetQuota(ctx, uid) + if err != nil { + return err + } + if quota.Amount < amount { + return errs.ErrNoAmount + } + + return q.SaveQuota(ctx, domain.Quota{ + Uid: uid, + Amount: amount, + Key: key2, + }) } diff --git a/internal/test/quota_test.go b/internal/test/quota_test.go new file mode 100644 index 0000000..1fe3000 --- /dev/null +++ b/internal/test/quota_test.go @@ -0,0 +1,130 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + + "github.com/ecodeclub/ai-gateway-go/internal/repository" + "github.com/ecodeclub/ai-gateway-go/internal/repository/dao" + "github.com/ecodeclub/ai-gateway-go/internal/service" + "github.com/ecodeclub/ai-gateway-go/internal/test/mocks" + "github.com/ecodeclub/ai-gateway-go/internal/web" + "github.com/ecodeclub/ginx/session" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/yumosx/got/pkg/config" + "go.uber.org/mock/gomock" + "gorm.io/gorm" +) + +type QuotaSuite struct { + suite.Suite + db *gorm.DB + server *gin.Engine +} + +func TestQuota(t *testing.T) { + suite.Run(t, &QuotaSuite{}) +} + +func (q *QuotaSuite) SetupSuite() { + dbConfig := config.NewConfig( + config.WithDBName("ai_gateway_platform"), + config.WithUserName("root"), + config.WithPassword("root"), + config.WithHost("127.0.0.1"), + config.WithPort("13306"), + ) + db, err := config.NewDB(dbConfig) + require.NoError(q.T(), err) + err = dao.InitQuotaTable(db) + require.NoError(q.T(), err) + q.db = db + + d := dao.NewQuotaDao(db) + repo := repository.NewQuotaRepo(d) + svc := service.NewQuotaService(repo) + handler := web.NewQuotaHandler(svc) + server := gin.Default() + handler.PrivateRoutes(server) + q.server = server +} + +func (q *QuotaSuite) TearDownTest() { + err := q.db.Exec("TRUNCATE TABLE quotas").Error + require.NoError(q.T(), err) + err = q.db.Exec("TRUNCATE TABLE quota_records").Error + require.NoError(q.T(), err) +} + +func (q *QuotaSuite) TestQuotaSave() { + t := q.T() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + testcases := []struct { + name string + before func() + after func() + reqBody string + }{ + { + name: "创建一个 quota", + before: func() { + sess := mocks.NewMockSession(ctrl) + sess.EXPECT().Claims().Return(session.Claims{ + Uid: 1, + }).AnyTimes() + provider := mocks.NewMockProvider(ctrl) + session.SetDefaultProvider(provider) + provider.EXPECT().Get(gomock.Any()).Return(sess, nil) + }, + after: func() { + var quota dao.Quota + err := q.db.Where("id = ?", 1).First("a).Error + require.NoError(t, err) + assert.Equal(t, int64(100000), quota.Amount) + + var record dao.QuotaRecord + err = q.db.Where("id = ?", 1).First(&record).Error + require.NoError(t, err) + assert.Equal(t, "23911", record.Key) + }, + reqBody: `{"amount": 100000, "key": "23911"}`, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + tc.before() + req, err := http.NewRequest(http.MethodPost, "/quota/save", bytes.NewBuffer([]byte(tc.reqBody))) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + q.server.ServeHTTP(resp, req) + + assert.Equal(t, http.StatusOK, resp.Code) + + tc.after() + }) + } +} diff --git a/internal/web/quota.go b/internal/web/quota.go index d8e01ef..5add0fc 100644 --- a/internal/web/quota.go +++ b/internal/web/quota.go @@ -33,20 +33,23 @@ func NewQuotaHandler(svc *service.QuotaService) *QuotaHandler { return &QuotaHandler{svc: svc} } -func (q *QuotaHandler) PrivateRoutes(_ *gin.Engine) {} +func (q *QuotaHandler) PublicRoutes(_ *gin.Engine) {} -func (q *QuotaHandler) PublicRoutes(server *gin.Engine) { +func (q *QuotaHandler) PrivateRoutes(server *gin.Engine) { group := server.Group("/quota") - group.POST("/create", ginx.BS(q.CreateTempQuota)) - group.POST("/create_tmp", ginx.BS(q.CreateTempQuota)) - group.POST("/deduct", ginx.BS(q.Deduct)) + group.POST("/save", ginx.BS(q.SaveQuota)) group.POST("/get", ginx.S(q.GetQuota)) - group.POST("/get_tmp", ginx.S(q.GetTempQuota)) + + tmp := server.Group("/tmp") + tmp.POST("/save", ginx.BS(q.SaveTempQuota)) + tmp.POST("/get", ginx.S(q.GetTempQuota)) + + server.POST("/deduct", ginx.BS(q.Deduct)) } -func (q *QuotaHandler) CreateQuota(ctx *ginx.Context, req QuotaRequest, sess session.Session) (ginx.Result, error) { +func (q *QuotaHandler) SaveQuota(ctx *ginx.Context, req QuotaRequest, sess session.Session) (ginx.Result, error) { uid := sess.Claims().Uid - err := q.svc.CreateQuota(ctx, domain.Quota{Amount: req.Amount, Uid: uid}) + err := q.svc.SaveQuota(ctx, domain.Quota{Amount: req.Amount, Uid: uid, Key: req.Key}) if err != nil { return systemErrorResult, nil } @@ -55,7 +58,7 @@ func (q *QuotaHandler) CreateQuota(ctx *ginx.Context, req QuotaRequest, sess ses }, nil } -func (q *QuotaHandler) CreateTempQuota(ctx *ginx.Context, req QuotaRequest, sess session.Session) (ginx.Result, error) { +func (q *QuotaHandler) SaveTempQuota(ctx *ginx.Context, req QuotaRequest, sess session.Session) (ginx.Result, error) { uid := sess.Claims().Uid if req.StartTime == "" || req.EndTime == "" { @@ -65,7 +68,7 @@ func (q *QuotaHandler) CreateTempQuota(ctx *ginx.Context, req QuotaRequest, sess start, _ := q.toTimestamp(req.StartTime) end, _ := q.toTimestamp(req.EndTime) - err := q.svc.CreateTempQuota(ctx, domain.TempQuota{Amount: req.Amount, Uid: uid, StartTime: start, EndTime: end}) + err := q.svc.SaveTempQuota(ctx, domain.TempQuota{Amount: req.Amount, Uid: uid, StartTime: start, EndTime: end}) if err != nil { return systemErrorResult, nil } @@ -99,22 +102,9 @@ func (q *QuotaHandler) GetTempQuota(ctx *ginx.Context, sees session.Session) (gi }, nil } -func (q *QuotaHandler) UpdateQuota(ctx *ginx.Context, req QuotaRequest, sees session.Session) (ginx.Result, error) { - uid := sees.Claims().Uid - - err := q.svc.UpdateQuota(ctx, domain.Quota{Uid: uid, Amount: req.Amount}) - if err != nil { - return systemErrorResult, nil - } - - return ginx.Result{ - Msg: "OK", - }, nil -} - func (q *QuotaHandler) Deduct(ctx *ginx.Context, req QuotaRequest, sees session.Session) (ginx.Result, error) { uid := sees.Claims().Uid - err := q.svc.Deduct(ctx, uid, req.Amount) + err := q.svc.Deduct(ctx, uid, req.Amount, req.Key) if err != nil { return systemErrorResult, nil } @@ -142,12 +132,14 @@ func (q *QuotaHandler) toQuotaResponse(tempQuotaList []domain.TempQuota) []Quota type QuotaRequest struct { Amount int64 `json:"amount,omitempty"` + Key string `json:"key,omitempty"` StartTime string `json:"start_time,omitempty"` EndTime string `json:"end_time,omitempty"` } type QuotaResponse struct { Amount int64 `json:"amount,omitempty"` + Key string `json:"key"` StartTime string `json:"start_time,omitempty"` EndTime string `json:"end_time,omitempty"` } From 77370966c4009d3f23cdab22c9975f89576396e7 Mon Sep 17 00:00:00 2001 From: yumosx Date: Tue, 8 Jul 2025 23:00:36 +0800 Subject: [PATCH 08/24] update deduct method --- errs/errs.go | 2 +- internal/domain/quota.go | 3 +- internal/repository/dao/quota.go | 130 +++++++++++++++++++++++-------- internal/repository/quota.go | 4 + internal/service/quota.go | 52 +------------ 5 files changed, 104 insertions(+), 87 deletions(-) diff --git a/errs/errs.go b/errs/errs.go index 7d925d2..dd02860 100644 --- a/errs/errs.go +++ b/errs/errs.go @@ -19,6 +19,6 @@ import ( ) var ( - ErrNoAmount = errors.New("余额不足") + DeductAmount = errors.New("扣减失败") ErrBizConfigNotFound = errors.New("biz config not found") ) diff --git a/internal/domain/quota.go b/internal/domain/quota.go index 8ccc932..1953893 100644 --- a/internal/domain/quota.go +++ b/internal/domain/quota.go @@ -29,5 +29,4 @@ type TempQuota struct { Uid int64 } -type Record struct { -} +type Record struct{} diff --git a/internal/repository/dao/quota.go b/internal/repository/dao/quota.go index 5ed5937..d201383 100644 --- a/internal/repository/dao/quota.go +++ b/internal/repository/dao/quota.go @@ -18,6 +18,7 @@ import ( "context" "time" + "github.com/ecodeclub/ai-gateway-go/errs" "gorm.io/gorm" "gorm.io/gorm/clause" ) @@ -74,38 +75,17 @@ func NewQuotaDao(db *gorm.DB) *QuotaDao { func (dao *QuotaDao) SaveTempQuota(ctx context.Context, quota TempQuota) error { now := time.Now().Unix() - return dao.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - record := QuotaRecord{ - Key: quota.Key, - Uid: quota.UID, - Amount: quota.Amount, - Ctime: now, - Utime: now, - } - - result := tx.Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "key"}}, - DoUpdates: clause.Assignments(map[string]interface{}{ - "amount": quota.Amount, - "utime": now, - }), - }).Create(&record) - if result.Error != nil { - return result.Error - } - quota.Ctime = now - quota.Utime = now - - return tx.Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "key"}}, - DoUpdates: clause.Assignments(map[string]interface{}{ - "amount": quota.Amount, - "start_time": quota.StartTime, - "end_time": quota.EndTime, - "utime": now, - }), - }).Create("a).Error - }) + quota.Ctime = now + quota.Utime = now + return dao.db.WithContext(ctx).Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "key"}}, + DoUpdates: clause.Assignments(map[string]any{ + "amount": quota.Amount, + "start_time": quota.StartTime, + "end_time": quota.EndTime, + "utime": now, + }), + }).Create("a).Error } func (dao *QuotaDao) SaveQuota(ctx context.Context, quota Quota) error { @@ -122,11 +102,12 @@ func (dao *QuotaDao) SaveQuota(ctx context.Context, quota Quota) error { } result := tx.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "key"}}, - DoUpdates: clause.Assignments(map[string]interface{}{ + DoUpdates: clause.Assignments(map[string]any{ "amount": quota.Amount, "utime": now, }), }).Create(&record) + if result.Error != nil { return result.Error } @@ -165,6 +146,89 @@ func (dao *QuotaDao) GetTempQuotaByUidAndTime(ctx context.Context, uid int64) ([ return quota, nil } +func (dao *QuotaDao) Deduct(ctx context.Context, uid int64, amount int64, key string) error { + return dao.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + now := time.Now().Unix() + record := QuotaRecord{ + Key: key, + Uid: uid, + Amount: amount, + Ctime: now, + Utime: now, + } + result := tx.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "key"}}, + DoNothing: true, + }).Create(&record) + + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return nil + } + // 执行扣减程序 + return dao.deduct(tx, uid, amount, now) + }) +} + +func (dao *QuotaDao) deduct(tx *gorm.DB, uid int64, amount int64, now int64) error { + var tempQuotas []TempQuota + err := tx.Where("uid = ? AND end_time >= ? AND amount > 0", uid, now). + Order("end_time ASC"). + Find(&tempQuotas).Error + if err != nil { + return err + } + + remain := amount + + // 先扣临时额度 + for i := range tempQuotas { + if remain <= 0 { + break + } + tq := &tempQuotas[i] + deduct := tq.Amount + if deduct > remain { + deduct = remain + } + // 原子扣减,防止并发下超扣 + update := tx.Model(&TempQuota{}). + Where("id = ? AND amount >= ?", tq.ID, deduct). + Updates(map[string]any{ + "amount": gorm.Expr("amount - ?", deduct), + "utime": now, + }) + if update.Error != nil { + return update.Error + } + if update.RowsAffected == 0 { + continue // 这条被其他并发扣完,跳过 + } + remain -= deduct + } + + // 如果还有剩余,从主额度扣 + if remain > 0 { + result := tx.Model(&Quota{}). + Where("uid = ? AND amount >= ?", uid, remain). + Updates(map[string]any{ + "amount": gorm.Expr("amount - ?", remain), + "utime": now, + "last_clear_time": now, + }) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return errs.DeductAmount + } + } + + return nil +} + func InitQuotaTable(db *gorm.DB) error { return db.AutoMigrate(&Quota{}, &TempQuota{}, &QuotaRecord{}) } diff --git a/internal/repository/quota.go b/internal/repository/quota.go index b0df5b3..51800b6 100644 --- a/internal/repository/quota.go +++ b/internal/repository/quota.go @@ -54,6 +54,10 @@ func (q *QuotaRepo) GetTempQuota(ctx context.Context, uid int64) ([]domain.TempQ return q.toDomainTempQuota(tempQuotaList), nil } +func (q *QuotaRepo) Deduct(ctx context.Context, uid int64, amount int64, key string) error { + return q.dao.Deduct(ctx, uid, amount, key) +} + func (q *QuotaRepo) toDomainTempQuota(tmpQuotaList []dao.TempQuota) []domain.TempQuota { return slice.Map[dao.TempQuota, domain.TempQuota](tmpQuotaList, func(idx int, src dao.TempQuota) domain.TempQuota { return domain.TempQuota{ diff --git a/internal/service/quota.go b/internal/service/quota.go index e176b7a..e429551 100644 --- a/internal/service/quota.go +++ b/internal/service/quota.go @@ -16,9 +16,7 @@ package service import ( "context" - "fmt" - "github.com/ecodeclub/ai-gateway-go/errs" "github.com/ecodeclub/ai-gateway-go/internal/domain" "github.com/ecodeclub/ai-gateway-go/internal/repository" ) @@ -48,53 +46,5 @@ func (q *QuotaService) GetQuota(ctx context.Context, uid int64) (domain.Quota, e } func (q *QuotaService) Deduct(ctx context.Context, uid int64, amount int64, key string) error { - key1 := fmt.Sprintf("temp_%s", key) - key2 := fmt.Sprintf("quota_%s", key) - tempQuotaList, err := q.repo.GetTempQuota(ctx, uid) - if err != nil { - return err - } - // 1. 优先扣减临时表 - for i := range tempQuotaList { - tq := tempQuotaList[i] - if amount <= 0 { - break - } - deduct := int64(0) - if tq.Amount >= amount { - deduct = amount - amount = 0 - } else { - deduct = tq.Amount - amount -= deduct - } - tq.Amount -= deduct - err = q.SaveTempQuota(ctx, domain.TempQuota{ - Uid: uid, - Amount: tq.Amount, - Key: key1, - }) - - if err != nil { - return err - } - } - // 扣减完成了 - if amount <= 0 { - return nil - } - - quota, err := q.GetQuota(ctx, uid) - if err != nil { - return err - } - if quota.Amount < amount { - return errs.ErrNoAmount - } - - return q.SaveQuota(ctx, domain.Quota{ - Uid: uid, - Amount: amount, - Key: key2, - }) + return q.repo.Deduct(ctx, uid, amount, key) } From ef5c009eb1a76dff90ec5f06fa3ecfd4c66209f3 Mon Sep 17 00:00:00 2001 From: yumosx Date: Wed, 9 Jul 2025 23:57:38 +0800 Subject: [PATCH 09/24] update and add tests --- errs/errs.go | 5 +- internal/repository/dao/quota.go | 48 ++++++------ internal/test/quota_test.go | 125 +++++++++++++++++++++++++++++++ 3 files changed, 151 insertions(+), 27 deletions(-) diff --git a/errs/errs.go b/errs/errs.go index dd02860..926e1f1 100644 --- a/errs/errs.go +++ b/errs/errs.go @@ -19,6 +19,7 @@ import ( ) var ( - DeductAmount = errors.New("扣减失败") - ErrBizConfigNotFound = errors.New("biz config not found") + ErrDeductAmountFailed = errors.New("deduct amount failed") + ErrInsufficientQuota = errors.New("insufficient quota") + ErrBizConfigNotFound = errors.New("biz config not found") ) diff --git a/internal/repository/dao/quota.go b/internal/repository/dao/quota.go index d201383..b90ac17 100644 --- a/internal/repository/dao/quota.go +++ b/internal/repository/dao/quota.go @@ -112,6 +112,10 @@ func (dao *QuotaDao) SaveQuota(ctx context.Context, quota Quota) error { return result.Error } + if result.RowsAffected == 0 { + return nil + } + return tx.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "key"}}, DoUpdates: clause.Assignments(map[string]any{ @@ -184,16 +188,9 @@ func (dao *QuotaDao) deduct(tx *gorm.DB, uid int64, amount int64, now int64) err remain := amount // 先扣临时额度 - for i := range tempQuotas { - if remain <= 0 { - break - } - tq := &tempQuotas[i] - deduct := tq.Amount - if deduct > remain { - deduct = remain - } - // 原子扣减,防止并发下超扣 + for _, tq := range tempQuotas { + deduct := min(tq.Amount, remain) + update := tx.Model(&TempQuota{}). Where("id = ? AND amount >= ?", tq.ID, deduct). Updates(map[string]any{ @@ -204,26 +201,27 @@ func (dao *QuotaDao) deduct(tx *gorm.DB, uid int64, amount int64, now int64) err return update.Error } if update.RowsAffected == 0 { - continue // 这条被其他并发扣完,跳过 + continue } + remain -= deduct + if remain <= 0 { + return nil + } } // 如果还有剩余,从主额度扣 - if remain > 0 { - result := tx.Model(&Quota{}). - Where("uid = ? AND amount >= ?", uid, remain). - Updates(map[string]any{ - "amount": gorm.Expr("amount - ?", remain), - "utime": now, - "last_clear_time": now, - }) - if result.Error != nil { - return result.Error - } - if result.RowsAffected == 0 { - return errs.DeductAmount - } + result := tx.Model(&Quota{}). + Where("uid = ? AND amount >= ?", uid, remain). + Updates(map[string]any{ + "amount": gorm.Expr("amount - ?", remain), + "utime": now, + }) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return errs.ErrDeductAmountFailed } return nil diff --git a/internal/test/quota_test.go b/internal/test/quota_test.go index 1fe3000..db584f1 100644 --- a/internal/test/quota_test.go +++ b/internal/test/quota_test.go @@ -19,6 +19,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/ecodeclub/ai-gateway-go/internal/repository" "github.com/ecodeclub/ai-gateway-go/internal/repository/dao" @@ -73,6 +74,8 @@ func (q *QuotaSuite) TearDownTest() { require.NoError(q.T(), err) err = q.db.Exec("TRUNCATE TABLE quota_records").Error require.NoError(q.T(), err) + err = q.db.Exec("TRUNCATE TABLE temp_quotas").Error + require.NoError(q.T(), err) } func (q *QuotaSuite) TestQuotaSave() { @@ -128,3 +131,125 @@ func (q *QuotaSuite) TestQuotaSave() { }) } } + +func (q *QuotaSuite) TestSaveTempQuota() { + t := q.T() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + testcases := []struct { + name string + reqBody string + before func() + after func() + }{ + { + name: "save temp", + reqBody: `{"amount": 100000, "key": "23911", "start_time": "123", "end_time": "456"}`, + before: func() { + sess := mocks.NewMockSession(ctrl) + sess.EXPECT().Claims().Return(session.Claims{ + Uid: 1, + }).AnyTimes() + provider := mocks.NewMockProvider(ctrl) + session.SetDefaultProvider(provider) + provider.EXPECT().Get(gomock.Any()).Return(sess, nil) + }, + after: func() { + var quota dao.TempQuota + err := q.db.Where("id = ?", 1).First("a).Error + require.NoError(t, err) + assert.Equal(t, int64(100000), quota.Amount) + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + tc.before() + req, err := http.NewRequest(http.MethodPost, "/tmp/save", bytes.NewBuffer([]byte(tc.reqBody))) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + q.server.ServeHTTP(resp, req) + + assert.Equal(t, http.StatusOK, resp.Code) + + tc.after() + }) + } +} + +func (q *QuotaSuite) TestDeduct() { + t := q.T() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + testcases := []struct { + name string + reqBody string + before func() + after func() + }{ + { + name: "deduct quota", + reqBody: `{"amount": 10, "key": "23911"}`, + before: func() { + sess := mocks.NewMockSession(ctrl) + sess.EXPECT().Claims().Return(session.Claims{ + Uid: 1, + }).AnyTimes() + provider := mocks.NewMockProvider(ctrl) + session.SetDefaultProvider(provider) + provider.EXPECT().Get(gomock.Any()).Return(sess, nil) + + quota := dao.Quota{Amount: 20, Key: "23911", UID: 1} + q.db.Create("a) + }, + after: func() { + var quota dao.Quota + err := q.db.Where("id = ?", 1).First("a).Error + require.NoError(t, err) + assert.Equal(t, int64(10), quota.Amount) + }, + }, + { + name: "deduct temp quota", + reqBody: `{"amount": 10, "key": "23921"}`, + before: func() { + sess := mocks.NewMockSession(ctrl) + sess.EXPECT().Claims().Return(session.Claims{ + Uid: 1, + }).AnyTimes() + provider := mocks.NewMockProvider(ctrl) + session.SetDefaultProvider(provider) + provider.EXPECT().Get(gomock.Any()).Return(sess, nil) + + quota := dao.TempQuota{Amount: 20, Key: "23921", UID: 1, StartTime: time.Now().Unix(), EndTime: time.Now().Add(24 * time.Hour).Unix()} + err := q.db.Create("a).Error + require.NoError(t, err) + }, + after: func() { + var quota dao.TempQuota + err := q.db.Where("id = ?", 1).First("a).Error + require.NoError(t, err) + assert.Equal(t, int64(10), quota.Amount) + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + tc.before() + req, err := http.NewRequest(http.MethodPost, "/deduct", bytes.NewBuffer([]byte(tc.reqBody))) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + q.server.ServeHTTP(resp, req) + + assert.Equal(t, http.StatusOK, resp.Code) + tc.after() + }) + } +} From f5f967522355239288ffd98af2591505cdb5c6e0 Mon Sep 17 00:00:00 2001 From: yumosx Date: Thu, 10 Jul 2025 00:05:24 +0800 Subject: [PATCH 10/24] fix lint --- .github/workflows/golangci-lint.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 23b5dcb..7bbe491 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -37,4 +37,6 @@ jobs: uses: golangci/golangci-lint-action@v7 with: version: v2.0 - only-new-issues: true \ No newline at end of file + only-new-issues: true + - name: Generate Protobuf + run: buf generate \ No newline at end of file From de2286b9bf963a3ad4536a708ad0d3c8153e2eaa Mon Sep 17 00:00:00 2001 From: yumosx Date: Thu, 10 Jul 2025 00:07:14 +0800 Subject: [PATCH 11/24] fix package --- api/proto/ai/v1/ai.proto | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/proto/ai/v1/ai.proto b/api/proto/ai/v1/ai.proto index 2117af1..7028425 100644 --- a/api/proto/ai/v1/ai.proto +++ b/api/proto/ai/v1/ai.proto @@ -14,7 +14,7 @@ syntax = "proto3"; package ai.v1; -option go_package = "ai/v1;aiv1"; +option go_package = "github.com/ecodeclub/ai-gateway-go/api/gen/ai/v1;ai_v1"; service AIService { rpc Chat(Message) returns (ChatResponse); From a96b0fb8f977716a28bae42a921f88e684d52f60 Mon Sep 17 00:00:00 2001 From: yumosx Date: Thu, 10 Jul 2025 00:36:13 +0800 Subject: [PATCH 12/24] add the pb.go --- api/gen/ai/v1/ai.pb.go | 743 ++++++++++++++++++++++++++++++++++++ api/gen/ai/v1/ai_grpc.pb.go | 435 +++++++++++++++++++++ api/proto/ai/v1/ai.proto | 2 +- api/proto/buf.yaml | 36 -- 4 files changed, 1179 insertions(+), 37 deletions(-) create mode 100644 api/gen/ai/v1/ai.pb.go create mode 100644 api/gen/ai/v1/ai_grpc.pb.go delete mode 100644 api/proto/buf.yaml diff --git a/api/gen/ai/v1/ai.pb.go b/api/gen/ai/v1/ai.pb.go new file mode 100644 index 0000000..9d70781 --- /dev/null +++ b/api/gen/ai/v1/ai.pb.go @@ -0,0 +1,743 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.6 +// protoc (unknown) +// source: ai.proto + +package aiv1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Role int32 + +const ( + Role_UNKNOWN Role = 0 + Role_USER Role = 1 + Role_ASSISTANT Role = 2 + Role_SYSTEM Role = 3 + Role_TOOL Role = 4 +) + +// Enum value maps for Role. +var ( + Role_name = map[int32]string{ + 0: "UNKNOWN", + 1: "USER", + 2: "ASSISTANT", + 3: "SYSTEM", + 4: "TOOL", + } + Role_value = map[string]int32{ + "UNKNOWN": 0, + "USER": 1, + "ASSISTANT": 2, + "SYSTEM": 3, + "TOOL": 4, + } +) + +func (x Role) Enum() *Role { + p := new(Role) + *p = x + return p +} + +func (x Role) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (Role) Descriptor() protoreflect.EnumDescriptor { + return file_ai_proto_enumTypes[0].Descriptor() +} + +func (Role) Type() protoreflect.EnumType { + return &file_ai_proto_enumTypes[0] +} + +func (x Role) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use Role.Descriptor instead. +func (Role) EnumDescriptor() ([]byte, []int) { + return file_ai_proto_rawDescGZIP(), []int{0} +} + +type StreamEvent struct { + state protoimpl.MessageState `protogen:"open.v1"` + Final bool `protobuf:"varint,1,opt,name=final,proto3" json:"final,omitempty"` + ReasoningContent string `protobuf:"bytes,2,opt,name=reasoningContent,proto3" json:"reasoningContent,omitempty"` + Content string `protobuf:"bytes,3,opt,name=content,proto3" json:"content,omitempty"` + Err string `protobuf:"bytes,4,opt,name=err,proto3" json:"err,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StreamEvent) Reset() { + *x = StreamEvent{} + mi := &file_ai_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StreamEvent) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StreamEvent) ProtoMessage() {} + +func (x *StreamEvent) ProtoReflect() protoreflect.Message { + mi := &file_ai_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StreamEvent.ProtoReflect.Descriptor instead. +func (*StreamEvent) Descriptor() ([]byte, []int) { + return file_ai_proto_rawDescGZIP(), []int{0} +} + +func (x *StreamEvent) GetFinal() bool { + if x != nil { + return x.Final + } + return false +} + +func (x *StreamEvent) GetReasoningContent() string { + if x != nil { + return x.ReasoningContent + } + return "" +} + +func (x *StreamEvent) GetContent() string { + if x != nil { + return x.Content + } + return "" +} + +func (x *StreamEvent) GetErr() string { + if x != nil { + return x.Err + } + return "" +} + +type Conversation struct { + state protoimpl.MessageState `protogen:"open.v1"` + Sn string `protobuf:"bytes,1,opt,name=sn,proto3" json:"sn,omitempty"` + Uid string `protobuf:"bytes,2,opt,name=uid,proto3" json:"uid,omitempty"` + Title string `protobuf:"bytes,3,opt,name=title,proto3" json:"title,omitempty"` + Message []*Message `protobuf:"bytes,4,rep,name=message,proto3" json:"message,omitempty"` + Ctime string `protobuf:"bytes,5,opt,name=ctime,proto3" json:"ctime,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Conversation) Reset() { + *x = Conversation{} + mi := &file_ai_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Conversation) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Conversation) ProtoMessage() {} + +func (x *Conversation) ProtoReflect() protoreflect.Message { + mi := &file_ai_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Conversation.ProtoReflect.Descriptor instead. +func (*Conversation) Descriptor() ([]byte, []int) { + return file_ai_proto_rawDescGZIP(), []int{1} +} + +func (x *Conversation) GetSn() string { + if x != nil { + return x.Sn + } + return "" +} + +func (x *Conversation) GetUid() string { + if x != nil { + return x.Uid + } + return "" +} + +func (x *Conversation) GetTitle() string { + if x != nil { + return x.Title + } + return "" +} + +func (x *Conversation) GetMessage() []*Message { + if x != nil { + return x.Message + } + return nil +} + +func (x *Conversation) GetCtime() string { + if x != nil { + return x.Ctime + } + return "" +} + +type ListReq struct { + state protoimpl.MessageState `protogen:"open.v1"` + Uid string `protobuf:"bytes,1,opt,name=uid,proto3" json:"uid,omitempty"` + Offset int64 `protobuf:"varint,2,opt,name=offset,proto3" json:"offset,omitempty"` + Limit int64 `protobuf:"varint,3,opt,name=limit,proto3" json:"limit,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ListReq) Reset() { + *x = ListReq{} + mi := &file_ai_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ListReq) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListReq) ProtoMessage() {} + +func (x *ListReq) ProtoReflect() protoreflect.Message { + mi := &file_ai_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListReq.ProtoReflect.Descriptor instead. +func (*ListReq) Descriptor() ([]byte, []int) { + return file_ai_proto_rawDescGZIP(), []int{2} +} + +func (x *ListReq) GetUid() string { + if x != nil { + return x.Uid + } + return "" +} + +func (x *ListReq) GetOffset() int64 { + if x != nil { + return x.Offset + } + return 0 +} + +func (x *ListReq) GetLimit() int64 { + if x != nil { + return x.Limit + } + return 0 +} + +type ListResp struct { + state protoimpl.MessageState `protogen:"open.v1"` + Conversations []*Conversation `protobuf:"bytes,1,rep,name=conversations,proto3" json:"conversations,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ListResp) Reset() { + *x = ListResp{} + mi := &file_ai_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ListResp) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListResp) ProtoMessage() {} + +func (x *ListResp) ProtoReflect() protoreflect.Message { + mi := &file_ai_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListResp.ProtoReflect.Descriptor instead. +func (*ListResp) Descriptor() ([]byte, []int) { + return file_ai_proto_rawDescGZIP(), []int{3} +} + +func (x *ListResp) GetConversations() []*Conversation { + if x != nil { + return x.Conversations + } + return nil +} + +type LLMRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Sn string `protobuf:"bytes,1,opt,name=sn,proto3" json:"sn,omitempty"` + Message []*Message `protobuf:"bytes,2,rep,name=message,proto3" json:"message,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *LLMRequest) Reset() { + *x = LLMRequest{} + mi := &file_ai_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *LLMRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*LLMRequest) ProtoMessage() {} + +func (x *LLMRequest) ProtoReflect() protoreflect.Message { + mi := &file_ai_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use LLMRequest.ProtoReflect.Descriptor instead. +func (*LLMRequest) Descriptor() ([]byte, []int) { + return file_ai_proto_rawDescGZIP(), []int{4} +} + +func (x *LLMRequest) GetSn() string { + if x != nil { + return x.Sn + } + return "" +} + +func (x *LLMRequest) GetMessage() []*Message { + if x != nil { + return x.Message + } + return nil +} + +type DetailRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Sn string `protobuf:"bytes,1,opt,name=sn,proto3" json:"sn,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DetailRequest) Reset() { + *x = DetailRequest{} + mi := &file_ai_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DetailRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DetailRequest) ProtoMessage() {} + +func (x *DetailRequest) ProtoReflect() protoreflect.Message { + mi := &file_ai_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DetailRequest.ProtoReflect.Descriptor instead. +func (*DetailRequest) Descriptor() ([]byte, []int) { + return file_ai_proto_rawDescGZIP(), []int{5} +} + +func (x *DetailRequest) GetSn() string { + if x != nil { + return x.Sn + } + return "" +} + +type DetailResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Message []*Message `protobuf:"bytes,2,rep,name=message,proto3" json:"message,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DetailResponse) Reset() { + *x = DetailResponse{} + mi := &file_ai_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DetailResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DetailResponse) ProtoMessage() {} + +func (x *DetailResponse) ProtoReflect() protoreflect.Message { + mi := &file_ai_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DetailResponse.ProtoReflect.Descriptor instead. +func (*DetailResponse) Descriptor() ([]byte, []int) { + return file_ai_proto_rawDescGZIP(), []int{6} +} + +func (x *DetailResponse) GetMessage() []*Message { + if x != nil { + return x.Message + } + return nil +} + +type Message struct { + state protoimpl.MessageState `protogen:"open.v1"` + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + Role Role `protobuf:"varint,2,opt,name=role,proto3,enum=ai.v1.Role" json:"role,omitempty"` + Content string `protobuf:"bytes,3,opt,name=content,proto3" json:"content,omitempty"` + ReasoningContent string `protobuf:"bytes,4,opt,name=reasoningContent,proto3" json:"reasoningContent,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Message) Reset() { + *x = Message{} + mi := &file_ai_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Message) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Message) ProtoMessage() {} + +func (x *Message) ProtoReflect() protoreflect.Message { + mi := &file_ai_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Message.ProtoReflect.Descriptor instead. +func (*Message) Descriptor() ([]byte, []int) { + return file_ai_proto_rawDescGZIP(), []int{7} +} + +func (x *Message) GetId() string { + if x != nil { + return x.Id + } + return "" +} + +func (x *Message) GetRole() Role { + if x != nil { + return x.Role + } + return Role_UNKNOWN +} + +func (x *Message) GetContent() string { + if x != nil { + return x.Content + } + return "" +} + +func (x *Message) GetReasoningContent() string { + if x != nil { + return x.ReasoningContent + } + return "" +} + +type ChatResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Sn string `protobuf:"bytes,1,opt,name=sn,proto3" json:"sn,omitempty"` + Response *Message `protobuf:"bytes,2,opt,name=response,proto3" json:"response,omitempty"` + Metadata string `protobuf:"bytes,3,opt,name=metadata,proto3" json:"metadata,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ChatResponse) Reset() { + *x = ChatResponse{} + mi := &file_ai_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ChatResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ChatResponse) ProtoMessage() {} + +func (x *ChatResponse) ProtoReflect() protoreflect.Message { + mi := &file_ai_proto_msgTypes[8] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ChatResponse.ProtoReflect.Descriptor instead. +func (*ChatResponse) Descriptor() ([]byte, []int) { + return file_ai_proto_rawDescGZIP(), []int{8} +} + +func (x *ChatResponse) GetSn() string { + if x != nil { + return x.Sn + } + return "" +} + +func (x *ChatResponse) GetResponse() *Message { + if x != nil { + return x.Response + } + return nil +} + +func (x *ChatResponse) GetMetadata() string { + if x != nil { + return x.Metadata + } + return "" +} + +var File_ai_proto protoreflect.FileDescriptor + +const file_ai_proto_rawDesc = "" + + "\n" + + "\bai.proto\x12\x05ai.v1\"{\n" + + "\vStreamEvent\x12\x14\n" + + "\x05final\x18\x01 \x01(\bR\x05final\x12*\n" + + "\x10reasoningContent\x18\x02 \x01(\tR\x10reasoningContent\x12\x18\n" + + "\acontent\x18\x03 \x01(\tR\acontent\x12\x10\n" + + "\x03err\x18\x04 \x01(\tR\x03err\"\x86\x01\n" + + "\fConversation\x12\x0e\n" + + "\x02sn\x18\x01 \x01(\tR\x02sn\x12\x10\n" + + "\x03uid\x18\x02 \x01(\tR\x03uid\x12\x14\n" + + "\x05title\x18\x03 \x01(\tR\x05title\x12(\n" + + "\amessage\x18\x04 \x03(\v2\x0e.ai.v1.MessageR\amessage\x12\x14\n" + + "\x05ctime\x18\x05 \x01(\tR\x05ctime\"I\n" + + "\aListReq\x12\x10\n" + + "\x03uid\x18\x01 \x01(\tR\x03uid\x12\x16\n" + + "\x06offset\x18\x02 \x01(\x03R\x06offset\x12\x14\n" + + "\x05limit\x18\x03 \x01(\x03R\x05limit\"E\n" + + "\bListResp\x129\n" + + "\rconversations\x18\x01 \x03(\v2\x13.ai.v1.ConversationR\rconversations\"F\n" + + "\n" + + "LLMRequest\x12\x0e\n" + + "\x02sn\x18\x01 \x01(\tR\x02sn\x12(\n" + + "\amessage\x18\x02 \x03(\v2\x0e.ai.v1.MessageR\amessage\"\x1f\n" + + "\rDetailRequest\x12\x0e\n" + + "\x02sn\x18\x01 \x01(\tR\x02sn\":\n" + + "\x0eDetailResponse\x12(\n" + + "\amessage\x18\x02 \x03(\v2\x0e.ai.v1.MessageR\amessage\"\x80\x01\n" + + "\aMessage\x12\x0e\n" + + "\x02id\x18\x01 \x01(\tR\x02id\x12\x1f\n" + + "\x04role\x18\x02 \x01(\x0e2\v.ai.v1.RoleR\x04role\x12\x18\n" + + "\acontent\x18\x03 \x01(\tR\acontent\x12*\n" + + "\x10reasoningContent\x18\x04 \x01(\tR\x10reasoningContent\"f\n" + + "\fChatResponse\x12\x0e\n" + + "\x02sn\x18\x01 \x01(\tR\x02sn\x12*\n" + + "\bresponse\x18\x02 \x01(\v2\x0e.ai.v1.MessageR\bresponse\x12\x1a\n" + + "\bmetadata\x18\x03 \x01(\tR\bmetadata*B\n" + + "\x04Role\x12\v\n" + + "\aUNKNOWN\x10\x00\x12\b\n" + + "\x04USER\x10\x01\x12\r\n" + + "\tASSISTANT\x10\x02\x12\n" + + "\n" + + "\x06SYSTEM\x10\x03\x12\b\n" + + "\x04TOOL\x10\x042h\n" + + "\tAIService\x12+\n" + + "\x04Chat\x12\x0e.ai.v1.Message\x1a\x13.ai.v1.ChatResponse\x12.\n" + + "\x06Stream\x12\x0e.ai.v1.Message\x1a\x12.ai.v1.StreamEvent0\x012\x8c\x02\n" + + "\x13ConversationService\x122\n" + + "\x06Create\x12\x13.ai.v1.Conversation\x1a\x13.ai.v1.Conversation\x12'\n" + + "\x04List\x12\x0e.ai.v1.ListReq\x1a\x0f.ai.v1.ListResp\x12.\n" + + "\x04Chat\x12\x11.ai.v1.LLMRequest\x1a\x13.ai.v1.ChatResponse\x125\n" + + "\x06Detail\x12\x14.ai.v1.DetailRequest\x1a\x15.ai.v1.DetailResponse\x121\n" + + "\x06Stream\x12\x11.ai.v1.LLMRequest\x1a\x12.ai.v1.StreamEvent0\x01Bz\n" + + "\tcom.ai.v1B\aAiProtoP\x01Z/github.com/ecodeclub/ai-gateway-go/api/gen;aiv1\xa2\x02\x03AXX\xaa\x02\x05Ai.V1\xca\x02\x05Ai\\V1\xe2\x02\x11Ai\\V1\\GPBMetadata\xea\x02\x06Ai::V1b\x06proto3" + +var ( + file_ai_proto_rawDescOnce sync.Once + file_ai_proto_rawDescData []byte +) + +func file_ai_proto_rawDescGZIP() []byte { + file_ai_proto_rawDescOnce.Do(func() { + file_ai_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_ai_proto_rawDesc), len(file_ai_proto_rawDesc))) + }) + return file_ai_proto_rawDescData +} + +var file_ai_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_ai_proto_msgTypes = make([]protoimpl.MessageInfo, 9) +var file_ai_proto_goTypes = []any{ + (Role)(0), // 0: ai.v1.Role + (*StreamEvent)(nil), // 1: ai.v1.StreamEvent + (*Conversation)(nil), // 2: ai.v1.Conversation + (*ListReq)(nil), // 3: ai.v1.ListReq + (*ListResp)(nil), // 4: ai.v1.ListResp + (*LLMRequest)(nil), // 5: ai.v1.LLMRequest + (*DetailRequest)(nil), // 6: ai.v1.DetailRequest + (*DetailResponse)(nil), // 7: ai.v1.DetailResponse + (*Message)(nil), // 8: ai.v1.Message + (*ChatResponse)(nil), // 9: ai.v1.ChatResponse +} +var file_ai_proto_depIdxs = []int32{ + 8, // 0: ai.v1.Conversation.message:type_name -> ai.v1.Message + 2, // 1: ai.v1.ListResp.conversations:type_name -> ai.v1.Conversation + 8, // 2: ai.v1.LLMRequest.message:type_name -> ai.v1.Message + 8, // 3: ai.v1.DetailResponse.message:type_name -> ai.v1.Message + 0, // 4: ai.v1.Message.role:type_name -> ai.v1.Role + 8, // 5: ai.v1.ChatResponse.response:type_name -> ai.v1.Message + 8, // 6: ai.v1.AIService.Chat:input_type -> ai.v1.Message + 8, // 7: ai.v1.AIService.Stream:input_type -> ai.v1.Message + 2, // 8: ai.v1.ConversationService.Create:input_type -> ai.v1.Conversation + 3, // 9: ai.v1.ConversationService.List:input_type -> ai.v1.ListReq + 5, // 10: ai.v1.ConversationService.Chat:input_type -> ai.v1.LLMRequest + 6, // 11: ai.v1.ConversationService.Detail:input_type -> ai.v1.DetailRequest + 5, // 12: ai.v1.ConversationService.Stream:input_type -> ai.v1.LLMRequest + 9, // 13: ai.v1.AIService.Chat:output_type -> ai.v1.ChatResponse + 1, // 14: ai.v1.AIService.Stream:output_type -> ai.v1.StreamEvent + 2, // 15: ai.v1.ConversationService.Create:output_type -> ai.v1.Conversation + 4, // 16: ai.v1.ConversationService.List:output_type -> ai.v1.ListResp + 9, // 17: ai.v1.ConversationService.Chat:output_type -> ai.v1.ChatResponse + 7, // 18: ai.v1.ConversationService.Detail:output_type -> ai.v1.DetailResponse + 1, // 19: ai.v1.ConversationService.Stream:output_type -> ai.v1.StreamEvent + 13, // [13:20] is the sub-list for method output_type + 6, // [6:13] is the sub-list for method input_type + 6, // [6:6] is the sub-list for extension type_name + 6, // [6:6] is the sub-list for extension extendee + 0, // [0:6] is the sub-list for field type_name +} + +func init() { file_ai_proto_init() } +func file_ai_proto_init() { + if File_ai_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_ai_proto_rawDesc), len(file_ai_proto_rawDesc)), + NumEnums: 1, + NumMessages: 9, + NumExtensions: 0, + NumServices: 2, + }, + GoTypes: file_ai_proto_goTypes, + DependencyIndexes: file_ai_proto_depIdxs, + EnumInfos: file_ai_proto_enumTypes, + MessageInfos: file_ai_proto_msgTypes, + }.Build() + File_ai_proto = out.File + file_ai_proto_goTypes = nil + file_ai_proto_depIdxs = nil +} diff --git a/api/gen/ai/v1/ai_grpc.pb.go b/api/gen/ai/v1/ai_grpc.pb.go new file mode 100644 index 0000000..359a52f --- /dev/null +++ b/api/gen/ai/v1/ai_grpc.pb.go @@ -0,0 +1,435 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc (unknown) +// source: ai.proto + +package aiv1 + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + AIService_Chat_FullMethodName = "/ai.v1.AIService/Chat" + AIService_Stream_FullMethodName = "/ai.v1.AIService/Stream" +) + +// AIServiceClient is the client API for AIService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type AIServiceClient interface { + Chat(ctx context.Context, in *Message, opts ...grpc.CallOption) (*ChatResponse, error) + Stream(ctx context.Context, in *Message, opts ...grpc.CallOption) (grpc.ServerStreamingClient[StreamEvent], error) +} + +type aIServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewAIServiceClient(cc grpc.ClientConnInterface) AIServiceClient { + return &aIServiceClient{cc} +} + +func (c *aIServiceClient) Chat(ctx context.Context, in *Message, opts ...grpc.CallOption) (*ChatResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(ChatResponse) + err := c.cc.Invoke(ctx, AIService_Chat_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *aIServiceClient) Stream(ctx context.Context, in *Message, opts ...grpc.CallOption) (grpc.ServerStreamingClient[StreamEvent], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &AIService_ServiceDesc.Streams[0], AIService_Stream_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[Message, StreamEvent]{ClientStream: stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type AIService_StreamClient = grpc.ServerStreamingClient[StreamEvent] + +// AIServiceServer is the server API for AIService service. +// All implementations must embed UnimplementedAIServiceServer +// for forward compatibility. +type AIServiceServer interface { + Chat(context.Context, *Message) (*ChatResponse, error) + Stream(*Message, grpc.ServerStreamingServer[StreamEvent]) error + mustEmbedUnimplementedAIServiceServer() +} + +// UnimplementedAIServiceServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedAIServiceServer struct{} + +func (UnimplementedAIServiceServer) Chat(context.Context, *Message) (*ChatResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Chat not implemented") +} +func (UnimplementedAIServiceServer) Stream(*Message, grpc.ServerStreamingServer[StreamEvent]) error { + return status.Errorf(codes.Unimplemented, "method Stream not implemented") +} +func (UnimplementedAIServiceServer) mustEmbedUnimplementedAIServiceServer() {} +func (UnimplementedAIServiceServer) testEmbeddedByValue() {} + +// UnsafeAIServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to AIServiceServer will +// result in compilation errors. +type UnsafeAIServiceServer interface { + mustEmbedUnimplementedAIServiceServer() +} + +func RegisterAIServiceServer(s grpc.ServiceRegistrar, srv AIServiceServer) { + // If the following call pancis, it indicates UnimplementedAIServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&AIService_ServiceDesc, srv) +} + +func _AIService_Chat_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Message) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(AIServiceServer).Chat(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: AIService_Chat_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(AIServiceServer).Chat(ctx, req.(*Message)) + } + return interceptor(ctx, in, info, handler) +} + +func _AIService_Stream_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(Message) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(AIServiceServer).Stream(m, &grpc.GenericServerStream[Message, StreamEvent]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type AIService_StreamServer = grpc.ServerStreamingServer[StreamEvent] + +// AIService_ServiceDesc is the grpc.ServiceDesc for AIService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var AIService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "ai.v1.AIService", + HandlerType: (*AIServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Chat", + Handler: _AIService_Chat_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "Stream", + Handler: _AIService_Stream_Handler, + ServerStreams: true, + }, + }, + Metadata: "ai.proto", +} + +const ( + ConversationService_Create_FullMethodName = "/ai.v1.ConversationService/Create" + ConversationService_List_FullMethodName = "/ai.v1.ConversationService/List" + ConversationService_Chat_FullMethodName = "/ai.v1.ConversationService/Chat" + ConversationService_Detail_FullMethodName = "/ai.v1.ConversationService/Detail" + ConversationService_Stream_FullMethodName = "/ai.v1.ConversationService/Stream" +) + +// ConversationServiceClient is the client API for ConversationService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type ConversationServiceClient interface { + Create(ctx context.Context, in *Conversation, opts ...grpc.CallOption) (*Conversation, error) + List(ctx context.Context, in *ListReq, opts ...grpc.CallOption) (*ListResp, error) + Chat(ctx context.Context, in *LLMRequest, opts ...grpc.CallOption) (*ChatResponse, error) + Detail(ctx context.Context, in *DetailRequest, opts ...grpc.CallOption) (*DetailResponse, error) + Stream(ctx context.Context, in *LLMRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[StreamEvent], error) +} + +type conversationServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewConversationServiceClient(cc grpc.ClientConnInterface) ConversationServiceClient { + return &conversationServiceClient{cc} +} + +func (c *conversationServiceClient) Create(ctx context.Context, in *Conversation, opts ...grpc.CallOption) (*Conversation, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(Conversation) + err := c.cc.Invoke(ctx, ConversationService_Create_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *conversationServiceClient) List(ctx context.Context, in *ListReq, opts ...grpc.CallOption) (*ListResp, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(ListResp) + err := c.cc.Invoke(ctx, ConversationService_List_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *conversationServiceClient) Chat(ctx context.Context, in *LLMRequest, opts ...grpc.CallOption) (*ChatResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(ChatResponse) + err := c.cc.Invoke(ctx, ConversationService_Chat_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *conversationServiceClient) Detail(ctx context.Context, in *DetailRequest, opts ...grpc.CallOption) (*DetailResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(DetailResponse) + err := c.cc.Invoke(ctx, ConversationService_Detail_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *conversationServiceClient) Stream(ctx context.Context, in *LLMRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[StreamEvent], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &ConversationService_ServiceDesc.Streams[0], ConversationService_Stream_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[LLMRequest, StreamEvent]{ClientStream: stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type ConversationService_StreamClient = grpc.ServerStreamingClient[StreamEvent] + +// ConversationServiceServer is the server API for ConversationService service. +// All implementations must embed UnimplementedConversationServiceServer +// for forward compatibility. +type ConversationServiceServer interface { + Create(context.Context, *Conversation) (*Conversation, error) + List(context.Context, *ListReq) (*ListResp, error) + Chat(context.Context, *LLMRequest) (*ChatResponse, error) + Detail(context.Context, *DetailRequest) (*DetailResponse, error) + Stream(*LLMRequest, grpc.ServerStreamingServer[StreamEvent]) error + mustEmbedUnimplementedConversationServiceServer() +} + +// UnimplementedConversationServiceServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedConversationServiceServer struct{} + +func (UnimplementedConversationServiceServer) Create(context.Context, *Conversation) (*Conversation, error) { + return nil, status.Errorf(codes.Unimplemented, "method Create not implemented") +} +func (UnimplementedConversationServiceServer) List(context.Context, *ListReq) (*ListResp, error) { + return nil, status.Errorf(codes.Unimplemented, "method List not implemented") +} +func (UnimplementedConversationServiceServer) Chat(context.Context, *LLMRequest) (*ChatResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Chat not implemented") +} +func (UnimplementedConversationServiceServer) Detail(context.Context, *DetailRequest) (*DetailResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Detail not implemented") +} +func (UnimplementedConversationServiceServer) Stream(*LLMRequest, grpc.ServerStreamingServer[StreamEvent]) error { + return status.Errorf(codes.Unimplemented, "method Stream not implemented") +} +func (UnimplementedConversationServiceServer) mustEmbedUnimplementedConversationServiceServer() {} +func (UnimplementedConversationServiceServer) testEmbeddedByValue() {} + +// UnsafeConversationServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to ConversationServiceServer will +// result in compilation errors. +type UnsafeConversationServiceServer interface { + mustEmbedUnimplementedConversationServiceServer() +} + +func RegisterConversationServiceServer(s grpc.ServiceRegistrar, srv ConversationServiceServer) { + // If the following call pancis, it indicates UnimplementedConversationServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&ConversationService_ServiceDesc, srv) +} + +func _ConversationService_Create_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Conversation) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ConversationServiceServer).Create(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ConversationService_Create_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ConversationServiceServer).Create(ctx, req.(*Conversation)) + } + return interceptor(ctx, in, info, handler) +} + +func _ConversationService_List_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListReq) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ConversationServiceServer).List(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ConversationService_List_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ConversationServiceServer).List(ctx, req.(*ListReq)) + } + return interceptor(ctx, in, info, handler) +} + +func _ConversationService_Chat_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(LLMRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ConversationServiceServer).Chat(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ConversationService_Chat_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ConversationServiceServer).Chat(ctx, req.(*LLMRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ConversationService_Detail_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(DetailRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ConversationServiceServer).Detail(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ConversationService_Detail_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ConversationServiceServer).Detail(ctx, req.(*DetailRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ConversationService_Stream_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(LLMRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(ConversationServiceServer).Stream(m, &grpc.GenericServerStream[LLMRequest, StreamEvent]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type ConversationService_StreamServer = grpc.ServerStreamingServer[StreamEvent] + +// ConversationService_ServiceDesc is the grpc.ServiceDesc for ConversationService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var ConversationService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "ai.v1.ConversationService", + HandlerType: (*ConversationServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Create", + Handler: _ConversationService_Create_Handler, + }, + { + MethodName: "List", + Handler: _ConversationService_List_Handler, + }, + { + MethodName: "Chat", + Handler: _ConversationService_Chat_Handler, + }, + { + MethodName: "Detail", + Handler: _ConversationService_Detail_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "Stream", + Handler: _ConversationService_Stream_Handler, + ServerStreams: true, + }, + }, + Metadata: "ai.proto", +} diff --git a/api/proto/ai/v1/ai.proto b/api/proto/ai/v1/ai.proto index c8d8481..2f8da2c 100644 --- a/api/proto/ai/v1/ai.proto +++ b/api/proto/ai/v1/ai.proto @@ -14,7 +14,7 @@ syntax = "proto3"; package ai.v1; -option go_package = "github.com/ecodeclub/ai-gateway-go/api/gen/ai/v1;ai_v1"; +option go_package = "ai/v1;aiv1"; service AIService { rpc Chat(Message) returns (ChatResponse); diff --git a/api/proto/buf.yaml b/api/proto/buf.yaml deleted file mode 100644 index ee12108..0000000 --- a/api/proto/buf.yaml +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2021 ecodeclub -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -version: v2 -deps: - - buf.build/googleapis/googleapis -breaking: - use: - - FILE -lint: - use: - - STANDARD # Omit all Buf categories if you don't want to use Buf's built-in rules - except: - - ENUM_VALUE_PREFIX - - ENUM_ZERO_VALUE_SUFFIX - ignore: - - google/type/datetime.proto - - google/protobuf/empty.proto - - google/protobuf/timestamp.proto - disallow_comment_ignores: false # The default behavior of this key has changed from v1 - enum_zero_value_suffix: _UNSPECIFIED - rpc_allow_same_request_response: false - rpc_allow_google_protobuf_empty_requests: false - rpc_allow_google_protobuf_empty_responses: false - service_suffix: Service From fe10e1f6c088fd47bc5a92a49d44ba82237b7ae7 Mon Sep 17 00:00:00 2001 From: yumosx Date: Thu, 10 Jul 2025 00:44:38 +0800 Subject: [PATCH 13/24] fix path --- cmd/platform/server.go | 2 +- internal/grpc/conversation.go | 2 +- internal/grpc/server.go | 2 +- internal/test/conversation_test.go | 4 ++-- internal/test/mocks/mock_resp.go | 2 +- internal/test/server_test.go | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/cmd/platform/server.go b/cmd/platform/server.go index b238a1c..1d3d4a8 100644 --- a/cmd/platform/server.go +++ b/cmd/platform/server.go @@ -16,7 +16,7 @@ package main import ( ds "github.com/cohesion-org/deepseek-go" - ai "github.com/ecodeclub/ai-gateway-go/api/proto/gen/ai/v1" + ai "github.com/ecodeclub/ai-gateway-go/api/gen/ai/v1" igrpc "github.com/ecodeclub/ai-gateway-go/internal/grpc" "github.com/ecodeclub/ai-gateway-go/internal/service" "github.com/ecodeclub/ai-gateway-go/internal/service/llm/platform/deepseek" diff --git a/internal/grpc/conversation.go b/internal/grpc/conversation.go index 43b111e..f9aa755 100644 --- a/internal/grpc/conversation.go +++ b/internal/grpc/conversation.go @@ -17,7 +17,7 @@ package grpc import ( "context" - ai "github.com/ecodeclub/ai-gateway-go/api/proto/gen/ai/v1" + ai "github.com/ecodeclub/ai-gateway-go/api/gen/ai/v1" "github.com/ecodeclub/ai-gateway-go/internal/domain" "github.com/ecodeclub/ai-gateway-go/internal/service" "github.com/ecodeclub/ekit/slice" diff --git a/internal/grpc/server.go b/internal/grpc/server.go index d4eb458..856176e 100644 --- a/internal/grpc/server.go +++ b/internal/grpc/server.go @@ -18,7 +18,7 @@ import ( "context" "strconv" - ai "github.com/ecodeclub/ai-gateway-go/api/proto/gen/ai/v1" + ai "github.com/ecodeclub/ai-gateway-go/api/gen/ai/v1" "github.com/ecodeclub/ai-gateway-go/internal/domain" "github.com/ecodeclub/ai-gateway-go/internal/service" ) diff --git a/internal/test/conversation_test.go b/internal/test/conversation_test.go index eca8946..cdc3789 100644 --- a/internal/test/conversation_test.go +++ b/internal/test/conversation_test.go @@ -18,7 +18,7 @@ import ( "context" "testing" - aiv1 "github.com/ecodeclub/ai-gateway-go/api/proto/gen/ai/v1" + aiv1 "github.com/ecodeclub/ai-gateway-go/api/gen/ai/v1" "github.com/ecodeclub/ai-gateway-go/internal/domain" "github.com/ecodeclub/ai-gateway-go/internal/grpc" "github.com/ecodeclub/ai-gateway-go/internal/repository" @@ -288,7 +288,7 @@ func (c *ConversationSuite) TestDetail() { conversationService := service.NewConversationService(repo, handler) server := grpc.NewConversationServer(conversationService) tc.before() - detail, err := server.Detail(context.Background(), &aiv1.MsgListReq{Sn: "1"}) + detail, err := server.Detail(context.Background(), &aiv1.DetailRequest{Sn: "1"}) require.NoError(t, err) assert.ElementsMatch(t, detail.Message, []*aiv1.Message{ {Role: aiv1.Role_USER, Content: "user1"}, diff --git a/internal/test/mocks/mock_resp.go b/internal/test/mocks/mock_resp.go index 44344ab..5f1f31a 100644 --- a/internal/test/mocks/mock_resp.go +++ b/internal/test/mocks/mock_resp.go @@ -3,7 +3,7 @@ package mocks import ( "context" - ai "github.com/ecodeclub/ai-gateway-go/api/proto/gen/ai/v1" + ai "github.com/ecodeclub/ai-gateway-go/api/gen/ai/v1" ) type MockStreamServer struct { diff --git a/internal/test/server_test.go b/internal/test/server_test.go index 5f91492..44923e9 100644 --- a/internal/test/server_test.go +++ b/internal/test/server_test.go @@ -18,7 +18,7 @@ import ( "context" "testing" - ai "github.com/ecodeclub/ai-gateway-go/api/proto/gen/ai/v1" + ai "github.com/ecodeclub/ai-gateway-go/api/gen/ai/v1" "github.com/ecodeclub/ai-gateway-go/internal/domain" igrpc "github.com/ecodeclub/ai-gateway-go/internal/grpc" "github.com/ecodeclub/ai-gateway-go/internal/service" From 84129ec74616f3507822863961f884f666ba8c6c Mon Sep 17 00:00:00 2001 From: yumosx Date: Thu, 10 Jul 2025 00:46:33 +0800 Subject: [PATCH 14/24] fix --- .github/workflows/golangci-lint.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 7bbe491..23b5dcb 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -37,6 +37,4 @@ jobs: uses: golangci/golangci-lint-action@v7 with: version: v2.0 - only-new-issues: true - - name: Generate Protobuf - run: buf generate \ No newline at end of file + only-new-issues: true \ No newline at end of file From 2216361361affaba1c48fd1bf3ed68e0416fc262 Mon Sep 17 00:00:00 2001 From: yumosx Date: Thu, 10 Jul 2025 21:21:34 +0800 Subject: [PATCH 15/24] update test --- internal/test/quota_test.go | 54 +++++++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/internal/test/quota_test.go b/internal/test/quota_test.go index db584f1..48e3c66 100644 --- a/internal/test/quota_test.go +++ b/internal/test/quota_test.go @@ -189,13 +189,13 @@ func (q *QuotaSuite) TestDeduct() { testcases := []struct { name string reqBody string - before func() - after func() + before func(db *gorm.DB, server *gin.Engine) + after func(db *gorm.DB) }{ { name: "deduct quota", reqBody: `{"amount": 10, "key": "23911"}`, - before: func() { + before: func(db *gorm.DB, server *gin.Engine) { sess := mocks.NewMockSession(ctrl) sess.EXPECT().Claims().Return(session.Claims{ Uid: 1, @@ -205,11 +205,11 @@ func (q *QuotaSuite) TestDeduct() { provider.EXPECT().Get(gomock.Any()).Return(sess, nil) quota := dao.Quota{Amount: 20, Key: "23911", UID: 1} - q.db.Create("a) + db.Create("a) }, - after: func() { + after: func(db *gorm.DB) { var quota dao.Quota - err := q.db.Where("id = ?", 1).First("a).Error + err := db.Where("id = ?", 1).First("a).Error require.NoError(t, err) assert.Equal(t, int64(10), quota.Amount) }, @@ -217,7 +217,7 @@ func (q *QuotaSuite) TestDeduct() { { name: "deduct temp quota", reqBody: `{"amount": 10, "key": "23921"}`, - before: func() { + before: func(db *gorm.DB, server *gin.Engine) { sess := mocks.NewMockSession(ctrl) sess.EXPECT().Claims().Return(session.Claims{ Uid: 1, @@ -227,12 +227,12 @@ func (q *QuotaSuite) TestDeduct() { provider.EXPECT().Get(gomock.Any()).Return(sess, nil) quota := dao.TempQuota{Amount: 20, Key: "23921", UID: 1, StartTime: time.Now().Unix(), EndTime: time.Now().Add(24 * time.Hour).Unix()} - err := q.db.Create("a).Error + err := db.Create("a).Error require.NoError(t, err) }, - after: func() { + after: func(db *gorm.DB) { var quota dao.TempQuota - err := q.db.Where("id = ?", 1).First("a).Error + err := db.Where("id = ?", 1).First("a).Error require.NoError(t, err) assert.Equal(t, int64(10), quota.Amount) }, @@ -241,15 +241,41 @@ func (q *QuotaSuite) TestDeduct() { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - tc.before() + // 独立初始化 DB 和 Gin Engine + dbConfig := config.NewConfig( + config.WithDBName("ai_gateway_platform"), + config.WithUserName("root"), + config.WithPassword("root"), + config.WithHost("127.0.0.1"), + config.WithPort("13306"), + ) + db, err := config.NewDB(dbConfig) + require.NoError(t, err) + err = dao.InitQuotaTable(db) + require.NoError(t, err) + + d := dao.NewQuotaDao(db) + repo := repository.NewQuotaRepo(d) + svc := service.NewQuotaService(repo) + handler := web.NewQuotaHandler(svc) + server := gin.Default() + handler.PrivateRoutes(server) + + // 清理表 + defer func() { + db.Exec("TRUNCATE TABLE quotas") + db.Exec("TRUNCATE TABLE quota_records") + db.Exec("TRUNCATE TABLE temp_quotas") + }() + + tc.before(db, server) req, err := http.NewRequest(http.MethodPost, "/deduct", bytes.NewBuffer([]byte(tc.reqBody))) require.NoError(t, err) req.Header.Set("Content-Type", "application/json") resp := httptest.NewRecorder() - q.server.ServeHTTP(resp, req) - + server.ServeHTTP(resp, req) assert.Equal(t, http.StatusOK, resp.Code) - tc.after() + tc.after(db) }) } } From 7efd4dba7e851363f4fa6c352326a433620107a8 Mon Sep 17 00:00:00 2001 From: yumosx Date: Fri, 11 Jul 2025 00:24:25 +0800 Subject: [PATCH 16/24] update --- internal/repository/dao/quota.go | 43 +++++++++++++------------------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/internal/repository/dao/quota.go b/internal/repository/dao/quota.go index b90ac17..79134e8 100644 --- a/internal/repository/dao/quota.go +++ b/internal/repository/dao/quota.go @@ -16,9 +16,11 @@ package dao import ( "context" + "errors" "time" "github.com/ecodeclub/ai-gateway-go/errs" + "github.com/go-sql-driver/mysql" "gorm.io/gorm" "gorm.io/gorm/clause" ) @@ -93,6 +95,7 @@ func (dao *QuotaDao) SaveQuota(ctx context.Context, quota Quota) error { quota.Utime = now return dao.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + now := time.Now().Unix() record := QuotaRecord{ Key: quota.Key, Uid: quota.UID, @@ -100,20 +103,13 @@ func (dao *QuotaDao) SaveQuota(ctx context.Context, quota Quota) error { Ctime: now, Utime: now, } - result := tx.Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "key"}}, - DoUpdates: clause.Assignments(map[string]any{ - "amount": quota.Amount, - "utime": now, - }), - }).Create(&record) - - if result.Error != nil { - return result.Error - } - - if result.RowsAffected == 0 { - return nil + err := tx.Create(&record).Error + if err != nil { + var mysqlErr *mysql.MySQLError + if errors.As(err, &mysqlErr) && mysqlErr.Number == 1062 { + return err + } + return err } return tx.Clauses(clause.OnConflict{ @@ -160,18 +156,15 @@ func (dao *QuotaDao) Deduct(ctx context.Context, uid int64, amount int64, key st Ctime: now, Utime: now, } - result := tx.Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "key"}}, - DoNothing: true, - }).Create(&record) - - if result.Error != nil { - return result.Error - } - if result.RowsAffected == 0 { - return nil + err := tx.Create(&record).Error + if err != nil { + // 判断是否唯一索引冲突(MySQL 1062) + var mysqlErr *mysql.MySQLError + if errors.As(err, &mysqlErr) && mysqlErr.Number == 1062 { + return err + } + return err } - // 执行扣减程序 return dao.deduct(tx, uid, amount, now) }) } From 1f420bd92fb02bda8936236318114980cf00bb1b Mon Sep 17 00:00:00 2001 From: yumosx Date: Sat, 12 Jul 2025 15:05:57 +0800 Subject: [PATCH 17/24] update the deduct --- errs/errs.go | 3 +- go.mod | 2 +- internal/repository/dao/quota.go | 63 ++++++++++++++------------------ 3 files changed, 29 insertions(+), 39 deletions(-) diff --git a/errs/errs.go b/errs/errs.go index 926e1f1..1bb8e03 100644 --- a/errs/errs.go +++ b/errs/errs.go @@ -19,7 +19,6 @@ import ( ) var ( - ErrDeductAmountFailed = errors.New("deduct amount failed") - ErrInsufficientQuota = errors.New("insufficient quota") + ErrDeductAmountFailed = errors.New("扣减失败") ErrBizConfigNotFound = errors.New("biz config not found") ) diff --git a/go.mod b/go.mod index 0afdb9f..774e252 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/ecodeclub/ekit v0.0.8-0.20240211141809-d8a351a335b5 github.com/ecodeclub/ginx v0.0.1 github.com/gin-gonic/gin v1.10.0 + github.com/go-sql-driver/mysql v1.7.0 github.com/golang-jwt/jwt/v5 v5.2.2 github.com/google/uuid v1.6.0 github.com/gotomicro/ego v1.2.3 @@ -43,7 +44,6 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.26.0 // indirect - github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/goccy/go-json v0.10.5 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/golang/snappy v0.0.4 // indirect diff --git a/internal/repository/dao/quota.go b/internal/repository/dao/quota.go index 79134e8..fa6d1cd 100644 --- a/internal/repository/dao/quota.go +++ b/internal/repository/dao/quota.go @@ -158,11 +158,6 @@ func (dao *QuotaDao) Deduct(ctx context.Context, uid int64, amount int64, key st } err := tx.Create(&record).Error if err != nil { - // 判断是否唯一索引冲突(MySQL 1062) - var mysqlErr *mysql.MySQLError - if errors.As(err, &mysqlErr) && mysqlErr.Number == 1062 { - return err - } return err } return dao.deduct(tx, uid, amount, now) @@ -170,44 +165,41 @@ func (dao *QuotaDao) Deduct(ctx context.Context, uid int64, amount int64, key st } func (dao *QuotaDao) deduct(tx *gorm.DB, uid int64, amount int64, now int64) error { - var tempQuotas []TempQuota - err := tx.Where("uid = ? AND end_time >= ? AND amount > 0", uid, now). - Order("end_time ASC"). - Find(&tempQuotas).Error - if err != nil { - return err - } - - remain := amount - - // 先扣临时额度 - for _, tq := range tempQuotas { - deduct := min(tq.Amount, remain) - - update := tx.Model(&TempQuota{}). - Where("id = ? AND amount >= ?", tq.ID, deduct). - Updates(map[string]any{ - "amount": gorm.Expr("amount - ?", deduct), - "utime": now, - }) - if update.Error != nil { - return update.Error + deductAmount := amount + for { + var quota TempQuota + err := tx.Where("amount > ? and uid = ?", 0, uid).First("a).Error + if err != nil { + // 表示找不到可以扣减的temp + if errors.Is(err, gorm.ErrRecordNotFound) { + break + } + return err } - if update.RowsAffected == 0 { + deductAmount = min(deductAmount, quota.Amount) + result := tx.Where("amount > ? and uid = ?", 0, uid).Updates(map[string]any{ + "amount": gorm.Expr("amount - ?", deductAmount), + "utime": now, + }) + if result.Error != nil { + return result.Error + } + // 并发问题, 直接下一个 + if result.RowsAffected == 0 { continue } - - remain -= deduct - if remain <= 0 { - return nil + // 表示扣减完毕 + amount -= deductAmount + if amount <= 0 { + break } } - // 如果还有剩余,从主额度扣 + // 从主额度扣 result := tx.Model(&Quota{}). - Where("uid = ? AND amount >= ?", uid, remain). + Where("uid = ? AND amount >= ?", uid, deductAmount). Updates(map[string]any{ - "amount": gorm.Expr("amount - ?", remain), + "amount": gorm.Expr("amount - ?", deductAmount), "utime": now, }) if result.Error != nil { @@ -216,7 +208,6 @@ func (dao *QuotaDao) deduct(tx *gorm.DB, uid int64, amount int64, now int64) err if result.RowsAffected == 0 { return errs.ErrDeductAmountFailed } - return nil } From 671269bc134a360fea0add3ab389fcdaf07d15ba Mon Sep 17 00:00:00 2001 From: yumosx Date: Sat, 12 Jul 2025 15:09:44 +0800 Subject: [PATCH 18/24] =?UTF-8?q?=E5=88=A0=E9=99=A4=E4=B8=8D=E5=BF=85?= =?UTF-8?q?=E8=A6=81=E7=9A=84=E6=8C=87=E4=BB=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Makefile | 1 - 1 file changed, 1 deletion(-) diff --git a/Makefile b/Makefile index 82789cd..4bc4d4d 100644 --- a/Makefile +++ b/Makefile @@ -46,7 +46,6 @@ tidy: check: @$(MAKE) fmt @$(MAKE) tidy - #@$(MAKE) lint # 生成gRPC相关文件 .PHONY: grpc From 5deee10c7a5e2d711b94ebca452b1cc68c2df045 Mon Sep 17 00:00:00 2001 From: yumosx Date: Sat, 12 Jul 2025 15:12:42 +0800 Subject: [PATCH 19/24] remove the error --- internal/repository/dao/quota.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/internal/repository/dao/quota.go b/internal/repository/dao/quota.go index fa6d1cd..7cd6fb7 100644 --- a/internal/repository/dao/quota.go +++ b/internal/repository/dao/quota.go @@ -20,7 +20,6 @@ import ( "time" "github.com/ecodeclub/ai-gateway-go/errs" - "github.com/go-sql-driver/mysql" "gorm.io/gorm" "gorm.io/gorm/clause" ) @@ -105,10 +104,6 @@ func (dao *QuotaDao) SaveQuota(ctx context.Context, quota Quota) error { } err := tx.Create(&record).Error if err != nil { - var mysqlErr *mysql.MySQLError - if errors.As(err, &mysqlErr) && mysqlErr.Number == 1062 { - return err - } return err } From 97cb475e183ec1f014ac01aa826ebecf7f0b48e5 Mon Sep 17 00:00:00 2001 From: yumosx Date: Sat, 12 Jul 2025 15:39:18 +0800 Subject: [PATCH 20/24] fix the deduct --- go.mod | 2 +- internal/repository/dao/quota.go | 6 ++++-- internal/test/quota_test.go | 4 ++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 774e252..0afdb9f 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ require ( github.com/ecodeclub/ekit v0.0.8-0.20240211141809-d8a351a335b5 github.com/ecodeclub/ginx v0.0.1 github.com/gin-gonic/gin v1.10.0 - github.com/go-sql-driver/mysql v1.7.0 github.com/golang-jwt/jwt/v5 v5.2.2 github.com/google/uuid v1.6.0 github.com/gotomicro/ego v1.2.3 @@ -44,6 +43,7 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.26.0 // indirect + github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/goccy/go-json v0.10.5 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/golang/snappy v0.0.4 // indirect diff --git a/internal/repository/dao/quota.go b/internal/repository/dao/quota.go index 7cd6fb7..e5f9d28 100644 --- a/internal/repository/dao/quota.go +++ b/internal/repository/dao/quota.go @@ -172,7 +172,7 @@ func (dao *QuotaDao) deduct(tx *gorm.DB, uid int64, amount int64, now int64) err return err } deductAmount = min(deductAmount, quota.Amount) - result := tx.Where("amount > ? and uid = ?", 0, uid).Updates(map[string]any{ + result := tx.Model(&TempQuota{}).Where("amount > ? and uid = ?", 0, uid).Updates(map[string]any{ "amount": gorm.Expr("amount - ?", deductAmount), "utime": now, }) @@ -189,7 +189,9 @@ func (dao *QuotaDao) deduct(tx *gorm.DB, uid int64, amount int64, now int64) err break } } - + if amount == 0 { + return nil + } // 从主额度扣 result := tx.Model(&Quota{}). Where("uid = ? AND amount >= ?", uid, deductAmount). diff --git a/internal/test/quota_test.go b/internal/test/quota_test.go index 48e3c66..2e59306 100644 --- a/internal/test/quota_test.go +++ b/internal/test/quota_test.go @@ -216,7 +216,7 @@ func (q *QuotaSuite) TestDeduct() { }, { name: "deduct temp quota", - reqBody: `{"amount": 10, "key": "23921"}`, + reqBody: `{"amount": 10, "key": "23922"}`, before: func(db *gorm.DB, server *gin.Engine) { sess := mocks.NewMockSession(ctrl) sess.EXPECT().Claims().Return(session.Claims{ @@ -226,7 +226,7 @@ func (q *QuotaSuite) TestDeduct() { session.SetDefaultProvider(provider) provider.EXPECT().Get(gomock.Any()).Return(sess, nil) - quota := dao.TempQuota{Amount: 20, Key: "23921", UID: 1, StartTime: time.Now().Unix(), EndTime: time.Now().Add(24 * time.Hour).Unix()} + quota := dao.TempQuota{Amount: 20, Key: "23922", UID: 1, StartTime: time.Now().Unix(), EndTime: time.Now().Add(24 * time.Hour).Unix()} err := db.Create("a).Error require.NoError(t, err) }, From 34fd8887dffbe0aa3016bbbe47230e3dc87fdc03 Mon Sep 17 00:00:00 2001 From: yumosx Date: Sat, 12 Jul 2025 15:41:42 +0800 Subject: [PATCH 21/24] fix --- internal/repository/dao/quota.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/internal/repository/dao/quota.go b/internal/repository/dao/quota.go index e5f9d28..5e62fc1 100644 --- a/internal/repository/dao/quota.go +++ b/internal/repository/dao/quota.go @@ -186,12 +186,9 @@ func (dao *QuotaDao) deduct(tx *gorm.DB, uid int64, amount int64, now int64) err // 表示扣减完毕 amount -= deductAmount if amount <= 0 { - break + return nil } } - if amount == 0 { - return nil - } // 从主额度扣 result := tx.Model(&Quota{}). Where("uid = ? AND amount >= ?", uid, deductAmount). From d0b1832dfa295f5a29aabeb7116e20af48d8e663 Mon Sep 17 00:00:00 2001 From: yumosx Date: Mon, 14 Jul 2025 19:00:09 +0800 Subject: [PATCH 22/24] add tests --- .github/workflows/go-fmt.yml | 4 - errs/errs.go | 5 +- internal/errs/code.go | 6 +- internal/repository/dao/quota.go | 18 +-- internal/repository/quota.go | 8 +- internal/service/quota.go | 8 +- internal/test/quota_test.go | 219 ++++++++++++++++++++++++------- internal/web/quota.go | 60 ++++----- internal/web/result.go | 10 ++ 9 files changed, 228 insertions(+), 110 deletions(-) diff --git a/.github/workflows/go-fmt.yml b/.github/workflows/go-fmt.yml index c6fff3f..e16877c 100644 --- a/.github/workflows/go-fmt.yml +++ b/.github/workflows/go-fmt.yml @@ -34,10 +34,6 @@ jobs: with: go-version: "1.24.2" - - name: Install goimports - run: | - go install golang.org/x/tools/cmd/goimports@latest - go install mvdan.cc/gofumpt@latest - name: Check run: | make check diff --git a/errs/errs.go b/errs/errs.go index 1bb8e03..80bf227 100644 --- a/errs/errs.go +++ b/errs/errs.go @@ -19,6 +19,7 @@ import ( ) var ( - ErrDeductAmountFailed = errors.New("扣减失败") - ErrBizConfigNotFound = errors.New("biz config not found") + ErrBizConfigNotFound = errors.New("查询业务配置失败") + ErrInvalidParam = errors.New("参数错误") + ErrInsufficientBalance = errors.New("余额不足") ) diff --git a/internal/errs/code.go b/internal/errs/code.go index 615b5e3..41509d1 100644 --- a/internal/errs/code.go +++ b/internal/errs/code.go @@ -14,7 +14,11 @@ package errs -var SystemError = ErrorCode{Code: 501001, Msg: "系统错误"} +var ( + SystemError = ErrorCode{Code: 501001, Msg: "系统错误"} + InvalidParamError = ErrorCode{Code: 400001, Msg: "参数错误"} + InsufficientBalanceError = ErrorCode{Code: 400002, Msg: "余额不足"} +) type ErrorCode struct { Code int diff --git a/internal/repository/dao/quota.go b/internal/repository/dao/quota.go index 5e62fc1..c88a7c1 100644 --- a/internal/repository/dao/quota.go +++ b/internal/repository/dao/quota.go @@ -74,22 +74,14 @@ func NewQuotaDao(db *gorm.DB) *QuotaDao { return &QuotaDao{db: db} } -func (dao *QuotaDao) SaveTempQuota(ctx context.Context, quota TempQuota) error { +func (dao *QuotaDao) CreateTempQuota(ctx context.Context, quota TempQuota) error { now := time.Now().Unix() quota.Ctime = now quota.Utime = now - return dao.db.WithContext(ctx).Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "key"}}, - DoUpdates: clause.Assignments(map[string]any{ - "amount": quota.Amount, - "start_time": quota.StartTime, - "end_time": quota.EndTime, - "utime": now, - }), - }).Create("a).Error + return dao.db.WithContext(ctx).Create("a).Error } -func (dao *QuotaDao) SaveQuota(ctx context.Context, quota Quota) error { +func (dao *QuotaDao) AddQuota(ctx context.Context, quota Quota) error { now := time.Now().Unix() quota.Utime = now @@ -110,7 +102,7 @@ func (dao *QuotaDao) SaveQuota(ctx context.Context, quota Quota) error { return tx.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "key"}}, DoUpdates: clause.Assignments(map[string]any{ - "amount": quota.Amount, + "amount": gorm.Expr("amount + ?", quota.Amount), "utime": now, }), }).Create("a).Error @@ -200,7 +192,7 @@ func (dao *QuotaDao) deduct(tx *gorm.DB, uid int64, amount int64, now int64) err return result.Error } if result.RowsAffected == 0 { - return errs.ErrDeductAmountFailed + return errs.ErrInsufficientBalance } return nil } diff --git a/internal/repository/quota.go b/internal/repository/quota.go index 51800b6..c55ba38 100644 --- a/internal/repository/quota.go +++ b/internal/repository/quota.go @@ -30,12 +30,12 @@ func NewQuotaRepo(dao *dao.QuotaDao) *QuotaRepo { return &QuotaRepo{dao: dao} } -func (q *QuotaRepo) SaveQuota(ctx context.Context, quota domain.Quota) error { - return q.dao.SaveQuota(ctx, dao.Quota{UID: quota.Uid, Amount: quota.Amount, Key: quota.Key}) +func (q *QuotaRepo) AddQuota(ctx context.Context, quota domain.Quota) error { + return q.dao.AddQuota(ctx, dao.Quota{UID: quota.Uid, Amount: quota.Amount, Key: quota.Key}) } -func (q *QuotaRepo) SaveTempQuota(ctx context.Context, quota domain.TempQuota) error { - return q.dao.SaveTempQuota(ctx, dao.TempQuota{Amount: quota.Amount, StartTime: quota.StartTime, EndTime: quota.EndTime, Key: quota.Key}) +func (q *QuotaRepo) CreateTempQuota(ctx context.Context, quota domain.TempQuota) error { + return q.dao.CreateTempQuota(ctx, dao.TempQuota{Amount: quota.Amount, StartTime: quota.StartTime, EndTime: quota.EndTime, Key: quota.Key, UID: quota.Uid}) } func (q *QuotaRepo) GetQuota(ctx context.Context, uid int64) (domain.Quota, error) { diff --git a/internal/service/quota.go b/internal/service/quota.go index e429551..4ec1cf4 100644 --- a/internal/service/quota.go +++ b/internal/service/quota.go @@ -29,12 +29,12 @@ func NewQuotaService(repo *repository.QuotaRepo) *QuotaService { return &QuotaService{repo: repo} } -func (q *QuotaService) SaveQuota(ctx context.Context, quota domain.Quota) error { - return q.repo.SaveQuota(ctx, quota) +func (q *QuotaService) AddQuota(ctx context.Context, quota domain.Quota) error { + return q.repo.AddQuota(ctx, quota) } -func (q *QuotaService) SaveTempQuota(ctx context.Context, quota domain.TempQuota) error { - return q.repo.SaveTempQuota(ctx, quota) +func (q *QuotaService) CreateTempQuota(ctx context.Context, quota domain.TempQuota) error { + return q.repo.CreateTempQuota(ctx, quota) } func (q *QuotaService) GetTempQuota(ctx context.Context, uid int64) ([]domain.TempQuota, error) { diff --git a/internal/test/quota_test.go b/internal/test/quota_test.go index 2e59306..fbee4f2 100644 --- a/internal/test/quota_test.go +++ b/internal/test/quota_test.go @@ -16,6 +16,7 @@ package test import ( "bytes" + "encoding/json" "net/http" "net/http/httptest" "testing" @@ -118,6 +119,14 @@ func (q *QuotaSuite) TestQuotaSave() { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { + // 每个测试用例开始时清理数据 + err := q.db.Exec("TRUNCATE TABLE quotas").Error + require.NoError(t, err) + err = q.db.Exec("TRUNCATE TABLE quota_records").Error + require.NoError(t, err) + err = q.db.Exec("TRUNCATE TABLE temp_quotas").Error + require.NoError(t, err) + tc.before() req, err := http.NewRequest(http.MethodPost, "/quota/save", bytes.NewBuffer([]byte(tc.reqBody))) require.NoError(t, err) @@ -140,13 +149,12 @@ func (q *QuotaSuite) TestSaveTempQuota() { testcases := []struct { name string - reqBody string before func() after func() + reqBody string }{ { - name: "save temp", - reqBody: `{"amount": 100000, "key": "23911", "start_time": "123", "end_time": "456"}`, + name: "创建临时额度", before: func() { sess := mocks.NewMockSession(ctrl) sess.EXPECT().Claims().Return(session.Claims{ @@ -158,15 +166,26 @@ func (q *QuotaSuite) TestSaveTempQuota() { }, after: func() { var quota dao.TempQuota - err := q.db.Where("id = ?", 1).First("a).Error + err := q.db.Where("uid = ? AND `key` = ?", 1, "temp_key_1").First("a).Error require.NoError(t, err) assert.Equal(t, int64(100000), quota.Amount) + assert.Equal(t, int64(123), quota.StartTime) + assert.Equal(t, int64(456), quota.EndTime) }, + reqBody: `{"amount": 100000, "key": "temp_key_1", "start_time": 123, "end_time": 456}`, }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { + // 每个测试用例开始时清理数据 + err := q.db.Exec("TRUNCATE TABLE quotas").Error + require.NoError(t, err) + err = q.db.Exec("TRUNCATE TABLE quota_records").Error + require.NoError(t, err) + err = q.db.Exec("TRUNCATE TABLE temp_quotas").Error + require.NoError(t, err) + tc.before() req, err := http.NewRequest(http.MethodPost, "/tmp/save", bytes.NewBuffer([]byte(tc.reqBody))) require.NoError(t, err) @@ -183,19 +202,19 @@ func (q *QuotaSuite) TestSaveTempQuota() { func (q *QuotaSuite) TestDeduct() { t := q.T() + ctrl := gomock.NewController(t) defer ctrl.Finish() testcases := []struct { name string + before func() + after func() reqBody string - before func(db *gorm.DB, server *gin.Engine) - after func(db *gorm.DB) }{ { - name: "deduct quota", - reqBody: `{"amount": 10, "key": "23911"}`, - before: func(db *gorm.DB, server *gin.Engine) { + name: "从主额度扣减", + before: func() { sess := mocks.NewMockSession(ctrl) sess.EXPECT().Claims().Return(session.Claims{ Uid: 1, @@ -204,20 +223,29 @@ func (q *QuotaSuite) TestDeduct() { session.SetDefaultProvider(provider) provider.EXPECT().Get(gomock.Any()).Return(sess, nil) - quota := dao.Quota{Amount: 20, Key: "23911", UID: 1} - db.Create("a) + // 创建主额度 + quota := dao.Quota{Amount: 100, Key: "main_quota", UID: 1} + err := q.db.Create("a).Error + require.NoError(t, err) }, - after: func(db *gorm.DB) { + after: func() { + // 验证主额度被扣减 var quota dao.Quota - err := db.Where("id = ?", 1).First("a).Error + err := q.db.Where("uid = ? AND `key` = ?", 1, "main_quota").First("a).Error + require.NoError(t, err) + assert.Equal(t, int64(80), quota.Amount) + + // 验证扣减记录被创建 + var record dao.QuotaRecord + err = q.db.Where("uid = ? AND `key` = ?", 1, "deduct_key_1").First(&record).Error require.NoError(t, err) - assert.Equal(t, int64(10), quota.Amount) + assert.Equal(t, int64(20), record.Amount) }, + reqBody: `{"amount": 20, "key": "deduct_key_1"}`, }, { - name: "deduct temp quota", - reqBody: `{"amount": 10, "key": "23922"}`, - before: func(db *gorm.DB, server *gin.Engine) { + name: "从临时额度扣减", + before: func() { sess := mocks.NewMockSession(ctrl) sess.EXPECT().Claims().Return(session.Claims{ Uid: 1, @@ -226,56 +254,145 @@ func (q *QuotaSuite) TestDeduct() { session.SetDefaultProvider(provider) provider.EXPECT().Get(gomock.Any()).Return(sess, nil) - quota := dao.TempQuota{Amount: 20, Key: "23922", UID: 1, StartTime: time.Now().Unix(), EndTime: time.Now().Add(24 * time.Hour).Unix()} - err := db.Create("a).Error + // 创建临时额度 + now := time.Now().Unix() + tempQuota := dao.TempQuota{ + Amount: 50, + Key: "temp_quota_1", + UID: 1, + StartTime: now, + EndTime: now + 24*3600, + } + err := q.db.Create(&tempQuota).Error require.NoError(t, err) }, - after: func(db *gorm.DB) { - var quota dao.TempQuota - err := db.Where("id = ?", 1).First("a).Error + after: func() { + // 验证临时额度被扣减 + var tempQuota dao.TempQuota + err := q.db.Where("uid = ? AND `key` = ?", 1, "temp_quota_1").First(&tempQuota).Error + require.NoError(t, err) + assert.Equal(t, int64(30), tempQuota.Amount) + + // 验证扣减记录被创建 + var record dao.QuotaRecord + err = q.db.Where("uid = ? AND `key` = ?", 1, "deduct_key_2").First(&record).Error require.NoError(t, err) - assert.Equal(t, int64(10), quota.Amount) + assert.Equal(t, int64(20), record.Amount) }, + reqBody: `{"amount": 20, "key": "deduct_key_2"}`, + }, + { + name: "优先从临时额度扣减,不足再从主额度扣减", + before: func() { + sess := mocks.NewMockSession(ctrl) + sess.EXPECT().Claims().Return(session.Claims{ + Uid: 1, + }).AnyTimes() + provider := mocks.NewMockProvider(ctrl) + session.SetDefaultProvider(provider) + provider.EXPECT().Get(gomock.Any()).Return(sess, nil) + + // 创建主额度 + quota := dao.Quota{Amount: 100, Key: "main_quota_2", UID: 1} + err := q.db.Create("a).Error + require.NoError(t, err) + + // 创建临时额度(金额不足) + now := time.Now().Unix() + tempQuota := dao.TempQuota{ + Amount: 10, + Key: "temp_quota_2", + UID: 1, + StartTime: now, + EndTime: now + 24*3600, + } + err = q.db.Create(&tempQuota).Error + require.NoError(t, err) + }, + after: func() { + // 验证临时额度被完全扣减 + var tempQuota dao.TempQuota + err := q.db.Where("uid = ? AND `key` = ?", 1, "temp_quota_2").First(&tempQuota).Error + require.NoError(t, err) + assert.Equal(t, int64(0), tempQuota.Amount) + + // 验证主额度被扣减 + var quota dao.Quota + err = q.db.Where("uid = ? AND `key` = ?", 1, "main_quota_2").First("a).Error + require.NoError(t, err) + assert.Equal(t, int64(90), quota.Amount) + + // 验证扣减记录被创建 + var record dao.QuotaRecord + err = q.db.Where("uid = ? AND `key` = ?", 1, "deduct_key_3").First(&record).Error + require.NoError(t, err) + assert.Equal(t, int64(30), record.Amount) + }, + reqBody: `{"amount": 30, "key": "deduct_key_3"}`, + }, + { + name: "扣减失败 - 余额不足", + before: func() { + sess := mocks.NewMockSession(ctrl) + sess.EXPECT().Claims().Return(session.Claims{ + Uid: 1, + }).AnyTimes() + provider := mocks.NewMockProvider(ctrl) + session.SetDefaultProvider(provider) + provider.EXPECT().Get(gomock.Any()).Return(sess, nil) + + // 创建少量主额度 + quota := dao.Quota{Amount: 10, Key: "main_quota_3", UID: 1} + err := q.db.Create("a).Error + require.NoError(t, err) + }, + after: func() { + // 验证主额度没有被扣减 + var quota dao.Quota + err := q.db.Where("uid = ? AND `key` = ?", 1, "main_quota_3").First("a).Error + require.NoError(t, err) + assert.Equal(t, int64(10), quota.Amount) // 额度应该保持不变 + + // 验证扣减记录没有被创建(因为事务回滚) + var record dao.QuotaRecord + err = q.db.Where("uid = ? AND `key` = ?", 1, "deduct_key_4").First(&record).Error + assert.Error(t, err) // 应该找不到记录,因为事务回滚了 + }, + reqBody: `{"amount": 50, "key": "deduct_key_4"}`, }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - // 独立初始化 DB 和 Gin Engine - dbConfig := config.NewConfig( - config.WithDBName("ai_gateway_platform"), - config.WithUserName("root"), - config.WithPassword("root"), - config.WithHost("127.0.0.1"), - config.WithPort("13306"), - ) - db, err := config.NewDB(dbConfig) + // 每个测试用例开始时清理数据 + err := q.db.Exec("TRUNCATE TABLE quotas").Error + require.NoError(t, err) + err = q.db.Exec("TRUNCATE TABLE quota_records").Error require.NoError(t, err) - err = dao.InitQuotaTable(db) + err = q.db.Exec("TRUNCATE TABLE temp_quotas").Error require.NoError(t, err) - d := dao.NewQuotaDao(db) - repo := repository.NewQuotaRepo(d) - svc := service.NewQuotaService(repo) - handler := web.NewQuotaHandler(svc) - server := gin.Default() - handler.PrivateRoutes(server) - - // 清理表 - defer func() { - db.Exec("TRUNCATE TABLE quotas") - db.Exec("TRUNCATE TABLE quota_records") - db.Exec("TRUNCATE TABLE temp_quotas") - }() - - tc.before(db, server) + tc.before() req, err := http.NewRequest(http.MethodPost, "/deduct", bytes.NewBuffer([]byte(tc.reqBody))) require.NoError(t, err) req.Header.Set("Content-Type", "application/json") resp := httptest.NewRecorder() - server.ServeHTTP(resp, req) - assert.Equal(t, http.StatusOK, resp.Code) - tc.after(db) + q.server.ServeHTTP(resp, req) + + if tc.name == "扣减失败 - 余额不足" { + assert.Equal(t, http.StatusOK, resp.Code) // HTTP状态码仍然是200 + + var response map[string]interface{} + err = json.Unmarshal(resp.Body.Bytes(), &response) + require.NoError(t, err) + + assert.Equal(t, float64(400002), response["code"]) + assert.Equal(t, "余额不足", response["msg"]) + } else { + assert.Equal(t, http.StatusOK, resp.Code) + } + + tc.after() }) } } diff --git a/internal/web/quota.go b/internal/web/quota.go index 5add0fc..0be1054 100644 --- a/internal/web/quota.go +++ b/internal/web/quota.go @@ -15,8 +15,9 @@ package web import ( - "time" + "errors" + "github.com/ecodeclub/ai-gateway-go/errs" "github.com/ecodeclub/ai-gateway-go/internal/domain" "github.com/ecodeclub/ai-gateway-go/internal/service" "github.com/ecodeclub/ekit/slice" @@ -33,44 +34,45 @@ func NewQuotaHandler(svc *service.QuotaService) *QuotaHandler { return &QuotaHandler{svc: svc} } -func (q *QuotaHandler) PublicRoutes(_ *gin.Engine) {} - func (q *QuotaHandler) PrivateRoutes(server *gin.Engine) { group := server.Group("/quota") - group.POST("/save", ginx.BS(q.SaveQuota)) + group.POST("/save", ginx.BS(q.AddQuota)) group.POST("/get", ginx.S(q.GetQuota)) tmp := server.Group("/tmp") - tmp.POST("/save", ginx.BS(q.SaveTempQuota)) + tmp.POST("/save", ginx.BS(q.CreateTempQuota)) tmp.POST("/get", ginx.S(q.GetTempQuota)) server.POST("/deduct", ginx.BS(q.Deduct)) } -func (q *QuotaHandler) SaveQuota(ctx *ginx.Context, req QuotaRequest, sess session.Session) (ginx.Result, error) { +func (q *QuotaHandler) AddQuota(ctx *ginx.Context, req QuotaRequest, sess session.Session) (ginx.Result, error) { uid := sess.Claims().Uid - err := q.svc.SaveQuota(ctx, domain.Quota{Amount: req.Amount, Uid: uid, Key: req.Key}) + err := q.svc.AddQuota(ctx, domain.Quota{Amount: req.Amount, Uid: uid, Key: req.Key}) if err != nil { - return systemErrorResult, nil + return systemErrorResult, err } return ginx.Result{ Msg: "OK", }, nil } -func (q *QuotaHandler) SaveTempQuota(ctx *ginx.Context, req QuotaRequest, sess session.Session) (ginx.Result, error) { +func (q *QuotaHandler) CreateTempQuota(ctx *ginx.Context, req QuotaRequest, sess session.Session) (ginx.Result, error) { uid := sess.Claims().Uid - if req.StartTime == "" || req.EndTime == "" { - return systemErrorResult, nil + if req.StartTime == 0 || req.EndTime == 0 { + return invalidParamResult, errs.ErrInvalidParam } - start, _ := q.toTimestamp(req.StartTime) - end, _ := q.toTimestamp(req.EndTime) - - err := q.svc.SaveTempQuota(ctx, domain.TempQuota{Amount: req.Amount, Uid: uid, StartTime: start, EndTime: end}) + err := q.svc.CreateTempQuota(ctx, domain.TempQuota{ + Amount: req.Amount, + Key: req.Key, + Uid: uid, + StartTime: req.StartTime, + EndTime: req.EndTime, + }) if err != nil { - return systemErrorResult, nil + return systemErrorResult, err } return ginx.Result{ Msg: "OK", @@ -106,26 +108,22 @@ func (q *QuotaHandler) Deduct(ctx *ginx.Context, req QuotaRequest, sees session. uid := sees.Claims().Uid err := q.svc.Deduct(ctx, uid, req.Amount, req.Key) if err != nil { + // 检查是否是余额不足错误 + if errors.Is(err, errs.ErrInsufficientBalance) { + return insufficientBalanceResult, nil + } + // 其他系统错误 return systemErrorResult, nil } return ginx.Result{Msg: "OK"}, nil } -func (q *QuotaHandler) toTimestamp(timeStr string) (int64, error) { - const layout = "2006-01-02 15:04:05" - t, err := time.Parse(layout, timeStr) - if err != nil { - return 0, err - } - return t.Unix(), nil -} - func (q *QuotaHandler) toQuotaResponse(tempQuotaList []domain.TempQuota) []QuotaResponse { return slice.Map[domain.TempQuota, QuotaResponse](tempQuotaList, func(idx int, src domain.TempQuota) QuotaResponse { return QuotaResponse{ Amount: src.Amount, - StartTime: time.Unix(src.StartTime, 0).Format("2006-01-02 15:04:05"), - EndTime: time.Unix(src.EndTime, 0).Format("2006-01-02 15:04:05"), + StartTime: src.StartTime, + EndTime: src.EndTime, } }) } @@ -133,13 +131,13 @@ func (q *QuotaHandler) toQuotaResponse(tempQuotaList []domain.TempQuota) []Quota type QuotaRequest struct { Amount int64 `json:"amount,omitempty"` Key string `json:"key,omitempty"` - StartTime string `json:"start_time,omitempty"` - EndTime string `json:"end_time,omitempty"` + StartTime int64 `json:"start_time,omitempty"` + EndTime int64 `json:"end_time,omitempty"` } type QuotaResponse struct { Amount int64 `json:"amount,omitempty"` Key string `json:"key"` - StartTime string `json:"start_time,omitempty"` - EndTime string `json:"end_time,omitempty"` + StartTime int64 `json:"start_time,omitempty"` + EndTime int64 `json:"end_time,omitempty"` } diff --git a/internal/web/result.go b/internal/web/result.go index fad3de7..411a466 100644 --- a/internal/web/result.go +++ b/internal/web/result.go @@ -23,3 +23,13 @@ var systemErrorResult = ginx.Result{ Code: errs.SystemError.Code, Msg: errs.SystemError.Msg, } + +var invalidParamResult = ginx.Result{ + Code: errs.InvalidParamError.Code, + Msg: errs.InvalidParamError.Msg, +} + +var insufficientBalanceResult = ginx.Result{ + Code: errs.InsufficientBalanceError.Code, + Msg: errs.InsufficientBalanceError.Msg, +} From 12cad22ed28fd7cf2eebdeca3f15dff6031710ef Mon Sep 17 00:00:00 2001 From: yumosx Date: Mon, 14 Jul 2025 19:08:26 +0800 Subject: [PATCH 23/24] rm fmt --- Makefile | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/Makefile b/Makefile index 4bc4d4d..277c7d4 100644 --- a/Makefile +++ b/Makefile @@ -28,13 +28,12 @@ e2e: @go test -race -v -failfast ./... @docker compose -f ./.script/docker-compose.yaml down +# .PHONY: fmt +# fmt: +# @goimports -l -w $(GOFILES) +# @gofumpt -l -w $(GOFILES) -.PHONY: fmt -fmt: - @goimports -l -w $(GOFILES) - @gofumpt -l -w $(GOFILES) - -.PHONY: lint +.PHONY: lint lint: @golangci-lint run -c .golangci.yml @@ -44,7 +43,6 @@ tidy: .PHONY: check check: - @$(MAKE) fmt @$(MAKE) tidy # 生成gRPC相关文件 From 8ac1412cc6945a9c9ab2f34dc00c13ed047d1068 Mon Sep 17 00:00:00 2001 From: yumosx Date: Mon, 14 Jul 2025 19:22:35 +0800 Subject: [PATCH 24/24] fix --- .github/workflows/go-fmt.yml | 3 +++ Makefile | 21 ++++++++------------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/.github/workflows/go-fmt.yml b/.github/workflows/go-fmt.yml index e16877c..eefc0d5 100644 --- a/.github/workflows/go-fmt.yml +++ b/.github/workflows/go-fmt.yml @@ -34,6 +34,9 @@ jobs: with: go-version: "1.24.2" + - name: Install goimports + run: go install golang.org/x/tools/cmd/goimports@latest + - name: Check run: | make check diff --git a/Makefile b/Makefile index 277c7d4..ee7b0aa 100644 --- a/Makefile +++ b/Makefile @@ -1,10 +1,3 @@ -GOFILES=$(shell find . -type f -name '*.go' \ - -not -path "./vendor/*" \ - -not -path "./third_party/*" \ - -not -path "./.idea/*" \ - -not -name '*.pb.go' \ - -not -name '*mock*.go') - .PHONY: bench bench: @go test -bench=. -benchmem ./... @@ -28,12 +21,12 @@ e2e: @go test -race -v -failfast ./... @docker compose -f ./.script/docker-compose.yaml down -# .PHONY: fmt -# fmt: -# @goimports -l -w $(GOFILES) -# @gofumpt -l -w $(GOFILES) -.PHONY: lint +.PHONY: fmt +fmt: + @goimports -l -w $$(find . -type f -name '*.go' -not -path "./.idea/*" -not -name '*.pb.go' -not -name '*mock*.go') + +.PHONY: lint lint: @golangci-lint run -c .golangci.yml @@ -43,11 +36,13 @@ tidy: .PHONY: check check: + @$(MAKE) fmt @$(MAKE) tidy + #@$(MAKE) lint # 生成gRPC相关文件 .PHONY: grpc grpc: @buf format -w api/proto # @buf lint api/proto - @buf generate api/proto + @buf generate api/proto \ No newline at end of file