# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4

# Copyright (c) 2010-2012, GEM Foundation.
#
# OpenQuake is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# OpenQuake is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with OpenQuake.  If not, see <http://www.gnu.org/licenses/>.

import os
import unittest

from tests.utils import helpers
from openquake import shapes
from openquake import kvs
from openquake.parser import vulnerability
from openquake import xml

# Discrete vulnerability model parsed by most tests. The path is relative to
# helpers.SCHEMA_DIR (it is joined with it in setUp and in the KVS test).
TEST_FILE = "examples/vulnerability-model-discrete.xml"
# Vulnerability file that must be rejected with xml.XMLValidationError.
INVALID_TEST_FILE = helpers.get_data_path("invalid/vulnerability.xml")
# A source-model file (not a vulnerability model): parsing it with
# VulnerabilityModelFile must raise xml.XMLMismatchError. Relative to
# helpers.SCHEMA_DIR.
MISMATCHED_TEST_FILE = "examples/source-model.xml"
# Number of vulnerability functions the parser yields for TEST_FILE.
NO_OF_CURVES_IN_TEST_FILE = 4


class VulnerabilityModelTestCase(unittest.TestCase):
    """Tests for the vulnerability model parser and its KVS round-trip.

    Covers schema validation, the number and content of the functions
    parsed from TEST_FILE, and storing/reloading the model through the KVS
    (``load_vulnerability_model`` / ``load_vuln_model_from_kvs``).
    """

    # Job id under which the KVS round-trip test stores the model.
    JOB_ID = 1234

    # Values shared by the "PK" and "IR" functions of TEST_FILE. They are kept
    # as tuples so no test can mutate them; tests compare/pass list copies.
    PAGER_IMLS = (5.00, 5.50, 6.00, 6.50, 7.00, 7.50,
                  8.00, 8.50, 9.00, 9.50, 10.00)
    PAGER_LOSS_RATIOS = (0.00, 0.00, 0.00, 0.00, 0.00, 0.01,
                         0.06, 0.18, 0.36, 0.36, 0.36)
    PAGER_COVS = (0.30,) * 11

    def setUp(self):
        """Create a parser for TEST_FILE and start from an empty KVS."""
        self.parser = vulnerability.VulnerabilityModelFile(
                os.path.join(helpers.SCHEMA_DIR, TEST_FILE))

        # delete server side cached data
        kvs.get_client().flushall()

    def tearDown(self):
        """Remove any data the test stored in the KVS (e.g. the model
        written under JOB_ID) so it does not leak into other test cases."""
        kvs.get_client().flushall()

    def test_schema_validation(self):
        """Invalid or mismatched documents are rejected by the parser."""
        self.assertRaises(xml.XMLValidationError,
                          vulnerability.VulnerabilityModelFile,
                          INVALID_TEST_FILE)

        self.assertRaises(xml.XMLMismatchError,
                          vulnerability.VulnerabilityModelFile,
                          os.path.join(helpers.SCHEMA_DIR,
                                       MISMATCHED_TEST_FILE))

    def test_loads_all_the_functions_defined(self):
        """Every vulnerability function in TEST_FILE is yielded."""
        self.assertEqual(NO_OF_CURVES_IN_TEST_FILE, len(list(self.parser)))

    def test_loads_the_functions_data(self):
        """Attributes and value lists of the parsed functions match the
        content of TEST_FILE."""
        model = self._load_vulnerability_model()

        self.assertEqual("MMI", model["PK"]["IMT"])
        self.assertEqual("fatalities", model["PK"]["lossCategory"])
        self.assertEqual("PAGER", model["PK"]["vulnerabilitySetID"])
        self.assertEqual("population", model["PK"]["assetCategory"])
        self.assertEqual("LN", model["PK"]["probabilisticDistribution"])

        # "PK" and "IR" are defined with identical values in TEST_FILE.
        for function_id in ("PK", "IR"):
            self.assertEqual(list(self.PAGER_LOSS_RATIOS),
                             model[function_id]["lossRatio"])
            self.assertEqual(list(self.PAGER_COVS),
                             model[function_id]["coefficientsVariation"])
            self.assertEqual(list(self.PAGER_IMLS),
                             model[function_id]["IML"])

        self.assertEqual("NPAGER", model["AA"]["vulnerabilitySetID"])

        self.assertEqual([6.00, 6.50, 7.00, 7.50, 8.00, 8.50,
                9.00, 9.50, 10.00, 10.50, 11.00],
                model["AA"]["IML"])

        self.assertEqual([0.50, 0.50, 0.50, 0.50, 0.50, 0.50,
                0.50, 0.50, 0.50, 0.50, 0.50],
                model["AA"]["coefficientsVariation"])

    def test_loading_and_storing_model_in_kvs(self):
        """The model stored in the KVS reloads as VulnerabilityFunction
        objects equal to the ones described in TEST_FILE."""
        path = os.path.join(helpers.SCHEMA_DIR, TEST_FILE)
        vulnerability.load_vulnerability_model(self.JOB_ID, path)
        model = vulnerability.load_vuln_model_from_kvs(self.JOB_ID)

        self.assertEqual(NO_OF_CURVES_IN_TEST_FILE, len(model))

        # One expected curve serves both "PK" and "IR": they share values.
        expected_curve = shapes.VulnerabilityFunction(
            list(self.PAGER_IMLS), list(self.PAGER_LOSS_RATIOS),
            list(self.PAGER_COVS), "LN")

        self.assertEqual(expected_curve, model["PK"])
        self.assertEqual(expected_curve, model["IR"])

    def _load_vulnerability_model(self):
        """Return the parsed functions as a dict keyed by their "ID".

        If two functions share an ID, the one yielded last wins.
        """
        return dict((function["ID"], function) for function in self.parser)
