From e6f7a2205c06e703c6b22bdcc5e1f248823c2a2e Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Sat, 7 Dec 2024 04:12:12 -0500 Subject: [PATCH] test: add arrow datasets and arrow unit tests (#403) Signed-off-by: Will Johnson --- tests/artifacts/testdata/__init__.py | 7 +++ .../twitter_complaints_input_output.arrow | Bin 0 -> 13858 bytes .../testdata/twitter_complaints_small.arrow | Bin 0 -> 3930 bytes ..._tokenized_with_maykeye_tinyllama_v0.arrow | Bin 0 -> 11466 bytes tests/data/test_data_preprocessing_utils.py | 41 ++++++++++++++++++ tuning/utils/utils.py | 2 + 6 files changed, 50 insertions(+) create mode 100644 tests/artifacts/testdata/twitter_complaints_input_output.arrow create mode 100644 tests/artifacts/testdata/twitter_complaints_small.arrow create mode 100644 tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow diff --git a/tests/artifacts/testdata/__init__.py b/tests/artifacts/testdata/__init__.py index 8b6a7ea43..39895f6f1 100644 --- a/tests/artifacts/testdata/__init__.py +++ b/tests/artifacts/testdata/__init__.py @@ -22,6 +22,7 @@ PARQUET_DATA_DIR = os.path.join(os.path.dirname(__file__), "parquet") TWITTER_COMPLAINTS_DATA_JSON = os.path.join(DATA_DIR, "twitter_complaints_small.json") TWITTER_COMPLAINTS_DATA_JSONL = os.path.join(DATA_DIR, "twitter_complaints_small.jsonl") +TWITTER_COMPLAINTS_DATA_ARROW = os.path.join(DATA_DIR, "twitter_complaints_small.arrow") TWITTER_COMPLAINTS_DATA_PARQUET = os.path.join( PARQUET_DATA_DIR, "twitter_complaints_small.parquet" ) @@ -31,6 +32,9 @@ TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL = os.path.join( DATA_DIR, "twitter_complaints_input_output.jsonl" ) +TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW = os.path.join( + DATA_DIR, "twitter_complaints_input_output.arrow" +) TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET = os.path.join( PARQUET_DATA_DIR, "twitter_complaints_input_output.parquet" ) @@ -40,6 +44,9 @@ TWITTER_COMPLAINTS_TOKENIZED_JSONL = os.path.join( DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl" ) +TWITTER_COMPLAINTS_TOKENIZED_ARROW = os.path.join( + DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow" +) TWITTER_COMPLAINTS_TOKENIZED_PARQUET = os.path.join( PARQUET_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.parquet" ) diff --git a/tests/artifacts/testdata/twitter_complaints_input_output.arrow b/tests/artifacts/testdata/twitter_complaints_input_output.arrow new file mode 100644 index 0000000000000000000000000000000000000000..602798d34df874ea5e4821f830eb04d37812a68e GIT binary patch literal 13858 zcmeI3U5p&rRmXcjve_6DoXsW>2yka-H{OMrw%7J98!M7)Jf88iJzvbs+UvY{Z&%&! zE_ZcRxo*|8+fR$c10o3JAw&d4D7>J30F;LaIe&_)vf5*nEbd%~%o)axHTaePO8MPFe$R3&?P%lUl?0Ev(Hc)Y z>G=L3vM$>Z&X4kLDdYd(qO9#ybyvEd-F3Y8@SW+{dUo3s~e z2Fuf(^ZNeRKUUqSy867H?{_|4-grR-N&BiOV^^!s>*wG4M0vBj>p79%?sdu&cY7{l zdPXEqKnjDn41EXs(9bj)*PyRJM~*fczXg34@=#n3{SZ3yVQd5X8|d+$!yceM z^bgQ93jYH1HuM8%86DPo^p;jO=KaFv$-rb*ttKk(g;w!zFjAc`v$?vFw7tt&Xmh=y zv(T*Ttgp6uI_|4cQm9UnNuQ}O&P`@)KBZb6cPB7@7#Ni&%GX$8G~=#v%y|^4`D#)% zuQQ``u3nyf?c~YX)oWX=m4(eFgS6zu(N4Y9eEEAM<_P5F1z zWZR){yoCj|7ORcK#!TB&b6Kb()zjNy+*REq2~?3ri4K&G17;k??!>0X1Xi_+Tp^h( zY)3{#I>#%G)Txsv2TXdkHxf~c$eZ&CgBYuZ@RRD)WHZVlx*|7Cb@nASeLhR_9y*%K zm(Z0r@cSl;f~0H9Wevh^4_zTt)i!EiW@O^oS!G5h$dc5agX=u(gubwz8;!C>?o3fw z(NwaXMP5SNVqIb1|#Wcf(hFgNGC*iEIQZWacG1w*LGsF*hpedaVI%(X)x zi0J(y%S_B1Xbx}%P2lLnAd^@Gjc#C3w+RwXPv=(6Hsjb{lBVtsfLWQF_Y=#rp^sF% zIs|JAJT+2zI56cX>zg=W_$DBR(Omb@ z=}0~Qyg=OBG-)r%w0FtGao9207K<6F6+kvp^F1-ufd*BH;_ncYICPh0l!$iog7;XJ ztYoA{fIyM;|U7C*YvKKEWG`bdt$;EBBoRX0m!M174~C; z-kBYS0nk7*$1uzc;I>#;?krUF+AVMw3IlaUYv;7qJ!w-kwK7uPhQ5tdy!A+r)?gJ4 z`aKmiw>9|gYt=J4u=oaFBjhlaw^}L)1MwLwEmr+F8Gz33~gAfr606#V=3c&iXxVSr2$-mS>rHhu1FR?Pao)nQl_i9h) z$E~~o{EVNma+paI?K_B?ZWkk8k-=9Nt}JY*C`s^yyyRR|)0?}xSu6Ic*!oaR0AHcS zF7E5poCYOC(o7e5Pqh-)%n)u_4)Ig;|{`@ zE6bl=C{DK*ya9^xdtgyr3ALC2OUT9g;hA8PfA6g?S=B>bXb@u&Sm3j*tb0rDZ8?w5 zT1x*emV`xPJS=mlI?+hP85Dlsp+ZjbWT=T4TKQQr3Dqy2#rJ5Q}!f;p~e@6-ROUq+k$loEaDnXFDNA>@!ib+4Md;w=s9!h2++8 z9B11IQgrAXH9FBMGbv+~=A93X!C8y*ys%QbcehHGfu9_zWn2-um{Q9+>l$@RE$TtY z+Y-TCFxb@M@?$l0RlHQqQs9~DN;V>yw@4eKm$<)ly|A8~j?ESWHq8GLr- z0T8K+5y-QHJqmCo(y&zW1JOsk9R?xrfO`@Ph4)n5{FXZN>PeBCjY2{M3IaO}kvmr# zlY&_FnhMOeOtu}Gp`;q2I8O=aE_w+kEm>C?@!X*U?l3Ih{|)iX%}bI{qzO0@Q!4%6 zSQ4OMQh{=JYz~P9%xoupA%*U$b}U{_^>ywa(4?qQB&AB`I0oobA89qfD0` zbrFH+oA5N)F~er7%$>TDExb}CA8@dhX&);AZ(eMOWijE5blHJdyja&^(Uy540?iD7 zAZMEL7Q?(9=C#LlV3SBM=!xxU`NLdIk#Mvayd%TWOhye#lAprpZgX}V6#8q1vyf3*c(6y5D3*gDnBH>3Q zl8WZesp*y;EDYK@%do6qG0Cnag*O*<^aczi;&iN93Tw9~S)p?_2x^9SB#4#NmXuk9 zKqldg3Y+P}n6YA~)_^5s!E+KQcypO3h}$KUjHcAZ8d#R~RDsPUL6U^D%6wsZ(X=yj z+pMiFsA-#$wvsaiK7m+JD8an$279^+a5`drV2fwRx+_<)pVY}#a$}d{U|k4C@@Dbq zO?45WppiU5eZ)wP$rx@CNWIJGB-5AIQE&xy0*GYc&ait0+HDtMtZLp*t>Z%?k#wm7=%t^$@55he2}6xTI=w>#bPA%iRf9O+k!}N zo+(|$4Zs9P-sa5$Fui_xhWsUUmMFI4RK#r-7m?>B@sxM9mqccf?IH}8xtxfrjay_I z7gC*T>acBenSl$pDVh#t2Zh(pC3^?m@>*uP$VW#f>J1MW4!b*cKo@iznuzd=w?fxF zux)d6Q$`;?VTC{*7>(T4rm5ClS|r;L;!U2Ax0COedmnZ$w&`gCTvBv<>Y7POu~}Qs zsZnC5^JiLABzwzBXCi$E=rY{6+%A?R-T4KVA<4doWsJC2W}#vU0`nq8P2=xam)QwZVo{vBFeqVOwPg*OMTo-0FOdKR1xa3Wi9MA@0OBGr%X;nA zwyxY667I!c$Z&1oN~$HkS!YePm=OfY;_>Rrb3@6ui_L9cd2?7c%Z}Yuj|3R8TjgCy z0_@3SFE(N;ot2)A9Zem6osT3VY3nZ89J^e}vT&3@Ow5tV#qkjl0=JEKaM{^|^2LK> z$&jP~2}U~C*ww1fZ;k4+S>uW7o_sr$#wRG(`26G=pPyXg)01m_dUB1=POkCE$u&MV zxmJHGewcGYLR#xGXaf2r!_ z+p4r*F1yCBRQ<13y{h`Ds(-HPr>p+?s(+#CKUMX|tNz8Rf2rz!t?GZh>SwBcw(6Tz zf1>J7R{hIWU(3JlYrWKceSL3V->>_<_wV++9(QkFk6T}_&kyPkdak~{x8FNXeSgyH zd!IY#ydHONFKQ7C3(lq>(eJC!d1w{73cU$M&@JerI zA43!&8c#Ffvyg&jAgOW8K`YQTs15a@9C{1-J?IahKY{)n`b+3r(BDDdfqnq}H}pTy z5f=MV=#xn@9R4Z+Ycm0FHUlN z=!-iK+<9=&_&ZPj)DIl=;w0C1X20_QTuN&}{|bE$EFA+w8hVDn@C|5=!0;aQCV}CP zp=kocx1n1EhWDTw1ctwWo*^)N6MCJ%@Xydi0>ghow=t?8Lti8?{5$kT0>k&AHwX;h zf|dyk{|5O4hVMh4AuzlPog*;(KJ*a+gN42dJx6f(Ge{8}z6QMqwFnS@57l*@>FH^; z#aH^XMB+&*7e|iB8?JJ9R#Ju?A?+II!523s;ObGTe~huap9CL7edI7E*jlg;nzVj0 zHrfHWc%(`!lQ8_qRADAzw%)hgw?fSr6Z@d{(JAB{#_7E;^o-4T0Q?WG1x@nymHS;T znt+K%tD-ak!w_e!)jO&!xRVTb|LeSCTK-b zp6}0>`<;^1PAFYaH++`g*vKz=j=375{CemN*QKUR>7+qfrLiP`Xj0pg_LJY={QvKn WL`Kr&|2^3Id!{#$*PQ!p*#7}KJyhxd literal 0 HcmV?d00001 diff --git a/tests/artifacts/testdata/twitter_complaints_small.arrow b/tests/artifacts/testdata/twitter_complaints_small.arrow new file mode 100644 index 0000000000000000000000000000000000000000..b5bba53e2318470857bdae0ad43d7ee85445350b GIT binary patch literal 3930 zcmeHKy>A>v6rUu{2jM^#ickuQ7demw*=Hw2D01U-j3EZ>M2;mZ6`9+8cQ<%%XEihH z^LDABp`bycrlzC-|@Gv=dX{6}WTwPSsz9b)|j+eGYmES^!Of?m!ndofcge z>p67LsFck^mItM7LMEG#du>8a&J4U4K1y_#6X=HAmzC_vgE|Gnhn1N6v6N9KI+*)@ zLVi;BNvy5CBet|aXG>c!RqmNq7r>vvtXMl;`=YXgFeA1}6MS1E=0Ji{=H;(yIQGdN z{CJ_}Tk6%?ldV##IyF`LWctNg0NO5MMZ79{e+P+;rQd6jm+JnoZnmBj-%ESB2GVH& zRnw~B@74Tg2i6RBhfBa;Sp?87Bm>y;}pMzr!*1>m+7+W(tVpsBV z*2n^0u(2*^?aGj1o3h_AG}LJus4LSn`9xF@c#URmI~xXw(Q>$hSapWc4kt-$ z`@S-3X!--h3Qnai(~xK7;JG>Cf@5bhzXa65^h_ks18W3Z{Ej1X&%HdM(STEl&=r|Y zn6tTGChv6~aFR*1XvAf>fRuPoM?qp8 z-DmWrPG=*B}Mq`Y3mfQ=EawA-Uo<(y?@d z9OB_<#0C&`7Z-OKr=<XO=I>Er!myGk;DKR^bTTL(B;dL;+^}P4Xo3h_c={X zkKGoGDQF!@R?x~oeyT%_tg?jf2w75Qcl9%}65WtvAFL8f3M!Bc^vC1x@GrMV)o^8$ zP3na-Tp49+uEpjWlbR;&Y}GaOM7>b+&Hkou^frAnf3R=Xo4#2WwJng;)k~l^L2rX@ zgEm0hpogI5UPj;}3s)m1VRAg491kW}pEe!~j|Y>IIJRtS`2Gy~6yT=;T!UVL10REy z;lSTOO&yq-nV~J*I*Zgf(oKB+yzD3|{-R{{7|4X8=YLrr0o8H3&-cMS%!AM4zJC^o zzgG|TX?sdu9IKLX6u`%(svHIEe_nX@fjA^1fjoI7o10R|JOSOi)jD90V_YqJLRx9-~fJGF5lFD0{)2Y Apa1{> literal 0 HcmV?d00001 diff --git a/tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow b/tests/artifacts/testdata/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.arrow new file mode 100644 index 0000000000000000000000000000000000000000..6afd36ddf6280f6116974582f143341033c33587 GIT binary patch literal 11466 zcmeHNeQX>@6<;@wlcsKPo3;wHrDL2=5X+M-kX`7)%EV&dDot`p7%D2cZKJ*c~Ta5tzL^a*PD+gl%{6z z9=Q4V+swb7HxJy2CeQm6>JrqyfPVqCkV$#oBmaU=iQqV$^0-DMEwPmgV`)OvY=<;rYxo-3GkKiEbHtC29~>s_DL<0QKH)AiG%n z^P2vr{vDZ&)4$!x&#kvd?2f7_lzZL=$W*4usgDrJa7*oEx-WCm^M)b!gp>U*8|Fg( zPQe277nzK+;l*k4({&?EUjJ5!M=R>QEFNgeaA^a&^MUTXEMjepXH({ffso=VaUGhl zHZ@|Uw97x22b(Z~#7Ii3cLOt_HPh(LY0fatEpz?i>0c|(!wRrEY` zHQTyL7t+9-8p?L%caZ;LUBB|$|4LoI@;YCefQ}-VOF2s3asS~}{=4iSCT+V)B^^~> z<#ToY$``M8^52w1omQ0B@xQ9;S6;`bg0e_Y>5qIBy@=*_W>-X;ijxVg@6V7K6m#f4OExgE&deebGqnjyOnC zf549eyTgyi%}~(~$4xD&n$aj08k-M*e{2(^bsL733KTSqz z1q`i36E@CxnQJ0~d1+MeYo>pB*%YGCCgXuA`(YQXcCPs*DftyxENVw!OwwiciVq>y zqRDNE66^l%vhjf+Z*$wWnY%)>GfF~CTQt3K;FnC%KM;gtW-N*drdlaQe!=)*0W(g7 z8Ix43EhJ{7Dr-EBgOQRoB|o)czGl{Scb75go?=a)7RbB$A}FzmnSebNzhdLnCINKS z)OMMkm8N4$9Hm9Xs5jk?Smnz3aa$@CqOrueRSw3Ah!sFJBi5AdYMHpR)7Y9V#8D;L z46D;%G{}qTsr3LB7F8K@V!G0HtO&bgf%yAX`9~t=MgRW<{LVOoDT2b1uxL z@=IfJP^c!zA%w}A5F-G6oJB~C8wrqtfIeT1V;f=)L=MSS=t4SXFvuhcfJSU!qCbK| zCyUeOL3;ptl44Y;WIafjf!@A+lt4Df1JyAb zoT-@sBw5Y$6=kWGePmT+@OJ`Pl4ZABMv+7}tzsO!5>INThGZ~*I`usM%kJTdZaZ?; zNx_zp72Vc#yO#4e0-uinHDZ!Zf1YZbdYImOq30>#sfX#3ZbrNt;NiEZ$6DQ{$#PcQ zj`}d_GpKS$Zo%H%gL)6@qp12GjmNC~{bP>)v%}HP9Cq{w54?lqE2L-Tf{60UR!8?! z(E%Qm-y^?;w8{hR1=6$B%j=1!bV0k*_blhRcgXQ-&yP6%(@R}D)q6A9|7GfXod^2e zHQt;&$C_Qp!8zLL-E-}f7I{K zasC~oPZM9w8}<7^_WRd3`(L}k(Oalb^K}vZTE;xvNV=Ef=TUwKok`PIIc zH@kV&54Hau>VK7Syq@!%A)vQXpW=OjysndG%3muVl28nJ#i6~&wd?q;ydZypeyIO3 z?bu8G7i#~par!jgN@t4uJ6(UD{Ggj>BiG-_TO9v5_m2(q^C%5i$rtRGh@&qrz9=1( zyN>Z+ben5G%k~?WJ3hx3@-xIm;};Q^`gL-eb}6o>_&r7AqWhWlD^5x;7GAsfVd?}g z7VSmnkM0-UjAzfSu3hc*6@= zUr)P^lP+8D?9ugbl5u;U@&lBcLtgRFxcAfVmxs{FA+oYW^0e=ZBO#&3q5&huWj@9V8x~ zrTj|+PM^xv*sgK@CGFouJnm<|##!m5iuV(oPvd+qdEM`i&_3Pg7mz2Xs_m_#zpC+iqFPP4!^(h9J`x-=y~>i^19wv5ns(;jn{6zN9aEJ zHRh?xEoHyrsNc19{V0CNIc^=}rTgu>#bvhVlYf)L;mFO7>U)-+mlekQS?0%`olfo$ zKe*h_c&mRWnZFOw55@B>;;HBPA@=J&)A-Jgc3e+$=-zE3P);q>VHj{3c@>e~0v|JS%)G%uCTl%KPWgU>Y-_8r}CxDC$uU8Hd@b2!h0 zSL0shaBAGk91i-P)wq{AJR0{hhl9S~)W84Dw7=;4P~%?aaBSSm91eP~(fnxK%N#Bj z>t2SRJiaX&|JUm>>O7agtO+%>-0+kkfY)FjttS;O-C1}QnEygy1?xc6qiVg5JU{9s1_ literal 0 HcmV?d00001 diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py index a4ec5dbf7..6e7dacde8 100644 --- a/tests/data/test_data_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -32,12 +32,15 @@ ) from tests.artifacts.testdata import ( MODEL_NAME, + TWITTER_COMPLAINTS_DATA_ARROW, + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, TWITTER_COMPLAINTS_DATA_JSON, TWITTER_COMPLAINTS_DATA_JSONL, TWITTER_COMPLAINTS_DATA_PARQUET, + TWITTER_COMPLAINTS_TOKENIZED_ARROW, TWITTER_COMPLAINTS_TOKENIZED_JSON, TWITTER_COMPLAINTS_TOKENIZED_JSONL, TWITTER_COMPLAINTS_TOKENIZED_PARQUET, @@ -62,6 +65,10 @@ TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL, set(["ID", "Label", "input", "output"]), ), + ( + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, + set(["ID", "Label", "input", "output", "sequence"]), + ), ( TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, set(["ID", "Label", "input", "output"]), @@ -80,6 +87,20 @@ ] ), ), + ( + TWITTER_COMPLAINTS_TOKENIZED_ARROW, + set( + [ + "Tweet text", + "ID", + "Label", + "text_label", + "output", + "input_ids", + "labels", + ] + ), + ), ( TWITTER_COMPLAINTS_TOKENIZED_PARQUET, set( @@ -98,6 +119,10 @@ TWITTER_COMPLAINTS_DATA_JSONL, set(["Tweet text", "ID", "Label", "text_label", "output"]), ), + ( + TWITTER_COMPLAINTS_DATA_ARROW, + set(["Tweet text", "ID", "Label", "text_label", "output"]), + ), ( TWITTER_COMPLAINTS_DATA_PARQUET, set(["Tweet text", "ID", "Label", "text_label", "output"]), @@ -123,6 +148,11 @@ def test_load_dataset_with_datafile(datafile, column_names): set(["ID", "Label", "input", "output"]), "text_dataset_input_output_masking", ), + ( + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_ARROW, + set(["ID", "Label", "input", "output", "sequence"]), + "text_dataset_input_output_masking", + ), ( TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET, set(["ID", "Label", "input", "output"]), @@ -163,6 +193,11 @@ def test_load_dataset_with_datafile(datafile, column_names): set(["Tweet text", "ID", "Label", "text_label", "output"]), "apply_custom_data_template", ), + ( + TWITTER_COMPLAINTS_DATA_ARROW, + set(["Tweet text", "ID", "Label", "text_label", "output"]), + "apply_custom_data_template", + ), ( TWITTER_COMPLAINTS_DATA_PARQUET, set(["Tweet text", "ID", "Label", "text_label", "output"]), @@ -593,6 +628,12 @@ def test_process_dataargs(data_args): training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL, ) ), + # ARROW pretokenized train datasets + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_TOKENIZED_ARROW, + ) + ), # PARQUET pretokenized train datasets ( configs.DataArguments( diff --git a/tuning/utils/utils.py b/tuning/utils/utils.py index 585011ae9..6eef6b2cf 100644 --- a/tuning/utils/utils.py +++ b/tuning/utils/utils.py @@ -31,6 +31,8 @@ def get_loader_for_filepath(file_path: str) -> str: return "text" if ext in (".json", ".jsonl"): return "json" + if ext in (".arrow"): + return "arrow" if ext in (".parquet"): return "parquet" return ext