From b6b6291016545d48f084cc5b5e2b6c93a3e6b651 Mon Sep 17 00:00:00 2001 From: FrostNov4 <97688822+FrostNov4@users.noreply.github.com> Date: Sun, 15 Oct 2023 16:57:28 +1000 Subject: [PATCH 01/20] initialised --- LICENSE | 875 ++++++++++++++++++++++++++++++++---------- README.md | 196 +++++++++- evaluate.py | 40 ++ predict.py | 117 ++++++ requirements.txt | 5 + train.py | 237 ++++++++++++ unet/__init__.py | 1 + unet/unet_model.py | 48 +++ unet/unet_parts.py | 77 ++++ utils/__init__.py | 0 utils/data_loading.py | 117 ++++++ utils/dice_score.py | 28 ++ utils/utils.py | 13 + 13 files changed, 1542 insertions(+), 212 deletions(-) create mode 100644 evaluate.py create mode 100755 predict.py create mode 100644 requirements.txt create mode 100644 train.py create mode 100644 unet/__init__.py create mode 100644 unet/unet_model.py create mode 100644 unet/unet_parts.py create mode 100644 utils/__init__.py create mode 100644 utils/data_loading.py create mode 100644 utils/dice_score.py create mode 100644 utils/utils.py diff --git a/LICENSE b/LICENSE index 261eeb9e9..94a9ed024 100644 --- a/LICENSE +++ b/LICENSE @@ -1,201 +1,674 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. 
- - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. 
Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. 
+ + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. 
This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. 
This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. 
+ + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. 
+ + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. 
+ + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. 
+ + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. 
If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. 
Interpretation of Sections 15 and 16.
+
+  If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+                     END OF TERMS AND CONDITIONS
+
+            How to Apply These Terms to Your New Programs
+
+  If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+  To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+    <one line to give the program's name and a brief idea of what it does.>
+    Copyright (C) <year>  <name of author>
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+  If the program does terminal interaction, make it output a short
+notice like this when it starts in an interactive mode:
+
+    <program>  Copyright (C) <year>  <name of author>
+    This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+    This is free software, and you are welcome to redistribute it
+    under certain conditions; type `show c' for details.
+
+The hypothetical commands `show w' and `show c' should show the appropriate
+parts of the General Public License. Of course, your program's commands
+might be different; for a GUI interface, you would use an "about box".
+
+  You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU GPL, see
+<https://www.gnu.org/licenses/>.
+
+  The GNU General Public License does not permit incorporating your program
+into proprietary programs. If your program is a subroutine library, you
+may consider it more useful to permit linking proprietary applications with
+the library. If this is what you want to do, use the GNU Lesser General
+Public License instead of this License. But first, please read
+<https://www.gnu.org/philosophy/why-not-lgpl.html>.
diff --git a/README.md b/README.md
index 4a064f841..cb4d4c986 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,189 @@
-# Pattern Analysis
-Pattern Analysis of various datasets by COMP3710 students at the University of Queensland.
+# U-Net: Semantic segmentation with PyTorch
+
+
+
+
-We create pattern recognition and image processing library for Tensorflow (TF), PyTorch or JAX.
+![input and output for a random image in the test dataset](https://i.imgur.com/GD8FcB7.png)
-This library is created and maintained by The University of Queensland [COMP3710](https://my.uq.edu.au/programs-courses/course.html?course_code=comp3710) students.
-The library includes the following implemented in Tensorflow:
-* fractals
-* recognition problems
+Customized implementation of the [U-Net](https://arxiv.org/abs/1505.04597) in PyTorch for Kaggle's [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge) from high definition images.
-In the recognition folder, you will find many recognition problems solved including:
-* OASIS brain segmentation
-* Classification
-etc.
+
+- [Quick start](#quick-start)
+  - [Without Docker](#without-docker)
+  - [With Docker](#with-docker)
+- [Description](#description)
+- [Usage](#usage)
+  - [Docker](#docker)
+  - [Training](#training)
+  - [Prediction](#prediction)
+- [Weights & Biases](#weights--biases)
+- [Pretrained model](#pretrained-model)
+- [Data](#data)
+
+## Quick start
+
+### Without Docker
+
+1. [Install CUDA](https://developer.nvidia.com/cuda-downloads)
+
+2. [Install PyTorch 1.13 or later](https://pytorch.org/get-started/locally/)
+
+3. Install dependencies
+```bash
+pip install -r requirements.txt
+```
+
+4. Download the data and run training:
+```bash
+bash scripts/download_data.sh
+python train.py --amp
+```
+
+### With Docker
+
+1. [Install Docker 19.03 or later:](https://docs.docker.com/get-docker/)
+```bash
+curl https://get.docker.com | sh && sudo systemctl --now enable docker
+```
+2. [Install the NVIDIA container toolkit:](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html)
+```bash
+distribution=$(. /etc/os-release;echo $ID$VERSION_ID) \
+   && curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - \
+   && curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list
+sudo apt-get update
+sudo apt-get install -y nvidia-docker2
+sudo systemctl restart docker
+```
+3. [Download and run the image:](https://hub.docker.com/repository/docker/milesial/unet)
+```bash
+sudo docker run --rm --shm-size=8g --ulimit memlock=-1 --gpus all -it milesial/unet
+```
+
+4. Download the data and run training:
+```bash
+bash scripts/download_data.sh
+python train.py --amp
+```
+
+## Description
+This model was trained from scratch with 5k images and scored a [Dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) of 0.988423 on over 100k test images.
+
+It can be easily used for multiclass segmentation, portrait segmentation, medical segmentation, ...
+
+
+## Usage
+**Note: Use Python 3.6 or newer**
+
+### Docker
+
+A docker image containing the code and the dependencies is available on [DockerHub](https://hub.docker.com/repository/docker/milesial/unet).
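+
+If you just want to fetch the image ahead of time (for example before going offline), a plain pull is enough:
+
+```console
+docker pull milesial/unet
+```
+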
+You can download and jump in the container with ([docker >=19.03](https://docs.docker.com/get-docker/)):
+
+```console
+docker run -it --rm --shm-size=8g --ulimit memlock=-1 --gpus all milesial/unet
+```
+
+
+### Training
+
+```console
+> python train.py -h
+usage: train.py [-h] [--epochs E] [--batch-size B] [--learning-rate LR]
+                [--load LOAD] [--scale SCALE] [--validation VAL] [--amp]
+
+Train the UNet on images and target masks
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --epochs E, -e E      Number of epochs
+  --batch-size B, -b B  Batch size
+  --learning-rate LR, -l LR
+                        Learning rate
+  --load LOAD, -f LOAD  Load model from a .pth file
+  --scale SCALE, -s SCALE
+                        Downscaling factor of the images
+  --validation VAL, -v VAL
+                        Percent of the data that is used as validation (0-100)
+  --amp                 Use mixed precision
+```
+
+By default, the `scale` is 0.5, so if you wish to obtain better results (but use more memory), set it to 1.
+
+Automatic mixed precision is also available with the `--amp` flag. [Mixed precision](https://arxiv.org/abs/1710.03740) allows the model to use less memory and to be faster on recent GPUs by using FP16 arithmetic. Enabling AMP is recommended.
+
+
+### Prediction
+
+After training your model and saving it to `MODEL.pth`, you can easily test the output masks on your images via the CLI.
+
+To predict a single image and save it:
+
+`python predict.py -i image.jpg -o output.jpg`
+
+To predict multiple images and show them without saving them:
+
+`python predict.py -i image1.jpg image2.jpg --viz --no-save`
+
+```console
+> python predict.py -h
+usage: predict.py [-h] [--model FILE] --input INPUT [INPUT ...]
+                  [--output INPUT [INPUT ...]] [--viz] [--no-save]
+                  [--mask-threshold MASK_THRESHOLD] [--scale SCALE]
+
+Predict masks from input images
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --model FILE, -m FILE
+                        Specify the file in which the model is stored
+  --input INPUT [INPUT ...], -i INPUT [INPUT ...]
+                        Filenames of input images
+  --output INPUT [INPUT ...], -o INPUT [INPUT ...]
+                        Filenames of output images
+  --viz, -v             Visualize the images as they are processed
+  --no-save, -n         Do not save the output masks
+  --mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD
+                        Minimum probability value to consider a mask pixel white
+  --scale SCALE, -s SCALE
+                        Scale factor for the input images
+```
+You can specify which model file to use with `--model MODEL.pth`.
+
+## Weights & Biases
+
+The training progress can be visualized in real-time using [Weights & Biases](https://wandb.ai/). Loss curves, validation curves, weights and gradient histograms, as well as predicted masks are logged to the platform.
+
+When launching a training, a link will be printed in the console. Click on it to go to your dashboard. If you have an existing W&B account, you can link it by setting the `WANDB_API_KEY` environment variable. If not, it will create an anonymous run which is automatically deleted after 7 days.
+
+
+## Pretrained model
+A [pretrained model](https://github.com/milesial/Pytorch-UNet/releases/tag/v3.0) is available for the Carvana dataset. It can also be loaded from torch.hub:
+
+```python
+net = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana', pretrained=True, scale=0.5)
+```
+Available scales are 0.5 and 1.0.
+
+## Data
+The Carvana data is available on the [Kaggle website](https://www.kaggle.com/c/carvana-image-masking-challenge/data).
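+
+If you prefer to fetch it manually, one option is the official Kaggle CLI. This is a sketch assuming the `kaggle` package is installed, an API token is set up under `~/.kaggle/`, and the competition rules have been accepted on Kaggle:
+
+```bash
+# downloads the competition archive into data/ (unpack it afterwards)
+kaggle competitions download -c carvana-image-masking-challenge -p data/
+```
+
+The archives then need to be unpacked so that images end up in `data/imgs` and masks in `data/masks`, as described below.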
+
+You can also download it using the helper script:
+
+```
+bash scripts/download_data.sh
+```
+
+The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively (note that the `imgs` and `masks` folders should not contain any sub-folders or any other files, due to the greedy data-loader). For Carvana, images are RGB and masks are black and white.
+
+You can use your own dataset as long as you make sure it is loaded properly in `utils/data_loading.py`.
+
+
+---
+
+Original paper by Olaf Ronneberger, Philipp Fischer, Thomas Brox:
+
+[U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597)
+
+![network architecture](https://i.imgur.com/jeDVpqF.png)
diff --git a/evaluate.py b/evaluate.py
new file mode 100644
index 000000000..9a4e3ba2b
--- /dev/null
+++ b/evaluate.py
@@ -0,0 +1,40 @@
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+
+from utils.dice_score import multiclass_dice_coeff, dice_coeff
+
+
+@torch.inference_mode()
+def evaluate(net, dataloader, device, amp):
+    net.eval()
+    num_val_batches = len(dataloader)
+    dice_score = 0
+
+    # iterate over the validation set
+    with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
+        for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
+            image, mask_true = batch['image'], batch['mask']
+
+            # move images and labels to correct device and type
+            image = image.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
+            mask_true = mask_true.to(device=device, dtype=torch.long)
+
+            # predict the mask
+            mask_pred = net(image)
+
+            if net.n_classes == 1:
+                assert mask_true.min() >= 0 and mask_true.max() <= 1, 'True mask indices should be in [0, 1]'
+                mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
+                # compute the Dice score
+                dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
+            else:
+                assert mask_true.min() >= 0 and mask_true.max() < net.n_classes, 'True mask indices should be in [0, n_classes['
+                # convert to one-hot format
+                mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
+                mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
+                # compute the Dice score, ignoring background
+                dice_score += multiclass_dice_coeff(mask_pred[:, 1:], mask_true[:, 1:], reduce_batch_first=False)
+
+    net.train()
+    return dice_score / max(num_val_batches, 1)
diff --git a/predict.py b/predict.py
new file mode 100755
index 000000000..b74c4608d
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,117 @@
+import argparse
+import logging
+import os
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from torchvision import transforms
+
+from utils.data_loading import BasicDataset
+from unet import UNet
+from utils.utils import plot_img_and_mask
+
+
+def predict_img(net,
+                full_img,
+                device,
+                scale_factor=1,
+                out_threshold=0.5):
+    net.eval()
+    img = torch.from_numpy(BasicDataset.preprocess(None, full_img, scale_factor, is_mask=False))
+    img = img.unsqueeze(0)
+    img = img.to(device=device, dtype=torch.float32)
+
+    with torch.no_grad():
+        output = net(img).cpu()
+        output = F.interpolate(output, (full_img.size[1], full_img.size[0]), mode='bilinear')
+        if net.n_classes > 1:
+            mask = output.argmax(dim=1)
+        else:
+            mask = torch.sigmoid(output) > out_threshold
+
+    return mask[0].long().squeeze().numpy()
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='Predict masks from input images')
+    parser.add_argument('--model', '-m', default='MODEL.pth', metavar='FILE',
+                        help='Specify the file in which the model is stored')
+    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', help='Filenames of input images', required=True)
+    parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+', help='Filenames of output images')
+    parser.add_argument('--viz', '-v', action='store_true',
+                        help='Visualize the images as they are processed')
+    parser.add_argument('--no-save', '-n', action='store_true', help='Do not save the output masks')
+    parser.add_argument('--mask-threshold', '-t', type=float, default=0.5,
+                        help='Minimum probability value to consider a mask pixel white')
+    parser.add_argument('--scale', '-s', type=float, default=0.5,
+                        help='Scale factor for the input images')
+    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
+    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')
+
+    return parser.parse_args()
+
+
+def get_output_filenames(args):
+    def _generate_name(fn):
+        return f'{os.path.splitext(fn)[0]}_OUT.png'
+
+    return args.output or list(map(_generate_name, args.input))
+
+
+def mask_to_image(mask: np.ndarray, mask_values):
+    if isinstance(mask_values[0], list):
+        out = np.zeros((mask.shape[-2], mask.shape[-1], len(mask_values[0])), dtype=np.uint8)
+    elif mask_values == [0, 1]:
+        out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=bool)
+    else:
+        out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=np.uint8)
+
+    if mask.ndim == 3:
+        mask = np.argmax(mask, axis=0)
+
+    for i, v in enumerate(mask_values):
+        out[mask == i] = v
+
+    return Image.fromarray(out)
+
+
+if __name__ == '__main__':
+    args = get_args()
+    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
+
+    in_files = args.input
+    out_files = get_output_filenames(args)
+
+    net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    logging.info(f'Loading model {args.model}')
+    logging.info(f'Using device {device}')
+
+    net.to(device=device)
+    state_dict = torch.load(args.model, map_location=device)
+    mask_values = state_dict.pop('mask_values', [0, 1])
+    net.load_state_dict(state_dict)
+
+    logging.info('Model loaded!')
+
+    for i, filename in enumerate(in_files):
+        logging.info(f'Predicting image {filename} ...')
+        img = Image.open(filename)
+
+        mask = predict_img(net=net,
+                           full_img=img,
+                           scale_factor=args.scale,
+                           out_threshold=args.mask_threshold,
+                           device=device)
+
+        if not args.no_save:
+            out_filename = out_files[i]
+            result = mask_to_image(mask, mask_values)
+            result.save(out_filename)
+            logging.info(f'Mask saved to {out_filename}')
+
+        if args.viz:
+            logging.info(f'Visualizing results for image {filename}, close to continue...')
+            plot_img_and_mask(img, mask)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 000000000..256fb596a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+matplotlib==3.6.2
+numpy==1.23.5
+Pillow==9.3.0
+tqdm==4.64.1
+wandb==0.13.5
diff --git a/train.py b/train.py
new file mode 100644
index 000000000..c7d27e2a3
--- /dev/null
+++ b/train.py
@@ -0,0 +1,237 @@
+import argparse
+import logging
+import os
+import random
+import sys
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as TF
+from pathlib import Path
+from torch
import optim +from torch.utils.data import DataLoader, random_split +from tqdm import tqdm + +import wandb +from evaluate import evaluate +from unet import UNet +from utils.data_loading import BasicDataset, CarvanaDataset +from utils.dice_score import dice_loss + +dir_img = Path('./data/imgs/') +dir_mask = Path('./data/masks/') +dir_checkpoint = Path('./checkpoints/') + + +def train_model( + model, + device, + epochs: int = 5, + batch_size: int = 1, + learning_rate: float = 1e-5, + val_percent: float = 0.1, + save_checkpoint: bool = True, + img_scale: float = 0.5, + amp: bool = False, + weight_decay: float = 1e-8, + momentum: float = 0.999, + gradient_clipping: float = 1.0, +): + # 1. Create dataset + try: + dataset = CarvanaDataset(dir_img, dir_mask, img_scale) + except (AssertionError, RuntimeError, IndexError): + dataset = BasicDataset(dir_img, dir_mask, img_scale) + + # 2. Split into train / validation partitions + n_val = int(len(dataset) * val_percent) + n_train = len(dataset) - n_val + train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0)) + + # 3. Create data loaders + loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True) + train_loader = DataLoader(train_set, shuffle=True, **loader_args) + val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args) + + # (Initialize logging) + experiment = wandb.init(project='U-Net', resume='allow', anonymous='must') + experiment.config.update( + dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate, + val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale, amp=amp) + ) + + logging.info(f'''Starting training: + Epochs: {epochs} + Batch size: {batch_size} + Learning rate: {learning_rate} + Training size: {n_train} + Validation size: {n_val} + Checkpoints: {save_checkpoint} + Device: {device.type} + Images scaling: {img_scale} + Mixed Precision: {amp} + ''') + + # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP + optimizer = optim.RMSprop(model.parameters(), + lr=learning_rate, weight_decay=weight_decay, momentum=momentum, foreach=True) + scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5) # goal: maximize Dice score + grad_scaler = torch.cuda.amp.GradScaler(enabled=amp) + criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss() + global_step = 0 + + # 5. Begin training + for epoch in range(1, epochs + 1): + model.train() + epoch_loss = 0 + with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar: + for batch in train_loader: + images, true_masks = batch['image'], batch['mask'] + + assert images.shape[1] == model.n_channels, \ + f'Network has been defined with {model.n_channels} input channels, ' \ + f'but loaded images have {images.shape[1]} channels. Please check that ' \ + 'the images are loaded correctly.' 
+
+                # channels_last matches the memory format the model is converted to in __main__;
+                # masks become int64 class indices as expected by the losses below
+                images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
+                true_masks = true_masks.to(device=device, dtype=torch.long)
+
+                with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
+                    masks_pred = model(images)
+                    if model.n_classes == 1:
+                        loss = criterion(masks_pred.squeeze(1), true_masks.float())
+                        loss += dice_loss(torch.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
+                    else:
+                        loss = criterion(masks_pred, true_masks)
+                        loss += dice_loss(
+                            F.softmax(masks_pred, dim=1).float(),
+                            F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
+                            multiclass=True
+                        )
+
+                optimizer.zero_grad(set_to_none=True)
+                grad_scaler.scale(loss).backward()
+                # unscale first so gradient clipping operates on the true (unscaled) gradient norms
+                grad_scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
+                grad_scaler.step(optimizer)
+                grad_scaler.update()
+
+                pbar.update(images.shape[0])
+                global_step += 1
+                epoch_loss += loss.item()
+                experiment.log({
+                    'train loss': loss.item(),
+                    'step': global_step,
+                    'epoch': epoch
+                })
+                pbar.set_postfix(**{'loss (batch)': loss.item()})
+
+                # Evaluation round
+                division_step = (n_train // (5 * batch_size))
+                if division_step > 0:
+                    if global_step % division_step == 0:
+                        histograms = {}
+                        for tag, value in model.named_parameters():
+                            tag = tag.replace('/', '.')
+                            if not (torch.isinf(value) | torch.isnan(value)).any():
+                                histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
+                            if value.grad is not None and not (torch.isinf(value.grad) | torch.isnan(value.grad)).any():
+                                histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())
+
+                        val_score = evaluate(model, val_loader, device, amp)
+                        scheduler.step(val_score)
+
+                        logging.info('Validation Dice score: {}'.format(val_score))
+                        try:
+                            experiment.log({
+                                'learning rate': optimizer.param_groups[0]['lr'],
+                                'validation Dice': val_score,
+                                'images': wandb.Image(images[0].cpu()),
+                                'masks': {
+                                    'true': wandb.Image(true_masks[0].float().cpu()),
+                                    'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()),
+                                },
+                                'step': global_step,
+                                'epoch': epoch,
+                                **histograms
+                            })
+                        except Exception:
+                            # W&B logging is best-effort and should never abort training
+                            pass
+
+        if save_checkpoint:
+            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
+            state_dict = model.state_dict()
+            state_dict['mask_values'] = dataset.mask_values
+            torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
+            logging.info(f'Checkpoint {epoch} saved!')
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
+    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
+    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
+    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
+                        help='Learning rate', dest='lr')
+    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
+    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
+    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
+                        help='Percent of the data that is used as validation (0-100)')
+    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
+    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
+    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')
+
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = 
get_args() + + logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + logging.info(f'Using device {device}') + + # Change here to adapt to your data + # n_channels=3 for RGB images + # n_classes is the number of probabilities you want to get per pixel + model = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear) + model = model.to(memory_format=torch.channels_last) + + logging.info(f'Network:\n' + f'\t{model.n_channels} input channels\n' + f'\t{model.n_classes} output channels (classes)\n' + f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling') + + if args.load: + state_dict = torch.load(args.load, map_location=device) + del state_dict['mask_values'] + model.load_state_dict(state_dict) + logging.info(f'Model loaded from {args.load}') + + model.to(device=device) + try: + train_model( + model=model, + epochs=args.epochs, + batch_size=args.batch_size, + learning_rate=args.lr, + device=device, + img_scale=args.scale, + val_percent=args.val / 100, + amp=args.amp + ) + except torch.cuda.OutOfMemoryError: + logging.error('Detected OutOfMemoryError! ' + 'Enabling checkpointing to reduce memory usage, but this slows down training. ' + 'Consider enabling AMP (--amp) for fast and memory efficient training') + torch.cuda.empty_cache() + model.use_checkpointing() + train_model( + model=model, + epochs=args.epochs, + batch_size=args.batch_size, + learning_rate=args.lr, + device=device, + img_scale=args.scale, + val_percent=args.val / 100, + amp=args.amp + ) diff --git a/unet/__init__.py b/unet/__init__.py new file mode 100644 index 000000000..2e9b63b67 --- /dev/null +++ b/unet/__init__.py @@ -0,0 +1 @@ +from .unet_model import UNet diff --git a/unet/unet_model.py b/unet/unet_model.py new file mode 100644 index 000000000..caf79f47f --- /dev/null +++ b/unet/unet_model.py @@ -0,0 +1,48 @@ +""" Full assembly of the parts to form the complete network """ + +from .unet_parts import * + + +class UNet(nn.Module): + def __init__(self, n_channels, n_classes, bilinear=False): + super(UNet, self).__init__() + self.n_channels = n_channels + self.n_classes = n_classes + self.bilinear = bilinear + + self.inc = (DoubleConv(n_channels, 64)) + self.down1 = (Down(64, 128)) + self.down2 = (Down(128, 256)) + self.down3 = (Down(256, 512)) + factor = 2 if bilinear else 1 + self.down4 = (Down(512, 1024 // factor)) + self.up1 = (Up(1024, 512 // factor, bilinear)) + self.up2 = (Up(512, 256 // factor, bilinear)) + self.up3 = (Up(256, 128 // factor, bilinear)) + self.up4 = (Up(128, 64, bilinear)) + self.outc = (OutConv(64, n_classes)) + + def forward(self, x): + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + x = self.up1(x5, x4) + x = self.up2(x, x3) + x = self.up3(x, x2) + x = self.up4(x, x1) + logits = self.outc(x) + return logits + + def use_checkpointing(self): + self.inc = torch.utils.checkpoint(self.inc) + self.down1 = torch.utils.checkpoint(self.down1) + self.down2 = torch.utils.checkpoint(self.down2) + self.down3 = torch.utils.checkpoint(self.down3) + self.down4 = torch.utils.checkpoint(self.down4) + self.up1 = torch.utils.checkpoint(self.up1) + self.up2 = torch.utils.checkpoint(self.up2) + self.up3 = torch.utils.checkpoint(self.up3) + self.up4 = torch.utils.checkpoint(self.up4) + self.outc = torch.utils.checkpoint(self.outc) \ No newline at end of file diff --git a/unet/unet_parts.py b/unet/unet_parts.py new file mode 100644 index 
000000000..986ba251f --- /dev/null +++ b/unet/unet_parts.py @@ -0,0 +1,77 @@ +""" Parts of the U-Net model """ + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DoubleConv(nn.Module): + """(convolution => [BN] => ReLU) * 2""" + + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(mid_channels), + nn.ReLU(inplace=True), + nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.double_conv(x) + + +class Down(nn.Module): + """Downscaling with maxpool then double conv""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool2d(2), + DoubleConv(in_channels, out_channels) + ) + + def forward(self, x): + return self.maxpool_conv(x) + + +class Up(nn.Module): + """Upscaling then double conv""" + + def __init__(self, in_channels, out_channels, bilinear=True): + super().__init__() + + # if bilinear, use the normal convolutions to reduce the number of channels + if bilinear: + self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) + else: + self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) + self.conv = DoubleConv(in_channels, out_channels) + + def forward(self, x1, x2): + x1 = self.up(x1) + # input is CHW + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, + diffY // 2, diffY - diffY // 2]) + # if you have padding issues, see + # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a + # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + + +class OutConv(nn.Module): + def __init__(self, in_channels, out_channels): + super(OutConv, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + def forward(self, x): + return self.conv(x) diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/utils/data_loading.py b/utils/data_loading.py new file mode 100644 index 000000000..11296e78b --- /dev/null +++ b/utils/data_loading.py @@ -0,0 +1,117 @@ +import logging +import numpy as np +import torch +from PIL import Image +from functools import lru_cache +from functools import partial +from itertools import repeat +from multiprocessing import Pool +from os import listdir +from os.path import splitext, isfile, join +from pathlib import Path +from torch.utils.data import Dataset +from tqdm import tqdm + + +def load_image(filename): + ext = splitext(filename)[1] + if ext == '.npy': + return Image.fromarray(np.load(filename)) + elif ext in ['.pt', '.pth']: + return Image.fromarray(torch.load(filename).numpy()) + else: + return Image.open(filename) + + +def unique_mask_values(idx, mask_dir, mask_suffix): + mask_file = list(mask_dir.glob(idx + mask_suffix + '.*'))[0] + mask = np.asarray(load_image(mask_file)) + if mask.ndim == 2: + return np.unique(mask) + elif mask.ndim == 3: + mask = mask.reshape(-1, mask.shape[-1]) + return np.unique(mask, axis=0) + 
else: + raise ValueError(f'Loaded masks should have 2 or 3 dimensions, found {mask.ndim}') + + +class BasicDataset(Dataset): + def __init__(self, images_dir: str, mask_dir: str, scale: float = 1.0, mask_suffix: str = ''): + self.images_dir = Path(images_dir) + self.mask_dir = Path(mask_dir) + assert 0 < scale <= 1, 'Scale must be between 0 and 1' + self.scale = scale + self.mask_suffix = mask_suffix + + self.ids = [splitext(file)[0] for file in listdir(images_dir) if isfile(join(images_dir, file)) and not file.startswith('.')] + if not self.ids: + raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there') + + logging.info(f'Creating dataset with {len(self.ids)} examples') + logging.info('Scanning mask files to determine unique values') + with Pool() as p: + unique = list(tqdm( + p.imap(partial(unique_mask_values, mask_dir=self.mask_dir, mask_suffix=self.mask_suffix), self.ids), + total=len(self.ids) + )) + + self.mask_values = list(sorted(np.unique(np.concatenate(unique), axis=0).tolist())) + logging.info(f'Unique mask values: {self.mask_values}') + + def __len__(self): + return len(self.ids) + + @staticmethod + def preprocess(mask_values, pil_img, scale, is_mask): + w, h = pil_img.size + newW, newH = int(scale * w), int(scale * h) + assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel' + pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC) + img = np.asarray(pil_img) + + if is_mask: + mask = np.zeros((newH, newW), dtype=np.int64) + for i, v in enumerate(mask_values): + if img.ndim == 2: + mask[img == v] = i + else: + mask[(img == v).all(-1)] = i + + return mask + + else: + if img.ndim == 2: + img = img[np.newaxis, ...] + else: + img = img.transpose((2, 0, 1)) + + if (img > 1).any(): + img = img / 255.0 + + return img + + def __getitem__(self, idx): + name = self.ids[idx] + mask_file = list(self.mask_dir.glob(name + self.mask_suffix + '.*')) + img_file = list(self.images_dir.glob(name + '.*')) + + assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}' + assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}' + mask = load_image(mask_file[0]) + img = load_image(img_file[0]) + + assert img.size == mask.size, \ + f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}' + + img = self.preprocess(self.mask_values, img, self.scale, is_mask=False) + mask = self.preprocess(self.mask_values, mask, self.scale, is_mask=True) + + return { + 'image': torch.as_tensor(img.copy()).float().contiguous(), + 'mask': torch.as_tensor(mask.copy()).long().contiguous() + } + + +class CarvanaDataset(BasicDataset): + def __init__(self, images_dir, mask_dir, scale=1): + super().__init__(images_dir, mask_dir, scale, mask_suffix='_mask') diff --git a/utils/dice_score.py b/utils/dice_score.py new file mode 100644 index 000000000..c89eebee7 --- /dev/null +++ b/utils/dice_score.py @@ -0,0 +1,28 @@ +import torch +from torch import Tensor + + +def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6): + # Average of Dice coefficient for all batches, or for a single mask + assert input.size() == target.size() + assert input.dim() == 3 or not reduce_batch_first + + sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3) + + inter = 2 * (input * target).sum(dim=sum_dim) + sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim) 
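+    # guard the empty-mask case: if both input and target are all background,
+    # inter == sets_sum == 0 and the epsilon terms below yield a Dice score of 1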
+ sets_sum = torch.where(sets_sum == 0, inter, sets_sum) + + dice = (inter + epsilon) / (sets_sum + epsilon) + return dice.mean() + + +def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6): + # Average of Dice coefficient for all classes + return dice_coeff(input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon) + + +def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False): + # Dice loss (objective to minimize) between 0 and 1 + fn = multiclass_dice_coeff if multiclass else dice_coeff + return 1 - fn(input, target, reduce_batch_first=True) diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 000000000..5e6e04128 --- /dev/null +++ b/utils/utils.py @@ -0,0 +1,13 @@ +import matplotlib.pyplot as plt + + +def plot_img_and_mask(img, mask): + classes = mask.max() + 1 + fig, ax = plt.subplots(1, classes + 1) + ax[0].set_title('Input image') + ax[0].imshow(img) + for i in range(classes): + ax[i + 1].set_title(f'Mask (class {i + 1})') + ax[i + 1].imshow(mask == i) + plt.xticks([]), plt.yticks([]) + plt.show() From 3195294949dfa1967dff6454957d193282ee6c76 Mon Sep 17 00:00:00 2001 From: FrostNov4 <97688822+FrostNov4@users.noreply.github.com> Date: Sun, 15 Oct 2023 17:04:11 +1000 Subject: [PATCH 02/20] update dataloader --- utils/data_loading.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/utils/data_loading.py b/utils/data_loading.py index 11296e78b..186ac00cd 100644 --- a/utils/data_loading.py +++ b/utils/data_loading.py @@ -115,3 +115,8 @@ def __getitem__(self, idx): class CarvanaDataset(BasicDataset): def __init__(self, images_dir, mask_dir, scale=1): super().__init__(images_dir, mask_dir, scale, mask_suffix='_mask') + + +class ISICDataset(BasicDataset): + def __init__(self, images_dir, mask_dir, scale=1): + super().__init__(images_dir, mask_dir, scale, mask_suffix="_segmentation") From cd85663cbd1192d49014ee0891f297df63ad522b Mon Sep 17 00:00:00 2001 From: FrostNov4 <97688822+FrostNov4@users.noreply.github.com> Date: Sun, 15 Oct 2023 17:46:15 +1000 Subject: [PATCH 03/20] Read me update --- README.md | 189 ------------------------------------------------------ 1 file changed, 189 deletions(-) diff --git a/README.md b/README.md index cb4d4c986..e69de29bb 100644 --- a/README.md +++ b/README.md @@ -1,189 +0,0 @@ -# U-Net: Semantic segmentation with PyTorch - - - - - -![input and output for a random image in the test dataset](https://i.imgur.com/GD8FcB7.png) - - -Customized implementation of the [U-Net](https://arxiv.org/abs/1505.04597) in PyTorch for Kaggle's [Carvana Image Masking Challenge](https://www.kaggle.com/c/carvana-image-masking-challenge) from high definition images. - -- [Quick start](#quick-start) - - [Without Docker](#without-docker) - - [With Docker](#with-docker) -- [Description](#description) -- [Usage](#usage) - - [Docker](#docker) - - [Training](#training) - - [Prediction](#prediction) -- [Weights & Biases](#weights--biases) -- [Pretrained model](#pretrained-model) -- [Data](#data) - -## Quick start - -### Without Docker - -1. [Install CUDA](https://developer.nvidia.com/cuda-downloads) - -2. [Install PyTorch 1.13 or later](https://pytorch.org/get-started/locally/) - -3. Install dependencies -```bash -pip install -r requirements.txt -``` - -4. Download the data and run training: -```bash -bash scripts/download_data.sh -python train.py --amp -``` - -### With Docker - -1. 
[Install Docker 19.03 or later:](https://docs.docker.com/get-docker/) -```bash -curl https://get.docker.com | sh && sudo systemctl --now enable docker -``` -2. [Install the NVIDIA container toolkit:](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) -```bash -distribution=$(. /etc/os-release;echo $ID$VERSION_ID) \ - && curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - \ - && curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | sudo tee /etc/apt/sources.list.d/nvidia-docker.list -sudo apt-get update -sudo apt-get install -y nvidia-docker2 -sudo systemctl restart docker -``` -3. [Download and run the image:](https://hub.docker.com/repository/docker/milesial/unet) -```bash -sudo docker run --rm --shm-size=8g --ulimit memlock=-1 --gpus all -it milesial/unet -``` - -4. Download the data and run training: -```bash -bash scripts/download_data.sh -python train.py --amp -``` - -## Description -This model was trained from scratch with 5k images and scored a [Dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) of 0.988423 on over 100k test images. - -It can be easily used for multiclass segmentation, portrait segmentation, medical segmentation, ... - - -## Usage -**Note : Use Python 3.6 or newer** - -### Docker - -A docker image containing the code and the dependencies is available on [DockerHub](https://hub.docker.com/repository/docker/milesial/unet). -You can download and jump in the container with ([docker >=19.03](https://docs.docker.com/get-docker/)): - -```console -docker run -it --rm --shm-size=8g --ulimit memlock=-1 --gpus all milesial/unet -``` - - -### Training - -```console -> python train.py -h -usage: train.py [-h] [--epochs E] [--batch-size B] [--learning-rate LR] - [--load LOAD] [--scale SCALE] [--validation VAL] [--amp] - -Train the UNet on images and target masks - -optional arguments: - -h, --help show this help message and exit - --epochs E, -e E Number of epochs - --batch-size B, -b B Batch size - --learning-rate LR, -l LR - Learning rate - --load LOAD, -f LOAD Load model from a .pth file - --scale SCALE, -s SCALE - Downscaling factor of the images - --validation VAL, -v VAL - Percent of the data that is used as validation (0-100) - --amp Use mixed precision -``` - -By default, the `scale` is 0.5, so if you wish to obtain better results (but use more memory), set it to 1. - -Automatic mixed precision is also available with the `--amp` flag. [Mixed precision](https://arxiv.org/abs/1710.03740) allows the model to use less memory and to be faster on recent GPUs by using FP16 arithmetic. Enabling AMP is recommended. - - -### Prediction - -After training your model and saving it to `MODEL.pth`, you can easily test the output masks on your images via the CLI. - -To predict a single image and save it: - -`python predict.py -i image.jpg -o output.jpg` - -To predict a multiple images and show them without saving them: - -`python predict.py -i image1.jpg image2.jpg --viz --no-save` - -```console -> python predict.py -h -usage: predict.py [-h] [--model FILE] --input INPUT [INPUT ...] - [--output INPUT [INPUT ...]] [--viz] [--no-save] - [--mask-threshold MASK_THRESHOLD] [--scale SCALE] - -Predict masks from input images - -optional arguments: - -h, --help show this help message and exit - --model FILE, -m FILE - Specify the file in which the model is stored - --input INPUT [INPUT ...], -i INPUT [INPUT ...] 
- Filenames of input images - --output INPUT [INPUT ...], -o INPUT [INPUT ...] - Filenames of output images - --viz, -v Visualize the images as they are processed - --no-save, -n Do not save the output masks - --mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD - Minimum probability value to consider a mask pixel white - --scale SCALE, -s SCALE - Scale factor for the input images -``` -You can specify which model file to use with `--model MODEL.pth`. - -## Weights & Biases - -The training progress can be visualized in real-time using [Weights & Biases](https://wandb.ai/). Loss curves, validation curves, weights and gradient histograms, as well as predicted masks are logged to the platform. - -When launching a training, a link will be printed in the console. Click on it to go to your dashboard. If you have an existing W&B account, you can link it - by setting the `WANDB_API_KEY` environment variable. If not, it will create an anonymous run which is automatically deleted after 7 days. - - -## Pretrained model -A [pretrained model](https://github.com/milesial/Pytorch-UNet/releases/tag/v3.0) is available for the Carvana dataset. It can also be loaded from torch.hub: - -```python -net = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana', pretrained=True, scale=0.5) -``` -Available scales are 0.5 and 1.0. - -## Data -The Carvana data is available on the [Kaggle website](https://www.kaggle.com/c/carvana-image-masking-challenge/data). - -You can also download it using the helper script: - -``` -bash scripts/download_data.sh -``` - -The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively (note that the `imgs` and `masks` folder should not contain any sub-folder or any other files, due to the greedy data-loader). For Carvana, images are RGB and masks are black and white. - -You can use your own dataset as long as you make sure it is loaded properly in `utils/data_loading.py`. - - ---- - -Original paper by Olaf Ronneberger, Philipp Fischer, Thomas Brox: - -[U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597) - -![network architecture](https://i.imgur.com/jeDVpqF.png) From 4ff482d4aa43a35b3ea752f44669e41270725a5a Mon Sep 17 00:00:00 2001 From: FrostNov4 <97688822+FrostNov4@users.noreply.github.com> Date: Tue, 17 Oct 2023 23:16:21 +1000 Subject: [PATCH 04/20] Readme updated --- README.md | 16 ++++++++++++++++ train.py | 1 + 2 files changed, 17 insertions(+) diff --git a/README.md b/README.md index e69de29bb..a2129c473 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,16 @@ +COMP_3710_Report + +This report is focus on the first task (a) + +Segment the ISIC data set with the Improved UNet +with all labels having a minimum Dice similarity coefficient of 0.8 on the test set. + +The structure of Improved UNet is based on the paper. +"Brain Tumor Segmentation and Radiomics Survival Prediction: Contribution to the BRATS 2017 Challenge" +https://arxiv.org/abs/1802.10508v1 + + + + + + diff --git a/train.py b/train.py index c7d27e2a3..5793fc0b8 100644 --- a/train.py +++ b/train.py @@ -234,4 +234,5 @@ def get_args(): img_scale=args.scale, val_percent=args.val / 100, amp=args.amp + ) From 5a15fd175b3c476686293269339a7221603b8ae1 Mon Sep 17 00:00:00 2001 From: FrostNov4 <97688822+FrostNov4@users.noreply.github.com> Date: Mon, 23 Oct 2023 00:19:20 +1000 Subject: [PATCH 05/20] Model updated. 
Add ContextModule --- unet/unet_model.py | 87 +++++++++++++++++++++++----------------------- 1 file changed, 44 insertions(+), 43 deletions(-) diff --git a/unet/unet_model.py b/unet/unet_model.py index caf79f47f..83bb4a321 100644 --- a/unet/unet_model.py +++ b/unet/unet_model.py @@ -3,46 +3,47 @@ from .unet_parts import * -class UNet(nn.Module): - def __init__(self, n_channels, n_classes, bilinear=False): - super(UNet, self).__init__() - self.n_channels = n_channels - self.n_classes = n_classes - self.bilinear = bilinear - - self.inc = (DoubleConv(n_channels, 64)) - self.down1 = (Down(64, 128)) - self.down2 = (Down(128, 256)) - self.down3 = (Down(256, 512)) - factor = 2 if bilinear else 1 - self.down4 = (Down(512, 1024 // factor)) - self.up1 = (Up(1024, 512 // factor, bilinear)) - self.up2 = (Up(512, 256 // factor, bilinear)) - self.up3 = (Up(256, 128 // factor, bilinear)) - self.up4 = (Up(128, 64, bilinear)) - self.outc = (OutConv(64, n_classes)) - - def forward(self, x): - x1 = self.inc(x) - x2 = self.down1(x1) - x3 = self.down2(x2) - x4 = self.down3(x3) - x5 = self.down4(x4) - x = self.up1(x5, x4) - x = self.up2(x, x3) - x = self.up3(x, x2) - x = self.up4(x, x1) - logits = self.outc(x) - return logits - - def use_checkpointing(self): - self.inc = torch.utils.checkpoint(self.inc) - self.down1 = torch.utils.checkpoint(self.down1) - self.down2 = torch.utils.checkpoint(self.down2) - self.down3 = torch.utils.checkpoint(self.down3) - self.down4 = torch.utils.checkpoint(self.down4) - self.up1 = torch.utils.checkpoint(self.up1) - self.up2 = torch.utils.checkpoint(self.up2) - self.up3 = torch.utils.checkpoint(self.up3) - self.up4 = torch.utils.checkpoint(self.up4) - self.outc = torch.utils.checkpoint(self.outc) \ No newline at end of file + +class ContextModule(nn.Module): + """ + Context Module: Consists of two convolutional layers for feature extraction and a dropout layer + for regularization, aimed at capturing and preserving the context information in the features. + """ + def __init__(self, in_channels, out_channels): + """ + Initialize the Context Module. + + Parameters: + - in_channels (int): Number of input channels. + - out_channels (int): Number of output channels. + """ + super(ContextModule, self).__init__() + # 2 3x3 convolution layer followed by instance normalization and leaky ReLU activation + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + self.conv2 = nn.Sequential( + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + # Dropout layer to prevent overfitting + self.dropout = nn.Dropout2d(p=0.3) + + def forward(self, x): + """ + Forward pass through the context module. Input is put through 2 3x3 stride 1 convolutions with a dropout + layer in between + + Parameters: + - x (Tensor): The input tensor. + + Returns: + - Tensor: The output tensor after passing through the context module. + """ + x = self.conv1(x) + x = self.dropout(x) + x = self.conv2(x) + return x \ No newline at end of file From 8f5b5ad9b128a2f47110b3fb6d5e733cd20519f2 Mon Sep 17 00:00:00 2001 From: FrostNov4 <97688822+FrostNov4@users.noreply.github.com> Date: Mon, 23 Oct 2023 00:23:29 +1000 Subject: [PATCH 06/20] Model updated. 
Add Segmentation Layer --- unet/unet_model.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/unet/unet_model.py b/unet/unet_model.py index 83bb4a321..fce578625 100644 --- a/unet/unet_model.py +++ b/unet/unet_model.py @@ -46,4 +46,38 @@ def forward(self, x): x = self.conv1(x) x = self.dropout(x) x = self.conv2(x) - return x \ No newline at end of file + return x + + +class SegmentationLayer(nn.Module): + """ + SegmentationLayer: A convolutional layer specifically utilized to generate a segmentation map. + """ + + def __init__(self, in_channels, out_channels): + """ + Initialize the SegmentationLayer. + + Parameters: + - in_channels (int): Number of input channels. + - out_channels (int): Number of output channels, often equal to the number of classes in segmentation. + """ + super(SegmentationLayer, self).__init__() + # A convolutional layer that produces segmentation map + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + """ + Forward pass through the SegmentationLayer. + + Parameters: + - x (Tensor): The input tensor. + + Returns: + - Tensor: The output tensor after applying the convolution, serving as a segmentation map. + """ + # Applying convolution + x = self.conv(x) + + return x + From b203d1627b2368b53bab8d4809988acdf6df8787 Mon Sep 17 00:00:00 2001 From: FrostNov4 <97688822+FrostNov4@users.noreply.github.com> Date: Mon, 23 Oct 2023 00:29:16 +1000 Subject: [PATCH 07/20] Model updated. Add Up sampling Layer --- unet/unet_model.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/unet/unet_model.py b/unet/unet_model.py index fce578625..631d43c95 100644 --- a/unet/unet_model.py +++ b/unet/unet_model.py @@ -78,6 +78,37 @@ def forward(self, x): """ # Applying convolution x = self.conv(x) + return x + + +class UpscalingLayer(nn.Module): + """ + UpscalingLayer: A layer designed to upscale feature maps by a factor of 2. + """ + + def __init__(self, scale_factor=2, mode='nearest'): + """ + Initialize the UpscalingLayer. + + Parameters: + - scale_factor (int, optional): Factor by which to upscale the input. Default is 2. + - mode (str, optional): Algorithm used for upscaling: 'nearest', 'bilinear', etc. Default is 'nearest'. + """ + super(UpscalingLayer, self).__init__() + # An upsampling layer that increases the spatial dimensions of the feature map + self.upsample = nn.Upsample(scale_factor=scale_factor, mode=mode) + def forward(self, x): + """ + Forward pass through the UpscalingLayer. + + Parameters: + - x (Tensor): The input tensor. + + Returns: + - Tensor: The output tensor after applying the upscaling, having increased spatial dimensions. + """ + # Applying upscaling + x = self.upsample(x) return x From 10ef43910aa9180d7f8124d0ab7e796854837067 Mon Sep 17 00:00:00 2001 From: FrostNov4 <97688822+FrostNov4@users.noreply.github.com> Date: Mon, 23 Oct 2023 00:29:38 +1000 Subject: [PATCH 08/20] Model updated. Add UpscalingLayer --- unet/unet_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unet/unet_model.py b/unet/unet_model.py index 631d43c95..cff643707 100644 --- a/unet/unet_model.py +++ b/unet/unet_model.py @@ -86,6 +86,7 @@ class UpscalingLayer(nn.Module): UpscalingLayer: A layer designed to upscale feature maps by a factor of 2. """ + def __init__(self, scale_factor=2, mode='nearest'): """ Initialize the UpscalingLayer. 
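Patches 06 through 08 supply the two building blocks of the Improved U-Net's deep-supervision path: segmentation maps are predicted at several decoder resolutions, upscaled, and summed element-wise, which is how `UNet_For_Brain` wires them together in patch 12. Below is a minimal sketch of that composition, assuming `SegmentationLayer` and `UpscalingLayer` from the patches above are in scope; the channel counts and spatial sizes are purely illustrative. (Note that `unet/__init__.py` still imports the `UNet` class removed in patch 05, so importing through the `unet` package would fail at this point in the series.)

```python
import torch

# illustrative decoder features: batch of 1, at half and full resolution
deep_feats = torch.randn(1, 64, 48, 64)      # deeper stage, half resolution
shallow_feats = torch.randn(1, 32, 96, 128)  # shallower stage, full resolution

seg_deep = SegmentationLayer(64, 2)(deep_feats)        # -> (1, 2, 48, 64)
seg_shallow = SegmentationLayer(32, 2)(shallow_feats)  # -> (1, 2, 96, 128)

# upscale the deeper prediction by 2x, then fuse the logits element-wise
fused = UpscalingLayer()(seg_deep) + seg_shallow       # -> (1, 2, 96, 128)
```

Summing logits from several depths gives the deeper decoder stages a direct gradient path, which is the motivation for deep supervision in the Improved U-Net paper referenced in the README.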
From 77cfa6c16e387c64e53f14d8255f23067aa43544 Mon Sep 17 00:00:00 2001 From: FrostNov4 <97688822+FrostNov4@users.noreply.github.com> Date: Mon, 23 Oct 2023 00:31:18 +1000 Subject: [PATCH 09/20] Update README.md --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index a2129c473..07dab1014 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ COMP_3710_Report -This report is focus on the first task (a) +Medical condition, Extended to 27 OCT + +This report is focused on the first task (a) Segment the ISIC data set with the Improved UNet with all labels having a minimum Dice similarity coefficient of 0.8 on the test set. From d853af72d1ee19b15eb6d840bc60baad522763a6 Mon Sep 17 00:00:00 2001 From: FrostNov4 <97688822+FrostNov4@users.noreply.github.com> Date: Mon, 23 Oct 2023 15:38:11 +1000 Subject: [PATCH 10/20] Model updated. Add Localisation module and Upsampling module --- unet/unet_model.py | 79 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 77 insertions(+), 2 deletions(-) diff --git a/unet/unet_model.py b/unet/unet_model.py index cff643707..85d26b2c8 100644 --- a/unet/unet_model.py +++ b/unet/unet_model.py @@ -3,12 +3,12 @@ from .unet_parts import * - class ContextModule(nn.Module): """ Context Module: Consists of two convolutional layers for feature extraction and a dropout layer for regularization, aimed at capturing and preserving the context information in the features. """ + def __init__(self, in_channels, out_channels): """ Initialize the Context Module. @@ -86,7 +86,6 @@ class UpscalingLayer(nn.Module): UpscalingLayer: A layer designed to upscale feature maps by a factor of 2. """ - def __init__(self, scale_factor=2, mode='nearest'): """ Initialize the UpscalingLayer. @@ -113,3 +112,79 @@ def forward(self, x): x = self.upsample(x) return x + +class LocalisationModule(nn.Module): + """ + Localisation Module: Focused on up-sampling the received feature map and reducing the + number of feature channels, working towards recovering the spatial resolution of the input data. + """ + + def __init__(self, in_channels, out_channels): + """ + Initialize the Localisation Module. + + Parameters: + - in_channels (int): Number of input channels. + - out_channels (int): Number of output channels. + """ + super(LocalisationModule, self).__init__() + # Using a simple upscale by repeating the feature pixels twice + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + # 3x3 convolution to process concatenated features + self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + # 1x1 convolution to reduce the number of feature maps + self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + """ + Forward pass through the localisation module. Input is put through 2 3x3 stride 1 convolutions + with leaky ReLU applied in between + + Parameters: + - x (Tensor): The input tensor. + + Returns: + - Tensor: The output tensor after passing through the localisation module. + """ + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + return x + + +class UpsamplingModule(nn.Module): + """ + Upsampling Module: Handles the up-sampling of feature maps in the decoder part of the UNet, + contributing to incrementing the spatial dimensions of the input feature map. + """ + + def __init__(self, in_channels, out_channels): + """ + Initialize the Upsampling Module. + + Parameters: + - in_channels (int): Number of input channels. 
+ - out_channels (int): Number of output channels. + """ + super(UpsamplingModule, self).__init__() + # Using a simple upscale by repeating the feature pixels twice + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + # 3x3 convolution that halves the number of feature maps + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + """ + Forward pass through the upsampling module. First the input is upsampled, then undergoes stride 1 + 3x3 convolution followed by leaky ReLU. + + Parameters: + - x (Tensor): The input tensor. + + Returns: + - Tensor: The output tensor after passing through the upsampling module. + """ + x = self.upsample(x) + x = self.conv(x) + x = F.relu(x) + return x From 7f542ba46f79b2a1f457054ba9154ffae16448cf Mon Sep 17 00:00:00 2001 From: FrostNov4 <97688822+FrostNov4@users.noreply.github.com> Date: Mon, 23 Oct 2023 16:30:53 +1000 Subject: [PATCH 11/20] Model updated. Add Localisation module and Upsampling module --- unet/unet_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unet/unet_model.py b/unet/unet_model.py index 85d26b2c8..6ca182770 100644 --- a/unet/unet_model.py +++ b/unet/unet_model.py @@ -187,4 +187,5 @@ def forward(self, x): x = self.upsample(x) x = self.conv(x) x = F.relu(x) + return x From b2781c5c439ae1fe9bcb491f4e3120358a1d9b23 Mon Sep 17 00:00:00 2001 From: FrostNov4 <97688822+FrostNov4@users.noreply.github.com> Date: Thu, 26 Oct 2023 02:38:04 +1000 Subject: [PATCH 12/20] Model updated. Add UpscalingLayer --- unet/unet_model.py | 115 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) diff --git a/unet/unet_model.py b/unet/unet_model.py index 6ca182770..0ae879415 100644 --- a/unet/unet_model.py +++ b/unet/unet_model.py @@ -189,3 +189,118 @@ def forward(self, x): x = F.relu(x) return x + + +class UNet_For_Brain(nn.Module): + """ + UNet2D: An Improved U-Net model implmented as the provided Improved U-Net documentation. + """ + + def __init__(self, in_channels, num_classes=2): + """ + Initialize the UNet2D model. + + Parameters: + - in_channels (int): Number of input channels. + - num_classes (int): Number of output classes for segmentation. 
+ """ + super(UNet_For_Brain, self).__init__() + + # Encoder + self.enc1 = nn.Conv2d(in_channels, 16 * 4, kernel_size=3, stride=1, padding=1) + self.context1 = ContextModule(16 * 4, 16 * 4) + self.enc2 = nn.Conv2d(16 * 4, 32 * 4, kernel_size=3, stride=2, padding=1) + self.context2 = ContextModule(32 * 4, 32 * 4) + self.enc3 = nn.Conv2d(32 * 4, 64 * 4, kernel_size=3, stride=2, padding=1) + self.context3 = ContextModule(64 * 4, 64 * 4) + self.enc4 = nn.Conv2d(64 * 4, 128 * 4, kernel_size=3, stride=2, padding=1) + self.context4 = ContextModule(128 * 4, 128 * 4) + + # Bottleneck + self.bottleneck = nn.Conv2d(128 * 4, 256 * 4, kernel_size=3, stride=2, padding=1) + self.bottleneck_context = ContextModule(256 * 4, 256 * 4) + self.up_bottleneck = UpsamplingModule(256 * 4, 128 * 4) + + # Decoder + self.local1 = LocalisationModule(256 * 4, 128 * 4) + self.up1 = UpsamplingModule(128 * 4, 64 * 4) + + self.local2 = LocalisationModule(128 * 4, 64 * 4) + self.up2 = UpsamplingModule(64 * 4, 32 * 4) + + self.seg1 = SegmentationLayer(64 * 4, num_classes) + self.upsample_seg1 = UpscalingLayer() + + self.local3 = LocalisationModule(64 * 4, 32 * 4) + self.up3 = UpsamplingModule(32 * 4, 16 * 4) + + self.seg2 = SegmentationLayer(32 * 4, num_classes) + self.upsample_seg2 = UpscalingLayer() + + self.final_conv = nn.Conv2d(32 * 4, 32 * 4, kernel_size=3, stride=1, padding=1) + + self.seg3 = SegmentationLayer(32 * 4, num_classes) + self.upsample_seg3 = UpscalingLayer() + + def forward(self, x): + """ + Define the forward pass through the UNet2D model. + + Parameters: + - x (Tensor): Input tensor. + + Returns: + - Tensor: The output tensor after passing through the model. + """ + y1 = self.enc1(x) + x1 = self.context1(y1) + x1 = x1 + y1 + + y2 = self.enc2(x1) + x2 = self.context2(y2) + x2 = x2 + y2 + + y3 = self.enc3(x2) + x3 = self.context3(y3) + x3 = x3 + y3 + + y4 = self.enc4(x3) + x4 = self.context4(y4) + x4 = x4 + y4 + + # Bottleneck + bottleneck_conv = self.bottleneck(x4) + + bottleneck = self.bottleneck_context(bottleneck_conv) + bottleneck = bottleneck + bottleneck_conv + + up_bottleneck = self.up_bottleneck(bottleneck) + + # Decoder + x = self.local1(torch.cat((x4, up_bottleneck), dim=1)) + x = self.up1(x) + + x = self.local2(torch.cat((x3, x), dim=1)) + seg1 = self.seg1(x) + x = self.up2(x) + + seg1_upsampled = self.upsample_seg1(seg1) + + x = self.local3(torch.cat((x2, x), dim=1)) + seg2 = self.seg2(x) + x = self.up3(x) + + seg12 = seg1_upsampled + seg2 + seg12_up = self.upsample_seg2(seg12) + + x = self.final_conv(torch.cat((x1, x), dim=1)) + + seg3 = self.seg3(x) + seg123 = seg3 + seg12_up + + out = seg123 + # out = nn.functional.softmax(seg123, dim=1) + # out = torch.sigmoid(seg123) + # print("out shape: ", out.size()) + + return out From ab7a362694d325d473372284cafe38b2ed16f58d Mon Sep 17 00:00:00 2001 From: FrostNov4 <97688822+FrostNov4@users.noreply.github.com> Date: Thu, 26 Oct 2023 03:05:51 +1000 Subject: [PATCH 13/20] Model updated. 
Add UpscalingLayer --- unet/__init__.py | 1 - unet/unet_model.py | 306 ------------------------------------------ unet/unet_parts.py | 77 ----------- utils/__init__.py | 0 utils/data_loading.py | 122 ----------------- utils/dice_score.py | 28 ---- utils/utils.py | 13 -- 7 files changed, 547 deletions(-) delete mode 100644 unet/__init__.py delete mode 100644 unet/unet_model.py delete mode 100644 unet/unet_parts.py delete mode 100644 utils/__init__.py delete mode 100644 utils/data_loading.py delete mode 100644 utils/dice_score.py delete mode 100644 utils/utils.py diff --git a/unet/__init__.py b/unet/__init__.py deleted file mode 100644 index 2e9b63b67..000000000 --- a/unet/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .unet_model import UNet diff --git a/unet/unet_model.py b/unet/unet_model.py deleted file mode 100644 index 0ae879415..000000000 --- a/unet/unet_model.py +++ /dev/null @@ -1,306 +0,0 @@ -""" Full assembly of the parts to form the complete network """ - -from .unet_parts import * - - -class ContextModule(nn.Module): - """ - Context Module: Consists of two convolutional layers for feature extraction and a dropout layer - for regularization, aimed at capturing and preserving the context information in the features. - """ - - def __init__(self, in_channels, out_channels): - """ - Initialize the Context Module. - - Parameters: - - in_channels (int): Number of input channels. - - out_channels (int): Number of output channels. - """ - super(ContextModule, self).__init__() - # 2 3x3 convolution layer followed by instance normalization and leaky ReLU activation - self.conv1 = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True), - ) - self.conv2 = nn.Sequential( - nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True), - ) - # Dropout layer to prevent overfitting - self.dropout = nn.Dropout2d(p=0.3) - - def forward(self, x): - """ - Forward pass through the context module. Input is put through 2 3x3 stride 1 convolutions with a dropout - layer in between - - Parameters: - - x (Tensor): The input tensor. - - Returns: - - Tensor: The output tensor after passing through the context module. - """ - x = self.conv1(x) - x = self.dropout(x) - x = self.conv2(x) - return x - - -class SegmentationLayer(nn.Module): - """ - SegmentationLayer: A convolutional layer specifically utilized to generate a segmentation map. - """ - - def __init__(self, in_channels, out_channels): - """ - Initialize the SegmentationLayer. - - Parameters: - - in_channels (int): Number of input channels. - - out_channels (int): Number of output channels, often equal to the number of classes in segmentation. - """ - super(SegmentationLayer, self).__init__() - # A convolutional layer that produces segmentation map - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x): - """ - Forward pass through the SegmentationLayer. - - Parameters: - - x (Tensor): The input tensor. - - Returns: - - Tensor: The output tensor after applying the convolution, serving as a segmentation map. - """ - # Applying convolution - x = self.conv(x) - return x - - -class UpscalingLayer(nn.Module): - """ - UpscalingLayer: A layer designed to upscale feature maps by a factor of 2. - """ - - def __init__(self, scale_factor=2, mode='nearest'): - """ - Initialize the UpscalingLayer. 
- - Parameters: - - scale_factor (int, optional): Factor by which to upscale the input. Default is 2. - - mode (str, optional): Algorithm used for upscaling: 'nearest', 'bilinear', etc. Default is 'nearest'. - """ - super(UpscalingLayer, self).__init__() - # An upsampling layer that increases the spatial dimensions of the feature map - self.upsample = nn.Upsample(scale_factor=scale_factor, mode=mode) - - def forward(self, x): - """ - Forward pass through the UpscalingLayer. - - Parameters: - - x (Tensor): The input tensor. - - Returns: - - Tensor: The output tensor after applying the upscaling, having increased spatial dimensions. - """ - # Applying upscaling - x = self.upsample(x) - return x - - -class LocalisationModule(nn.Module): - """ - Localisation Module: Focused on up-sampling the received feature map and reducing the - number of feature channels, working towards recovering the spatial resolution of the input data. - """ - - def __init__(self, in_channels, out_channels): - """ - Initialize the Localisation Module. - - Parameters: - - in_channels (int): Number of input channels. - - out_channels (int): Number of output channels. - """ - super(LocalisationModule, self).__init__() - # Using a simple upscale by repeating the feature pixels twice - self.upsample = nn.Upsample(scale_factor=2, mode='nearest') - # 3x3 convolution to process concatenated features - self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) - # 1x1 convolution to reduce the number of feature maps - self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x): - """ - Forward pass through the localisation module. Input is put through 2 3x3 stride 1 convolutions - with leaky ReLU applied in between - - Parameters: - - x (Tensor): The input tensor. - - Returns: - - Tensor: The output tensor after passing through the localisation module. - """ - x = self.conv1(x) - x = F.relu(x) - x = self.conv2(x) - x = F.relu(x) - return x - - -class UpsamplingModule(nn.Module): - """ - Upsampling Module: Handles the up-sampling of feature maps in the decoder part of the UNet, - contributing to incrementing the spatial dimensions of the input feature map. - """ - - def __init__(self, in_channels, out_channels): - """ - Initialize the Upsampling Module. - - Parameters: - - in_channels (int): Number of input channels. - - out_channels (int): Number of output channels. - """ - super(UpsamplingModule, self).__init__() - # Using a simple upscale by repeating the feature pixels twice - self.upsample = nn.Upsample(scale_factor=2, mode='nearest') - # 3x3 convolution that halves the number of feature maps - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x): - """ - Forward pass through the upsampling module. First the input is upsampled, then undergoes stride 1 - 3x3 convolution followed by leaky ReLU. - - Parameters: - - x (Tensor): The input tensor. - - Returns: - - Tensor: The output tensor after passing through the upsampling module. - """ - x = self.upsample(x) - x = self.conv(x) - x = F.relu(x) - - return x - - -class UNet_For_Brain(nn.Module): - """ - UNet2D: An Improved U-Net model implmented as the provided Improved U-Net documentation. - """ - - def __init__(self, in_channels, num_classes=2): - """ - Initialize the UNet2D model. - - Parameters: - - in_channels (int): Number of input channels. - - num_classes (int): Number of output classes for segmentation. 
- """ - super(UNet_For_Brain, self).__init__() - - # Encoder - self.enc1 = nn.Conv2d(in_channels, 16 * 4, kernel_size=3, stride=1, padding=1) - self.context1 = ContextModule(16 * 4, 16 * 4) - self.enc2 = nn.Conv2d(16 * 4, 32 * 4, kernel_size=3, stride=2, padding=1) - self.context2 = ContextModule(32 * 4, 32 * 4) - self.enc3 = nn.Conv2d(32 * 4, 64 * 4, kernel_size=3, stride=2, padding=1) - self.context3 = ContextModule(64 * 4, 64 * 4) - self.enc4 = nn.Conv2d(64 * 4, 128 * 4, kernel_size=3, stride=2, padding=1) - self.context4 = ContextModule(128 * 4, 128 * 4) - - # Bottleneck - self.bottleneck = nn.Conv2d(128 * 4, 256 * 4, kernel_size=3, stride=2, padding=1) - self.bottleneck_context = ContextModule(256 * 4, 256 * 4) - self.up_bottleneck = UpsamplingModule(256 * 4, 128 * 4) - - # Decoder - self.local1 = LocalisationModule(256 * 4, 128 * 4) - self.up1 = UpsamplingModule(128 * 4, 64 * 4) - - self.local2 = LocalisationModule(128 * 4, 64 * 4) - self.up2 = UpsamplingModule(64 * 4, 32 * 4) - - self.seg1 = SegmentationLayer(64 * 4, num_classes) - self.upsample_seg1 = UpscalingLayer() - - self.local3 = LocalisationModule(64 * 4, 32 * 4) - self.up3 = UpsamplingModule(32 * 4, 16 * 4) - - self.seg2 = SegmentationLayer(32 * 4, num_classes) - self.upsample_seg2 = UpscalingLayer() - - self.final_conv = nn.Conv2d(32 * 4, 32 * 4, kernel_size=3, stride=1, padding=1) - - self.seg3 = SegmentationLayer(32 * 4, num_classes) - self.upsample_seg3 = UpscalingLayer() - - def forward(self, x): - """ - Define the forward pass through the UNet2D model. - - Parameters: - - x (Tensor): Input tensor. - - Returns: - - Tensor: The output tensor after passing through the model. - """ - y1 = self.enc1(x) - x1 = self.context1(y1) - x1 = x1 + y1 - - y2 = self.enc2(x1) - x2 = self.context2(y2) - x2 = x2 + y2 - - y3 = self.enc3(x2) - x3 = self.context3(y3) - x3 = x3 + y3 - - y4 = self.enc4(x3) - x4 = self.context4(y4) - x4 = x4 + y4 - - # Bottleneck - bottleneck_conv = self.bottleneck(x4) - - bottleneck = self.bottleneck_context(bottleneck_conv) - bottleneck = bottleneck + bottleneck_conv - - up_bottleneck = self.up_bottleneck(bottleneck) - - # Decoder - x = self.local1(torch.cat((x4, up_bottleneck), dim=1)) - x = self.up1(x) - - x = self.local2(torch.cat((x3, x), dim=1)) - seg1 = self.seg1(x) - x = self.up2(x) - - seg1_upsampled = self.upsample_seg1(seg1) - - x = self.local3(torch.cat((x2, x), dim=1)) - seg2 = self.seg2(x) - x = self.up3(x) - - seg12 = seg1_upsampled + seg2 - seg12_up = self.upsample_seg2(seg12) - - x = self.final_conv(torch.cat((x1, x), dim=1)) - - seg3 = self.seg3(x) - seg123 = seg3 + seg12_up - - out = seg123 - # out = nn.functional.softmax(seg123, dim=1) - # out = torch.sigmoid(seg123) - # print("out shape: ", out.size()) - - return out diff --git a/unet/unet_parts.py b/unet/unet_parts.py deleted file mode 100644 index 986ba251f..000000000 --- a/unet/unet_parts.py +++ /dev/null @@ -1,77 +0,0 @@ -""" Parts of the U-Net model """ - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class DoubleConv(nn.Module): - """(convolution => [BN] => ReLU) * 2""" - - def __init__(self, in_channels, out_channels, mid_channels=None): - super().__init__() - if not mid_channels: - mid_channels = out_channels - self.double_conv = nn.Sequential( - nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(mid_channels), - nn.ReLU(inplace=True), - nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), - nn.BatchNorm2d(out_channels), - 
nn.ReLU(inplace=True) - ) - - def forward(self, x): - return self.double_conv(x) - - -class Down(nn.Module): - """Downscaling with maxpool then double conv""" - - def __init__(self, in_channels, out_channels): - super().__init__() - self.maxpool_conv = nn.Sequential( - nn.MaxPool2d(2), - DoubleConv(in_channels, out_channels) - ) - - def forward(self, x): - return self.maxpool_conv(x) - - -class Up(nn.Module): - """Upscaling then double conv""" - - def __init__(self, in_channels, out_channels, bilinear=True): - super().__init__() - - # if bilinear, use the normal convolutions to reduce the number of channels - if bilinear: - self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) - self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) - else: - self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) - self.conv = DoubleConv(in_channels, out_channels) - - def forward(self, x1, x2): - x1 = self.up(x1) - # input is CHW - diffY = x2.size()[2] - x1.size()[2] - diffX = x2.size()[3] - x1.size()[3] - - x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, - diffY // 2, diffY - diffY // 2]) - # if you have padding issues, see - # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a - # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd - x = torch.cat([x2, x1], dim=1) - return self.conv(x) - - -class OutConv(nn.Module): - def __init__(self, in_channels, out_channels): - super(OutConv, self).__init__() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) - - def forward(self, x): - return self.conv(x) diff --git a/utils/__init__.py b/utils/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/utils/data_loading.py b/utils/data_loading.py deleted file mode 100644 index 186ac00cd..000000000 --- a/utils/data_loading.py +++ /dev/null @@ -1,122 +0,0 @@ -import logging -import numpy as np -import torch -from PIL import Image -from functools import lru_cache -from functools import partial -from itertools import repeat -from multiprocessing import Pool -from os import listdir -from os.path import splitext, isfile, join -from pathlib import Path -from torch.utils.data import Dataset -from tqdm import tqdm - - -def load_image(filename): - ext = splitext(filename)[1] - if ext == '.npy': - return Image.fromarray(np.load(filename)) - elif ext in ['.pt', '.pth']: - return Image.fromarray(torch.load(filename).numpy()) - else: - return Image.open(filename) - - -def unique_mask_values(idx, mask_dir, mask_suffix): - mask_file = list(mask_dir.glob(idx + mask_suffix + '.*'))[0] - mask = np.asarray(load_image(mask_file)) - if mask.ndim == 2: - return np.unique(mask) - elif mask.ndim == 3: - mask = mask.reshape(-1, mask.shape[-1]) - return np.unique(mask, axis=0) - else: - raise ValueError(f'Loaded masks should have 2 or 3 dimensions, found {mask.ndim}') - - -class BasicDataset(Dataset): - def __init__(self, images_dir: str, mask_dir: str, scale: float = 1.0, mask_suffix: str = ''): - self.images_dir = Path(images_dir) - self.mask_dir = Path(mask_dir) - assert 0 < scale <= 1, 'Scale must be between 0 and 1' - self.scale = scale - self.mask_suffix = mask_suffix - - self.ids = [splitext(file)[0] for file in listdir(images_dir) if isfile(join(images_dir, file)) and not file.startswith('.')] - if not self.ids: - raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there') - - logging.info(f'Creating 
dataset with {len(self.ids)} examples') - logging.info('Scanning mask files to determine unique values') - with Pool() as p: - unique = list(tqdm( - p.imap(partial(unique_mask_values, mask_dir=self.mask_dir, mask_suffix=self.mask_suffix), self.ids), - total=len(self.ids) - )) - - self.mask_values = list(sorted(np.unique(np.concatenate(unique), axis=0).tolist())) - logging.info(f'Unique mask values: {self.mask_values}') - - def __len__(self): - return len(self.ids) - - @staticmethod - def preprocess(mask_values, pil_img, scale, is_mask): - w, h = pil_img.size - newW, newH = int(scale * w), int(scale * h) - assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel' - pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC) - img = np.asarray(pil_img) - - if is_mask: - mask = np.zeros((newH, newW), dtype=np.int64) - for i, v in enumerate(mask_values): - if img.ndim == 2: - mask[img == v] = i - else: - mask[(img == v).all(-1)] = i - - return mask - - else: - if img.ndim == 2: - img = img[np.newaxis, ...] - else: - img = img.transpose((2, 0, 1)) - - if (img > 1).any(): - img = img / 255.0 - - return img - - def __getitem__(self, idx): - name = self.ids[idx] - mask_file = list(self.mask_dir.glob(name + self.mask_suffix + '.*')) - img_file = list(self.images_dir.glob(name + '.*')) - - assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}' - assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}' - mask = load_image(mask_file[0]) - img = load_image(img_file[0]) - - assert img.size == mask.size, \ - f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}' - - img = self.preprocess(self.mask_values, img, self.scale, is_mask=False) - mask = self.preprocess(self.mask_values, mask, self.scale, is_mask=True) - - return { - 'image': torch.as_tensor(img.copy()).float().contiguous(), - 'mask': torch.as_tensor(mask.copy()).long().contiguous() - } - - -class CarvanaDataset(BasicDataset): - def __init__(self, images_dir, mask_dir, scale=1): - super().__init__(images_dir, mask_dir, scale, mask_suffix='_mask') - - -class ISICDataset(BasicDataset): - def __init__(self, images_dir, mask_dir, scale=1): - super().__init__(images_dir, mask_dir, scale, mask_suffix="_segmentation") diff --git a/utils/dice_score.py b/utils/dice_score.py deleted file mode 100644 index c89eebee7..000000000 --- a/utils/dice_score.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch -from torch import Tensor - - -def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6): - # Average of Dice coefficient for all batches, or for a single mask - assert input.size() == target.size() - assert input.dim() == 3 or not reduce_batch_first - - sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3) - - inter = 2 * (input * target).sum(dim=sum_dim) - sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim) - sets_sum = torch.where(sets_sum == 0, inter, sets_sum) - - dice = (inter + epsilon) / (sets_sum + epsilon) - return dice.mean() - - -def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6): - # Average of Dice coefficient for all classes - return dice_coeff(input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon) - - -def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False): - # Dice loss (objective to minimize) 
between 0 and 1 - fn = multiclass_dice_coeff if multiclass else dice_coeff - return 1 - fn(input, target, reduce_batch_first=True) diff --git a/utils/utils.py b/utils/utils.py deleted file mode 100644 index 5e6e04128..000000000 --- a/utils/utils.py +++ /dev/null @@ -1,13 +0,0 @@ -import matplotlib.pyplot as plt - - -def plot_img_and_mask(img, mask): - classes = mask.max() + 1 - fig, ax = plt.subplots(1, classes + 1) - ax[0].set_title('Input image') - ax[0].imshow(img) - for i in range(classes): - ax[i + 1].set_title(f'Mask (class {i + 1})') - ax[i + 1].imshow(mask == i) - plt.xticks([]), plt.yticks([]) - plt.show() From 8166c1aa5bc5c175afe5e617af13c2782189ae71 Mon Sep 17 00:00:00 2001 From: FrostNov4 <97688822+FrostNov4@users.noreply.github.com> Date: Thu, 26 Oct 2023 04:02:11 +1000 Subject: [PATCH 14/20] File location update --- .idea/.gitignore | 8 + .idea/PatternAnalysis-2023.iml | 14 + .idea/inspectionProfiles/Project_Default.xml | 41 +++ .../inspectionProfiles/profiles_settings.xml | 6 + .idea/misc.xml | 4 + .idea/modules.xml | 8 + .idea/vcs.xml | 6 + __init__.py | 3 + data_loading.py | 121 +++++++ dice_score.py | 31 ++ unet_model.py | 313 ++++++++++++++++++ utils.py | 18 + 12 files changed, 573 insertions(+) create mode 100644 .idea/.gitignore create mode 100644 .idea/PatternAnalysis-2023.iml create mode 100644 .idea/inspectionProfiles/Project_Default.xml create mode 100644 .idea/inspectionProfiles/profiles_settings.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/vcs.xml create mode 100644 __init__.py create mode 100644 data_loading.py create mode 100644 dice_score.py create mode 100644 unet_model.py create mode 100644 utils.py diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 000000000..13566b81b --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/PatternAnalysis-2023.iml b/.idea/PatternAnalysis-2023.iml new file mode 100644 index 000000000..136187df6 --- /dev/null +++ b/.idea/PatternAnalysis-2023.iml @@ -0,0 +1,14 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 000000000..c2745e30b --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,41 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 000000000..105ce2da2 --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 000000000..a971a2c93 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 000000000..b55eb140b --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 000000000..35eb1ddfb --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/__init__.py b/__init__.py new file mode 100644 index 000000000..2cf6a655c --- /dev/null +++ b/__init__.py @@ -0,0 +1,3 @@ +from .unet_model 
import UNet + + diff --git a/data_loading.py b/data_loading.py new file mode 100644 index 000000000..0ea6027c5 --- /dev/null +++ b/data_loading.py @@ -0,0 +1,121 @@ +import logging +import numpy as np +import torch +from PIL import Image +from functools import lru_cache +from functools import partial +from itertools import repeat +from multiprocessing import Pool +from os import listdir +from os.path import splitext, isfile, join +from pathlib import Path +from torch.utils.data import Dataset +from tqdm import tqdm +# Updated +def load_image(filename): + ext = splitext(filename)[1] + if ext == '.npy': + return Image.fromarray(np.load(filename)) + elif ext in ['.pt', '.pth']: + return Image.fromarray(torch.load(filename).numpy()) + else: + return Image.open(filename) + + +def unique_mask_values(idx, mask_dir, mask_suffix): + mask_file = list(mask_dir.glob(idx + mask_suffix + '.*'))[0] + mask = np.asarray(load_image(mask_file)) + if mask.ndim == 2: + return np.unique(mask) + elif mask.ndim == 3: + mask = mask.reshape(-1, mask.shape[-1]) + return np.unique(mask, axis=0) + else: + raise ValueError(f'Loaded masks should have 2 or 3 dimensions, found {mask.ndim}') + + +class BasicDataset(Dataset): + def __init__(self, images_dir: str, mask_dir: str, scale: float = 1.0, mask_suffix: str = ''): + self.images_dir = Path(images_dir) + self.mask_dir = Path(mask_dir) + assert 0 < scale <= 1, 'Scale must be between 0 and 1' + self.scale = scale + self.mask_suffix = mask_suffix + + self.ids = [splitext(file)[0] for file in listdir(images_dir) if isfile(join(images_dir, file)) and not file.startswith('.')] + if not self.ids: + raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there') + + logging.info(f'Creating dataset with {len(self.ids)} examples') + logging.info('Scanning mask files to determine unique values') + with Pool() as p: + unique = list(tqdm( + p.imap(partial(unique_mask_values, mask_dir=self.mask_dir, mask_suffix=self.mask_suffix), self.ids), + total=len(self.ids) + )) + + self.mask_values = list(sorted(np.unique(np.concatenate(unique), axis=0).tolist())) + logging.info(f'Unique mask values: {self.mask_values}') + + def __len__(self): + return len(self.ids) + + @staticmethod + def preprocess(mask_values, pil_img, scale, is_mask): + w, h = pil_img.size + newW, newH = int(scale * w), int(scale * h) + assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel' + pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC) + img = np.asarray(pil_img) + + if is_mask: + mask = np.zeros((newH, newW), dtype=np.int64) + for i, v in enumerate(mask_values): + if img.ndim == 2: + mask[img == v] = i + else: + mask[(img == v).all(-1)] = i + + return mask + + else: + if img.ndim == 2: + img = img[np.newaxis, ...] 
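+                # a single-channel (grayscale) image gains an explicit channel axis here,
+                # giving the CHW layout PyTorch expects; 8-bit values are scaled to [0, 1] below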
+ else: + img = img.transpose((2, 0, 1)) + + if (img > 1).any(): + img = img / 255.0 + + return img + + def __getitem__(self, idx): + name = self.ids[idx] + mask_file = list(self.mask_dir.glob(name + self.mask_suffix + '.*')) + img_file = list(self.images_dir.glob(name + '.*')) + + assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}' + assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}' + mask = load_image(mask_file[0]) + img = load_image(img_file[0]) + + assert img.size == mask.size, \ + f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}' + + img = self.preprocess(self.mask_values, img, self.scale, is_mask=False) + mask = self.preprocess(self.mask_values, mask, self.scale, is_mask=True) + + return { + 'image': torch.as_tensor(img.copy()).float().contiguous(), + 'mask': torch.as_tensor(mask.copy()).long().contiguous() + } + + +class CarvanaDataset(BasicDataset): + def __init__(self, images_dir, mask_dir, scale=1): + super().__init__(images_dir, mask_dir, scale, mask_suffix='_mask') + + +class ISICDataset(BasicDataset): + def __init__(self, images_dir, mask_dir, scale=1): + super().__init__(images_dir, mask_dir, scale, mask_suffix="_segmentation") diff --git a/dice_score.py b/dice_score.py new file mode 100644 index 000000000..84a219789 --- /dev/null +++ b/dice_score.py @@ -0,0 +1,31 @@ +import torch +from torch import Tensor + + +def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6): + # Average of Dice coefficient for all batches, or for a single mask + assert input.size() == target.size() + assert input.dim() == 3 or not reduce_batch_first + + + sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3) + + inter = 2 * (input * target).sum(dim=sum_dim) + sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim) + sets_sum = torch.where(sets_sum == 0, inter, sets_sum) + + dice = (inter + epsilon) / (sets_sum + epsilon) + + return dice.mean() + +def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6): + # Average of Dice coefficient for all classes + return dice_coeff(input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon) + + +def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False): + # Dice loss (objective to minimize) between 0 and 1 + fn = multiclass_dice_coeff if multiclass else dice_coeff + return 1 - fn(input, target, reduce_batch_first=True) + +# Updated \ No newline at end of file diff --git a/unet_model.py b/unet_model.py new file mode 100644 index 000000000..3e7bf5a7a --- /dev/null +++ b/unet_model.py @@ -0,0 +1,313 @@ +""" Full assembly of the parts to form the complete network """ +from __future__ import print_function, division +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.data +import torch + + + +class ContextModule(nn.Module): + """ + Context Module: Consists of two convolutional layers for feature extraction and a dropout layer + for regularization, aimed at capturing and preserving the context information in the features. + """ + + def __init__(self, in_channels, out_channels): + """ + Initialize the Context Module. + + Parameters: + - in_channels (int): Number of input channels. + - out_channels (int): Number of output channels. 
+ """ + + super(ContextModule, self).__init__() + # 2 3x3 convolution layer followed by instance normalization and leaky ReLU activation + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + self.conv2 = nn.Sequential( + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + # Dropout layer to prevent overfitting + self.dropout = nn.Dropout2d(p=0.3) + + def forward(self, x): + """ + Forward pass through the context module. Input is put through 2 3x3 stride 1 convolutions with a dropout + layer in between + + Parameters: + - x (Tensor): The input tensor. + + Returns: + - Tensor: The output tensor after passing through the context module. + """ + x = self.conv1(x) + x = self.dropout(x) + x = self.conv2(x) + return x + + +class SegmentationLayer(nn.Module): + """ + SegmentationLayer: A convolutional layer specifically utilized to generate a segmentation map. + """ + + def __init__(self, in_channels, out_channels): + """ + Initialize the SegmentationLayer. + + Parameters: + - in_channels (int): Number of input channels. + - out_channels (int): Number of output channels, often equal to the number of classes in segmentation. + """ + super(SegmentationLayer, self).__init__() + # A convolutional layer that produces segmentation map + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + """ + Forward pass through the SegmentationLayer. + + Parameters: + - x (Tensor): The input tensor. + + Returns: + - Tensor: The output tensor after applying the convolution, serving as a segmentation map. + """ + # Applying convolution + x = self.conv(x) + return x + + +class UpscalingLayer(nn.Module): + """ + UpscalingLayer: A layer designed to upscale feature maps by a factor of 2. + """ + + def __init__(self, scale_factor=2, mode='nearest'): + """ + Initialize the UpscalingLayer. + + Parameters: + - scale_factor (int, optional): Factor by which to upscale the input. Default is 2. + - mode (str, optional): Algorithm used for upscaling: 'nearest', 'bilinear', etc. Default is 'nearest'. + """ + super(UpscalingLayer, self).__init__() + # An upsampling layer that increases the spatial dimensions of the feature map + self.upsample = nn.Upsample(scale_factor=scale_factor, mode=mode) + + def forward(self, x): + """ + Forward pass through the UpscalingLayer. + + Parameters: + - x (Tensor): The input tensor. + + Returns: + - Tensor: The output tensor after applying the upscaling, having increased spatial dimensions. + """ + # Applying upscaling + x = self.upsample(x) + return x + + +class LocalisationModule(nn.Module): + """ + Localisation Module: Focused on up-sampling the received feature map and reducing the + number of feature channels, working towards recovering the spatial resolution of the input data. + """ + + def __init__(self, in_channels, out_channels): + """ + Initialize the Localisation Module. + + Parameters: + - in_channels (int): Number of input channels. + - out_channels (int): Number of output channels. 
+ """ + super(LocalisationModule, self).__init__() + # Using a simple upscale by repeating the feature pixels twice + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + # 3x3 convolution to process concatenated features + self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + # 1x1 convolution to reduce the number of feature maps + self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + """ + Forward pass through the localisation module. Input is put through 2 3x3 stride 1 convolutions + with leaky ReLU applied in between + + Parameters: + - x (Tensor): The input tensor. + + Returns: + - Tensor: The output tensor after passing through the localisation module. + """ + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + return x + + +class UpsamplingModule(nn.Module): + """ + Upsampling Module: Handles the up-sampling of feature maps in the decoder part of the UNet, + contributing to incrementing the spatial dimensions of the input feature map. + """ + + def __init__(self, in_channels, out_channels): + """ + Initialize the Upsampling Module. + + Parameters: + - in_channels (int): Number of input channels. + - out_channels (int): Number of output channels. + """ + super(UpsamplingModule, self).__init__() + # Using a simple upscale by repeating the feature pixels twice + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + # 3x3 convolution that halves the number of feature maps + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + """ + Forward pass through the upsampling module. First the input is upsampled, then undergoes stride 1 + 3x3 convolution followed by leaky ReLU. + + Parameters: + - x (Tensor): The input tensor. + + Returns: + - Tensor: The output tensor after passing through the upsampling module. + """ + x = self.upsample(x) + x = self.conv(x) + x = F.relu(x) + + return x + + +class UNet_For_Brain(nn.Module): + """ + UNet2D: An Improved U-Net model implmented as the provided Improved U-Net paper + """ + + def __init__(self, in_channels, num_classes=2): + """ + Initialize the UNet2D model. + + Parameters: + - in_channels (int): Number of input channels. + - num_classes (int): Number of output classes for segmentation. 
+ """ + super(UNet_For_Brain, self).__init__() + + # Encoder + self.enc1 = nn.Conv2d(in_channels, 16 * 4, kernel_size=3, stride=1, padding=1) + self.context1 = ContextModule(16 * 4, 16 * 4) + self.enc2 = nn.Conv2d(16 * 4, 32 * 4, kernel_size=3, stride=2, padding=1) + self.context2 = ContextModule(32 * 4, 32 * 4) + self.enc3 = nn.Conv2d(32 * 4, 64 * 4, kernel_size=3, stride=2, padding=1) + self.context3 = ContextModule(64 * 4, 64 * 4) + self.enc4 = nn.Conv2d(64 * 4, 128 * 4, kernel_size=3, stride=2, padding=1) + self.context4 = ContextModule(128 * 4, 128 * 4) + + # Bottleneck + self.bottleneck = nn.Conv2d(128 * 4, 256 * 4, kernel_size=3, stride=2, padding=1) + self.bottleneck_context = ContextModule(256 * 4, 256 * 4) + self.up_bottleneck = UpsamplingModule(256 * 4, 128 * 4) + + # Decoder + self.local1 = LocalisationModule(256 * 4, 128 * 4) + self.up1 = UpsamplingModule(128 * 4, 64 * 4) + + self.local2 = LocalisationModule(128 * 4, 64 * 4) + self.up2 = UpsamplingModule(64 * 4, 32 * 4) + + self.seg1 = SegmentationLayer(64 * 4, num_classes) + self.upsample_seg1 = UpscalingLayer() + + self.local3 = LocalisationModule(64 * 4, 32 * 4) + self.up3 = UpsamplingModule(32 * 4, 16 * 4) + + self.seg2 = SegmentationLayer(32 * 4, num_classes) + self.upsample_seg2 = UpscalingLayer() + + self.final_conv = nn.Conv2d(32 * 4, 32 * 4, kernel_size=3, stride=1, padding=1) + + self.seg3 = SegmentationLayer(32 * 4, num_classes) + self.upsample_seg3 = UpscalingLayer() + + def forward(self, x): + """ + Define the forward pass through the UNet2D model. + + Parameters: + - x (Tensor): Input tensor. + + Returns: + - Tensor: The output tensor after passing through the model. + """ + y1 = self.enc1(x) + x1 = self.context1(y1) + x1 = x1 + y1 + + y2 = self.enc2(x1) + x2 = self.context2(y2) + x2 = x2 + y2 + + y3 = self.enc3(x2) + x3 = self.context3(y3) + x3 = x3 + y3 + + y4 = self.enc4(x3) + x4 = self.context4(y4) + x4 = x4 + y4 + + # Bottleneck + bottleneck_conv = self.bottleneck(x4) + + bottleneck = self.bottleneck_context(bottleneck_conv) + bottleneck = bottleneck + bottleneck_conv + + up_bottleneck = self.up_bottleneck(bottleneck) + + # Decoder + x = self.local1(torch.cat((x4, up_bottleneck), dim=1)) + x = self.up1(x) + + x = self.local2(torch.cat((x3, x), dim=1)) + seg1 = self.seg1(x) + x = self.up2(x) + + seg1_upsampled = self.upsample_seg1(seg1) + + x = self.local3(torch.cat((x2, x), dim=1)) + seg2 = self.seg2(x) + x = self.up3(x) + + seg12 = seg1_upsampled + seg2 + seg12_up = self.upsample_seg2(seg12) + + x = self.final_conv(torch.cat((x1, x), dim=1)) + + seg3 = self.seg3(x) + seg123 = seg3 + seg12_up + + out = seg123 + # out = nn.functional.softmax(seg123, dim=1) + # out = torch.sigmoid(seg123) + # print("out shape: ", out.size()) + + return out + +# Updated \ No newline at end of file diff --git a/utils.py b/utils.py new file mode 100644 index 000000000..9f3b54559 --- /dev/null +++ b/utils.py @@ -0,0 +1,18 @@ +import matplotlib.pyplot as plt + + +def plot_img_and_mask(img, mask): + classes = mask.max() + 1 + fig, ax = plt.subplots(1, classes + 1) + ax[0].set_title('Input image') + ax[0].imshow(img) + for i in range(classes): + ax[i + 1].set_title(f'Mask (class {i + 1})') + ax[i + 1].imshow(mask == i) + plt.xticks([]), plt.yticks([]) + + + plt.show() + + +# Updated \ No newline at end of file From f6901bb85951e7f1c48f42944c5d2fdd620a73c1 Mon Sep 17 00:00:00 2001 From: FrostNov4 <97688822+FrostNov4@users.noreply.github.com> Date: Thu, 26 Oct 2023 04:32:19 +1000 Subject: [PATCH 15/20] File location update --- 
 Data_Loader.py              | 116 ++++
 Losses.py                   |  61 ++++
 Metrics.py                  |  55 +++
 unet_model.py => Modules.py |   0
 Ploting.py                  | 207 +++++++++++
 __init__.py                 |   2 +-
 data_loading.py             | 121 -------
 train.py                    | 685 +++++++++++++++++++++++-------------
 8 files changed, 890 insertions(+), 357 deletions(-)
 create mode 100644 Data_Loader.py
 create mode 100644 Losses.py
 create mode 100644 Metrics.py
 rename unet_model.py => Modules.py (100%)
 create mode 100644 Ploting.py
 delete mode 100644 data_loading.py

diff --git a/Data_Loader.py b/Data_Loader.py
new file mode 100644
index 000000000..42df1a874
--- /dev/null
+++ b/Data_Loader.py
@@ -0,0 +1,116 @@
+from __future__ import print_function, division
+import os
+from PIL import Image
+import torch
+import torch.utils.data
+import torchvision
+from skimage import io
+from torch.utils.data import Dataset
+import random
+import numpy as np
+
+
+class Images_Dataset(Dataset):
+    """Class for getting data as a Dict
+    Args:
+        images_dir = list of paths of input images
+        labels_dir = list of paths of labeled images
+        transformI = Input Images transformation (default: None)
+        transformM = Input Labels transformation (default: None)
+    Output:
+        sample : Dict of images and labels"""
+
+    def __init__(self, images_dir, labels_dir, transformI=None, transformM=None):
+
+        self.labels_dir = labels_dir
+        self.images_dir = images_dir
+        self.transformI = transformI
+        self.transformM = transformM
+
+    def __len__(self):
+        return len(self.images_dir)
+
+    def __getitem__(self, idx):
+        # load only the requested sample
+        image = io.imread(self.images_dir[idx])
+        label = io.imread(self.labels_dir[idx])
+        if self.transformI:
+            image = self.transformI(image)
+        if self.transformM:
+            label = self.transformM(label)
+        sample = {'images': image, 'labels': label}
+
+        return sample
+
+
+class Images_Dataset_folder(torch.utils.data.Dataset):
+    """Class for getting individual transformations and data
+    Args:
+        images_dir = path of input images
+        labels_dir = path of labeled images
+        transformI = Input Images transformation (default: None)
+        transformM = Input Labels transformation (default: None)
+    Output:
+        tx = Transformed images
+        lx = Transformed labels"""
+
+    def __init__(self, images_dir, labels_dir, transformI=None, transformM=None):
+        self.images = sorted(os.listdir(images_dir))
+        self.labels = sorted(os.listdir(labels_dir))
+        self.images_dir = images_dir
+        self.labels_dir = labels_dir
+        self.transformI = transformI
+        self.transformM = transformM
+
+        if self.transformI:
+            self.tx = self.transformI
+        else:
+            self.tx = torchvision.transforms.Compose([
+                torchvision.transforms.Resize((128, 128)),
+                # torchvision.transforms.CenterCrop(96),
+                # torchvision.transforms.RandomRotation((-10,10)),
+                # torchvision.transforms.RandomHorizontalFlip(),
+                # torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
+                torchvision.transforms.ToTensor(),
+                # torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+            ])
+
+        if self.transformM:
+            self.lx = self.transformM
+        else:
+            self.lx = torchvision.transforms.Compose([
+                torchvision.transforms.Resize((128, 128)),
+                # torchvision.transforms.CenterCrop(96),
+                # torchvision.transforms.RandomRotation((-10,10)),
+                # torchvision.transforms.Grayscale(),
+                torchvision.transforms.ToTensor(),
+                # torchvision.transforms.Lambda(lambda x: torch.cat([x, 1 - x], dim=0))
+            ])
+
+    def __len__(self):
+
+        return len(self.images)
+
+    def __getitem__(self, i):
+        i1 = Image.open(self.images_dir + self.images[i]).convert("RGB")
+        l1 = Image.open(self.labels_dir + self.labels[i]).convert("L")
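+        # The same RNG seed is reused for the image and the label transform pipelines
+        # below, so any random augmentation (rotation, flip, crop) is applied
+        # identically to both; otherwise image and mask would be augmented independently.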
+        # np_i1 = np.array(i1)
+        # np_l1 = np.array(l1)
+        # print(np_i1.shape)
+        # print(np_l1.shape)
+
+        seed = np.random.randint(0, 2 ** 32)  # make a seed with numpy generator
+
+        # apply this seed to the image transforms
+        random.seed(seed)
+        torch.manual_seed(seed)
+        img = self.tx(i1)
+
+        # apply the same seed to the target/label transforms
+        random.seed(seed)
+        torch.manual_seed(seed)
+        label = self.lx(l1)
+
+        return img, label
+
diff --git a/Losses.py b/Losses.py
new file mode 100644
index 000000000..65813cf5c
--- /dev/null
+++ b/Losses.py
@@ -0,0 +1,61 @@
+from __future__ import print_function, division
+import torch
+import torch.nn.functional as F
+
+# Implemented from https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets
+def dice_loss(prediction, target):
+    """Calculating the dice loss
+    Args:
+        prediction = predicted mask (after sigmoid)
+        target = ground-truth mask
+    Output:
+        dice_loss"""
+
+    smooth = 1.0
+
+    i_flat = prediction.view(-1)
+    t_flat = target.view(-1)
+
+    intersection = (i_flat * t_flat).sum()
+
+    return 1 - ((2. * intersection + smooth) / (i_flat.sum() + t_flat.sum() + smooth))
+
+
+def calc_loss(prediction, target, bce_weight=0.5):
+    """Calculating the combined loss
+    Args:
+        prediction = predicted logits
+        target = ground-truth mask
+        bce_weight = weight of the BCE term (default: 0.5)
+    Output:
+        loss : weighted sum of BCE (on logits) and dice loss"""
+    bce = F.binary_cross_entropy_with_logits(prediction, target)
+    prediction = torch.sigmoid(prediction)
+    dice = dice_loss(prediction, target)
+
+    loss = bce * bce_weight + dice * (1 - bce_weight)
+
+    return loss
+
+
+def threshold_predictions_v(predictions, thr=150):
+    thresholded_preds = predictions.copy()  # copy so the caller's array is not modified in place
+    # hist = cv2.calcHist([predictions], [0], None, [2], [0, 2])
+    # plt.plot(hist)
+    # plt.xlim([0, 2])
+    # plt.show()
+    low_values_indices = thresholded_preds < thr
+    thresholded_preds[low_values_indices] = 0
+    low_values_indices = thresholded_preds >= thr
+    thresholded_preds[low_values_indices] = 255
+    return thresholded_preds
+
+
+def threshold_predictions_p(predictions, thr=0.01):
+    thresholded_preds = predictions.copy()  # copy so the caller's array is not modified in place
+    #hist = cv2.calcHist([predictions], [0], None, [256], [0, 256])
+    low_values_indices = thresholded_preds < thr
+    thresholded_preds[low_values_indices] = 0
+    low_values_indices = thresholded_preds >= thr
+    thresholded_preds[low_values_indices] = 1
+    return thresholded_preds
\ No newline at end of file
diff --git a/Metrics.py b/Metrics.py
new file mode 100644
index 000000000..233be8455
--- /dev/null
+++ b/Metrics.py
@@ -0,0 +1,55 @@
+# From https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets
+
+
+import numpy as np
+from scipy import spatial
+
+
+def dice_coeff(im1, im2, empty_score=1.0):
+    """Calculates the dice coefficient for the images"""
+
+    im1 = np.asarray(im1).astype(bool)
+    im2 = np.asarray(im2).astype(bool)
+
+    if im1.shape != im2.shape:
+        raise ValueError("Shape mismatch: im1 and im2 must have the same shape.")
+
+    im1 = im1 > 0.5
+    im2 = im2 > 0.5
+
+    im_sum = im1.sum() + im2.sum()
+    if im_sum == 0:
+        return empty_score
+
+    # Compute Dice coefficient
+    intersection = np.logical_and(im1, im2)
+    #print(im_sum)
+
+    return 2. * intersection.sum() / im_sum
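+# Usage sketch (hypothetical tensors/arrays) tying the two files together: during
+# training, train.py feeds raw logits to calc_loss from Losses.py,
+#
+#     loss = calc_loss(model_test(x), y)   # 0.5*BCE + 0.5*dice by default
+#
+# while at evaluation time binarised numpy masks go through dice_coeff above,
+# e.g. dice_coeff(np.array([1, 1, 0]), np.array([1, 0, 0])) == 2*1/(2+1) = 0.667.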
+
+
+def numeric_score(prediction, groundtruth):
+    """Computes scores:
+    FP = False Positives
+    FN = False Negatives
+    TP = True Positives
+    TN = True Negatives
+    return: FP, FN, TP, TN"""
+
+    FP = float(np.sum((prediction == 1) & (groundtruth == 0)))
+    FN = float(np.sum((prediction == 0) & (groundtruth == 1)))
+    TP = float(np.sum((prediction == 1) & (groundtruth == 1)))
+    TN = float(np.sum((prediction == 0) & (groundtruth == 0)))
+
+    return FP, FN, TP, TN
+
+
+def accuracy_score(prediction, groundtruth):
+    """Despite its name, this returns the Dice/F1 score of the model,
+    2*TP / (FP + 2*TP + FN), rather than the pixel accuracy (TP + TN) / N"""
+
+    FP, FN, TP, TN = numeric_score(prediction, groundtruth)
+    N = FP + FN + TP + TN
+    # accuracy = np.divide(TP + TN, N)
+    print("2*TP: ", 2*TP, "FP+2*TP+FN: ", FP+2*TP+FN)
+    accuracy = np.divide(2*TP, FP+2*TP+FN)
+    return accuracy  #* 100.0
\ No newline at end of file
diff --git a/unet_model.py b/Modules.py
similarity index 100%
rename from unet_model.py
rename to Modules.py
diff --git a/Ploting.py b/Ploting.py
new file mode 100644
index 000000000..76aac1399
--- /dev/null
+++ b/Ploting.py
@@ -0,0 +1,207 @@
+import matplotlib.pyplot as plt
+from matplotlib.lines import Line2D
+import numpy as np
+from visdom import Visdom
+
+def draw_loss(Loss_list, epoch, pic_name):
+    plt.cla()
+    x1 = list(range(len(Loss_list)))  # one point per recorded epoch (was hard-coded to 8)
+    y1 = Loss_list
+    plt.title('Train loss vs. epochs', fontsize=12)
+    plt.plot(x1, y1, '.-')
+    plt.xlabel('epochs', fontsize=10)
+    plt.ylabel('Train loss', fontsize=10)
+    plt.grid()
+    plt.savefig(pic_name)
+    plt.show()
+
+def show_images(images, labels):
+    """Show image with label
+    Args:
+        images = input images
+        labels = input labels
+    Output:
+        plt = concatenated image and label"""
+
+    plt.imshow(images.permute(1, 2, 0))
+    plt.imshow(labels, alpha=0.7, cmap='gray')
+    plt.figure()
+
+
+def show_training_dataset(training_dataset):
+    """Showing the images in the training set for dict images and labels
+    Args:
+        training_dataset = dictionary of images and labels
+    Output:
+        figure = 3 images shown"""
+
+    if training_dataset:
+        print(len(training_dataset))
+
+    for i in range(len(training_dataset)):
+        sample = training_dataset[i]
+
+        print(i, sample['images'].shape, sample['labels'].shape)
+
+        ax = plt.subplot(1, 4, i + 1)
+        plt.tight_layout()
+        ax.set_title('Sample #{}'.format(i))
+        ax.axis('off')
+        show_images(sample['images'], sample['labels'])
+
+        if i == 3:
+            plt.show()
+            break
+
+class VisdomLinePlotter(object):
+
+    """Plots to Visdom"""
+
+    def __init__(self, env_name='main'):
+        self.viz = Visdom()
+        self.env = env_name
+        self.plots = {}
+
+    def plot(self, var_name, split_name, title_name, x, y):
+        if var_name not in self.plots:
+            self.plots[var_name] = self.viz.line(X=np.array([x, x]), Y=np.array([y, y]), env=self.env, opts=dict(
+                legend=[split_name],
+                title=title_name,
+                xlabel='Epochs',
+                ylabel=var_name
+            ))
+        else:
+            self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env, win=self.plots[var_name], name=split_name, update='append')
+
+
+def input_images(x, y, i, n_iter, k=1):
+    """
+    :param x: input image batch
+    :param y: input label batch
+    :param i: the epoch number
+    :param n_iter: the iteration number
+    :param k: flag for keeping the call inside a loop
+    :return: saves a side-by-side figure of one image and its label
+    """
+    if k == 1:
+        x1 = x
+        y1 = y
+
+        x2 = x1.to('cpu')
+        y2 = y1.to('cpu')
+        x2 = x2.detach().numpy()
+        y2 = y2.detach().numpy()
+
+        x3 = x2[1, 1, :, :]
+        y3 = y2[1, 0, :, :]
+
+        fig = plt.figure()
+
+        ax1 = fig.add_subplot(1, 2, 1)
+        ax1.imshow(x3)
+        ax1.axis('off')
+        ax1.set_xticklabels([])
+        ax1.set_yticklabels([])
+        ax1 = fig.add_subplot(1, 2, 2)
+        ax1.imshow(y3)
+        ax1.axis('off')
+        ax1.set_xticklabels([])
+        ax1.set_yticklabels([])
+        plt.savefig(
+            './model/pred/L_' + str(n_iter-1) + '_epoch_'
+            + str(i))
+
+
+def plot_kernels(tensor, n_iter, num_cols=5, cmap="gray"):
+    """Plotting the kernels and layers
+    Args:
+        Tensor : input layer,
+        n_iter : number of iterations,
+        num_cols : number of columns required for the figure
+    Output:
+        Gives a figure of the decided size with the output layer's activation maps
+
+    Default : the last layer will be taken into consideration
+    """
+    if not len(tensor.shape) == 4:
+        raise Exception("assumes a 4D tensor")
+
+    fig = plt.figure()
+    i = 0
+    t = tensor.data.numpy()
+    b = 0
+    a = 1
+
+    for t1 in t:
+        for t2 in t1:
+            i += 1
+
+            ax1 = fig.add_subplot(5, num_cols, i)
+            ax1.imshow(t2, cmap=cmap)
+            ax1.axis('off')
+            ax1.set_xticklabels([])
+            ax1.set_yticklabels([])
+
+            if i == 1:
+                a = 1
+            if a == 10:
+                break
+            a += 1
+        if i % a == 0:
+            a = 0
+        b += 1
+        if b == 20:
+            break
+
+    plt.savefig(
+        './model/pred/Kernal_' + str(n_iter - 1) + '_epoch_'
+        + str(i))
+
+
+class LayerActivations():
+    """Getting the hooks on each layer"""
+
+    features = None
+
+    def __init__(self, layer):
+        self.hook = layer.register_forward_hook(self.hook_fn)
+
+    def hook_fn(self, module, input, output):
+        self.features = output.cpu()
+
+    def remove(self):
+        self.hook.remove()
+
+
+# to get the gradient flow
+# From the PyTorch forums
+def plot_grad_flow(named_parameters, n_iter):
+
+    '''Plots the gradients flowing through the different layers in the net during training.
+    Can be used for checking for possible gradient vanishing / exploding problems.
+
+    Usage: Plug this function into the Trainer class after loss.backward() as
+    "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow'''
+    ave_grads = []
+    max_grads = []
+    layers = []
+    for n, p in named_parameters:
+        if (p.requires_grad) and ("bias" not in n):
+            layers.append(n)
+            ave_grads.append(p.grad.abs().mean())
+            max_grads.append(p.grad.abs().max())
+    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
+    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
+    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
+    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
+    plt.xlim(left=0, right=len(ave_grads))
+    plt.ylim(bottom=-0.001, top=0.02)  # zoom in on the lower gradient regions
+    plt.xlabel("Layers")
+    plt.ylabel("average gradient")
+    plt.title("Gradient flow")
+    plt.grid(True)
+    plt.legend([Line2D([0], [0], color="c", lw=4),
+                Line2D([0], [0], color="b", lw=4),
+                Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
+    #plt.savefig('./model/pred/Grad_Flow_' + str(n_iter - 1))
diff --git a/__init__.py b/__init__.py
index 2cf6a655c..279eecc85 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1,3 +1,3 @@
-from .unet_model import UNet
+from .Modules import UNet
 
 
diff --git a/data_loading.py b/data_loading.py
deleted file mode 100644
index 0ea6027c5..000000000
--- a/data_loading.py
+++ /dev/null
@@ -1,121 +0,0 @@
-import logging
-import numpy as np
-import torch
-from PIL import Image
-from functools import lru_cache
-from functools import partial
-from itertools import repeat
-from multiprocessing import Pool
-from os import listdir
-from os.path import splitext, isfile, join
-from pathlib import Path
-from torch.utils.data import Dataset
-from tqdm import tqdm
-# Updated
-def load_image(filename):
-    ext = splitext(filename)[1]
-    if ext == '.npy':
-        return 
Image.fromarray(np.load(filename)) - elif ext in ['.pt', '.pth']: - return Image.fromarray(torch.load(filename).numpy()) - else: - return Image.open(filename) - - -def unique_mask_values(idx, mask_dir, mask_suffix): - mask_file = list(mask_dir.glob(idx + mask_suffix + '.*'))[0] - mask = np.asarray(load_image(mask_file)) - if mask.ndim == 2: - return np.unique(mask) - elif mask.ndim == 3: - mask = mask.reshape(-1, mask.shape[-1]) - return np.unique(mask, axis=0) - else: - raise ValueError(f'Loaded masks should have 2 or 3 dimensions, found {mask.ndim}') - - -class BasicDataset(Dataset): - def __init__(self, images_dir: str, mask_dir: str, scale: float = 1.0, mask_suffix: str = ''): - self.images_dir = Path(images_dir) - self.mask_dir = Path(mask_dir) - assert 0 < scale <= 1, 'Scale must be between 0 and 1' - self.scale = scale - self.mask_suffix = mask_suffix - - self.ids = [splitext(file)[0] for file in listdir(images_dir) if isfile(join(images_dir, file)) and not file.startswith('.')] - if not self.ids: - raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there') - - logging.info(f'Creating dataset with {len(self.ids)} examples') - logging.info('Scanning mask files to determine unique values') - with Pool() as p: - unique = list(tqdm( - p.imap(partial(unique_mask_values, mask_dir=self.mask_dir, mask_suffix=self.mask_suffix), self.ids), - total=len(self.ids) - )) - - self.mask_values = list(sorted(np.unique(np.concatenate(unique), axis=0).tolist())) - logging.info(f'Unique mask values: {self.mask_values}') - - def __len__(self): - return len(self.ids) - - @staticmethod - def preprocess(mask_values, pil_img, scale, is_mask): - w, h = pil_img.size - newW, newH = int(scale * w), int(scale * h) - assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel' - pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC) - img = np.asarray(pil_img) - - if is_mask: - mask = np.zeros((newH, newW), dtype=np.int64) - for i, v in enumerate(mask_values): - if img.ndim == 2: - mask[img == v] = i - else: - mask[(img == v).all(-1)] = i - - return mask - - else: - if img.ndim == 2: - img = img[np.newaxis, ...] 
- else: - img = img.transpose((2, 0, 1)) - - if (img > 1).any(): - img = img / 255.0 - - return img - - def __getitem__(self, idx): - name = self.ids[idx] - mask_file = list(self.mask_dir.glob(name + self.mask_suffix + '.*')) - img_file = list(self.images_dir.glob(name + '.*')) - - assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}' - assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}' - mask = load_image(mask_file[0]) - img = load_image(img_file[0]) - - assert img.size == mask.size, \ - f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}' - - img = self.preprocess(self.mask_values, img, self.scale, is_mask=False) - mask = self.preprocess(self.mask_values, mask, self.scale, is_mask=True) - - return { - 'image': torch.as_tensor(img.copy()).float().contiguous(), - 'mask': torch.as_tensor(mask.copy()).long().contiguous() - } - - -class CarvanaDataset(BasicDataset): - def __init__(self, images_dir, mask_dir, scale=1): - super().__init__(images_dir, mask_dir, scale, mask_suffix='_mask') - - -class ISICDataset(BasicDataset): - def __init__(self, images_dir, mask_dir, scale=1): - super().__init__(images_dir, mask_dir, scale, mask_suffix="_segmentation") diff --git a/train.py b/train.py index 5793fc0b8..19b0f3ff2 100644 --- a/train.py +++ b/train.py @@ -1,238 +1,453 @@ -import argparse -import logging +from __future__ import print_function, division import os -import random -import sys +import numpy as np +from PIL import Image +import glob +# import SimpleITK as sitk +from torch import optim +import torch.utils.data import torch -import torch.nn as nn import torch.nn.functional as F -import torchvision.transforms as transforms -import torchvision.transforms.functional as TF -from pathlib import Path -from torch import optim -from torch.utils.data import DataLoader, random_split -from tqdm import tqdm - -import wandb -from evaluate import evaluate -from unet import UNet -from utils.data_loading import BasicDataset, CarvanaDataset -from utils.dice_score import dice_loss - -dir_img = Path('./data/imgs/') -dir_mask = Path('./data/masks/') -dir_checkpoint = Path('./checkpoints/') - - -def train_model( - model, - device, - epochs: int = 5, - batch_size: int = 1, - learning_rate: float = 1e-5, - val_percent: float = 0.1, - save_checkpoint: bool = True, - img_scale: float = 0.5, - amp: bool = False, - weight_decay: float = 1e-8, - momentum: float = 0.999, - gradient_clipping: float = 1.0, -): - # 1. Create dataset - try: - dataset = CarvanaDataset(dir_img, dir_mask, img_scale) - except (AssertionError, RuntimeError, IndexError): - dataset = BasicDataset(dir_img, dir_mask, img_scale) - - # 2. Split into train / validation partitions - n_val = int(len(dataset) * val_percent) - n_train = len(dataset) - n_val - train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0)) - - # 3. 
Create data loaders - loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True) - train_loader = DataLoader(train_set, shuffle=True, **loader_args) - val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args) - - # (Initialize logging) - experiment = wandb.init(project='U-Net', resume='allow', anonymous='must') - experiment.config.update( - dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate, - val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale, amp=amp) - ) - - logging.info(f'''Starting training: - Epochs: {epochs} - Batch size: {batch_size} - Learning rate: {learning_rate} - Training size: {n_train} - Validation size: {n_val} - Checkpoints: {save_checkpoint} - Device: {device.type} - Images scaling: {img_scale} - Mixed Precision: {amp} - ''') - - # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP - optimizer = optim.RMSprop(model.parameters(), - lr=learning_rate, weight_decay=weight_decay, momentum=momentum, foreach=True) - scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5) # goal: maximize Dice score - grad_scaler = torch.cuda.amp.GradScaler(enabled=amp) - criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss() - global_step = 0 - - # 5. Begin training - for epoch in range(1, epochs + 1): - model.train() - epoch_loss = 0 - with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar: - for batch in train_loader: - images, true_masks = batch['image'], batch['mask'] - - assert images.shape[1] == model.n_channels, \ - f'Network has been defined with {model.n_channels} input channels, ' \ - f'but loaded images have {images.shape[1]} channels. Please check that ' \ - 'the images are loaded correctly.' 
- - images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last) - true_masks = true_masks.to(device=device, dtype=torch.long) - - with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp): - masks_pred = model(images) - if model.n_classes == 1: - loss = criterion(masks_pred.squeeze(1), true_masks.float()) - loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False) - else: - loss = criterion(masks_pred, true_masks) - loss += dice_loss( - F.softmax(masks_pred, dim=1).float(), - F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(), - multiclass=True - ) - - optimizer.zero_grad(set_to_none=True) - grad_scaler.scale(loss).backward() - torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping) - grad_scaler.step(optimizer) - grad_scaler.update() - - pbar.update(images.shape[0]) - global_step += 1 - epoch_loss += loss.item() - experiment.log({ - 'train loss': loss.item(), - 'step': global_step, - 'epoch': epoch - }) - pbar.set_postfix(**{'loss (batch)': loss.item()}) - - # Evaluation round - division_step = (n_train // (5 * batch_size)) - if division_step > 0: - if global_step % division_step == 0: - histograms = {} - for tag, value in model.named_parameters(): - tag = tag.replace('/', '.') - if not (torch.isinf(value) | torch.isnan(value)).any(): - histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu()) - if not (torch.isinf(value.grad) | torch.isnan(value.grad)).any(): - histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu()) - - val_score = evaluate(model, val_loader, device, amp) - scheduler.step(val_score) - - logging.info('Validation Dice score: {}'.format(val_score)) - try: - experiment.log({ - 'learning rate': optimizer.param_groups[0]['lr'], - 'validation Dice': val_score, - 'images': wandb.Image(images[0].cpu()), - 'masks': { - 'true': wandb.Image(true_masks[0].float().cpu()), - 'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()), - }, - 'step': global_step, - 'epoch': epoch, - **histograms - }) - except: - pass - - if save_checkpoint: - Path(dir_checkpoint).mkdir(parents=True, exist_ok=True) - state_dict = model.state_dict() - state_dict['mask_values'] = dataset.mask_values - torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch))) - logging.info(f'Checkpoint {epoch} saved!') - - -def get_args(): - parser = argparse.ArgumentParser(description='Train the UNet on images and target masks') - parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs') - parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size') - parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5, - help='Learning rate', dest='lr') - parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file') - parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images') - parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0, - help='Percent of the data that is used as validation (0-100)') - parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision') - parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling') - parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes') - - return parser.parse_args() - - -if __name__ == '__main__': - args = 
get_args() - - logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - logging.info(f'Using device {device}') - - # Change here to adapt to your data - # n_channels=3 for RGB images - # n_classes is the number of probabilities you want to get per pixel - model = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear) - model = model.to(memory_format=torch.channels_last) - - logging.info(f'Network:\n' - f'\t{model.n_channels} input channels\n' - f'\t{model.n_classes} output channels (classes)\n' - f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling') - - if args.load: - state_dict = torch.load(args.load, map_location=device) - del state_dict['mask_values'] - model.load_state_dict(state_dict) - logging.info(f'Model loaded from {args.load}') - - model.to(device=device) - try: - train_model( - model=model, - epochs=args.epochs, - batch_size=args.batch_size, - learning_rate=args.lr, - device=device, - img_scale=args.scale, - val_percent=args.val / 100, - amp=args.amp - ) - except torch.cuda.OutOfMemoryError: - logging.error('Detected OutOfMemoryError! ' - 'Enabling checkpointing to reduce memory usage, but this slows down training. ' - 'Consider enabling AMP (--amp) for fast and memory efficient training') - torch.cuda.empty_cache() - model.use_checkpointing() - train_model( - model=model, - epochs=args.epochs, - batch_size=args.batch_size, - learning_rate=args.lr, - device=device, - img_scale=args.scale, - val_percent=args.val / 100, - amp=args.amp - - ) + +import torch.nn +import torchvision +import matplotlib.pyplot as plt +import natsort +from torch.utils.data.sampler import SubsetRandomSampler +from Data_Loader import Images_Dataset, Images_Dataset_folder +import torchsummary +# from torch.utils.tensorboard import SummaryWriter +# from tensorboardX import SummaryWriter + +import shutil +import random +from Modules import UNet_For_Brain +from Losses import calc_loss, dice_loss, threshold_predictions_v, threshold_predictions_p +from Ploting import plot_kernels, LayerActivations, input_images, plot_grad_flow, draw_loss +from Metrics import dice_coeff, accuracy_score +import time + +# from ploting import VisdomLinePlotter +# from visdom import Visdom + + +####################################################### +# Checking if GPU is used +####################################################### + +train_on_gpu = torch.cuda.is_available() + +if not train_on_gpu: + print('CUDA is not available. Training on CPU') +else: + print('CUDA is available. 
Training on GPU')

+# os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3,4'
+device = torch.device("cuda:0" if train_on_gpu else "cpu")
+
+#######################################################
+# Setting the basic parameters of the model
+#######################################################
+
+batch_size = 16
+print('batch_size = ' + str(batch_size))
+
+valid_size = 0.15
+
+epoch = 400
+print('epoch = ' + str(epoch))
+
+random_seed = random.randint(1, 100)
+print('random_seed = ' + str(random_seed))
+
+shuffle = True
+valid_loss_min = np.inf
+num_workers = 24
+lossT = []
+lossL = []
+lossL.append(np.inf)
+lossT.append(np.inf)
+epoch_valid = epoch - 2
+n_iter = 1
+i_valid = 0
+
+pin_memory = False
+if train_on_gpu:
+    pin_memory = True
+
+# plotter = VisdomLinePlotter(env_name='Tutorial Plots')
+
+#######################################################
+# Setting up the model
+#######################################################
+
+model_Inputs = [UNet_For_Brain]
+
+
+def model_unet(model_input, in_channel=1, out_channel=1):
+    model_test = model_input(in_channel, out_channel)
+    return model_test
+
+
+# the wrapper keeps the call below generic, so other architectures (e.g. AttU_Net,
+# R2ATTU_Net) could be slotted into model_Inputs without changing the torchsummary call
+
+
+model_test = model_unet(model_Inputs[-1], 3, 1)
+
+model_test.to(device)
+
+#######################################################
+# Getting the Summary of the Model
+#######################################################
+
+torchsummary.summary(model_test, input_size=(3, 128, 128))
+
+#######################################################
+# Passing the Dataset of Images and Labels
+#######################################################
+
+
+# ISIC2018 data
+t_data = './ISIC2018/ISIC2018_Task1-2_Training_Input_x2/'
+l_data = './ISIC2018/ISIC2018_Task1_Training_GroundTruth_x2/'
+test_image = './keras_png_slices_data/keras_png_slices_test/'
+test_label = './keras_png_slices_data/keras_png_slices_seg_test/'
+test_folderP = './ISIC2018/ISIC2018_Task1-2_Training_Input_x2/*'
+test_folderL = './ISIC2018/ISIC2018_Task1_Training_GroundTruth_x2/*'
+# NOTE: the validation paths below point at the same folders as t_data/l_data,
+# so validation images are drawn from the training set
+valid_image = './ISIC2018/ISIC2018_Task1-2_Training_Input_x2/'
+valid_lable = './ISIC2018/ISIC2018_Task1_Training_GroundTruth_x2/'
+
+Training_Data = Images_Dataset_folder(t_data, l_data)
+
+Validing_Data = Images_Dataset_folder(valid_image, valid_lable)
+
+#######################################################
+# Transformation applied to the input data
+#######################################################
+
+data_transform = torchvision.transforms.Compose([
+    torchvision.transforms.Resize((128, 128)),
+    # torchvision.transforms.CenterCrop(96),
+    torchvision.transforms.ToTensor(),
+    # torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+])
+
+#######################################################
+# Training / Validation Split
+#######################################################
+
+num_train = len(Training_Data)
+indices_train = list(range(num_train))
+# split = int(np.floor(valid_size * num_train))
+
+num_valid = len(Validing_Data)
+indices_valid = list(range(num_valid))
+
+if shuffle:
+    np.random.seed(random_seed)
+    np.random.shuffle(indices_train)
+    np.random.shuffle(indices_valid)
+
+# train_idx, valid_idx = indices[split:], indices[:split]
+train_idx, valid_idx = indices_train, indices_valid
+train_sampler = SubsetRandomSampler(train_idx)
+valid_sampler = SubsetRandomSampler(valid_idx)
+
+train_loader = torch.utils.data.DataLoader(Training_Data, batch_size=batch_size, sampler=train_sampler,
+                                           num_workers=num_workers, pin_memory=pin_memory, )
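+# NOTE: valid_size is defined above but the commented-out split is never applied;
+# a split along the originally intended lines would be
+#     split = int(np.floor(valid_size * num_train))
+#     train_idx, valid_idx = indices_train[split:], indices_train[:split]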
+valid_loader = torch.utils.data.DataLoader(Validing_Data, batch_size=batch_size, sampler=valid_sampler,
+                                           num_workers=num_workers, pin_memory=pin_memory, )
+
+#######################################################
+# Using Adam as the Optimizer
+#######################################################
+# Hyper-parameters
+initial_lr = 5e-4
+lr_decay = 0.985
+l2_weight_decay = 1e-5
+# initial_lr = 0.001
+# opt = torch.optim.Adam(model_test.parameters(), lr=initial_lr)  # try SGD
+# # opt = optim.SGD(model_test.parameters(), lr=initial_lr, momentum=0.99)
+
+# MAX_STEP = int(1e10)
+# eta_min = 1e-2
+# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, MAX_STEP, eta_min=1e-5)
+
+# Optimisation & Loss Settings
+opt = optim.Adam(model_test.parameters(), lr=initial_lr, weight_decay=l2_weight_decay)
+scheduler = optim.lr_scheduler.ExponentialLR(optimizer=opt, gamma=lr_decay)
+
+New_folder = './model'
+
+if os.path.exists(New_folder) and os.path.isdir(New_folder):
+    shutil.rmtree(New_folder)
+
+try:
+    os.mkdir(New_folder)
+except OSError:
+    print("Creation of the main directory '%s' failed " % New_folder)
+else:
+    print("Successfully created the main directory '%s' " % New_folder)
+
+#######################################################
+# Setting the folder for saving the predictions
+#######################################################
+
+read_pred = './model/pred'
+
+#######################################################
+# Checking if the prediction folder exists
+#######################################################
+
+if os.path.exists(read_pred) and os.path.isdir(read_pred):
+    shutil.rmtree(read_pred)
+
+try:
+    os.mkdir(read_pred)
+except OSError:
+    print("Creation of the prediction directory '%s' failed" % read_pred)
+else:
+    print("Successfully created the prediction directory '%s' " % read_pred)
+
+#######################################################
+# Checking if the model folder exists and deleting it if so
+#######################################################
+
+read_model_path = './model/Unet_D_' + str(epoch) + '_' + str(batch_size)
+
+if os.path.exists(read_model_path) and os.path.isdir(read_model_path):
+    shutil.rmtree(read_model_path)
+    print('Model folder there, so deleted for newer one')
+
+try:
+    os.mkdir(read_model_path)
+except OSError:
+    print("Creation of the model directory '%s' failed" % read_model_path)
+else:
+    print("Successfully created the model directory '%s' " % read_model_path)
+
+#######################################################
+# Training loop
+#######################################################
+
+for i in range(epoch):
+
+    train_loss = 0.0
+    valid_loss = 0.0
+    train_loss_list = []
+    valid_loss_list = []
+    since = time.time()
+    # lr = scheduler.get_lr()
+
+    #######################################################
+    # Training Data
+    #######################################################
+
+    model_test.train()
+    k = 1
+
+    for x, y in train_loader:
+        # print("x: ", x.shape)
+        # print("y: ", y.shape)
+        x, y = x.to(device), y.to(device)
+
+        opt.zero_grad()
+
+        y_pred = model_test(x)
+        lossT = calc_loss(y_pred, y)  # BCE + dice loss
+
+        train_loss += lossT.item() * x.size(0)
+        lossT.backward()
+        opt.step()
+        x_size = lossT.item() * x.size(0)
+        k = 2
+
+    # decay the learning rate once per epoch, after the optimiser updates
+    scheduler.step()
+
+    train_loss_list.append(train_loss)
+
+    #######################################################
+    # Validation Step
+    #######################################################
+
+    model_test.eval()
+    torch.no_grad()  # NOTE: a bare torch.no_grad() call is a no-op; only wrapping the loop below in 'with torch.no_grad():' ensures
+New_folder = './model'
+
+if os.path.exists(New_folder) and os.path.isdir(New_folder):
+    shutil.rmtree(New_folder)
+
+try:
+    os.mkdir(New_folder)
+except OSError:
+    print("Creation of the main directory '%s' failed" % New_folder)
+else:
+    print("Successfully created the main directory '%s'" % New_folder)
+
+#######################################################
+# Setting the folder for saving the predictions
+#######################################################
+
+read_pred = './model/pred'
+
+#######################################################
+# Checking if the prediction folder exists
+#######################################################
+
+if os.path.exists(read_pred) and os.path.isdir(read_pred):
+    shutil.rmtree(read_pred)
+
+try:
+    os.mkdir(read_pred)
+except OSError:
+    print("Creation of the prediction directory '%s' failed" % read_pred)
+else:
+    print("Successfully created the prediction directory '%s'" % read_pred)
+
+#######################################################
+# Checking if the model folder exists; if so, delete it
+#######################################################
+
+read_model_path = './model/Unet_D_' + str(epoch) + '_' + str(batch_size)
+
+if os.path.exists(read_model_path) and os.path.isdir(read_model_path):
+    shutil.rmtree(read_model_path)
+    print('Model folder already exists, so it was deleted to make room for a new one')
+
+try:
+    os.mkdir(read_model_path)
+except OSError:
+    print("Creation of the model directory '%s' failed" % read_model_path)
+else:
+    print("Successfully created the model directory '%s'" % read_model_path)
+
+#######################################################
+# Training loop
+#######################################################
+
+for i in range(epoch):
+
+    train_loss = 0.0
+    valid_loss = 0.0
+    train_loss_list = []
+    valid_loss_list = []
+    since = time.time()
+
+    #######################################################
+    # Training Data
+    #######################################################
+
+    model_test.train()
+
+    for x, y in train_loader:
+        # print("x: ", x.shape)
+        # print("y: ", y.shape)
+        x, y = x.to(device), y.to(device)
+
+        opt.zero_grad()
+
+        y_pred = model_test(x)
+        lossT = calc_loss(y_pred, y)  # BCE + Dice loss
+
+        train_loss += lossT.item() * x.size(0)
+        lossT.backward()
+        opt.step()
+
+    # Decay the learning rate once per epoch, after the optimiser updates
+    scheduler.step()
+    # lr = scheduler.get_lr()
+
+    train_loss_list.append(train_loss)
+
+    #######################################################
+    # Validation Step
+    #######################################################
+
+    model_test.eval()
+
+    with torch.no_grad():  # disable gradient tracking so the validation pass uses less memory
+        for x1, y1 in valid_loader:
+            x1, y1 = x1.to(device), y1.to(device)
+
+            y_pred1 = model_test(x1)
+            lossL = calc_loss(y_pred1, y1)  # BCE + Dice loss
+
+            valid_loss += lossL.item() * x1.size(0)
+
+    valid_loss_list.append(valid_loss)
+
+    if (i + 1) % 1 == 0:
+        print('Epoch: {}/{} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(i + 1, epoch, train_loss,
+                                                                                      valid_loss))
+    #######################################################
+    # Early Stopping
+    #######################################################
+
+    if valid_loss <= valid_loss_min and epoch_valid >= i:  # and i_valid <= 2:
+
+        print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model '.format(valid_loss_min, valid_loss))
+        torch.save(model_test.state_dict(), './model/Unet_D_' +
+                   str(epoch) + '_' + str(batch_size) + '/Unet_epoch_' + str(epoch)
+                   + '_batchsize_' + str(batch_size) + '.pth')
+
+        if round(valid_loss, 4) == round(valid_loss_min, 4):
+            print(i_valid)
+            i_valid = i_valid + 1
+        valid_loss_min = valid_loss
+
+if torch.cuda.is_available():
+    torch.cuda.empty_cache()
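One editorial note on the block above: the "Early Stopping" section checkpoints whenever the validation loss reaches a new minimum, but it never breaks out of the loop, so all `epoch` iterations always run. A minimal patience-based sketch of what a halting criterion could look like is below; `patience` and `min_delta` are assumed hyperparameters, not values the original script defines.

```python
# Patience-based early stopping on synthetic losses (standalone sketch).
def should_stop(history, patience=10, min_delta=1e-4):
    """Stop once the last `patience` losses show no improvement over the prior best."""
    if len(history) <= patience:
        return False
    best_before = min(history[:-patience])
    return min(history[-patience:]) > best_before - min_delta

losses = [0.9, 0.7, 0.6, 0.55] + [0.54] * 12
print(should_stop(losses))  # True: no improvement in the last 10 epochs
```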
+#######################################################
+# Loading the model
+#######################################################
+
+model_test.load_state_dict(torch.load('./model/Unet_D_' +
+                                      str(epoch) + '_' + str(batch_size) + '/Unet_epoch_' + str(epoch)
+                                      + '_batchsize_' + str(batch_size) + '.pth'))
+
+model_test.eval()
+
+#######################################################
+# Opening the test folder and creating a folder for generated images
+#######################################################
+
+read_test_folder = glob.glob(test_folderP)
+x_sort_test = natsort.natsorted(read_test_folder)  # To sort
+
+read_test_folder112 = './model/gen_images'
+
+if os.path.exists(read_test_folder112) and os.path.isdir(read_test_folder112):
+    shutil.rmtree(read_test_folder112)
+
+try:
+    os.mkdir(read_test_folder112)
+except OSError:
+    print("Creation of the testing directory %s failed" % read_test_folder112)
+else:
+    print("Successfully created the testing directory %s " % read_test_folder112)
+
+# For Prediction Threshold
+
+read_test_folder_P_Thres = './model/pred_threshold'
+
+if os.path.exists(read_test_folder_P_Thres) and os.path.isdir(read_test_folder_P_Thres):
+    shutil.rmtree(read_test_folder_P_Thres)
+
+try:
+    os.mkdir(read_test_folder_P_Thres)
+except OSError:
+    print("Creation of the testing directory %s failed" % read_test_folder_P_Thres)
+else:
+    print("Successfully created the testing directory %s " % read_test_folder_P_Thres)
+
+# For Label Threshold
+
+read_test_folder_L_Thres = './model/label_threshold'
+
+if os.path.exists(read_test_folder_L_Thres) and os.path.isdir(read_test_folder_L_Thres):
+    shutil.rmtree(read_test_folder_L_Thres)
+
+try:
+    os.mkdir(read_test_folder_L_Thres)
+except OSError:
+    print("Creation of the testing directory %s failed" % read_test_folder_L_Thres)
+else:
+    print("Successfully created the testing directory %s " % read_test_folder_L_Thres)
+
+#######################################################
+# Saving the generated images to the files
+#######################################################
+
+img_test_no = 0
+
+for i in range(len(read_test_folder)):
+    im = Image.open(x_sort_test[i]).convert("RGB")
+
+    s = data_transform(im)
+    pred = model_test(s.unsqueeze(0).cuda()).cpu()
+    pred = torch.sigmoid(pred)  # torch.sigmoid replaces the deprecated F.sigmoid
+    pred = pred.detach().numpy()
+
+    # pred = threshold_predictions_p(pred)  # Value kept 0.01 as max is 1 and noise is very small.
+
+    if i % 24 == 0:
+        img_test_no = img_test_no + 1
+
+    plt.imsave('./model/gen_images/im_epoch_' + str(epoch) + 'int_' + str(i)
+               + '_img_no_' + str(img_test_no) + '.png', pred[0][0], cmap='gray')
+
+####################################################
+# Calculating the Dice Score
+####################################################
+
+data_transform = torchvision.transforms.Compose([
+    torchvision.transforms.Resize((128, 128)),
+    # torchvision.transforms.CenterCrop(96),
+    torchvision.transforms.Grayscale(),
+    # torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+])
+
+read_test_folderP = glob.glob('./model/gen_images/*')
+x_sort_testP = natsort.natsorted(read_test_folderP)
+
+read_test_folderL = glob.glob(test_folderL)
+x_sort_testL = natsort.natsorted(read_test_folderL)  # To sort
+
+dice_score123 = 0.0
+x_count = 0
+x_dice = 0
+
+for i in range(len(read_test_folderP)):
+
+    x = Image.open(x_sort_testP[i]).convert("L")
+    s = data_transform(x)
+    s = np.array(s)
+    s = threshold_predictions_v(s)
+
+    # save the thresholded predictions
+    plt.imsave('./model/pred_threshold/im_epoch_' + str(epoch) + 'int_' + str(i)
+               + '_img_no_' + str(img_test_no) + '.png', s)
+
+    y = Image.open(x_sort_testL[i]).convert("L")
+    s3 = np.array(data_transform(y))
+    s3 = threshold_predictions_v(s3)  # threshold the label itself, not a stale copy
+
+    # save the thresholded labels
+    plt.imsave('./model/label_threshold/im_epoch_' + str(epoch) + 'int_' + str(i)
+               + '_img_no_' + str(img_test_no) + '.png', s3)
+
+    total = dice_coeff(s, s3)
+    print(total)
+
+    if total <= 0.8:
+        x_count += 1
+    if total > 0.8:
+        x_dice = x_dice + total
+    dice_score123 = dice_score123 + total
+
+# print('Dice Score : ' + str(dice_score123/len(read_test_folderP)))
+print(x_count)
+print(x_dice)
+# Average Dice over the images that cleared the 0.8 threshold
+# (max(..., 1) guards against division by zero when none do)
+print('Dice Score : ' + str(float(x_dice / max(len(read_test_folderP) - x_count, 1))))
+

From 4afb3db9e61b166a29a5df029855dd98cdb95af6 Mon Sep 17 00:00:00 2001
From: FrostNov4 <97688822+FrostNov4@users.noreply.github.com>
Date: Thu, 26 Oct 2023 04:36:11 +1000
Subject: [PATCH 16/20] Update training function, including Losses, Metrics and Ploting

---
 Losses.py  | 3 +--
 Metrics.py | 1 +
 Ploting.py | 1 +
 train.py   | 2 ++
 4 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/Losses.py b/Losses.py
index 65813cf5c..afbc7c3f3 100644
--- a/Losses.py
+++ b/Losses.py
@@ -1,7 +1,6 @@
+# From https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets
 from __future__ import print_function, division
 import torch.nn.functional as F
-
-# Implemented from https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets
 def dice_loss(prediction, target):
     """Calculating the dice loss
     Args:
diff --git a/Metrics.py b/Metrics.py
index 233be8455..a0f7c5d43 100644
--- a/Metrics.py
+++ b/Metrics.py
@@ -1,6 +1,7 @@
 # From https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets
+
 import numpy as np
 from scipy import spatial
diff --git a/Ploting.py b/Ploting.py
index 76aac1399..0fded5035 100644
--- a/Ploting.py
+++ b/Ploting.py
@@ -3,6 +3,7 @@
 import numpy as np
 from visdom import Visdom
+
 def draw_loss(Loss_list,epoch,pic_name):
     plt.cla()
     x1 = [i for i in range(8)]
diff --git a/train.py b/train.py
index 19b0f3ff2..e121b06c0 100644
--- a/train.py
+++ b/train.py
@@ -1,3 +1,4 @@
+# From https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets
 from __future__ import print_function, division
 import os
 import numpy as np
@@ -19,6 +20,7 @@
 # from torch.utils.tensorboard import SummaryWriter
 # from tensorboardX import SummaryWriter
+
 import shutil
 import random
 from Modules import UNet_For_Brain
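Since both the training loop above and this patch touch `calc_loss`, a quick smoke test may help readers: per `Losses.py`, `calc_loss` applies BCE-with-logits to the raw network output, then sigmoids it and adds the Dice term, weighting the two equally by default. A minimal sketch, assuming `Losses.py` is on the import path:

```python
# Smoke test of the loss defined in Losses.py.
import torch
from Losses import calc_loss, dice_loss

logits = torch.randn(4, 1, 128, 128)                    # raw model output (no sigmoid yet)
target = torch.randint(0, 2, (4, 1, 128, 128)).float()  # binary ground-truth masks

loss = calc_loss(logits, target)  # 0.5 * BCE + 0.5 * Dice loss by default
print(loss.item())

# A perfect prediction drives the Dice term to 0:
print(dice_loss(target, target).item())  # 0.0
```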
From a08b86b799a0f921c2ee3f74b9a2bea41385a043 Mon Sep 17 00:00:00 2001
From: FrostNov4 <97688822+FrostNov4@users.noreply.github.com>
Date: Thu, 26 Oct 2023 22:04:17 +1000
Subject: [PATCH 17/20] Update README, train

---
 README.md | 40 ++++++++++++++++++++++++++++------------
 train.py  |  4 +---
 2 files changed, 29 insertions(+), 15 deletions(-)

diff --git a/README.md b/README.md
index 07dab1014..5a62afec2 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,31 @@
-COMP_3710_Report
-
-Medical condition, Extended to 27 OCT
-
-This report is focused on the first task (a)
-
-Segment the ISIC data set with the Improved UNet
-with all labels having a minimum Dice similarity coefficient of 0.8 on the test set.
-
-The structure of Improved UNet is based on the paper.
-"Brain Tumor Segmentation and Radiomics Survival Prediction: Contribution to the BRATS 2017 Challenge"
-https://arxiv.org/abs/1802.10508v1
+## COMP_3710_Report
+
+## Medical condition, Extended to 27 OCT
+
+## This report is focused on the first task (a)
+   Segment the ISIC data set with the Improved UNet,
+   with all labels having a minimum Dice similarity coefficient of 0.8 on the test set.
+
+## The structure of the Improved UNet is based on the paper:
+   "Brain Tumor Segmentation and Radiomics Survival Prediction: Contribution to the BRATS 2017 Challenge"
+   https://arxiv.org/abs/1802.10508v1
+
+   U-Net is a convolutional neural network architecture primarily used for biomedical image segmentation.
+   Its U-shaped structure consists of a contracting path, which captures context, and an expansive path,
+   which enables precise localization. Through skip connections, features from the contracting path are
+   concatenated with those on the expansive path, improving localization.
+
+## Dataset
+   This report uses the ISIC 2018 dataset, a publicly available dataset for skin lesion image segmentation
+   provided by the International Skin Imaging Collaboration (ISIC). Because the real-world images in the
+   dataset come in different sizes, they are uniformly resized to 128x128. These images are RGB inputs
+   with three color channels. The label data, which marks where the lesions are, is preprocessed in the
+   same way as the input images, but the labels are loaded as single-channel grayscale images, making
+   them simpler and focused on the lesion's location and shape.
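A minimal sketch of the preprocessing the Dataset section above describes, using torchvision; the dummy images below stand in for an ISIC photo and its ground-truth mask:

```python
# Resize to 128x128: RGB inputs keep 3 channels, masks are single-channel grayscale.
import torchvision.transforms as T
from PIL import Image

image_tx = T.Compose([T.Resize((128, 128)), T.ToTensor()])  # RGB input -> [3, 128, 128]
mask_tx = T.Compose([T.Resize((128, 128)), T.ToTensor()])   # grayscale mask -> [1, 128, 128]

img = image_tx(Image.new('RGB', (767, 1022)))  # placeholder for a lesion photo
mask = mask_tx(Image.new('L', (767, 1022)))    # placeholder for its segmentation mask
print(img.shape, mask.shape)  # torch.Size([3, 128, 128]) torch.Size([1, 128, 128])
```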
+ +## Model + diff --git a/train.py b/train.py index e121b06c0..6d51bfe13 100644 --- a/train.py +++ b/train.py @@ -91,8 +91,6 @@ def model_unet(model_input, in_channel=1, out_channel=1): return model_test -# passsing this string so that if it's AttU_Net or R2ATTU_Net it doesn't throw an error at torchSummary - model_test = model_unet(model_Inputs[-1], 3, 1) @@ -109,7 +107,7 @@ def model_unet(model_input, in_channel=1, out_channel=1): ####################################################### -#ISIC2018 data +# ISIC2018 data t_data = './ISIC2018/ISIC2018_Task1-2_Training_Input_x2/' l_data = './ISIC2018/ISIC2018_Task1_Training_GroundTruth_x2/' test_image = './keras_png_slices_data/keras_png_slices_test/' From 1fd0ff7c59f7db0030d7bd7164884fa34a52ec3d Mon Sep 17 00:00:00 2001 From: FrostNov4 <97688822+FrostNov4@users.noreply.github.com> Date: Thu, 26 Oct 2023 22:12:00 +1000 Subject: [PATCH 18/20] Update, remove files from Main --- Data_Loader.py | 116 ------------ Losses.py | 60 ------- Metrics.py | 56 ------ Modules.py | 313 -------------------------------- Ploting.py | 208 ---------------------- __init__.py | 3 - dice_score.py | 31 ---- evaluate.py | 40 ----- predict.py | 117 ------------ requirements.txt | 5 - train.py | 453 ----------------------------------------------- utils.py | 18 -- 12 files changed, 1420 deletions(-) delete mode 100644 Data_Loader.py delete mode 100644 Losses.py delete mode 100644 Metrics.py delete mode 100644 Modules.py delete mode 100644 Ploting.py delete mode 100644 __init__.py delete mode 100644 dice_score.py delete mode 100644 evaluate.py delete mode 100755 predict.py delete mode 100644 requirements.txt delete mode 100644 train.py delete mode 100644 utils.py diff --git a/Data_Loader.py b/Data_Loader.py deleted file mode 100644 index 42df1a874..000000000 --- a/Data_Loader.py +++ /dev/null @@ -1,116 +0,0 @@ -from __future__ import print_function, division -import os -from PIL import Image -import torch -import torch.utils.data -import torchvision -from skimage import io -from torch.utils.data import Dataset -import random -import numpy as np - - -class Images_Dataset(Dataset): - """Class for getting data as a Dict - Args: - images_dir = path of input images - labels_dir = path of labeled images - transformI = Input Images transformation (default: None) - transformM = Input Labels transformation (default: None) - Output: - sample : Dict of images and labels""" - - def __init__(self, images_dir, labels_dir, transformI=None, transformM=None): - - self.labels_dir = labels_dir - self.images_dir = images_dir - self.transformI = transformI - self.transformM = transformM - - def __len__(self): - return len(self.images_dir) - - def __getitem__(self, idx): - - for i in range(len(self.images_dir)): - image = io.imread(self.images_dir[i]) - label = io.imread(self.labels_dir[i]) - if self.transformI: - image = self.transformI(image) - if self.transformM: - label = self.transformM(label) - sample = {'images': image, 'labels': label} - - return sample - - -class Images_Dataset_folder(torch.utils.data.Dataset): - """Class for getting individual transformations and data - Args: - images_dir = path of input images - labels_dir = path of labeled images - transformI = Input Images transformation (default: None) - transformM = Input Labels transformation (default: None) - Output: - tx = Transformed images - lx = Transformed labels""" - - def __init__(self, images_dir, labels_dir, transformI=None, transformM=None): - self.images = sorted(os.listdir(images_dir)) - self.labels = 
sorted(os.listdir(labels_dir)) - self.images_dir = images_dir - self.labels_dir = labels_dir - self.transformI = transformI - self.transformM = transformM - - if self.transformI: - self.tx = self.transformI - else: - self.tx = torchvision.transforms.Compose([ - torchvision.transforms.Resize((128, 128)), - # torchvision.transforms.CenterCrop(96), - # torchvision.transforms.RandomRotation((-10,10)), - # torchvision.transforms.RandomHorizontalFlip(), - # torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), - torchvision.transforms.ToTensor(), - # torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) - ]) - - if self.transformM: - self.lx = self.transformM - else: - self.lx = torchvision.transforms.Compose([ - torchvision.transforms.Resize((128, 128)), - # torchvision.transforms.CenterCrop(96), - # torchvision.transforms.RandomRotation((-10,10)), - # torchvision.transforms.Grayscale(), - torchvision.transforms.ToTensor(), - # torchvision.transforms.Lambda(lambda x: torch.cat([x, 1 - x], dim=0)) - ]) - - def __len__(self): - - return len(self.images) - - def __getitem__(self, i): - i1 = Image.open(self.images_dir + self.images[i]).convert("RGB") - l1 = Image.open(self.labels_dir + self.labels[i]).convert("L") - # np_i1 = np.array(i1) - # np_l1 = np.array(l1) - # print(np_i1.shape) - # print(np_l1.shape) - - seed = np.random.randint(0, 2 ** 32) # make a seed with numpy generator - - # apply this seed to img tranfsorms - random.seed(seed) - torch.manual_seed(seed) - img = self.tx(i1) - - # apply this seed to target/label tranfsorms - random.seed(seed) - torch.manual_seed(seed) - label = self.lx(l1) - - return img, label - diff --git a/Losses.py b/Losses.py deleted file mode 100644 index afbc7c3f3..000000000 --- a/Losses.py +++ /dev/null @@ -1,60 +0,0 @@ -# From https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets -from __future__ import print_function, division -import torch.nn.functional as F -def dice_loss(prediction, target): - """Calculating the dice loss - Args: - prediction = predicted image - target = Targeted image - Output: - dice_loss""" - - smooth = 1.0 - - i_flat = prediction.view(-1) - t_flat = target.view(-1) - - intersection = (i_flat * t_flat).sum() - - return 1 - ((2. 
* intersection + smooth) / (i_flat.sum() + t_flat.sum() + smooth)) - - -def calc_loss(prediction, target, bce_weight=0.5): - """Calculating the loss and metrics - Args: - prediction = predicted image - target = Targeted image - metrics = Metrics printed - bce_weight = 0.5 (default) - Output: - loss : dice loss of the epoch """ - bce = F.binary_cross_entropy_with_logits(prediction, target) - prediction = F.sigmoid(prediction) - dice = dice_loss(prediction, target) - - loss = bce * bce_weight + dice * (1 - bce_weight) - - return loss - - -def threshold_predictions_v(predictions, thr=150): - thresholded_preds = predictions[:] - # hist = cv2.calcHist([predictions], [0], None, [2], [0, 2]) - # plt.plot(hist) - # plt.xlim([0, 2]) - # plt.show() - low_values_indices = thresholded_preds < thr - thresholded_preds[low_values_indices] = 0 - low_values_indices = thresholded_preds >= thr - thresholded_preds[low_values_indices] = 255 - return thresholded_preds - - -def threshold_predictions_p(predictions, thr=0.01): - thresholded_preds = predictions[:] - #hist = cv2.calcHist([predictions], [0], None, [256], [0, 256]) - low_values_indices = thresholded_preds < thr - thresholded_preds[low_values_indices] = 0 - low_values_indices = thresholded_preds >= thr - thresholded_preds[low_values_indices] = 1 - return thresholded_preds \ No newline at end of file diff --git a/Metrics.py b/Metrics.py deleted file mode 100644 index a0f7c5d43..000000000 --- a/Metrics.py +++ /dev/null @@ -1,56 +0,0 @@ -# From https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets - - - -import numpy as np -from scipy import spatial - - -def dice_coeff(im1, im2, empty_score=1.0): - """Calculates the dice coefficient for the images""" - - im1 = np.asarray(im1).astype(np.bool) - im2 = np.asarray(im2).astype(np.bool) - - if im1.shape != im2.shape: - raise ValueError("Shape mismatch: im1 and im2 must have the same shape.") - - im1 = im1 > 0.5 - im2 = im2 > 0.5 - - im_sum = im1.sum() + im2.sum() - if im_sum == 0: - return empty_score - - # Compute Dice coefficient - intersection = np.logical_and(im1, im2) - #print(im_sum) - - return 2. 
* intersection.sum() / im_sum - - -def numeric_score(prediction, groundtruth): - """Computes scores: - FP = False Positives - FN = False Negatives - TP = True Positives - TN = True Negatives - return: FP, FN, TP, TN""" - - FP = np.float(np.sum((prediction == 1) & (groundtruth == 0))) - FN = np.float(np.sum((prediction == 0) & (groundtruth == 1))) - TP = np.float(np.sum((prediction == 1) & (groundtruth == 1))) - TN = np.float(np.sum((prediction == 0) & (groundtruth == 0))) - - return FP, FN, TP, TN - - -def accuracy_score(prediction, groundtruth): - """Getting the accuracy of the model""" - - FP, FN, TP, TN = numeric_score(prediction, groundtruth) - N = FP + FN + TP + TN - # accuracy = np.divide(TP + TN, N) - print("2*TP: ", 2*TP, "FP+2*TP+FN: ", FP+2*TP+FN) - accuracy = np.divide(2*TP, FP+2*TP+FN) - return accuracy #* 100.0 \ No newline at end of file diff --git a/Modules.py b/Modules.py deleted file mode 100644 index 3e7bf5a7a..000000000 --- a/Modules.py +++ /dev/null @@ -1,313 +0,0 @@ -""" Full assembly of the parts to form the complete network """ -from __future__ import print_function, division -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.data -import torch - - - -class ContextModule(nn.Module): - """ - Context Module: Consists of two convolutional layers for feature extraction and a dropout layer - for regularization, aimed at capturing and preserving the context information in the features. - """ - - def __init__(self, in_channels, out_channels): - """ - Initialize the Context Module. - - Parameters: - - in_channels (int): Number of input channels. - - out_channels (int): Number of output channels. - """ - - super(ContextModule, self).__init__() - # 2 3x3 convolution layer followed by instance normalization and leaky ReLU activation - self.conv1 = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True), - ) - self.conv2 = nn.Sequential( - nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), - nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True), - ) - # Dropout layer to prevent overfitting - self.dropout = nn.Dropout2d(p=0.3) - - def forward(self, x): - """ - Forward pass through the context module. Input is put through 2 3x3 stride 1 convolutions with a dropout - layer in between - - Parameters: - - x (Tensor): The input tensor. - - Returns: - - Tensor: The output tensor after passing through the context module. - """ - x = self.conv1(x) - x = self.dropout(x) - x = self.conv2(x) - return x - - -class SegmentationLayer(nn.Module): - """ - SegmentationLayer: A convolutional layer specifically utilized to generate a segmentation map. - """ - - def __init__(self, in_channels, out_channels): - """ - Initialize the SegmentationLayer. - - Parameters: - - in_channels (int): Number of input channels. - - out_channels (int): Number of output channels, often equal to the number of classes in segmentation. - """ - super(SegmentationLayer, self).__init__() - # A convolutional layer that produces segmentation map - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x): - """ - Forward pass through the SegmentationLayer. - - Parameters: - - x (Tensor): The input tensor. - - Returns: - - Tensor: The output tensor after applying the convolution, serving as a segmentation map. 
- """ - # Applying convolution - x = self.conv(x) - return x - - -class UpscalingLayer(nn.Module): - """ - UpscalingLayer: A layer designed to upscale feature maps by a factor of 2. - """ - - def __init__(self, scale_factor=2, mode='nearest'): - """ - Initialize the UpscalingLayer. - - Parameters: - - scale_factor (int, optional): Factor by which to upscale the input. Default is 2. - - mode (str, optional): Algorithm used for upscaling: 'nearest', 'bilinear', etc. Default is 'nearest'. - """ - super(UpscalingLayer, self).__init__() - # An upsampling layer that increases the spatial dimensions of the feature map - self.upsample = nn.Upsample(scale_factor=scale_factor, mode=mode) - - def forward(self, x): - """ - Forward pass through the UpscalingLayer. - - Parameters: - - x (Tensor): The input tensor. - - Returns: - - Tensor: The output tensor after applying the upscaling, having increased spatial dimensions. - """ - # Applying upscaling - x = self.upsample(x) - return x - - -class LocalisationModule(nn.Module): - """ - Localisation Module: Focused on up-sampling the received feature map and reducing the - number of feature channels, working towards recovering the spatial resolution of the input data. - """ - - def __init__(self, in_channels, out_channels): - """ - Initialize the Localisation Module. - - Parameters: - - in_channels (int): Number of input channels. - - out_channels (int): Number of output channels. - """ - super(LocalisationModule, self).__init__() - # Using a simple upscale by repeating the feature pixels twice - self.upsample = nn.Upsample(scale_factor=2, mode='nearest') - # 3x3 convolution to process concatenated features - self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) - # 1x1 convolution to reduce the number of feature maps - self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x): - """ - Forward pass through the localisation module. Input is put through 2 3x3 stride 1 convolutions - with leaky ReLU applied in between - - Parameters: - - x (Tensor): The input tensor. - - Returns: - - Tensor: The output tensor after passing through the localisation module. - """ - x = self.conv1(x) - x = F.relu(x) - x = self.conv2(x) - x = F.relu(x) - return x - - -class UpsamplingModule(nn.Module): - """ - Upsampling Module: Handles the up-sampling of feature maps in the decoder part of the UNet, - contributing to incrementing the spatial dimensions of the input feature map. - """ - - def __init__(self, in_channels, out_channels): - """ - Initialize the Upsampling Module. - - Parameters: - - in_channels (int): Number of input channels. - - out_channels (int): Number of output channels. - """ - super(UpsamplingModule, self).__init__() - # Using a simple upscale by repeating the feature pixels twice - self.upsample = nn.Upsample(scale_factor=2, mode='nearest') - # 3x3 convolution that halves the number of feature maps - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x): - """ - Forward pass through the upsampling module. First the input is upsampled, then undergoes stride 1 - 3x3 convolution followed by leaky ReLU. - - Parameters: - - x (Tensor): The input tensor. - - Returns: - - Tensor: The output tensor after passing through the upsampling module. 
- """ - x = self.upsample(x) - x = self.conv(x) - x = F.relu(x) - - return x - - -class UNet_For_Brain(nn.Module): - """ - UNet2D: An Improved U-Net model implmented as the provided Improved U-Net paper - """ - - def __init__(self, in_channels, num_classes=2): - """ - Initialize the UNet2D model. - - Parameters: - - in_channels (int): Number of input channels. - - num_classes (int): Number of output classes for segmentation. - """ - super(UNet_For_Brain, self).__init__() - - # Encoder - self.enc1 = nn.Conv2d(in_channels, 16 * 4, kernel_size=3, stride=1, padding=1) - self.context1 = ContextModule(16 * 4, 16 * 4) - self.enc2 = nn.Conv2d(16 * 4, 32 * 4, kernel_size=3, stride=2, padding=1) - self.context2 = ContextModule(32 * 4, 32 * 4) - self.enc3 = nn.Conv2d(32 * 4, 64 * 4, kernel_size=3, stride=2, padding=1) - self.context3 = ContextModule(64 * 4, 64 * 4) - self.enc4 = nn.Conv2d(64 * 4, 128 * 4, kernel_size=3, stride=2, padding=1) - self.context4 = ContextModule(128 * 4, 128 * 4) - - # Bottleneck - self.bottleneck = nn.Conv2d(128 * 4, 256 * 4, kernel_size=3, stride=2, padding=1) - self.bottleneck_context = ContextModule(256 * 4, 256 * 4) - self.up_bottleneck = UpsamplingModule(256 * 4, 128 * 4) - - # Decoder - self.local1 = LocalisationModule(256 * 4, 128 * 4) - self.up1 = UpsamplingModule(128 * 4, 64 * 4) - - self.local2 = LocalisationModule(128 * 4, 64 * 4) - self.up2 = UpsamplingModule(64 * 4, 32 * 4) - - self.seg1 = SegmentationLayer(64 * 4, num_classes) - self.upsample_seg1 = UpscalingLayer() - - self.local3 = LocalisationModule(64 * 4, 32 * 4) - self.up3 = UpsamplingModule(32 * 4, 16 * 4) - - self.seg2 = SegmentationLayer(32 * 4, num_classes) - self.upsample_seg2 = UpscalingLayer() - - self.final_conv = nn.Conv2d(32 * 4, 32 * 4, kernel_size=3, stride=1, padding=1) - - self.seg3 = SegmentationLayer(32 * 4, num_classes) - self.upsample_seg3 = UpscalingLayer() - - def forward(self, x): - """ - Define the forward pass through the UNet2D model. - - Parameters: - - x (Tensor): Input tensor. - - Returns: - - Tensor: The output tensor after passing through the model. 
- """ - y1 = self.enc1(x) - x1 = self.context1(y1) - x1 = x1 + y1 - - y2 = self.enc2(x1) - x2 = self.context2(y2) - x2 = x2 + y2 - - y3 = self.enc3(x2) - x3 = self.context3(y3) - x3 = x3 + y3 - - y4 = self.enc4(x3) - x4 = self.context4(y4) - x4 = x4 + y4 - - # Bottleneck - bottleneck_conv = self.bottleneck(x4) - - bottleneck = self.bottleneck_context(bottleneck_conv) - bottleneck = bottleneck + bottleneck_conv - - up_bottleneck = self.up_bottleneck(bottleneck) - - # Decoder - x = self.local1(torch.cat((x4, up_bottleneck), dim=1)) - x = self.up1(x) - - x = self.local2(torch.cat((x3, x), dim=1)) - seg1 = self.seg1(x) - x = self.up2(x) - - seg1_upsampled = self.upsample_seg1(seg1) - - x = self.local3(torch.cat((x2, x), dim=1)) - seg2 = self.seg2(x) - x = self.up3(x) - - seg12 = seg1_upsampled + seg2 - seg12_up = self.upsample_seg2(seg12) - - x = self.final_conv(torch.cat((x1, x), dim=1)) - - seg3 = self.seg3(x) - seg123 = seg3 + seg12_up - - out = seg123 - # out = nn.functional.softmax(seg123, dim=1) - # out = torch.sigmoid(seg123) - # print("out shape: ", out.size()) - - return out - -# Updated \ No newline at end of file diff --git a/Ploting.py b/Ploting.py deleted file mode 100644 index 0fded5035..000000000 --- a/Ploting.py +++ /dev/null @@ -1,208 +0,0 @@ -import matplotlib.pyplot as plt -from matplotlib.lines import Line2D -import numpy as np -from visdom import Visdom - - -def draw_loss(Loss_list,epoch,pic_name): - plt.cla() - x1 = [i for i in range(8)] - y1 = Loss_list - plt.title('Train loss vs. epoches', fontsize=1) - plt.plot(x1, y1, '.-') - plt.xlabel('epoches', fontsize=1) - plt.ylabel('Train loss', fontsize=1) - plt.grid() - plt.savefig(pic_name) - plt.show() - -def show_images(images, labels): - """Show image with label - Args: - images = input images - labels = input labels - Output: - plt = concatenated image and label """ - - plt.imshow(images.permute(1, 2, 0)) - plt.imshow(labels, alpha=0.7, cmap='gray') - plt.figure() - - -def show_training_dataset(training_dataset): - """Showing the images in training set for dict images and labels - Args: - training_dataset = dictionary of images and labels - Output: - figure = 3 images shown""" - - if training_dataset: - print(len(training_dataset)) - - for i in range(len(training_dataset)): - sample = training_dataset[i] - - print(i, sample['images'].shape, sample['labels'].shape) - - ax = plt.subplot(1, 4, i + 1) - plt.tight_layout() - ax.set_title('Sample #{}'.format(i)) - ax.axis('off') - show_images(sample['images'],sample['labels']) - - if i == 3: - plt.show() - break - -class VisdomLinePlotter(object): - - """Plots to Visdom""" - - def __init__(self, env_name='main'): - self.viz = Visdom() - self.env = env_name - self.plots = {} - - def plot(self, var_name, split_name, title_name, x, y): - if var_name not in self.plots: - self.plots[var_name] = self.viz.line(X=np.array([x,x]), Y=np.array([y,y]), env=self.env, opts=dict( - legend=[split_name], - title=title_name, - xlabel='Epochs', - ylabel=var_name - )) - else: - self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env, win=self.plots[var_name], name=split_name, update = 'append') - - -def input_images(x, y, i, n_iter, k=1): - """ - - :param x: takes input image - :param y: take input label - :param i: the epoch number - :param n_iter: - :param k: for keeping it in loop - :return: Returns a image and label - """ - if k == 1: - x1 = x - y1 = y - - x2 = x1.to('cpu') - y2 = y1.to('cpu') - x2 = x2.detach().numpy() - y2 = y2.detach().numpy() - - x3 = x2[1, 1, :, :] - y3 = y2[1, 0, 
:, :] - - fig = plt.figure() - - ax1 = fig.add_subplot(1, 2, 1) - ax1.imshow(x3) - ax1.axis('off') - ax1.set_xticklabels([]) - ax1.set_yticklabels([]) - ax1 = fig.add_subplot(1, 2, 2) - ax1.imshow(y3) - ax1.axis('off') - ax1.set_xticklabels([]) - ax1.set_yticklabels([]) - plt.savefig( - './model/pred/L_' + str(n_iter-1) + '_epoch_' - + str(i)) - - -def plot_kernels(tensor, n_iter, num_cols=5, cmap="gray"): - """Plotting the kernals and layers - Args: - Tensor :Input layer, - n_iter : number of interation, - num_cols : number of columbs required for figure - Output: - Gives the figure of the size decided with output layers activation map - - Default : Last layer will be taken into consideration - """ - if not len(tensor.shape) == 4: - raise Exception("assumes a 4D tensor") - - fig = plt.figure() - i = 0 - t = tensor.data.numpy() - b = 0 - a = 1 - - for t1 in t: - for t2 in t1: - i += 1 - - ax1 = fig.add_subplot(5, num_cols, i) - ax1.imshow(t2, cmap=cmap) - ax1.axis('off') - ax1.set_xticklabels([]) - ax1.set_yticklabels([]) - - if i == 1: - a = 1 - if a == 10: - break - a += 1 - if i % a == 0: - a = 0 - b += 1 - if b == 20: - break - - plt.savefig( - './model/pred/Kernal_' + str(n_iter - 1) + '_epoch_' - + str(i)) - - -class LayerActivations(): - """Getting the hooks on each layer""" - - features = None - - def __init__(self, layer): - self.hook = layer.register_forward_hook(self.hook_fn) - - def hook_fn(self, module, input, output): - self.features = output.cpu() - - def remove(self): - self.hook.remove() - - -#to get gradient flow -#From Pytorch-forums -def plot_grad_flow(named_parameters,n_iter): - - '''Plots the gradients flowing through different layers in the net during training. - Can be used for checking for possible gradient vanishing / exploding problems. 
- - Usage: Plug this function in Trainer class after loss.backwards() as - "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow''' - ave_grads = [] - max_grads = [] - layers = [] - for n, p in named_parameters: - if (p.requires_grad) and ("bias" not in n): - layers.append(n) - ave_grads.append(p.grad.abs().mean()) - max_grads.append(p.grad.abs().max()) - plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c") - plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b") - plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k") - plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical") - plt.xlim(left=0, right=len(ave_grads)) - plt.ylim(bottom=-0.001, top=0.02) # zoom in on the lower gradient regions - plt.xlabel("Layers") - plt.ylabel("average gradient") - plt.title("Gradient flow") - plt.grid(True) - plt.legend([Line2D([0], [0], color="c", lw=4), - Line2D([0], [0], color="b", lw=4), - Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient']) - #plt.savefig('./model/pred/Grad_Flow_' + str(n_iter - 1)) diff --git a/__init__.py b/__init__.py deleted file mode 100644 index 279eecc85..000000000 --- a/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .Modules import UNet - - diff --git a/dice_score.py b/dice_score.py deleted file mode 100644 index 84a219789..000000000 --- a/dice_score.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -from torch import Tensor - - -def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6): - # Average of Dice coefficient for all batches, or for a single mask - assert input.size() == target.size() - assert input.dim() == 3 or not reduce_batch_first - - - sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3) - - inter = 2 * (input * target).sum(dim=sum_dim) - sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim) - sets_sum = torch.where(sets_sum == 0, inter, sets_sum) - - dice = (inter + epsilon) / (sets_sum + epsilon) - - return dice.mean() - -def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6): - # Average of Dice coefficient for all classes - return dice_coeff(input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon) - - -def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False): - # Dice loss (objective to minimize) between 0 and 1 - fn = multiclass_dice_coeff if multiclass else dice_coeff - return 1 - fn(input, target, reduce_batch_first=True) - -# Updated \ No newline at end of file diff --git a/evaluate.py b/evaluate.py deleted file mode 100644 index 9a4e3ba2b..000000000 --- a/evaluate.py +++ /dev/null @@ -1,40 +0,0 @@ -import torch -import torch.nn.functional as F -from tqdm import tqdm - -from utils.dice_score import multiclass_dice_coeff, dice_coeff - - -@torch.inference_mode() -def evaluate(net, dataloader, device, amp): - net.eval() - num_val_batches = len(dataloader) - dice_score = 0 - - # iterate over the validation set - with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp): - for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False): - image, mask_true = batch['image'], batch['mask'] - - # move images and labels to correct device and type - image = image.to(device=device, dtype=torch.float32, memory_format=torch.channels_last) - mask_true = mask_true.to(device=device, dtype=torch.long) - - # predict the mask - 
mask_pred = net(image) - - if net.n_classes == 1: - assert mask_true.min() >= 0 and mask_true.max() <= 1, 'True mask indices should be in [0, 1]' - mask_pred = (F.sigmoid(mask_pred) > 0.5).float() - # compute the Dice score - dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False) - else: - assert mask_true.min() >= 0 and mask_true.max() < net.n_classes, 'True mask indices should be in [0, n_classes[' - # convert to one-hot format - mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float() - mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float() - # compute the Dice score, ignoring background - dice_score += multiclass_dice_coeff(mask_pred[:, 1:], mask_true[:, 1:], reduce_batch_first=False) - - net.train() - return dice_score / max(num_val_batches, 1) diff --git a/predict.py b/predict.py deleted file mode 100755 index b74c4608d..000000000 --- a/predict.py +++ /dev/null @@ -1,117 +0,0 @@ -import argparse -import logging -import os - -import numpy as np -import torch -import torch.nn.functional as F -from PIL import Image -from torchvision import transforms - -from utils.data_loading import BasicDataset -from unet import UNet -from utils.utils import plot_img_and_mask - -def predict_img(net, - full_img, - device, - scale_factor=1, - out_threshold=0.5): - net.eval() - img = torch.from_numpy(BasicDataset.preprocess(None, full_img, scale_factor, is_mask=False)) - img = img.unsqueeze(0) - img = img.to(device=device, dtype=torch.float32) - - with torch.no_grad(): - output = net(img).cpu() - output = F.interpolate(output, (full_img.size[1], full_img.size[0]), mode='bilinear') - if net.n_classes > 1: - mask = output.argmax(dim=1) - else: - mask = torch.sigmoid(output) > out_threshold - - return mask[0].long().squeeze().numpy() - - -def get_args(): - parser = argparse.ArgumentParser(description='Predict masks from input images') - parser.add_argument('--model', '-m', default='MODEL.pth', metavar='FILE', - help='Specify the file in which the model is stored') - parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', help='Filenames of input images', required=True) - parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+', help='Filenames of output images') - parser.add_argument('--viz', '-v', action='store_true', - help='Visualize the images as they are processed') - parser.add_argument('--no-save', '-n', action='store_true', help='Do not save the output masks') - parser.add_argument('--mask-threshold', '-t', type=float, default=0.5, - help='Minimum probability value to consider a mask pixel white') - parser.add_argument('--scale', '-s', type=float, default=0.5, - help='Scale factor for the input images') - parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling') - parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes') - - return parser.parse_args() - - -def get_output_filenames(args): - def _generate_name(fn): - return f'{os.path.splitext(fn)[0]}_OUT.png' - - return args.output or list(map(_generate_name, args.input)) - - -def mask_to_image(mask: np.ndarray, mask_values): - if isinstance(mask_values[0], list): - out = np.zeros((mask.shape[-2], mask.shape[-1], len(mask_values[0])), dtype=np.uint8) - elif mask_values == [0, 1]: - out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=bool) - else: - out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=np.uint8) - - if mask.ndim == 3: - mask = np.argmax(mask, axis=0) - - for i, v in 
enumerate(mask_values): - out[mask == i] = v - - return Image.fromarray(out) - - -if __name__ == '__main__': - args = get_args() - logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') - - in_files = args.input - out_files = get_output_filenames(args) - - net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear) - - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - logging.info(f'Loading model {args.model}') - logging.info(f'Using device {device}') - - net.to(device=device) - state_dict = torch.load(args.model, map_location=device) - mask_values = state_dict.pop('mask_values', [0, 1]) - net.load_state_dict(state_dict) - - logging.info('Model loaded!') - - for i, filename in enumerate(in_files): - logging.info(f'Predicting image {filename} ...') - img = Image.open(filename) - - mask = predict_img(net=net, - full_img=img, - scale_factor=args.scale, - out_threshold=args.mask_threshold, - device=device) - - if not args.no_save: - out_filename = out_files[i] - result = mask_to_image(mask, mask_values) - result.save(out_filename) - logging.info(f'Mask saved to {out_filename}') - - if args.viz: - logging.info(f'Visualizing results for image {filename}, close to continue...') - plot_img_and_mask(img, mask) diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 256fb596a..000000000 --- a/requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -matplotlib==3.6.2 -numpy==1.23.5 -Pillow==9.3.0 -tqdm==4.64.1 -wandb==0.13.5 diff --git a/train.py b/train.py deleted file mode 100644 index 6d51bfe13..000000000 --- a/train.py +++ /dev/null @@ -1,453 +0,0 @@ -# From https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets -from __future__ import print_function, division -import os -import numpy as np -from PIL import Image -import glob -# import SimpleITK as sitk -from torch import optim -import torch.utils.data -import torch -import torch.nn.functional as F - -import torch.nn -import torchvision -import matplotlib.pyplot as plt -import natsort -from torch.utils.data.sampler import SubsetRandomSampler -from Data_Loader import Images_Dataset, Images_Dataset_folder -import torchsummary -# from torch.utils.tensorboard import SummaryWriter -# from tensorboardX import SummaryWriter - - -import shutil -import random -from Modules import UNet_For_Brain -from Losses import calc_loss, dice_loss, threshold_predictions_v, threshold_predictions_p -from Ploting import plot_kernels, LayerActivations, input_images, plot_grad_flow, draw_loss -from Metrics import dice_coeff, accuracy_score -import time - -# from ploting import VisdomLinePlotter -# from visdom import Visdom - - -####################################################### -# Checking if GPU is used -####################################################### - -train_on_gpu = torch.cuda.is_available() - -if not train_on_gpu: - print('CUDA is not available. Training on CPU') -else: - print('CUDA is available. 
Training on GPU') - -# os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3,4' -device = torch.device("cuda:0" if train_on_gpu else "cpu") - -####################################################### -# Setting the basic paramters of the model -####################################################### - -batch_size = 16 -print('batch_size = ' + str(batch_size)) - -valid_size = 0.15 - -epoch = 400 -print('epoch = ' + str(epoch)) - -random_seed = random.randint(1, 100) -print('random_seed = ' + str(random_seed)) - -shuffle = True -valid_loss_min = np.Inf -num_workers = 24 -lossT = [] -lossL = [] -lossL.append(np.inf) -lossT.append(np.inf) -epoch_valid = epoch - 2 -n_iter = 1 -i_valid = 0 - -pin_memory = False -if train_on_gpu: - pin_memory = True - -# plotter = VisdomLinePlotter(env_name='Tutorial Plots') - -####################################################### -# Setting up the model -####################################################### - -model_Inputs = [UNet_For_Brain] - - -def model_unet(model_input, in_channel=1, out_channel=1): - model_test = model_input(in_channel, out_channel) - return model_test - - - -model_test = model_unet(model_Inputs[-1], 3, 1) - -model_test.to(device) - -####################################################### -# Getting the Summary of Model -####################################################### - -torchsummary.summary(model_test, input_size=(3, 128, 128)) - -####################################################### -# Passing the Dataset of Images and Labels -####################################################### - - -# ISIC2018 data -t_data = './ISIC2018/ISIC2018_Task1-2_Training_Input_x2/' -l_data = './ISIC2018/ISIC2018_Task1_Training_GroundTruth_x2/' -test_image = './keras_png_slices_data/keras_png_slices_test/' -test_label = './keras_png_slices_data/keras_png_slices_seg_test/' -test_folderP = './ISIC2018/ISIC2018_Task1-2_Training_Input_x2/*' -test_folderL = './ISIC2018/ISIC2018_Task1_Training_GroundTruth_x2/*' -valid_image = './ISIC2018/ISIC2018_Task1-2_Training_Input_x2/' -valid_lable = './ISIC2018/ISIC2018_Task1_Training_GroundTruth_x2/' - -Training_Data = Images_Dataset_folder(t_data, l_data) - -Validing_Data = Images_Dataset_folder(valid_image, valid_lable) - -####################################################### -# Giving a transformation for input data -####################################################### - -data_transform = torchvision.transforms.Compose([ - torchvision.transforms.Resize((128, 128)), - # torchvision.transforms.CenterCrop(96), - torchvision.transforms.ToTensor(), - # torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -]) - -####################################################### -# Trainging Validation Split -####################################################### - -num_train = len(Training_Data) -indices_train = list(range(num_train)) -# split = int(np.floor(valid_size * num_train)) - -num_valid = len(Validing_Data) -indices_valid = list(range(num_valid)) - -if shuffle: - np.random.seed(random_seed) - np.random.shuffle(indices_train) - np.random.shuffle(indices_valid) - -# train_idx, valid_idx = indices[split:], indices[:split] -train_idx, valid_idx = indices_train, indices_valid -train_sampler = SubsetRandomSampler(train_idx) -valid_sampler = SubsetRandomSampler(valid_idx) - -train_loader = torch.utils.data.DataLoader(Training_Data, batch_size=batch_size, sampler=train_sampler, - num_workers=num_workers, pin_memory=pin_memory, ) - -valid_loader = torch.utils.data.DataLoader(Validing_Data, 
batch_size=batch_size, sampler=valid_sampler, - num_workers=num_workers, pin_memory=pin_memory, ) - -####################################################### -# Using Adam as Optimizer -####################################################### -# Hyper Parameters -initial_lr = 5e-4 -lr_decay = 0.985 -l2_weight_decay = 1e-5 -# initial_lr = 0.001 -# opt = torch.optim.Adam(model_test.parameters(), lr=initial_lr) # try SGD -# # opt = optim.SGD(model_test.parameters(), lr = initial_lr, momentum=0.99) - -# MAX_STEP = int(1e10) -# eta_min = 1e-2 -# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, MAX_STEP, eta_min=1e-5) - -# Optimisation & Loss Settings -opt = optim.Adam(model_test.parameters(), lr=initial_lr, weight_decay=l2_weight_decay) -scheduler = optim.lr_scheduler.ExponentialLR(optimizer=opt, gamma=lr_decay) - -New_folder = './model' - -if os.path.exists(New_folder) and os.path.isdir(New_folder): - shutil.rmtree(New_folder) - -try: - os.mkdir(New_folder) -except OSError: - print("Creation of the main directory '%s' failed " % New_folder) -else: - print("Successfully created the main directory '%s' " % New_folder) - -####################################################### -# Setting the folder of saving the predictions -####################################################### - -read_pred = './model/pred' - -####################################################### -# Checking if prediction folder exixts -####################################################### - -if os.path.exists(read_pred) and os.path.isdir(read_pred): - shutil.rmtree(read_pred) - -try: - os.mkdir(read_pred) -except OSError: - print("Creation of the prediction directory '%s' failed of dice loss" % read_pred) -else: - print("Successfully created the prediction directory '%s' of dice loss" % read_pred) - -####################################################### -# checking if the model exists and if true then delete -####################################################### - -read_model_path = './model/Unet_D_' + str(epoch) + '_' + str(batch_size) - -if os.path.exists(read_model_path) and os.path.isdir(read_model_path): - shutil.rmtree(read_model_path) - print('Model folder there, so deleted for newer one') - -try: - os.mkdir(read_model_path) -except OSError: - print("Creation of the model directory '%s' failed" % read_model_path) -else: - print("Successfully created the model directory '%s' " % read_model_path) - -####################################################### -# Training loop -####################################################### - -for i in range(epoch): - - train_loss = 0.0 - valid_loss = 0.0 - train_loss_list = [] - valid_loss_list = [] - since = time.time() - scheduler.step() - # lr = scheduler.get_lr() - - ####################################################### - # Training Data - ####################################################### - - model_test.train() - k = 1 - - for x, y in train_loader: - # print("x: ", x.shape) - # print("y: ", y.shape) - x, y = x.to(device), y.to(device) - - opt.zero_grad() - - y_pred = model_test(x) - lossT = calc_loss(y_pred, y) # Dice_loss Used - - train_loss += lossT.item() * x.size(0) - lossT.backward() - opt.step() - x_size = lossT.item() * x.size(0) - k = 2 - - train_loss_list.append(train_loss) - - ####################################################### - # Validation Step - ####################################################### - - model_test.eval() - torch.no_grad() # to increase the validation process uses less memory - - for x1, y1 in valid_loader: - x1, y1 = 
x1.to(device), y1.to(device) - - y_pred1 = model_test(x1) - lossL = calc_loss(y_pred1, y1) # Dice_loss Used - - valid_loss += lossL.item() * x1.size(0) - x_size1 = lossL.item() * x1.size(0) - - valid_loss_list.append(valid_loss) - - if (i + 1) % 1 == 0: - print('Epoch: {}/{} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(i + 1, epoch, train_loss, - valid_loss)) - ####################################################### - # Early Stopping - ####################################################### - - if valid_loss <= valid_loss_min and epoch_valid >= i: # and i_valid <= 2: - - print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model '.format(valid_loss_min, valid_loss)) - torch.save(model_test.state_dict(), './model/Unet_D_' + - str(epoch) + '_' + str(batch_size) + '/Unet_epoch_' + str(epoch) - + '_batchsize_' + str(batch_size) + '.pth') - - if round(valid_loss, 4) == round(valid_loss_min, 4): - print(i_valid) - i_valid = i_valid + 1 - valid_loss_min = valid_loss - -if torch.cuda.is_available(): - torch.cuda.empty_cache() - -####################################################### -# Loading the model -####################################################### - -model_test.load_state_dict(torch.load('./model/Unet_D_' + - str(epoch) + '_' + str(batch_size) + '/Unet_epoch_' + str(epoch) - + '_batchsize_' + str(batch_size) + '.pth')) - -model_test.eval() - -####################################################### -# opening the test folder and creating a folder for generated images -####################################################### - -read_test_folder = glob.glob(test_folderP) -x_sort_test = natsort.natsorted(read_test_folder) # To sort - -read_test_folder112 = './model/gen_images' - -if os.path.exists(read_test_folder112) and os.path.isdir(read_test_folder112): - shutil.rmtree(read_test_folder112) - -try: - os.mkdir(read_test_folder112) -except OSError: - print("Creation of the testing directory %s failed" % read_test_folder112) -else: - print("Successfully created the testing directory %s " % read_test_folder112) - -# For Prediction Threshold - -read_test_folder_P_Thres = './model/pred_threshold' - -if os.path.exists(read_test_folder_P_Thres) and os.path.isdir(read_test_folder_P_Thres): - shutil.rmtree(read_test_folder_P_Thres) - -try: - os.mkdir(read_test_folder_P_Thres) -except OSError: - print("Creation of the testing directory %s failed" % read_test_folder_P_Thres) -else: - print("Successfully created the testing directory %s " % read_test_folder_P_Thres) - -# For Label Threshold - -read_test_folder_L_Thres = './model/label_threshold' - -if os.path.exists(read_test_folder_L_Thres) and os.path.isdir(read_test_folder_L_Thres): - shutil.rmtree(read_test_folder_L_Thres) - -try: - os.mkdir(read_test_folder_L_Thres) -except OSError: - print("Creation of the testing directory %s failed" % read_test_folder_L_Thres) -else: - print("Successfully created the testing directory %s " % read_test_folder_L_Thres) - -####################################################### -# saving the images in the files -####################################################### - -img_test_no = 0 - -for i in range(len(read_test_folder)): - im = Image.open(x_sort_test[i]).convert("RGB") - - im1 = im - im_n = np.array(im1) - im_n_flat = im_n.reshape(-1, 1) - - for j in range(im_n_flat.shape[0]): - if im_n_flat[j] != 0: - im_n_flat[j] = 255 - - s = data_transform(im) - pred = model_test(s.unsqueeze(0).cuda()).cpu() - pred = F.sigmoid(pred) - pred = pred.detach().numpy() - - # pred = 
threshold_predictions_p(pred) #Value kept 0.01 as max is 1 and noise is very small. - - if i % 24 == 0: - img_test_no = img_test_no + 1 - - x1 = plt.imsave('./model/gen_images/im_epoch_' + str(epoch) + 'int_' + str(i) - + '_img_no_' + str(img_test_no) + '.png', pred[0][0], cmap='gray') - -#################################################### -# Calculating the Dice Score -#################################################### - -data_transform = torchvision.transforms.Compose([ - torchvision.transforms.Resize((128, 128)), - # torchvision.transforms.CenterCrop(96), - torchvision.transforms.Grayscale(), - # torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -]) - -read_test_folderP = glob.glob('./model/gen_images/*') -x_sort_testP = natsort.natsorted(read_test_folderP) - -read_test_folderL = glob.glob(test_folderL) -x_sort_testL = natsort.natsorted(read_test_folderL) # To sort - -dice_score123 = 0.0 -x_count = 0 -x_dice = 0 - -for i in range(len(read_test_folderP)): - - x = Image.open(x_sort_testP[i]).convert("L") - s = data_transform(x) - s = np.array(s) - s = threshold_predictions_v(s) - # print(s) - # print("------------------") - - # save the images - x1 = plt.imsave('./model/pred_threshold/im_epoch_' + str(epoch) + 'int_' + str(i) - + '_img_no_' + str(img_test_no) + '.png', s) - - y = Image.open(x_sort_testL[i]).convert("L") - s2 = data_transform(y) - s3 = np.array(s2) - s2 = threshold_predictions_v(s2) - # print(s3) - - # save the Images - y1 = plt.imsave('./model/label_threshold/im_epoch_' + str(epoch) + 'int_' + str(i) - + '_img_no_' + str(img_test_no) + '.png', s3) - - total = dice_coeff(s, s3) - print(total) - - if total <= 0.8: - x_count += 1 - if total > 0.8: - x_dice = x_dice + total - dice_score123 = dice_score123 + total - -# print('Dice Score : ' + str(dice_score123/len(read_test_folderP))) -print(x_count) -print(x_dice) -print('Dice Score : ' + str(float(x_dice / (len(read_test_folderP) - x_count)))) - diff --git a/utils.py b/utils.py deleted file mode 100644 index 9f3b54559..000000000 --- a/utils.py +++ /dev/null @@ -1,18 +0,0 @@ -import matplotlib.pyplot as plt - - -def plot_img_and_mask(img, mask): - classes = mask.max() + 1 - fig, ax = plt.subplots(1, classes + 1) - ax[0].set_title('Input image') - ax[0].imshow(img) - for i in range(classes): - ax[i + 1].set_title(f'Mask (class {i + 1})') - ax[i + 1].imshow(mask == i) - plt.xticks([]), plt.yticks([]) - - - plt.show() - - -# Updated \ No newline at end of file From f5d66e17d1115a56d36dfb00c1bde59e037a42f6 Mon Sep 17 00:00:00 2001 From: FrostNov4 <97688822+FrostNov4@users.noreply.github.com> Date: Fri, 27 Oct 2023 13:51:25 +1000 Subject: [PATCH 19/20] Update README.md --- README.md | 40 +++++++++++++++------------------------- 1 file changed, 15 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 5a62afec2..5c09f0a41 100644 --- a/README.md +++ b/README.md @@ -1,30 +1,20 @@ -## COMP_3710_Report -## Medical condition, Extended to 27 OCT +# Pattern Analysis +Pattern Analysis of various datasets by COMP3710 students at the University of Queensland. +We create pattern recognition and image processing library for Tensorflow (TF), PyTorch or JAX. + +This library is created and maintained by The University of Queensland [COMP3710](https://my.uq.edu.au/programs-courses/course.html?course_code=comp3710) students. 
+ +The library includes the following implemented in Tensorflow: +* fractals +* recognition problems -## This report is focused on the first task (a) - Segment the ISIC data set with the Improved UNet - with all labels having a minimum Dice similarity coefficient of 0.8 on the test set. - -## The structure of Improved UNet is based on the paper. - "Brain Tumor Segmentation and Radiomics Survival Prediction: Contribution to the BRATS 2017 Challenge" - https://arxiv.org/abs/1802.10508v1 - - U-Net is a convolutional neural network architecture primarily used for biomedical image segmentation. - Its U-shaped structure consists of a contracting path, which captures context, and an expansive path, - which enables precise localization. Through skip connections, features from the contracting path are concatenated - with the expansive path, enhancing localization capabilities. - -## Dataset - In this report, the ISIC 2018 dataset will be used. - The ISIC 2018 dataset is a publicly available dataset for skin lesion image segmentation, - provided by the International Skin Imaging Collaboration (ISIC). Given that the real-world - images in the dataset come in different sizes, they are uniformly resized to a 128x128 dimension. - These images use RGB with 3 color channels for input. The label data, which indicates where the lesions are, - is treated in the same way as the real data. However, these labels are input as grayscale images with a single channel, - making them simpler and more focused on the lesion's location and shape. - -## Model +In the recognition folder, you will find many recognition problems solved including: +* OASIS brain segmentation +* Classification +etc. + +## Medical condition, Extended to 27 OCT From fbd3f830627d5644d6226c7e23cf753d74107187 Mon Sep 17 00:00:00 2001 From: FrostNov4 <97688822+FrostNov4@users.noreply.github.com> Date: Fri, 27 Oct 2023 13:51:50 +1000 Subject: [PATCH 20/20] Update README.md --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 5c09f0a41..8ef13980c 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,6 @@ In the recognition folder, you will find many recognition problems solved includ * Classification etc. -## Medical condition, Extended to 27 OCT
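Since the report this series builds toward targets a minimum Dice similarity coefficient of 0.8, a small worked example of the metric may be useful. This mirrors the `dice_coeff` recorded in the deleted `Metrics.py`, with the deprecated `np.bool` replaced by the builtin `bool`:

```python
# Dice = 2|A ∩ B| / (|A| + |B|) on binary masks.
import numpy as np

def dice_coeff(im1, im2, empty_score=1.0):
    im1 = np.asarray(im1).astype(bool)
    im2 = np.asarray(im2).astype(bool)
    if im1.shape != im2.shape:
        raise ValueError("Shape mismatch: im1 and im2 must have the same shape.")
    im_sum = im1.sum() + im2.sum()
    if im_sum == 0:
        return empty_score  # both masks empty: count as a perfect match
    return 2.0 * np.logical_and(im1, im2).sum() / im_sum

a = np.zeros((4, 4), dtype=bool); a[:, :2] = True   # left half of the grid
b = np.zeros((4, 4), dtype=bool); b[:, 1:3] = True  # middle half of the grid
print(dice_coeff(a, b))  # 2*4 / (8+8) = 0.5, well below the 0.8 target
```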