@@ -394,6 +394,46 @@ TEST(torchapi, service_predict_object_detection)
394
394
ASSERT_EQ (preds_best.Size (), 3 );
395
395
}
396
396
397
+ TEST (torchapi, service_predict_object_detection_any_size)
398
+ {
399
+ JsonAPI japi;
400
+ std::string sname = " detectserv" ;
401
+ std::string jstr
402
+ = " {\" mllib\" :\" torch\" ,\" description\" :\" fasterrcnn\" ,\" type\" :"
403
+ " \" supervised\" ,\" model\" :{\" repository\" :\" "
404
+ + detect_repo
405
+ + " \" },\" parameters\" :{\" input\" :{\" connector\" :\" image\" ,\" height\" :"
406
+ " -1,\" width\" :-1,\" rgb\" :true,\" scale\" :0.0039},\" mllib\" :{"
407
+ " \" template\" :\" fasterrcnn\" }}}" ;
408
+
409
+ std::string joutstr = japi.jrender (japi.service_create (sname, jstr));
410
+ ASSERT_EQ (created_str, joutstr);
411
+ std::string jpredictstr
412
+ = " {\" service\" :\" detectserv\" ,\" parameters\" :{"
413
+ " \" input\" :{\" height\" :-1,"
414
+ " \" width\" :-1},\" output\" :{\" bbox\" :true, "
415
+ " \" best_bbox\" :1,\" confidence_threshold\" :0.8}},\" data\" :[\" "
416
+ + detect_train_repo_fasterrcnn + " /imgs/000550-L.jpg\" ]}" ;
417
+
418
+ joutstr = japi.jrender (japi.service_predict (jpredictstr));
419
+ JDoc jd;
420
+ std::cout << " joutstr=" << joutstr << std::endl;
421
+ jd.Parse <rapidjson::kParseNanAndInfFlag >(joutstr.c_str ());
422
+ ASSERT_TRUE (!jd.HasParseError ());
423
+ ASSERT_EQ (200 , jd[" status" ][" code" ]);
424
+ ASSERT_TRUE (jd[" body" ][" predictions" ].IsArray ());
425
+
426
+ auto &preds = jd[" body" ][" predictions" ][0 ][" classes" ];
427
+ std::string cl1 = preds[0 ][" cat" ].GetString ();
428
+ ASSERT_TRUE (cl1 == " car" );
429
+ ASSERT_TRUE (preds[0 ][" prob" ].GetDouble () > 0.9 );
430
+ auto &bbox = preds[0 ][" bbox" ];
431
+ ASSERT_NEAR (bbox[" xmin" ].GetDouble (), 258.0 , 5.0 );
432
+ ASSERT_NEAR (bbox[" ymin" ].GetDouble (), 333.0 , 5.0 );
433
+ ASSERT_NEAR (bbox[" xmax" ].GetDouble (), 401.0 , 5.0 );
434
+ ASSERT_NEAR (bbox[" ymax" ].GetDouble (), 448.0 , 5.0 );
435
+ }
436
+
397
437
TEST (torchapi, service_predict_segmentation)
398
438
{
399
439
JsonAPI japi;
@@ -2748,6 +2788,107 @@ TEST(torchapi, service_train_object_detection_translation)
2748
2788
fileops::remove_dir (detect_train_repo_yolox + " test_0.lmdb" );
2749
2789
}
2750
2790
2791
+ TEST (torchapi, service_train_object_detection_yolox_any_size)
2792
+ {
2793
+ // Test with arbitrary image size: width = -1, height = -1
2794
+ setenv (" CUBLAS_WORKSPACE_CONFIG" , " :4096:8" , true );
2795
+ torch::manual_seed (torch_seed);
2796
+ at::globalContext ().setDeterministicCuDNN (true );
2797
+
2798
+ JsonAPI japi;
2799
+ std::string sname = " detectserv" ;
2800
+ std::string jstr
2801
+ = " {\" mllib\" :\" torch\" ,\" description\" :\" yolox\" ,\" type\" :"
2802
+ " \" supervised\" ,\" model\" :{\" repository\" :\" "
2803
+ + detect_train_repo_yolox
2804
+ + " \" },\" parameters\" :{\" input\" :{\" connector\" :\" image\" ,\" height\" :"
2805
+ " -1,\" width\" :-1,\" rgb\" :true,\" bbox\" :true,\" db\" :true},"
2806
+ " \" mllib\" :{\" template\" :\" yolox\" ,\" gpu\" :true,"
2807
+ " \" nclasses\" :2}}}" ;
2808
+
2809
+ std::string joutstr = japi.jrender (japi.service_create (sname, jstr));
2810
+ ASSERT_EQ (created_str, joutstr);
2811
+
2812
+ // Train
2813
+ std::string jtrainstr
2814
+ = " {\" service\" :\" detectserv\" ,\" async\" :false,\" parameters\" :{"
2815
+ " \" mllib\" :{\" solver\" :{\" iterations\" :3"
2816
+ + std::string (" " )
2817
+ // + iterations_detection + ",\"base_lr\":" + torch_lr
2818
+ + " ,\" iter_size\" :2,\" solver_"
2819
+ " type\" :\" ADAM\" ,\" test_interval\" :200},\" net\" :{\" batch_size\" :2,"
2820
+ " \" test_batch_size\" :1,\" reg_weight\" :0.5},\" resume\" :false,"
2821
+ " \" mirror\" :true,\" rotate\" :true,\" crop_size\" :512,"
2822
+ " \" test_crop_samples\" :10,"
2823
+ " \" cutout\" :0.1,\" geometry\" :{\" prob\" :0.1,\" persp_horizontal\" :"
2824
+ " true,\" persp_vertical\" :true,\" zoom_in\" :true,\" zoom_out\" :true,"
2825
+ " \" pad_mode\" :\" constant\" },\" noise\" :{\" prob\" :0.01},\" distort\" :{"
2826
+ " \" prob\" :0.01}},\" input\" :{\" seed\" :12347,\" db\" :true,"
2827
+ " \" shuffle\" :true},\" output\" :{\" measure\" :[\" map-05\" ,\" map-50\" ,"
2828
+ " \" map-90\" ]}},\" data\" :[\" "
2829
+ + fasterrcnn_train_data + " \" ,\" " + fasterrcnn_test_data + " \" ]}" ;
2830
+
2831
+ joutstr = japi.jrender (japi.service_train (jtrainstr));
2832
+ JDoc jd;
2833
+ std::cout << " joutstr=" << joutstr << std::endl;
2834
+ jd.Parse <rapidjson::kParseNanAndInfFlag >(joutstr.c_str ());
2835
+ ASSERT_TRUE (!jd.HasParseError ());
2836
+ ASSERT_EQ (201 , jd[" status" ][" code" ]);
2837
+
2838
+ // ASSERT_EQ(jd["body"]["measure"]["iteration"], 200) << "iterations";
2839
+ ASSERT_TRUE (jd[" body" ][" measure" ][" map" ].GetDouble () <= 1.0 ) << " map" ;
2840
+ ASSERT_TRUE (jd[" body" ][" measure" ][" map-05" ].GetDouble () <= 1.0 ) << " map-05" ;
2841
+ ASSERT_TRUE (jd[" body" ][" measure" ][" map-50" ].GetDouble () <= 1.0 ) << " map-50" ;
2842
+ ASSERT_TRUE (jd[" body" ][" measure" ][" map-90" ].GetDouble () <= 1.0 ) << " map-90" ;
2843
+ // ASSERT_TRUE(jd["body"]["measure"]["map"].GetDouble() > 0.0) << "map";
2844
+
2845
+ // check metrics
2846
+ auto &meas = jd[" body" ][" measure" ];
2847
+ ASSERT_TRUE (meas.HasMember (" iou_loss" ));
2848
+ ASSERT_TRUE (meas.HasMember (" conf_loss" ));
2849
+ ASSERT_TRUE (meas.HasMember (" cls_loss" ));
2850
+ ASSERT_TRUE (meas.HasMember (" l1_loss" ));
2851
+ ASSERT_TRUE (meas.HasMember (" train_loss" ));
2852
+ ASSERT_TRUE (
2853
+ std::abs (meas[" train_loss" ].GetDouble ()
2854
+ - (meas[" iou_loss" ].GetDouble () * 0.5
2855
+ + meas[" cls_loss" ].GetDouble () + meas[" l1_loss" ].GetDouble ()
2856
+ + meas[" conf_loss" ].GetDouble ()))
2857
+ < 0.0001 );
2858
+
2859
+ // check that predict works fine
2860
+ std::string jpredictstr = " {\" service\" :\" detectserv\" ,\" parameters\" :{"
2861
+ " \" input\" :{\" height\" :-1,"
2862
+ " \" width\" :-1},\" output\" :{\" bbox\" :true, "
2863
+ " \" confidence_threshold\" :0.8}},\" data\" :[\" "
2864
+ + detect_train_repo_fasterrcnn
2865
+ + " /imgs/000550-L.jpg\" ]}" ;
2866
+ joutstr = japi.jrender (japi.service_predict (jpredictstr));
2867
+ jd = JDoc ();
2868
+ std::cout << " joutstr=" << joutstr << std::endl;
2869
+ jd.Parse <rapidjson::kParseNanAndInfFlag >(joutstr.c_str ());
2870
+ ASSERT_TRUE (!jd.HasParseError ());
2871
+ ASSERT_EQ (200 , jd[" status" ][" code" ]);
2872
+
2873
+ std::unordered_set<std::string> lfiles;
2874
+ fileops::list_directory (detect_train_repo_yolox, true , false , false , lfiles);
2875
+ for (std::string ff : lfiles)
2876
+ {
2877
+ if (ff.find (" checkpoint" ) != std::string::npos
2878
+ || ff.find (" solver" ) != std::string::npos)
2879
+ remove (ff.c_str ());
2880
+ }
2881
+ ASSERT_TRUE (!fileops::file_exists (detect_train_repo_yolox + " checkpoint-"
2882
+ + iterations_detection + " .ptw" ));
2883
+ ASSERT_TRUE (!fileops::file_exists (detect_train_repo_yolox + " checkpoint-"
2884
+ + iterations_detection + " .pt" ));
2885
+
2886
+ fileops::clear_directory (detect_train_repo_yolox + " train.lmdb" );
2887
+ fileops::clear_directory (detect_train_repo_yolox + " test_0.lmdb" );
2888
+ fileops::remove_dir (detect_train_repo_yolox + " train.lmdb" );
2889
+ fileops::remove_dir (detect_train_repo_yolox + " test_0.lmdb" );
2890
+ }
2891
+
2751
2892
TEST (torchapi, service_train_images_native)
2752
2893
{
2753
2894
setenv (" CUBLAS_WORKSPACE_CONFIG" , " :4096:8" , true );
0 commit comments