diff --git a/Makefile b/Makefile index d58046d..ee7b0aa 100644 --- a/Makefile +++ b/Makefile @@ -45,4 +45,4 @@ check: grpc: @buf format -w api/proto # @buf lint api/proto - @buf generate api/proto + @buf generate api/proto \ No newline at end of file diff --git a/api/gen/ai/v1/ai.pb.go b/api/gen/ai/v1/ai.pb.go new file mode 100644 index 0000000..9d70781 --- /dev/null +++ b/api/gen/ai/v1/ai.pb.go @@ -0,0 +1,743 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.6 +// protoc (unknown) +// source: ai.proto + +package aiv1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Role int32 + +const ( + Role_UNKNOWN Role = 0 + Role_USER Role = 1 + Role_ASSISTANT Role = 2 + Role_SYSTEM Role = 3 + Role_TOOL Role = 4 +) + +// Enum value maps for Role. +var ( + Role_name = map[int32]string{ + 0: "UNKNOWN", + 1: "USER", + 2: "ASSISTANT", + 3: "SYSTEM", + 4: "TOOL", + } + Role_value = map[string]int32{ + "UNKNOWN": 0, + "USER": 1, + "ASSISTANT": 2, + "SYSTEM": 3, + "TOOL": 4, + } +) + +func (x Role) Enum() *Role { + p := new(Role) + *p = x + return p +} + +func (x Role) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (Role) Descriptor() protoreflect.EnumDescriptor { + return file_ai_proto_enumTypes[0].Descriptor() +} + +func (Role) Type() protoreflect.EnumType { + return &file_ai_proto_enumTypes[0] +} + +func (x Role) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use Role.Descriptor instead. +func (Role) EnumDescriptor() ([]byte, []int) { + return file_ai_proto_rawDescGZIP(), []int{0} +} + +type StreamEvent struct { + state protoimpl.MessageState `protogen:"open.v1"` + Final bool `protobuf:"varint,1,opt,name=final,proto3" json:"final,omitempty"` + ReasoningContent string `protobuf:"bytes,2,opt,name=reasoningContent,proto3" json:"reasoningContent,omitempty"` + Content string `protobuf:"bytes,3,opt,name=content,proto3" json:"content,omitempty"` + Err string `protobuf:"bytes,4,opt,name=err,proto3" json:"err,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StreamEvent) Reset() { + *x = StreamEvent{} + mi := &file_ai_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StreamEvent) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StreamEvent) ProtoMessage() {} + +func (x *StreamEvent) ProtoReflect() protoreflect.Message { + mi := &file_ai_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StreamEvent.ProtoReflect.Descriptor instead. +func (*StreamEvent) Descriptor() ([]byte, []int) { + return file_ai_proto_rawDescGZIP(), []int{0} +} + +func (x *StreamEvent) GetFinal() bool { + if x != nil { + return x.Final + } + return false +} + +func (x *StreamEvent) GetReasoningContent() string { + if x != nil { + return x.ReasoningContent + } + return "" +} + +func (x *StreamEvent) GetContent() string { + if x != nil { + return x.Content + } + return "" +} + +func (x *StreamEvent) GetErr() string { + if x != nil { + return x.Err + } + return "" +} + +type Conversation struct { + state protoimpl.MessageState `protogen:"open.v1"` + Sn string `protobuf:"bytes,1,opt,name=sn,proto3" json:"sn,omitempty"` + Uid string `protobuf:"bytes,2,opt,name=uid,proto3" json:"uid,omitempty"` + Title string `protobuf:"bytes,3,opt,name=title,proto3" json:"title,omitempty"` + Message []*Message `protobuf:"bytes,4,rep,name=message,proto3" json:"message,omitempty"` + Ctime string `protobuf:"bytes,5,opt,name=ctime,proto3" json:"ctime,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Conversation) Reset() { + *x = Conversation{} + mi := &file_ai_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Conversation) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Conversation) ProtoMessage() {} + +func (x *Conversation) ProtoReflect() protoreflect.Message { + mi := &file_ai_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Conversation.ProtoReflect.Descriptor instead. +func (*Conversation) Descriptor() ([]byte, []int) { + return file_ai_proto_rawDescGZIP(), []int{1} +} + +func (x *Conversation) GetSn() string { + if x != nil { + return x.Sn + } + return "" +} + +func (x *Conversation) GetUid() string { + if x != nil { + return x.Uid + } + return "" +} + +func (x *Conversation) GetTitle() string { + if x != nil { + return x.Title + } + return "" +} + +func (x *Conversation) GetMessage() []*Message { + if x != nil { + return x.Message + } + return nil +} + +func (x *Conversation) GetCtime() string { + if x != nil { + return x.Ctime + } + return "" +} + +type ListReq struct { + state protoimpl.MessageState `protogen:"open.v1"` + Uid string `protobuf:"bytes,1,opt,name=uid,proto3" json:"uid,omitempty"` + Offset int64 `protobuf:"varint,2,opt,name=offset,proto3" json:"offset,omitempty"` + Limit int64 `protobuf:"varint,3,opt,name=limit,proto3" json:"limit,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ListReq) Reset() { + *x = ListReq{} + mi := &file_ai_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ListReq) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListReq) ProtoMessage() {} + +func (x *ListReq) ProtoReflect() protoreflect.Message { + mi := &file_ai_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListReq.ProtoReflect.Descriptor instead. +func (*ListReq) Descriptor() ([]byte, []int) { + return file_ai_proto_rawDescGZIP(), []int{2} +} + +func (x *ListReq) GetUid() string { + if x != nil { + return x.Uid + } + return "" +} + +func (x *ListReq) GetOffset() int64 { + if x != nil { + return x.Offset + } + return 0 +} + +func (x *ListReq) GetLimit() int64 { + if x != nil { + return x.Limit + } + return 0 +} + +type ListResp struct { + state protoimpl.MessageState `protogen:"open.v1"` + Conversations []*Conversation `protobuf:"bytes,1,rep,name=conversations,proto3" json:"conversations,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ListResp) Reset() { + *x = ListResp{} + mi := &file_ai_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ListResp) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListResp) ProtoMessage() {} + +func (x *ListResp) ProtoReflect() protoreflect.Message { + mi := &file_ai_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListResp.ProtoReflect.Descriptor instead. +func (*ListResp) Descriptor() ([]byte, []int) { + return file_ai_proto_rawDescGZIP(), []int{3} +} + +func (x *ListResp) GetConversations() []*Conversation { + if x != nil { + return x.Conversations + } + return nil +} + +type LLMRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Sn string `protobuf:"bytes,1,opt,name=sn,proto3" json:"sn,omitempty"` + Message []*Message `protobuf:"bytes,2,rep,name=message,proto3" json:"message,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *LLMRequest) Reset() { + *x = LLMRequest{} + mi := &file_ai_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *LLMRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*LLMRequest) ProtoMessage() {} + +func (x *LLMRequest) ProtoReflect() protoreflect.Message { + mi := &file_ai_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use LLMRequest.ProtoReflect.Descriptor instead. +func (*LLMRequest) Descriptor() ([]byte, []int) { + return file_ai_proto_rawDescGZIP(), []int{4} +} + +func (x *LLMRequest) GetSn() string { + if x != nil { + return x.Sn + } + return "" +} + +func (x *LLMRequest) GetMessage() []*Message { + if x != nil { + return x.Message + } + return nil +} + +type DetailRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Sn string `protobuf:"bytes,1,opt,name=sn,proto3" json:"sn,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DetailRequest) Reset() { + *x = DetailRequest{} + mi := &file_ai_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DetailRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DetailRequest) ProtoMessage() {} + +func (x *DetailRequest) ProtoReflect() protoreflect.Message { + mi := &file_ai_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DetailRequest.ProtoReflect.Descriptor instead. +func (*DetailRequest) Descriptor() ([]byte, []int) { + return file_ai_proto_rawDescGZIP(), []int{5} +} + +func (x *DetailRequest) GetSn() string { + if x != nil { + return x.Sn + } + return "" +} + +type DetailResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Message []*Message `protobuf:"bytes,2,rep,name=message,proto3" json:"message,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DetailResponse) Reset() { + *x = DetailResponse{} + mi := &file_ai_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DetailResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DetailResponse) ProtoMessage() {} + +func (x *DetailResponse) ProtoReflect() protoreflect.Message { + mi := &file_ai_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DetailResponse.ProtoReflect.Descriptor instead. +func (*DetailResponse) Descriptor() ([]byte, []int) { + return file_ai_proto_rawDescGZIP(), []int{6} +} + +func (x *DetailResponse) GetMessage() []*Message { + if x != nil { + return x.Message + } + return nil +} + +type Message struct { + state protoimpl.MessageState `protogen:"open.v1"` + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + Role Role `protobuf:"varint,2,opt,name=role,proto3,enum=ai.v1.Role" json:"role,omitempty"` + Content string `protobuf:"bytes,3,opt,name=content,proto3" json:"content,omitempty"` + ReasoningContent string `protobuf:"bytes,4,opt,name=reasoningContent,proto3" json:"reasoningContent,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Message) Reset() { + *x = Message{} + mi := &file_ai_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Message) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Message) ProtoMessage() {} + +func (x *Message) ProtoReflect() protoreflect.Message { + mi := &file_ai_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Message.ProtoReflect.Descriptor instead. +func (*Message) Descriptor() ([]byte, []int) { + return file_ai_proto_rawDescGZIP(), []int{7} +} + +func (x *Message) GetId() string { + if x != nil { + return x.Id + } + return "" +} + +func (x *Message) GetRole() Role { + if x != nil { + return x.Role + } + return Role_UNKNOWN +} + +func (x *Message) GetContent() string { + if x != nil { + return x.Content + } + return "" +} + +func (x *Message) GetReasoningContent() string { + if x != nil { + return x.ReasoningContent + } + return "" +} + +type ChatResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Sn string `protobuf:"bytes,1,opt,name=sn,proto3" json:"sn,omitempty"` + Response *Message `protobuf:"bytes,2,opt,name=response,proto3" json:"response,omitempty"` + Metadata string `protobuf:"bytes,3,opt,name=metadata,proto3" json:"metadata,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ChatResponse) Reset() { + *x = ChatResponse{} + mi := &file_ai_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ChatResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ChatResponse) ProtoMessage() {} + +func (x *ChatResponse) ProtoReflect() protoreflect.Message { + mi := &file_ai_proto_msgTypes[8] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ChatResponse.ProtoReflect.Descriptor instead. +func (*ChatResponse) Descriptor() ([]byte, []int) { + return file_ai_proto_rawDescGZIP(), []int{8} +} + +func (x *ChatResponse) GetSn() string { + if x != nil { + return x.Sn + } + return "" +} + +func (x *ChatResponse) GetResponse() *Message { + if x != nil { + return x.Response + } + return nil +} + +func (x *ChatResponse) GetMetadata() string { + if x != nil { + return x.Metadata + } + return "" +} + +var File_ai_proto protoreflect.FileDescriptor + +const file_ai_proto_rawDesc = "" + + "\n" + + "\bai.proto\x12\x05ai.v1\"{\n" + + "\vStreamEvent\x12\x14\n" + + "\x05final\x18\x01 \x01(\bR\x05final\x12*\n" + + "\x10reasoningContent\x18\x02 \x01(\tR\x10reasoningContent\x12\x18\n" + + "\acontent\x18\x03 \x01(\tR\acontent\x12\x10\n" + + "\x03err\x18\x04 \x01(\tR\x03err\"\x86\x01\n" + + "\fConversation\x12\x0e\n" + + "\x02sn\x18\x01 \x01(\tR\x02sn\x12\x10\n" + + "\x03uid\x18\x02 \x01(\tR\x03uid\x12\x14\n" + + "\x05title\x18\x03 \x01(\tR\x05title\x12(\n" + + "\amessage\x18\x04 \x03(\v2\x0e.ai.v1.MessageR\amessage\x12\x14\n" + + "\x05ctime\x18\x05 \x01(\tR\x05ctime\"I\n" + + "\aListReq\x12\x10\n" + + "\x03uid\x18\x01 \x01(\tR\x03uid\x12\x16\n" + + "\x06offset\x18\x02 \x01(\x03R\x06offset\x12\x14\n" + + "\x05limit\x18\x03 \x01(\x03R\x05limit\"E\n" + + "\bListResp\x129\n" + + "\rconversations\x18\x01 \x03(\v2\x13.ai.v1.ConversationR\rconversations\"F\n" + + "\n" + + "LLMRequest\x12\x0e\n" + + "\x02sn\x18\x01 \x01(\tR\x02sn\x12(\n" + + "\amessage\x18\x02 \x03(\v2\x0e.ai.v1.MessageR\amessage\"\x1f\n" + + "\rDetailRequest\x12\x0e\n" + + "\x02sn\x18\x01 \x01(\tR\x02sn\":\n" + + "\x0eDetailResponse\x12(\n" + + "\amessage\x18\x02 \x03(\v2\x0e.ai.v1.MessageR\amessage\"\x80\x01\n" + + "\aMessage\x12\x0e\n" + + "\x02id\x18\x01 \x01(\tR\x02id\x12\x1f\n" + + "\x04role\x18\x02 \x01(\x0e2\v.ai.v1.RoleR\x04role\x12\x18\n" + + "\acontent\x18\x03 \x01(\tR\acontent\x12*\n" + + "\x10reasoningContent\x18\x04 \x01(\tR\x10reasoningContent\"f\n" + + "\fChatResponse\x12\x0e\n" + + "\x02sn\x18\x01 \x01(\tR\x02sn\x12*\n" + + "\bresponse\x18\x02 \x01(\v2\x0e.ai.v1.MessageR\bresponse\x12\x1a\n" + + "\bmetadata\x18\x03 \x01(\tR\bmetadata*B\n" + + "\x04Role\x12\v\n" + + "\aUNKNOWN\x10\x00\x12\b\n" + + "\x04USER\x10\x01\x12\r\n" + + "\tASSISTANT\x10\x02\x12\n" + + "\n" + + "\x06SYSTEM\x10\x03\x12\b\n" + + "\x04TOOL\x10\x042h\n" + + "\tAIService\x12+\n" + + "\x04Chat\x12\x0e.ai.v1.Message\x1a\x13.ai.v1.ChatResponse\x12.\n" + + "\x06Stream\x12\x0e.ai.v1.Message\x1a\x12.ai.v1.StreamEvent0\x012\x8c\x02\n" + + "\x13ConversationService\x122\n" + + "\x06Create\x12\x13.ai.v1.Conversation\x1a\x13.ai.v1.Conversation\x12'\n" + + "\x04List\x12\x0e.ai.v1.ListReq\x1a\x0f.ai.v1.ListResp\x12.\n" + + "\x04Chat\x12\x11.ai.v1.LLMRequest\x1a\x13.ai.v1.ChatResponse\x125\n" + + "\x06Detail\x12\x14.ai.v1.DetailRequest\x1a\x15.ai.v1.DetailResponse\x121\n" + + "\x06Stream\x12\x11.ai.v1.LLMRequest\x1a\x12.ai.v1.StreamEvent0\x01Bz\n" + + "\tcom.ai.v1B\aAiProtoP\x01Z/github.com/ecodeclub/ai-gateway-go/api/gen;aiv1\xa2\x02\x03AXX\xaa\x02\x05Ai.V1\xca\x02\x05Ai\\V1\xe2\x02\x11Ai\\V1\\GPBMetadata\xea\x02\x06Ai::V1b\x06proto3" + +var ( + file_ai_proto_rawDescOnce sync.Once + file_ai_proto_rawDescData []byte +) + +func file_ai_proto_rawDescGZIP() []byte { + file_ai_proto_rawDescOnce.Do(func() { + file_ai_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_ai_proto_rawDesc), len(file_ai_proto_rawDesc))) + }) + return file_ai_proto_rawDescData +} + +var file_ai_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_ai_proto_msgTypes = make([]protoimpl.MessageInfo, 9) +var file_ai_proto_goTypes = []any{ + (Role)(0), // 0: ai.v1.Role + (*StreamEvent)(nil), // 1: ai.v1.StreamEvent + (*Conversation)(nil), // 2: ai.v1.Conversation + (*ListReq)(nil), // 3: ai.v1.ListReq + (*ListResp)(nil), // 4: ai.v1.ListResp + (*LLMRequest)(nil), // 5: ai.v1.LLMRequest + (*DetailRequest)(nil), // 6: ai.v1.DetailRequest + (*DetailResponse)(nil), // 7: ai.v1.DetailResponse + (*Message)(nil), // 8: ai.v1.Message + (*ChatResponse)(nil), // 9: ai.v1.ChatResponse +} +var file_ai_proto_depIdxs = []int32{ + 8, // 0: ai.v1.Conversation.message:type_name -> ai.v1.Message + 2, // 1: ai.v1.ListResp.conversations:type_name -> ai.v1.Conversation + 8, // 2: ai.v1.LLMRequest.message:type_name -> ai.v1.Message + 8, // 3: ai.v1.DetailResponse.message:type_name -> ai.v1.Message + 0, // 4: ai.v1.Message.role:type_name -> ai.v1.Role + 8, // 5: ai.v1.ChatResponse.response:type_name -> ai.v1.Message + 8, // 6: ai.v1.AIService.Chat:input_type -> ai.v1.Message + 8, // 7: ai.v1.AIService.Stream:input_type -> ai.v1.Message + 2, // 8: ai.v1.ConversationService.Create:input_type -> ai.v1.Conversation + 3, // 9: ai.v1.ConversationService.List:input_type -> ai.v1.ListReq + 5, // 10: ai.v1.ConversationService.Chat:input_type -> ai.v1.LLMRequest + 6, // 11: ai.v1.ConversationService.Detail:input_type -> ai.v1.DetailRequest + 5, // 12: ai.v1.ConversationService.Stream:input_type -> ai.v1.LLMRequest + 9, // 13: ai.v1.AIService.Chat:output_type -> ai.v1.ChatResponse + 1, // 14: ai.v1.AIService.Stream:output_type -> ai.v1.StreamEvent + 2, // 15: ai.v1.ConversationService.Create:output_type -> ai.v1.Conversation + 4, // 16: ai.v1.ConversationService.List:output_type -> ai.v1.ListResp + 9, // 17: ai.v1.ConversationService.Chat:output_type -> ai.v1.ChatResponse + 7, // 18: ai.v1.ConversationService.Detail:output_type -> ai.v1.DetailResponse + 1, // 19: ai.v1.ConversationService.Stream:output_type -> ai.v1.StreamEvent + 13, // [13:20] is the sub-list for method output_type + 6, // [6:13] is the sub-list for method input_type + 6, // [6:6] is the sub-list for extension type_name + 6, // [6:6] is the sub-list for extension extendee + 0, // [0:6] is the sub-list for field type_name +} + +func init() { file_ai_proto_init() } +func file_ai_proto_init() { + if File_ai_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_ai_proto_rawDesc), len(file_ai_proto_rawDesc)), + NumEnums: 1, + NumMessages: 9, + NumExtensions: 0, + NumServices: 2, + }, + GoTypes: file_ai_proto_goTypes, + DependencyIndexes: file_ai_proto_depIdxs, + EnumInfos: file_ai_proto_enumTypes, + MessageInfos: file_ai_proto_msgTypes, + }.Build() + File_ai_proto = out.File + file_ai_proto_goTypes = nil + file_ai_proto_depIdxs = nil +} diff --git a/api/gen/ai/v1/ai_grpc.pb.go b/api/gen/ai/v1/ai_grpc.pb.go new file mode 100644 index 0000000..359a52f --- /dev/null +++ b/api/gen/ai/v1/ai_grpc.pb.go @@ -0,0 +1,435 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc (unknown) +// source: ai.proto + +package aiv1 + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + AIService_Chat_FullMethodName = "/ai.v1.AIService/Chat" + AIService_Stream_FullMethodName = "/ai.v1.AIService/Stream" +) + +// AIServiceClient is the client API for AIService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type AIServiceClient interface { + Chat(ctx context.Context, in *Message, opts ...grpc.CallOption) (*ChatResponse, error) + Stream(ctx context.Context, in *Message, opts ...grpc.CallOption) (grpc.ServerStreamingClient[StreamEvent], error) +} + +type aIServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewAIServiceClient(cc grpc.ClientConnInterface) AIServiceClient { + return &aIServiceClient{cc} +} + +func (c *aIServiceClient) Chat(ctx context.Context, in *Message, opts ...grpc.CallOption) (*ChatResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(ChatResponse) + err := c.cc.Invoke(ctx, AIService_Chat_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *aIServiceClient) Stream(ctx context.Context, in *Message, opts ...grpc.CallOption) (grpc.ServerStreamingClient[StreamEvent], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &AIService_ServiceDesc.Streams[0], AIService_Stream_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[Message, StreamEvent]{ClientStream: stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type AIService_StreamClient = grpc.ServerStreamingClient[StreamEvent] + +// AIServiceServer is the server API for AIService service. +// All implementations must embed UnimplementedAIServiceServer +// for forward compatibility. +type AIServiceServer interface { + Chat(context.Context, *Message) (*ChatResponse, error) + Stream(*Message, grpc.ServerStreamingServer[StreamEvent]) error + mustEmbedUnimplementedAIServiceServer() +} + +// UnimplementedAIServiceServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedAIServiceServer struct{} + +func (UnimplementedAIServiceServer) Chat(context.Context, *Message) (*ChatResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Chat not implemented") +} +func (UnimplementedAIServiceServer) Stream(*Message, grpc.ServerStreamingServer[StreamEvent]) error { + return status.Errorf(codes.Unimplemented, "method Stream not implemented") +} +func (UnimplementedAIServiceServer) mustEmbedUnimplementedAIServiceServer() {} +func (UnimplementedAIServiceServer) testEmbeddedByValue() {} + +// UnsafeAIServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to AIServiceServer will +// result in compilation errors. +type UnsafeAIServiceServer interface { + mustEmbedUnimplementedAIServiceServer() +} + +func RegisterAIServiceServer(s grpc.ServiceRegistrar, srv AIServiceServer) { + // If the following call pancis, it indicates UnimplementedAIServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&AIService_ServiceDesc, srv) +} + +func _AIService_Chat_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Message) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(AIServiceServer).Chat(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: AIService_Chat_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(AIServiceServer).Chat(ctx, req.(*Message)) + } + return interceptor(ctx, in, info, handler) +} + +func _AIService_Stream_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(Message) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(AIServiceServer).Stream(m, &grpc.GenericServerStream[Message, StreamEvent]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type AIService_StreamServer = grpc.ServerStreamingServer[StreamEvent] + +// AIService_ServiceDesc is the grpc.ServiceDesc for AIService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var AIService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "ai.v1.AIService", + HandlerType: (*AIServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Chat", + Handler: _AIService_Chat_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "Stream", + Handler: _AIService_Stream_Handler, + ServerStreams: true, + }, + }, + Metadata: "ai.proto", +} + +const ( + ConversationService_Create_FullMethodName = "/ai.v1.ConversationService/Create" + ConversationService_List_FullMethodName = "/ai.v1.ConversationService/List" + ConversationService_Chat_FullMethodName = "/ai.v1.ConversationService/Chat" + ConversationService_Detail_FullMethodName = "/ai.v1.ConversationService/Detail" + ConversationService_Stream_FullMethodName = "/ai.v1.ConversationService/Stream" +) + +// ConversationServiceClient is the client API for ConversationService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type ConversationServiceClient interface { + Create(ctx context.Context, in *Conversation, opts ...grpc.CallOption) (*Conversation, error) + List(ctx context.Context, in *ListReq, opts ...grpc.CallOption) (*ListResp, error) + Chat(ctx context.Context, in *LLMRequest, opts ...grpc.CallOption) (*ChatResponse, error) + Detail(ctx context.Context, in *DetailRequest, opts ...grpc.CallOption) (*DetailResponse, error) + Stream(ctx context.Context, in *LLMRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[StreamEvent], error) +} + +type conversationServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewConversationServiceClient(cc grpc.ClientConnInterface) ConversationServiceClient { + return &conversationServiceClient{cc} +} + +func (c *conversationServiceClient) Create(ctx context.Context, in *Conversation, opts ...grpc.CallOption) (*Conversation, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(Conversation) + err := c.cc.Invoke(ctx, ConversationService_Create_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *conversationServiceClient) List(ctx context.Context, in *ListReq, opts ...grpc.CallOption) (*ListResp, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(ListResp) + err := c.cc.Invoke(ctx, ConversationService_List_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *conversationServiceClient) Chat(ctx context.Context, in *LLMRequest, opts ...grpc.CallOption) (*ChatResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(ChatResponse) + err := c.cc.Invoke(ctx, ConversationService_Chat_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *conversationServiceClient) Detail(ctx context.Context, in *DetailRequest, opts ...grpc.CallOption) (*DetailResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(DetailResponse) + err := c.cc.Invoke(ctx, ConversationService_Detail_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *conversationServiceClient) Stream(ctx context.Context, in *LLMRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[StreamEvent], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &ConversationService_ServiceDesc.Streams[0], ConversationService_Stream_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[LLMRequest, StreamEvent]{ClientStream: stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type ConversationService_StreamClient = grpc.ServerStreamingClient[StreamEvent] + +// ConversationServiceServer is the server API for ConversationService service. +// All implementations must embed UnimplementedConversationServiceServer +// for forward compatibility. +type ConversationServiceServer interface { + Create(context.Context, *Conversation) (*Conversation, error) + List(context.Context, *ListReq) (*ListResp, error) + Chat(context.Context, *LLMRequest) (*ChatResponse, error) + Detail(context.Context, *DetailRequest) (*DetailResponse, error) + Stream(*LLMRequest, grpc.ServerStreamingServer[StreamEvent]) error + mustEmbedUnimplementedConversationServiceServer() +} + +// UnimplementedConversationServiceServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedConversationServiceServer struct{} + +func (UnimplementedConversationServiceServer) Create(context.Context, *Conversation) (*Conversation, error) { + return nil, status.Errorf(codes.Unimplemented, "method Create not implemented") +} +func (UnimplementedConversationServiceServer) List(context.Context, *ListReq) (*ListResp, error) { + return nil, status.Errorf(codes.Unimplemented, "method List not implemented") +} +func (UnimplementedConversationServiceServer) Chat(context.Context, *LLMRequest) (*ChatResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Chat not implemented") +} +func (UnimplementedConversationServiceServer) Detail(context.Context, *DetailRequest) (*DetailResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Detail not implemented") +} +func (UnimplementedConversationServiceServer) Stream(*LLMRequest, grpc.ServerStreamingServer[StreamEvent]) error { + return status.Errorf(codes.Unimplemented, "method Stream not implemented") +} +func (UnimplementedConversationServiceServer) mustEmbedUnimplementedConversationServiceServer() {} +func (UnimplementedConversationServiceServer) testEmbeddedByValue() {} + +// UnsafeConversationServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to ConversationServiceServer will +// result in compilation errors. +type UnsafeConversationServiceServer interface { + mustEmbedUnimplementedConversationServiceServer() +} + +func RegisterConversationServiceServer(s grpc.ServiceRegistrar, srv ConversationServiceServer) { + // If the following call pancis, it indicates UnimplementedConversationServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&ConversationService_ServiceDesc, srv) +} + +func _ConversationService_Create_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Conversation) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ConversationServiceServer).Create(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ConversationService_Create_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ConversationServiceServer).Create(ctx, req.(*Conversation)) + } + return interceptor(ctx, in, info, handler) +} + +func _ConversationService_List_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListReq) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ConversationServiceServer).List(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ConversationService_List_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ConversationServiceServer).List(ctx, req.(*ListReq)) + } + return interceptor(ctx, in, info, handler) +} + +func _ConversationService_Chat_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(LLMRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ConversationServiceServer).Chat(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ConversationService_Chat_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ConversationServiceServer).Chat(ctx, req.(*LLMRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ConversationService_Detail_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(DetailRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ConversationServiceServer).Detail(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ConversationService_Detail_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ConversationServiceServer).Detail(ctx, req.(*DetailRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ConversationService_Stream_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(LLMRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(ConversationServiceServer).Stream(m, &grpc.GenericServerStream[LLMRequest, StreamEvent]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type ConversationService_StreamServer = grpc.ServerStreamingServer[StreamEvent] + +// ConversationService_ServiceDesc is the grpc.ServiceDesc for ConversationService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var ConversationService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "ai.v1.ConversationService", + HandlerType: (*ConversationServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Create", + Handler: _ConversationService_Create_Handler, + }, + { + MethodName: "List", + Handler: _ConversationService_List_Handler, + }, + { + MethodName: "Chat", + Handler: _ConversationService_Chat_Handler, + }, + { + MethodName: "Detail", + Handler: _ConversationService_Detail_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "Stream", + Handler: _ConversationService_Stream_Handler, + ServerStreams: true, + }, + }, + Metadata: "ai.proto", +} diff --git a/api/proto/ai/v1/ai.proto b/api/proto/ai/v1/ai.proto index 6e540ec..2f8da2c 100644 --- a/api/proto/ai/v1/ai.proto +++ b/api/proto/ai/v1/ai.proto @@ -14,8 +14,7 @@ syntax = "proto3"; package ai.v1; - -option go_package = "v1/ai;aiv1"; +option go_package = "ai/v1;aiv1"; service AIService { rpc Chat(Message) returns (ChatResponse); diff --git a/api/proto/buf.yaml b/api/proto/buf.yaml deleted file mode 100644 index ee12108..0000000 --- a/api/proto/buf.yaml +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2021 ecodeclub -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -version: v2 -deps: - - buf.build/googleapis/googleapis -breaking: - use: - - FILE -lint: - use: - - STANDARD # Omit all Buf categories if you don't want to use Buf's built-in rules - except: - - ENUM_VALUE_PREFIX - - ENUM_ZERO_VALUE_SUFFIX - ignore: - - google/type/datetime.proto - - google/protobuf/empty.proto - - google/protobuf/timestamp.proto - disallow_comment_ignores: false # The default behavior of this key has changed from v1 - enum_zero_value_suffix: _UNSPECIFIED - rpc_allow_same_request_response: false - rpc_allow_google_protobuf_empty_requests: false - rpc_allow_google_protobuf_empty_responses: false - service_suffix: Service diff --git a/cmd/platform/server.go b/cmd/platform/server.go index b238a1c..1d3d4a8 100644 --- a/cmd/platform/server.go +++ b/cmd/platform/server.go @@ -16,7 +16,7 @@ package main import ( ds "github.com/cohesion-org/deepseek-go" - ai "github.com/ecodeclub/ai-gateway-go/api/proto/gen/ai/v1" + ai "github.com/ecodeclub/ai-gateway-go/api/gen/ai/v1" igrpc "github.com/ecodeclub/ai-gateway-go/internal/grpc" "github.com/ecodeclub/ai-gateway-go/internal/service" "github.com/ecodeclub/ai-gateway-go/internal/service/llm/platform/deepseek" diff --git a/errs/errs.go b/errs/errs.go index f4b0fea..80bf227 100644 --- a/errs/errs.go +++ b/errs/errs.go @@ -18,4 +18,8 @@ import ( "errors" ) -var ErrBizConfigNotFound = errors.New("biz config not found") +var ( + ErrBizConfigNotFound = errors.New("查询业务配置失败") + ErrInvalidParam = errors.New("参数错误") + ErrInsufficientBalance = errors.New("余额不足") +) diff --git a/internal/domain/quota.go b/internal/domain/quota.go new file mode 100644 index 0000000..1953893 --- /dev/null +++ b/internal/domain/quota.go @@ -0,0 +1,32 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package domain + +type Quota struct { + Amount int64 + Key string + Uid int64 + LastClearTime int64 +} + +type TempQuota struct { + Amount int64 + Key string + StartTime int64 + EndTime int64 + Uid int64 +} + +type Record struct{} diff --git a/internal/errs/code.go b/internal/errs/code.go index 615b5e3..41509d1 100644 --- a/internal/errs/code.go +++ b/internal/errs/code.go @@ -14,7 +14,11 @@ package errs -var SystemError = ErrorCode{Code: 501001, Msg: "系统错误"} +var ( + SystemError = ErrorCode{Code: 501001, Msg: "系统错误"} + InvalidParamError = ErrorCode{Code: 400001, Msg: "参数错误"} + InsufficientBalanceError = ErrorCode{Code: 400002, Msg: "余额不足"} +) type ErrorCode struct { Code int diff --git a/internal/grpc/conversation.go b/internal/grpc/conversation.go index 43b111e..f9aa755 100644 --- a/internal/grpc/conversation.go +++ b/internal/grpc/conversation.go @@ -17,7 +17,7 @@ package grpc import ( "context" - ai "github.com/ecodeclub/ai-gateway-go/api/proto/gen/ai/v1" + ai "github.com/ecodeclub/ai-gateway-go/api/gen/ai/v1" "github.com/ecodeclub/ai-gateway-go/internal/domain" "github.com/ecodeclub/ai-gateway-go/internal/service" "github.com/ecodeclub/ekit/slice" diff --git a/internal/grpc/server.go b/internal/grpc/server.go index d4eb458..856176e 100644 --- a/internal/grpc/server.go +++ b/internal/grpc/server.go @@ -18,7 +18,7 @@ import ( "context" "strconv" - ai "github.com/ecodeclub/ai-gateway-go/api/proto/gen/ai/v1" + ai "github.com/ecodeclub/ai-gateway-go/api/gen/ai/v1" "github.com/ecodeclub/ai-gateway-go/internal/domain" "github.com/ecodeclub/ai-gateway-go/internal/service" ) diff --git a/internal/repository/dao/quota.go b/internal/repository/dao/quota.go new file mode 100644 index 0000000..c88a7c1 --- /dev/null +++ b/internal/repository/dao/quota.go @@ -0,0 +1,202 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dao + +import ( + "context" + "errors" + "time" + + "github.com/ecodeclub/ai-gateway-go/errs" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +type TempQuota struct { + ID int64 `gorm:"primaryKey;autoIncrement;column:id"` + UID int64 `gorm:"column:uid"` + Key string `gorm:"column:key;uniqueIndex;type:varchar(256)"` + Amount int64 `gorm:"column:amount"` + StartTime int64 `gorm:"column:start_time"` + EndTime int64 `gorm:"column:end_time"` + Ctime int64 `gorm:"column:ctime"` + Utime int64 `gorm:"column:utime"` +} + +func (TempQuota) TableName() string { + return "temp_quotas" +} + +type QuotaRecord struct { + ID int64 `gorm:"primaryKey;autoIncrement;column:id"` + Uid int64 `gorm:"column:uid;index"` + Key string `gorm:"column:key;uniqueIndex;type:varchar(256)"` + Amount int64 `gorm:"column:amount"` + Ctime int64 `gorm:"column:ctime"` + Utime int64 `gorm:"column:utime"` +} + +func (QuotaRecord) TableName() string { + return "quota_records" +} + +type Quota struct { + ID int64 `gorm:"primaryKey;autoIncrement;column:id"` + UID int64 `gorm:"column:uid"` + Key string `gorm:"column:key;uniqueIndex;type:varchar(256)"` + Amount int64 `gorm:"column:amount"` + LastClearTime int64 `gorm:"column:last_clear_time"` + Ctime int64 `gorm:"column:ctime"` + Utime int64 `gorm:"column:utime"` +} + +func (Quota) TableName() string { + return "quotas" +} + +type QuotaDao struct { + db *gorm.DB +} + +func NewQuotaDao(db *gorm.DB) *QuotaDao { + return &QuotaDao{db: db} +} + +func (dao *QuotaDao) CreateTempQuota(ctx context.Context, quota TempQuota) error { + now := time.Now().Unix() + quota.Ctime = now + quota.Utime = now + return dao.db.WithContext(ctx).Create("a).Error +} + +func (dao *QuotaDao) AddQuota(ctx context.Context, quota Quota) error { + now := time.Now().Unix() + quota.Utime = now + + return dao.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + now := time.Now().Unix() + record := QuotaRecord{ + Key: quota.Key, + Uid: quota.UID, + Amount: quota.Amount, + Ctime: now, + Utime: now, + } + err := tx.Create(&record).Error + if err != nil { + return err + } + + return tx.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "key"}}, + DoUpdates: clause.Assignments(map[string]any{ + "amount": gorm.Expr("amount + ?", quota.Amount), + "utime": now, + }), + }).Create("a).Error + }) +} + +func (dao *QuotaDao) GetQuotaByUid(ctx context.Context, uid int64) (Quota, error) { + var quota Quota + err := dao.db.WithContext(ctx). + Where("uid = ? and end_time >= ?", uid). + First("a).Error + if err != nil { + return Quota{}, err + } + return quota, nil +} + +func (dao *QuotaDao) GetTempQuotaByUidAndTime(ctx context.Context, uid int64) ([]TempQuota, error) { + now := time.Now().Unix() + var quota []TempQuota + err := dao.db.WithContext(ctx). + Where("uid = ? and end_time >= ?", uid, now). + Order("end_time ASC"). + Find("a).Error + if err != nil { + return nil, err + } + return quota, nil +} + +func (dao *QuotaDao) Deduct(ctx context.Context, uid int64, amount int64, key string) error { + return dao.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + now := time.Now().Unix() + record := QuotaRecord{ + Key: key, + Uid: uid, + Amount: amount, + Ctime: now, + Utime: now, + } + err := tx.Create(&record).Error + if err != nil { + return err + } + return dao.deduct(tx, uid, amount, now) + }) +} + +func (dao *QuotaDao) deduct(tx *gorm.DB, uid int64, amount int64, now int64) error { + deductAmount := amount + for { + var quota TempQuota + err := tx.Where("amount > ? and uid = ?", 0, uid).First("a).Error + if err != nil { + // 表示找不到可以扣减的temp + if errors.Is(err, gorm.ErrRecordNotFound) { + break + } + return err + } + deductAmount = min(deductAmount, quota.Amount) + result := tx.Model(&TempQuota{}).Where("amount > ? and uid = ?", 0, uid).Updates(map[string]any{ + "amount": gorm.Expr("amount - ?", deductAmount), + "utime": now, + }) + if result.Error != nil { + return result.Error + } + // 并发问题, 直接下一个 + if result.RowsAffected == 0 { + continue + } + // 表示扣减完毕 + amount -= deductAmount + if amount <= 0 { + return nil + } + } + // 从主额度扣 + result := tx.Model(&Quota{}). + Where("uid = ? AND amount >= ?", uid, deductAmount). + Updates(map[string]any{ + "amount": gorm.Expr("amount - ?", deductAmount), + "utime": now, + }) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return errs.ErrInsufficientBalance + } + return nil +} + +func InitQuotaTable(db *gorm.DB) error { + return db.AutoMigrate(&Quota{}, &TempQuota{}, &QuotaRecord{}) +} diff --git a/internal/repository/quota.go b/internal/repository/quota.go new file mode 100644 index 0000000..c55ba38 --- /dev/null +++ b/internal/repository/quota.go @@ -0,0 +1,69 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package repository + +import ( + "context" + + "github.com/ecodeclub/ai-gateway-go/internal/domain" + "github.com/ecodeclub/ai-gateway-go/internal/repository/dao" + "github.com/ecodeclub/ekit/slice" +) + +type QuotaRepo struct { + dao *dao.QuotaDao +} + +func NewQuotaRepo(dao *dao.QuotaDao) *QuotaRepo { + return &QuotaRepo{dao: dao} +} + +func (q *QuotaRepo) AddQuota(ctx context.Context, quota domain.Quota) error { + return q.dao.AddQuota(ctx, dao.Quota{UID: quota.Uid, Amount: quota.Amount, Key: quota.Key}) +} + +func (q *QuotaRepo) CreateTempQuota(ctx context.Context, quota domain.TempQuota) error { + return q.dao.CreateTempQuota(ctx, dao.TempQuota{Amount: quota.Amount, StartTime: quota.StartTime, EndTime: quota.EndTime, Key: quota.Key, UID: quota.Uid}) +} + +func (q *QuotaRepo) GetQuota(ctx context.Context, uid int64) (domain.Quota, error) { + quota, err := q.dao.GetQuotaByUid(ctx, uid) + if err != nil { + return domain.Quota{}, err + } + return domain.Quota{Amount: quota.Amount, Uid: uid}, nil +} + +func (q *QuotaRepo) GetTempQuota(ctx context.Context, uid int64) ([]domain.TempQuota, error) { + tempQuotaList, err := q.dao.GetTempQuotaByUidAndTime(ctx, uid) + if err != nil { + return nil, err + } + return q.toDomainTempQuota(tempQuotaList), nil +} + +func (q *QuotaRepo) Deduct(ctx context.Context, uid int64, amount int64, key string) error { + return q.dao.Deduct(ctx, uid, amount, key) +} + +func (q *QuotaRepo) toDomainTempQuota(tmpQuotaList []dao.TempQuota) []domain.TempQuota { + return slice.Map[dao.TempQuota, domain.TempQuota](tmpQuotaList, func(idx int, src dao.TempQuota) domain.TempQuota { + return domain.TempQuota{ + Amount: src.Amount, + StartTime: src.StartTime, + EndTime: src.EndTime, + } + }) +} diff --git a/internal/service/quota.go b/internal/service/quota.go new file mode 100644 index 0000000..4ec1cf4 --- /dev/null +++ b/internal/service/quota.go @@ -0,0 +1,50 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + + "github.com/ecodeclub/ai-gateway-go/internal/domain" + "github.com/ecodeclub/ai-gateway-go/internal/repository" +) + +type QuotaService struct { + repo *repository.QuotaRepo +} + +func NewQuotaService(repo *repository.QuotaRepo) *QuotaService { + return &QuotaService{repo: repo} +} + +func (q *QuotaService) AddQuota(ctx context.Context, quota domain.Quota) error { + return q.repo.AddQuota(ctx, quota) +} + +func (q *QuotaService) CreateTempQuota(ctx context.Context, quota domain.TempQuota) error { + return q.repo.CreateTempQuota(ctx, quota) +} + +func (q *QuotaService) GetTempQuota(ctx context.Context, uid int64) ([]domain.TempQuota, error) { + return q.repo.GetTempQuota(ctx, uid) +} + +func (q *QuotaService) GetQuota(ctx context.Context, uid int64) (domain.Quota, error) { + return q.repo.GetQuota(ctx, uid) +} + +func (q *QuotaService) Deduct(ctx context.Context, uid int64, amount int64, key string) error { + return q.repo.Deduct(ctx, uid, amount, key) +} diff --git a/internal/test/conversation_test.go b/internal/test/conversation_test.go index eca8946..cdc3789 100644 --- a/internal/test/conversation_test.go +++ b/internal/test/conversation_test.go @@ -18,7 +18,7 @@ import ( "context" "testing" - aiv1 "github.com/ecodeclub/ai-gateway-go/api/proto/gen/ai/v1" + aiv1 "github.com/ecodeclub/ai-gateway-go/api/gen/ai/v1" "github.com/ecodeclub/ai-gateway-go/internal/domain" "github.com/ecodeclub/ai-gateway-go/internal/grpc" "github.com/ecodeclub/ai-gateway-go/internal/repository" @@ -288,7 +288,7 @@ func (c *ConversationSuite) TestDetail() { conversationService := service.NewConversationService(repo, handler) server := grpc.NewConversationServer(conversationService) tc.before() - detail, err := server.Detail(context.Background(), &aiv1.MsgListReq{Sn: "1"}) + detail, err := server.Detail(context.Background(), &aiv1.DetailRequest{Sn: "1"}) require.NoError(t, err) assert.ElementsMatch(t, detail.Message, []*aiv1.Message{ {Role: aiv1.Role_USER, Content: "user1"}, diff --git a/internal/test/mocks/mock_resp.go b/internal/test/mocks/mock_resp.go index 44344ab..5f1f31a 100644 --- a/internal/test/mocks/mock_resp.go +++ b/internal/test/mocks/mock_resp.go @@ -3,7 +3,7 @@ package mocks import ( "context" - ai "github.com/ecodeclub/ai-gateway-go/api/proto/gen/ai/v1" + ai "github.com/ecodeclub/ai-gateway-go/api/gen/ai/v1" ) type MockStreamServer struct { diff --git a/internal/test/quota_test.go b/internal/test/quota_test.go new file mode 100644 index 0000000..fbee4f2 --- /dev/null +++ b/internal/test/quota_test.go @@ -0,0 +1,398 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/ecodeclub/ai-gateway-go/internal/repository" + "github.com/ecodeclub/ai-gateway-go/internal/repository/dao" + "github.com/ecodeclub/ai-gateway-go/internal/service" + "github.com/ecodeclub/ai-gateway-go/internal/test/mocks" + "github.com/ecodeclub/ai-gateway-go/internal/web" + "github.com/ecodeclub/ginx/session" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/yumosx/got/pkg/config" + "go.uber.org/mock/gomock" + "gorm.io/gorm" +) + +type QuotaSuite struct { + suite.Suite + db *gorm.DB + server *gin.Engine +} + +func TestQuota(t *testing.T) { + suite.Run(t, &QuotaSuite{}) +} + +func (q *QuotaSuite) SetupSuite() { + dbConfig := config.NewConfig( + config.WithDBName("ai_gateway_platform"), + config.WithUserName("root"), + config.WithPassword("root"), + config.WithHost("127.0.0.1"), + config.WithPort("13306"), + ) + db, err := config.NewDB(dbConfig) + require.NoError(q.T(), err) + err = dao.InitQuotaTable(db) + require.NoError(q.T(), err) + q.db = db + + d := dao.NewQuotaDao(db) + repo := repository.NewQuotaRepo(d) + svc := service.NewQuotaService(repo) + handler := web.NewQuotaHandler(svc) + server := gin.Default() + handler.PrivateRoutes(server) + q.server = server +} + +func (q *QuotaSuite) TearDownTest() { + err := q.db.Exec("TRUNCATE TABLE quotas").Error + require.NoError(q.T(), err) + err = q.db.Exec("TRUNCATE TABLE quota_records").Error + require.NoError(q.T(), err) + err = q.db.Exec("TRUNCATE TABLE temp_quotas").Error + require.NoError(q.T(), err) +} + +func (q *QuotaSuite) TestQuotaSave() { + t := q.T() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + testcases := []struct { + name string + before func() + after func() + reqBody string + }{ + { + name: "创建一个 quota", + before: func() { + sess := mocks.NewMockSession(ctrl) + sess.EXPECT().Claims().Return(session.Claims{ + Uid: 1, + }).AnyTimes() + provider := mocks.NewMockProvider(ctrl) + session.SetDefaultProvider(provider) + provider.EXPECT().Get(gomock.Any()).Return(sess, nil) + }, + after: func() { + var quota dao.Quota + err := q.db.Where("id = ?", 1).First("a).Error + require.NoError(t, err) + assert.Equal(t, int64(100000), quota.Amount) + + var record dao.QuotaRecord + err = q.db.Where("id = ?", 1).First(&record).Error + require.NoError(t, err) + assert.Equal(t, "23911", record.Key) + }, + reqBody: `{"amount": 100000, "key": "23911"}`, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + // 每个测试用例开始时清理数据 + err := q.db.Exec("TRUNCATE TABLE quotas").Error + require.NoError(t, err) + err = q.db.Exec("TRUNCATE TABLE quota_records").Error + require.NoError(t, err) + err = q.db.Exec("TRUNCATE TABLE temp_quotas").Error + require.NoError(t, err) + + tc.before() + req, err := http.NewRequest(http.MethodPost, "/quota/save", bytes.NewBuffer([]byte(tc.reqBody))) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + q.server.ServeHTTP(resp, req) + + assert.Equal(t, http.StatusOK, resp.Code) + + tc.after() + }) + } +} + +func (q *QuotaSuite) TestSaveTempQuota() { + t := q.T() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + testcases := []struct { + name string + before func() + after func() + reqBody string + }{ + { + name: "创建临时额度", + before: func() { + sess := mocks.NewMockSession(ctrl) + sess.EXPECT().Claims().Return(session.Claims{ + Uid: 1, + }).AnyTimes() + provider := mocks.NewMockProvider(ctrl) + session.SetDefaultProvider(provider) + provider.EXPECT().Get(gomock.Any()).Return(sess, nil) + }, + after: func() { + var quota dao.TempQuota + err := q.db.Where("uid = ? AND `key` = ?", 1, "temp_key_1").First("a).Error + require.NoError(t, err) + assert.Equal(t, int64(100000), quota.Amount) + assert.Equal(t, int64(123), quota.StartTime) + assert.Equal(t, int64(456), quota.EndTime) + }, + reqBody: `{"amount": 100000, "key": "temp_key_1", "start_time": 123, "end_time": 456}`, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + // 每个测试用例开始时清理数据 + err := q.db.Exec("TRUNCATE TABLE quotas").Error + require.NoError(t, err) + err = q.db.Exec("TRUNCATE TABLE quota_records").Error + require.NoError(t, err) + err = q.db.Exec("TRUNCATE TABLE temp_quotas").Error + require.NoError(t, err) + + tc.before() + req, err := http.NewRequest(http.MethodPost, "/tmp/save", bytes.NewBuffer([]byte(tc.reqBody))) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + q.server.ServeHTTP(resp, req) + + assert.Equal(t, http.StatusOK, resp.Code) + + tc.after() + }) + } +} + +func (q *QuotaSuite) TestDeduct() { + t := q.T() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + testcases := []struct { + name string + before func() + after func() + reqBody string + }{ + { + name: "从主额度扣减", + before: func() { + sess := mocks.NewMockSession(ctrl) + sess.EXPECT().Claims().Return(session.Claims{ + Uid: 1, + }).AnyTimes() + provider := mocks.NewMockProvider(ctrl) + session.SetDefaultProvider(provider) + provider.EXPECT().Get(gomock.Any()).Return(sess, nil) + + // 创建主额度 + quota := dao.Quota{Amount: 100, Key: "main_quota", UID: 1} + err := q.db.Create("a).Error + require.NoError(t, err) + }, + after: func() { + // 验证主额度被扣减 + var quota dao.Quota + err := q.db.Where("uid = ? AND `key` = ?", 1, "main_quota").First("a).Error + require.NoError(t, err) + assert.Equal(t, int64(80), quota.Amount) + + // 验证扣减记录被创建 + var record dao.QuotaRecord + err = q.db.Where("uid = ? AND `key` = ?", 1, "deduct_key_1").First(&record).Error + require.NoError(t, err) + assert.Equal(t, int64(20), record.Amount) + }, + reqBody: `{"amount": 20, "key": "deduct_key_1"}`, + }, + { + name: "从临时额度扣减", + before: func() { + sess := mocks.NewMockSession(ctrl) + sess.EXPECT().Claims().Return(session.Claims{ + Uid: 1, + }).AnyTimes() + provider := mocks.NewMockProvider(ctrl) + session.SetDefaultProvider(provider) + provider.EXPECT().Get(gomock.Any()).Return(sess, nil) + + // 创建临时额度 + now := time.Now().Unix() + tempQuota := dao.TempQuota{ + Amount: 50, + Key: "temp_quota_1", + UID: 1, + StartTime: now, + EndTime: now + 24*3600, + } + err := q.db.Create(&tempQuota).Error + require.NoError(t, err) + }, + after: func() { + // 验证临时额度被扣减 + var tempQuota dao.TempQuota + err := q.db.Where("uid = ? AND `key` = ?", 1, "temp_quota_1").First(&tempQuota).Error + require.NoError(t, err) + assert.Equal(t, int64(30), tempQuota.Amount) + + // 验证扣减记录被创建 + var record dao.QuotaRecord + err = q.db.Where("uid = ? AND `key` = ?", 1, "deduct_key_2").First(&record).Error + require.NoError(t, err) + assert.Equal(t, int64(20), record.Amount) + }, + reqBody: `{"amount": 20, "key": "deduct_key_2"}`, + }, + { + name: "优先从临时额度扣减,不足再从主额度扣减", + before: func() { + sess := mocks.NewMockSession(ctrl) + sess.EXPECT().Claims().Return(session.Claims{ + Uid: 1, + }).AnyTimes() + provider := mocks.NewMockProvider(ctrl) + session.SetDefaultProvider(provider) + provider.EXPECT().Get(gomock.Any()).Return(sess, nil) + + // 创建主额度 + quota := dao.Quota{Amount: 100, Key: "main_quota_2", UID: 1} + err := q.db.Create("a).Error + require.NoError(t, err) + + // 创建临时额度(金额不足) + now := time.Now().Unix() + tempQuota := dao.TempQuota{ + Amount: 10, + Key: "temp_quota_2", + UID: 1, + StartTime: now, + EndTime: now + 24*3600, + } + err = q.db.Create(&tempQuota).Error + require.NoError(t, err) + }, + after: func() { + // 验证临时额度被完全扣减 + var tempQuota dao.TempQuota + err := q.db.Where("uid = ? AND `key` = ?", 1, "temp_quota_2").First(&tempQuota).Error + require.NoError(t, err) + assert.Equal(t, int64(0), tempQuota.Amount) + + // 验证主额度被扣减 + var quota dao.Quota + err = q.db.Where("uid = ? AND `key` = ?", 1, "main_quota_2").First("a).Error + require.NoError(t, err) + assert.Equal(t, int64(90), quota.Amount) + + // 验证扣减记录被创建 + var record dao.QuotaRecord + err = q.db.Where("uid = ? AND `key` = ?", 1, "deduct_key_3").First(&record).Error + require.NoError(t, err) + assert.Equal(t, int64(30), record.Amount) + }, + reqBody: `{"amount": 30, "key": "deduct_key_3"}`, + }, + { + name: "扣减失败 - 余额不足", + before: func() { + sess := mocks.NewMockSession(ctrl) + sess.EXPECT().Claims().Return(session.Claims{ + Uid: 1, + }).AnyTimes() + provider := mocks.NewMockProvider(ctrl) + session.SetDefaultProvider(provider) + provider.EXPECT().Get(gomock.Any()).Return(sess, nil) + + // 创建少量主额度 + quota := dao.Quota{Amount: 10, Key: "main_quota_3", UID: 1} + err := q.db.Create("a).Error + require.NoError(t, err) + }, + after: func() { + // 验证主额度没有被扣减 + var quota dao.Quota + err := q.db.Where("uid = ? AND `key` = ?", 1, "main_quota_3").First("a).Error + require.NoError(t, err) + assert.Equal(t, int64(10), quota.Amount) // 额度应该保持不变 + + // 验证扣减记录没有被创建(因为事务回滚) + var record dao.QuotaRecord + err = q.db.Where("uid = ? AND `key` = ?", 1, "deduct_key_4").First(&record).Error + assert.Error(t, err) // 应该找不到记录,因为事务回滚了 + }, + reqBody: `{"amount": 50, "key": "deduct_key_4"}`, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + // 每个测试用例开始时清理数据 + err := q.db.Exec("TRUNCATE TABLE quotas").Error + require.NoError(t, err) + err = q.db.Exec("TRUNCATE TABLE quota_records").Error + require.NoError(t, err) + err = q.db.Exec("TRUNCATE TABLE temp_quotas").Error + require.NoError(t, err) + + tc.before() + req, err := http.NewRequest(http.MethodPost, "/deduct", bytes.NewBuffer([]byte(tc.reqBody))) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + resp := httptest.NewRecorder() + q.server.ServeHTTP(resp, req) + + if tc.name == "扣减失败 - 余额不足" { + assert.Equal(t, http.StatusOK, resp.Code) // HTTP状态码仍然是200 + + var response map[string]interface{} + err = json.Unmarshal(resp.Body.Bytes(), &response) + require.NoError(t, err) + + assert.Equal(t, float64(400002), response["code"]) + assert.Equal(t, "余额不足", response["msg"]) + } else { + assert.Equal(t, http.StatusOK, resp.Code) + } + + tc.after() + }) + } +} diff --git a/internal/test/server_test.go b/internal/test/server_test.go index 5f91492..44923e9 100644 --- a/internal/test/server_test.go +++ b/internal/test/server_test.go @@ -18,7 +18,7 @@ import ( "context" "testing" - ai "github.com/ecodeclub/ai-gateway-go/api/proto/gen/ai/v1" + ai "github.com/ecodeclub/ai-gateway-go/api/gen/ai/v1" "github.com/ecodeclub/ai-gateway-go/internal/domain" igrpc "github.com/ecodeclub/ai-gateway-go/internal/grpc" "github.com/ecodeclub/ai-gateway-go/internal/service" diff --git a/internal/web/quota.go b/internal/web/quota.go new file mode 100644 index 0000000..0be1054 --- /dev/null +++ b/internal/web/quota.go @@ -0,0 +1,143 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package web + +import ( + "errors" + + "github.com/ecodeclub/ai-gateway-go/errs" + "github.com/ecodeclub/ai-gateway-go/internal/domain" + "github.com/ecodeclub/ai-gateway-go/internal/service" + "github.com/ecodeclub/ekit/slice" + "github.com/ecodeclub/ginx" + "github.com/ecodeclub/ginx/session" + "github.com/gin-gonic/gin" +) + +type QuotaHandler struct { + svc *service.QuotaService +} + +func NewQuotaHandler(svc *service.QuotaService) *QuotaHandler { + return &QuotaHandler{svc: svc} +} + +func (q *QuotaHandler) PrivateRoutes(server *gin.Engine) { + group := server.Group("/quota") + group.POST("/save", ginx.BS(q.AddQuota)) + group.POST("/get", ginx.S(q.GetQuota)) + + tmp := server.Group("/tmp") + tmp.POST("/save", ginx.BS(q.CreateTempQuota)) + tmp.POST("/get", ginx.S(q.GetTempQuota)) + + server.POST("/deduct", ginx.BS(q.Deduct)) +} + +func (q *QuotaHandler) AddQuota(ctx *ginx.Context, req QuotaRequest, sess session.Session) (ginx.Result, error) { + uid := sess.Claims().Uid + err := q.svc.AddQuota(ctx, domain.Quota{Amount: req.Amount, Uid: uid, Key: req.Key}) + if err != nil { + return systemErrorResult, err + } + return ginx.Result{ + Msg: "OK", + }, nil +} + +func (q *QuotaHandler) CreateTempQuota(ctx *ginx.Context, req QuotaRequest, sess session.Session) (ginx.Result, error) { + uid := sess.Claims().Uid + + if req.StartTime == 0 || req.EndTime == 0 { + return invalidParamResult, errs.ErrInvalidParam + } + + err := q.svc.CreateTempQuota(ctx, domain.TempQuota{ + Amount: req.Amount, + Key: req.Key, + Uid: uid, + StartTime: req.StartTime, + EndTime: req.EndTime, + }) + if err != nil { + return systemErrorResult, err + } + return ginx.Result{ + Msg: "OK", + }, nil +} + +func (q *QuotaHandler) GetQuota(ctx *ginx.Context, sees session.Session) (ginx.Result, error) { + uid := sees.Claims().Uid + + quota, err := q.svc.GetQuota(ctx, uid) + if err != nil { + return systemErrorResult, err + } + return ginx.Result{ + Msg: "ok", + Data: QuotaResponse{Amount: quota.Amount}, + }, nil +} + +func (q *QuotaHandler) GetTempQuota(ctx *ginx.Context, sees session.Session) (ginx.Result, error) { + uid := sees.Claims().Uid + quotaList, err := q.svc.GetTempQuota(ctx, uid) + if err != nil { + return systemErrorResult, err + } + return ginx.Result{ + Msg: "ok", + Data: q.toQuotaResponse(quotaList), + }, nil +} + +func (q *QuotaHandler) Deduct(ctx *ginx.Context, req QuotaRequest, sees session.Session) (ginx.Result, error) { + uid := sees.Claims().Uid + err := q.svc.Deduct(ctx, uid, req.Amount, req.Key) + if err != nil { + // 检查是否是余额不足错误 + if errors.Is(err, errs.ErrInsufficientBalance) { + return insufficientBalanceResult, nil + } + // 其他系统错误 + return systemErrorResult, nil + } + return ginx.Result{Msg: "OK"}, nil +} + +func (q *QuotaHandler) toQuotaResponse(tempQuotaList []domain.TempQuota) []QuotaResponse { + return slice.Map[domain.TempQuota, QuotaResponse](tempQuotaList, func(idx int, src domain.TempQuota) QuotaResponse { + return QuotaResponse{ + Amount: src.Amount, + StartTime: src.StartTime, + EndTime: src.EndTime, + } + }) +} + +type QuotaRequest struct { + Amount int64 `json:"amount,omitempty"` + Key string `json:"key,omitempty"` + StartTime int64 `json:"start_time,omitempty"` + EndTime int64 `json:"end_time,omitempty"` +} + +type QuotaResponse struct { + Amount int64 `json:"amount,omitempty"` + Key string `json:"key"` + StartTime int64 `json:"start_time,omitempty"` + EndTime int64 `json:"end_time,omitempty"` +} diff --git a/internal/web/result.go b/internal/web/result.go index fad3de7..411a466 100644 --- a/internal/web/result.go +++ b/internal/web/result.go @@ -23,3 +23,13 @@ var systemErrorResult = ginx.Result{ Code: errs.SystemError.Code, Msg: errs.SystemError.Msg, } + +var invalidParamResult = ginx.Result{ + Code: errs.InvalidParamError.Code, + Msg: errs.InvalidParamError.Msg, +} + +var insufficientBalanceResult = ginx.Result{ + Code: errs.InsufficientBalanceError.Code, + Msg: errs.InsufficientBalanceError.Msg, +}