diff --git a/CHANGELOG.md b/CHANGELOG.md index dc53bf654..eb2f2c78c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,10 @@ # Changelog All notable changes to this project will be documented in this file. +## [0.19.2] - 2019-02-11 +### Fixed +- Fix an issue regarding the way builtin entities were handled by the `CRFSlotFiller` + ## [0.19.1] - 2019-02-04 ### Fixed - Bug causing an unnecessary reloading of shared resources @@ -232,6 +236,7 @@ several commands. - Fix compiling issue with `bindgen` dependency when installing from source - Fix issue in `CRFSlotFiller` when handling builtin entities +[0.19.2]: https://github.com/snipsco/snips-nlu/compare/0.19.1...0.19.2 [0.19.1]: https://github.com/snipsco/snips-nlu/compare/0.19.0...0.19.1 [0.19.0]: https://github.com/snipsco/snips-nlu/compare/0.18.0...0.19.0 [0.18.0]: https://github.com/snipsco/snips-nlu/compare/0.17.4...0.18.0 diff --git a/snips_nlu/__about__.py b/snips_nlu/__about__.py index b8955c299..c9d76cadc 100644 --- a/snips_nlu/__about__.py +++ b/snips_nlu/__about__.py @@ -11,7 +11,7 @@ __email__ = "clement.doumouro@snips.ai, adrien.ball@snips.ai" __license__ = "Apache License, Version 2.0" -__version__ = "0.19.1" +__version__ = "0.19.2" __model_version__ = "0.19.0" __download_url__ = "https://github.com/snipsco/snips-nlu-language-resources/releases/download" diff --git a/snips_nlu/cli/metrics.py b/snips_nlu/cli/metrics.py index b8ad55fce..aa181f42a 100644 --- a/snips_nlu/cli/metrics.py +++ b/snips_nlu/cli/metrics.py @@ -66,6 +66,7 @@ def progression_handler(progress): nb_folds=nb_folds, train_size_ratio=train_size_ratio, include_slot_metrics=not exclude_slot_metrics, + slot_matching_lambda=_match_trimmed_values ) from snips_nlu_metrics import compute_cross_val_metrics @@ -108,7 +109,8 @@ def train_test_metrics(train_dataset_path, test_dataset_path, output_path, train_dataset=train_dataset_path, test_dataset=test_dataset_path, engine_class=engine_cls, - include_slot_metrics=not exclude_slot_metrics + include_slot_metrics=not exclude_slot_metrics, + slot_matching_lambda=_match_trimmed_values ) from snips_nlu_metrics import compute_train_test_metrics @@ -119,3 +121,9 @@ def train_test_metrics(train_dataset_path, test_dataset_path, output_path, with Path(output_path).open(mode="w", encoding="utf8") as f: f.write(json_string(metrics)) + + +def _match_trimmed_values(lhs_slot, rhs_slot): + lhs_value = lhs_slot["text"].strip() + rhs_value = rhs_slot["rawValue"].strip() + return lhs_value == rhs_value diff --git a/snips_nlu/slot_filler/crf_slot_filler.py b/snips_nlu/slot_filler/crf_slot_filler.py index 6c50814ac..ff9b0a4c6 100644 --- a/snips_nlu/slot_filler/crf_slot_filler.py +++ b/snips_nlu/slot_filler/crf_slot_filler.py @@ -7,8 +7,6 @@ import shutil import tempfile from builtins import range -from copy import copy -from itertools import groupby, product from pathlib import Path from future.utils import iteritems @@ -20,20 +18,16 @@ from snips_nlu.common.log_utils import DifferedLoggingMessage, log_elapsed_time from snips_nlu.common.utils import ( check_persisted_path, - check_random_state, fitted_required, json_string, - ranges_overlap) + check_random_state, fitted_required, json_string) from snips_nlu.constants import ( - DATA, END, ENTITY_KIND, LANGUAGE, RES_ENTITY, - RES_MATCH_RANGE, RES_VALUE, START) + DATA, LANGUAGE) from snips_nlu.data_augmentation import augment_utterances from snips_nlu.dataset import validate_and_format_dataset -from snips_nlu.entity_parser.builtin_entity_parser import is_builtin_entity from snips_nlu.exceptions import LoadingError from snips_nlu.pipeline.configs import CRFSlotFillerConfig from snips_nlu.preprocessing import tokenize from snips_nlu.slot_filler.crf_utils import ( - OUTSIDE, TAGS, TOKENS, positive_tagging, tag_name_to_slot_name, - tags_to_preslots, tags_to_slots, utterance_to_sample) + OUTSIDE, TAGS, TOKENS, tags_to_slots, utterance_to_sample) from snips_nlu.slot_filler.feature import TOKEN_NAME from snips_nlu.slot_filler.feature_factory import CRFFeatureFactory from snips_nlu.slot_filler.slot_filler import SlotFiller @@ -186,18 +180,8 @@ def get_slots(self, text): features = self.compute_features(tokens) tags = [_decode_tag(tag) for tag in self.crf_model.predict_single(features)] - slots = tags_to_slots(text, tokens, tags, self.config.tagging_scheme, - self.slot_name_mapping) - - builtin_slots_names = set(slot_name for (slot_name, entity) in - iteritems(self.slot_name_mapping) - if is_builtin_entity(entity)) - if not builtin_slots_names: - return slots - - # Replace tags corresponding to builtin entities by outside tags - tags = _replace_builtin_tags(tags, builtin_slots_names) - return self._augment_slots(text, tokens, tags, builtin_slots_names) + return tags_to_slots(text, tokens, tags, self.config.tagging_scheme, + self.slot_name_mapping) def compute_features(self, tokens, drop_out=False): """Computes features on the provided tokens @@ -280,60 +264,6 @@ def log_weights(self): log += "\n%s %s: %s" % (feat, _decode_tag(tag), weight) return log - def _augment_slots(self, text, tokens, tags, builtin_slots_names): - scope = set(self.slot_name_mapping[slot] - for slot in builtin_slots_names) - builtin_entities = [ - be for entity_kind in scope for be in - self.builtin_entity_parser.parse(text, scope=[entity_kind], - use_cache=True)] - # We remove builtin entities which conflicts with custom slots - # extracted by the CRF - builtin_entities = _filter_overlapping_builtins( - builtin_entities, tokens, tags, self.config.tagging_scheme) - - # We resolve conflicts between builtin entities by keeping the longest - # matches. In case when two builtin entities span the same range, we - # keep both. - builtin_entities = _disambiguate_builtin_entities(builtin_entities) - - # We group builtin entities based on their position - grouped_entities = ( - list(bes) - for _, bes in groupby(builtin_entities, - key=lambda s: s[RES_MATCH_RANGE][START])) - grouped_entities = sorted( - grouped_entities, - key=lambda entities: entities[0][RES_MATCH_RANGE][START]) - - features = self.compute_features(tokens) - spans_ranges = [entities[0][RES_MATCH_RANGE] - for entities in grouped_entities] - tokens_indexes = _spans_to_tokens_indexes(spans_ranges, tokens) - - # We loop on all possible slots permutations and use the CRF to find - # the best one in terms of probability - slots_permutations = _get_slots_permutations( - grouped_entities, self.slot_name_mapping) - best_updated_tags = tags - best_permutation_score = -1 - for slots in slots_permutations: - updated_tags = copy(tags) - for slot_index, slot in enumerate(slots): - indexes = tokens_indexes[slot_index] - sub_tags_sequence = positive_tagging( - self.config.tagging_scheme, slot, len(indexes)) - updated_tags[indexes[0]:indexes[-1] + 1] = sub_tags_sequence - score = self._get_sequence_probability(features, updated_tags) - if score > best_permutation_score: - best_updated_tags = updated_tags - best_permutation_score = score - slots = tags_to_slots(text, tokens, best_updated_tags, - self.config.tagging_scheme, - self.slot_name_mapping) - - return _reconciliate_builtin_slots(text, slots, builtin_entities) - @check_persisted_path def persist(self, path): """Persists the object at the given path""" @@ -404,104 +334,6 @@ def _get_crf_model(crf_args): return CRF(model_filename=model_filename, **crf_args) -def _replace_builtin_tags(tags, builtin_slot_names): - new_tags = [] - for tag in tags: - if tag == OUTSIDE: - new_tags.append(tag) - else: - slot_name = tag_name_to_slot_name(tag) - if slot_name in builtin_slot_names: - new_tags.append(OUTSIDE) - else: - new_tags.append(tag) - return new_tags - - -def _filter_overlapping_builtins(builtin_entities, tokens, tags, - tagging_scheme): - slots = tags_to_preslots(tokens, tags, tagging_scheme) - ents = [] - for ent in builtin_entities: - if any(ranges_overlap(ent[RES_MATCH_RANGE], s[RES_MATCH_RANGE]) - for s in slots): - continue - ents.append(ent) - return ents - - -def _spans_to_tokens_indexes(spans, tokens): - tokens_indexes = [] - for span in spans: - indexes = [] - for i, token in enumerate(tokens): - if span[END] > token.start and span[START] < token.end: - indexes.append(i) - tokens_indexes.append(indexes) - return tokens_indexes - - -def _reconciliate_builtin_slots(text, slots, builtin_entities): - for slot in slots: - if not is_builtin_entity(slot[RES_ENTITY]): - continue - for be in builtin_entities: - if be[ENTITY_KIND] != slot[RES_ENTITY]: - continue - be_start = be[RES_MATCH_RANGE][START] - be_end = be[RES_MATCH_RANGE][END] - be_length = be_end - be_start - slot_start = slot[RES_MATCH_RANGE][START] - slot_end = slot[RES_MATCH_RANGE][END] - slot_length = slot_end - slot_start - if be_start <= slot_start and be_end >= slot_end \ - and be_length > slot_length: - slot[RES_MATCH_RANGE] = { - START: be_start, - END: be_end - } - slot[RES_VALUE] = text[be_start: be_end] - break - return slots - - -def _disambiguate_builtin_entities(builtin_entities): - if not builtin_entities: - return [] - builtin_entities = sorted( - builtin_entities, - key=lambda be: be[RES_MATCH_RANGE][END] - be[RES_MATCH_RANGE][START], - reverse=True) - - disambiguated_entities = [builtin_entities[0]] - for entity in builtin_entities[1:]: - entity_rng = entity[RES_MATCH_RANGE] - conflict = False - for disambiguated_entity in disambiguated_entities: - disambiguated_entity_rng = disambiguated_entity[RES_MATCH_RANGE] - if ranges_overlap(entity_rng, disambiguated_entity_rng): - conflict = True - if entity_rng == disambiguated_entity_rng: - disambiguated_entities.append(entity) - break - if not conflict: - disambiguated_entities.append(entity) - - return sorted(disambiguated_entities, - key=lambda be: be[RES_MATCH_RANGE][START]) - - -def _get_slots_permutations(grouped_entities, slot_name_mapping): - # We associate to each group of entities the list of slot names that - # could correspond - possible_slots = [ - list(set(slot_name for slot_name, ent in iteritems(slot_name_mapping) - for entity in entities if ent == entity[ENTITY_KIND])) - + [OUTSIDE] - for entities in grouped_entities] - return product(*possible_slots) - - def _encode_tag(tag): return base64.b64encode(tag.encode("utf8")) diff --git a/snips_nlu/tests/test_crf_slot_filler.py b/snips_nlu/tests/test_crf_slot_filler.py index 17c0d0391..3b03137f2 100644 --- a/snips_nlu/tests/test_crf_slot_filler.py +++ b/snips_nlu/tests/test_crf_slot_filler.py @@ -9,21 +9,15 @@ from sklearn_crfsuite import CRF from snips_nlu.constants import ( - DATA, END, ENTITY, ENTITY_KIND, LANGUAGE_EN, RES_MATCH_RANGE, SLOT_NAME, - SNIPS_DATETIME, START, TEXT, VALUE) + DATA, END, ENTITY, LANGUAGE_EN, SLOT_NAME, START, TEXT) from snips_nlu.dataset import Dataset -from snips_nlu.entity_parser import BuiltinEntityParser, \ - CustomEntityParserUsage +from snips_nlu.entity_parser import CustomEntityParserUsage from snips_nlu.exceptions import NotTrained from snips_nlu.pipeline.configs import CRFSlotFillerConfig -from snips_nlu.preprocessing import Token, tokenize +from snips_nlu.preprocessing import tokenize from snips_nlu.result import unresolved_slot -from snips_nlu.slot_filler.crf_slot_filler import ( - CRFSlotFiller, _disambiguate_builtin_entities, _ensure_safe, - _filter_overlapping_builtins, _get_slots_permutations, - _spans_to_tokens_indexes) -from snips_nlu.slot_filler.crf_utils import ( - BEGINNING_PREFIX, INSIDE_PREFIX, TaggingScheme) +from snips_nlu.slot_filler.crf_slot_filler import CRFSlotFiller, _ensure_safe +from snips_nlu.slot_filler.crf_utils import TaggingScheme from snips_nlu.slot_filler.feature_factory import ( IsDigitFactory, NgramFactory, ShapeNgramFactory) from snips_nlu.tests.utils import FixtureTest, TEST_PATH @@ -76,21 +70,57 @@ def test_should_get_builtin_slots(self): slot_filler.fit(dataset, intent) # When - slots = slot_filler.get_slots("Give me the weather at 9p.m. in Paris") + slots = slot_filler.get_slots("Give me the weather at 9pm in Paris") # Then expected_slots = [ - unresolved_slot(match_range={START: 20, END: 28}, - value='at 9p.m.', + unresolved_slot(match_range={START: 20, END: 26}, + value='at 9pm', entity='snips/datetime', slot_name='datetime'), - unresolved_slot(match_range={START: 32, END: 37}, + unresolved_slot(match_range={START: 30, END: 35}, value='Paris', entity='weather_location', slot_name='location') ] self.assertListEqual(expected_slots, slots) + def test_should_get_sub_builtin_slots(self): + # Given + dataset_stream = io.StringIO(""" +--- +type: intent +name: PlanBreak +utterances: +- 'I want to leave from [start:snips/datetime](tomorrow) until + [end:snips/datetime](next thursday)' +- find me something from [start](9am) to [end](12pm) +- I need a break from [start](2pm) until [end](4pm) +- Can you suggest something from [start](april 4th) until [end](april 6th) ? +- Book me a trip from [start](this friday) to [end](next tuesday)""") + dataset = Dataset.from_yaml_files("en", [dataset_stream]).json + config = CRFSlotFillerConfig(random_seed=42) + intent = "PlanBreak" + slot_filler = CRFSlotFiller(config, + **self.get_shared_data(dataset)) + slot_filler.fit(dataset, intent) + + # When + slots = slot_filler.get_slots("Find me a plan from 5pm to 6pm") + + # Then + expected_slots = [ + unresolved_slot(match_range={START: 20, END: 23}, + value="5pm", + entity="snips/datetime", + slot_name="start"), + unresolved_slot(match_range={START: 27, END: 30}, + value="6pm", + entity="snips/datetime", + slot_name="end") + ] + self.assertListEqual(expected_slots, slots) + def test_should_not_use_crf_when_dataset_with_no_slots(self): # Given dataset = { @@ -698,269 +728,6 @@ def test_should_compute_features(self): ] self.assertListEqual(expected_features, features_with_drop_out) - def test_spans_to_tokens_indexes(self): - # Given - spans = [ - {START: 0, END: 1}, - {START: 2, END: 6}, - {START: 5, END: 6}, - {START: 9, END: 15} - ] - tokens = [ - Token(value="abc", start=0, end=3), - Token(value="def", start=4, end=7), - Token(value="ghi", start=10, end=13) - ] - - # When - indexes = _spans_to_tokens_indexes(spans, tokens) - - # Then - expected_indexes = [[0], [0, 1], [1], [2]] - self.assertListEqual(indexes, expected_indexes) - - def test_augment_slots(self): - # Given - language = LANGUAGE_EN - text = "Find me a flight before 10pm and after 8pm" - tokens = tokenize(text, language) - missing_slots = {"start_date", "end_date"} - - tags = ['O' for _ in tokens] - - def mocked_sequence_probability(_, tags_): - tags_1 = ['O', - 'O', - 'O', - 'O', - '%sstart_date' % BEGINNING_PREFIX, - '%sstart_date' % INSIDE_PREFIX, - 'O', - '%send_date' % BEGINNING_PREFIX, - '%send_date' % INSIDE_PREFIX] - - tags_2 = ['O', - 'O', - 'O', - 'O', - '%send_date' % BEGINNING_PREFIX, - '%send_date' % INSIDE_PREFIX, - 'O', - '%sstart_date' % BEGINNING_PREFIX, - '%sstart_date' % INSIDE_PREFIX] - - tags_3 = ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'] - - tags_4 = ['O', - 'O', - 'O', - 'O', - 'O', - 'O', - 'O', - '%sstart_date' % BEGINNING_PREFIX, - '%sstart_date' % INSIDE_PREFIX] - - tags_5 = ['O', - 'O', - 'O', - 'O', - 'O', - 'O', - 'O', - '%send_date' % BEGINNING_PREFIX, - '%send_date' % INSIDE_PREFIX] - - tags_6 = ['O', - 'O', - 'O', - 'O', - '%sstart_date' % BEGINNING_PREFIX, - '%sstart_date' % INSIDE_PREFIX, - 'O', - 'O', - 'O'] - - tags_7 = ['O', - 'O', - 'O', - 'O', - '%send_date' % BEGINNING_PREFIX, - '%send_date' % INSIDE_PREFIX, - 'O', - 'O', - 'O'] - - tags_8 = ['O', - 'O', - 'O', - 'O', - '%sstart_date' % BEGINNING_PREFIX, - '%sstart_date' % INSIDE_PREFIX, - 'O', - '%sstart_date' % BEGINNING_PREFIX, - '%sstart_date' % INSIDE_PREFIX] - - tags_9 = ['O', - 'O', - 'O', - 'O', - '%send_date' % BEGINNING_PREFIX, - '%send_date' % INSIDE_PREFIX, - 'O', - '%send_date' % BEGINNING_PREFIX, - '%send_date' % INSIDE_PREFIX] - - if tags_ == tags_1: - return 0.6 - elif tags_ == tags_2: - return 0.8 - elif tags_ == tags_3: - return 0.2 - elif tags_ == tags_4: - return 0.2 - elif tags_ == tags_5: - return 0.99 - elif tags_ == tags_6: - return 0.0 - elif tags_ == tags_7: - return 0.0 - elif tags_ == tags_8: - return 0.5 - elif tags_ == tags_9: - return 0.5 - else: - raise ValueError("Unexpected tag sequence: %s" % tags_) - - slot_filler_config = CRFSlotFillerConfig(random_seed=42) - slot_filler = CRFSlotFiller( - config=slot_filler_config, - builtin_entity_parser=BuiltinEntityParser.build(language="en")) - slot_filler.language = LANGUAGE_EN - slot_filler.intent = "intent1" - slot_filler.slot_name_mapping = { - "start_date": "snips/datetime", - "end_date": "snips/datetime", - } - - # pylint:disable=protected-access - slot_filler._get_sequence_probability = MagicMock( - side_effect=mocked_sequence_probability) - # pylint:enable=protected-access - - slot_filler.compute_features = MagicMock(return_value=None) - - # When - # pylint: disable=protected-access - augmented_slots = slot_filler._augment_slots(text, tokens, tags, - missing_slots) - # pylint: enable=protected-access - - # Then - expected_slots = [ - unresolved_slot(value='after 8pm', - match_range={START: 33, END: 42}, - entity='snips/datetime', slot_name='end_date') - ] - self.assertListEqual(augmented_slots, expected_slots) - - def test_filter_overlapping_builtins(self): - # Given - language = LANGUAGE_EN - text = "Find me a flight before 10pm and after 8pm" - tokens = tokenize(text, language) - tags = ['O' for _ in range(5)] + ['B-flight'] + ['O' for _ in range(3)] - tagging_scheme = TaggingScheme.BIO - builtin_entities = [ - { - RES_MATCH_RANGE: {START: 17, END: 28}, - VALUE: "before 10pm", - ENTITY_KIND: SNIPS_DATETIME - }, - { - RES_MATCH_RANGE: {START: 33, END: 42}, - VALUE: "after 8pm", - ENTITY_KIND: SNIPS_DATETIME - } - ] - - # When - entities = _filter_overlapping_builtins(builtin_entities, tokens, tags, - tagging_scheme) - - # Then - expected_entities = [ - { - RES_MATCH_RANGE: {START: 33, END: 42}, - VALUE: "after 8pm", - ENTITY_KIND: SNIPS_DATETIME - } - ] - self.assertEqual(entities, expected_entities) - - def test_should_disambiguate_builtin_entities(self): - # Given - builtin_entities = [ - {RES_MATCH_RANGE: {START: 7, END: 10}}, - {RES_MATCH_RANGE: {START: 9, END: 15}}, - {RES_MATCH_RANGE: {START: 10, END: 17}}, - {RES_MATCH_RANGE: {START: 12, END: 19}}, - {RES_MATCH_RANGE: {START: 9, END: 15}}, - {RES_MATCH_RANGE: {START: 0, END: 5}}, - {RES_MATCH_RANGE: {START: 0, END: 5}}, - {RES_MATCH_RANGE: {START: 0, END: 8}}, - {RES_MATCH_RANGE: {START: 2, END: 5}}, - {RES_MATCH_RANGE: {START: 0, END: 8}}, - ] - - # When - disambiguated_entities = _disambiguate_builtin_entities( - builtin_entities) - - # Then - expected_entities = [ - {RES_MATCH_RANGE: {START: 0, END: 8}}, - {RES_MATCH_RANGE: {START: 0, END: 8}}, - {RES_MATCH_RANGE: {START: 10, END: 17}}, - ] - - self.assertListEqual(expected_entities, disambiguated_entities) - - def test_generate_slots_permutations(self): - # Given - slot_name_mapping = { - "start_date": "snips/datetime", - "end_date": "snips/datetime", - "temperature": "snips/temperature" - } - grouped_entities = [ - [ - {ENTITY_KIND: "snips/datetime"}, - {ENTITY_KIND: "snips/temperature"} - ], - [ - {ENTITY_KIND: "snips/temperature"} - ] - ] - - # When - slots_permutations = set( - "||".join(perm) for perm in - _get_slots_permutations(grouped_entities, slot_name_mapping)) - - # Then - expected_permutations = { - "start_date||temperature", - "end_date||temperature", - "temperature||temperature", - "O||temperature", - "start_date||O", - "end_date||O", - "temperature||O", - "O||O", - } - self.assertSetEqual(expected_permutations, slots_permutations) - def test_should_fit_and_parse_empty_intent(self): # Given dataset = {