Using Functions to Display Multiple Similar Subplots in Matplotlib

The StackOverflow 2019 data set can be found at: https://insights.stackoverflow.com/survey

Okay, so I am fairly new to Python and really new to the Matplotlib package. Here is the issue that I am having:

While working with data from a table of StackOverflow survey results, I noticed that there was a relationship between age and salary for programmers who responded to the survey (I know, not that shocking). The code to grab this info looked like this:

avg_salary_per_age = df.groupby('Age')['Salary_USD'].mean().reset_index()

Pretty simple. Just grouping by the ‘Age’ column and showing the average salary for each. The subplot is simple as well:

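import matplotlib.pyplot as plt

plt.figure(figsize=(6, 6))
# subplot needs rows, columns, and index; a bare 1 raises a ValueError
ax1 = plt.subplot(1, 1, 1)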

ax1.set_xticks(range(20,70,5))
plt.plot(avg_salary_per_age['Age'], avg_salary_per_age['Salary_USD'], 'o')
plt.xlabel('Age (years)')
plt.ylabel('Avg. Salary (USD)')
plt.title('Programmer Salary v. Age')

Once I had that subplot, I wanted to see if this relationship would hold true for individual countries. It is simple enough conceptually. Just do the same process but group by country and age, then redo the process for the top 9 countries to get a 3x3 subplot figure. However, I am sure there is a way to write a function for this since it is just me iterating over the same steps 9 times.

This is where my lack of skill is glaring… So I tried to write one:

def plotter(countries, df):

    dfs = []
    countries = []

    for country in countries:
        countries.append(country)
        #filtering based on country.
        var = df.loc[df['Country'] == country] 
        grouped_var = var.groupby('Age')['Salary_USD'].mean().reset_index()
        dfs.append(grouped_var)
        
    for df in range(len(dfs)):
        plt.figure(figsize=(12, 10))
        
        ax1 = plt.subplot(2,2,1)

        ax1.set_xticks(range(20,80,5))
        plt.plot(dfs[df]['Age'], dfs[df]['Salary_USD'], 'o')
        plt.xlabel('Age (years)')
        plt.ylabel('Avg. Salary (USD)')
        plt.title('Programmer Salary v. Age')
        plt.show()

It obviously doesn’t work, and I apologize for it being a complete mess. I guess the problem is that I am not sure how to treat dataframes. Can they even be stored in lists? What about displaying them with a loop? Could someone give me a hint on how to structure the function? Also, I apologize if I did a terrible job of explaining the problem. I am not quite sure how I would display a dataframe in the question body.
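
For reference, three things stand out in the attempt above: `countries = []` overwrites the argument, so the first loop never runs; `plt.figure(...)` inside the second loop opens a new figure on every pass; and `plt.subplot(2,2,1)` always targets the same grid position. Dataframes can be stored in lists, but it is not required here. Below is a minimal sketch of one way the function could be structured, not the only way. The name `plot_salary_by_age`, the `ncols` parameter, and the tiny `demo_df` are made up for illustration; only the column names come from the post.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd


def plot_salary_by_age(df, countries, ncols=3):
    """Draw one 'average salary vs. age' subplot per country on a single figure.

    Hypothetical helper: the name and signature are illustrative only.
    """
    nrows = -(-len(countries) // ncols)  # ceiling division
    # Create the figure and the whole grid of axes once, outside the loop.
    fig, axes = plt.subplots(nrows, ncols,
                             figsize=(4 * ncols, 3.5 * nrows),
                             squeeze=False)
    flat_axes = axes.flatten()

    for ax, country in zip(flat_axes, countries):
        # Filter to one country, then average salary per age.
        subset = df.loc[df['Country'] == country]
        grouped = subset.groupby('Age')['Salary_USD'].mean().reset_index()

        ax.plot(grouped['Age'], grouped['Salary_USD'], 'o')
        ax.set_xticks(range(20, 80, 5))
        ax.set_xlabel('Age (years)')
        ax.set_ylabel('Avg. Salary (USD)')
        ax.set_title(country)

    # Hide any grid cells that did not get a country.
    for ax in flat_axes[len(countries):]:
        ax.set_visible(False)

    fig.suptitle('Programmer Salary v. Age')
    fig.tight_layout()
    return fig, axes


# Made-up data, only to show the call; the real df comes from the survey.
demo_df = pd.DataFrame({
    'Country': ['A', 'A', 'A', 'B', 'B', 'B'],
    'Age': [25, 25, 30, 25, 30, 30],
    'Salary_USD': [50000, 60000, 70000, 40000, 45000, 55000],
})
fig, axes = plot_salary_by_age(demo_df, ['A', 'B'], ncols=2)
```

With the real data, and assuming "top 9" means the nine countries with the most respondents, the list could be built with `df['Country'].value_counts().head(9).index` and passed in with `ncols=3` to get the 3x3 grid. In a notebook or script with a display, drop the `matplotlib.use("Agg")` line and call `plt.show()` once after the function returns.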

How exactly is it not working?

Or, more accurately, what is your program doing that you don’t expect, or what is it not doing that you expected?

And for the sake of us all, please post the URL to this exercise in your next reply. Thanks.

I don’t know if this relates to anything in the LE here; I think it might be an off-platform question. :)

Not sure, though. I don’t know whether there’s material on matplotlib in the Python 3 Pro courses. :)

Pretty sure it is in the Data Science and/or Computer Science tracks.
