# mongodb.py

from chatterbot.storage import StorageAdapter


class Query(object):
    """
    Builder for MongoDB query documents.

    Each builder method leaves the query it is called on untouched and
    returns a new ``Query`` holding the extended query document.
    """

    def __init__(self, query=None):
        """
        :param query: The initial query document. An empty query is used
            when omitted. The dictionary is stored as given (not copied).
        :type query: dict
        """
        # A literal ``{}`` default would be one dictionary shared by
        # every instance created without arguments.
        self.query = {} if query is None else query

    def _clone(self):
        """
        Return a deep copy of the query document.

        A shallow copy would share nested values (such as the ``$nin``
        list) with this instance, so extending a derived query would
        silently modify this one as well.
        """
        from copy import deepcopy

        return deepcopy(self.query)

    def value(self):
        """
        Return an independent copy of the query document.
        """
        return self._clone()

    def raw(self, data):
        """
        Return a new query with the key/value pairs of ``data`` merged in.
        Keys in ``data`` replace keys already present in the query.
        """
        query = self._clone()

        query.update(data)

        return Query(query)

    def statement_text_equals(self, statement_text):
        """
        Return a new query that matches ``text`` against ``statement_text``.
        """
        query = self._clone()

        query['text'] = statement_text

        return Query(query)

    def statement_text_not_in(self, statements):
        """
        Return a new query with ``statements`` appended to the ``$nin``
        list of the ``text`` condition, creating the condition if needed.
        """
        query = self._clone()

        if 'text' not in query:
            query['text'] = {}

        if '$nin' not in query['text']:
            query['text']['$nin'] = []

        query['text']['$nin'].extend(statements)

        return Query(query)

    def statement_response_list_contains(self, statement_text):
        """
        Return a new query whose ``in_response_to`` condition uses
        ``$elemMatch`` to match an element with the given ``text``.
        """
        query = self._clone()

        if 'in_response_to' not in query:
            query['in_response_to'] = {}

        if '$elemMatch' not in query['in_response_to']:
            query['in_response_to']['$elemMatch'] = {}

        query['in_response_to']['$elemMatch']['text'] = statement_text

        return Query(query)

    def statement_response_list_equals(self, response_list):
        """
        Return a new query that matches ``in_response_to`` against
        ``response_list``.
        """
        query = self._clone()

        query['in_response_to'] = response_list

        return Query(query)


class MongoDatabaseAdapter(StorageAdapter):
    """
    The MongoDatabaseAdapter is an interface that allows
    ChatterBot to store statements in a MongoDB database.

    :keyword database: The name of the database you wish to connect to.
    :type database: str

    .. code-block:: python

       database='chatterbot-database'

    :keyword database_uri: The URI of a remote instance of MongoDB.
    :type database_uri: str

    .. code-block:: python

       database_uri='mongodb://example.com:8100/'
    """

    def __init__(self, **kwargs):
        """
        Connect to MongoDB and set up the collections used by the adapter.

        :keyword database: Name of the database to use.
            Defaults to ``'chatterbot-database'``.
        :keyword database_uri: URI of the MongoDB instance.
            Defaults to ``'mongodb://localhost:27017/'``.
        """
        super(MongoDatabaseAdapter, self).__init__(**kwargs)
        from pymongo import MongoClient
        from pymongo.errors import OperationFailure

        options = self.kwargs
        self.database_name = options.get('database', 'chatterbot-database')
        self.database_uri = options.get(
            'database_uri', 'mongodb://localhost:27017/'
        )

        # Connect using the configured URI
        self.client = MongoClient(self.database_uri)

        # Best effort: request a 42 MiB blocking-sort buffer.
        # An OperationFailure raised by the command is ignored.
        sort_buffer_command = {
            'setParameter': 1,
            'internalQueryExecMaxBlockingSortBytes': 44040192
        }
        try:
            self.client.admin.command(sort_buffer_command)
        except OperationFailure:
            pass

        # Database handle, followed by the statement and conversation
        # collections that live inside it
        self.database = self.client[self.database_name]
        self.statements = self.database['statements']
        self.conversations = self.database['conversations']

        # Statement text must be unique across the collection
        self.statements.create_index('text', unique=True)

        # Starting point for every query built by this adapter
        self.base_query = Query()

    def get_statement_model(self):
        """
        Return the statement model class with this adapter
        attached as its ``storage`` attribute.
        """
        from chatterbot.conversation import Statement

        # The adapter is set on the imported class itself
        setattr(Statement, 'storage', self)

        return Statement

    def get_response_model(self):
        """
        Return the response model class with this adapter
        attached as its ``storage`` attribute.
        """
        from chatterbot.conversation import Response

        # The adapter is set on the imported class itself
        setattr(Response, 'storage', self)

        return Response

    def count(self):
        """
        Return the number of documents in the statements collection.
        """
        import pymongo

        # Collection.count() was deprecated in PyMongo 3.7 and removed
        # in PyMongo 4.0; count_documents() is its replacement.
        if pymongo.version_tuple >= (3, 7):
            return self.statements.count_documents({})

        # Older PyMongo releases only provide count()
        return self.statements.count()

    def find(self, statement_text):
        """
        Return the statement whose text equals ``statement_text``,
        or None if no such document is stored.
        """
        statement_model = self.get_model('statement')
        lookup = self.base_query.statement_text_equals(statement_text).value()

        document = self.statements.find_one(lookup)

        if not document:
            return None

        # The text is passed positionally, so drop it from the keyword data
        document.pop('text')

        # Replace the stored response dictionaries with response objects
        stored_responses = document.get('in_response_to', [])
        document['in_response_to'] = self.deserialize_responses(stored_responses)

        return statement_model(statement_text, **document)

    def deserialize_responses(self, response_list):
        """
        Convert a list of stored response dictionaries to Response objects.

        The ``text`` key is removed from every dictionary in
        ``response_list``; the remaining keys are passed to the
        Response constructor as keyword arguments.

        :returns: The ``in_response_to`` value of a temporary statement
            with empty text to which every response has been added.
        """
        statement_model = self.get_model('statement')
        response_model = self.get_model('response')
        collector = statement_model('')

        for entry in response_list:
            response_text = entry.pop('text')
            collector.add_response(response_model(response_text, **entry))

        return collector.in_response_to

    def mongo_to_object(self, statement_data):
        """
        Build a Statement object from a document returned by MongoDB.

        The ``text`` key is removed from ``statement_data`` and its
        ``in_response_to`` entry is replaced with deserialized responses.
        """
        statement_model = self.get_model('statement')

        # The text is passed positionally, the rest as keyword arguments
        text = statement_data.pop('text')

        stored_responses = statement_data.get('in_response_to', [])
        statement_data['in_response_to'] = self.deserialize_responses(
            stored_responses
        )

        return statement_model(text, **statement_data)

    def filter(self, **kwargs):
        """
        Return a list of the stored statements that match the
        given parameters.

        :keyword order_by: Optional field name to sort by. ``created_at``
            is sorted in descending order, any other field ascending.
        :keyword in_response_to: Iterable of response texts that the
            statement's response list must equal.
        :keyword in_response_to__contains: Text that must occur in the
            statement's response list.

        Any remaining keyword arguments are merged into the query as-is.
        """
        import pymongo

        order_by = kwargs.pop('order_by', None)
        query = self.base_query

        # Exact match on the response list: wrap each text the way
        # responses are stored
        if 'in_response_to' in kwargs:
            exact_responses = [
                {'text': text} for text in kwargs.pop('in_response_to')
            ]
            query = query.statement_response_list_equals(exact_responses)

        if 'in_response_to__contains' in kwargs:
            query = query.statement_response_list_contains(
                kwargs.pop('in_response_to__contains')
            )

        cursor = self.statements.find(query.raw(kwargs).value())

        if order_by:
            # Newer datetimes first; everything else in ascending order
            if order_by == 'created_at':
                direction = pymongo.DESCENDING
            else:
                direction = pymongo.ASCENDING

            cursor = cursor.sort(order_by, direction)

        return [self.mongo_to_object(document) for document in cursor]

    def update(self, statement):
        """
        Save the statement, and a document for each of its responses,
        using a single unordered bulk write.

        A BulkWriteError is logged rather than raised.

        :returns: The statement that was passed in.
        """
        from pymongo import UpdateOne
        from pymongo.errors import BulkWriteError

        data = statement.serialize()

        # Upsert the statement itself
        operations = [
            UpdateOne({'text': statement.text}, {'$set': data}, upsert=True)
        ]

        # Make sure that an entry for each response is saved.
        # The upsert creates the document when it is missing; otherwise
        # $set overwrites the fields present in the response data.
        operations.extend(
            UpdateOne(
                {'text': response_dict.get('text')},
                {'$set': response_dict},
                upsert=True
            )
            for response_dict in data.get('in_response_to', [])
        )

        try:
            self.statements.bulk_write(operations, ordered=False)
        except BulkWriteError as error:
            # Log the details of a bulk write error
            self.logger.error(str(error.details))

        return statement

    def create_conversation(self):
        """
        Insert an empty conversation document and return its id.
        """
        result = self.conversations.insert_one({})
        return result.inserted_id

    def get_latest_response(self, conversation_id):
        """
        Returns the latest response in a conversation if it exists.
        Returns None if a matching conversation cannot be found.

        :param conversation_id: The id of the conversation to look up.
        """
        from pymongo import DESCENDING

        # The results are sorted newest first, so the latest entry is the
        # first document. Only that document is needed, so the result is
        # limited to one instead of loading every match.
        latest = list(
            self.statements.find({
                'conversations.id': conversation_id
            }).sort('conversations.created_at', DESCENDING).limit(1)
        )

        if not latest:
            return None

        # Indexing from the end (as ``[-2]``) would select one of the
        # oldest documents and fail when only one document matches.
        return self.mongo_to_object(latest[0])

    def add_to_conversation(self, conversation_id, statement, response):
        """
        Record the statement and the response as members of the
        conversation by pushing a conversation entry onto each document.
        """
        from datetime import datetime, timedelta

        # Force the response to be at least one millisecond after the
        # input statement
        entries = (
            (statement, timedelta()),
            (response, timedelta(milliseconds=1)),
        )

        for entry, offset in entries:
            self.statements.update_one(
                {
                    'text': entry.text
                },
                {
                    '$push': {
                        'conversations': {
                            'id': conversation_id,
                            'created_at': datetime.utcnow() + offset
                        }
                    }
                }
            )

    def get_random(self):
        """
        Return a randomly chosen statement from the database.

        :raises EmptyDatabaseException: If there are no statements.
        """
        from random import randint

        total = self.count()

        if total < 1:
            raise self.EmptyDatabaseException()

        # Skip a random number of documents and take the next one
        offset = randint(0, total - 1)
        cursor = self.statements.find().limit(1).skip(offset)
        documents = list(cursor)

        return self.mongo_to_object(documents[0])

    def remove(self, statement_text):
        """
        Delete the statement whose text matches the input text.
        The text is first removed from the response list of every
        statement that lists it as a response.
        """
        referencing = self.filter(in_response_to__contains=statement_text)

        for parent in referencing:
            parent.remove_response(statement_text)
            self.update(parent)

        self.statements.delete_one({'text': statement_text})

    def get_response_statements(self):
        """
        Return the statements whose text appears in the ``in_response_to``
        list of some stored statement.

        A statement must exist which lists the closest matching statement
        in its in_response_to field. Otherwise, the logic adapter may find
        a closest matching statement that does not have a known response.
        """
        grouped = self.statements.aggregate([{'$group': {'_id': '$in_response_to.text'}}])

        # Flatten the grouped response texts into a single list
        response_texts = []
        for group in grouped:
            try:
                response_texts.extend(group['_id'])
            except TypeError:
                # The group key is not iterable (presumably None for
                # documents without responses); nothing to collect
                pass

        lookup = {
            'text': {
                '$in': response_texts
            }
        }

        # Conditions of the base query replace keys of the same name
        lookup.update(self.base_query.value())

        return [
            self.mongo_to_object(document)
            for document in self.statements.find(lookup)
        ]

    def drop(self):
        """
        Drop the entire database that this adapter is configured to use.
        """
        client = self.client
        client.drop_database(self.database_name)

  

# Posted 2018-03-05 17:30 by Daniel_Lu | Views (138) | Comments (0) | Edit | Favorite | Report