Skip to content

Commit

Permalink
Fix (examples/ptq): fix arguments typing and names
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Feb 19, 2024
1 parent 2369645 commit 8b3bbc5
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import argparse
from functools import partial
import os
import random
import warnings
Expand Down Expand Up @@ -38,6 +39,14 @@
# Ignore warnings about __torch_function__
warnings.filterwarnings("ignore")


def parse_type(v, default_type):
if v == 'None':
return None
else:
return default_type(v)


model_names = sorted(
name for name in torchvision.models.__dict__ if name.islower() and not name.startswith("__") and
callable(torchvision.models.__dict__[name]) and not name.startswith("get_"))
Expand Down Expand Up @@ -96,7 +105,7 @@
parser.add_argument(
'--bias-bit-width',
default=32,
type=int,
type=partial(parse_type, default_type=int),
choices=[32, 16, None],
help='Bias bit width (default: 32)')
parser.add_argument(
Expand Down Expand Up @@ -234,7 +243,7 @@
add_bool_arg(parser, 'calibrate-bn', default=False, help='Calibrate BN (default: disabled)')
add_bool_arg(
parser,
'split-input',
'channel-splitting-split-input',
default=False,
help='Input Channels Splitting for channel splitting (default: disabled)')
add_bool_arg(
Expand Down Expand Up @@ -311,7 +320,7 @@ def main():
f"Weight quant calibration type: {args.weight_quant_calibration_type} - "
f"Calibrate BN: {args.calibrate_bn} - "
f"Channel Splitting Ratio: {args.channel_splitting_ratio} - "
f"Split Input: {args.split_input} - "
f"Split Input: {args.channel_splitting_split_input} - "
f"Merge BN: {args.merge_bn}")

# Get model-specific configurations about input shapes and normalization
Expand Down Expand Up @@ -358,7 +367,7 @@ def main():
equalize_merge_bias=args.graph_eq_merge_bias,
merge_bn=args.merge_bn,
channel_splitting_ratio=args.channel_splitting_ratio,
channel_splitting_split_input=args.split_input)
channel_splitting_split_input=args.channel_splitting_split_input)
else:
raise RuntimeError(f"{args.target_backend} backend not supported.")

Expand Down

0 comments on commit 8b3bbc5

Please # to comment.