|
1 | 1 | import platform |
2 | 2 | from argparse import ArgumentParser |
3 | 3 |
|
| 4 | +import pyarrow |
| 5 | + |
4 | 6 | from datasets import __version__ as version |
5 | | -from datasets import config |
6 | | -from datasets.commands import BaseTransformersCLICommand |
| 7 | +from datasets.commands import BaseDatasetsCLICommand |
7 | 8 |
|
8 | 9 |
|
9 | 10 | def info_command_factory(_): |
10 | 11 | return EnvironmentCommand() |
11 | 12 |
|
12 | 13 |
|
13 | | -class EnvironmentCommand(BaseTransformersCLICommand): |
| 14 | +class EnvironmentCommand(BaseDatasetsCLICommand): |
14 | 15 | @staticmethod |
15 | 16 | def register_subcommand(parser: ArgumentParser): |
16 | | - download_parser = parser.add_parser("env") |
| 17 | + download_parser = parser.add_parser("env", help="Print relevant system environment info.") |
17 | 18 | download_parser.set_defaults(func=info_command_factory) |
18 | 19 |
|
19 | 20 | def run(self): |
20 | | - pt_version = "not installed" |
21 | | - pt_cuda_available = "NA" |
22 | | - if config.TORCH_AVAILABLE: |
23 | | - import torch |
24 | | - |
25 | | - pt_version = torch.__version__ |
26 | | - pt_cuda_available = torch.cuda.is_available() |
27 | | - |
28 | | - tf_version = "not installed" |
29 | | - tf_cuda_available = "NA" |
30 | | - if config.TF_AVAILABLE: |
31 | | - import tensorflow as tf |
32 | | - |
33 | | - tf_version = tf.__version__ |
34 | | - try: |
35 | | - # deprecated in v2.1 |
36 | | - tf_cuda_available = tf.test.is_gpu_available() |
37 | | - except AttributeError: |
38 | | - # returns list of devices, convert to bool |
39 | | - tf_cuda_available = bool(tf.config.list_physical_devices("GPU")) |
40 | | - |
41 | 21 | info = { |
42 | 22 | "`datasets` version": version, |
43 | 23 | "Platform": platform.platform(), |
44 | 24 | "Python version": platform.python_version(), |
45 | | - "PyTorch version (GPU?)": "{} ({})".format(pt_version, pt_cuda_available), |
46 | | - "Tensorflow version (GPU?)": "{} ({})".format(tf_version, tf_cuda_available), |
47 | | - "Using GPU in script?": "<fill in>", |
48 | | - "Using distributed or parallel set-up in script?": "<fill in>", |
| 25 | + "PyArrow version": pyarrow.__version__, |
49 | 26 | } |
50 | 27 |
|
51 | | - print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n") |
| 28 | + print("\nCopy-and-paste the text below in your GitHub issue.\n") |
52 | 29 | print(self.format_dict(info)) |
53 | 30 |
|
54 | 31 | return info |
|
0 commit comments