Merge pull request #753 from snipsco/release/0.19.2
Release 0.19.2
adrienball authored Feb 11, 2019
2 parents 343cf2e + 3caa419 commit 04087e1
Showing 5 changed files with 65 additions and 453 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,6 +1,10 @@
# Changelog
All notable changes to this project will be documented in this file.

## [0.19.2] - 2019-02-11
### Fixed
- Fix an issue regarding the way builtin entities were handled by the `CRFSlotFiller`

## [0.19.1] - 2019-02-04
### Fixed
- Bug causing an unnecessary reloading of shared resources
@@ -232,6 +236,7 @@ several commands.
- Fix compiling issue with `bindgen` dependency when installing from source
- Fix issue in `CRFSlotFiller` when handling builtin entities

[0.19.2]: https://github.com/snipsco/snips-nlu/compare/0.19.1...0.19.2
[0.19.1]: https://github.com/snipsco/snips-nlu/compare/0.19.0...0.19.1
[0.19.0]: https://github.com/snipsco/snips-nlu/compare/0.18.0...0.19.0
[0.18.0]: https://github.com/snipsco/snips-nlu/compare/0.17.4...0.18.0
2 changes: 1 addition & 1 deletion snips_nlu/__about__.py
@@ -11,7 +11,7 @@
__email__ = "clement.doumouro@snips.ai, adrien.ball@snips.ai"
__license__ = "Apache License, Version 2.0"

__version__ = "0.19.1"
__version__ = "0.19.2"
__model_version__ = "0.19.0"

__download_url__ = "https://github.com/snipsco/snips-nlu-language-resources/releases/download"
10 changes: 9 additions & 1 deletion snips_nlu/cli/metrics.py
@@ -66,6 +66,7 @@ def progression_handler(progress):
nb_folds=nb_folds,
train_size_ratio=train_size_ratio,
include_slot_metrics=not exclude_slot_metrics,
slot_matching_lambda=_match_trimmed_values
)

from snips_nlu_metrics import compute_cross_val_metrics
@@ -108,7 +109,8 @@ def train_test_metrics(train_dataset_path, test_dataset_path, output_path,
train_dataset=train_dataset_path,
test_dataset=test_dataset_path,
engine_class=engine_cls,
-        include_slot_metrics=not exclude_slot_metrics
+        include_slot_metrics=not exclude_slot_metrics,
+        slot_matching_lambda=_match_trimmed_values
)

from snips_nlu_metrics import compute_train_test_metrics
@@ -119,3 +121,9 @@ def train_test_metrics(train_dataset_path, test_dataset_path, output_path,

with Path(output_path).open(mode="w", encoding="utf8") as f:
f.write(json_string(metrics))


def _match_trimmed_values(lhs_slot, rhs_slot):
lhs_value = lhs_slot["text"].strip()
rhs_value = rhs_slot["rawValue"].strip()
return lhs_value == rhs_value
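
The helper above is wired into both metrics commands through the new `slot_matching_lambda` argument, so that an expected slot and a predicted slot count as a match when their whitespace-trimmed values are equal. Below is a minimal usage sketch against `snips_nlu_metrics` directly; the dataset paths and the choice of `SnipsNLUEngine` as `engine_class` are illustrative assumptions, while the keyword arguments mirror the calls shown in this diff.

from snips_nlu import SnipsNLUEngine  # assumed engine class for this sketch
from snips_nlu_metrics import compute_train_test_metrics


def match_trimmed_values(lhs_slot, rhs_slot):
    # Same comparison as _match_trimmed_values above: ignore surrounding whitespace
    return lhs_slot["text"].strip() == rhs_slot["rawValue"].strip()


metrics = compute_train_test_metrics(
    train_dataset="train_dataset.json",  # hypothetical path
    test_dataset="test_dataset.json",    # hypothetical path
    engine_class=SnipsNLUEngine,
    include_slot_metrics=True,
    slot_matching_lambda=match_trimmed_values,
)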
178 changes: 5 additions & 173 deletions snips_nlu/slot_filler/crf_slot_filler.py
@@ -7,8 +7,6 @@
import shutil
import tempfile
from builtins import range
-from copy import copy
-from itertools import groupby, product
from pathlib import Path

from future.utils import iteritems
@@ -20,20 +18,16 @@
from snips_nlu.common.log_utils import DifferedLoggingMessage, log_elapsed_time
from snips_nlu.common.utils import (
check_persisted_path,
-    check_random_state, fitted_required, json_string,
-    ranges_overlap)
+    check_random_state, fitted_required, json_string)
from snips_nlu.constants import (
-    DATA, END, ENTITY_KIND, LANGUAGE, RES_ENTITY,
-    RES_MATCH_RANGE, RES_VALUE, START)
+    DATA, LANGUAGE)
from snips_nlu.data_augmentation import augment_utterances
from snips_nlu.dataset import validate_and_format_dataset
-from snips_nlu.entity_parser.builtin_entity_parser import is_builtin_entity
from snips_nlu.exceptions import LoadingError
from snips_nlu.pipeline.configs import CRFSlotFillerConfig
from snips_nlu.preprocessing import tokenize
from snips_nlu.slot_filler.crf_utils import (
-    OUTSIDE, TAGS, TOKENS, positive_tagging, tag_name_to_slot_name,
-    tags_to_preslots, tags_to_slots, utterance_to_sample)
+    OUTSIDE, TAGS, TOKENS, tags_to_slots, utterance_to_sample)
from snips_nlu.slot_filler.feature import TOKEN_NAME
from snips_nlu.slot_filler.feature_factory import CRFFeatureFactory
from snips_nlu.slot_filler.slot_filler import SlotFiller
@@ -186,18 +180,8 @@ def get_slots(self, text):
features = self.compute_features(tokens)
tags = [_decode_tag(tag) for tag in
self.crf_model.predict_single(features)]
-        slots = tags_to_slots(text, tokens, tags, self.config.tagging_scheme,
-                              self.slot_name_mapping)
-
-        builtin_slots_names = set(slot_name for (slot_name, entity) in
-                                  iteritems(self.slot_name_mapping)
-                                  if is_builtin_entity(entity))
-        if not builtin_slots_names:
-            return slots
-
-        # Replace tags corresponding to builtin entities by outside tags
-        tags = _replace_builtin_tags(tags, builtin_slots_names)
-        return self._augment_slots(text, tokens, tags, builtin_slots_names)
+        return tags_to_slots(text, tokens, tags, self.config.tagging_scheme,
+                             self.slot_name_mapping)

def compute_features(self, tokens, drop_out=False):
"""Computes features on the provided tokens
@@ -280,60 +264,6 @@ def log_weights(self):
log += "\n%s %s: %s" % (feat, _decode_tag(tag), weight)
return log

def _augment_slots(self, text, tokens, tags, builtin_slots_names):
scope = set(self.slot_name_mapping[slot]
for slot in builtin_slots_names)
builtin_entities = [
be for entity_kind in scope for be in
self.builtin_entity_parser.parse(text, scope=[entity_kind],
use_cache=True)]
        # We remove builtin entities which conflict with custom slots
        # extracted by the CRF
builtin_entities = _filter_overlapping_builtins(
builtin_entities, tokens, tags, self.config.tagging_scheme)

        # We resolve conflicts between builtin entities by keeping the longest
        # matches. In cases where two builtin entities span the same range, we
        # keep both.
builtin_entities = _disambiguate_builtin_entities(builtin_entities)

# We group builtin entities based on their position
grouped_entities = (
list(bes)
for _, bes in groupby(builtin_entities,
key=lambda s: s[RES_MATCH_RANGE][START]))
grouped_entities = sorted(
grouped_entities,
key=lambda entities: entities[0][RES_MATCH_RANGE][START])

features = self.compute_features(tokens)
spans_ranges = [entities[0][RES_MATCH_RANGE]
for entities in grouped_entities]
tokens_indexes = _spans_to_tokens_indexes(spans_ranges, tokens)

        # We loop over all possible slot permutations and use the CRF to find
        # the best one in terms of probability
slots_permutations = _get_slots_permutations(
grouped_entities, self.slot_name_mapping)
best_updated_tags = tags
best_permutation_score = -1
for slots in slots_permutations:
updated_tags = copy(tags)
for slot_index, slot in enumerate(slots):
indexes = tokens_indexes[slot_index]
sub_tags_sequence = positive_tagging(
self.config.tagging_scheme, slot, len(indexes))
updated_tags[indexes[0]:indexes[-1] + 1] = sub_tags_sequence
score = self._get_sequence_probability(features, updated_tags)
if score > best_permutation_score:
best_updated_tags = updated_tags
best_permutation_score = score
slots = tags_to_slots(text, tokens, best_updated_tags,
self.config.tagging_scheme,
self.slot_name_mapping)

return _reconciliate_builtin_slots(text, slots, builtin_entities)

@check_persisted_path
def persist(self, path):
"""Persists the object at the given path"""
@@ -404,104 +334,6 @@ def _get_crf_model(crf_args):
return CRF(model_filename=model_filename, **crf_args)


def _replace_builtin_tags(tags, builtin_slot_names):
new_tags = []
for tag in tags:
if tag == OUTSIDE:
new_tags.append(tag)
else:
slot_name = tag_name_to_slot_name(tag)
if slot_name in builtin_slot_names:
new_tags.append(OUTSIDE)
else:
new_tags.append(tag)
return new_tags


def _filter_overlapping_builtins(builtin_entities, tokens, tags,
tagging_scheme):
slots = tags_to_preslots(tokens, tags, tagging_scheme)
ents = []
for ent in builtin_entities:
if any(ranges_overlap(ent[RES_MATCH_RANGE], s[RES_MATCH_RANGE])
for s in slots):
continue
ents.append(ent)
return ents


def _spans_to_tokens_indexes(spans, tokens):
tokens_indexes = []
for span in spans:
indexes = []
for i, token in enumerate(tokens):
if span[END] > token.start and span[START] < token.end:
indexes.append(i)
tokens_indexes.append(indexes)
return tokens_indexes


def _reconciliate_builtin_slots(text, slots, builtin_entities):
for slot in slots:
if not is_builtin_entity(slot[RES_ENTITY]):
continue
for be in builtin_entities:
if be[ENTITY_KIND] != slot[RES_ENTITY]:
continue
be_start = be[RES_MATCH_RANGE][START]
be_end = be[RES_MATCH_RANGE][END]
be_length = be_end - be_start
slot_start = slot[RES_MATCH_RANGE][START]
slot_end = slot[RES_MATCH_RANGE][END]
slot_length = slot_end - slot_start
if be_start <= slot_start and be_end >= slot_end \
and be_length > slot_length:
slot[RES_MATCH_RANGE] = {
START: be_start,
END: be_end
}
slot[RES_VALUE] = text[be_start: be_end]
break
return slots


def _disambiguate_builtin_entities(builtin_entities):
if not builtin_entities:
return []
builtin_entities = sorted(
builtin_entities,
key=lambda be: be[RES_MATCH_RANGE][END] - be[RES_MATCH_RANGE][START],
reverse=True)

disambiguated_entities = [builtin_entities[0]]
for entity in builtin_entities[1:]:
entity_rng = entity[RES_MATCH_RANGE]
conflict = False
for disambiguated_entity in disambiguated_entities:
disambiguated_entity_rng = disambiguated_entity[RES_MATCH_RANGE]
if ranges_overlap(entity_rng, disambiguated_entity_rng):
conflict = True
if entity_rng == disambiguated_entity_rng:
disambiguated_entities.append(entity)
break
if not conflict:
disambiguated_entities.append(entity)

return sorted(disambiguated_entities,
key=lambda be: be[RES_MATCH_RANGE][START])


def _get_slots_permutations(grouped_entities, slot_name_mapping):
    # We associate each group of entities with the list of slot names that
    # could correspond to it
possible_slots = [
list(set(slot_name for slot_name, ent in iteritems(slot_name_mapping)
for entity in entities if ent == entity[ENTITY_KIND]))
+ [OUTSIDE]
for entities in grouped_entities]
return product(*possible_slots)


def _encode_tag(tag):
return base64.b64encode(tag.encode("utf8"))

