2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -20,6 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed HTTP Error 308 when running `examples/super_resolution.py`

### Security

### Dependencies
4 changes: 2 additions & 2 deletions docs/user_guide/intermediate/turbulence_super_resolution.rst
@@ -27,8 +27,8 @@ In this example you will learn the following:

.. warning::

The Python package `pyJHTDB <https://github.com/idies/pyJHTDB>`_ is required for this example to download and process the training and validation datasets.
Install using ``pip install pyJHTDB``.
The Python package `giverny <https://github.com/sciserver/giverny>`_ is required for this example to download and process the training and validation datasets.
Install using ``pip install giverny``.

Problem Description
-------------------
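For reviewer convenience, the sketch below collects the giverny calls this PR switches to, using only names that appear in the jhtdb_utils.py changes further down (turb_dataset, getCutout, to_array); the dataset title, token, and output path are the values from this diff, and the cutout bounds are illustrative.

# Minimal sketch of the new data-access path; assumes `pip install giverny`.
import numpy as np

from givernylocal.turbulence_dataset import turb_dataset
from givernylocal.turbulence_toolkit import getCutout

# Public testing token from conf/config.yaml; request a personal token for real runs.
dataset = turb_dataset(
    dataset_title="isotropic1024coarse",
    output_path="datasets/jhtdb_training",
    auth_token="edu.jhu.pha.turbulence.testing-201406",
)

# Cutout bounds are index ranges per axis: [[x0, x1], [y0, y1], [z0, z1]].
axes_ranges = np.array([[1, 128], [1, 128], [1, 128]])
strides = np.array([1, 1, 1], dtype=int)
result = getCutout(dataset, "velocity", 1, axes_ranges, strides=strides)

# Index 0 selects the returned field, mirroring the np.save call in this diff.
velocity = result.to_array()[0, :, :, :, :]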
2 changes: 1 addition & 1 deletion examples/super_resolution/conf/config.yaml
@@ -59,7 +59,7 @@ custom:
n_train: 512
n_valid: 16
domain_size: 128
access_token: "edu.jhu.pha.turbulence.testing-201311" #Replace with your own token here
access_token: "edu.jhu.pha.turbulence.testing-201406" #Replace with your own token here

loss_weights:
U: 1.0
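The access_token sits under the custom: block of this Hydra-style config; below is a hypothetical snippet of how such a value can be read with OmegaConf (the actual wiring in super_resolution.py is outside this diff).

# Hypothetical illustration only; not part of this PR.
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/super_resolution/conf/config.yaml")
print(cfg.custom.access_token)  # "edu.jhu.pha.turbulence.testing-201406" by default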
56 changes: 27 additions & 29 deletions examples/super_resolution/jhtdb_utils.py
@@ -19,12 +19,12 @@
import torch

try:
import pyJHTDB
import pyJHTDB.dbinfo
from givernylocal.turbulence_dataset import turb_dataset
from givernylocal.turbulence_toolkit import getCutout
except:
raise ModuleNotFoundError(
"This example requires the pyJHTDB python package for access to the JHT database.\n"
+ "Find out information here: https://github.com/idies/pyJHTDB"
"This example requires the givernylocal python package for access to the JHTDB database.\n"
+ "Find out information here: https://github.com/sciserver/giverny"
)
from tqdm import *
from typing import List
@@ -33,7 +33,7 @@
from physicsnemo.sym.distributed.manager import DistributedManager


def _pos_to_name(dataset, field, time_step, start, end, step, filter_width):
def _pos_to_name(dataset, field, time_step, start, end, step):
return (
"jhtdb_field_"
+ str(field)
@@ -57,8 +57,6 @@ def _pos_to_name(dataset, field, time_step, start, end, step, filter_width):
+ str(step[1])
+ "_"
+ str(step[2])
+ "_filter_width_"
+ str(filter_width)
)


@@ -74,30 +72,31 @@ def _name_to_pos(name):


def get_jhtdb(
loader, data_dir: Path, dataset, field, time_step, start, end, step, filter_width
loader, data_dir: Path, dataset, field, time_step, start, end, step
):
# get filename
file_name = (
_pos_to_name(dataset, field, time_step, start, end, step, filter_width) + ".npy"
_pos_to_name(dataset, field, time_step, start, end, step) + ".npy"
)
file_dir = data_dir / Path(file_name)

axes_ranges = np.array([[start[0], end[0]], [start[1], end[1]], [start[2], end[2]]])

# check if file exists and if not download it
try:
results = np.load(file_dir)
except:
# Only MPI process 0 can download data
if DistributedManager().rank == 0:
results = loader.getCutout(
data_set=dataset,
field=field,
time_step=time_step,
start=start,
end=end,
step=step,
filter_width=filter_width,
results = getCutout(
loader,
field,
time_step,
axes_ranges,
strides=step
)
np.save(file_dir, results)

np.save(file_dir, results.to_array()[0,:,:,:,:])
# Wait for all processes to get here
if DistributedManager().distributed:
torch.distributed.barrier()
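The rank-0 download guard and the barrier are unchanged context in this hunk; as a condensed, hypothetical restatement of that caching pattern (rank 0 downloads a missing cutout, the other ranks wait at the barrier, then every rank loads the cached .npy file):

# Hypothetical condensed helper, not the PR code; names mirror the function above.
from pathlib import Path

import numpy as np
import torch

from physicsnemo.sym.distributed.manager import DistributedManager


def cached_cutout(file_dir: Path, download) -> np.ndarray:
    if not file_dir.exists():
        # Only MPI process 0 downloads; download() wraps getCutout(...).to_array()[0]
        if DistributedManager().rank == 0:
            np.save(file_dir, download())
        # Remaining ranks wait here until rank 0 has written the file
        if DistributedManager().distributed:
            torch.distributed.barrier()
    return np.load(file_dir)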
@@ -110,7 +109,7 @@ def make_jhtdb_dataset(
nr_samples: int = 128,
domain_size: int = 64,
lr_factor: int = 4,
token: str = "edu.jhu.pha.turbulence.testing-201311",
token: str = "edu.jhu.pha.turbulence.testing-201406", # Request your own token at https://turbulence.idies.jhu.edu/home
data_dir: str = to_absolute_path("datasets/jhtdb_training"),
time_range: List[int] = [1, 1024],
dataset_seed: int = 123,
@@ -120,50 +119,49 @@
data_dir = Path(data_dir)
data_dir.mkdir(parents=True, exist_ok=True)

dataset_title = "isotropic1024coarse"

# initialize runner
lJHTDB = pyJHTDB.libJHTDB()
lJHTDB.initialize()
lJHTDB.add_token(token)
dataset = turb_dataset(dataset_title = dataset_title, output_path = str(data_dir), auth_token = token)

# loop to get dataset
np.random.seed(dataset_seed)
list_low_res_u = []
list_high_res_u = []
for i in tqdm(range(nr_samples)):
# set download params
dataset = "isotropic1024coarse"
field = "u"
field = "velocity"
# subfield = 0 # Velocity has 3 components: u-v-w. Specify 0 for u
time_step = int(np.random.randint(time_range[0], time_range[1]))
start = np.array(
[np.random.randint(1, 1024 - domain_size) for _ in range(3)], dtype=int
)
end = np.array([x + domain_size - 1 for x in start], dtype=int)
np.array(3 * [1], dtype=int)

# get high res data
high_res_u = get_jhtdb(
lJHTDB,
dataset,
data_dir,
dataset,
field,
time_step,
start,
end,
np.array(3 * [1], dtype=int),
1,
# 1, # JHTDB no longer supports filtering operations
)

# get low res data
low_res_u = get_jhtdb(
lJHTDB,
dataset,
data_dir,
dataset,
field,
time_step,
start,
end,
np.array(3 * [lr_factor], dtype=int),
lr_factor,
# lr_factor, # JHTDB no longer supports filtering operations
)

# plot
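To close the loop, a usage sketch of the updated dataset builder; the keyword names come from the make_jhtdb_dataset signature above, the values echo conf/config.yaml, and the import and return value are assumptions (neither is shown in this diff).

# Sketch only; assumes the example script imports the helper like this.
from jhtdb_utils import make_jhtdb_dataset

samples = make_jhtdb_dataset(
    nr_samples=512,                                 # n_train in conf/config.yaml
    domain_size=128,                                 # domain_size in conf/config.yaml
    lr_factor=4,                                     # signature default
    token="edu.jhu.pha.turbulence.testing-201406",   # replace with a personal token
    data_dir="datasets/jhtdb_training",
    time_range=[1, 1024],
    dataset_seed=123,
)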