Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/cells.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ This page documents all custom recurrent cells provided in the `torchrecurrent.c
torchrecurrent.OriginalLSTMCell
torchrecurrent.PeepholeLSTMCell
torchrecurrent.RANCell
torchrecurrent.ResLSTMCell
torchrecurrent.SCRNCell
torchrecurrent.SGUCell
torchrecurrent.SGRNCell
Expand Down
1 change: 1 addition & 0 deletions docs/api/layers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ This page documents all custom recurrent layers provided in the `torchrecurrent`
torchrecurrent.OriginalLSTM
torchrecurrent.PeepholeLSTM
torchrecurrent.RAN
torchrecurrent.ResLSTM
torchrecurrent.SCRN
torchrecurrent.SGU
torchrecurrent.SGRN
Expand Down
8 changes: 8 additions & 0 deletions docs/generated/torchrecurrent.ResLSTM.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
torchrecurrent.ResLSTM
======================

.. currentmodule:: torchrecurrent

.. autoclass:: ResLSTM

.. automethod:: __init__
8 changes: 8 additions & 0 deletions docs/generated/torchrecurrent.ResLSTMCell.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
torchrecurrent.ResLSTMCell
==========================

.. currentmodule:: torchrecurrent

.. autoclass:: ResLSTMCell

.. automethod:: __init__
3 changes: 3 additions & 0 deletions docs/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ references and official implementations where available.
* - :doc:`RAN <generated/torchrecurrent.RAN>`
- `arXiv 2017 <https://arxiv.org/abs/1705.07393>`__
- `kentonl/ran <https://github.com/kentonl/ran>`__
* - :doc:`ResLSTM <generated/torchrecurrent.ResLSTM>`
- `arXiv 2017 <https://arxiv.org/abs/1701.03360>`__
- –
* - :doc:`SCRN <generated/torchrecurrent.SCRN>`
- `ICLR 2015 <https://arxiv.org/abs/1412.7753>`__
- `facebookarchive/SCRNNs <https://github.com/facebookarchive/SCRNNs>`__
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "torchrecurrent"
version = "0.2.1"
version = "0.2.2"
description = "A package for recurrent neural networks in PyTorch"
readme = "README.md"
authors = [
Expand Down
12 changes: 12 additions & 0 deletions tests/test_cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
OriginalLSTMCell,
PeepholeLSTMCell,
RANCell,
ResLSTMCell,
SCRNCell,
SGUCell,
SGRNCell,
Expand Down Expand Up @@ -68,6 +69,7 @@
(PeepholeLSTMCell, 5, 10, True),
(OriginalLSTMCell, 3, 5, True),
(RANCell, 4, 9, True),
(ResLSTMCell, 4, 9, True),
(SCRNCell, 3, 5, True),
(SGUCell, 3, 5, False),
(SGRNCell, 3, 5, False),
Expand Down Expand Up @@ -120,6 +122,16 @@ def test_cell_output_and_state_shapes(Cell, in_size, hid_size, double):
assert h3.shape == (B, hid_size)


def test_reslstm_cell_parameter_shapes():
cell = ResLSTMCell(4, 9)

assert cell.weight_ih.shape == (36, 4)
assert cell.weight_hh.shape == (36, 9)
assert cell.weight_proj.shape == (9, 9)
assert cell.weight_res.shape == (9, 4)
assert cell.weight_ph.shape == (27,)


@pytest.mark.parametrize("Cell, in_size, hid_size, _", CELL_CASES)
def test_cell_gradients(Cell, in_size, hid_size, _):
"""A quick smoke test: outputs should be differentiable wrt parameters."""
Expand Down
3 changes: 3 additions & 0 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
OriginalLSTM,
PeepholeLSTM,
RAN,
ResLSTM,
SCRN,
SGU,
SGRN,
Expand Down Expand Up @@ -57,6 +58,7 @@
OriginalLSTM,
PeepholeLSTM,
RAN,
ResLSTM,
SCRN,
SGU,
SGRN,
Expand Down Expand Up @@ -90,6 +92,7 @@
(OriginalLSTM, True),
(PeepholeLSTM, True),
(RAN, True),
(ResLSTM, True),
(SCRN, True),
(SGU, False),
(SGRN, False),
Expand Down
4 changes: 4 additions & 0 deletions torchrecurrent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
OriginalLSTMCell,
PeepholeLSTMCell,
RANCell,
ResLSTMCell,
coRNNCell,
SCRNCell,
SGUCell,
Expand Down Expand Up @@ -64,6 +65,7 @@
OriginalLSTM,
PeepholeLSTM,
RAN,
ResLSTM,
coRNN,
SCRN,
SGU,
Expand Down Expand Up @@ -127,6 +129,8 @@
"PeepholeLSTMCell",
"RAN",
"RANCell",
"ResLSTM",
"ResLSTMCell",
"SCRN",
"SCRNCell",
"SGU",
Expand Down
3 changes: 3 additions & 0 deletions torchrecurrent/cells/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .originallstm_cell import OriginalLSTM, OriginalLSTMCell
from .peepholelstm_cell import PeepholeLSTM, PeepholeLSTMCell
from .ran_cell import RAN, RANCell
from .reslstm_cell import ResLSTM, ResLSTMCell
from .scrn_cell import SCRN, SCRNCell
from .sgu_cell import DSGU, DSGUCell, SGU, SGUCell
from .sgrn_cell import SGRN, SGRNCell
Expand Down Expand Up @@ -86,6 +87,8 @@
"PeepholeLSTMCell",
"RAN",
"RANCell",
"ResLSTM",
"ResLSTMCell",
"SCRN",
"SCRNCell",
"DSGU",
Expand Down
Loading
Loading