Skip to content

Commit 6d35bbc

Browse files
authored
added the freeze automation to the client object (#120)
* added the freeze automation to the client object, pulled from CLI * added test cases for freeze * fixed imports * fixed directory collision
1 parent bf2a033 commit 6d35bbc

File tree

5 files changed

+94
-53
lines changed

5 files changed

+94
-53
lines changed

Algorithmia/CLI.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import os
33
from Algorithmia.errors import DataApiError
44
from Algorithmia.algo_response import AlgoResponse
5-
from Algorithmia.util import md5_for_file, md5_for_str
65
import json, re, requests, six
76
import toml
87
import shutil
@@ -244,28 +243,7 @@ def cat(self, path, client):
244243

245244
# algo freeze
246245
def freezeAlgo(self, client, manifest_path="model_manifest.json"):
247-
if os.path.exists(manifest_path):
248-
with open(manifest_path, 'r') as f:
249-
manifest_file = json.load(f)
250-
manifest_file['timestamp'] = str(time())
251-
required_files = manifest_file['required_files']
252-
optional_files = manifest_file['optional_files']
253-
for i in range(len(required_files)):
254-
uri = required_files[i]['source_uri']
255-
local_file = client.file(uri).getFile(as_path=True)
256-
md5_checksum = md5_for_file(local_file)
257-
required_files[i]['md5_checksum'] = md5_checksum
258-
for i in range(len(optional_files)):
259-
uri = required_files[i]['source_uri']
260-
local_file = client.file(uri).getFile(as_path=True)
261-
md5_checksum = md5_for_file(local_file)
262-
required_files[i]['md5_checksum'] = md5_checksum
263-
lock_md5_checksum = md5_for_str(str(manifest_file))
264-
manifest_file['lock_checksum'] = lock_md5_checksum
265-
with open('model_manifest.json.freeze', 'w') as f:
266-
json.dump(manifest_file, f)
267-
else:
268-
print("Expected to find a model_manifest.json file, none was discovered in working directory")
246+
client.freeze(manifest_path)
269247

270248
# algo cp <src> <dest>
271249
def cp(self, src, dest, client):

Algorithmia/algorithm.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
from Algorithmia.errors import ApiError, ApiInternalError, raiseAlgoApiError
99
from enum import Enum
1010
from algorithmia_api_client.rest import ApiException
11-
from algorithmia_api_client import CreateRequest, UpdateRequest, VersionRequest, Details, Settings, SettingsMandatory, SettingsPublish, \
11+
from algorithmia_api_client import CreateRequest, UpdateRequest, VersionRequest, Details, Settings, SettingsMandatory, \
12+
SettingsPublish, \
1213
CreateRequestVersionInfo, VersionInfo, VersionInfoPublish
1314

14-
OutputType = Enum('OutputType','default raw void')
15+
OutputType = Enum('OutputType', 'default raw void')
16+
1517

1618
class Algorithm(object):
1719
def __init__(self, client, algoRef):
@@ -32,7 +34,7 @@ def __init__(self, client, algoRef):
3234
raise ValueError('Invalid algorithm URI: ' + algoRef)
3335

3436
def set_options(self, timeout=300, stdout=False, output=OutputType.default, **query_parameters):
35-
self.query_parameters = {'timeout':timeout, 'stdout':stdout}
37+
self.query_parameters = {'timeout': timeout, 'stdout': stdout}
3638
self.output_type = output
3739
self.query_parameters.update(query_parameters)
3840
return self
@@ -42,7 +44,8 @@ def create(self, details={}, settings={}, version_info={}):
4244
detailsObj = Details(**details)
4345
settingsObj = SettingsMandatory(**settings)
4446
createRequestVersionInfoObj = CreateRequestVersionInfo(**version_info)
45-
create_parameters = {"name": self.algoname, "details": detailsObj, "settings": settingsObj, "version_info": createRequestVersionInfoObj}
47+
create_parameters = {"name": self.algoname, "details": detailsObj, "settings": settingsObj,
48+
"version_info": createRequestVersionInfoObj}
4649
create_request = CreateRequest(**create_parameters)
4750
try:
4851
# Create Algorithm
@@ -57,7 +60,8 @@ def update(self, details={}, settings={}, version_info={}):
5760
detailsObj = Details(**details)
5861
settingsObj = Settings(**settings)
5962
createRequestVersionInfoObj = CreateRequestVersionInfo(**version_info)
60-
update_parameters = {"details": detailsObj, "settings": settingsObj, "version_info": createRequestVersionInfoObj}
63+
update_parameters = {"details": detailsObj, "settings": settingsObj,
64+
"version_info": createRequestVersionInfoObj}
6165
update_request = UpdateRequest(**update_parameters)
6266
try:
6367
# Update Algorithm
@@ -70,9 +74,10 @@ def update(self, details={}, settings={}, version_info={}):
7074
# Publish an algorithm
7175
def publish(self, details={}, settings={}, version_info={}):
7276
publish_parameters = {"details": details, "settings": settings, "version_info": version_info}
73-
url = "/v1/algorithms/"+self.username+"/"+self.algoname + "/versions"
77+
url = "/v1/algorithms/" + self.username + "/" + self.algoname + "/versions"
7478
print(publish_parameters)
75-
api_response = self.client.postJsonHelper(url, publish_parameters, parse_response_as_json=True, **self.query_parameters)
79+
api_response = self.client.postJsonHelper(url, publish_parameters, parse_response_as_json=True,
80+
**self.query_parameters)
7681
return api_response
7782
# except ApiException as e:
7883
# error_message = json.loads(e.body)
@@ -81,7 +86,8 @@ def publish(self, details={}, settings={}, version_info={}):
8186
def builds(self, limit=56, marker=None):
8287
try:
8388
if marker is not None:
84-
api_response = self.client.manageApi.get_algorithm_builds(self.username, self.algoname, limit=limit, marker=marker)
89+
api_response = self.client.manageApi.get_algorithm_builds(self.username, self.algoname, limit=limit,
90+
marker=marker)
8591
else:
8692
api_response = self.client.manageApi.get_algorithm_builds(self.username, self.algoname, limit=limit)
8793
return api_response
@@ -109,11 +115,10 @@ def get_build_logs(self, build_id):
109115
raise raiseAlgoApiError(error_message)
110116

111117
def build_logs(self):
112-
url = '/v1/algorithms/'+self.username+'/'+self.algoname+'/builds'
118+
url = '/v1/algorithms/' + self.username + '/' + self.algoname + '/builds'
113119
response = json.loads(self.client.getHelper(url).content.decode('utf-8'))
114120
return response
115121

116-
117122
def get_scm_status(self):
118123
try:
119124
api_response = self.client.manageApi.get_algorithm_scm_connection_status(self.username, self.algoname)
@@ -157,7 +162,6 @@ def versions(self, limit=None, marker=None, published=None, callable=None):
157162
error_message = json.loads(e.body)
158163
raise raiseAlgoApiError(error_message)
159164

160-
161165
# Compile an algorithm
162166
def compile(self):
163167
try:
@@ -176,25 +180,26 @@ def pipe(self, input1):
176180
elif self.output_type == OutputType.void:
177181
return self._postVoidOutput(input1)
178182
else:
179-
return AlgoResponse.create_algo_response(self.client.postJsonHelper(self.url, input1, **self.query_parameters))
183+
return AlgoResponse.create_algo_response(
184+
self.client.postJsonHelper(self.url, input1, **self.query_parameters))
180185

181186
def _postRawOutput(self, input1):
182-
# Don't parse response as json
183-
self.query_parameters['output'] = 'raw'
184-
response = self.client.postJsonHelper(self.url, input1, parse_response_as_json=False, **self.query_parameters)
185-
# Check HTTP code and throw error as needed
186-
if response.status_code == 400:
187-
# Bad request
188-
raise ApiError(response.text)
189-
elif response.status_code == 500:
190-
raise ApiInternalError(response.text)
191-
else:
192-
return response.text
187+
# Don't parse response as json
188+
self.query_parameters['output'] = 'raw'
189+
response = self.client.postJsonHelper(self.url, input1, parse_response_as_json=False, **self.query_parameters)
190+
# Check HTTP code and throw error as needed
191+
if response.status_code == 400:
192+
# Bad request
193+
raise ApiError(response.text)
194+
elif response.status_code == 500:
195+
raise ApiInternalError(response.text)
196+
else:
197+
return response.text
193198

194199
def _postVoidOutput(self, input1):
195-
self.query_parameters['output'] = 'void'
196-
responseJson = self.client.postJsonHelper(self.url, input1, **self.query_parameters)
197-
if 'error' in responseJson:
198-
raise ApiError(responseJson['error']['message'])
199-
else:
200-
return AsyncResponse(responseJson)
200+
self.query_parameters['output'] = 'void'
201+
responseJson = self.client.postJsonHelper(self.url, input1, **self.query_parameters)
202+
if 'error' in responseJson:
203+
raise ApiError(responseJson['error']['message'])
204+
else:
205+
return AsyncResponse(responseJson)

Algorithmia/client.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
from Algorithmia.datafile import DataFile, LocalDataFile, AdvancedDataFile
77
from Algorithmia.datadirectory import DataDirectory, LocalDataDirectory, AdvancedDataDirectory
88
from algorithmia_api_client import Configuration, DefaultApi, ApiClient
9-
9+
from Algorithmia.util import md5_for_file, md5_for_str
1010
from tempfile import mkstemp
1111
import atexit
1212
import json, re, requests, six, certifi
1313
import tarfile
1414
import os
15+
from time import time
1516

1617

1718
class Client(object):
@@ -343,6 +344,30 @@ def exit_handler(self):
343344
except OSError as e:
344345
print(e)
345346

347+
# Used by CI/CD automation for freezing model manifest files, and by the CLI for manual freezing
348+
def freeze(self, manifest_path, manifest_output_dir="."):
349+
if os.path.exists(manifest_path):
350+
with open(manifest_path, 'r') as f:
351+
manifest_file = json.load(f)
352+
manifest_file['timestamp'] = str(time())
353+
required_files = manifest_file['required_files']
354+
optional_files = manifest_file['optional_files']
355+
for i in range(len(required_files)):
356+
uri = required_files[i]['source_uri']
357+
local_file = self.file(uri).getFile(as_path=True)
358+
md5_checksum = md5_for_file(local_file)
359+
required_files[i]['md5_checksum'] = md5_checksum
360+
for i in range(len(optional_files)):
361+
uri = required_files[i]['source_uri']
362+
local_file = self.file(uri).getFile(as_path=True)
363+
md5_checksum = md5_for_file(local_file)
364+
required_files[i]['md5_checksum'] = md5_checksum
365+
lock_md5_checksum = md5_for_str(str(manifest_file))
366+
manifest_file['lock_checksum'] = lock_md5_checksum
367+
with open(manifest_output_dir+'/'+'model_manifest.json.freeze', 'w') as f:
368+
json.dump(manifest_file, f)
369+
else:
370+
print("Expected to find a model_manifest.json file, none was discovered in working directory")
346371

347372
def isJson(myjson):
348373
try:

Test/client_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,10 @@ def test_algorithm_programmatic_create_process(self):
397397

398398
self.assertEqual(response.version_info.semantic_version, "0.1.0", "information is incorrect")
399399

400+
def test_algo_freeze(self):
401+
self.regular_client.freeze("Test/resources/manifests/example_manifest.json", "Test/resources/manifests")
402+
403+
400404

401405
if __name__ == '__main__':
402406
unittest.main()
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
{
2+
"required_files" : [
3+
{ "name": "squeezenet",
4+
"source_uri": "data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth",
5+
"fail_on_tamper": true,
6+
"metadata": {
7+
"dataset_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266"
8+
}
9+
},
10+
{
11+
"name": "labels",
12+
"source_uri": "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json",
13+
"fail_on_tamper": true,
14+
"metadata": {
15+
"dataset_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266"
16+
}
17+
}
18+
],
19+
"optional_files": [
20+
{
21+
"name": "mobilenet",
22+
"source_uri": "data://AlgorithmiaSE/image_cassification_demo/mobilenet_v2-b0353104.pth",
23+
"fail_on_tamper": false,
24+
"metadata": {
25+
"dataset_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266"
26+
}
27+
}
28+
]
29+
}

0 commit comments

Comments
 (0)