기계는 거짓말하지 않는다

Python DataFrame 데이터 분리 후 csv 생성 본문

Python

Python DataFrame 데이터 분리 후 csv 생성

KillinTime 2021. 10. 8. 15:50

특정 데이터 값을 기준으로 나눠 각각 csv로 저장

아래는 하나의 data csv 파일에서 AI 학습용 train, test, valid dataset으로 분리한다.

비율만큼 분리는 sklearn의 train_test_split을 이용했다.

def trim_string(x, trim_word_count=100):
    x = x.split(maxsplit=trim_word_count)
    x = ' '.join(x[:trim_word_count])
    return x

def create_split_csv(raw_data_path=".", dest_path=".", label_numbers=[0, 1, 2],
    train_csv_name="train.csv", valid_csv_name="valid.csv", test_csv_name="test.csv", 
    skiprows=1, encoding="utf-8", test_size=0.25, valid_size=0.25, random_seed=1):

    # 원본 csv data
    df_raw = pd.read_csv(raw_data_path, skiprows=skiprows, encoding=encoding)

    # 지정된 문자 제거
    # df_raw["text"] = df_raw["text"].str.replace(pat=r'[\'\,\.\?\!]', repl=r'', regex=True)

    # 빈 텍스트 행 제거
    df_raw.drop(df_raw[df_raw.text.str.len() < 1].index, inplace=True)

    # 빈 값을 가진 행 제거
    df_raw = df_raw.dropna()

    # 지정된 단어 개수 만큼 단어 앞뒤로 공백 trim
    df_raw['text'] = df_raw['text'].apply(trim_string)

    df_split_train = pd.DataFrame()
    df_split_valid = pd.DataFrame()
    df_split_test = pd.DataFrame()

    for ln in label_numbers:
        # label에 따라 분리
        df_label = df_raw[df_raw['label'] == ln]
        if len(df_label) == 0:
            continue

        # 원본 data를 지정된 비율만큼 train, test dataset으로 나눔
        df_full_train, df_test = train_test_split(df_label, test_size=test_size, random_state=random_seed, shuffle=True)
        # 위에서 나눠진 train dataset에서 지정된 비율만큼 valid dataset으로 나눔
        df_train, df_valid = train_test_split(df_full_train, test_size=valid_size, random_state=random_seed, shuffle=True)

        # label 값 별로 모아서 이어줌
        df_split_train = pd.concat([df_split_train, df_train], ignore_index=True, sort=False)
        df_split_valid = pd.concat([df_split_valid, df_valid], ignore_index=True, sort=False)
        df_split_test = pd.concat([df_split_test, df_test], ignore_index=True, sort=False)

    # 나눠진 dataset들을 csv로 저장
    df_split_train.to_csv(dest_path + "/" + train_csv_name, index=False, encoding="utf-8")
    df_split_valid.to_csv(dest_path + "/" + valid_csv_name, index=False, encoding="utf-8")
    df_split_test.to_csv(dest_path + "/" + test_csv_name, index=False, encoding="utf-8")

 

'Python' 카테고리의 다른 글

Python XML ElementTree Read  (0) 2021.10.15
Python Priority Queue  (0) 2021.10.13
Python GUI PyQt5, QtDesigner  (0) 2021.10.07
Python DataFrame 특정 columns 추출  (0) 2021.10.05
Python Ignoring invalid distribution 오류  (0) 2021.09.26
Comments