
Commit 10fb90a

Merge pull request #33 from matlab-deep-learning/japanese-bert-fix

Updating predictMaskedToken.m to match new tokenizer.

Authored Apr 19, 2023 · 2 parents 6715db0 + 09ad027

3 files changed: +53 −4

README.md (+4 −2)
```diff
@@ -71,12 +71,14 @@
 ## Example: Classify Text Data Using BERT
 The simplest use of a pretrained BERT model is to use it as a feature extractor. In particular, you can use the BERT model to convert documents to feature vectors which you can then use as inputs to train a deep learning classification network.
 
-The example [`ClassifyTextDataUsingBERT.m`](./ClassifyTextDataUsingBERT.m) shows how to use a pretrained BERT model to classify failure events given a data set of factory reports.
+The example [`ClassifyTextDataUsingBERT.m`](./ClassifyTextDataUsingBERT.m) shows how to use a pretrained BERT model to classify failure events given a data set of factory reports. This example requires the `factoryReports.csv` data set from the Text Analytics Toolbox example [Prepare Text Data for Analysis](https://www.mathworks.com/help/textanalytics/ug/prepare-text-data-for-analysis.html).
 
 ## Example: Fine-Tune Pretrained BERT Model
 To get the most out of a pretrained BERT model, you can retrain and fine-tune the BERT parameter weights for your task.
 
-The example [`FineTuneBERT.m`](./FineTuneBERT.m) shows how to fine-tune a pretrained BERT model to classify failure events given a data set of factory reports.
+The example [`FineTuneBERT.m`](./FineTuneBERT.m) shows how to fine-tune a pretrained BERT model to classify failure events given a data set of factory reports. This example requires the `factoryReports.csv` data set from the Text Analytics Toolbox example [Prepare Text Data for Analysis](https://www.mathworks.com/help/textanalytics/ug/prepare-text-data-for-analysis.html).
+
+The example [`FineTuneBERTJapanese.m`](./FineTuneBERTJapanese.m) shows the same workflow using a pretrained Japanese-BERT model. This example requires the `factoryReportsJP.csv` data set from the Text Analytics Toolbox example [Analyze Japanese Text Data](https://www.mathworks.com/help/textanalytics/ug/analyze-japanese-text.html), available in R2023a or later.
 
 ## Example: Analyze Sentiment with FinBERT
 FinBERT is a sentiment analysis model trained on financial text data and fine-tuned for sentiment analysis.
```
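
As context for the README change above, the feature-extractor workflow that `ClassifyTextDataUsingBERT.m` builds on looks roughly as follows. This is a minimal sketch, not the example file itself: it assumes the repository's `bert` and `bert.model` entry points, assumes `encode` returns one token-code sequence per input string, and uses a made-up report string.

```matlab
% Minimal feature-extraction sketch (see assumptions in the lead-in).
mdl = bert("Model","tiny");                 % pretrained model, as in the tests

str = "Coolant pump is leaking near the main valve.";  % made-up factory report
seq = encode(mdl.Tokenizer, str);           % token codes, incl. special tokens

% bert.model is assumed to return one hidden vector per input token.
emb = bert.model(seq{1}, mdl.Parameters);

% The first ([CLS]) vector serves as a fixed-length feature for the whole
% document; collecting these vectors gives training inputs for a classifier.
feature = emb(:,1);
```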

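Fine-tuning, by contrast, updates the BERT weights themselves. Below is a skeleton of the custom training loop under the same assumptions; `modelLoss` and `nextMiniBatch` are hypothetical stand-ins for the loss function and data pipeline that `FineTuneBERT.m` defines in full.

```matlab
% Fine-tuning skeleton: adamupdate steps the BERT weights themselves.
% modelLoss and nextMiniBatch are hypothetical stand-ins (see lead-in).
mdl = bert("Model","tiny");
trailingAvg = [];                           % Adam state, built up by adamupdate
trailingAvgSq = [];
numIterations = 1000;                       % illustrative iteration budget
for iteration = 1:numIterations
    [X,T] = nextMiniBatch();                % encoded text X and labels T
    [loss,gradients] = dlfeval(@modelLoss, mdl.Parameters, X, T);
    [mdl.Parameters,trailingAvg,trailingAvgSq] = adamupdate( ...
        mdl.Parameters, gradients, trailingAvg, trailingAvgSq, iteration);
end
```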
predictMaskedToken.m (+2 −2)
```diff
@@ -6,7 +6,7 @@
 % replaces instances of mdl.Tokenizer.MaskToken in the string text with
 % the most likely token according to the BERT model mdl.
 
-% Copyright 2021 The MathWorks, Inc.
+% Copyright 2021-2023 The MathWorks, Inc.
 arguments
     mdl {mustBeA(mdl,'struct')}
     str {mustBeText}
@@ -44,7 +44,7 @@
         tokens = fulltok.tokenize(pieces(i));
         if ~isempty(tokens)
             % "" tokenizes to empty - awkward
-            x = cat(2,x,fulltok.encode(tokens));
+            x = cat(2,x,fulltok.encode(tokens{1}));
         end
         if i<numel(pieces)
             x = cat(2,x,maskCode);
```
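
The substantive change here is the `tokens{1}` indexing: the updated tokenizer's `tokenize` method returns a cell array holding one token sequence per input string, so `encode` must be given the contained sequence rather than the cell itself. A usage sketch follows; the model name and mask string are taken from the test file below, and the predicted token depends on the chosen model.

```matlab
% Usage sketch for the fixed function; the filled-in token is model-dependent.
mdl = bert("Model","tiny");                 % or "japanese-base-wwm"
out = predictMaskedToken(mdl, "The cat [MASK] soundly.");
% out is the input string with [MASK] replaced by the most likely token.
```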

test/tpredictMaskedToken.m (+47, new file)
```matlab
classdef(SharedTestFixtures={
        DownloadBERTFixture, DownloadJPBERTFixture}) tpredictMaskedToken < matlab.unittest.TestCase
    % tpredictMaskedToken   Unit test for predictMaskedToken

    % Copyright 2023 The MathWorks, Inc.

    properties(TestParameter)
        Models = {"tiny","japanese-base-wwm"}
        ValidText = iGetValidText;
    end

    methods(Test)
        function verifyOutputDimSizes(test, Models, ValidText)
            inSize = size(ValidText);
            mdl = bert("Model", Models);
            outputText = predictMaskedToken(mdl,ValidText);
            test.verifyEqual(size(outputText), inSize);
        end

        function maskTokenIsRemoved(test, Models)
            text = "This has a [MASK] token.";
            mdl = bert("Model", Models);
            outputText = predictMaskedToken(mdl,text);
            test.verifyFalse(contains(outputText, "[MASK]"));
        end

        function inputWithoutMASKRemainsTheSame(test, Models)
            text = "This has a no mask token.";
            mdl = bert("Model", Models);
            outputText = predictMaskedToken(mdl,text);
            test.verifyEqual(text, outputText);
        end
    end
end

function validText = iGetValidText
manyStrs = ["Accelerating the pace of [MASK] and science";
    "The cat [MASK] soundly.";
    "The [MASK] set beautifully."];
singleStr = "Artificial intelligence continues to shape the future of industries," + ...
    " as innovative applications emerge in fields such as healthcare, transportation," + ...
    " entertainment, and finance, driving productivity and enhancing human capabilities.";
validText = struct('StringsAsColumns',manyStrs,...
    'StringsAsRows',manyStrs',...
    'ManyStrings',repmat(singleStr,3),...
    'SingleString',singleStr);
end
```
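
To exercise the new test locally, point `runtests` at the file from the repository root; the shared fixtures download the required pretrained models on first run.

```matlab
% Run the new unit test; the shared fixtures fetch the "tiny" and
% "japanese-base-wwm" models the first time the fixtures are set up.
results = runtests("test/tpredictMaskedToken.m");
table(results)                              % summarize pass/fail per test point
```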
