Subplots in Plotly

Plotly Express can split a single dataset into faceted subplots (with `facet_row` and `facet_col`), but it cannot combine independent plots into one figure. For that, we build the figure manually from plotly graph objects, a step Plotly Express usually hides from you. When Plotly Express makes a scatter plot, the library is really creating a scatter trace object (`go.Scatter`) and placing it inside a figure.

To make a figure with subplots, we create these trace objects ourselves instead of letting Plotly Express do it for us.

import pandas as pd

import plotly.graph_objects as go

from plotly.subplots import make_subplots
s_orbitals = pd.read_csv("s_orbitals_1D.csv")
s_orbitals.head()
          r        1s        2s        3s
0  0.000000  0.564190  0.199471  0.108578
1  0.517241  0.336349  0.114183  0.061683
2  1.034483  0.200519  0.057408  0.029966
3  1.551724  0.119542  0.020580  0.009313
4  2.068966  0.071266 -0.002445 -0.003390
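The file `s_orbitals_1D.csv` is not reproduced here. The printed values are consistent with the hydrogen-atom s-orbital wavefunctions in atomic units, sampled at 30 evenly spaced radii from 0 to 15 (for example, the 1s value at r = 0 is 1/√π ≈ 0.564190). Assuming that is what the file contains, which is our reconstruction and not something the original states, a sketch like this rebuilds an equivalent DataFrame so the rest of the page can be run without the file.

```python
import numpy as np
import pandas as pd

# ASSUMPTION: hydrogen s orbitals in atomic units, 30 radii on [0, 15]
r = np.linspace(0, 15, 30)
psi_1s = np.exp(-r) / np.sqrt(np.pi)
psi_2s = (2 - r) * np.exp(-r / 2) / (4 * np.sqrt(2 * np.pi))
psi_3s = (27 - 18 * r + 2 * r**2) * np.exp(-r / 3) / (81 * np.sqrt(3 * np.pi))

s_orbitals = pd.DataFrame({"r": r, "1s": psi_1s, "2s": psi_2s, "3s": psi_3s})

# Optionally write it out so the read_csv call works unchanged:
# s_orbitals.to_csv("s_orbitals_1D.csv", index=False)
print(s_orbitals.head())
```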
first_line = go.Scatter(x=s_orbitals["r"], y=s_orbitals["1s"], name="1s")
first_line
Scatter({
    'name': '1s',
    'x': array([ 0.        ,  0.51724138,  1.03448276,  1.55172414,  2.06896552,
                 2.5862069 ,  3.10344828,  3.62068966,  4.13793103,  4.65517241,
                 5.17241379,  5.68965517,  6.20689655,  6.72413793,  7.24137931,
                 7.75862069,  8.27586207,  8.79310345,  9.31034483,  9.82758621,
                10.34482759, 10.86206897, 11.37931034, 11.89655172, 12.4137931 ,
                12.93103448, 13.44827586, 13.96551724, 14.48275862, 15.        ]),
    'y': array([5.64189584e-01, 3.36348881e-01, 2.00518714e-01, 1.19541812e-01,
                7.12663894e-02, 4.24863751e-02, 2.53287993e-02, 1.51000896e-02,
                9.00211277e-03, 5.36672537e-03, 3.19944239e-03, 1.90738876e-03,
                1.13711435e-03, 6.77905355e-04, 4.04142001e-04, 2.40934455e-04,
                1.43636176e-04, 8.56305547e-05, 5.10497571e-05, 3.04339697e-05,
                1.81436027e-05, 1.08165422e-05, 6.44842084e-06, 3.84430907e-06,
                2.29183432e-06, 1.36630653e-06, 8.14541218e-07, 4.85599229e-07,
                2.89496230e-07, 1.72586903e-07])
})
second_line = go.Scatter(x=s_orbitals["r"], y=s_orbitals["2s"], name="2s")
third_line = go.Scatter(x=s_orbitals["r"], y=s_orbitals["3s"], name="3s")

Next, create a figure with a grid of subplots using `make_subplots`, then add each trace to a specific cell of the grid with `add_trace`, passing its `row` and `col`.

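# One row and three columns; shared_yaxes=True links the y-axes of the row
fig = make_subplots(rows=1, cols=3, shared_yaxes=True)
fig  # displays the empty 1x3 grid

# Place each trace in its own cell of the grid
fig.add_trace(first_line, row=1, col=1)
fig.add_trace(second_line, row=1, col=2)
fig.add_trace(third_line, row=1, col=3)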

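We can then set layout properties, such as axis titles, just as we did for a single plot. When a figure has more than one set of axes, the axis keys in the layout dictionary get numbers appended: the first subplot uses `xaxis` and `yaxis`, the second uses `xaxis2` and `yaxis2`, the third uses `xaxis3` and `yaxis3`, and so on. Because the y-axes here are shared, only the first subplot's `yaxis` needs a title.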

layout = {
    "xaxis": {
        "title": "r",
    },
    "yaxis": {
        "title": r"$\psi$"
    },
    
    "xaxis2":{
        "title": "r"
    },
    
    "xaxis3":{
        "title": "r"
    }
}

fig.update_layout(layout)