Spaced repetition algorithm implementation in Kotlin

Have you ever thought about creating your own flashcards app?

Well, I have. Although there are plenty of applications like that, I developed a simple one for my own use.

Today I’d like to share the implementation of the SM-2 spaced repetition algorithm in the Kotlin language.

Introduction

I did my best to keep this implementation as simple as possible, yet easy to read and understand.

I’d like to keep it this way because I’m not a big fan of algorithm implementations that contain hundreds of lines of code and plenty of cryptic variables such as k, p, V, etc.

The flashcard data model

A single flashcard is represented by a simple data model. Take a look at this class:

package spacedrepetition.card

import java.time.LocalDateTime

data class Card(
        val frontSide: String,
        val backSide: String,
        val nextRepetition: LocalDateTime = LocalDateTime.now(),
        val repetitions: Int = 0,
        val easinessFactor: Float = 2.5f,
        val interval: Int = 1
) {
    fun withUpdatedRepetitionProperties(
            newRepetitions: Int,
            newEasinessFactor: Float,
            newNextRepetitionDate: LocalDateTime,
            newInterval: Int
    ) = copy(repetitions = newRepetitions,
            easinessFactor = newEasinessFactor,
            nextRepetition = newNextRepetitionDate,
            interval = newInterval
    )
}
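Because Card is an immutable data class, every review produces a new instance instead of modifying the old one. Here is a minimal sketch of that behaviour. It repeats a trimmed copy of the model (without withUpdatedRepetitionProperties) only so that the snippet runs on its own, and calls copy directly, which is what that helper does under the hood.

```kotlin
import java.time.LocalDateTime

// Trimmed copy of the Card model above, repeated so the snippet is self-contained.
data class Card(
    val frontSide: String,
    val backSide: String,
    val nextRepetition: LocalDateTime = LocalDateTime.now(),
    val repetitions: Int = 0,
    val easinessFactor: Float = 2.5f,
    val interval: Int = 1
)

fun main() {
    val fresh = Card(frontSide = "🍎", backSide = "Apple")

    // copy() returns a new instance; `fresh` keeps its default values.
    val reviewed = fresh.copy(repetitions = 1, easinessFactor = 2.6f)

    println(fresh.repetitions)       // 0
    println(reviewed.repetitions)    // 1
    println(fresh.easinessFactor)    // 2.5
    println(reviewed.backSide)       // Apple
}
```

Keeping the card immutable means the scheduler can stay a pure function of the old card and the answer quality, which is also what makes it easy to test.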

Implementation

As I mentioned earlier, the implementation is written in Kotlin.

I’m not going to explain the code in detail here. If you’re looking for a good reference, please take a look here.

package spacedrepetition

import spacedrepetition.card.Card
import java.time.Duration
import java.time.Instant
import java.time.LocalDateTime
import java.time.ZoneId
import java.util.logging.Logger
import kotlin.math.roundToInt

class SpacedRepetition {

    fun calculateRepetition(card: Card, quality: Int): Card {
        validateQualityFactorInput(quality)

        val easiness = calculateEasinessFactor(card.easinessFactor, quality)
        val repetitions = calculateRepetitions(quality, card.repetitions)
        val interval = calculateInterval(repetitions, card.interval, easiness)

        val cardAfterRepetition = card.withUpdatedRepetitionProperties(
                newRepetitions = repetitions,
                newEasinessFactor = easiness,
                newNextRepetitionDate = calculateNextPracticeDate(interval),
                newInterval = interval
        )
        log.info(cardAfterRepetition.toString())
        return cardAfterRepetition
    }

    private fun validateQualityFactorInput(quality: Int) {
        log.info("Input quality: $quality")
        // require throws IllegalArgumentException when the condition is false.
        require(quality in 0..5) { "Provided quality value is invalid ($quality)" }
    }

    // SM-2 easiness update; the factor is never allowed to drop below 1.3.
    private fun calculateEasinessFactor(easiness: Float, quality: Int) =
            maxOf(1.3, easiness + 0.1 - (5.0 - quality) * (0.08 + (5.0 - quality) * 0.02)).toFloat()


    private fun calculateRepetitions(quality: Int, cardRepetitions: Int) = if (quality < 3) {
        0
    } else {
        cardRepetitions + 1
    }

    private fun calculateInterval(repetitions: Int, cardInterval: Int, easiness: Float) = when {
        repetitions <= 1 -> 1
        repetitions == 2 -> 6
        else -> (cardInterval * easiness).roundToInt()
    }

    private fun calculateNextPracticeDate(interval: Int): LocalDateTime {
        val now = System.currentTimeMillis()
        val nextPracticeDate = now + dayInMs * interval
        return LocalDateTime.ofInstant(Instant.ofEpochMilli(nextPracticeDate), ZoneId.systemDefault())
    }

    private companion object {
        private val dayInMs = Duration.ofDays(1).toMillis()
        private val log: Logger = Logger.getLogger(SpacedRepetition::class.java.name)
    }

}
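The one-line easiness formula is the densest part of the class, so here is a small sketch that unpacks it and prints the new factor for every answer quality, starting from the default of 2.5. The function name nextEasiness and the helper variable miss are mine, chosen for illustration; the arithmetic is the same as in calculateEasinessFactor, just kept in Double.

```kotlin
import java.util.Locale
import kotlin.math.max

// Illustrative rewrite of calculateEasinessFactor; names are hypothetical.
fun nextEasiness(easiness: Double, quality: Int): Double {
    val miss = 5.0 - quality                        // distance from a perfect answer
    val delta = 0.1 - miss * (0.08 + miss * 0.02)   // +0.1 for a 5, negative below 4
    return max(1.3, easiness + delta)               // never drop below 1.3
}

fun main() {
    for (quality in 5 downTo 0) {
        val factor = "%.2f".format(Locale.ROOT, nextEasiness(2.5, quality))
        println("quality=$quality -> $factor")
    }
    // quality=5 -> 2.60
    // quality=4 -> 2.50
    // quality=3 -> 2.36
    // quality=2 -> 2.18
    // quality=1 -> 1.96
    // quality=0 -> 1.70
}
```

Note that this implementation lowers the factor for answers below 3 as well. Some descriptions of SM-2 leave the factor unchanged in that case and only reset the repetition count, so treat this as one of the possible variations.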

Tests

To make sure it works, I’ve written the following test cases:

package spacedrepetition

import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
import spacedrepetition.card.Card
import java.time.LocalDateTime

internal class SpacedRepetitionTest {

    private val spacedRepetition = SpacedRepetition()

    @Test
    fun `should send single answer with quality 5 and check returned card`() {
        val card = getCard()

        val response = spacedRepetition.calculateRepetition(card, 5)

        assertEquals(1, response.repetitions)
        assertEquals(1, response.interval)
        assertEquals(2.6.toFloat(), response.easinessFactor)
        assertTrue(response.nextRepetition.isAfter(LocalDateTime.now()))
    }

    @Test
    fun `should send single answer with quality 4 and check returned card`() {
        val card = getCard()

        val response = spacedRepetition.calculateRepetition(card, 4)

        assertEquals(1, response.repetitions)
        assertEquals(1, response.interval)
        assertEquals(2.5.toFloat(), response.easinessFactor)
        assertTrue(response.nextRepetition.isAfter(LocalDateTime.now()))
    }

    @Test
    fun `should send single answer with quality 3 and check returned card`() {
        val card = getCard()

        val response = spacedRepetition.calculateRepetition(card, 3)

        assertEquals(1, response.repetitions)
        assertEquals(1, response.interval)
        assertEquals(2.36.toFloat(), response.easinessFactor)
        assertTrue(response.nextRepetition.isAfter(LocalDateTime.now()))
    }

    @Test
    fun `should send single answer with quality 2 and check returned card`() {
        val card = getCard()

        val response = spacedRepetition.calculateRepetition(card, 2)

        assertEquals(0, response.repetitions)
        assertEquals(1, response.interval)
        assertEquals(2.18.toFloat(), response.easinessFactor)
        assertTrue(response.nextRepetition.isAfter(LocalDateTime.now()))
    }

    @Test
    fun `should send single answer with quality 1 and check returned card`() {
        val card = getCard()

        val response = spacedRepetition.calculateRepetition(card, 1)

        assertEquals(0, response.repetitions)
        assertEquals(1, response.interval)
        assertEquals(1.96.toFloat(), response.easinessFactor)
        assertTrue(response.nextRepetition.isAfter(LocalDateTime.now()))
    }

    @Test
    fun `should simulate successful cards repetition session and expect no more cards to repeat today`() {
        val deck = getDeckWithSixCards()

        val cardsAfterRepetition = deck.map { flashCard ->
            val repetition1 = spacedRepetition.calculateRepetition(flashCard, 1)
            val repetition2 = spacedRepetition.calculateRepetition(repetition1, 2)
            val repetition3 = spacedRepetition.calculateRepetition(repetition2, 3)
            val repetition4 = spacedRepetition.calculateRepetition(repetition3, 4)
            spacedRepetition.calculateRepetition(repetition4, 5)
        }

        val cardToRepeatToday = cardsAfterRepetition.filter { it.nextRepetition.isBefore(LocalDateTime.now()) }
        assertTrue(cardToRepeatToday.isEmpty())
    }

    @Test
    fun `should throw an exception if the user's quality of repetition response is invalid`() {
        val flashCard = getCard(repetitionDate = LocalDateTime.now().minusDays(1))

        Assertions.assertThrows(IllegalArgumentException::class.java) {
            spacedRepetition.calculateRepetition(flashCard, 6)
        }
        Assertions.assertThrows(IllegalArgumentException::class.java) {
            spacedRepetition.calculateRepetition(flashCard, -1)
        }
        Assertions.assertThrows(IllegalArgumentException::class.java) {
            spacedRepetition.calculateRepetition(flashCard, Int.MAX_VALUE)
        }
        Assertions.assertThrows(IllegalArgumentException::class.java) {
            spacedRepetition.calculateRepetition(flashCard, Int.MIN_VALUE)
        }
    }

    private fun getCard(front: String = "🍎",
                        back: String = "Apple",
                        repetitionDate: LocalDateTime = LocalDateTime.now()) = Card(
            frontSide = front,
            backSide = back,
            nextRepetition = repetitionDate
    )

    private fun getDeckWithSixCards() = listOf(
            getCard(),
            getCard(repetitionDate = LocalDateTime.now().minusDays(1)),
            getCard(repetitionDate = LocalDateTime.now().minusDays(2)),
            getCard(repetitionDate = LocalDateTime.now().plusDays(2)),
            getCard(repetitionDate = LocalDateTime.now().plusDays(1)),
            getCard(repetitionDate = LocalDateTime.now().minusDays(3))
    )
}
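The tests above mostly check a single answer, so they never reach the branch where the interval is multiplied by the easiness factor. To see how intervals grow over several reviews, here is a rough, self-contained sketch. ReviewState, review and intervalsFor are hypothetical names of mine; the logic mirrors calculateRepetition, minus logging and dates, and uses Double instead of Float.

```kotlin
import kotlin.math.max
import kotlin.math.roundToInt

// Hypothetical, trimmed-down state: only the fields the scheduler changes.
data class ReviewState(val repetitions: Int = 0, val easiness: Double = 2.5, val interval: Int = 1)

// Mirrors calculateRepetition above, without logging and date handling.
fun review(state: ReviewState, quality: Int): ReviewState {
    require(quality in 0..5) { "Provided quality value is invalid ($quality)" }
    val miss = 5.0 - quality
    val easiness = max(1.3, state.easiness + 0.1 - miss * (0.08 + miss * 0.02))
    val repetitions = if (quality < 3) 0 else state.repetitions + 1
    val interval = when {
        repetitions <= 1 -> 1
        repetitions == 2 -> 6
        else -> (state.interval * easiness).roundToInt()
    }
    return ReviewState(repetitions, easiness, interval)
}

// Applies the answers in order and collects the interval (in days) after each one.
fun intervalsFor(qualities: List<Int>): List<Int> {
    var state = ReviewState()
    val intervals = mutableListOf<Int>()
    for (quality in qualities) {
        state = review(state, quality)
        intervals += state.interval
    }
    return intervals
}

fun main() {
    println(intervalsFor(listOf(5, 5, 5, 5)))  // [1, 6, 17, 49]
    println(intervalsFor(listOf(5, 5, 2, 5)))  // [1, 6, 1, 1]
}
```

The second run shows the reset: a single answer below 3 sends the card back to a one-day interval, and it has to climb through 1 and 6 days again.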

Conclusion

Please note that the original SM-2 algorithm was first published in 1988. A newer, improved version of the algorithm (SM-18) is already in use, but it’s not open source, so the most popular applications (like Anki) mostly use variations of SM-2.

Anyway, I hope you find this useful and are now ready to use the algorithm in your own spaced repetition-based application.

Good luck 👋
