#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 10 14:04:32 2020

@author: rzelinsky
"""
def TDR_comp(direct):
    from netCDF4 import Dataset
    import numpy as np
    import plotly.graph_objects as go
    import os
    from datetime import datetime
    from datetime import timedelta
    from plotly.subplots import make_subplots
    import plotly.figure_factory as ff

    for d in direct:
        top_dir = d
        print(top_dir)
        entries = []
        for file in os.listdir(top_dir):
            if file.endswith(".nc"):
                print(file)
                entries.append(file)
        entries = sorted(entries)
        flight  = top_dir[-11:-1]

        if os.path.isdir('/home/rzelinsky/public_html/data/'+flight):
            save_dir = '/home/rzelinsky/public_html/data/'+flight+'/'
        else:
            os.system('mkdir /home/rzelinsky/public_html/data/'+flight)
            save_dir = '/home/rzelinsky/public_html/data/'+flight+'/'
        
        lat = np.empty((250,len(entries)))
        lon = np.empty((250,len(entries)))
        level = np.empty((37,len(entries)))
        ref = np.empty((250,250,37,len(entries)))
        u = np.empty((250,250,37,len(entries)))
        v = np.empty((250,250,37,len(entries)))
        w = np.empty((250,250,37,len(entries)))
        wnd_sp = np.empty((250,250,37,len(entries)))
        leg_start = []
        leg_end = []
        i = 0
        for filename in entries:
            nc_single = Dataset(top_dir+'/'+filename, 'r')
            lat[:,i] = nc_single.variables['lats'][:]
            lon[:,i] = nc_single.variables['lons'][:]
            level[:,i] = nc_single.variables['level'][:]
            ref[:,:,:,i] = np.squeeze(nc_single.variables['REFLECTIVITY'][:])
            u[:,:,:,i] = np.squeeze(nc_single.variables['U'][:])
            v[:,:,:,i] = np.squeeze(nc_single.variables['V'][:])
            w[:,:,:,i] = np.squeeze(nc_single.variables['W'][:])
            wnd_sp[:,:,:,i] = np.squeeze(nc_single.variables['WIND_SPEED'][:])
            i = i+1
            time_units = nc_single.variables['time'].UNITS
            [w1,w2,ymd,hms,n1] = time_units.split()
            [uy,umo,ud] = ymd.split('-')
            [uh,um,us] = hms.split(':')
            unit_time = datetime(int(uy),int(umo),int(ud),int(uh),int(um),int(us))
            start_time = nc_single.getncattr('START_TIME')
            end_time = nc_single.getncattr('END_TIME')
            start = unit_time+timedelta(seconds=int(start_time))
            leg_start.append(start.strftime('%Y/%m/%d %H:%M:%S'))
            end = unit_time+timedelta(seconds=int(end_time))
            leg_end.append(end.strftime('%Y/%m/%d %H:%M:%S'))
            
        storm_name = nc_single.getncattr('STMNAME')

        
        ref_new = np.transpose(np.copy(ref),(1,0,2,3))
        ref_new[np.where(ref_new < -999)] = np.nan
        u_new = np.transpose(np.copy(u),(1,0,2,3))
        u_new[np.where(u_new < -999)] = np.nan
        v_new = np.transpose(np.copy(v),(1,0,2,3))
        v_new[np.where(v_new < -999)] = np.nan
        w_new = np.transpose(np.copy(w),(1,0,2,3))
        w_new[np.where(w_new < -999)] = np.nan
        wnd_new = np.transpose(np.copy(wnd_sp),(1,0,2,3))
        wnd_new[np.where(wnd_new < -999)] = np.nan
        
        ref_comp = np.nanmean(ref_new,3)
        u_comp = np.nanmean(u_new,3)*1.94384 #convert to kts
        v_comp = np.nanmean(v_new,3)*1.94384 #convert to kts
        w_comp = np.nanmean(w_new,3)*1.94384 #convert to kts
        wind = np.nanmean(wnd_new,3)*1.94384 #convert to kts
        
        
        color_map = [
         [0.0, 'rgb(255, 255, 255)'],
         [0.1, 'rgb(0, 0, 255)'],
         [0.2, 'rgb(0, 255, 255)'],
         [0.3, 'rgb(0, 255, 128)'],
         [0.4, 'rgb(0, 255, 0)'],
         [0.5, 'rgb(128, 255, 0)'],
         [0.6, 'rgb(255, 255, 0)'],
         [0.7, 'rgb(255, 128, 0)'],
         [0.8, 'rgb(255, 0, 0)'],
         [0.9, 'rgb(255, 0, 255)'],
         [1.0, 'rgb(127, 0, 255)'],]
        
        black = [[0, 'rgb(0,0,0)'],
                 [1, 'rgb(0,0,0)']]
        
        ind = np.where(np.isnan(ref_comp[:,:,:]) == False)
        camera1 = dict(
            up=dict(x=0, y=0., z=1),
            eye=dict(x=1.25, y=-1.25, z=.25))
        
        fig = make_subplots(rows=2,cols=2,specs=[[{"type":"scene","rowspan":2},{"type":"xy","rowspan":2}],
                                         [None,None]],
                    column_widths=[0.6, 0.4],row_heights = [0.4, 0.6], vertical_spacing=0.05,
                    subplot_titles = [None,'0.5 km Reflectivity',None, None])

        X, Y, Z = np.meshgrid(np.arange(0,250,1),np.arange(0,250,1),level[:,0])
        X_e = X[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        Y_e = Y[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        Z_e = Z[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        u_e = u_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        v_e = v_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        w_e = w_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        
        fig.append_trace(go.Volume(
            x=X[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:].flatten(),
            y=Y[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:].flatten(),
            z=Z[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:].flatten(),
            value=ref_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:].flatten(),
            opacityscale = 'max',
            surface_count = 10,
            cmin = 0, cmax = 70,
            isomin = 10,
            colorscale = color_map,
            #colorbar = dict(x = 0.54, len = .8, title = 'Reflectivity (dBZ)'),
            showscale = False,
            scene = 'scene1',),1,1)
        fig.append_trace(go.Cone(x = X_e[::20,::20,::4].flatten(),
                            y = Y_e[::20,::20,::4].flatten(),
                            z = Z_e[::20,::20,::4].flatten(),
                            u = u_e[::20,::20,::4].flatten(),
                            v = v_e[::20,::20,::4].flatten(),
                            w = w_e[::20,::20,::4].flatten(),
                          sizeref = 0.1,
                          colorscale = black,
                          opacity = 0.3,
                          showscale = False,
                          anchor = 'tail',
                            ),1,1)
        fig.update_layout(scene_camera = camera1,
            title = flight+' ('+storm_name+')'+"<br>"+'Valid: '+leg_start[0]+' to '+leg_end[-1],
                 width=1400,
                 height=600,
                 scene1=dict(
                            #zaxis=dict(range=[-0.1, nb_pos], autorange=False),
                            xaxis_title='East-West',
                            yaxis_title='North-South',
                            zaxis_title='Height (km)',
                            xaxis = dict(range=[0,249], autorange = False), 
                            yaxis = dict(range=[0,249], autorange = False), 
                            zaxis = dict(range=[level.min(), level.max()], autorange = False),
                             aspectratio=dict(x=1, y=1, z=0.5),
                            ),)
        
        fig.append_trace(go.Heatmap(
            x=X[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,1].flatten(),
            y=Y[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,1].flatten(),
            z=ref_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,1].flatten(),
            colorscale = color_map,
            zsmooth='best', 
            connectgaps=False,
            zmin = 0,
            zmax = 70,
            colorbar = dict(title = 'Reflectivity (dBZ)'),),1,2)
        
        x_t = X[::10,::10,1]
        y_t = Y[::10,::10,1]
        u_t = u_comp[::10,::10,1]
        v_t = v_comp[::10,::10,1]
        
        f = ff.create_quiver(x_t,y_t,u_t,v_t,  scale = 0.25, scaleratio = 1,
                             marker_color='black',hoverinfo='none',showlegend = False)
        
        fig.add_trace(f.data[0])
        fig.update_layout(hovermode="closest")
        fig.update_xaxes(title_text="East-West", range=[0, 250], row=1, col=2)
        fig.update_yaxes(title_text="North-South", range=[0, 250], row=1, col=2)
        
        fig.write_html(save_dir+flight+'_'+storm_name+'_reflectivity.html')
        
        

        fig = make_subplots(rows=2,cols=2,specs=[[{"type":"scene","rowspan":2},{"type":"xy","rowspan":2}],
                                         [None,None]],
                    column_widths=[0.6, 0.4],row_heights = [0.4, 0.6], vertical_spacing=0.05,
                    subplot_titles = [None,'0.5 km Wind Speed',None, None])
        
        fig.append_trace(go.Volume(
            x=X[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:].flatten(),
            y=Y[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:].flatten(),
            z=Z[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:].flatten(),
            value=wind[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:].flatten(),
            opacityscale = 'max',
            surface_count = 10,
            cmin = 0, cmax = 140,
            isomin = 10,
            colorscale = 'Jet',
            #colorbar = dict(x = 0.54, len = .8, title = 'Reflectivity (dBZ)'),
            showscale = False,
            scene = 'scene1',),1,1)
        fig.append_trace(go.Cone(x = X_e[::20,::20,::4].flatten(),
                            y = Y_e[::20,::20,::4].flatten(),
                            z = Z_e[::20,::20,::4].flatten(),
                            u = u_e[::20,::20,::4].flatten(),
                            v = v_e[::20,::20,::4].flatten(),
                            w = w_e[::20,::20,::4].flatten(),
                          sizeref = 0.1,
                          colorscale = black,
                          opacity = 0.3,
                          showscale = False,
                          anchor = 'tail',
                            ),1,1)
        
        fig.update_layout(scene_camera = camera1,
            title = flight+' ('+storm_name+')'+"<br>"+'Valid: '+leg_start[0]+' to '+leg_end[-1],
                 width=1400,
                 height=600,
                 scene1=dict(
                            #zaxis=dict(range=[-0.1, nb_pos], autorange=False),
                            xaxis_title='East-West',
                            yaxis_title='North-South',
                            zaxis_title='Height (km)',
                            xaxis = dict(range=[0,249], autorange = False), 
                            yaxis = dict(range=[0,249], autorange = False), 
                            zaxis = dict(range=[level.min(), level.max()], autorange = False),
                             aspectratio=dict(x=1, y=1, z=0.5),
                            ),)
        
        fig.append_trace(go.Heatmap(
            x=X[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,1].flatten(),
            y=Y[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,1].flatten(),
            z=wind[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,1].flatten(),
            colorscale = 'Jet',
            zsmooth='best', 
            connectgaps=False,
            zmin = 0,
            zmax = 140,
            colorbar = dict(title = 'Wind Speed (kts)'),),1,2)

        f = ff.create_quiver(x_t,y_t,u_t,v_t, scale = 0.25, scaleratio = 1,
                             marker_color='black',hoverinfo='none',showlegend = False)
        
        fig.add_trace(f.data[0])
        fig.update_layout(hovermode="closest")
        fig.update_xaxes(title_text="East-West", range=[0, 250], row=1, col=2)
        fig.update_yaxes(title_text="North-South", range=[0, 250], row=1, col=2)
        
        fig.write_html(save_dir+flight+'_'+storm_name+'_wind.html')

#### Horizontal slice render reflectivity only
        camera1 = dict(
        up=dict(x=0, y=0., z=1),
        eye=dict(x=1.25, y=-1.25, z=1.25))

        X, Y, Z = np.meshgrid(np.arange(0,250,1),np.arange(0,250,1),level[:,0])
        ind = np.where(np.isnan(ref_comp[:,:,:]) == False)
        r, c = X[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,0].shape
        
        nb_frames = 25
        nb_start = nb_frames-1
        nb_pos = nb_start/10
        cmi = 0
        cma = 70
        
        X_e = X[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        Y_e = Y[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        Z_e = Z[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        u_e = u_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        v_e = v_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        w_e = w_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        
        frames=[go.Frame(data=[go.Surface(
                                    x = X_e[:,:,k],
                                    y = Y_e[:,:,k],
                                    z = Z_e[:,:,k],
                                    hovertext=np.round(ref_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,k],decimals = 0),
                                    surfacecolor=ref_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,k],
                                    cmin=cmi, cmax=cma,
                                    opacityscale = [[0, 0.0], [0.1, 0.8], [1, 0.8]],),
                                go.Cone(x = X_e[::10,::10,k].flatten(),
                                   y = Y_e[::10,::10,k].flatten(),
                                   z = Z_e[::10,::10,k].flatten(),
                                   u = u_e[::10,::10,k].flatten(),
                                   v = v_e[::10,::10,k].flatten(),
                                   w = w_e[::10,::10,k].flatten(),
                                   sizeref = 0.1,
                                   colorscale = black,
                                   opacity = 0.3,
                                   showscale = False,
                                   anchor = 'tail',)],
                         name=str(level[k]),
                         )
                for k in range(len(level))]
        
        fig = go.Figure(frames = frames)
        
        fig.add_trace(go.Surface(
            x = X_e[:,:,0],
            y = Y_e[:,:,0],
            z = Z_e[:,:,0],
            surfacecolor=ref_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,0],
            colorscale=color_map,
            cmin=cmi, cmax=cma,
            opacityscale = [[0, 0.0], [0.1, 0.8], [1, 0.8]],
            colorbar = dict(title = 'Reflectivity (dBZ)')))
        
        fig.add_trace(go.Cone(x = X_e[::10,::10,0].flatten(),
                                   y = Y_e[::10,::10,0].flatten(),
                                   z = Z_e[::10,::10,0].flatten(),
                                   u = u_e[::10,::10,0].flatten(),
                                   v = v_e[::10,::10,0].flatten(),
                                   w = w_e[::10,::10,0].flatten(),
                                   sizeref = 0.1,
                                   colorscale = black,
                                   opacity = 0.3,
                                   showscale = False,
                                   anchor = 'tail',))
        def frame_args(duration):
            """Return the plotly ``animate`` options dict for one step.

            ``duration`` is applied both to the frame itself and to its linear
            transition; ``fromcurrent`` makes playback resume from the frame
            currently shown, and ``mode="immediate"`` interrupts any pending
            animation.
            """
            options = dict(frame=dict(duration=duration))
            options["mode"] = "immediate"
            options["fromcurrent"] = True
            options["transition"] = dict(duration=duration, easing="linear")
            return options
        
        sliders = [
                    {
                        "pad": {"b": 10, "t": 60},
                        "len": 0.9,
                        "x": 0.1,
                        "y": 0,
                        "steps": [
                            {
                                "args": [[f.name], frame_args(0)],
                                "label": str(level[k,0]),
                                "method": "animate",
                            }
                            for k, f in enumerate(fig.frames)
                        ],
                    }
                ]
        
        # Layout
        fig.update_layout(scene_camera = camera1,
                 title = flight+' ('+storm_name+')'+"<br>"+'Valid: '+leg_start[0]+' to '+leg_end[-1],
                 width=800,
                 height=800,
                 scene=dict(xaxis_title='East-West',
                            yaxis_title='North-South',
                            zaxis_title='Height (km)',
                            xaxis = dict(range=[0,249], autorange = False), 
                            yaxis = dict(range=[0,249], autorange = False), 
                            zaxis = dict(range=[level.min(), level.max()], autorange = False),
                            aspectratio=dict(x=1, y=1, z=0.5),
                            ),
                    
                 updatemenus = [
                    {
                        "buttons": [
                            {
                                "args": [None, frame_args(50)],
                                "label": "&#9654;", # play symbol
                                "method": "animate",
                            },
                            {
                                "args": [[None], frame_args(0)],
                                "label": "&#9724;", # pause symbol
                                "method": "animate",
                            },
                        ],
                        "direction": "left",
                        "pad": {"r": 10, "t": 70},
                        "type": "buttons",
                        "x": 0.1,
                        "y": 0,
                    }
                 ],
                 sliders=sliders
        )
        
        #fig.show()
        fig.write_html(save_dir+flight+'_'+storm_name+'_hor_slice_ref.html',auto_play=False)
        
#### Horizontal slice render wind only
        camera1 = dict(
        up=dict(x=0, y=0., z=1),
        eye=dict(x=1.25, y=-1.25, z=1.25))

        X, Y, Z = np.meshgrid(np.arange(0,250,1),np.arange(0,250,1),level[:,0])
        ind = np.where(np.isnan(ref_comp[:,:,:]) == False)
        r, c = X[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,0].shape
        
        nb_frames = 25
        nb_start = nb_frames-1
        nb_pos = nb_start/10
        cmi = 0
        cma = 70
        
        X_e = X[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        Y_e = Y[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        Z_e = Z[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        u_e = u_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        v_e = v_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        w_e = w_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        
        frames=[go.Frame(data=[go.Surface(
                                    x = X_e[:,:,k],
                                    y = Y_e[:,:,k],
                                    z = Z_e[:,:,k],
                                    hovertext=np.round(wind[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,k],decimals = 0),
                                    surfacecolor=wind[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,k],
                                    cmin=0, cmax=140,
                                    opacityscale = [[0, 0.0], [0.1, 0.8], [1, 0.8]],),
                                go.Cone(x = X_e[::10,::10,k].flatten(),
                                   y = Y_e[::10,::10,k].flatten(),
                                   z = Z_e[::10,::10,k].flatten(),
                                   u = u_e[::10,::10,k].flatten(),
                                   v = v_e[::10,::10,k].flatten(),
                                   w = w_e[::10,::10,k].flatten(),
                                   sizeref = 0.1,
                                   colorscale = black,
                                   opacity = 0.3,
                                   showscale = False,
                                   anchor = 'tail',)],
                         name=str(level[k]),
                         )
                for k in range(len(level))]
        
        fig = go.Figure(frames = frames)
        
        fig.add_trace(go.Surface(
            x = X_e[:,:,0],
            y = Y_e[:,:,0],
            z = Z_e[:,:,0],
            surfacecolor=wind[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,0],
            colorscale='Jet',
            cmin=0, cmax=140,
            opacityscale = [[0, 0.0], [0.1, 0.8], [1, 0.8]],
            colorbar = dict(title = 'Wind Speed (kts)')))
        
        fig.add_trace(go.Cone(x = X_e[::10,::10,0].flatten(),
                                   y = Y_e[::10,::10,0].flatten(),
                                   z = Z_e[::10,::10,0].flatten(),
                                   u = u_e[::10,::10,0].flatten(),
                                   v = v_e[::10,::10,0].flatten(),
                                   w = w_e[::10,::10,0].flatten(),
                                   sizeref = 0.1,
                                   colorscale = black,
                                   opacity = 0.3,
                                   showscale = False,
                                   anchor = 'tail',))
        def frame_args(duration):
            """Return the plotly ``animate`` options dict for one step.

            ``duration`` is applied both to the frame itself and to its linear
            transition; ``fromcurrent`` makes playback resume from the frame
            currently shown, and ``mode="immediate"`` interrupts any pending
            animation.
            """
            options = dict(frame=dict(duration=duration))
            options["mode"] = "immediate"
            options["fromcurrent"] = True
            options["transition"] = dict(duration=duration, easing="linear")
            return options
        
        sliders = [
                    {
                        "pad": {"b": 10, "t": 60},
                        "len": 0.9,
                        "x": 0.1,
                        "y": 0,
                        "steps": [
                            {
                                "args": [[f.name], frame_args(0)],
                                "label": str(level[k,0]),
                                "method": "animate",
                            }
                            for k, f in enumerate(fig.frames)
                        ],
                    }
                ]
        
        # Layout
        fig.update_layout(scene_camera = camera1,
                 title = flight+' ('+storm_name+')'+"<br>"+'Valid: '+leg_start[0]+' to '+leg_end[-1],
                 width=800,
                 height=800,
                 scene=dict(xaxis_title='East-West',
                            yaxis_title='North-South',
                            zaxis_title='Height (km)',
                            xaxis = dict(range=[0,249], autorange = False), 
                            yaxis = dict(range=[0,249], autorange = False), 
                            zaxis = dict(range=[level.min(), level.max()], autorange = False),
                            aspectratio=dict(x=1, y=1, z=0.5),
                            ),
                    
                 updatemenus = [
                    {
                        "buttons": [
                            {
                                "args": [None, frame_args(50)],
                                "label": "&#9654;", # play symbol
                                "method": "animate",
                            },
                            {
                                "args": [[None], frame_args(0)],
                                "label": "&#9724;", # pause symbol
                                "method": "animate",
                            },
                        ],
                        "direction": "left",
                        "pad": {"r": 10, "t": 70},
                        "type": "buttons",
                        "x": 0.1,
                        "y": 0,
                    }
                 ],
                 sliders=sliders
        )
        
        #fig.show()
        fig.write_html(save_dir+flight+'_'+storm_name+'_hor_slice_wind.html',auto_play=False)
                



#### Horizontal slice render
        camera1 = dict(
        up=dict(x=0, y=0., z=1),
        eye=dict(x=1.25, y=-1.25, z=1.25))
        
        nb_frames = 25
        nb_start = nb_frames-1
        nb_pos = nb_start/10
        cmi = 0
        cma = 70


        fig = make_subplots(rows=2,cols=2,specs=[[{"type":"scene","rowspan":2},{"type":"scene","rowspan":2}],
                                         [None,None]],
                    column_widths=[0.5, 0.5],row_heights = [0.5, 0.5], vertical_spacing=0.05,
                    subplot_titles = ['Reflectivity','Wind Speed',None, None])

        frames=[go.Frame(data=[go.Surface(
                                    x = X_e[:,:,k],
                                    y = Y_e[:,:,k],
                                    z = Z_e[:,:,k],
                                    hovertext=np.round(ref_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,k],decimals = 0),
                                    surfacecolor=ref_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,k],
                                    cmin=cmi, cmax=cma,
                                    opacityscale = [[0, 0.0], [0.1, 1], [1, 1]],
                                   scene = 'scene1'),
                                go.Cone(x = X_e[::10,::10,k].flatten(),
                                   y = Y_e[::10,::10,k].flatten(),
                                   z = Z_e[::10,::10,k].flatten(),
                                   u = u_e[::10,::10,k].flatten(),
                                   v = v_e[::10,::10,k].flatten(),
                                   w = w_e[::10,::10,k].flatten(),
                                   sizeref = 0.1,
                                   colorscale = black,
                                   opacity = 0.3,
                                   showscale = False,
                                   anchor = 'tail',
                                   scene = 'scene1'),
                                go.Surface(
                                    x = X_e[:,:,k],
                                    y = Y_e[:,:,k],
                                    z = Z_e[:,:,k],
                                    hovertext=np.round(ref_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,k],decimals = 0),
                                    surfacecolor=wind[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,k],
                                    cmin=0, cmax=140,
                                    opacityscale = [[0, 0.0], [0.1, 1], [1, 1]],
                                   scene = 'scene2'),
                                go.Cone(x = X_e[::10,::10,k].flatten(),
                                   y = Y_e[::10,::10,k].flatten(),
                                   z = Z_e[::10,::10,k].flatten(),
                                   u = u_e[::10,::10,k].flatten(),
                                   v = v_e[::10,::10,k].flatten(),
                                   w = w_e[::10,::10,k].flatten(),
                                   sizeref = 0.1,
                                   colorscale = black,
                                   opacity = 0.3,
                                   showscale = False,
                                   anchor = 'tail',
                                   scene = 'scene2'),],
                         name=str(level[k,0]),
                         traces = [0,1,2,3]
                         )
                for k in range(len(level[:,0]))]
        fig.frames=frames
        
        fig.append_trace(go.Surface(
            x = X_e[:,:,0],
            y = Y_e[:,:,0],
            z = Z_e[:,:,0],
            surfacecolor=ref_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,0],
            colorscale=color_map,
            cmin=cmi, cmax=cma,
            opacityscale = [[0, 0.0], [0.1, 1], [1, 1]],
            colorbar = dict(title = 'Reflectivity (dBZ)',x=0.46),
                                   scene = 'scene1'),1,1)
        
        fig.append_trace(go.Cone(x = X_e[::10,::10,0].flatten(),
                                   y = Y_e[::10,::10,0].flatten(),
                                   z = Z_e[::10,::10,0].flatten(),
                                   u = u_e[::10,::10,0].flatten(),
                                   v = v_e[::10,::10,0].flatten(),
                                   w = w_e[::10,::10,0].flatten(),
                                   sizeref = 0.1,
                                   colorscale = black,
                                   opacity = 0.3,
                                   showscale = False,
                                   anchor = 'tail',
                                   scene = 'scene1'),row=1,col=1)
        
        fig.append_trace(go.Surface(
            x = X_e[:,:,0],
            y = Y_e[:,:,0],
            z = Z_e[:,:,0],
            surfacecolor=wind[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,0],
            colorscale='Jet',
            cmin=0, cmax=140,
            opacityscale = [[0, 0.0], [0.1, 1], [1, 1]],
            colorbar = dict(title = 'Wind Speed (kts)'),
                                   scene = 'scene2'),1,2)
        
        fig.append_trace(go.Cone(x = X_e[::10,::10,0].flatten(),
                                   y = Y_e[::10,::10,0].flatten(),
                                   z = Z_e[::10,::10,0].flatten(),
                                   u = u_e[::10,::10,0].flatten(),
                                   v = v_e[::10,::10,0].flatten(),
                                   w = w_e[::10,::10,0].flatten(),
                                   sizeref = 0.1,
                                   colorscale = black,
                                   opacity = 0.3,
                                   showscale = False,
                                   anchor = 'tail',
                                   scene = 'scene2'),row=1,col=2)
        def cam_change(layout, camera):
            """Mirror scene1's camera onto scene2 so both 3-D panels share a view.

            Registered below via ``fig.layout.scene1.on_change(cam_change,
            'camera')``, so it is called with the scene1 layout object
            (``layout``) and its new ``camera`` value.

            The new ``camera`` is copied rather than the fixed initial
            ``camera1``, which the previous version copied. With that, later
            view changes were never mirrored.

            The target figure is resolved from ``layout`` instead of the
            enclosing ``fig`` name, because ``fig`` is rebound for every
            figure this function builds.

            NOTE(review): on_change callbacks run on Python-side layout
            changes; presumably they do not affect interaction in the exported
            HTML -- confirm.
            """
            owner = layout.figure  # the Figure that this scene belongs to
            owner.layout.scene2.camera = camera
        
        fig.layout.scene1.on_change(cam_change, 'camera')
        def frame_args(duration):
            """Return the plotly ``animate`` options dict for one step.

            ``duration`` is applied both to the frame itself and to its linear
            transition; ``fromcurrent`` makes playback resume from the frame
            currently shown, and ``mode="immediate"`` interrupts any pending
            animation.
            """
            options = dict(frame=dict(duration=duration))
            options["mode"] = "immediate"
            options["fromcurrent"] = True
            options["transition"] = dict(duration=duration, easing="linear")
            return options
        
        sliders = [
                    {
                        "pad": {"b": 10, "t": 60},
                        "len": 0.9,
                        "x": 0.1,
                        "y": 0,
                        "steps": [
                            {
                                "args": [[f.name], frame_args(0)],
                                "label": str(level[k,0]),
                                "method": "animate",
                            }
                            for k, f in enumerate(fig.frames)
                        ],
                    }
                ]
        
        # Layout
        fig.update_layout(scene_camera = camera1,
                 title = flight+' ('+storm_name+')'+"<br>"+'Valid: '+leg_start[0]+' to '+leg_end[-1],
                 width=1400,
                 height=600,
                 scene1=dict(xaxis_title='East-West',
                            yaxis_title='North-South',
                            zaxis_title='Height (km)',
                            xaxis = dict(range=[0,249], autorange = False), 
                            yaxis = dict(range=[0,249], autorange = False), 
                            zaxis = dict(range=[level.min(), level.max()], autorange = False),
                            aspectratio=dict(x=1, y=1, z=0.5),
                            ),
                 scene2=dict(xaxis_title='East-West',
                            yaxis_title='North-South',
                            zaxis_title='Height (km)',
                            xaxis = dict(range=[0,249], autorange = False), 
                            yaxis = dict(range=[0,249], autorange = False), 
                            zaxis = dict(range=[level.min(), level.max()], autorange = False),
                            aspectratio=dict(x=1, y=1, z=0.5),
                            ),
                    
                 updatemenus = [
                    {
                        "buttons": [
                            {
                                "args": [None, frame_args(50)],
                                "label": "&#9654;", # play symbol
                                "method": "animate",
                            },
                            {
                                "args": [[None], frame_args(0)],
                                "label": "&#9724;", # pause symbol
                                "method": "animate",
                            },
                        ],
                        "direction": "left",
                        "pad": {"r": 10, "t": 70},
                        "type": "buttons",
                        "x": 0.1,
                        "y": 0,
                    }
                 ],
                 sliders=sliders
        )
        
        #fig.show()
        fig.write_html(save_dir+flight+'_'+storm_name+'_hor_slice.html',auto_play=False)
        
### Vertical Slice Reflectivity only
        camera1 = dict(
            up=dict(x=0, y=0., z=1),
            eye=dict(x=1.25, y=-1.25, z=1.25))
        
        X, Y, Z = np.meshgrid(np.arange(0,250,1),np.arange(0,250,1),level[:,0])
        ind = np.where(np.isnan(ref_comp[:,:,:]) == False)
        r, c = X[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,0].shape
        
        nb_frames = 25
        nb_start = nb_frames-1
        nb_pos = nb_start/10
        cmi = 0
        cma = 70
        
        X_e = X[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        Y_e = Y[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        Z_e = Z[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        u_e = u_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        v_e = v_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        w_e = w_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        ref_e = ref_comp[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        wind_e = wind[ind[0].min():ind[0].max():1,ind[1].min():ind[1].max():1,:]
        
        fig = go.Figure()
        km = np.arange(0,250,1)
        km = km[ind[1].min():ind[1].max():1]
        frames=[go.Frame(data=[go.Surface(
                                    x = X_e[:,k,:],
                                    y = Y_e[:,k,:],
                                    z = Z_e[:,k,:],
                                    hovertext=np.round(ref_e[:,k,:],decimals = 0),
                                    surfacecolor=ref_e[:,k,:],
                                    cmin=cmi, cmax=cma,
                                    opacityscale = [[0, 0.0], [0.1, 1], [1, 1]],
                                   scene = 'scene1'),
                                go.Cone(x = X_e[::10,k,::4].flatten(),
                                   y = Y_e[::10,k,::4].flatten(),
                                   z = Z_e[::10,k,::4].flatten(),
                                   u = u_e[::10,k,::4].flatten(),
                                   v = v_e[::10,k,::4].flatten(),
                                   w = w_e[::10,k,::4].flatten(),
                                   sizeref = 0.1,
                                   colorscale = black,
                                   opacity = 0.3,
                                   showscale = False,
                                   anchor = 'tail',
                                   scene = 'scene1')],
                         name=str(km[k]),
                         )
                for k in range(len(km))]
        fig.frames=frames
        
        fig.add_trace(go.Surface(
            x = X_e[:,0,:],
            y = Y_e[:,0,:],
            z = Z_e[:,0,:],
            surfacecolor=ref_e[:,0,:],
            colorscale=color_map,
            cmin=cmi, cmax=cma,
            opacityscale = [[0, 0.0], [0.1, 1], [1, 1]],
            colorbar = dict(title = 'Reflectivity (dBZ)'),
                                   scene = 'scene1'))
        
        fig.add_trace(go.Cone(x = X_e[::10,0,::4].flatten(),
                                   y = Y_e[::10,0,::4].flatten(),
                                   z = Z_e[::10,0,::4].flatten(),
                                   u = u_e[::10,0,::4].flatten(),
                                   v = v_e[::10,0,::4].flatten(),
                                   w = w_e[::10,0,::4].flatten(),
                                   sizeref = 0.1,
                                   colorscale = black,
                                   opacity = 0.3,
                                   showscale = False,
                                   anchor = 'tail',
                                   scene = 'scene1'))
        
        
        def frame_args(duration):
            """Assemble animation settings for the slider steps and play/pause buttons.

            ``duration`` sets both the frame duration and the transition
            duration; the mode is "immediate", animation continues from the
            current frame, and the transition easing is linear.
            """
            settings = {}
            settings["frame"] = {"duration": duration}
            settings["mode"] = "immediate"
            settings["fromcurrent"] = True
            settings["transition"] = {"duration": duration, "easing": "linear"}
            return settings
        
        sliders = [
                    {
                        "pad": {"b": 10, "t": 60},
                        "len": 0.9,
                        "x": 0.1,
                        "y": 0,
                        "steps": [
                            {
                                "args": [[f.name], frame_args(0)],
                                "label": str(km[k]),
                                "method": "animate",
                            }
                            for k, f in enumerate(fig.frames)
                        ],
                    }
                ]
        
        # Layout
        fig.update_layout(scene_camera = camera1,
                 title = flight+' ('+storm_name+')'+"<br>"+'Valid: '+leg_start[0]+' to '+leg_end[-1],
                 width=800,
                 height=800,
                 scene1=dict(xaxis_title='East-West',
                            yaxis_title='North-South',
                            zaxis_title='Height (km)',
                            xaxis = dict(range=[0,249], autorange = False), 
                            yaxis = dict(range=[0,249], autorange = False), 
                            zaxis = dict(range=[level.min(), level.max()], autorange = False),
                            aspectratio=dict(x=1, y=1, z=0.5),
                            ),
                    
                 updatemenus = [
                    {
                        "buttons": [
                            {
                                "args": [None, frame_args(50)],
                                "label": "&#9654;", # play symbol
                                "method": "animate",
                            },
                            {
                                "args": [[None], frame_args(0)],
                                "label": "&#9724;", # pause symbol
                                "method": "animate",
                            },
                        ],
                        "direction": "left",
                        "pad": {"r": 10, "t": 70},
                        "type": "buttons",
                        "x": 0.1,
                        "y": 0,
                    }
                 ],
                 sliders=sliders
        )
        
        fig.write_html(save_dir+flight+'_'+storm_name+'_vert_slice_ref.html',auto_play=False)
        
#### Vertical slice wind only
        fig = go.Figure()
        km = np.arange(0,250,1)
        km = km[ind[1].min():ind[1].max():1]
        frames=[go.Frame(data=[go.Surface(
                                    x = X_e[:,k,:],
                                    y = Y_e[:,k,:],
                                    z = Z_e[:,k,:],
                                    hovertext=np.round(wind_e[:,k,:],decimals = 0),
                                    surfacecolor=wind_e[:,k,:],
                                    cmin=0, cmax=140,
                                    opacityscale = [[0, 0.0], [0.1, 1], [1, 1]],
                                   scene = 'scene1'),
                                go.Cone(x = X_e[::10,k,::4].flatten(),
                                   y = Y_e[::10,k,::4].flatten(),
                                   z = Z_e[::10,k,::4].flatten(),
                                   u = u_e[::10,k,::4].flatten(),
                                   v = v_e[::10,k,::4].flatten(),
                                   w = w_e[::10,k,::4].flatten(),
                                   sizeref = 0.1,
                                   colorscale = black,
                                   opacity = 0.3,
                                   showscale = False,
                                   anchor = 'tail',
                                   scene = 'scene1')],
                         name=str(km[k]),
                         )
                for k in range(len(km))]
        fig.frames=frames
        
        fig.add_trace(go.Surface(
            x = X_e[:,0,:],
            y = Y_e[:,0,:],
            z = Z_e[:,0,:],
            surfacecolor=ref_e[:,0,:],
            colorscale='Jet',
            cmin=0, cmax=140,
            opacityscale = [[0, 0.0], [0.1, 1], [1, 1]],
            colorbar = dict(title = 'Wind Speed (kts)'),
                                   scene = 'scene1'))
        
        fig.add_trace(go.Cone(x = X_e[::10,0,::4].flatten(),
                                   y = Y_e[::10,0,::4].flatten(),
                                   z = Z_e[::10,0,::4].flatten(),
                                   u = u_e[::10,0,::4].flatten(),
                                   v = v_e[::10,0,::4].flatten(),
                                   w = w_e[::10,0,::4].flatten(),
                                   sizeref = 0.1,
                                   colorscale = black,
                                   opacity = 0.3,
                                   showscale = False,
                                   anchor = 'tail',
                                   scene = 'scene1'))
        
        
        def frame_args(duration):
            """Return Plotly ``animate`` options for the wind-speed slice animation.

            Both the frame and the linear-eased transition use ``duration``;
            the animation runs in "immediate" mode starting from the current
            frame.
            """
            frame_opts = {"duration": duration}
            transition_opts = {"duration": duration, "easing": "linear"}
            return {
                "frame": frame_opts,
                "mode": "immediate",
                "fromcurrent": True,
                "transition": transition_opts,
            }
        
        sliders = [
                    {
                        "pad": {"b": 10, "t": 60},
                        "len": 0.9,
                        "x": 0.1,
                        "y": 0,
                        "steps": [
                            {
                                "args": [[f.name], frame_args(0)],
                                "label": str(km[k]),
                                "method": "animate",
                            }
                            for k, f in enumerate(fig.frames)
                        ],
                    }
                ]
        
        # Layout
        fig.update_layout(scene_camera = camera1,
                 title = flight+' ('+storm_name+')'+"<br>"+'Valid: '+leg_start[0]+' to '+leg_end[-1],
                 width=800,
                 height=800,
                 scene1=dict(xaxis_title='East-West',
                            yaxis_title='North-South',
                            zaxis_title='Height (km)',
                            xaxis = dict(range=[0,249], autorange = False), 
                            yaxis = dict(range=[0,249], autorange = False), 
                            zaxis = dict(range=[level.min(), level.max()], autorange = False),
                            aspectratio=dict(x=1, y=1, z=0.5),
                            ),
                    
                 updatemenus = [
                    {
                        "buttons": [
                            {
                                "args": [None, frame_args(50)],
                                "label": "&#9654;", # play symbol
                                "method": "animate",
                            },
                            {
                                "args": [[None], frame_args(0)],
                                "label": "&#9724;", # pause symbol
                                "method": "animate",
                            },
                        ],
                        "direction": "left",
                        "pad": {"r": 10, "t": 70},
                        "type": "buttons",
                        "x": 0.1,
                        "y": 0,
                    }
                 ],
                 sliders=sliders
        )
        

        fig.write_html(save_dir+flight+'_'+storm_name+'_vert_slice_wind.html',auto_play=False)
        
        

