問題

discovery2016-qual.contest.atcoder.jp

解法

入力文字列 s の長さを N とします。もし s の中に N-K 個以上 a がある場合は, N-K 個の a を最終的な文字列にするのが最強です。

別の場合は, a のみを残す, という戦略は出来ないので, a をなるべく手前に残してその後の辞書順もなるべく最小にする, というような戦略をしたいです。

これをやるために, SA を考えます。

s[i] != a であるような i について, その手前に a をいくつおけるか, そしてその後の文字列がどうなるかを調べて, 辞書順最小になるものを選びます。

#define _CRT_SECURE_NO_WARNINGS

#include <stdio.h>
#include <algorithm>
#include <string.h>
#include<iostream>

using namespace std;

#define FOR(i,a,b) for(int i=a;i<b;i++)
typedef long long ll;
typedef vector<int> vi;
typedef vector<ll> vll;
typedef pair<int, int> pii;

#define REP(i,b) FOR(i,0,b)

const int Nmax = 300300;
int bucket[Nmax];

template <class T>
void CreateBeginBucket(T* data, int size, int maxVal){
    REP(i, maxVal + 1) bucket[i] = 0;
    REP(i, size) bucket[data[i]]++;
    int sum = 0;
    REP(i, maxVal + 1){ bucket[i] += sum; swap(bucket[i], sum); }
}

template <class T>
void CreateEndBucket(T* data, int size, int maxVal){
    REP(i, maxVal + 1) bucket[i] = 0;
    REP(i, size) bucket[data[i]]++;
    int sum = 0;
    REP(i, maxVal + 1){ sum += bucket[i]; bucket[i] = sum; }
}

template<class T>
void InducedSort(T* data, int size, int* SA, int maxVal, bool* isL){
    CreateBeginBucket(data, size, maxVal);
    REP(i, size) if (SA[i] > 0 && isL[SA[i] - 1]) SA[bucket[data[SA[i] - 1]]++] = SA[i] - 1;
}

template<class T>
void InvertInducedSort(T* data, int size, int* SA, int maxVal, bool* isL){
    CreateEndBucket(data, size, maxVal);
    for (int i = size - 1; i >= 0; --i) if (SA[i] > 0 && !isL[SA[i] - 1]) SA[--bucket[data[SA[i] - 1]]] = SA[i] - 1;
}

template <class T>
void DBGOUT(T* sa, int size){
    REP(i, size) printf("%d ", int(sa[i]));
    printf("
");
}

template<class T>
void SA_IS(T* data, int size, int* SA, int maxVal, bool* isL){
    REP(i, size) SA[i] = -1;
#define isLMS(x) (x>0 && isL[x-1] && !isL[x])
    isL[size - 1] = false;
    for (int i = size - 2; i >= 0; i--) isL[i] = data[i] > data[i + 1] || (data[i] == data[i + 1] && isL[i + 1]);
    CreateEndBucket(data, size, maxVal);
    FOR(i, 1, size) if (isLMS(i)) SA[--bucket[data[i]]] = i;
    InducedSort(data, size, SA, maxVal, isL);
    InvertInducedSort(data, size, SA, maxVal, isL);

    int c = 0;
    REP(i, size) if (isLMS(SA[i])) SA[c++] = SA[i];
    FOR(i, c, size) SA[i] = -1;

    int idx = -1;
    int prev = -1;
    REP(i, c){
        bool diff = false;
        REP(d, size){
            if (prev == -1 || data[SA[i] + d] != data[prev + d] || isL[SA[i] + d] != isL[prev + d]){
                diff = true;
                break;
            }
            else if (d > 0 && isLMS(SA[i] + d)) break;
        }
        if (diff){ idx++; prev = SA[i]; }
        SA[c + SA[i] / 2] = idx;
    }
    int j = size;
    for (int i = size - 1; i >= c; i--) if (SA[i] >= 0) SA[--j] = SA[i];

    int* nxdata = SA + size - c;
    int* nxsa = SA;
    if (c == idx + 1) REP(i, c) nxsa[nxdata[i]] = i;
    else SA_IS(nxdata, c, nxsa, idx, isL + size);

    j = c;
    for (int i = size - 1; i >= 1; i--) if (isLMS(i)) nxdata[--j] = i;
    REP(i, c) nxsa[i] = nxdata[nxsa[i]];
    FOR(i, c, size) SA[i] = -1;
    CreateEndBucket(data, size, maxVal);
    for (int i = c - 1; i >= 0; i--) swap(nxsa[i], SA[--bucket[data[nxsa[i]]]]);
    InducedSort(data, size, SA, maxVal, isL);
    InvertInducedSort(data, size, SA, maxVal, isL);
}

// SA_IS
// input: 対象となる文字列
// size : 文字列の長さ
// SA   : 返される suffix array
bool isLPool[Nmax * 2];
void SA_IS(unsigned char* input, int size, int* SA){
    int mv = 0;
    REP(i, size) if (mv < input[i]) mv = input[i];
    SA_IS(input, size, SA, mv, isLPool);
}

// CreateLCP
// data: 対象となる文字列
// size: 文字列の長さ
// SA  : dataの suffix array の情報
// lcp という配列に情報が保存される
int lcp[Nmax];
int invertSA[Nmax];
void CreateLCP(unsigned char* data, int size, int* SA){
    lcp[0] = -1;
    REP(i, size) invertSA[SA[i]] = i;
    int prev = 0;
    REP(i, size){
        if (invertSA[i] > 0){
            while (data[i + prev] == data[SA[invertSA[i] - 1] + prev]){
                ++prev;
                if (i + prev >= size || SA[invertSA[i] - 1] + prev >= size)
                    break;
            }
            lcp[invertSA[i]] = prev;
        }
        prev = max(prev - 1, 0);
    }
}

int st[21][Nmax];
void InitSparseTable(int n){
    int h = 1;
    while ((1 << h) < n) h++;
    REP(i, n) st[0][i] = lcp[i];
    FOR(j, 1, h + 1){
        REP(i, n - (1<<j) + 1){
            st[j][i] = min(st[j - 1][i], st[j - 1][i + (1 << (j - 1))]);
        }
    }
}

inline int TopBit(int t){
    for (int i = 20; i >= 0; i--){
        if ((1 << i)&t) return i;
    }
    return -1;
}

int GetLCP(int f, int s){
    if (f > s) swap(f, s);
    int diff = TopBit(s-f);
    return min(st[diff][f], st[diff][s - (1 << diff)]);
}

unsigned char str[Nmax];
int indices[Nmax];

int compare(int f, int s, int l){
    int fi = invertSA[f];
    int si = invertSA[s];
    if (GetLCP(fi + 1, si + 1) >= l)
        return 0;
    else
        return 2 * (fi > si) - 1;
}

int SA[Nmax];
int cnt[Nmax];
int inv[Nmax];

int main() {
    string s;
    scanf("%s", str);
    s = string((char*)str);
    int N = s.size();
    SA_IS(str, N+1, SA);
    CreateLCP(str, N+1, SA);
    int K;
    cin >> K;
    int n = 0;
    for (int i = 0; i < N; i++) n += (s[i]==a);
    string ans;
    if (K+n >= N) {
        for (int i = 0; i < N-K; i++) ans += a;
    } else {
        ll head = K;
        pair<ll, int> best = make_pair(100*100*100, -1);
        for (int i = 0; i < N; i++) {
            if (s[i] == a) continue;
            best = min(best, make_pair((head+i)*(-Nmax)+invertSA[i], i));
            head--;
            if (head < 0) break;
        }
        head = K;
        int len = best.second;
        for (int i = 0; i < len; i++) {
            ans += a;
            if (s[i] != a) {
                head--;
            }
        }
        for (int i = 0; i < head; i++) ans += a;
        ans += s.substr(len);
    }
    cout << ans << endl;
    return 0;
}