# widget_2dgraph.py
#2D Graph Widget
# Imports
from PyQt5 import QtWidgets
from matplotlib.backend_bases import MouseEvent
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as Canvas
import matplotlib
import numpy as np
import random
from scipy.interpolate import griddata
# Select Matplotlib's Qt5 Agg backend so figures render through PyQt5.
# NOTE(review): this call runs after pyplot and backend_qt5agg have already
# been imported above -- confirm the switch still takes effect on the
# Matplotlib version in use.
matplotlib.use('QT5Agg')
# Matplotlib canvas class to create figure
class MplCanvas2dGraph(Canvas):
    """Interactive 2D scatter canvas used to build a two-class data set.

    Mouse buttons add points to the axes:

    * left click   -> one red point, output value ``[0]``
    * right click  -> one blue point, output value ``[1]``
    * middle click -> ten random red and ten random blue points

    A filled-contour "heat map" can be drawn underneath the points from an
    array of ``x, y, value`` rows (see ``updateHeatMap`` / ``toggleHeatMap``).

    Attributes:
        inputPoints: list of ``[x, y]`` pairs, one per plotted point.
        outputValues: list of ``[0]`` (red) or ``[1]`` (blue), parallel to
            ``inputPoints``.
        scatterpoints: the scatter artists, kept so they can be removed.
        heatmap: the current contour set, or ``None`` when none is shown.
    """

    # Scatter colour for each output value (index 0 -> red, index 1 -> blue).
    _CLASS_COLOURS = ('red', 'blue')
    # Number of red/blue pairs generated by one middle click.
    _RANDOM_PAIRS = 10
    # Both axes are fixed to the range [0, _AXIS_MAX].
    _AXIS_MAX = 10
    # Number of grid samples per axis used to interpolate the heat map.
    _HEATMAP_RESOLUTION = 50

    def __init__(self):
        # Stores the x, y input values for the neural network
        self.inputPoints = []
        # red = 0, blue = 1 for the outputs
        self.outputValues = []
        # Stores the scatter plot points so they can visibly be removed
        self.scatterpoints = []
        # Heat map graphic
        self.heatmap = None

        self.fig = Figure()
        self.axes = self.fig.add_subplot(111)
        self.axes.set_xlabel('Width Of Flower Petal')
        self.axes.set_ylabel('Length Of Flower Petal')
        self.axes.set_xlim(0, self._AXIS_MAX)
        self.axes.set_ylim(0, self._AXIS_MAX)

        Canvas.__init__(self, self.fig)
        Canvas.setSizePolicy(
            self,
            QtWidgets.QSizePolicy.Expanding,
            QtWidgets.QSizePolicy.Expanding,
        )
        Canvas.updateGeometry(self)

        # Connect only after Canvas.__init__, so the handler is registered on
        # this Qt canvas and not on whatever canvas the bare Figure had before.
        self.mpl_connect('button_press_event', self.on_mouse_click)

    def on_mouse_click(self, event):
        """Add point(s) for a mouse press and redraw the canvas.

        Args:
            event: the Matplotlib mouse event for the button press.

        Left/right clicks that land outside the axes carry no data
        coordinates (``xdata``/``ydata`` are ``None``) and are ignored so
        that no ``[None, None]`` sample ends up in the data set.
        """
        if event.button == 2:
            # Middle click: bulk-generate random samples of both classes.
            for _ in range(self._RANDOM_PAIRS):
                for label in (0, 1):
                    x = random.uniform(0, self._AXIS_MAX)
                    y = random.uniform(0, self._AXIS_MAX)
                    self._add_point(x, y, label)
        elif event.button in (1, 3):
            if event.xdata is None or event.ydata is None:
                return
            # Left click (1) -> red / 0, right click (3) -> blue / 1.
            label = 0 if event.button == 1 else 1
            self._add_point(event.xdata, event.ydata, label)
        else:
            # Any other button changes nothing, so there is nothing to redraw.
            return
        # Draw the image. Must be redrawn with every change; it does not
        # happen automatically.
        self.figure.canvas.draw()

    def clearGraph(self):
        """Remove the heat map and every point, then redraw the canvas."""
        self._remove_heatmap()
        # Remove all the points one by one
        for point in self.scatterpoints:
            point.remove()
        # Clear in place: getGraphValues() hands out these same list objects.
        self.inputPoints.clear()
        self.outputValues.clear()
        self.scatterpoints.clear()
        self.figure.canvas.draw()

    def getGraphValues(self):
        """Return the ``(inputPoints, outputValues)`` lists (not copies)."""
        return self.inputPoints, self.outputValues

    def updateHeatMap(self, points):
        """Replace any existing heat map with one built from ``points``.

        Args:
            points: 2-D array-like whose columns 0, 1 and 2 are read as
                x, y and the value to interpolate.
        """
        self._remove_heatmap()
        self._draw_heatmap(points)
        self.figure.canvas.draw()

    def toggleHeatMap(self, points):
        """Hide the heat map if one is shown, otherwise draw it.

        Args:
            points: 2-D array-like whose columns 0, 1 and 2 are read as
                x, y and the value to interpolate. Only used when a heat
                map has to be drawn.
        """
        if not self._remove_heatmap():
            self._draw_heatmap(points)
        self.figure.canvas.draw()

    def _add_point(self, x, y, label):
        """Record one sample and plot it in the colour of its class."""
        self.inputPoints.append([x, y])
        self.outputValues.append([label])
        self.scatterpoints.append(
            self.axes.scatter(x, y, color=self._CLASS_COLOURS[label])
        )

    def _remove_heatmap(self):
        """Remove the current heat map; return True if there was one."""
        if self.heatmap is None:
            return False
        remove = getattr(self.heatmap, 'remove', None)
        if callable(remove):
            # Newer Matplotlib: the contour set is itself a removable artist
            # and its ``collections`` attribute is deprecated/removed.
            remove()
        else:
            # Older Matplotlib: remove the per-level collections one by one.
            for element in self.heatmap.collections:
                element.remove()
        self.heatmap = None
        return True

    def _draw_heatmap(self, points):
        """Interpolate ``points`` onto a regular grid and draw it as contours."""
        samples = np.asarray(points)
        x = samples[:, 0]
        y = samples[:, 1]
        z = samples[:, 2]
        # An imaginary step makes np.mgrid treat it as a sample count
        # (inclusive of both end points) instead of a step size.
        steps = complex(0, self._HEATMAP_RESOLUTION)
        X, Y = np.mgrid[x.min():x.max():steps, y.min():y.max():steps]
        Z = griddata(np.column_stack((x, y)), z, (X, Y), method='linear')
        # The colormap is passed by name; matplotlib.cm.get_cmap is no longer
        # available in recent Matplotlib releases.
        self.heatmap = self.axes.contourf(
            X, Y, Z,
            vmin=0, vmax=1,
            cmap='seismic_r',
            alpha=0.65,
            zorder=0,
        )
# Matplotlib widget
class Widget2dGraph(QtWidgets.QWidget):
    """QWidget wrapping a single MplCanvas2dGraph in a vertical box layout.

    Attributes:
        canvas: the embedded MplCanvas2dGraph instance.
        vbl: the QVBoxLayout that holds the canvas.
    """

    def __init__(self, parent=None):
        super().__init__(parent)

        # Build the canvas and the layout that will own it.
        graph_canvas = MplCanvas2dGraph()
        box_layout = QtWidgets.QVBoxLayout()
        box_layout.addWidget(graph_canvas)

        # Expose both under their public attribute names.
        self.canvas = graph_canvas
        self.vbl = box_layout
        self.setLayout(box_layout)