Question

I have a very large script which outputs several plots, one for each function it runs. I save all these plots in a single output .png file. Depending on the situation, sometimes the script saves plots from all functions and other times only from a few of them.

One of these plots includes a colorbar which I want positioned in the same place inside that plot, regardless of whether all the functions were processed or only a few of them.

I've tried everything I could think of, but I just can't stop the colorbar from moving around in the final .png output file when it contains only a few plots (more precisely: when the plots below the one that contains the colorbar are not generated).

Here's a picture to show what I mean:


The MWE is below. To generate the first file I plotted everything, and for the second one I just commented out the last eight ax* blocks (ax20 to ax27).

import matplotlib.pyplot as plt
import numpy as np
import matplotlib.gridspec as gridspec

# Generate random data.
x = np.random.randn(60)
y = np.random.randn(60)
z = [np.random.random() for _ in range(60)]

fig = plt.figure(figsize=(20, 35))  # create the top-level container
gs = gridspec.GridSpec(14, 8)  # create a GridSpec object

ax0 = plt.subplot(gs[0:2, 0:2])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax1 = plt.subplot(gs[0:2, 2:4])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax2 = plt.subplot(gs[0:2, 4:6])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax3 = plt.subplot(gs[0:2, 6:8])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax4 = plt.subplot(gs[2:4, 0:2])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax5 = plt.subplot(gs[2:4, 2:4])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax6 = plt.subplot(gs[2:4, 4:6])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax7 = plt.subplot(gs[2:4, 6:8])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax8 = plt.subplot(gs[4:6, 0:2])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax9 = plt.subplot(gs[4:6, 2:4])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax10 = plt.subplot(gs[4:6, 4:6])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax11 = plt.subplot(gs[4:6, 6:8])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax12 = plt.subplot(gs[6:8, 0:2])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax13 = plt.subplot(gs[6:8, 2:4])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax14 = plt.subplot(gs[6:8, 4:6])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax15 = plt.subplot(gs[6:8, 6:8])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax16 = plt.subplot(gs[8:10, 0:2])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax17 = plt.subplot(gs[8:10, 2:4])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax18 = plt.subplot(gs[8:10, 4:6])
cm = plt.get_cmap('RdYlBu_r')  # plt.cm.get_cmap no longer exists in Matplotlib 3.9+
plt.scatter(x, y, s=20, c=z, cmap=cm, vmin=0, vmax=1)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
# Plot colorbar.
box = ax18.get_position()
cbar_posit = [box.x1 * 0.93, box.y1 * 0.94, 0.04, 0.005]
cbaxes = fig.add_axes(cbar_posit)
cbar = plt.colorbar(cax=cbaxes, ticks=[0, 1], orientation='horizontal')
cbar.ax.tick_params(labelsize=9)

ax19 = plt.subplot(gs[8:10, 6:8])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax20 = plt.subplot(gs[10:12, 0:2])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax21 = plt.subplot(gs[10:12, 2:4])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax22 = plt.subplot(gs[10:12, 4:6])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax23 = plt.subplot(gs[10:12, 6:8])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax24 = plt.subplot(gs[12:14, 0:2])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax25 = plt.subplot(gs[12:14, 2:4])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax26 = plt.subplot(gs[12:14, 4:6])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

ax27 = plt.subplot(gs[12:14, 6:8])
plt.scatter(x, y, s=20)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)

fig.tight_layout()

out_png = 'colorbar.png'
plt.savefig(out_png, dpi=150)
plt.close()

This MWE reproduces the way I generate the output .png file in my actual code, hence the apparently unnecessarily large code.

Since the colorbar is positioned according to the position of the ax18 plot, I'd expect it to always end up in the right place inside it, but clearly that's not what I'm getting.

I just want that colorbar to be fixed inside the ax18 plot no matter how many other plots I generate around it. Is this possible?


Update

OK, I've narrowed the issue down to the fig.tight_layout() call. When it is commented out, the colorbar stays exactly in position no matter how many plots are produced. The downside is that the final image looks much worse, with the axes of adjacent plots overlapping.

Is there a way to keep fig.tight_layout() on and still get the colorbar positioned correctly?


Solution

There are a couple of issues, I think.

If I understand you correctly, you want the colorbar to be in a fixed position relative to the axes where it is embedded.

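fig.tight_layout() moves the 'regular' axes around, i.e. the ones created with plt.subplot on the GridSpec. The colorbar axes, on the other hand, is one you created manually with fig.add_axes, so this function does not affect it. In your code, ax18.get_position() is also read before tight_layout runs, so the colorbar is placed from ax18's old position and stays there while ax18 moves. As far as I know, there is no way to anchor one axes to another so that moving the first also moves the second to keep its relative position. This leads to the first point: you need to set the position of the colorbar after your call to tight_layout.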
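A small self-contained check of this behaviour (my own sketch, not part of the original answer; it uses a tiny 2x2 GridSpec instead of the full MWE and the Agg backend so no window opens). The GridSpec axes changes position when tight_layout runs, while the axes added with fig.add_axes keeps exactly the position it was given. Matplotlib may emit a UserWarning saying the figure includes Axes that are not compatible with tight_layout; that is expected here.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no window needed
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

fig = plt.figure(figsize=(6, 4))
gs = gridspec.GridSpec(2, 2)
ax = fig.add_subplot(gs[1, 1])  # a 'regular' GridSpec axes

# Axes placed by hand from ax's position *before* tight_layout,
# just like the colorbar axes in the question.
box = ax.get_position()
cax = fig.add_axes([box.x0, box.y1 - 0.05, 0.1, 0.02])

before_ax, before_cax = ax.get_position().bounds, cax.get_position().bounds
fig.tight_layout()  # may warn about Axes not compatible with tight_layout
after_ax, after_cax = ax.get_position().bounds, cax.get_position().bounds

print("GridSpec axes moved:", before_ax != after_ax)
print("manually added axes moved:", before_cax != after_cax)
plt.close(fig)
```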

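Even then, however, you'll still have problems, since you're setting the position of the colorbar in absolute figure coordinates (box.x1 * 0.93 and box.y1 * 0.94 scale figure coordinates, so the offset from the axes corner changes depending on where ax18 sits in the figure), while you want a fixed position relative to the axes. You need to specify the location relative to the axes.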
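Relative placement boils down to simple arithmetic: if ax.get_position() gives the axes box (X0, Y0, W, H) in figure fraction, a box (u, v, w, h) given in axes fraction maps to (X0 + u*W, Y0 + v*H, w*W, h*H). The sketch below wraps that in a hypothetical helper, axes_to_figure_bounds (my own name, not a Matplotlib function), and checks it against the transform-based conversion. It assumes the helper is called after fig.tight_layout(), because it only reads the axes position at the moment it is called.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no window needed
import matplotlib.pyplot as plt
from matplotlib.transforms import Bbox, TransformedBbox


def axes_to_figure_bounds(ax, x0, y0, width, height):
    """Convert a box given in ax's axes fraction to figure fraction.

    Hypothetical helper: it reads ax's position at call time, so call it
    after fig.tight_layout() has finished moving the axes.
    """
    pos = ax.get_position()  # Bbox of ax in figure fraction
    return (pos.x0 + x0 * pos.width,
            pos.y0 + y0 * pos.height,
            width * pos.width,
            height * pos.height)


fig, ax = plt.subplots(figsize=(4, 3))
fig.tight_layout()  # settle the layout first

rel = (0.6, 0.9, 0.2, 0.04)  # box in axes fraction (near the top right)
by_hand = axes_to_figure_bounds(ax, *rel)

# Same conversion done with transforms: axes fraction -> pixels -> figure fraction.
trans = ax.transAxes + fig.transFigure.inverted()
via_transform = TransformedBbox(Bbox.from_bounds(*rel), trans).bounds

cax = fig.add_axes(by_hand)  # colorbar axes placed inside ax
print(by_hand)
print(via_transform)
plt.close(fig)
```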

Here is how you could do that:

## ...
## ...
## Your MWE starts up there...

ax18 = plt.subplot(gs[8:10, 4:6])
cm = plt.get_cmap('RdYlBu_r')  # plt.cm.get_cmap no longer exists in Matplotlib 3.9+
sca = plt.scatter(x, y, s=20, c=z, cmap=cm, vmin=0, vmax=1)
plt.ylabel('$L$', fontsize=16)
plt.xlabel('$E$', fontsize=16)
# Plot colorbar.
#box = ax18.get_position()
#cbar_posit = [box.x1 * 0.93, box.y1 * 0.94, 0.04, 0.005]
#cbaxes = fig.add_axes(cbar_posit)
#cbar = plt.colorbar(cax=cbaxes, ticks=[0, 1], orientation='horizontal')
#cbar.ax.tick_params(labelsize=9)

## You keep creating and filling axes...
## ...
## ...

Notice that I saved the return value of the scatter call in sca; we'll need it for the colorbar. You can also remove the commented-out lines; they are only there to show that you no longer want them at that point.

Then, at the end of your script, you can create the colorbar, specifying its position and size relative to the axes:

## ...
## ...
## All the plots are created above...

fig.tight_layout()  # You call fig.tight_layout BEFORE creating the colorbar

from matplotlib.transforms import Bbox, TransformedBbox

# You input the POSITION AND DIMENSIONS RELATIVE TO THE AXES
x0, y0, width, height = [0.6, 0.9, 0.2, 0.04]

# and transform them after to get the ABSOLUTE POSITION AND DIMENSIONS
bbox = Bbox.from_bounds(x0, y0, width, height)
# axes fraction -> display (pixels) -> figure fraction
trans = ax18.transAxes + fig.transFigure.inverted()
l, b, w, h = TransformedBbox(bbox, trans).bounds

# Now just create the axes and the colorbar
cbaxes = fig.add_axes([l, b, w, h])
cbar = plt.colorbar(sca, cax=cbaxes, ticks=[0, 1], orientation='horizontal')
cbar.ax.tick_params(labelsize=9)

out_png = 'colorbar.png'
plt.savefig(out_png, dpi=150)
plt.close()

I got the code for the coordinate transformation long ago from an answer here on SO and saved it in a script. I tried to credit that answer, but I can't find it any more, sorry.
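For completeness, and as a different technique from the one above: Matplotlib 3.0 and later provide Axes.inset_axes, which takes [x0, y0, width, height] in the parent's axes fraction and recomputes the inset position from the parent at draw time. That gives the kind of anchoring described as missing earlier, so the colorbar follows its parent axes even if tight_layout runs after the colorbar is created. A minimal sketch, assuming a recent Matplotlib and using a small 2x2 figure with random data in place of the MWE (the output file name is arbitrary):

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no window needed
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x, y, z = rng.standard_normal(60), rng.standard_normal(60), rng.random(60)

fig, axes = plt.subplots(2, 2, figsize=(8, 6))
ax = axes[1, 0]  # plays the role of ax18
sca = ax.scatter(x, y, s=20, c=z, cmap="RdYlBu_r", vmin=0, vmax=1)

# Bounds are [x0, y0, width, height] in ax's axes fraction (transAxes by default).
cax = ax.inset_axes([0.6, 0.9, 0.2, 0.04])
cbar = fig.colorbar(sca, cax=cax, ticks=[0, 1], orientation="horizontal")
cbar.ax.tick_params(labelsize=9)

# tight_layout may run *after* the colorbar exists: the inset is
# re-located relative to ax every time the figure is drawn.
fig.tight_layout()
fig.savefig("colorbar_inset.png", dpi=150)
plt.close(fig)
```

With this approach the position never has to be recomputed by hand, which is convenient when the number of subplots around the colorbar plot varies from run to run.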

OTHER TIPS

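Call fig.canvas.draw() directly after fig.tight_layout(), and only then build the colorbar axes from the updated position. Even if a UserWarning is raised, it works.

ax = ax18  # the axes holding the colored scatter plot
fig.tight_layout()
fig.canvas.draw()
pos = ax.get_position()  # position AFTER tight_layout
cax = fig.add_axes([pos.x0, pos.y0 - 0.12, pos.width, 0.02])
cbar = plt.colorbar(sca, orientation='horizontal', cax=cax)  # sca: the scatter mappable, saved as in the answer above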
Licensed under: CC-BY-SA with attribution