diff --git a/README.md b/README.md index a752569..1a4a852 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,11 @@ [![build-dot](https://github.com/sensity-ai/dot/actions/workflows/build_dot.yaml/badge.svg)](https://github.com/sensity-ai/dot/actions/workflows/build_dot.yaml) [![code-check](https://github.com/sensity-ai/dot/actions/workflows/code_check.yaml/badge.svg)](https://github.com/sensity-ai/dot/actions/workflows/code_check.yaml) -dot (aka Deepfake Offensive Toolkit) makes real-time, controllable deepfakes ready for virtual cameras injection. dot is created for performing penetration testing against e.g. identity verification and video conferencing systems, for the use by security analysts, Red Team members, and biometrics researchers. +*dot* (aka Deepfake Offensive Toolkit) makes real-time, controllable deepfakes ready for virtual camera injection. *dot* is created for performing penetration testing against, e.g., identity verification and video conferencing systems, for use by security analysts, Red Team members, and biometrics researchers. -If you want to learn more about dot is used for penetration tests with deepfakes in the industry, read [this article by The Verge](https://www.theverge.com/2022/5/18/23092964/deepfake-attack-facial-recognition-liveness-test-banks-sensity-report) +If you want to learn more about how *dot* is used for penetration tests with deepfakes in the industry, read [this article by The Verge](https://www.theverge.com/2022/5/18/23092964/deepfake-attack-facial-recognition-liveness-test-banks-sensity-report). -*dot is developed for research and demonstration purposes. As an end user, you have the responsibility to obey all applicable laws when using this program. Authors and contributing developers assume no liability and are not responsible for any misuse or damage caused by the use of this program.* +dot *is developed for research and demonstration purposes. As an end user, you have the responsibility to obey all applicable laws when using this program. Authors and contributing developers assume no liability and are not responsible for any misuse or damage caused by the use of this program.*

@@ -14,7 +14,7 @@ If you want to learn more about dot is used for penetration tests with deepfakes ## How it works -In a nutshell, dot works like this +In a nutshell, *dot* works like this ```text __________________ _____________________________ __________________________ @@ -22,14 +22,14 @@ In a nutshell, dot works like this ------------------ ----------------------------- -------------------------- ``` -All deepfakes supported by dot do not require additional training. They can be used +All deepfakes supported by *dot* do not require additional training. They can be used in real-time on the fly on a photo that becomes the target of face impersonation. Supported methods: - face swap (via [SimSwap](https://github.com/neuralchen/SimSwap)), at resolutions `224` and `512` - with the option of face superresolution (via [GPen](https://github.com/yangxy/GPEN)) at resolutions `256` and `512` - lower quality face swap (via OpenCV) -- [first order motion model](https://github.com/AliaksandrSiarohin/first-order-model) +- [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), First Order Motion Model for image animation ## Installation @@ -92,7 +92,7 @@ There are 2 options for downloading the model weights: ## Usage -### Running `dot`: +### Running dot Run `dot --help` to get a full list of available options. @@ -127,14 +127,14 @@ Run `dot --help` to get a full list of available options. Additionally, to enable face superresolution, use the flag `--gpen_type gpen_256` or `--gpen_type gpen_512`. -3. Avatarify +3. FOMM ```bash dot \ - --swap_type avatarify \ + --swap_type fomm \ --target 0 \ --source "./data" \ - --model_path ./saved_models/avatarify/vox-adv-cpk.pth.tar \ + --model_path ./saved_models/fomm/vox-adv-cpk.pth.tar \ --show_fps \ --use_gpu ``` @@ -151,19 +151,19 @@ Run `dot --help` to get a full list of available options. --use_gpu ``` -**Note**: To use dot on CPU (not recommended), do not pass the `--use_gpu` flag. +**Note**: To use *dot* on CPU (not recommended), do not pass the `--use_gpu` flag. -### Controlling dot: +### Controlling dot > **Disclaimer**: We use the `SimSwap` technique for the following demonstration -Running `dot` via any of the above methods generates real-time Deepfake on the input video feed using source images from the `./data` folder. +Running *dot* via any of the above methods generates real-time Deepfake on the input video feed using source images from the `./data` folder.

-When running `dot` a list of available control options appear on the terminal window as shown above. +When running *dot*, a list of available control options appears on the terminal window as shown above. You can toggle through and select different source images by pressing the associated control key. Watch the following demo video for better understanding of the control options: @@ -177,7 +177,7 @@ Instructions vary depending on your operating system. ### Windows -- Install [OBS Studio](https://obsproject.com/) for capturing Avatarify output. +- Install [OBS Studio](https://obsproject.com/). - Install [VirtualCam plugin](https://obsproject.com/forum/resources/obs-virtualcam.539/). @@ -188,7 +188,7 @@ Choose `Install and register only 1 virtual camera`. - In the Sources section, press on Add button ("+" sign), select Windows Capture and press OK. In the appeared window, - choose "[python.exe]: avatarify" in Window drop-down menu and press OK. + choose "[python.exe]: fomm" in Window drop-down menu and press OK. Then select Edit -> Transform -> Fit to screen. - In OBS Studio, go to Tools -> VirtualCam. Check AutoStart, @@ -230,7 +230,7 @@ Use the virtual camera with `OBS Studio`: - Download and install OBS Studio for MacOS from [here](https://obsproject.com/) - Open OBS and follow the first-time setup (you might be required to enable certain permissions in *System Preferences*) -- Run dot with `--use_cam` flag to enable camera feed +- Run *dot* with the `--use_cam` flag to enable the camera feed - Click the "+" button in the sources section → select "Windows Capture", create a new source and enter "OK" → select window with "python" included in the name and enter OK - Click "Start Virtual Camera" button in the controls section - Select "OBS Cam" as default camera in the video settings of the application target of the injection @@ -240,7 +240,7 @@ *This is not a commercial Sensity product, and it is distributed freely with no warranties* The software is distributed under [BSD 3-Clause](LICENSE). -dot utilizes several open source libraries. If you use dot, make sure you agree with their +*dot* utilizes several open source libraries. If you use *dot*, make sure you agree with their licenses too. In particular, this codebase is built on top of the following research projects: - @@ -252,9 +252,9 @@ licenses too. In particular, this codebase is built on top of the following rese This repository follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html) for code formatting. -If you have ideas for improving dot, feel free to open relevant Issues and PRs. Please read [CONTRIBUTING.md](./CONTRIBUTING.md) before contributing to the repository. +If you have ideas for improving *dot*, feel free to open relevant Issues and PRs. Please read [CONTRIBUTING.md](./CONTRIBUTING.md) before contributing to the repository. -If you are working on improving the speed of dot, please read first our guide on [code profiling](docs/profiling.md). +If you are working on improving the speed of *dot*, please first read our guide on [code profiling](docs/profiling.md).
### Setup Dev-Tools @@ -286,4 +286,4 @@ If you are working on improving the speed of dot, please read first our guide on ## Research -- [Run dot on image and video files instead of camera feed](docs/run_without_camera.md) +- [Run *dot* on image and video files instead of camera feed](docs/run_without_camera.md) diff --git a/dot/__main__.py b/dot/__main__.py index e2333d1..c9fea9b 100644 --- a/dot/__main__.py +++ b/dot/__main__.py @@ -15,7 +15,7 @@ @click.option( "--swap_type", "swap_type", - type=click.Choice(["avatarify", "faceswap_cv2", "simswap"], case_sensitive=False), + type=click.Choice(["fomm", "faceswap_cv2", "simswap"], case_sensitive=False), required=True, ) @click.option( diff --git a/dot/avatarify/__init__.py b/dot/avatarify/__init__.py deleted file mode 100644 index 6aba85f..0000000 --- a/dot/avatarify/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -#!/usr/bin/env python3 - -from .option import AvatarifyOption - -__all__ = ["AvatarifyOption"] diff --git a/dot/dot.py b/dot/dot.py index ce66752..ff71a8f 100644 --- a/dot/dot.py +++ b/dot/dot.py @@ -7,12 +7,12 @@ from pathlib import Path from typing import List, Optional, Union -from .avatarify import AvatarifyOption from .commons import ModelOption from .faceswap_cv2 import FaceswapCVOption +from .fomm import FOMMOption from .simswap import SimswapOption -AVAILABLE_SWAP_TYPES = ["simswap", "avatarify", "faceswap_cv2"] +AVAILABLE_SWAP_TYPES = ["simswap", "fomm", "faceswap_cv2"] class DOT: @@ -20,7 +20,7 @@ class DOT: Supported Engines: - `simswap` - - `avatarify` + - `fomm` - `faceswap_cv2` Attributes: @@ -90,8 +90,8 @@ def build_option( gpen_path=gpen_path, crop_size=crop_size, ) - elif swap_type == "avatarify": - option = self.avatarify( + elif swap_type == "fomm": + option = self.fomm( use_gpu=use_gpu, gpen_type=gpen_type, gpen_path=gpen_path ) elif swap_type == "faceswap_cv2": @@ -197,10 +197,10 @@ def faceswap_cv2( crop_size=crop_size, ) - def avatarify( + def fomm( self, use_gpu: bool, gpen_type: str, gpen_path: str, crop_size: int = 256 - ) -> AvatarifyOption: - """Build Avatarify Option. + ) -> FOMMOption: + """Build FOMM Option. Args: use_gpu (bool): If True, use GPU. @@ -209,9 +209,9 @@ def avatarify( crop_size (int, optional): crop size. Defaults to 256. Returns: - AvatarifyOption: Avatarify Option. + FOMMOption: FOMM Option. 
""" - return AvatarifyOption( + return FOMMOption( use_gpu=use_gpu, gpen_type=gpen_type, gpen_path=gpen_path, diff --git a/dot/fomm/__init__.py b/dot/fomm/__init__.py new file mode 100644 index 0000000..52b41ba --- /dev/null +++ b/dot/fomm/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +from .option import FOMMOption + +__all__ = ["FOMMOption"] diff --git a/dot/avatarify/config/vox-adv-256.yaml b/dot/fomm/config/vox-adv-256.yaml similarity index 100% rename from dot/avatarify/config/vox-adv-256.yaml rename to dot/fomm/config/vox-adv-256.yaml diff --git a/dot/avatarify/modules/__init__.py b/dot/fomm/modules/__init__.py similarity index 100% rename from dot/avatarify/modules/__init__.py rename to dot/fomm/modules/__init__.py diff --git a/dot/avatarify/modules/dense_motion.py b/dot/fomm/modules/dense_motion.py similarity index 97% rename from dot/avatarify/modules/dense_motion.py rename to dot/fomm/modules/dense_motion.py index 9fb1723..1e1e538 100644 --- a/dot/avatarify/modules/dense_motion.py +++ b/dot/fomm/modules/dense_motion.py @@ -1,155 +1,155 @@ -#!/usr/bin/env python3 - -import torch -import torch.nn.functional as F -from torch import nn - -from .util import AntiAliasInterpolation2d, Hourglass, kp2gaussian, make_coordinate_grid - - -class DenseMotionNetwork(nn.Module): - """ - Module that predicting a dense motion - from sparse motion representation given - by kp_source and kp_driving - """ - - def __init__( - self, - block_expansion, - num_blocks, - max_features, - num_kp, - num_channels, - estimate_occlusion_map=False, - scale_factor=1, - kp_variance=0.01, - ): - - super(DenseMotionNetwork, self).__init__() - self.hourglass = Hourglass( - block_expansion=block_expansion, - in_features=(num_kp + 1) * (num_channels + 1), - max_features=max_features, - num_blocks=num_blocks, - ) - - self.mask = nn.Conv2d( - self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3) - ) - - if estimate_occlusion_map: - self.occlusion = nn.Conv2d( - self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3) - ) - else: - self.occlusion = None - - self.num_kp = num_kp - self.scale_factor = scale_factor - self.kp_variance = kp_variance - - if self.scale_factor != 1: - self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) - - def create_heatmap_representations(self, source_image, kp_driving, kp_source): - """ - Eq 6. in the paper H_k(z) - """ - spatial_size = source_image.shape[2:] - gaussian_driving = kp2gaussian( - kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance - ) - - gaussian_source = kp2gaussian( - kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance - ) - heatmap = gaussian_driving - gaussian_source - - # adding background feature - zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type( - heatmap.type() - ) - heatmap = torch.cat([zeros, heatmap], dim=1) - heatmap = heatmap.unsqueeze(2) - return heatmap - - def create_sparse_motions(self, source_image, kp_driving, kp_source): - """ - Eq 4. 
in the paper T_{s<-d}(z) - """ - bs, _, h, w = source_image.shape - identity_grid = make_coordinate_grid((h, w), type=kp_source["value"].type()) - identity_grid = identity_grid.view(1, 1, h, w, 2) - coordinate_grid = identity_grid - kp_driving["value"].view( - bs, self.num_kp, 1, 1, 2 - ) - if "jacobian" in kp_driving: - jacobian = torch.matmul( - kp_source["jacobian"], torch.inverse(kp_driving["jacobian"]) - ) - jacobian = jacobian.unsqueeze(-3).unsqueeze(-3) - jacobian = jacobian.repeat(1, 1, h, w, 1, 1) - coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1)) - coordinate_grid = coordinate_grid.squeeze(-1) - - driving_to_source = coordinate_grid + kp_source["value"].view( - bs, self.num_kp, 1, 1, 2 - ) - - # adding background feature - identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1) - sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) - return sparse_motions - - def create_deformed_source_image(self, source_image, sparse_motions): - """ - Eq 7. in the paper hat{T}_{s<-d}(z) - """ - bs, _, h, w = source_image.shape - source_repeat = ( - source_image.unsqueeze(1) - .unsqueeze(1) - .repeat(1, self.num_kp + 1, 1, 1, 1, 1) - ) - source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w) - sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1)) - sparse_deformed = F.grid_sample(source_repeat, sparse_motions) - sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w)) - return sparse_deformed - - def forward(self, source_image, kp_driving, kp_source): - if self.scale_factor != 1: - source_image = self.down(source_image) - - bs, _, h, w = source_image.shape - - out_dict = dict() - heatmap_representation = self.create_heatmap_representations( - source_image, kp_driving, kp_source - ) - sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source) - deformed_source = self.create_deformed_source_image(source_image, sparse_motion) - out_dict["sparse_deformed"] = deformed_source - - input = torch.cat([heatmap_representation, deformed_source], dim=2) - input = input.view(bs, -1, h, w) - - prediction = self.hourglass(input) - - mask = self.mask(prediction) - mask = F.softmax(mask, dim=1) - out_dict["mask"] = mask - mask = mask.unsqueeze(2) - sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3) - deformation = (sparse_motion * mask).sum(dim=1) - deformation = deformation.permute(0, 2, 3, 1) - - out_dict["deformation"] = deformation - - # Sec. 
3.2 in the paper - if self.occlusion: - occlusion_map = torch.sigmoid(self.occlusion(prediction)) - out_dict["occlusion_map"] = occlusion_map - - return out_dict +#!/usr/bin/env python3 + +import torch +import torch.nn.functional as F +from torch import nn + +from .util import AntiAliasInterpolation2d, Hourglass, kp2gaussian, make_coordinate_grid + + +class DenseMotionNetwork(nn.Module): + """ + Module that predicting a dense motion + from sparse motion representation given + by kp_source and kp_driving + """ + + def __init__( + self, + block_expansion, + num_blocks, + max_features, + num_kp, + num_channels, + estimate_occlusion_map=False, + scale_factor=1, + kp_variance=0.01, + ): + + super(DenseMotionNetwork, self).__init__() + self.hourglass = Hourglass( + block_expansion=block_expansion, + in_features=(num_kp + 1) * (num_channels + 1), + max_features=max_features, + num_blocks=num_blocks, + ) + + self.mask = nn.Conv2d( + self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3) + ) + + if estimate_occlusion_map: + self.occlusion = nn.Conv2d( + self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3) + ) + else: + self.occlusion = None + + self.num_kp = num_kp + self.scale_factor = scale_factor + self.kp_variance = kp_variance + + if self.scale_factor != 1: + self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) + + def create_heatmap_representations(self, source_image, kp_driving, kp_source): + """ + Eq 6. in the paper H_k(z) + """ + spatial_size = source_image.shape[2:] + gaussian_driving = kp2gaussian( + kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance + ) + + gaussian_source = kp2gaussian( + kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance + ) + heatmap = gaussian_driving - gaussian_source + + # adding background feature + zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type( + heatmap.type() + ) + heatmap = torch.cat([zeros, heatmap], dim=1) + heatmap = heatmap.unsqueeze(2) + return heatmap + + def create_sparse_motions(self, source_image, kp_driving, kp_source): + """ + Eq 4. in the paper T_{s<-d}(z) + """ + bs, _, h, w = source_image.shape + identity_grid = make_coordinate_grid((h, w), type=kp_source["value"].type()) + identity_grid = identity_grid.view(1, 1, h, w, 2) + coordinate_grid = identity_grid - kp_driving["value"].view( + bs, self.num_kp, 1, 1, 2 + ) + if "jacobian" in kp_driving: + jacobian = torch.matmul( + kp_source["jacobian"], torch.inverse(kp_driving["jacobian"]) + ) + jacobian = jacobian.unsqueeze(-3).unsqueeze(-3) + jacobian = jacobian.repeat(1, 1, h, w, 1, 1) + coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1)) + coordinate_grid = coordinate_grid.squeeze(-1) + + driving_to_source = coordinate_grid + kp_source["value"].view( + bs, self.num_kp, 1, 1, 2 + ) + + # adding background feature + identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1) + sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) + return sparse_motions + + def create_deformed_source_image(self, source_image, sparse_motions): + """ + Eq 7. 
in the paper hat{T}_{s<-d}(z) + """ + bs, _, h, w = source_image.shape + source_repeat = ( + source_image.unsqueeze(1) + .unsqueeze(1) + .repeat(1, self.num_kp + 1, 1, 1, 1, 1) + ) + source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w) + sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1)) + sparse_deformed = F.grid_sample(source_repeat, sparse_motions) + sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w)) + return sparse_deformed + + def forward(self, source_image, kp_driving, kp_source): + if self.scale_factor != 1: + source_image = self.down(source_image) + + bs, _, h, w = source_image.shape + + out_dict = dict() + heatmap_representation = self.create_heatmap_representations( + source_image, kp_driving, kp_source + ) + sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source) + deformed_source = self.create_deformed_source_image(source_image, sparse_motion) + out_dict["sparse_deformed"] = deformed_source + + input = torch.cat([heatmap_representation, deformed_source], dim=2) + input = input.view(bs, -1, h, w) + + prediction = self.hourglass(input) + + mask = self.mask(prediction) + mask = F.softmax(mask, dim=1) + out_dict["mask"] = mask + mask = mask.unsqueeze(2) + sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3) + deformation = (sparse_motion * mask).sum(dim=1) + deformation = deformation.permute(0, 2, 3, 1) + + out_dict["deformation"] = deformation + + # Sec. 3.2 in the paper + if self.occlusion: + occlusion_map = torch.sigmoid(self.occlusion(prediction)) + out_dict["occlusion_map"] = occlusion_map + + return out_dict diff --git a/dot/avatarify/modules/generator_optim.py b/dot/fomm/modules/generator_optim.py similarity index 97% rename from dot/avatarify/modules/generator_optim.py rename to dot/fomm/modules/generator_optim.py index f2b3d14..8f3dd87 100644 --- a/dot/avatarify/modules/generator_optim.py +++ b/dot/fomm/modules/generator_optim.py @@ -1,146 +1,146 @@ -#!/usr/bin/env python3 - -import torch -import torch.nn.functional as F -from torch import nn - -from .dense_motion import DenseMotionNetwork -from .util import DownBlock2d, ResBlock2d, SameBlock2d, UpBlock2d - - -class OcclusionAwareGenerator(nn.Module): - """ - Generator that given source image and keypoints - try to transform image according to movement trajectories - induced by keypoints. Generator follows Johnson architecture. 
- """ - - def __init__( - self, - num_channels, - num_kp, - block_expansion, - max_features, - num_down_blocks, - num_bottleneck_blocks, - estimate_occlusion_map=False, - dense_motion_params=None, - estimate_jacobian=False, - ): - super(OcclusionAwareGenerator, self).__init__() - - if dense_motion_params is not None: - self.dense_motion_network = DenseMotionNetwork( - num_kp=num_kp, - num_channels=num_channels, - estimate_occlusion_map=estimate_occlusion_map, - **dense_motion_params - ) - else: - self.dense_motion_network = None - - self.first = SameBlock2d( - num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3) - ) - - down_blocks = [] - for i in range(num_down_blocks): - in_features = min(max_features, block_expansion * (2**i)) - out_features = min(max_features, block_expansion * (2 ** (i + 1))) - down_blocks.append( - DownBlock2d( - in_features, out_features, kernel_size=(3, 3), padding=(1, 1) - ) - ) - self.down_blocks = nn.ModuleList(down_blocks) - - up_blocks = [] - for i in range(num_down_blocks): - in_features = min( - max_features, block_expansion * (2 ** (num_down_blocks - i)) - ) - out_features = min( - max_features, block_expansion * (2 ** (num_down_blocks - i - 1)) - ) - up_blocks.append( - UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)) - ) - self.up_blocks = nn.ModuleList(up_blocks) - - self.bottleneck = torch.nn.Sequential() - in_features = min(max_features, block_expansion * (2**num_down_blocks)) - for i in range(num_bottleneck_blocks): - self.bottleneck.add_module( - "r" + str(i), - ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1)), - ) - - self.final = nn.Conv2d( - block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3) - ) - self.estimate_occlusion_map = estimate_occlusion_map - self.num_channels = num_channels - - self.enc_features = None - - def deform_input(self, inp, deformation): - _, h_old, w_old, _ = deformation.shape - _, _, h, w = inp.shape - if h_old != h or w_old != w: - deformation = deformation.permute(0, 3, 1, 2) - deformation = F.interpolate(deformation, size=(h, w), mode="bilinear") - deformation = deformation.permute(0, 2, 3, 1) - return F.grid_sample(inp, deformation) - - def encode_source(self, source_image): - # Encoding (downsampling) part - out = self.first(source_image) - for i in range(len(self.down_blocks)): - out = self.down_blocks[i](out) - - self.enc_features = out - - def forward(self, source_image, kp_driving, kp_source, optim_ret=True): - assert self.enc_features is not None, "Call encode_source()" - out = self.enc_features - - # Transforming feature representation - # according to deformation and occlusion - output_dict = {} - if self.dense_motion_network is not None: - dense_motion = self.dense_motion_network( - source_image=source_image, kp_driving=kp_driving, kp_source=kp_source - ) - output_dict["mask"] = dense_motion["mask"] - output_dict["sparse_deformed"] = dense_motion["sparse_deformed"] - - if "occlusion_map" in dense_motion: - occlusion_map = dense_motion["occlusion_map"] - output_dict["occlusion_map"] = occlusion_map - else: - occlusion_map = None - deformation = dense_motion["deformation"] - out = self.deform_input(out, deformation) - - if occlusion_map is not None: - if (out.shape[2] != occlusion_map.shape[2]) or ( - out.shape[3] != occlusion_map.shape[3] - ): - occlusion_map = F.interpolate( - occlusion_map, size=out.shape[2:], mode="bilinear" - ) - out = out * occlusion_map - - if not optim_ret: - output_dict["deformed"] = self.deform_input(source_image, deformation) 
- - # Decoding part - out = self.bottleneck(out) - for i in range(len(self.up_blocks)): - out = self.up_blocks[i](out) - out = self.final(out) - out = F.sigmoid(out) - - output_dict["prediction"] = out - - return output_dict +#!/usr/bin/env python3 + +import torch +import torch.nn.functional as F +from torch import nn + +from .dense_motion import DenseMotionNetwork +from .util import DownBlock2d, ResBlock2d, SameBlock2d, UpBlock2d + + +class OcclusionAwareGenerator(nn.Module): + """ + Generator that given source image and keypoints + try to transform image according to movement trajectories + induced by keypoints. Generator follows Johnson architecture. + """ + + def __init__( + self, + num_channels, + num_kp, + block_expansion, + max_features, + num_down_blocks, + num_bottleneck_blocks, + estimate_occlusion_map=False, + dense_motion_params=None, + estimate_jacobian=False, + ): + super(OcclusionAwareGenerator, self).__init__() + + if dense_motion_params is not None: + self.dense_motion_network = DenseMotionNetwork( + num_kp=num_kp, + num_channels=num_channels, + estimate_occlusion_map=estimate_occlusion_map, + **dense_motion_params + ) + else: + self.dense_motion_network = None + + self.first = SameBlock2d( + num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3) + ) + + down_blocks = [] + for i in range(num_down_blocks): + in_features = min(max_features, block_expansion * (2**i)) + out_features = min(max_features, block_expansion * (2 ** (i + 1))) + down_blocks.append( + DownBlock2d( + in_features, out_features, kernel_size=(3, 3), padding=(1, 1) + ) + ) + self.down_blocks = nn.ModuleList(down_blocks) + + up_blocks = [] + for i in range(num_down_blocks): + in_features = min( + max_features, block_expansion * (2 ** (num_down_blocks - i)) + ) + out_features = min( + max_features, block_expansion * (2 ** (num_down_blocks - i - 1)) + ) + up_blocks.append( + UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)) + ) + self.up_blocks = nn.ModuleList(up_blocks) + + self.bottleneck = torch.nn.Sequential() + in_features = min(max_features, block_expansion * (2**num_down_blocks)) + for i in range(num_bottleneck_blocks): + self.bottleneck.add_module( + "r" + str(i), + ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1)), + ) + + self.final = nn.Conv2d( + block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3) + ) + self.estimate_occlusion_map = estimate_occlusion_map + self.num_channels = num_channels + + self.enc_features = None + + def deform_input(self, inp, deformation): + _, h_old, w_old, _ = deformation.shape + _, _, h, w = inp.shape + if h_old != h or w_old != w: + deformation = deformation.permute(0, 3, 1, 2) + deformation = F.interpolate(deformation, size=(h, w), mode="bilinear") + deformation = deformation.permute(0, 2, 3, 1) + return F.grid_sample(inp, deformation) + + def encode_source(self, source_image): + # Encoding (downsampling) part + out = self.first(source_image) + for i in range(len(self.down_blocks)): + out = self.down_blocks[i](out) + + self.enc_features = out + + def forward(self, source_image, kp_driving, kp_source, optim_ret=True): + assert self.enc_features is not None, "Call encode_source()" + out = self.enc_features + + # Transforming feature representation + # according to deformation and occlusion + output_dict = {} + if self.dense_motion_network is not None: + dense_motion = self.dense_motion_network( + source_image=source_image, kp_driving=kp_driving, kp_source=kp_source + ) + output_dict["mask"] = dense_motion["mask"] + 
output_dict["sparse_deformed"] = dense_motion["sparse_deformed"] + + if "occlusion_map" in dense_motion: + occlusion_map = dense_motion["occlusion_map"] + output_dict["occlusion_map"] = occlusion_map + else: + occlusion_map = None + deformation = dense_motion["deformation"] + out = self.deform_input(out, deformation) + + if occlusion_map is not None: + if (out.shape[2] != occlusion_map.shape[2]) or ( + out.shape[3] != occlusion_map.shape[3] + ): + occlusion_map = F.interpolate( + occlusion_map, size=out.shape[2:], mode="bilinear" + ) + out = out * occlusion_map + + if not optim_ret: + output_dict["deformed"] = self.deform_input(source_image, deformation) + + # Decoding part + out = self.bottleneck(out) + for i in range(len(self.up_blocks)): + out = self.up_blocks[i](out) + out = self.final(out) + out = F.sigmoid(out) + + output_dict["prediction"] = out + + return output_dict diff --git a/dot/avatarify/modules/keypoint_detector.py b/dot/fomm/modules/keypoint_detector.py similarity index 96% rename from dot/avatarify/modules/keypoint_detector.py rename to dot/fomm/modules/keypoint_detector.py index 8f9be4e..8ebbde3 100644 --- a/dot/avatarify/modules/keypoint_detector.py +++ b/dot/fomm/modules/keypoint_detector.py @@ -1,111 +1,111 @@ -#!/usr/bin/env python3 - -import torch -import torch.nn.functional as F -from torch import nn - -from .util import AntiAliasInterpolation2d, Hourglass, make_coordinate_grid - - -class KPDetector(nn.Module): - """ - Detecting a keypoints. Return keypoint position - and jacobian near each keypoint. - """ - - def __init__( - self, - block_expansion, - num_kp, - num_channels, - max_features, - num_blocks, - temperature, - estimate_jacobian=False, - scale_factor=1, - single_jacobian_map=False, - pad=0, - ): - - super(KPDetector, self).__init__() - - self.predictor = Hourglass( - block_expansion, - in_features=num_channels, - max_features=max_features, - num_blocks=num_blocks, - ) - - self.kp = nn.Conv2d( - in_channels=self.predictor.out_filters, - out_channels=num_kp, - kernel_size=(7, 7), - padding=pad, - ) - - if estimate_jacobian: - self.num_jacobian_maps = 1 if single_jacobian_map else num_kp - self.jacobian = nn.Conv2d( - in_channels=self.predictor.out_filters, - out_channels=4 * self.num_jacobian_maps, - kernel_size=(7, 7), - padding=pad, - ) - self.jacobian.weight.data.zero_() - self.jacobian.bias.data.copy_( - torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float) - ) - else: - self.jacobian = None - - self.temperature = temperature - self.scale_factor = scale_factor - if self.scale_factor != 1: - self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) - - def gaussian2kp(self, heatmap): - """ - Extract the mean and from a heatmap - """ - shape = heatmap.shape - heatmap = heatmap.unsqueeze(-1) - grid = ( - make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) - ) - value = (heatmap * grid).sum(dim=(2, 3)) - kp = {"value": value} - - return kp - - def forward(self, x): - if self.scale_factor != 1: - x = self.down(x) - - feature_map = self.predictor(x) - prediction = self.kp(feature_map) - - final_shape = prediction.shape - heatmap = prediction.view(final_shape[0], final_shape[1], -1) - heatmap = F.softmax(heatmap / self.temperature, dim=2) - heatmap = heatmap.view(*final_shape) - - out = self.gaussian2kp(heatmap) - - if self.jacobian is not None: - jacobian_map = self.jacobian(feature_map) - jacobian_map = jacobian_map.reshape( - final_shape[0], - self.num_jacobian_maps, - 4, - final_shape[2], - 
final_shape[3], - ) - heatmap = heatmap.unsqueeze(2) - - jacobian = heatmap * jacobian_map - jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) - jacobian = jacobian.sum(dim=-1) - jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) - out["jacobian"] = jacobian - - return out +#!/usr/bin/env python3 + +import torch +import torch.nn.functional as F +from torch import nn + +from .util import AntiAliasInterpolation2d, Hourglass, make_coordinate_grid + + +class KPDetector(nn.Module): + """ + Detecting a keypoints. Return keypoint position + and jacobian near each keypoint. + """ + + def __init__( + self, + block_expansion, + num_kp, + num_channels, + max_features, + num_blocks, + temperature, + estimate_jacobian=False, + scale_factor=1, + single_jacobian_map=False, + pad=0, + ): + + super(KPDetector, self).__init__() + + self.predictor = Hourglass( + block_expansion, + in_features=num_channels, + max_features=max_features, + num_blocks=num_blocks, + ) + + self.kp = nn.Conv2d( + in_channels=self.predictor.out_filters, + out_channels=num_kp, + kernel_size=(7, 7), + padding=pad, + ) + + if estimate_jacobian: + self.num_jacobian_maps = 1 if single_jacobian_map else num_kp + self.jacobian = nn.Conv2d( + in_channels=self.predictor.out_filters, + out_channels=4 * self.num_jacobian_maps, + kernel_size=(7, 7), + padding=pad, + ) + self.jacobian.weight.data.zero_() + self.jacobian.bias.data.copy_( + torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float) + ) + else: + self.jacobian = None + + self.temperature = temperature + self.scale_factor = scale_factor + if self.scale_factor != 1: + self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) + + def gaussian2kp(self, heatmap): + """ + Extract the mean and from a heatmap + """ + shape = heatmap.shape + heatmap = heatmap.unsqueeze(-1) + grid = ( + make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) + ) + value = (heatmap * grid).sum(dim=(2, 3)) + kp = {"value": value} + + return kp + + def forward(self, x): + if self.scale_factor != 1: + x = self.down(x) + + feature_map = self.predictor(x) + prediction = self.kp(feature_map) + + final_shape = prediction.shape + heatmap = prediction.view(final_shape[0], final_shape[1], -1) + heatmap = F.softmax(heatmap / self.temperature, dim=2) + heatmap = heatmap.view(*final_shape) + + out = self.gaussian2kp(heatmap) + + if self.jacobian is not None: + jacobian_map = self.jacobian(feature_map) + jacobian_map = jacobian_map.reshape( + final_shape[0], + self.num_jacobian_maps, + 4, + final_shape[2], + final_shape[3], + ) + heatmap = heatmap.unsqueeze(2) + + jacobian = heatmap * jacobian_map + jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) + jacobian = jacobian.sum(dim=-1) + jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) + out["jacobian"] = jacobian + + return out diff --git a/dot/avatarify/modules/util.py b/dot/fomm/modules/util.py similarity index 96% rename from dot/avatarify/modules/util.py rename to dot/fomm/modules/util.py index cc4f8a5..86ac63e 100644 --- a/dot/avatarify/modules/util.py +++ b/dot/fomm/modules/util.py @@ -1,291 +1,291 @@ -#!/usr/bin/env python3 - -import torch -import torch.nn.functional as F -from torch import nn - -from ..sync_batchnorm.batchnorm import SynchronizedBatchNorm2d as BatchNorm2d - - -def kp2gaussian(kp, spatial_size, kp_variance): - """ - Transform a keypoint into gaussian like representation - """ - mean = kp["value"] - - coordinate_grid = 
make_coordinate_grid(spatial_size, mean.type()) - number_of_leading_dimensions = len(mean.shape) - 1 - shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape - coordinate_grid = coordinate_grid.view(*shape) - repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1) - coordinate_grid = coordinate_grid.repeat(*repeats) - - # Preprocess kp shape - shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2) - mean = mean.view(*shape) - - mean_sub = coordinate_grid - mean - - out = torch.exp(-0.5 * (mean_sub**2).sum(-1) / kp_variance) - - return out - - -def make_coordinate_grid(spatial_size, type): - """ - Create a meshgrid [-1,1] x [-1,1] of given spatial_size. - """ - h, w = spatial_size - x = torch.arange(w).type(type) - y = torch.arange(h).type(type) - - x = 2 * (x / (w - 1)) - 1 - y = 2 * (y / (h - 1)) - 1 - - yy = y.view(-1, 1).repeat(1, w) - xx = x.view(1, -1).repeat(h, 1) - - meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) - - return meshed - - -class ResBlock2d(nn.Module): - """ - Res block, preserve spatial resolution. - """ - - def __init__(self, in_features, kernel_size, padding): - super(ResBlock2d, self).__init__() - self.conv1 = nn.Conv2d( - in_channels=in_features, - out_channels=in_features, - kernel_size=kernel_size, - padding=padding, - ) - - self.conv2 = nn.Conv2d( - in_channels=in_features, - out_channels=in_features, - kernel_size=kernel_size, - padding=padding, - ) - - self.norm1 = BatchNorm2d(in_features, affine=True) - self.norm2 = BatchNorm2d(in_features, affine=True) - - def forward(self, x): - out = self.norm1(x) - out = F.relu(out) - out = self.conv1(out) - out = self.norm2(out) - out = F.relu(out) - out = self.conv2(out) - out += x - return out - - -class UpBlock2d(nn.Module): - """ - Upsampling block for use in decoder. - """ - - def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): - - super(UpBlock2d, self).__init__() - - self.conv = nn.Conv2d( - in_channels=in_features, - out_channels=out_features, - kernel_size=kernel_size, - padding=padding, - groups=groups, - ) - - self.norm = BatchNorm2d(out_features, affine=True) - - def forward(self, x): - out = F.interpolate(x, scale_factor=2) - out = self.conv(out) - out = self.norm(out) - out = F.relu(out) - return out - - -class DownBlock2d(nn.Module): - """ - Downsampling block for use in encoder. - """ - - def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): - - super(DownBlock2d, self).__init__() - self.conv = nn.Conv2d( - in_channels=in_features, - out_channels=out_features, - kernel_size=kernel_size, - padding=padding, - groups=groups, - ) - - self.norm = BatchNorm2d(out_features, affine=True) - self.pool = nn.AvgPool2d(kernel_size=(2, 2)) - - def forward(self, x): - out = self.conv(x) - out = self.norm(out) - out = F.relu(out) - out = self.pool(out) - return out - - -class SameBlock2d(nn.Module): - """ - Simple block, preserve spatial resolution. 
- """ - - def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1): - - super(SameBlock2d, self).__init__() - self.conv = nn.Conv2d( - in_channels=in_features, - out_channels=out_features, - kernel_size=kernel_size, - padding=padding, - groups=groups, - ) - - self.norm = BatchNorm2d(out_features, affine=True) - - def forward(self, x): - out = self.conv(x) - out = self.norm(out) - out = F.relu(out) - return out - - -class Encoder(nn.Module): - """ - Hourglass Encoder - """ - - def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): - - super(Encoder, self).__init__() - - down_blocks = [] - for i in range(num_blocks): - down_blocks.append( - DownBlock2d( - in_features - if i == 0 - else min(max_features, block_expansion * (2**i)), - min(max_features, block_expansion * (2 ** (i + 1))), - kernel_size=3, - padding=1, - ) - ) - - self.down_blocks = nn.ModuleList(down_blocks) - - def forward(self, x): - outs = [x] - for down_block in self.down_blocks: - outs.append(down_block(outs[-1])) - return outs - - -class Decoder(nn.Module): - """ - Hourglass Decoder - """ - - def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): - - super(Decoder, self).__init__() - - up_blocks = [] - - for i in range(num_blocks)[::-1]: - in_filters = (1 if i == num_blocks - 1 else 2) * min( - max_features, block_expansion * (2 ** (i + 1)) - ) - out_filters = min(max_features, block_expansion * (2**i)) - up_blocks.append( - UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1) - ) - - self.up_blocks = nn.ModuleList(up_blocks) - self.out_filters = block_expansion + in_features - - def forward(self, x): - out = x.pop() - for up_block in self.up_blocks: - out = up_block(out) - skip = x.pop() - out = torch.cat([out, skip], dim=1) - return out - - -class Hourglass(nn.Module): - """ - Hourglass architecture. - """ - - def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): - - super(Hourglass, self).__init__() - self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) - - self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) - - self.out_filters = self.decoder.out_filters - - def forward(self, x): - return self.decoder(self.encoder(x)) - - -class AntiAliasInterpolation2d(nn.Module): - """ - Band-limited downsampling, - for better preservation of the input signal. - """ - - def __init__(self, channels, scale): - super(AntiAliasInterpolation2d, self).__init__() - sigma = (1 / scale - 1) / 2 - kernel_size = 2 * round(sigma * 4) + 1 - self.ka = kernel_size // 2 - self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka - - kernel_size = [kernel_size, kernel_size] - sigma = [sigma, sigma] - # The gaussian kernel is the product of the - # gaussian function of each dimension. - kernel = 1 - meshgrids = torch.meshgrid( - [torch.arange(size, dtype=torch.float32) for size in kernel_size] - ) - for size, std, mgrid in zip(kernel_size, sigma, meshgrids): - mean = (size - 1) / 2 - kernel *= torch.exp(-((mgrid - mean) ** 2) / (2 * std**2)) - - # Make sure sum of values in gaussian kernel equals 1. 
- kernel = kernel / torch.sum(kernel) - # Reshape to depthwise convolutional weight - kernel = kernel.view(1, 1, *kernel.size()) - kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) - - self.register_buffer("weight", kernel) - self.groups = channels - self.scale = scale - - def forward(self, input): - if self.scale == 1.0: - return input - - out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) - out = F.conv2d(out, weight=self.weight, groups=self.groups) - out = F.interpolate(out, scale_factor=(self.scale, self.scale)) - - return out +#!/usr/bin/env python3 + +import torch +import torch.nn.functional as F +from torch import nn + +from ..sync_batchnorm.batchnorm import SynchronizedBatchNorm2d as BatchNorm2d + + +def kp2gaussian(kp, spatial_size, kp_variance): + """ + Transform a keypoint into gaussian like representation + """ + mean = kp["value"] + + coordinate_grid = make_coordinate_grid(spatial_size, mean.type()) + number_of_leading_dimensions = len(mean.shape) - 1 + shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape + coordinate_grid = coordinate_grid.view(*shape) + repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1) + coordinate_grid = coordinate_grid.repeat(*repeats) + + # Preprocess kp shape + shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2) + mean = mean.view(*shape) + + mean_sub = coordinate_grid - mean + + out = torch.exp(-0.5 * (mean_sub**2).sum(-1) / kp_variance) + + return out + + +def make_coordinate_grid(spatial_size, type): + """ + Create a meshgrid [-1,1] x [-1,1] of given spatial_size. + """ + h, w = spatial_size + x = torch.arange(w).type(type) + y = torch.arange(h).type(type) + + x = 2 * (x / (w - 1)) - 1 + y = 2 * (y / (h - 1)) - 1 + + yy = y.view(-1, 1).repeat(1, w) + xx = x.view(1, -1).repeat(h, 1) + + meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) + + return meshed + + +class ResBlock2d(nn.Module): + """ + Res block, preserve spatial resolution. + """ + + def __init__(self, in_features, kernel_size, padding): + super(ResBlock2d, self).__init__() + self.conv1 = nn.Conv2d( + in_channels=in_features, + out_channels=in_features, + kernel_size=kernel_size, + padding=padding, + ) + + self.conv2 = nn.Conv2d( + in_channels=in_features, + out_channels=in_features, + kernel_size=kernel_size, + padding=padding, + ) + + self.norm1 = BatchNorm2d(in_features, affine=True) + self.norm2 = BatchNorm2d(in_features, affine=True) + + def forward(self, x): + out = self.norm1(x) + out = F.relu(out) + out = self.conv1(out) + out = self.norm2(out) + out = F.relu(out) + out = self.conv2(out) + out += x + return out + + +class UpBlock2d(nn.Module): + """ + Upsampling block for use in decoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + + super(UpBlock2d, self).__init__() + + self.conv = nn.Conv2d( + in_channels=in_features, + out_channels=out_features, + kernel_size=kernel_size, + padding=padding, + groups=groups, + ) + + self.norm = BatchNorm2d(out_features, affine=True) + + def forward(self, x): + out = F.interpolate(x, scale_factor=2) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + + +class DownBlock2d(nn.Module): + """ + Downsampling block for use in encoder. 
+ """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d( + in_channels=in_features, + out_channels=out_features, + kernel_size=kernel_size, + padding=padding, + groups=groups, + ) + + self.norm = BatchNorm2d(out_features, affine=True) + self.pool = nn.AvgPool2d(kernel_size=(2, 2)) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + out = self.pool(out) + return out + + +class SameBlock2d(nn.Module): + """ + Simple block, preserve spatial resolution. + """ + + def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1): + + super(SameBlock2d, self).__init__() + self.conv = nn.Conv2d( + in_channels=in_features, + out_channels=out_features, + kernel_size=kernel_size, + padding=padding, + groups=groups, + ) + + self.norm = BatchNorm2d(out_features, affine=True) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + return out + + +class Encoder(nn.Module): + """ + Hourglass Encoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + + super(Encoder, self).__init__() + + down_blocks = [] + for i in range(num_blocks): + down_blocks.append( + DownBlock2d( + in_features + if i == 0 + else min(max_features, block_expansion * (2**i)), + min(max_features, block_expansion * (2 ** (i + 1))), + kernel_size=3, + padding=1, + ) + ) + + self.down_blocks = nn.ModuleList(down_blocks) + + def forward(self, x): + outs = [x] + for down_block in self.down_blocks: + outs.append(down_block(outs[-1])) + return outs + + +class Decoder(nn.Module): + """ + Hourglass Decoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + + super(Decoder, self).__init__() + + up_blocks = [] + + for i in range(num_blocks)[::-1]: + in_filters = (1 if i == num_blocks - 1 else 2) * min( + max_features, block_expansion * (2 ** (i + 1)) + ) + out_filters = min(max_features, block_expansion * (2**i)) + up_blocks.append( + UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1) + ) + + self.up_blocks = nn.ModuleList(up_blocks) + self.out_filters = block_expansion + in_features + + def forward(self, x): + out = x.pop() + for up_block in self.up_blocks: + out = up_block(out) + skip = x.pop() + out = torch.cat([out, skip], dim=1) + return out + + +class Hourglass(nn.Module): + """ + Hourglass architecture. + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + + super(Hourglass, self).__init__() + self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) + + self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) + + self.out_filters = self.decoder.out_filters + + def forward(self, x): + return self.decoder(self.encoder(x)) + + +class AntiAliasInterpolation2d(nn.Module): + """ + Band-limited downsampling, + for better preservation of the input signal. + """ + + def __init__(self, channels, scale): + super(AntiAliasInterpolation2d, self).__init__() + sigma = (1 / scale - 1) / 2 + kernel_size = 2 * round(sigma * 4) + 1 + self.ka = kernel_size // 2 + self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka + + kernel_size = [kernel_size, kernel_size] + sigma = [sigma, sigma] + # The gaussian kernel is the product of the + # gaussian function of each dimension. 
+ kernel = 1 + meshgrids = torch.meshgrid( + [torch.arange(size, dtype=torch.float32) for size in kernel_size] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= torch.exp(-((mgrid - mean) ** 2) / (2 * std**2)) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer("weight", kernel) + self.groups = channels + self.scale = scale + + def forward(self, input): + if self.scale == 1.0: + return input + + out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) + out = F.conv2d(out, weight=self.weight, groups=self.groups) + out = F.interpolate(out, scale_factor=(self.scale, self.scale)) + + return out diff --git a/dot/avatarify/option.py b/dot/fomm/option.py similarity index 96% rename from dot/avatarify/option.py rename to dot/fomm/option.py index 444301c..af6c851 100644 --- a/dot/avatarify/option.py +++ b/dot/fomm/option.py @@ -35,7 +35,7 @@ def determine_path(): sys.exit() -class AvatarifyOption(ModelOption): +class FOMMOption(ModelOption): def __init__( self, use_gpu=True, @@ -44,7 +44,7 @@ def __init__( gpen_type=None, gpen_path=None, ): - super(AvatarifyOption, self).__init__( + super(FOMMOption, self).__init__( gpen_type=gpen_type, use_gpu=use_gpu, crop_size=crop_size, @@ -134,8 +134,8 @@ def handle_keyboard_input(self): self.predictor.reset_frames() if not self.is_calibrated: - cv2.namedWindow("FOM", cv2.WINDOW_GUI_NORMAL) - cv2.moveWindow("FOM", 600, 250) + cv2.namedWindow("FOMM", cv2.WINDOW_GUI_NORMAL) + cv2.moveWindow("FOMM", 600, 250) self.is_calibrated = True self.show_landmarks = False @@ -239,7 +239,7 @@ def process_image(self, image, use_gpu=True, **kwargs) -> np.array: if not self.opt_hide_rect: draw_rect(preview_frame) - cv2.imshow("FOM", preview_frame[..., ::-1]) + cv2.imshow("FOMM", preview_frame[..., ::-1]) if out is not None: if not self.opt_no_pad: diff --git a/dot/avatarify/predictor_local.py b/dot/fomm/predictor_local.py similarity index 96% rename from dot/avatarify/predictor_local.py rename to dot/fomm/predictor_local.py index e10e2cb..c7803d8 100644 --- a/dot/avatarify/predictor_local.py +++ b/dot/fomm/predictor_local.py @@ -1,172 +1,172 @@ -#!/usr/bin/env python3 - -import face_alignment -import numpy as np -import torch -import yaml -from scipy.spatial import ConvexHull - -from .modules.generator_optim import OcclusionAwareGenerator -from .modules.keypoint_detector import KPDetector - - -def normalize_kp( - kp_source, - kp_driving, - kp_driving_initial, - adapt_movement_scale=False, - use_relative_movement=False, - use_relative_jacobian=False, -): - - if adapt_movement_scale: - source_area = ConvexHull(kp_source["value"][0].data.cpu().numpy()).volume - driving_area = ConvexHull( - kp_driving_initial["value"][0].data.cpu().numpy() - ).volume - adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area) - else: - adapt_movement_scale = 1 - - kp_new = {k: v for k, v in kp_driving.items()} - - if use_relative_movement: - kp_value_diff = kp_driving["value"] - kp_driving_initial["value"] - kp_value_diff *= adapt_movement_scale - kp_new["value"] = kp_value_diff + kp_source["value"] - - if use_relative_jacobian: - jacobian_diff = torch.matmul( - kp_driving["jacobian"], torch.inverse(kp_driving_initial["jacobian"]) - ) - kp_new["jacobian"] = torch.matmul(jacobian_diff, kp_source["jacobian"]) - - 
return kp_new - - -def to_tensor(a): - return torch.tensor(a[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2) / 255 - - -class PredictorLocal: - def __init__( - self, - config_path, - checkpoint_path, - relative=False, - adapt_movement_scale=False, - device=None, - enc_downscale=1, - ): - - self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") - self.relative = relative - self.adapt_movement_scale = adapt_movement_scale - self.start_frame = None - self.start_frame_kp = None - self.kp_driving_initial = None - self.config_path = config_path - self.checkpoint_path = checkpoint_path - self.generator, self.kp_detector = self.load_checkpoints() - self.fa = face_alignment.FaceAlignment( - face_alignment.LandmarksType._2D, flip_input=True, device=self.device - ) - self.source = None - self.kp_source = None - self.enc_downscale = enc_downscale - - def load_checkpoints(self): - with open(self.config_path) as f: - config = yaml.load(f, Loader=yaml.FullLoader) - - generator = OcclusionAwareGenerator( - **config["model_params"]["generator_params"], - **config["model_params"]["common_params"] - ) - generator.to(self.device) - - kp_detector = KPDetector( - **config["model_params"]["kp_detector_params"], - **config["model_params"]["common_params"] - ) - kp_detector.to(self.device) - - checkpoint = torch.load(self.checkpoint_path, map_location=self.device) - generator.load_state_dict(checkpoint["generator"]) - kp_detector.load_state_dict(checkpoint["kp_detector"]) - - generator.eval() - kp_detector.eval() - - return generator, kp_detector - - def reset_frames(self): - self.kp_driving_initial = None - - def set_source_image(self, source_image): - self.source = to_tensor(source_image).to(self.device) - self.kp_source = self.kp_detector(self.source) - - if self.enc_downscale > 1: - h = int(self.source.shape[2] / self.enc_downscale) - w = int(self.source.shape[3] / self.enc_downscale) - source_enc = torch.nn.functional.interpolate( - self.source, size=(h, w), mode="bilinear" - ) - else: - source_enc = self.source - - self.generator.encode_source(source_enc) - - def predict(self, driving_frame): - assert self.kp_source is not None, "call set_source_image()" - - with torch.no_grad(): - driving = to_tensor(driving_frame).to(self.device) - - if self.kp_driving_initial is None: - self.kp_driving_initial = self.kp_detector(driving) - self.start_frame = driving_frame.copy() - self.start_frame_kp = self.get_frame_kp(driving_frame) - - kp_driving = self.kp_detector(driving) - kp_norm = normalize_kp( - kp_source=self.kp_source, - kp_driving=kp_driving, - kp_driving_initial=self.kp_driving_initial, - use_relative_movement=self.relative, - use_relative_jacobian=self.relative, - adapt_movement_scale=self.adapt_movement_scale, - ) - - out = self.generator( - self.source, kp_source=self.kp_source, kp_driving=kp_norm - ) - - out = np.transpose(out["prediction"].data.cpu().numpy(), [0, 2, 3, 1])[0] - out = (np.clip(out, 0, 1) * 255).astype(np.uint8) - - return out - - def get_frame_kp(self, image): - kp_landmarks = self.fa.get_landmarks(image) - if kp_landmarks: - kp_image = kp_landmarks[0] - kp_image = self.normalize_alignment_kp(kp_image) - return kp_image - else: - return None - - @staticmethod - def normalize_alignment_kp(kp): - kp = kp - kp.mean(axis=0, keepdims=True) - area = ConvexHull(kp[:, :2]).volume - area = np.sqrt(area) - kp[:, :2] = kp[:, :2] / area - return kp - - def get_start_frame(self): - return self.start_frame - - def get_start_frame_kp(self): - return self.start_frame_kp 
+#!/usr/bin/env python3 + +import face_alignment +import numpy as np +import torch +import yaml +from scipy.spatial import ConvexHull + +from .modules.generator_optim import OcclusionAwareGenerator +from .modules.keypoint_detector import KPDetector + + +def normalize_kp( + kp_source, + kp_driving, + kp_driving_initial, + adapt_movement_scale=False, + use_relative_movement=False, + use_relative_jacobian=False, +): + + if adapt_movement_scale: + source_area = ConvexHull(kp_source["value"][0].data.cpu().numpy()).volume + driving_area = ConvexHull( + kp_driving_initial["value"][0].data.cpu().numpy() + ).volume + adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area) + else: + adapt_movement_scale = 1 + + kp_new = {k: v for k, v in kp_driving.items()} + + if use_relative_movement: + kp_value_diff = kp_driving["value"] - kp_driving_initial["value"] + kp_value_diff *= adapt_movement_scale + kp_new["value"] = kp_value_diff + kp_source["value"] + + if use_relative_jacobian: + jacobian_diff = torch.matmul( + kp_driving["jacobian"], torch.inverse(kp_driving_initial["jacobian"]) + ) + kp_new["jacobian"] = torch.matmul(jacobian_diff, kp_source["jacobian"]) + + return kp_new + + +def to_tensor(a): + return torch.tensor(a[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2) / 255 + + +class PredictorLocal: + def __init__( + self, + config_path, + checkpoint_path, + relative=False, + adapt_movement_scale=False, + device=None, + enc_downscale=1, + ): + + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + self.relative = relative + self.adapt_movement_scale = adapt_movement_scale + self.start_frame = None + self.start_frame_kp = None + self.kp_driving_initial = None + self.config_path = config_path + self.checkpoint_path = checkpoint_path + self.generator, self.kp_detector = self.load_checkpoints() + self.fa = face_alignment.FaceAlignment( + face_alignment.LandmarksType._2D, flip_input=True, device=self.device + ) + self.source = None + self.kp_source = None + self.enc_downscale = enc_downscale + + def load_checkpoints(self): + with open(self.config_path) as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + generator = OcclusionAwareGenerator( + **config["model_params"]["generator_params"], + **config["model_params"]["common_params"] + ) + generator.to(self.device) + + kp_detector = KPDetector( + **config["model_params"]["kp_detector_params"], + **config["model_params"]["common_params"] + ) + kp_detector.to(self.device) + + checkpoint = torch.load(self.checkpoint_path, map_location=self.device) + generator.load_state_dict(checkpoint["generator"]) + kp_detector.load_state_dict(checkpoint["kp_detector"]) + + generator.eval() + kp_detector.eval() + + return generator, kp_detector + + def reset_frames(self): + self.kp_driving_initial = None + + def set_source_image(self, source_image): + self.source = to_tensor(source_image).to(self.device) + self.kp_source = self.kp_detector(self.source) + + if self.enc_downscale > 1: + h = int(self.source.shape[2] / self.enc_downscale) + w = int(self.source.shape[3] / self.enc_downscale) + source_enc = torch.nn.functional.interpolate( + self.source, size=(h, w), mode="bilinear" + ) + else: + source_enc = self.source + + self.generator.encode_source(source_enc) + + def predict(self, driving_frame): + assert self.kp_source is not None, "call set_source_image()" + + with torch.no_grad(): + driving = to_tensor(driving_frame).to(self.device) + + if self.kp_driving_initial is None: + self.kp_driving_initial = self.kp_detector(driving) + 
self.start_frame = driving_frame.copy() + self.start_frame_kp = self.get_frame_kp(driving_frame) + + kp_driving = self.kp_detector(driving) + kp_norm = normalize_kp( + kp_source=self.kp_source, + kp_driving=kp_driving, + kp_driving_initial=self.kp_driving_initial, + use_relative_movement=self.relative, + use_relative_jacobian=self.relative, + adapt_movement_scale=self.adapt_movement_scale, + ) + + out = self.generator( + self.source, kp_source=self.kp_source, kp_driving=kp_norm + ) + + out = np.transpose(out["prediction"].data.cpu().numpy(), [0, 2, 3, 1])[0] + out = (np.clip(out, 0, 1) * 255).astype(np.uint8) + + return out + + def get_frame_kp(self, image): + kp_landmarks = self.fa.get_landmarks(image) + if kp_landmarks: + kp_image = kp_landmarks[0] + kp_image = self.normalize_alignment_kp(kp_image) + return kp_image + else: + return None + + @staticmethod + def normalize_alignment_kp(kp): + kp = kp - kp.mean(axis=0, keepdims=True) + area = ConvexHull(kp[:, :2]).volume + area = np.sqrt(area) + kp[:, :2] = kp[:, :2] / area + return kp + + def get_start_frame(self): + return self.start_frame + + def get_start_frame_kp(self): + return self.start_frame_kp diff --git a/dot/avatarify/sync_batchnorm/__init__.py b/dot/fomm/sync_batchnorm/__init__.py similarity index 96% rename from dot/avatarify/sync_batchnorm/__init__.py rename to dot/fomm/sync_batchnorm/__init__.py index f7d529c..fe0627a 100644 --- a/dot/avatarify/sync_batchnorm/__init__.py +++ b/dot/fomm/sync_batchnorm/__init__.py @@ -1,10 +1,10 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# File : __init__.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. diff --git a/dot/avatarify/sync_batchnorm/batchnorm.py b/dot/fomm/sync_batchnorm/batchnorm.py similarity index 97% rename from dot/avatarify/sync_batchnorm/batchnorm.py rename to dot/fomm/sync_batchnorm/batchnorm.py index b1c2d97..9ae06db 100644 --- a/dot/avatarify/sync_batchnorm/batchnorm.py +++ b/dot/fomm/sync_batchnorm/batchnorm.py @@ -1,216 +1,216 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# File : batchnorm.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. 
- -import collections - -import torch.nn.functional as F -from torch.nn.modules.batchnorm import _BatchNorm -from torch.nn.parallel._functions import Broadcast, ReduceAddCoalesced - -from .comm import SyncMaster - -__all__ = ["SynchronizedBatchNorm2d"] - - -def _sum_ft(tensor): - """sum over the first and last dimention""" - return tensor.sum(dim=0).sum(dim=-1) - - -def _unsqueeze_ft(tensor): - """add new dementions at the front and the tail""" - return tensor.unsqueeze(0).unsqueeze(-1) - - -_ChildMessage = collections.namedtuple("_ChildMessage", ["sum", "ssum", "sum_size"]) - -_MasterMessage = collections.namedtuple("_MasterMessage", ["sum", "inv_std"]) - - -class _SynchronizedBatchNorm(_BatchNorm): - def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): - - super(_SynchronizedBatchNorm, self).__init__( - num_features, eps=eps, momentum=momentum, affine=affine - ) - - self._sync_master = SyncMaster(self._data_parallel_master) - - self._is_parallel = False - self._parallel_id = None - self._slave_pipe = None - - def forward(self, input): - # If it is not parallel computation or is in - # evaluation mode, use PyTorch's implementation. - if not (self._is_parallel and self.training): - return F.batch_norm( - input, - self.running_mean, - self.running_var, - self.weight, - self.bias, - self.training, - self.momentum, - self.eps, - ) - - # Resize the input to (B, C, -1). - input_shape = input.size() - input = input.view(input.size(0), self.num_features, -1) - - # Compute the sum and square-sum. - sum_size = input.size(0) * input.size(2) - input_sum = _sum_ft(input) - input_ssum = _sum_ft(input**2) - - # Reduce-and-broadcast the statistics. - if self._parallel_id == 0: - mean, inv_std = self._sync_master.run_master( - _ChildMessage(input_sum, input_ssum, sum_size) - ) - else: - mean, inv_std = self._slave_pipe.run_slave( - _ChildMessage(input_sum, input_ssum, sum_size) - ) - - # Compute the output. - if self.affine: - # MJY:: Fuse the multiplication for speed. - output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft( - inv_std * self.weight - ) + _unsqueeze_ft(self.bias) - else: - output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) - - # Reshape it. - return output.view(input_shape) - - def __data_parallel_replicate__(self, ctx, copy_id): - self._is_parallel = True - self._parallel_id = copy_id - - # parallel_id == 0 means master device. - if self._parallel_id == 0: - ctx.sync_master = self._sync_master - else: - self._slave_pipe = ctx.sync_master.register_slave(copy_id) - - def _data_parallel_master(self, intermediates): - """Reduce the sum and square-sum, - compute the statistics, and broadcast it.""" - # Always using same "device order" makes the - # ReduceAdd operation faster. 
- # Thanks to:: Tete Xiao (http://tetexiao.com/) - intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) - - to_reduce = [i[1][:2] for i in intermediates] - to_reduce = [j for i in to_reduce for j in i] # flatten - target_gpus = [i[1].sum.get_device() for i in intermediates] - - sum_size = sum([i[1].sum_size for i in intermediates]) - sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) - mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) - - broadcasted = Broadcast.apply(target_gpus, mean, inv_std) - - outputs = [] - for i, rec in enumerate(intermediates): - outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2 : i * 2 + 2]))) - - return outputs - - def _compute_mean_std(self, sum_, ssum, size): - """Compute the mean and standard-deviation with - sum and square-sum. This method also maintains - the moving average on the master device.""" - assert size > 1, ( - "BatchNorm computes unbiased " - "standard-deviation, which requires size > 1." - ) - mean = sum_ / size - sumvar = ssum - sum_ * mean - unbias_var = sumvar / (size - 1) - bias_var = sumvar / size - - self.running_mean = ( - 1 - self.momentum - ) * self.running_mean + self.momentum * mean.data - self.running_var = ( - 1 - self.momentum - ) * self.running_var + self.momentum * unbias_var.data - - return mean, bias_var.clamp(self.eps) ** -0.5 - - -class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): - r"""Applies Batch Normalization over a 4d input that is seen as a - mini-batch of 3d inputs - - .. math:: - - y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta - - This module differs from the built-in PyTorch BatchNorm2d as the mean and - standard-deviation are reduced across all devices during training. - - For example, when one uses `nn.DataParallel` to wrap the network during - training, PyTorch's implementation normalize the tensor on each device - using the statistics only on that device, which accelerated the - computation and is also easy to implement, but the statistics might - be inaccurate. - Instead, in this synchronized version, the statistics will be computed - over all training samples distributed on multiple devices. - - Note that, for one-GPU or CPU-only case, this module behaves exactly same - as the built-in PyTorch implementation. - - The mean and standard-deviation are calculated per-dimension over - the mini-batches and gamma and beta are learnable parameter vectors - of size C (where C is the input size). - - During training, this layer keeps a running estimate of its computed mean - and variance. The running sum is kept with a default momentum of 0.1. - - During evaluation, this running mean/variance is used for normalization. - - Because the BatchNorm is done over the `C` dimension, computing statistics - on `(N, H, W)` slices, it's common terminology to call this Spatial - BatchNorm - - Args: - num_features: num_features from an expected input of - size batch_size x num_features x height x width - eps: a value added to the denominator for numerical stability. - Default: 1e-5 - momentum: the value used for the running_mean and running_var - computation. Default: 0.1 - affine: a boolean value that when set to ``True``, - gives the layer learnable - affine parameters. 
Default: ``True`` - - Shape: - - Input: :math:`(N, C, H, W)` - - Output: :math:`(N, C, H, W)` (same shape as input) - - Examples: - >>> # With Learnable Parameters - >>> m = SynchronizedBatchNorm2d(100) - >>> # Without Learnable Parameters - >>> m = SynchronizedBatchNorm2d(100, affine=False) - >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) - >>> output = m(input) - """ - - def _check_input_dim(self, input): - if input.dim() != 4: - raise ValueError("expected 4D input (got {}D input)".format(input.dim())) - super(SynchronizedBatchNorm2d, self)._check_input_dim(input) +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import collections + +import torch.nn.functional as F +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.parallel._functions import Broadcast, ReduceAddCoalesced + +from .comm import SyncMaster + +__all__ = ["SynchronizedBatchNorm2d"] + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dementions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple("_ChildMessage", ["sum", "ssum", "sum_size"]) + +_MasterMessage = collections.namedtuple("_MasterMessage", ["sum", "inv_std"]) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): + + super(_SynchronizedBatchNorm, self).__init__( + num_features, eps=eps, momentum=momentum, affine=affine + ) + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input): + # If it is not parallel computation or is in + # evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, + self.running_mean, + self.running_var, + self.weight, + self.bias, + self.training, + self.momentum, + self.eps, + ) + + # Resize the input to (B, C, -1). + input_shape = input.size() + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input**2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master( + _ChildMessage(input_sum, input_ssum, sum_size) + ) + else: + mean, inv_std = self._slave_pipe.run_slave( + _ChildMessage(input_sum, input_ssum, sum_size) + ) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft( + inv_std * self.weight + ) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. 
+ if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, + compute the statistics, and broadcast it.""" + # Always using same "device order" makes the + # ReduceAdd operation faster. + # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2 : i * 2 + 2]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with + sum and square-sum. This method also maintains + the moving average on the master device.""" + assert size > 1, ( + "BatchNorm computes unbiased " + "standard-deviation, which requires size > 1." + ) + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + self.running_mean = ( + 1 - self.momentum + ) * self.running_mean + self.momentum * mean.data + self.running_var = ( + 1 - self.momentum + ) * self.running_var + self.momentum * unbias_var.data + + return mean, bias_var.clamp(self.eps) ** -0.5 + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a + mini-batch of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device + using the statistics only on that device, which accelerated the + computation and is also easy to implement, but the statistics might + be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial + BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. 
Default: 0.1 + affine: a boolean value that when set to ``True``, + gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError("expected 4D input (got {}D input)".format(input.dim())) + super(SynchronizedBatchNorm2d, self)._check_input_dim(input) diff --git a/dot/avatarify/sync_batchnorm/comm.py b/dot/fomm/sync_batchnorm/comm.py similarity index 96% rename from dot/avatarify/sync_batchnorm/comm.py rename to dot/fomm/sync_batchnorm/comm.py index deb79c3..352dc0f 100644 --- a/dot/avatarify/sync_batchnorm/comm.py +++ b/dot/fomm/sync_batchnorm/comm.py @@ -1,155 +1,155 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# File : comm.py -# Author : Jiayuan Mao -# Email : maojiayuan@gmail.com -# Date : 27/01/2018 -# -# This file is part of Synchronized-BatchNorm-PyTorch. -# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch -# Distributed under MIT License. - -import collections -import queue -import threading - -__all__ = ["FutureResult", "SlavePipe", "SyncMaster"] - - -class FutureResult(object): - """ - A thread-safe future implementation. - Used only as one-to-one pipe. - """ - - def __init__(self): - self._result = None - self._lock = threading.Lock() - self._cond = threading.Condition(self._lock) - - def put(self, result): - with self._lock: - assert self._result is None, "Previous result has't been fetched." - self._result = result - self._cond.notify() - - def get(self): - with self._lock: - if self._result is None: - self._cond.wait() - - res = self._result - self._result = None - return res - - -_MasterRegistry = collections.namedtuple("_MasterRegistry", ["result"]) -_SlavePipeBase = collections.namedtuple( - "_SlavePipeBase", ["identifier", "queue", "result"] -) - - -class SlavePipe(_SlavePipeBase): - """ - Pipe for master-slave communication. - """ - - def run_slave(self, msg): - self.queue.put((self.identifier, msg)) - ret = self.result.get() - self.queue.put(True) - return ret - - -class SyncMaster(object): - """ - An abstract `SyncMaster` object. - - - During the replication, as the data parallel will - trigger an callback of each module, all slave devices should - call `register(id)` and obtain an `SlavePipe` - to communicate with the master. - - During the forward pass, master device invokes - `run_master`, all messages from slave devices - will be collected, and passed to a registered callback. - - After receiving the messages, the master device - should gather the information and determine - to message passed back to each slave devices. - """ - - def __init__(self, master_callback): - """ - Args: - master_callback: a callback to be invoked - after having collected messages from slave devices. - """ - self._master_callback = master_callback - self._queue = queue.Queue() - self._registry = collections.OrderedDict() - self._activated = False - - def __getstate__(self): - return {"master_callback": self._master_callback} - - def __setstate__(self, state): - self.__init__(state["master_callback"]) - - def register_slave(self, identifier): - """ - Register an slave device. 
- - Args: - identifier: an identifier, usually is the device id. - - Returns: a `SlavePipe` object which can be used - to communicate with the master device. - """ - if self._activated: - assert self._queue.empty(), ( - "Queue is not clean " "before next initialization." - ) - self._activated = False - self._registry.clear() - future = FutureResult() - self._registry[identifier] = _MasterRegistry(future) - return SlavePipe(identifier, self._queue, future) - - def run_master(self, master_msg): - """ - Main entry for the master device in each forward pass. - The messages were first collected from each devices - (including the master device), and then an callback - will be invoked to compute the message to be sent - back to each devices (including the master device). - - Args: - master_msg: the message that the master want to send - to itself. This will be placed as the first message - when calling `master_callback`. - For detailed usage, see `_SynchronizedBatchNorm` - for an example. - - Returns: the message to be sent back to the master device. - """ - self._activated = True - - intermediates = [(0, master_msg)] - for i in range(self.nr_slaves): - intermediates.append(self._queue.get()) - - results = self._master_callback(intermediates) - assert results[0][0] == 0, "The first result " "should belongs to the master." - - for i, res in results: - if i == 0: - continue - self._registry[i].result.put(res) - - for i in range(self.nr_slaves): - assert self._queue.get() is True - - return results[0][1] - - @property - def nr_slaves(self): - return len(self._registry) +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import collections +import queue +import threading + +__all__ = ["FutureResult", "SlavePipe", "SyncMaster"] + + +class FutureResult(object): + """ + A thread-safe future implementation. + Used only as one-to-one pipe. + """ + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, "Previous result has't been fetched." + self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple("_MasterRegistry", ["result"]) +_SlavePipeBase = collections.namedtuple( + "_SlavePipeBase", ["identifier", "queue", "result"] +) + + +class SlavePipe(_SlavePipeBase): + """ + Pipe for master-slave communication. + """ + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """ + An abstract `SyncMaster` object. + + - During the replication, as the data parallel will + trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` + to communicate with the master. + - During the forward pass, master device invokes + `run_master`, all messages from slave devices + will be collected, and passed to a registered callback. + - After receiving the messages, the master device + should gather the information and determine + to message passed back to each slave devices. 
+ """ + + def __init__(self, master_callback): + """ + Args: + master_callback: a callback to be invoked + after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {"master_callback": self._master_callback} + + def __setstate__(self, state): + self.__init__(state["master_callback"]) + + def register_slave(self, identifier): + """ + Register an slave device. + + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used + to communicate with the master device. + """ + if self._activated: + assert self._queue.empty(), ( + "Queue is not clean " "before next initialization." + ) + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices + (including the master device), and then an callback + will be invoked to compute the message to be sent + back to each devices (including the master device). + + Args: + master_msg: the message that the master want to send + to itself. This will be placed as the first message + when calling `master_callback`. + For detailed usage, see `_SynchronizedBatchNorm` + for an example. + + Returns: the message to be sent back to the master device. + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, "The first result " "should belongs to the master." + + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/tests/pipeline_test.py b/tests/pipeline_test.py index f6f00da..a8695d2 100644 --- a/tests/pipeline_test.py +++ b/tests/pipeline_test.py @@ -21,7 +21,7 @@ def setUp(self): self.faceswap_cv2_option = self._dot.faceswap_cv2(False, False, None) - self.avatarify_option = self._dot.avatarify(False, False, None) + self.fomm_option = self._dot.fomm(False, False, None) self.simswap_option = self._dot.simswap(False, False, None) @@ -39,7 +39,7 @@ def test_option_creation(self): assert len(rejected) == 1 success, rejected = self._dot.generate( - self.avatarify_option, + self.fomm_option, "./tests", "./tests", show_fps=False,
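A few hedged usage sketches for the modules touched by the hunks above. First, the renamed `PredictorLocal` exposes a small API: `set_source_image()` fixes the face to be animated, `predict()` warps it frame by frame, and `reset_frames()` re-anchors the driving keypoints. The sketch below is illustrative only; the module path, config file, and source/checkpoint locations are assumptions, not taken from this patch.

```python
# Hedged sketch: drive the FOMM predictor from a webcam feed.
# The import path and the config/checkpoint/source paths are assumed for
# illustration only; point them at the files in your own checkout.
import cv2

from dot.fomm.predictor_local import PredictorLocal  # hypothetical module name

predictor = PredictorLocal(
    config_path="./configs/vox-adv-256.yaml",               # assumed FOMM config
    checkpoint_path="./saved_models/fomm/vox-adv-cpk.pth.tar",
    relative=True,                 # express driving motion relative to the first frame
    adapt_movement_scale=True,     # rescale motion to the source face
)

# Source face to impersonate; the vox FOMM models work on 256x256 RGB crops.
source = cv2.cvtColor(cv2.imread("./data/source.jpg"), cv2.COLOR_BGR2RGB)
predictor.set_source_image(cv2.resize(source, (256, 256)))

cap = cv2.VideoCapture(0)
while True:
    ok, frame = cap.read()
    if not ok:
        break
    driving = cv2.cvtColor(cv2.resize(frame, (256, 256)), cv2.COLOR_BGR2RGB)
    out = predictor.predict(driving)                         # uint8 RGB frame
    cv2.imshow("fomm preview", cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
    key = cv2.waitKey(1) & 0xFF
    if key == ord("r"):
        predictor.reset_frames()   # treat the next frame as the new neutral pose
    elif key == ord("q"):
        break
cap.release()
cv2.destroyAllWindows()
```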
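`normalize_kp()` is what makes the animation relative: the driving motion is expressed as an offset from the first driving frame and re-applied on top of the source keypoints, with an optional convex-hull-based scale adaptation. A toy illustration with dummy tensors, assuming the function lives in the same (hypothetical) module as `PredictorLocal`:

```python
# Toy illustration of the relative-movement path in normalize_kp().
# The import path is an assumption; the tensors are dummies, not real detections.
import torch

from dot.fomm.predictor_local import normalize_kp  # hypothetical module name

identity = torch.eye(2).expand(1, 10, 2, 2).contiguous()
kp_source = {"value": torch.zeros(1, 10, 2), "jacobian": identity}
kp_driving_initial = {"value": torch.zeros(1, 10, 2), "jacobian": identity}
kp_driving = {"value": torch.full((1, 10, 2), 0.1), "jacobian": identity}

kp_norm = normalize_kp(
    kp_source=kp_source,
    kp_driving=kp_driving,
    kp_driving_initial=kp_driving_initial,
    use_relative_movement=True,
    use_relative_jacobian=True,
)

# Every keypoint moved +0.1 relative to the first driving frame, so the same
# offset is applied on top of the (all-zero) source keypoints.
print(kp_norm["value"][0, 0])   # tensor([0.1000, 0.1000])
```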
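The vendored `SynchronizedBatchNorm2d` moved under `dot/fomm/sync_batchnorm/` acts as a drop-in for `nn.BatchNorm2d` whenever it is not replicated across devices: as its `forward()` shows, single-device or eval-mode calls fall straight through to `F.batch_norm`. A minimal sanity check, importing from the renamed module shown in the diff:

```python
# Minimal sketch: on a single device the layer defers to torch's own batch_norm,
# so it can be exercised without any multi-GPU setup.
import torch

from dot.fomm.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d

bn = SynchronizedBatchNorm2d(num_features=64)   # learnable gamma/beta by default
x = torch.randn(8, 64, 32, 32)                  # (N, C, H, W); 4D input is required
y = bn(x)
assert y.shape == x.shape                       # same shape as the input

bn.eval()                                       # eval mode uses the running stats
with torch.no_grad():
    _ = bn(torch.randn(1, 64, 32, 32))

# Cross-device statistics are only gathered when the module is replicated by a
# callback-aware DataParallel wrapper that triggers __data_parallel_replicate__.
```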
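Finally, `comm.py` implements the master/slave handshake those statistics travel over: each replica registers a `SlavePipe`, the master collects one message per replica in `run_master()`, and the callback's result is broadcast back. A toy, thread-based illustration of that protocol; the summing callback and integer messages are invented for demonstration (in sync-BN the messages carry per-device sums and square-sums):

```python
# Toy demonstration of the SyncMaster / SlavePipe handshake using plain threads.
import threading

from dot.fomm.sync_batchnorm.comm import SyncMaster


def reduce_callback(intermediates):
    # intermediates is a list of (identifier, message) pairs, master (id 0) first.
    total = sum(msg for _, msg in intermediates)
    # One (identifier, result) pair must be returned per device, master first.
    return [(ident, total) for ident, _ in intermediates]


master = SyncMaster(reduce_callback)
pipes = {i: master.register_slave(i) for i in (1, 2)}   # normally done at replication time

slave_results = {}


def slave(ident, value):
    # run_slave blocks until the master has gathered every message and replied.
    slave_results[ident] = pipes[ident].run_slave(value)


threads = [
    threading.Thread(target=slave, args=(1, 10)),
    threading.Thread(target=slave, args=(2, 20)),
]
for t in threads:
    t.start()

master_total = master.run_master(5)   # master contributes 5; expects 5 + 10 + 20
for t in threads:
    t.join()

print(master_total, slave_results)    # 35 {1: 35, 2: 35}
```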