-
Notifications
You must be signed in to change notification settings - Fork 3k
Feature: Automated Creation Based on Example for PyTorch Linear Modules with ReLU Activations #126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 2 commits
cd2f012
419439a
898129f
5c12449
58a67b6
5526745
ec2d32f
f74311e
d1bf4cb
2e4c312
18a5c18
695bef6
01eda11
6b359d1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,94 @@ | ||
|
|
||
| \documentclass[border=8pt, multi, tikz]{standalone} | ||
| \usepackage{import} | ||
| \subimport{../layers/}{init} | ||
| \usetikzlibrary{positioning} | ||
| \usetikzlibrary{3d} %for including external image | ||
|
|
||
| \def\ConvColor{rgb:yellow,5;red,2.5;white,5} | ||
| \def\ConvReluColor{rgb:yellow,5;red,5;white,5} | ||
| \def\PoolColor{rgb:red,1;black,0.3} | ||
| \def\UnpoolColor{rgb:blue,2;green,1;black,0.3} | ||
| \def\FcColor{rgb:blue,5;red,2.5;white,5} | ||
| \def\FcReluColor{rgb:blue,5;red,5;white,4} | ||
| \def\SoftmaxColor{rgb:magenta,5;black,7} | ||
| \def\SumColor{rgb:blue,5;green,15} | ||
|
|
||
| \newcommand{\copymidarrow}{\tikz \draw[-Stealth,line width=0.8mm,draw={rgb:blue,4;red,1;green,1;black,3}] (-0.3,0) -- ++(0.3,0);} | ||
|
|
||
| \begin{document} | ||
| \begin{tikzpicture} | ||
| \tikzstyle{connection}=[ultra thick,every node/.style={sloped,allow upside down},draw=\edgecolor,opacity=0.7] | ||
| \tikzstyle{copyconnection}=[ultra thick,every node/.style={sloped,allow upside down},draw={rgb:blue,4;red,1;green,1;black,3},opacity=0.7] | ||
|
|
||
| \pic[shift={(1, 0, 0)}] at (0, 0, 0) | ||
| {Box={ | ||
| name=module1, | ||
| caption=$\mathrm{{FC}}$, | ||
| xlabel={{16, }}, | ||
| zlabel=, | ||
| fill=\FcColor, | ||
| height=16, | ||
| width=1, | ||
| depth=1 | ||
| } | ||
| }; | ||
|
|
||
| \pic[shift={(0.5, 0, 0)}] at (module1-east) | ||
| {Box={ | ||
| name=module2, | ||
| caption=$\varphi_\mathrm{{ReLU}}$, | ||
| xlabel={{, }}, | ||
| zlabel=, | ||
| fill=\ConvColor, | ||
| height=16, | ||
| width=0.5, | ||
| depth=1 | ||
| } | ||
| }; | ||
|
|
||
| \pic[shift={(1, 0, 0)}] at (module2-east) | ||
| {Box={ | ||
| name=module3, | ||
| caption=$\mathrm{{FC}}$, | ||
| xlabel={{16, }}, | ||
| zlabel=, | ||
| fill=\FcColor, | ||
| height=16, | ||
| width=1, | ||
| depth=1 | ||
| } | ||
| }; | ||
|
|
||
| \draw [connection] (module2-east) -- node {\midarrow} (module3-west); | ||
|
|
||
| \pic[shift={(0.5, 0, 0)}] at (module3-east) | ||
| {Box={ | ||
| name=module4, | ||
| caption=$\varphi_\mathrm{{ReLU}}$, | ||
| xlabel={{, }}, | ||
| zlabel=, | ||
| fill=\ConvColor, | ||
| height=16, | ||
| width=0.5, | ||
| depth=1 | ||
| } | ||
| }; | ||
|
|
||
| \pic[shift={(1, 0, 0)}] at (module4-east) | ||
| {Box={ | ||
| name=module5, | ||
| caption=$\mathrm{{FC}}$, | ||
| xlabel={{1, }}, | ||
| zlabel=, | ||
| fill=\FcColor, | ||
| height=1, | ||
| width=1, | ||
| depth=1 | ||
| } | ||
| }; | ||
|
|
||
| \draw [connection] (module4-east) -- node {\midarrow} (module5-west); | ||
|
|
||
| \end{tikzpicture} | ||
| \end{document} |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| from torchinfo import summary | ||
|
|
||
| import pycore.tikzeng as pnn | ||
|
|
||
|
|
||
| class TorchArchParser: | ||
|
|
||
| text_mapping = { | ||
| "Linear": "\\mathrm{{FC}}", | ||
| "ReLU": "\\varphi_\\mathrm{{ReLU}}" | ||
| } | ||
|
|
||
| def __init__(self, torch_module, input_size): | ||
|
|
||
git-thor marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| self.torch_module = torch_module | ||
| self.summary_list = summary(self.torch_module, input_size=input_size).summary_list | ||
|
|
||
| self.arch = self.parse(self.summary_list) | ||
|
|
||
| def get_arch(self): | ||
|
|
||
| return self.arch | ||
|
|
||
| @staticmethod | ||
| def parse(summary_list): | ||
|
|
||
| arch = list() | ||
| arch.append(pnn.to_head("..")) | ||
| arch.append(pnn.to_cor()) | ||
| arch.append(pnn.to_begin()) | ||
| for idx, layer in enumerate(summary_list[2:], start=1): | ||
|
|
||
| if layer.class_name == "Linear": | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If more module types would be parsed in the future, having helper functions or a builder class corresponding to the layer.class_name would omit cluttering the
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree and this becomes even more important with increased types of supported PyTorch modules (layers, activations). Though, I have oriented myself on the coding style in Further, I propose to abstract parsing and constructing the tikz code toward a collection of layers that map to a similar tikz representation and a general activation representation that maps to all possible PyTorch-implemented activation functions. Though, all of this could - and imo should - be an improvement and extension on-top of this basic functionaility. |
||
| text = TorchArchParser.text_mapping.get(layer.class_name, "\\mathrm{{FC}}") | ||
| arch_layer = pnn.to_Conv( | ||
| name=f"module{idx}", | ||
| s_filer="", | ||
| n_filer=layer.module.out_features, | ||
| offset=str((1, 0, 0)), | ||
| width=1, | ||
| height=layer.module.out_features, | ||
| depth=1, | ||
| fill_color="\\FcColor", | ||
| caption=f"${text}$", | ||
| to=f"(module{idx-1}-east)" if idx > 1 else str((0, 0, 0)), | ||
| ) | ||
| arch.append(arch_layer) | ||
|
|
||
| if idx > 1: | ||
| arch_layer = pnn.to_connection(f"module{idx-1}", f"module{idx}") | ||
| arch.append(arch_layer) | ||
|
|
||
| if layer.class_name in {"ReLU"}: | ||
git-thor marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| text = TorchArchParser.text_mapping.get(layer.class_name, "\\varphi") | ||
| arch_layer = pnn.to_Conv( | ||
| name=f"module{idx}", | ||
| s_filer="", | ||
| n_filer="", | ||
| offset=str((0.5, 0, 0)), | ||
| width=0.5, | ||
| height=layer.input_size[1], | ||
| depth=layer.input_size[0], | ||
| caption=f"${text}$", | ||
| to=f"(module{idx-1}-east)" if idx > 1 else str((0, 0, 0)), | ||
| ) | ||
| arch.append(arch_layer) | ||
|
|
||
| arch.append(pnn.to_end()) | ||
|
|
||
| return arch | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| import sys | ||
| sys.path.append('../') | ||
|
|
||
| import torch as th | ||
|
|
||
| from pycore.torchparse import TorchArchParser | ||
| from pycore.tikzeng import to_generate | ||
|
|
||
|
|
||
| DEVICE = th.device('cuda' if th.cuda.is_available() else 'cpu') | ||
git-thor marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| class MLP(th.nn.Module): | ||
|
|
||
| def __init__(self): | ||
|
|
||
| super(MLP, self).__init__() | ||
|
|
||
| self.net = th.nn.Sequential( | ||
| th.nn.Linear(2, 16), | ||
| th.nn.ReLU(), | ||
| th.nn.Linear(16, 16), | ||
| th.nn.ReLU(), | ||
| th.nn.Linear(16, 1) | ||
| ) | ||
|
|
||
| def forward(self, x): | ||
|
|
||
| x = x.view(-1, 2) | ||
| y_hat = self.net(x) | ||
|
|
||
| return y_hat.view(-1, 1) | ||
|
|
||
|
|
||
| def main(): | ||
|
|
||
| mlp = MLP() | ||
| parser = TorchArchParser(torch_module=mlp, input_size=(1, 2)) | ||
| arch = parser.get_arch() | ||
| to_generate(arch, pathname="./test_torch_mlp.tex") | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
|
|
||
| main() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| ## The following requirements were added by pip freeze: | ||
| torchinfo[pytorch]==1.6.2 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,6 @@ | ||
| #!/bin/bash | ||
|
|
||
|
|
||
| python $1.py | ||
| python $1.py | ||
| pdflatex $1.tex | ||
|
|
||
| rm *.aux *.log *.vscodeLog | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.