Python - Check if all elements of a list are equal

I am trying to solve this problem:
Python Challenge: All equal - Python Principles
I need to check whether all elements in a list are equal.

The solution I thought of is to compare the length of the list with the number of times the first element appears in it. If every element is equal, the two numbers match.
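For example, with a list that is not all-equal, the comparison works out like this (a small illustrative sketch; [1, 2, 1] is just an example value):

```python
lst = [1, 2, 1]
first_element = lst[0]              # 1
print(lst.count(first_element))     # 2: the value 1 appears twice
print(len(lst))                     # 3: the list has three elements
print(lst.count(first_element) == len(lst))  # False: one element differs
```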

Writing the code like this prints the correct answer:

def all_equal(lst):
    first_element = lst[0]
    print(lst.count(first_element) == len(lst))
all_equal([1, 1, 1])

But when I change the function to return the boolean value instead of printing it, I get an IndexError saying that the index I use in first_element = lst[0] is out of range:

def all_equal(lst):
    first_element = lst[0]
    return lst.count(first_element) == len(lst)


all_equal([1, 1, 1])

I can’t figure out why it gives me an error when returning but not when printing.

Hello!

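The list may be empty. When lst is [], lst[0] raises IndexError before count() or len() ever runs, and your printing version fails the same way on an empty list. Try all_equal([]) and fix the error.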
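One possible fix is a minimal sketch like the one below. It assumes the challenge expects True for an empty list (a list with no elements has no two elements that differ); check the challenge's own tests to confirm. The name all_equal_alt is hypothetical and just shows a second way to write it.

```python
def all_equal(lst):
    """Return True if every element of lst equals the first one.

    Assumption: an empty list counts as all-equal.
    """
    if not lst:  # guard: lst[0] would raise IndexError on []
        return True
    first_element = lst[0]
    return lst.count(first_element) == len(lst)


def all_equal_alt(lst):
    # Hypothetical alternative: all() on an empty generator returns True,
    # and lst[0] is only evaluated once there is at least one element.
    return all(x == lst[0] for x in lst)


print(all_equal([1, 1, 1]))  # True
print(all_equal([1, 2, 1]))  # False
print(all_equal([]))         # True
```

The guard keeps your original counting idea unchanged; the all() version avoids the special case because the generator never touches lst[0] when the list is empty.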