# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from awscrt import io, http, auth
from awscrt import mqtt
from awsiot import mqtt_connection_builder
from awsiot import iotshadow
from uuid import uuid4
import datetime
import time
import json
import logging


class AwsDeviceShadow:
    """Wrapper around an AWS IoT Device Shadow client over an mTLS/PKCS#11 MQTT link.

    The MQTT connection is built lazily by ``connect()``. Shadow requests are
    only published while ``send_enable`` is True. That flag is set on a
    successful connect / connection resume and cleared on disconnect /
    interruption.
    """

    def __init__(self, thing_name, config_data_cls):
        """Store the thing identity and load the PKCS#11 library.

        Args:
            thing_name: IoT thing name; also used as the MQTT client ID.
            config_data_cls: configuration provider exposing
                ``get_shadow_conf()`` and the ``IOT_*`` key constants read here
                and in ``build_pkcs11_mqtt_connection()``.
        """
        self.thing_name   = thing_name
        self._config_data = config_data_cls
        self.client_id    = thing_name
        self.mqtt_connection = None
        self.shadow_client = None
        # Gates every publish; toggled by connect/disconnect and the
        # interrupted/resumed connection callbacks.
        self.send_enable = False
        self._logging = logging.getLogger('cloud-agent')

        conf = self._config_data.get_shadow_conf()
        # NOTE(review): with STRICT behavior awscrt has this object perform the
        # PKCS#11 initialize/finalize itself; presumably only one such instance
        # per library should exist in the process -- confirm before creating
        # several AwsDeviceShadow objects.
        self._pkcs11_lib = io.Pkcs11Lib(
            file=conf[self._config_data.IOT_PKCS11_PATH],
            behavior=io.Pkcs11Lib.InitializeFinalizeBehavior.STRICT)

    def build_pkcs11_mqtt_connection(self):
        """Create the MQTT connection and the shadow client (does not connect).

        The private key is referenced by its PKCS#11 label; the certificate and
        CA are read from files named in the shadow config.

        Returns:
            True when both objects were created, False (after logging) otherwise.
        """
        conf = self._config_data.get_shadow_conf()

        self._logging.info("Connecting to {} with client ID '{}'...".format(
            conf[self._config_data.IOT_SHADOW_ENDPOINT], self.client_id))
        try:
            self.mqtt_connection = mqtt_connection_builder.mtls_with_pkcs11(
                pkcs11_lib=self._pkcs11_lib,
                # An empty PIN is used; the configured IOT_PIN is not read.
                user_pin="",
                #user_pin=conf[self._config_data.IOT_PIN],
                slot_id=None,
                token_label=None,
                private_key_label=conf[self._config_data.IOT_KEY_LABEL],
                cert_filepath=conf[self._config_data.IOT_CERT_FILE],
                endpoint=conf[self._config_data.IOT_SHADOW_ENDPOINT],
                port=conf[self._config_data.IOT_PORT],
                ca_filepath=conf[self._config_data.IOT_CA_FILE],
                on_connection_interrupted=self.on_connection_interrupted,
                on_connection_resumed=self.on_connection_resumed,
                client_id=self.client_id,
                # Persistent MQTT session (not a clean session).
                clean_session=False,
                keep_alive_secs=30)
            self.shadow_client = iotshadow.IotShadowClient(self.mqtt_connection)
            return True
        except Exception as e:
            self._logging.error("AWS mqtt_connection() has Exception: " \
                + str(e.args) + " Please check cert files.")
            return False

    def _wait_for_subscription(self, label, start_subscribe):
        """Run one shadow subscribe call and block until its future resolves.

        Args:
            label: public method name, used in the failure log message.
            start_subscribe: zero-argument callable performing the subscribe and
                returning ``(future, topic)``. It is invoked inside the ``try``
                so that a missing shadow client is reported, not raised.

        Returns:
            True on success, False (after logging the cause) on any failure.
        """
        try:
            subscribed_future, _ = start_subscribe()
            subscribed_future.result()
        except Exception as e:
            self._logging.warning("{}() is failed. ({})".format(label, e))
            return False
        return True

    def create_subscribe_update_shadow_accepted_callback(self, on_update_shadow_accepted):
        """Subscribe *on_update_shadow_accepted* to update/accepted for this thing.

        Returns:
            True if the subscription completed, False otherwise.
        """
        return self._wait_for_subscription(
            "create_subscribe_update_shadow_accepted_callback",
            lambda: self.shadow_client.subscribe_to_update_shadow_accepted(
                request=iotshadow.UpdateShadowSubscriptionRequest(thing_name=self.thing_name),
                qos=mqtt.QoS.AT_LEAST_ONCE,
                callback=on_update_shadow_accepted))

    def create_subscribe_shadow_delta_updated_callback(self, on_shadow_delta_updated):
        """Subscribe *on_shadow_delta_updated* to delta-updated events for this thing.

        Returns:
            True if the subscription completed, False otherwise.
        """
        return self._wait_for_subscription(
            "create_subscribe_shadow_delta_updated_callback",
            lambda: self.shadow_client.subscribe_to_shadow_delta_updated_events(
                request=iotshadow.ShadowDeltaUpdatedSubscriptionRequest(thing_name=self.thing_name),
                qos=mqtt.QoS.AT_LEAST_ONCE,
                callback=on_shadow_delta_updated))

    def create_subscribe_get_shadow_accepted_callback(self, on_get_shadow_accepted):
        """Subscribe *on_get_shadow_accepted* to get/accepted for this thing.

        Returns:
            True if the subscription completed, False otherwise.
        """
        return self._wait_for_subscription(
            "create_subscribe_get_shadow_accepted_callback",
            lambda: self.shadow_client.subscribe_to_get_shadow_accepted(
                request=iotshadow.GetShadowSubscriptionRequest(thing_name=self.thing_name),
                qos=mqtt.QoS.AT_LEAST_ONCE,
                callback=on_get_shadow_accepted))

    def on_publish_update_shadow(self, future):
        #type: (Future) -> None
        """Done-callback for an update publish: log whether the publish succeeded."""
        try:
            future.result()
        except Exception as e:
            self._logging.error("Failed to publish update request: {}".format(e))
        else:
            self._logging.info("Update request published.")

    def _on_publish_get_shadow(self, future):
        """Done-callback for a get publish: log whether the publish succeeded."""
        try:
            future.result()
        except Exception as e:
            self._logging.error("Failed to publish get request: {}".format(e))
        else:
            self._logging.info("Get request published.")

    def publish_get_shadow(self):
        """Publish a get-shadow request for this thing.

        The reply is presumably delivered to the handler registered via
        ``create_subscribe_get_shadow_accepted_callback`` (verify with caller).

        Returns:
            The request's client token, or False when sending is disabled.
        """
        if not self.send_enable:
            self._logging.warning("network is False, publish_get_shadow() is not send")
            return False
        token = str(uuid4())
        request = iotshadow.GetShadowRequest(
            thing_name=self.thing_name,
            client_token=token,
        )
        get_future = self.shadow_client.publish_get_shadow(request, mqtt.QoS.AT_LEAST_ONCE)
        # Surface publish failures instead of dropping the future unobserved.
        get_future.add_done_callback(self._on_publish_get_shadow)
        return token

    def update_publish_shadow(self, send_dict):
        """Publish *send_dict* as the thing's ``reported`` shadow state.

        Returns:
            The request's client token, or None when sending is disabled.
        """
        if not self.send_enable:
            self._logging.warning("network is False, update_publish_shadow() is not send")
            return None
        token = str(uuid4())
        request = iotshadow.UpdateShadowRequest(
            thing_name=self.thing_name,
            state=iotshadow.ShadowState(
                reported=send_dict,
                #desired=send_dict,
            ),
            client_token=token,
        )
        self._logging.info("Send data: {}".format(send_dict))

        update_future = self.shadow_client.publish_update_shadow(request, mqtt.QoS.AT_LEAST_ONCE)
        update_future.add_done_callback(self.on_publish_update_shadow)
        return token

    def connect(self):
        """Connect to AWS IoT, building the MQTT connection first if needed.

        On a failed connect, the connection and shadow client are discarded so
        that the next call rebuilds them.

        Returns:
            True once connected (publishing enabled), False otherwise.
        """
        if self.mqtt_connection is None:
            if not self.build_pkcs11_mqtt_connection():
                return False
        try:
            self.mqtt_connection.connect().result()
        except Exception as e:
            self._logging.error("AWS connect() failed: {}".format(e))
            self.mqtt_connection = None
            self.shadow_client = None
            return False
        self.send_enable = True
        return True

    def disconnect(self):
        """Disconnect (best effort) and drop the connection and shadow client.

        This is a no-op when not connected. Errors are logged, not raised; the
        local state is cleared either way.
        """
        if self.mqtt_connection is None:
            return

        try:
            self.mqtt_connection.disconnect().result()
        except Exception as e:
            self._logging.warning("AWS disconnect() failed: {}".format(e))
        self.send_enable = False
        self.mqtt_connection = None
        self.shadow_client = None

    def on_connection_interrupted(self, connection, error, **kwargs):
        """MQTT callback: stop publishing while the connection is down."""
        self.send_enable = False
        self._logging.info("Connection interrupted. error: {}".format(error))

    def on_connection_resumed(self, connection, return_code, session_present, **kwargs):
        """MQTT callback: re-enable publishing after the connection resumes."""
        self.send_enable = True
        self._logging.info("Connection resumed. return_code: {} session_present: {}".format(return_code, session_present))

    def is_send_enable(self):
        """Return True while publishes are allowed (connection believed up)."""
        return self.send_enable


#
# EOF
#
