Підручник з обрізки¶

Сучасні методи глибокого навчання покладаються на надпараметризовані моделі, які важко застосувати. Навпаки, відомо, що біологічні нейронні мережі використовують ефективний розріджений зв’язок. Визначення оптимальних методів стиснення моделей за рахунок зменшення кількості параметрів у них є важливим для зменшення споживання пам'яті, заряду акумулятора та обладнання без шкоди точності, розгортання легких моделей на пристрої та гарантування конфіденційності при приватних обчисленнях на пристрої. На фронті досліджень обрізання використовується для дослідження відмінностей у динаміці навчання між надпараметризованими та недопараметризованими мережами, для вивчення ролі щасливих розріджених підмереж та ініціалізацій («лотерейних квитків») як деструктивної техніки пошуку нейронної архітектури, і більше.

utils prune

У цьому підручнику ви дізнаєтесь, як використовувати torch.nn.utils.prune для розрідження нейронних мереж, і як розширити його, щоб реалізувати власну техніку обрізки.

Вимоги¶

Створити модель¶

У цьому посібнику ми використовуємо архітектуру LeNet від LeCun et al., 1998.

Перевірка модуля¶

Давайте перевіримо (необрізаний) рівень conv1 у нашій моделі LeNet. Наразі він буде містити два параметри ваги та упередженості та жодних буферів.

Обрізка модуля¶

Щоб обрізати модуль (у цьому прикладі рівень conv1 нашої архітектури LeNet), спочатку виберіть техніку обрізки серед доступних у torch.nn.utils.prune (або застосуйте свою власну, підкласуючи BasePruningMethod). Потім вкажіть модуль та ім’я параметра, який потрібно обрізати в цьому модулі. Нарешті, використовуючи відповідні аргументи ключових слів, необхідні для обраної техніки обрізки, вкажіть параметри обрізки.

У цьому прикладі ми обрізаємо випадковим чином 30% з'єднань у параметрі з назвою вага у шарі conv1. Модуль передається як перший аргумент функції; name ідентифікує параметр у цьому модулі за допомогою його ідентифікатора рядка; і сума вказує або відсоток з'єднань для обрізки (якщо це плаваюча величина від 0 до 1), або абсолютну кількість з'єднань для обрізки (якщо це невід'ємне ціле число).

Обрізка діє шляхом видалення ваги з параметрів і заміни її новим параметром, який називається weight_orig (тобто додаванням "_orig" до початкового імені параметра). weight_orig зберігає необрізану версію тензора. Упередження не було обрізане, тому воно залишиться цілим.

Маска обрізання, створена вибраною вище технікою обрізки, зберігається як буфер модуля з назвою weight_mask (тобто додаючи "_mask" до початкового імені параметра).

Щоб прямий прохід працював без змін, повинен існувати атрибут вага. Прийоми обрізки, реалізовані в torch.nn.utils.prune, обчислюють обрізану версію ваги (комбінуючи маску з вихідним параметром) і зберігають їх у вазі атрибута. Зауважте, це вже не параметр модуля, це тепер просто атрибут.

Нарешті, обрізка застосовується перед кожним прямим проходом за допомогою forward_pre_hooks PyTorch. Зокрема, коли модуль буде обрізаний, як ми це зробили тут, він отримає forward_pre_hook для кожного пов'язаного з ним параметра, який отримує обрізку. У цьому випадку, оскільки ми досі обрізали лише початковий параметр з назвою вага, буде присутній лише один гачок.

Для повноти ми тепер можемо також обрізати упередження, щоб побачити, як змінюються параметри, буфери, хуки та атрибути модуля. Просто для того, щоб випробувати інший прийом обрізки, тут ми обрізаємо 3 найменші записи в ухилі за нормою L1, як реалізовано у функції l1_unstructured обрізки.

Тепер ми очікуємо, що названі параметри включатимуть і weight_orig (раніше), і bias_orig. Буфери включатимуть weight_mask та bias_mask. Обрізані версії двох тензорів існуватимуть як атрибути модуля, і модуль тепер матиме два forward_pre_hooks .

Повторна обрізка¶

Один і той же параметр у модулі можна обрізати кілька разів, при цьому ефект від різних викликів обрізки дорівнює комбінації різних масок, що застосовуються послідовно. Поєднання нової маски зі старою маскою обробляється методом compute_mask PruningContainer.

Скажімо, наприклад, що ми зараз хочемо продовжити обрізання module.weight, цього разу використовуючи структуровану обрізку вздовж 0-ї осі тензора (0-а вісь відповідає вихідним каналам згорткового шару і має розмірність 6 для conv1), на основі щодо норми L2 каналів. Цього можна досягти за допомогою ln_structured функції з n = 2 і dim = 0 .

Відповідний гачок тепер матиме тип torch.nn.utils.prune.PruningContainer і буде зберігати історію обрізки, застосовану до параметра ваги.

Серіалізація обрізаної моделі¶

Усі відповідні тензори, включаючи буфери маски та оригінальні параметри, що використовуються для обчислення обрізаних тензорів, зберігаються в state_dict моделі і, отже, можуть бути легко серіалізовані та збережені, якщо потрібно.

Видалити повторну параметризацію обрізкиning

Щоб зробити обрізку постійною, видаліть повторну параметризацію з точки зору weight_orig та weight_mask та видаліть forward_pre_hook, ми можемо використовувати функцію видалення з torch.nn.utils.prune. Зверніть увагу, що це не скасовує обрізання, ніби цього ніколи не було. Це просто робить його постійним, натомість, переназначаючи вагу параметрів на параметри моделі у її обрізаній версії.

До зняття повторної параметризації:

Після видалення повторної параметризації:

Обрізання кількох параметрів у моделі¶

Вказавши бажану техніку та параметри обрізки, ми можемо легко обрізати кілька тензорів у мережі, можливо, відповідно до їх типу, як ми побачимо у цьому прикладі.

Глобальна обрізка¶

Поки що ми розглядали лише те, що зазвичай називають „місцевим” обрізанням, тобто практикою обрізання тензорів у моделі по одному, порівнюючи статистичні дані (величина ваги, активація, градієнт тощо) кожного винятково до інших записів у цьому тензорі. Однак загальним і, можливо, більш потужним прийомом є обрізання моделі відразу, шляхом видалення (наприклад) найнижчих 20% з'єднань по всій моделі, замість того, щоб видалити найменших 20% з'єднань у кожному шарі. Це може призвести до різного відсотка обрізки на шар. Давайте подивимося, як це зробити, використовуючи global_unstructured з torch.nn.utils.prune .

Тепер ми можемо перевірити розрідженість, спричинену кожним обрізаним параметром, яка не буде дорівнювати 20% у кожному шарі. Однак глобальна розрідженість становитиме (приблизно) 20%.

Розширення torch.nn.utils.prune за допомогою спеціальних функцій обрізки¶

Щоб реалізувати власну функцію обрізки, ви можете розширити модуль nn.utils.prune, підкласуючи базовий клас BasePruningMethod, як і всі інші методи обрізки. Базовий клас реалізує для вас наступні методи: __call__, apply_mask, apply, prune та видалення. Окрім деяких особливих випадків, вам не доведеться застосовувати ці методи для нової техніки обрізки. Однак вам доведеться реалізувати __init__ (конструктор) і compute_mask (інструкції щодо обчислення маски для даного тензора відповідно до логіки вашої техніки обрізки). Крім того, вам доведеться вказати, який тип обрізки реалізує ця техніка (підтримувані варіанти є загальними, структурованими та неструктурованими). Це потрібно для того, щоб визначити, як поєднувати маски у випадку, коли обрізка застосовується ітеративно. Іншими словами, при обрізанні попередньо обрізаного параметра, як очікується, поточна техніка обрізки впливатиме на необрізану частину параметра. Вказівка ​​PRUNING_TYPE дозволить PruningContainer (який обробляє ітеративне застосування масок для обрізки) правильно визначити зріз параметра для обрізки.

Припустимо, наприклад, що ви хочете застосувати техніку обрізки, яка обрізає кожен другий запис у тензорі (або - якщо тензор раніше був обрізаний - у решті необрізаної частини тензора). Це буде PRUNING_TYPE = 'неструктурований', оскільки діє на окремі з'єднання в шарі, а не на цілі одиниці/канали ('структурований'), або на різні параметри ('глобальний').