Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/fastdeploy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,5 @@
from . import vision
from . import pipeline
from . import text
from .download import download, download_and_decompress, download_model
from .download import download, download_and_decompress, download_model, get_model_list
from . import serving
26 changes: 26 additions & 0 deletions python/fastdeploy/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,30 @@ def download_and_decompress(url, path='.', rename=None):
return


def get_model_list(category: str=None):
'''
Get all pre-trained models information supported by fd.download_model.
Args:
category(str): model category, if None, list all models in all categories.
Returns:
results(dict): a dictionary, key is category, value is a list which contains models information.
'''
result = model_server.get_model_list()
if result['status'] != 0:
raise ValueError(
'Failed to get pretrained models information from hub model server.'
)
result = result['data']
if category is None:
return result
elif category in result:
return {category: result[category]}
else:
raise ValueError(
'No pretrained model in category {} can be downloaded now.'.format(
category))


def download_model(name: str,
path: str=None,
format: str=None,
Expand All @@ -237,11 +261,13 @@ def download_model(name: str,
if format == 'paddle':
if url.count(".tgz") > 0 or url.count(".tar") > 0 or url.count(
"zip") > 0:
archive_path = fullpath
fullpath = decompress(fullpath)
try:
os.rename(fullpath,
os.path.join(os.path.dirname(fullpath), name))
fullpath = os.path.join(os.path.dirname(fullpath), name)
os.remove(archive_path)
except FileExistsError:
pass
print('Successfully download model at path: {}'.format(fullpath))
Expand Down
14 changes: 14 additions & 0 deletions python/fastdeploy/utils/hub_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,20 @@ def request(self, path: str, params: dict) -> dict:
except requests.exceptions.ConnectionError as e:
raise ServerConnectionError(self._url)

def get_model_list(self):
'''
Get all pre-trained models information in dataset.
Return:
result(dict): key is category name, value is a list which contains models \
information such as name, format and version.
'''
api = '{}/{}'.format(self._url, 'fastdeploy_listmodels')
try:
result = requests.get(api, timeout=self._timeout)
return result.json()
except requests.exceptions.ConnectionError as e:
raise ServerConnectionError(self._url)

def is_connected(self):
return self.check(self._url)

Expand Down