#!/usr/bin/env python

import xml.etree.ElementTree as ET
import tables
import numpy as np
import matplotlib.pyplot as plt

# Read in UHS hazard results data from uhs.xml.
# Collects, from each child of the document root:
#   periods - the floats listed in its uhsPeriods element (converted to a
#             numpy array at the end),
#   tspan   - the float in its timeSpan element,
#   poe     - the poE attribute of every uhsResult child (as float),
#   path    - the path attribute of every uhsResult child (HDF5 file names).
# Every tag in the file lives in the NRML 0.3 namespace.
NRML_NS = "{http://openquake.org/xmlns/nrml/0.3}"

xmltree = ET.parse('uhs.xml')
root = xmltree.getroot()
# Single-argument / pre-formatted print() calls behave identically under
# Python 2 and Python 3.
print(root)
poe = []
path = []
periods = []
for uhs in root:
  # NOTE(review): assumes every child of root has uhsPeriods and timeSpan
  # (find() returns None otherwise). With several children, periods keeps
  # growing across them while tspan holds only the last value -- confirm the
  # file contains a single result set.
  period_text = uhs.find(NRML_NS + "uhsPeriods").text.split()
  periods.extend(float(period) for period in period_text)
  tspan = float(uhs.find(NRML_NS + "timeSpan").text)
  print("%s %s" % (tspan, period_text))
  # findall() with a plain tag matches direct children only, like the
  # original tag comparison over `for child in uhs`.
  for child in uhs.findall(NRML_NS + "uhsResult"):
    print(child.attrib)
    poe.append(float(child.attrib['poE']))
    path.append(child.attrib['path'])
# floor(tspan / poE) for each result, in the same order as poe/path;
# presumably an approximate return period -- units follow timeSpan.
intervals = np.floor(tspan / np.array(poe))
periods = np.array(periods)

# Determine the mean and standard deviation of the data in each hazard file.
# Row j of muhs/suhs corresponds to periods[j]; column i corresponds to
# path[i] (and therefore to poe[i] / intervals[i]).

# HDF5 node read from every file.
# NOTE(review): hard-coded to a single site; edit here to process another one.
SITE_NODE = 'lon:-97.362-lat:35.645'

# PyTables 3 renamed openFile/getNode to open_file/get_node. Prefer the new
# names and fall back to the camelCase ones so PyTables 2 keeps working.
_open_h5 = getattr(tables, 'open_file', None) or tables.openFile

muhs = np.zeros(shape=(len(periods), len(intervals)))
suhs = np.zeros(shape=(len(periods), len(intervals)))
for i, intpath in enumerate(path):
  print("%s %s" % (i, intpath))
  h5file = _open_h5(intpath, mode="r")
  try:
    get_node = getattr(h5file, 'get_node', None) or h5file.getNode
    # np.array copies the node's data, so arr stays valid after close().
    arr = np.array(get_node('/', name=SITE_NODE))
  finally:
    # Close explicitly so one file handle is not leaked per result file.
    h5file.close()
  # Statistics over axis 0. muhs[:, i] has len(periods) rows, so arr must
  # have one column per period for these assignments to succeed.
  muhs[:, i] = np.mean(arr, 0)
  suhs[:, i] = np.std(arr, 0)
np.savetxt('mUHS.csv', muhs, delimiter=',')
np.savetxt('sigmaUHS.csv', suhs, delimiter=',')

# Diagnostic plot for the last file read: mean and mean +/- one standard
# deviation. These are exactly the last columns of muhs/suhs, so reuse them
# instead of recomputing the statistics from arr.
plt.figure()
plt.semilogx(periods, muhs[:, -1], label='mean')
plt.semilogx(periods, muhs[:, -1] + suhs[:, -1], label='mean + 1 std')
plt.semilogx(periods, muhs[:, -1] - suhs[:, -1], label='mean - 1 std')
ax = plt.gca()
ax.xaxis.grid(True, 'minor')
ax.xaxis.grid(True, 'major', linewidth=0.75)
ax.yaxis.grid(True, 'major', linewidth=0.75)
plt.xlabel("Period (s)")
plt.ylabel("SA (g)")
plt.legend()
  
# Generate Plots
# Mean spectra: one curve per hazard file (column of muhs), labelled in
# column order by the matching entries of intervals. Saved to uhs_mean.pdf.
plt.figure()
plt.semilogx(periods, muhs)
ax = plt.gca()
ax.xaxis.grid(True, 'minor')
ax.xaxis.grid(True, 'major', linewidth=0.75)
ax.yaxis.grid(True, 'major', linewidth=0.75)
plt.xlabel("Period (s)")
plt.ylabel("SA (g)")
plt.legend(intervals)
# 'bbox' is not a savefig option; 'bbox_inches' is the keyword that trims
# surrounding whitespace.
plt.savefig("uhs_mean.pdf", bbox_inches='tight')

# Generate Plots
# Standard-deviation spectra: one curve per hazard file (column of suhs),
# labelled in column order by intervals. Saved to uhs_std.pdf, then all open
# figures are shown.
plt.figure()
plt.semilogx(periods, suhs)
# Same grid styling as the mean-spectra figure, for consistency.
ax = plt.gca()
ax.xaxis.grid(True, 'minor')
ax.xaxis.grid(True, 'major', linewidth=0.75)
ax.yaxis.grid(True, 'major', linewidth=0.75)
plt.xlabel("Period (s)")
# The curves are standard deviations of SA, not SA itself.
plt.ylabel("SA standard deviation (g)")
plt.legend(intervals)
# 'bbox' is not a savefig option; 'bbox_inches' is the keyword that trims
# surrounding whitespace.
plt.savefig("uhs_std.pdf", bbox_inches='tight')
plt.show()