diff --git a/Algorithm/OCR/digits/LeNet.py b/Algorithm/OCR/digits/LeNet.py
new file mode 100644
index 0000000..b962abe
--- /dev/null
+++ b/Algorithm/OCR/digits/LeNet.py
@@ -0,0 +1,133 @@
+import torch.nn as nn
+from torch.nn import init
+
+class myNet(nn.Module):
+    def __init__(self):
+        super(myNet, self).__init__()
+        self.conv1_1 = nn.Sequential(  # input_size=(1*28*28)
+            nn.Conv2d(
+                in_channels=1,
+                out_channels=16,
+                kernel_size=3,
+                padding=1
+            ),  # padding=1 keeps the 28x28 size unchanged for a 3x3 kernel
+        )
+        self.BN = nn.BatchNorm2d(1, momentum=0.5)  # unused; see the commented-out lines in forward
+        self.conv1_2 = nn.Sequential(  # input_size=(16*28*28)
+            nn.Conv2d(
+                in_channels=16,
+                out_channels=16,
+                kernel_size=3,
+                padding=1
+            ),  # padding=1 keeps the size unchanged
+            nn.ReLU(),  # input_size=(16*28*28)
+            nn.MaxPool2d(kernel_size=2, stride=2)  # output_size=(16*14*14)
+        )
+        self.conv2_1 = nn.Sequential(
+            nn.Conv2d(
+                in_channels=16,
+                out_channels=32,
+                kernel_size=3),
+        )  # output 32*12*12
+        self.conv2_2 = nn.Sequential(
+            nn.Conv2d(
+                in_channels=32,
+                out_channels=32,
+                kernel_size=3),
+            nn.ReLU(),  # input_size=(32*10*10)
+            nn.MaxPool2d(2, 2)  # output_size=(32*5*5)
+        )
+        self.fc1 = nn.Linear(32 * 5 * 5, 128)
+        self.relu1 = nn.ReLU()
+        self.dropout = nn.Dropout(0.2)
+        self._set_init(self.fc1)
+        self.fc2 = nn.Linear(128, 11)
+        self._set_init(self.fc2)
+        self.softmax = nn.LogSoftmax(dim=1)
+
+    def _set_init(self, layer):  # weight initialization
+        init.normal_(layer.weight, mean=0., std=.1)
+
+    # forward pass; x is the input batch
+    def forward(self, x):
+        # x = x.view(-1, 28, 28)
+        # x = self.BN(x)
+        # x = x.view(-1, 1, 28, 28)
+        x = self.conv1_1(x)
+        x = self.conv1_2(x)
+        x = self.conv2_1(x)
+        x = self.conv2_2(x)
+        x = x.view(x.size()[0], -1)
+        x = self.fc1(x)
+        x = self.relu1(x)
+        x = self.dropout(x)
+        x = self.fc2(x)
+        return self.softmax(x)
+
+class rgbNet(nn.Module):
+    def __init__(self, imtype):
+        super(rgbNet, self).__init__()
+        if imtype == 'bit':
+            n = 1
+        elif imtype == 'rgb':
+            n = 3
+        else:
+            raise ValueError("imtype must be 'bit' or 'rgb'")
+
+        self.conv1_1 = nn.Sequential(  # input_size=(n*28*28)
+            nn.Conv2d(
+                in_channels=n,
+                out_channels=16,
+                kernel_size=3,
+                padding=1
+            ),  # padding=1 keeps the 28x28 size unchanged for a 3x3 kernel
+        )
+        self.BN = nn.BatchNorm2d(1, momentum=0.5)  # unused; see the commented-out lines in forward
+        self.conv1_2 = nn.Sequential(  # input_size=(16*28*28)
+            nn.Conv2d(
+                in_channels=16,
+                out_channels=16,
+                kernel_size=3,
+                padding=1
+            ),  # padding=1 keeps the size unchanged
+            nn.ReLU(),  # input_size=(16*28*28)
+            nn.MaxPool2d(kernel_size=2, stride=2)  # output_size=(16*14*14)
+        )
+        self.conv2_1 = nn.Sequential(
+            nn.Conv2d(
+                in_channels=16,
+                out_channels=32,
+                kernel_size=3),
+        )  # output 32*12*12
+        self.conv2_2 = nn.Sequential(
+            nn.Conv2d(
+                in_channels=32,
+                out_channels=32,
+                kernel_size=3),
+            nn.ReLU(),  # input_size=(32*10*10)
+            nn.MaxPool2d(2, 2)  # output_size=(32*5*5)
+        )
+        self.fc1 = nn.Linear(32 * 5 * 5, 128)
+        self.relu1 = nn.ReLU()
+        self.dropout = nn.Dropout(0.2)
+        self._set_init(self.fc1)
+        self.fc2 = nn.Linear(128, 11)
+        self._set_init(self.fc2)
+        self.softmax = nn.LogSoftmax(dim=1)
+
+    def _set_init(self, layer):  # weight initialization
+        init.normal_(layer.weight, mean=0., std=.1)
+
+    # forward pass; x is the input batch
+    def forward(self, x):
+        # x = x.view(-1, 28, 28)
+        # x = self.BN(x)
+        # x = x.view(-1, 1, 28, 28)
+        x = self.conv1_1(x)
+        x = self.conv1_2(x)
+        x = self.conv2_1(x)
+        x = self.conv2_2(x)
+        x = x.view(x.size()[0], -1)
+        x = self.fc1(x)
+        x = self.relu1(x)
+        x = self.dropout(x)
+        x = self.fc2(x)
+        return self.softmax(x)
\ No newline at end of file
diff --git a/Algorithm/OCR/newNet/__init__.py b/Algorithm/OCR/digits/__init__.py
similarity index 100%
rename from Algorithm/OCR/newNet/__init__.py
rename to Algorithm/OCR/digits/__init__.py
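The `32 * 5 * 5` input size of `fc1` follows from the layer chain: 28 → 28 (conv1_1) → 14 (pool) → 12 → 10 (unpadded 3x3 convs) → 5 (pool). A minimal sketch to verify the flattened feature size, assuming it is run from the repo root with PyTorch installed (not part of the patch):

```python
import torch
from Algorithm.OCR.digits.LeNet import myNet

net = myNet()
x = torch.randn(1, 1, 28, 28)  # one fake 28x28 grayscale digit
feats = net.conv2_2(net.conv2_1(net.conv1_2(net.conv1_1(x))))
print(feats.shape)   # expected: torch.Size([1, 32, 5, 5])
print(net(x).shape)  # expected: torch.Size([1, 11]) log-probabilities over 11 classes
```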
diff --git a/Algorithm/OCR/digits/dataLoader.py b/Algorithm/OCR/digits/dataLoader.py
new file mode 100644
index 0000000..3a566b4
--- /dev/null
+++ b/Algorithm/OCR/digits/dataLoader.py
@@ -0,0 +1,128 @@
+import numpy as np
+import torch
+import os
+import cv2
+import pickle
+import random
+
+
+class dataLoader():
+    # TODO: reshuffle between epochs so the batch order differs every epoch
+    def __init__(self, imtype, path, bs, ifUpdate):
+        self.meanPixel = 80
+        self.bs = bs
+        self.pointer = 0
+        self.imtype = imtype
+        if imtype == 'rgb':
+            self.train_path = os.path.join(path, "rgb_augment_train.pkl")
+            self.test_path = os.path.join(path, "rgb_test.pkl")
+        elif imtype == 'bit':
+            self.train_path = os.path.join(path, "bit_augment_train.pkl")
+            self.test_path = os.path.join(path, "bit_test.pkl")
+
+        if not os.path.exists(self.train_path) or not os.path.exists(self.test_path):
+            ifUpdate = True
+
+        if ifUpdate:
+            if os.path.exists(self.train_path):
+                os.remove(self.train_path)
+            if os.path.exists(self.test_path):
+                os.remove(self.test_path)
+            self.readImagesFromMultiFiles(path)
+
+        self.readDataFromPkl()
+        self.shuffle()
+
+    def readImagesFromMultiFiles(self, path):
+        # both loaders read from the rgb source folders;
+        # the 'bit' branch binarizes each image before storing it
+        for t in ["rgb_augmentation", "rgb_test"]:
+            if self.imtype == 'bit':
+                data = torch.Tensor(np.zeros((1, 1, 28, 28)))
+            elif self.imtype == 'rgb':
+                data = torch.Tensor(np.zeros((1, 3, 28, 28)))
+            label = []
+            names = []
+
+            for i in range(11):
+                root = path + "/" + t + "/" + str(i) + "/"
+                images = os.listdir(root)
+                for im in images:
+                    if im.split(".")[-1] != "bmp":
+                        continue
+                    names.append(root + im)
+                    if self.imtype == "bit":
+                        img = cv2.imread(root + im)[:, :, 0]
+                        img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_CUBIC)
+                        img = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, 55, 11)
+                        # enhancement: opening with a thin kernel removes speckle noise
+                        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 2))
+                        img = cv2.morphologyEx(img, cv2.MORPH_OPEN, kernel)
+                        # _, img = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
+                        temp = torch.Tensor(img).view(1, 1, 28, 28) - self.meanPixel
+                    elif self.imtype == 'rgb':
+                        img = cv2.imread(root + im)
+                        img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_CUBIC)
+                        if len(img.shape) == 2:
+                            img = np.stack((img,) * 3, axis=-1)  # gray -> 3 channels
+                            print("convert to rgb: ", img.shape)
+                        # HWC -> CHW before building the (1, 3, 28, 28) tensor
+                        temp = torch.Tensor(img).permute(2, 0, 1).unsqueeze(0) - self.meanPixel
+                    data = torch.cat((data, temp), 0)
+                    label.append(i)
+            if t.endswith("test"):
+                with open(self.test_path, "wb") as fp:
+                    pickle.dump([data[1:], torch.Tensor(np.array(label)).long(), names], fp)
+            else:
+                with open(self.train_path, "wb") as fp:
+                    pickle.dump([data[1:], torch.Tensor(np.array(label)).long()], fp)
+
+    def readDataFromPkl(self):
+        with open(self.train_path, "rb") as fp:
+            self.trainData, self.trainLabel = pickle.load(fp)
+        with open(self.test_path, "rb") as fp:
+            self.testData, self.testLabel, self.names = pickle.load(fp)
+
+    def getTrainData(self):
+        return self.trainData, self.trainLabel
+
+    def getTestData(self):
+        return self.testData, self.testLabel, self.names
+
+    def shuffle(self):
+        li = list(range(self.trainData.shape[0]))
+        random.shuffle(li)
+        self.trainData = self.trainData[li]
+        self.trainLabel = self.trainLabel[li]
+
+    def next_batch(self):
+        if self.pointer * self.bs == self.trainData.shape[0]:
+            self.pointer = 0
+
+        if (self.pointer + 1) * self.bs > self.trainData.shape[0]:
+            temp = self.pointer
+            self.pointer = 0
+            return self.trainData[temp * self.bs:], \
+                   self.trainLabel[temp * self.bs:]
+
+        temp = self.pointer
+        self.pointer += 1
+
+        return self.trainData[temp * self.bs:self.pointer * self.bs], \
+               self.trainLabel[temp * self.bs:self.pointer * self.bs]
+
+    def get_rounds(self):
+        # ceil, so a dataset size divisible by bs does not produce a duplicate round
+        return int(np.ceil(self.trainData.shape[0] / self.bs))
+
+
+# if __name__ == "__main__":
+#     dl = dataLoader("rgb", "dataset/", 64, True)
+#     dl.shuffle()
+#     train, trl = dl.getTrainData()
+#     test, tel, names = dl.getTestData()
+#
+#     print(train.shape, trl)
+#     print(test.shape, tel)
+#
+#     print(dl.trainLabel)
+#     dl.shuffle()
+#     print(dl.trainLabel)
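For orientation, a hedged sketch of how the loader is meant to be driven (assumes the `dataset/rgb_augmentation/<0..10>/*.bmp` and `dataset/rgb_test/<0..10>/*.bmp` layout described above, and that the pickles already exist; not part of the patch):

```python
from dataLoader import dataLoader

data = dataLoader("rgb", "dataset/", bs=256, ifUpdate=False)
for step in range(data.get_rounds()):
    inputs, labels = data.next_batch()  # the last batch may be smaller than bs
    print(step, inputs.shape, labels.shape)
```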
diff --git a/Algorithm/OCR/digits/data_augmentation.py b/Algorithm/OCR/digits/data_augmentation.py
new file mode 100644
index 0000000..b68b23f
--- /dev/null
+++ b/Algorithm/OCR/digits/data_augmentation.py
@@ -0,0 +1,132 @@
+import os
+import cv2
+import numpy as np
+import random
+from PIL import Image, ImageEnhance
+
+'''
+HSV transform:
+hue_delta is an additive hue shift,
+sat_mult is a saturation multiplier,
+val_mult is a value (brightness) multiplier
+'''
+def hsv_transform(img, hue_delta, sat_mult, val_mult):
+    img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float32)
+    img_hsv[:, :, 0] = (img_hsv[:, :, 0] + hue_delta) % 180
+    img_hsv[:, :, 1] *= sat_mult
+    img_hsv[:, :, 2] *= val_mult
+    img_hsv[img_hsv > 255] = 255
+    return cv2.cvtColor(np.round(img_hsv).astype(np.uint8), cv2.COLOR_HSV2BGR)
+
+'''
+Random HSV transform:
+hue_vari bounds the hue shift,
+sat_vari bounds the saturation multiplier,
+val_vari bounds the value multiplier
+'''
+def random_hsv_transform(img, hue_vari=10, sat_vari=0.1, val_vari=0.1):
+    hue_delta = np.random.randint(-hue_vari, hue_vari)
+    sat_mult = 1 + np.random.uniform(-sat_vari, sat_vari)
+    val_mult = 1 + np.random.uniform(-val_vari, val_vari)
+    return hsv_transform(img, hue_delta, sat_mult, val_mult)
+
+'''
+Gamma transform: gamma is the exponent of the power-law mapping
+'''
+def gamma_transform(img, gamma=1.0):
+    gamma_table = [np.power(x / 255.0, gamma) * 255.0 for x in range(256)]
+    gamma_table = np.round(np.array(gamma_table)).astype(np.uint8)
+    return cv2.LUT(img, gamma_table)
+
+'''
+Random gamma transform:
+gamma is drawn from [1/gamma_vari, gamma_vari)
+'''
+def random_gamma_transform(img, gamma_vari=2.0):
+    log_gamma_vari = np.log(gamma_vari)
+    alpha = np.random.uniform(-log_gamma_vari, log_gamma_vari)
+    gamma = np.exp(alpha)
+    return gamma_transform(img, gamma)
+
+def randomGaussian(image, mean=0.2, sigma=0.3):
+    """
+    Add Gaussian noise to an image.
+    :param image: input image (BGR ndarray)
+    :return: noisy image
+    """
+    def gaussianNoisy(im, mean=0.2, sigma=0.3):
+        """
+        Add Gaussian noise to one flattened channel.
+        :param im: single-channel image
+        :param mean: noise offset
+        :param sigma: noise standard deviation
+        :return: noisy channel
+        """
+        for _i in range(len(im)):
+            im[_i] += random.gauss(mean, sigma)
+        return im
+
+    # convert the image to a writable array
+    img = np.asarray(image)
+    img.flags.writeable = True
+    width, height = img.shape[:2]
+    img_r = gaussianNoisy(img[:, :, 0].flatten(), mean, sigma)
+    img_g = gaussianNoisy(img[:, :, 1].flatten(), mean, sigma)
+    img_b = gaussianNoisy(img[:, :, 2].flatten(), mean, sigma)
+    img[:, :, 0] = img_r.reshape([width, height])
+    img[:, :, 1] = img_g.reshape([width, height])
+    img[:, :, 2] = img_b.reshape([width, height])
+    return np.uint8(img)
+
+def randomColor(image):
+    """
+    Apply random color jitter to an image.
+    :param image: image convertible to a PIL Image
+    :return: color-jittered image
+    """
+    image = Image.fromarray(image)
+    random_factor = np.random.randint(0, 31) / 10.  # random factor
+    color_image = ImageEnhance.Color(image).enhance(random_factor)  # adjust saturation
+    random_factor = np.random.randint(10, 21) / 10.  # random factor
+    brightness_image = ImageEnhance.Brightness(color_image).enhance(random_factor)  # adjust brightness
+    random_factor = np.random.randint(10, 21) / 10.  # random factor
+    contrast_image = ImageEnhance.Contrast(brightness_image).enhance(random_factor)  # adjust contrast
+    random_factor = np.random.randint(0, 31) / 10.  # random factor
+    img = ImageEnhance.Sharpness(contrast_image).enhance(random_factor)  # adjust sharpness
+    return np.array(img)
+
+
+def augmentation(origin, dest):
+    for sub in os.listdir(origin):
+        subpath = os.path.join(origin, sub)
+        destpath = os.path.join(dest, sub)
+        if not os.path.exists(destpath):
+            os.makedirs(destpath)
+        for file in os.listdir(subpath):
+            filename = os.path.join(subpath, file)
+            img = cv2.imread(filename)
+            cv2.imwrite(os.path.join(destpath, file[:-4] + "_origin.bmp"), img)
+            # random HSV transform
+            img_hsv = random_hsv_transform(img.copy())
+            destname = os.path.join(destpath, file[:-4] + "_hsv.bmp")
+            cv2.imwrite(destname, img_hsv)
+            # random gamma transform
+            img_gamma = random_gamma_transform(img.copy())
+            destname = os.path.join(destpath, file[:-4] + "_gamma.bmp")
+            cv2.imwrite(destname, img_gamma)
+            # random color jitter
+            img_color = randomColor(img.copy())
+            destname = os.path.join(destpath, file[:-4] + "_color.bmp")
+            cv2.imwrite(destname, img_color)
+            # Gaussian noise
+            img_gaussian = randomGaussian(img.copy())
+            destname = os.path.join(destpath, file[:-4] + "_gaussian.bmp")
+            cv2.imwrite(destname, img_gaussian)
+
+
+if __name__ == "__main__":
+    origin = "dataset/rgb_train"
+    dest = "dataset/rgb_augmentation"
+    augmentation(origin, dest)
\ No newline at end of file
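A hedged usage sketch of the individual transforms on a synthetic image (no dataset needed; assumes this module is importable as `data_augmentation`):

```python
import numpy as np
from data_augmentation import random_hsv_transform, random_gamma_transform, randomColor

img = np.full((28, 28, 3), 128, dtype=np.uint8)  # flat gray dummy image
out = random_gamma_transform(random_hsv_transform(img))
out = randomColor(out)
print(out.shape, out.dtype)  # (28, 28, 3) uint8
```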
diff --git a/Algorithm/OCR/newNet/net.pkl b/Algorithm/OCR/digits/model/net.pkl
similarity index 100%
rename from Algorithm/OCR/newNet/net.pkl
rename to Algorithm/OCR/digits/model/net.pkl
diff --git a/Algorithm/OCR/digits/model/rgb_myNet_net.pkl b/Algorithm/OCR/digits/model/rgb_myNet_net.pkl
new file mode 100644
index 0000000..72c8283
Binary files /dev/null and b/Algorithm/OCR/digits/model/rgb_myNet_net.pkl differ
diff --git a/Algorithm/OCR/digits/train.py b/Algorithm/OCR/digits/train.py
new file mode 100644
index 0000000..c506a91
--- /dev/null
+++ b/Algorithm/OCR/digits/train.py
@@ -0,0 +1,120 @@
+import sys
+import os
+import cv2
+import numpy as np
+import torch
+import torch.optim as optim
+import torch.nn.init as init
+import matplotlib.pyplot as plt
+# from tensorboardX import SummaryWriter
+
+from dataLoader import dataLoader
+from LeNet import rgbNet
+
+def train(imtype):
+
+    torch.manual_seed(10)
+
+    bs = 256
+    lr = 0.001
+    epoch = 40
+    stepLength = 20
+    classes = 11
+
+    data = dataLoader(imtype, "dataset/", bs, ifUpdate=True)
+    testInputs, testLabels, _ = data.getTestData()
+    testInputs = testInputs / 255  # match the training-time scaling below
+    print("dataset loaded")
+
+    # loss weights proportional to per-class sample counts
+    weight = np.zeros(classes)
+    for i in range(classes):
+        images = os.listdir("dataset/rgb_augmentation/" + str(i))
+        weight[i] = len(images)
+    weight = weight / np.sum(weight)
+    print(weight)
+
+    def adjust_learning_rate(optimizer, epoch, t=10):
+        """Set the learning rate to the initial LR decayed by 10 every t epochs (default t=10)."""
+        new_lr = lr * (0.1 ** (epoch // t))
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = new_lr
+
+    net = rgbNet(imtype)
+    net.train()
+
+    criterion = torch.nn.NLLLoss(weight=torch.Tensor(weight))
+    optimizer = optim.Adam(net.parameters(), weight_decay=1e-5, lr=lr)
+
+    steps = data.get_rounds()
+    train_loss = []
+
+    for e in range(epoch):
+        sum_loss = 0.0
+        for step in range(steps):
+
+            inputs, labels = data.next_batch()
+            inputs = inputs / 255
+            optimizer.zero_grad()
+
+            # forward + backward
+            outputs = net(inputs)
+            loss = criterion(outputs, labels.long())
+            loss.backward()
+            optimizer.step()
+
+            sum_loss += loss.item()
+            if step % stepLength == stepLength - 1:
+                print('epoch %d /step %d: loss:%.03f'
+                      % (e + 1, step + 1, sum_loss / stepLength))
+                # sum_loss = 0.0
+        train_loss.append(sum_loss)
+        adjust_learning_rate(optimizer, e)
+
+        outputs = net(testInputs)
+        _, predicted = torch.max(outputs.data, 1)
+        correct = (predicted == testLabels.long()).sum().item()
+        print('epoch %d: test accuracy %d%%' % (e + 1, (100 * correct / testLabels.shape[0])))
+        torch.save(net.state_dict(), imtype + "_" + str(net.__class__.__name__) + "_net.pkl")
+    plt.figure(0)
+    x = [i for i in range(len(train_loss))]
+    plt.plot(x, train_loss)
+    plt.savefig("train_loss.jpg")
+
+def test(imtype):
+    import shutil
+    result = imtype + "result/"
+    if os.path.exists(result):
+        shutil.rmtree(result)
+    if not os.path.exists(result):
+        os.makedirs(result)
+    data = dataLoader(imtype, "dataset/", bs=256, ifUpdate=False)
+    testInputs, testLabels, names = data.getTestData()
+    testInputs = testInputs / 255  # same scaling as during training
+    net = rgbNet(imtype)
+    modelname = imtype + "_" + str(net.__class__.__name__) + "_net.pkl"
+    net.load_state_dict(torch.load(modelname))
+    net.eval()
+    print("model loaded")
+    outputs = net(testInputs)
+    _, predicted = torch.max(outputs.data, 1)
+    correct = (predicted == testLabels.long()).sum().item()
+    print('test accuracy: %d%%' % (100 * correct / testLabels.shape[0]))
+
+    # dump result pictures, one per test image, named by predicted digit
+    for i in range(testInputs.shape[0]):
+        testimg = cv2.imread(names[i])
+        res = net(testInputs[i].view(1, 3, 28, 28))
+        _, predicted = torch.max(res.data, 1)
+        # cv2.imshow("test", test)
+        name = os.path.join(result, str(i) + "__" + str(predicted.numpy()) + ".bmp")
+        cv2.imwrite(name, testimg)
+        # print(predicted)
+        # cv2.waitKey(0)
+    print("done!")
+
+if __name__ == "__main__":
+    train('rgb')
+    test('rgb')
diff --git a/Algorithm/OCR/newNet/LeNet.py b/Algorithm/OCR/newNet/LeNet.py
deleted file mode 100644
index c7ff0e7..0000000
--- a/Algorithm/OCR/newNet/LeNet.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import torch.nn as nn
-from torch.nn import init
-
-class myNet(nn.Module):
-    def __init__(self):
-        super(myNet, self).__init__()
-        self.conv1_1 = nn.Sequential(  # input_size=(1*28*28)
-            nn.Conv2d(
-                in_channels=1,
-                out_channels=16,
-                kernel_size=3,
-                padding=1
-            ),  # padding=2保证输入输出尺寸相同
-        )
-        self.BN = nn.BatchNorm2d(1, momentum=0.5)
-        self.conv1_2 = nn.Sequential(  # input_size=(16*28*28)
-            nn.Conv2d(
-                in_channels=16,
-                out_channels=16,
-                kernel_size=3,
-                padding=1
-            ),  # padding=2保证输入输出尺寸相同
-            nn.ReLU(),  # input_size=(16*28*28)
-            nn.MaxPool2d(kernel_size=2, stride=2)  # output_size=(16*14*14)
-        )
-        self.conv2_1 = nn.Sequential(
-            nn.Conv2d(
-                in_channels=16,
-                out_channels=32,
-                kernel_size=3),
-        )  # output 32*12*12
-        self.conv2_2 = nn.Sequential(
-            nn.Conv2d(
-                in_channels=32,
-                out_channels=32,
-                kernel_size=3),
-            nn.ReLU(),  # input_size=(32*10*10)
-            nn.MaxPool2d(2, 2)  # output_size=(8*5*5)
-        )
-        self.fc1 = nn.Linear(32 * 5 * 5, 128)
-        self.relu1 = nn.ReLU()
-        self.dropout = nn.Dropout(0.2)
-        self._set_init(self.fc1)
-        self.fc2 = nn.Linear(128, 11)
-        self._set_init(self.fc2)
-        self.softmax = nn.LogSoftmax(dim=1)
-
-    def _set_init(self, layer):  # 参数初始化
-        init.normal_(layer.weight, mean=0., std=.1)
-
-    # 定义前向传播过程,输入为x
-    def forward(self, x):
-        # x = x.view(-1, 28, 28)
-        # x = self.BN(x)
-        # x = x.view(-1, 1, 28, 28)
-        x = self.conv1_1(x)
-        x = self.conv1_2(x)
-        x = self.conv2_1(x)
-        x = self.conv2_2(x)
-        x = x.view(x.size()[0], -1)
-        x = self.fc1(x)
-        x = self.relu1(x)
-        x = self.dropout(x)
-        x = self.fc2(x)
-        return self.softmax(x)
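One thing worth flagging in the new train.py above: the NLLLoss weights are proportional to per-class sample counts, so over-represented digits get a *larger* loss weight. If the intent is to compensate for class imbalance, inverse frequency is the usual choice — a hedged sketch with hypothetical counts, not the patch's behavior:

```python
import numpy as np

counts = np.array([120, 80, 80, 75, 90, 85, 70, 95, 100, 60, 30])  # hypothetical per-class counts
freq = counts / counts.sum()
inv = 1.0 / freq
weights = inv / inv.sum()  # rarer classes get larger loss weight
print(weights.round(3))
```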
diff --git a/Algorithm/OCR/newNet/dataLoader.py b/Algorithm/OCR/newNet/dataLoader.py
deleted file mode 100644
index e3406af..0000000
--- a/Algorithm/OCR/newNet/dataLoader.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import numpy as np
-import torch
-import os
-import cv2
-import pickle
-import random
-
-
-class dataLoader():
-    # todo 调整batch,使每个batch顺序都不一样
-    def __init__(self, path, bs, ifUpdate):
-        self.bs = bs
-        self.pointer = 0
-
-        if not os.path.exists(path + "/train.pkl") or not os.path.exists(path + "/test.pkl"):
-            ifUpdate = True
-
-        if ifUpdate:
-            os.system("rm -rf {}".format(path + "/train.pkl"))
-            os.system("rm -rf {}".format(path + "/test.pkl"))
-            self.readImagesFromMultiFils(path)
-
-        self.readDataFromPkl(path)
-        self.shuffle()
-
-    def readImagesFromMultiFils(self, path):
-        for t in ["train", "test"]:
-            data = torch.Tensor(np.zeros((1, 1, 28, 28)))
-            label = []
-
-            for i in range(11):
-                root = path + "/" + t + "/" + str(i) + "/"
-                images = os.listdir(root)
-                for im in images:
-                    if im.split(".")[-1] != "bmp":
-                        continue
-                    img = cv2.imread(root + im)[:, :, 0]
-                    _, thresh = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
-                    temp = torch.Tensor(thresh).view(1, 1, 28, 28)
-                    data = torch.cat((data, temp), 0)
-
-                    label.append(i)
-
-            fp = open(path + "/" + t + ".pkl", "wb")
-            pickle.dump([data[1:], torch.Tensor(np.array(label)).long()], fp)
-            fp.close()
-
-    def readDataFromPkl(self, path):
-        with open(path+"/train.pkl", "rb") as fp:
-            self.trainData, self.trainLabel = pickle.load(fp)
-        with open(path+"/test.pkl", "rb") as fp:
-            self.testData, self.testLabel = pickle.load(fp)
-
-    def getTrainData(self):
-        return self.trainData, self.trainLabel
-
-    def getTestData(self):
-        return self.testData, self.testLabel
-
-    def shuffle(self):
-        li = list(range(self.trainData.shape[0]))
-        random.shuffle(li)
-        self.trainData = self.trainData[li]
-        self.trainLabel = self.trainLabel[li]
-
-    def next_batch(self):
-        if self.pointer * self.bs == self.trainData.shape[0]:
-            self.pointer = 0
-
-        if (self.pointer + 1) * self.bs > self.trainData.shape[0]:
-            temp = self.pointer
-            self.pointer = 0
-            return self.trainData[temp * self.bs:], \
-                   self.trainLabel[temp * self.bs:]
-
-        temp = self.pointer
-        self.pointer += 1
-
-        return self.trainData[temp * self.bs:self.pointer * self.bs], \
-               self.trainLabel[temp * self.bs:self.pointer * self.bs]
-
-    def get_rounds(self):
-        return int(self.trainData.shape[0] / self.bs) + 1
-
-
-if __name__ == "__main__":
-    dl = dataLoader("images/LCD_enhanced", 64, True)
-    dl.shuffle()
-    train, trl = dl.getTrainData()
-    test, tel = dl.getTestData()
-    #
-    print(train.shape, trl)
-    print(test.shape, tel)
-    #
-
-    print(dl.trainLabel)
-    dl.shuffle()
-    print(dl.trainLabel)
diff --git a/Algorithm/OCR/newNet/train.py b/Algorithm/OCR/newNet/train.py
deleted file mode 100644
index 0b79be0..0000000
--- a/Algorithm/OCR/newNet/train.py
+++ /dev/null
@@ -1,85 +0,0 @@
-from Algorithm.OCR.newNet.dataLoader import *
-from Algorithm.OCR.newNet.LeNet import *
-
-import sys
-sys.path.append("../LeNet")
-
-import numpy as np
-import torch
-import torch.optim as optim
-import torch.nn.init as init
-
-torch.manual_seed(10)
-
-bs = 128
-lr = 0.001
-epoch = 40
-stepLength = 20
-classes = 11
-
-training = "LCD_enhanced"
-data = dataLoader("images/"+training, bs, ifUpdate=True)
-testInputs, testLabels = data.getTestData()
-print("数据集加载完毕")
-
-
-weight = np.zeros(classes)
-
-for i in range(classes):
-    images = os.listdir("images/LCD_enhanced/train/"+str(i))
-    weight[i] = len(images)
-weight = weight/np.sum(weight)
-print(weight)
-
-
-def adjust_learning_rate(optimizer, epoch, t=10):
-    """Sets the learning rate to the initial LR decayed by 10 every t epochs,default=10"""
-    new_lr = lr * (0.1 ** (epoch // t))
-    for param_group in optimizer.param_groups:
-        param_group['lr'] = new_lr
-
-net = myNet()
-net.train()
-
-criterion = nn.NLLLoss(weight=torch.Tensor(weight))
-optimizer = optim.Adam(net.parameters(), weight_decay=1e-5, lr=lr)
-
-steps = data.get_rounds()
-
-for e in range(epoch):
-    sum_loss = 0.0
-    for step in range(steps):
-
-        inputs, labels = data.next_batch()
-        inputs = inputs/255
-        optimizer.zero_grad()
-
-        # forward + backward
-        outputs = net.forward(inputs)
-        loss = criterion(outputs, labels.long())
-        loss.backward()
-        optimizer.step()
-
-        sum_loss += loss.item()
-        if step % stepLength == stepLength-1:
-            print('epoch %d /step %d: loss:%.03f'
-                  % (e + 1, step + 1, sum_loss / stepLength))
-            sum_loss = 0.0
-    adjust_learning_rate(optimizer, e)
-    outputs = net.forward(testInputs)
-    _, predicted = torch.max(outputs.data, 1)
-    correct = (predicted == testLabels.long()).sum()
-    print('第%d个epoch的识别准确率为:%d%%' % (e + 1, (100 * correct / testLabels.shape[0])))
-torch.save(net.state_dict(), "net.pkl")
-
-# for i in range(testInputs.shape[0]):
-#     test = np.array(testInputs[i].view(28, 28))
-#     res = net.forward(testInputs[i].view(1,1,28,28))
-#     cv2.imshow("test", test)
-#     _, predicted = torch.max(res.data, 1)
-#     print(predicted)
-#     cv2.waitKey(0)
diff --git a/Algorithm/OCR/utils.py b/Algorithm/OCR/utils.py
index ff08282..c795a8e 100644
--- a/Algorithm/OCR/utils.py
+++ b/Algorithm/OCR/utils.py
@@ -31,23 +31,36 @@ def __init__(self):
         :return:
         """
         sys.path.append("newNet")
-        from Algorithm.OCR.newNet.LeNet import myNet
+        from Algorithm.OCR.digits.LeNet import myNet, rgbNet
 
-        self.net = myNet()
-        self.net.eval()
-        self.net.load_state_dict(torch.load("Algorithm/OCR/newNet/net.pkl"))
+        self.bitNet = myNet()
+        self.bitNet.eval()
+        self.bitNet.load_state_dict(torch.load("Algorithm/OCR/digits/model/net.pkl"))
+        self.rgbNet = rgbNet('rgb')
+        self.rgbNet.eval()
+        self.rgbNet.load_state_dict(torch.load("Algorithm/OCR/digits/model/rgb_myNet_net.pkl"))
 
-    def recognizeNet(self, image):
+    def recognizeNet(self, image, imtype):
         """
-        LeNet识别图像中的数字
-        :param image: 输入图像
-        :return: 识别的数字值
+        Recognize the digit in an image with the LeNet-style network.
+        :param image: input image
+        :param imtype: 'bit' for binarized input, 'rgb' for color input
+        :return: the recognized digit
         """
-        image = fillAndResize(image)
-        tensor = torch.Tensor(image).view((1, 1, 28, 28))/255
+        if imtype == 'bit':
+            image = fillAndResize(image)
+            tensor = torch.Tensor(image).view((1, 1, 28, 28)) / 255
+            tensor = tensor.to("cpu")
+            result = self.bitNet(tensor)
+        elif imtype == 'rgb':
+            img = cv2.resize(image, (28, 28), interpolation=cv2.INTER_CUBIC)
+            if len(img.shape) == 2:
+                img = np.stack((img,) * 3, axis=-1)  # gray -> 3 channels
+            # match the training pipeline in digits/dataLoader.py:
+            # HWC -> CHW, subtract meanPixel (80), then scale by 255
+            tensor = (torch.Tensor(img).permute(2, 0, 1).unsqueeze(0) - 80) / 255
+            tensor = tensor.to("cpu")
+            result = self.rgbNet(tensor)
+        else:
+            raise ValueError("imtype must be 'bit' or 'rgb'")
 
-        tensor = tensor.to("cpu")
-        result = self.net.forward(tensor)
         _, predicted = torch.max(result.data, 1)
 
         num = int(np.array(predicted[0]).astype(np.uint32))
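A hedged usage sketch of the two recognizeNet paths (the crop path is hypothetical; the 'bit' branch expects a binarized crop, which fillAndResize pads and resizes to 28x28 internally):

```python
import cv2
from Algorithm.OCR.utils import newNet

net = newNet()
crop = cv2.imread("digit_crop.bmp")        # hypothetical digit crop from a meter image
rgb_digit = net.recognizeNet(crop, 'rgb')  # color path: resized to 28x28 internally
binary = cv2.adaptiveThreshold(cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY), 255,
                               cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, 55, 11)
bit_digit = net.recognizeNet(binary, 'bit')  # binarized path
print(rgb_digit, bit_digit)
```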
diff --git a/Algorithm/pressure/digitPressure.py b/Algorithm/pressure/digitPressure.py
index 9811d1e..b5e649a 100644
--- a/Algorithm/pressure/digitPressure.py
+++ b/Algorithm/pressure/digitPressure.py
@@ -12,6 +12,87 @@ def digitPressure(image, info):
     template = meterFinderBySIFT(image, info)
+
+    # create the folders for dumping intermediate crops
+    os.makedirs("storeDigitData/thresh", exist_ok=True)
+    os.makedirs("storeDigitData/rgb", exist_ok=True)
+    for i in range(11):
+        os.makedirs("storeDigitData/thresh/" + str(i), exist_ok=True)
+        os.makedirs("storeDigitData/rgb/" + str(i), exist_ok=True)
+
+    if 'rgb' in info and info['rgb']:  # rgb as input
+        myRes = rgbRecognize(template, info)
+    else:
+        myRes = bitRecognize(template, info)
+
+    if info["digitType"] == "KWH":
+        myRes[0] = myRes[0][:4] + myRes.pop(1)
+
+    # drop unrecognized characters at the head of a reading; elsewhere
+    # substitute a random digit, then parse the reading as a float
+    for i in range(len(myRes)):
+        temp = ""
+        for j, c in enumerate(myRes[i]):
+            if c != "?":
+                temp += c
+            elif j != 0:
+                temp += str(random.randint(0, 9))
+        myRes[i] = float(temp)
+
+    return myRes
+
+def rgbRecognize(template, info):
+    # rectify the LCD region from the calibration points
+    dst = boxRectifier(template, info)
+    # read the calibration info
+    widthSplit = info["widthSplit"]
+    heightSplit = info["heightSplit"]
+    # initialize the recognizer
+    MyNet = newNet()
+    myRes = []
+    imgNum = int((len(os.listdir("storeDigitData/")) - 1) / 3)
+    for i in range(len(heightSplit)):
+        split = widthSplit[i]
+        myNum = ""
+        for j in range(len(split) - 1):
+            if "decimal" in info.keys() and j == info["decimal"][i]:
+                myNum += "."
+                continue
+            # crop one digit cell
+            img = dst[heightSplit[i][0]:heightSplit[i][1], split[j]:split[j + 1]]
+            num = MyNet.recognizeNet(img, 'rgb')
+            myNum = myNum + num
+
+            # dump the crop for later inspection
+            cv2.imwrite("storeDigitData/rgb/{}/{}_{}{}_p{}.bmp".format(
+                num,
+                imgNum,
+                i,
+                j,
+                num
+            ), img)
+
+        myRes.append(myNum)
+    if ifShow:
+        cv2.imshow("rec", dst)
+        cv2.imshow("template", template)
+        print(myRes)
+        cv2.waitKey(0)
+        cv2.destroyAllWindows()
+    return myRes
+
+
+def bitRecognize(template, info):
     template = cv2.GaussianBlur(template, (3, 3), 0)
 
     # 读取标定信息
@@ -37,31 +118,16 @@ def digitPressure(image, info):
         thresh = cv2.adaptiveThreshold(Hist, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, 15, 11)
     else:
         thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, block, param)
-    if ifOpen=="close":
+    if ifOpen == "close":
         p = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
         res1 = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, p)
 
-    # 存储图片
-    if not os.path.exists("storeDigitData"):
-        os.mkdir("storeDigitData")
-
-    try:
-        os.mkdir("storeDigitData/thresh")
-        os.mkdir("storeDigitData/rgb")
-    except IOError:
-        pass
-
-    for i in range(11):
-        try:
-            os.mkdir("storeDigitData/thresh/"+str(i))
-            os.mkdir("storeDigitData/rgb/"+str(i))
-        except IOError:
-            continue
+    if os.path.exists("storeDigitData/"):
+        imgNum = int((len(os.listdir("storeDigitData/")) - 1) / 3)
+        cv2.imwrite("storeDigitData/" + str(imgNum) + "_dst.bmp", dst)
+        cv2.imwrite("storeDigitData/" + str(imgNum) + "_gray.bmp", gray)
+        cv2.imwrite("storeDigitData/" + str(imgNum) + "_thresh.bmp", thresh)
 
-    imgNum = int((len(os.listdir("storeDigitData/"))-1)/3)
-    cv2.imwrite("storeDigitData/" + str(imgNum) + "_dst.bmp", dst)
-    cv2.imwrite("storeDigitData/" + str(imgNum) + "_gray.bmp", gray)
-    cv2.imwrite("storeDigitData/" + str(imgNum) + "_thresh.bmp", thresh)
 
     # 网络初始化
     MyNet = newNet()
@@ -77,50 +143,26 @@ def digitPressure(image, info):
             # 得到分割的图片区域
             img = thresh[heightSplit[i][0]:heightSplit[i][1], split[j]:split[j + 1]]
             rgb_ = dst[heightSplit[i][0]:heightSplit[i][1], split[j]:split[j + 1]]
-            # 增强
-            # kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 2))
-            # img = cv2.morphologyEx(img, cv2.MORPH_OPEN, kernel)
-            num = MyNet.recognizeNet(img)
+            num = MyNet.recognizeNet(img, 'bit')
             myNum = myNum + num
 
             # 存储图片
-            # cv2.imwrite("storeDigitData/thresh/{}/{}_{}{}_p{}.bmp".format(
-            #     num,
-            #     imgNum,
-            #     i,
-            #     j,
-            #     num
-            # ), img)
-            #
-            # cv2.imwrite("storeDigitData/rgb/{}/{}_{}{}_p{}.bmp".format(
-            #     num,
-            #     imgNum,
-            #     i,
-            #     j,
-            #     num
-            # ), rgb_)
+            cv2.imwrite("storeDigitData/thresh/{}/{}_{}{}_p{}.bmp".format(
+                num,
+                imgNum,
+                i,
+                j,
+                num
+            ), img)
 
         myRes.append(myNum)
+    if ifShow:
+        cv2.imshow("rec", dst)
+        cv2.imshow("template", template)
+        print(myRes)
+        cv2.waitKey(0)
+        cv2.destroyAllWindows()
+    return myRes
 
-    if info["digitType"] == "KWH":
-        myRes[0] = myRes[0][:4]+myRes.pop(1)
-
-    # 去除头部的非数字字符,同时将非头部的字符转为数字
-    for i in range(len(myRes)):
-        temp = ""
-        for j, c in enumerate(myRes[i]):
-            if c != "?":
-                temp += c
-            elif j != 0:
-                temp += str(random.randint(0, 9))
-        myRes[i] = float(temp)
-
-    if ifShow:
-        cv2.imshow("rec", dst)
-        cv2.imshow("image", image)
-        print(myRes)
-        cv2.waitKey(0)
-        cv2.destroyAllWindows()
-
-    return myRes
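The split arrays drive the digit cropping: heightSplit[i] bounds row i, consecutive widthSplit[i] entries bound each digit cell, and decimal[i] marks where the decimal point sits. A hedged sketch of that slicing, using the PMW2800-style values (dst would be the rectified LCD image):

```python
# illustrative slicing, mirroring rgbRecognize's loop
widthSplit = [[8, 26, 44, 62, 80, 81, 98, 116]]
heightSplit = [[17, 48]]
decimal = [4]

row = 0
cells = []
for j in range(len(widthSplit[row]) - 1):
    if j == decimal[row]:
        cells.append(".")  # decimal point position: no crop, just a "."
        continue
    x0, x1 = widthSplit[row][j], widthSplit[row][j + 1]
    y0, y1 = heightSplit[row]
    cells.append((y0, y1, x0, x1))  # dst[y0:y1, x0:x1] would be one digit cell
print(cells)
```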
diff --git a/Algorithm/utils/Finder.py b/Algorithm/utils/Finder.py
index 20fe5cd..ee087d7 100644
--- a/Algorithm/utils/Finder.py
+++ b/Algorithm/utils/Finder.py
@@ -246,3 +246,145 @@ def isOverflow(point, width, height):
     return src_correct
 
 
+def meterRegionAndLocationBySIFT(image, info):
+    """
+    Locate the meter's bounding box.
+    :param image: input image
+    :param info: calibration info (updated in place with the relocated points)
+    :return: the rectified bbox image, the rotation matrix, and the bbox
+    """
+    template = info["template"]
+
+    # cv2.imshow("template", template)
+    # cv2.imshow("image", image)
+    # cv2.waitKey(0)
+
+    startPoint = (info["startPoint"]["x"], info["startPoint"]["y"])
+    centerPoint = (info["centerPoint"]["x"], info["centerPoint"]["y"])
+    endPoint = (info["endPoint"]["x"], info["endPoint"]["y"])
+    # startPointUp = (info["startPointUp"]["x"], info["startPointUp"]["y"])
+    # endPointUp = (info["endPointUp"]["x"], info["endPointUp"]["y"])
+    # centerPointUp = (info["centerPointUp"]["x"], info["centerPointUp"]["y"])
+
+    templateBlurred = cv2.GaussianBlur(template, (3, 3), 0)
+    imageBlurred = cv2.GaussianBlur(image, (3, 3), 0)
+
+    sift = cv2.xfeatures2d.SIFT_create()
+
+    # the descriptor has shape n * 128, where n is the number of key points;
+    # each row is the feature of the corresponding key point
+    templateKeyPoint, templateDescriptor = sift.detectAndCompute(templateBlurred, None)
+    imageKeyPoint, imageDescriptor = sift.detectAndCompute(imageBlurred, None)
+
+    # for debug
+    # templateBlurred = cv2.drawKeypoints(templateBlurred, templateKeyPoint, templateBlurred)
+    # imageBlurred = cv2.drawKeypoints(imageBlurred, imageKeyPoint, imageBlurred)
+    # cv2.imshow("template", templateBlurred)
+    # cv2.imshow("image", imageBlurred)
+
+    # match
+    bf = cv2.BFMatcher()
+    # k = 2, so each match carries its two nearest candidates, sorted by distance
+    matches = bf.knnMatch(templateDescriptor, imageDescriptor, k=2)
+
+    # Lowe's ratio test: keep a match only if its best distance is clearly
+    # better than the second-best distance
+    good = [[m] for m, n in matches if m.distance < 0.8 * n.distance]
+
+    # distance matrices
+    templatePointMatrix = np.array([list(templateKeyPoint[p[0].queryIdx].pt) for p in good])
+    imagePointMatrix = np.array([list(imageKeyPoint[p[0].trainIdx].pt) for p in good])
+    templatePointDistanceMatrix = pairwise_distances(templatePointMatrix, metric="euclidean")
+    imagePointDistanceMatrix = pairwise_distances(imagePointMatrix, metric="euclidean")
+
+    # drop bad matches
+    distances = []
+    maxAbnormalNum = 15
+    for i in range(len(good)):
+        diff = abs(templatePointDistanceMatrix[i] - imagePointDistanceMatrix[i])
+        # distance between the two distance profiles, ignoring the worst entries
+        diff.sort()
+        distances.append(np.sqrt(np.sum(np.square(diff[:-maxAbnormalNum]))))
+
+    averageDistance = np.average(distances)
+    good2 = [good[i] for i in range(len(good)) if distances[i] < 2 * averageDistance]
+
+    # for debug
+    # matchImage = cv2.drawMatchesKnn(template, templateKeyPoint, image, imageKeyPoint, good2, None, flags=2)
+    # cv2.imshow("matchImage", matchImage)
+    # cv2.waitKey(0)
+
+    # not matched
+    if len(good2) < 3:
+        print("not found!")
+        return template, None, None
+
+    # estimate the homography M
+    src_pts = np.float32([templateKeyPoint[m[0].queryIdx].pt for m in good2]).reshape(-1, 1, 2)
+    dst_pts = np.float32([imageKeyPoint[m[0].trainIdx].pt for m in good2]).reshape(-1, 1, 2)
+    M, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
+    matchesMask = mask.ravel().tolist()
+    h, w, _ = template.shape
+
+    # project the template corners plus all calibration points into the image
+    pts = np.float32(
+        [[0, 0], [0, h - 1], [w - 1, h - 1], [w - 1, 0], [startPoint[0], startPoint[1]], [endPoint[0], endPoint[1]],
+         [centerPoint[0], centerPoint[1]],
+         # [startPointUp[0], startPointUp[1]],
+         # [endPointUp[0], endPointUp[1]],
+         # [centerPointUp[0], centerPointUp[1]]
+         ]).reshape(-1, 1, 2)
+    dst = cv2.perspectiveTransform(pts, M)
+
+    # deskew the image: angle of the projected top edge against the x-axis
+    vector = (dst[3][0][0] - dst[0][0][0], dst[3][0][1] - dst[0][0][1])
+    cos = vector[0] / math.sqrt(vector[0] ** 2 + vector[1] ** 2)
+    if vector[1] > 0:
+        angle = math.acos(cos) * 180.0 / math.pi
+    else:
+        angle = -math.acos(cos) * 180.0 / math.pi
+    # print(angle)
+
+    change = cv2.getRotationMatrix2D((dst[0][0][0], dst[0][0][1]), angle, 1)
+    src_correct = cv2.warpAffine(image, change, (image.shape[1], image.shape[0]))
+    array = np.array([[0, 0, 1]])
+    newchange = np.vstack((change, array))
+    # map the needed points into the corrected image
+    newpoints = []
+    for i in range(len(pts)):
+        point = newchange.dot(np.array([dst[i][0][0], dst[i][0][1], 1]))
+        point = list(point)
+        point.pop()
+        newpoints.append(point)
+    src_correct = src_correct[int(round(newpoints[0][1])):int(round(newpoints[1][1])),
+                  int(round(newpoints[0][0])):int(round(newpoints[3][0]))]
+    bbox = [int(round(newpoints[0][1])), int(round(newpoints[1][1])),
+            int(round(newpoints[0][0])), int(round(newpoints[3][0]))]
+
+    width = src_correct.shape[1]
+    height = src_correct.shape[0]
+    if width == 0 or height == 0:
+        return template, None, None
+
+    startPoint = (int(round(newpoints[4][0]) - newpoints[0][0]), int(round(newpoints[4][1]) - newpoints[0][1]))
+    endPoint = (int(round(newpoints[5][0]) - newpoints[0][0]), int(round(newpoints[5][1]) - newpoints[0][1]))
+    centerPoint = (int(round(newpoints[6][0]) - newpoints[0][0]), int(round(newpoints[6][1]) - newpoints[0][1]))
+
+    def isOverflow(point, width, height):
+        if point[0] < 0 or point[1] < 0 or point[0] > width - 1 or point[1] > height - 1:
+            return True
+        return False
+
+    if isOverflow(startPoint, width, height) or isOverflow(endPoint, width, height) or \
+            isOverflow(centerPoint, width, height):
+        print("overflow!")
+        return template, None, None
+
+    info["startPoint"]["x"] = startPoint[0]
+    info["startPoint"]["y"] = startPoint[1]
+    info["centerPoint"]["x"] = centerPoint[0]
+    info["centerPoint"]["y"] = centerPoint[1]
+    info["endPoint"]["x"] = endPoint[0]
+    info["endPoint"]["y"] = endPoint[1]
+    return src_correct, change, bbox
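The match filter above exploits the fact that pairwise distances between keypoints should be preserved (up to the meter's pose change) between template and image: for each match, it compares that match's distance profile against all other matches in both spaces, drops the largest per-row differences (maxAbnormalNum) so a few gross outliers don't contaminate every row, and rejects rows whose residual is more than twice the average. A hedged standalone sketch of the idea with hypothetical points:

```python
import numpy as np
from sklearn.metrics import pairwise_distances

rng = np.random.default_rng(0)
tpl = rng.uniform(0, 100, size=(10, 2))  # hypothetical template keypoints
img = tpl + 5                            # a pure translation: all matches consistent
img[-1] = [400, 400]                     # ...except one gross mismatch

dt = pairwise_distances(tpl)
di = pairwise_distances(img)
drop = 2                                 # like maxAbnormalNum: ignore the worst entries per row
scores = []
for i in range(len(tpl)):
    diff = np.sort(np.abs(dt[i] - di[i]))
    scores.append(np.sqrt(np.sum(diff[:-drop] ** 2)))
scores = np.array(scores)
keep = scores < 2 * scores.mean()
print(keep)                              # only the last (inconsistent) match is rejected
```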
diff --git a/Algorithm/videoDigit.py b/Algorithm/videoDigit.py
index bc37ad5..639636c 100644
--- a/Algorithm/videoDigit.py
+++ b/Algorithm/videoDigit.py
@@ -1,19 +1,24 @@
-from collections import defaultdict
-
 import cv2
+import sys
+import os
+import json
+import shutil
 import numpy as np
 import torch
+from collections import defaultdict, Counter
+import copy
 
-from Algorithm.OCR.character.characterNet import characterNet
 from Algorithm.pressure.digitPressure import digitPressure
-from Algorithm.utils.Finder import meterFinderBySIFT
+from Algorithm.utils.Finder import meterFinderBySIFT, meterRegionAndLocationBySIFT
+from Algorithm.OCR.character.characterNet import characterNet
 
 def videoDigit(video, info):
     """
     :param video: VideoCapture Input
     :param info: info
-    :return:
+    :return: per-type majority-voted readings, e.g. {'A': [...], 'B': [...], 'C': [...], 'D': [...]}
     """
     net = characterNet()
     net.load_state_dict(torch.load("Algorithm/OCR/character/NewNet_minLoss_model_0.965.pkl"))
@@ -23,19 +28,69 @@ def emptyLsit():
         return []
 
     imagesDict = defaultdict(emptyLsit)
-
+    box = None
+    newinfo = None
+    newchange = None
+    saveframe = []
     for i, frame in enumerate(pictures):
-        res = digitPressure(frame, info)
-        index = checkFrame(net, frame, info)
-        imagesDict[chr(index+ord('A'))] += [res]
-        # debug
+        res = digitPressure(frame, copy.deepcopy(info))
+        eachinfo = copy.deepcopy(info)
+        template, change, bbox = meterRegionAndLocationBySIFT(frame, eachinfo)
+        index, charimg = checkFrame(i, net, template, eachinfo)
+        if index < 4 and box is None:
+            # remember the first successfully located bbox and its transform
+            box = bbox.copy()
+            newchange = change.copy()
+            newinfo = copy.deepcopy(eachinfo)
+        elif index == 4 and box is None:
+            saveframe.append(i)
+        if index < 4:
+            imagesDict[chr(index + ord('A'))] += [res]
+    if box is not None:
+        # retry the frames that failed before a bbox was known, reusing it
+        for index in saveframe:
+            frame = pictures[index]
+            res = digitPressure(frame, copy.deepcopy(info))
+            src_correct = cv2.warpAffine(frame, newchange, (frame.shape[1], frame.shape[0]))
+            newtemplate = src_correct[box[0]:box[1], box[2]:box[3]]
+            index, charimg = checkFrame(index, net, newtemplate, newinfo)
+            if index < 4:
+                imagesDict[chr(index + ord('A'))] += [res]
+    # # debug
     # title = str(chr(index+ord('A'))) + str(res)
     # frame = cv2.putText(frame, title, (50, 100), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 255), 3)
-    # cv2.imwrite(os.path.join("video_result", info['name']+str(i)+'.jpg'), frame)
+    # cv2.imwrite(os.path.join(resultdir, info['name']+str(i)+'.jpg'), frame)
     # cv2.imshow(title, frame)
     # cv2.waitKey(0)
+    imagesDict = getResult(imagesDict)
     return imagesDict
 
+def getResult(dicts):
+    """Majority-vote each digit position across all readings of the same type."""
+    newdicts = {'A': [], 'B': [], 'C': [], 'D': []}
+    for ctype, res in dicts.items():
+        print("res ", res)
+        firsts = [list(str(x[0])) for x in res]
+        seconds = [list(str(x[1])) for x in res]
+        print("firsts ", firsts)
+        print("seconds ", seconds)
+        # left-pad short readings so every reading is 6 characters wide
+        for x in firsts:
+            if len(x) < 6:
+                x.insert(0, ' ')
+        for x in seconds:
+            if len(x) < 6:
+                x.insert(0, ' ')
+        if len(firsts[0]) > 0 and len(seconds[0]) > 0:
+            number = ""
+            for j in range(6):
+                words = [a[j] for a in firsts]
+                num = Counter(words).most_common(1)
+                number = number + num[0][0]
+            newdicts[ctype].append(number)
+            number = ""
+            for j in range(6):
+                words = [a[j] for a in seconds]
+                num = Counter(words).most_common(1)
+                number = number + num[0][0]
+            newdicts[ctype].append(number)
+    return newdicts
 
 def getPictures(videoCapture):
     """
@@ -45,7 +100,7 @@ def getPictures(videoCapture):
     """
     pictures = []
     cnt = 0
-    skipFrameNum = 30
+    skipFrameNum = 15
     while True:
         ret, frame = videoCapture.read()
         # print(cnt, np.shape(frame))
@@ -59,21 +114,18 @@ def getPictures(videoCapture):
     videoCapture.release()
     return pictures
 
-
-def checkFrame(net, image, info):
+def checkFrame(count, net, template, info):
     """
-    判断图片的类型A,AB等
-    :param image: image
-    :param info: info
-    :return: 出现次序0.1.2.3
+    Determine which character type (A, B, C, D) the frame shows.
+    :param count: frame index, used only for debug dumps
+    :param net: character classifier
+    :param template: located meter region image
+    :param info: info
+    :return: the type index (0-3) and the thresholded character image
     """
-    start = ([info["startPoint"]["x"], info["startPoint"]["y"]])
-    end = ([info["endPoint"]["x"], info["endPoint"]["y"]])
-    center = ([info["centerPoint"]["x"], info["centerPoint"]["y"]])
+    start = [info["startPoint"]["x"], info["startPoint"]["y"]]
+    end = [info["endPoint"]["x"], info["endPoint"]["y"]]
+    center = [info["centerPoint"]["x"], info["centerPoint"]["y"]]
     width = info["rectangle"]["width"]
     height = info["rectangle"]["height"]
-    widthSplit = info["widthSplit"]
-    heightSplit = info["heightSplit"]
     characSplit = info["characSplit"]
 
     # 计算数字表的矩形外框,并且拉直矫正
@@ -82,7 +134,7 @@ def checkFrame(net, image, info):
     pts2 = np.float32([[0, 0], [width, 0], [width, height], [0, height]])
     M = cv2.getPerspectiveTransform(pts1, pts2)
 
-    template = meterFinderBySIFT(image, info)
+    # oritemplate = template.copy()
     template = cv2.cvtColor(template, cv2.COLOR_BGR2GRAY)
     template = cv2.equalizeHist(template)
     dst = cv2.warpPerspective(template, M, (width, height))
@@ -91,7 +143,7 @@ def checkFrame(net, image, info):
     imgType = cv2.adaptiveThreshold(imgType, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 17, 11)
     imgType = cv2.bitwise_not(imgType)
-    # orimage = imgType.copy()
+    orimage = imgType.copy()
 
     imgType = torch.Tensor(np.array(imgType, dtype=np.float32))
     imgType = torch.unsqueeze(imgType, 0)
@@ -101,9 +153,47 @@ def checkFrame(net, image, info):
     maxIndex = np.argmax(type_probe)
 
     # debug
+    # oritemplate = drawTemplatePoints(template, info)
+    # dst = drawDstPoints(dst, info)
+    # cv2.imshow("M", M)
+    # cv2.imshow("template ", oritemplate)
+    # cv2.imshow("des ", dst)
     # cv2.imshow(str(chr(maxIndex+ord('A'))), orimage)
-    # type = str(chr(maxIndex+ord('A')))
-    # cv2.imwrite(os.path.join("video_result", type+info['name']+'_'+str(type_probe)+'.jpg'), orimage)
+    # kind = str(chr(maxIndex+ord('A')))
+    # cv2.imwrite(os.path.join(resultdir, info['name'] + '_' + str(count) + '_' + kind + '.bmp'), orimage)
     # cv2.waitKey(0)
 
-    return maxIndex
+    return maxIndex, orimage
+
+def drawTemplatePoints(template, info):
+    start = (info["startPoint"]["x"], info["startPoint"]["y"])
+    end = (info["endPoint"]["x"], info["endPoint"]["y"])
+    center = (info["centerPoint"]["x"], info["centerPoint"]["y"])
+    cv2.circle(template, start, 2, (255, 0, 0), 3)
+    cv2.circle(template, end, 2, (255, 0, 0), 3)
+    cv2.circle(template, center, 2, (255, 0, 0), 3)
+    return template
+
+def drawDstPoints(dst, info):
+    widthSplit = info["widthSplit"]
+    heightSplit = info["heightSplit"]
+    characSplit = info["characSplit"]
+    pts = [
+        (characSplit[0][0], characSplit[1][0]),
+        (characSplit[0][0], characSplit[1][1]),
+        (characSplit[0][1], characSplit[1][0]),
+        (characSplit[0][1], characSplit[1][1]),
+    ]
+    for x in widthSplit[0]:
+        for j in range(len(heightSplit)):
+            for y in heightSplit[j]:
+                pts.append((x, y))
+    for point in pts:
+        cv2.circle(dst, point, 1, (255, 0, 0), 3)
+    return dst
+
+
+# resultdir = "/Users/yuanyuan/Documents/GitHub/meterReader/result_character"
+# if os.path.exists(resultdir):
+#     shutil.rmtree(resultdir)
+# if not os.path.exists(resultdir):
+#     os.makedirs(resultdir)
diff --git a/TestServiceSample.py b/TestServiceSample.py
index 2cbf7f7..1d5967a 100644
--- a/TestServiceSample.py
+++ b/TestServiceSample.py
@@ -88,4 +88,4 @@ def testVideo():
 
 
 # codecov("info/20190416/IMAGES/image")
-    # testVideo()
+    # testVideo()
\ No newline at end of file
diff --git a/configuration.py b/configuration.py
index c168d77..fa3ba47 100644
--- a/configuration.py
+++ b/configuration.py
@@ -11,6 +11,4 @@
 # templatePath = "info/20190416/template"
 
 configPath = "info/20190514/config"
-templatePath = "info/20190514/template"
-
-
+templatePath = "info/20190514/template"
\ No newline at end of file
diff --git a/ocr_config/PMW2800.json b/ocr_config/PMW2800.json
new file mode 100644
index 0000000..0f4f394
--- /dev/null
+++ b/ocr_config/PMW2800.json
@@ -0,0 +1,21 @@
+{
+    "rectangle": {
+        "width": 200,
+        "height": 200
+    },
+    "widthSplit": [
+        [8, 26, 44, 62, 80, 81, 98, 116],
+        [8, 26, 44, 62, 80, 81, 98, 116],
+        [8, 26, 44, 62, 80, 81, 98, 116],
+        [8, 26, 44, 62, 80, 81, 98, 116]
+    ],
+    "heightSplit": [
+        [17, 48],
+        [50, 80],
+        [83, 114],
+        [119, 147]
+    ],
+    "decimal": [
+        4, 4, 4, 4
+    ]
+}
\ No newline at end of file
diff --git a/ocr_config/PMW2810.json b/ocr_config/PMW2810.json
new file mode 100644
index 0000000..2c1ac67
--- /dev/null
+++ b/ocr_config/PMW2810.json
@@ -0,0 +1,20 @@
+{
+    "rectangle": {
+        "width": 200,
+        "height": 200
+    },
+    "widthSplit": [
+        [19, 40, 61, 82, 106, 107, 127, 150],
+        [19, 40, 61, 82, 106, 107, 127, 150],
+        [19, 40, 61, 82, 106, 107, 127, 150]
+    ],
+    "heightSplit": [
+        [21, 58],
+        [62, 103],
+        [105, 146]
+    ],
+    "decimal": [
+        4, 4, 4
+    ]
+}
\ No newline at end of file
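Both meter configs share the same schema consumed by rgbRecognize/bitRecognize. A hedged sketch of loading one and reading its fields (file name as added in this patch; run from the repo root):

```python
import json

with open("ocr_config/PMW2800.json") as fp:
    info = json.load(fp)

# each row of the LCD has 7 digit cells between 8 consecutive x-coordinates,
# with the decimal point at cell index 4
print(info["rectangle"])                            # {'width': 200, 'height': 200}
print(len(info["widthSplit"][0]) - 1, "cells/row")  # 7 cells/row
print("decimal at cell", info["decimal"][0])        # decimal at cell 4
```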