How to create a custom legend (Seaborn)

import seaborn as sns

# ages_charges and color are defined elsewhere (not shown)
scat = sns.regplot(
    x='age',
    y='charges',
    data=ages_charges,
    truncate=False,
    scatter_kws={'facecolors': color}
)
scat.set(
    title='The Correlation between Age and Charge Amount',
    xlabel='Age',
    ylabel='Amount in Charges (Dollars)'
)
    

I would like to create a legend that gives me something like:

   Legend
o=18-34
o=35-49
o=50+
--- = (my linear reg. equation + my error)

I’ve not used seaborn much, but as far as I’m aware it’s mostly a wrapper on top of matplotlib, so if the seaborn docs don’t cover your issue, perhaps the matplotlib docs do:


The link above should cover it and there are several examples at the bottom of the page.

What have you tried (what’s given you an error or failed)?
You imported matplotlib.pyplot as plt, right?

Would either of these help?
https://jakevdp.github.io/PythonDataScienceHandbook/04.06-customizing-legends.html

https://matplotlib.org/api/_as_gen/matplotlib.pyplot.legend.html


Whoops! Haha. Jinx! :rofl:


So I had to use 3 existing points to make my legend-

dot=scat.scatter(20,562857.83475)
dot2=scat.scatter(39,252568.341850)
dot3=scat.scatter(57,427626.816500)
scat.legend(
    (dot,dot2,dot3),
    ('18-34','35-49','50+'),
    loc='lower right'
)

Which gave me what I wanted. Now I’m trying to find a way to graph the line equation, but not as a dot.


Cool.
So, you don’t want a scatterplot(?) You have to change the type of plot.

No, that’s not what I meant. I have a regplot, so it automatically graphs the line for me. I just want to add a key for it in my legend.

You want to change the dot to a line in the legend?
What about this:
https://www.nuomiphp.com/eplan/en/242911.html

It overwrites it :confused:

scat=sns.regplot(
    x='age',
    y='charges',
    data=ages_charges,
    truncate=False,
    scatter_kws={'facecolors':color},
    line_kws={
        'alpha':0.8,
        'linewidth':2,
        'label':lin_reg
    } #gives me what I want
)

#using existing points to put in the legend
dot=scat.scatter(20,562857.83475,edgecolors='#4890c1')
dot2=scat.scatter(39,252568.341850,edgecolors='#4890c1')
dot3=scat.scatter(57,427626.816500,edgecolors='#4890c1')

scat.legend(
    (dot,dot2,dot3),
    ('18-34','35-49','50+'),
    loc='lower right'
)  #then when I come down here to add more plots it overwrites the line

Is there a way to predefine the legend in the .regplot() function?
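
One way around the overwriting, sketched here with plain matplotlib so it runs without seaborn or the original data: ask the axes for the handles that already carry a label, then pass them to legend() together with the dots. The data values, colours, and the plotted line below are made-up stand-ins for what regplot draws. Only the idea of combining the handles is the point.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Made-up data standing in for ages_charges
ages = np.array([20, 25, 39, 45, 57, 62])
charges = np.array([3000.0, 4200.0, 8100.0, 9500.0, 13000.0, 14500.0])

fig, ax = plt.subplots()
ax.scatter(ages, charges, color="lightgray")

# Stand-in for the regression line that regplot draws; the label plays the
# role of line_kws={'label': lin_reg} from the thread
slope, intercept = np.polyfit(ages, charges, 1)
ax.plot(ages, slope * ages + intercept,
        label=f"y = {slope:.1f}x + {intercept:.1f}")

# One marker per age group, as in the thread
dot = ax.scatter(20, 3000.0)
dot2 = ax.scatter(39, 8100.0)
dot3 = ax.scatter(57, 13000.0)

# Collect whatever already carries a label (here: only the line) ...
line_handles, line_labels = ax.get_legend_handles_labels()

# ... and pass it along with the dots so nothing gets dropped
handles = [dot, dot2, dot3] + line_handles
labels = ["18-34", "35-49", "50+"] + line_labels
legend = ax.legend(handles, labels, loc="lower right")
```

Passing explicit handles tells legend() to show only those handles, which is why the labelled line disappears when just the three dots are given.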

When you say add more plots on the last line what exactly do you mean?

I’m afraid I’m just not that familiar with seaborn so I’ve no idea if this is a quirk related to that or not. I think one of the oddities is that what you really need from the regplot function is a return of the objects it creates (paths, polys, lines etc.), the return of an axis is unhelpful.

An ugly fix if you only have one line plotted on this axis is to get the object reference with the following function (it doesn’t have to be a function but this whole method of working is a bit backwards). There must be a better way to get that reference so consider your options. This is a bit sketchy for a proper project.

from matplotlib.lines import Line2D


def get_line_from_axis(axis):
    # Warning: returns only the first Line2D found among the axis children,
    # or None if the axis holds no line at all
    for item in axis.get_children():
        if isinstance(item, Line2D):
            return item
    return None


# ax is the axis returned by sns.regplot (named scat earlier in the thread)
rfit_line = get_line_from_axis(ax)
# I think you should be able to simply tag it on.
# Without having access to the same data/objects you do I cannot test.
ax.legend((dot, dot2, dot3, rfit_line), ('18-34','35-49','50+', 'regression_fit'))

Perhaps a more sensible alternative is creating a proxy just for the legend (or add two separate legends). There are some more advanced details in the following-
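
As a minimal sketch of the proxy approach, assuming plain matplotlib: the Line2D objects below are never drawn on the axes, they exist only so the legend has something to display. The colours, the group-to-colour mapping, and the plotted stand-in line are hypothetical and not taken from the original plot.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

fig, ax = plt.subplots()
# Stand-in for the fitted line that regplot would draw
ax.plot([18, 64], [3000, 14000], color="tab:blue")

# Hypothetical colour per age group
group_colors = {"18-34": "tab:green", "35-49": "tab:orange", "50+": "tab:red"}

# Proxy artists: empty data, marker only, used purely as legend handles
proxies = [
    Line2D([], [], marker="o", linestyle="none", color=c, label=name)
    for name, c in group_colors.items()
]
# A proxy drawn as a line, so the regression entry shows a line and not a dot
proxies.append(Line2D([], [], color="tab:blue", linewidth=2, label="regression fit"))

legend = ax.legend(handles=proxies, loc="lower right", title="Legend")
```

Because the proxies are not added to the axes, no real data point has to be re-plotted just to obtain a legend handle.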


Two separate legends is the way to go I think.

I used the link to add two separate legends. Thanks a lot for that!

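first_legend = plt.legend(
    title='Linear Regression Equation',
    loc='upper right'
)

# add_artist keeps the first legend on the axes;
# the second plt.legend() call below would otherwise replace it
plt.gca().add_artist(first_legend)

plt.legend(
    (dot, dot2, dot3),
    ('18-34', '35-49', '50+'),
    loc='lower right',
    title='Legend'
)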


That makes a lot more sense! Whew! :slight_smile:
Can I see the visualization now? :smiley:
(Or, I mean, when you’re done with the project.)


[Image: visualization]
Oh I forgot to add the error! :rofl:
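
A rough sketch of how the equation and its error could be turned into a legend label, assuming "error" means the standard error of the fitted coefficients (the thread does not say which error is meant). The data points are made up, and lin_reg is simply the name used for the label earlier in the thread.

```python
import numpy as np

# Made-up points; in the thread these would come from ages_charges
ages = np.array([20, 25, 39, 45, 57, 62], dtype=float)
charges = np.array([3100.0, 4000.0, 8300.0, 9200.0, 13400.0, 14100.0])

# Degree-1 fit; cov=True also returns the covariance matrix of the coefficients
(slope, intercept), cov = np.polyfit(ages, charges, 1, cov=True)

# Standard errors are the square roots of the diagonal of the covariance matrix
slope_err, intercept_err = np.sqrt(np.diag(cov))

lin_reg = (
    f"y = ({slope:.1f} ± {slope_err:.1f})x "
    f"+ ({intercept:.0f} ± {intercept_err:.0f})"
)
```

The resulting string can then be passed as the 'label' entry of line_kws, as in the regplot call above.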


No worries! Work in progress :slight_smile:


Good work in progress too! “One line at a time we learn to code”


Ah, is this still the medical project? Nice one.

Those charges and the subsequent fitting seem to be an order of magnitude higher than they should be though. I recall the range being roughly $1,000 to $60,000. On that note, the charges seem a little odd (the highest charges are for 18-20 year olds?).

Have a little check of the set-up for this one. Something has gone a little awry.

Edit: Is it perhaps grouped by age and summed together? Make sure to divide by the number of datapoints for an average if that’s the case.

By fitting, do you mean the line of best fit (and the little area around it)?

There are some outliers in the data.

I did sum them.

ages_charges=insurance_dataframe.groupby('age').charges.sum().reset_index() 

So I need to average instead?

What I mean is the maximum charge of any single individual in this dataset was roughly $60,000 and the minimum was about $1,000. Your scale goes between $200,000 and $600,000. Either this plot is not what I think it is (in which case make sure your labels, captions etc. make it clear what is being plotted) or something is wrong.

Sum alone will be a problem. If half the dataset had the same age, their summed charge would be phenomenal. Something like this seems to be happening: there are more folks around 18 to 20-something in this dataset, and it’s messing with the result.


Oh! Thanks for calling that out. Now that I think about it, $600,000 is too high! I readjusted it to get the mean, so now it is where it’s supposed to be!
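
For anyone following along, the switch from sum to mean can be sketched like this. Only the names insurance_dataframe, age, and charges come from the thread. The rows are made up, with the youngest age over-represented on purpose.

```python
import pandas as pd

# Made-up rows: three people aged 18, one aged 40, one aged 60
insurance_dataframe = pd.DataFrame({
    "age": [18, 18, 18, 40, 60],
    "charges": [2000.0, 3000.0, 4000.0, 8000.0, 12000.0],
})

# Summing inflates ages that have many rows
summed = insurance_dataframe.groupby("age").charges.sum().reset_index()

# Averaging gives a per-person figure that is comparable across ages
averaged = insurance_dataframe.groupby("age").charges.mean().reset_index()
```

In this toy data the summed value for age 18 is three times the typical individual charge for that age, while the mean stays on the scale of a single person's charges.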
