-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSpamModelTester.php
135 lines (114 loc) · 4.27 KB
/
SpamModelTester.php
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
<?php
namespace App\Module\ML\Application\Model;
use App\Core\Application\Path\AppPathResolver;
use App\Module\ML\Application\Model\VO\TestInput;
use App\Module\ML\Application\Model\VO\TestResult;
use App\Module\ML\Application\Utils\WordsUtils;
use App\Module\ML\Domain\Constant;
use Ramsey\Collection\Collection;
use Ramsey\Collection\CollectionInterface;
use Ramsey\Collection\Map\TypedMap;
use Ramsey\Collection\Sort;
use Rubix\ML\CrossValidation\KFold;
use Rubix\ML\CrossValidation\Metrics\FBeta;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Extractors\CSV;
use Symfony\Component\Console\Helper\ProgressBar;
use Symfony\Component\Console\Style\SymfonyStyle;
/**
 * Grid-searches learner hyper-parameters for the spam model.
 *
 * Every parameter combination produced by generateParameters() is turned into a
 * learner via LearnerFactory::createLearner() and scored with k-fold
 * cross-validation using the FBeta metric. Results are returned sorted by score,
 * best first.
 *
 * @see https://docs.rubixml.com/latest/cross-validation.html#validators
 */
class SpamModelTester
{
    /**
     * Console style used to render a progress bar; when null, no progress is shown.
     */
    private static ?SymfonyStyle $io = null;

    public function __construct(
        private readonly AppPathResolver $appPathResolver,
    ) {
    }

    /**
     * Cross-validates one learner per parameter combination on the given dataset.
     *
     * @param string $testingDatasetFilename dataset file name, resolved through
     *                                       AppPathResolver::getDatasetPath() and
     *                                       read as CSV with a header row
     * @param int    $foldsNumber            number of folds used by the KFold validator
     *
     * @return CollectionInterface<TestResult> one result per combination, sorted by score descending
     */
    public function test(
        string $testingDatasetFilename,
        int $foldsNumber = Constant::DEFAULT_FOLDS_NUMBER,
    ): CollectionInterface {
        $parametersCollection = $this->generateParameters();
        $results = new Collection(TestResult::class);

        $dataset = Labeled::fromIterator(new CSV(
            $this->appPathResolver->getDatasetPath($testingDatasetFilename),
            header: true,
        ));

        // Computed once from the dataset samples; it does not depend on the tested parameters.
        $uniqueWordsNum = WordsUtils::countUniqueWords($dataset->samples(), Constant::DEFAULT_MIN_WORDS_COUNT);

        // The validator and the metric are identical for every combination, so build them once.
        $validator = new KFold($foldsNumber);
        $metric = new FBeta();

        $progressBar = $this->startProgressBar($parametersCollection->count());

        foreach ($parametersCollection as $testParams) {
            /** @var TestInput $testParams */
            $estimator = LearnerFactory::createLearner(
                uniqueWordsNum: $uniqueWordsNum,
                minDocumentCount: $testParams->minDocumentCount,
                maxDocumentRatio: $testParams->maxDocumentRatio,
                treeEstimators: $testParams->treeEstimators,
                treeRatio: $testParams->treeRatio,
            );

            $results->add(new TestResult(
                score: $validator->test($estimator, $dataset, $metric),
                parameters: $testParams,
            ));

            $progressBar?->advance();
        }

        $this->finishProgressBar($progressBar);

        return $results->sort('score', Sort::Descending);
    }

    /**
     * Sets (or clears, with null) the console style used to render the progress bar.
     */
    public static function setIo(?SymfonyStyle $io): void
    {
        self::$io = $io;
    }

    /**
     * Builds the Cartesian product of all candidate values returned by createRange().
     *
     * @return CollectionInterface<TestInput>
     */
    private function generateParameters(): CollectionInterface
    {
        $collection = new Collection(TestInput::class);
        $range = $this->createRange();

        foreach ($range->get('minDocumentCount') as $minDocumentCount) {
            foreach ($range->get('maxDocumentRatio') as $maxDocumentRatio) {
                foreach ($range->get('treeEstimators') as $treeEstimators) {
                    foreach ($range->get('treeRatio') as $treeRatio) {
                        $collection->add(new TestInput(
                            minDocumentCount: $minDocumentCount,
                            maxDocumentRatio: $maxDocumentRatio,
                            treeEstimators: $treeEstimators,
                            treeRatio: $treeRatio,
                        ));
                    }
                }
            }
        }

        return $collection;
    }

    /**
     * Candidate values for each tuned parameter, keyed by TestInput property name.
     *
     * @return TypedMap<string, iterable<int|float>>
     */
    private function createRange(): TypedMap
    {
        /* @phpstan-ignore-next-line */
        return new TypedMap('string', 'array', [
            'minDocumentCount' => range(2, 4),
            'maxDocumentRatio' => self::floatRange(0.3, 0.6, 0.1),
            'treeEstimators' => range(100, 200, 100),
            'treeRatio' => self::floatRange(0.1, 0.3, 0.1),
        ]);
    }

    /**
     * Inclusive float range with binary floating-point noise rounded away.
     *
     * range() computes each element as start + i * step, so e.g. 0.1 + 2 * 0.1
     * yields 0.30000000000000004; rounding keeps the grid (and the reported
     * TestResult parameters) at the intended decimal values.
     *
     * @return list<float>
     */
    private static function floatRange(float $start, float $end, float $step): array
    {
        return array_map(
            static fn (float $value): float => round($value, 10),
            range($start, $end, $step),
        );
    }

    /**
     * Creates and starts a progress bar on the configured console style.
     *
     * @return ProgressBar|null null when no console style has been set
     */
    private function startProgressBar(int $max): ?ProgressBar
    {
        ProgressBar::setFormatDefinition(
            'spam_model_tester',
            '%current%/%max% of Property sets [%bar%] %percent:3s%%'
        );

        $progressBar = self::$io?->createProgressBar($max);
        $progressBar?->setFormat('spam_model_tester');
        $progressBar?->start();

        return $progressBar;
    }

    /**
     * Completes the progress bar and moves the cursor past it.
     */
    private function finishProgressBar(?ProgressBar $progressBar): void
    {
        if ($progressBar === null) {
            return;
        }

        $progressBar->finish();
        // finish() leaves the cursor on the bar's line; SymfonyStyle::progressFinish()
        // follows it with two new lines so subsequent output starts on a fresh block.
        self::$io?->newLine(2);
    }
}