diff --git a/Algorithm/OCR/digits/dataLoader.py b/Algorithm/OCR/digits/dataLoader.py
index e3406af..3a566b4 100644
--- a/Algorithm/OCR/digits/dataLoader.py
+++ b/Algorithm/OCR/digits/dataLoader.py
@@ -8,25 +8,37 @@ class dataLoader():
 
     # todo 调整batch,使每个batch顺序都不一样
-    def __init__(self, path, bs, ifUpdate):
+    def __init__(self, type, path, bs, ifUpdate):
+        self.meanPixel = 80
         self.bs = bs
         self.pointer = 0
-
-        if not os.path.exists(path + "/train.pkl") or not os.path.exists(path + "/test.pkl"):
+        self.type = type
+        if type == 'rgb':
+            self.train_path = os.path.join(path, "rgb_augment_train.pkl")
+            self.test_path = os.path.join(path, "rgb_test.pkl")
+        elif type == 'bit':
+            self.train_path = os.path.join(path, "bit_augment_train.pkl")
+            self.test_path = os.path.join(path, "bit_test.pkl")
+
+        if not os.path.exists(self.train_path) or not os.path.exists(self.test_path):
             ifUpdate = True
 
         if ifUpdate:
-            os.system("rm -rf {}".format(path + "/train.pkl"))
-            os.system("rm -rf {}".format(path + "/test.pkl"))
+            os.system("rm -rf {}".format(self.train_path))
+            os.system("rm -rf {}".format(self.test_path))
             self.readImagesFromMultiFils(path)
 
-        self.readDataFromPkl(path)
+        self.readDataFromPkl()
         self.shuffle()
 
     def readImagesFromMultiFils(self, path):
-        for t in ["train", "test"]:
-            data = torch.Tensor(np.zeros((1, 1, 28, 28)))
+        for t in ["rgb_augmentation", "rgb_test"]:
+            if self.type == 'bit':
+                data = torch.Tensor(np.zeros((1, 1, 28, 28)))
+            elif self.type == 'rgb':
+                data = torch.Tensor(np.zeros((1, 3, 28, 28)))
             label = []
+            names = []
             for i in range(11):
                 root = path + "/" + t + "/" + str(i) + "/"
@@ -34,28 +46,46 @@ def readImagesFromMultiFils(self, path):
                 for im in images:
                     if im.split(".")[-1] != "bmp":
                         continue
-                    img = cv2.imread(root + im)[:, :, 0]
-                    _, thresh = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
-                    temp = torch.Tensor(thresh).view(1, 1, 28, 28)
+                    # print(img.shape)
+                    names.append(root+im)
+                    if self.type == "bit":
+                        img = cv2.imread(root + im)[:, :, 0]
+                        img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_CUBIC)
+                        img = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, 55, 11)
+                        # enhance: suppress speckle noise with a morphological opening
+                        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 2))
+                        img = cv2.morphologyEx(img, cv2.MORPH_OPEN, kernel)
+                        # _, img = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
+                    elif self.type == 'rgb':
+                        img = cv2.imread(root + im)
+                        img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_CUBIC)
+                        if len(img.shape) == 2:
+                            img = cv2.merge([img, img, img])  # replicate the gray channel into 3 channels
+                            print("convert to rgb: ", img.shape)
+
+                    # bit images are single-channel, rgb images three-channel
+                    channels = 1 if self.type == "bit" else 3
+                    temp = torch.Tensor(img).view(1, channels, 28, 28) - self.meanPixel
                     data = torch.cat((data, temp), 0)
                     label.append(i)
-
-            fp = open(path + "/" + t + ".pkl", "wb")
-            pickle.dump([data[1:], torch.Tensor(np.array(label)).long()], fp)
+            if t.endswith("test"):
+                fp = open(self.test_path, "wb")
+                pickle.dump([data[1:], torch.Tensor(np.array(label)).long(), names], fp)
+            else:
+                fp = open(self.train_path, "wb")
+                pickle.dump([data[1:], torch.Tensor(np.array(label)).long()], fp)
             fp.close()
 
-    def readDataFromPkl(self, path):
-        with open(path+"/train.pkl", "rb") as fp:
+    def readDataFromPkl(self):
+        with open(self.train_path, "rb") as fp:
             self.trainData, self.trainLabel = pickle.load(fp)
-        with open(path+"/test.pkl", "rb") as fp:
-            self.testData, self.testLabel = pickle.load(fp)
+        with open(self.test_path, "rb") as fp:
+            self.testData, self.testLabel, self.names = pickle.load(fp)
 
     def getTrainData(self):
         return self.trainData, self.trainLabel
 
     def getTestData(self):
-        return self.testData, self.testLabel
+        return self.testData, self.testLabel, self.names
 
     def shuffle(self):
         li = list(range(self.trainData.shape[0]))
@@ -83,16 +113,16 @@ def get_rounds(self):
         return int(self.trainData.shape[0] / self.bs) + 1
 
-if __name__ == "__main__":
-    dl = dataLoader("images/LCD_enhanced", 64, True)
-    dl.shuffle()
-    train, trl = dl.getTrainData()
-    test, tel = dl.getTestData()
-    #
-    print(train.shape, trl)
-    print(test.shape, tel)
-    #
-
-    print(dl.trainLabel)
-    dl.shuffle()
-    print(dl.trainLabel)
+# if __name__ == "__main__":
+#     dl = dataLoader("dataset/", 64, True)
+#     dl.shuffle()
+#     train, trl = dl.getTrainData()
+#     test, tel = dl.getTestData()
+#     #
+#     print(train.shape, trl)
+#     print(test.shape, tel)
+#     #
+#
+#     print(dl.trainLabel)
+#     dl.shuffle()
+#     print(dl.trainLabel)
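For reference, the loader's rgb preprocessing distilled into a standalone sketch; `MEAN_PIXEL` mirrors `self.meanPixel` from the diff, while `preprocess_rgb` and its file path are hypothetical names. The `.view()` caveat in the comments is worth noting when reading the loader.

```python
# A minimal sketch of the rgb preprocessing path above (resize -> 3-channel
# check -> mean subtraction). "digit.bmp" is a placeholder path.
import cv2
import torch

MEAN_PIXEL = 80  # matches self.meanPixel in the diff

def preprocess_rgb(path):
    img = cv2.imread(path)                                         # BGR, uint8
    img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_CUBIC)
    if len(img.shape) == 2:                                        # grayscale fallback
        img = cv2.merge([img, img, img])                           # replicate to 3 channels
    # Note: .view() reinterprets the HWC buffer rather than transposing to CHW.
    # That is only consistent here because training and inference share this
    # exact code path; a true channel-first layout would use
    # torch.from_numpy(img).float().permute(2, 0, 1).
    return torch.Tensor(img).view(1, 3, 28, 28) - MEAN_PIXEL

tensor = preprocess_rgb("digit.bmp")  # shape (1, 3, 28, 28)
```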
diff --git a/Algorithm/OCR/digits/data_augmentation.py b/Algorithm/OCR/digits/data_augmentation.py
new file mode 100644
index 0000000..b68b23f
--- /dev/null
+++ b/Algorithm/OCR/digits/data_augmentation.py
@@ -0,0 +1,132 @@
+import os
+import cv2
+import numpy as np
+import random
+from PIL import Image, ImageEnhance
+
+'''
+HSV transform:
+hue_delta is the hue shift
+sat_mult is the saturation scaling factor
+val_mult is the value (brightness) scaling factor
+'''
+def hsv_transform(img, hue_delta, sat_mult, val_mult):
+    img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float64)  # np.float was removed in modern NumPy
+    img_hsv[:, :, 0] = (img_hsv[:, :, 0] + hue_delta) % 180
+    img_hsv[:, :, 1] *= sat_mult
+    img_hsv[:, :, 2] *= val_mult
+    img_hsv[img_hsv > 255] = 255
+    return cv2.cvtColor(np.round(img_hsv).astype(np.uint8), cv2.COLOR_HSV2BGR)
+
+'''
+Random HSV transform:
+hue_vari is the range of the hue shift
+sat_vari is the range of the saturation scaling
+val_vari is the range of the value scaling
+'''
+def random_hsv_transform(img, hue_vari=10, sat_vari=0.1, val_vari=0.1):
+    hue_delta = np.random.randint(-hue_vari, hue_vari)
+    sat_mult = 1 + np.random.uniform(-sat_vari, sat_vari)
+    val_mult = 1 + np.random.uniform(-val_vari, val_vari)
+    return hsv_transform(img, hue_delta, sat_mult, val_mult)
+
+'''
+Gamma transform via a 256-entry lookup table.
+'''
+def gamma_transform(img, gamma=1.0):
+    gamma_table = [np.power(x / 255.0, gamma) * 255.0 for x in range(256)]
+    gamma_table = np.round(np.array(gamma_table)).astype(np.uint8)
+    return cv2.LUT(img, gamma_table)
+
+'''
+Random gamma transform:
+gamma_vari bounds gamma to the range [1/gamma_vari, gamma_vari)
+'''
+def random_gamma_transform(img, gamma_vari=2.0):
+    log_gamma_vari = np.log(gamma_vari)
+    alpha = np.random.uniform(-log_gamma_vari, log_gamma_vari)
+    gamma = np.exp(alpha)
+    return gamma_transform(img, gamma)
+
+def randomGaussian(image, mean=0.2, sigma=0.3):
+    """
+    Add Gaussian noise to an image
+    :param image:
+    :return:
+    """
+    def gaussianNoisy(im, mean=0.2, sigma=0.3):
+        """
+        Add Gaussian noise to a single channel
+        :param im: single-channel image
+        :param mean: noise mean (offset)
+        :param sigma: standard deviation
+        :return:
+        """
+        for _i in range(len(im)):
+            im[_i] += random.gauss(mean, sigma)
+        return im
+
+    # convert the image to an ndarray
+    img = np.asarray(image)
+    img.flags.writeable = True  # make the array writable
+    width, height = img.shape[:2]
+    img_r = gaussianNoisy(img[:, :, 0].flatten(), mean, sigma)
+    img_g = gaussianNoisy(img[:, :, 1].flatten(), mean, sigma)
+    img_b = gaussianNoisy(img[:, :, 2].flatten(), mean, sigma)
+    img[:, :, 0] = img_r.reshape([width, height])
+    img[:, :, 1] = img_g.reshape([width, height])
+    img[:, :, 2] = img_b.reshape([width, height])
+    return np.uint8(img)
+
+def randomColor(image):
+    """
+    Apply color jitter to an image
+    :param image: image as an ndarray (converted to PIL internally)
+    :return: color-jittered image as an ndarray
+    """
+    image = Image.fromarray(image)
+    random_factor = np.random.randint(0, 31) / 10.  # random factor
+    color_image = ImageEnhance.Color(image).enhance(random_factor)  # adjust saturation
+    random_factor = np.random.randint(10, 21) / 10.  # random factor
+    brightness_image = ImageEnhance.Brightness(color_image).enhance(random_factor)  # adjust brightness
+    random_factor = np.random.randint(10, 21) / 10.  # random factor
+    contrast_image = ImageEnhance.Contrast(brightness_image).enhance(random_factor)  # adjust contrast
+    random_factor = np.random.randint(0, 31) / 10.  # random factor
+    img = ImageEnhance.Sharpness(contrast_image).enhance(random_factor)  # adjust sharpness
+    return np.array(img)
+
+
+def augmentation(origin, dest):
+    for sub in os.listdir(origin):
+        subpath = os.path.join(origin, sub)
+        destpath = os.path.join(dest, sub)
+        if not os.path.exists(destpath):
+            os.makedirs(destpath)
+        for file in os.listdir(subpath):
+            filename = os.path.join(subpath, file)
+            img = cv2.imread(filename)
+            cv2.imwrite(os.path.join(destpath, file[:-4]+"_origin.bmp"), img)
+            # random HSV transform
+            img_hsv = random_hsv_transform(img.copy())
+            destname = os.path.join(destpath, file[:-4]+"_hsv.bmp")
+            cv2.imwrite(destname, img_hsv)
+            # random gamma transform
+            img_gamma = random_gamma_transform(img.copy())
+            destname = os.path.join(destpath, file[:-4] + "_gamma.bmp")
+            cv2.imwrite(destname, img_gamma)
+            # color jitter
+            img_color = randomColor(img.copy())
+            destname = os.path.join(destpath, file[:-4]+"_color.bmp")
+            cv2.imwrite(destname, img_color)
+            # Gaussian noise
+            img_gaussian = randomGaussian(img.copy())
+            destname = os.path.join(destpath, file[:-4] + "_gaussian.bmp")
+            cv2.imwrite(destname, img_gaussian)
+
+
+origin = "dataset/rgb_train"
+dest = "dataset/rgb_augmentation"
+augmentation(origin, dest)
\ No newline at end of file
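A quick smoke test for the four augmenters defined above. It is a hypothetical snippet meant to run inside data_augmentation.py itself (importing the module from elsewhere triggers the `augmentation(origin, dest)` call at the bottom as a side effect), and `sample.bmp` is a placeholder path.

```python
# Generate one variant per augmenter for a single image; names are placeholders.
img = cv2.imread("sample.bmp")
for name, fn in [("hsv", random_hsv_transform),
                 ("gamma", random_gamma_transform),
                 ("color", randomColor),
                 ("gaussian", randomGaussian)]:
    cv2.imwrite("sample_{}.bmp".format(name), fn(img.copy()))
```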
diff --git a/Algorithm/OCR/digits/train.py b/Algorithm/OCR/digits/train.py
index 0b79be0..c506a91 100644
--- a/Algorithm/OCR/digits/train.py
+++ b/Algorithm/OCR/digits/train.py
@@ -1,85 +1,120 @@
-from Algorithm.OCR.newNet.dataLoader import *
-from Algorithm.OCR.newNet.LeNet import *
-
 import sys
-sys.path.append("../LeNet")
-
+import os
+import cv2
 import numpy as np
 import torch
 import torch.optim as optim
 import torch.nn.init as init
-
-torch.manual_seed(10)
-
-bs = 128
-lr = 0.001
-epoch = 40
-stepLength = 20
-classes = 11
-
-training = "LCD_enhanced"
-data = dataLoader("images/"+training, bs, ifUpdate=True)
-testInputs, testLabels = data.getTestData()
-print("数据集加载完毕")
-
-
-weight = np.zeros(classes)
-
-for i in range(classes):
-    images = os.listdir("images/LCD_enhanced/train/"+str(i))
-    weight[i] = len(images)
-weight = weight/np.sum(weight)
-print(weight)
-
-
-def adjust_learning_rate(optimizer, epoch, t=10):
-    """Sets the learning rate to the initial LR decayed by 10 every t epochs,default=10"""
-    new_lr = lr * (0.1 ** (epoch // t))
-    for param_group in optimizer.param_groups:
-        param_group['lr'] = new_lr
-
-net = myNet()
-net.train()
-
-criterion = nn.NLLLoss(weight=torch.Tensor(weight))
-optimizer = optim.Adam(net.parameters(), weight_decay=1e-5, lr=lr)
-
-steps = data.get_rounds()
-
-for e in range(epoch):
-    sum_loss = 0.0
-    for step in range(steps):
-
-        inputs, labels = data.next_batch()
-        inputs = inputs/255
-        optimizer.zero_grad()
-
-        # forward + backward
-        outputs = net.forward(inputs)
-        loss = criterion(outputs, labels.long())
-        loss.backward()
-        optimizer.step()
-
-        sum_loss += loss.item()
-        if step % stepLength == stepLength-1:
-            print('epoch %d /step %d: loss:%.03f'
-                  % (e + 1, step + 1, sum_loss / stepLength))
-            sum_loss = 0.0
-    adjust_learning_rate(optimizer, e)
+import matplotlib.pyplot as plt
+# from tensorboardX import SummaryWriter
+
+from dataLoader import dataLoader
+from net import rgbNet
+
+def train(type):
+
+    torch.manual_seed(10)
+
+    bs = 256
+    lr = 0.001
+    epoch = 40
+    stepLength = 20
+    classes = 11
+
+    data = dataLoader(type, "dataset/", bs, ifUpdate=True)
+    testInputs, testLabels, _ = data.getTestData()
+    print("dataset loaded")
+
+    weight = np.zeros(classes)
+
+    for i in range(classes):
+        images = os.listdir("dataset/rgb_augmentation/"+str(i))
+        weight[i] = len(images)
+    weight = weight/np.sum(weight)
+    print(weight)
+
+    def adjust_learning_rate(optimizer, epoch, t=10):
+        """Sets the learning rate to the initial LR decayed by 10 every t epochs, default t=10"""
+        new_lr = lr * (0.1 ** (epoch // t))
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = new_lr
+
+    net = rgbNet(type)
+    net.train()
+
+    criterion = torch.nn.NLLLoss(weight=torch.Tensor(weight))
+    optimizer = optim.Adam(net.parameters(), weight_decay=1e-5, lr=lr)
+
+    steps = data.get_rounds()
+    train_loss = []
+
+    for e in range(epoch):
+        sum_loss = 0.0
+        for step in range(steps):
+
+            inputs, labels = data.next_batch()
+            inputs = inputs/255
+            optimizer.zero_grad()
+
+            # forward + backward
+            outputs = net.forward(inputs)
+            loss = criterion(outputs, labels.long())
+            loss.backward()
+            optimizer.step()
+
+            sum_loss += loss.item()
+            if step % stepLength == stepLength-1:
+                print('epoch %d /step %d: loss:%.03f'
+                      % (e + 1, step + 1, sum_loss / stepLength))
+            # sum_loss = 0.0
+            train_loss.append(sum_loss)
+        adjust_learning_rate(optimizer, e)
+        outputs = net.forward(testInputs)
+        _, predicted = torch.max(outputs.data, 1)
+        correct = (predicted == testLabels.long()).sum()
+        print('epoch %d test accuracy: %d%%' % (e + 1, (100 * correct / testLabels.shape[0])))
+    torch.save(net.state_dict(), type + "_" + str(net.__class__.__name__) + "_net.pkl")
+    plt.figure(0)
+    x = [i for i in range(len(train_loss))]
+    plt.plot(x, train_loss)
+    plt.savefig("train_loss.jpg")
+
+def test(type):
+    import shutil
+    result = type + "result/"
+    if os.path.exists(result):
+        shutil.rmtree(result)
+    if not os.path.exists(result):
+        os.makedirs(result)
+    data = dataLoader(type, "dataset/", bs=256, ifUpdate=False)
+    testInputs, testLabels, names = data.getTestData()
+    net = rgbNet(type)
+    modelname = type + "_" + str(net.__class__.__name__) + "_net.pkl"
+    net.load_state_dict(torch.load(modelname))
+    net.eval()
+    print("model loaded")
     outputs = net.forward(testInputs)
     _, predicted = torch.max(outputs.data, 1)
     correct = (predicted == testLabels.long()).sum()
-    print('第%d个epoch的识别准确率为:%d%%' % (e + 1, (100 * correct / testLabels.shape[0])))
-torch.save(net.state_dict(), "net.pkl")
-
-# for i in range(testInputs.shape[0]):
-#     test = np.array(testInputs[i].view(28, 28))
-#     res = net.forward(testInputs[i].view(1,1,28,28))
-#     cv2.imshow("test", test)
-#     _, predicted = torch.max(res.data, 1)
-#     print(predicted)
-#     cv2.waitKey(0)
-
+    print('test accuracy: %d%%' % ((100 * correct / testLabels.shape[0])))
+
+    # show result picture
+    for i in range(testInputs.shape[0]):
+        testimg = cv2.imread(names[i])
+        res = net.forward(testInputs[i].view(1, 3, 28, 28))
+        _, predicted = torch.max(res.data, 1)
+        # cv2.imshow("test", test)
+        name = os.path.join(result, str(i) + "__" + str(predicted.numpy()) + ".bmp")
+        cv2.imwrite(name, testimg)
+        # print(predicted)
+        # cv2.waitKey(0)
+    print("done!")
+
+if __name__ == "__main__":
+    train('rgb')
+    test('rgb')
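The hand-rolled `adjust_learning_rate` inside `train()` is a standard step decay; PyTorch ships the same schedule as `torch.optim.lr_scheduler.StepLR`. A minimal equivalent sketch follows; the `Linear` model is a stand-in, not the repo's `rgbNet`.

```python
# Step decay: multiply the learning rate by 0.1 every 10 epochs,
# matching adjust_learning_rate(optimizer, e, t=10) above.
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

model = torch.nn.Linear(28 * 28 * 3, 11)               # stand-in for rgbNet
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)  # lr *= 0.1 every 10 epochs

for epoch in range(40):
    # ... run the training steps for this epoch ...
    scheduler.step()                                    # decay once per epoch
```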
diff --git a/Algorithm/pressure/digitPressure.py b/Algorithm/pressure/digitPressure.py
index 9811d1e..b5e649a 100644
--- a/Algorithm/pressure/digitPressure.py
+++ b/Algorithm/pressure/digitPressure.py
@@ -12,6 +12,87 @@ def digitPressure(image, info):
     template = meterFinderBySIFT(image, info)
+
+    # save images
+    if not os.path.exists("storeDigitData"):
+        os.mkdir("storeDigitData")
+
+    try:
+        os.mkdir("storeDigitData/thresh")
+        os.mkdir("storeDigitData/rgb")
+    except IOError:
+        pass
+
+    for i in range(11):
+        try:
+            os.mkdir("storeDigitData/thresh/" + str(i))
+            os.mkdir("storeDigitData/rgb/" + str(i))
+        except IOError:
+            continue
+
+    myRes = []
+    if 'rgb' in info and info['rgb']:  # rgb as input
+        myRes = rgbRecognize(template, info)
+    else:
+        myRes = bitRecognize(template, info)
+
+    if info["digitType"] == "KWH":
+        myRes[0] = myRes[0][:4] + myRes.pop(1)
+
+    # drop a '?' at the head; replace an interior '?' with a random digit
+    for i in range(len(myRes)):
+        temp = ""
+        for j, c in enumerate(myRes[i]):
+            if c != "?":
+                temp += c
+            elif j != 0:
+                temp += str(random.randint(0, 9))
+        myRes[i] = float(temp)
+
+    return myRes
+
+def rgbRecognize(template, info):
+    # get the LCD region from the calibration points
+    dst = boxRectifier(template, info)
+    # read the calibration info
+    widthSplit = info["widthSplit"]
+    heightSplit = info["heightSplit"]
+    # initialize the network
+    MyNet = newNet()
+    myRes = []
+    imgNum = int((len(os.listdir("storeDigitData/")) - 1) / 3)
+    for i in range(len(heightSplit)):
+        split = widthSplit[i]
+        myNum = ""
+        for j in range(len(split) - 1):
+            if "decimal" in info.keys() and j == info["decimal"][i]:
+                myNum += "."
+                continue
+            # crop the split digit region
+            img = dst[heightSplit[i][0]:heightSplit[i][1], split[j]:split[j + 1]]
+            num = MyNet.recognizeNet(img, 'rgb')
+            myNum = myNum + num
+
+            # save images
+            cv2.imwrite("storeDigitData/rgb/{}/{}_{}{}_p{}.bmp".format(
+                num,
+                imgNum,
+                i,
+                j,
+                num
+            ), img)
+
+        myRes.append(myNum)
+    if ifShow:
+        cv2.imshow("rec", dst)
+        cv2.imshow("template", template)
+        print(myRes)
+        cv2.waitKey(0)
+        cv2.destroyAllWindows()
+    return myRes
+
+
+def bitRecognize(template, info):
     template = cv2.GaussianBlur(template, (3, 3), 0)
 
     # 读取标定信息
@@ -37,31 +118,16 @@
         thresh = cv2.adaptiveThreshold(Hist, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, 15, 11)
     else:
         thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, block, param)
-    if ifOpen=="close":
+    if ifOpen == "close":
         p = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
         res1 = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, p)
 
-    # 存储图片
-    if not os.path.exists("storeDigitData"):
-        os.mkdir("storeDigitData")
-
-    try:
-        os.mkdir("storeDigitData/thresh")
-        os.mkdir("storeDigitData/rgb")
-    except IOError:
-        pass
-
-    for i in range(11):
-        try:
-            os.mkdir("storeDigitData/thresh/"+str(i))
-            os.mkdir("storeDigitData/rgb/"+str(i))
-        except IOError:
-            continue
+    if os.path.exists("storeDigitData/"):
+        imgNum = int((len(os.listdir("storeDigitData/"))-1)/3)
+        cv2.imwrite("storeDigitData/" + str(imgNum) + "_dst.bmp", dst)
+        cv2.imwrite("storeDigitData/" + str(imgNum) + "_gray.bmp", gray)
+        cv2.imwrite("storeDigitData/" + str(imgNum) + "_thresh.bmp", thresh)
 
-    imgNum = int((len(os.listdir("storeDigitData/"))-1)/3)
-    cv2.imwrite("storeDigitData/" + str(imgNum) + "_dst.bmp", dst)
-    cv2.imwrite("storeDigitData/" + str(imgNum) + "_gray.bmp", gray)
-    cv2.imwrite("storeDigitData/" + str(imgNum) + "_thresh.bmp", thresh)
 
     # 网络初始化
     MyNet = newNet()
@@ -77,50 +143,26 @@
             # 得到分割的图片区域
             img = thresh[heightSplit[i][0]:heightSplit[i][1], split[j]:split[j + 1]]
             rgb_ = dst[heightSplit[i][0]:heightSplit[i][1], split[j]:split[j + 1]]
-            # 增强
-            # kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 2))
-            # img = cv2.morphologyEx(img, cv2.MORPH_OPEN, kernel)
-            num = MyNet.recognizeNet(img)
+            num = MyNet.recognizeNet(img, 'bit')
             myNum = myNum + num
 
             # 存储图片
-            # cv2.imwrite("storeDigitData/thresh/{}/{}_{}{}_p{}.bmp".format(
-            #     num,
-            #     imgNum,
-            #     i,
-            #     j,
-            #     num
-            # ), img)
-            #
-            # cv2.imwrite("storeDigitData/rgb/{}/{}_{}{}_p{}.bmp".format(
-            #     num,
-            #     imgNum,
-            #     i,
-            #     j,
-            #     num
-            # ), rgb_)
+            cv2.imwrite("storeDigitData/thresh/{}/{}_{}{}_p{}.bmp".format(
+                num,
+                imgNum,
+                i,
+                j,
+                num
+            ), img)
 
         myRes.append(myNum)
+    if ifShow:
+        cv2.imshow("rec", dst)
+        cv2.imshow("template", template)
+        print(myRes)
+        cv2.waitKey(0)
+        cv2.destroyAllWindows()
+    return myRes
 
-    if info["digitType"] == "KWH":
-        myRes[0] = myRes[0][:4]+myRes.pop(1)
-
-    # 去除头部的非数字字符,同时将非头部的字符转为数字
-    for i in range(len(myRes)):
-        temp = ""
-        for j, c in enumerate(myRes[i]):
-            if c != "?":
-                temp += c
-            elif j != 0:
-                temp += str(random.randint(0, 9))
-        myRes[i] = float(temp)
-
-    if ifShow:
-        cv2.imshow("rec", dst)
-        cv2.imshow("image", image)
-        print(myRes)
-        cv2.waitKey(0)
-        cv2.destroyAllWindows()
-
-    return myRes
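The `?` handling moved to the top of `digitPressure` is easy to misread, so here is the rule isolated as a worked example: a `?` at the head of a reading is dropped, while an interior `?` is replaced by a random digit so `float()` still parses. `clean` is a hypothetical helper name; the body matches the loop in the diff.

```python
import random

def clean(reading):
    temp = ""
    for j, c in enumerate(reading):
        if c != "?":
            temp += c
        elif j != 0:
            temp += str(random.randint(0, 9))
    return float(temp)

print(clean("?123.4"))  # -> 123.4  (head '?' stripped)
print(clean("12?.4"))   # -> e.g. 125.4  (interior '?' becomes a random digit)
```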
diff --git a/Algorithm/utils/Finder.py b/Algorithm/utils/Finder.py
index 20fe5cd..ee087d7 100644
--- a/Algorithm/utils/Finder.py
+++ b/Algorithm/utils/Finder.py
@@ -246,3 +246,145 @@ def isOverflow(point, width, height):
 
     return src_correct
 
+def meterRegionAndLocationBySIFT(image, info):
+    """
+    locate the meter's bbox
+    :param image: image
+    :param info: info
+    :return: rectified bbox image, rotation matrix and bbox
+    """
+    template = info["template"]
+
+    # cv2.imshow("template", template)
+    # cv2.imshow("image", image)
+    # cv2.waitKey(0)
+
+    startPoint = (info["startPoint"]["x"], info["startPoint"]["y"])
+    centerPoint = (info["centerPoint"]["x"], info["centerPoint"]["y"])
+    endPoint = (info["endPoint"]["x"], info["endPoint"]["y"])
+    # startPointUp = (info["startPointUp"]["x"], info["startPointUp"]["y"])
+    # endPointUp = (info["endPointUp"]["x"], info["endPointUp"]["y"])
+    # centerPointUp = (info["centerPointUp"]["x"], info["centerPointUp"]["y"])
+
+    templateBlurred = cv2.GaussianBlur(template, (3, 3), 0)
+    imageBlurred = cv2.GaussianBlur(image, (3, 3), 0)
+
+    sift = cv2.xfeatures2d.SIFT_create()
+
+    # shape of descriptor n * 128, n is the num of key points.
+    # a row of descriptor is the feature of related key point.
+    templateKeyPoint, templateDescriptor = sift.detectAndCompute(templateBlurred, None)
+    imageKeyPoint, imageDescriptor = sift.detectAndCompute(imageBlurred, None)
+
+    # for debug
+    # templateBlurred = cv2.drawKeypoints(templateBlurred, templateKeyPoint, templateBlurred)
+    # imageBlurred = cv2.drawKeypoints(imageBlurred, imageKeyPoint, imageBlurred)
+    # cv2.imshow("template", templateBlurred)
+    # cv2.imshow("image", imageBlurred)
+
+    # match
+    bf = cv2.BFMatcher()
+    # k = 2, so each match has 2 points. 2 points are sorted by distance.
+    matches = bf.knnMatch(templateDescriptor, imageDescriptor, k=2)
+
+    # The first one is better than the second one
+    good = [[m] for m, n in matches if m.distance < 0.8 * n.distance]
+
+    # distance matrix
+    templatePointMatrix = np.array([list(templateKeyPoint[p[0].queryIdx].pt) for p in good])
+    imagePointMatrix = np.array([list(imageKeyPoint[p[0].trainIdx].pt) for p in good])
+    templatePointDistanceMatrix = pairwise_distances(templatePointMatrix, metric="euclidean")
+    imagePointDistanceMatrix = pairwise_distances(imagePointMatrix, metric="euclidean")
+
+    # del bad match
+    distances = []
+    maxAbnormalNum = 15
+    for i in range(len(good)):
+        diff = abs(templatePointDistanceMatrix[i] - imagePointDistanceMatrix[i])
+        # distance between distance features
+        diff.sort()
+        distances.append(np.sqrt(np.sum(np.square(diff[:-maxAbnormalNum]))))
+
+    averageDistance = np.average(distances)
+    good2 = [good[i] for i in range(len(good)) if distances[i] < 2 * averageDistance]
+
+    # for debug
+    # matchImage = cv2.drawMatchesKnn(template, templateKeyPoint, image, imageKeyPoint, good2, None, flags=2)
+    # cv2.imshow("matchImage", matchImage)
+    # cv2.waitKey(0)
+
+    # not enough matches
+    if len(good2) < 3:
+        print("not found!")
+        return template, None, None
+
+    # find the transform matrix M
+    src_pts = np.float32([templateKeyPoint[m[0].queryIdx].pt for m in good2]).reshape(-1, 1, 2)
+    dst_pts = np.float32([imageKeyPoint[m[0].trainIdx].pt for m in good2]).reshape(-1, 1, 2)
+    M, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
+    matchesMask = mask.ravel().tolist()
+    h, w, _ = template.shape
+
+    # the four corners of the matched region plus all calibration points
+    pts = np.float32(
+        [[0, 0], [0, h - 1], [w - 1, h - 1], [w - 1, 0], [startPoint[0], startPoint[1]], [endPoint[0], endPoint[1]],
+         [centerPoint[0], centerPoint[1]],
+         # [startPointUp[0], startPointUp[1]],
+         # [endPointUp[0], endPointUp[1]],
+         # [centerPointUp[0], centerPointUp[1]]
+         ]).reshape(-1, 1, 2)
+    dst = cv2.perspectiveTransform(pts, M)
+
+    # rectify the image
+    angle = 0.0
+    vector = (dst[3][0][0] - dst[0][0][0], dst[3][0][1] - dst[0][0][1])
+    cos = (vector[0] * (200.0)) / (200.0 * math.sqrt(vector[0] ** 2 + vector[1] ** 2))
+    if (vector[1] > 0):
+        angle = math.acos(cos) * 180.0 / math.pi
+    else:
+        angle = -math.acos(cos) * 180.0 / math.pi
+    # print(angle)
+
+    change = cv2.getRotationMatrix2D((dst[0][0][0], dst[0][0][1]), angle, 1)
+    src_correct = cv2.warpAffine(image, change, (image.shape[1], image.shape[0]))
+    array = np.array([[0, 0, 1]])
+    newchange = np.vstack((change, array))
+    # map the needed points into the rectified image
+    newpoints = []
+    for i in range(len(pts)):
+        point = newchange.dot(np.array([dst[i][0][0], dst[i][0][1], 1]))
+        point = list(point)
+        point.pop()
+        newpoints.append(point)
+    src_correct = src_correct[int(round(newpoints[0][1])):int(round(newpoints[1][1])),
+                  int(round(newpoints[0][0])):int(round(newpoints[3][0]))]
+    bbox = [int(round(newpoints[0][1])), int(round(newpoints[1][1])),
+            int(round(newpoints[0][0])), int(round(newpoints[3][0]))]
+
+    width = src_correct.shape[1]
+    height = src_correct.shape[0]
+    if width == 0 or height == 0:
+        return template, None, None
+
+    startPoint = (int(round(newpoints[4][0]) - newpoints[0][0]), int(round(newpoints[4][1]) - newpoints[0][1]))
+    endPoint = (int(round(newpoints[5][0]) - newpoints[0][0]), int(round(newpoints[5][1]) - newpoints[0][1]))
+    centerPoint = (int(round(newpoints[6][0]) - newpoints[0][0]), int(round(newpoints[6][1]) - newpoints[0][1]))
+
+    def isOverflow(point, width, height):
+        if point[0] < 0 or point[1] < 0 or point[0] > width - 1 or point[1] > height - 1:
+            return True
+        return False
+
+    if isOverflow(startPoint, width, height) or isOverflow(endPoint, width, height) or isOverflow(centerPoint, width,
+                                                                                                  height):
+        print("overflow!")
+        return template, None, None
+
+    info["startPoint"]["x"] = startPoint[0]
+    info["startPoint"]["y"] = startPoint[1]
+    info["centerPoint"]["x"] = centerPoint[0]
+    info["centerPoint"]["y"] = centerPoint[1]
+    info["endPoint"]["x"] = endPoint[0]
+    info["endPoint"]["y"] = endPoint[1]
+    return src_correct, change, bbox
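`meterRegionAndLocationBySIFT` follows the same match, filter, homography recipe as the existing `meterFinderBySIFT`. Reduced to a sketch below; `tpl`/`img`/`locate` are placeholder names. One caveat worth flagging: `cv2.findHomography` needs at least four correspondences, so the `len(good2) < 3` guard in the function above is one short of that minimum.

```python
# Condensed match-then-homography pipeline (Lowe's 0.8 ratio test, RANSAC).
import cv2
import numpy as np

def locate(tpl, img):
    sift = cv2.xfeatures2d.SIFT_create()       # cv2.SIFT_create() on newer OpenCV builds
    kp1, des1 = sift.detectAndCompute(tpl, None)
    kp2, des2 = sift.detectAndCompute(img, None)
    # keep a match only if it is clearly better than the runner-up
    matches = cv2.BFMatcher().knnMatch(des1, des2, k=2)
    good = [m for m, n in matches if m.distance < 0.8 * n.distance]
    if len(good) < 4:                           # findHomography needs >= 4 pairs
        return None
    src = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
    dst = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
    M, _ = cv2.findHomography(src, dst, cv2.RANSAC, 5.0)
    return M                                    # template -> image perspective map
```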
diff --git a/Algorithm/videoDigit.py b/Algorithm/videoDigit.py
index bc37ad5..639636c 100644
--- a/Algorithm/videoDigit.py
+++ b/Algorithm/videoDigit.py
@@ -1,19 +1,24 @@
-from collections import defaultdict
-
 import cv2
+import sys
+import os
+import json
+import shutil
 import numpy as np
 import torch
+from collections import defaultdict, Counter
+import copy
 
-from Algorithm.OCR.character.characterNet import characterNet
 from Algorithm.pressure.digitPressure import digitPressure
-from Algorithm.utils.Finder import meterFinderBySIFT
+from Algorithm.utils.Finder import meterFinderBySIFT, meterRegionAndLocationBySIFT
+from Algorithm.OCR.character.characterNet import characterNet
 
 
 def videoDigit(video, info):
     """
     :param video: VideoCapture Input
     :param info: info
-    :return: 
+    :return:
     """
     net = characterNet()
     net.load_state_dict(torch.load("Algorithm/OCR/character/NewNet_minLoss_model_0.965.pkl"))
@@ -23,19 +28,69 @@ def emptyLsit():
         return []
 
     imagesDict = defaultdict(emptyLsit)
-
+    box = None
+    newinfo = None
+    newchange = None
+    saveframe = []
     for i, frame in enumerate(pictures):
-        res = digitPressure(frame, info)
-        index = checkFrame(net, frame, info)
-        imagesDict[chr(index+ord('A'))] += [res]
-        # debug
+        res = digitPressure(frame, copy.deepcopy(info))
+        eachinfo = copy.deepcopy(info)
+        template, change, bbox = meterRegionAndLocationBySIFT(frame, eachinfo)
+        index, charimg = checkFrame(i, net, template, eachinfo)
+        if index < 4 and box is None:
+            box = bbox.copy()
+            newchange = change.copy()
+            newinfo = copy.deepcopy(eachinfo)
+        elif index == 4 and box is None:
+            saveframe.append(i)
+        if index < 4:
+            imagesDict[chr(index+ord('A'))] += [res]
+    if box is not None:
+        for idx in saveframe:
+            frame = pictures[idx]  # fetch the saved frame before recognizing it
+            res = digitPressure(frame, copy.deepcopy(info))
+            src_correct = cv2.warpAffine(frame, newchange, (frame.shape[1], frame.shape[0]))
+            newtemplate = src_correct[box[0]:box[1], box[2]:box[3]]
+            index, charimg = checkFrame(idx, net, newtemplate, newinfo)
+            if index < 4:
+                imagesDict[chr(index + ord('A'))] += [res]
+    # # debug
     #     title = str(chr(index+ord('A'))) + str(res)
     #     frame = cv2.putText(frame, title, (50, 100), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 255), 3)
-    #     cv2.imwrite(os.path.join("video_result", info['name']+str(i)+'.jpg'), frame)
+    #     cv2.imwrite(os.path.join(resultdir, info['name']+str(i)+'.jpg'), frame)
     #     cv2.imshow(title, frame)
     #     cv2.waitKey(0)
 
+    imagesDict = getResult(imagesDict)
     return imagesDict
 
+def getResult(dicts):
+    newdicts = {'A': [], 'B': [], 'C': [], 'D': []}
+    for ctype, res in dicts.items():
+        print("res ", res)
+        firsts = [[c for c in str(x[0])] for x in res]
+        seconds = [[c for c in str(x[1])] for x in res]
+        print("firsts ", firsts)
+        print("seconds ", seconds)
+        for x in firsts:
+            if len(x) < 6:
+                x.insert(0, ' ')
+        for x in seconds:
+            if len(x) < 6:
+                x.insert(0, ' ')
+        if len(firsts[0]) > 0 and len(seconds[0]) > 0:
+            number = ""
+            for j in range(6):
+                words = [a[j] for a in firsts]
+                num = Counter(words).most_common(1)
+                number = number + num[0][0]
+            newdicts[ctype].append(number)
+            number = ""
+            for j in range(6):
+                words = [a[j] for a in seconds]
+                num = Counter(words).most_common(1)
+                number = number + num[0][0]
+            newdicts[ctype].append(number)
+    return newdicts
 
 def getPictures(videoCapture):
     """
@@ -45,7 +100,7 @@ def getPictures(videoCapture):
     """
     pictures = []
     cnt = 0
-    skipFrameNum = 30
+    skipFrameNum = 15
     while True:
         ret, frame = videoCapture.read()
         # print(cnt, np.shape(frame))
@@ -59,21 +114,18 @@ def getPictures(videoCapture):
     videoCapture.release()
     return pictures
 
-
-def checkFrame(net, image, info):
+def checkFrame(count, net, template, info):
     """
     判断图片的类型A,AB等
     :param image: image
     :param info: info
    :return: 出现次序0.1.2.3
     """
-    start = ([info["startPoint"]["x"], info["startPoint"]["y"]])
-    end = ([info["endPoint"]["x"], info["endPoint"]["y"]])
-    center = ([info["centerPoint"]["x"], info["centerPoint"]["y"]])
+    start = [info["startPoint"]["x"], info["startPoint"]["y"]]
+    end = [info["endPoint"]["x"], info["endPoint"]["y"]]
+    center = [info["centerPoint"]["x"], info["centerPoint"]["y"]]
     width = info["rectangle"]["width"]
     height = info["rectangle"]["height"]
-    widthSplit = info["widthSplit"]
-    heightSplit = info["heightSplit"]
     characSplit = info["characSplit"]
 
     # 计算数字表的矩形外框,并且拉直矫正
@@ -82,7 +134,7 @@ def checkFrame(count, net, template, info):
     pts2 = np.float32([[0, 0], [width, 0], [width, height], [0, height]])
     M = cv2.getPerspectiveTransform(pts1, pts2)
 
-    template = meterFinderBySIFT(image, info)
+    # oritemplate = template.copy()
     template = cv2.cvtColor(template, cv2.COLOR_BGR2GRAY)
     template = cv2.equalizeHist(template)
     dst = cv2.warpPerspective(template, M, (width, height))
@@ -91,7 +143,7 @@ def checkFrame(count, net, template, info):
     imgType = cv2.adaptiveThreshold(imgType, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 17, 11)
     imgType = cv2.bitwise_not(imgType)
 
-    # orimage = imgType.copy()
+    orimage = imgType.copy()
 
     imgType = torch.Tensor(np.array(imgType, dtype=np.float32))
     imgType = torch.unsqueeze(imgType, 0)
@@ -101,9 +153,47 @@ def checkFrame(count, net, template, info):
     maxIndex = np.argmax(type_probe)
 
     # debug
+    # oritemplate = drawTemplatePoints(template, info)
+    # dst = drawDstPoints(dst, info)
+    # cv2.imshow("M", M)
+    # cv2.imshow("template ", oritemplate)
+    # cv2.imshow("des ", dst)
     # cv2.imshow(str(chr(maxIndex+ord('A'))), orimage)
-    # type = str(chr(maxIndex+ord('A')))
-    # cv2.imwrite(os.path.join("video_result", type+info['name']+'_'+str(type_probe)+'.jpg'), orimage)
+    # kind = str(chr(maxIndex+ord('A')))
+    # cv2.imwrite(os.path.join(resultdir, info['name'] + '_' + str(count) + '_' + kind + '.bmp'), orimage)
     # cv2.waitKey(0)
 
-    return maxIndex
+    return maxIndex, orimage
+
+def drawTemplatePoints(template, info):
+    start = (info["startPoint"]["x"], info["startPoint"]["y"])
+    end = (info["endPoint"]["x"], info["endPoint"]["y"])
+    center = (info["centerPoint"]["x"], info["centerPoint"]["y"])
+    cv2.circle(template, start, 2, (255, 0, 0), 3)
+    cv2.circle(template, end, 2, (255, 0, 0), 3)
+    cv2.circle(template, center, 2, (255, 0, 0), 3)
+    return template
+
+def drawDstPoints(dst, info):
+    widthSplit = info["widthSplit"]
+    heightSplit = info["heightSplit"]
+    characSplit = info["characSplit"]
+    pts = []
+    pts.append((characSplit[0][0], characSplit[1][0]))
+    pts.append((characSplit[0][0], characSplit[1][1]))
+    pts.append((characSplit[0][1], characSplit[1][0]))
+    pts.append((characSplit[0][1], characSplit[1][1]))
+    for x in widthSplit[0]:
+        for j in range(len(heightSplit)):
+            for y in heightSplit[j]:
+                pts.append((x, y))
+    for point in pts:
+        cv2.circle(dst, point, 1, (255, 0, 0), 3)
+    return dst
+
+
+# resultdir = "/Users/yuanyuan/Documents/GitHub/meterReader/result_character"
+# if os.path.exists(resultdir):
+#     shutil.rmtree(resultdir)
+# if not os.path.exists(resultdir):
+#     os.makedirs(resultdir)
diff --git a/TestServiceSample.py b/TestServiceSample.py
index 2cbf7f7..1d5967a 100644
--- a/TestServiceSample.py
+++ b/TestServiceSample.py
@@ -88,4 +88,4 @@ def testVideo():
 
     #
    # codecov("info/20190416/IMAGES/image")
-    # testVideo()
+    # testVideo()
\ No newline at end of file
diff --git a/configuration.py b/configuration.py
index c168d77..fa3ba47 100644
--- a/configuration.py
+++ b/configuration.py
@@ -11,6 +11,4 @@
 # templatePath = "info/20190416/template"
 
 configPath = "info/20190514/config"
-templatePath = "info/20190514/template"
-
-
+templatePath = "info/20190514/template"
\ No newline at end of file
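Finally, the per-digit majority vote in `getResult` (videoDigit.py above) boils down to: pad every reading of one display to the same length, then vote each character column independently across the repeated reads. A compact illustration with `Counter`; the sample readings are made up.

```python
# Column-wise majority vote over repeated readings of the same display.
from collections import Counter

readings = ["123.45", "123.46", "128.45"]      # repeated reads of one display
columns = zip(*[list(r) for r in readings])    # transpose into character columns
voted = "".join(Counter(col).most_common(1)[0][0] for col in columns)
print(voted)  # -> "123.45"
```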