요약
- 데이터 전처리를 보다 더 효과적으로 수행할 수 있게 코드를 수정한다.
- 파이썬 직렬화 최고! 👍
문제
어떤 데이터들은 에이전트(프로세스)를 띄울때마다 다시 로딩을 한다.
좋은 방법이 없을까? 공유된 캐시가 있어야 재 로딩을 막을 수 있을 텐데…
방법1
wandb의 agent(이렇게 명명하는게 맞나?)들은 그 자체가 새로운 프로세스다.
때문에 전역변수로는 캐시를 만들 수 없다.
갑자기 프로세스 간의 공유변수는 따로 없을까 싶어 multiprocessing 패키지의 공유 변수 기능을 찾아보았다.
아래의 Manager를 사용하면 될것 같다.
(Manager는 proxy process를 이용하고, 부모 프로세스가 종료되면 gc된다.)
일단, main에서 공유변수를 인자로 받아 캐시를 저장하게 한다.
원래, 이러면 안좋지만 파이썬이 모듈을 한번만 로드한다는 특성을 노리고 process-wise global 변수인 Cache를 만들었다.
(원래대로라면 의존성 주입을 사용했을 것이다. 허나 지금은… 그러기 힘들것 같다.)
def main(args, shared=None):
Setting.seed_everything(args.seed)
Cache.init(shared)
그리고 main에는 manager로 생성한 공유 딕셔너리를 주입했다.
sweep_id = wandb.sweep(s_config, entity=args.entity)
args_dict = vars(args)
manager = multiprocessing.Manager()
shared = manager.dict()
sweep_func = partial(main, args, shared)
마지막으로 이미지를 로딩하는 image_vector에서 cache를 가져오면 끝!
def image_vector(path):
"""
Parameters
----------
path : str
이미지가 존재하는 경로를 입력합니다.
----------
"""
img_cache = Cache.load_data(path)
if img_cache:
img = img_cache
else:
img = Image.open(path)
scale = transforms.Resize((32, 32))
img = scale(img)
Cache.save_data(path, img)
tensor = transforms.ToTensor()
img_fe = Variable(tensor(img))
return img_fe
속도는 아래와 같이 빨라졌다.
# 전
129777it [01:37, 1328.17it/s]
52000it [00:26, 1952.13it/s]
# 후
129777it [00:46, 2774.93it/s]
52000it [00:18, 2811.94it/s]
방법2
위의 방법은 생각보다 느려서 아예 이미지 텐서 자체를 덤핑하는 방법을 생각했다.
이유는 모르지만 더 빨라졌다…
공유변수지만 같은 메모리를 참조하고 있는 건지 의심이 들었다.
129777it [00:27, 4720.70it/s]
52000it [00:10, 4760.39it/s]
방법3
이번엔 img_path column의 해시값으로 img_vector를 아예 덤핑해보았다.
몇초 걸리고 아예 로딩이 없다.
사용한 캐시객체는 아래와 같다.
Series객체는 hash 함수가 먹히지 않는다.
따라서 해싱을 하려면 pd.util.hash_pandas_object 를 사용해야 한다.
(다른 방법으로는 hash(tuple(sorted(series))) 하는 방법이 있을 것 같다.)
class Cache:
@staticmethod
def dump(key, data):
"""
Parameters
----------
key : str
덤핑되는 파일 이름
data : any
저장할 object
----------
"""
fpath = f'__pycache__/{key}.pt'
torch.save(data, fpath)
@staticmethod
def load(key):
"""
Parameters
----------
key : str
불러올 데이터의 key
----------
"""
fpath = f'__pycache__/{key}.pt'
if os.path.exists(fpath):
return torch.load(fpath)
return None
@staticmethod
def hash(series):
"""
Parameters
----------
series : pandas.Series
불러올 데이터의 key
----------
"""
return pd.util.hash_pandas_object(series).sum()
위의 Cache를 이용해 해시값을 만들고, 해시값으로 데이터를 찾을 수 있다.
img_path_hash = Cache.hash(img_vector_df['img_path'])
cache = Cache.load(img_path_hash)
if cache is not None: # 여기서 if cache 하면 멍청한 pandas가 이상한 로직 돌린다.
print(f'image vector found: cache [{img_path_hash}]')
img_vector_df['img_vector'] = cache
else:
data_box = []
for idx, path in tqdm(enumerate(sorted(img_vector_df['img_path']))):
data = image_vector(path)
if data.size()[0] == 3:
data_box.append(np.array(data))
else:
data_box.append(np.array(data.expand(3, data.size()[1], data.size()[2])))
img_vector_df['img_vector'] = data_box
Cache.dump(img_path_hash, img_vector_df['img_vector'])
'프로그래밍 > 부스트캠프 AI' 카테고리의 다른 글
[DKT] lgbm에 label을 feature로 넣으면... (0) | 2023.05.12 |
---|---|
DKT - EDA 해보기 (1) | 2023.05.06 |
Day 5 - CNN, RNN (0) | 2023.03.11 |
Day 4 - AI Basic (0) | 2023.03.09 |
Day 3 - pandas (0) | 2023.03.08 |