Rasa Feature Extraction: CountVectorsFeaturizer

Source file: rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py

train: the training process

_get_featurized_attribute: builds features for each attribute

process: the prediction step

Source code (annotated)
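Before walking through the source, here is a minimal standalone sketch of the sklearn CountVectorizer behavior this featurizer builds on (plain sklearn with made-up example texts, not Rasa code):

from sklearn.feature_extraction.text import CountVectorizer

# the same token pattern Rasa passes for analyzer="word"; unlike sklearn's
# default r"(?u)\b\w\w+\b", it keeps single-character tokens
vectorizer = CountVectorizer(token_pattern=r"(?u)\b\w+\b", lowercase=True)

texts = ["check the weather today", "weather in Beijing today"]
matrix = vectorizer.fit_transform(texts)  # sparse matrix [n_samples, n_features]

print(vectorizer.vocabulary_)  # token -> column index
print(matrix.toarray())        # per-document token counts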

class CountVectorsFeaturizer(SparseFeaturizer):
    """基于 sklearn 的 `CountVectorizer` 创建一系列令牌计数功能.
    所有仅由数字组成的标记(例如 123 和 99 但不包括 ab12d)将由单个特征表示。
    """

    @classmethod
    def required_components(cls) -> List[Type[Component]]:
        return [Tokenizer]

    defaults = {
        # whether to use a shared vocab
        "use_shared_vocab": False,
        # the parameters below are taken from sklearn's CountVectorizer
        #     whether to use word or character n-grams
        #     'char_wb' creates character n-grams inside word boundaries
        #     n-grams at the edges of words are padded with space
        "analyzer": "word",  # use 'char' or 'char_wb' for character n-grams
        # remove accents during the preprocessing step
        "strip_accents": None,  # {'ascii', 'unicode', None}
        # list of stop words
        "stop_words": None,  # string {'english'}, list, or None (default)
        # min document frequency of a word to be added to the vocabulary
        # float - the parameter represents a proportion of documents
        # integer - absolute counts
        "min_df": 1,  # float in range [0.0, 1.0] or int
        # max document frequency of a word to be added to the vocabulary
        # float - the parameter represents a proportion of documents
        # integer - absolute counts
        "max_df": 1.0,  # float in range [0.0, 1.0] or int
        # set the range of n-grams to be extracted
        "min_ngram": 1,  # int
        "max_ngram": 1,  # int
        # limit vocabulary size
        "max_features": None,  # int or None
        # if convert all characters to lowercase
        "lowercase": True,  # bool
        # handling Out-Of-Vocabulary (OOV) words
        # will be converted to lowercase if lowercase is True
        "OOV_token": None,  # string or None
        "OOV_words": [],  # string or list of strings
        # indicates whether the featurizer should use the lemma of a word for
        # counting (if available) or not
        "use_lemma": True,
        # additional vocabulary size to keep reserved for incremental training
        "additional_vocabulary_size": {TEXT: None, RESPONSE: None, ACTION_TEXT: None},
    }

    @classmethod
    def required_packages(cls) -> List[Text]:
        return ["sklearn"]

    def _load_count_vect_params(self) -> None:

        # use a shared vocabulary between text and all other attributes of Message
        self.use_shared_vocab = self.component_config["use_shared_vocab"]

        # set analyzer
        self.analyzer = self.component_config["analyzer"]

        # remove accents during the preprocessing step
        self.strip_accents = self.component_config["strip_accents"]

        # list of stop words
        self.stop_words = self.component_config["stop_words"]

        # min number of word occurrences in the document to add to vocabulary
        self.min_df = self.component_config["min_df"]

        # max number (fraction if float) of word occurrences in the document
        # to add to vocabulary
        self.max_df = self.component_config["max_df"]

        # set ngram range
        self.min_ngram = self.component_config["min_ngram"]
        self.max_ngram = self.component_config["max_ngram"]

        # limit vocabulary size
        self.max_features = self.component_config["max_features"]

        # if convert all characters to lowercase
        self.lowercase = self.component_config["lowercase"]

        # use the lemma of the words or not
        self.use_lemma = self.component_config["use_lemma"]

    def _load_vocabulary_params(self) -> None:
        self.OOV_token = self.component_config["OOV_token"]

        self.OOV_words = self.component_config["OOV_words"]
        if self.OOV_words and not self.OOV_token:
            logger.error(
                "The list OOV_words={} was given, but "
                "OOV_token was not. OOV words are ignored."
                "".format(self.OOV_words)
            )
            self.OOV_words = []

        if self.lowercase and self.OOV_token:
            # convert to lowercase
            self.OOV_token = self.OOV_token.lower()
            if self.OOV_words:
                self.OOV_words = [w.lower() for w in self.OOV_words]

        # additional vocabulary size to keep reserved
        self.additional_vocabulary_size = self.component_config[
            "additional_vocabulary_size"
        ]

    def _check_attribute_vocabulary(self, attribute: Text) -> bool:
        """检查训练词汇是否存在于属性的计数向量化器中."""
        try:
            return hasattr(self.vectorizers[attribute], "vocabulary_")
        except (AttributeError, TypeError):
            return False

    def _get_attribute_vocabulary(self, attribute: Text) -> Optional[Dict[Text, int]]:
        """从属性计数向量器中获取训练词汇"""

        try:
            return self.vectorizers[attribute].vocabulary_
        except (AttributeError, TypeError):
            return None

    def _get_attribute_vocabulary_tokens(self, attribute: Text) -> Optional[List[Text]]:
        """获取一个属性的词汇表的所有键"""

        attribute_vocabulary = self._get_attribute_vocabulary(attribute)
        try:
            return list(attribute_vocabulary.keys())
        except TypeError:
            return None

    def _check_analyzer(self) -> None:
        if self.analyzer != "word":
            if self.OOV_token is not None:
                logger.warning(
                    "Analyzer is set to character, "
                    "provided OOV word token will be ignored."
                )
            if self.stop_words is not None:
                logger.warning(
                    "Analyzer is set to character, "
                    "provided stop words will be ignored."
                )
            if self.max_ngram == 1:
                logger.warning(
                    "Analyzer is set to character, "
                    "but max n-gram is set to 1. "
                    "It means that the vocabulary will "
                    "contain single letters only."
                )

    @staticmethod
    def _attributes_for(analyzer: Text) -> List[Text]:
        """创建应该特征化的属性列表."""

        # intents should be featurized only by word level count vectorizer
        return (
            MESSAGE_ATTRIBUTES if analyzer == "word" else DENSE_FEATURIZABLE_ATTRIBUTES
        )

    def __init__(
        self,
        component_config: Optional[Dict[Text, Any]] = None,
        vectorizers: Optional[Dict[Text, "CountVectorizer"]] = None,
        finetune_mode: bool = False,
    ) -> None:
        """Construct a new count vectorizer using the sklearn framework."""
        super().__init__(component_config)

        # parameters for sklearn's CountVectorizer
        self._load_count_vect_params()

        # handling Out-Of-Vocabulary (OOV) words
        self._load_vocabulary_params()

        # warn that some of config parameters might be ignored
        self._check_analyzer()

        # set which attributes to featurize
        self._attributes = self._attributes_for(self.analyzer)

        # declare class instance for CountVectorizer
        self.vectorizers = vectorizers

        self.finetune_mode = finetune_mode

    def _get_message_tokens_by_attribute(
        self, message: "Message", attribute: Text
    ) -> List[Text]:
        """Get text tokens of an attribute of a message 获取消息属性的文本标记"""
        if message.get(TOKENS_NAMES[attribute]):
            return [
                t.lemma if self.use_lemma else t.text
                for t in message.get(TOKENS_NAMES[attribute])
            ]
        else:
            return []

    def _process_tokens(self, tokens: List[Text], attribute: Text = TEXT) -> List[Text]:
        """对文本应用处理和清理步骤"""

        if attribute in [INTENT, ACTION_NAME, INTENT_RESPONSE_KEY]:
            # Don't do any processing for intent attribute. Treat them as whole labels
            return tokens

        # replace all digits with the NUMBER token
        tokens = [re.sub(r"\b[0-9]+\b", "__NUMBER__", text) for text in tokens]

        # convert to lowercase if necessary
        if self.lowercase:
            tokens = [text.lower() for text in tokens]

        return tokens
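    # For example (hypothetical input), with lowercase=True:
    #   _process_tokens(["Hello", "123", "ab12d"], TEXT)
    #   -> ["hello", "__number__", "ab12d"]
    # (the NUMBER token ends up lowercased too, since lowercasing runs last)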

    def _replace_with_oov_token(
        self, tokens: List[Text], attribute: Text
    ) -> List[Text]:
        """Replace OOV words with OOV token"""

        if self.OOV_token and self.analyzer == "word":
            vocabulary_exists = self._check_attribute_vocabulary(attribute)
            if vocabulary_exists and self.OOV_token in self._get_attribute_vocabulary(
                attribute
            ):
                # the CountVectorizer is trained: prediction-time path
                tokens = [
                    t
                    if t in self._get_attribute_vocabulary_tokens(attribute)
                    else self.OOV_token
                    for t in tokens
                ]
            elif self.OOV_words:
                # the CountVectorizer is not trained yet: training-time path
                tokens = [self.OOV_token if t in self.OOV_words else t for t in tokens]

        return tokens
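    # For example (hypothetical): with OOV_token="oov" and a trained vocabulary
    # that contains "oov", the tokens ["today", "snow"] become ["today", "oov"]
    # if "snow" is not in the vocabulary; before training, only the tokens
    # listed in OOV_words are replaced.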

    def _get_processed_message_tokens_by_attribute(
        self, message: Message, attribute: Text = TEXT
    ) -> List[Text]:
        """获取消息属性的处理文本"""

        if message.get(attribute) is None:
            # return an empty list, since sklearn's CountVectorizer does not
            # like None objects at train and predict time
            return []

        tokens = self._get_message_tokens_by_attribute(message, attribute) # ['今天', '上海', '的', '天气', '怎么样']
        tokens = self._process_tokens(tokens, attribute)
        tokens = self._replace_with_oov_token(tokens, attribute)

        return tokens

    # noinspection PyPep8Naming
    def _check_OOV_present(self, all_tokens: List[List[Text]], attribute: Text) -> None:
        """Check if an OOV word is present"""
        if not self.OOV_token or self.OOV_words or not all_tokens:
            return

        for tokens in all_tokens:
            for text in tokens:
                if self.OOV_token in text or (
                    self.lowercase and self.OOV_token in text.lower()
                ):
                    return

        if any(text for tokens in all_tokens for text in tokens):
            training_data_type = "NLU" if attribute == TEXT else "ResponseSelector"

            # if there is some text in tokens, warn if there is no oov token
            rasa.shared.utils.io.raise_warning(
                f"The out of vocabulary token '{self.OOV_token}' was configured, but "
                f"could not be found in any one of the {training_data_type} "
                f"training examples. All unseen words will be ignored during prediction.",
                docs=DOCS_URL_COMPONENTS + "#countvectorsfeaturizer",
            )

    def _get_all_attributes_processed_tokens(
        self, training_data: TrainingData
    ) -> Dict[Text, List[List[Text]]]:
        """获取训练数据中所有示例属性的处理文本"""

        processed_attribute_tokens = {}
        for attribute in self._attributes: # ['text', 'intent', 'response', 'action_name', 'action_text', 'intent_response_key']
            all_tokens = [
                self._get_processed_message_tokens_by_attribute(example, attribute)
                for example in training_data.training_examples
            ]
            if attribute in DENSE_FEATURIZABLE_ATTRIBUTES:
                # check for oov tokens only in text based attributes
                self._check_OOV_present(all_tokens, attribute)
            processed_attribute_tokens[attribute] = all_tokens

        return processed_attribute_tokens

    @staticmethod
    def _convert_attribute_tokens_to_texts(
        attribute_tokens: Dict[Text, List[List[Text]]]
    ) -> Dict[Text, List[Text]]:
        attribute_texts = {}

        for attribute in attribute_tokens.keys():
            list_of_tokens = attribute_tokens[attribute]
            attribute_texts[attribute] = [" ".join(tokens) for tokens in list_of_tokens]  # join tokens with spaces into one string

        return attribute_texts
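    # For example: {"text": [["今天", "天气"], []]} -> {"text": ["今天 天气", ""]}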

    @staticmethod
    def _get_starting_empty_index(vocabulary: Dict[Text, int]) -> int:
        for key in vocabulary.keys():
            if key.startswith(BUFFER_SLOTS_PREFIX):
                return int(key.split(BUFFER_SLOTS_PREFIX)[1])
        return len(vocabulary)
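    # For example: {"weather": 0, "today": 1, "buf_2": 2, "buf_3": 3} -> 2;
    # a vocabulary without any buffer slots returns its full length.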

    def _update_vectorizer_vocabulary(
        self, attribute: Text, new_vocabulary: Set[Text]
    ) -> None:
        """使用新的不可见单词更新矢量器的现有词汇表。这些看不见的字应该只占据空的缓冲槽。
        """
        existing_vocabulary: Dict[Text, int] = self.vectorizers[attribute].vocabulary
        if len(new_vocabulary) > len(existing_vocabulary):
            rasa.shared.utils.io.raise_warning(
                f"New data contains vocabulary of size {len(new_vocabulary)} for attribute {attribute} "
                f"which is larger than the maximum vocabulary size({len(existing_vocabulary)}) "
                f"of the original model. Some tokens will have to be dropped "
                f"in order to continue training. It is advised to re-train the "
                f"model from scratch on the complete data."
            )
        self._merge_new_vocabulary_tokens(existing_vocabulary, new_vocabulary)
        self._set_vocabulary(attribute, existing_vocabulary)

    def _merge_new_vocabulary_tokens(
        self, existing_vocabulary: Dict[Text, int], vocabulary: Set[Text]
    ) -> None:
        available_empty_index = self._get_starting_empty_index(existing_vocabulary)
        for token in vocabulary:
            if token not in existing_vocabulary:
                existing_vocabulary[token] = available_empty_index
                del existing_vocabulary[f"{BUFFER_SLOTS_PREFIX}{available_empty_index}"]
                available_empty_index += 1
                if available_empty_index == len(existing_vocabulary):
                    # We have exhausted all available vocabulary slots.
                    # Drop the remaining vocabulary.
                    return
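    # For example (hypothetical): merging {"snow"} into
    # {"weather": 0, "buf_1": 1, "buf_2": 2} gives
    # {"weather": 0, "snow": 1, "buf_2": 2}; the new token takes over the
    # index of the first buffer slot and that buffer key is removed.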

    def _get_additional_vocabulary_size(
        self, attribute: Text, existing_vocabulary_size: int
    ) -> int:
        """获取要为增量训练保存的额外词汇表大小。

        如果为“self.additional_vocabulary_size`”不是“None”,我们将返回,因为用户应该指定大小。如果不是,则我们采取默认的额外词汇量,即当前词汇量的1/2。
        """
    # 目前不支持 INENTS、ACTION_NAME 和 INTENT_RESPONSE_KEY 的词汇扩展,因为增量训练不支持创建/删除新/现有标签(intents, actions等)
        if attribute not in DENSE_FEATURIZABLE_ATTRIBUTES:
            return 0

        configured_additional_size = self.additional_vocabulary_size.get(attribute)
        if configured_additional_size is not None:
            return configured_additional_size

        # If the user hasn't defined additional vocabulary size,
        # then we increase it by 1000 minimum. If the current
        # vocabulary size is greater than 2000, we take half of
        # that number as additional vocabulary size.
        return max(MIN_ADDITIONAL_CVF_VOCABULARY, int(existing_vocabulary_size * 0.5))
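    # For example: with MIN_ADDITIONAL_CVF_VOCABULARY = 1000, an existing
    # vocabulary of 500 tokens reserves 1000 extra slots, while one of
    # 3000 tokens reserves 1500.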

    def _add_buffer_to_vocabulary(self, attribute: Text) -> None:
        """ 为词汇添加额外的标记以进行增量训练。
        这些额外的tokens充当缓冲槽,当作为增量训练的一部分接收到更多数据时,它们会按顺序用完。 这些标记中的每一个都以前缀“buf_”开头,后跟额外的插槽索引。 例如 - buf_1、buf_2、buf_3...等等。
        
        """
        original_vocabulary = self.vectorizers[attribute].vocabulary_
        current_vocabulary_size = len(original_vocabulary)
        for index in range(
            current_vocabulary_size,
            current_vocabulary_size
            + self._get_additional_vocabulary_size(attribute, current_vocabulary_size),
        ):
            original_vocabulary[f"{BUFFER_SLOTS_PREFIX}{index}"] = index
        self._set_vocabulary(attribute, original_vocabulary)
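    # For example: a freshly fitted vocabulary {"weather": 0, "today": 1} with
    # 1000 additional slots becomes
    # {"weather": 0, "today": 1, "buf_2": 2, ..., "buf_1001": 1001}.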

    def _set_vocabulary(
        self, attribute: Text, original_vocabulary: Dict[Text, int]
    ) -> None:
        """Sets the vocabulary of the vectorizer of attribute."""
        self.vectorizers[attribute].vocabulary_ = original_vocabulary
        self.vectorizers[attribute]._validate_vocabulary()

    @staticmethod
    def _construct_vocabulary_from_texts(
        vectorizer: CountVectorizer, texts: List[Text]
    ) -> Set:
        """Applies vectorizer's preprocessor on texts to get the vocabulary from texts"""
        analyzer = vectorizer.build_analyzer()
        vocabulary_words = set()
        for example in texts:
            example_vocabulary: List[Text] = analyzer(example)
            vocabulary_words.update(example_vocabulary)
        return vocabulary_words

    @staticmethod
    def _attribute_texts_is_non_empty(attribute_texts: List[Text]) -> bool:
        return any(attribute_texts)

    def _train_with_shared_vocab(self, attribute_texts: Dict[Text, List[Text]]) -> None:
        """构建矢量化器并使用共享词汇对其进行训练。"""
        combined_cleaned_texts = []
        for attribute in self._attributes:
            combined_cleaned_texts += attribute_texts[attribute]

        # To train the shared vocabulary, we use TEXT as the attribute under
        # which the combined vocabulary is built.
        if not self.finetune_mode:
            self.vectorizers = self._create_shared_vocab_vectorizers(
                {
                    "strip_accents": self.strip_accents,
                    "lowercase": self.lowercase,
                    "stop_words": self.stop_words,
                    "min_ngram": self.min_ngram,
                    "max_ngram": self.max_ngram,
                    "max_df": self.max_df,
                    "min_df": self.min_df,
                    "max_features": self.max_features,
                    "analyzer": self.analyzer,
                }
            )
            self._fit_vectorizer_from_scratch(TEXT, combined_cleaned_texts)
        else:
            self._fit_loaded_vectorizer(TEXT, combined_cleaned_texts)
        self._log_vocabulary_stats(TEXT)

    def _train_with_independent_vocab(
        self, attribute_texts: Dict[Text, List[Text]]
    ) -> None:
        """Constructs the vectorizers and train them with an independent vocab."""
        if not self.finetune_mode:
            self.vectorizers = self._create_independent_vocab_vectorizers(
                {
                    "strip_accents": self.strip_accents,
                    "lowercase": self.lowercase,
                    "stop_words": self.stop_words,
                    "min_ngram": self.min_ngram,
                    "max_ngram": self.max_ngram,
                    "max_df": self.max_df,
                    "min_df": self.min_df,
                    "max_features": self.max_features,
                    "analyzer": self.analyzer,
                }
            )
        for attribute in self._attributes:
            if self._attribute_texts_is_non_empty(attribute_texts[attribute]):
                if not self.finetune_mode:
                    self._fit_vectorizer_from_scratch(
                        attribute, attribute_texts[attribute]
                    )
                else:
                    self._fit_loaded_vectorizer(attribute, attribute_texts[attribute])

                self._log_vocabulary_stats(attribute)
            else:
                logger.debug(
                    f"No text provided for {attribute} attribute in any messages of "
                    f"training data. Skipping training a CountVectorizer for it."
                )

    def _log_vocabulary_stats(self, attribute: Text) -> None:
        """Logs number of vocabulary slots filled out of the total number of available slots.
        记录在可用槽总数中填充的词汇槽数。
        Args:
            attribute: Message attribute for which vocabulary stats are logged.
        """
        if attribute in DENSE_FEATURIZABLE_ATTRIBUTES:
            attribute_vocabulary = self.vectorizers[attribute].vocabulary_
            first_empty_index = self._get_starting_empty_index(attribute_vocabulary)
            logger.info(
                f"{first_empty_index} vocabulary slots "
                f"consumed out of {len(attribute_vocabulary)} "
                f"slots configured for {attribute} attribute."
            )

    def _fit_loaded_vectorizer(
        self, attribute: Text, attribute_texts: List[Text]
    ) -> None:
        """使训练文本适合先前训练过的计数向量化器。

       我们不使用`.fit()` 方法,因为新的看不见的词应该占据词汇表的缓冲区。
        """
        # 通过预处理器获取词汇词
        new_vocabulary = self._construct_vocabulary_from_texts(
            self.vectorizers[attribute], attribute_texts
        )
        # update the vectorizer's vocabulary with the new one
        self._update_vectorizer_vocabulary(attribute, new_vocabulary)

    def _fit_vectorizer_from_scratch(
        self, attribute: Text, attribute_texts: List[Text]
    ) -> None:
        """使训练文本适合未经训练的计数向量化器。  """
        try:
            self.vectorizers[attribute].fit(attribute_texts)
        except ValueError:
            logger.warning(
                f"Unable to train CountVectorizer for message "
                f"attribute {attribute} since the call to sklearn's "
                f"`.fit()` method failed. Leaving an untrained "
                f"CountVectorizer for it."
            )
        # add a buffer for additional vocabulary tokens that
        # show up during incremental training
        self._add_buffer_to_vocabulary(attribute)

    def _create_features(
        self, attribute: Text, all_tokens: List[List[Text]]
    ) -> Tuple[
        List[Optional[scipy.sparse.spmatrix]], List[Optional[scipy.sparse.spmatrix]]
    ]:
        if not self.vectorizers.get(attribute):
            return [None], [None]

        sequence_features = []
        sentence_features = []

        for i, tokens in enumerate(all_tokens):
            if not tokens:
                # nothing to featurize
                sequence_features.append(None)
                sentence_features.append(None)
                continue

            # vectorizer.transform returns a sparse matrix of size
            # [n_samples, n_features];
            # the input is the list of tokens, so that a sequence of vectors
            # is returned; for the sentence-level feature, all tokens are
            # joined into a single string and passed as a one-element list

            seq_vec = self.vectorizers[attribute].transform(tokens)
            seq_vec.sort_indices()

            sequence_features.append(seq_vec.tocoo())

            if attribute in DENSE_FEATURIZABLE_ATTRIBUTES:
                tokens_text = [" ".join(tokens)]
                sentence_vec = self.vectorizers[attribute].transform(tokens_text)
                sentence_vec.sort_indices()

                sentence_features.append(sentence_vec.tocoo())
            else:
                sentence_features.append(None)

        return sequence_features, sentence_features
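    # For a message of N tokens and a vocabulary of V entries, the sequence
    # feature is a sparse matrix of shape (N, V) (one row per token) and the
    # sentence feature has shape (1, V) (counts over the whole message).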

    def _get_featurized_attribute(
        self, attribute: Text, all_tokens: List[List[Text]]
    ) -> Tuple[
        List[Optional[scipy.sparse.spmatrix]], List[Optional[scipy.sparse.spmatrix]]
    ]:
        """为完整数据返回特定属性的特征"""

        if self._check_attribute_vocabulary(attribute):
            # 训练计数向量化器
            return self._create_features(attribute, all_tokens)
        else:
            return [], []

    def _set_attribute_features(
        self,
        attribute: Text,
        sequence_features: List[scipy.sparse.spmatrix],
        sentence_features: List[scipy.sparse.spmatrix],
        examples: List[Message],
    ) -> None:
        """将属性的计算特征设置为相应的消息对象"""
        for i, message in enumerate(examples):
            # create bag for each example
            if sequence_features[i] is not None:
                final_sequence_features = Features(
                    sequence_features[i],
                    FEATURE_TYPE_SEQUENCE,
                    attribute,
                    self.component_config[FEATURIZER_CLASS_ALIAS],
                )
                message.add_features(final_sequence_features)

            if sentence_features[i] is not None:
                final_sentence_features = Features(
                    sentence_features[i],
                    FEATURE_TYPE_SENTENCE,
                    attribute,
                    self.component_config[FEATURIZER_CLASS_ALIAS],
                )
                message.add_features(final_sentence_features)

    def train(
        self,
        training_data: TrainingData,
        cfg: Optional[RasaNLUModelConfig] = None,
        **kwargs: Any,
    ) -> None:
        """训练特征
        从配置中获取参数并使用 sklearn 框架构建一个新的计数向量化器。
        """

        spacy_nlp = kwargs.get("spacy_nlp")
        if spacy_nlp is not None:
            # create spacy lemma_ for OOV_words
            self.OOV_words = [
                t.lemma_ if self.use_lemma else t.text
                for w in self.OOV_words
                for t in spacy_nlp(w)
            ]

        # process sentences and collect data for all attributes
        processed_attribute_tokens = self._get_all_attributes_processed_tokens(
            training_data
        )
        """
        {'text': [['查询', '明天', '的', '天气'],['查询', '今天', '的', '天气'],['北京', '的', '天气']],
	    'intent': [['search_weather'],['greet'],['search_weather']],
	    'response': [[],[],[]],
	    'action_name': [['action_listen'],['utter_greet'],['weather_form']],
	    'action_text': [[],[],[]],
	    'intent_response_key': [[],[],[]]}
        """

        #  训练所有属性
        attribute_texts = self._convert_attribute_tokens_to_texts(
            processed_attribute_tokens
        )
        """
        {'text': [ '今天 上海 的 天气 怎么样', '北京 明天 的 天气 怎么样', '杭州 的 天气'],
        'intent': [ 'search_weather', 'search_weather', 'search_weather'],
        'response': ['', '', ''],
        'action_name': ['action_deactivate_loop', 'utter_greet', 'weather_form'],
        'action_text': ['', '', ''],
        'intent_response_key': ['', '', '']}
        """
        if self.use_shared_vocab:
            self._train_with_shared_vocab(attribute_texts)
        else:
            self._train_with_independent_vocab(attribute_texts)  # construct vectorizers and train them with independent vocabularies

        # transform for all attributes
        for attribute in self._attributes:
            sequence_features, sentence_features = self._get_featurized_attribute(
                attribute, processed_attribute_tokens[attribute]
            )

            if sequence_features and sentence_features:
                self._set_attribute_features(
                    attribute,
                    sequence_features,
                    sentence_features,
                    training_data.training_examples,
                )

    def process(self, message: Message, **kwargs: Any) -> None:
        """处理传入消息并计算和设置特征"""

        if self.vectorizers is None:
            logger.error(
                "There is no trained CountVectorizer: "
                "component is either not trained or "
                "didn't receive enough training data"
            )
            return
        for attribute in self._attributes:

            message_tokens = self._get_processed_message_tokens_by_attribute(
                message, attribute
            )

            # features shape (1, seq, dim)
            sequence_features, sentence_features = self._create_features(
                attribute, [message_tokens]
            )

            self._set_attribute_features(
                attribute, sequence_features, sentence_features, [message]
            )

    def _collect_vectorizer_vocabularies(self) -> Dict[Text, Optional[Dict[Text, int]]]:
        """获取所有属性的词汇"""

        attribute_vocabularies = {}
        for attribute in self._attributes:
            attribute_vocabularies[attribute] = self._get_attribute_vocabulary(
                attribute
            )
        return attribute_vocabularies

    @staticmethod
    def _is_any_model_trained(attribute_vocabularies) -> bool:
        """检查是否有任何模型得到训练"""

        return any(value is not None for value in attribute_vocabularies.values())

    def persist(self, file_name: Text, model_dir: Text) -> Optional[Dict[Text, Any]]:
        """将此模型保存到传递的目录中。"""

        file_name = file_name + ".pkl"

        if self.vectorizers:
            # the vectorizer instance is not None; some model may have been trained
            attribute_vocabularies = self._collect_vectorizer_vocabularies()
            if self._is_any_model_trained(attribute_vocabularies):
                # some vocabulary definitely needs to be persisted
                featurizer_file = os.path.join(model_dir, file_name)

                if self.use_shared_vocab:
                    # Only persist the vocabulary of one attribute; it can be
                    # loaded and distributed to all attributes.
                    vocab = attribute_vocabularies[TEXT]
                else:
                    vocab = attribute_vocabularies

                io_utils.json_pickle(featurizer_file, vocab)

        return {"file": file_name}

    @classmethod
    def _create_shared_vocab_vectorizers(
        cls, parameters: Dict[Text, Any], vocabulary: Optional[Any] = None
    ) -> Dict[Text, CountVectorizer]:
        """Create vectorizers for all attributes with shared vocabulary"""

        shared_vectorizer = CountVectorizer(
            token_pattern=r"(?u)\b\w+\b" if parameters["analyzer"] == "word" else None,
            strip_accents=parameters["strip_accents"],
            lowercase=parameters["lowercase"],
            stop_words=parameters["stop_words"],
            ngram_range=(parameters["min_ngram"], parameters["max_ngram"]),
            max_df=parameters["max_df"],
            min_df=parameters["min_df"],
            max_features=parameters["max_features"],
            analyzer=parameters["analyzer"],
            vocabulary=vocabulary,
        )

        attribute_vectorizers = {}

        for attribute in cls._attributes_for(parameters["analyzer"]):
            attribute_vectorizers[attribute] = shared_vectorizer

        return attribute_vectorizers

    @classmethod
    def _create_independent_vocab_vectorizers(
        cls, parameters: Dict[Text, Any], vocabulary: Optional[Any] = None
    ) -> Dict[Text, CountVectorizer]:
        """为具有独立词汇表的所有属性创建向量化器"""

        attribute_vectorizers = {}

        for attribute in cls._attributes_for(parameters["analyzer"]):

            attribute_vocabulary = vocabulary[attribute] if vocabulary else None

            attribute_vectorizer = CountVectorizer(
                token_pattern=r"(?u)\b\w+\b"
                if parameters["analyzer"] == "word"
                else None,
                strip_accents=parameters["strip_accents"],
                lowercase=parameters["lowercase"],
                stop_words=parameters["stop_words"],
                ngram_range=(parameters["min_ngram"], parameters["max_ngram"]),
                max_df=parameters["max_df"],
                min_df=parameters["min_df"],
                max_features=parameters["max_features"],
                analyzer=parameters["analyzer"],
                vocabulary=attribute_vocabulary,
            )
            attribute_vectorizers[attribute] = attribute_vectorizer

        return attribute_vectorizers

    @classmethod
    def load(
        cls,
        meta: Dict[Text, Any],
        model_dir: Optional[Text] = None,
        model_metadata: Optional[Metadata] = None,
        cached_component: Optional["CountVectorsFeaturizer"] = None,
        should_finetune: bool = False,
        **kwargs: Any,
    ) -> "CountVectorsFeaturizer":
        file_name = meta.get("file")
        featurizer_file = os.path.join(model_dir, file_name)

        if not os.path.exists(featurizer_file):
            return cls(meta)

        vocabulary = io_utils.json_unpickle(featurizer_file)

        share_vocabulary = meta["use_shared_vocab"]

        if share_vocabulary:
            vectorizers = cls._create_shared_vocab_vectorizers(
                meta, vocabulary=vocabulary
            )
        else:
            vectorizers = cls._create_independent_vocab_vectorizers(
                meta, vocabulary=vocabulary
            )

        ftr = cls(meta, vectorizers, should_finetune)

        # make sure the vocabulary is loaded correctly
        for attribute in vectorizers:
            ftr.vectorizers[attribute]._validate_vocabulary()

        return ftr
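To make the sequence/sentence split in _create_features concrete, here is a minimal standalone sketch (plain sklearn with made-up tokens, not part of the Rasa source):

from sklearn.feature_extraction.text import CountVectorizer

tokens = ["今天", "上海", "的", "天气", "怎么样"]

# fit on the joined tokens, as is done for the TEXT attribute
vectorizer = CountVectorizer(token_pattern=r"(?u)\b\w+\b")
vectorizer.fit([" ".join(tokens)])

# sequence features: transform each token separately, one row per token
seq_vec = vectorizer.transform(tokens)
print(seq_vec.shape)  # (5, 5): (number of tokens, vocabulary size)

# sentence feature: transform all tokens joined into one string, a single row
sentence_vec = vectorizer.transform([" ".join(tokens)])
print(sentence_vec.shape)  # (1, 5)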