Skip to content

Commit

Permalink
Add option to disable plotting the linear scaling fit
Browse files Browse the repository at this point in the history
  • Loading branch information
mimischi committed Oct 22, 2018
1 parent f3bc35a commit 000d23a
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 15 deletions.
21 changes: 15 additions & 6 deletions mdbenchmark/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,20 @@ def plot_projection(df, selection, color, ax=None):
return ax


def plot_line(df, selection, label, ax=None):
def plot_line(df, selection, label, fit, ax=None):
if ax is None:
ax = plt.gca()

p = ax.plot(selection, "ns/day", ".-", data=df, ms="10", label=label)
color = p[0].get_color()

if len(df[selection]) > 1:
if fit and (len(df[selection]) > 1):
plot_projection(df=df, selection=selection, color=color, ax=ax)

return ax


def plot_over_group(df, plot_cores, ax=None):
def plot_over_group(df, plot_cores, fit, ax=None):
# plot all lines
selection = "ncores" if plot_cores else "nodes"

Expand All @@ -101,7 +101,7 @@ def plot_over_group(df, plot_cores, ax=None):
label = "{template} - {module} on {pu}s".format(
template=template, module=module, pu=pu
)
plot_line(df=df, selection=selection, ax=ax, label=label)
plot_line(df=df, selection=selection, ax=ax, fit=fit, label=label)

# style axes
xlabel = "cores" if plot_cores else "nodes"
Expand Down Expand Up @@ -221,6 +221,12 @@ def filter_dataframe_for_plotting(df, host_name, module_name, gpu, cpu):
show_default=True,
is_flag=True,
)
@click.option(
"--fit/--no-fit",
help="Fit a line through the first two data points, indicating linear scaling.",
show_default=True,
default=True,
)
@click.option(
"--font-size", help="Font size for generated plot.", default=16, show_default=True
)
Expand Down Expand Up @@ -249,6 +255,7 @@ def plot(
gpu,
cpu,
plot_cores,
fit,
font_size,
dpi,
xtick_step,
Expand All @@ -261,7 +268,9 @@ def plot(
command.
You can customize the filename and file format of the generated plot with
the `--output-name` and `--format` option, respectively.
the `--output-name` and `--format` option, respectively. Per default, a fit
will be plotted through the first data points of each benchmark group. To
disable the fit, use the `--no-fit` option.
To only plot specific benchmarks, make use of the `--module`, `--template`,
`--cpu/--no-cpu` and `--gpu/--no-gpu` options.
Expand All @@ -284,7 +293,7 @@ def plot(
fig = Figure()
FigureCanvas(fig)
ax = fig.add_subplot(111)
ax = plot_over_group(df, plot_cores, ax=ax)
ax = plot_over_group(df=df, plot_cores=plot_cores, fit=fit, ax=ax)

# Update xticks
selection = "ncores" if plot_cores else "nodes"
Expand Down
17 changes: 8 additions & 9 deletions mdbenchmark/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,18 @@
import os

import click

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest
from numpy.testing import assert_equal
from pandas.testing import assert_frame_equal

from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from mdbenchmark import cli, plot, utils
from mdbenchmark.ext.click_test import cli_runner
from mdbenchmark.testing import data

import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from numpy.testing import assert_equal
from pandas.testing import assert_frame_equal


@pytest.mark.parametrize(
Expand Down Expand Up @@ -321,7 +320,7 @@ def test_plot_plot_line(capsys, cli_runner, tmpdir, data):
fig = Figure()
FigureCanvas(fig)
ax = fig.add_subplot(111)
plot.plot_line(df=df, selection=selection, label=label, ax=ax)
plot.plot_line(df=df, selection=selection, label=label, fit=True, ax=ax)


def test_plot_plot_line_singlepoint(capsys, cli_runner, tmpdir, data):
Expand All @@ -334,4 +333,4 @@ def test_plot_plot_line_singlepoint(capsys, cli_runner, tmpdir, data):
fig = Figure()
FigureCanvas(fig)
ax = fig.add_subplot(111)
plot.plot_line(df=df, selection=selection, label=label, ax=ax)
plot.plot_line(df=df, selection=selection, label=label, fit=True, ax=ax)

0 comments on commit 000d23a

Please # to comment.