17 changes: 17 additions & 0 deletions README.md
@@ -501,6 +501,23 @@ There are also `m` loss computations instead of the usual 1.

For more information see Cui et al. (https://arxiv.org/abs/2112.09331) or Pham et al. (https://arxiv.org/abs/2111.10050).

### Support for remote loading/training

It is always possible to resume directly from a remote file, e.g., a file in an s3 bucket. Just set `--resume s3://<path-to-checkpoint>`.
This will work with any filesystem supported by `fsspec`.
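
For illustration, the snippet below sketches how such a remote checkpoint is opened under the hood (mirroring `pt_load` in `src/training/file_utils.py`); the bucket path is a placeholder:

```python
import fsspec
import torch

# Placeholder path; any fsspec-supported URL works here.
checkpoint_path = "s3://my-bucket/logs/my-run/checkpoints/epoch_32.pt"

# Open the remote file via fsspec and load it like a local checkpoint.
with fsspec.open(checkpoint_path, "rb") as f:
    checkpoint = torch.load(f, map_location="cpu")
```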

It is also possible to train `open_clip` models while continuously backing up checkpoints and logs to s3. This can help avoid slow local file systems.

Say that your node has a local SSD mounted at `/scratch` and access to an s3 bucket `s3://<path-to-bucket>`.

In that case, set `--logs /scratch` and `--remote-sync s3://<path-to-bucket>`. Then, a background process will sync `/scratch/<run-name>` to `s3://<path-to-bucket>/<run-name>`. After syncing, the background process will sleep for `--remote-sync-frequency` seconds, which defaults to 5 minutes.
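
Under the hood, the background process is essentially a loop around `aws s3 sync` (see `src/training/file_utils.py`). A simplified sketch is below; the function name `sync_forever` and the paths are illustrative only:

```python
import subprocess
import time

def sync_forever(local_dir, remote_dir, sync_every):
    # Mirror the local run directory to the remote bucket on a fixed interval,
    # skipping epoch_latest.pt, which can change while a sync is in progress.
    while True:
        time.sleep(sync_every)
        subprocess.run(
            ["aws", "s3", "sync", local_dir, remote_dir, "--exclude", "*epoch_latest.pt"],
            check=False,
        )

# e.g. sync_forever("/scratch/my-run", "s3://<path-to-bucket>/my-run", 300)
```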

There is also experimental support for syncing to other remote file systems, not just s3. To do so, specify `--remote-sync-protocol fsspec`. However, this is currently very slow and not recommended.

To avoid saving too many checkpoints locally when using these features, you can optionally pass `--delete-previous-checkpoint`, which deletes the previous checkpoint after saving a new one.

Note: if you are using this feature with `--resume latest`, there are a few caveats. First, combining it with `--save-most-recent` is not supported. Second, only the `s3` protocol is supported. Finally, since the sync happens in the background, the most recent checkpoint may not have finished syncing to the remote when you resume.

## Scaling trends

The plot below shows how zero-shot performance of CLIP models varies as we scale the number of samples used for training. Zero-shot performance increases steadily for both ImageNet and [ImageNetV2](https://arxiv.org/abs/1902.10811), and is far from saturated at ~15M samples.
1 change: 1 addition & 0 deletions requirements-training.txt
@@ -8,3 +8,4 @@ pandas
braceexpand
huggingface_hub
transformers
fsspec
83 changes: 83 additions & 0 deletions src/training/file_utils.py
@@ -0,0 +1,83 @@
import logging
import os
import multiprocessing
import subprocess
import time
import fsspec
import torch
from tqdm import tqdm

def remote_sync_s3(local_dir, remote_dir):
    # skip epoch_latest which can change during sync.
    result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode != 0:
        logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}")
        return False

    logging.info("Successfully synced with S3 bucket")
    return True

def remote_sync_fsspec(local_dir, remote_dir):
    # FIXME currently this is slow and not recommended. Look into speeding up.
    a = fsspec.get_mapper(local_dir)
    b = fsspec.get_mapper(remote_dir)

    for k in a:
        # skip epoch_latest which can change during sync.
        if 'epoch_latest.pt' in k:
            continue

        logging.info(f'Attempting to sync {k}')
        if k in b and len(a[k]) == len(b[k]):
            logging.debug(f'Skipping remote sync for {k}.')
            continue

        try:
            b[k] = a[k]
            logging.info(f'Successful sync for {k}.')
        except Exception as e:
            logging.info(f'Error during remote sync for {k}: {e}')
            return False

    return True

def remote_sync(local_dir, remote_dir, protocol):
    logging.info('Starting remote sync.')
    if protocol == 's3':
        return remote_sync_s3(local_dir, remote_dir)
    elif protocol == 'fsspec':
        return remote_sync_fsspec(local_dir, remote_dir)
    else:
        logging.error('Remote protocol not known')
        return False

def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol):
    while True:
        time.sleep(sync_every)
        remote_sync(local_dir, remote_dir, protocol)

def start_sync_process(sync_every, local_dir, remote_dir, protocol):
    p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol))
    return p

# Note: we are not currently using this save function.
def pt_save(pt_obj, file_path):
    of = fsspec.open(file_path, "wb")
    with of as f:
        torch.save(pt_obj, f)

def pt_load(file_path, map_location=None):
    if not file_path.startswith('/'):
        logging.info('Loading remote checkpoint, which may take a bit.')
    of = fsspec.open(file_path, "rb")
    with of as f:
        out = torch.load(f, map_location=map_location)
    return out

def check_exists(file_path):
    try:
        with fsspec.open(file_path):
            pass
    except FileNotFoundError:
        return False
    return True
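
For context, a minimal usage sketch of these helpers, mirroring how `main.py` uses them below; the bucket and run names are placeholders:

```python
from training.file_utils import check_exists, pt_load, start_sync_process

# Resume from a remote checkpoint if one exists (placeholder path).
ckpt_path = "s3://my-bucket/logs/my-run/checkpoints/epoch_latest.pt"
if check_exists(ckpt_path):
    checkpoint = pt_load(ckpt_path, map_location="cpu")

# Sync the local run directory to the bucket every 300 seconds in a background process.
p = start_sync_process(300, "/scratch/my-run", "s3://my-bucket/my-run", "s3")
p.start()
```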
73 changes: 67 additions & 6 deletions src/training/main.py
@@ -2,6 +2,7 @@
import logging
import os
import re
import subprocess
import sys
import random
from datetime import datetime
@@ -33,6 +34,7 @@
from training.params import parse_args
from training.scheduler import cosine_lr
from training.train import train_one_epoch, evaluate
from training.file_utils import pt_load, check_exists, start_sync_process, remote_sync


LATEST_CHECKPOINT_NAME = "epoch_latest.pt"
@@ -49,9 +51,16 @@ def natural_key(string_):
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


def get_latest_checkpoint(path: str):
def get_latest_checkpoint(path: str, remote: bool):
    # as written, this glob recurses, so can pick up checkpoints across multiple sub-folders
    checkpoints = glob.glob(path + '**/*.pt', recursive=True)
    if remote:
        result = subprocess.run(["aws", "s3", "ls", path + "/"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)

[Review comment from a collaborator: would be much better to use fsspec to do that kind of listing]

        print(result)
        if result.returncode == 1:
            return None
        checkpoints = [os.path.join(path, x.split(' ')[-1]) for x in result.stdout.decode().split('\n')[:-1]]
    else:
        checkpoints = glob.glob(path + '**/*.pt', recursive=True)
    if checkpoints:
        checkpoints = sorted(checkpoints, key=natural_key)
        return checkpoints[-1]
@@ -121,23 +130,33 @@ def main(args):

    if resume_latest:
        resume_from = None
        checkpoint_path = args.checkpoint_path
        # If using remote_sync, need to check the remote instead of the local checkpoints folder.
        if args.remote_sync is not None:
            checkpoint_path = os.path.join(args.remote_sync, args.name, "checkpoints")
            if args.save_most_recent:
                print('Error. Cannot use save-most-recent with remote_sync and resume latest.')
                return -1
            if args.remote_sync_protocol != 's3':
                print('Error. Sync protocol not supported when using resume latest.')
                return -1
        if is_master(args):
            # Checking for existing checkpoint via master rank only. It is possible for
            # different rank processes to see different files if a shared file-system is under
            # stress, however it's very difficult to fully work around such situations.
            if args.save_most_recent:
                # if --save-most-recent flag is set, look for latest at a fixed filename
                resume_from = os.path.join(args.checkpoint_path, LATEST_CHECKPOINT_NAME)
                resume_from = os.path.join(checkpoint_path, LATEST_CHECKPOINT_NAME)
                if not os.path.exists(resume_from):
                    # If no latest checkpoint has been saved yet, don't try to resume
                    resume_from = None
            else:
                # otherwise, list checkpoint dir contents and pick the newest checkpoint
                resume_from = get_latest_checkpoint(args.checkpoint_path)
                resume_from = get_latest_checkpoint(checkpoint_path, remote=args.remote_sync is not None)
            if resume_from:
                logging.info(f'Found latest resume checkpoint at {resume_from}.')
            else:
                logging.info(f'No latest resume checkpoint found in {args.checkpoint_path}.')
                logging.info(f'No latest resume checkpoint found in {checkpoint_path}.')
        if args.distributed:
            # sync found checkpoint path to all ranks
            resume_from = broadcast_object(args, resume_from)
@@ -146,6 +165,29 @@
    if args.copy_codebase:
        copy_codebase(args)

    # start the sync process if remote-sync is not None
    remote_sync_process = None
    if is_master(args) and args.remote_sync is not None:
        # first make sure it works
        result = remote_sync(
            os.path.join(args.logs, args.name),
            os.path.join(args.remote_sync, args.name),
            args.remote_sync_protocol
        )
        if result:
            logging.info('remote sync successful.')
        else:
            logging.info('Error: remote sync failed. Exiting.')
            return -1
        # if all looks good, start a process to do this every args.remote_sync_frequency seconds
        remote_sync_process = start_sync_process(
            args.remote_sync_frequency,
            os.path.join(args.logs, args.name),
            os.path.join(args.remote_sync, args.name),
            args.remote_sync_protocol
        )
        remote_sync_process.start()

    if args.precision == 'fp16':
        logging.warning(
            'It is recommended to use AMP mixed-precision instead of FP16. '
@@ -247,7 +289,7 @@ def main(args):
    # optionally resume from a checkpoint
    start_epoch = 0
    if args.resume is not None:
        checkpoint = torch.load(args.resume, map_location='cpu')
        checkpoint = pt_load(args.resume, map_location='cpu')
        if 'epoch' in checkpoint:
            # resuming a train checkpoint w/ epoch and optimizer state
            start_epoch = checkpoint["epoch"]
@@ -335,6 +377,11 @@ def main(args):
                    checkpoint_dict,
                    os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"),
                )
            if args.delete_previous_checkpoint:
                previous_checkpoint = os.path.join(args.checkpoint_path, f"epoch_{completed_epoch - 1}.pt")
                if os.path.exists(previous_checkpoint):
                    os.remove(previous_checkpoint)

            if args.save_most_recent:
                # try not to corrupt the latest checkpoint if save fails
                tmp_save_path = os.path.join(args.checkpoint_path, "tmp.pt")
@@ -345,6 +392,20 @@
    if args.wandb and is_master(args):
        wandb.finish()

    # run a final sync.
    if remote_sync_process is not None:
        logging.info('Final remote sync.')
        remote_sync_process.terminate()
        result = remote_sync(
            os.path.join(args.logs, args.name),
            os.path.join(args.remote_sync, args.name),
            args.remote_sync_protocol
        )
        if result:
            logging.info('Final remote sync successful.')
        else:
            logging.info('Final remote sync failed.')


def copy_codebase(args):
    from shutil import copytree, ignore_patterns
26 changes: 24 additions & 2 deletions src/training/params.py
@@ -332,8 +332,30 @@ def parse_args(args):
        default=100,
        help="Log every n steps to tensorboard/console/wandb.",
    )

    parser.add_argument(
        "--remote-sync",
        type=str,
        default=None,
        help="Optionally sync with a remote path specified by this arg.",
    )
    parser.add_argument(
        "--remote-sync-frequency",
        type=int,
        default=300,
        help="How frequently to sync to a remote directory if --remote-sync is not None.",
    )
    parser.add_argument(
        "--remote-sync-protocol",
        choices=["s3", "fsspec"],
        default="s3",
        help="How to do the remote sync backup if --remote-sync is not None.",
    )
    parser.add_argument(
        "--delete-previous-checkpoint",
        default=False,
        action="store_true",
        help="If true, delete previous checkpoint after storing a new one."
    )
    args = parser.parse_args(args)

    # If some params are not passed, we use the default values based on model name.