JetBrains Academy: Edit distance in Java - Kamil-Jankowski/Learning-JAVA GitHub Wiki

JetBrains Academy: Edit distance in Java

Edit distance with a scoring matrix:

Earlier, we defined edit distance as the minimal number of insertions, deletions and substitutions required to transform one string into another. The metric can also be formulated differently: assign a cost to each operation and define the edit distance as the minimal total cost of a sequence of operations that converts one string into the other.
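
As a sketch of how this cost-based definition is usually computed (the symbols $D$, $c_{\mathrm{ins}}$ and $c_{\mathrm{del}}$ are introduced here for illustration), let $D[i][j]$ be the minimal cost of transforming the first $i$ characters of $s$ into the first $j$ characters of $t$, and let $m[a][b]$ be the cost of substituting $a$ by $b$, with $m[a][a] = 0$ for a match:

$$
\begin{aligned}
D[i][0] &= i \cdot c_{\mathrm{del}}, \qquad D[0][j] = j \cdot c_{\mathrm{ins}},\\
D[i][j] &= \min\bigl(D[i-1][j] + c_{\mathrm{del}},\; D[i][j-1] + c_{\mathrm{ins}},\; D[i-1][j-1] + m[s_i][t_j]\bigr),
\end{aligned}
$$

and $d_E(s,t) = D[|s|][|t|]$. With $c_{\mathrm{ins}} = c_{\mathrm{del}} = 1$ and $m[a][b] = 1$ for $a \ne b$, this reduces to counting operations, as in the original definition.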

For example, we may say that each of the described operations costs 1. In this case, there is no difference between the two formulations. But sometimes it is convenient to assign costs in another way.

Assume we are working on a system for correcting spelling mistakes. The algorithm is as follows: take the user's request, use the edit distance metric to find the most similar word in a database of correct words, and use that word instead of the original one.

Suppose we get the request "flaq". At least two words are at edit distance 1 from it: "flaw" and "flat". Which one should we use? By edit distance alone, they are equally good. However, "flaw" is the more likely intended word: "q" and "w" are adjacent on a keyboard, while "q" and "t" are far apart, so the user more probably meant to type "flaw" than "flat".
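
To confirm the tie, here is a minimal sketch of the classic unit-cost edit distance (the class and method names are illustrative, not from the original page). It gives the same distance for both candidates:

```java
class UnitEditDistanceDemo {

    // Classic edit distance: insertion, deletion and substitution all cost 1.
    static int unitEditDistance(String s, String t) {
        int[][] d = new int[s.length() + 1][t.length() + 1];
        for (int i = 0; i <= s.length(); i++) {
            d[i][0] = i;                       // delete the first i letters of s
        }
        for (int j = 0; j <= t.length(); j++) {
            d[0][j] = j;                       // insert the first j letters of t
        }
        for (int i = 1; i <= s.length(); i++) {
            for (int j = 1; j <= t.length(); j++) {
                int sub = s.charAt(i - 1) == t.charAt(j - 1) ? 0 : 1;
                d[i][j] = Math.min(Math.min(d[i - 1][j] + 1, d[i][j - 1] + 1), d[i - 1][j - 1] + sub);
            }
        }
        return d[s.length()][t.length()];
    }

    public static void main(String[] args) {
        System.out.println(unitEditDistance("flaq", "flaw")); // 1
        System.out.println(unitEditDistance("flaq", "flat")); // 1
    }
}
```

With unit costs the metric cannot prefer one candidate over the other; the choice would fall to whatever tie-breaking rule the system uses.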

To handle such cases correctly, one can use a so-called scoring matrix. A scoring matrix is a table m where m[s1][s2] is the cost of substituting the symbol s1 with the symbol s2. For example, to solve the problem above, we can use a matrix that assigns lower costs to symbols that are close on a keyboard and higher costs to symbols that are far apart.
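
One possible way to derive such costs, sketched under loudly stated assumptions: place the keys on a simplified QWERTY grid (rows `qwertyuiop`, `asdfghjkl`, `zxcvbnm`, ignoring the row stagger) and charge the Manhattan distance between the two keys. The layout, the distance measure and the names below are hypothetical choices for illustration, not the scoring matrix used later in the task.

```java
class KeyboardScoring {

    // Hypothetical simplified QWERTY layout; the physical row stagger is ignored.
    static final String[] ROWS = {"qwertyuiop", "asdfghjkl", "zxcvbnm"};

    // Hypothetical substitution cost: Manhattan distance between the keys' (row, column) positions.
    static int substitutionCost(char a, char b) {
        int[] pa = position(a);
        int[] pb = position(b);
        return Math.abs(pa[0] - pb[0]) + Math.abs(pa[1] - pb[1]);
    }

    // Returns {row, column} of a lowercase letter on the simplified grid.
    static int[] position(char c) {
        for (int row = 0; row < ROWS.length; row++) {
            int col = ROWS[row].indexOf(c);
            if (col >= 0) {
                return new int[] {row, col};
            }
        }
        throw new IllegalArgumentException("Not a lowercase Latin letter: " + c);
    }

    public static void main(String[] args) {
        System.out.println(substitutionCost('q', 'w')); // 1
        System.out.println(substitutionCost('q', 't')); // 4
    }
}
```

Under these assumed costs, replacing "q" by "w" is cheaper than replacing it by "t", so "flaw" would win over "flat".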

So, your task here is to implement a simple system for correcting spelling mistakes. For convenience, we will use a reduced alphabet.

Input: The first line contains a string s, the user's request. The second line contains an integer k, the size of the database. Each of the next k lines contains a string, a correct word. Each string consists only of the letters a, s, d, b, n, m.

Output: The first line should contain the edit distance $d_E(s,t)$, where t is the word with the minimal edit distance to s among all words in the database. The second line should contain the word t itself. If several words share the minimal edit distance, print the one that occurs first in the database.

Consider the cost of an insertion and a deletion to be equal to 11. To calculate the cost of a substitution, use the following scoring matrix:

  a s d b n m
a 0 1 2 5 6 7
s 1 0 1 5 6 7
d 2 1 0 5 6 7
b 5 6 7 0 1 2
n 5 6 7 1 0 1
m 5 6 7 2 1 0
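
A small worked example with these costs (an illustrative input of our own, not one from the original task): let s = "an" and let the database contain "sn" and then "am". Because the strings have equal length, any alignment that uses an insertion also needs a deletion and costs at least 22, so the cheapest alignments substitute letter by letter:

$$
\begin{aligned}
d_E(\texttt{an}, \texttt{sn}) &= m[a][s] + m[n][n] = 1 + 0 = 1,\\
d_E(\texttt{an}, \texttt{am}) &= m[a][a] + m[n][m] = 0 + 1 = 1.
\end{aligned}
$$

Both words are at distance 1, so the expected output is 1 followed by "sn", the word that occurs first in the database.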

```java
import java.util.Scanner;

public class Main {

    // Letters of the reduced alphabet, in the same order as the rows and columns of SCORING_MATRIX.
    private static final String ALPHABET = "asdbnm";

    // Cost of a single insertion or deletion, as given in the problem statement.
    private static final int INSERTION_DELETION_COST = 11;

    // SCORING_MATRIX[x][y] is the cost of substituting letter x by letter y.
    private static final int[][] SCORING_MATRIX = {
            {0, 1, 2, 5, 6, 7},
            {1, 0, 1, 5, 6, 7},
            {2, 1, 0, 5, 6, 7},
            {5, 6, 7, 0, 1, 2},
            {5, 6, 7, 1, 0, 1},
            {5, 6, 7, 2, 1, 0}};

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        String input = scanner.nextLine();
        int size = Integer.parseInt(scanner.nextLine().trim());
        String[] vocabulary = new String[size];
        for (int i = 0; i < size; i++) {
            vocabulary[i] = scanner.nextLine();
        }

        int minDistance = Integer.MAX_VALUE;
        String correctWord = "";
        for (String candidate : vocabulary) {
            int distance = editDistance(input, candidate);   // computed once per word
            // Strict "<" keeps the earliest word when several share the minimal distance.
            if (distance < minDistance) {
                minDistance = distance;
                correctWord = candidate;
            }
        }
        System.out.println(minDistance);
        System.out.println(correctWord);
    }

    private static int editDistance(String s, String t) {
        int[][] distance = new int[s.length() + 1][t.length() + 1];

        for (int i = 0; i <= s.length(); i++) {     // column 0: delete the first i letters of s
            distance[i][0] = i * INSERTION_DELETION_COST;
        }
        for (int j = 0; j <= t.length(); j++) {     // row 0: insert the first j letters of t
            distance[0][j] = j * INSERTION_DELETION_COST;
        }

        for (int i = 1; i <= s.length(); i++) {
            for (int j = 1; j <= t.length(); j++) {
                int insertionCost = distance[i][j - 1] + INSERTION_DELETION_COST;
                int deletionCost = distance[i - 1][j] + INSERTION_DELETION_COST;
                int substitutionCost = distance[i - 1][j - 1]
                        + scoringCost(s.charAt(i - 1), t.charAt(j - 1));
                distance[i][j] = Math.min(Math.min(insertionCost, deletionCost), substitutionCost);
            }
        }
        return distance[s.length()][t.length()];
    }

    // Looks up the cost of substituting letter a (from s) by letter b (from t).
    private static int scoringCost(char a, char b) {
        return SCORING_MATRIX[ALPHABET.indexOf(a)][ALPHABET.indexOf(b)];
    }
}
```

Efficient memory consumption:

It was mentioned earlier that edit distance can be computed more efficiently in terms of memory consumption. Write a program that, for two strings s and t, calculates $d_E(s,t)$ using $O(\min(|s|, |t|))$ memory.
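
The memory bound rests on the fact that each row of the DP table depends only on the row above it. A sketch that keeps exactly two rows (an alternative formulation, not the author's single-row solution; the class name is illustrative) and indexes the columns by the shorter string:

```java
class TwoRowEditDistance {

    // Unit-cost edit distance keeping only the previous and the current DP row.
    // Columns are indexed by the shorter string, so both arrays hold min(|s|, |t|) + 1 ints.
    static int editDistance(String s, String t) {
        if (t.length() > s.length()) {         // unit-cost distance is symmetric, so swap freely
            String tmp = s;
            s = t;
            t = tmp;
        }
        int[] previous = new int[t.length() + 1];
        int[] current = new int[t.length() + 1];
        for (int j = 0; j <= t.length(); j++) {
            previous[j] = j;                   // row 0: j insertions
        }
        for (int i = 1; i <= s.length(); i++) {
            current[0] = i;                    // column 0: i deletions
            for (int j = 1; j <= t.length(); j++) {
                int substitution = previous[j - 1] + (s.charAt(i - 1) == t.charAt(j - 1) ? 0 : 1);
                int deletion = previous[j] + 1;
                int insertion = current[j - 1] + 1;
                current[j] = Math.min(substitution, Math.min(deletion, insertion));
            }
            int[] swap = previous;             // the finished row becomes the previous one
            previous = current;
            current = swap;
        }
        return previous[t.length()];
    }

    public static void main(String[] args) {
        System.out.println(editDistance("kitten", "sitting")); // 3
    }
}
```

Swapping array references instead of copying keeps each row transition O(1) apart from the inner loop itself.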

```java
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        String word1 = scanner.nextLine();
        String word2 = scanner.nextLine();

        System.out.println(editDistanceEfficiently(word1, word2));
    }

    public static int match(char a, char b) {
        return (a == b) ? 0 : 1;
    }

    public static int editDistanceEfficiently(String s, String t) {
        // Keep the shorter string in t so the single row has min(|s|, |t|) + 1 cells;
        // with unit costs, d_E(s, t) = d_E(t, s).
        if (t.length() > s.length()) {
            String swap = s;
            s = t;
            t = swap;
        }
        int[] row = new int[t.length() + 1];

        for (int j = 0; j <= t.length(); j++) {      // row 0: j insertions
            row[j] = j;
        }

        for (int i = 1; i <= s.length(); i++) {
            int diagonal = row[0];                    // D[i-1][j-1] for j = 1
            row[0] = i;                               // column 0: i deletions (set before it is read)
            for (int j = 1; j < row.length; j++) {
                int insCost = row[j - 1] + 1;         // row[j - 1] already holds D[i][j-1]
                int delCost = row[j] + 1;             // row[j] still holds D[i-1][j]
                int subCost = diagonal + match(s.charAt(i - 1), t.charAt(j - 1));
                diagonal = row[j];                    // becomes D[i-1][j-1] for the next j
                row[j] = Math.min(Math.min(insCost, delCost), subCost);
            }
        }

        return row[t.length()];
    }
}
```

Edit distance with insertions and deletions:

Assume we are only allowed to use two operations to transform one string into the other: an insertion and a deletion. Write a program that calculates edit distance under this restriction.

Input: two strings s and t.

Output: the minimum number of insertions and deletions required to transform s into t.
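
A different route to the same number, offered as a sketch rather than the author's approach: when only insertions and deletions are allowed, the letters that survive unchanged form a common subsequence of s and t, and every other letter is deleted from s or inserted from t. The minimum is therefore $|s| + |t| - 2\,\mathrm{LCS}(s,t)$, where $\mathrm{LCS}(s,t)$ is the length of a longest common subsequence. Class and method names are illustrative.

```java
class InsertDeleteViaLcs {

    // Length of a longest common subsequence, using one row of |t| + 1 ints.
    static int lcsLength(String s, String t) {
        int[] row = new int[t.length() + 1];     // row[j] = LCS of the processed prefix of s and t[0..j)
        for (int i = 1; i <= s.length(); i++) {
            int diagonal = 0;                    // previous row's value at column j - 1
            for (int j = 1; j <= t.length(); j++) {
                int above = row[j];              // previous row's value at column j
                if (s.charAt(i - 1) == t.charAt(j - 1)) {
                    row[j] = diagonal + 1;
                } else {
                    row[j] = Math.max(above, row[j - 1]);
                }
                diagonal = above;
            }
        }
        return row[t.length()];
    }

    // Every letter outside the LCS is either deleted from s or inserted from t.
    static int insertDeleteDistance(String s, String t) {
        return s.length() + t.length() - 2 * lcsLength(s, t);
    }

    public static void main(String[] args) {
        System.out.println(insertDeleteDistance("flaq", "flaw")); // 2
    }
}
```

For "flaq" and "flaw" the common subsequence "fla" survives, so one deletion ("q") and one insertion ("w") remain.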

```java
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        String word1 = scanner.nextLine();
        String word2 = scanner.nextLine();

        System.out.println(insertDeleteOnly(word1, word2));
    }

    // Diagonal move: free when the letters match; otherwise it returns a value larger than
    // min(insertionCost, deletionCost), so the diagonal is never chosen for a mismatch.
    private static int match(char a, char b, int insertionCost, int deletionCost, int diagonal) {
        return (a == b) ? diagonal : Math.min(insertionCost, deletionCost) + 1;
    }

    public static int insertDeleteOnly(String s, String t) {
        int[] row = new int[t.length() + 1];

        for (int j = 0; j <= t.length(); j++) {      // row 0: j insertions
            row[j] = j;
        }

        for (int i = 1; i <= s.length(); i++) {
            int diagonal = row[0];                    // D[i-1][j-1] for j = 1
            row[0] = i;                               // column 0: i deletions
            for (int j = 1; j < row.length; j++) {
                int insCost = row[j - 1] + 1;
                int delCost = row[j] + 1;
                int subCost = match(s.charAt(i - 1), t.charAt(j - 1), insCost, delCost, diagonal);
                int minCost = Math.min(Math.min(insCost, delCost), subCost);
                diagonal = row[j];
                row[j] = minCost;
            }
        }
        return row[t.length()];
    }
}
```

Alternatively, charge a mismatched substitution 2, the cost of one deletion plus one insertion:

```java
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        String word1 = scanner.nextLine();
        String word2 = scanner.nextLine();

        System.out.println(insertDeleteOnly(word1, word2));
    }

    public static int match(char a, char b) {
        return (a == b) ? 0 : 2;    // returns 2, because it requires two operations: deletion + insertion
    }

    public static int insertDeleteOnly(String s, String t) {
        int[] row = new int[t.length() + 1];

        for (int j = 0; j < t.length() + 1; j++) {      // fill in the row 0
            row[j] = j;
        }

        for (int i = 1; i < s.length() + 1; i++) {
            int temp = row[0];
            row[0] = i;
            for (int j = 1; j < row.length; j++) {
                int insCost = row[j - 1] + 1;
                int delCost = row[j] + 1;
                int subCost = temp + match(s.charAt(i - 1), t.charAt(j - 1));
                int minCost = Math.min(Math.min(insCost, delCost), subCost);
                temp = row[j];
                row[j] = minCost;
            }
        }

        return row[t.length()];
    }
}
```
