Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
0756f39
Make module swap the main QAT flow again
andrewor14 Oct 4, 2024
9e9fdef
Add generic fake quantized linear for QAT
andrewor14 Oct 4, 2024
7f623a5
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 8, 2024
c8f9f37
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 8, 2024
75fcd21
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 8, 2024
9185cc4
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 8, 2024
d671826
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 8, 2024
59b6644
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 8, 2024
d4332cb
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 8, 2024
8e5d2ea
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 8, 2024
dbad878
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 8, 2024
ab43744
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 8, 2024
d6750a9
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 9, 2024
15a3d81
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 9, 2024
0153d66
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
8de3ba6
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
c18c60f
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
ef4f062
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
8f48663
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
e442439
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
4239d47
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
b0c6cc7
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
75c83ef
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
39ebc46
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
e08517c
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
f9286c5
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
d0d9573
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
c0ed9ed
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 11, 2024
f9a2f4c
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 14, 2024
83e2f10
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 14, 2024
5b4feb0
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 14, 2024
fbc0259
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 14, 2024
756cb8d
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 14, 2024
5642f44
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 14, 2024
622b6df
Update base for Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 14, 2024
b5fe5a7
Update on "Add generic fake quantized linear for QAT"
andrewor14 Oct 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions test/dtypes/test_affine_quantized_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,17 @@
float8_weight_only,
quantize_,
)
-from torchao.quantization.observer import PerRow, PerTensor
+from torchao.quantization.granularity import (
+PerRow,
+PerTensor,
+)
from torchao.quantization.quant_api import (
float8_static_activation_float8_weight,
)
-from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine
+from torchao.quantization.quant_primitives import (
+MappingType,
+choose_qparams_affine,
+)

random.seed(0)
torch.manual_seed(0)
Expand Down
22 changes: 12 additions & 10 deletions test/quantization/test_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import TestCase

-from torchao.quantization.observer import (
-AffineQuantizedMinMaxObserver,
+from torchao.quantization.granularity import (
PerAxis,
PerTensor,
)
+from torchao.quantization.observer import (
+AffineQuantizedMinMaxObserver,
+)
from torchao.quantization.quant_api import (
insert_observers_,
)
Expand Down Expand Up @@ -42,7 +44,7 @@ def test_min_max_per_tensor_affine(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.ASYMMETRIC,
torch.uint8,
-granularity_type=PerTensor(),
+granularity=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
Expand All @@ -54,7 +56,7 @@ def test_min_max_per_channel_affine(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.ASYMMETRIC,
torch.uint8,
-granularity_type=PerAxis(axis=0),
+granularity=PerAxis(axis=0),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
Expand All @@ -68,7 +70,7 @@ def test_block_size_calc_success(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
-granularity_type=PerTensor(),
+granularity=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
Expand All @@ -87,7 +89,7 @@ def test_block_size_calc_success(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
-granularity_type=PerAxis(1),
+granularity=PerAxis(1),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
Expand All @@ -102,7 +104,7 @@ def test_block_size_row_errors(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
-granularity_type=PerAxis(0),
+granularity=PerAxis(0),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
Expand All @@ -121,7 +123,7 @@ def test_block_size_row_errors(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
-granularity_type=PerAxis(1),
+granularity=PerAxis(1),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
Expand Down Expand Up @@ -149,7 +151,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
input_observer = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
-granularity_type=PerTensor(),
+granularity=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
Expand All @@ -159,7 +161,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
weight_observer = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
-granularity_type=PerTensor(),
+granularity=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
Expand Down
Loading