Skip to content

Commit f9d9709

Browse files
rm typer dep
1 parent f5de87f commit f9d9709

File tree

1 file changed

+1
-9
lines changed

1 file changed

+1
-9
lines changed

src/demo_hf.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -120,20 +120,12 @@ def generate(
120120

121121

122122
def main(prompt: str = "my name is", model_name: str = "blt-1b"):
123-
# distributed_args = DistributedArgs()
124-
# distributed_args.configure_world()
125-
# if not torch.distributed.is_initialized():
126-
# setup_torch_distributed(distributed_args)
127-
128123
# Set device and ensure CUDA is available
129124
if not torch.cuda.is_available():
130125
raise RuntimeError("CUDA is required but not available")
131126
device = torch.device("cuda")
132127
torch.cuda.empty_cache() # Clear any existing CUDA memory
133128

134-
assert model_name in ["blt-1b", "blt-7b"]
135-
model_name = model_name.replace("-", "_")
136-
137129
#HF
138130
blt_repo = "facebook/blt-1b"
139131

@@ -197,4 +189,4 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"):
197189

198190

199191
if __name__ == "__main__":
200-
typer.run(main)
192+
main()

0 commit comments

Comments
 (0)