726. Number of Atoms

class Solution {
    /**
     * Counts the atoms in a chemical formula and renders them in canonical form.
     *
     * <p>A formula is a sequence of atoms (one uppercase letter followed by any lowercase
     * letters) and parenthesized groups, each optionally followed by a decimal count.
     * A missing count means 1. A group's count multiplies every atom inside it.
     *
     * @param formula the formula to parse, e.g. {@code "K4(ON(SO3)2)2"}
     * @return every distinct atom in lexicographic order, each followed by its total count
     *     when that count is not 1, e.g. {@code "K4N2O14S4"}; {@code ""} for an empty formula
     * @throws IllegalArgumentException if the formula contains an unexpected character,
     *     has unbalanced parentheses, or has a count too large for an {@code int}
     * @throws ArithmeticException if an atom's total count overflows an {@code int}
     * @throws NullPointerException if {@code formula} is null
     */
    public String countOfAtoms(String formula) {
        int[] cursor = {0};
        Map<String, Integer> counts = parseGroup(formula, cursor, false);

        List<String> atoms = new ArrayList<>(counts.keySet());
        Collections.sort(atoms);

        StringBuilder sb = new StringBuilder();
        for (String atom : atoms) {
            int count = counts.get(atom);
            sb.append(atom);
            if (count != 1) {
                sb.append(count);
            }
        }
        return sb.toString();
    }

    /**
     * Parses atoms and parenthesized groups starting at {@code cursor[0]}.
     *
     * <p>Top-level parsing ({@code nested == false}) runs to the end of the formula.
     * Nested parsing runs up to the {@code ')'} that closes the current group. It consumes
     * that parenthesis and its trailing multiplier, and applies the multiplier to the
     * returned counts. On return {@code cursor[0]} points just past everything consumed.
     *
     * @param formula the whole formula
     * @param cursor single-element array holding the current read position (updated in place)
     * @param nested whether this call parses the inside of a {@code '('} group
     * @return a fresh mutable map from atom name to its count within this group
     * @throws IllegalArgumentException on an unexpected character or unbalanced parentheses
     */
    private static Map<String, Integer> parseGroup(String formula, int[] cursor, boolean nested) {
        Map<String, Integer> counts = new HashMap<>();
        while (cursor[0] < formula.length()) {
            char c = formula.charAt(cursor[0]);
            if (c >= 'A' && c <= 'Z') {
                String atom = readAtom(formula, cursor);
                counts.merge(atom, readCount(formula, cursor), Math::addExact);
            } else if (c == '(') {
                cursor[0]++; // step past '('
                Map<String, Integer> inner = parseGroup(formula, cursor, true);
                inner.forEach((atom, n) -> counts.merge(atom, n, Math::addExact));
            } else if (c == ')') {
                if (!nested) {
                    throw new IllegalArgumentException("Unmatched ')' at index " + cursor[0]);
                }
                cursor[0]++; // step past ')'
                int multiplier = readCount(formula, cursor);
                counts.replaceAll((atom, n) -> Math.multiplyExact(n, multiplier));
                return counts;
            } else {
                // A stray digit or lowercase letter cannot start a token. Reject it rather
                // than looping without ever advancing the cursor.
                throw new IllegalArgumentException(
                        "Unexpected character '" + c + "' at index " + cursor[0]);
            }
        }
        if (nested) {
            throw new IllegalArgumentException("Missing ')' before end of formula");
        }
        return counts;
    }

    /**
     * Reads an atom name at {@code cursor[0]}: the uppercase letter there (already checked
     * by the caller) plus any following lowercase ASCII letters. Advances the cursor past it.
     */
    private static String readAtom(String formula, int[] cursor) {
        int start = cursor[0];
        int end = start + 1;
        while (end < formula.length() && formula.charAt(end) >= 'a' && formula.charAt(end) <= 'z') {
            end++;
        }
        cursor[0] = end;
        return formula.substring(start, end);
    }

    /**
     * Reads an optional run of ASCII digits at {@code cursor[0]} and advances past it.
     *
     * @return the parsed number, or 1 when no digits are present
     * @throws NumberFormatException if the digits do not fit in an {@code int}
     */
    private static int readCount(String formula, int[] cursor) {
        int start = cursor[0];
        int end = start;
        while (end < formula.length() && formula.charAt(end) >= '0' && formula.charAt(end) <= '9') {
            end++;
        }
        cursor[0] = end;
        return end > start ? Integer.parseInt(formula.substring(start, end)) : 1;
    }
}

Last updated