Skip to content

Commit e84c616

Browse files
Bycobmergify[bot]
authored andcommittedMay 15, 2023
fix: no resize when training with images
1 parent 20d8ebe commit e84c616

File tree

6 files changed

+188
-71
lines changed

6 files changed

+188
-71
lines changed
 

‎src/backends/torch/torchdataaug.cc

+19-10
Original file line numberDiff line numberDiff line change
@@ -395,17 +395,24 @@ namespace dd
395395

396396
if (sample)
397397
{
398+
int img_width = src.cols;
399+
int img_height = src.rows;
400+
std::uniform_int_distribution<int> uniform_int_crop_x(
401+
0, img_width - cp._crop_size);
402+
std::uniform_int_distribution<int> uniform_int_crop_y(
403+
0, img_height - cp._crop_size);
404+
398405
#pragma omp critical
399406
{
400407
if (test)
401408
{
402-
crop_x = cp._uniform_int_crop_x(_rnd_test_gen);
403-
crop_y = cp._uniform_int_crop_y(_rnd_test_gen);
409+
crop_x = uniform_int_crop_x(_rnd_test_gen);
410+
crop_y = uniform_int_crop_y(_rnd_test_gen);
404411
}
405412
else
406413
{
407-
crop_x = cp._uniform_int_crop_x(_rnd_gen);
408-
crop_y = cp._uniform_int_crop_y(_rnd_gen);
414+
crop_x = uniform_int_crop_x(_rnd_gen);
415+
crop_y = uniform_int_crop_y(_rnd_gen);
409416
}
410417
}
411418
}
@@ -464,20 +471,22 @@ namespace dd
464471

465472
#pragma omp critical
466473
{
474+
int img_width = src.cols;
475+
int img_height = src.rows;
467476
// get shape and area to erase
468477
int w = 0, h = 0, rect_x = 0, rect_y = 0;
469478
if (cp._w == 0 && cp._h == 0)
470479
{
471-
float s = cp._uniform_real_cutout_s(_rnd_gen) * cp._img_width
472-
* cp._img_height; // area
480+
float s = cp._uniform_real_cutout_s(_rnd_gen) * img_width
481+
* img_height; // area
473482
float r = cp._uniform_real_cutout_r(_rnd_gen); // aspect ratio
474483

475-
w = std::min(cp._img_width,
484+
w = std::min(img_width,
476485
static_cast<int>(std::floor(std::sqrt(s / r))));
477-
h = std::min(cp._img_height,
486+
h = std::min(img_height,
478487
static_cast<int>(std::floor(std::sqrt(s * r))));
479-
std::uniform_int_distribution<int> distx(0, cp._img_width - w);
480-
std::uniform_int_distribution<int> disty(0, cp._img_height - h);
488+
std::uniform_int_distribution<int> distx(0, img_width - w);
489+
std::uniform_int_distribution<int> disty(0, img_height - h);
481490
rect_x = distx(_rnd_gen);
482491
rect_y = disty(_rnd_gen);
483492
}

‎src/backends/torch/torchdataaug.h

+6-44
Original file line numberDiff line numberDiff line change
@@ -33,67 +33,34 @@
3333

3434
namespace dd
3535
{
36-
class ImgAugParams
36+
class CropParams
3737
{
3838
public:
39-
ImgAugParams() : _img_width(224), _img_height(224)
39+
CropParams()
4040
{
4141
}
4242

43-
ImgAugParams(const int &img_width, const int &img_height)
44-
: _img_width(img_width), _img_height(img_height)
43+
CropParams(const int &crop_size) : _crop_size(crop_size)
4544
{
4645
}
4746

48-
~ImgAugParams()
49-
{
50-
}
51-
52-
int _img_width;
53-
int _img_height;
54-
};
55-
56-
class CropParams : public ImgAugParams
57-
{
58-
public:
59-
CropParams() : ImgAugParams()
60-
{
61-
}
62-
63-
CropParams(const int &crop_size, const int &img_width,
64-
const int &img_height)
65-
: ImgAugParams(img_width, img_height), _crop_size(crop_size)
66-
{
67-
if (_crop_size > 0)
68-
{
69-
_uniform_int_crop_x
70-
= std::uniform_int_distribution<int>(0, _img_width - _crop_size);
71-
_uniform_int_crop_y = std::uniform_int_distribution<int>(
72-
0, _img_height - _crop_size);
73-
}
74-
}
75-
7647
~CropParams()
7748
{
7849
}
7950

8051
// default params
8152
int _crop_size = -1;
82-
std::uniform_int_distribution<int> _uniform_int_crop_x;
83-
std::uniform_int_distribution<int> _uniform_int_crop_y;
8453
int _test_crop_samples = 1; /**< number of sampled crops (at test time). */
8554
};
8655

87-
class CutoutParams : public ImgAugParams
56+
class CutoutParams
8857
{
8958
public:
90-
CutoutParams() : ImgAugParams()
59+
CutoutParams()
9160
{
9261
}
9362

94-
CutoutParams(const float &prob, const int &img_width,
95-
const int &img_height)
96-
: ImgAugParams(img_width, img_height), _prob(prob)
63+
CutoutParams(const float &prob) : _prob(prob)
9764
{
9865
_uniform_real_cutout_s
9966
= std::uniform_real_distribution<float>(_cutout_sl, _cutout_sh);
@@ -287,11 +254,6 @@ namespace dd
287254
_uniform_real_1(0.0, 1.0), _bernouilli(0.5),
288255
_uniform_int_rotate(0, 3)
289256
{
290-
if (_crop_params._crop_size > 0)
291-
{
292-
_cutout_params._img_width = _crop_params._crop_size;
293-
_cutout_params._img_height = _crop_params._crop_size;
294-
}
295257
reset_rnd_test_gen();
296258
}
297259

‎src/backends/torch/torchdataset.cc

+8-5
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ namespace dd
227227
torch::load(targett, targetstream);
228228
}
229229

230-
if (bgr.cols != width || bgr.rows != height)
230+
if (width > 0 && height > 0 && (bgr.cols != width || bgr.rows != height))
231231
{
232232
cv::resize(bgr, bgr, cv::Size(width, height), 0, 0, cv::INTER_CUBIC);
233233

@@ -860,10 +860,13 @@ namespace dd
860860

861861
std::ifstream infile(bboxfname);
862862
std::string line;
863-
double wfactor = static_cast<double>(inputc->_width)
864-
/ static_cast<double>(orig_width);
865-
double hfactor = static_cast<double>(inputc->_height)
866-
/ static_cast<double>(orig_height);
863+
double wfactor = inputc->_width > 0 ? static_cast<double>(inputc->_width)
864+
/ static_cast<double>(orig_width)
865+
: 1;
866+
double hfactor = inputc->_height > 0
867+
? static_cast<double>(inputc->_height)
868+
/ static_cast<double>(orig_height)
869+
: 1;
867870

868871
while (std::getline(infile, line))
869872
{

‎src/backends/torch/torchlib.cc

+10-8
Original file line numberDiff line numberDiff line change
@@ -701,8 +701,7 @@ namespace dd
701701
if (ad_mllib.has("crop_size"))
702702
{
703703
int crop_size = ad_mllib.get("crop_size").get<int>();
704-
crop_params
705-
= CropParams(crop_size, inputc.width(), inputc.height());
704+
crop_params = CropParams(crop_size);
706705
if (ad_mllib.has("test_crop_samples"))
707706
crop_params._test_crop_samples
708707
= ad_mllib.get("test_crop_samples").get<int>();
@@ -712,8 +711,7 @@ namespace dd
712711
if (ad_mllib.has("cutout"))
713712
{
714713
float cutout = ad_mllib.get("cutout").get<double>();
715-
cutout_params
716-
= CutoutParams(cutout, inputc.width(), inputc.height());
714+
cutout_params = CutoutParams(cutout);
717715
this->_logger->info("cutout: {}", cutout);
718716
}
719717
GeometryParams geometry_params;
@@ -1640,6 +1638,10 @@ namespace dd
16401638
throw MLLibInternalException(
16411639
"Couldn't find original image size for " + uri);
16421640
}
1641+
int src_width
1642+
= inputc.width() > 0 ? inputc.width() : cols - 1;
1643+
int src_height
1644+
= inputc.height() > 0 ? inputc.height() : rows - 1;
16431645

16441646
APIData results_ad;
16451647
std::vector<double> probs;
@@ -1676,10 +1678,10 @@ namespace dd
16761678
this->_mlmodel.get_hcorresp(labels_acc[j]));
16771679

16781680
double bbox[] = {
1679-
bboxes_acc[j][0] / inputc.width() * (cols - 1),
1680-
bboxes_acc[j][1] / inputc.height() * (rows - 1),
1681-
bboxes_acc[j][2] / inputc.width() * (cols - 1),
1682-
bboxes_acc[j][3] / inputc.height() * (rows - 1),
1681+
bboxes_acc[j][0] / src_width * (cols - 1),
1682+
bboxes_acc[j][1] / src_height * (rows - 1),
1683+
bboxes_acc[j][2] / src_width * (cols - 1),
1684+
bboxes_acc[j][3] / src_height * (rows - 1),
16831685
};
16841686

16851687
// clamp bbox

‎src/imginputfileconn.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@ namespace dd
100100
{
101101
if (_scaled)
102102
scale(src, dst);
103-
else if (_width == 0 || _height == 0)
103+
else if (_width < 0 || _height < 0)
104104
{
105-
if (_width == 0 && _height == 0)
105+
if (_width < 0 && _height < 0)
106106
{
107107
// Do nothing and keep native resolution. May cause issues if
108108
// batched images are different resolutions
@@ -199,9 +199,9 @@ namespace dd
199199
{
200200
if (_scaled)
201201
scale_cuda(src, dst);
202-
else if (_width == 0 || _height == 0)
202+
else if (_width < 0 || _height < 0)
203203
{
204-
if (_width == 0 && _height == 0)
204+
if (_width < 0 && _height < 0)
205205
{
206206
// Do nothing and keep native resolution. May cause issues if
207207
// batched images are different resolutions

‎tests/ut-torchapi.cc

+141
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,46 @@ TEST(torchapi, service_predict_object_detection)
394394
ASSERT_EQ(preds_best.Size(), 3);
395395
}
396396

397+
TEST(torchapi, service_predict_object_detection_any_size)
398+
{
399+
JsonAPI japi;
400+
std::string sname = "detectserv";
401+
std::string jstr
402+
= "{\"mllib\":\"torch\",\"description\":\"fasterrcnn\",\"type\":"
403+
"\"supervised\",\"model\":{\"repository\":\""
404+
+ detect_repo
405+
+ "\"},\"parameters\":{\"input\":{\"connector\":\"image\",\"height\":"
406+
"-1,\"width\":-1,\"rgb\":true,\"scale\":0.0039},\"mllib\":{"
407+
"\"template\":\"fasterrcnn\"}}}";
408+
409+
std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
410+
ASSERT_EQ(created_str, joutstr);
411+
std::string jpredictstr
412+
= "{\"service\":\"detectserv\",\"parameters\":{"
413+
"\"input\":{\"height\":-1,"
414+
"\"width\":-1},\"output\":{\"bbox\":true, "
415+
"\"best_bbox\":1,\"confidence_threshold\":0.8}},\"data\":[\""
416+
+ detect_train_repo_fasterrcnn + "/imgs/000550-L.jpg\"]}";
417+
418+
joutstr = japi.jrender(japi.service_predict(jpredictstr));
419+
JDoc jd;
420+
std::cout << "joutstr=" << joutstr << std::endl;
421+
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
422+
ASSERT_TRUE(!jd.HasParseError());
423+
ASSERT_EQ(200, jd["status"]["code"]);
424+
ASSERT_TRUE(jd["body"]["predictions"].IsArray());
425+
426+
auto &preds = jd["body"]["predictions"][0]["classes"];
427+
std::string cl1 = preds[0]["cat"].GetString();
428+
ASSERT_TRUE(cl1 == "car");
429+
ASSERT_TRUE(preds[0]["prob"].GetDouble() > 0.9);
430+
auto &bbox = preds[0]["bbox"];
431+
ASSERT_NEAR(bbox["xmin"].GetDouble(), 258.0, 5.0);
432+
ASSERT_NEAR(bbox["ymin"].GetDouble(), 333.0, 5.0);
433+
ASSERT_NEAR(bbox["xmax"].GetDouble(), 401.0, 5.0);
434+
ASSERT_NEAR(bbox["ymax"].GetDouble(), 448.0, 5.0);
435+
}
436+
397437
TEST(torchapi, service_predict_segmentation)
398438
{
399439
JsonAPI japi;
@@ -2748,6 +2788,107 @@ TEST(torchapi, service_train_object_detection_translation)
27482788
fileops::remove_dir(detect_train_repo_yolox + "test_0.lmdb");
27492789
}
27502790

2791+
TEST(torchapi, service_train_object_detection_yolox_any_size)
2792+
{
2793+
// Test with arbitrary image size: width = -1, height = -1
2794+
setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
2795+
torch::manual_seed(torch_seed);
2796+
at::globalContext().setDeterministicCuDNN(true);
2797+
2798+
JsonAPI japi;
2799+
std::string sname = "detectserv";
2800+
std::string jstr
2801+
= "{\"mllib\":\"torch\",\"description\":\"yolox\",\"type\":"
2802+
"\"supervised\",\"model\":{\"repository\":\""
2803+
+ detect_train_repo_yolox
2804+
+ "\"},\"parameters\":{\"input\":{\"connector\":\"image\",\"height\":"
2805+
"-1,\"width\":-1,\"rgb\":true,\"bbox\":true,\"db\":true},"
2806+
"\"mllib\":{\"template\":\"yolox\",\"gpu\":true,"
2807+
"\"nclasses\":2}}}";
2808+
2809+
std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
2810+
ASSERT_EQ(created_str, joutstr);
2811+
2812+
// Train
2813+
std::string jtrainstr
2814+
= "{\"service\":\"detectserv\",\"async\":false,\"parameters\":{"
2815+
"\"mllib\":{\"solver\":{\"iterations\":3"
2816+
+ std::string("")
2817+
//+ iterations_detection + ",\"base_lr\":" + torch_lr
2818+
+ ",\"iter_size\":2,\"solver_"
2819+
"type\":\"ADAM\",\"test_interval\":200},\"net\":{\"batch_size\":2,"
2820+
"\"test_batch_size\":1,\"reg_weight\":0.5},\"resume\":false,"
2821+
"\"mirror\":true,\"rotate\":true,\"crop_size\":512,"
2822+
"\"test_crop_samples\":10,"
2823+
"\"cutout\":0.1,\"geometry\":{\"prob\":0.1,\"persp_horizontal\":"
2824+
"true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
2825+
"\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{"
2826+
"\"prob\":0.01}},\"input\":{\"seed\":12347,\"db\":true,"
2827+
"\"shuffle\":true},\"output\":{\"measure\":[\"map-05\",\"map-50\","
2828+
"\"map-90\"]}},\"data\":[\""
2829+
+ fasterrcnn_train_data + "\",\"" + fasterrcnn_test_data + "\"]}";
2830+
2831+
joutstr = japi.jrender(japi.service_train(jtrainstr));
2832+
JDoc jd;
2833+
std::cout << "joutstr=" << joutstr << std::endl;
2834+
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
2835+
ASSERT_TRUE(!jd.HasParseError());
2836+
ASSERT_EQ(201, jd["status"]["code"]);
2837+
2838+
// ASSERT_EQ(jd["body"]["measure"]["iteration"], 200) << "iterations";
2839+
ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() <= 1.0) << "map";
2840+
ASSERT_TRUE(jd["body"]["measure"]["map-05"].GetDouble() <= 1.0) << "map-05";
2841+
ASSERT_TRUE(jd["body"]["measure"]["map-50"].GetDouble() <= 1.0) << "map-50";
2842+
ASSERT_TRUE(jd["body"]["measure"]["map-90"].GetDouble() <= 1.0) << "map-90";
2843+
// ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() > 0.0) << "map";
2844+
2845+
// check metrics
2846+
auto &meas = jd["body"]["measure"];
2847+
ASSERT_TRUE(meas.HasMember("iou_loss"));
2848+
ASSERT_TRUE(meas.HasMember("conf_loss"));
2849+
ASSERT_TRUE(meas.HasMember("cls_loss"));
2850+
ASSERT_TRUE(meas.HasMember("l1_loss"));
2851+
ASSERT_TRUE(meas.HasMember("train_loss"));
2852+
ASSERT_TRUE(
2853+
std::abs(meas["train_loss"].GetDouble()
2854+
- (meas["iou_loss"].GetDouble() * 0.5
2855+
+ meas["cls_loss"].GetDouble() + meas["l1_loss"].GetDouble()
2856+
+ meas["conf_loss"].GetDouble()))
2857+
< 0.0001);
2858+
2859+
// check that predict works fine
2860+
std::string jpredictstr = "{\"service\":\"detectserv\",\"parameters\":{"
2861+
"\"input\":{\"height\":-1,"
2862+
"\"width\":-1},\"output\":{\"bbox\":true, "
2863+
"\"confidence_threshold\":0.8}},\"data\":[\""
2864+
+ detect_train_repo_fasterrcnn
2865+
+ "/imgs/000550-L.jpg\"]}";
2866+
joutstr = japi.jrender(japi.service_predict(jpredictstr));
2867+
jd = JDoc();
2868+
std::cout << "joutstr=" << joutstr << std::endl;
2869+
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
2870+
ASSERT_TRUE(!jd.HasParseError());
2871+
ASSERT_EQ(200, jd["status"]["code"]);
2872+
2873+
std::unordered_set<std::string> lfiles;
2874+
fileops::list_directory(detect_train_repo_yolox, true, false, false, lfiles);
2875+
for (std::string ff : lfiles)
2876+
{
2877+
if (ff.find("checkpoint") != std::string::npos
2878+
|| ff.find("solver") != std::string::npos)
2879+
remove(ff.c_str());
2880+
}
2881+
ASSERT_TRUE(!fileops::file_exists(detect_train_repo_yolox + "checkpoint-"
2882+
+ iterations_detection + ".ptw"));
2883+
ASSERT_TRUE(!fileops::file_exists(detect_train_repo_yolox + "checkpoint-"
2884+
+ iterations_detection + ".pt"));
2885+
2886+
fileops::clear_directory(detect_train_repo_yolox + "train.lmdb");
2887+
fileops::clear_directory(detect_train_repo_yolox + "test_0.lmdb");
2888+
fileops::remove_dir(detect_train_repo_yolox + "train.lmdb");
2889+
fileops::remove_dir(detect_train_repo_yolox + "test_0.lmdb");
2890+
}
2891+
27512892
TEST(torchapi, service_train_images_native)
27522893
{
27532894
setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);

0 commit comments

Comments
 (0)