Keyboard shortcuts

Press ← or → to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

Solution

Note

Clicking the button below automatically creates a blood oath with the course. It only works if you actually tried to do the exercise beforehand. Click at your own risk.

Reveal the solution.

The exercise program is completely broken. Here is a corrected version. Look for comments starting with FIX: to get explanations.

Writing safe C is possible, but it’s hard. It requires expert knowledge about the language and even experts can still make mistakes. We hope the following example convinces you of the need for memory-safe programming languages.

#include <errno.h>
#include <limits.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/*
================================= UTILS =================================
*/

// FIX: The broken version of the code uses malloc everywhere but never
//      checks whether it returns something or not. Indeed, malloc could
//      return NULL if something went wrong!
//      See <https://en.cppreference.com/w/c/memory/malloc>.
//
//      So, to avoid null pointer dereferencing, we'll wrap malloc and
//      abort if something goes wrong (in C, abort = crash the program).
//      Usually, that's a good enough thing to do.
//
//      In some cases, however, you really don't want your program to
//      crash (e.g., in a web server where reliability is of utmost
//      importance). In such cases, you must apply another strategy...
// Allocates `size` bytes with malloc, aborting the program instead of ever
// returning NULL to the caller.
void* safe_alloc(size_t size) {
    // Zero-byte requests are refused outright: malloc(0) may return either
    // NULL or a unique pointer (implementation-defined), so its result would
    // be ambiguous. Treating it like a failure keeps the contract simple.
    void* block = (size != 0) ? malloc(size) : NULL;
    if (block == NULL) {
        abort();
    }
    return block;
}

// FIX: realloc is even more sneaky. If something goes wrong, it returns
//      NULL too, BUT the initial pointer stays valid!
//      See <https://en.cppreference.com/w/c/memory/realloc>.
//
//      Let's take care of this by writing another aborting wrapper.
// Resizes `pointer` to `size` bytes with realloc, aborting the program
// instead of ever returning NULL to the caller.
void* safe_realloc(void* pointer, size_t size) {
    // A zero size is refused: realloc(ptr, 0) is implementation-defined in
    // C17 and undefined behavior since C23.
    if (size != 0) {
        void* resized = realloc(pointer, size);
        if (resized != NULL) {
            return resized;
        }
        // On failure the original block is still allocated. We abort right
        // away, so there is nothing else to clean up; a non-aborting caller
        // would still own (and have to free) `pointer` here.
    }
    abort();
}

// FIX: The previous version of the program used malloc(capacity * sizeof(Client))
//      everywhere. That can get you in trouble, because it's an integer
//      overflow hiding in plain sight! What happens if a malicious user
//      wants to allocate a lot of clients? You'll end up with a mismatch
//      between capacity and your actual buffer size!
//
//      Let's make it safer using another aborting wrapper.
// Allocates room for `array_size` items of `item_size` bytes each. Aborts if
// the total byte count does not fit in a size_t, and otherwise inherits
// safe_alloc's abort-on-failure (and abort-on-zero) behavior.
void* safe_alloc_array(size_t array_size, size_t item_size) {
    size_t total_bytes;
    // GCC/Clang builtin: returns true when the product wrapped around.
    // See <https://gcc.gnu.org/onlinedocs/gcc/Integer-Overflow-Builtins.html>.
    if (__builtin_mul_overflow(array_size, item_size, &total_bytes)) {
        abort();
    }
    return safe_alloc(total_bytes);
}

// FIX: Same remark as above concerning realloc.
// Resizes `pointer` to hold `array_size` items of `item_size` bytes each.
// Aborts if the total byte count does not fit in a size_t, and otherwise
// inherits safe_realloc's abort-on-failure (and abort-on-zero) behavior.
void* safe_realloc_array(void* pointer, size_t array_size, size_t item_size) {
    size_t total_bytes;
    // Same overflow check as safe_alloc_array, for the same reason.
    if (__builtin_mul_overflow(array_size, item_size, &total_bytes)) {
        abort();
    }
    return safe_realloc(pointer, total_bytes);
}

// FIX: In the broken version of the code, gets only took a pointer to a
//      buffer. That's a no-go. When programming in C, it should
//      immediately trigger a red flag in your mind!
//
//      Let's add a new parameter to get the buffer's size and use it in
//      the loop. We're now safe against buffer overflows.
// Reads one line from stdin into `buf`, which is always NUL-terminated when
// `buf_size` > 0 (nothing is written when `buf_size` is 0).
//
// At most `buf_size - 1` characters are stored. The line terminator ("\n",
// "\r\n" or a lone "\r") is consumed but not stored. Characters that don't fit
// are read and discarded, so an over-long line cannot spill into the next
// read. Reading stops early at end of file.
void safe_gets(char* buf, size_t buf_size) {
    size_t length = 0;
    int c;
    while ((c = getchar()) != EOF && c != '\n' && c != '\r') {
        // FIX: `length + 1 < buf_size` instead of `i < buf_size - 1`: the
        //      latter wraps around to SIZE_MAX when buf_size is 0.
        if (length + 1 < buf_size) {
            buf[length] = (char)c;
            length++;
        }
        // FIX: Once the buffer is full we keep reading to the end of the
        //      line. The old version stopped there, losing one character and
        //      leaving the rest of the line to be read by the next prompt.
    }

    if (c == '\r') {
        // FIX: Swallow the '\n' of a "\r\n" pair; otherwise it would be read
        //      as an empty line by the next call.
        int next = getchar();
        if (next != '\n' && next != EOF) {
            ungetc(next, stdin);
        }
    }

    if (buf_size > 0) {
        buf[length] = '\0';
    }
}

// FIX: Same remark as for gets, we must add a new size parameter.
// Prints `question` as a prompt, then reads one line of at most
// `buf_size - 1` characters into `buf` using safe_gets.
void safe_input(const char* question, char* buf, size_t buf_size) {
    fputs(question, stdout);
    // The prompt has no trailing newline, so a line-buffered or fully
    // buffered stdout is not guaranteed to show it before we block on stdin.
    fflush(stdout);
    safe_gets(buf, buf_size);
}

typedef struct Date {
    int day;
    int month;
    int year;
} Date;

// Returns true if `year` is a leap year in the Gregorian calendar.
static bool is_leap_year(int year) {
    return (year % 4 == 0 && year % 100 != 0) || year % 400 == 0;
}

// Returns the number of days in `month` (which must be 1-12) of `year`.
static int days_in_month(int month, int year) {
    static const int days_per_month[12] = {
        31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31,
    };
    if (month == 2 && is_leap_year(year)) {
        return 29;
    }
    return days_per_month[month - 1];
}

// Parses a base-10 integer in [min, max] at *cursor that must be immediately
// followed by `terminator`. On success, stores it in *out, moves *cursor past
// the terminator and returns true. strtol is used instead of sscanf's %d
// because %d has undefined behavior on out-of-range input, whereas strtol
// reports it through errno.
static bool parse_date_field(
    const char** cursor, char terminator, long min, long max, int* out
) {
    char* end;
    errno = 0;
    long value = strtol(*cursor, &end, 10);
    if (
        end == *cursor
        || errno == ERANGE
        || value < min
        || value > max
        || *end != terminator
    ) {
        return false;
    }
    *out = (int)value;
    *cursor = (terminator == '\0') ? end : end + 1;
    return true;
}

// Parses a "dd/mm/yyyy" string into *date and returns true. Returns false,
// leaving *date untouched, unless the whole string is a real calendar date
// with a year >= 1900. As with %d, leading whitespace and a sign are tolerated
// before each number; anything after the year is rejected.
bool parse_date(const char* date_string, Date* date) {
    const char* cursor = date_string;
    int day;
    int month;
    int year;
    if (
        !parse_date_field(&cursor, '/', 1, 31, &day)
        || !parse_date_field(&cursor, '/', 1, 12, &month)
        || !parse_date_field(&cursor, '\0', 1900, INT_MAX, &year)
        // Reject dates such as 31/04 or 29/02 of a non-leap year.
        || day > days_in_month(month, year)
    ) {
        return false;
    }
    date->day = day;
    date->month = month;
    date->year = year;
    return true;
}

/*
================================= CLIENT =================================
*/

// Buffer sizes (in bytes, terminating '\0' included) of the fixed-size text
// fields stored in each Client. They are also passed as `buf_size` to
// safe_input when those fields are read.
enum {
    USERNAME_SIZE = 64,
    FIRST_NAME_SIZE = 64,
    LAST_NAME_SIZE = 64,
    EMAIL_SIZE = 128,
    CITY_SIZE = 64,
    COUNTRY_SIZE = 64,
};

// FIX: In the broken version of the code, we were using pointers to
//      store referred_by. However, this doesn't work once we realloc the
//      client array. Indeed, realloc can return a new pointer to a new
//      array location, making all existing pointers dangling!
//
//      This is a subtle bug for which there's no quick solution. Here,
//      we'll use indices instead of pointers for them to stay valid after
//      realloc is called. Instead of NULL pointers, we need a special value to
//      indicate when there's no referrer, hence the NO_CLIENT macro below.
//      There are also multiple changes in the program below for this to work.
//
//      You could also store the referrer's username instead of their ID,
//      since usernames are unique here (encode_client re-prompts on duplicates).
//
//      Another solution could be using a linked list. Indeed, linked lists
//      ensure stable addresses. Other data structures could help, see:
//      - <https://www.dgtlgrove.com/p/in-defense-of-linked-lists>
//      - <https://danielchasehooper.com/posts/segment_array/>
//      - <https://skypjack.github.io/2019-05-06-ecs-baf-part-3/>

// Sentinel index meaning "no such client": it marks a client without a
// referrer and is what search_client_index returns when nothing matches.
// Converting -1 to an unsigned type is defined by the C standard to yield that
// type's maximum value (SIZE_MAX here) whatever the signed representation, so
// no valid array index can ever collide with it.
#define NO_CLIENT ((size_t)-1)

// One customer record. The text fields are fixed-size buffers filled through
// safe_input, so each one holds a NUL-terminated string.
typedef struct Client {
    char username[USERNAME_SIZE];
    char first_name[FIRST_NAME_SIZE];
    char last_name[LAST_NAME_SIZE];
    char email[EMAIL_SIZE];
    char city[CITY_SIZE];
    char country[COUNTRY_SIZE];
    // FIX: Index of the referrer in the owning ClientArrayList, or NO_CLIENT.
    //      It replaces the old `referred_by` pointer, which realloc could
    //      leave dangling.
    size_t referrer_id;
    Date birth_date;
} Client;

// FIX: size_t is a more appropriate type for size and capacity.
// Growable array of clients: `clients` has `capacity` slots, of which the
// first `size` are in use.
typedef struct ClientArrayList {
    Client* clients; // heap buffer, released by free_client_list
    size_t size;     // number of slots in use
    size_t capacity; // number of allocated slots
} ClientArrayList;

ClientArrayList create_client_list(size_t capacity) {
    return (ClientArrayList){
        .clients = safe_alloc_array(capacity, sizeof(Client)),
        .size = 0,
        .capacity = capacity,
    };
}

// Appends a copy of *client to the list, doubling the capacity when full.
void append_client(ClientArrayList* list, Client* client) {
    bool is_full = list->size == list->capacity;
    if (is_full) {
        // FIX: Doubling `capacity` up front could overflow. Instead we ask
        //      safe_realloc_array for `capacity` items of `2 * sizeof(Client)`
        //      bytes. That is the same byte count, but its overflow check
        //      (which aborts) also proves `capacity * 2` fits, so the
        //      multiplication below is safe. `2 * sizeof(Client)` itself is
        //      assumed not to overflow.
        //
        //      A hack like this can be OK, but only with a comment explaining it!
        list->clients = safe_realloc_array(list->clients, list->capacity, 2 * sizeof(Client));
        list->capacity *= 2;
    }
    list->clients[list->size++] = *client;
}

// Removes the client at index `client_id` and shifts the later clients down
// one slot. The buffer shrinks when it becomes half full, and every remaining
// referrer_id is updated so it still designates the same client (or NO_CLIENT
// if that client was the one removed). Does nothing if `client_id` is out of
// range, which includes NO_CLIENT.
void delete_client(ClientArrayList* list, size_t client_id) {
    // FIX: Guard against invalid indices; the size-- below would otherwise
    //      silently drop the last client.
    if (client_id >= list->size) {
        return;
    }

    // FIX: The previous version of this loop actually overflowed.
    //      Look closely at the loop condition, we changed it...
    list->size--;
    for (size_t i = client_id; i < list->size; i++) {
        // No overflow because we decremented the size just above,
        // so clients[i + 1] is still a valid item at this point.
        list->clients[i] = list->clients[i + 1];
    }

    // FIX: With size 0 and capacity 1, capacity / 2 is 0 and we would ask
    //      safe_realloc_array for zero bytes, which aborts. Requiring
    //      size > 1 guarantees the new capacity is never 0.
    if (list->size > 1 && list->size == list->capacity / 2) {
        list->capacity /= 2;
        list->clients = safe_realloc_array(list->clients, list->capacity, sizeof(Client));
    }

    // FIX: Since we're using referrer IDs instead of pointers now, we must make
    //      sure existing IDs are still valid.
    for (size_t i = 0; i < list->size; i++) {
        Client* client = &list->clients[i];
        if (client->referrer_id == NO_CLIENT) {
            // FIX: NO_CLIENT is SIZE_MAX, so it compares greater than every
            //      index. Without this check it was decremented into a bogus
            //      index that display_client then read out of bounds.
            continue;
        }
        if (client->referrer_id == client_id) {
            client->referrer_id = NO_CLIENT;
        } else if (client->referrer_id > client_id) {
            client->referrer_id--;
        }
    }
}

// Returns the index of the first client whose username equals `username`
// exactly, or NO_CLIENT if there is none.
size_t search_client_index(ClientArrayList* list, char* username) {
    size_t index = 0;
    while (index < list->size && strcmp(list->clients[index].username, username) != 0) {
        index++;
    }
    return (index < list->size) ? index : NO_CLIENT;
}

// Called when an answer was rejected. If stdin is exhausted (EOF or read
// error), it reports it and exits: every further read would yield an empty,
// still-invalid answer, so the retry loops below would spin forever.
static void exit_on_closed_input(void) {
    if (feof(stdin) || ferror(stdin)) {
        fprintf(stderr, "\nUnexpected end of input.\n");
        exit(EXIT_FAILURE);
    }
}

// Prompts for every field of a new client and stores them in *client.
// The username is asked again until it is non-empty and not already used by
// a client in `list`. The birth date is asked again until parse_date accepts
// it. `referrer_id` is stored as is.
void encode_client(ClientArrayList* list, Client* client, size_t referrer_id) {
    for (;;) {
        safe_input("Enter the client username : ", client->username, USERNAME_SIZE);
        // FIX: Reject empty usernames too. An empty username is
        //      indistinguishable from "no referrer" at the referrer prompt
        //      ("leave empty if none").
        if (
            client->username[0] != '\0'
            && search_client_index(list, client->username) == NO_CLIENT
        ) {
            break;
        }
        exit_on_closed_input();
    }

    safe_input("Enter the client first name : ", client->first_name, FIRST_NAME_SIZE);
    safe_input("Enter the client last name : ", client->last_name, LAST_NAME_SIZE);
    safe_input("Enter the client email : ", client->email, EMAIL_SIZE);
    safe_input("Enter the client city : ", client->city, CITY_SIZE);
    safe_input("Enter the client country : ", client->country, COUNTRY_SIZE);

    char birth_date[64];
    for (;;) {
        safe_input(
            "Enter the client birth date (dd/mm/yyyy) : ",
            birth_date,
            sizeof(birth_date)
        );
        if (parse_date(birth_date, &client->birth_date)) {
            break;
        }
        exit_on_closed_input();
    }

    client->referrer_id = referrer_id;
}

// Prints every field of the client at index `client_id`, including its
// referrer's username when it has one.
void display_client(ClientArrayList* list, size_t client_id) {
    const Client* entry = list->clients + client_id;

    printf("Username: %s\n", entry->username);
    printf("First name: %s\n", entry->first_name);
    printf("Last name: %s\n", entry->last_name);
    printf("Email: %s\n", entry->email);
    printf("City: %s\n", entry->city);
    printf("Country: %s\n", entry->country);

    if (entry->referrer_id != NO_CLIENT) {
        printf("Referred by: %s\n", list->clients[entry->referrer_id].username);
    }

    // FIX: The broken version first formatted the date with sprintf into a
    //      char[11], which overflows once the year exceeds 9999. Printing the
    //      three numbers directly needs no intermediate buffer at all.
    const Date* born = &entry->birth_date;
    printf("Birth date: %d/%d/%d\n", born->day, born->month, born->year);
}

// Releases the list's buffer and leaves the list empty: clients is NULL and
// size and capacity are 0, so loops bounded by `size` do nothing afterwards.
void free_client_list(ClientArrayList* list) {
    // FIX: The broken version looped over the clients and freed each
    //      `referred_by` pointer. Those pointed into this very array rather
    //      than to separate allocations, so that loop was simply removed.
    free(list->clients);

    // FIX: Clear the dangling pointer AND the counters. Before, size and
    //      capacity stayed stale, so a later append_client with
    //      size < capacity wrote straight through the NULL pointer. With
    //      both reset to 0, it goes through safe_realloc_array instead.
    list->clients = NULL;
    list->size = 0;
    list->capacity = 0;
}

/*
================================= COMMANDS =================================
*/

// Asks for an optional referrer, then for the new client's details, and
// appends the new client to `list`.
void encode_command(ClientArrayList* list) {
    Client client;
    size_t referrer_id = NO_CLIENT;
    char referrer_username[USERNAME_SIZE] = {0};

    // Ask again until the answer is empty (no referrer) or names an existing client.
    do {
        safe_input(
            "Enter the username of the client that referred this "
            "client (leave empty if none) : ",
            referrer_username,
            USERNAME_SIZE
        );
        // FIX: An empty answer always means "no referrer". Searching for ""
        //      would otherwise match a client whose username is empty.
        referrer_id = (referrer_username[0] == '\0')
            ? NO_CLIENT
            : search_client_index(list, referrer_username);
    } while (referrer_username[0] != '\0' && referrer_id == NO_CLIENT);

    encode_client(list, &client, referrer_id);

    append_client(list, &client);
}

// Asks for a username and deletes the matching client, or reports that no
// such client exists.
void delete_command(ClientArrayList* list) {
    char username[USERNAME_SIZE] = {0};

    safe_input(
        "Enter the username of the client that you want to delete : ",
        username,
        USERNAME_SIZE
    );

    size_t client_id = search_client_index(list, username);

    // FIX: We must check whether the client actually exists.
    if (client_id == NO_CLIENT) {
        // FIX: Typo in the user-facing message ("There's not client").
        printf("There's no client named %s.\n", username);
    } else {
        delete_client(list, client_id);
    }
}

// Prints every client in the list, each preceded by a 1-based "Client N" header.
void display_command(ClientArrayList* list) {
    size_t shown = 0;
    while (shown < list->size) {
        printf("Client %zu\n", shown + 1);
        display_client(list, shown);
        shown++;
    }
}

// Asks for a username and displays the matching client, or reports that no
// such client exists.
void search_command(ClientArrayList* list) {
    char username[USERNAME_SIZE] = {0};

    safe_input(
        "Enter the username that you want to search for : ",
        username,
        USERNAME_SIZE
    );

    size_t client_id = search_client_index(list, username);

    // FIX: Again, we must check whether the client actually exists.
    if (client_id == NO_CLIENT) {
        // FIX: Typo in the user-facing message ("There's not client").
        printf("There's no client named %s.\n", username);
    } else {
        display_client(list, client_id);
    }
}

/*
============================== MAIN FUNCTION ==============================
*/

#define COMMAND_SIZE 16

// Interactive loop: reads commands from stdin and dispatches them to the
// *_command functions. Stops on "quit", or when stdin is exhausted.
int main(void) {
    ClientArrayList list = create_client_list(1);

    bool continue_encoding = true;
    while (continue_encoding) {
        char command[COMMAND_SIZE];

        safe_input(
            "Enter the command that you want "
            "(encode/delete/search/display/quit) : ",
            command,
            COMMAND_SIZE
        );

        if (command[0] == '\0' && (feof(stdin) || ferror(stdin))) {
            // FIX: Once stdin is closed (e.g. piped input without "quit"),
            //      every read returns an empty command. Stop here instead of
            //      reporting an unknown command forever.
            continue_encoding = false;
        } else if (strcmp(command, "encode") == 0) {
            encode_command(&list);
        } else if (strcmp(command, "delete") == 0) {
            delete_command(&list);
        } else if (strcmp(command, "display") == 0) {
            display_command(&list);
        } else if (strcmp(command, "search") == 0) {
            search_command(&list);
        } else if (strcmp(command, "quit") == 0) {
            continue_encoding = false;
        } else {
            printf("Unknown command\n");
        }
    }

    free_client_list(&list);

    return EXIT_SUCCESS;
}