Test inheritance with Python (and Django)
2021-03-26
2021-03-26
Let's imagine that we have these models:
models.py
class Athlete(models.Model):
    """Abstract base model holding the fields shared by every athlete."""

    name = models.CharField(max_length=255)
    # Unique slug; the API tests pass it as the ``slug`` URL argument.
    slug = models.SlugField(max_length=300, unique=True)
    age = models.PositiveIntegerField()

    class Meta:
        # Abstract base: Django creates no table for Athlete itself, only
        # for the concrete models that inherit from it.
        abstract = True
class BasketballPlayer(Athlete):
    """Concrete athlete model that adds basketball statistics."""

    points_scored = models.PositiveIntegerField()
    assists = models.PositiveIntegerField()
    rebounds = models.PositiveIntegerField()

    def __str__(self):
        """Return the player's name as the display string."""
        return self.name
class SoccerPlayer(Athlete):
    """Concrete athlete model that adds soccer statistics."""

    goals_scored = models.PositiveIntegerField()
    assists = models.PositiveIntegerField()
    yellow_cards = models.PositiveIntegerField()

    def __str__(self):
        """Return the player's name as the display string."""
        return self.name
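We also provide two GET endpoints, one per model (URL names basketball_player and soccer_player), each taking the player's slug in the URL.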
We want to test that each endpoint returns a 200 status code and that the serialized data looks the way we expect.
tests.py
class TestBasketballPlayerAPI(TestCase):
    """Check the GET endpoint that serves a single basketball player."""

    def setUp(self) -> None:
        """Build the fixture player via ``create_test_basketball_player()``."""
        self.basketball_player = create_test_basketball_player()

    def test__response_ok(self):
        """The endpoint answers 200 and its payload equals the serializer output."""
        player = self.basketball_player
        endpoint = reverse_lazy("basketball_player", kwargs={"slug": player.slug})
        api_response = self.client.get(endpoint)

        assert api_response.status_code == 200
        # Compare against what the serializer produces for the same instance.
        expected_payload = BasketballPlayerSerializer(player).data
        assert api_response.data == expected_payload
class TestSoccerPlayerAPI(TestCase):
    """Check the GET endpoint that serves a single soccer player."""

    def setUp(self) -> None:
        """Build the fixture player via ``create_test_soccer_player()``."""
        self.soccer_player = create_test_soccer_player()

    def test__response_ok(self):
        """The endpoint answers 200 and its payload equals the serializer output."""
        player = self.soccer_player
        endpoint = reverse_lazy("soccer_player", kwargs={"slug": player.slug})
        api_response = self.client.get(endpoint)

        assert api_response.status_code == 200
        # Compare against what the serializer produces for the same instance.
        expected_payload = SoccerPlayerSerializer(player).data
        assert api_response.data == expected_payload
The two test classes, TestBasketballPlayerAPI
and TestSoccerPlayerAPI,
are very similar: both check that a 200 status code is returned and that the response data matches the serialized player.
Let's try to refactor this. We could make a base class AthleteTest
for the tests and inherit from it.
class AthleteTest(TestCase):
    """Shared test logic for the athlete GET endpoints.

    Subclasses set ``view_name`` and ``serializer_class`` and assign
    ``self.test_athlete`` in ``setUp``; the inherited ``test__response_ok``
    then runs against that endpoint.
    """

    # URL name passed to reverse_lazy(); each subclass overrides it.
    view_name = ""
    # Serializer that builds the expected payload; each subclass overrides it.
    serializer_class = None

    def setUp(self) -> None:
        """Declare ``test_athlete``; subclasses replace it with a real instance."""
        self.test_athlete = None

    # No pytest ``django_db`` mark is needed here: pytest-django gives Django
    # ``TestCase`` subclasses database access by default.
    def test__response_ok(self):
        """The endpoint answers 200 and its payload equals the serializer output."""
        url = reverse_lazy(self.view_name, kwargs={"slug": self.test_athlete.slug})
        response = self.client.get(url)

        assert response.status_code == 200
        assert response.data == self.serializer_class(self.test_athlete).data
class TestBasketballPlayerAPI(AthleteTest):
    """Run the shared athlete endpoint test against basketball players."""

    view_name = "basketball_player"
    serializer_class = BasketballPlayerSerializer

    def setUp(self) -> None:
        """Run the base setup, then create the basketball player under test."""
        # Chain to the parent so shared setup added there is never skipped.
        super().setUp()
        self.test_athlete = create_test_basketball_player()
class TestSoccerPlayerAPI(AthleteTest):
    """Run the shared athlete endpoint test against soccer players."""

    view_name = "soccer_player"
    serializer_class = SoccerPlayerSerializer

    def setUp(self) -> None:
        """Run the base setup, then create the soccer player under test."""
        # Chain to the parent so shared setup added there is never skipped.
        super().setUp()
        self.test_athlete = create_test_soccer_player()
The code is now much shorter and easier to extend with new test cases. But if we run pytest, it will fail.
Pytest collects 3 test cases, and one of them fails because AthleteTest
is not an actual test case: it is only meant to be inherited from, and its test_athlete is None, so the test cannot build the URL.
To avoid this problem we can set the __test__ attribute, which tells pytest whether a class should be collected.
After adding it, the tests look like this:
class AthleteTest(TestCase):
    """Shared test logic for the athlete GET endpoints.

    Subclasses set ``view_name`` and ``serializer_class`` and assign
    ``self.test_athlete`` in ``setUp``; the inherited ``test__response_ok``
    then runs against that endpoint.
    """

    # Keep pytest from collecting this base class. The attribute is
    # inherited, so concrete subclasses must set it back to True.
    __test__ = False

    # URL name passed to reverse_lazy(); each subclass overrides it.
    view_name = ""
    # Serializer that builds the expected payload; each subclass overrides it.
    serializer_class = None

    def setUp(self) -> None:
        """Declare ``test_athlete``; subclasses replace it with a real instance."""
        self.test_athlete = None

    # No pytest ``django_db`` mark is needed here: pytest-django gives Django
    # ``TestCase`` subclasses database access by default.
    def test__response_ok(self):
        """The endpoint answers 200 and its payload equals the serializer output."""
        url = reverse_lazy(self.view_name, kwargs={"slug": self.test_athlete.slug})
        response = self.client.get(url)

        assert response.status_code == 200
        assert response.data == self.serializer_class(self.test_athlete).data
class TestBasketballPlayerAPI(AthleteTest):
    """Run the shared athlete endpoint test against basketball players."""

    # Re-enable pytest collection, which the base class switches off.
    __test__ = True

    view_name = "basketball_player"
    serializer_class = BasketballPlayerSerializer

    def setUp(self) -> None:
        """Run the base setup, then create the basketball player under test."""
        # Chain to the parent so shared setup added there is never skipped.
        super().setUp()
        self.test_athlete = create_test_basketball_player()
class TestSoccerPlayerAPI(AthleteTest):
    """Run the shared athlete endpoint test against soccer players."""

    # Re-enable pytest collection, which the base class switches off.
    __test__ = True

    view_name = "soccer_player"
    serializer_class = SoccerPlayerSerializer

    def setUp(self) -> None:
        """Run the base setup, then create the soccer player under test."""
        # Chain to the parent so shared setup added there is never skipped.
        super().setUp()
        self.test_athlete = create_test_soccer_player()
Now pytest will collect only 2 tests, and both pass as expected.
Check out the complete source code here.
Note: this method only works with pytest and NOT with the Django test runner.