Newer
Older
# @Author: Felix Kramer <kramer>
# @Date: 2021-05-08T20:34:30+02:00
# @Email: kramer@mpi-cbg.de
# @Project: go-with-the-flow
# @Last modified by: Felix Kramer
# @License: MIT
# standard types
import networkx as nx
import numpy as np
import pandas as pd
import plotly.graph_objects as go
# generate interactive plots with plotly and return the respective figures
def plot_networkx(input_graph, **kwargs):
    """Render a single networkx graph as an interactive plotly figure.

    Node positions are read from each node's ``'pos'`` attribute.

    Args:
        input_graph: networkx graph whose nodes carry a ``'pos'`` attribute.
        **kwargs: plotting options. Recognised keys:
            ``network_id`` (index into the list-valued options below),
            ``color_nodes``, ``color_edges``, ``markersize`` (lists),
            ``node_data`` / ``edge_data`` (mappings column -> array-like whose
            columns are shown as hover text for nodes / edges).

    Returns:
        plotly.graph_objects.Figure with node and edge traces, legend hidden.
    """
    options = {
        'network_id': 0,
        'color_nodes': ['#a845b5'],
        # add_traces_edges / add_traces_nodes index these two by network_id,
        # so they must exist; colours match the second layer of the dual plot.
        'color_edges': ['#c762d4'],
        'markersize': [2],
    }
    for k, v in kwargs.items():
        if k in options:
            options[k] = v

    # Convert every column to a plain array so the hover tables get a fresh
    # 0..n-1 index; np.asarray accepts Series, lists and arrays alike.
    node_data = pd.DataFrame(
        {key: np.asarray(val) for key, val in kwargs.get('node_data', {}).items()}
    )
    edge_data = pd.DataFrame(
        {key: np.asarray(val) for key, val in kwargs.get('edge_data', {}).items()}
    )

    fig = go.Figure()
    add_traces_nodes(fig, options, input_graph, node_data)
    add_traces_edges(fig, options, input_graph, edge_data)
    fig.update_layout(showlegend=False)

    return fig
def plot_networkx_dual(dual_graph, **kwargs):
    """Render the two interwoven networks of a dual graph into one figure.

    Each network gets its own entry of ``color_nodes`` / ``color_edges`` /
    ``markersize``, selected by setting ``network_id`` to the layer index.

    Args:
        dual_graph: container of the two networks (see NOTE below).
        **kwargs: plotting options (``color_nodes``, ``color_edges``,
            ``markersize`` as per-layer lists); the full kwargs are also
            forwarded to each layer's ``get_nodes_data`` / ``get_edges_data``.

    Returns:
        plotly.graph_objects.Figure with both networks drawn, legend hidden.
    """
    options = {
        'network_id': 0,
        'color_nodes': ['#6aa84f', '#a845b5'],
        'color_edges': ['#2BDF94', '#c762d4'],
        'markersize': [2, 2],
    }
    for k, v in kwargs.items():
        if k in options:
            options[k] = v

    fig = go.Figure()

    # NOTE(review): the original body referenced an undefined name `K` and
    # never used `dual_graph`. It is assumed here that `dual_graph.layer`
    # holds the networks, each exposing `G`, `get_nodes_data` and
    # `get_edges_data` -- confirm against the dual-graph class.
    for idx, layer in enumerate(dual_graph.layer):
        options['network_id'] = idx
        node_data = layer.get_nodes_data(**kwargs)
        edges_data = layer.get_edges_data(**kwargs)
        add_traces_nodes(fig, options, layer.G, node_data)
        add_traces_edges(fig, options, layer.G, edges_data)

    fig.update_layout(showlegend=False)

    return fig
# integrate traces into the figure
def add_traces_edges(fig, options, input_graph, extra_data):
    """Append the edge traces of ``input_graph`` to ``fig``.

    The per-edge line traces are added first, followed by the edge-midpoint
    hover trace. All of them use the colour
    ``options['color_edges'][options['network_id']]``.
    """
    edge_color = options['color_edges'][options['network_id']]

    hover_trace = get_edge_mid_trace(input_graph, extra_data, color=edge_color)
    line_traces = get_edge_invd_traces(input_graph, extra_data, color=edge_color)

    for trace in [*line_traces, hover_trace]:
        fig.add_trace(trace)
def add_traces_nodes(fig, options, input_graph, extra_data):
    """Build the node marker trace of ``input_graph`` and add it to ``fig``.

    Colour and marker size are taken from ``options['color_nodes']`` and
    ``options['markersize']`` at index ``options['network_id']``.
    """
    idx = options['network_id']
    node_trace = get_node_trace(
        input_graph,
        extra_data,
        color=options['color_nodes'][idx],
        markersize=options['markersize'][idx],
    )
    # Previously the trace was built but never attached to the figure.
    fig.add_trace(node_trace)
# auxiliary functions generating traces for nodes and edges
def get_edge_mid_trace(input_graph, extra_data, **kwargs):
    """Build a hover trace with one marker at the midpoint of every edge.

    Args:
        input_graph: networkx graph whose nodes carry a ``'pos'`` attribute.
        extra_data: per-edge table whose columns become the hover text
            (one row per edge, in edge order).
        **kwargs: ``color`` for the markers; optional ``edge_list`` to use
            instead of ``input_graph.edges()``.

    Returns:
        The trace from ``get_hover_scatter_from_template`` with coordinates
        and hover text filled in by ``set_hover_info``.
    """
    options = {'color': '#888'}
    for k, v in kwargs.items():
        if k in options:
            options[k] = v

    # 2D vs 3D is inferred from the first stored position; default to 3D so
    # an empty graph does not raise.
    pos = nx.get_node_attributes(input_graph, 'pos')
    dim = len(next(iter(pos.values()))) if pos else 3

    # 'edge_list' is read from kwargs: it is not a default option key, so the
    # filtered options dict could never contain it.
    edges = kwargs.get('edge_list', input_graph.edges())

    middle_node_trace = get_hover_scatter_from_template(dim, options)

    XYZ = [[] for _ in range(dim)]
    for edge in edges:
        XYZ_0 = input_graph.nodes[edge[0]]['pos']
        XYZ_1 = input_graph.nodes[edge[1]]['pos']
        for i, xi in enumerate(XYZ):
            xi.append((XYZ_0[i] + XYZ_1[i]) / 2.)

    set_hover_info(middle_node_trace, XYZ, extra_data)

    return middle_node_trace
def set_hover_info(trace, XYZ, extra_data):
    """Write coordinates and hover text into ``trace`` in place.

    Args:
        trace: plotly trace (or any item-assignable mapping).
        XYZ: coordinate lists; the first three are used as x, y, z, or only
            x and y when fewer than three are given.
        extra_data: table whose rows become hover tags via ``create_tag``;
            when it has no columns, hovering is disabled instead.
    """
    axis_names = ['x', 'y'] if len(XYZ) < 3 else ['x', 'y', 'z']
    for axis_idx, axis in enumerate(axis_names):
        trace[axis] = XYZ[axis_idx]

    if len(extra_data.keys()) == 0:
        trace['hoverinfo'] = 'none'
        return

    column_names = extra_data.columns
    column_values = [list(extra_data[name]) for name in column_names]
    # zip(*columns) turns column-wise data into one tuple per row.
    trace['text'] = [create_tag(row, column_names) for row in zip(*column_values)]
def get_hover_scatter_from_template(dim, options):
    """Return an empty, transparent marker trace used as a hover-text layer.

    Args:
        dim: 3 selects a ``go.Scatter3d`` trace; any other value a 2D
            ``go.Scatter`` trace.
        options: marker properties (e.g. ``color``) passed to ``marker``.

    Returns:
        A plotly trace with empty coordinates/text and ``hoverinfo='text'``;
        the coordinates and text are expected to be filled in by the caller.
    """
    if dim == 3:
        middle_node_trace = go.Scatter3d(
            x=[],
            y=[],
            z=[],
            text=[],
            mode='markers',
            hoverinfo='text',
            opacity=0,
            marker=dict(**options),
        )
    else:
        # Same transparent-marker setup as the 3D case so the 2D plot does
        # not show stray dots at the edge midpoints.
        middle_node_trace = go.Scatter(
            x=[],
            y=[],
            text=[],
            mode='markers',
            hoverinfo='text',
            opacity=0,
            marker=dict(**options),
        )
    return middle_node_trace
def get_edge_invd_traces(input_graph, extra_data, **kwargs):
    """Build one line trace per edge so every edge can have its own style.

    Args:
        input_graph: networkx graph whose nodes carry a ``'pos'`` attribute.
        extra_data: per-edge table; if it has a ``'weight'`` column, the i-th
            value sets the line width of the i-th edge.
        **kwargs: ``color`` -- a single colour string, or a per-edge sequence
            of numeric values mapped through the ``'plasma'`` colorscale;
            ``width`` -- line width used when there is no weight column;
            optional ``edge_list`` to use instead of ``input_graph.edges()``.

    Returns:
        list of plotly line traces, one per edge, in edge order.
    """
    options = {
        'color': '#888',
        'width': 2,  # plotly's own default line width
    }
    for k, v in kwargs.items():
        if k in options:
            options[k] = v

    # A non-string colour is treated as per-edge values on a colorscale.
    # NOTE(review): colorscale/cmin/cmax and array-valued line colours are
    # Scatter3d line properties; a 2D go.Scatter line may reject them.
    colorful = not isinstance(options['color'], str)
    if colorful:
        options['colorscale'] = 'plasma'
        options['cmin'] = np.min(options['color'])
        options['cmax'] = np.max(options['color'])

    # 2D vs 3D is inferred from the first stored position; default to 3D.
    pos = nx.get_node_attributes(input_graph, 'pos')
    dim = len(next(iter(pos.values()))) if pos else 3

    # Positional per-edge widths; skipped when no weight column exists
    # (previously a KeyError for e.g. an empty edge table).
    weights = np.asarray(extra_data['weight']) if 'weight' in extra_data else None
    edges = kwargs.get('edge_list', input_graph.edges())

    trace_list = []
    for i, edge in enumerate(edges):
        aux_option = dict(options)
        if weights is not None:
            aux_option['width'] = weights[i]
        if colorful:
            # one colour value per line end point
            aux_option['color'] = [options['color'][i]] * 2

        trace = get_line_from_template(dim, aux_option)
        XYZ_0 = input_graph.nodes[edge[0]]['pos']
        XYZ_1 = input_graph.nodes[edge[1]]['pos']
        set_edge_info(trace, XYZ_0, XYZ_1)
        trace_list.append(trace)

    return trace_list


def set_edge_info(trace, XYZ_0, XYZ_1):
    """Store the segment XYZ_0 -> XYZ_1 as the coordinates of ``trace``.

    Each axis gets ``[start, end, None]``; the trailing ``None`` breaks the
    line so consecutive segments are not joined. Only x and y are written
    when the points have fewer than three coordinates.
    """
    axes = ('x', 'y') if len(XYZ_0) < 3 else ('x', 'y', 'z')
    for k, axis in enumerate(axes):
        trace[axis] = [XYZ_0[k], XYZ_1[k], None]


def get_line_from_template(dim, options):
    """Return an empty line trace (3D when ``dim == 3``, else 2D).

    ``options`` is passed verbatim as the trace's ``line`` properties.
    """
    if dim == 3:
        return go.Scatter3d(
            x=[],
            y=[],
            z=[],
            mode='lines',
            line=dict(**options),
        )
    return go.Scatter(
        x=[],
        y=[],
        mode='lines',
        line=dict(**options),
    )
def get_node_trace(input_graph, extra_data, **kwargs):
    """Return the plotly marker trace for the nodes of ``input_graph``.

    Recognised kwargs (others are ignored): ``color`` (default ``'#888'``),
    ``dim`` (default 3, replaced by the actual position dimension) and
    ``markersize`` (default 2).
    """
    defaults = {
        'color': '#888',
        'dim': 3,
        'markersize': 2,
    }
    options = dict(defaults)
    options.update({key: val for key, val in kwargs.items() if key in defaults})

    coords = get_node_coords(input_graph, options)
    return get_node_scatter(coords, extra_data, options)
def get_node_coords(input_graph, options):
    """Collect node positions as one coordinate list per axis.

    Args:
        input_graph: networkx graph whose nodes carry a ``'pos'`` attribute.
        options: option dict; ``options['dim']`` is overwritten in place with
            the dimension of the stored positions (callers read it later).
            An optional ``options['node_list']`` restricts/orders the nodes.

    Returns:
        list of ``options['dim']`` lists: ``[xs, ys(, zs)]``.
    """
    pos = nx.get_node_attributes(input_graph, 'pos')
    # Infer the dimension from the data; keep the default for an empty graph
    # instead of raising IndexError.
    if pos:
        options['dim'] = len(next(iter(pos.values())))

    # Bug fix: the node list was previously read from options['edge_list'].
    nodes = options['node_list'] if 'node_list' in options else input_graph.nodes()

    node_xyz = [[] for _ in range(options['dim'])]
    for node in nodes:
        xyz = pos[node]
        for axis, coords in enumerate(node_xyz):
            coords.append(xyz[axis])

    return node_xyz
def get_node_scatter(node_xyz, extra_data, options):
    """Build the node marker trace from precomputed coordinates.

    Args:
        node_xyz: per-axis coordinate lists as returned by ``get_node_coords``.
        extra_data: per-node table whose rows become hover tags via
            ``create_tag``; with no columns, hovering is disabled.
        options: must provide ``dim`` (3 selects ``Scatter3d``), ``color``
            and ``markersize``.

    Returns:
        go.Scatter3d or go.Scatter marker trace.
    """
    # Hover text only when there is per-node data to show. These values were
    # computed but never used before; they drive hoverinfo/text now.
    hover_mode = 'none'
    hover = ''
    if len(extra_data.keys()) != 0:
        hover_mode = 'text'
        data = [list(extra_data[c]) for c in extra_data.columns]
        hover = [create_tag(vals, extra_data.columns) for vals in zip(*data)]

    marker = dict(
        size=options['markersize'],
        line=dict(width=2),
        color=options['color'],
    )

    if options['dim'] == 3:
        node_trace = go.Scatter3d(
            x=node_xyz[0], y=node_xyz[1], z=node_xyz[2],
            mode='markers',
            hoverinfo=hover_mode,
            text=hover,
            marker=marker,
        )
    else:
        node_trace = go.Scatter(
            x=node_xyz[0], y=node_xyz[1],
            mode='markers',
            hoverinfo=hover_mode,
            text=hover,
            marker=marker,
        )

    return node_trace