Skip to content

Commit cd314cc

Browse files
committed
first commit
0 parents  commit cd314cc

16 files changed

+2063
-0
lines changed

LICENSE.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
Copyright (c) Quectel Wireless Solution, Co., Ltd. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.

code/ark_lib.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import ujson as json
2+
import request
3+
from usr.logging import getLogger
4+
from usr.logging import getLogger
5+
logger = getLogger(__name__)
6+
7+
8+
9+
class ARKConfig:
    """Static settings used when calling the ARK chat-completions HTTP API."""

    # Model identifier and bearer token sent with every request.
    # NOTE(review): 'xx' looks like a placeholder -- fill in real values
    # before use and avoid committing real credentials.
    MODEL_ID = "xx"
    API_KEY = "xx"

    # Endpoint that receives the chat-completions POST request.
    CHAT_COMPLETIONS_POST_URL = (
        "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
    )
13+
14+
class ChatCompletionsError(Exception):
    """Raised for invalid questions and for failed chat-completions requests."""
16+
17+
18+
class ChatCompletions(object):
    """Ask one question through the ARK chat-completions API and stream the reply.

    Typical use::

        with ChatCompletions("hello") as cc:
            for text in cc.answer:
                print(text, end="")

    Entering the ``with`` block sends the request exactly once; leaving it
    closes whatever HTTP response is still open.
    """

    def __init__(self, question):
        """Remember the question to send.

        :param question: non-empty ``str`` used as the single user message.
        :raises ChatCompletionsError: if ``question`` is not a non-empty str.
        """
        if not (isinstance(question, str) and question):
            raise ChatCompletionsError("question must be str type and not blank.")
        self.question = question
        # Open streaming response not yet consumed by `answer` (None if none).
        self.resp = None

    def __enter__(self):
        # Keep the response instead of discarding it, so `answer` streams this
        # very request rather than issuing (and leaking) a second one.
        self.resp = self.__post()
        return self

    def __exit__(self, *args, **kwargs):
        resp, self.resp = self.resp, None
        if resp is not None:
            resp.close()

    def __post(self):
        """POST the question with streaming enabled and return the response.

        :raises ChatCompletionsError: if the HTTP status code is not 200; the
            message carries the response text.
        """
        resp = request.post(
            ARKConfig.CHAT_COMPLETIONS_POST_URL,
            headers={
                "Content-Type": "application/json",
                "Authorization": "Bearer {}".format(ARKConfig.API_KEY)
            },
            json={
                "model": ARKConfig.MODEL_ID,
                "messages": [
                    {"role": "user", "content": self.question},
                ],
                "stream": True
            }
        )

        if resp.status_code != 200:
            # Read the error text, then release the connection before raising.
            try:
                detail = "".join(resp.text)
            finally:
                resp.close()
            raise ChatCompletionsError("query_chat_completions exc: {}".format(detail))

        return resp

    @staticmethod
    def _iter_contents(chunks):
        """Yield the ``content`` pieces found in a stream of text chunks.

        ``chunks`` is any iterable of ``str`` fragments; events are expected
        to be separated by a blank line and to carry their JSON after a
        ``"data: "`` marker. Events may be split across fragments at
        arbitrary positions. Parsing stops at the ``[DONE]`` event. Events
        without a ``"data: "`` marker, without choices, or without a
        ``content`` field in the delta are skipped.
        """
        buffer = ""
        for chunk in chunks:
            buffer += chunk
            while True:
                end = buffer.find("\n\n")
                if end == -1:
                    # Incomplete event: wait for more data.
                    break
                event = buffer[:end]
                buffer = buffer[end + 2:]

                start = event.find("data: ")
                if start == -1:
                    continue
                body = event[start + 6:].strip()
                if body == "[DONE]":
                    # End of stream: stop reading entirely, not just this chunk.
                    return

                choices = json.loads(body).get("choices")
                if not choices:
                    continue
                delta = choices[0].get("delta") or {}
                content = delta.get("content")
                if content is not None:
                    yield content

    @property
    def answer(self):
        """Generator over the text pieces of the model's reply.

        Uses the response opened by ``__enter__`` when there is one, otherwise
        sends the request itself. The response is closed once the stream has
        been consumed, so reading ``answer`` again issues a new request.

        :raises ChatCompletionsError: if a request has to be sent and fails.
        """
        resp = self.resp
        if resp is None:
            resp = self.resp = self.__post()
        try:
            for content in self._iter_contents(resp.text):
                yield content
        finally:
            resp.close()
            if self.resp is resp:
                self.resp = None
72+
73+
74+
if __name__ == "__main__":
    # Demo: send a greeting and print the reply pieces as they arrive,
    # without inserting line breaks between them.
    with ChatCompletions("你好") as chat:
        for piece in chat.answer:
            print(piece, end="")

code/asr_lib.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import modem
2+
import ujson as json
3+
from usr import uwebsocket as ws
4+
from usr import uuid
5+
from usr.logging import getLogger
6+
from usr.message import *
7+
8+
9+
logger = getLogger(__name__)
10+
11+
12+
13+
class ASRConfig:
    """Static settings used when talking to the ASR websocket service."""

    # Websocket endpoint the client connects to by default.
    ASR_HOST = "wss://openspeech.bytedance.com/api/v2/asr"

    # Values sent in the "app" section of the full client request; the token
    # is also used in the Authorization header of the websocket handshake.
    # NOTE(review): 'xx'/'xxx' look like placeholders -- fill in real values
    # before use and avoid committing real credentials.
    ASR_APP_ID = "xx"
    ASR_AUTH_TOKEN = "xxx"
    ASR_CLUSTER = "xx"
18+
19+
20+
class ASRConnectError(Exception):
    """Raised when the websocket connection to the ASR service cannot be opened."""
22+
23+
24+
class ASRQueryError(Exception):
    """Raised when sending a request to, or reading a reply from, the ASR service fails."""
26+
27+
28+
class ASRWebSocket(object):
    """Websocket client that submits an audio file to the ASR service.

    Typical use::

        with ASRWebSocket() as asr:
            text = asr.query_asr("/usr/input.mp3")

    The connection is opened lazily on first use (or on entering the ``with``
    block) and closed on leaving it.
    """

    def __init__(self, host=ASRConfig.ASR_HOST, debug=False):
        """
        :param host: websocket URL of the ASR service.
        :param debug: forwarded to the websocket client's ``connect`` call.
        """
        self.debug = debug
        self.host = host
        # Cached websocket connection; None while disconnected.
        self.__client__ = None

    def __enter__(self):
        self.open()
        return self

    def __exit__(self, *args, **kwargs):
        # Returns None so exceptions raised inside the `with` body propagate.
        self.close()

    @property
    def client(self):
        """The websocket connection, created on first access and then cached.

        :raises ASRConnectError: if the connection cannot be established.
        """
        conn = getattr(self, "__client__", None)
        if conn is not None:
            return conn

        try:
            conn = ws.Client.connect(
                self.host,
                headers={"Authorization": "Bearer; {}".format(ASRConfig.ASR_AUTH_TOKEN)},
                debug=self.debug
            )
        except Exception as e:
            raise ASRConnectError("ASR websocket connect failed, pls checkout your network! Exception details: {}, {}".format(type(e).__name__, str(e)))

        self.__client__ = conn
        return conn

    def send(self, data):
        """send data to server"""
        return self.client.send(data)

    def recv(self):
        """receive data from server, return None or "" means disconnection"""
        return self.client.recv()

    def open(self):
        """Connect if necessary and return the websocket connection."""
        return self.client

    def close(self):
        """Close the websocket if one is open; otherwise do nothing.

        The cached connection is inspected directly instead of going through
        the ``client`` property, which would open a new connection just to
        close it again.
        """
        conn = getattr(self, "__client__", None)
        if conn is None:
            return
        try:
            conn.close()
        finally:
            # Forget the connection even if closing it failed.
            self.__client__ = None

    @staticmethod
    def _frame(payload):
        """Prefix ``payload`` (bytes) with its length as a 4-byte big-endian integer."""
        return len(payload).to_bytes(4, "big") + payload

    def _exchange(self, message, send_label, recv_label):
        """Send ``message`` and return the wrapped reply.

        ``send_label`` / ``recv_label`` are the prefixes used in the
        ``ASRQueryError`` messages for the respective failure.

        :raises ASRQueryError: if sending or receiving fails, or if the
            connection was closed before a reply arrived.
        """
        try:
            self.send(message.to_bytes())
        except Exception as e:
            raise ASRQueryError("{}: {}, {}".format(send_label, type(e).__name__, str(e)))

        try:
            resp_data = self.recv()
        except Exception as e:
            raise ASRQueryError("{}: {}, {}".format(recv_label, type(e).__name__, str(e)))

        if not resp_data:
            # recv() reports a disconnection as None or "".
            raise ASRQueryError("{}: connection closed".format(recv_label))
        if isinstance(resp_data, str):
            resp_data = resp_data.encode()
        return MessageWrapper(Message.from_bytes(resp_data))

    def full_client_request(self):
        """Send the full client request that describes the recognition job.

        :return: the wrapped reply, or ``None`` if the service answered with
            an error response (which is logged).
        :raises ASRQueryError: if sending or receiving fails.
        """
        # Encode before framing so the length prefix counts bytes, not
        # characters, and the concatenation is bytes + bytes.
        payload = json.dumps(
            {
                "app": {
                    "appid": ASRConfig.ASR_APP_ID,
                    "token": ASRConfig.ASR_AUTH_TOKEN,
                    "cluster": ASRConfig.ASR_CLUSTER
                },
                "user": {
                    "uid": modem.getDevImei()
                },
                "audio": {
                    "format": "mp3",
                    "rate": 8000,
                    "bits": 16,
                    "channel": 1,
                    "language": "zh-CN"
                },
                "request": {
                    "reqid": str(uuid.uuid4()),
                    "sequence": 1,
                    "nbest": 1,
                    "show_utterances": False
                }
            }
        ).encode()

        full_client_request_msg = Message(
            proto_version=ProtoVersion.V1,
            header_size=1,
            message_type=MessageType.FULL_CLIENT_REQUEST,
            message_type_specific_flags=MessageTypeSpecificFlags.NONE,
            message_serialization_method=MessageSerializationMethod.JSON,
            message_compression=MessageCompression.NONE,
            payload=self._frame(payload)
        )

        resp = self._exchange(full_client_request_msg, "send error", "recv error")
        if resp and resp.message_type == MessageType.ERROR_RESPONSE:
            logger.error("err resp: {}".format(resp.payload))
            return None
        return resp

    def audio_only_request(self, payload, is_last=False):
        """Send one chunk of audio data.

        :param payload: raw audio bytes for this chunk.
        :param is_last: mark this chunk as the last package of the audio.
        :return: the wrapped reply.
        :raises ASRQueryError: if sending or receiving fails, or if the
            service answered with an error response.
        """
        if is_last:
            flags = MessageTypeSpecificFlags.AUDIO_ONLY_REQUEST_LAST_PACKAGE
        else:
            flags = MessageTypeSpecificFlags.NONE

        audio_only_request_msg = Message(
            proto_version=ProtoVersion.V1,
            header_size=1,
            message_type=MessageType.AUDIO_ONLY_REQUEST,
            message_type_specific_flags=flags,
            message_serialization_method=MessageSerializationMethod.NONE,
            message_compression=MessageCompression.NONE,
            payload=self._frame(payload)
        )

        resp = self._exchange(
            audio_only_request_msg,
            "audio_only_request error",
            "audio_only_request recv error"
        )
        if resp and resp.message_type == MessageType.ERROR_RESPONSE:
            raise ASRQueryError("audio_only_request err resp: {}".format(resp.payload))
        return resp

    def query_asr(self, input_file_path="/usr/input.mp3", read_buffer_size=8192):
        """Recognize the speech in an audio file.

        Sends the full client request, then the file content in chunks of
        ``read_buffer_size`` bytes, flagging the final chunk as the last
        package.

        :param input_file_path: path of the audio file to upload.
        :param read_buffer_size: number of bytes sent per audio request.
        :return: the text of the first result when the final reply carries
            code 1000 (presumably "success" -- confirm against the service
            documentation), otherwise ``None``.
        :raises ASRQueryError: if any request fails.
        """
        resp = self.full_client_request()
        if resp is None:
            raise ASRQueryError("query_asr error resp.")

        with open(input_file_path, "rb") as f:
            # Seek to the end to learn the size, then rewind.
            total_size = f.seek(0, 2)
            f.seek(0, 0)

            # Number of chunks, rounding up for a trailing partial chunk.
            chunk_nums = total_size // read_buffer_size
            if total_size % read_buffer_size != 0:
                chunk_nums += 1

            logger.debug("input file chunk nums: {}".format(chunk_nums))
            for index in range(chunk_nums - 1):
                logger.debug("send {} chunk".format(index + 1))
                resp = self.audio_only_request(f.read(read_buffer_size))
                if resp is None:
                    break
            else:
                # All leading chunks went through: send the final one
                # (empty for an empty file) flagged as the last package.
                logger.debug("send last chunk")
                resp = self.audio_only_request(f.read(read_buffer_size), is_last=True)

        if resp and resp.payload["code"] == 1000:
            return resp.payload["result"][0]["text"]
        return None

0 commit comments

Comments
 (0)